expanded weights: embedding faster rule
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73693
Approved by: https://github.com/zou3519
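For context on what this commit speeds up: expanded weights are PyTorch's mechanism for computing per-sample gradients, i.e. one weight gradient per batch element instead of a single summed gradient. A naive baseline for embedding loops over the batch and runs one backward pass per sample; the expanded-weights rule produces the same quantities in a single pass. A minimal sketch of that baseline, using only public autograd APIs (the sizes and names below are illustrative, not taken from this PR):

import torch
import torch.nn.functional as F

# Illustrative sizes, not from the test suite.
batch_size, seq_len, num_embeddings, embedding_dim = 4, 5, 10, 3
weight = torch.randn(num_embeddings, embedding_dim, requires_grad=True)
indices = torch.randint(num_embeddings, (batch_size, seq_len))

# Naive per-sample gradients: one forward/backward per batch element.
per_sample_grads = []
for i in range(batch_size):
    loss = F.embedding(indices[i], weight).sum()
    (g,) = torch.autograd.grad(loss, weight)
    per_sample_grads.append(g)

per_sample_grads = torch.stack(per_sample_grads)  # [batch, num_embeddings, embedding_dim]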
@@ -187,10 +187,16 @@ class TestExpandedWeightFunctional(TestCase):
     def test_expanded_weight_forward(self, device, dtype, op):
         sample_inputs = op.sample_inputs(device, dtype)
         for sample_input in supported_inputs(op, sample_inputs):
+            if op.name == "nn.functional.embedding":  # embedding flips its argument order for autograd tests
+                sample_input = SampleInput(sample_input.args[0].clone(),
+                                           args=(sample_input.input.clone(),),
+                                           kwargs=sample_input.kwargs)
+                if "cuda" in device and "max_norm" in sample_input.kwargs and "padding_idx" in sample_input.kwargs:
+                    self.skipTest("embedding is non-deterministic in this case, see issue #74679")
             batch_size = sample_input.input.shape[0] if len(sample_input.input.shape) > 1 else 1
             (ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size)
-            expanded_weight_result = op(ew_input, *ew_args, **ew_kwargs)
-            normal_result = op(sample_input.input, *sample_input.args, **sample_input.kwargs)
+            expanded_weight_result = run_op(op, ew_input, *ew_args, **ew_kwargs)
+            normal_result = run_op(op, sample_input.input, *sample_input.args, **sample_input.kwargs)
             self.assertEqual(expanded_weight_result, normal_result)

     def test_expanded_weight_error(self, device):
@@ -320,6 +326,8 @@ class ContextManagerTests(TestBase):
     def test_context_manager(self, test_case, device):
         kwargs = {'device': device, 'dtype': torch.double}
         module = self.constructor(*self.constructor_args).to(**kwargs)
+        if 'Embedding' in self.get_name():
+            kwargs['dtype'] = torch.long
         input = self._get_input().to(**kwargs)
         if len(input.shape) == 0 or input.shape[0] == 0:
             raise unittest.SkipTest("Can't get per sample gradients when no batch dim or batch dim is 0")
@@ -338,7 +346,7 @@ class ContextManagerTests(TestBase):

 # TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
 # These currently use the legacy nn tests
-supported_modules = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d']
+supported_modules = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'Embedding']
 supported_tests = [t for t in module_tests + new_module_tests if 'module_name' in t and t['module_name'] in supported_modules]
 for test_param in supported_tests:
     if 'constructor' not in test_param:
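A note on the in-test comment about argument order: OpInfo samples for nn.functional.embedding store the weight as `sample_input.input` so that gradient tests differentiate with respect to it, while the functional itself is called as `F.embedding(indices, weight)`; the test above swaps the two back before building expanded weights. A minimal illustration of the functional's actual calling convention (the tensors below are made up for illustration, not from the test suite):

import torch
import torch.nn.functional as F

weight = torch.randn(10, 3, requires_grad=True)  # num_embeddings x embedding_dim
indices = torch.tensor([[0, 2, 4], [1, 3, 5]])   # batch of index sequences

out = F.embedding(indices, weight)  # functional order: (indices, weight)
out.sum().backward()
print(weight.grad.shape)  # torch.Size([10, 3]): one shared gradient, not per-sample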