diff options
author | Karun <karun.ellango7@gmail.com> | 2023-03-25 09:12:55 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-25 09:12:55 +0000 |
commit | 63a2f8d8225228a52b3ca7f19d2ba1fd07a6234b (patch) | |
tree | 9a7c38070d83b409895704125525dfc70cc21215 /modules/memmon.py | |
parent | ca2b8faa83076a21dd14c974f03f88eb6da57485 (diff) | |
parent | 70615448b2ef3285dba9bb1992974cb1eaf10995 (diff) | |
download | stable-diffusion-webui-gfx803-63a2f8d8225228a52b3ca7f19d2ba1fd07a6234b.tar.gz stable-diffusion-webui-gfx803-63a2f8d8225228a52b3ca7f19d2ba1fd07a6234b.tar.bz2 stable-diffusion-webui-gfx803-63a2f8d8225228a52b3ca7f19d2ba1fd07a6234b.zip |
Merge branch 'master' into master
Diffstat (limited to 'modules/memmon.py')
-rw-r--r-- | modules/memmon.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/modules/memmon.py b/modules/memmon.py index a7060f58..4018edcc 100644 --- a/modules/memmon.py +++ b/modules/memmon.py @@ -23,12 +23,16 @@ class MemUsageMonitor(threading.Thread): self.data = defaultdict(int) try: - torch.cuda.mem_get_info() + self.cuda_mem_get_info() torch.cuda.memory_stats(self.device) except Exception as e: # AMD or whatever print(f"Warning: caught exception '{e}', memory monitor disabled") self.disabled = True + def cuda_mem_get_info(self): + index = self.device.index if self.device.index is not None else torch.cuda.current_device() + return torch.cuda.mem_get_info(index) + def run(self): if self.disabled: return @@ -43,10 +47,10 @@ class MemUsageMonitor(threading.Thread): self.run_flag.clear() continue - self.data["min_free"] = torch.cuda.mem_get_info()[0] + self.data["min_free"] = self.cuda_mem_get_info()[0] while self.run_flag.is_set(): - free, total = torch.cuda.mem_get_info() # calling with self.device errors, torch bug? + free, total = self.cuda_mem_get_info() self.data["min_free"] = min(self.data["min_free"], free) time.sleep(1 / self.opts.memmon_poll_rate) @@ -70,7 +74,7 @@ class MemUsageMonitor(threading.Thread): def read(self): if not self.disabled: - free, total = torch.cuda.mem_get_info() + free, total = self.cuda_mem_get_info() self.data["free"] = free self.data["total"] = total |