# Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00).
# Summary: Closes gh-42968. Pull Request resolved: https://github.com/pytorch/pytorch/pull/43446
# Reviewed By: albanD. Differential Revision: D23280962. Pulled By: malfet.
# fbshipit-source-id: de5386a95a20ecc814c39cbec3e4252112340b3a
# File info: 137 lines, 5.5 KiB, Python.
import torch
|
|
import copy
|
|
from typing import Dict, Any
|
|
|
|
# Module classes accepted by cross_layer_equalization; passing a module of any
# other type makes it raise a ValueError.
_supported_types = {torch.nn.Linear, torch.nn.Conv2d}


def max_over_ndim(input, axis_list, keepdim=False):
    ''' Applies 'torch.max' over the given axes

    Reduces ``input`` by taking the maximum over every dimension listed in
    ``axis_list`` and returns the resulting values (the indices returned by
    ``torch.max`` are discarded).

    Args:
        input: the tensor to reduce
        axis_list: an iterable of dimension indices to reduce over; negative
            indices count from the end of ``input``'s original dimensions.
            The iterable is not modified.
        keepdim: if True, the reduced dimensions are kept with size 1

    Returns:
        the reduced tensor
    '''
    # Wrap negative indices against the ORIGINAL rank up front: once dimensions
    # start disappearing (keepdim=False) a negative index would point elsewhere.
    # A 0-dim tensor is treated as rank 1, matching torch's own dim wrapping.
    ndim = max(input.dim(), 1)
    normalized = (axis + ndim if axis < 0 else axis for axis in axis_list)
    # Reduce the highest dimension first so removing it does not shift the
    # indices still to be processed. sorted() leaves the caller's list untouched.
    for axis in sorted(normalized, reverse=True):
        input, _ = input.max(axis, keepdim)
    return input


def min_over_ndim(input, axis_list, keepdim=False):
    ''' Applies 'torch.min' over the given axes

    Reduces ``input`` by taking the minimum over every dimension listed in
    ``axis_list`` and returns the resulting values (the indices returned by
    ``torch.min`` are discarded).

    Args:
        input: the tensor to reduce
        axis_list: an iterable of dimension indices to reduce over; negative
            indices count from the end of ``input``'s original dimensions.
            The iterable is not modified.
        keepdim: if True, the reduced dimensions are kept with size 1

    Returns:
        the reduced tensor
    '''
    # Wrap negative indices against the ORIGINAL rank up front: once dimensions
    # start disappearing (keepdim=False) a negative index would point elsewhere.
    # A 0-dim tensor is treated as rank 1, matching torch's own dim wrapping.
    ndim = max(input.dim(), 1)
    normalized = (axis + ndim if axis < 0 else axis for axis in axis_list)
    # Reduce the highest dimension first so removing it does not shift the
    # indices still to be processed. sorted() leaves the caller's list untouched.
    for axis in sorted(normalized, reverse=True):
        input, _ = input.min(axis, keepdim)
    return input


def channel_range(input, axis=0):
    ''' finds the range of weights associated with a specific channel

    For every index ``c`` along ``axis``, computes ``max - min`` over all
    elements of ``input`` whose coordinate along ``axis`` is ``c``.

    Args:
        input: the tensor to analyze (e.g. a module's weight)
        axis: the channel dimension; negative values count from the end

    Returns:
        a 1-D tensor of length ``input.size(axis)`` holding each channel's range

    Raises:
        ValueError: if ``axis`` is out of range for ``input``
    '''
    ndim = input.ndim
    if not -ndim <= axis < ndim:
        raise ValueError(
            f"axis {axis} is out of range for a tensor with {ndim} dimensions")
    axis %= ndim  # normalize a negative axis to its non-negative equivalent

    # reduce over every dimension except the channel dimension
    axis_list = [dim for dim in range(ndim) if dim != axis]

    mins = min_over_ndim(input, axis_list)
    maxs = max_over_ndim(input, axis_list)

    assert mins.size(0) == input.size(axis), "Dimensions of resultant channel range does not match size of requested axis"
    return maxs - mins


def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
    ''' Given two adjacent modules, the weights are scaled such that the
    ranges of the first module's output channels are equal to the ranges of
    the second module's input channels

    For each output channel ``i`` of ``module1`` (and the matching input
    channel of ``module2``) a scale ``s_i = sqrt(r1_i / r2_i)`` is computed from
    the channel ranges ``r1_i`` and ``r2_i``. ``module1``'s weight and bias are
    multiplied by ``1 / s_i`` and ``module2``'s weight is multiplied by ``s_i``.
    The modules are updated in place by assigning new ``torch.nn.Parameter``
    objects to their ``weight`` (and ``module1``'s ``bias``, when it has one).

    Args:
        module1: the first module; its type must be in ``_supported_types``
        module2: the module following ``module1``; its type must be in
            ``_supported_types``
        output_axis: the output-channel dimension of ``module1.weight``
        input_axis: the input-channel dimension of ``module2.weight``

    Raises:
        ValueError: if either module's type is not supported
        TypeError: if the channel counts of the two modules do not match
    '''
    if type(module1) not in _supported_types or type(module2) not in _supported_types:
        raise ValueError(
            f"module type not supported: {type(module1)} {type(module2)}")

    if module1.weight.size(output_axis) != module2.weight.size(input_axis):
        raise TypeError("Number of output channels of first arg do not match "
                        "number input channels of second arg")

    # The rescaling is a one-off data transformation, not part of a training
    # step, so there is no reason to record it in the autograd graph.
    with torch.no_grad():
        weight1 = module1.weight
        weight2 = module2.weight
        bias = module1.bias

        weight1_range = channel_range(weight1, output_axis)
        weight2_range = channel_range(weight2, input_axis)

        # producing the scaling factors to be applied; the epsilon guards the
        # division against constant (zero-range) input channels of module2
        weight2_range += 1e-9
        scaling_factors = torch.sqrt(weight1_range / weight2_range)
        # A zero-range output channel of module1 would get a scale of 0 and an
        # infinite reciprocal, turning its weights into inf/nan. Scaling cannot
        # change a zero range anyway, so such channels are left unscaled.
        scaling_factors = torch.where(scaling_factors == 0,
                                      torch.ones_like(scaling_factors),
                                      scaling_factors)
        inverse_scaling_factors = torch.reciprocal(scaling_factors)

        # reshape the 1-D factors to size 1 everywhere except the channel
        # dimension so they broadcast against the weight tensors
        size1 = [1] * weight1.ndim
        size1[output_axis] = weight1.size(output_axis)
        size2 = [1] * weight2.ndim
        size2[input_axis] = weight2.size(input_axis)

        weight1 = weight1 * torch.reshape(inverse_scaling_factors, size1)
        weight2 = weight2 * torch.reshape(scaling_factors, size2)
        # a module created with bias=False has module1.bias set to None
        if bias is not None:
            bias = bias * inverse_scaling_factors

    module1.weight = torch.nn.Parameter(weight1)
    if bias is not None:
        module1.bias = torch.nn.Parameter(bias)
    module2.weight = torch.nn.Parameter(weight2)


def equalize(model, paired_modules_list, threshold=1e-4, inplace=True):
    ''' Given a list of adjacent modules within a model, equalization will
    be applied between each pair; this will be repeated until convergence is
    achieved

    Keeps a copy of the changing modules from the previous iteration; if the
    copies are not that different from the current modules (determined by
    converged), then the modules have converged enough that further equalizing
    is not necessary

    Implementation of this references section 4.1 of this paper https://arxiv.org/pdf/1906.04721.pdf

    Args:
        model: a model (nn.module) that equalization is to be applied on
        paired_modules_list: a list of lists where each sublist is a pair of two
            submodules found in the model, for each pair the two submodules generally
            have to be adjacent in the model to get expected/reasonable results
        threshold: a number used by the converged function to determine what degree
            of similarity between models is necessary for them to be called equivalent
        inplace: determines if function is inplace or not

    Returns:
        the equalized model (``model`` itself when ``inplace`` is True,
        otherwise a deep copy of it)

    Raises:
        KeyError: if a name in ``paired_modules_list`` is not the name of a
            submodule of ``model`` (as reported by ``named_modules()``)
    '''
    if not inplace:
        model = copy.deepcopy(model)

    name_to_module : Dict[str, torch.nn.Module] = {}
    previous_name_to_module: Dict[str, Any] = {}
    name_set = {name for pair in paired_modules_list for name in pair}

    for name, module in model.named_modules():
        if name in name_set:
            name_to_module[name] = module
            # None means "no previous iteration yet", which makes converged()
            # return False so at least one equalization pass runs
            previous_name_to_module[name] = None

    # Fail loudly on unknown names: otherwise a fully mistyped list silently
    # does nothing, and a partially mistyped one fails mid-equalization.
    missing = name_set.difference(name_to_module)
    if missing:
        raise KeyError(f"modules not found in model: {sorted(missing)}")

    while not converged(name_to_module, previous_name_to_module, threshold):
        # Snapshot every module once, before any pair of this iteration is
        # equalized, so a module shared by two pairs is compared against its
        # state from the previous iteration rather than a mid-iteration state.
        for name, module in name_to_module.items():
            previous_name_to_module[name] = copy.deepcopy(module)

        for pair in paired_modules_list:
            cross_layer_equalization(name_to_module[pair[0]], name_to_module[pair[1]])

    return model


def converged(curr_modules, prev_modules, threshold=1e-4):
    ''' Tests for the summed norm of the differences between each set of modules
    being less than the given threshold

    Takes two dictionaries mapping names to modules; the set of names for each
    dictionary should be the same. For each name, the norm of the difference
    between the ``weight`` of the associated modules in each dictionary is
    taken, and those norms are summed.

    Args:
        curr_modules: mapping of names to the current modules
        prev_modules: mapping of the same names to the modules' previous state;
            a ``None`` value means there is no previous state yet
        threshold: the summed norm must be strictly below this value

    Returns:
        True if the summed norm is below ``threshold``, False otherwise
        (always False when any value of ``prev_modules`` is ``None``)

    Raises:
        ValueError: if the two mappings do not have the same keys
    '''
    if curr_modules.keys() != prev_modules.keys():
        raise ValueError("The keys to the given mappings must have the same set of names of modules")

    # identity test; a plain ``None in values()`` would compare with ``==``
    if any(module is None for module in prev_modules.values()):
        return False

    # Only a number is needed, so skip autograd tracking, and accumulate a
    # Python float rather than adding into a CPU tensor so the weights may
    # live on any device.
    summed_norms = 0.0
    with torch.no_grad():
        for name, curr_module in curr_modules.items():
            difference = curr_module.weight.sub(prev_modules[name].weight)
            summed_norms += torch.norm(difference).item()
    return summed_norms < threshold