Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

Revert "Revert "Reland "Add forward mode AD to out-place foreach functions (#102409) (#106043)""" (#106320)

Fixed a typo specifying the number of tensors and elements in the test that had been failing in slow gradcheck.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106320
Approved by: https://github.com/soulitzer

Committed by: PyTorch MergeBot
Parent: bc662ffff9
Commit: 5814380e7b
@@ -982,39 +982,62 @@ class TestForeach(TestCase):
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.float64, torch.complex128),
    )
    def test_inplace_forward_mode_AD(self, device, dtype, op):
    def test_outplace_forward_mode_AD(self, device, dtype, op):
        if not op.supports_forward_ad:
            self.skipTest("forward AD not supported")

        # note(crcrpar): The combinations below are failing in their forward path,
        # which is before forward-mode AD happens. This function gates the combinations where
        # - subtraction with Scalar/ScalarList of boolean value:
        # - combinations where the in-place op in question tries to write out complex result
        #   into float storage (= `self`)
        def check_sample_eligibility(op, sample, dtype):
            if (
                op.name == "_foreach_sub"
                and (
                    (isinstance(sample.args[0], list) and any(isinstance(a, bool) for a in sample.args[0]))
                    or isinstance(sample.args[0], bool)
                )
            ):
                return False, _BOOL_SUB_ERR_MSG
            rhs_arg_has_complex_number = sample.args and ((
                isinstance(sample.args[0], list)
                and any(isinstance(a, complex) for a in sample.args[0])
            ) or (
                isinstance(sample.args[0], complex)
            ))
            if dtype == torch.float64 and rhs_arg_has_complex_number:
                if op.name in ("_foreach_add", "_foreach_sub", "_foreach_mul", "_foreach_div"):
                    return False, "result type ComplexDouble can't be cast to the desired output type Double"
                if op.name in ("_foreach_clamp_max", "_foreach_clamp_min"):
                    return False, "clamp is not supported for complex types"
                if op.name == "_foreach_pow":
                    return False, "Found dtype Double but expected ComplexDouble"
        # note(crcrpar): without this, some unary functions fail, unlike inplace and/or complex.
        if dtype == torch.float64 and op.name in (
            "_foreach_acos", "_foreach_asin", "_foreach_log10", "_foreach_log1p", "_foreach_log2",
            "_foreach_log", "_foreach_pow", "_foreach_sqrt",
        ):
            value_range = {"low": 0.5, "high": 1.0}
        else:
            value_range = {}
        for sample in op.sample_inputs(
            device, dtype, requires_grad=True, num_input_tensors=[5], **value_range,
        ):
            # Skip `_foreach_pow.ScalarAndTensor(Scalar, Tensor[])`
            if op.name == "_foreach_pow" and isinstance(sample.input, Number):
                continue

            return True, ""
            def func(*tensorlist):
                kwargs = {"alpha": sample.kwargs["alpha"]} if "alpha" in sample.kwargs else {}
                return op.method_variant(tensorlist, *sample.args, **kwargs)

            working_sample, err_msg_pattern = check_forward_mode_AD_sample(op, sample, dtype, False)
            if not working_sample:
                if not err_msg_pattern:
                    # lhs of float64 and rhs of complex.
                    continue
                with self.assertRaisesRegex(RuntimeError, re.escape(err_msg_pattern)):
                    gradcheck(
                        func,
                        sample.input,
                        raise_exception=True,
                        check_forward_ad=True,
                        check_batched_forward_grad=False,
                        check_backward_ad=False,
                        check_batched_grad=False,
                    )
            else:
                gradcheck(
                    func,
                    sample.input,
                    raise_exception=True,
                    check_forward_ad=True,
                    check_backward_ad=False,
                    check_batched_grad=False,
                )

    @ops(
        foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_lerp_op_db,
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.float64, torch.complex128),
    )
    def test_inplace_forward_mode_AD(self, device, dtype, op):
        if not op.supports_forward_ad:
            self.skipTest("forward AD not supported")

        for sample in op.sample_inputs(
            device, dtype, requires_grad=True, num_input_tensors=[5], same_size=True,
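
As a side note for readers following the new test above: a minimal sketch of the user-facing behaviour it exercises, namely forward-mode AD through an out-of-place foreach op. The values and the choice of `_foreach_exp` are only illustrative, and the sketch assumes a PyTorch build that includes this change.

# Minimal sketch: forward-mode AD through an out-of-place foreach op.
# Assumes a build that includes this change; tensors and the op are illustrative.
import torch
import torch.autograd.forward_ad as fwAD

primals = [torch.rand(3) for _ in range(2)]
tangents = [torch.rand(3) for _ in range(2)]
with fwAD.dual_level():
    duals = [fwAD.make_dual(p, t) for p, t in zip(primals, tangents)]
    outs = torch._foreach_exp(duals)  # out-of-place foreach op
    jvps = [fwAD.unpack_dual(o).tangent for o in outs]

# d/dx exp(x) = exp(x), so each JVP should equal exp(primal) * tangent.
for p, t, jvp in zip(primals, tangents, jvps):
    assert torch.allclose(jvp, torch.exp(p) * t)
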
@@ -1026,7 +1049,7 @@ class TestForeach(TestCase):
                op.inplace_variant(tuple(t.clone() for t in tensorlist), *sample.args, **kwargs)
                return tensorlist

            working_sample, err_msg_pattern = check_sample_eligibility(op, sample, dtype)
            working_sample, err_msg_pattern = check_forward_mode_AD_sample(op, sample, dtype, True)
            if not working_sample:
                with self.assertRaisesRegex(RuntimeError, re.escape(err_msg_pattern)):
                    gradcheck(
@@ -1113,6 +1136,38 @@ class TestForeach(TestCase):
            self.assertEqual(ref_input, sample.input)


# TODO(crcrpar): Hide this inside torch/testing/_internal.
# would end up adding another layer to `foreach_inputs_sample_func.__call__`
# so that we can use this function as something like the first argument of `filter` function.
# Even after moving this function to testing, I personally think it'd be better to check the error message.
def check_forward_mode_AD_sample(op, sample, dtype, is_inplace):
    if (
        op.name == "_foreach_sub"
        and (
            (isinstance(sample.args[0], list) and any(isinstance(a, bool) for a in sample.args[0]))
            or isinstance(sample.args[0], bool)
        )
    ):
        return False, _BOOL_SUB_ERR_MSG
    rhs_arg_has_complex_number = sample.args and ((
        isinstance(sample.args[0], list)
        and any(isinstance(a, complex) for a in sample.args[0])
    ) or (
        isinstance(sample.args[0], complex)
    ))
    if rhs_arg_has_complex_number and dtype == torch.float64:
        if op.name in ("_foreach_clamp_max", "_foreach_clamp_min"):
            return False, "clamp is not supported for complex types"
        if not is_inplace:
            return False, ""
        else:
            if op.name == "_foreach_pow":
                return False, "Found dtype Double but expected ComplexDouble"
            if op.name in ("_foreach_add", "_foreach_sub", "_foreach_mul", "_foreach_div"):
                return False, "result type ComplexDouble can't be cast to the desired output type Double"
    return True, ""


instantiate_device_type_tests(TestForeach, globals())
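
The TODO above mentions using this gate as a filter predicate; a hedged sketch of what that usage could look like. The OpInfo database and the `sample_inputs` signature are taken from the test code in this diff, and the chosen entry is arbitrary.

# Hedged sketch of the filter-predicate usage hinted at in the TODO above.
import torch
from torch.testing._internal.common_methods_invocations import foreach_binary_op_db

op = foreach_binary_op_db[0]          # any ForeachFuncInfo entry works for illustration
device, dtype = "cpu", torch.float64
eligible_samples = [
    sample
    for sample in op.sample_inputs(device, dtype, requires_grad=True, num_input_tensors=[5])
    if check_forward_mode_AD_sample(op, sample, dtype, is_inplace=False)[0]
]
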
@@ -3040,13 +3040,16 @@
- name: _foreach_div.List(Tensor[] self, Tensor[] other) -> Tensor[]
  self: div_tensor_self_backward(grads[i], other[i], self[i].scalar_type())
  other: div_tensor_other_backward(grads[i], self[i], other[i])
  result: (self_t - other_t * result[i]) / other_p

- name: _foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[]
  self: pow_backward_self(grads[i], self[i], exponent[i])
  exponent: pow_backward_exponent(grads[i], self[i], exponent[i], result[i])
  result: (pow_backward_self(self_t.conj(), self_p, exponent_p) + pow_backward_exponent(exponent_t.conj(), self_p, exponent_p, result[i])).conj()

- name: _foreach_pow.ScalarList(Tensor[] self, Scalar[] exponent) -> Tensor[]
  self: pow_backward(grads[i], self[i], exponent[i])
  result: pow_backward(self_t.conj(), self_p, exponent[i]).conj()

- name: _foreach_pow.ScalarAndTensor(Scalar self, Tensor[] exponent) -> Tensor[]
  exponent: pow_backward_exponent(grads[i], self, exponent[i], result[i])
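
The new `result:` line for `_foreach_div.List` is the element-wise quotient rule, d(a/b) = (da - db * (a/b)) / b. A quick numeric sanity check of that formula with plain tensors and illustrative values:

# Numeric sanity check of the quotient-rule JVP used above:
# (self_t - other_t * result) / other_p, evaluated per element.
import torch

a = torch.rand(3, dtype=torch.float64) + 0.5   # primal self_p
b = torch.rand(3, dtype=torch.float64) + 0.5   # primal other_p
da = torch.rand(3, dtype=torch.float64)        # tangent self_t
db = torch.rand(3, dtype=torch.float64)        # tangent other_t
result = a / b
jvp_formula = (da - db * result) / b
eps = 1e-6
jvp_numeric = ((a + eps * da) / (b + eps * db) - result) / eps  # forward difference
assert torch.allclose(jvp_formula, jvp_numeric, atol=1e-4)
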
@@ -3055,12 +3058,22 @@
# of `maximum` and `minimum` don't have the overload def with Scalar as their second argument.
- name: _foreach_minimum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
  self: at::where(self[i] == scalar, grads[i] / 2, grads[i]).masked_fill_(self[i] > scalar, 0)
  result: scalar + at::where(self_p == scalar, at::scalar_tensor(0.5, result[i].options()), (self_p < scalar).to(result[i].scalar_type())) * (self_t - scalar)

- name: _foreach_minimum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
  self: at::where(self[i] == scalars[i], grads[i] / 2, grads[i]).masked_fill_(self[i] > scalars[i], 0)
  result: scalars[i] + at::where(self_p == scalars[i], at::scalar_tensor(0.5, result[i].options()), (self_p < scalars[i]).to(result[i].scalar_type())) * (self_t - scalars[i])

- name: _foreach_maximum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
  self: at::where(self[i] == scalar, grads[i] / 2, grads[i]).masked_fill_(self[i] < scalar, 0)
  result: scalar + at::where(self_p == scalar, at::scalar_tensor(0.5, result[i].options()), (self_p > scalar).to(result[i].scalar_type())) * (self_t - scalar)

- name: _foreach_maximum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
  self: at::where(self[i] == scalars[i], grads[i] / 2, grads[i]).masked_fill_(self[i] < scalars[i], 0)
  result: scalars[i] + at::where(self_p == scalars[i], at::scalar_tensor(0.5, result[i].options()), (self_p > scalars[i]).to(result[i].scalar_type())) * (self_t - scalars[i])

# note(crcrpar): forward-mode AD is tricky for a simple string replace to handle:
# formula.replace("p", "ord") produces `norm_jvord(self_ord, self_t, ord, result)`
- name: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2) -> Tensor[]
  self: norm_backward(grads[i], self[i], ord, result[i])
  result: norm_jvp(self_p, self_t, ord, result[i])

@@ -1772,7 +1772,7 @@ def emit_body(

    def emit_any_has_forward_grad() -> List[str]:
        content: List[str] = []
        if not is_inplace_foreach:
        if not is_foreach:
            for derivative in fw_derivatives:
                requires_fw_grad = get_any_has_fw_grad_cond(derivative=derivative)
                if info and info.output_differentiability_conditions:
@@ -1786,11 +1786,17 @@ def emit_body(
                bool_vector_name = get_any_has_forward_grad_name(derivative.var_names)
                cur_derivative_conditions = [
                    FW_DERIVATIVE_CHECK_TEMPLATE.substitute(
                        req_inp=refargname2inplace_foreacharg[inp.name].name
                        req_inp=(
                            inp.name
                            if not inplace
                            else refargname2inplace_foreacharg[inp.name].name
                        )
                        + (
                            "[i]"
                            if is_tensor_list_type(
                                refargname2inplace_foreacharg[inp.name].type
                                inp.type
                                if not inplace
                                else refargname2inplace_foreacharg[inp.name].type
                            )
                            else ""
                        ),
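
In effect, the hunk above makes the generated forward-grad check index into TensorList arguments for out-of-place foreach ops as well, not only for in-place ones. A rough sketch of the suffixing logic in plain Python; the template text and names are simplified assumptions, not the real codegen:

# Simplified sketch of the "[i]" suffixing above; the real code substitutes into
# codegen templates, this only shows which reference expression gets produced.
def fw_grad_check_expr(inp_name: str, is_tensor_list: bool, inplace: bool,
                       inplace_name_map: dict) -> str:
    name = inp_name if not inplace else inplace_name_map[inp_name]
    suffix = "[i]" if is_tensor_list else ""
    # hypothetical check template: isFwGradDefined(<name><suffix>)
    return f"isFwGradDefined({name}{suffix})"

# Out-of-place foreach op over a TensorList argument `self`:
print(fw_grad_check_expr("self", is_tensor_list=True, inplace=False, inplace_name_map={}))
# -> isFwGradDefined(self[i])
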
@@ -1832,8 +1838,10 @@ def emit_body(
            unpacked_arguments = ""
            for inp in differentiable_inputs:
                inp_name = inp.name
                is_input_tensorlist = is_inplace_foreach and is_tensor_list_type(
                    refargname2inplace_foreacharg[inp.name].type
                is_input_tensorlist = is_foreach and is_tensor_list_type(
                    inp.type
                    if not inplace
                    else refargname2inplace_foreacharg[inp.name].type
                )
                input_suffix = "[i]" if is_input_tensorlist else ""
                if is_inplace_foreach:
@@ -1890,14 +1898,14 @@ def emit_body(
            # Is there a way to get from BaseType to BaseCType
            if len(derivative.var_types) == 1:
                opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type()
                if not is_inplace_foreach:
                if not is_foreach:
                    fw_grad_setters.append(
                        FW_DERIVATIVE_SETTER_TENSOR.substitute(
                            out_arg=res[0], is_inplace=is_inplace_str
                        )
                    )
                else:
                    assert res[0] == "self"
                    assert res[0] == ("result" if not inplace else "self")
                    fw_grad_setters.append(
                        FW_DERIVATIVE_SETTER_TENSOR_FOREACH.substitute(
                            out_arg=res[0], is_inplace=is_inplace_str
@@ -1922,18 +1930,29 @@ def emit_body(
                assert (
                    len(derivative.var_types) == 1
                ), "Expected number of outputs to be 1 if function returns ListType"
                opt_res_grad_type = OptionalCType(
                    VectorCType(BaseCType(tensorT))
                ).cpp_type()
                fw_grad_setters.append(
                    FW_DERIVATIVE_SETTER_TENSOR_LIST.substitute(
                        out_arg=res[0], is_inplace=is_inplace_str
                if not is_foreach:
                    opt_res_grad_type = OptionalCType(
                        VectorCType(BaseCType(tensorT))
                    ).cpp_type()
                    fw_grad_setters.append(
                        FW_DERIVATIVE_SETTER_TENSOR_LIST.substitute(
                            out_arg=res[0], is_inplace=is_inplace_str
                        )
                    )
                else:
                    # TODO(crcrpar): Should this (= the foreach specific logic) be refactored somehow?
                    # Only out-place foreach functions that have entries in `tools/autograd/derivatives.yaml`
                    # can reach here.
                    opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type()
                    fw_grad_setters.append(
                        FW_DERIVATIVE_SETTER_TENSOR_FOREACH.substitute(
                            out_arg=res[0], is_inplace=is_inplace_str
                        )
                    )
                )
            else:
                raise RuntimeError("Unsupported output type for forward derivative")

            if not is_inplace_foreach:
            if not is_foreach:
                fw_grad_opt_definition = f"{opt_res_grad_type} {'_'.join(res)}_new_fw_grad_opt = c10::nullopt;"
                # View ops create fw_grad that already is a view of the base's fw_grad so just use that
                content.append(
@@ -1951,20 +1970,30 @@ def emit_body(
                    f"std::vector<{opt_res_grad_type}> {'_'.join(res)}_new_fw_grad_opts"
                    "(self.size(), c10::nullopt);"
                )
                inplace_foreach_forward_grad_formula = derivative.formula
                for _foreach_arg, _ref_arg in inplace_foreacharg2refarg.items():
                    # note(crcrpar): Massage only Scalar and ArrayRef<Scalar> here.
                    if not (
                        is_tensor_type(_foreach_arg.type)
                        or is_tensor_list_type(_foreach_arg.type)
                    ):
                        pattern = _foreach_arg.name
                        if isinstance(_foreach_arg.type, ListType):
                            pattern += "[i]"
                        inplace_foreach_forward_grad_formula = (
                            inplace_foreach_forward_grad_formula.replace(
                                _ref_arg.name, pattern
                foreach_forward_grad_formula = derivative.formula
                _foreach_arg: Union[Argument, DifferentiableInput]
                if inplace:
                    for _foreach_arg, _ref_arg in inplace_foreacharg2refarg.items():
                        # note(crcrpar): Massage only Scalar and ArrayRef<Scalar> here.
                        if not (
                            is_tensor_type(_foreach_arg.type)
                            or is_tensor_list_type(_foreach_arg.type)
                        ):
                            pattern = _foreach_arg.name
                            if isinstance(_foreach_arg.type, ListType):
                                pattern += "[i]"
                            foreach_forward_grad_formula = (
                                foreach_forward_grad_formula.replace(
                                    _ref_arg.name, pattern
                                )
                            )
                else:
                    if (
                        "result" in foreach_forward_grad_formula
                        and "result[i]" not in foreach_forward_grad_formula
                    ):
                        foreach_forward_grad_formula = (
                            foreach_forward_grad_formula.replace("result", "result[i]")
                        )

                content.append(
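
The out-of-place branch above only rewrites `result` to `result[i]` when the formula has not been indexed already. A tiny sketch of that guard, taken directly from the logic shown:

# Sketch of the guarded rewrite in the out-of-place branch above.
def index_result(formula: str) -> str:
    if "result" in formula and "result[i]" not in formula:
        formula = formula.replace("result", "result[i]")
    return formula

print(index_result("(self_t - other_t * result) / other_p"))
# -> (self_t - other_t * result[i]) / other_p
print(index_result("(self_t - other_t * result[i]) / other_p"))  # already indexed, unchanged
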
@@ -1975,7 +2004,7 @@ def emit_body(
                        get_any_has_forward_grad_name(derivative.var_names) + "[i]"
                        for derivative in fw_derivatives
                    ),
                    formula=inplace_foreach_forward_grad_formula,
                    formula=foreach_forward_grad_formula,
                    unpacked_arguments=unpacked_arguments,
                )
            )
@@ -2037,7 +2066,11 @@ def emit_body(
        else:
            any_has_fw_grad = " || ".join(
                [
                    FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
                    (
                        FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE
                        if is_tensor_list_type(inp.type)
                        else FW_DERIVATIVE_CHECK_TEMPLATE
                    ).substitute(req_inp=inp.name)
                    for inp in differentiable_inputs
                    if inp.name in derivative.required_inputs_fw_grad
                ]

@@ -272,9 +272,13 @@ def postprocess_forward_derivatives(
    args_with_derivatives: Sequence[Binding],
) -> List[ForwardDerivative]:
    def find_required_inputs(formula: str, postfix: str) -> Tuple[str, ...]:
        is_foreach = f.func.name.name.base.startswith("_foreach_")
        required_inputs = set()
        for arg in args_with_derivatives:
            if arg.type in ("at::TensorList", "const at::ITensorListRef &"):
            if (
                arg.type in ("at::TensorList", "const at::ITensorListRef &")
                and not is_foreach
            ):
                # The functions taking TensorList handle everything internally
                continue
            arg_name = arg.name

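
With `is_foreach`, TensorList arguments are no longer skipped when scanning the formula for required inputs, so tangents like `self_t` and `other_t` in a foreach formula are actually collected. A rough sketch of such a scan, assuming the usual `<arg>_p` / `<arg>_t` naming convention; this is a simplification for illustration, not the real implementation:

# Simplified sketch: collect arguments whose tangents appear in a forward formula.
# It shows why foreach formulas like "(self_t - other_t * result[i]) / other_p"
# must not skip TensorList args.
import re

def required_inputs(formula: str, arg_names: list, postfix: str = "_t") -> set:
    found = set()
    for name in arg_names:
        if re.search(rf"\b{re.escape(name + postfix)}\b", formula):
            found.add(name)
    return found

print(required_inputs("(self_t - other_t * result[i]) / other_p", ["self", "other"]))
# -> {'self', 'other'} (set order may vary)
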
@@ -8833,6 +8833,15 @@ foreach_unary_op_db: List[OpInfo] = [
        supports_autograd=True,
        supports_forward_ad=True,
        has_no_out_of_place=True,
        skips=(
            # note(crcrpar): excluding cdouble from dtypes above might be better.
            # Guard for `error: "In-place abs is not supported for complex tensors."`
            DecorateInfo(
                unittest.skip("_foreach_zero is not implemented"),
                'TestForeach',
                'test_outplace_forward_mode_AD',
            ),
        ),
    ),

    ForeachFuncInfo(
@@ -7,6 +7,8 @@ from torchgen import local
from torchgen.api import cpp
from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT
from torchgen.model import (
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    NativeFunction,
@@ -314,6 +316,7 @@ def is_foreach_func(f: NativeFunction) -> bool:
# is functional for their backward derivatives (and forward derivatives in the future), i.e.,
# they would find such one in `functional_info_by_signature`. There however are some exceptions:
_foreach_with_inplace_ref = {"_foreach_zero_"}
_foreach_with_tensor_overload = {"_foreach_mul.Tensor"}


# Checks if `function_schema` is a native, non-foreach function which `f`, a foreach function
@@ -462,13 +465,86 @@ def gen_foreach_derivativeinfo(
        for arg in foreach_function.func.arguments.flat_non_out
        if arg.name in all_var_names
    ]

    forward_derivatives: List[ForwardDerivative] = []
    fw_derivative: ForwardDerivative
    for fw_derivative in ref_diff_info.forward_derivatives:
        var_names: List[str] = list(fw_derivative.var_names)  # type: ignore[no-redef]
        var_types: List[Type] = list(fw_derivative.var_types)
        required_inputs_fw_grad: List[str] = []
        required_inputs_primal: List[str] = []
        if fw_derivative.required_inputs_fw_grad is not None:
            required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
        if fw_derivative.required_inputs_primal:
            required_inputs_primal = list(fw_derivative.required_inputs_primal)
        modified_formula = fw_derivative.formula

        # Foreach's result is TensorList
        if "result" in modified_formula:
            modified_formula = fw_derivative.formula.replace("result", "result[i]")

        for foreach_arg, ref_arg in zip(
            foreach_function.func.arguments.flat_non_out,
            ref_diff_info.func.func.arguments.flat_non_out,
        ):
            # Modify reference forward formula
            if (
                isinstance(foreach_arg.type, ListType)
                and not foreach_arg.type.is_tensor_like()
            ):
                # Assuming ScalarList
                modified_formula = modified_formula.replace(
                    ref_arg.name, foreach_arg.name + "[i]"
                )
            elif foreach_arg.type.is_tensor_like():
                # Assuming TensorList / Tensor
                # assert isinstance(foreach_arg.type, ListType), f"{foreach_function.func.name}, {foreach_arg.type}"
                assert isinstance(foreach_arg.type, ListType) or (
                    foreach_arg.type == BaseType(BaseTy.Tensor)
                    and str(foreach_function.func.name) in _foreach_with_tensor_overload
                ), f"{foreach_function.func.name}, {foreach_arg.type}"
                for suffix in ("_p", "_t"):
                    curr_expr = ref_arg.name + suffix
                    if curr_expr in modified_formula:
                        new_expr = foreach_arg.name + suffix
                        modified_formula = modified_formula.replace(curr_expr, new_expr)
            else:
                # Assuming Scalar
                if foreach_arg.name != ref_arg.name:
                    modified_formula = modified_formula.replace(
                        ref_arg.name, foreach_arg.name
                    )

            # note(crcrpar): there should exist a cooler way...
            for i, name in enumerate(var_names):
                if name == ref_arg.name:
                    var_names[i] = foreach_arg.name
                    var_types[i] = foreach_arg.type
            for i, name in enumerate(required_inputs_fw_grad):
                if name == ref_arg.name:
                    required_inputs_fw_grad[i] = foreach_arg.name
            for i, name in enumerate(required_inputs_primal):
                if name == ref_arg.name:
                    required_inputs_primal[i] = foreach_arg.name
        forward_derivatives.append(
            ForwardDerivative(
                formula=modified_formula,
                var_names=tuple(var_names),
                var_types=tuple(var_types),
                required_inputs_fw_grad=tuple(required_inputs_fw_grad),
                required_inputs_primal=tuple(required_inputs_primal),
                required_original_self_value=fw_derivative.required_original_self_value,
                is_reusing_outplace_formula=fw_derivative.is_reusing_outplace_formula,
            )
        )

    return (
        DifferentiabilityInfo(
            name=foreach_function.func.name.name.base,
            func=foreach_function,
            op=f"Foreach{ref_diff_info.op}{foreach_function.func.name.overload_name}",
            derivatives=modified_derivative_formulas,
            forward_derivatives=[],
            forward_derivatives=forward_derivatives,
            all_saved_inputs=tuple(set(all_saved_inputs)),
            all_saved_outputs=tuple(set(all_saved_outputs)),
            available_named_gradients=(),
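
Tying this back to the derivatives.yaml hunk earlier: the loop above turns a reference forward formula such as div.Tensor's `(self_t - other_t * result) / other_p` into the `_foreach_div.List` entry by indexing `result` and renaming per-argument expressions. A condensed sketch of just the string-rewriting part, with the argument metadata faked for illustration:

# Condensed sketch of the formula rewriting performed above; only the string
# manipulation is reproduced, with argument metadata faked for illustration.
ref_formula = "(self_t - other_t * result) / other_p"   # reference (div.Tensor) forward formula

# Step 1: the foreach result is a TensorList, so index it.
modified = ref_formula.replace("result", "result[i]")

# Step 2: per-argument renaming. For div.Tensor -> _foreach_div.List the names and
# tensor-ness already match, so nothing changes; a ScalarList argument would instead
# be rewritten as "<name>[i]" (see the ScalarList branch above).
for ref_name, foreach_name, is_scalar_list in [("self", "self", False), ("other", "other", False)]:
    if is_scalar_list:
        modified = modified.replace(ref_name, foreach_name + "[i]")

print(modified)  # (self_t - other_t * result[i]) / other_p, as in derivatives.yaml above
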