Diffstat (limited to 'modules')
-rw-r--r--  modules/memmon.py | 77
-rw-r--r--  modules/shared.py |  5
-rw-r--r--  modules/ui.py     | 14
3 files changed, 95 insertions(+), 1 deletion(-)
diff --git a/modules/memmon.py b/modules/memmon.py
new file mode 100644
index 00000000..f2cac841
--- /dev/null
+++ b/modules/memmon.py
@@ -0,0 +1,77 @@
+import threading
+import time
+from collections import defaultdict
+
+import torch
+
+
+class MemUsageMonitor(threading.Thread):
+    run_flag = None
+    device = None
+    disabled = False
+    opts = None
+    data = None
+
+    def __init__(self, name, device, opts):
+        threading.Thread.__init__(self)
+        self.name = name
+        self.device = device
+        self.opts = opts
+
+        self.daemon = True
+        self.run_flag = threading.Event()
+        self.data = defaultdict(int)
+
+    def run(self):
+        if self.disabled:
+            return
+
+        while True:
+            self.run_flag.wait()
+
+            torch.cuda.reset_peak_memory_stats()
+            self.data.clear()
+
+            if self.opts.memmon_poll_rate <= 0:
+                self.run_flag.clear()
+                continue
+
+            self.data["min_free"] = torch.cuda.mem_get_info()[0]
+
+            while self.run_flag.is_set():
+                free, total = torch.cuda.mem_get_info()  # calling with self.device errors, torch bug?
+                self.data["min_free"] = min(self.data["min_free"], free)
+
+                time.sleep(1 / self.opts.memmon_poll_rate)
+
+    def dump_debug(self):
+        print(self, 'recorded data:')
+        for k, v in self.read().items():
+            print(k, -(v // -(1024 ** 2)))
+
+        print(self, 'raw torch memory stats:')
+        tm = torch.cuda.memory_stats(self.device)
+        for k, v in tm.items():
+            if 'bytes' not in k:
+                continue
+            print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))
+
+        print(torch.cuda.memory_summary())
+
+    def monitor(self):
+        self.run_flag.set()
+
+    def read(self):
+        free, total = torch.cuda.mem_get_info()
+        self.data["total"] = total
+
+        torch_stats = torch.cuda.memory_stats(self.device)
+        self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
+        self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
+        self.data["system_peak"] = total - self.data["min_free"]
+
+        return self.data
+
+    def stop(self):
+        self.run_flag.clear()
+        return self.read()
diff --git a/modules/shared.py b/modules/shared.py
index da56b6ae..4f877036
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -12,6 +12,7 @@ from modules.paths import script_path, sd_path
 from modules.devices import get_optimal_device
 import modules.styles
 import modules.interrogate
+import modules.memmon
 
 sd_model_file = os.path.join(script_path, 'model.ckpt')
 if not os.path.exists(sd_model_file):
@@ -138,6 +139,7 @@ class Options:
         "show_progressbar": OptionInfo(True, "Show progressbar"),
         "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
         "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. Broken in PyCharm console."),
+        "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
         "face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
         "code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
         "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
@@ -217,3 +219,6 @@ class TotalTQDM:
 
 
 total_tqdm = TotalTQDM()
+
+mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
+mem_mon.start()
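With this in place, shared.py owns a single MemUsageMonitor instance and ui.py (below) drives it around each generation. As a minimal sketch of that monitor()/stop() cycle outside the webui — assuming a CUDA device is available; the SimpleNamespace stand-in for shared.opts is hypothetical and only supplies the memmon_poll_rate option added above:

    import types
    import torch
    import modules.memmon

    # Hypothetical stand-in for shared.opts; the monitor only reads memmon_poll_rate.
    opts = types.SimpleNamespace(memmon_poll_rate=8)

    mem_mon = modules.memmon.MemUsageMonitor("MemMon", torch.device("cuda"), opts)
    mem_mon.start()        # daemon thread; blocks on run_flag until monitor() sets it

    mem_mon.monitor()      # wakes the thread, which resets torch's peak stats and polls mem_get_info()
    ...                    # run a generation here
    data = mem_mon.stop()  # clears run_flag and returns the collected counters

    for k, v in data.items():
        print(k, -(v // -(1024 ** 2)), "MiB")  # same ceiling division to MiB as dump_debug()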
diff --git a/modules/ui.py b/modules/ui.py
index 738ac945..01b2ba85
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -119,6 +119,7 @@ def save_files(js_data, images, index):
 
 def wrap_gradio_call(func):
     def f(*args, **kwargs):
+        shared.mem_mon.monitor()
         t = time.perf_counter()
 
         try:
@@ -135,8 +136,19 @@ def wrap_gradio_call(func):
 
         elapsed = time.perf_counter() - t
 
+        mem_stats = {k:-(v//-(1024*1024)) for k,v in shared.mem_mon.stop().items()}
+        active_peak = mem_stats['active_peak']
+        reserved_peak = mem_stats['reserved_peak']
+        sys_peak = '?' if opts.memmon_poll_rate <= 0 else mem_stats['system_peak']
+        sys_total = mem_stats['total']
+        sys_pct = '?' if opts.memmon_poll_rate <= 0 else round(sys_peak/sys_total * 100, 2)
+        vram_tooltip = "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.&#13;" \
+                       "Torch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.&#13;" \
+                       "Sys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%)."
+
         # last item is always HTML
-        res[-1] = res[-1] + f"<p class='performance'>Time taken: {elapsed:.2f}s</p>"
+        res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>" \
+                   f"<p class='vram' title='{vram_tooltip}'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p></div>"
 
         shared.state.interrupted = False
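One detail worth noting in both dump_debug() and the ui.py hunk: -(v // -(1024 ** 2)) is the negate/floor-divide/negate idiom for ceiling division, so any nonzero byte count reports as at least 1 MiB instead of rounding down to 0. A quick self-contained illustration:

    def ceil_div(n, d):
        # Python's // floors toward negative infinity, so -(n // -d) rounds up.
        return -(n // -d)

    assert ceil_div(1, 1024 ** 2) == 1              # a single byte still reports as 1 MiB
    assert ceil_div(1024 ** 2, 1024 ** 2) == 1
    assert ceil_div(1024 ** 2 + 1, 1024 ** 2) == 2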