mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-15 06:48:48 +08:00
This reverts commit e773f28ee307e2a246a4b765f3a51117661b45ba. Reverted https://github.com/pytorch/pytorch/pull/106043 on behalf of https://github.com/DanilBaibak due to Break slow tests ([comment](https://github.com/pytorch/pytorch/pull/106043#issuecomment-1658642734))
This commit is contained in:
@ -959,59 +959,6 @@ class TestForeach(TestCase):
|
||||
sample.args = new_args
|
||||
_test(func, sample)
|
||||
|
||||
@ops(
|
||||
foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_lerp_op_db,
|
||||
dtypes=OpDTypes.supported,
|
||||
allowed_dtypes=(torch.float64, torch.complex128),
|
||||
)
|
||||
def test_outplace_forward_mode_AD(self, device, dtype, op):
|
||||
if not op.supports_forward_ad:
|
||||
self.skipTest("forward AD not supported")
|
||||
|
||||
# note(crcrpar): without this, some unary functions fail, unlike inplace and/or complex.
|
||||
if dtype == torch.float64 and op.name in (
|
||||
"_foreach_acos", "_foreach_asin", "_foreach_log10", "_foreach_log1p", "_foreach_log2",
|
||||
"_foreach_log", "_foreach_pow", "_foreach_sqrt",
|
||||
):
|
||||
value_range = {"low": 0.5, "high": 1.0}
|
||||
else:
|
||||
value_range = {}
|
||||
for sample in op.sample_inputs(
|
||||
device, dtype, requires_grad=True, num_input_tenosrs=[5], same_size=True, **value_range,
|
||||
):
|
||||
# Skip `_foreach_pow.ScalarAndTensor(Scalar, Tensor[])`
|
||||
if op.name == "_foreach_pow" and isinstance(sample.input, Number):
|
||||
continue
|
||||
|
||||
def func(*tensorlist):
|
||||
kwargs = {"alpha": sample.kwargs["alpha"]} if "alpha" in sample.kwargs else {}
|
||||
return op.method_variant(tensorlist, *sample.args, **kwargs)
|
||||
|
||||
working_sample, err_msg_pattern = check_forward_mode_AD_sample(op, sample, dtype, False)
|
||||
if not working_sample:
|
||||
if not err_msg_pattern:
|
||||
# lhs of float64 and rhs of complex.
|
||||
continue
|
||||
with self.assertRaisesRegex(RuntimeError, re.escape(err_msg_pattern)):
|
||||
gradcheck(
|
||||
func,
|
||||
sample.input,
|
||||
raise_exception=True,
|
||||
check_forward_ad=True,
|
||||
check_batched_forward_grad=False,
|
||||
check_backward_ad=False,
|
||||
check_batched_grad=False,
|
||||
)
|
||||
else:
|
||||
gradcheck(
|
||||
func,
|
||||
sample.input,
|
||||
raise_exception=True,
|
||||
check_forward_ad=True,
|
||||
check_backward_ad=False,
|
||||
check_batched_grad=False,
|
||||
)
|
||||
|
||||
@ops(
|
||||
foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_lerp_op_db,
|
||||
dtypes=OpDTypes.supported,
|
||||
@ -1021,6 +968,36 @@ class TestForeach(TestCase):
|
||||
if not op.supports_forward_ad:
|
||||
self.skipTest("forward AD not supported")
|
||||
|
||||
# note(crcrpar): The combinations below are failing in its forward path,
|
||||
# which is before forward-mode AD happens. This function gates the combinations where
|
||||
# - subtraction with Scalar/ScalarList of boolean value:
|
||||
# - combinations where the in-place op in questions tries to write out complex result
|
||||
# into float storage (= `self`)
|
||||
def check_sample_eligibility(op, sample, dtype):
|
||||
if (
|
||||
op.name == "_foreach_sub"
|
||||
and (
|
||||
(isinstance(sample.args[0], list) and any(isinstance(a, bool) for a in sample.args[0]))
|
||||
or isinstance(sample.args[0], bool)
|
||||
)
|
||||
):
|
||||
return False, _BOOL_SUB_ERR_MSG
|
||||
rhs_arg_has_complex_number = sample.args and ((
|
||||
isinstance(sample.args[0], list)
|
||||
and any(isinstance(a, complex) for a in sample.args[0])
|
||||
) or (
|
||||
isinstance(sample.args[0], complex)
|
||||
))
|
||||
if dtype == torch.float64 and rhs_arg_has_complex_number:
|
||||
if op.name in ("_foreach_add", "_foreach_sub", "_foreach_mul", "_foreach_div"):
|
||||
return False, "result type ComplexDouble can't be cast to the desired output type Double"
|
||||
if op.name in ("_foreach_clamp_max", "_foreach_clamp_min"):
|
||||
return False, "clamp is not supported for complex types"
|
||||
if op.name == "_foreach_pow":
|
||||
return False, "Found dtype Double but expected ComplexDouble"
|
||||
|
||||
return True, ""
|
||||
|
||||
for sample in op.sample_inputs(
|
||||
device, dtype, requires_grad=True, num_input_tensors=[5], same_size=True,
|
||||
):
|
||||
@ -1031,7 +1008,7 @@ class TestForeach(TestCase):
|
||||
op.inplace_variant(tuple(t.clone() for t in tensorlist), *sample.args, **kwargs)
|
||||
return tensorlist
|
||||
|
||||
working_sample, err_msg_pattern = check_forward_mode_AD_sample(op, sample, dtype, True)
|
||||
working_sample, err_msg_pattern = check_sample_eligibility(op, sample, dtype)
|
||||
if not working_sample:
|
||||
with self.assertRaisesRegex(RuntimeError, re.escape(err_msg_pattern)):
|
||||
gradcheck(
|
||||
@ -1083,38 +1060,6 @@ class TestForeach(TestCase):
|
||||
self.assertEqual(num_tensors_seen, 2 * num_tensors_per_list)
|
||||
|
||||
|
||||
# TODO(crcrpar): Hide this inside torch/testing/_internal.
|
||||
# would end up adding another layer to `foreach_inputs_sample_func.__call__`
|
||||
# so that we can use this function as something like the first argument of `filter` function.
|
||||
# Even after moving this function to testing, I personally think it'd be better to check the error message.
|
||||
def check_forward_mode_AD_sample(op, sample, dtype, is_inplace):
|
||||
if (
|
||||
op.name == "_foreach_sub"
|
||||
and (
|
||||
(isinstance(sample.args[0], list) and any(isinstance(a, bool) for a in sample.args[0]))
|
||||
or isinstance(sample.args[0], bool)
|
||||
)
|
||||
):
|
||||
return False, _BOOL_SUB_ERR_MSG
|
||||
rhs_arg_has_complex_number = sample.args and ((
|
||||
isinstance(sample.args[0], list)
|
||||
and any(isinstance(a, complex) for a in sample.args[0])
|
||||
) or (
|
||||
isinstance(sample.args[0], complex)
|
||||
))
|
||||
if rhs_arg_has_complex_number and dtype == torch.float64:
|
||||
if op.name in ("_foreach_clamp_max", "_foreach_clamp_min"):
|
||||
return False, "clamp is not supported for complex types"
|
||||
if not is_inplace:
|
||||
return False, ""
|
||||
else:
|
||||
if op.name == "_foreach_pow":
|
||||
return False, "Found dtype Double but expected ComplexDouble"
|
||||
if op.name in ("_foreach_add", "_foreach_sub", "_foreach_mul", "_foreach_div"):
|
||||
return False, "result type ComplexDouble can't be cast to the desired output type Double"
|
||||
return True, ""
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestForeach, globals())
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user