diff options
author | Aarni Koskela <akx@iki.fi> | 2023-05-11 15:28:15 +0000 |
---|---|---|
committer | Aarni Koskela <akx@iki.fi> | 2023-05-11 17:29:11 +0000 |
commit | 49a55b410b66b7dd9be9335d8a2e3a71e4f8b15c (patch) | |
tree | d79f004eae46bc1c49832f3c668a524107c30034 /modules/sd_hijack.py | |
parent | 431bc5a297ff7c17231b92b6c8f8152b2fab8553 (diff) | |
download | stable-diffusion-webui-gfx803-49a55b410b66b7dd9be9335d8a2e3a71e4f8b15c.tar.gz stable-diffusion-webui-gfx803-49a55b410b66b7dd9be9335d8a2e3a71e4f8b15c.tar.bz2 stable-diffusion-webui-gfx803-49a55b410b66b7dd9be9335d8a2e3a71e4f8b15c.zip |
Autofix Ruff W (not W605) (mostly whitespace)
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r-- | modules/sd_hijack.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index e374aeb8..7e50f1ab 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -34,7 +34,7 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
-
+
optimization_method = None
can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp
@@ -92,12 +92,12 @@ def fix_checkpoint(): def weighted_loss(sd_model, pred, target, mean=True):
#Calculate the weight normally, but ignore the mean
loss = sd_model._old_get_loss(pred, target, mean=False)
-
+
#Check if we have weights available
weight = getattr(sd_model, '_custom_loss_weight', None)
if weight is not None:
loss *= weight
-
+
#Return the loss, as mean if specified
return loss.mean() if mean else loss
@@ -105,7 +105,7 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs): try:
#Temporarily append weights to a place accessible during loss calc
sd_model._custom_loss_weight = w
-
+
#Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
#Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
if not hasattr(sd_model, '_old_get_loss'):
@@ -120,7 +120,7 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs): del sd_model._custom_loss_weight
except AttributeError:
pass
-
+
#If we have an old loss function, reset the loss function to the original one
if hasattr(sd_model, '_old_get_loss'):
sd_model.get_loss = sd_model._old_get_loss
@@ -184,7 +184,7 @@ class StableDiffusionModelHijack: def undo_hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
- m.cond_stage_model = m.cond_stage_model.wrapped
+ m.cond_stage_model = m.cond_stage_model.wrapped
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped
|