[nn] lstm : no batch dim support (#71056)

Summary:
Reference: https://github.com/pytorch/pytorch/issues/60585
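
For context, a minimal sketch of the unbatched usage this change enables for nn.LSTM (sizes below are illustrative, and assume a build that includes this PR):

    import torch
    import torch.nn as nn

    lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)

    x = torch.randn(5, 10)        # (seq_len, input_size) -- no batch dimension
    output, (h_n, c_n) = lstm(x)

    print(output.shape)           # torch.Size([5, 20]) -> (seq_len, hidden_size)
    print(h_n.shape)              # torch.Size([2, 20]) -> (num_layers, hidden_size)

Batched calls are unchanged; the module additionally accepts inputs without the leading batch dimension.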

TODO:
* [x] Update docs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71056

Reviewed By: samdow

Differential Revision: D33638643

Pulled By: jbschlosser

fbshipit-source-id: c0949829de8a8e6e7b2873f459a8d7da597a3be3
(cherry picked from commit f94d5849f66dd7da2ae4037b7c1d3e72817e926f)
Author: kshitij12345
Date: 2022-01-24 07:08:32 -08:00
Committed by: PyTorch MergeBot
Parent: 99d9883a22
Commit: b372be4211

3 changed files with 151 additions and 65 deletions

@@ -363,8 +363,11 @@ class TestModule(TestCase):
             grad_output = default_output.clone().detach_().normal_()
             default_output.backward(grad_output, retain_graph=True)
         else:
-            grad_output = tuple(o.clone().detach_().normal_() for o in default_output)
-            for o, g_o in zip(default_output, grad_output):
+            grad_output = tuple(self._traverse_obj(o, lambda o: o.clone().detach_().normal_())
+                                for o in default_output)
+            flattened_default_output, _ = torch.utils._pytree.tree_flatten(default_output)
+            flattened_grad_output, _ = torch.utils._pytree.tree_flatten(grad_output)
+            for o, g_o in zip(flattened_default_output, flattened_grad_output):
                 o.backward(g_o, retain_graph=True)

         default_input_args_grad, default_input_kwargs_grad = deepcopy(self._get_grads((input_args, input_kwargs)))
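
The tree_flatten calls above are what let this test handle modules like LSTM whose forward returns a nested tuple (output, (h_n, c_n)): zipping the top-level tuples would pair a tensor with an inner tuple, while flattening yields plain lists of tensors that can be zipped element-wise. A minimal sketch of that step (tensor shapes are illustrative):

    import torch
    from torch.utils._pytree import tree_flatten

    # Nested structure shaped like an LSTM forward result: (output, (h_n, c_n)).
    nested = (torch.randn(5, 20), (torch.randn(2, 20), torch.randn(2, 20)))

    leaves, spec = tree_flatten(nested)
    print(len(leaves))                # 3 tensors in a flat list
    print([t.shape for t in leaves])  # ready to zip against another flattened pytree
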
@@ -388,7 +391,9 @@ class TestModule(TestCase):
             if isinstance(out, torch.Tensor):
                 out.backward(g_out_copy, retain_graph=True)
             else:
-                for o, g_o in zip(out, g_out_copy):
+                flattened_out, _ = torch.utils._pytree.tree_flatten(out)
+                flattened_g_out_copy, _ = torch.utils._pytree.tree_flatten(g_out_copy)
+                for o, g_o in zip(flattened_out, flattened_g_out_copy):
                     o.backward(g_o, retain_graph=True)

             input_args_grad, input_kwargs_grad = self._get_grads((in_args, in_kwargs))
@@ -447,7 +452,9 @@ class TestModule(TestCase):
                 new_kwargs = {name: obj for (name, _), obj in zip(kwarg_tensors, kwarg_args)}
                 with freeze_rng_state():
-                    return m(*new_input_args, **new_kwargs, **other_kwargs)
+                    output = m(*new_input_args, **new_kwargs, **other_kwargs)
+                    output_flattened, _ = torch.utils._pytree.tree_flatten(output)
+                    return output_flattened

             self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
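
The fn_to_gradcheck change follows the same idea: gradcheck expects a Tensor or a flat sequence of Tensors from the function under test, so a nested module output is flattened before returning. A rough standalone sketch of the pattern (module, dtype, and sizes are illustrative, and the unbatched input assumes this PR's LSTM support):

    import torch
    import torch.nn as nn
    from torch.utils._pytree import tree_flatten

    lstm = nn.LSTM(input_size=4, hidden_size=3).double()
    x = torch.randn(5, 4, dtype=torch.double, requires_grad=True)

    def fn_to_gradcheck(inp):
        out = lstm(inp)                  # nested: (output, (h_n, c_n))
        flat_out, _ = tree_flatten(out)  # flat list of tensors
        return tuple(flat_out)           # a shape gradcheck can consume

    torch.autograd.gradcheck(fn_to_gradcheck, (x,))
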
@@ -531,7 +538,9 @@ class TestModule(TestCase):
         if isinstance(cpu_outputs, torch.Tensor):
             check_backward(cpu_outputs, gpu_outputs)
         else:
-            for cpu_output, gpu_output in zip(cpu_outputs, gpu_outputs):
+            flatten_cpu_outputs, _ = torch.utils._pytree.tree_flatten(cpu_outputs)
+            flatten_gpu_outputs, _ = torch.utils._pytree.tree_flatten(gpu_outputs)
+            for cpu_output, gpu_output in zip(flatten_cpu_outputs, flatten_gpu_outputs):
                 check_backward(cpu_output, gpu_output)