[gen_autograd_functions] rename some variables (#143166)

This is a follow-up from https://github.com/pytorch/pytorch/pull/141278.

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143166
Approved by: https://github.com/soulitzer
Author: rzou
Date: 2024-12-16 08:16:13 -08:00
Committed by: PyTorch MergeBot
Parent: 4c62275325
Commit: 557da8014d

@@ -99,7 +99,7 @@ FUNCTION_DEFINITION = CodeTemplate(
"""\
static variable_list ${op}_apply_functional(
variable_list&& grads,
-std::array<bool,${num_vars}> needs_input_grad${,unpacked_saved_vars_signature})
+std::array<bool,${num_inputs}> needs_input_grad${,apply_functional_args_signature})
{
IndexRangeGenerator gen;
${compute_index_ranges}
@@ -113,7 +113,7 @@ variable_list ${op}::apply(variable_list&& grads) {
${asserts}
${unpacks}
${compute_needs_input_grad}
-return ${op}_apply_functional(std::move(grads), needs_input_grad${,unpacked_saved_vars});
+return ${op}_apply_functional(std::move(grads), needs_input_grad${,apply_functional_args});
}
void ${op}::compiled_args(CompiledNodeArgs& args) {
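For orientation, here is a minimal plain-Python sketch (not the real CodeTemplate machinery) of how the renamed template parameters fit together; the op name and the saved-value entries below are hypothetical:

```python
# Hypothetical values for one generated node; in the real codegen they are
# built up by process_function() further down in this file.
op = "FooBackward0"
num_inputs = 2  # number of differentiable inputs, e.g. "self" and "other"
apply_functional_args = ["self_", "other_"]  # unpacked saved values passed through
apply_functional_args_signature = ["at::Tensor& self_", "at::Tensor& other_"]

# Roughly what the FUNCTION_DEFINITION template above expands to:
signature = (
    f"static variable_list {op}_apply_functional(\n"
    f"    variable_list&& grads,\n"
    f"    std::array<bool,{num_inputs}> needs_input_grad, "
    + ", ".join(apply_functional_args_signature)
    + ")"
)
# ...and, roughly, the call emitted inside ${op}::apply:
call = (
    f"return {op}_apply_functional(std::move(grads), needs_input_grad, "
    + ", ".join(apply_functional_args)
    + ");"
)
print(signature)
print(call)
```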
@@ -587,24 +587,27 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
compiled_args: list[str] = []
apply_with_saved_before: list[str] = []
apply_with_saved_after: list[str] = []
-unpacked_saved_vars: list[str] = []
-unpacked_saved_vars_ref_type: list[str] = []
-# Maps var_name to a unique index. The var_name is the
-# name of an input to the operator that needs a gradient (like "self", "other").
-# The index is the order in which they appear. We use this mapping
-# to populate needs_input_grad in some order and then grab values from it.
-var_name_map: dict[str, int] = {}
+apply_functional_args: list[str] = []
+apply_functional_args_ref_types: list[str] = []
+# Maps the name of an input (to the original forward operator;
+# examples are "self", "other") to the order in which they appear in the
+# operator.
+# For example, if the operator is foo(Tensor self, int64_t k, Tensor other),
+# the mapping is: {"self": 0, "other": 1}.
+# We use this mapping to populate needs_input_grad in some order and then grab
+# values from it.
+input_name_to_idx: dict[str, int] = {}
for idx, arg in enumerate(info.args_with_derivatives):
if arg.type in TENSOR_LIST_LIKE_CTYPES:
size = f"{arg.name}_size_"
saved_list_sizes.append(f"size_t {arg.name}_size_;")
unpacked_saved_vars.append(f"{arg.name}_size_")
unpacked_saved_vars_ref_type.append("size_t")
apply_functional_args.append(f"{arg.name}_size_")
apply_functional_args_ref_types.append("size_t")
else:
size = "1"
compute_index_ranges.append(f"auto {arg.name}_ix = gen.range({size});")
-var_name_map[arg.name] = idx
+input_name_to_idx[arg.name] = idx
def save_var(var: SavedAttribute, is_output: bool) -> None:
name = var.nctype.name
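As a concrete illustration of the comment above, a small sketch (with a hypothetical stand-in for info.args_with_derivatives) of how input_name_to_idx is built and later used:

```python
# Hypothetical operator: foo(Tensor self, int64_t k, Tensor other), where only
# "self" and "other" are differentiable inputs.
args_with_derivatives = ["self", "other"]  # stand-in for info.args_with_derivatives

input_name_to_idx = {name: idx for idx, name in enumerate(args_with_derivatives)}
assert input_name_to_idx == {"self": 0, "other": 1}

# Derivative formulas then index needs_input_grad through this mapping, e.g.
assert f"needs_input_grad[{input_name_to_idx['other']}]" == "needs_input_grad[1]"
```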
@@ -856,8 +859,8 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
if unpacked_ref_type is None:
unpacked_ref_type = f"{saved_variables[-1].split(' ')[0]}&"
-unpacked_saved_vars.append(str(name))
-unpacked_saved_vars_ref_type.append(unpacked_ref_type)
+apply_functional_args.append(str(name))
+apply_functional_args_ref_types.append(unpacked_ref_type)
for var in sorted(info.all_saved_inputs, key=lambda sa: str(sa.nctype.name)):
save_var(var, is_output=False)
@@ -872,8 +875,8 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
thread_lock = ""
if uses_retain_variables(info):
unpacked_saved_vars.append("retain_variables")
unpacked_saved_vars_ref_type.append("bool")
apply_functional_args.append("retain_variables")
apply_functional_args_ref_types.append("bool")
will_release_variables = WILL_RELEASE_VARIABLES.substitute()
else:
will_release_variables = ""
@@ -919,14 +922,15 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
derivative_template.substitute(
name=var_names[0],
derivative=formula,
-idx=var_name_map[var_names[0]],
+idx=input_name_to_idx[var_names[0]],
),
)
else:
if "grad_input_mask" in formula:
masks = [
f"needs_input_grad[{var_name_map[name]}]," for name in var_names
f"needs_input_grad[{input_name_to_idx[name]}],"
for name in var_names
]
grad_input_mask = GRAD_INPUT_MASK.substitute(
n=len(var_names), masks=masks
@@ -934,14 +938,14 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
else:
grad_input_mask = ""
needs_input_grad = [
f"needs_input_grad[{var_name_map[name]}]" for name in var_names
f"needs_input_grad[{input_name_to_idx[name]}]" for name in var_names
]
needs_input_grad = " || ".join(needs_input_grad)
copy_ranges: list[str] = []
for i, n in enumerate(var_names):
copy_ranges.append(
DERIVATIVE_MULTI_COPY_RANGE.substitute(
-name=n, i=i, idx=var_name_map[n]
+name=n, i=i, idx=input_name_to_idx[n]
)
)
return False, DERIVATIVE_MULTI.substitute(
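To make the needs_input_grad plumbing above concrete, a short sketch using the hypothetical {"self": 0, "other": 1} mapping from before:

```python
input_name_to_idx = {"self": 0, "other": 1}
var_names = ["self", "other"]

# grad_input_mask entries for a formula that uses "grad_input_mask":
masks = [f"needs_input_grad[{input_name_to_idx[name]}]," for name in var_names]
assert masks == ["needs_input_grad[0],", "needs_input_grad[1],"]

# The combined guard emitted for the derivative body:
needs_input_grad = " || ".join(
    f"needs_input_grad[{input_name_to_idx[name]}]" for name in var_names
)
assert needs_input_grad == "needs_input_grad[0] || needs_input_grad[1]"
```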
@@ -961,7 +965,7 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
body.append(derivative_text)
need_any_grad_defined_var |= checks_any_grad_defined
-for name in var_name_map:
+for name in input_name_to_idx:
masks.append(f"task_should_compute_output({{ {name}_ix }}),")
# Since single-output derivative formulas need to check if grads are
@@ -985,17 +989,18 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
compute_needs_input_grad = COMPUTE_NEEDS_INPUT_GRAD.substitute(
n=len(masks), compute_index_ranges=compute_index_ranges, masks=masks
)
-unpacked_saved_vars_signature = [
-f"{T} {x}" for T, x in zip(unpacked_saved_vars_ref_type, unpacked_saved_vars)
+apply_functional_args_signature = [
+f"{T} {x}"
+for T, x in zip(apply_functional_args_ref_types, apply_functional_args)
]
return template.substitute(
unpacks="\n".join(unpack),
op=info.op,
-unpacked_saved_vars=unpacked_saved_vars,
-unpacked_saved_vars_signature=unpacked_saved_vars_signature,
+apply_functional_args=apply_functional_args,
+apply_functional_args_signature=apply_functional_args_signature,
compute_needs_input_grad=compute_needs_input_grad,
-num_vars=len(var_name_map),
+num_inputs=len(input_name_to_idx),
compute_index_ranges=compute_index_ranges,
saved_variables=saved_variables,
release_variables=release_variables,
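Finally, a sketch of how the signature strings are assembled from the two parallel lists; the entries here are hypothetical but mirror the cases handled above (a tensor-list size, an unpacked saved value, and retain_variables):

```python
apply_functional_args = ["tensors_size_", "self_", "retain_variables"]
apply_functional_args_ref_types = ["size_t", "at::Tensor&", "bool"]

apply_functional_args_signature = [
    f"{T} {x}"
    for T, x in zip(apply_functional_args_ref_types, apply_functional_args)
]
assert apply_functional_args_signature == [
    "size_t tensors_size_",
    "at::Tensor& self_",
    "bool retain_variables",
]
```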