aboutsummaryrefslogtreecommitdiffstats
path: root/modules/esrgan_model.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2022-11-02 11:09:38 +0000
committerGitHub <noreply@github.com>2022-11-02 11:09:38 +0000
commite359268be9936db1c16c78adf544d622d33d1bfb (patch)
tree15d41dc2832ea849192eb46f22a26cb2255f76f9 /modules/esrgan_model.py
parentbb21a4cb35986d95ec63ace48ce13b75a776f5a5 (diff)
parentc9bb33dd43dbb9479ff1b70351df14508c89ac60 (diff)
downloadstable-diffusion-webui-gfx803-e359268be9936db1c16c78adf544d622d33d1bfb.tar.gz
stable-diffusion-webui-gfx803-e359268be9936db1c16c78adf544d622d33d1bfb.tar.bz2
stable-diffusion-webui-gfx803-e359268be9936db1c16c78adf544d622d33d1bfb.zip
Merge pull request #3976 from victorca25/esrgan_fea
multiple trivial changes for "extras" models
Diffstat (limited to 'modules/esrgan_model.py')
-rw-r--r--modules/esrgan_model.py17
1 files changed, 13 insertions, 4 deletions
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index a13cf6ac..c61669b4 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -50,6 +50,7 @@ def mod2normal(state_dict):
def resrgan2normal(state_dict, nb=23):
# this code is copied from https://github.com/victorca25/iNNfer
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
+ re8x = 0
crt_net = {}
items = []
for k, v in state_dict.items():
@@ -75,10 +76,18 @@ def resrgan2normal(state_dict, nb=23):
crt_net['model.3.bias'] = state_dict['conv_up1.bias']
crt_net['model.6.weight'] = state_dict['conv_up2.weight']
crt_net['model.6.bias'] = state_dict['conv_up2.bias']
- crt_net['model.8.weight'] = state_dict['conv_hr.weight']
- crt_net['model.8.bias'] = state_dict['conv_hr.bias']
- crt_net['model.10.weight'] = state_dict['conv_last.weight']
- crt_net['model.10.bias'] = state_dict['conv_last.bias']
+
+ if 'conv_up3.weight' in state_dict:
+ # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
+ re8x = 3
+ crt_net['model.9.weight'] = state_dict['conv_up3.weight']
+ crt_net['model.9.bias'] = state_dict['conv_up3.bias']
+
+ crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
+ crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
+ crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
+ crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
+
state_dict = crt_net
return state_dict