| 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
 | import threading
import time
from collections import defaultdict
import torch
class MemUsageMonitor(threading.Thread):
    """Daemon thread that records the lowest free CUDA memory seen while monitoring.

    A monitoring session is started with :meth:`monitor` and ended with
    :meth:`stop`. While a session is active the thread polls
    ``torch.cuda.mem_get_info()`` at ``opts.memmon_poll_rate`` samples per
    second and keeps the minimum "free" value in ``data["min_free"]``.
    :meth:`read` combines that with torch's own peak allocator statistics.

    If the CUDA memory queries fail during construction, the monitor marks
    itself ``disabled``: the thread body returns immediately and
    :meth:`read` returns whatever is in ``data`` without touching torch.

    NOTE(review): ``mem_get_info()`` and ``reset_peak_memory_stats()`` are
    called without a device argument, so they presumably act on the current
    CUDA device, while ``memory_stats`` is queried for ``self.device``.
    Confirm these refer to the same device in multi-GPU setups.
    """

    run_flag = None  # threading.Event; set while a monitoring session is active
    device = None  # device handed to torch.cuda.memory_stats
    disabled = False  # True when the CUDA memory queries failed in __init__
    opts = None  # options object; only `memmon_poll_rate` is read by this class
    data = None  # defaultdict(int) holding the recorded statistics

    def __init__(self, name, device, opts):
        """Create the monitor thread (not started) and probe CUDA memory APIs.

        Args:
            name: Thread name.
            device: Device passed to ``torch.cuda.memory_stats``.
            opts: Object exposing a numeric ``memmon_poll_rate`` attribute
                (samples per second; values <= 0 turn polling off).
        """
        super().__init__()
        self.name = name
        self.device = device
        self.opts = opts

        self.daemon = True
        self.run_flag = threading.Event()
        self.data = defaultdict(int)

        # Probe both queries used later; any failure disables the monitor
        # instead of propagating to the caller.
        try:
            torch.cuda.mem_get_info()
            torch.cuda.memory_stats(self.device)
        except Exception as e:  # AMD or whatever
            print(f"Warning: caught exception '{e}', memory monitor disabled")
            self.disabled = True

    def run(self):
        """Thread body: wait for a session, then poll free memory until it ends."""
        if self.disabled:
            return

        while True:
            # Block until monitor() starts a session.
            self.run_flag.wait()

            # Each session starts from fresh peak statistics and empty data.
            torch.cuda.reset_peak_memory_stats()
            self.data.clear()

            if self.opts.memmon_poll_rate <= 0:
                # Polling is turned off: end the session right away.
                self.run_flag.clear()
                continue

            self.data["min_free"] = torch.cuda.mem_get_info()[0]

            while self.run_flag.is_set():
                free, _ = torch.cuda.mem_get_info()  # calling with self.device errors, torch bug?
                self.data["min_free"] = min(self.data["min_free"], free)

                # Re-read the rate on every iteration: `opts` may change while
                # a session is running, and a value <= 0 would otherwise make
                # the sleep below raise (ZeroDivisionError / ValueError) and
                # kill this thread for good. Treat it like the check above.
                poll_rate = self.opts.memmon_poll_rate
                if poll_rate <= 0:
                    self.run_flag.clear()
                    break

                time.sleep(1 / poll_rate)

    def dump_debug(self):
        """Print the recorded data and torch's byte-valued memory statistics.

        Values are printed after ceiling division by 1024 ** 2; entries whose
        key contains 'peak' are prefixed with a tab. Finishes by printing
        ``torch.cuda.memory_summary()``.
        """
        print(self, 'recorded data:')
        for k, v in self.read().items():
            print(k, -(v // -(1024 ** 2)))

        print(self, 'raw torch memory stats:')
        tm = torch.cuda.memory_stats(self.device)
        for k, v in tm.items():
            if 'bytes' not in k:
                continue
            print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))

        print(torch.cuda.memory_summary())

    def monitor(self):
        """Start a monitoring session by setting the run flag."""
        self.run_flag.set()

    def read(self):
        """Refresh and return the statistics dictionary.

        When the monitor is enabled this stores ``total``, ``active_peak``,
        ``reserved_peak`` and ``system_peak`` (``total - min_free``) in
        ``data``. Because ``data`` is a ``defaultdict(int)``, a missing
        ``min_free`` reads as 0, which makes ``system_peak`` equal ``total``.

        Returns:
            The monitor's own ``data`` mapping (not a copy).
        """
        if not self.disabled:
            free, total = torch.cuda.mem_get_info()
            self.data["total"] = total

            torch_stats = torch.cuda.memory_stats(self.device)
            self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
            self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
            self.data["system_peak"] = total - self.data["min_free"]

        return self.data

    def stop(self):
        """End the current monitoring session and return :meth:`read`'s result."""
        self.run_flag.clear()
        return self.read()
 |