[nn] lstm : no batch dim support (#71056)
Summary:
Reference: https://github.com/pytorch/pytorch/issues/60585

TODO:
* [x] Update docs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71056
Reviewed By: samdow
Differential Revision: D33638643
Pulled By: jbschlosser
fbshipit-source-id: c0949829de8a8e6e7b2873f459a8d7da597a3be3
(cherry picked from commit f94d5849f66dd7da2ae4037b7c1d3e72817e926f)
Committed by: PyTorch MergeBot
Parent: 99d9883a22
Commit: b372be4211
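What the change enables, as a minimal standalone sketch (the layer sizes here are illustrative; shapes follow the nn.LSTM docs for unbatched inputs):

    import torch
    import torch.nn as nn

    # With no-batch-dim support, LSTM accepts input of shape
    # (seq_len, input_size) instead of requiring (seq_len, batch, input_size).
    lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
    seq = torch.randn(5, 10)              # unbatched: (seq_len, input_size)
    output, (h_n, c_n) = lstm(seq)
    print(output.shape)                   # torch.Size([5, 20])
    print(h_n.shape)                      # torch.Size([2, 20])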
test/test_modules.py

@@ -363,8 +363,11 @@ class TestModule(TestCase):
                 grad_output = default_output.clone().detach_().normal_()
                 default_output.backward(grad_output, retain_graph=True)
             else:
-                grad_output = tuple(o.clone().detach_().normal_() for o in default_output)
-                for o, g_o in zip(default_output, grad_output):
+                grad_output = tuple(self._traverse_obj(o, lambda o: o.clone().detach_().normal_())
+                                    for o in default_output)
+                flattened_default_output, _ = torch.utils._pytree.tree_flatten(default_output)
+                flattened_grad_output, _ = torch.utils._pytree.tree_flatten(grad_output)
+                for o, g_o in zip(flattened_default_output, flattened_grad_output):
                     o.backward(g_o, retain_graph=True)
 
             default_input_args_grad, default_input_kwargs_grad = deepcopy(self._get_grads((input_args, input_kwargs)))
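Why zip over the top-level tuple no longer works: an LSTM-style output is a nested tuple, (output, (h_n, c_n)), so the old loop would pair a tensor with an inner tuple. A minimal standalone sketch of the tree_flatten behavior the new code relies on:

    import torch
    import torch.utils._pytree as pytree

    # Flattening a nested, LSTM-like output yields the tensor leaves
    # plus a TreeSpec describing the (x, (y, z)) structure.
    nested = (torch.randn(5, 20), (torch.randn(2, 20), torch.randn(2, 20)))
    leaves, spec = pytree.tree_flatten(nested)
    print(len(leaves))                    # 3 tensor leaves
    print(spec)                           # structure of (x, (y, z))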
@@ -388,7 +391,9 @@ class TestModule(TestCase):
             if isinstance(out, torch.Tensor):
                 out.backward(g_out_copy, retain_graph=True)
             else:
-                for o, g_o in zip(out, g_out_copy):
+                flattened_out, _ = torch.utils._pytree.tree_flatten(out)
+                flattened_g_out_copy, _ = torch.utils._pytree.tree_flatten(g_out_copy)
+                for o, g_o in zip(flattened_out, flattened_g_out_copy):
                     o.backward(g_o, retain_graph=True)
 
             input_args_grad, input_kwargs_grad = self._get_grads((in_args, in_kwargs))
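The same pattern drives the backward pass here: flatten both the outputs and the matching grad_outputs, then pair them leaf-by-leaf. A toy sketch with plain tensors in place of the test's module outputs:

    import torch
    import torch.utils._pytree as pytree

    # Pair grads with outputs leaf-by-leaf after flattening both trees.
    x = torch.randn(3, requires_grad=True)
    out = (x * 2, (x.sum(), x * 3))                     # nested output
    grads = pytree.tree_map(torch.ones_like, out)       # same structure
    flat_out, _ = pytree.tree_flatten(out)
    flat_grads, _ = pytree.tree_flatten(grads)
    for o, g in zip(flat_out, flat_grads):
        o.backward(g, retain_graph=True)
    print(x.grad)                                       # tensor([6., 6., 6.])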
@@ -447,7 +452,9 @@ class TestModule(TestCase):
                 new_kwargs = {name: obj for (name, _), obj in zip(kwarg_tensors, kwarg_args)}
 
                 with freeze_rng_state():
-                    return m(*new_input_args, **new_kwargs, **other_kwargs)
+                    output = m(*new_input_args, **new_kwargs, **other_kwargs)
+                    output_flattened, _ = torch.utils._pytree.tree_flatten(output)
+                    return output_flattened
 
             self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
 
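gradcheck expects the function under test to return a tensor or a flat tuple of tensors, which is why the wrapper now flattens the module's nested output before returning it. A minimal sketch with a toy function (the name fn_returning_flat is illustrative):

    import torch
    import torch.utils._pytree as pytree
    from torch.autograd import gradcheck

    # Flatten a nested output into the flat tuple gradcheck expects.
    def fn_returning_flat(x):
        out = (x * 2, (x.sin(), x.cos()))               # nested output
        flat, _ = pytree.tree_flatten(out)
        return tuple(flat)

    x = torch.randn(4, dtype=torch.double, requires_grad=True)
    assert gradcheck(fn_returning_flat, (x,))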
@@ -531,7 +538,9 @@ class TestModule(TestCase):
             if isinstance(cpu_outputs, torch.Tensor):
                 check_backward(cpu_outputs, gpu_outputs)
             else:
-                for cpu_output, gpu_output in zip(cpu_outputs, gpu_outputs):
+                flatten_cpu_outputs, _ = torch.utils._pytree.tree_flatten(cpu_outputs)
+                flatten_gpu_outputs, _ = torch.utils._pytree.tree_flatten(gpu_outputs)
+                for cpu_output, gpu_output in zip(flatten_cpu_outputs, flatten_gpu_outputs):
                     check_backward(cpu_output, gpu_output)
 
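The CPU/GPU parity check follows the same leaf-by-leaf walk. A hypothetical mini version of the comparison (compare_outputs is not part of the PR, and check_backward's internals are not shown here):

    import torch
    import torch.utils._pytree as pytree

    # Compare CPU and GPU results leaf-by-leaf rather than zipping
    # nested tuples directly.
    def compare_outputs(cpu_outputs, gpu_outputs):
        flat_cpu, _ = pytree.tree_flatten(cpu_outputs)
        flat_gpu, _ = pytree.tree_flatten(gpu_outputs)
        assert len(flat_cpu) == len(flat_gpu)
        for c, g in zip(flat_cpu, flat_gpu):
            torch.testing.assert_close(c, g.cpu())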