diff options
author | zhaohu xing <920232796@qq.com> | 2022-11-30 06:56:12 +0000 |
---|---|---|
committer | zhaohu xing <920232796@qq.com> | 2022-11-30 06:56:12 +0000 |
commit | 52cc83d36b7663a77b79fd2258d2ca871af73e55 (patch) | |
tree | 5c31e75a3934327331d5636bd6ef1420c3ba32fe /ldm/modules/ema.py | |
parent | a39a57cb1f5964d9af2b541f7b352576adeeac0f (diff) | |
download | stable-diffusion-webui-gfx803-52cc83d36b7663a77b79fd2258d2ca871af73e55.tar.gz stable-diffusion-webui-gfx803-52cc83d36b7663a77b79fd2258d2ca871af73e55.tar.bz2 stable-diffusion-webui-gfx803-52cc83d36b7663a77b79fd2258d2ca871af73e55.zip |
fix bugs
Signed-off-by: zhaohu xing <920232796@qq.com>
Diffstat (limited to 'ldm/modules/ema.py')
-rw-r--r-- | ldm/modules/ema.py | 76 |
1 files changed, 0 insertions, 76 deletions
diff --git a/ldm/modules/ema.py b/ldm/modules/ema.py deleted file mode 100644 index c8c75af4..00000000 --- a/ldm/modules/ema.py +++ /dev/null @@ -1,76 +0,0 @@ -import torch -from torch import nn - - -class LitEma(nn.Module): - def __init__(self, model, decay=0.9999, use_num_upates=True): - super().__init__() - if decay < 0.0 or decay > 1.0: - raise ValueError('Decay must be between 0 and 1') - - self.m_name2s_name = {} - self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) - self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates - else torch.tensor(-1,dtype=torch.int)) - - for name, p in model.named_parameters(): - if p.requires_grad: - #remove as '.'-character is not allowed in buffers - s_name = name.replace('.','') - self.m_name2s_name.update({name:s_name}) - self.register_buffer(s_name,p.clone().detach().data) - - self.collected_params = [] - - def forward(self,model): - decay = self.decay - - if self.num_updates >= 0: - self.num_updates += 1 - decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) - - one_minus_decay = 1.0 - decay - - with torch.no_grad(): - m_param = dict(model.named_parameters()) - shadow_params = dict(self.named_buffers()) - - for key in m_param: - if m_param[key].requires_grad: - sname = self.m_name2s_name[key] - shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) - shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) - else: - assert not key in self.m_name2s_name - - def copy_to(self, model): - m_param = dict(model.named_parameters()) - shadow_params = dict(self.named_buffers()) - for key in m_param: - if m_param[key].requires_grad: - m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) - else: - assert not key in self.m_name2s_name - - def store(self, parameters): - """ - Save the current parameters for restoring later. - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - temporarily stored. - """ - self.collected_params = [param.clone() for param in parameters] - - def restore(self, parameters): - """ - Restore the parameters stored with the `store` method. - Useful to validate the model with EMA parameters without affecting the - original optimization process. Store the parameters before the - `copy_to` method. After validation (or model saving), use this to - restore the former parameters. - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored parameters. - """ - for c_param, param in zip(self.collected_params, parameters): - param.data.copy_(c_param.data) |