expanded weights: group norm faster rule
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73696
Approved by: https://github.com/zou3519
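This PR adds a dedicated per-sample-gradient ("expanded weights") rule for group norm, so per-sample gradients can be computed in one forward/backward over the whole batch instead of one backward pass per example. For context, here is a minimal sketch of the naive per-example baseline that such a rule replaces, using only stable PyTorch APIs; the shapes and the scalar loss are illustrative, not taken from the PR:

```python
import torch
import torch.nn as nn

# Naive baseline: one backward pass per example to get per-sample gradients.
# An expanded-weights rule computes the same gradients in a single pass.
N, C, H, W = 4, 6, 8, 8
module = nn.GroupNorm(num_groups=3, num_channels=C)
inp = torch.randn(N, C, H, W)

per_sample_grads = {name: [] for name, _ in module.named_parameters()}
for i in range(N):
    module.zero_grad()
    out = module(inp[i:i + 1])   # forward on a single example
    out.sum().backward()         # illustrative scalar loss
    for name, param in module.named_parameters():
        per_sample_grads[name].append(param.grad.detach().clone())

# One gradient per batch element: weight -> (N, C), bias -> (N, C)
per_sample_grads = {k: torch.stack(v) for k, v in per_sample_grads.items()}
```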
@@ -6,6 +6,7 @@
 import unittest
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss
 from torch.nn.utils._per_sample_grad import call_for_per_sample_grads
 from torch.testing._internal.common_cuda import TEST_CUDA
@@ -247,6 +248,15 @@ class TestExpandedWeightFunctional(TestCase):
         for (res, exp) in zip(result, expected):
             self.assertEqual(res, exp, atol=1e-4, rtol=5e-5)
 
+    def test_group_norm_error(self, device):
+        # group norm has to call native_group_norm. This checks that it hits the same errors
+        # that normal group norm would
+
+        N = 3
+        C = 5
+        inp = torch.randn(N, C)
+        with self.assertRaisesRegex(RuntimeError, r"Expected number of channels in input to be divisible"):
+            F.group_norm(inp, 2)  # 5 is not divisible by 2
 
 class TestExpandedWeightModule(TestCase):
     def _do_test(self, module, input):
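The error exercised by this new test comes from group norm's shape contract: for input of shape (N, C, *), the number of channels C must be divisible by the number of groups. A small sketch of the valid and invalid cases, with illustrative values:

```python
import torch
import torch.nn.functional as F

inp = torch.randn(3, 6)   # N=3, C=6; group_norm accepts (N, C, *) inputs
F.group_norm(inp, 3)      # 6 % 3 == 0: valid
F.group_norm(inp, 6)      # 6 % 6 == 0: valid
try:
    F.group_norm(inp, 4)  # 6 % 4 != 0: RuntimeError, as the test expects
except RuntimeError as err:
    print(err)
```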
@@ -370,7 +380,7 @@ class ContextManagerTests(TestBase):
 
 # TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
 # These currently use the legacy nn tests
-supported_modules = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'Embedding', 'LayerNorm']
+supported_modules = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'Embedding', 'LayerNorm', 'GroupNorm']
 supported_tests = [t for t in module_tests + new_module_tests if 'module_name' in t and t['module_name'] in supported_modules]
 for test_param in supported_tests:
     if 'constructor' not in test_param:
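With 'GroupNorm' added to supported_modules, the legacy nn module tests now exercise group norm through call_for_per_sample_grads, the prototype entry point imported at the top of the file. The sketch below shows the shape of such a call; the (module, batch_size, input) call form and the grad_sample attribute are assumptions based on the prototype API of this era (the interface has changed in later releases), so treat this as illustrative rather than the current API:

```python
import torch
import torch.nn as nn
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

N, C = 4, 6
module = nn.GroupNorm(num_groups=3, num_channels=C)
inp = torch.randn(N, C, 8, 8)

# Assumed prototype call form: wraps the module's parameters as expanded
# weights so one backward pass yields a gradient per batch element.
out = call_for_per_sample_grads(module, N, inp)
out.sum().backward()

for name, param in module.named_parameters():
    # `grad_sample` is assumed to hold per-sample gradients, e.g. weight -> (4, 6)
    print(name, param.grad_sample.shape)
```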