diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-08-28 22:58:15 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-08-28 22:58:15 +0000 |
commit | 9c9f048b5e3128f06ebdee12581f6c00692b2973 (patch) | |
tree | aca2b2fbb5e5d4cdbd6ed7a57ec8736cb0ba1e43 /webui.py | |
parent | 7a7a3a6b194578097b054ac0e58e15b2142ac106 (diff) | |
download | stable-diffusion-webui-gfx803-9c9f048b5e3128f06ebdee12581f6c00692b2973.tar.gz stable-diffusion-webui-gfx803-9c9f048b5e3128f06ebdee12581f6c00692b2973.tar.bz2 stable-diffusion-webui-gfx803-9c9f048b5e3128f06ebdee12581f6c00692b2973.zip |
support for generating images on video cards with 4GB
Diffstat (limited to 'webui.py')
-rw-r--r-- | webui.py | 90 |
1 files changed, 86 insertions, 4 deletions
@@ -2,6 +2,8 @@ import argparse import os
import sys
from collections import namedtuple
+from contextlib import nullcontext
+
import torch
import torch.nn as nn
import numpy as np
@@ -51,6 +53,7 @@ parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default='embeddings', help="embeddings dirtectory for textual inversion (default: embeddings)")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
+parser.add_argument("--lowvram", action='store_true', help="enamble optimizations for low vram")
cmd_opts = parser.parse_args()
@@ -185,11 +188,80 @@ def load_model_from_config(config, ckpt, verbose=False): print("unexpected keys:")
print(u)
- model.cuda()
model.eval()
return model
+module_in_gpu = None
+
+
+def setup_for_low_vram(sd_model):
+ parents = {}
+
+ def send_me_to_gpu(module, _):
+ """send this module to GPU; send whatever tracked module was previous in GPU to CPU;
+ we add this as forward_pre_hook to a lot of modules and this way all but one of them will
+ be in CPU
+ """
+ global module_in_gpu
+
+ module = parents.get(module, module)
+
+ if module_in_gpu == module:
+ return
+
+ if module_in_gpu is not None:
+ print('removing from gpu:', type(module_in_gpu))
+ module_in_gpu.to(cpu)
+
+ print('adding to gpu:', type(module))
+ module.to(gpu)
+
+ print('added to gpu:', type(module))
+ module_in_gpu = module
+
+ # see below for register_forward_pre_hook;
+ # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
+ # useless here, and we just replace those methods
+ def first_stage_model_encode_wrap(self, encoder, x):
+ send_me_to_gpu(self, None)
+ return encoder(x)
+
+ def first_stage_model_decode_wrap(self, decoder, z):
+ send_me_to_gpu(self, None)
+ return decoder(z)
+
+ # remove three big modules, cond, first_stage, and unet from the model and then
+ # send the model to GPU. Then put modules back. the modules will be in CPU.
+ stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
+ sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
+ sd_model.to(device)
+ sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
+
+ # register hooks for those the first two models
+ sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
+ sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
+ sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
+ sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
+ parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
+
+ # the third remaining model is still too big for 4GB, so we also do the same for its submodules
+ # so that only one of them is in GPU at a time
+ diff_model = sd_model.model.diffusion_model
+ stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
+ diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
+ sd_model.model.to(device)
+ diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
+
+ # install hooks for bits of third model
+ diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
+ for block in diff_model.input_blocks:
+ block.register_forward_pre_hook(send_me_to_gpu)
+ diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
+ for block in diff_model.output_blocks:
+ block.register_forward_pre_hook(send_me_to_gpu)
+
+
def create_random_tensors(shape, seeds):
xs = []
for seed in seeds:
@@ -838,7 +910,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, model)
output_images = []
- with torch.no_grad(), autocast("cuda"), model.ema_scope():
+ ema_scope = (nullcontext if cmd_opts.lowvram else model.ema_scope)
+ with torch.no_grad(), autocast("cuda"), ema_scope():
p.init()
for n in range(p.n_iter):
@@ -1327,8 +1400,17 @@ interfaces = [ sd_config = OmegaConf.load(cmd_opts.config)
sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
-device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-sd_model = (sd_model if cmd_opts.no_half else sd_model.half()).to(device)
+cpu = torch.device("cpu")
+gpu = torch.device("cuda")
+device = gpu if torch.cuda.is_available() else cpu
+
+sd_model = (sd_model if cmd_opts.no_half else sd_model.half())
+
+if not cmd_opts.lowvram:
+ sd_model = sd_model.to(device)
+
+else:
+ setup_for_low_vram(sd_model)
model_hijack = StableDiffusionModelHijack()
model_hijack.hijack(sd_model)
|