[dynamo] Avoid unnecessary caching source codegen (#155376)

We only need to cache a source (e.g., `x.y.z`) into a temporary local if
it's used multiple times in the codegen; otherwise we'd just be creating
redundant `DUP` and `STORE_FAST tmp_...` instructions, which may degrade
perf and definitely make the generated bytecode harder to read.
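
Conceptually, the change amounts to counting how often each source is loaded during a first codegen pass and only promoting multi-use, non-local sources to temporaries. Below is a minimal standalone sketch of that decision rule; `pick_tempvars`, `uses`, and `is_local` are illustrative names, not Dynamo's actual API:

```python
from collections import Counter

def pick_tempvars(uses: Counter, is_local) -> dict:
    """Return the sources worth caching in a temporary local.

    A source is promoted to a temporary only if it is loaded more than
    once and is not already a plain local; single-use sources are
    reconstructed inline, so no COPY/STORE_FAST pair is emitted for them.
    """
    return {
        src: None
        for src, count in uses.items()
        if count > 1 and not is_local(src)
    }

# Hypothetical use counts from a first codegen pass:
uses = Counter({"x.y.z": 3, "x": 1, "y": 1})
print(pick_tempvars(uses, is_local=lambda s: s in ("x", "y")))
# -> {'x.y.z': None}: only the multi-use chained source gets a temporary
```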

Example:
```python
import torch

@torch.compile(backend="eager")
def fn(x, y):
    return x + y

fn(torch.ones(2), torch.ones(1))
```
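
The `[__bytecode]` traces below are the kind of output produced by Dynamo's bytecode logging; one way to reproduce them is to enable the `bytecode` log artifact before running the snippet (equivalent to running with `TORCH_LOGS="bytecode"`):

```python
import torch

# Log the original and Dynamo-modified bytecode of compiled frames.
torch._logging.set_logs(bytecode=True)
```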

Original bytecode:
```verbatim
[0/0] [__bytecode]   3           0 RESUME                   0
[0/0] [__bytecode]
[0/0] [__bytecode]   5           2 LOAD_FAST                0 (x)
[0/0] [__bytecode]               4 LOAD_FAST                1 (y)
[0/0] [__bytecode]               6 BINARY_OP                0 (+)
[0/0] [__bytecode]              10 RETURN_VALUE
```

Modified bytecode (before this patch):
```verbatim
[__bytecode]   3           0 RESUME                   0
[__bytecode]               2 LOAD_GLOBAL              1 (NULL + __compiled_fn_1_578c8d9a_2a9b_4d15_bac7_267591cdee32)
[__bytecode]              14 LOAD_FAST                0 (x)
[__bytecode]              16 COPY                     1
[__bytecode]              18 STORE_FAST               3 (tmp_1)
[__bytecode]              20 LOAD_FAST                1 (y)
[__bytecode]              22 COPY                     1
[__bytecode]              24 STORE_FAST               4 (tmp_2)
[__bytecode]              26 PRECALL                  2
[__bytecode]              30 CALL                     2
[__bytecode]              40 STORE_FAST               2 (graph_out_0)
[__bytecode]              42 LOAD_FAST                2 (graph_out_0)
[__bytecode]              44 LOAD_CONST               1 (0)
[__bytecode]              46 BINARY_SUBSCR
[__bytecode]              56 DELETE_FAST              2 (graph_out_0)
[__bytecode]              58 RETURN_VALUE
```

Modified bytecode (after this patch):
```verbatim
[__bytecode]   3           0 RESUME                   0
[__bytecode]               2 LOAD_GLOBAL              1 (NULL + __compiled_fn_1_2c498af2_ce5c_49cb_abba_a0c7489b09ce)
[__bytecode]              14 LOAD_FAST                0 (x)
[__bytecode]              16 LOAD_FAST                1 (y)
[__bytecode]              18 PRECALL                  2
[__bytecode]              22 CALL                     2
[__bytecode]              32 STORE_FAST               2 (graph_out_0)
[__bytecode]              34 LOAD_FAST                2 (graph_out_0)
[__bytecode]              36 LOAD_CONST               1 (0)
[__bytecode]              38 BINARY_SUBSCR
[__bytecode]              48 DELETE_FAST              2 (graph_out_0)
[__bytecode]              50 RETURN_VALUE
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155376
Approved by: https://github.com/williamwen42
Author: Ryan Guo
Date: 2025-06-09 15:12:44 -07:00
Committed by: PyTorch MergeBot
Parent: 91ee9ee82d
Commit: 07eb374e7e
3 changed files with 27 additions and 12 deletions


```diff
@@ -79,7 +79,7 @@ class PyCodegen:
     ) -> None:
         self.root = root
         self.top_of_stack: Optional[Union[VariableTracker, Source]] = None
-        self.uses: Counter[VariableTracker] = collections.Counter()
+        self.uses: Counter[Union[VariableTracker, Source]] = collections.Counter()
         self.graph_outputs: dict[int, GraphOutputEntry] = {}
         self._output: list[Instruction] = []
         # This determines which VariableTracker/Source should be stored as
@@ -181,9 +181,9 @@ class PyCodegen:
         Notable effects:
         1. `self.top_of_stack` will be set to `value`, if we don't codegen
            `value` based on source.
-        2. `self.uses[value]` will increment, if we don't codegen `value` based
-           on source or cache/top-of-stack reuse; in other words, if we codegen
-           as if `value` is modelling some brand new python value.
+        2. `self.uses[value]` will increment, unless (a). we codegen via
+           `top_of_stack` or cached `tempvars`, or (b). `value` has special VT
+           types like `NNModuleVariable`, etc.
         """
         if isinstance(value, Source):
             # If the source needs to be overridden, use the new one.
@@ -198,6 +198,7 @@ class PyCodegen:
                 self.top_of_stack = source
                 return
 
+            self.uses[source] += 1
             try:
                 self.call_reconstruct(source)
             except NotImplementedError:
@@ -207,9 +208,9 @@ class PyCodegen:
                     explanation=f"Dynamo has no bytecode reconstruction implemented for {type(source)} variable {source}.",
                     hints=[*graph_break_hints.DYNAMO_BUG],
                 )
-            self._output.append(create_dup_top())
-            self.add_cache(source)
+            if source in self.tempvars:
+                self._output.append(create_dup_top())
+                self.add_cache(source)
             self.top_of_stack = source
             return
```


```diff
@@ -125,6 +125,7 @@ from .utils import (
     get_unique_name_wrt,
     graph_break_reasons,
     increment_op_count,
+    istype,
     lazy_format_graph_code,
     LazyString,
     nn_module_proxy,
@@ -1394,12 +1395,20 @@ class OutputGraph(OutputGraphGuardsState):
             )
             self.codegen_suffix(tx, stack_values_flat, pass1)
 
-            # one more time now that we have established tempvars
+            # Use `pass1.uses` to selectively cache multi-user variables into a
+            # temporary local source. This (a). speeds up loading VTs with long
+            # chained source, and (b). avoids redundantly saving single-user VT
+            # into a temporary local.
+            tempvars = {}  # type: ignore[var-annotated]
+            for val, count in pass1.uses.items():
+                # If it's already a local source, no need to cache it
+                if count > 1 and not istype(val, (SyntheticLocalSource, LocalSource)):
+                    tempvars[val] = None
             pass2 = PyCodegen(
                 self.root_tx,
                 root,
                 graph_output_var,
-                tempvars={val: None for val, count in pass1.uses.items() if count > 1},
+                tempvars=tempvars,
                 overridden_sources=overridden_sources,
             )
             self.codegen_suffix(tx, stack_values_flat, pass2)
```


```diff
@@ -1965,6 +1965,10 @@ class BuiltinVariable(VariableTracker):
         name = name_var.as_python_constant()
+        # See NOTE [Tensor "grad" and "_grad" attr]
+        if isinstance(obj, TensorVariable) and name == "_grad":
+            name = "grad"
         if tx.output.side_effects.is_attribute_mutation(obj):
             if isinstance(obj, variables.UnspecializedNNModuleVariable):
                 if (
@@ -2199,11 +2203,12 @@ class BuiltinVariable(VariableTracker):
                 # Step 4 - replace all reference to the current object with the new one
                 return out
             elif name in ("_grad", "grad"):
                 # NOTE: [Tensor "grad" and "_grad" attr]
                 # _grad and grad share the same setter/getter, see
                 # THPVariable_properties, and here we make sure setting one
-                # enables reading `val` from the other.
-                tx.output.side_effects.store_attr(obj, "grad", val)
-                tx.output.side_effects.store_attr(obj, "_grad", val)
+                # enables reading `val` from the other, by routing all
+                # read/write to `grad`.
+                name = "grad"
             elif is_tensor_getset_descriptor(name):
                 # Attribute like `torch.Tensor.real` has special setters we
                 # don't yet support; it's not as simple adding an entry to
```