diff options
Diffstat (limited to 'modules/memmon.py')
-rw-r--r-- | modules/memmon.py | 22 |
1 files changed, 15 insertions, 7 deletions
diff --git a/modules/memmon.py b/modules/memmon.py
index f2cac841..9fb9b687 100644
--- a/modules/memmon.py
+++ b/modules/memmon.py
@@ -22,6 +22,13 @@ class MemUsageMonitor(threading.Thread):
         self.run_flag = threading.Event()
         self.data = defaultdict(int)
 
+        try:
+            torch.cuda.mem_get_info()
+            torch.cuda.memory_stats(self.device)
+        except Exception as e:  # AMD or whatever
+            print(f"Warning: caught exception '{e}', memory monitor disabled")
+            self.disabled = True
+
     def run(self):
         if self.disabled:
             return
@@ -62,13 +69,14 @@ class MemUsageMonitor(threading.Thread):
         self.run_flag.set()
 
     def read(self):
-        free, total = torch.cuda.mem_get_info()
-        self.data["total"] = total
-
-        torch_stats = torch.cuda.memory_stats(self.device)
-        self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
-        self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
-        self.data["system_peak"] = total - self.data["min_free"]
+        if not self.disabled:
+            free, total = torch.cuda.mem_get_info()
+            self.data["total"] = total
+
+            torch_stats = torch.cuda.memory_stats(self.device)
+            self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
+            self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
+            self.data["system_peak"] = total - self.data["min_free"]
 
         return self.data
 