expanded weights: embedding faster rule

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73693

Approved by: https://github.com/zou3519
Author: samdow
Date: 2022-03-29 17:20:55 +00:00
Committed by: PyTorch MergeBot
Parent: c074a53002
Commit: fc47257b30
4 changed files with 67 additions and 3 deletions
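
Expanded weights are PyTorch's prototype mechanism for per-sample gradients: instead of one gradient summed over the batch, each sample gets its own gradient for a module's weight. This commit adds a faster rule for embedding; the diff below shows only the test changes. As a rough sketch of the kind of computation such a rule performs (illustrative only, not the code this commit adds; per_sample_embedding_grad is a hypothetical helper), the per-sample gradient of an embedding table can be formed by scatter-adding each token's output gradient into that sample's slice of a zero-initialized (batch, vocab, dim) tensor:

    import torch

    # Hypothetical helper (not part of this PR): per-sample gradient of an
    # embedding table, computed without materializing one weight per sample.
    def per_sample_embedding_grad(indices, grad_output, num_embeddings):
        # indices: (B, T) integer token ids
        # grad_output: (B, T, D) gradient flowing into the embedding output
        B, T = indices.shape
        D = grad_output.shape[-1]
        grad_weight = grad_output.new_zeros(B, num_embeddings, D)
        # route each token's gradient row to its id's row, separately per sample
        idx = indices.unsqueeze(-1).expand(B, T, D)
        grad_weight.scatter_add_(1, idx, grad_output)
        return grad_weight  # (B, num_embeddings, D)

This produces the same result a per-sample loop would, but in a single scatter kernel.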


@@ -187,10 +187,16 @@ class TestExpandedWeightFunctional(TestCase):
     def test_expanded_weight_forward(self, device, dtype, op):
         sample_inputs = op.sample_inputs(device, dtype)
         for sample_input in supported_inputs(op, sample_inputs):
+            if op.name == "nn.functional.embedding":  # embedding flips its argument order for autograd tests
+                sample_input = SampleInput(sample_input.args[0].clone(),
+                                           args=(sample_input.input.clone(),),
+                                           kwargs=sample_input.kwargs)
+            if "cuda" in device and "max_norm" in sample_input.kwargs and "padding_idx" in sample_input.kwargs:
+                self.skipTest("embedding is non-deterministic in this case, see issue #74679")
             batch_size = sample_input.input.shape[0] if len(sample_input.input.shape) > 1 else 1
             (ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size)
-            expanded_weight_result = op(ew_input, *ew_args, **ew_kwargs)
-            normal_result = op(sample_input.input, *sample_input.args, **sample_input.kwargs)
+            expanded_weight_result = run_op(op, ew_input, *ew_args, **ew_kwargs)
+            normal_result = run_op(op, sample_input.input, *sample_input.args, **sample_input.kwargs)
             self.assertEqual(expanded_weight_result, normal_result)

     def test_expanded_weight_error(self, device):
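
The argument swap in this hunk compensates for how the OpInfo samples for nn.functional.embedding are laid out: the weight tensor is stored as sample_input.input so that autograd tests differentiate with respect to the table, while the functional itself takes the indices first. For reference, the real call order (which the test restores) is:

    import torch
    import torch.nn.functional as F

    weight = torch.randn(10, 3, requires_grad=True)  # the differentiable table
    indices = torch.tensor([[1, 2, 4], [4, 3, 2]])   # integer ids, not differentiable
    out = F.embedding(indices, weight)               # signature: embedding(input, weight, ...)

The skipped CUDA case reflects issue #74679: embedding with max_norm renormalizes the weight in place, and combined with padding_idx the result is nondeterministic on CUDA.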
@@ -320,6 +326,8 @@ class ContextManagerTests(TestBase):
     def test_context_manager(self, test_case, device):
         kwargs = {'device': device, 'dtype': torch.double}
         module = self.constructor(*self.constructor_args).to(**kwargs)
+        if 'Embedding' in self.get_name():
+            kwargs['dtype'] = torch.long
         input = self._get_input().to(**kwargs)
         if len(input.shape) == 0 or input.shape[0] == 0:
             raise unittest.SkipTest("Can't get per sample gradients when no batch dim or batch dim is 0")
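
The dtype override exists because Embedding, unlike the Linear and Conv modules this harness previously covered, takes integer indices rather than floating-point features; casting its input to torch.double would break the lookup. A minimal illustration:

    import torch
    from torch import nn

    emb = nn.Embedding(10, 3).to(dtype=torch.double)  # weight dtype becomes double
    idx = torch.tensor([[1, 2], [3, 4]])              # input must stay an integer dtype
    out = emb(idx)                                    # shape (2, 2, 3), dtype torch.double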
@@ -338,7 +346,7 @@ class ContextManagerTests(TestBase):

 # TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
 # These currently use the legacy nn tests
-supported_modules = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d']
+supported_modules = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'Embedding']
 supported_tests = [t for t in module_tests + new_module_tests if 'module_name' in t and t['module_name'] in supported_modules]
 for test_param in supported_tests:
     if 'constructor' not in test_param:
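
Adding 'Embedding' to supported_modules pulls the legacy nn Embedding test cases into the per-sample-gradient context-manager tests. What those tests ultimately have to agree with is the naive per-sample computation; a hypothetical reference (not a helper from this PR) looks like:

    import torch
    from torch import nn

    def per_sample_grads_loop(module, inp):
        # slow but unambiguous reference: one backward pass per sample
        grads = []
        for i in range(inp.shape[0]):
            loss = module(inp[i:i + 1]).sum()
            grads.append(torch.autograd.grad(loss, module.weight)[0])
        return torch.stack(grads)

    emb = nn.Embedding(10, 3)
    idx = torch.randint(0, 10, (4, 5))     # batch of 4 samples, 5 tokens each
    ref = per_sample_grads_loop(emb, idx)  # (4, 10, 3): one table gradient per sample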