Revert "Reland "Add forward mode AD to out-place foreach functions (#102409) (#106043)"

This reverts commit e773f28ee307e2a246a4b765f3a51117661b45ba.

Reverted https://github.com/pytorch/pytorch/pull/106043 on behalf of https://github.com/DanilBaibak due to Break slow tests ([comment](https://github.com/pytorch/pytorch/pull/106043#issuecomment-1658642734))
Author: PyTorch MergeBot
Date:   2023-07-31 15:50:34 +00:00
Parent: c5b9dc1f40
Commit: 2b427ae3a7

6 changed files with 63 additions and 246 deletions
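For context, the change being reverted is what allowed forward-mode AD to flow through the out-of-place `_foreach_*` functions. Below is a minimal sketch of that behavior using the public forward-AD API; `torch._foreach_exp` accepting dual tensors is exactly what #102409 added, so this snippet only works on builds that include the reland:

```python
import torch
import torch.autograd.forward_ad as fwAD

# Two primal/tangent pairs, mirroring the float64 tensorlists used in the tests below.
primals = [torch.randn(3, dtype=torch.float64) for _ in range(2)]
tangents = [torch.randn(3, dtype=torch.float64) for _ in range(2)]

with fwAD.dual_level():
    duals = [fwAD.make_dual(p, t) for p, t in zip(primals, tangents)]
    outs = torch._foreach_exp(duals)  # out-of-place foreach op on dual tensors
    # Forward mode attaches a JVP to each output; since d/dx exp(x) = exp(x),
    # each output tangent should equal exp(primal) * tangent.
    jvps = [fwAD.unpack_dual(o).tangent for o in outs]

for p, t, jvp in zip(primals, tangents, jvps):
    torch.testing.assert_close(jvp, p.exp() * t)
```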

test/test_foreach.py

@@ -959,59 +959,6 @@ class TestForeach(TestCase):
             sample.args = new_args
             _test(func, sample)
 
-    @ops(
-        foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_lerp_op_db,
-        dtypes=OpDTypes.supported,
-        allowed_dtypes=(torch.float64, torch.complex128),
-    )
-    def test_outplace_forward_mode_AD(self, device, dtype, op):
-        if not op.supports_forward_ad:
-            self.skipTest("forward AD not supported")
-        # note(crcrpar): without this, some unary functions fail, unlike inplace and/or complex.
-        if dtype == torch.float64 and op.name in (
-            "_foreach_acos", "_foreach_asin", "_foreach_log10", "_foreach_log1p", "_foreach_log2",
-            "_foreach_log", "_foreach_pow", "_foreach_sqrt",
-        ):
-            value_range = {"low": 0.5, "high": 1.0}
-        else:
-            value_range = {}
-        for sample in op.sample_inputs(
-            device, dtype, requires_grad=True, num_input_tensors=[5], same_size=True, **value_range,
-        ):
-            # Skip `_foreach_pow.ScalarAndTensor(Scalar, Tensor[])`
-            if op.name == "_foreach_pow" and isinstance(sample.input, Number):
-                continue
-
-            def func(*tensorlist):
-                kwargs = {"alpha": sample.kwargs["alpha"]} if "alpha" in sample.kwargs else {}
-                return op.method_variant(tensorlist, *sample.args, **kwargs)
-
-            working_sample, err_msg_pattern = check_forward_mode_AD_sample(op, sample, dtype, False)
-            if not working_sample:
-                if not err_msg_pattern:
-                    # lhs of float64 and rhs of complex.
-                    continue
-                with self.assertRaisesRegex(RuntimeError, re.escape(err_msg_pattern)):
-                    gradcheck(
-                        func,
-                        sample.input,
-                        raise_exception=True,
-                        check_forward_ad=True,
-                        check_batched_forward_grad=False,
-                        check_backward_ad=False,
-                        check_batched_grad=False,
-                    )
-            else:
-                gradcheck(
-                    func,
-                    sample.input,
-                    raise_exception=True,
-                    check_forward_ad=True,
-                    check_backward_ad=False,
-                    check_batched_grad=False,
-                )
-
     @ops(
         foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_lerp_op_db,
         dtypes=OpDTypes.supported,
@@ -1021,6 +968,36 @@ class TestForeach(TestCase):
         if not op.supports_forward_ad:
             self.skipTest("forward AD not supported")
+
+        # note(crcrpar): The combinations below fail in their forward path,
+        # which runs before forward-mode AD happens. This function gates
+        # - subtraction with a Scalar/ScalarList of boolean value, and
+        # - combinations where the in-place op in question tries to write a
+        #   complex result into float storage (= `self`).
+        def check_sample_eligibility(op, sample, dtype):
+            if (
+                op.name == "_foreach_sub"
+                and (
+                    (isinstance(sample.args[0], list) and any(isinstance(a, bool) for a in sample.args[0]))
+                    or isinstance(sample.args[0], bool)
+                )
+            ):
+                return False, _BOOL_SUB_ERR_MSG
+            rhs_arg_has_complex_number = sample.args and ((
+                isinstance(sample.args[0], list)
+                and any(isinstance(a, complex) for a in sample.args[0])
+            ) or (
+                isinstance(sample.args[0], complex)
+            ))
+            if dtype == torch.float64 and rhs_arg_has_complex_number:
+                if op.name in ("_foreach_add", "_foreach_sub", "_foreach_mul", "_foreach_div"):
+                    return False, "result type ComplexDouble can't be cast to the desired output type Double"
+                if op.name in ("_foreach_clamp_max", "_foreach_clamp_min"):
+                    return False, "clamp is not supported for complex types"
+                if op.name == "_foreach_pow":
+                    return False, "Found dtype Double but expected ComplexDouble"
+            return True, ""
+
         for sample in op.sample_inputs(
             device, dtype, requires_grad=True, num_input_tensors=[5], same_size=True,
         ):
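As an aside, the boolean-subtraction gate above guards against a failure that happens in the op's forward pass itself, not in AD. A quick way to reproduce it (the resulting message is what `_BOOL_SUB_ERR_MSG` matches in these tests):

```python
import torch

# Subtraction with booleans is rejected by PyTorch before any autograd runs,
# which is why the eligibility check returns _BOOL_SUB_ERR_MSG for such samples.
try:
    torch._foreach_sub([torch.ones(2)], [True])
except RuntimeError as e:
    print(e)
```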
@@ -1031,7 +1008,7 @@ class TestForeach(TestCase):
             op.inplace_variant(tuple(t.clone() for t in tensorlist), *sample.args, **kwargs)
             return tensorlist
 
-        working_sample, err_msg_pattern = check_forward_mode_AD_sample(op, sample, dtype, True)
+        working_sample, err_msg_pattern = check_sample_eligibility(op, sample, dtype)
         if not working_sample:
             with self.assertRaisesRegex(RuntimeError, re.escape(err_msg_pattern)):
                 gradcheck(
@@ -1083,38 +1060,6 @@ class TestForeach(TestCase):
         self.assertEqual(num_tensors_seen, 2 * num_tensors_per_list)
 
-# TODO(crcrpar): Hide this inside torch/testing/_internal.
-# Doing so would end up adding another layer to `foreach_inputs_sample_func.__call__`
-# so that this function could be used as something like the first argument of `filter`.
-# Even after moving this function to testing, I personally think it'd be better to check the error message.
-def check_forward_mode_AD_sample(op, sample, dtype, is_inplace):
-    if (
-        op.name == "_foreach_sub"
-        and (
-            (isinstance(sample.args[0], list) and any(isinstance(a, bool) for a in sample.args[0]))
-            or isinstance(sample.args[0], bool)
-        )
-    ):
-        return False, _BOOL_SUB_ERR_MSG
-    rhs_arg_has_complex_number = sample.args and ((
-        isinstance(sample.args[0], list)
-        and any(isinstance(a, complex) for a in sample.args[0])
-    ) or (
-        isinstance(sample.args[0], complex)
-    ))
-    if rhs_arg_has_complex_number and dtype == torch.float64:
-        if op.name in ("_foreach_clamp_max", "_foreach_clamp_min"):
-            return False, "clamp is not supported for complex types"
-        if not is_inplace:
-            return False, ""
-        else:
-            if op.name == "_foreach_pow":
-                return False, "Found dtype Double but expected ComplexDouble"
-            if op.name in ("_foreach_add", "_foreach_sub", "_foreach_mul", "_foreach_div"):
-                return False, "result type ComplexDouble can't be cast to the desired output type Double"
-    return True, ""
-
 
 instantiate_device_type_tests(TestForeach, globals())
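For reference, the `gradcheck` calls in the tests above exercise forward-mode AD only. The same pattern on a plain unary function, independent of the foreach support being reverted, looks like this:

```python
import torch
from torch.autograd import gradcheck

x = torch.randn(5, dtype=torch.float64, requires_grad=True)

# Check Jacobian-vector products from forward-mode AD against numerical
# derivatives, skipping reverse-mode and batched-gradient checks, just as
# the tests in this diff do.
assert gradcheck(
    torch.exp,
    (x,),
    raise_exception=True,
    check_forward_ad=True,
    check_backward_ad=False,
    check_batched_grad=False,
)
```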