diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2022-11-06 08:27:54 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-11-06 08:27:54 +0000 |
commit | 07d1bd426722b4c53b38ff682c5aab53177d8530 (patch) | |
tree | 4fdad803a4536cec2bd3e622c5f4cfb980f04550 /modules/lowvram.py | |
parent | 3f3d14afd5abd07d3843370dc1c28be299dbdbab (diff) | |
parent | 6e4de5b4422dfc0d45063b2c8c78b19f00321615 (diff) | |
download | stable-diffusion-webui-gfx803-07d1bd426722b4c53b38ff682c5aab53177d8530.tar.gz stable-diffusion-webui-gfx803-07d1bd426722b4c53b38ff682c5aab53177d8530.tar.bz2 stable-diffusion-webui-gfx803-07d1bd426722b4c53b38ff682c5aab53177d8530.zip |
Merge branch 'master' into roy.add_simple_interrogate_api
Diffstat (limited to 'modules/lowvram.py')
-rw-r--r-- | modules/lowvram.py | 21 |
1 files changed, 13 insertions, 8 deletions
diff --git a/modules/lowvram.py b/modules/lowvram.py index f327c3df..a4652cb1 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -38,13 +38,18 @@ def setup_for_low_vram(sd_model, use_medvram): # see below for register_forward_pre_hook;
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
# useless here, and we just replace those methods
- def first_stage_model_encode_wrap(self, encoder, x):
- send_me_to_gpu(self, None)
- return encoder(x)
- def first_stage_model_decode_wrap(self, decoder, z):
- send_me_to_gpu(self, None)
- return decoder(z)
+ first_stage_model = sd_model.first_stage_model
+ first_stage_model_encode = sd_model.first_stage_model.encode
+ first_stage_model_decode = sd_model.first_stage_model.decode
+
+ def first_stage_model_encode_wrap(x):
+ send_me_to_gpu(first_stage_model, None)
+ return first_stage_model_encode(x)
+
+ def first_stage_model_decode_wrap(z):
+ send_me_to_gpu(first_stage_model, None)
+ return first_stage_model_decode(z)
# remove three big modules, cond, first_stage, and unet from the model and then
# send the model to GPU. Then put modules back. the modules will be in CPU.
@@ -56,8 +61,8 @@ def setup_for_low_vram(sd_model, use_medvram): # register hooks for those the first two models
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
- sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
- sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
+ sd_model.first_stage_model.encode = first_stage_model_encode_wrap
+ sd_model.first_stage_model.decode = first_stage_model_decode_wrap
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
if use_medvram:
|