Fix test_decomp test for ops with py_impl(CompositeImplicitAutograd) (#116832)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116832
Approved by: https://github.com/lezcano
Isuru Fernando
2024-04-19 17:37:45 +00:00
committed by PyTorch MergeBot
parent a3e3693afc
commit 97ccfad915


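The ops in question are those whose decompositions are registered at the Python dispatcher level via py_impl(CompositeImplicitAutograd) rather than as aten kernels. A minimal sketch of such a registration (the op and the decomposition body here are illustrative, not taken from this PR):

    import torch
    from torch._C import DispatchKey

    # Illustrative only: a Python-level decomposition registered under the
    # CompositeImplicitAutograd key.  It participates only while the python
    # dispatcher is enabled, which is why the test below re-runs its checks
    # without enable_python_dispatcher().
    @torch.ops.aten.silu.default.py_impl(DispatchKey.CompositeImplicitAutograd)
    def _silu_decomp(x):
        return x * torch.sigmoid(x)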
@@ -144,7 +144,8 @@ def ref_vjp_no_create(f, *primals):
 
     def wrapped(cotangents):
         return _autograd_grad(
-            _as_tuple(result), primals, _as_tuple(cotangents), create_graph=False
+            _as_tuple(result), primals, _as_tuple(cotangents), create_graph=False,
+            retain_graph=True,
         )
 
     return result, wrapped
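The retain_graph=True addition matters because the test now calls the returned vjp function twice on the same forward graph: once with the python dispatcher enabled and once without it. A standalone sketch of why the flag is required, using the public torch.autograd.grad API:

    import torch

    x = torch.randn(3, requires_grad=True)
    y = (x * x).sum()

    # Without retain_graph=True the first call frees the graph, and a second
    # backward over the same graph raises a RuntimeError.
    g1 = torch.autograd.grad(y, x, retain_graph=True)
    g2 = torch.autograd.grad(y, x)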
@@ -200,6 +201,12 @@ def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs)
         (torch.bfloat16, torch.ops.aten.nll_loss_forward.default): 1e-1,
         (torch.float16, torch.ops.aten.nll_loss2d_forward.default): 1e-2,
         (torch.bfloat16, torch.ops.aten.nll_loss2d_forward.default): 2e-1,
+        (torch.float16, torch.ops.aten.hardswish.default): 2e-7,
+        (torch.bfloat16, torch.ops.aten.hardswish.default): 2e-7,
+        (torch.float16, torch.ops.aten.multi_margin_loss.default): 3e-2,
+        (torch.bfloat16, torch.ops.aten.multi_margin_loss.default): 3e-2,
+        (torch.float16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
+        (torch.bfloat16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
         # see https://github.com/pytorch/pytorch/pull/96264
         (torch.float16, torch.ops.aten.mv.default): 1e-5,
     }
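The new entries relax the reference comparison for the listed (dtype, overload) pairs. A minimal sketch of how a table of this shape is typically consulted (the helper below is illustrative, not the actual op_assert_ref logic):

    import torch

    # Illustrative per-(dtype, overload) tolerance lookup mirroring the table above.
    tol_table = {
        (torch.float16, torch.ops.aten.hardswish.default): 2e-7,
        (torch.bfloat16, torch.ops.aten.hardswish.default): 2e-7,
    }

    def assert_close_to_ref(op, dtype, decomp, ref, default_tol=1e-7):
        tol = tol_table.get((dtype, op), default_tol)
        assert (decomp - ref).abs().max() <= tol, f"{op} exceeded tol={tol}"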
@@ -488,6 +495,11 @@ if not TEST_WITH_SLOW:
         skip('unsafe_split'),  # slow: takes 49 sec on A100
     })
 
+comprehensive_failures = {
+    xfail("nn.functional.interpolate", "bilinear", dtypes=(torch.uint8,)),  # off by one error
+    xfail("nn.functional.interpolate", "bicubic", dtypes=(torch.uint8,)),  # off by one error
+    xfail("nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,)),  # off by one error
+}
 
 @unMarkDynamoStrictTest
 class TestDecomp(TestCase):
@@ -524,6 +536,7 @@ class TestDecomp(TestCase):
     @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
     @onlyNativeDeviceTypes
     @skipIfCrossRef
+    @skipOps('TestDecomp', 'test_comprehensive', comprehensive_failures)
     @suppress_warnings
     @ops(op_db)
     def test_comprehensive(self, device, dtype, op):
@@ -810,6 +823,12 @@ def forward(self, x_1, start_1):
         aten_name = op.decomp_aten_name or op.aten_name
 
         func = op.get_op()
+
+        def run_without_python_dispatcher(mode):
+            return any(isinstance(op, torch._ops.OpOverload) and
+                       op.has_kernel_for_dispatch_key(DispatchKey.CompositeImplicitAutograd)
+                       for op in mode.decomposed.union([func]))
+
         for sample_input in samples:
             if requires_grad:
                 fn, primals = normalize_op_input_output(func, sample_input)
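The new run_without_python_dispatcher helper keys off OpOverload.has_kernel_for_dispatch_key: the extra no-dispatcher run below only happens when the op under test, or something it decomposed into, carries a CompositeImplicitAutograd kernel. The flag can be inspected directly on any overload (illustrative one-off check):

    import torch
    from torch._C import DispatchKey

    op = torch.ops.aten.hardswish.default
    print(op.has_kernel_for_dispatch_key(DispatchKey.CompositeImplicitAutograd))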
@@ -824,6 +843,12 @@ def forward(self, x_1, start_1):
                 with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
                      as mode, enable_python_dispatcher():
                     decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
+                if run_without_python_dispatcher(mode):
+                    # without this check, incorrect decomps at the python dispatcher level can still pass because
+                    # they're checking aten decomps at the torch_dispatch level.
+                    with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
+                         as mode:
+                        decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
 
                 if aten_name in decomposition_names:
                     self.check_decomposed(aten_name, mode)
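The duplicated block is deliberate: the first run wraps the call in enable_python_dispatcher() so Python-level decomps participate, while the re-run drops that context manager so an incorrect Python-level decomposition cannot hide behind the correct aten kernel seen by the torch_dispatch check. Stripped of the test machinery, the two runs look roughly like this (sketch; the op and input are placeholders):

    import torch
    from torch._dispatch.python import enable_python_dispatcher

    op = torch.ops.aten.hardswish.default
    x = torch.randn(4)

    # Once with the python dispatcher active (python-level decomps can fire)...
    with enable_python_dispatcher():
        out_with = op(x)

    # ...and once without it, exercising only the aten-level behaviour.
    out_without = op(x)
    assert torch.allclose(out_with, out_without)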
@@ -833,15 +858,31 @@ def forward(self, x_1, start_1):
                     with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
                          as mode, enable_python_dispatcher():
                         decomp_vjp_fn(cotangents)
+                    if run_without_python_dispatcher(mode):
+                        # without this check, incorrect decomps at the python dispatcher level can still pass because
+                        # they're checking aten decomps at the torch_dispatch level.
+                        with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
+                             as mode:
+                            decomp_vjp_fn(cotangents)
+
                     if not run_all:
                         self.check_decomposed(op.aten_backward_name, mode)
 
             elif aten_name in decomposition_names or run_all:
                 args = [sample_input.input] + list(sample_input.args)
                 kwargs = sample_input.kwargs
+                # A failure here might be because the decomposition for the op is wrong or because a
+                # decomposition used by the particular op is wrong.
                 with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
                      as mode, enable_python_dispatcher():
                     func(*args, **kwargs)
+                if run_without_python_dispatcher(mode):
+                    # without this check, incorrect decomps at the python dispatcher level can still pass because
+                    # they're checking aten decomps at the torch_dispatch level.
+                    with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
+                         as mode:
+                        func(*args, **kwargs)
+
                 if not run_all:
                     self.check_decomposed(aten_name, mode)
             else: