diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/devices.py | 9 | ||||
-rw-r--r-- | modules/extras.py | 8 | ||||
-rw-r--r-- | modules/import_hook.py | 5 | ||||
-rw-r--r-- | modules/safe.py | 12 | ||||
-rw-r--r-- | modules/sd_hijack_optimizations.py | 10 | ||||
-rw-r--r-- | modules/sd_samplers.py | 29 | ||||
-rw-r--r-- | modules/shared.py | 1 | ||||
-rw-r--r-- | modules/ui_extensions.py | 15 |
8 files changed, 64 insertions, 25 deletions
diff --git a/modules/devices.py b/modules/devices.py index f8cffae1..800510b7 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -125,7 +125,16 @@ def layer_norm_fix(*args, **kwargs): return orig_layer_norm(*args, **kwargs) +# MPS workaround for https://github.com/pytorch/pytorch/issues/90532 +orig_tensor_numpy = torch.Tensor.numpy +def numpy_fix(self, *args, **kwargs): + if self.requires_grad: + self = self.detach() + return orig_tensor_numpy(self, *args, **kwargs) + + # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working if has_mps() and version.parse(torch.__version__) < version.parse("1.13"): torch.Tensor.to = tensor_to_fix torch.nn.functional.layer_norm = layer_norm_fix + torch.Tensor.numpy = numpy_fix diff --git a/modules/extras.py b/modules/extras.py index 0ad8deec..704e5165 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -193,8 +193,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ else:
basename = ''
+ # Add upscaler name as a suffix.
+ suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
+ # Add second upscaler if applicable.
+ if suffix and extras_upscaler_2 and extras_upscaler_2_visibility:
+ suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}"
+
images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
- no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
+ no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)
if opts.enable_pnginfo:
image.info = existing_pnginfo
diff --git a/modules/import_hook.py b/modules/import_hook.py new file mode 100644 index 00000000..28c67dfa --- /dev/null +++ b/modules/import_hook.py @@ -0,0 +1,5 @@ +import sys + +# this will break any attempt to import xformers which will prevent stability diffusion repo from trying to use it +if "--xformers" not in "".join(sys.argv): + sys.modules["xformers"] = None diff --git a/modules/safe.py b/modules/safe.py index 10460ad0..7c89c4c2 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -37,16 +37,16 @@ class RestrictedUnpickler(pickle.Unpickler): if module == 'collections' and name == 'OrderedDict':
return getattr(collections, name)
- if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
+ if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
return getattr(torch._utils, name)
- if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage']:
+ if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
return getattr(torch, name)
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
return getattr(torch.nn.modules.container, name)
- if module == 'numpy.core.multiarray' and name == 'scalar':
- return numpy.core.multiarray.scalar
- if module == 'numpy' and name == 'dtype':
- return numpy.dtype
+ if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
+ return getattr(numpy.core.multiarray, name)
+ if module == 'numpy' and name in ['dtype', 'ndarray']:
+ return getattr(numpy, name)
if module == '_codecs' and name == 'encode':
return encode
if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 98123fbf..02c87f40 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -127,7 +127,7 @@ def check_for_psutil(): invokeAI_mps_available = check_for_psutil()
-# -- Taken from https://github.com/invoke-ai/InvokeAI --
+# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
if invokeAI_mps_available:
import psutil
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
@@ -152,14 +152,16 @@ def einsum_op_slice_1(q, k, v, slice_size): return r
def einsum_op_mps_v1(q, k, v):
- if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
+ if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
return einsum_op_compvis(q, k, v)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+ if slice_size % 4096 == 0:
+ slice_size -= 1
return einsum_op_slice_1(q, k, v, slice_size)
def einsum_op_mps_v2(q, k, v):
- if mem_total_gb > 8 and q.shape[1] <= 4096:
+ if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
return einsum_op_compvis(q, k, v)
else:
return einsum_op_slice_0(q, k, v, 1)
@@ -188,7 +190,7 @@ def einsum_op(q, k, v): return einsum_op_cuda(q, k, v)
if q.device.type == 'mps':
- if mem_total_gb >= 32:
+ if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
return einsum_op_mps_v1(q, k, v)
return einsum_op_mps_v2(q, k, v)
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 4c123d3b..d26e48dc 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -23,16 +23,16 @@ samplers_k_diffusion = [ ('Euler', 'sample_euler', ['k_euler'], {}),
('LMS', 'sample_lms', ['k_lms'], {}),
('Heun', 'sample_heun', ['k_heun'], {}),
- ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}),
- ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
+ ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
+ ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
- ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
- ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
+ ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
+ ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
@@ -444,9 +444,7 @@ class KDiffusionSampler: return extra_params_kwargs
- def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
- steps, t_enc = setup_img2img_steps(p, steps)
-
+ def get_sigmas(self, p, steps):
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
@@ -454,6 +452,16 @@ class KDiffusionSampler: else:
sigmas = self.model_wrap.get_sigmas(steps)
+ if self.config is not None and self.config.options.get('discard_next_to_last_sigma', False):
+ sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
+
+ return sigmas
+
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ steps, t_enc = setup_img2img_steps(p, steps)
+
+ sigmas = self.get_sigmas(p, steps)
+
sigma_sched = sigmas[steps - t_enc - 1:]
xi = x + noise * sigma_sched[0]
@@ -485,12 +493,7 @@ class KDiffusionSampler: def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
steps = steps or p.steps
- if p.sampler_noise_scheduler_override:
- sigmas = p.sampler_noise_scheduler_override(steps)
- elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
- sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
- else:
- sigmas = self.model_wrap.get_sigmas(steps)
+ sigmas = self.get_sigmas(p, steps)
x = x * sigmas[0]
diff --git a/modules/shared.py b/modules/shared.py index 272267c1..215c1358 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -293,6 +293,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
+ "use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 1434f25f..eec9586f 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -9,6 +9,8 @@ import git import gradio as gr
import html
+import shutil
+import errno
from modules import extensions, shared, paths
@@ -138,7 +140,18 @@ def install_extension_from_url(dirname, url): repo = git.Repo.clone_from(url, tmpdir)
repo.remote().fetch()
- os.rename(tmpdir, target_dir)
+ try:
+ os.rename(tmpdir, target_dir)
+ except OSError as err:
+ # TODO what does this do on windows? I think it'll be a different error code but I don't have a system to check it
+ # Shouldn't cause any new issues at least but we probably want to handle it there too.
+ if err.errno == errno.EXDEV:
+ # Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems
+ # Since we can't use a rename, do the slower but more versitile shutil.move()
+ shutil.move(tmpdir, target_dir)
+ else:
+ # Something else, not enough free space, permissions, etc. rethrow it so that it gets handled.
+ raise(err)
import launch
launch.run_extension_installer(target_dir)
|