import collections.abc
from itertools import repeat
from typing import List, Optional, Tuple, Union
import torch
from torch import nn as nn
from torch import _assert
from torchvision.ops.misc import FrozenBatchNorm2d


def freeze_batch_norm_2d(module, module_match={}, name=''):
"""
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
returned. Otherwise, the module is walked recursively and submodules are converted in place.
Args:
module (torch.nn.Module): Any PyTorch module.
module_match (dict): Dictionary of full module names to freeze (all if empty)
name (str): Full module name (prefix)
Returns:
torch.nn.Module: Resulting module
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
"""
res = module
is_match = True
if module_match:
is_match = name in module_match
if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
res = FrozenBatchNorm2d(module.num_features)
res.num_features = module.num_features
res.affine = module.affine
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for child_name, child in module.named_children():
full_child_name = '.'.join([name, child_name]) if name else child_name
new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
if new_child is not child:
res.add_module(child_name, new_child)
return res
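
# Usage sketch for freeze_batch_norm_2d (illustrative only; `resnet18` is just a
# convenient model containing BatchNorm2d layers, not a dependency of this module):
#
#   from torchvision.models import resnet18
#   model = freeze_batch_norm_2d(resnet18())
#   assert not any(isinstance(m, nn.BatchNorm2d) for m in model.modules())
#   assert any(isinstance(m, FrozenBatchNorm2d) for m in model.modules())
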
# From PyTorch internals
def _ntuple(n):
    def parse(x):
        # pass iterables through unchanged, broadcast scalars into an n-tuple
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = lambda n, x: _ntuple(n)(x)
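
# For example (a quick sketch of the broadcast behaviour):
#
#   to_2tuple(7)       # -> (7, 7)
#   to_2tuple((3, 5))  # -> (3, 5), iterables pass through unchanged
#   to_ntuple(3, 1)    # -> (1, 1, 1)
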
# Replaces linear layers whose attribute name is in `include_modules` with `linear_replacement`
# TODO: add int8 support for other linear layers including attn and convnets
def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True):
for name, module in model.named_children():
if len(list(module.children())) > 0:
replace_linear(module, linear_replacement, include_modules, copy_weights)
if isinstance(module, torch.nn.Linear) and name in include_modules:
old_module = model._modules[name]
model._modules[name] = linear_replacement(
module.in_features,
module.out_features,
module.bias is not None,
)
            if copy_weights:
                model._modules[name].weight.data.copy_(old_module.weight.data)
                if model._modules[name].bias is not None:
                    model._modules[name].bias.data.copy_(old_module.bias.data)
return model
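
# Usage sketch for replace_linear, e.g. with bitsandbytes' int8 linear layer
# (assumes bitsandbytes is installed; any class with an nn.Linear-style
# (in_features, out_features, bias) constructor works):
#
#   import bitsandbytes as bnb
#   model = replace_linear(model, bnb.nn.Linear8bitLt)
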
def convert_int8_model_to_inference_mode(model):
    """Put int8 modules into inference mode, recording the original weight dtype on each module."""
    for m in model.modules():
        if hasattr(m, 'prepare_for_eval'):
            int8_original_dtype = m.weight.dtype
            m.prepare_for_eval()
            m.int8_original_dtype = int8_original_dtype
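
# Typically called once on a model whose linear layers were swapped for int8
# layers (see replace_linear above) before running inference:
#
#   convert_int8_model_to_inference_mode(model)
#   model.eval()
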
def feature_take_indices(
num_features: int,
indices: Optional[Union[int, List[int]]] = None,
as_set: bool = False,
) -> Tuple[List[int], int]:
""" Determine the absolute feature indices to 'take' from.
Note: This function can be called in forward() so must be torchscript compatible,
which requires some incomplete typing and workaround hacks.
Args:
num_features: total number of features to select from
indices: indices to select,
None -> select all
int -> select last n
list/tuple of int -> return specified (-ve indices specify from end)
as_set: return as a set
Returns:
List (or set) of absolute (from beginning) indices, Maximum index
"""
if indices is None:
indices = num_features # all features if None
if isinstance(indices, int):
# convert int -> last n indices
_assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})')
take_indices = [num_features - indices + i for i in range(indices)]
else:
take_indices: List[int] = []
for i in indices:
idx = num_features + i if i < 0 else i
_assert(0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})')
take_indices.append(idx)
if not torch.jit.is_scripting() and as_set:
return set(take_indices), max(take_indices)
return take_indices, max(take_indices)
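
# For example (a quick sketch of the index resolution):
#
#   feature_take_indices(5)          # -> ([0, 1, 2, 3, 4], 4)  all features
#   feature_take_indices(5, 2)       # -> ([3, 4], 4)           last two
#   feature_take_indices(5, [0, -1]) # -> ([0, 4], 4)           -1 resolves to 4
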
def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
if isinstance(x, int):
# if indices is an int, take last N features
return tuple(range(-x, 0))
return tuple(x)
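
# For example, _out_indices_as_tuple(3) -> (-3, -2, -1), i.e. the last three
# features, while _out_indices_as_tuple((0, 2)) -> (0, 2) is passed through.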