mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Add backward check for test_memory_format (#106104)
Add backward check for test_memory_format.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106104
Approved by: https://github.com/mikaylagawarecki
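In rough terms, the new check computes reference gradients in the (contiguous, contiguous) configuration with a fixed set of grad_outputs, then re-runs the backward pass for each (input memory format, module memory format) combination and asserts that the gradients match. The snippet below is a minimal standalone sketch of that pattern, not the test code itself; the Conv2d module, shapes, and default tolerances are illustrative assumptions.

import torch

torch.manual_seed(0)
m = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
x = torch.randn(2, 3, 16, 16, requires_grad=True)

# Reference: contiguous input, contiguous module, fixed grad_outputs.
ref_out = m(x)
ref_grad_out = torch.rand_like(ref_out)
ref_grads = torch.autograd.grad(ref_out, (x, *m.parameters()), grad_outputs=ref_grad_out)

# Converted: channels_last input (cloned and detached so it is a fresh leaf that
# torch.autograd.grad can differentiate with respect to) and channels_last module.
x_cl = x.clone().to(memory_format=torch.channels_last).detach().requires_grad_(True)
m.to(memory_format=torch.channels_last)
out = m(x_cl)
grads = torch.autograd.grad(
    out, (x_cl, *m.parameters()),
    grad_outputs=torch.empty_like(out).copy_(ref_grad_out),
)

# Gradients should agree regardless of memory format (tolerances may need
# loosening on some backends, as the test does for sm86/sm80).
for g, ref_g in zip(grads, ref_grads):
    torch.testing.assert_close(g, ref_g)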
@@ -610,7 +610,7 @@ class TestModule(TestCase):
         atol, rtol = (3e-3, 7e-3) if is_sm86or80 else (None, None)
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
-                                                       requires_grad=False, training=training)
+                                                       requires_grad=True, training=training)
         module_memformat_affects_out = module_info.module_memformat_affects_out

         def _get_mem_formats(channels_last=False, channels_last_3d=False):
@@ -639,8 +639,13 @@ class TestModule(TestCase):
                 d = obj.dim()
                 if ((mem_format == torch.channels_last and d != 4)
                         or (mem_format == torch.channels_last_3d and d != 5)):
-                    return obj
-                return obj.to(memory_format=mem_format)
+                    return obj.clone().detach().requires_grad_(obj.requires_grad)
+                return (
+                    obj.clone()
+                    .to(memory_format=mem_format)
+                    .detach()
+                    .requires_grad_(obj.requires_grad)
+                )

             return self._traverse_obj(obj, inner_to_mem_format)

@@ -657,6 +662,9 @@ class TestModule(TestCase):
                     self.assertTrue(output.is_contiguous())
             return self._traverse_obj(output, inner_check_out_mem_format)

+        def _req_grad(t):
+            return isinstance(t, torch.Tensor) and t.requires_grad
+
         for module_input in module_inputs:
             if module_input.forward_input is None:
                 continue
@@ -676,6 +684,24 @@ class TestModule(TestCase):
             # === Get output in (contiguous, contiguous) configuration. ===
             args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
             desired_outputs = m(*args, **kwargs)
+            # === Do backward pass. ===
+            ref_diff_outputs = tuple(t for t in torch.utils._pytree.tree_flatten(desired_outputs)[0] if _req_grad(t))
+            if training and len(ref_diff_outputs) > 0:
+                params = tuple(p for p in m.parameters())
+                ref_diff_inputs = tuple(
+                    t
+                    for t in torch.utils._pytree.tree_flatten((args, kwargs, params))[0]
+                    if _req_grad(t)
+                )
+                ref_grad_outputs = tuple(
+                    torch.rand_like(t)
+                    for t in ref_diff_outputs
+                )
+                ref_grad_inputs = torch.autograd.grad(
+                    ref_diff_outputs,
+                    ref_diff_inputs,
+                    grad_outputs=ref_grad_outputs,
+                )

             for input_mem_format in input_mem_formats:
                 # === Change memformat of input. ===
@@ -693,12 +719,43 @@ class TestModule(TestCase):
                     outputs = m(*args, **kwargs)

                     # === Compare outputs to (contiguous, contiguous) output. ===
-                    if input_mem_format != torch.contiguous_format or module_mem_formats != torch.contiguous_format:
+                    if input_mem_format != torch.contiguous_format or module_mem_format != torch.contiguous_format:
                         self.assertEqual(outputs, desired_outputs, rtol=rtol, atol=atol)

                     # === Check mem format of output. ===
                     _check_out_mem_format(outputs, input_mem_format, module_mem_format)

+                    # === Do backward pass. ===
+                    diff_outputs = tuple(t for t in torch.utils._pytree.tree_flatten(outputs)[0] if _req_grad(t))
+                    if training and len(diff_outputs) > 0:
+                        params = tuple(p for p in m.parameters())
+                        diff_inputs = tuple(
+                            t
+                            for t in torch.utils._pytree.tree_flatten((args, kwargs, params))[0]
+                            if _req_grad(t)
+                        )
+                        grad_outputs = tuple(
+                            torch.empty_like(t1).copy_(t2)
+                            for (t1, t2) in zip(diff_outputs, ref_grad_outputs)
+                        )
+
+                        grad_inputs = torch.autograd.grad(
+                            diff_outputs,
+                            diff_inputs,
+                            grad_outputs=grad_outputs,
+                        )
+
+                        if (
+                            input_mem_format != torch.contiguous_format
+                            or module_mem_format != torch.contiguous_format
+                        ):
+                            self.assertEqual(
+                                grad_inputs, ref_grad_inputs, rtol=rtol, atol=atol
+                            )
+
+                        # === Check mem format of grad_inputs. ===
+                        _check_out_mem_format(grad_inputs, input_mem_format, module_mem_format)
+
     # Test whether train and eval modes differ for each module. Use to verify
     # that the ModuleInfo entry flag is correct.
     @modules(module_db, train_eval_mode=TrainEvalMode.train_only)