aboutsummaryrefslogtreecommitdiffstats
path: root/modules/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/modules/devices.py b/modules/devices.py
index bd6bd579..ff279ac5 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -4,7 +4,7 @@ from functools import lru_cache
import torch
from modules import errors, shared
-from modules.torch_utils import get_param
+from modules import torch_utils
if sys.platform == "darwin":
from modules import mac_specific
@@ -132,7 +132,7 @@ patch_module_list = [
def manual_cast_forward(self, *args, **kwargs):
- org_dtype = get_param(self).dtype
+ org_dtype = torch_utils.get_param(self).dtype
self.to(dtype)
args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}