diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-15 16:23:27 +0000 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-15 16:23:40 +0000 |
commit | f01682ee01e81e8ef84fd6fffe8f7aa17233285d (patch) | |
tree | 80f62099e6af5f77c7df8c092c37c71ed24750d9 /modules | |
parent | 7327be97aa9beeae881bf4649a56792bd284efd5 (diff) | |
download | stable-diffusion-webui-gfx803-f01682ee01e81e8ef84fd6fffe8f7aa17233285d.tar.gz stable-diffusion-webui-gfx803-f01682ee01e81e8ef84fd6fffe8f7aa17233285d.tar.bz2 stable-diffusion-webui-gfx803-f01682ee01e81e8ef84fd6fffe8f7aa17233285d.zip |
store patches for Lora in a specialized module
Diffstat (limited to 'modules')
-rw-r--r-- | modules/patches.py | 64 |
1 files changed, 64 insertions, 0 deletions
diff --git a/modules/patches.py b/modules/patches.py new file mode 100644 index 00000000..348235e7 --- /dev/null +++ b/modules/patches.py @@ -0,0 +1,64 @@ +from collections import defaultdict
+
+
+def patch(key, obj, field, replacement):
+ """Replaces a function in a module or a class.
+
+ Also stores the original function in this module, possible to be retrieved via original(key, obj, field).
+ If the function is already replaced by this caller (key), an exception is raised -- use undo() before that.
+
+ Arguments:
+ key: identifying information for who is doing the replacement. You can use __name__.
+ obj: the module or the class
+ field: name of the function as a string
+ replacement: the new function
+
+ Returns:
+ the original function
+ """
+
+ patch_key = (obj, field)
+ if patch_key in originals[key]:
+ raise RuntimeError(f"patch for {field} is already applied")
+
+ original_func = getattr(obj, field)
+ originals[key][patch_key] = original_func
+
+ setattr(obj, field, replacement)
+
+ return original_func
+
+
+def undo(key, obj, field):
+ """Undoes the peplacement by the patch().
+
+ If the function is not replaced, raises an exception.
+
+ Arguments:
+ key: identifying information for who is doing the replacement. You can use __name__.
+ obj: the module or the class
+ field: name of the function as a string
+
+ Returns:
+ Always None
+ """
+
+ patch_key = (obj, field)
+
+ if patch_key not in originals[key]:
+ raise RuntimeError(f"there is no patch for {field} to undo")
+
+ original_func = originals[key].pop(patch_key)
+ setattr(obj, field, original_func)
+
+ return None
+
+
+def original(key, obj, field):
+ """Returns the original function for the patch created by the patch() function"""
+ patch_key = (obj, field)
+
+ return originals[key].get(patch_key, None)
+
+
+originals = defaultdict(dict)
|