pytorch/torch/_dynamo/codegen.py
Aaron Orenstein 23695ea002 Fix dynamo use of list[int] in graph break (#145554)
This reintroduces the change backed out by #145393 and fixes the underlying problem.

Although using a BuiltinVariable was better than nothing when we saw a GenericAlias, it caused problems when there was a graph break and we had to reconstruct the original Python code: BuiltinVariable reconstructed it as a plain `list` instead of `list[int]`.

This changes it to use a TypingVariable instead and then teaches TypingVariable how to reconstruct.
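
For illustration, the affected pattern looks roughly like the following minimal sketch (hypothetical code, not the actual test body; the `eager` backend and the names are only for the example): a compiled function checks `list[int]` via `torch.jit.isinstance` and then hits a graph break, so Dynamo has to rebuild the `list[int]` expression when reconstructing the resume bytecode.

    import torch

    @torch.compile(backend="eager")
    def fn(x, xs):
        # list[int] is a types.GenericAlias that Dynamo must be able to rebuild
        if torch.jit.isinstance(xs, list[int]):
            x = x + 1
        torch._dynamo.graph_break()  # forces reconstruction of live locals
        return torch.jit.isinstance(xs, list[int]), x

    fn(torch.ones(3), [1, 2, 3])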

Original commit changeset: 77b9193acb23

python test/dynamo/test_repros.py ReproTests.test_graph_break_on_jit_isinstance

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145554
Approved by: https://github.com/anijain2305
ghstack dependencies: #145551, #145552, #145553
2025-01-30 22:21:40 +00:00


# mypy: allow-untyped-defs
import collections
import dataclasses
import re
import sys
import types
from collections import Counter
from typing import Optional
import torch.nn
from . import utils
from .bytecode_transformation import (
add_push_null,
add_push_null_call_function_ex,
create_call_function,
create_call_method,
create_dup_top,
create_instruction,
create_load_const,
create_load_method,
create_rot_n,
Instruction,
)
from .exc import IncorrectUsage, unimplemented
from .source import AttrSource, Source
from .utils import is_safe_constant, rot_n_helper
from .variables.base import ValueMutationExisting, VariableTracker
from .variables.functions import FunctionDecoratedByContextlibContextManagerVariable
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
NumpyNdarrayVariable,
SymNodeVariable,
TensorVariable,
UnspecializedPythonVariable,
)
from .variables.torch_function import TensorWithTFOverrideVariable
@dataclasses.dataclass
class GraphOutputEntry:
index: int
variable: VariableTracker
class PyCodegen:
"""
    Helper class used for constructing Python bytecode
"""
def __init__(
self,
tx=None,
root: Optional[torch.nn.Module] = None,
graph_output_var: Optional[str] = None,
tempvars=None,
overridden_sources=None,
) -> None:
self.root = root
self.top_of_stack: Optional[VariableTracker] = None
self.uses: Counter[VariableTracker] = collections.Counter()
self.graph_outputs: dict[int, GraphOutputEntry] = {}
self._output: list[Instruction] = []
# This determines which VariableTracker should be stored as locals, and
# maps the VariableTracker to the local variable name. Note that it
# could map to None initially, in which case we'll overwrite it to map
# to real temporary names via `add_cache`.
self.tempvars = tempvars or {}
self.tx = tx
self.graph_output_var = graph_output_var
self.code_options = self.tx.output.code_options
self.cell_and_freevars = self.tx.cell_and_freevars
self.new_var = self.tx.output.new_var
self.value_from_source: bool = True
# This serves as a way for codegen to use a different source; we need
# this because sometimes we can't easily modify the original source
# without affecting other components, e.g., guards.
self.overridden_sources: dict[Source, Source] = overridden_sources or {}
def restore_stack(self, stack_values, *, value_from_source=True):
prev = self.value_from_source
self.value_from_source &= value_from_source
try:
self.foreach(stack_values)
finally:
self.value_from_source = prev
def graph_output_vars(self):
return [x.variable for x in self.graph_outputs.values()]
def call_reconstruct(self, value):
res = value.reconstruct(self)
assert res is None, f"reconstruct!=None {value}"
def add_push_null(self, gen_fn, call_function_ex=False):
"""
`gen_fn` generates instructions via PyCodegen methods
that push a single callable to the stack.
`add_push_null` pushes a NULL to the stack before or after the
instructions generated by `gen_fn`, depending on Python version.
Will attempt to use the NULL push bit for instructions
with such bits (LOAD_GLOBAL 3.11+, LOAD_ATTR 3.12+, LOAD_SUPER_ATTR).
"""
old_len = len(self._output)
if sys.version_info < (3, 13):
            # gen_fn may emit DUP_TOP instead if TOS is not cleared, which
            # would cause problems since NULL will be pushed right before
            # the generated instructions on <= 3.12.
self.clear_tos()
gen_fn()
# inplace modify self._output
added_insts = self._output[old_len:]
del self._output[old_len:]
if call_function_ex:
self._output.extend(add_push_null_call_function_ex(added_insts))
else:
self._output.extend(add_push_null(added_insts))
if sys.version_info >= (3, 13):
# NULL will be at top of stack
self.clear_tos()
def __call__(self, value, allow_cache=True):
"""
Generate code such that top-of-stack (TOS) is set to value.
`allow_cache` is used to determine whether the following could happen,
when `value` is a `VariableTracker`:
1. if `value` was codegen-ed previously with `allow_cache=True` and
without using source, reuse the generated code by loading from top
of stack or tempvars.
2. emit code based on `value.source` to handle aliasing.
Notable effects:
1. `self.top_of_stack` will be set to `value`, if we don't codegen
`value` based on source.
2. `self.uses[value]` will increment, if we don't codegen `value` based
on source or cache/top-of-stack reuse; in other words, if we codegen
as if `value` is modelling some brand new python value.
"""
if isinstance(value, Source):
# If the source needs to be overridden, use the new one.
source = self.overridden_sources.get(value, value)
self.call_reconstruct(source)
# We don't support dup_top optimization for source yet.
self.clear_tos()
return
assert isinstance(value, VariableTracker)
output = self._output
graph_outputs = self.graph_outputs
if allow_cache:
if self.top_of_stack is value:
output.append(create_dup_top())
return
if self.tempvars.get(value) is not None:
output.append(self.create_load(self.tempvars[value]))
self.top_of_stack = value
return
if value.is_realized() and isinstance(
value, FunctionDecoratedByContextlibContextManagerVariable
):
raise IncorrectUsage(
"NYI: Returning a @contextmanager object from a torch.compile function"
)
# Dynamo normally prefers codegen from source to account for aliasing.
if value.source is not None and allow_cache:
# There's a corner case for export: for instance, if the computation
# graph is just identity on an input tensor, Dynamo would just emit
# a `LOAD_FAST` from the input source, rather than generating an
# identity FX graph.
#
# However, export wants to maximize graph capture; in the case
# above, export _wants to_ obtain an identity FX graph (despite it
            # appearing unnecessarily expensive for `torch.compile`), so we have
# the following option to override Dynamo's preference for codegen
            # from source. Moreover, this option applies recursively, for cases
# like input tensor being returned in a new dictionary.
#
# And why the `ValueMutationExisting` check? Not sure, so leaving it
# to keep the old behavior, as when `value_from_source` was
# introduced. TODO sort out the invariants among side effect,
# codegen and export.
if (
isinstance(value.mutation_type, ValueMutationExisting)
or self.value_from_source
):
return self(value.source)
if value.is_python_constant() and is_safe_constant(value.as_python_constant()):
output.append(self.create_load_const(value.as_python_constant()))
elif isinstance(value, TensorWithTFOverrideVariable):
graph_outputs_key = self.add_graph_output(value)
self.add_push_null(
lambda: self.load_import_from(utils.__name__, "to_subclass")
)
self.load_graph_output(graph_outputs[graph_outputs_key].index)
output.append(
self.create_load_global(
value.global_mangled_class_name(self.tx), add=True
)
)
output.extend(create_call_function(2, False))
elif (
isinstance(value, SymNodeVariable)
and value.python_type() == float
and not self.tx.export
):
# This is a little unusual; force the output convention to be a
# Tensor here. Don't do this for export because this is
# apparently load bearing for export tests (but I am a bit
# doubtful it actually works in the real world)
# NB: It works to add_graph_output on a computed expression
# as_tensor here, because we memoize as_tensor calls on
# SymNodeVariable!
graph_outputs_key = self.add_graph_output(
value.as_tensor(self.tx, torch.float64)
)
def gen_fn():
self.load_graph_output(graph_outputs[graph_outputs_key].index)
output.append(self.create_load_attr("item"))
self.add_push_null(gen_fn)
output.extend(create_call_function(0, False))
elif isinstance(
value,
(
TensorVariable,
SymNodeVariable,
UnspecializedPythonVariable,
NumpyNdarrayVariable,
),
):
graph_outputs_key = self.add_graph_output(value)
if isinstance(value, NumpyNdarrayVariable):
self.add_push_null(
lambda: self.load_import_from(utils.__name__, "to_numpy_helper")
)
self.load_graph_output(graph_outputs[graph_outputs_key].index)
output.extend(create_call_function(1, False))
elif isinstance(value, UnspecializedPythonVariable) and value.need_unwrap:
def gen_fn():
self.load_graph_output(graph_outputs[graph_outputs_key].index)
output.append(self.create_load_attr("item"))
self.add_push_null(gen_fn)
output.extend(create_call_function(0, False))
else:
self.load_graph_output(graph_outputs[graph_outputs_key].index)
elif isinstance(value, NNModuleVariable):
parts = value.module_key.split(".")
if parts[0] in self.code_options["co_varnames"]:
output.append(self.create_load(parts[0]))
parts = parts[1:]
else:
assert self.root is not None
output.append(self.create_load_const_unchecked(self.root))
for part in parts:
output.append(self.create_load_attr(part))
else:
self.uses[value] += 1
try:
self.call_reconstruct(value)
except NotImplementedError:
unimplemented(f"reconstruct: {value}")
if allow_cache and value in self.tempvars:
self._output.append(create_dup_top())
self.add_cache(value)
self.top_of_stack = value
def add_graph_output(self, value):
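        # Deduplicate graph outputs: each FX proxy (keyed by id) is registered
        # only once, and later loads reuse its index.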
graph_outputs_key = id(value.as_proxy())
if graph_outputs_key not in self.graph_outputs:
self.graph_outputs[graph_outputs_key] = GraphOutputEntry(
len(self.graph_outputs), value
)
return graph_outputs_key
def load_graph_output(self, index):
output = self._output
output.append(self.create_load(self.graph_output_var))
output.append(self.create_load_const(index))
output.append(self.create_binary_subscr())
def add_cache(self, value):
var = self.new_var()
self.tempvars[value] = var
self._output.append(self.create_store(var))
def foreach(self, items):
for i in items:
self(i)
def create_binary_subscr(self) -> Instruction:
return create_instruction("BINARY_SUBSCR")
def setup_globally_cached(self, name, value):
"""Store value in a new global"""
name = re.sub(r"[^a-zA-Z0-9_]+", "_", name)
f_globals = self.tx.f_globals
if name in f_globals:
assert id(f_globals[name]) == id(value)
else:
f_globals[name] = value
return [self.create_load_global(name, add=True)]
def clear_tos(self):
self.top_of_stack = None
def append_output(self, inst):
assert isinstance(inst, Instruction)
self._output.append(inst)
self.clear_tos()
def extend_output(self, insts):
assert all(isinstance(x, Instruction) for x in insts)
self._output.extend(insts)
self.clear_tos()
def get_instructions(self) -> list[Instruction]:
return self._output
def create_load(self, name) -> Instruction:
assert name in self.code_options["co_varnames"], f"{name} missing"
return create_instruction("LOAD_FAST", argval=name)
def create_load_closure(self, name) -> Instruction:
assert name in self.cell_and_freevars()
inst_name = "LOAD_FAST" if sys.version_info >= (3, 13) else "LOAD_CLOSURE"
return create_instruction(inst_name, argval=name)
def create_load_deref(self, name) -> Instruction:
assert name in self.cell_and_freevars()
return create_instruction("LOAD_DEREF", argval=name)
def create_store(self, name) -> Instruction:
assert name in self.code_options["co_varnames"], f"{name} missing"
return create_instruction("STORE_FAST", argval=name)
def create_store_deref(self, name) -> Instruction:
assert name in self.cell_and_freevars()
return create_instruction("STORE_DEREF", argval=name)
def create_load_global(self, name, add=False) -> Instruction:
if add:
self.tx.output.update_co_names(name)
assert name in self.code_options["co_names"], f"{name} not in co_names"
return create_instruction("LOAD_GLOBAL", argval=name)
def create_load_const(self, value) -> Instruction:
return create_load_const(value)
def create_load_const_unchecked(self, value) -> Instruction:
return create_load_const(value, checked=False)
def load_method(self, name):
self.tx.output.update_co_names(name)
self.append_output(create_load_method(name))
def call_method(self, nargs):
self.extend_output(create_call_method(nargs))
def create_load_attr(self, name) -> Instruction:
if name not in self.code_options["co_names"]:
self.code_options["co_names"] += (name,)
return create_instruction("LOAD_ATTR", argval=name)
def load_attr(self, name):
self.append_output(self.create_load_attr(name))
def create_load_attrs(self, names):
return [self.create_load_attr(name) for name in names.split(".")]
def create_store_attr(self, name) -> Instruction:
if name not in self.code_options["co_names"]:
self.code_options["co_names"] += (name,)
return create_instruction("STORE_ATTR", argval=name)
def store_attr(self, name):
self.append_output(self.create_store_attr(name))
def load_function_name(self, fn_name, push_null, num_on_stack=0):
"""Load the global fn_name on the stack num_on_stack down"""
output = []
if push_null and sys.version_info >= (3, 11):
output.extend(add_push_null(self.create_load_global(fn_name, add=True)))
if num_on_stack > 0:
output.extend(
[
*self.rot_n(num_on_stack + 2),
*self.rot_n(num_on_stack + 2),
]
)
else:
output.extend(
[
self.create_load_global(fn_name, add=True),
*self.rot_n(num_on_stack + 1),
]
)
return output
def rot_n(self, n):
try:
return create_rot_n(n)
except AttributeError:
# desired rotate bytecode doesn't exist, generate equivalent bytecode
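            # i.e. pack the top n stack values into a tuple, call a helper
            # that returns them rotated, and unpack the result back onto the
            # stack.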
return [
create_instruction("BUILD_TUPLE", arg=n),
self.create_load_const_unchecked(rot_n_helper(n)),
*create_rot_n(2),
create_instruction("CALL_FUNCTION_EX", arg=0),
create_instruction("UNPACK_SEQUENCE", arg=n),
]
def pop_null(self):
        # POP_TOP doesn't work for NULL, so we pop NULLs by pushing a no-op
        # function, calling it (which consumes the NULL), and popping the result.
assert sys.version_info >= (3, 11)
return [
self.create_load_const_unchecked(lambda: None),
# 3.13 swapped NULL and callable
*(
(create_instruction("SWAP", arg=2),)
if sys.version_info >= (3, 13)
else ()
),
*create_call_function(0, False),
create_instruction("POP_TOP"),
]
def pop_top(self):
self.append_output(create_instruction("POP_TOP"))
def call_function(self, nargs: int, push_null: bool):
self.extend_output(create_call_function(nargs, push_null=push_null))
def dup_top(self):
self.append_output(create_dup_top())
def store(self, varname):
self.append_output(self.create_store(varname))
def load_deref(self, varname):
self.append_output(self.create_load_deref(varname))
def make_function_with_closure(
self, fn_name: str, code: types.CodeType, push_null: bool, num_on_stack=0
):
freevars = code.co_freevars
assert freevars
output = self._output
def gen_fn():
# Emitting `LOAD_FAST/LOAD_CLOSURE` with names in `co_freevars`
# requires that in the generated bytecode, these cells would keep
# their original local names, which we ensure via
# `CellVariable.local_name`.
for var in freevars:
assert var in self.cell_and_freevars()
output.append(self.create_load_closure(var))
output.append(create_instruction("BUILD_TUPLE", arg=len(freevars)))
output.append(self.create_load_const(code))
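            # MAKE_FUNCTION is version dependent: before 3.11 it also expects
            # the qualified name on the stack, and on 3.13+ it takes no flags,
            # so the closure tuple (flag 0x08) is attached via
            # SET_FUNCTION_ATTRIBUTE instead.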
if sys.version_info < (3, 11):
output.append(self.create_load_const(fn_name))
if sys.version_info >= (3, 13):
output.extend(
[
create_instruction("MAKE_FUNCTION"),
create_instruction("SET_FUNCTION_ATTRIBUTE", arg=0x08),
]
)
else:
output.append(create_instruction("MAKE_FUNCTION", arg=0x08))
if push_null and sys.version_info >= (3, 11):
self.add_push_null(gen_fn)
output.extend(self.rot_n(num_on_stack + 2))
output.extend(self.rot_n(num_on_stack + 2))
else:
gen_fn()
output.extend(self.rot_n(num_on_stack + 1))
self.clear_tos()
def create_load_python_module(self, mod) -> Instruction:
"""
        Generate a LOAD_GLOBAL instruction to fetch a given python module,
        installing it into the frame's globals under a mangled name if needed.
"""
output = self.tx.output
global_scope = output.global_scope
name = re.sub(r"^.*[.]", "", mod.__name__)
if global_scope.get(name, None) is mod:
return self.create_load_global(name, add=True)
prefix = f"___module_{name}"
global_name = self.tx.output.install_global_by_id(prefix, mod)
return self.create_load_global(global_name, add=True)
def make_call_generated_code(self, fn_name: str) -> None:
"""Call the generated code function stored in fn_name"""
self.extend_output(self.load_function_name(fn_name, True))
graphargs = self.tx.output.graphargs
for arg in graphargs:
if arg.pass_arg_as_tensor:
self.add_push_null(
lambda: self.extend_output(
[
self.create_load_python_module(torch),
self.create_load_attr("_as_tensor_fullprec"),
]
)
)
self.call_reconstruct(arg)
self.extend_output(create_call_function(1, False))
else:
self.call_reconstruct(arg)
self.extend_output(create_call_function(len(graphargs), False))
def load_import_from(self, module_name, object_name) -> None:
self(AttrSource(self.tx.import_source(module_name), object_name))
def create_call_function_kw(self, nargs, kw_names, push_null) -> list[Instruction]:
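        # Keyword-argument calls are version dependent: 3.13+ loads kw_names
        # as a constant and uses CALL_KW; 3.11/3.12 emit KW_NAMES right before
        # CALL (3.12) or PRECALL (3.11); older versions use CALL_FUNCTION_KW.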
if sys.version_info >= (3, 13):
output = create_call_function(nargs, push_null)
assert output[-1].opname == "CALL"
output.insert(-1, self.create_load_const(kw_names))
output[-1] = create_instruction("CALL_KW", arg=nargs)
return output
elif sys.version_info >= (3, 11):
output = create_call_function(nargs, push_null)
if sys.version_info >= (3, 12):
idx = -1
expected_inst = "CALL"
else:
idx = -2
expected_inst = "PRECALL"
assert output[idx].opname == expected_inst
kw_names_inst = create_instruction("KW_NAMES", argval=kw_names)
output.insert(idx, kw_names_inst)
return output
return [
self.create_load_const(kw_names),
create_instruction("CALL_FUNCTION_KW", arg=nargs),
]
def create_delete(self, value) -> Instruction:
return create_instruction("DELETE_FAST", argval=value)