Diffstat (limited to 'modules')
-rw-r--r--  modules/memmon.py  81
-rw-r--r--  modules/shared.py   5
-rw-r--r--  modules/ui.py      15
3 files changed, 100 insertions(+), 1 deletion(-)
diff --git a/modules/memmon.py b/modules/memmon.py
new file mode 100644
index 00000000..f2cac841
--- /dev/null
+++ b/modules/memmon.py
@@ -0,0 +1,81 @@
+import threading
+import time
+from collections import defaultdict
+
+import torch
+
+
+class MemUsageMonitor(threading.Thread):
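+    """Daemon thread that samples CUDA free memory while a job runs, tracking
+    the minimum free VRAM seen so peak system-wide usage can be reported."""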
+    run_flag = None
+    device = None
+    disabled = False
+    opts = None
+    data = None
+
+    def __init__(self, name, device, opts):
+        threading.Thread.__init__(self)
+        self.name = name
+        self.device = device
+        self.opts = opts
+
+        self.daemon = True
+        self.run_flag = threading.Event()
+        self.data = defaultdict(int)
+
+    def run(self):
+        if self.disabled:
+            return
+
+        while True:
+            self.run_flag.wait()
+
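+            # a job has started: reset peak counters, then sample free VRAM until stop() clears the flag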
+            torch.cuda.reset_peak_memory_stats()
+            self.data.clear()
+
+            if self.opts.memmon_poll_rate <= 0:
+                self.run_flag.clear()
+                continue
+
+            self.data["min_free"] = torch.cuda.mem_get_info()[0]
+
+            while self.run_flag.is_set():
+                free, total = torch.cuda.mem_get_info()  # calling with self.device errors, torch bug?
+                self.data["min_free"] = min(self.data["min_free"], free)
+
+                time.sleep(1 / self.opts.memmon_poll_rate)
+
+    def dump_debug(self):
+        print(self, 'recorded data:')
+        for k, v in self.read().items():
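+            # -(v // -(1024 ** 2)) is ceiling division: byte counts rounded up to whole MiB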
+            print(k, -(v // -(1024 ** 2)))
+
+        print(self, 'raw torch memory stats:')
+        tm = torch.cuda.memory_stats(self.device)
+        for k, v in tm.items():
+            if 'bytes' not in k:
+                continue
+            print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))
+
+        print(torch.cuda.memory_summary())
+
+    def monitor(self):
+        self.run_flag.set()
+
+    def read(self):
+        free, total = torch.cuda.mem_get_info()
+        self.data["total"] = total
+
+        torch_stats = torch.cuda.memory_stats(self.device)
+        self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
+        self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
+        self.data["system_peak"] = total - self.data["min_free"]
+
+        return self.data
+
+    def stop(self):
+        self.run_flag.clear()
+        return self.read()
diff --git a/modules/shared.py b/modules/shared.py
index da56b6ae..4f877036 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -12,6 +12,7 @@ from modules.paths import script_path, sd_path
 from modules.devices import get_optimal_device
 import modules.styles
 import modules.interrogate
+import modules.memmon
 
 sd_model_file = os.path.join(script_path, 'model.ckpt')
 if not os.path.exists(sd_model_file):
@@ -138,6 +139,7 @@ class Options:
"show_progressbar": OptionInfo(True, "Show progressbar"),
"show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. Broken in PyCharm console."),
+ "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step":1}),
"face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
@@ -217,3 +219,6 @@ class TotalTQDM:
 total_tqdm = TotalTQDM()
+
+mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
+mem_mon.start()
diff --git a/modules/ui.py b/modules/ui.py
index 738ac945..01b2ba85 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -119,6 +119,7 @@ def save_files(js_data, images, index):
 
 def wrap_gradio_call(func):
     def f(*args, **kwargs):
+        shared.mem_mon.monitor()
         t = time.perf_counter()
 
         try:
@@ -135,8 +136,20 @@ def wrap_gradio_call(func):
         elapsed = time.perf_counter() - t
 
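+        # stop() returns byte counts; -(v // -(1024 * 1024)) rounds them up to whole MiB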
+        mem_stats = {k: -(v // -(1024 * 1024)) for k, v in shared.mem_mon.stop().items()}
+        active_peak = mem_stats['active_peak']
+        reserved_peak = mem_stats['reserved_peak']
+        sys_peak = '?' if opts.memmon_poll_rate <= 0 else mem_stats['system_peak']
+        sys_total = mem_stats['total']
+        sys_pct = '?' if opts.memmon_poll_rate <= 0 else round(sys_peak / sys_total * 100, 2)
+        vram_tooltip = "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.&#013;" \
+                       "Torch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.&#013;" \
+                       "Sys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%)."
+
         # last item is always HTML
-        res[-1] = res[-1] + f"<p class='performance'>Time taken: {elapsed:.2f}s</p>"
+        res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>" \
+                   f"<p class='vram' title='{vram_tooltip}'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p></div>"
 
         shared.state.interrupted = False
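
For reference, a minimal standalone sketch of how the monitor is driven. This is not part of the patch: the SimpleNamespace stand-in for the webui's opts object and the run_generation_job() placeholder are assumptions for illustration, and a CUDA build of torch is required. Inside the webui, shared.py constructs and starts the monitor as shown in the diff above.

    import types

    import torch

    from modules.memmon import MemUsageMonitor

    # stand-in for shared.opts; the monitor only reads memmon_poll_rate
    opts = types.SimpleNamespace(memmon_poll_rate=8)

    mem_mon = MemUsageMonitor("MemMon", torch.device("cuda"), opts)
    mem_mon.start()         # daemon thread; idles until monitor() sets its Event

    mem_mon.monitor()       # begin polling free VRAM
    run_generation_job()    # hypothetical placeholder for the work being measured
    stats = mem_mon.stop()  # clear the flag and read back the recorded statistics

    # stats values are byte counts; round up to MiB the same way ui.py does
    print({k: -(v // -(1024 * 1024)) for k, v in stats.items()})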