path: root/modules/hypernetworks/hypernetwork.py
author	aria1th <35677394+aria1th@users.noreply.github.com>	2022-11-03 05:30:53 +0000
committer	aria1th <35677394+aria1th@users.noreply.github.com>	2022-11-03 05:30:53 +0000
commit	0b143c1163a96b193a4e8512be9c5831c661a50d (patch)
tree	59352d832798e6b65f334373c317ab7cb774205a	/modules/hypernetworks/hypernetwork.py
parent	7ea5956ad5fa925f92116e8a3bf78d7f6517b654 (diff)
Separate .optim file from model
Diffstat (limited to 'modules/hypernetworks/hypernetwork.py')
-rw-r--r--	modules/hypernetworks/hypernetwork.py	12
1 file changed, 8 insertions(+), 4 deletions(-)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 8f74cdea..63c25de8 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -161,6 +161,7 @@ class Hypernetwork:
 
     def save(self, filename):
         state_dict = {}
+        optimizer_saved_dict = {}
 
         for k, v in self.layers.items():
             state_dict[k] = (v[0].state_dict(), v[1].state_dict())
@@ -175,9 +176,10 @@
         state_dict['sd_checkpoint'] = self.sd_checkpoint
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
         if self.optimizer_name is not None:
-            state_dict['optimizer_name'] = self.optimizer_name
+            optimizer_saved_dict['optimizer_name'] = self.optimizer_name
         if self.optimizer_state_dict:
-            state_dict['optimizer_state_dict'] = self.optimizer_state_dict
+            optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
+        torch.save(optimizer_saved_dict, filename + '.optim')
 
         torch.save(state_dict, filename)
 
@@ -198,9 +200,11 @@
         print(f"Layer norm is set to {self.add_layer_norm}")
         self.use_dropout = state_dict.get('use_dropout', False)
         print(f"Dropout usage is set to {self.use_dropout}")
-        self.optimizer_name = state_dict.get('optimizer_name', 'AdamW')
+
+        optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
+        self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
         print(f"Optimizer name is {self.optimizer_name}")
-        self.optimizer_state_dict = state_dict.get('optimizer_state_dict', None)
+        self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
         if self.optimizer_state_dict:
             print("Loaded existing optimizer from checkpoint")
         else:
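
Taken together, the patch writes optimizer state to a sidecar file next to the hypernetwork checkpoint (filename + '.optim') and reads it back only when that file exists, so older checkpoints without the sidecar still load. A minimal standalone sketch of the same save/load pattern follows; the save_with_optim/load_with_optim helpers and the model/optimizer arguments are illustrative names, not part of this commit:

import os
import torch

def save_with_optim(model, optimizer, filename):
    # Model weights go to `filename`; optimizer state goes to a separate
    # sidecar file so the main checkpoint does not carry it.
    torch.save(model.state_dict(), filename)
    torch.save({
        'optimizer_name': type(optimizer).__name__,
        'optimizer_state_dict': optimizer.state_dict(),
    }, filename + '.optim')

def load_with_optim(model, optimizer, filename):
    model.load_state_dict(torch.load(filename, map_location='cpu'))
    # The sidecar is optional: if it is missing, fall back to fresh optimizer state.
    optim_path = filename + '.optim'
    saved = torch.load(optim_path, map_location='cpu') if os.path.exists(optim_path) else {}
    if saved.get('optimizer_state_dict') is not None:
        optimizer.load_state_dict(saved['optimizer_state_dict'])
        print("Loaded existing optimizer from checkpoint")
    else:
        print("No saved optimizer exists in checkpoint")

The practical effect of the split is that the shareable hypernetwork file no longer contains optimizer state, while training can still resume from the .optim sidecar whenever it is present.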