From af081211ee93622473ee575de30fed2fd8263c09 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 11 Jul 2023 21:16:43 +0300 Subject: getting SD2.1 to run on SDXL repo --- modules/sd_models_xl.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 modules/sd_models_xl.py (limited to 'modules/sd_models_xl.py') diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py new file mode 100644 index 00000000..d43b8868 --- /dev/null +++ b/modules/sd_models_xl.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import torch + +import sgm.models.diffusion +import sgm.modules.diffusionmodules.denoiser_scaling +import sgm.modules.diffusionmodules.discretizer +from modules import devices + + +def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: list[str]): + for embedder in self.conditioner.embedders: + embedder.ucg_rate = 0.0 + + c = self.conditioner({'txt': batch}) + + return c + + +def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): + return self.model(x, t, cond) + + +def extend_sdxl(model): + dtype = next(model.model.diffusion_model.parameters()).dtype + model.model.diffusion_model.dtype = dtype + model.model.conditioning_key = 'crossattn' + + model.cond_stage_model = [x for x in model.conditioner.embedders if type(x).__name__ == 'FrozenOpenCLIPEmbedder'][0] + model.cond_stage_key = model.cond_stage_model.input_key + + model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" + + discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() + model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) + + +sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning +sgm.models.diffusion.DiffusionEngine.apply_model = apply_model + -- cgit v1.2.3