diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/esrgan_model.py | 13 | ||||
-rw-r--r-- | modules/realesrgan_model.py | 6 | ||||
-rw-r--r-- | modules/sd_hijack.py | 20 | ||||
-rw-r--r-- | modules/shared.py | 5 |
4 files changed, 30 insertions, 14 deletions
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 2ed1d273..e86ad775 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -14,17 +14,20 @@ import modules.images def load_model(filename):
# this code is adapted from https://github.com/xinntao/ESRGAN
- if torch.has_mps:
- map_l = 'cpu'
- else:
- map_l = None
- pretrained_net = torch.load(filename, map_location=map_l)
+ pretrained_net = torch.load(filename, map_location='cpu' if torch.has_mps else None)
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
if 'conv_first.weight' in pretrained_net:
crt_model.load_state_dict(pretrained_net)
return crt_model
+ if 'model.0.weight' not in pretrained_net:
+ is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
+ if is_realesrgan:
+ raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
+ else:
+ raise Exception("The file is not a ESRGAN model.")
+
crt_net = crt_model.state_dict()
load_net_clean = {}
for k, v in pretrained_net.items():
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index e480887f..e2cef0c8 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -5,7 +5,7 @@ import numpy as np from PIL import Image
import modules.images
-from modules.shared import cmd_opts
+from modules.shared import cmd_opts, opts
RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
@@ -76,7 +76,9 @@ def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index) scale=info.netscale,
model_path=info.location,
model=model,
- half=not cmd_opts.no_half
+ half=not cmd_opts.no_half,
+ tile=opts.ESRGAN_tile,
+ tile_pad=opts.ESRGAN_tile_overlap,
)
upsampled = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)[0]
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 1084e248..db9952a5 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -73,11 +73,21 @@ class StableDiffusionModelHijack: name = os.path.splitext(filename)[0]
data = torch.load(path)
- param_dict = data['string_to_param']
- if hasattr(param_dict, '_parameters'):
- param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
- assert len(param_dict) == 1, 'embedding file has multiple terms in it'
- emb = next(iter(param_dict.items()))[1]
+
+ # textual inversion embeddings
+ if 'string_to_param' in data:
+ param_dict = data['string_to_param']
+ if hasattr(param_dict, '_parameters'):
+ param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
+ assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+ emb = next(iter(param_dict.items()))[1]
+ elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+ assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+ emb = next(iter(data.values()))
+ if len(emb.shape) == 1:
+ emb = emb.unsqueeze(0)
+
self.word_embeddings[name] = emb.detach()
self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1))&0xffff:04x}'
diff --git a/modules/shared.py b/modules/shared.py index e529ec27..85318d7e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -34,6 +34,7 @@ parser.add_argument("--share", action='store_true', help="use share=True for gra parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN'))
parser.add_argument("--opt-split-attention", action='store_true', help="enable optimization that reduced vram usage by a lot for about 10%% decrease in performance")
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
+parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
cmd_opts = parser.parse_args()
if torch.has_cuda:
@@ -117,8 +118,8 @@ class Options: "font": OptionInfo(find_any_font(), "Font for image grids that have text"),
"enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text text and [text] to make it pay less attention"),
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
- "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscaling. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
- "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscaling. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
+ "ESRGAN_tile": OptionInfo(192, "Tile size for upscaling. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
+ "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for upscaling. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
"upscale_at_full_resolution_padding": OptionInfo(16, "Inpainting at full resolution: padding, in pixels, for the masked region.", gr.Slider, {"minimum": 0, "maximum": 128, "step": 4}),
"show_progressbar": OptionInfo(True, "Show progressbar"),
|