Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-17 16:46:31 +08:00)
This reverts commit e773f28ee307e2a246a4b765f3a51117661b45ba. Reverted https://github.com/pytorch/pytorch/pull/106043 on behalf of https://github.com/DanilBaibak due to Break slow tests ([comment](https://github.com/pytorch/pytorch/pull/106043#issuecomment-1658642734))
@@ -959,59 +959,6 @@ class TestForeach(TestCase):
                 sample.args = new_args
             _test(func, sample)
 
-    @ops(
-        foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_lerp_op_db,
-        dtypes=OpDTypes.supported,
-        allowed_dtypes=(torch.float64, torch.complex128),
-    )
-    def test_outplace_forward_mode_AD(self, device, dtype, op):
-        if not op.supports_forward_ad:
-            self.skipTest("forward AD not supported")
-
-        # note(crcrpar): without this, some unary functions fail, unlike inplace and/or complex.
-        if dtype == torch.float64 and op.name in (
-            "_foreach_acos", "_foreach_asin", "_foreach_log10", "_foreach_log1p", "_foreach_log2",
-            "_foreach_log", "_foreach_pow", "_foreach_sqrt",
-        ):
-            value_range = {"low": 0.5, "high": 1.0}
-        else:
-            value_range = {}
-        for sample in op.sample_inputs(
-            device, dtype, requires_grad=True, num_input_tenosrs=[5], same_size=True, **value_range,
-        ):
-            # Skip `_foreach_pow.ScalarAndTensor(Scalar, Tensor[])`
-            if op.name == "_foreach_pow" and isinstance(sample.input, Number):
-                continue
-
-            def func(*tensorlist):
-                kwargs = {"alpha": sample.kwargs["alpha"]} if "alpha" in sample.kwargs else {}
-                return op.method_variant(tensorlist, *sample.args, **kwargs)
-
-            working_sample, err_msg_pattern = check_forward_mode_AD_sample(op, sample, dtype, False)
-            if not working_sample:
-                if not err_msg_pattern:
-                    # lhs of float64 and rhs of complex.
-                    continue
-                with self.assertRaisesRegex(RuntimeError, re.escape(err_msg_pattern)):
-                    gradcheck(
-                        func,
-                        sample.input,
-                        raise_exception=True,
-                        check_forward_ad=True,
-                        check_batched_forward_grad=False,
-                        check_backward_ad=False,
-                        check_batched_grad=False,
-                    )
-            else:
-                gradcheck(
-                    func,
-                    sample.input,
-                    raise_exception=True,
-                    check_forward_ad=True,
-                    check_backward_ad=False,
-                    check_batched_grad=False,
-                )
-
     @ops(
         foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_lerp_op_db,
         dtypes=OpDTypes.supported,
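The removed test drove torch.autograd.gradcheck in a forward-mode-only configuration. As a rough standalone illustration of that flag combination (not part of the diff; a plain torch.sin stands in for a foreach op):

import torch
from torch.autograd import gradcheck

# Forward-mode-only numerical check, mirroring the flags used by the removed test.
x = torch.randn(5, dtype=torch.float64, requires_grad=True)

def fn(t):
    return torch.sin(t)

gradcheck(
    fn,
    (x,),
    check_forward_ad=True,           # compare JVPs against numerical derivatives
    check_backward_ad=False,         # skip reverse-mode checks
    check_undefined_grad=False,      # undefined-grad checks need backward AD
    check_batched_grad=False,
    check_batched_forward_grad=False,
)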
@@ -1021,6 +968,36 @@ class TestForeach(TestCase):
         if not op.supports_forward_ad:
             self.skipTest("forward AD not supported")
 
+        # note(crcrpar): The combinations below are failing in its forward path,
+        # which is before forward-mode AD happens. This function gates the combinations where
+        # - subtraction with Scalar/ScalarList of boolean value:
+        # - combinations where the in-place op in questions tries to write out complex result
+        #   into float storage (= `self`)
+        def check_sample_eligibility(op, sample, dtype):
+            if (
+                op.name == "_foreach_sub"
+                and (
+                    (isinstance(sample.args[0], list) and any(isinstance(a, bool) for a in sample.args[0]))
+                    or isinstance(sample.args[0], bool)
+                )
+            ):
+                return False, _BOOL_SUB_ERR_MSG
+            rhs_arg_has_complex_number = sample.args and ((
+                isinstance(sample.args[0], list)
+                and any(isinstance(a, complex) for a in sample.args[0])
+            ) or (
+                isinstance(sample.args[0], complex)
+            ))
+            if dtype == torch.float64 and rhs_arg_has_complex_number:
+                if op.name in ("_foreach_add", "_foreach_sub", "_foreach_mul", "_foreach_div"):
+                    return False, "result type ComplexDouble can't be cast to the desired output type Double"
+                if op.name in ("_foreach_clamp_max", "_foreach_clamp_min"):
+                    return False, "clamp is not supported for complex types"
+                if op.name == "_foreach_pow":
+                    return False, "Found dtype Double but expected ComplexDouble"
+
+            return True, ""
+
         for sample in op.sample_inputs(
             device, dtype, requires_grad=True, num_input_tensors=[5], same_size=True,
         ):
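The restored comment block gates two failure modes that happen before forward-mode AD is even reached. A small, hedged illustration of the second one (an in-place foreach op cannot write a complex result into real float64 storage; the exact error text may vary between versions):

import torch

# Illustrative only: float64 tensor list, complex scalar on the right-hand side.
tensors = [torch.randn(3, dtype=torch.float64) for _ in range(2)]
try:
    torch._foreach_mul_(tensors, 1j)  # in-place, so the result must fit in float64 storage
except RuntimeError as e:
    # Expected to be along the lines of
    # "result type ComplexDouble can't be cast to the desired output type Double".
    print(e)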
@@ -1031,7 +1008,7 @@ class TestForeach(TestCase):
                 op.inplace_variant(tuple(t.clone() for t in tensorlist), *sample.args, **kwargs)
                 return tensorlist
 
-            working_sample, err_msg_pattern = check_forward_mode_AD_sample(op, sample, dtype, True)
+            working_sample, err_msg_pattern = check_sample_eligibility(op, sample, dtype)
             if not working_sample:
                 with self.assertRaisesRegex(RuntimeError, re.escape(err_msg_pattern)):
                     gradcheck(
@@ -1083,38 +1060,6 @@ class TestForeach(TestCase):
         self.assertEqual(num_tensors_seen, 2 * num_tensors_per_list)
 
 
-# TODO(crcrpar): Hide this inside torch/testing/_internal.
-# would end up adding another layer to `foreach_inputs_sample_func.__call__`
-# so that we can use this function as something like the first argument of `filter` function.
-# Even after moving this function to testing, I personally think it'd be better to check the error message.
-def check_forward_mode_AD_sample(op, sample, dtype, is_inplace):
-    if (
-        op.name == "_foreach_sub"
-        and (
-            (isinstance(sample.args[0], list) and any(isinstance(a, bool) for a in sample.args[0]))
-            or isinstance(sample.args[0], bool)
-        )
-    ):
-        return False, _BOOL_SUB_ERR_MSG
-    rhs_arg_has_complex_number = sample.args and ((
-        isinstance(sample.args[0], list)
-        and any(isinstance(a, complex) for a in sample.args[0])
-    ) or (
-        isinstance(sample.args[0], complex)
-    ))
-    if rhs_arg_has_complex_number and dtype == torch.float64:
-        if op.name in ("_foreach_clamp_max", "_foreach_clamp_min"):
-            return False, "clamp is not supported for complex types"
-        if not is_inplace:
-            return False, ""
-        else:
-            if op.name == "_foreach_pow":
-                return False, "Found dtype Double but expected ComplexDouble"
-            if op.name in ("_foreach_add", "_foreach_sub", "_foreach_mul", "_foreach_div"):
-                return False, "result type ComplexDouble can't be cast to the desired output type Double"
-    return True, ""
-
-
 instantiate_device_type_tests(TestForeach, globals())
 
 
@@ -3040,16 +3040,13 @@
 - name: _foreach_div.List(Tensor[] self, Tensor[] other) -> Tensor[]
   self: div_tensor_self_backward(grads[i], other[i], self[i].scalar_type())
   other: div_tensor_other_backward(grads[i], self[i], other[i])
-  result: (self_t - other_t * result[i]) / other_p
 
 - name: _foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[]
   self: pow_backward_self(grads[i], self[i], exponent[i])
   exponent: pow_backward_exponent(grads[i], self[i], exponent[i], result[i])
-  result: (pow_backward_self(self_t.conj(), self_p, exponent_p) + pow_backward_exponent(exponent_t.conj(), self_p, exponent_p, result[i])).conj()
 
 - name: _foreach_pow.ScalarList(Tensor[] self, Scalar[] exponent) -> Tensor[]
   self: pow_backward(grads[i], self[i], exponent[i])
-  result: pow_backward(self_t.conj(), self_p, exponent[i]).conj()
 
 - name: _foreach_pow.ScalarAndTensor(Scalar self, Tensor[] exponent) -> Tensor[]
   exponent: pow_backward_exponent(grads[i], self, exponent[i], result[i])
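For context, the deleted `result:` lines are forward-mode (JVP) formulas; the `self:`/`other:`/`exponent:` lines that remain are the backward formulas. The division entry is the ordinary quotient-rule JVP, reusing the primal output. With primals $a_p, b_p$, tangents $a_t, b_t$, and $r = a_p / b_p$:

$$
\frac{d}{d\varepsilon}\left.\frac{a_p + \varepsilon a_t}{b_p + \varepsilon b_t}\right|_{\varepsilon = 0}
= \frac{a_t b_p - a_p b_t}{b_p^{2}}
= \frac{a_t - b_t\, r}{b_p},
$$

which is what `(self_t - other_t * result[i]) / other_p` computes per list element; the `_foreach_pow` entries are the analogous chain-rule JVPs, conjugated for complex inputs.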
@@ -3058,22 +3055,12 @@
 # of `maximum` and `minimum` don't have the overload def with Scalar as their second argument.
 - name: _foreach_minimum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
   self: at::where(self[i] == scalar, grads[i] / 2, grads[i]).masked_fill_(self[i] > scalar, 0)
-  result: scalar + at::where(self_p == scalar, at::scalar_tensor(0.5, result[i].options()), (self_p < scalar).to(result[i].scalar_type())) * (self_t - scalar)
 
 - name: _foreach_minimum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
   self: at::where(self[i] == scalars[i], grads[i] / 2, grads[i]).masked_fill_(self[i] > scalars[i], 0)
-  result: scalars[i] + at::where(self_p == scalars[i], at::scalar_tensor(0.5, result[i].options()), (self_p < scalars[i]).to(result[i].scalar_type())) * (self_t - scalars[i])
 
 - name: _foreach_maximum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
   self: at::where(self[i] == scalar, grads[i] / 2, grads[i]).masked_fill_(self[i] < scalar, 0)
-  result: scalar + at::where(self_p == scalar, at::scalar_tensor(0.5, result[i].options()), (self_p > scalar).to(result[i].scalar_type())) * (self_t - scalar)
 
 - name: _foreach_maximum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
   self: at::where(self[i] == scalars[i], grads[i] / 2, grads[i]).masked_fill_(self[i] < scalars[i], 0)
-  result: scalars[i] + at::where(self_p == scalars[i], at::scalar_tensor(0.5, result[i].options()), (self_p > scalars[i]).to(result[i].scalar_type())) * (self_t - scalars[i])
-
-# note(crcrpar): forward-mode AD is tricky for a simple string replace to handle:
-# formula.replace("p", "ord") produces `norm_jvord(self_ord, self_t, ord, result)`
-- name: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2) -> Tensor[]
-  self: norm_backward(grads[i], self[i], ord, result[i])
-  result: norm_jvp(self_p, self_t, ord, result[i])
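The deleted `result:` lines here follow the generic elementwise min/max JVP pattern, which weights the incoming tangents by comparing the primals and uses one half at exact ties. For $m = \min(a, b)$ with tangents $\dot a, \dot b$:

$$
\dot m = \dot b + w\,(\dot a - \dot b),
\qquad
w =
\begin{cases}
1, & a < b,\\
\tfrac{1}{2}, & a = b,\\
0, & a > b,
\end{cases}
$$

with the comparison flipped for `maximum`. The removed `.Scalar`/`.ScalarList` entries substitute the scalar argument into this two-argument pattern.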
@@ -1769,7 +1769,7 @@ def emit_body(
 
     def emit_any_has_forward_grad() -> List[str]:
         content: List[str] = []
-        if not is_foreach:
+        if not is_inplace_foreach:
             for derivative in fw_derivatives:
                 requires_fw_grad = get_any_has_fw_grad_cond(derivative=derivative)
                 if info and info.output_differentiability_conditions:
@@ -1783,17 +1783,11 @@ def emit_body(
                 bool_vector_name = get_any_has_forward_grad_name(derivative.var_names)
                 cur_derivative_conditions = [
                     FW_DERIVATIVE_CHECK_TEMPLATE.substitute(
-                        req_inp=(
-                            inp.name
-                            if not inplace
-                            else refargname2inplace_foreacharg[inp.name].name
-                        )
+                        req_inp=refargname2inplace_foreacharg[inp.name].name
                         + (
                             "[i]"
                             if is_tensor_list_type(
-                                inp.type
-                                if not inplace
-                                else refargname2inplace_foreacharg[inp.name].type
+                                refargname2inplace_foreacharg[inp.name].type
                             )
                             else ""
                         ),
@@ -1835,10 +1829,8 @@ def emit_body(
             unpacked_arguments = ""
             for inp in differentiable_inputs:
                 inp_name = inp.name
-                is_input_tensorlist = is_foreach and is_tensor_list_type(
-                    inp.type
-                    if not inplace
-                    else refargname2inplace_foreacharg[inp.name].type
+                is_input_tensorlist = is_inplace_foreach and is_tensor_list_type(
+                    refargname2inplace_foreacharg[inp.name].type
                 )
                 input_suffix = "[i]" if is_input_tensorlist else ""
                 if is_inplace_foreach:
@@ -1895,14 +1887,14 @@ def emit_body(
             # Is there a way to get from BaseType to BaseCType
             if len(derivative.var_types) == 1:
                 opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type()
-                if not is_foreach:
+                if not is_inplace_foreach:
                     fw_grad_setters.append(
                         FW_DERIVATIVE_SETTER_TENSOR.substitute(
                             out_arg=res[0], is_inplace=is_inplace_str
                         )
                     )
                 else:
-                    assert res[0] == ("result" if not inplace else "self")
+                    assert res[0] == "self"
                     fw_grad_setters.append(
                         FW_DERIVATIVE_SETTER_TENSOR_FOREACH.substitute(
                             out_arg=res[0], is_inplace=is_inplace_str
@@ -1927,7 +1919,6 @@ def emit_body(
                 assert (
                     len(derivative.var_types) == 1
                 ), "Expected number of outputs to be 1 if function returns ListType"
-                if not is_foreach:
                 opt_res_grad_type = OptionalCType(
                     VectorCType(BaseCType(tensorT))
                 ).cpp_type()
@@ -1936,20 +1927,10 @@ def emit_body(
                         out_arg=res[0], is_inplace=is_inplace_str
                     )
                 )
-                else:
-                    # TODO(crcrpar): Should this (= the foreach specific logic) be refactored somehow?
-                    # Only out-place foreach functions that have entries in `tools/autograd/derivatives.yaml`
-                    # can reach here.
-                    opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type()
-                    fw_grad_setters.append(
-                        FW_DERIVATIVE_SETTER_TENSOR_FOREACH.substitute(
-                            out_arg=res[0], is_inplace=is_inplace_str
-                        )
-                    )
             else:
                 raise RuntimeError("Unsupported output type for forward derivative")
 
-            if not is_foreach:
+            if not is_inplace_foreach:
                 fw_grad_opt_definition = f"{opt_res_grad_type} {'_'.join(res)}_new_fw_grad_opt = c10::nullopt;"
                 # View ops create fw_grad that already is a view of the base's fw_grad so just use that
                 content.append(
@@ -1967,9 +1948,7 @@ def emit_body(
                     f"std::vector<{opt_res_grad_type}> {'_'.join(res)}_new_fw_grad_opts"
                     "(self.size(), c10::nullopt);"
                 )
-                foreach_forward_grad_formula = derivative.formula
-                _foreach_arg: Union[Argument, DifferentiableInput]
-                if inplace:
+                inplace_foreach_forward_grad_formula = derivative.formula
                 for _foreach_arg, _ref_arg in inplace_foreacharg2refarg.items():
                     # note(crcrpar): Massage only Scalar and ArrayRef<Scalar> here.
                     if not (
@@ -1979,19 +1958,11 @@ def emit_body(
                         pattern = _foreach_arg.name
                         if isinstance(_foreach_arg.type, ListType):
                             pattern += "[i]"
-                        foreach_forward_grad_formula = (
-                            foreach_forward_grad_formula.replace(
+                        inplace_foreach_forward_grad_formula = (
+                            inplace_foreach_forward_grad_formula.replace(
                                 _ref_arg.name, pattern
                             )
                         )
-                else:
-                    if (
-                        "result" in foreach_forward_grad_formula
-                        and "result[i]" not in foreach_forward_grad_formula
-                    ):
-                        foreach_forward_grad_formula = (
-                            foreach_forward_grad_formula.replace("result", "result[i]")
-                        )
 
                 content.append(
                     FW_DERIVATIVE_FOREACH_TEMPLATE.substitute(
@@ -2001,7 +1972,7 @@ def emit_body(
                             get_any_has_forward_grad_name(derivative.var_names) + "[i]"
                             for derivative in fw_derivatives
                         ),
-                        formula=foreach_forward_grad_formula,
+                        formula=inplace_foreach_forward_grad_formula,
                         unpacked_arguments=unpacked_arguments,
                     )
                 )
@@ -2063,11 +2034,7 @@ def emit_body(
         else:
             any_has_fw_grad = " || ".join(
                 [
-                    (
-                        FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE
-                        if is_tensor_list_type(inp.type)
-                        else FW_DERIVATIVE_CHECK_TEMPLATE
-                    ).substitute(req_inp=inp.name)
+                    FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
                     for inp in differentiable_inputs
                     if inp.name in derivative.required_inputs_fw_grad
                 ]
@@ -272,13 +272,9 @@ def postprocess_forward_derivatives(
     args_with_derivatives: Sequence[Binding],
 ) -> List[ForwardDerivative]:
     def find_required_inputs(formula: str, postfix: str) -> Tuple[str, ...]:
-        is_foreach = f.func.name.name.base.startswith("_foreach_")
         required_inputs = set()
         for arg in args_with_derivatives:
-            if (
-                arg.type in ("at::TensorList", "const at::ITensorListRef &")
-                and not is_foreach
-            ):
+            if arg.type in ("at::TensorList", "const at::ITensorListRef &"):
                 # The functions taking TensorList handle everything internally
                 continue
             arg_name = arg.name
@@ -8779,15 +8779,6 @@ foreach_unary_op_db: List[OpInfo] = [
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
         supports_autograd=True,
         supports_forward_ad=True,
-        skips=(
-            # note(crcrpar): excluding cdouble from dtypes above might be better.
-            # Guard for `error: "In-place abs is not supported for complex tensors."`
-            DecorateInfo(
-                unittest.skip("_foreach_zero is not implemented"),
-                'TestForeach',
-                'test_outplace_forward_mode_AD',
-            ),
-        ),
     ),
 ]
 
@@ -462,82 +462,13 @@ def gen_foreach_derivativeinfo(
         for arg in foreach_function.func.arguments.flat_non_out
         if arg.name in all_var_names
     ]
-
-    forward_derivatives: List[ForwardDerivative] = []
-    fw_derivative: ForwardDerivative
-    for fw_derivative in ref_diff_info.forward_derivatives:
-        var_names: List[str] = list(fw_derivative.var_names)  # type: ignore[no-redef]
-        var_types: List[Type] = list(fw_derivative.var_types)
-        required_inputs_fw_grad: List[str] = []
-        required_inputs_primal: List[str] = []
-        if fw_derivative.required_inputs_fw_grad is not None:
-            required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
-        if fw_derivative.required_inputs_primal:
-            required_inputs_primal = list(fw_derivative.required_inputs_primal)
-        modified_formula = fw_derivative.formula
-
-        # Foreach's result is TensorList
-        if "result" in modified_formula:
-            modified_formula = fw_derivative.formula.replace("result", "result[i]")
-
-        for foreach_arg, ref_arg in zip(
-            foreach_function.func.arguments.flat_non_out,
-            ref_diff_info.func.func.arguments.flat_non_out,
-        ):
-            # Modify reference forward formula
-            if (
-                isinstance(foreach_arg.type, ListType)
-                and not foreach_arg.type.is_tensor_like()
-            ):
-                # Assuming ScalarList
-                modified_formula = modified_formula.replace(
-                    ref_arg.name, foreach_arg.name + "[i]"
-                )
-            elif foreach_arg.type.is_tensor_like():
-                # Assuming TensorList / Tensor
-                assert isinstance(foreach_arg.type, ListType)
-                for suffix in ("_p", "_t"):
-                    curr_expr = ref_arg.name + suffix
-                    if curr_expr in modified_formula:
-                        new_expr = foreach_arg.name + suffix
-                        modified_formula = modified_formula.replace(curr_expr, new_expr)
-            else:
-                # Assuming Scalar
-                if foreach_arg.name != ref_arg.name:
-                    modified_formula = modified_formula.replace(
-                        ref_arg.name, foreach_arg.name
-                    )
-
-            # note(crcrpar): there should exist a cooler way...
-            for i, name in enumerate(var_names):
-                if name == ref_arg.name:
-                    var_names[i] = foreach_arg.name
-                    var_types[i] = foreach_arg.type
-            for i, name in enumerate(required_inputs_fw_grad):
-                if name == ref_arg.name:
-                    required_inputs_fw_grad[i] = foreach_arg.name
-            for i, name in enumerate(required_inputs_primal):
-                if name == ref_arg.name:
-                    required_inputs_primal[i] = foreach_arg.name
-        forward_derivatives.append(
-            ForwardDerivative(
-                formula=modified_formula,
-                var_names=tuple(var_names),
-                var_types=tuple(var_types),
-                required_inputs_fw_grad=tuple(required_inputs_fw_grad),
-                required_inputs_primal=tuple(required_inputs_primal),
-                required_original_self_value=fw_derivative.required_original_self_value,
-                is_reusing_outplace_formula=fw_derivative.is_reusing_outplace_formula,
-            )
-        )
-
     return (
         DifferentiabilityInfo(
             name=foreach_function.func.name.name.base,
             func=foreach_function,
             op=f"Foreach{ref_diff_info.op}{foreach_function.func.name.overload_name}",
             derivatives=modified_derivative_formulas,
-            forward_derivatives=forward_derivatives,
+            forward_derivatives=[],
             all_saved_inputs=tuple(set(all_saved_inputs)),
             all_saved_outputs=tuple(set(all_saved_outputs)),
             available_named_gradients=(),
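The block removed above generated foreach forward derivatives by textually rewriting the reference per-tensor formulas: the reused `result` becomes `result[i]`, Scalar-list arguments gain an `[i]` index, and tensor arguments keep their `_p`/`_t` suffixes under the foreach argument name. A minimal standalone sketch of that rewriting (the formula and argument names are illustrative, not read from derivatives.yaml):

# Minimal sketch of the string rewriting the removed code performed.
ref_formula = "pow_backward(self_t.conj(), self_p, exponent).conj()"

modified = ref_formula
# Foreach outputs are TensorLists, so the reused primal output is indexed.
if "result" in modified and "result[i]" not in modified:
    modified = modified.replace("result", "result[i]")
# A Scalar[] (ScalarList) argument is indexed per element.
modified = modified.replace("exponent", "exponent[i]")

print(modified)  # pow_backward(self_t.conj(), self_p, exponent[i]).conj()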