diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2022-09-17 11:57:10 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-17 11:57:10 +0000 |
commit | 0d7fdb179104e48983d07e0175021f0e4bdc2d55 (patch) | |
tree | a183247f90049207a5af64b2882c0f92136ee6fe /modules/memmon.py | |
parent | ac61e4663c21ea0f51a4319162d3877e00554a2a (diff) | |
parent | 1ef79f926e6314b3ef9308b12ff7ad482afd790a (diff) | |
download | stable-diffusion-webui-gfx803-0d7fdb179104e48983d07e0175021f0e4bdc2d55.tar.gz stable-diffusion-webui-gfx803-0d7fdb179104e48983d07e0175021f0e4bdc2d55.tar.bz2 stable-diffusion-webui-gfx803-0d7fdb179104e48983d07e0175021f0e4bdc2d55.zip |
Merge branch 'master' into image_info_tab
Diffstat (limited to 'modules/memmon.py')
-rw-r--r-- | modules/memmon.py | 77 |
1 file changed, 77 insertions, 0 deletions
import threading
import time
from collections import defaultdict

import torch


class MemUsageMonitor(threading.Thread):
    """Daemon thread that samples CUDA memory usage in the background.

    A pass is started with monitor() and ended with stop() (or read()).
    While a pass is active, the thread polls torch.cuda.mem_get_info()
    at opts.memmon_poll_rate samples per second and records the minimum
    free device memory seen in self.data["min_free"] (bytes).
    """

    run_flag = None   # threading.Event; set while a monitoring pass is active
    device = None     # device passed to torch.cuda.memory_stats()
    disabled = False  # when True, run() exits immediately and nothing is sampled
    opts = None       # options object; only memmon_poll_rate is read here
    data = None       # defaultdict(int) of collected statistics, all in bytes

    def __init__(self, name, device, opts):
        threading.Thread.__init__(self)
        self.name = name
        self.device = device
        self.opts = opts

        self.daemon = True  # don't keep the process alive on interpreter exit
        self.run_flag = threading.Event()
        self.data = defaultdict(int)

    def run(self):
        """Thread main loop: wait for monitor(), then poll until stop()."""
        if self.disabled:
            return

        while True:
            self.run_flag.wait()  # block until monitor() starts a pass

            torch.cuda.reset_peak_memory_stats()
            self.data.clear()

            if self.opts.memmon_poll_rate <= 0:
                # polling disabled: drop this request and wait for the next one
                self.run_flag.clear()
                continue

            self.data["min_free"] = torch.cuda.mem_get_info()[0]

            while self.run_flag.is_set():
                # calling mem_get_info with self.device errors, torch bug?
                free = torch.cuda.mem_get_info()[0]
                self.data["min_free"] = min(self.data["min_free"], free)

                time.sleep(1 / self.opts.memmon_poll_rate)

    def dump_debug(self):
        """Print collected data and raw torch memory stats, values in MiB."""
        print(self, 'recorded data:')
        for k, v in self.read().items():
            # -(v // -d) is ceiling division: round bytes up to whole MiB
            print(k, -(v // -(1024 ** 2)))

        print(self, 'raw torch memory stats:')
        tm = torch.cuda.memory_stats(self.device)
        for k, v in tm.items():
            if 'bytes' not in k:
                continue
            print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))

        print(torch.cuda.memory_summary())

    def monitor(self):
        """Start (or resume) a monitoring pass in the background thread."""
        self.run_flag.set()

    def read(self):
        """Collect current totals and torch peak stats into self.data and return it."""
        self.data["total"] = torch.cuda.mem_get_info()[1]

        torch_stats = torch.cuda.memory_stats(self.device)
        self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
        self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
        # peak overall device usage observed during the pass: total minus the
        # minimum free memory the polling loop saw
        self.data["system_peak"] = self.data["total"] - self.data["min_free"]

        return self.data

    def stop(self):
        """End the monitoring pass and return the final statistics."""
        self.run_flag.clear()
        return self.read()