author | EyeDeck <eyedeck@gmail.com> | 2022-09-17 04:49:31 +0000 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2022-09-17 06:15:16 +0000 |
commit | ed6787ca2fe950f633a925ccb0467eafd4ec0f43 (patch) | |
tree | 390b50f5f940efd255e5bf7f38ce6ca785cc4cf4 /modules/memmon.py | |
parent | 1fc1c537c7303be88e0da93c3a632c48acb101e9 (diff) | |
Add VRAM monitoring
Diffstat (limited to 'modules/memmon.py')
-rw-r--r-- | modules/memmon.py | 77 |
1 file changed, 77 insertions(+), 0 deletions(-)
diff --git a/modules/memmon.py b/modules/memmon.py
new file mode 100644
index 00000000..f2cac841
--- /dev/null
+++ b/modules/memmon.py
@@ -0,0 +1,77 @@
+import threading
+import time
+from collections import defaultdict
+
+import torch
+
+
+class MemUsageMonitor(threading.Thread):
+    run_flag = None
+    device = None
+    disabled = False
+    opts = None
+    data = None
+
+    def __init__(self, name, device, opts):
+        threading.Thread.__init__(self)
+        self.name = name
+        self.device = device
+        self.opts = opts
+
+        self.daemon = True
+        self.run_flag = threading.Event()
+        self.data = defaultdict(int)
+
+    def run(self):
+        if self.disabled:
+            return
+
+        while True:
+            self.run_flag.wait()
+
+            torch.cuda.reset_peak_memory_stats()
+            self.data.clear()
+
+            if self.opts.memmon_poll_rate <= 0:
+                self.run_flag.clear()
+                continue
+
+            self.data["min_free"] = torch.cuda.mem_get_info()[0]
+
+            while self.run_flag.is_set():
+                free, total = torch.cuda.mem_get_info()  # calling with self.device errors, torch bug?
+                self.data["min_free"] = min(self.data["min_free"], free)
+
+                time.sleep(1 / self.opts.memmon_poll_rate)
+
+    def dump_debug(self):
+        print(self, 'recorded data:')
+        for k, v in self.read().items():
+            print(k, -(v // -(1024 ** 2)))
+
+        print(self, 'raw torch memory stats:')
+        tm = torch.cuda.memory_stats(self.device)
+        for k, v in tm.items():
+            if 'bytes' not in k:
+                continue
+            print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))
+
+        print(torch.cuda.memory_summary())
+
+    def monitor(self):
+        self.run_flag.set()
+
+    def read(self):
+        free, total = torch.cuda.mem_get_info()
+        self.data["total"] = total
+
+        torch_stats = torch.cuda.memory_stats(self.device)
+        self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
+        self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
+        self.data["system_peak"] = total - self.data["min_free"]
+
+        return self.data
+
+    def stop(self):
+        self.run_flag.clear()
+        return self.read()
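Since the commit adds a self-contained monitor class but no caller, a minimal usage sketch may help. `MemUsageMonitor`, its `(name, device, opts)` constructor, and the `memmon_poll_rate` option are from the diff above; the `SimpleNamespace` options object, the dummy tensor workload, and the sleep are illustrative assumptions, not part of the commit.

```python
# Hypothetical harness for the MemUsageMonitor added in this commit.
# Assumes a CUDA device and the repo root on sys.path.
import time
import types

import torch

from modules.memmon import MemUsageMonitor

# The monitor only reads opts.memmon_poll_rate (samples per second),
# so a plain namespace stands in for the webui's opts object here.
opts = types.SimpleNamespace(memmon_poll_rate=8)
device = torch.device("cuda")

mem_mon = MemUsageMonitor("MemMon", device, opts)
mem_mon.start()    # daemon thread; idles until monitor() sets the flag

mem_mon.monitor()  # resets peak stats and begins polling free VRAM

# ... some CUDA workload; a throwaway tensor stands in for real work ...
x = torch.ones((1024, 1024, 64), device=device)
del x
time.sleep(0.5)    # let the poll loop record at least one sample

data = mem_mon.stop()  # clears the flag and returns the recorded dict
# Values are bytes; -(v // -(1024 ** 2)) is ceiling division to MiB,
# the same rounding dump_debug() uses.
print({k: -(v // -(1024 ** 2)) for k, v in data.items()})
```

A higher `memmon_poll_rate` tracks `min_free` (and thus `system_peak`) more tightly at the cost of more frequent `torch.cuda.mem_get_info()` calls, while a rate of 0 or below makes the thread park itself without recording anything.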