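# Automatically generates torch.autograd.Function subclasses (and matching
# *Backward classes used for double backward) for the THNN kernels declared in
# THNN.h: the header is parsed at import time and a forward/backward pair is
# synthesized for every updateOutput/updateGradInput(/accGradParameters)
# triple, for criterion (loss) kernels and regular module kernels alike.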
from itertools import repeat
from collections import defaultdict

import torch
from torch._thnn.utils import parse_header, THNN_H_PATH
from torch.autograd.function import Function, InplaceFunction
from torch._thnn import type2backend
from .auto_double_backwards import double_backwards_fns
from .auto_symbolic import symbolic_fns

from . import _all_functions


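# Builds an autograd Function class (plus its *Backward class) for a THNN
# criterion. The updateOutput argument list is scanned once up front to find
# where the optional weight tensor, the 'reduce' flag and any scratch buffers
# have to be inserted when calling into the backend.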
def _make_function_class_criterion(class_name, update_output, update_grad_input, acc_grad_parameters,
                                    double_backwards_fn, symbolic_fn):
    weight_arg_idx = -1
    for i, arg in enumerate(update_output.arguments):
        if arg.name.startswith('weight'):
            weight_arg_idx = i
            break

    reduce_arg_idx = -1
    for i, arg in enumerate(update_output.arguments):
        if arg.name == 'reduce':
            reduce_arg_idx = i
            break

    buffers_idx = []
    additional_arg_idx = 0
    for arg in update_output.arguments[4:]:
        if not arg.name.startswith('weight') and arg.type == 'THTensor*':
            buffers_idx.append(additional_arg_idx)
        additional_arg_idx += 1

    @staticmethod
    def symbolic(*args, **kwargs):
        a = symbolic_fn(*args, **kwargs)
        return a

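    # forward assembles the positional argument list that the THNN updateOutput
    # kernel expects, (state, input, target, output, *additional_args),
    # inserting the weight tensor and fresh scratch buffers at the indices
    # discovered above.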
    @staticmethod
    def forward(ctx, input, target, *args):
        ctx._backend = type2backend[input.type()]
        ctx.save_for_backward(input, target)
        if weight_arg_idx >= 0:
            ctx.weight = args[0]
            args = args[1:]
            ctx.additional_args = list(args)
            insert_idx = weight_arg_idx - 4  # state, input, target, output
            ctx.additional_args.insert(insert_idx, ctx.weight)
        else:
            ctx.additional_args = list(args)

        ctx.forward_args_count = len(ctx.additional_args)
        for idx in buffers_idx:
            ctx.additional_args.insert(idx, input.new(1))
        output = input.new(1)
        getattr(ctx._backend, update_output.name)(ctx._backend.library_state, input, target,
                                                  output, *ctx.additional_args)
        return output

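    # backward defers the actual gradient computation to the generated
    # *Backward class so that grad_input itself remains differentiable.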
    @staticmethod
    def backward(ctx, grad_output):
        input, target = ctx.saved_tensors
        # apply returns only grad_input, so pad the result with one None for
        # target plus one None per extra argument that was passed to forward.
        return ((backward_cls.apply(input, target, grad_output, ctx.additional_args, ctx._backend),) +
                (None,) * (ctx.forward_args_count + 1))

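    # Recomputes grad_input with the THNN updateGradInput kernel. Kernels that
    # declare a 'reduce' argument consume grad_output directly; older kernels
    # produce the unscaled gradient, which is then multiplied by grad_output
    # broadcast to the shape of grad_input.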
    @staticmethod
    def backward_cls_forward(ctx, input, target, grad_output, additional_args_ctx, backend_ctx):
        ctx.additional_args = additional_args_ctx
        ctx._backend = backend_ctx
        ctx.save_for_backward(input, target, grad_output)
        grad_input = grad_output.new().resize_as_(input).zero_()

        if reduce_arg_idx >= 0:
            getattr(ctx._backend, update_grad_input.name)(ctx._backend.library_state, input, target,
                                                          grad_output, grad_input, *ctx.additional_args)
            return grad_input

        getattr(ctx._backend, update_grad_input.name)(ctx._backend.library_state, input, target,
                                                      grad_input, *ctx.additional_args)
        grad_output_expanded = grad_output.view(*repeat(1, grad_input.dim()))
        grad_input.mul_(grad_output_expanded.expand_as(grad_input))
        return grad_input

    @staticmethod
    def backward_cls_backward(ctx, *grad_params):
        return double_backwards_fn(ctx, *grad_params)

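    # Both classes are created dynamically with type(); the staticmethods
    # defined above become their forward/backward/symbolic methods.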
    backward_cls = type(class_name + "Backward", (Function,),
                        dict(forward=backward_cls_forward, backward=backward_cls_backward))
    return type(class_name, (Function,), dict(forward=forward, backward=backward, symbolic=symbolic)), backward_cls


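# Scan a THNN argument list and return (insertion index, name) pairs for the
# scratch tensors ("buffers") the kernel expects, skipping ignored arguments.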
def _find_buffers(args, ignored_args):
    additional_arg_idx = 0
    buffers = []
    for arg in args:
        if arg.name in ignored_args:
            continue
        if arg.type == 'THTensor*':
            buffers.append((additional_arg_idx, arg.name))
        additional_arg_idx += 1
    return buffers


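# Builds an autograd Function class (plus its *Backward class) for a regular,
# non-criterion THNN module: weight/bias tensors are separated from scalar
# arguments, scratch buffers are allocated lazily, and an in-place variant is
# detected from a trailing 'inplace' flag.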
def _make_function_class(class_name, update_output, update_grad_input, acc_grad_parameters,
                         double_backwards_fn, symbolic_fn):
    def has_argument(fn, name):
        for arg in fn.arguments:
            if arg.name == name:
                return True
        return False
    save_output = has_argument(update_grad_input, 'output')

    param_args = {'weight', 'bias'}
    ignored_args = {'weight', 'bias', 'gradWeight', 'gradBias', 'output'}
    expected_params = [arg for arg in update_output.arguments[3:]
                       if arg.name in param_args]
    buffers = {}
    buffers['update_output'] = _find_buffers(update_output.arguments[3:],
                                             ignored_args)
    buffers['update_grad_input'] = _find_buffers(
        update_grad_input.arguments[4:], ignored_args)
    if acc_grad_parameters is not None:
        buffers['acc_grad_parameters'] = _find_buffers(
            acc_grad_parameters.arguments[3:], ignored_args)

    # This assumes that only the last argument can be
    # an inplace flag
    is_inplace = update_output.arguments[-1].name == 'inplace'

    def _initialize_buffers(ctx, fn_name):
        additional_args = ctx.additional_args
        for idx, name in buffers[fn_name]:
            # TODO: some buffers are necessary only for update output and can be
            # freed right afterwards
            buffer = ctx.buffers[name]
            additional_args = additional_args[:idx] + [buffer] + additional_args[idx:]
        return tuple(additional_args)

    @staticmethod
    def symbolic(*args, **kwargs):
        return symbolic_fn(*args, **kwargs)

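    # forward splits *params into tensor parameters (weight/bias) and scalar
    # arguments, allocates scratch buffers, fills omitted optional parameters
    # with None, and then dispatches to the THNN updateOutput kernel.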
    @staticmethod
    def forward(ctx, input, *params):
        ctx._backend = type2backend[input.type()]

        ctx.additional_args = []
        tensor_param_list = []
        for param in params:
            if isinstance(param, torch.Tensor):
                if type(param) != type(input):
                    raise RuntimeError("input type ({}) doesn't match the type of "
                                       "a parameter tensor ({})".format(torch.typename(input),
                                                                        torch.typename(param)))
                tensor_param_list.append(param)
            else:
                ctx.additional_args.append(param)

        tensor_params = tuple(tensor_param_list)
        if is_inplace:
            ctx.inplace = params[-1]
        # Allocate temporary buffers and insert them into additional_args
        ctx.buffers = defaultdict(type(input))
        additional_args = _initialize_buffers(ctx, 'update_output')

        # Fill in optional params with None
        args = tensor_params
        for i in range(len(params), len(expected_params)):
            param = expected_params[i]
            if param.is_optional:
                args += (None,)
            else:
                raise ValueError("missing required argument '%s'" % param.name)

        args += tuple(additional_args)

        # If the module is working in-place its output will be set to the
        # same storage as input, but its tensor won't be dirty.
        if is_inplace and ctx.inplace:
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.new()

        if save_output:
            ctx.save_for_backward(input, output, *tensor_params)
        else:
            ctx.save_for_backward(input, *tensor_params)

        if not ctx.requires_grad:
            del ctx.buffers

        getattr(ctx._backend, update_output.name)(ctx._backend.library_state, input, output, *args)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        t = ctx.saved_tensors
        input, tensor_params = t[0], t[1:]
        # Some notes on this function call:
        # 1) We need to pass params as *params so they are unwrapped correctly in backward_cls_forward.
        # 2) apply returns the grad_input / grad_tensor_params, so we need to append Nones equal to the number
        #    of non tensor_params, i.e. the additional_args
        # 3) it may be simpler to recalculate some of these parameters (e.g. ctx._backend) in backward_cls_forward?

        return (backward_cls.apply(input, grad_output, ctx.additional_args, ctx._backend, ctx.buffers, *tensor_params) +
                (None,) * len(ctx.additional_args))

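    # The generated backward class computes grad_input via updateGradInput and,
    # when the kernel provides accGradParameters, the weight/bias gradients as
    # well; gradients for the plain scalar arguments are always None.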
    @staticmethod
    def backward_cls_forward(ctx, input, grad_output, additional_args_ctx, backend_ctx, buffers_ctx, *params):
        ctx.additional_args = additional_args_ctx
        ctx.buffers = buffers_ctx
        ctx._backend = backend_ctx
        ctx.save_for_backward(input, grad_output, *params)
        if save_output:
            output = params[0]
            params = params[1:]

        grad_params = tuple(None for p in params)
        grad_input_tuple = (None,)
        if is_inplace:
            ctx.inplace = additional_args_ctx[-1]

        if ctx.needs_input_grad[0]:
            additional_args = _initialize_buffers(ctx, 'update_grad_input')
            if save_output:
                additional_args = (output,) + additional_args

            if is_inplace and ctx.inplace:
                assert additional_args[-1] is True
                tmp_args = list(additional_args)
                tmp_args[-1] = False
                additional_args = tuple(tmp_args)
            grad_input = input.new(input.size())
            params_without_bias = params if len(params) < 2 else params[:1]
            update_grad_input_fn = getattr(ctx._backend, update_grad_input.name)
            gi_args = params_without_bias + additional_args
            update_grad_input_fn(ctx._backend.library_state, input, grad_output, grad_input, *gi_args)
            grad_input_tuple = (grad_input,)

        if acc_grad_parameters and any(ctx.needs_input_grad[1:]):
            additional_args = _initialize_buffers(ctx, 'acc_grad_parameters')
            grad_params = tuple(p.new(p.size()).zero_() for p in params)
            appended_grads = len(expected_params) - len(grad_params)
            grad_params += (None,) * appended_grads
            acc_grad_parameters_fn = getattr(ctx._backend, acc_grad_parameters.name)
            param_args = grad_params + additional_args + (1,)
            acc_grad_parameters_fn(ctx._backend.library_state, input, grad_output, *param_args)
            if appended_grads:
                grad_params = grad_params[:-appended_grads]

        return grad_input_tuple + grad_params

    @staticmethod
    def backward_cls_backward(ctx, *grad_params):
        return double_backwards_fn(ctx, *grad_params)

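    # Modules with an 'inplace' flag derive from InplaceFunction so autograd
    # knows their output may alias the input.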
    base_class = Function if not is_inplace else InplaceFunction
    backward_cls = type(class_name + "Backward", (base_class,), dict(forward=backward_cls_forward,
                                                                     backward=backward_cls_backward))

    return type(class_name, (base_class,), dict(forward=forward, backward=backward, symbolic=symbolic)), backward_cls


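# Parses the THNN header, pairs up each *_updateOutput/_updateGradInput (and
# optional *_accGradParameters) triple, and generates an autograd Function
# class for every kernel that is not explicitly excluded or implemented by hand.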
def _generate_function_classes(scope_dict):
    global function_list, function_by_name
    function_list = parse_header(THNN_H_PATH)
    function_by_name = {fn.name: fn for fn in function_list}
    classes_to_generate = {fn.name.partition('_')[0] for fn in function_list}
    exceptions = {
        'Linear',
        'IndexLinear',
        'SpatialFullConvolution',
        'SpatialConvolutionMM',
        'TemporalConvolution',
        'SpatialAveragePooling',
        'SpatialMaxUnpooling',
        'VolumetricAveragePooling',
        'VolumetricMaxPooling',
        'VolumetricMaxUnpooling',
        'VolumetricConvolution',
        'VolumetricFullConvolution',
        'VolumetricConvolutionMM',
        'TemporalMaxPooling',
        'BatchNormalization',
        'LookupTable',
        'LookupTableBag',
        'PReLU',
        'RReLU',
        'SoftMax',
        'LogSoftMax',
        'GRUFused',
        'LSTMFused',
        'unfolded',
    }
    name_remap = {
        'TemporalConvolution': 'Conv1d',
        'SpatialDilatedConvolution': 'DilatedConv2d',
        'SpatialMaxUnpooling': 'MaxUnpool2d',
        'VolumetricMaxUnpooling': 'MaxUnpool3d',
        'HardTanh': 'Hardtanh',
        'HardShrink': 'Hardshrink',
        'SoftPlus': 'Softplus',
        'SoftShrink': 'Softshrink',
        'MSECriterion': 'MSELoss',
        'AbsCriterion': 'L1Loss',
        'BCECriterion': 'BCELoss',
        'ClassNLLCriterion': 'NLLLoss',
        'DistKLDivCriterion': 'KLDivLoss',
        'SpatialClassNLLCriterion': 'NLLLoss2d',
        'MultiLabelMarginCriterion': 'MultiLabelMarginLoss',
        'MultiMarginCriterion': 'MultiMarginLoss',
        'SmoothL1Criterion': 'SmoothL1Loss',
        'SoftMarginCriterion': 'SoftMarginLoss',
    }

    classes_to_generate -= exceptions
    for fn in classes_to_generate:
        update_output = function_by_name[fn + '_updateOutput']
        update_grad_input = function_by_name[fn + '_updateGradInput']
        acc_grad_parameters = function_by_name.get(fn + '_accGradParameters')
        class_name = name_remap.get(fn, fn)
        double_backwards_fn = double_backwards_fns.get(class_name)
        if double_backwards_fn is None:
            def make_default_double_backwards_fn(class_name):
                def default_double_backwards_fn(ctx, *grad_params):
                    raise ValueError(class_name + " can only be differentiated once.")
                return default_double_backwards_fn
            double_backwards_fn = make_default_double_backwards_fn(class_name)
        symbolic_fn = symbolic_fns.get(class_name)
        # This has to call a function to retain correct references to functions
        is_criterion_fn = 'Criterion' in fn
        if is_criterion_fn:
            cls, backward_cls = _make_function_class_criterion(class_name, update_output,
                                                               update_grad_input, acc_grad_parameters,
                                                               double_backwards_fn, symbolic_fn)
        else:
            cls, backward_cls = _make_function_class(class_name, update_output,
                                                     update_grad_input, acc_grad_parameters,
                                                     double_backwards_fn, symbolic_fn)
        scope_dict[class_name] = cls
        scope_dict[backward_cls.__name__] = backward_cls
        if not class_name.startswith('_'):
            _all_functions.append(cls)
            _all_functions.append(backward_cls)


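# Generate the classes into this module's namespace when it is imported.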
_generate_function_classes(locals())
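# A minimal usage sketch, written as a comment because everything here is
# generated at import time. It is hypothetical: it assumes this file lives at
# torch/nn/_functions/thnn/auto.py and that the parsed THNN header declares
# MSECriterion (so an "MSELoss" class is generated); the exact extra positional
# arguments accepted by .apply() depend on that parsed signature.
#
#   import torch
#   from torch.nn._functions.thnn import auto
#
#   x = torch.randn(4, requires_grad=True)
#   y = torch.randn(4)
#   loss = auto.MSELoss.apply(x, y)   # further args follow the THNN signature
#   loss.backward()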