import collections.abc
from itertools import repeat
from typing import List, Optional, Tuple, Union

import torch
from torch import nn as nn
from torch import _assert
from torchvision.ops.misc import FrozenBatchNorm2d


def freeze_batch_norm_2d(module, module_match=None, name=''):
    """
    Converts all `BatchNorm2d` and `SyncBatchNorm` layers of the provided module into `FrozenBatchNorm2d`. If
    `module` is itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into
    `FrozenBatchNorm2d` and returned. Otherwise, the module is walked recursively and submodules are converted
    in place.

    Args:
        module (torch.nn.Module): Any PyTorch module.
        module_match (Optional[dict]): Dictionary of full module names to freeze (all if empty or None)
        name (str): Full module name (prefix)

    Returns:
        torch.nn.Module: Resulting module

    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
    """
    res = module
    is_match = True
    if module_match:
        is_match = name in module_match
    if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
        res = FrozenBatchNorm2d(module.num_features)
        res.num_features = module.num_features
        res.affine = module.affine
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
    else:
        # Walk children recursively, replacing any converted submodule in place.
        for child_name, child in module.named_children():
            full_child_name = '.'.join([name, child_name]) if name else child_name
            new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
            if new_child is not child:
                res.add_module(child_name, new_child)
    return res
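# Example use of freeze_batch_norm_2d (an illustrative sketch, not from the
# original source; `torchvision.models.resnet50` is assumed purely for demo):
#
#   from torchvision.models import resnet50
#   model = resnet50()
#   model = freeze_batch_norm_2d(model)  # freeze every BN layer
#   # or freeze only specific layers, keyed by full (dotted) module name:
#   model = freeze_batch_norm_2d(model, module_match={'layer1.0.bn1': True})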
def _ntuple(n):
    def parse(x):
        # Pass real iterables through as tuples; treat str as a scalar so that
        # e.g. to_2tuple('same') -> ('same', 'same') rather than 'same'.
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = lambda n, x: _ntuple(n)(x)
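# Example use (illustrative): these helpers normalize scalar-or-iterable
# arguments such as kernel sizes or strides:
#
#   to_2tuple(3)       # -> (3, 3)
#   to_2tuple((3, 5))  # -> (3, 5)
#   to_ntuple(4, 1)    # -> (1, 1, 1, 1)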
def replace_linear(model, linear_replacement, include_modules=('c_fc', 'c_proj'), copy_weights=True):
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            # Recurse into non-leaf submodules first.
            replace_linear(module, linear_replacement, include_modules, copy_weights)

        if isinstance(module, torch.nn.Linear) and name in include_modules:
            old_module = model._modules[name]
            model._modules[name] = linear_replacement(
                module.in_features,
                module.out_features,
                module.bias is not None,
            )
            if copy_weights:
                model._modules[name].weight.data.copy_(old_module.weight.data)
                if model._modules[name].bias is not None:
                    model._modules[name].bias.data.copy_(old_module.bias)

    return model
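# Example use (a hedged sketch): swap selected Linear layers for a drop-in
# replacement class. `bitsandbytes.nn.Linear8bitLt` is assumed here only as one
# plausible replacement; any callable taking (in_features, out_features, bias)
# and returning an nn.Module fits:
#
#   import bitsandbytes as bnb
#   model = replace_linear(model, bnb.nn.Linear8bitLt)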
def convert_int8_model_to_inference_mode(model):
    for m in model.modules():
        if hasattr(m, 'prepare_for_eval'):
            # Stash the pre-conversion weight dtype before the in-place conversion.
            int8_original_dtype = m.weight.dtype
            m.prepare_for_eval()
            m.int8_original_dtype = int8_original_dtype
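# Example use (illustrative): run once after loading weights, before inference.
# Only modules exposing a `prepare_for_eval()` hook (e.g. custom int8 linear
# layers) are converted; everything else is left untouched:
#
#   model.load_state_dict(state_dict)
#   convert_int8_model_to_inference_mode(model)
#   model.eval()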
def feature_take_indices(
        num_features: int,
        indices: Optional[Union[int, List[int]]] = None,
        as_set: bool = False,
) -> Tuple[List[int], int]:
    """ Determine the absolute feature indices to 'take' from.

    Note: This function can be called in forward() so must be torchscript compatible,
    which requires some incomplete typing and workaround hacks.

    Args:
        num_features: total number of features to select from
        indices: indices to select,
            None -> select all
            int -> select last n
            list/tuple of int -> return specified (-ve indices specify from end)
        as_set: return as a set

    Returns:
        List (or set) of absolute (from beginning) indices, and the maximum index
    """
    if indices is None:
        indices = num_features  # all features if None

    if isinstance(indices, int):
        # Convert int -> 'last n' absolute indices.
        _assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})')
        take_indices = [num_features - indices + i for i in range(indices)]
    else:
        take_indices: List[int] = []
        for i in indices:
            # Resolve negative indices relative to the end.
            idx = num_features + i if i < 0 else i
            _assert(0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})')
            take_indices.append(idx)

    if not torch.jit.is_scripting() and as_set:
        return set(take_indices), max(take_indices)

    return take_indices, max(take_indices)
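# Example use (illustrative values, e.g. a 12-block ViT):
#
#   feature_take_indices(12)                  # -> ([0, 1, ..., 11], 11)  all features
#   feature_take_indices(12, 3)               # -> ([9, 10, 11], 11)      last 3
#   feature_take_indices(12, [-1, 2])         # -> ([11, 2], 11)          negatives count from the end
#   feature_take_indices(12, 3, as_set=True)  # -> ({9, 10, 11}, 11)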
def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
    if isinstance(x, int):
        # An int n means 'the last n indices', expressed as negative offsets.
        return tuple(range(-x, 0))
    return tuple(x)
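# Example use (illustrative):
#
#   _out_indices_as_tuple(3)          # -> (-3, -2, -1)
#   _out_indices_as_tuple([0, 2, 4])  # -> (0, 2, 4)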