[nn] mha : no-batch-dim support (python) (#67176)
Summary:
Reference: https://github.com/pytorch/pytorch/issues/60585

* [x] Update docs
* [x] Tests for shape checking

Tests take roughly 20s on the system that I use. Below are the timings for the slowest 20 tests.

```
pytest test/test_modules.py -k _multih --durations=20
=========================== test session starts ===========================
platform linux -- Python 3.10.0, pytest-6.2.5, py-1.10.0, pluggy-1.0.0
rootdir: /home/kshiteej/Pytorch/pytorch_no_batch_mha, configfile: pytest.ini
plugins: hypothesis-6.23.2, repeat-0.9.1
collected 372 items / 336 deselected / 36 selected

test/test_modules.py ..............ssssssss..............            [100%]

============================ warnings summary =============================
../../.conda/envs/pytorch-cuda-dev/lib/python3.10/site-packages/torch/backends/cudnn/__init__.py:73
test/test_modules.py::TestModuleCUDA::test_factory_kwargs_nn_MultiheadAttention_cuda_float32
  /home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.10/site-packages/torch/backends/cudnn/__init__.py:73: UserWarning: PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild PyTorch making sure the library is visible to the build system.
    warnings.warn(
-- Docs: https://docs.pytest.org/en/stable/warnings.html
=========================== slowest 20 durations ==========================
8.66s call  test/test_modules.py::TestModuleCUDA::test_gradgrad_nn_MultiheadAttention_cuda_float64
2.02s call  test/test_modules.py::TestModuleCPU::test_gradgrad_nn_MultiheadAttention_cpu_float64
1.89s call  test/test_modules.py::TestModuleCUDA::test_grad_nn_MultiheadAttention_cuda_float64
1.01s call  test/test_modules.py::TestModuleCUDA::test_factory_kwargs_nn_MultiheadAttention_cuda_float32
0.51s call  test/test_modules.py::TestModuleCPU::test_grad_nn_MultiheadAttention_cpu_float64
0.46s call  test/test_modules.py::TestModuleCUDA::test_forward_nn_MultiheadAttention_cuda_float32
0.45s call  test/test_modules.py::TestModuleCUDA::test_non_contiguous_tensors_nn_MultiheadAttention_cuda_float64
0.44s call  test/test_modules.py::TestModuleCUDA::test_non_contiguous_tensors_nn_MultiheadAttention_cuda_float32
0.21s call  test/test_modules.py::TestModuleCUDA::test_pickle_nn_MultiheadAttention_cuda_float64
0.21s call  test/test_modules.py::TestModuleCUDA::test_pickle_nn_MultiheadAttention_cuda_float32
0.18s call  test/test_modules.py::TestModuleCUDA::test_forward_nn_MultiheadAttention_cuda_float64
0.17s call  test/test_modules.py::TestModuleCPU::test_non_contiguous_tensors_nn_MultiheadAttention_cpu_float32
0.16s call  test/test_modules.py::TestModuleCPU::test_non_contiguous_tensors_nn_MultiheadAttention_cpu_float64
0.11s call  test/test_modules.py::TestModuleCUDA::test_factory_kwargs_nn_MultiheadAttention_cuda_float64
0.08s call  test/test_modules.py::TestModuleCPU::test_pickle_nn_MultiheadAttention_cpu_float32
0.08s call  test/test_modules.py::TestModuleCPU::test_pickle_nn_MultiheadAttention_cpu_float64
0.06s call  test/test_modules.py::TestModuleCUDA::test_repr_nn_MultiheadAttention_cuda_float64
0.06s call  test/test_modules.py::TestModuleCUDA::test_repr_nn_MultiheadAttention_cuda_float32
0.06s call  test/test_modules.py::TestModuleCPU::test_forward_nn_MultiheadAttention_cpu_float32
0.06s call  test/test_modules.py::TestModuleCPU::test_forward_nn_MultiheadAttention_cpu_float64
========================= short test summary info =========================
========= 28 passed, 8 skipped, 336 deselected, 2 warnings in 19.71s =========
```

cc albanD mruberry jbschlosser walterddr

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67176
Reviewed By: dagitses
Differential Revision: D33094285
Pulled By: jbschlosser
fbshipit-source-id: 0dd08261b8a457bf8bad5c7f3f6ded14b0beaf0d
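For context, a minimal sketch of the unbatched call pattern this change enables. The module and its signature are the real torch.nn API; the sizes are illustrative:

```python
import torch
import torch.nn as nn

# With this PR, nn.MultiheadAttention accepts unbatched inputs:
# query is (L, E) and key/value are (S, E) -- no batch dimension.
embed_dim, num_heads = 8, 2
mha = nn.MultiheadAttention(embed_dim, num_heads)

query = torch.randn(3, embed_dim)   # (L, E)
key = torch.randn(5, embed_dim)     # (S, E)
value = torch.randn(5, embed_dim)   # (S, E)

attn_output, attn_weights = mha(query, key, value)
print(attn_output.shape)   # torch.Size([3, 8]) -- output is unbatched too
print(attn_weights.shape)  # torch.Size([3, 5]) -- weights averaged over heads
```

Note that the forward pass returns a tuple, `(attn_output, attn_weights)`, which is what drives the test-harness changes in the diff below.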
committed by Facebook GitHub Bot
parent 37ec99c0e4
commit e8d5c7cf7f
```diff
--- a/test/test_modules.py
+++ b/test/test_modules.py
@@ -359,8 +359,13 @@ class TestModule(TestCase):
 
             # === Forward with default input
             with freeze_rng_state():
                 default_output = m(*input_args, **input_kwargs)
-                grad_output = default_output.clone().detach_().normal_()
-                default_output.backward(grad_output, retain_graph=True)
+                if isinstance(default_output, torch.Tensor):
+                    grad_output = default_output.clone().detach_().normal_()
+                    default_output.backward(grad_output, retain_graph=True)
+                else:
+                    grad_output = tuple(o.clone().detach_().normal_() for o in default_output)
+                    for o, g_o in zip(default_output, grad_output):
+                        o.backward(g_o, retain_graph=True)
 
             default_input_args_grad, default_input_kwargs_grad = deepcopy(self._get_grads((input_args, input_kwargs)))
             default_param_grad = deepcopy([p.grad for p in m.parameters()])
@@ -380,7 +385,11 @@ class TestModule(TestCase):
 
                 with freeze_rng_state():
                     out = m(*in_args, **in_kwargs)
-                    out.backward(g_out_copy, retain_graph=True)
+                    if isinstance(out, torch.Tensor):
+                        out.backward(g_out_copy, retain_graph=True)
+                    else:
+                        for o, g_o in zip(out, g_out_copy):
+                            o.backward(g_o, retain_graph=True)
 
                 input_args_grad, input_kwargs_grad = self._get_grads((in_args, in_kwargs))
                 self.assertEqual(out, default_output)
@@ -483,13 +492,13 @@ class TestModule(TestCase):
             gpu_p.data.copy_(cpu_p)
 
         # === Compare forward output between cpu and gpu ===
-        cpu_output = cpu_module(*cpu_forward_args, **cpu_forward_kwargs)
-        gpu_output = gpu_module(*gpu_forward_args, **gpu_forward_kwargs)
+        cpu_outputs = cpu_module(*cpu_forward_args, **cpu_forward_kwargs)
+        gpu_outputs = gpu_module(*gpu_forward_args, **gpu_forward_kwargs)
 
-        self.assertEqual(cpu_output, gpu_output)
+        self.assertEqual(cpu_outputs, gpu_outputs)
 
         # === Run backwards on CPU and GPU and compare results ===
-        for _ in range(5):
+        def check_backward(cpu_output, gpu_output):
            cpu_grad_output = cpu_output.clone().normal_()
            gpu_grad_output = cpu_grad_output.type_as(gpu_output)
 
@@ -507,6 +516,13 @@ class TestModule(TestCase):
            gpu_grad_kwarg_input = self._get_grads(gpu_forward_kwargs)
            self.assertEqual(cpu_grad_kwarg_input, gpu_grad_kwarg_input)
 
+        for _ in range(5):
+            if isinstance(cpu_outputs, torch.Tensor):
+                check_backward(cpu_outputs, gpu_outputs)
+            else:
+                for cpu_output, gpu_output in zip(cpu_outputs, gpu_outputs):
+                    check_backward(cpu_output, gpu_output)
+
 
 instantiate_device_type_tests(TestModule, globals())
```
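The common thread in the hunks above: a module's forward may now return a tuple of Tensors rather than a single Tensor (as `nn.MultiheadAttention` does), so the generic module tests seed a random `grad_output` per tensor and branch on the output type before calling backward. A standalone sketch of that pattern; `backward_on_output` is a hypothetical name, not a helper in the test suite:

```python
import torch

def backward_on_output(output):
    # Hypothetical helper mirroring the branch added in test_modules.py:
    # seed a random grad_output for each tensor, then backprop through each.
    if isinstance(output, torch.Tensor):
        grad_output = output.clone().detach_().normal_()
        output.backward(grad_output, retain_graph=True)
    else:
        grad_output = tuple(o.clone().detach_().normal_() for o in output)
        for o, g_o in zip(output, grad_output):
            o.backward(g_o, retain_graph=True)

# e.g. for nn.MultiheadAttention the output is the (attn_output, attn_weights)
# tuple, so both tensors get a backward pass.
```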