From bf5067f50ca32cd4764638702e3cc38bca8bfd8b Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 25 Oct 2023 12:54:28 +0800 Subject: Fix alphas cumprod --- modules/sd_models_xl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_models_xl.py') diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 01123321..11259a36 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -93,7 +93,7 @@ def extend_sdxl(model): model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() - model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) + model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32) model.conditioner.wrapped = torch.nn.Module() -- cgit v1.2.3 From 9feb034e343d6d7ef63395821658fb3774b30a24 Mon Sep 17 00:00:00 2001 From: wangqyqq Date: Thu, 21 Dec 2023 20:15:51 +0800 Subject: support for sdxl-inpaint model --- configs/sd_xl_inpaint.yaml | 98 +++++++++++++++++++++++++++++++++++++++++++++ modules/processing.py | 19 +++++++++ modules/sd_models_config.py | 6 ++- modules/sd_models_xl.py | 5 +++ 4 files changed, 127 insertions(+), 1 deletion(-) create mode 100644 configs/sd_xl_inpaint.yaml (limited to 'modules/sd_models_xl.py') diff --git a/configs/sd_xl_inpaint.yaml b/configs/sd_xl_inpaint.yaml new file mode 100644 index 00000000..3bad3721 --- /dev/null +++ b/configs/sd_xl_inpaint.yaml @@ -0,0 +1,98 @@ +model: + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.13025 + disable_first_stage_autocast: True + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + adm_in_channels: 2816 + num_classes: sequential + use_checkpoint: True + in_channels: 9 + out_channels: 4 + model_channels: 320 + attention_resolutions: [4, 2] + num_res_blocks: 2 + channel_mult: [1, 2, 4] + num_head_channels: 64 + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16 + context_dim: 2048 + spatial_transformer_attn_type: softmax-xformers + legacy: False + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + # crossattn cond + - is_trainable: False + input_key: txt + target: sgm.modules.encoders.modules.FrozenCLIPEmbedder + params: + layer: hidden + layer_idx: 11 + # crossattn and vector cond + - is_trainable: False + input_key: txt + target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 + params: + arch: ViT-bigG-14 + version: laion2b_s39b_b160k + freeze: True + layer: penultimate + always_return_pooled: True + legacy: False + # vector cond + - is_trainable: False + input_key: original_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + # vector cond + - is_trainable: False + input_key: crop_coords_top_left + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + # vector cond + - is_trainable: False + input_key: target_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + + first_stage_config: + target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity diff --git a/modules/processing.py b/modules/processing.py index 6f01c95f..159548db 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -106,6 +106,20 @@ def txt2img_image_conditioning(sd_model, x, width, height): return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device) else: + sd = sd_model.model.state_dict() + diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) + if diffusion_model_input.shape[1] == 9: + # The "masked-image" in this case will just be all 0.5 since the entire image is masked. + image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 + image_conditioning = images_tensor_to_samples(image_conditioning, + approximation_indexes.get(opts.sd_vae_encode_method)) + + # Add the fake full 1s mask to the first dimension. + image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) + + return image_conditioning + # Dummy zero conditioning if we're not using inpainting or unclip models. # Still takes up a bit of memory, but no encoder call. # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. @@ -362,6 +376,11 @@ class StableDiffusionProcessing: if self.sampler.conditioning_key == "crossattn-adm": return self.unclip_image_conditioning(source_image) + sd = self.sampler.model_wrap.inner_model.model.state_dict() + diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) + if diffusion_model_input.shape[1] == 9: + return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + # Dummy zero conditioning if we're not using inpainting or depth model. return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index deab2f6e..b38137eb 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -15,6 +15,7 @@ config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml") +config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml") @@ -71,7 +72,10 @@ def guess_model_config_from_state_dict(sd, filename): sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None) if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: - return config_sdxl + if diffusion_model_input.shape[1] == 9: + return config_sdxl_inpainting + else: + return config_sdxl if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None: return config_sdxl_refiner elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 01123321..d8a9a73b 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -34,6 +34,11 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): + sd = self.model.state_dict() + diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) + if diffusion_model_input.shape[1] == 9: + x = torch.cat([x] + cond['c_concat'], dim=1) + return self.model(x, t, cond) -- cgit v1.2.3 From bfe418a58d39c69ca2672e7d8a1fd7ad2b34869b Mon Sep 17 00:00:00 2001 From: wangqyqq Date: Wed, 27 Dec 2023 10:20:56 +0800 Subject: add some codes for robust --- modules/processing.py | 24 +++++++++++++----------- modules/sd_models_xl.py | 5 +++-- 2 files changed, 16 insertions(+), 13 deletions(-) (limited to 'modules/sd_models_xl.py') diff --git a/modules/processing.py b/modules/processing.py index 159548db..c05e608a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -108,17 +108,18 @@ def txt2img_image_conditioning(sd_model, x, width, height): else: sd = sd_model.model.state_dict() diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) - if diffusion_model_input.shape[1] == 9: - # The "masked-image" in this case will just be all 0.5 since the entire image is masked. - image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 - image_conditioning = images_tensor_to_samples(image_conditioning, - approximation_indexes.get(opts.sd_vae_encode_method)) + if diffusion_model_input is not None: + if diffusion_model_input.shape[1] == 9: + # The "masked-image" in this case will just be all 0.5 since the entire image is masked. + image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 + image_conditioning = images_tensor_to_samples(image_conditioning, + approximation_indexes.get(opts.sd_vae_encode_method)) - # Add the fake full 1s mask to the first dimension. - image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) - image_conditioning = image_conditioning.to(x.dtype) + # Add the fake full 1s mask to the first dimension. + image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) - return image_conditioning + return image_conditioning # Dummy zero conditioning if we're not using inpainting or unclip models. # Still takes up a bit of memory, but no encoder call. @@ -378,8 +379,9 @@ class StableDiffusionProcessing: sd = self.sampler.model_wrap.inner_model.model.state_dict() diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) - if diffusion_model_input.shape[1] == 9: - return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + if diffusion_model_input is not None: + if diffusion_model_input.shape[1] == 9: + return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) # Dummy zero conditioning if we're not using inpainting or depth model. return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index d8a9a73b..162d0fee 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -36,8 +36,9 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): sd = self.model.state_dict() diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) - if diffusion_model_input.shape[1] == 9: - x = torch.cat([x] + cond['c_concat'], dim=1) + if diffusion_model_input is not None: + if diffusion_model_input.shape[1] == 9: + x = torch.cat([x] + cond['c_concat'], dim=1) return self.model(x, t, cond) -- cgit v1.2.3 From 5768afc776a66bb94e77a9c1daebeea58fa731d5 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 00:20:30 +0200 Subject: Add utility to inspect a model's parameters (to get dtype/device) --- modules/devices.py | 3 ++- modules/interrogate.py | 3 ++- modules/sd_models_xl.py | 3 ++- modules/torch_utils.py | 17 +++++++++++++++++ modules/upscaler_utils.py | 5 +++-- modules/xlmr.py | 5 ++++- modules/xlmr_m18.py | 5 ++++- test/test_torch_utils.py | 19 +++++++++++++++++++ 8 files changed, 53 insertions(+), 7 deletions(-) create mode 100644 modules/torch_utils.py create mode 100644 test/test_torch_utils.py (limited to 'modules/sd_models_xl.py') diff --git a/modules/devices.py b/modules/devices.py index c956207f..bd6bd579 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -4,6 +4,7 @@ from functools import lru_cache import torch from modules import errors, shared +from modules.torch_utils import get_param if sys.platform == "darwin": from modules import mac_specific @@ -131,7 +132,7 @@ patch_module_list = [ def manual_cast_forward(self, *args, **kwargs): - org_dtype = next(self.parameters()).dtype + org_dtype = get_param(self).dtype self.to(dtype) args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} diff --git a/modules/interrogate.py b/modules/interrogate.py index 3045560d..5be5a10f 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -11,6 +11,7 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from modules import devices, paths, shared, lowvram, modelloader, errors +from modules.torch_utils import get_param blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' @@ -131,7 +132,7 @@ class InterrogateModels: self.clip_model = self.clip_model.to(devices.device_interrogate) - self.dtype = next(self.clip_model.parameters()).dtype + self.dtype = get_param(self.clip_model).dtype def send_clip_to_ram(self): if not shared.opts.interrogate_keep_models_in_memory: diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 1de31b0d..c3602a7e 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -6,6 +6,7 @@ import sgm.models.diffusion import sgm.modules.diffusionmodules.denoiser_scaling import sgm.modules.diffusionmodules.discretizer from modules import devices, shared, prompt_parser +from modules.torch_utils import get_param def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): @@ -90,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt def extend_sdxl(model): """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" - dtype = next(model.model.diffusion_model.parameters()).dtype + dtype = get_param(model.model.diffusion_model).dtype model.model.diffusion_model.dtype = dtype model.model.conditioning_key = 'crossattn' model.cond_stage_key = 'txt' diff --git a/modules/torch_utils.py b/modules/torch_utils.py new file mode 100644 index 00000000..e5b52393 --- /dev/null +++ b/modules/torch_utils.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +import torch.nn + + +def get_param(model) -> torch.nn.Parameter: + """ + Find the first parameter in a model or module. + """ + if hasattr(model, "model") and hasattr(model.model, "parameters"): + # Unpeel a model descriptor to get at the actual Torch module. + model = model.model + + for param in model.parameters(): + return param + + raise ValueError(f"No parameters found in model {model!r}") diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 8e413854..c60e3beb 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -7,6 +7,7 @@ import tqdm from PIL import Image from modules import images, shared +from modules.torch_utils import get_param logger = logging.getLogger(__name__) @@ -17,8 +18,8 @@ def upscale_without_tiling(model, img: Image.Image): img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - model_weight = next(iter(model.model.parameters())) - img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype) + param = get_param(model) + img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype) with torch.no_grad(): output = model(img) diff --git a/modules/xlmr.py b/modules/xlmr.py index a407a3ca..6e000a56 100644 --- a/modules/xlmr.py +++ b/modules/xlmr.py @@ -5,6 +5,9 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta from transformers import XLMRobertaModel,XLMRobertaTokenizer from typing import Optional +from modules.torch_utils import get_param + + class BertSeriesConfig(BertConfig): def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): @@ -62,7 +65,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): self.post_init() def encode(self,c): - device = next(self.parameters()).device + device = get_param(self).device text = self.tokenizer(c, truncation=True, max_length=77, diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py index a727e865..e3e81961 100644 --- a/modules/xlmr_m18.py +++ b/modules/xlmr_m18.py @@ -5,6 +5,9 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta from transformers import XLMRobertaModel,XLMRobertaTokenizer from typing import Optional +from modules.torch_utils import get_param + + class BertSeriesConfig(BertConfig): def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): @@ -68,7 +71,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): self.post_init() def encode(self,c): - device = next(self.parameters()).device + device = get_param(self).device text = self.tokenizer(c, truncation=True, max_length=77, diff --git a/test/test_torch_utils.py b/test/test_torch_utils.py new file mode 100644 index 00000000..f1aec832 --- /dev/null +++ b/test/test_torch_utils.py @@ -0,0 +1,19 @@ +import types + +import pytest +import torch + +from modules.torch_utils import get_param + + +@pytest.mark.parametrize("wrapped", [True, False]) +def test_get_param(wrapped): + mod = torch.nn.Linear(1, 1) + cpu = torch.device("cpu") + mod.to(dtype=torch.float16, device=cpu) + if wrapped: + # more or less how spandrel wraps a thing + mod = types.SimpleNamespace(model=mod) + p = get_param(mod) + assert p.dtype == torch.float16 + assert p.device == cpu -- cgit v1.2.3 From a70dfb64a86b9b6d869deffdb0ffebe980365473 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 31 Dec 2023 22:38:30 +0300 Subject: change import statements for #14478 --- modules/devices.py | 4 ++-- modules/interrogate.py | 5 ++--- modules/sd_models_xl.py | 4 ++-- modules/upscaler_utils.py | 5 ++--- modules/xlmr.py | 4 ++-- modules/xlmr_m18.py | 5 ++--- test/test_torch_utils.py | 4 ++-- 7 files changed, 14 insertions(+), 17 deletions(-) (limited to 'modules/sd_models_xl.py') diff --git a/modules/devices.py b/modules/devices.py index bd6bd579..ff279ac5 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -4,7 +4,7 @@ from functools import lru_cache import torch from modules import errors, shared -from modules.torch_utils import get_param +from modules import torch_utils if sys.platform == "darwin": from modules import mac_specific @@ -132,7 +132,7 @@ patch_module_list = [ def manual_cast_forward(self, *args, **kwargs): - org_dtype = get_param(self).dtype + org_dtype = torch_utils.get_param(self).dtype self.to(dtype) args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} diff --git a/modules/interrogate.py b/modules/interrogate.py index 5be5a10f..35a627ca 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -10,8 +10,7 @@ import torch.hub from torchvision import transforms from torchvision.transforms.functional import InterpolationMode -from modules import devices, paths, shared, lowvram, modelloader, errors -from modules.torch_utils import get_param +from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' @@ -132,7 +131,7 @@ class InterrogateModels: self.clip_model = self.clip_model.to(devices.device_interrogate) - self.dtype = get_param(self.clip_model).dtype + self.dtype = torch_utils.get_param(self.clip_model).dtype def send_clip_to_ram(self): if not shared.opts.interrogate_keep_models_in_memory: diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index c3602a7e..0de17af3 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -6,7 +6,7 @@ import sgm.models.diffusion import sgm.modules.diffusionmodules.denoiser_scaling import sgm.modules.diffusionmodules.discretizer from modules import devices, shared, prompt_parser -from modules.torch_utils import get_param +from modules import torch_utils def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): @@ -91,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt def extend_sdxl(model): """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" - dtype = get_param(model.model.diffusion_model).dtype + dtype = torch_utils.get_param(model.model.diffusion_model).dtype model.model.diffusion_model.dtype = dtype model.model.conditioning_key = 'crossattn' model.cond_stage_key = 'txt' diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index c60e3beb..f5cb92d5 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,8 +6,7 @@ import torch import tqdm from PIL import Image -from modules import images, shared -from modules.torch_utils import get_param +from modules import images, shared, torch_utils logger = logging.getLogger(__name__) @@ -18,7 +17,7 @@ def upscale_without_tiling(model, img: Image.Image): img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - param = get_param(model) + param = torch_utils.get_param(model) img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype) with torch.no_grad(): diff --git a/modules/xlmr.py b/modules/xlmr.py index 6e000a56..319771b7 100644 --- a/modules/xlmr.py +++ b/modules/xlmr.py @@ -5,7 +5,7 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta from transformers import XLMRobertaModel,XLMRobertaTokenizer from typing import Optional -from modules.torch_utils import get_param +from modules import torch_utils class BertSeriesConfig(BertConfig): @@ -65,7 +65,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): self.post_init() def encode(self,c): - device = get_param(self).device + device = torch_utils.get_param(self).device text = self.tokenizer(c, truncation=True, max_length=77, diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py index e3e81961..f6055504 100644 --- a/modules/xlmr_m18.py +++ b/modules/xlmr_m18.py @@ -4,8 +4,7 @@ import torch from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig from transformers import XLMRobertaModel,XLMRobertaTokenizer from typing import Optional - -from modules.torch_utils import get_param +from modules import torch_utils class BertSeriesConfig(BertConfig): @@ -71,7 +70,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): self.post_init() def encode(self,c): - device = get_param(self).device + device = torch_utils.get_param(self).device text = self.tokenizer(c, truncation=True, max_length=77, diff --git a/test/test_torch_utils.py b/test/test_torch_utils.py index f1aec832..23ccb93a 100644 --- a/test/test_torch_utils.py +++ b/test/test_torch_utils.py @@ -3,7 +3,7 @@ import types import pytest import torch -from modules.torch_utils import get_param +from modules import torch_utils @pytest.mark.parametrize("wrapped", [True, False]) @@ -14,6 +14,6 @@ def test_get_param(wrapped): if wrapped: # more or less how spandrel wraps a thing mod = types.SimpleNamespace(model=mod) - p = get_param(mod) + p = torch_utils.get_param(mod) assert p.dtype == torch.float16 assert p.device == cpu -- cgit v1.2.3