Revert "Add backward check for test_memory_format (#106104)"

This reverts commit 2e44adb06608d09a36b899ffdb375cb7d46a78d2.

Reverted https://github.com/pytorch/pytorch/pull/106104 on behalf of https://github.com/huydhn due to: Sorry for reverting this, but it is failing the inductor job in trunk (2e44adb066). I will add the ciflow/inductor label to the PR to make sure that the test runs there ([comment](https://github.com/pytorch/pytorch/pull/106104#issuecomment-1683119990))
PyTorch MergeBot committed 2023-08-17 23:45:31 +00:00
parent d3f92ca9e9
commit 02bcaf45f6
2 changed files with 14 additions and 145 deletions
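
For context on the change being undone here: #106104 extended test_memory_format to also check the backward pass, computing reference gradients in the (contiguous, contiguous) configuration and comparing them against gradients obtained under non-contiguous memory formats (channels_last / channels_last_3d). The snippet below is a minimal, self-contained sketch of that pattern, not the reverted test code itself; the Conv2d module, tensor shapes, dtype, and seed are arbitrary choices for illustration.

# Illustrative sketch only (not the reverted test code): compare gradients from a
# channels_last run against a contiguous-format reference that uses the same
# grad_outputs. Module choice, shapes, and dtype are arbitrary for the example.
import torch

torch.manual_seed(0)
m = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1).double()

# Reference run: contiguous input, contiguous module.
x_ref = torch.randn(2, 3, 16, 16, dtype=torch.double, requires_grad=True)
out_ref = m(x_ref)
grad_out = torch.rand_like(out_ref)
ref_grads = torch.autograd.grad(out_ref, (x_ref, *m.parameters()), grad_outputs=grad_out)

# Same computation with input and module converted to channels_last.
x_cl = x_ref.detach().clone().to(memory_format=torch.channels_last).requires_grad_(True)
m.to(memory_format=torch.channels_last)
out_cl = m(x_cl)
grads_cl = torch.autograd.grad(out_cl, (x_cl, *m.parameters()), grad_outputs=grad_out.clone())

# Gradient values should not depend on memory format.
for g_ref, g_cl in zip(ref_grads, grads_cl):
    torch.testing.assert_close(g_cl, g_ref)

Feeding both runs the same grad_outputs is what makes the two gradient sets directly comparable; the reverted test did the same by copying ref_grad_outputs into grad_outputs via torch.empty_like(t1).copy_(t2).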

@@ -610,7 +610,7 @@ class TestModule(TestCase):
         atol, rtol = (3e-3, 7e-3) if is_sm86or80 else (None, None)
         module_cls = module_info.module_cls
         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
-                                                       requires_grad=True, training=training)
+                                                       requires_grad=False, training=training)
         module_memformat_affects_out = module_info.module_memformat_affects_out

         def _get_mem_formats(channels_last=False, channels_last_3d=False):
@@ -639,13 +639,8 @@ class TestModule(TestCase):
                 d = obj.dim()
                 if ((mem_format == torch.channels_last and d != 4)
                    or (mem_format == torch.channels_last_3d and d != 5)):
-                    return obj.clone().detach().requires_grad_(obj.requires_grad)
-                return (
-                    obj.clone()
-                    .to(memory_format=mem_format)
-                    .detach()
-                    .requires_grad_(obj.requires_grad)
-                )
+                    return obj
+                return obj.to(memory_format=mem_format)

             return self._traverse_obj(obj, inner_to_mem_format)

@@ -662,9 +657,6 @@ class TestModule(TestCase):
                     self.assertTrue(output.is_contiguous())
             return self._traverse_obj(output, inner_check_out_mem_format)

-        def _req_grad(t):
-            return isinstance(t, torch.Tensor) and t.requires_grad
-
         for module_input in module_inputs:
             if module_input.forward_input is None:
                 continue
@@ -684,24 +676,6 @@ class TestModule(TestCase):
                 # === Get output in (contiguous, contiguous) configuration. ===
                 args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                 desired_outputs = m(*args, **kwargs)

-                # === Do backward pass. ===
-                ref_diff_outputs = tuple(t for t in torch.utils._pytree.tree_flatten(desired_outputs)[0] if _req_grad(t))
-                if training and len(ref_diff_outputs) > 0:
-                    params = tuple(p for p in m.parameters())
-                    ref_diff_inputs = tuple(
-                        t
-                        for t in torch.utils._pytree.tree_flatten((args, kwargs, params))[0]
-                        if _req_grad(t)
-                    )
-                    ref_grad_outputs = tuple(
-                        torch.rand_like(t)
-                        for t in ref_diff_outputs
-                    )
-                    ref_grad_inputs = torch.autograd.grad(
-                        ref_diff_outputs,
-                        ref_diff_inputs,
-                        grad_outputs=ref_grad_outputs,
-                    )
-
                 for input_mem_format in input_mem_formats:
                     # === Change memformat of input. ===
@@ -719,43 +693,12 @@ class TestModule(TestCase):
                         outputs = m(*args, **kwargs)

                         # === Compare outputs to (contiguous, contiguous) output. ===
-                        if input_mem_format != torch.contiguous_format or module_mem_format != torch.contiguous_format:
+                        if input_mem_format != torch.contiguous_format or module_mem_formats != torch.contiguous_format:
                             self.assertEqual(outputs, desired_outputs, rtol=rtol, atol=atol)

                         # === Check mem format of output. ===
                         _check_out_mem_format(outputs, input_mem_format, module_mem_format)

-                        # === Do backward pass. ===
-                        diff_outputs = tuple(t for t in torch.utils._pytree.tree_flatten(outputs)[0] if _req_grad(t))
-                        if training and len(diff_outputs) > 0:
-                            params = tuple(p for p in m.parameters())
-                            diff_inputs = tuple(
-                                t
-                                for t in torch.utils._pytree.tree_flatten((args, kwargs, params))[0]
-                                if _req_grad(t)
-                            )
-                            grad_outputs = tuple(
-                                torch.empty_like(t1).copy_(t2)
-                                for (t1, t2) in zip(diff_outputs, ref_grad_outputs)
-                            )
-                            grad_inputs = torch.autograd.grad(
-                                diff_outputs,
-                                diff_inputs,
-                                grad_outputs=grad_outputs,
-                            )
-                            if (
-                                input_mem_format != torch.contiguous_format
-                                or module_mem_format != torch.contiguous_format
-                            ):
-                                self.assertEqual(
-                                    grad_inputs, ref_grad_inputs, rtol=rtol, atol=atol
-                                )
-
-                            # === Check mem format of grad_inputs. ===
-                            _check_out_mem_format(grad_inputs, input_mem_format, module_mem_format)

     # Test whether train and eval modes differ for each module. Use to verify
     # that the ModuleInfo entry flag is correct.
     @modules(module_db, train_eval_mode=TrainEvalMode.train_only)