Revert "Revert "Reland "Add forward mode AD to out-place foreach functions (#102409) (#106043)""" (#106320)

Fixed a typo in the number of tensors and elements specified by the test that had been failing in slow gradcheck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106320
Approved by: https://github.com/soulitzer
Masaki Kozuki
2023-08-18 23:01:42 +00:00
committed by PyTorch MergeBot
parent bc662ffff9
commit 5814380e7b
6 changed files with 252 additions and 62 deletions
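What the relanded change enables, in user terms: out-of-place `_foreach_*` ops can now carry forward-mode (JVP) gradients. A minimal sketch, assuming a build that contains this commit and that `_foreach_exp` is among the ops that gained a forward derivative:

import torch
import torch.autograd.forward_ad as fwAD

primals = [torch.rand(3, dtype=torch.float64) for _ in range(2)]
tangents = [torch.rand(3, dtype=torch.float64) for _ in range(2)]

with fwAD.dual_level():
    duals = [fwAD.make_dual(p, t) for p, t in zip(primals, tangents)]
    outs = torch._foreach_exp(duals)  # out-of-place foreach op on dual tensors
    jvps = [fwAD.unpack_dual(o).tangent for o in outs]

# JVP of exp(x) is exp(x) * tangent
for p, t, jvp in zip(primals, tangents, jvps):
    torch.testing.assert_close(jvp, p.exp() * t)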

View File

@@ -982,39 +982,62 @@ class TestForeach(TestCase):
dtypes=OpDTypes.supported,
allowed_dtypes=(torch.float64, torch.complex128),
)
def test_inplace_forward_mode_AD(self, device, dtype, op):
def test_outplace_forward_mode_AD(self, device, dtype, op):
if not op.supports_forward_ad:
self.skipTest("forward AD not supported")
# note(crcrpar): The combinations below fail in their forward path,
# i.e. before forward-mode AD happens. This function gates the combinations where
# - the op is a subtraction with a Scalar/ScalarList of boolean values, or
# - the in-place op in question tries to write a complex result
# into float storage (= `self`)
def check_sample_eligibility(op, sample, dtype):
if (
op.name == "_foreach_sub"
and (
(isinstance(sample.args[0], list) and any(isinstance(a, bool) for a in sample.args[0]))
or isinstance(sample.args[0], bool)
)
):
return False, _BOOL_SUB_ERR_MSG
rhs_arg_has_complex_number = sample.args and ((
isinstance(sample.args[0], list)
and any(isinstance(a, complex) for a in sample.args[0])
) or (
isinstance(sample.args[0], complex)
))
if dtype == torch.float64 and rhs_arg_has_complex_number:
if op.name in ("_foreach_add", "_foreach_sub", "_foreach_mul", "_foreach_div"):
return False, "result type ComplexDouble can't be cast to the desired output type Double"
if op.name in ("_foreach_clamp_max", "_foreach_clamp_min"):
return False, "clamp is not supported for complex types"
if op.name == "_foreach_pow":
return False, "Found dtype Double but expected ComplexDouble"
# note(crcrpar): without this, some unary functions fail, unlike the in-place and/or complex cases.
if dtype == torch.float64 and op.name in (
"_foreach_acos", "_foreach_asin", "_foreach_log10", "_foreach_log1p", "_foreach_log2",
"_foreach_log", "_foreach_pow", "_foreach_sqrt",
):
value_range = {"low": 0.5, "high": 1.0}
else:
value_range = {}
for sample in op.sample_inputs(
device, dtype, requires_grad=True, num_input_tensors=[5], **value_range,
):
# Skip `_foreach_pow.ScalarAndTensor(Scalar, Tensor[])`
if op.name == "_foreach_pow" and isinstance(sample.input, Number):
continue
return True, ""
def func(*tensorlist):
kwargs = {"alpha": sample.kwargs["alpha"]} if "alpha" in sample.kwargs else {}
return op.method_variant(tensorlist, *sample.args, **kwargs)
working_sample, err_msg_pattern = check_forward_mode_AD_sample(op, sample, dtype, False)
if not working_sample:
if not err_msg_pattern:
# lhs of float64 and rhs of complex.
continue
with self.assertRaisesRegex(RuntimeError, re.escape(err_msg_pattern)):
gradcheck(
func,
sample.input,
raise_exception=True,
check_forward_ad=True,
check_batched_forward_grad=False,
check_backward_ad=False,
check_batched_grad=False,
)
else:
gradcheck(
func,
sample.input,
raise_exception=True,
check_forward_ad=True,
check_backward_ad=False,
check_batched_grad=False,
)
@ops(
foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_lerp_op_db,
dtypes=OpDTypes.supported,
allowed_dtypes=(torch.float64, torch.complex128),
)
def test_inplace_forward_mode_AD(self, device, dtype, op):
if not op.supports_forward_ad:
self.skipTest("forward AD not supported")
for sample in op.sample_inputs(
device, dtype, requires_grad=True, num_input_tensors=[5], same_size=True,
@@ -1026,7 +1049,7 @@ class TestForeach(TestCase):
op.inplace_variant(tuple(t.clone() for t in tensorlist), *sample.args, **kwargs)
return tensorlist
working_sample, err_msg_pattern = check_sample_eligibility(op, sample, dtype)
working_sample, err_msg_pattern = check_forward_mode_AD_sample(op, sample, dtype, True)
if not working_sample:
with self.assertRaisesRegex(RuntimeError, re.escape(err_msg_pattern)):
gradcheck(
@@ -1113,6 +1136,38 @@ class TestForeach(TestCase):
self.assertEqual(ref_input, sample.input)
# TODO(crcrpar): Hide this inside torch/testing/_internal.
# Doing so would end up adding another layer to `foreach_inputs_sample_func.__call__`
# so that we can use this function as something like the first argument of `filter`.
# Even after moving this function into testing, I personally think it'd be better to check the error message.
def check_forward_mode_AD_sample(op, sample, dtype, is_inplace):
if (
op.name == "_foreach_sub"
and (
(isinstance(sample.args[0], list) and any(isinstance(a, bool) for a in sample.args[0]))
or isinstance(sample.args[0], bool)
)
):
return False, _BOOL_SUB_ERR_MSG
rhs_arg_has_complex_number = sample.args and ((
isinstance(sample.args[0], list)
and any(isinstance(a, complex) for a in sample.args[0])
) or (
isinstance(sample.args[0], complex)
))
if rhs_arg_has_complex_number and dtype == torch.float64:
if op.name in ("_foreach_clamp_max", "_foreach_clamp_min"):
return False, "clamp is not supported for complex types"
if not is_inplace:
return False, ""
else:
if op.name == "_foreach_pow":
return False, "Found dtype Double but expected ComplexDouble"
if op.name in ("_foreach_add", "_foreach_sub", "_foreach_mul", "_foreach_div"):
return False, "result type ComplexDouble can't be cast to the desired output type Double"
return True, ""
instantiate_device_type_tests(TestForeach, globals())
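For reference outside the OpInfo machinery, the heart of the new `test_outplace_forward_mode_AD` is a plain `gradcheck` call with forward-mode checking turned on. A minimal self-contained sketch (the op and shapes here are illustrative, not taken from the OpInfo samples):

import torch
from torch.autograd import gradcheck

tensors = tuple(
    torch.rand(3, dtype=torch.float64, requires_grad=True) for _ in range(5)
)

def func(*tensorlist):
    # out-of-place foreach op under test
    return torch._foreach_exp(tensorlist)

gradcheck(
    func,
    tensors,
    raise_exception=True,
    check_forward_ad=True,
    check_batched_forward_grad=False,
    check_backward_ad=False,
    check_batched_grad=False,
)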

View File

@@ -3040,13 +3040,16 @@
- name: _foreach_div.List(Tensor[] self, Tensor[] other) -> Tensor[]
self: div_tensor_self_backward(grads[i], other[i], self[i].scalar_type())
other: div_tensor_other_backward(grads[i], self[i], other[i])
result: (self_t - other_t * result[i]) / other_p
- name: _foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[]
self: pow_backward_self(grads[i], self[i], exponent[i])
exponent: pow_backward_exponent(grads[i], self[i], exponent[i], result[i])
result: (pow_backward_self(self_t.conj(), self_p, exponent_p) + pow_backward_exponent(exponent_t.conj(), self_p, exponent_p, result[i])).conj()
- name: _foreach_pow.ScalarList(Tensor[] self, Scalar[] exponent) -> Tensor[]
self: pow_backward(grads[i], self[i], exponent[i])
result: pow_backward(self_t.conj(), self_p, exponent[i]).conj()
- name: _foreach_pow.ScalarAndTensor(Scalar self, Tensor[] exponent) -> Tensor[]
exponent: pow_backward_exponent(grads[i], self, exponent[i], result[i])
@@ -3055,12 +3058,22 @@
# of `maximum` and `minimum` don't have the overload def with Scalar as their second argument.
- name: _foreach_minimum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
self: at::where(self[i] == scalar, grads[i] / 2, grads[i]).masked_fill_(self[i] > scalar, 0)
result: scalar + at::where(self_p == scalar, at::scalar_tensor(0.5, result[i].options()), (self_p < scalar).to(result[i].scalar_type())) * (self_t - scalar)
- name: _foreach_minimum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
self: at::where(self[i] == scalars[i], grads[i] / 2, grads[i]).masked_fill_(self[i] > scalars[i], 0)
result: scalars[i] + at::where(self_p == scalars[i], at::scalar_tensor(0.5, result[i].options()), (self_p < scalars[i]).to(result[i].scalar_type())) * (self_t - scalars[i])
- name: _foreach_maximum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
self: at::where(self[i] == scalar, grads[i] / 2, grads[i]).masked_fill_(self[i] < scalar, 0)
result: scalar + at::where(self_p == scalar, at::scalar_tensor(0.5, result[i].options()), (self_p > scalar).to(result[i].scalar_type())) * (self_t - scalar)
- name: _foreach_maximum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
self: at::where(self[i] == scalars[i], grads[i] / 2, grads[i]).masked_fill_(self[i] < scalars[i], 0)
result: scalars[i] + at::where(self_p == scalars[i], at::scalar_tensor(0.5, result[i].options()), (self_p > scalars[i]).to(result[i].scalar_type())) * (self_t - scalars[i])
# note(crcrpar): forward-mode AD is tricky for a simple string replace to handle:
# formula.replace("p", "ord") produces `norm_jvord(self_ord, self_t, ord, result)`
- name: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2) -> Tensor[]
self: norm_backward(grads[i], self[i], ord, result[i])
result: norm_jvp(self_p, self_t, ord, result[i])
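The new `result:` line for `_foreach_div.List` is the quotient rule written against the saved output: d(a/b) = (da - db * (a/b)) / b. A quick numeric sanity check of that tangent, assuming `_foreach_div` now propagates forward grads as this commit intends:

import torch
import torch.autograd.forward_ad as fwAD

a_p, a_t = torch.rand(4, dtype=torch.float64), torch.rand(4, dtype=torch.float64)
b_p, b_t = torch.rand(4, dtype=torch.float64) + 0.5, torch.rand(4, dtype=torch.float64)

with fwAD.dual_level():
    out = torch._foreach_div([fwAD.make_dual(a_p, a_t)], [fwAD.make_dual(b_p, b_t)])
    jvp = fwAD.unpack_dual(out[0]).tangent

# (self_t - other_t * result[i]) / other_p, per the formula above
expected = (a_t - b_t * (a_p / b_p)) / b_p
torch.testing.assert_close(jvp, expected)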

View File

@@ -1772,7 +1772,7 @@ def emit_body(
def emit_any_has_forward_grad() -> List[str]:
content: List[str] = []
if not is_inplace_foreach:
if not is_foreach:
for derivative in fw_derivatives:
requires_fw_grad = get_any_has_fw_grad_cond(derivative=derivative)
if info and info.output_differentiability_conditions:
@@ -1786,11 +1786,17 @@ def emit_body(
bool_vector_name = get_any_has_forward_grad_name(derivative.var_names)
cur_derivative_conditions = [
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(
req_inp=refargname2inplace_foreacharg[inp.name].name
req_inp=(
inp.name
if not inplace
else refargname2inplace_foreacharg[inp.name].name
)
+ (
"[i]"
if is_tensor_list_type(
refargname2inplace_foreacharg[inp.name].type
inp.type
if not inplace
else refargname2inplace_foreacharg[inp.name].type
)
else ""
),
@@ -1832,8 +1838,10 @@ def emit_body(
unpacked_arguments = ""
for inp in differentiable_inputs:
inp_name = inp.name
is_input_tensorlist = is_inplace_foreach and is_tensor_list_type(
refargname2inplace_foreacharg[inp.name].type
is_input_tensorlist = is_foreach and is_tensor_list_type(
inp.type
if not inplace
else refargname2inplace_foreacharg[inp.name].type
)
input_suffix = "[i]" if is_input_tensorlist else ""
if is_inplace_foreach:
@@ -1890,14 +1898,14 @@ def emit_body(
# Is there a way to get from BaseType to BaseCType
if len(derivative.var_types) == 1:
opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type()
if not is_inplace_foreach:
if not is_foreach:
fw_grad_setters.append(
FW_DERIVATIVE_SETTER_TENSOR.substitute(
out_arg=res[0], is_inplace=is_inplace_str
)
)
else:
assert res[0] == "self"
assert res[0] == ("result" if not inplace else "self")
fw_grad_setters.append(
FW_DERIVATIVE_SETTER_TENSOR_FOREACH.substitute(
out_arg=res[0], is_inplace=is_inplace_str
@@ -1922,18 +1930,29 @@ def emit_body(
assert (
len(derivative.var_types) == 1
), "Expected number of outputs to be 1 if function returns ListType"
opt_res_grad_type = OptionalCType(
VectorCType(BaseCType(tensorT))
).cpp_type()
fw_grad_setters.append(
FW_DERIVATIVE_SETTER_TENSOR_LIST.substitute(
out_arg=res[0], is_inplace=is_inplace_str
if not is_foreach:
opt_res_grad_type = OptionalCType(
VectorCType(BaseCType(tensorT))
).cpp_type()
fw_grad_setters.append(
FW_DERIVATIVE_SETTER_TENSOR_LIST.substitute(
out_arg=res[0], is_inplace=is_inplace_str
)
)
else:
# TODO(crcrpar): Should this (= the foreach specific logic) be refactored somehow?
# Only out-place foreach functions that have entries in `tools/autograd/derivatives.yaml`
# can reach here.
opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type()
fw_grad_setters.append(
FW_DERIVATIVE_SETTER_TENSOR_FOREACH.substitute(
out_arg=res[0], is_inplace=is_inplace_str
)
)
)
else:
raise RuntimeError("Unsupported output type for forward derivative")
if not is_inplace_foreach:
if not is_foreach:
fw_grad_opt_definition = f"{opt_res_grad_type} {'_'.join(res)}_new_fw_grad_opt = c10::nullopt;"
# View ops create fw_grad that already is a view of the base's fw_grad so just use that
content.append(
@@ -1951,20 +1970,30 @@ def emit_body(
f"std::vector<{opt_res_grad_type}> {'_'.join(res)}_new_fw_grad_opts"
"(self.size(), c10::nullopt);"
)
inplace_foreach_forward_grad_formula = derivative.formula
for _foreach_arg, _ref_arg in inplace_foreacharg2refarg.items():
# note(crcrpar): Massage only Scalar and ArrayRef<Scalar> here.
if not (
is_tensor_type(_foreach_arg.type)
or is_tensor_list_type(_foreach_arg.type)
):
pattern = _foreach_arg.name
if isinstance(_foreach_arg.type, ListType):
pattern += "[i]"
inplace_foreach_forward_grad_formula = (
inplace_foreach_forward_grad_formula.replace(
_ref_arg.name, pattern
foreach_forward_grad_formula = derivative.formula
_foreach_arg: Union[Argument, DifferentiableInput]
if inplace:
for _foreach_arg, _ref_arg in inplace_foreacharg2refarg.items():
# note(crcrpar): Massage only Scalar and ArrayRef<Scalar> here.
if not (
is_tensor_type(_foreach_arg.type)
or is_tensor_list_type(_foreach_arg.type)
):
pattern = _foreach_arg.name
if isinstance(_foreach_arg.type, ListType):
pattern += "[i]"
foreach_forward_grad_formula = (
foreach_forward_grad_formula.replace(
_ref_arg.name, pattern
)
)
else:
if (
"result" in foreach_forward_grad_formula
and "result[i]" not in foreach_forward_grad_formula
):
foreach_forward_grad_formula = (
foreach_forward_grad_formula.replace("result", "result[i]")
)
content.append(
@@ -1975,7 +2004,7 @@ def emit_body(
get_any_has_forward_grad_name(derivative.var_names) + "[i]"
for derivative in fw_derivatives
),
formula=inplace_foreach_forward_grad_formula,
formula=foreach_forward_grad_formula,
unpacked_arguments=unpacked_arguments,
)
)
@@ -2037,7 +2066,11 @@ def emit_body(
else:
any_has_fw_grad = " || ".join(
[
FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
(
FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE
if is_tensor_list_type(inp.type)
else FW_DERIVATIVE_CHECK_TEMPLATE
).substitute(req_inp=inp.name)
for inp in differentiable_inputs
if inp.name in derivative.required_inputs_fw_grad
]
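One piece of the new out-of-place branch above is pure string surgery on the stored forward formula: a bare `result` has to be indexed per tensor unless the YAML entry already wrote `result[i]`. A standalone illustration of that guard-and-replace step (the helper name is made up for the example):

def index_result_per_tensor(formula: str) -> str:
    # Rewrite a bare `result` to `result[i]`, leaving formulas that already
    # index the output untouched.
    if "result" in formula and "result[i]" not in formula:
        formula = formula.replace("result", "result[i]")
    return formula

assert index_result_per_tensor(
    "norm_jvp(self_p, self_t, ord, result)"
) == "norm_jvp(self_p, self_t, ord, result[i])"
assert index_result_per_tensor(
    "(self_t - other_t * result[i]) / other_p"
) == "(self_t - other_t * result[i]) / other_p"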

View File

@@ -272,9 +272,13 @@ def postprocess_forward_derivatives(
args_with_derivatives: Sequence[Binding],
) -> List[ForwardDerivative]:
def find_required_inputs(formula: str, postfix: str) -> Tuple[str, ...]:
is_foreach = f.func.name.name.base.startswith("_foreach_")
required_inputs = set()
for arg in args_with_derivatives:
if arg.type in ("at::TensorList", "const at::ITensorListRef &"):
if (
arg.type in ("at::TensorList", "const at::ITensorListRef &")
and not is_foreach
):
# The functions taking TensorList handle everything internally
continue
arg_name = arg.name
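A simplified stand-in for the behavior this hunk changes: when the function is a foreach op, TensorList arguments are no longer skipped while scanning a formula for required `_p`/`_t` inputs. The sketch below is not the torchgen implementation, just the shape of the check (the argument types and the regex are assumptions):

import re
from typing import NamedTuple, Sequence, Tuple

class Arg(NamedTuple):
    name: str
    type: str

def find_required_inputs(
    formula: str, args: Sequence[Arg], postfix: str, is_foreach: bool
) -> Tuple[str, ...]:
    required = set()
    for arg in args:
        # Non-foreach functions taking TensorList handle everything internally,
        # so their list arguments are skipped; foreach functions must not skip them.
        if (
            arg.type in ("at::TensorList", "const at::ITensorListRef &")
            and not is_foreach
        ):
            continue
        if re.search(rf"\b{re.escape(arg.name + postfix)}\b", formula):
            required.add(arg.name)
    return tuple(sorted(required))

args = [Arg("self", "at::TensorList"), Arg("other", "at::TensorList")]
formula = "(self_t - other_t * result[i]) / other_p"
print(find_required_inputs(formula, args, "_t", is_foreach=True))   # ('other', 'self')
print(find_required_inputs(formula, args, "_t", is_foreach=False))  # ()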

View File

@@ -8833,6 +8833,15 @@ foreach_unary_op_db: List[OpInfo] = [
supports_autograd=True,
supports_forward_ad=True,
has_no_out_of_place=True,
skips=(
# note(crcrpar): excluding cdouble from dtypes above might be better.
# Guard for `error: "In-place abs is not supported for complex tensors."`
DecorateInfo(
unittest.skip("_foreach_zero is not implemented"),
'TestForeach',
'test_outplace_forward_mode_AD',
),
),
),
ForeachFuncInfo(

View File

@@ -7,6 +7,8 @@ from torchgen import local
from torchgen.api import cpp
from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT
from torchgen.model import (
BaseTy,
BaseType,
FunctionSchema,
ListType,
NativeFunction,
@@ -314,6 +316,7 @@ def is_foreach_func(f: NativeFunction) -> bool:
# is functional for their backward derivatives (and forward derivatives in the future), i.e.,
# they would find such a function in `functional_info_by_signature`. There are, however, some exceptions:
_foreach_with_inplace_ref = {"_foreach_zero_"}
_foreach_with_tensor_overload = {"_foreach_mul.Tensor"}
# Checks if `function_schema` is a native, non-foreach function which `f`, a foreach function
@@ -462,13 +465,86 @@ def gen_foreach_derivativeinfo(
for arg in foreach_function.func.arguments.flat_non_out
if arg.name in all_var_names
]
forward_derivatives: List[ForwardDerivative] = []
fw_derivative: ForwardDerivative
for fw_derivative in ref_diff_info.forward_derivatives:
var_names: List[str] = list(fw_derivative.var_names) # type: ignore[no-redef]
var_types: List[Type] = list(fw_derivative.var_types)
required_inputs_fw_grad: List[str] = []
required_inputs_primal: List[str] = []
if fw_derivative.required_inputs_fw_grad is not None:
required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
if fw_derivative.required_inputs_primal:
required_inputs_primal = list(fw_derivative.required_inputs_primal)
modified_formula = fw_derivative.formula
# Foreach's result is TensorList
if "result" in modified_formula:
modified_formula = fw_derivative.formula.replace("result", "result[i]")
for foreach_arg, ref_arg in zip(
foreach_function.func.arguments.flat_non_out,
ref_diff_info.func.func.arguments.flat_non_out,
):
# Modify reference forward formula
if (
isinstance(foreach_arg.type, ListType)
and not foreach_arg.type.is_tensor_like()
):
# Assuming ScalarList
modified_formula = modified_formula.replace(
ref_arg.name, foreach_arg.name + "[i]"
)
elif foreach_arg.type.is_tensor_like():
# Assuming TensorList / Tensor
# assert isinstance(foreach_arg.type, ListType), f"{foreach_function.func.name}, {foreach_arg.type}"
assert isinstance(foreach_arg.type, ListType) or (
foreach_arg.type == BaseType(BaseTy.Tensor)
and str(foreach_function.func.name) in _foreach_with_tensor_overload
), f"{foreach_function.func.name}, {foreach_arg.type}"
for suffix in ("_p", "_t"):
curr_expr = ref_arg.name + suffix
if curr_expr in modified_formula:
new_expr = foreach_arg.name + suffix
modified_formula = modified_formula.replace(curr_expr, new_expr)
else:
# Assuming Scalar
if foreach_arg.name != ref_arg.name:
modified_formula = modified_formula.replace(
ref_arg.name, foreach_arg.name
)
# note(crcrpar): there should exist a cooler way...
for i, name in enumerate(var_names):
if name == ref_arg.name:
var_names[i] = foreach_arg.name
var_types[i] = foreach_arg.type
for i, name in enumerate(required_inputs_fw_grad):
if name == ref_arg.name:
required_inputs_fw_grad[i] = foreach_arg.name
for i, name in enumerate(required_inputs_primal):
if name == ref_arg.name:
required_inputs_primal[i] = foreach_arg.name
forward_derivatives.append(
ForwardDerivative(
formula=modified_formula,
var_names=tuple(var_names),
var_types=tuple(var_types),
required_inputs_fw_grad=tuple(required_inputs_fw_grad),
required_inputs_primal=tuple(required_inputs_primal),
required_original_self_value=fw_derivative.required_original_self_value,
is_reusing_outplace_formula=fw_derivative.is_reusing_outplace_formula,
)
)
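# Net effect of the rewrites above: a reference formula such as
#   pow_backward(self_t.conj(), self_p, exponent).conj()       (Scalar `exponent`)
# comes out, for a ScalarList overload, as
#   pow_backward(self_t.conj(), self_p, exponent[i]).conj()
# with any bare `result` indexed as `result[i]`, while TensorList arguments
# keep their `_p`/`_t` suffixes unchanged.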
return (
DifferentiabilityInfo(
name=foreach_function.func.name.name.base,
func=foreach_function,
op=f"Foreach{ref_diff_info.op}{foreach_function.func.name.overload_name}",
derivatives=modified_derivative_formulas,
forward_derivatives=[],
forward_derivatives=forward_derivatives,
all_saved_inputs=tuple(set(all_saved_inputs)),
all_saved_outputs=tuple(set(all_saved_outputs)),
available_named_gradients=(),