From 97ccfad915fb291aecf0de4018d1c7558df8657f Mon Sep 17 00:00:00 2001
From: Isuru Fernando
Date: Fri, 19 Apr 2024 17:37:45 +0000
Subject: [PATCH] Fix test_decomp test for ops with py_impl(CompositeImplicitAutograd) (#116832)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116832
Approved by: https://github.com/lezcano
---
 test/test_decomp.py | 43 ++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 42 insertions(+), 1 deletion(-)

diff --git a/test/test_decomp.py b/test/test_decomp.py
index 39d0c2eef2b2..1063e0d33e01 100644
--- a/test/test_decomp.py
+++ b/test/test_decomp.py
@@ -144,7 +144,8 @@ def ref_vjp_no_create(f, *primals):
 
     def wrapped(cotangents):
         return _autograd_grad(
-            _as_tuple(result), primals, _as_tuple(cotangents), create_graph=False
+            _as_tuple(result), primals, _as_tuple(cotangents), create_graph=False,
+            retain_graph=True,
         )
 
     return result, wrapped
@@ -200,6 +201,12 @@ def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs)
         (torch.bfloat16, torch.ops.aten.nll_loss_forward.default): 1e-1,
         (torch.float16, torch.ops.aten.nll_loss2d_forward.default): 1e-2,
         (torch.bfloat16, torch.ops.aten.nll_loss2d_forward.default): 2e-1,
+        (torch.float16, torch.ops.aten.hardswish.default): 2e-7,
+        (torch.bfloat16, torch.ops.aten.hardswish.default): 2e-7,
+        (torch.float16, torch.ops.aten.multi_margin_loss.default): 3e-2,
+        (torch.bfloat16, torch.ops.aten.multi_margin_loss.default): 3e-2,
+        (torch.float16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
+        (torch.bfloat16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
         # see https://github.com/pytorch/pytorch/pull/96264
         (torch.float16, torch.ops.aten.mv.default): 1e-5,
     }
@@ -488,6 +495,11 @@ if not TEST_WITH_SLOW:
         skip('unsafe_split'),  # slow: takes 49 sec on A100
     })
 
+comprehensive_failures = {
+    xfail("nn.functional.interpolate", "bilinear", dtypes=(torch.uint8,)),  # off by one error
+    xfail("nn.functional.interpolate", "bicubic", dtypes=(torch.uint8,)),  # off by one error
+    xfail("nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,)),  # off by one error
+}
 
 @unMarkDynamoStrictTest
 class TestDecomp(TestCase):
@@ -524,6 +536,7 @@ class TestDecomp(TestCase):
     @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
     @onlyNativeDeviceTypes
     @skipIfCrossRef
+    @skipOps('TestDecomp', 'test_comprehensive', comprehensive_failures)
     @suppress_warnings
     @ops(op_db)
     def test_comprehensive(self, device, dtype, op):
@@ -810,6 +823,12 @@ def forward(self, x_1, start_1):
         aten_name = op.decomp_aten_name or op.aten_name
 
         func = op.get_op()
+
+        def run_without_python_dispatcher(mode):
+            return any(isinstance(op, torch._ops.OpOverload) and
+                       op.has_kernel_for_dispatch_key(DispatchKey.CompositeImplicitAutograd)
+                       for op in mode.decomposed.union([func]))
+
         for sample_input in samples:
             if requires_grad:
                 fn, primals = normalize_op_input_output(func, sample_input)
@@ -824,6 +843,12 @@ def forward(self, x_1, start_1):
                 with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
                      as mode, enable_python_dispatcher():
                     decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
+                if run_without_python_dispatcher(mode):
+                    # without this check, incorrect decomps at the python dispatcher level can still pass because
+                    # they're checking aten decomps at the torch_dispatch level.
+                    with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
+                         as mode:
+                        decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
 
                 if aten_name in decomposition_names:
                     self.check_decomposed(aten_name, mode)
@@ -833,15 +858,31 @@ def forward(self, x_1, start_1):
                     with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
                          as mode, enable_python_dispatcher():
                         decomp_vjp_fn(cotangents)
+                    if run_without_python_dispatcher(mode):
+                        # without this check, incorrect decomps at the python dispatcher level can still pass because
+                        # they're checking aten decomps at the torch_dispatch level.
+                        with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
+                             as mode:
+                            decomp_vjp_fn(cotangents)
                     if not run_all:
                         self.check_decomposed(op.aten_backward_name, mode)
 
             elif aten_name in decomposition_names or run_all:
                 args = [sample_input.input] + list(sample_input.args)
                 kwargs = sample_input.kwargs
+                # A failure here might be because the decomposition for the op is wrong or because a
+                # decomposition used by the particular op is wrong.
                 with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
                      as mode, enable_python_dispatcher():
                     func(*args, **kwargs)
+
+                if run_without_python_dispatcher(mode):
+                    # without this check, incorrect decomps at the python dispatcher level can still pass because
+                    # they're checking aten decomps at the torch_dispatch level.
+                    with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
+                         as mode:
+                        func(*args, **kwargs)
+
                 if not run_all:
                     self.check_decomposed(aten_name, mode)
             else:
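The new run_without_python_dispatcher helper re-runs the cross-ref check without enable_python_dispatcher() whenever an involved overload carries a CompositeImplicitAutograd kernel, since DecompCrossRefMode only intercepts at the __torch_dispatch__ level and could otherwise miss an incorrect python-dispatcher-level decomp. Below is a minimal sketch (not part of the patch) of the detection it relies on; aten.matmul is chosen purely as an illustrative composite op.

    # Sketch only: check whether an OpOverload has a CompositeImplicitAutograd
    # kernel, i.e. whether it may decompose differently once the python
    # dispatcher is enabled. aten.matmul is an illustrative example op.
    import torch
    from torch._C import DispatchKey

    op = torch.ops.aten.matmul.default
    has_composite = op.has_kernel_for_dispatch_key(DispatchKey.CompositeImplicitAutograd)
    print(has_composite)  # expected to print True for a composite op such as matmul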