diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2022-09-18 11:35:04 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-18 11:35:04 +0000 |
commit | 9e892d90ceaa6493d8b9092c89017485bb658c5b (patch) | |
tree | 02df6e8d420d207783be425cf57d07fa46cbe58b /modules/memmon.py | |
parent | 83a65919bb2af35c0d47cbb47b8db2ac233e86ce (diff) | |
parent | 46db1405df64db4c543b7cc9db958479a3db200f (diff) | |
download | stable-diffusion-webui-gfx803-9e892d90ceaa6493d8b9092c89017485bb658c5b.tar.gz stable-diffusion-webui-gfx803-9e892d90ceaa6493d8b9092c89017485bb658c5b.tar.bz2 stable-diffusion-webui-gfx803-9e892d90ceaa6493d8b9092c89017485bb658c5b.zip |
Merge pull request #651 from EyeDeck/master
Add some error handling for VRAM monitor
Diffstat (limited to 'modules/memmon.py')
-rw-r--r-- | modules/memmon.py | 22 |
1 file changed, 15 insertions, 7 deletions
diff --git a/modules/memmon.py b/modules/memmon.py index f2cac841..9fb9b687 100644 --- a/modules/memmon.py +++ b/modules/memmon.py @@ -22,6 +22,13 @@ class MemUsageMonitor(threading.Thread): self.run_flag = threading.Event() self.data = defaultdict(int) + try: + torch.cuda.mem_get_info() + torch.cuda.memory_stats(self.device) + except Exception as e: # AMD or whatever + print(f"Warning: caught exception '{e}', memory monitor disabled") + self.disabled = True + def run(self): if self.disabled: return @@ -62,13 +69,14 @@ class MemUsageMonitor(threading.Thread): self.run_flag.set() def read(self): - free, total = torch.cuda.mem_get_info() - self.data["total"] = total - - torch_stats = torch.cuda.memory_stats(self.device) - self.data["active_peak"] = torch_stats["active_bytes.all.peak"] - self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"] - self.data["system_peak"] = total - self.data["min_free"] + if not self.disabled: + free, total = torch.cuda.mem_get_info() + self.data["total"] = total + + torch_stats = torch.cuda.memory_stats(self.device) + self.data["active_peak"] = torch_stats["active_bytes.all.peak"] + self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"] + self.data["system_peak"] = total - self.data["min_free"] return self.data |