expanded weights: group norm faster rule
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73696
Approved by: https://github.com/zou3519

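For context (not part of the commit message): this adds a faster expanded-weights rule for group norm, so per-sample gradients for nn.GroupNorm can flow through the private call_for_per_sample_grads helper exercised in the test changes below. A hedged sketch of the intended usage, assuming call_for_per_sample_grads(module, batch_size, input) is the calling convention (the exact signature of this private API may differ by version):

import torch
import torch.nn as nn
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

batch_size, C = 4, 6                      # illustrative sizes, not from the PR
module = nn.GroupNorm(2, C)
inp = torch.randn(batch_size, C, 5, 5)

# Assumed convention: run the module through the helper, then backward()
# populates a per-sample .grad_sample field on each parameter.
out = call_for_per_sample_grads(module, batch_size, inp)
out.sum().backward()
print(module.weight.grad_sample.shape)    # expected (batch_size, C)
print(module.bias.grad_sample.shape)      # expected (batch_size, C)
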
@@ -16,6 +16,36 @@
namespace at {
namespace native {

void check_group_norm_inputs(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    int64_t C,
    int64_t num_groups) {
  TORCH_CHECK(
      C % num_groups == 0,
      "Expected number of channels in input to be divisible by ",
      "num_groups, but got input of shape ",
      input.sizes(),
      " and "
      "num_groups=",
      num_groups);
  TORCH_CHECK(
      !weight.defined() || (weight.dim() == 1 && weight.numel() == C),
      "Expected weight to be a vector of size equal to the number of ",
      "channels in input, but got weight of shape ",
      weight.sizes(),
      " and input of shape ",
      input.sizes());
  TORCH_CHECK(
      !bias.defined() || (bias.dim() == 1 && bias.numel() == C),
      "Expected bias to be a vector of size equal to the number of ",
      "channels in input, but got bias of shape ",
      bias.sizes(),
" and input of shape ",
|
||||
input.sizes());
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor> native_group_norm(
|
||||
const Tensor& X,
|
||||
const c10::optional<Tensor>& gamma_opt /* optional */,
|
||||
@ -31,6 +61,9 @@ std::tuple<Tensor, Tensor, Tensor> native_group_norm(
|
||||
const Tensor& gamma = *gamma_maybe_owned;
|
||||
const Tensor& beta = c10::value_or_else(beta_opt, [] { return Tensor(); });
|
||||
|
||||
// repeated check so expanded weights can call native_group_norm directly but
|
||||
// save mean and variance from forward
|
||||
check_group_norm_inputs(X, gamma, beta, C, group);
|
||||
auto memory_format = X.device().is_cpu() ?
|
||||
X.suggest_memory_format() : at::MemoryFormat::Contiguous;
|
||||
|
||||
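For orientation (not part of the commit): the comment above explains why the check is repeated — the expanded-weights rule calls native_group_norm directly so it can keep the per-group mean and rstd computed in the forward pass. A minimal Python sketch of that direct call, with illustrative sizes of my own choosing:

import torch

N, C, H, W = 3, 6, 4, 4      # illustrative sizes, not from the PR
num_groups, eps = 2, 1e-5
x = torch.randn(N, C, H, W)
weight, bias = torch.randn(C), torch.randn(C)

# native_group_norm takes the flattened spatial size HxW and returns
# (output, mean, rstd); mean and rstd have shape (N, num_groups).
out, mean, rstd = torch.native_group_norm(x, weight, bias, N, C, H * W, num_groups, eps)

# The output should match the public functional API.
assert torch.allclose(out, torch.nn.functional.group_norm(x, num_groups, weight, bias, eps), atol=1e-6)
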
@@ -128,28 +161,7 @@ Tensor group_norm(

  const int64_t N = input.size(0);
  const int64_t C = input.size(1);
  TORCH_CHECK(
      C % num_groups == 0,
      "Expected number of channels in input to be divisible by ",
      "num_groups, but got input of shape ",
      input.sizes(),
      " and "
      "num_groups=",
      num_groups);
  TORCH_CHECK(
      !weight.defined() || (weight.dim() == 1 && weight.numel() == C),
      "Expected weight to be a vector of size equal to the number of ",
      "channels in input, but got weight of shape ",
      weight.sizes(),
      " and input of shape ",
      input.sizes());
  TORCH_CHECK(
      !bias.defined() || (bias.dim() == 1 && bias.numel() == C),
      "Expected bias to be a vector of size equal to the number of ",
      "channels in input, but got bias of shape ",
      weight.sizes(),
      " and input of shape ",
      input.sizes());
  check_group_norm_inputs(input, weight, bias, C, num_groups);

  const auto input_shape = input.sizes();
  const int64_t HxW =

@@ -6,6 +6,7 @@ import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads
from torch.testing._internal.common_cuda import TEST_CUDA

@@ -247,6 +248,15 @@ class TestExpandedWeightFunctional(TestCase):
        for (res, exp) in zip(result, expected):
            self.assertEqual(res, exp, atol=1e-4, rtol=5e-5)

    def test_group_norm_error(self, device):
        # group norm has to call native_group_norm. This checks that it hits the same errors
        # that normal group norm would

        N = 3
        C = 5
        inp = torch.randn(N, C)
        with self.assertRaisesRegex(RuntimeError, r"Expected number of channels in input to be divisible"):
            F.group_norm(inp, 2)  # 5 is not divisible by 2

class TestExpandedWeightModule(TestCase):
    def _do_test(self, module, input):

@@ -370,7 +380,7 @@ class ContextManagerTests(TestBase):

# TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
# These currently use the legacy nn tests
supported_modules = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'Embedding', 'LayerNorm']
supported_modules = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'Embedding', 'LayerNorm', 'GroupNorm']
supported_tests = [t for t in module_tests + new_module_tests if 'module_name' in t and t['module_name'] in supported_modules]
for test_param in supported_tests:
    if 'constructor' not in test_param:

@@ -1,5 +1,6 @@
from .conv_expanded_weights import ConvPerSampleGrad
from .embedding_expanded_weights import EmbeddingPerSampleGrad
from .group_norm_expanded_weights import GroupNormPerSampleGrad
from .layer_norm_expanded_weights import LayerNormPerSampleGrad
from .linear_expanded_weights import LinearPerSampleGrad
from .expanded_weights_impl import ExpandedWeight

@@ -0,0 +1,64 @@
from functools import reduce
import operator
import torch
import torch.nn.functional as F
from .expanded_weights_impl import ExpandedWeight, implements_per_sample_grads
from .expanded_weights_utils import standard_kwargs, \
    forward_helper, set_grad_sample_if_exists, unpack_expanded_weight_or_tensor
from typing import List, Optional

@implements_per_sample_grads(F.group_norm)
class GroupNormPerSampleGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
        expanded_args, expanded_kwargs = standard_kwargs(kwarg_names, expanded_args_and_kwargs)
        input, num_groups = expanded_args
        N = input.shape[0]
        C = input.shape[1]
        HxW = reduce(operator.mul, input.shape[2:], 1)
        weight, bias, eps = expanded_kwargs['weight'], expanded_kwargs['bias'], expanded_kwargs['eps']
        output, mean, rstd = forward_helper(torch.native_group_norm, (input, weight, bias, N, C, HxW, num_groups, eps), {})
        ctx.input, ctx.num_groups = input, num_groups
        ctx.weight, ctx.eps = weight, eps
        ctx.mean, ctx.rstd = mean, rstd
        if isinstance(bias, ExpandedWeight):
            ctx.bias = bias
        if input.requires_grad and isinstance(weight, ExpandedWeight):
            ctx.weight = weight
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, num_groups = ctx.input, ctx.num_groups
        weight, bias, eps = ctx.weight, ctx.bias, ctx.eps
        mean, rstd = ctx.mean, ctx.rstd

        results: List[Optional[torch.Tensor]] = []
        results.append(None)  # for kwarg names
        results.append(None)  # for op reference

        if input.requires_grad:
            weight_c = unpack_expanded_weight_or_tensor(weight, lambda t: t.contiguous())
            input_c = input.contiguous()
            grad_output_c = grad_output.contiguous() if grad_output is not None else None
            N = input.shape[0]
            C = input.shape[1]
            HxW = 1
            for s in input.shape[2:]:
                HxW *= s
            bw_fn = torch.ops.aten.native_group_norm_backward
            results.append(bw_fn(grad_output_c, input_c,
                                 mean, rstd, weight_c, N, C, HxW, num_groups, (True, False, False))[0])
        else:
            results.append(None)

        # weight and bias don't compute batched gradients; no other arguments are differentiable
        results = results + [None] * 4

        # set grad_sample field for weight and bias with per sample gradients
        if hasattr(ctx, "weight"):
            set_grad_sample_if_exists(weight,
                                      lambda _: torch.einsum("ni...->ni", F.group_norm(input, num_groups, eps=eps) * grad_output))
        if hasattr(ctx, "bias"):
            set_grad_sample_if_exists(bias, lambda _: torch.einsum("ni...->ni", grad_output))
        return tuple(results)

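A note on the backward rule above (my own check, not part of the commit): the per-sample parameter gradients reduce to sums over the non-batch, non-channel dimensions, with the weight gradient using the affine-free normalized input. A small sketch verifying those einsum formulas against a per-sample autograd loop, using illustrative sizes:

import torch
import torch.nn.functional as F

N, C, H, W = 3, 6, 4, 4                   # illustrative sizes, not from the PR
num_groups, eps = 2, 1e-5
x = torch.randn(N, C, H, W)
weight = torch.randn(C, requires_grad=True)
bias = torch.randn(C, requires_grad=True)
grad_output = torch.randn(N, C, H, W)

# Closed-form per-sample grads, mirroring GroupNormPerSampleGrad.backward
norm = F.group_norm(x, num_groups, eps=eps)                          # normalized input, no affine
weight_grad_sample = torch.einsum("ni...->ni", norm * grad_output)   # shape (N, C)
bias_grad_sample = torch.einsum("ni...->ni", grad_output)            # shape (N, C)

# Reference: one autograd call per sample
for i in range(N):
    out_i = F.group_norm(x[i:i + 1], num_groups, weight, bias, eps)
    gw, gb = torch.autograd.grad(out_i, (weight, bias), grad_output[i:i + 1])
    assert torch.allclose(gw, weight_grad_sample[i], atol=1e-4)
    assert torch.allclose(gb, bias_grad_sample[i], atol=1e-4)
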
@@ -11233,7 +11233,8 @@ op_db: List[OpInfo] = [
               # Consider making it a parameter or input, or detaching the gradient
               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,))
           ],
           sample_inputs_func=sample_inputs_group_norm,),
           sample_inputs_func=sample_inputs_group_norm,
           supports_expanded_weight=True,),
    OpInfo('nn.functional.instance_norm',
           # no ref because instance_norm will often have numerical instability (large numbers or nan)
           dtypes=floating_types(),