[gen_autograd_functions] rename some variables (#143166)

This is a follow-up from https://github.com/pytorch/pytorch/pull/141278.

Test Plan:
- existing tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143166
Approved by: https://github.com/soulitzer
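
All hunks are in tools/autograd/gen_autograd_functions.py. At a glance, the change renames the codegen bookkeeping that feeds the generated ${op}_apply_functional helper; both sides of each rename are taken directly from the hunks below:

- unpacked_saved_vars            -> apply_functional_args
- unpacked_saved_vars_ref_type   -> apply_functional_args_ref_types
- unpacked_saved_vars_signature  -> apply_functional_args_signature
- var_name_map                   -> input_name_to_idx
- num_vars (template key)        -> num_inputs
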
@@ -99,7 +99,7 @@ FUNCTION_DEFINITION = CodeTemplate(
     """\
 static variable_list ${op}_apply_functional(
   variable_list&& grads,
-  std::array<bool,${num_vars}> needs_input_grad${,unpacked_saved_vars_signature})
+  std::array<bool,${num_inputs}> needs_input_grad${,apply_functional_args_signature})
 {
   IndexRangeGenerator gen;
   ${compute_index_ranges}
@@ -113,7 +113,7 @@ variable_list ${op}::apply(variable_list&& grads) {
   ${asserts}
   ${unpacks}
   ${compute_needs_input_grad}
-  return ${op}_apply_functional(std::move(grads), needs_input_grad${,unpacked_saved_vars});
+  return ${op}_apply_functional(std::move(grads), needs_input_grad${,apply_functional_args});
 }

 void ${op}::compiled_args(CompiledNodeArgs& args) {
@@ -587,24 +587,27 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
     compiled_args: list[str] = []
     apply_with_saved_before: list[str] = []
     apply_with_saved_after: list[str] = []
-    unpacked_saved_vars: list[str] = []
-    unpacked_saved_vars_ref_type: list[str] = []
-    # Maps var_name to a unique index. The var_name is the
-    # name of an input to the operator that needs a gradient (like "self", "other").
-    # The index is the order in which they appear. We use this mapping
-    # to populate needs_input_grad in some order and then grab values from it.
-    var_name_map: dict[str, int] = {}
+    apply_functional_args: list[str] = []
+    apply_functional_args_ref_types: list[str] = []
+    # Maps the name of an input (to the original forward operator;
+    # examples are "self", "other") to the order in which they appear in the
+    # operator.
+    # For example; if the operator is foo(Tensor self, int64_t k, Tensor other),
+    # the mapping is: {"self": 0, "other": 1}.
+    # We use this mapping to populate needs_input_grad in some order and then grab
+    # values from it.
+    input_name_to_idx: dict[str, int] = {}

     for idx, arg in enumerate(info.args_with_derivatives):
         if arg.type in TENSOR_LIST_LIKE_CTYPES:
             size = f"{arg.name}_size_"
             saved_list_sizes.append(f"size_t {arg.name}_size_;")
-            unpacked_saved_vars.append(f"{arg.name}_size_")
-            unpacked_saved_vars_ref_type.append("size_t")
+            apply_functional_args.append(f"{arg.name}_size_")
+            apply_functional_args_ref_types.append("size_t")
         else:
             size = "1"
         compute_index_ranges.append(f"auto {arg.name}_ix = gen.range({size});")
-        var_name_map[arg.name] = idx
+        input_name_to_idx[arg.name] = idx

     def save_var(var: SavedAttribute, is_output: bool) -> None:
         name = var.nctype.name
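
To make the new comment concrete, here is a minimal standalone Python sketch (not generator code) of how input_name_to_idx indexes needs_input_grad, using the comment's own example operator; the runtime flag values are hypothetical:

    # Sketch only: foo(Tensor self, int64_t k, Tensor other) has two inputs
    # that can require grad; they are indexed in the order they appear.
    differentiable_inputs = ["self", "other"]
    input_name_to_idx = {name: i for i, name in enumerate(differentiable_inputs)}
    assert input_name_to_idx == {"self": 0, "other": 1}

    # The generated backward fills needs_input_grad in that same order, so a
    # derivative formula checks a specific input by index:
    needs_input_grad = [True, False]  # hypothetical runtime values
    if needs_input_grad[input_name_to_idx["other"]]:
        pass  # would compute the gradient for "other"
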
@@ -856,8 +859,8 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {

         if unpacked_ref_type is None:
             unpacked_ref_type = f"{saved_variables[-1].split(' ')[0]}&"
-        unpacked_saved_vars.append(str(name))
-        unpacked_saved_vars_ref_type.append(unpacked_ref_type)
+        apply_functional_args.append(str(name))
+        apply_functional_args_ref_types.append(unpacked_ref_type)

     for var in sorted(info.all_saved_inputs, key=lambda sa: str(sa.nctype.name)):
         save_var(var, is_output=False)
@@ -872,8 +875,8 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
         thread_lock = ""

     if uses_retain_variables(info):
-        unpacked_saved_vars.append("retain_variables")
-        unpacked_saved_vars_ref_type.append("bool")
+        apply_functional_args.append("retain_variables")
+        apply_functional_args_ref_types.append("bool")
         will_release_variables = WILL_RELEASE_VARIABLES.substitute()
     else:
         will_release_variables = ""
@@ -919,14 +922,15 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
                 derivative_template.substitute(
                     name=var_names[0],
                     derivative=formula,
-                    idx=var_name_map[var_names[0]],
+                    idx=input_name_to_idx[var_names[0]],
                 ),
             )

         else:
             if "grad_input_mask" in formula:
                 masks = [
-                    f"needs_input_grad[{var_name_map[name]}]," for name in var_names
+                    f"needs_input_grad[{input_name_to_idx[name]}],"
+                    for name in var_names
                 ]
                 grad_input_mask = GRAD_INPUT_MASK.substitute(
                     n=len(var_names), masks=masks
@@ -934,14 +938,14 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
             else:
                 grad_input_mask = ""
             needs_input_grad = [
-                f"needs_input_grad[{var_name_map[name]}]" for name in var_names
+                f"needs_input_grad[{input_name_to_idx[name]}]" for name in var_names
             ]
             needs_input_grad = " || ".join(needs_input_grad)
             copy_ranges: list[str] = []
             for i, n in enumerate(var_names):
                 copy_ranges.append(
                     DERIVATIVE_MULTI_COPY_RANGE.substitute(
-                        name=n, i=i, idx=var_name_map[n]
+                        name=n, i=i, idx=input_name_to_idx[n]
                     )
                 )
             return False, DERIVATIVE_MULTI.substitute(
@@ -961,7 +965,7 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
             body.append(derivative_text)
             need_any_grad_defined_var |= checks_any_grad_defined

-    for name in var_name_map:
+    for name in input_name_to_idx:
         masks.append(f"task_should_compute_output({{ {name}_ix }}),")

     # Since single-output derivative formulas need to check if grads are
@@ -985,17 +989,18 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
     compute_needs_input_grad = COMPUTE_NEEDS_INPUT_GRAD.substitute(
         n=len(masks), compute_index_ranges=compute_index_ranges, masks=masks
     )
-    unpacked_saved_vars_signature = [
-        f"{T} {x}" for T, x in zip(unpacked_saved_vars_ref_type, unpacked_saved_vars)
+    apply_functional_args_signature = [
+        f"{T} {x}"
+        for T, x in zip(apply_functional_args_ref_types, apply_functional_args)
     ]

     return template.substitute(
         unpacks="\n".join(unpack),
         op=info.op,
-        unpacked_saved_vars=unpacked_saved_vars,
-        unpacked_saved_vars_signature=unpacked_saved_vars_signature,
+        apply_functional_args=apply_functional_args,
+        apply_functional_args_signature=apply_functional_args_signature,
         compute_needs_input_grad=compute_needs_input_grad,
-        num_vars=len(var_name_map),
+        num_inputs=len(input_name_to_idx),
         compute_index_ranges=compute_index_ranges,
         saved_variables=saved_variables,
         release_variables=release_variables,
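
For reference, the apply_functional_args_signature comprehension above just zips the two parallel lists into "type name" parameter strings. A tiny standalone sketch of that step, with made-up entries (the argument names and reference types below are illustrative, not taken from a real operator):

    apply_functional_args = ["tensors_size_", "self_", "retain_variables"]
    apply_functional_args_ref_types = ["size_t", "SavedVariable&", "bool"]
    apply_functional_args_signature = [
        f"{T} {x}"
        for T, x in zip(apply_functional_args_ref_types, apply_functional_args)
    ]
    # -> ["size_t tensors_size_", "SavedVariable& self_", "bool retain_variables"]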