aboutsummaryrefslogtreecommitdiffstats
path: root/extensions-builtin/Lora/network.py
diff options
context:
space:
mode:
Diffstat (limited to 'extensions-builtin/Lora/network.py')
-rw-r--r--extensions-builtin/Lora/network.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py
index fe42dbdd..8ecfa29a 100644
--- a/extensions-builtin/Lora/network.py
+++ b/extensions-builtin/Lora/network.py
@@ -1,5 +1,6 @@
import os
from collections import namedtuple
+import enum
from modules import sd_models, cache, errors, hashes, shared
@@ -8,6 +9,13 @@ NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
+class SdVersion(enum.Enum):
+ Unknown = 1
+ SD1 = 2
+ SD2 = 3
+ SDXL = 4
+
+
class NetworkOnDisk:
def __init__(self, name, filename):
self.name = name
@@ -44,6 +52,18 @@ class NetworkOnDisk:
''
)
+ self.sd_version = self.detect_version()
+
+ def detect_version(self):
+ if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"):
+ return SdVersion.SDXL
+ elif str(self.metadata.get('ss_v2', "")) == "True":
+ return SdVersion.SD2
+ elif len(self.metadata):
+ return SdVersion.SD1
+
+ return SdVersion.Unknown
+
def set_hash(self, v):
self.hash = v
self.shorthash = self.hash[0:12]