Spaces:
Runtime error
Runtime error
| from __future__ import division | |
| from __future__ import unicode_literals | |
| import torch | |
def get_param_buffer_for_ema(model,
                             update_buffer=False,
                             required_buffers=('running_mean', 'running_var')):
    """
    Collect the tensors of `model` that should be tracked by an EMA.

    Args:
        model: Object exposing `parameters()` and `named_buffers()`
            (e.g. a `torch.nn.Module`).
        update_buffer: When true, matching buffers are appended after the
            trainable parameters.
        required_buffers: Iterable of substrings; a buffer is selected when
            its name contains at least one of them. A single string is
            treated as one substring.

    Returns:
        A list holding every parameter with `requires_grad` set, followed
        (only if `update_buffer` is true) by the selected buffers in the
        order `model.named_buffers()` yields them.
    """
    # A bare string would otherwise be iterated character by character and
    # match almost every buffer name.
    if isinstance(required_buffers, str):
        required_buffers = (required_buffers,)

    all_param_buffer = [p for p in model.parameters() if p.requires_grad]
    if update_buffer:
        # Materialize once so one-shot iterables can be reused per buffer.
        wanted = tuple(required_buffers)
        all_param_buffer.extend(
            value for key, value in model.named_buffers()
            if any(buffer_name in key for buffer_name in wanted))
    return all_param_buffer
class ExponentialMovingAverage:
    """
    Maintains (exponential) moving average of a set of parameters.

    The averages live in `shadow_params`; `store`/`copy_to`/`restore` allow
    temporarily swapping them into the live parameters (e.g. for
    validation) and swapping the originals back afterwards.
    """

    def __init__(self, parameters, decay, use_num_updates=True):
        """
        Args:
            parameters: Iterable of `torch.nn.Parameter`; usually the result of
                `model.parameters()`.
            decay: The exponential decay, in the closed range [0, 1].
            use_num_updates: Whether to use number of updates when computing
                averages. When enabled, the decay applied at update `n` is
                `min(decay, (1 + n) / (10 + n))`.

        Raises:
            ValueError: If `decay` lies outside [0, 1].
        """
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')
        self.decay = decay
        # `None` disables the update-count based warm-up in `update`.
        self.num_updates = 0 if use_num_updates else None
        # Independent copies, detached from autograd.
        self.shadow_params = [p.detach().clone() for p in parameters]
        self.collected_params = []

    def update(self, parameters):
        """
        Update currently maintained parameters.

        Call this every time the parameters are updated, such as the result of
        the `optimizer.step()` call.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; usually the same set of
                parameters used to initialize this object. Entries are paired
                with the shadow copies by position.
        """
        decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            # Warm-up: cap the decay while only a few updates have been seen.
            decay = min(decay,
                        (1.0 + self.num_updates) / (10.0 + self.num_updates))
        one_minus_decay = 1.0 - decay
        with torch.no_grad():
            for s_param, param in zip(self.shadow_params, parameters):
                # In place: shadow = decay * shadow + (1 - decay) * param
                s_param.sub_(one_minus_decay * (s_param - param))

    def copy_to(self, parameters):
        """
        Copy current parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages.
        """
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)

    def store(self, parameters):
        """
        Save the current parameters for restoring later.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        # Detach so the backups are plain value snapshots rather than
        # non-leaf tensors attached to the autograd graph.
        self.collected_params = [param.detach().clone()
                                 for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.

        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.

        The stored copies are released afterwards; calling `restore` again
        without a new `store` leaves the parameters untouched.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
        # Drop the backups to free memory, but keep the attribute that
        # `__init__` created so a later `restore` does not raise
        # AttributeError.
        self.collected_params = []