expanded weights: conv faster rule (#73692)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/73692
Test Plan: Imported from OSS
Reviewed By: bdhirsh
Differential Revision: D34719302
Pulled By: samdow
fbshipit-source-id: 2288320a5f5d6a442da78e9fbe722f300b844be9
(cherry picked from commit a4cf23383c16d3c61d53e9d21f426259d2dc2d37)
committed by PyTorch MergeBot
parent 6cf2cafe60
commit 78e17eaadc
@@ -6,7 +6,9 @@ import unittest

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import OpDTypes, instantiate_device_type_tests, ops
from torch.testing._internal.common_nn import TestBase, module_tests, new_module_tests
from torch.testing._internal.common_utils import TestCase, freeze_rng_state, make_tensor, run_tests

@@ -159,7 +161,7 @@ class TestExpandedWeightFunctional(TestCase):
        for (result_grad, expected_grad) in zip(expanded_weight_grad, per_sample_grad):
            if result_grad is None:
                result_grad = torch.zeros_like(expected_grad)
            assert torch.allclose(result_grad, expected_grad), f"Got {result_grad}, expected {expected_grad}"
            self.assertEqual(result_grad, expected_grad)

    @ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,))
    def test_unsupported_expand_weights(self, device, dtype, op):
@@ -198,6 +200,47 @@ class TestExpandedWeightFunctional(TestCase):
        with self.assertRaisesRegex(RuntimeError, r"Expanded Weights encountered but cannot handle function"):
            torch.add(sample_input, ExpandedWeight(sample_weight, batch_size))

    def test_small_model(self, device):
        def convnet(num_classes):
            return nn.Sequential(
                nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(128, num_classes, bias=True),
            )

        batch_size = 32
        model = convnet(10).to(device)
        input = torch.randn([batch_size, 3, 28, 28], device=device)
        targets = torch.randint(0, 10, (batch_size,), device=device)
        criterion = CrossEntropyLoss(reduction='sum')  # use a loss that doesn't average across the batch to test in a for loop
        result = call_for_per_sample_grads(model, batch_size, input)
        loss = criterion(result, targets)
        loss.backward()
        result = []
        for weight in model.parameters():
            result.append(weight.grad_sample)
            del weight.grad_sample

        expected = []
        for i in range(batch_size):
            loss = criterion(model(input[i].unsqueeze(0)), targets[i].unsqueeze(0))
            expected.append(torch.autograd.grad(loss, model.parameters(), torch.ones_like(loss)))

        expected = [torch.stack(grad) for grad in zip(*expected)]
        for (res, exp) in zip(result, expected):
            self.assertEqual(res, exp, atol=1e-4, rtol=5e-5)


class TestExpandedWeightModule(TestCase):
    def _do_test(self, module, input):
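
The new test_small_model exercises the full per-sample-gradient workflow that this commit enables for convolutions. Below is a minimal sketch of the same pattern using only APIs that appear in the diff; the toy model and tensor shapes are illustrative, not taken from the commit:

import torch
import torch.nn as nn
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

batch_size = 4
# illustrative toy model built only from layers this PR supports (Conv*, Linear)
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.Flatten(), nn.Linear(8 * 8 * 8, 2))
x = torch.randn(batch_size, 3, 8, 8)

out = call_for_per_sample_grads(model, batch_size, x)  # forward with weights wrapped as ExpandedWeights
out.sum().backward()  # a sum-reduced loss keeps samples independent, as in the test above

for p in model.parameters():
    # each parameter now carries a per-sample gradient of shape (batch_size, *p.shape)
    assert p.grad_sample.shape == (batch_size, *p.shape)
    del p.grad_sample  # the test deletes grad_sample after reading it; mirrored here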
@@ -212,7 +255,7 @@ class TestExpandedWeightModule(TestCase):
                del param.grad_sample

            # get per sample grads with a for loop
            expected_res = torch.tensor(0.)
            expected_res = torch.tensor(0., device=input.device, dtype=torch.double)
            expected_grads = []
            for i in range(batch_size):
                res = module(input[i].unsqueeze(0)).sum()
@@ -220,7 +263,7 @@ class TestExpandedWeightModule(TestCase):
                expected_res += res
            expected_grads = tuple(torch.stack(grad) for grad in zip(*expected_grads))
        self.assertEqual(actual_res, expected_res)
        assert [torch.allclose(actual, expected) for (actual, expected) in zip(actual_grads, expected_grads)]
        [self.assertEqual(actual, expected) for (actual, expected) in zip(actual_grads, expected_grads)]

    def _do_test_multi_input(self, module, input):
        class TestModule(nn.Module):
@@ -248,7 +291,7 @@ class TestExpandedWeightModule(TestCase):
                res = module(input[i].unsqueeze(0)).sum()
                expected_grads.append(torch.autograd.grad(res, module.parameters(), torch.ones_like(res)))
        expected_grads = tuple(torch.stack(grad) for grad in zip(*expected_grads))
        assert [torch.allclose(actual, 2 * expected) for (actual, expected) in zip(actual_grads, expected_grads)]
        assert [self.assertEqual(actual, 2 * expected) for (actual, expected) in zip(actual_grads, expected_grads)]

    def test_per_sample_api_failing(self):
        module = nn.Linear(10, 10)
@@ -266,23 +309,26 @@ class TestExpandedWeightModule(TestCase):

class ContextManagerTests(TestBase):
    def __init__(self, *args, **kwargs):
        self.test_cpu = kwargs.get('test_cpu', True)
        self.test_cuda = kwargs.get('test_cuda', True)
        super().__init__(*args, **kwargs)

    @property
    def constructor_args(self):
        return self._get_arg('constructor_args', False)

    def test_context_manager(self, test_case):
        module = self.constructor(*self.constructor_args)
        input = self._get_input()
    def test_context_manager(self, test_case, device):
        kwargs = {'device': device, 'dtype': torch.double}
        module = self.constructor(*self.constructor_args).to(**kwargs)
        input = self._get_input().to(**kwargs)
        if len(input.shape) == 0 or input.shape[0] == 0:
            raise unittest.SkipTest("Can't get per sample gradients when no batch dim or batch dim is 0")
        if self.constructor == torch.nn.Linear and len(input.shape) == 1:
            raise unittest.SkipTest("Can't get per sample gradients for input of rank 1")
        test_case._do_test(module, input)

    def test_context_manager_multiple_inputs(self, test_case):
        module = self.constructor(*self.constructor_args)
    def test_context_manager_multiple_inputs(self, test_case, device):
        module = self.constructor(*self.constructor_args).to(device)
        input = self._get_input()
        if len(input.shape) == 0 or input.shape[0] == 0:
            raise unittest.SkipTest("Can't get per sample gradients when no batch dim or batch dim is 0")
@@ -292,7 +338,7 @@ class ContextManagerTests(TestBase):

# TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
# These currently use the legacy nn tests
supported_modules = ['Linear']
supported_modules = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d']
supported_tests = [t for t in module_tests + new_module_tests if 'module_name' in t and t['module_name'] in supported_modules]
for test_param in supported_tests:
    if 'constructor' not in test_param:
@@ -308,9 +354,14 @@ for test_param in supported_tests:
        raise RuntimeError('Found two tests with the same name: ' + test_name)
    if decorator is not None:
        fn = decorator(fn)
    setattr(TestExpandedWeightModule, test_name, lambda self, test=test: test.test_context_manager(self))
    setattr(TestExpandedWeightModule, test_name_multi_input,
            lambda self, test=test: test.test_context_manager_multiple_inputs(self))
    if test.test_cpu:
        setattr(TestExpandedWeightModule, test_name, lambda self, test=test: test.test_context_manager(self, 'cpu'))
        setattr(TestExpandedWeightModule, test_name_multi_input,
                lambda self, test=test: test.test_context_manager_multiple_inputs(self, 'cpu'))
    if TEST_CUDA and test.test_cuda:
        # since this checks derivatives, only use double for precision
        setattr(TestExpandedWeightModule, test_name + '_cuda_double',
                lambda self, test=test: test.test_context_manager(self, 'cuda'))

# ------------- HELPER FUNCTIONS -----------------
@@ -340,12 +391,13 @@ def supported_inputs(op, sample_inputs, supported_inputs=True):
    operations that would cause inter-batch operations. Removes all of the cases it cannot deal with
    """
    def filter_fn(input):
        convolutions = ["nn.functional.conv1d", "nn.functional.conv2d", "nn.functional.conv3d"]
        if op.name == "nn.functional.linear":
            is_supported_input = len(input.input.shape) > 1  # input of rank 1 means no batch dim
        elif op.name == "nn.functional.layer_norm":
            normalized_shape = input.args[0]
            is_supported_input = input.input.shape != normalized_shape  # would cause inter-batch operations
        elif op.name == "nn.functional.conv2d":
        elif op.name in convolutions:
            # currently can't deal with padding computation on Python level
            is_supported_input = 'padding' not in input.kwargs or not isinstance(input.kwargs['padding'], str)
        elif op.name == "nn.functional.embedding":
@@ -1,3 +1,4 @@
from .conv_expanded_weights import ConvPerSampleGrad
from .linear_expanded_weights import LinearPerSampleGrad
from .expanded_weights_impl import ExpandedWeight
torch/nn/utils/_expanded_weights/conv_expanded_weights.py (new file, 37 lines)
@@ -0,0 +1,37 @@
import torch
import torch.nn.functional as F

from .conv_utils import conv_backward, conv_args_and_kwargs
from .expanded_weights_impl import ExpandedWeight, implements_per_sample_grads
from .expanded_weights_utils import forward_helper

@implements_per_sample_grads(F.conv1d)
@implements_per_sample_grads(F.conv2d)
@implements_per_sample_grads(F.conv3d)
class ConvPerSampleGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, kwarg_names, conv_fn, *expanded_args_and_kwargs):
        if any([isinstance(i, str) for i in expanded_args_and_kwargs]):
            raise RuntimeError("Expanded Weights does not support convolution padding as a string. "
                               "Please file an issue to prioritize support")
        expanded_args, expanded_kwargs = conv_args_and_kwargs(kwarg_names, expanded_args_and_kwargs)
        output = forward_helper(conv_fn, expanded_args, expanded_kwargs)
        input, weight = expanded_args

        ctx.conv_fn = conv_fn

        ctx.batch_size = input.shape[0]
        ctx.input_required_grad = input.requires_grad
        ctx.stride, ctx.padding = expanded_kwargs['stride'], expanded_kwargs['padding']
        ctx.dilation, ctx.groups = expanded_kwargs['dilation'], expanded_kwargs['groups']

        if isinstance(weight, ExpandedWeight):
            ctx.input = input
            ctx.weight = weight
            ctx.bias = expanded_kwargs['bias']

        return output

    @staticmethod
    def backward(ctx, grad_output):
        return conv_backward(ctx.conv_fn, ctx, grad_output)
torch/nn/utils/_expanded_weights/conv_utils.py (new file, 184 lines)
@@ -0,0 +1,184 @@
import torch
import torch.nn.functional as F

import numpy as np
from typing import List, Optional

from .expanded_weights_utils import \
    set_grad_sample_if_exists, unpack_expanded_weight_or_tensor

THRESHOLD = 32

def conv_picker(func, conv1dOpt, conv2dOpt, conv3dOpt):
    if func == F.conv1d:
        return conv1dOpt
    if func == F.conv2d:
        return conv2dOpt
    else:
        assert func == F.conv3d
        return conv3dOpt

def conv_args_and_kwargs(kwarg_names, expanded_args_and_kwargs):
    args = expanded_args_and_kwargs[:len(expanded_args_and_kwargs) - len(kwarg_names)]
    kwargs = expanded_args_and_kwargs[len(expanded_args_and_kwargs) - len(kwarg_names):]
    kwargs = {name: arg for (name, arg) in zip(kwarg_names, kwargs)}

    return conv_normalizer(*args, **kwargs)

def conv_normalizer(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    return (input, weight), {'bias': bias, 'stride': stride, 'padding': padding, 'dilation': dilation, 'groups': groups}
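
As a concrete illustration of the normalization step above: ExpandedWeight.__torch_function__ (see the hunk further down) appends a call's keyword-argument values onto the positional arguments and passes the names separately; conv_args_and_kwargs undoes that flattening and conv_normalizer fills in the conv defaults. A small sketch of the round trip, with purely illustrative tensors:

import torch
from torch.nn.utils._expanded_weights.conv_utils import conv_args_and_kwargs

x = torch.randn(8, 3, 16, 16)   # illustrative input
w = torch.randn(4, 3, 3, 3)     # illustrative weight

# F.conv2d(x, w, stride=2, padding=1) arrives as the positional args followed by the
# flattened kwarg values, with the kwarg names carried on the side:
kwarg_names = ('stride', 'padding')
flattened = (x, w, 2, 1)

args, kwargs = conv_args_and_kwargs(kwarg_names, flattened)
# args   -> (x, w)
# kwargs -> {'bias': None, 'stride': 2, 'padding': 1, 'dilation': 1, 'groups': 1}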
def conv_backward(func, ctx, grad_output):

    def weight_grad_sample(weight):
        if (batch_size < THRESHOLD and groups == 1):
            return conv_group_weight_grad_sample(ctx.input, grad_output, weight_shape, stride, padding, dilation, batch_size, func)
        else:
            return conv_unfold_weight_grad_sample(ctx.input, grad_output, weight_shape, kernel_size,
                                                  stride, padding, dilation, groups, func)

    def expand(param):
        if isinstance(param, int):
            return conv_picker(func, (param,), (param, param), (param, param, param))
        else:
            return param

    weight_shape = ctx.weight.shape
    stride, padding, dilation, groups = expand(ctx.stride), expand(ctx.padding), expand(ctx.dilation), ctx.groups

    kernel_size = []
    for i in range(2, conv_picker(func, 3, 4, 5)):
        kernel_size.append(weight_shape[i])

    batch_size = ctx.batch_size
    results: List[Optional[torch.Tensor]] = []
    results.append(None)  # for kwarg names
    results.append(None)  # for op reference

    if ctx.input_required_grad:
        output_padding = []
        input_dims = conv_picker(func, 1, 2, 3)
        for i in range(input_dims):
            input_dim = ctx.input.shape[2 + i]
            output_padding.append((2 * padding[i] + input_dim - (kernel_size[i] * dilation[i] - dilation[i] + 1)) % stride[i])
        weight_ = unpack_expanded_weight_or_tensor(ctx.weight)
        transpose_func = conv_picker(func, F.conv_transpose1d, F.conv_transpose2d, F.conv_transpose3d)
        results.append(transpose_func(grad_output, weight_, None, stride, padding, tuple(output_padding), groups, dilation))
    else:
        results.append(None)
    # weight and bias don't compute batched gradients; no other arguments are differentiable
    results = results + [None] * 6

    # set grad_sample field for weight and bias with per sample gradients
    set_grad_sample_if_exists(ctx.weight, weight_grad_sample)
    set_grad_sample_if_exists(ctx.bias, lambda _: grad_output.reshape(*grad_output.shape[:2], -1).sum(dim=2))
    return tuple(results)
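
One detail that is easy to miss when reading conv_backward: the returned tuple must line up positionally with the inputs of ConvPerSampleGrad.forward, while the per-sample weight and bias gradients are delivered out of band through .grad_sample rather than through this tuple. The annotation below is my reading of the code above (assuming the standard F.convNd argument order), not code from the diff:

# ConvPerSampleGrad.forward(ctx, kwarg_names, conv_fn, input, weight, bias, stride, padding, dilation, groups)
# conv_backward returns one entry per forward input, in order:
#   results[0]   -> kwarg_names  (not a tensor: None)
#   results[1]   -> conv_fn      (the forwarded op reference: None)
#   results[2]   -> input        (gradient via the transposed convolution, or None)
#   results[3:9] -> weight, bias, stride, padding, dilation, groups (all None here;
#                   weight and bias instead receive .grad_sample via set_grad_sample_if_exists)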
def conv_unfold_weight_grad_sample(input, grad_output, weight_shape, kernel_size, stride, padding, dilation, groups, func):
    n = input.shape[0]
    in_channels = input.shape[1]

    unfold_func = conv_picker(
        func,
        lambda: F.unfold(input.unsqueeze(-2),
                         kernel_size=(1, kernel_size[0]),
                         dilation=(1, dilation[0]),
                         padding=(0, padding[0]),
                         stride=(1, stride[0])),
        lambda: F.unfold(input, kernel_size, dilation=dilation, padding=padding, stride=stride),
        lambda: unfold3d(input, kernel_size, dilation, padding, stride)
    )

    input = unfold_func()
    grad_output = grad_output.reshape(n, -1, input.shape[-1])

    # n=batch_sz; o=num_out_channels; p=(num_in_channels/groups)*kernel_sz
    weight_grad_sample = torch.einsum("noq,npq->nop", grad_output, input)
    # rearrange the above tensor and extract diagonals.
    weight_grad_sample = weight_grad_sample.view(
        n,
        groups,
        -1,
        groups,
        int(in_channels / groups),
        np.prod(kernel_size),
    )
    weight_grad_sample = torch.einsum("ngrg...->ngr...", weight_grad_sample).contiguous()
    shape = [n] + list(weight_shape)
    weight_grad_sample = weight_grad_sample.view(shape)
    return weight_grad_sample
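
The unfold path computes, for each sample, the product of the flattened grad_output with the unfolded input patches. For the ungrouped conv2d case the grouped einsum and diagonal extraction collapse to a single einsum; here is a self-contained numerical check of that identity (shapes and names are illustrative, not library code):

import torch
import torch.nn.functional as F

N, C_in, C_out, K, H, W = 4, 3, 5, 3, 8, 8
x = torch.randn(N, C_in, H, W, dtype=torch.double)
w = torch.randn(C_out, C_in, K, K, dtype=torch.double, requires_grad=True)
go = torch.randn(N, C_out, H - K + 1, W - K + 1, dtype=torch.double)  # upstream grad of the conv output

# Reference: one autograd call per sample.
expected = torch.stack([
    torch.autograd.grad(F.conv2d(x[i:i + 1], w), w, go[i:i + 1])[0]
    for i in range(N)
])

# Unfold path (groups == 1): einsum of grad_output against the unfolded input patches.
inp_unf = F.unfold(x, kernel_size=K)            # (N, C_in*K*K, L)
go_flat = go.reshape(N, C_out, -1)              # (N, C_out, L)
res = torch.einsum("noq,npq->nop", go_flat, inp_unf).view(N, C_out, C_in, K, K)

assert torch.allclose(res, expected)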
def conv_group_weight_grad_sample(input, grad_output, weight_shape, stride, padding, dilation, batch_size, func):
    I = input.shape[1]
    O = grad_output.shape[1]

    input_ = input.transpose(0, 1)
    grad_output_ = grad_output.view(grad_output.shape[0] * grad_output.shape[1], 1, *grad_output.shape[2:])

    weight_grad_sample = func(input_, grad_output_, None, stride=dilation, padding=padding, dilation=stride, groups=batch_size)
    input_dims = conv_picker(func, 3, 4, 5)
    for i in range(2, input_dims):
        weight_grad_sample = weight_grad_sample.narrow(i, 0, weight_shape[i])
    weight_grad_sample = weight_grad_sample.view(I, batch_size, O, *weight_grad_sample.shape[2:])
    weight_grad_sample = weight_grad_sample.movedim(0, 2)
    return weight_grad_sample
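
The second strategy avoids materializing unfolded patches by folding the batch into the group dimension: the input is viewed as C_in "images" with N channels, the per-sample grad_outputs become N*C_out single-channel filters, and one grouped convolution produces every per-sample weight gradient at once (the swap of stride and dilation in the call above reflects how those roles interchange when differentiating with respect to the weight). A self-contained check of the trick for the simplest conv2d case, stride 1, padding 0, dilation 1, groups 1 (illustrative, not library code):

import torch
import torch.nn.functional as F

N, C_in, C_out, K, H, W = 4, 3, 5, 3, 8, 8
x = torch.randn(N, C_in, H, W, dtype=torch.double)
w = torch.randn(C_out, C_in, K, K, dtype=torch.double, requires_grad=True)
go = torch.randn(N, C_out, H - K + 1, W - K + 1, dtype=torch.double)  # upstream grad of the conv output

# Reference: one autograd call per sample.
expected = torch.stack([
    torch.autograd.grad(F.conv2d(x[i:i + 1], w), w, go[i:i + 1])[0]
    for i in range(N)
])

# Groups trick: fold the batch into channels/filters and run one grouped conv.
x_ = x.transpose(0, 1)                               # (C_in, N, H, W)
go_ = go.reshape(N * C_out, 1, *go.shape[2:])        # (N*C_out, 1, H_out, W_out)
res = F.conv2d(x_, go_, groups=N)                    # (C_in, N*C_out, K, K)
res = res.view(C_in, N, C_out, K, K).movedim(0, 2)   # (N, C_out, C_in, K, K)

assert torch.allclose(res, expected)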
def unfold3d(
    tensor,
    kernel_size,
    padding,
    stride,
    dilation,
):
    r"""
    Extracts sliding local blocks from an batched input tensor.
    :class:`torch.nn.Unfold` only supports 4D inputs (batched image-like tensors).
    This method implements the same action for 5D inputs
    Args:
        tensor: An input tensor of shape ``(B, C, D, H, W)``.
        kernel_size: the size of the sliding blocks
        padding: implicit zero padding to be added on both sides of input
        stride: the stride of the sliding blocks in the input spatial dimensions
        dilation: the spacing between the kernel points.
    Returns:
        A tensor of shape ``(B, C * np.product(kernel_size), L)``, where L - output spatial dimensions.
        See :class:`torch.nn.Unfold` for more details
    Example:
        >>> B, C, D, H, W = 3, 4, 5, 6, 7
        >>> tensor = torch.arange(1, B*C*D*H*W + 1.).view(B, C, D, H, W)
        >>> unfold3d(tensor, kernel_size=2, padding=0, stride=1).shape
        torch.Size([3, 32, 120])
    """

    if len(tensor.shape) != 5:
        raise ValueError(
            f"Input tensor must be of the shape [B, C, D, H, W]. Got{tensor.shape}"
        )

    if dilation != (1, 1, 1):
        raise NotImplementedError(f"dilation={dilation} not supported.")

    batch_size, channels, _, _, _ = tensor.shape

    # Input shape: (B, C, D, H, W)
    tensor = F.pad(
        tensor, (padding[2], padding[2], padding[1], padding[1], padding[0], padding[0])
    )
    # Output shape: (B, C, D+2*padding[2], H+2*padding[1], W+2*padding[0])

    tensor = tensor.unfold(dimension=2, size=kernel_size[0], step=stride[0])
    tensor = tensor.unfold(dimension=3, size=kernel_size[1], step=stride[1])
    tensor = tensor.unfold(dimension=4, size=kernel_size[2], step=stride[2])
    # Output shape: (B, C, D_out, H_out, W_out, kernel_size[0], kernel_size[1], kernel_size[2])
    # For D_out, H_out, W_out definitions see :class:`torch.nn.Unfold`

    tensor = tensor.permute(0, 2, 3, 4, 1, 5, 6, 7)
    # Output shape: (B, D_out, H_out, W_out, C, kernel_size[0], kernel_size[1], kernel_size[2])

    tensor = tensor.reshape(batch_size, -1, channels * np.prod(kernel_size)).transpose(
        1, 2
    )
    # Output shape: (B, D_out * H_out * W_out, C * kernel_size[0] * kernel_size[1] * kernel_size[2]

    return tensor
@@ -45,7 +45,7 @@ class ExpandedWeight(torch.Tensor):
        if kwargs is None:
            kwargs = {}
        if func in cls.handled_functions:
            return cls.handled_functions[func].apply(tuple(kwargs.keys()), *(args + tuple(kwargs.values())))
            return cls.handled_functions[func].apply(tuple(kwargs.keys()), func, *(args + tuple(kwargs.values())))
        # We cannot use a fallback here because we do not know the batch dimension for any regular tensor inputs,
        # i.e. torch.add(torch.Tensor, ExpandedWeight)
        raise RuntimeError(f"Expanded Weights encountered but cannot handle function {func.__name__}")
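
The only change in this hunk is that the op itself (func) is now forwarded to the registered autograd.Function, which is what lets a single ConvPerSampleGrad serve conv1d, conv2d and conv3d. A rough sketch of that dispatch path follows; the entry point the tests actually rely on is call_for_per_sample_grads, so constructing an ExpandedWeight by hand here is only to make the flow visible, and the shapes are illustrative:

import torch
import torch.nn.functional as F
from torch.nn.utils._expanded_weights.expanded_weights_impl import ExpandedWeight

batch_size = 8
x = torch.randn(batch_size, 3, 16, 16, requires_grad=True)  # illustrative shapes
w = torch.randn(4, 3, 3, 3, requires_grad=True)

# Wrapping the weight routes F.conv2d through ExpandedWeight.__torch_function__, which now calls
# ConvPerSampleGrad.apply(('padding',), F.conv2d, x, expanded_w, 1) via handled_functions[F.conv2d].
expanded_w = ExpandedWeight(w, batch_size)
out = F.conv2d(x, expanded_w, padding=1)
out.sum().backward()

# conv_backward populated the per-sample gradient on the wrapped weight.
assert w.grad_sample.shape == (batch_size, 4, 3, 3, 3)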
@@ -75,6 +75,8 @@ def unpack_expanded_weight_or_tensor(maybe_expanded_weight, func=lambda x: x):
        raise RuntimeError("ExpandedWeights currently does not support a mixture of ExpandedWeight parameters "
                           "and normal Parameters. Please file and issue with pytorch/pytorch")


def sum_over_all_but_batch_and_last_n(
    tensor: torch.Tensor, n_dims: int
) -> torch.Tensor:
@@ -8,7 +8,7 @@ from typing import List, Optional
@implements_per_sample_grads(F.linear)
class LinearPerSampleGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, _, *expanded_args_and_kwargs):
    def forward(ctx, _, __, *expanded_args_and_kwargs):
        if len(expanded_args_and_kwargs[0].shape) <= 1:
            raise RuntimeError("Input does not have a batch dimension. Expanded Weights expected input "
                               f"of at least rank 2, got of rank {len(expanded_args_and_kwargs[0].shape)}")
@@ -25,6 +25,7 @@ class LinearPerSampleGrad(torch.autograd.Function):
        bias = ctx.kwargs['bias']
        results: List[Optional[torch.Tensor]] = []
        results.append(None)  # for kwarg_names
        results.append(None)  # for op reference

        if input.requires_grad:
            results.append(grad_output.matmul(unpack_expanded_weight_or_tensor(weight)))
@@ -10999,6 +10999,7 @@ op_db: List[OpInfo] = [
                   # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch.
                   DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
               ),
           supports_expanded_weight=True,
           supports_out=False,),
    OpInfo('nn.functional.conv2d',
           aliases=('conv2d',),
@@ -11015,6 +11016,7 @@ op_db: List[OpInfo] = [
                   # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch.
                   DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
               ),
           supports_expanded_weight=True,
           supports_out=False,),
    OpInfo('nn.functional.group_norm',
           aten_name='group_norm',