Fix test_decomp test for ops with py_impl(CompositeImplicitAutograd) (#116832)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116832
Approved by: https://github.com/lezcano
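The gap addressed here involves ops whose decomposition is registered at the python dispatcher level rather than as an aten kernel. For background, a sketch of what such a registration can look like (illustrative only, not code from this PR; registering on an op that already has a Python kernel for this key will raise):

    import torch
    from torch._C import DispatchKey

    aten = torch.ops.aten

    # Illustrative: a Python-level CompositeImplicitAutograd decomposition.
    # Kernels registered this way only fire when enable_python_dispatcher()
    # is active, so a test that never leaves the python dispatcher can miss
    # bugs in the plain aten path.
    @aten.silu.default.py_impl(DispatchKey.CompositeImplicitAutograd)
    def silu_decomp(x):
        return x * torch.sigmoid(x)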
Committed by: PyTorch MergeBot
Parent: a3e3693afc
Commit: 97ccfad915
@@ -144,7 +144,8 @@ def ref_vjp_no_create(f, *primals):
 
     def wrapped(cotangents):
         return _autograd_grad(
-            _as_tuple(result), primals, _as_tuple(cotangents), create_graph=False
+            _as_tuple(result), primals, _as_tuple(cotangents), create_graph=False,
+            retain_graph=True,
         )
 
     return result, wrapped
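The `retain_graph=True` addition matters because the test below now invokes the returned vjp function twice over the same autograd graph: once under the python dispatcher and once without it. A minimal standalone illustration of why the graph must be retained (assumes nothing from this PR):

    import torch

    # By default autograd frees the graph after the first backward pass, so a
    # second grad call through the same graph raises a RuntimeError unless the
    # first call passes retain_graph=True.
    x = torch.randn(3, requires_grad=True)
    y = (x * x).sum()

    g1 = torch.autograd.grad(y, x, retain_graph=True)  # keeps the graph alive
    g2 = torch.autograd.grad(y, x)                     # second call now works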
@@ -200,6 +201,12 @@ def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs)
         (torch.bfloat16, torch.ops.aten.nll_loss_forward.default): 1e-1,
         (torch.float16, torch.ops.aten.nll_loss2d_forward.default): 1e-2,
         (torch.bfloat16, torch.ops.aten.nll_loss2d_forward.default): 2e-1,
+        (torch.float16, torch.ops.aten.hardswish.default): 2e-7,
+        (torch.bfloat16, torch.ops.aten.hardswish.default): 2e-7,
+        (torch.float16, torch.ops.aten.multi_margin_loss.default): 3e-2,
+        (torch.bfloat16, torch.ops.aten.multi_margin_loss.default): 3e-2,
+        (torch.float16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
+        (torch.bfloat16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
         # see https://github.com/pytorch/pytorch/pull/96264
         (torch.float16, torch.ops.aten.mv.default): 1e-5,
     }
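These entries extend the per-(dtype, overload) tolerance table that `op_assert_ref` consults when comparing decomposition output against the reference. A hypothetical sketch of such a lookup (`rel_tol_for` and its default are made up for illustration, not the real helper):

    import torch

    tol_table = {
        (torch.float16, torch.ops.aten.mv.default): 1e-5,
        (torch.float16, torch.ops.aten.hardswish.default): 2e-7,
    }

    def rel_tol_for(dtype, op_overload, default=0.0):
        # Ops absent from the table fall back to the default tolerance.
        return tol_table.get((dtype, op_overload), default)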
@@ -488,6 +495,11 @@ if not TEST_WITH_SLOW:
         skip('unsafe_split'),  # slow: takes 49 sec on A100
     })
 
+comprehensive_failures = {
+    xfail("nn.functional.interpolate", "bilinear", dtypes=(torch.uint8,)),  # off by one error
+    xfail("nn.functional.interpolate", "bicubic", dtypes=(torch.uint8,)),  # off by one error
+    xfail("nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,)),  # off by one error
+}
+
 @unMarkDynamoStrictTest
 class TestDecomp(TestCase):
@@ -524,6 +536,7 @@ class TestDecomp(TestCase):
     @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
     @onlyNativeDeviceTypes
     @skipIfCrossRef
+    @skipOps('TestDecomp', 'test_comprehensive', comprehensive_failures)
     @suppress_warnings
     @ops(op_db)
     def test_comprehensive(self, device, dtype, op):
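`comprehensive_failures` feeds the new `@skipOps` decorator above, which marks those (op, variant, dtype) combinations as expected failures instead of hard errors. A self-contained toy version of the pattern (hypothetical helper, not the real skipOps machinery):

    import unittest

    EXPECTED_FAILURES = {("nn.functional.interpolate", "bilinear", "uint8")}

    def maybe_xfail(op_name, variant, dtype_name):
        # Decorate a test as an expected failure when the combo is known bad.
        def deco(fn):
            if (op_name, variant, dtype_name) in EXPECTED_FAILURES:
                return unittest.expectedFailure(fn)
            return fn
        return deco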
@@ -810,6 +823,12 @@ def forward(self, x_1, start_1):
         aten_name = op.decomp_aten_name or op.aten_name
 
         func = op.get_op()
+
+        def run_without_python_dispatcher(mode):
+            return any(isinstance(op, torch._ops.OpOverload) and
+                       op.has_kernel_for_dispatch_key(DispatchKey.CompositeImplicitAutograd)
+                       for op in mode.decomposed.union([func]))
+
         for sample_input in samples:
             if requires_grad:
                 fn, primals = normalize_op_input_output(func, sample_input)
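The helper decides whether a second, dispatcher-free run is worthwhile: it fires when the op under test, or anything the cross-ref mode saw decompose, carries a CompositeImplicitAutograd kernel. The underlying predicate can be tried directly (illustrative snippet; the printed value depends on the op):

    import torch
    from torch._C import DispatchKey

    # Same check run_without_python_dispatcher applies per decomposed op.
    op = torch.ops.aten.silu.default
    print(op.has_kernel_for_dispatch_key(DispatchKey.CompositeImplicitAutograd))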
@@ -824,6 +843,12 @@ def forward(self, x_1, start_1):
                 with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
                      as mode, enable_python_dispatcher():
                     decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
+                if run_without_python_dispatcher(mode):
+                    # without this check, incorrect decomps at the python dispatcher level can still pass because
+                    # they're checking aten decomps at the torch_dispatch level.
+                    with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
+                         as mode:
+                        decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
                 if aten_name in decomposition_names:
                     self.check_decomposed(aten_name, mode)
 
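This pattern is the heart of the fix: each sample now runs once under `enable_python_dispatcher()` and, when the helper says so, once more without it, so an incorrect python-dispatcher decomp can no longer hide behind a correct aten one. The context manager itself is cheap to try (illustrative, not from the PR):

    import torch
    from torch._dispatch.python import enable_python_dispatcher

    x = torch.randn(4)
    with enable_python_dispatcher():
        y_dispatched = torch.nn.functional.silu(x)  # python-level kernels may fire
    y_plain = torch.nn.functional.silu(x)           # plain aten path
    torch.testing.assert_close(y_dispatched, y_plain)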
@@ -833,15 +858,31 @@ def forward(self, x_1, start_1):
                 with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
                      as mode, enable_python_dispatcher():
                     decomp_vjp_fn(cotangents)
+                if run_without_python_dispatcher(mode):
+                    # without this check, incorrect decomps at the python dispatcher level can still pass because
+                    # they're checking aten decomps at the torch_dispatch level.
+                    with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
+                         as mode:
+                        decomp_vjp_fn(cotangents)
                 if not run_all:
                     self.check_decomposed(op.aten_backward_name, mode)
 
         elif aten_name in decomposition_names or run_all:
             args = [sample_input.input] + list(sample_input.args)
             kwargs = sample_input.kwargs
             # A failure here might be because the decomposition for the op is wrong or because a
             # decomposition used by the particular op is wrong.
             with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
                  as mode, enable_python_dispatcher():
                 func(*args, **kwargs)
+
+            if run_without_python_dispatcher(mode):
+                # without this check, incorrect decomps at the python dispatcher level can still pass because
+                # they're checking aten decomps at the torch_dispatch level.
+                with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
+                     as mode:
+                    func(*args, **kwargs)
+
             if not run_all:
                 self.check_decomposed(aten_name, mode)
         else:
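For readers unfamiliar with `DecompCrossRefMode`: it is a `TorchDispatchMode` that records which ops decompose so `check_decomposed` can assert on them afterwards. A heavily simplified toy analogue (hypothetical, for intuition only):

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode

    class RecordOps(TorchDispatchMode):
        """Record every aten overload that executes under the mode."""
        def __init__(self):
            super().__init__()
            self.seen = set()

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            self.seen.add(func)
            return func(*args, **(kwargs or {}))

    with RecordOps() as mode:
        torch.nn.functional.silu(torch.randn(3))
    print(sorted(str(f) for f in mode.seen))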