[dynamo] Avoid unnecessary caching source codegen (#155376)
We only need to cache a source (e.g., `x.y.z`) into a temporary local if it's used multiple times in the codegen; otherwise we'd just be creating redundant `DUP` and `STORE_FAST tmp_...` instructions, which might degrade perf and definitely make the generated bytecode harder to read. Example:

```python
import torch

@torch.compile(backend="eager")
def fn(x, y):
    return x + y

fn(torch.ones(2), torch.ones(1))
```

Original bytecode:

```verbatim
[0/0] [__bytecode] 3 0 RESUME 0
[0/0] [__bytecode]
[0/0] [__bytecode] 5 2 LOAD_FAST 0 (x)
[0/0] [__bytecode] 4 LOAD_FAST 1 (y)
[0/0] [__bytecode] 6 BINARY_OP 0 (+)
[0/0] [__bytecode] 10 RETURN_VALUE
```

Modified bytecode (before this patch):

```verbatim
[__bytecode] 3 0 RESUME 0
[__bytecode] 2 LOAD_GLOBAL 1 (NULL + __compiled_fn_1_578c8d9a_2a9b_4d15_bac7_267591cdee32)
[__bytecode] 14 LOAD_FAST 0 (x)
[__bytecode] 16 COPY 1
[__bytecode] 18 STORE_FAST 3 (tmp_1)
[__bytecode] 20 LOAD_FAST 1 (y)
[__bytecode] 22 COPY 1
[__bytecode] 24 STORE_FAST 4 (tmp_2)
[__bytecode] 26 PRECALL 2
[__bytecode] 30 CALL 2
[__bytecode] 40 STORE_FAST 2 (graph_out_0)
[__bytecode] 42 LOAD_FAST 2 (graph_out_0)
[__bytecode] 44 LOAD_CONST 1 (0)
[__bytecode] 46 BINARY_SUBSCR
[__bytecode] 56 DELETE_FAST 2 (graph_out_0)
[__bytecode] 58 RETURN_VALUE
```

Modified bytecode (after this patch):

```verbatim
[__bytecode] 3 0 RESUME 0
[__bytecode] 2 LOAD_GLOBAL 1 (NULL + __compiled_fn_1_2c498af2_ce5c_49cb_abba_a0c7489b09ce)
[__bytecode] 14 LOAD_FAST 0 (x)
[__bytecode] 16 LOAD_FAST 1 (y)
[__bytecode] 18 PRECALL 2
[__bytecode] 22 CALL 2
[__bytecode] 32 STORE_FAST 2 (graph_out_0)
[__bytecode] 34 LOAD_FAST 2 (graph_out_0)
[__bytecode] 36 LOAD_CONST 1 (0)
[__bytecode] 38 BINARY_SUBSCR
[__bytecode] 48 DELETE_FAST 2 (graph_out_0)
[__bytecode] 50 RETURN_VALUE
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155376
Approved by: https://github.com/williamwen42
Committed by: PyTorch MergeBot
Parent: 91ee9ee82d
Commit: 07eb374e7e
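To make the idea in the commit message concrete, here is a minimal standalone sketch of the two-pass scheme, not the actual `PyCodegen`/`OutputGraph` code: pass 1 counts how often each source expression is emitted, and pass 2 caches into a temporary local only the expressions referenced more than once. All names (`pick_tempvars`, `emit_load`, the pseudo-instruction strings) are illustrative.

```python
# Illustrative sketch only -- not Dynamo's implementation.
import collections


def pick_tempvars(emitted_sources):
    # Pass 1: count how often each source expression is emitted; only
    # expressions referenced more than once are worth a temporary local.
    uses = collections.Counter(emitted_sources)
    return {expr for expr, count in uses.items() if count > 1}


def emit_load(expr, tempvars, cached, out):
    # Pass 2: emit a load of `expr`; duplicate and store into a temp local
    # only when the expression was selected as a tempvar.
    if expr in cached:
        out.append(f"LOAD_FAST {cached[expr]}")  # reuse the cached temp
        return
    out.append(f"LOAD {expr}")  # stands for e.g. LOAD_FAST x; LOAD_ATTR y; ...
    if expr in tempvars:
        name = f"tmp_{len(cached) + 1}"
        cached[expr] = name
        out.append("COPY 1")
        out.append(f"STORE_FAST {name}")


# "x.y" is used twice (worth caching), "z" only once (plain load).
emitted = ["x.y", "z", "x.y"]
tempvars = pick_tempvars(emitted)
out, cached = [], {}
for e in emitted:
    emit_load(e, tempvars, cached, out)
print("\n".join(out))
```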
@@ -79,7 +79,7 @@ class PyCodegen:
     ) -> None:
         self.root = root
         self.top_of_stack: Optional[Union[VariableTracker, Source]] = None
-        self.uses: Counter[VariableTracker] = collections.Counter()
+        self.uses: Counter[Union[VariableTracker, Source]] = collections.Counter()
         self.graph_outputs: dict[int, GraphOutputEntry] = {}
         self._output: list[Instruction] = []
         # This determines which VariableTracker/Source should be stored as
@@ -181,9 +181,9 @@ class PyCodegen:
         Notable effects:
         1. `self.top_of_stack` will be set to `value`, if we don't codegen
            `value` based on source.
-        2. `self.uses[value]` will increment, if we don't codegen `value` based
-           on source or cache/top-of-stack reuse; in other words, if we codegen
-           as if `value` is modelling some brand new python value.
+        2. `self.uses[value]` will increment, unless (a). we codegen via
+           `top_of_stack` or cached `tempvars`, or (b). `value` has special VT
+           types like `NNModuleVariable`, etc.
         """
         if isinstance(value, Source):
             # If the source needs to be overridden, use the new one.
@@ -198,6 +198,7 @@ class PyCodegen:
                 self.top_of_stack = source
                 return

+            self.uses[source] += 1
             try:
                 self.call_reconstruct(source)
             except NotImplementedError:
@@ -207,9 +208,9 @@ class PyCodegen:
                     explanation=f"Dynamo has no bytecode reconstruction implemented for {type(source)} variable {source}.",
                     hints=[*graph_break_hints.DYNAMO_BUG],
                 )

-            self._output.append(create_dup_top())
-            self.add_cache(source)
+            if source in self.tempvars:
+                self._output.append(create_dup_top())
+                self.add_cache(source)
             self.top_of_stack = source

             return
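As a rough plain-Python analogy for why the previously unconditional caching was wasteful (ordinary CPython bytecode from `dis`, not Dynamo-generated bytecode, and the function names are made up): routing a single-use value through a named local adds a `STORE_FAST`/`LOAD_FAST` pair that the direct version never emits.

```python
import dis


def single_use_cached(x):
    tmp = x      # caching a value that is only read once...
    return tmp   # ...costs an extra STORE_FAST/LOAD_FAST pair


def single_use_direct(x):
    return x     # a single LOAD_FAST, nothing else


dis.dis(single_use_cached)
dis.dis(single_use_direct)
```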
@@ -125,6 +125,7 @@ from .utils import (
     get_unique_name_wrt,
     graph_break_reasons,
     increment_op_count,
+    istype,
     lazy_format_graph_code,
     LazyString,
     nn_module_proxy,
@@ -1394,12 +1395,20 @@ class OutputGraph(OutputGraphGuardsState):
             )
             self.codegen_suffix(tx, stack_values_flat, pass1)

-            # one more time now that we have established tempvars
+            # Use `pass1.uses` to selectively cache multi-user variables into a
+            # temporary local source. This (a). speeds up loading VTs with long
+            # chained source, and (b). avoids redundantly saving single-user VT
+            # into a temporary local.
+            tempvars = {}  # type: ignore[var-annotated]
+            for val, count in pass1.uses.items():
+                # If it's already a local source, no need to cache it
+                if count > 1 and not istype(val, (SyntheticLocalSource, LocalSource)):
+                    tempvars[val] = None
             pass2 = PyCodegen(
                 self.root_tx,
                 root,
                 graph_output_var,
-                tempvars={val: None for val, count in pass1.uses.items() if count > 1},
+                tempvars=tempvars,
                 overridden_sources=overridden_sources,
             )
             self.codegen_suffix(tx, stack_values_flat, pass2)
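The selection loop above skips `SyntheticLocalSource`/`LocalSource` because a value that already lives in a local is a single `LOAD_FAST` away, so aliasing it into another temp local buys nothing; `istype` (from `torch._dynamo.utils`) is, as I understand it, an exact-type check rather than `isinstance`. Below is a standalone sketch of the same filter, where the `Source` classes, the toy counts, and the `exact_type_in` helper are all stand-ins, not Dynamo code.

```python
def exact_type_in(obj, types):
    # stand-in for the exact-type check: type membership, no subclasses
    return type(obj) in types


class LocalSource: ...
class SyntheticLocalSource: ...
class AttrSource: ...  # stand-in for a chained source like `x.y.z`


pass1_uses = {LocalSource(): 3, AttrSource(): 2, AttrSource(): 1}  # toy counts
tempvars = {}
for val, count in pass1_uses.items():
    # locals are already one LOAD_FAST away, so caching them buys nothing
    if count > 1 and not exact_type_in(val, (SyntheticLocalSource, LocalSource)):
        tempvars[val] = None

print(len(tempvars))  # 1 -- only the multi-use, non-local source gets a temp
```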
@@ -1965,6 +1965,10 @@ class BuiltinVariable(VariableTracker):

         name = name_var.as_python_constant()

+        # See NOTE [Tensor "grad" and "_grad" attr]
+        if isinstance(obj, TensorVariable) and name == "_grad":
+            name = "grad"
+
         if tx.output.side_effects.is_attribute_mutation(obj):
             if isinstance(obj, variables.UnspecializedNNModuleVariable):
                 if (
@@ -2199,11 +2203,12 @@ class BuiltinVariable(VariableTracker):
                # Step 4 - replace all reference to the current object with the new one
                return out
            elif name in ("_grad", "grad"):
                # NOTE: [Tensor "grad" and "_grad" attr]
                # _grad and grad share the same setter/getter, see
                # THPVariable_properties, and here we make sure setting one
-               # enables reading `val` from the other.
-               tx.output.side_effects.store_attr(obj, "grad", val)
-               tx.output.side_effects.store_attr(obj, "_grad", val)
+               # enables reading `val` from the other, by routing all
+               # read/write to `grad`.
+               name = "grad"
            elif is_tensor_getset_descriptor(name):
                # Attribute like `torch.Tensor.real` has special setters we
                # don't yet support; it's not as simple adding an entry to
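The routing to `"grad"` matches observable eager behavior: `Tensor.grad` and `Tensor._grad` are backed by the same getter/setter (see `THPVariable_properties`), so a write through either name is visible through the other. A quick eager-mode check, illustrative only and not part of the patch:

```python
import torch

t = torch.ones(3, requires_grad=True)
t.sum().backward()
print(torch.equal(t.grad, t._grad))  # True: both names read the same attribute

t._grad = torch.zeros(3)             # write through the "_grad" alias...
print(t.grad)                        # ...is visible when reading "grad"
```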