From f01682ee01e81e8ef84fd6fffe8f7aa17233285d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 15 Aug 2023 19:23:27 +0300 Subject: store patches for Lora in a specialized module --- modules/patches.py | 64 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 modules/patches.py (limited to 'modules/patches.py') diff --git a/modules/patches.py b/modules/patches.py new file mode 100644 index 00000000..348235e7 --- /dev/null +++ b/modules/patches.py @@ -0,0 +1,64 @@ +from collections import defaultdict + + +def patch(key, obj, field, replacement): + """Replaces a function in a module or a class. + + Also stores the original function in this module, possible to be retrieved via original(key, obj, field). + If the function is already replaced by this caller (key), an exception is raised -- use undo() before that. + + Arguments: + key: identifying information for who is doing the replacement. You can use __name__. + obj: the module or the class + field: name of the function as a string + replacement: the new function + + Returns: + the original function + """ + + patch_key = (obj, field) + if patch_key in originals[key]: + raise RuntimeError(f"patch for {field} is already applied") + + original_func = getattr(obj, field) + originals[key][patch_key] = original_func + + setattr(obj, field, replacement) + + return original_func + + +def undo(key, obj, field): + """Undoes the peplacement by the patch(). + + If the function is not replaced, raises an exception. + + Arguments: + key: identifying information for who is doing the replacement. You can use __name__. + obj: the module or the class + field: name of the function as a string + + Returns: + Always None + """ + + patch_key = (obj, field) + + if patch_key not in originals[key]: + raise RuntimeError(f"there is no patch for {field} to undo") + + original_func = originals[key].pop(patch_key) + setattr(obj, field, original_func) + + return None + + +def original(key, obj, field): + """Returns the original function for the patch created by the patch() function""" + patch_key = (obj, field) + + return originals[key].get(patch_key, None) + + +originals = defaultdict(dict) -- cgit v1.2.3