expanded weights: group norm faster rule

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73696

Approved by: https://github.com/zou3519
Author: samdow
Date: 2022-03-31 15:25:07 +00:00
Committed by: PyTorch MergeBot
Parent: 8d7242a18b
Commit: 72f7193f4d
5 changed files with 112 additions and 24 deletions
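For context, a minimal sketch (not part of the diff) of the naive per-sample-gradient baseline that expanded weights replace for GroupNorm: one backward pass per sample. The module size, input shape, and loss below are illustrative assumptions only; the expanded-weights path itself is driven through `call_for_per_sample_grads`, imported in the test file below.

```python
import torch
import torch.nn as nn

# Illustrative shapes only: 4 samples, 6 channels split into 3 groups.
N, C = 4, 6
module = nn.GroupNorm(num_groups=3, num_channels=C)
inp = torch.randn(N, C, 5)

# Naive baseline: one forward/backward per sample to collect per-sample grads.
per_sample_weight_grads = []
for i in range(N):
    out = module(inp[i:i + 1])   # forward on a single sample
    out.sum().backward()         # backward for that sample alone
    per_sample_weight_grads.append(module.weight.grad.clone())
    module.zero_grad()

# per_sample_weight_grads[i] is d(loss_i)/d(weight); the expanded-weights
# rule produces the same quantities from a single batched backward pass.
```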


@@ -6,6 +6,7 @@ import unittest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss
 from torch.nn.utils._per_sample_grad import call_for_per_sample_grads
 from torch.testing._internal.common_cuda import TEST_CUDA
@@ -247,6 +248,15 @@ class TestExpandedWeightFunctional(TestCase):
         for (res, exp) in zip(result, expected):
             self.assertEqual(res, exp, atol=1e-4, rtol=5e-5)

+    def test_group_norm_error(self, device):
+        # group norm has to call native_group_norm. This checks that it hits the same errors
+        # that normal group norm would
+        N = 3
+        C = 5
+        inp = torch.randn(N, C)
+        with self.assertRaisesRegex(RuntimeError, r"Expected number of channels in input to be divisible"):
+            F.group_norm(inp, 2)  # 5 is not divisible by 2
+
 class TestExpandedWeightModule(TestCase):
     def _do_test(self, module, input):
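The new test leans on the validation that group norm already performs; a standalone illustration of that channel-divisibility check (shapes mirror the test, the surrounding script is an assumption):

```python
import torch
import torch.nn.functional as F

inp = torch.randn(3, 5)   # batch of 3, 5 channels

F.group_norm(inp, 5)      # fine: 5 channels, 5 groups
try:
    F.group_norm(inp, 2)  # 5 channels are not divisible by 2 groups
except RuntimeError as err:
    print(err)            # mentions "Expected number of channels in input to be divisible"
```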
@@ -370,7 +380,7 @@ class ContextManagerTests(TestBase):
 # TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
 # These currently use the legacy nn tests
-supported_modules = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'Embedding', 'LayerNorm']
+supported_modules = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'Embedding', 'LayerNorm', 'GroupNorm']
 supported_tests = [t for t in module_tests + new_module_tests if 'module_name' in t and t['module_name'] in supported_modules]
 for test_param in supported_tests:
     if 'constructor' not in test_param: