diff options
author | Sihan Wang <31711261+shwang95@users.noreply.github.com> | 2022-11-02 06:09:33 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-11-02 06:09:33 +0000 |
commit | 5c864be010b42373e51eac4c6869d561adca4202 (patch) | |
tree | 1de5f558ab056bc766ac51d3612356f0de52b1f8 /modules/lowvram.py | |
parent | 7bd8581e461623932ffbd5762ee931ee51f798db (diff) | |
parent | 65522ff157e4be4095a99421da04ecb0749824ac (diff) | |
download | stable-diffusion-webui-gfx803-5c864be010b42373e51eac4c6869d561adca4202.tar.gz stable-diffusion-webui-gfx803-5c864be010b42373e51eac4c6869d561adca4202.tar.bz2 stable-diffusion-webui-gfx803-5c864be010b42373e51eac4c6869d561adca4202.zip |
Merge branch 'AUTOMATIC1111:master' into master
Diffstat (limited to 'modules/lowvram.py')
-rw-r--r-- | modules/lowvram.py | 21 |
1 files changed, 13 insertions, 8 deletions
diff --git a/modules/lowvram.py b/modules/lowvram.py index f327c3df..a4652cb1 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -38,13 +38,18 @@ def setup_for_low_vram(sd_model, use_medvram): # see below for register_forward_pre_hook;
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
# useless here, and we just replace those methods
- def first_stage_model_encode_wrap(self, encoder, x):
- send_me_to_gpu(self, None)
- return encoder(x)
- def first_stage_model_decode_wrap(self, decoder, z):
- send_me_to_gpu(self, None)
- return decoder(z)
+ first_stage_model = sd_model.first_stage_model
+ first_stage_model_encode = sd_model.first_stage_model.encode
+ first_stage_model_decode = sd_model.first_stage_model.decode
+
+ def first_stage_model_encode_wrap(x):
+ send_me_to_gpu(first_stage_model, None)
+ return first_stage_model_encode(x)
+
+ def first_stage_model_decode_wrap(z):
+ send_me_to_gpu(first_stage_model, None)
+ return first_stage_model_decode(z)
# remove three big modules, cond, first_stage, and unet from the model and then
# send the model to GPU. Then put modules back. the modules will be in CPU.
@@ -56,8 +61,8 @@ def setup_for_low_vram(sd_model, use_medvram): # register hooks for those the first two models
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
- sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
- sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
+ sd_model.first_stage_model.encode = first_stage_model_encode_wrap
+ sd_model.first_stage_model.decode = first_stage_model_decode_wrap
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
if use_medvram:
|