mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-01 04:54:55 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			708 lines
		
	
	
		
			28 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			708 lines
		
	
	
		
			28 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| This module provides utilities for generating Python bytecode in PyTorch's Dynamo system.
 | |
| It includes functionality for:
 | |
| - Constructing bytecode sequences for Python operations
 | |
| - Managing stack operations and variable tracking
 | |
| - Handling graph outputs and their conversions
 | |
| - Supporting different Python versions (3.11+, 3.12+, 3.13+)
 | |
| - Converting high-level operations to low-level bytecode instructions
 | |
| - Managing constant loading and attribute access
 | |
| - Supporting function creation and closure handling
 | |
| """
 | |
| 
 | |
| import collections
 | |
| import dataclasses
 | |
| import re
 | |
| import sys
 | |
| import types
 | |
| from collections import Counter
 | |
| from collections.abc import Callable, Iterable
 | |
| from typing import Any, Optional, TYPE_CHECKING, Union
 | |
| 
 | |
| import torch.nn
 | |
| from torch.utils._ordered_set import OrderedSet
 | |
| 
 | |
| from . import config, graph_break_hints, utils
 | |
| from .bytecode_transformation import (
 | |
|     add_push_null,
 | |
|     add_push_null_call_function_ex,
 | |
|     create_binary_subscr,
 | |
|     create_build_tuple,
 | |
|     create_call_function,
 | |
|     create_call_function_ex,
 | |
|     create_call_method,
 | |
|     create_dup_top,
 | |
|     create_instruction,
 | |
|     create_load_const,
 | |
|     create_load_method,
 | |
|     create_rot_n,
 | |
|     Instruction,
 | |
| )
 | |
| from .exc import IncorrectUsage, unimplemented_v2
 | |
| from .source import AttrSource, ChainedSource, DictGetItemSource, Source
 | |
| from .utils import is_safe_constant, rot_n_helper
 | |
| from .variables.base import ValueMutationExisting, VariableTracker
 | |
| from .variables.functions import (
 | |
|     ContextlibContextManagerLocalGeneratorObjectVariable,
 | |
|     LocalGeneratorObjectVariable,
 | |
| )
 | |
| from .variables.nn_module import NNModuleVariable
 | |
| from .variables.tensor import (
 | |
|     NumpyNdarrayVariable,
 | |
|     SymNodeVariable,
 | |
|     TensorVariable,
 | |
|     UnspecializedPythonVariable,
 | |
| )
 | |
| from .variables.torch_function import TensorWithTFOverrideVariable
 | |
| 
 | |
| 
 | |
| if TYPE_CHECKING:
 | |
|     from torch._dynamo.variables.builder import GraphArg
 | |
| 
 | |
|     from .symbolic_convert import InstructionTranslatorBase
 | |
| 
 | |
| 
 | |
@dataclasses.dataclass
class GraphOutputEntry:
    """Record for one value that PyCodegen registered as a graph output."""

    # Position assigned at registration time (the number of entries already
    # registered); `PyCodegen.load_graph_output` uses it to subscript the
    # graph-output local.
    index: int
    # The variable tracker that was registered for this output slot.
    variable: VariableTracker
 | |
| 
 | |
| 
 | |
| class PyCodegen:
 | |
|     """
 | |
|     Helper class used for constructing Python bytecode
 | |
|     """
 | |
| 
 | |
|     def __init__(
 | |
|         self,
 | |
|         tx: "InstructionTranslatorBase",
 | |
|         root: Optional[torch.nn.Module] = None,
 | |
|         graph_output_var: Optional[str] = None,
 | |
|         tempvars: Optional[dict[Union[VariableTracker, Source], Any]] = None,
 | |
|         overridden_sources: Optional[dict[Source, Source]] = None,
 | |
|     ) -> None:
 | |
|         self.root = root
 | |
|         self.top_of_stack: Optional[Union[VariableTracker, Source]] = None
 | |
|         self.uses: Counter[Union[VariableTracker, Source]] = collections.Counter()
 | |
|         self.graph_outputs: dict[int, GraphOutputEntry] = {}
 | |
|         self._output: list[Instruction] = []
 | |
|         # This determines which VariableTracker/Source should be stored as
 | |
|         # locals, and maps the VariableTracker/Source to the local variable
 | |
|         # name. Note that it could map to None initially, in which case we'll
 | |
|         # overwrite it to map to real temporary names via `add_cache`.
 | |
|         self.tempvars: dict[Union[VariableTracker, Source], Any] = tempvars or {}
 | |
|         self.tx = tx
 | |
|         self.graph_output_var = graph_output_var
 | |
|         self.code_options = self.tx.output.code_options
 | |
|         self.cell_and_freevars = self.tx.cell_and_freevars
 | |
|         self.new_var = self.tx.output.new_var
 | |
|         self.value_from_source: bool = True
 | |
|         # This serves as a way for codegen to use a different source; we need
 | |
|         # this because sometimes we can't easily modify the original source
 | |
|         # without affecting other components, e.g., guards.
 | |
|         self.overridden_sources: dict[Source, Source] = overridden_sources or {}
 | |
| 
 | |
|     def restore_stack(
 | |
|         self, stack_values: list[Any], *, value_from_source: bool = True
 | |
|     ) -> None:
 | |
|         prev = self.value_from_source
 | |
|         self.value_from_source &= value_from_source
 | |
|         try:
 | |
|             self.foreach(stack_values)
 | |
|         finally:
 | |
|             self.value_from_source = prev
 | |
| 
 | |
|     def graph_output_vars(self) -> list[VariableTracker]:
 | |
|         return [x.variable for x in self.graph_outputs.values()]
 | |
| 
 | |
|     def call_reconstruct(
 | |
|         self, value: Union[VariableTracker, Source, "GraphArg"]
 | |
|     ) -> None:
 | |
|         res = value.reconstruct(self)
 | |
|         assert res is None, f"reconstruct!=None {value}"
 | |
| 
 | |
|     def add_push_null(
 | |
|         self, gen_fn: Callable[[], None], call_function_ex: bool = False
 | |
|     ) -> None:
 | |
|         """
 | |
|         `gen_fn` generates instructions via PyCodegen methods
 | |
|         that push a single callable to the stack.
 | |
| 
 | |
|         `add_push_null` pushes a NULL to the stack before or after the
 | |
|         instructions generated by `gen_fn`, depending on Python version.
 | |
| 
 | |
|         Will attempt to use the NULL push bit for instructions
 | |
|         with such bits (LOAD_GLOBAL 3.11+, LOAD_ATTR 3.12+, LOAD_SUPER_ATTR).
 | |
|         """
 | |
|         old_len = len(self._output)
 | |
|         if sys.version_info < (3, 13):
 | |
|             # gen_fn may DUP_TOP instead if TOS is not cleared.
 | |
|             # Will cause problems since NULL will be pushed right
 | |
|             # before the generated instructions in <= 3.12
 | |
|             self.clear_tos()
 | |
|         gen_fn()
 | |
|         # inplace modify self._output
 | |
|         added_insts = self._output[old_len:]
 | |
|         del self._output[old_len:]
 | |
|         if call_function_ex:
 | |
|             self._output.extend(add_push_null_call_function_ex(added_insts))
 | |
|         else:
 | |
|             self._output.extend(add_push_null(added_insts))
 | |
|         if sys.version_info >= (3, 13):
 | |
|             # NULL will be at top of stack
 | |
|             self.clear_tos()
 | |
| 
 | |
    def __call__(
        self, value: Union[VariableTracker, Source, None], allow_cache: bool = True
    ) -> None:
        """
        Generate code such that top-of-stack (TOS) is set to value.

        `allow_cache` controls the behavior in the following manner. `value` can
        either be a VariableTracker or a Source.

        If `value` is a `Source`, `allow_cache` must be True (invariant asserted
        below). If the source was reconstructed earlier, we will reuse the
        generated code by loading from top of stack or tempvars.

        If `value` is a `VariableTracker`, we have the following cases:

        1) `allow_cache=True`
            a) If the value.source is not None, we will emit the code based on
            `value.source` to handle aliasing.
            b) If value.source is None (example reconstructing a local list
            returned by the compiled function), we will reconstruct the variable
            tracker (w/o any source) to emit bytecode that generates a new
            python object.

            In both cases of value.source being None or not, if the value was
            reconstructed earlier, we will reuse the generated code by loading from
            top of stack or tempvars.

        2) `allow_cache=False` - This is a special case (allow_cache defaults to
        True).
            a) If the value.source is not None, we reconstruct the variable
            tracker and emit a new python object. You might wonder what about
            aliasing? The place where we use this config also has the followup
            code where the original python object is assigned to this new python
            value to handle aliasing (check side_effects.py and search for
            allow_cache=False).

            b) If value.source is None, this is not allowed

        Notable effects:
        1. `self.top_of_stack` will be set to `value`, if we don't codegen
           `value` based on source.
        2. `self.uses[value]` will increment, unless (a). we codegen via
            `top_of_stack` or cached `tempvars`, or (b). `value` has special VT
            types like `NNModuleVariable`, etc.

        Raises:
            IncorrectUsage: if `value` is a realized
                `ContextlibContextManagerLocalGeneratorObjectVariable`.
            AssertionError: if `value` is None, is neither a Source nor a
                VariableTracker, or is a Source with `allow_cache` not True.
        """
        assert value is not None
        if isinstance(value, Source):
            # If the source needs to be overridden, use the new one.
            source = self.overridden_sources.get(value, value)
            assert allow_cache is True, "allow_cache must be True for Source"
            # NOTE(review): this identity check uses the original `value`,
            # whereas `top_of_stack` is only ever set to the (possibly
            # overridden) `source` below — confirm this is intended when an
            # override is in effect.
            if self.top_of_stack is value:
                self._output.append(create_dup_top())
                return

            # A non-None tempvars entry is the name of a local that already
            # holds this source's value; reload it instead of reconstructing.
            if self.tempvars.get(source) is not None:
                self._output.append(self.create_load(self.tempvars[source]))
                self.top_of_stack = source
                return

            self.uses[source] += 1
            try:
                self.call_reconstruct(source)
            except NotImplementedError:
                unimplemented_v2(
                    gb_type="Reconstruction failure: source.reconstruct not implemented",
                    context=str(source),
                    explanation=f"Dynamo has no bytecode reconstruction implemented for {type(source)} variable {source}.",
                    hints=[*graph_break_hints.DYNAMO_BUG],
                )
            # The source was marked as a temp (key present, value None at this
            # point): duplicate TOS and store one copy into a new local.
            if source in self.tempvars:
                self._output.append(create_dup_top())
                self.add_cache(source)
            self.top_of_stack = source

            return

        assert isinstance(value, VariableTracker)
        output = self._output
        graph_outputs = self.graph_outputs

        if allow_cache:
            # Cheapest path: the value is already on top of the stack.
            if self.top_of_stack is value:
                output.append(create_dup_top())
                return

            # Next cheapest: the value was cached in a local earlier.
            if self.tempvars.get(value) is not None:
                output.append(self.create_load(self.tempvars[value]))
                self.top_of_stack = value
                return

        if value.is_realized() and isinstance(
            value, ContextlibContextManagerLocalGeneratorObjectVariable
        ):
            raise IncorrectUsage(
                "NYI: Returning a @contextmanager object from a torch.compile function"
            )

        # Dynamo normally prefers codegen from source to account for aliasing.
        if (
            value.source is not None
            and allow_cache
            and not (
                value.is_realized() and isinstance(value, LocalGeneratorObjectVariable)
            )
        ):
            # There's a corner case for export: for instance, if the computation
            # graph is just identity on an input tensor, Dynamo would just emit
            # a `LOAD_FAST` from the input source, rather than generating an
            # identity FX graph.
            #
            # However, export wants to maximize graph capture; in the case
            # above, export _wants to_ obtain an identity FX graph (despite it
            # appears unnecessarily expensive for `torch.compile`), so we have
            # the following option to override Dynamo's preference for codegen
            # from source. Moreover, this option applies recursively, for cases
            # like input tensor being returned in a new dictionary.
            #
            # And why the `ValueMutationExisting` check? Not sure, so leaving it
            # to keep the old behavior, as when `value_from_source` was
            # introduced. TODO sort out the invariants among side effect,
            # codegen and export.
            if (
                isinstance(value.mutation_type, ValueMutationExisting)
                or self.value_from_source
            ):
                # Delegates to the Source branch above; `top_of_stack` ends up
                # set to the source, not to `value`.
                return self(value.source)

        if value.is_python_constant() and is_safe_constant(value.as_python_constant()):
            output.append(self.create_load_const(value.as_python_constant()))
        elif isinstance(value, TensorWithTFOverrideVariable):
            graph_outputs_key = self.add_graph_output(value)

            # Emits a call of `to_subclass(<graph output>, <mangled class global>)`.
            self.add_push_null(
                lambda: self.load_import_from(utils.__name__, "to_subclass")
            )
            self.load_graph_output(graph_outputs[graph_outputs_key].index)
            output.append(
                self.create_load_global(
                    value.global_mangled_class_name(self.tx),  # type: ignore[arg-type]
                    add=True,
                )
            )
            output.extend(create_call_function(2, False))
        elif (
            isinstance(value, SymNodeVariable)
            and value.python_type() is float
            and not self.tx.export
        ):
            # This is a little unusual; force the output convention to be a
            # Tensor here.  Don't do this for export because this is
            # apparently load bearing for export tests (but I am a bit
            # doubtful it actually works in the real world)
            # NB: It works to add_graph_output on a computed expression
            # as_tensor here, because we memoize as_tensor calls on
            # SymNodeVariable!
            graph_outputs_key = self.add_graph_output(
                value.as_tensor(self.tx, torch.float64)
            )

            # Emits `<graph output>.item()` to turn the tensor output back
            # into a Python value.
            def gen_fn() -> None:
                self.load_graph_output(graph_outputs[graph_outputs_key].index)
                output.append(self.create_load_attr("item"))

            self.add_push_null(gen_fn)
            output.extend(create_call_function(0, False))
        elif isinstance(
            value,
            (
                TensorVariable,
                SymNodeVariable,
                UnspecializedPythonVariable,
                NumpyNdarrayVariable,
            ),
        ):
            graph_outputs_key = self.add_graph_output(value)

            if isinstance(value, NumpyNdarrayVariable):
                # Emits `to_numpy_helper(<graph output>)`.
                self.add_push_null(
                    lambda: self.load_import_from(utils.__name__, "to_numpy_helper")
                )
                self.load_graph_output(graph_outputs[graph_outputs_key].index)
                output.extend(create_call_function(1, False))
            elif isinstance(value, UnspecializedPythonVariable) and value.need_unwrap:

                # Emits `<graph output>.item()`.
                def gen_fn() -> None:
                    self.load_graph_output(graph_outputs[graph_outputs_key].index)
                    output.append(self.create_load_attr("item"))

                self.add_push_null(gen_fn)
                output.extend(create_call_function(0, False))
            else:
                # Plain case: the graph output itself is the value.
                self.load_graph_output(graph_outputs[graph_outputs_key].index)
        elif isinstance(value, NNModuleVariable):
            # Walk the dotted module key as an attribute chain, starting
            # either from a local variable (when the first component names
            # one) or from the root module loaded as a constant.
            parts = value.module_key.split(".")
            if parts[0] in self.code_options["co_varnames"]:
                output.append(self.create_load(parts[0]))
                parts = parts[1:]
            else:
                assert self.root is not None
                output.append(self.create_load_const_unchecked(self.root))
            for part in parts:
                output.append(self.create_load_attr(part))
        else:
            # Generic fallback: let the variable tracker reconstruct itself.
            self.uses[value] += 1
            try:
                self.call_reconstruct(value)
            except NotImplementedError:
                unimplemented_v2(
                    gb_type="Reconstruction failure",
                    context=str(value),
                    explanation=f"Dynamo has no bytecode reconstruction implemented for sourceless variable {value}.",
                    hints=[
                        "If Dynamo is attempting to trace a return statement and your code is attempting to return a variable "
                        "that Dynamo cannot reconstruct, then remove it from the return statement.",
                        *graph_break_hints.CAUSED_BY_EARLIER_GRAPH_BREAK,
                        "Report an issue to PyTorch if you need reconstrtuction support. Note that objects that don't have "
                        "reconstruction rules may be fundamentally unreconstructable.",
                    ],
                )
            # Only this fallback branch caches into a new local, and only when
            # the value was pre-registered in `tempvars`.
            if allow_cache and value in self.tempvars:
                self._output.append(create_dup_top())
                self.add_cache(value)

        self.top_of_stack = value
 | |
| 
 | |
|     def add_graph_output(self, value: VariableTracker) -> int:
 | |
|         graph_outputs_key = id(value.as_proxy())
 | |
|         if graph_outputs_key not in self.graph_outputs:
 | |
|             self.graph_outputs[graph_outputs_key] = GraphOutputEntry(
 | |
|                 len(self.graph_outputs), value
 | |
|             )
 | |
|         return graph_outputs_key
 | |
| 
 | |
|     def load_graph_output(self, index: int) -> None:
 | |
|         output = self._output
 | |
|         assert self.graph_output_var is not None
 | |
|         output.append(self.create_load(self.graph_output_var))
 | |
|         output.append(self.create_load_const(index))
 | |
|         output.append(self.create_binary_subscr())
 | |
| 
 | |
|     def add_cache(self, value: Union[VariableTracker, Source]) -> None:
 | |
|         var = self.new_var()
 | |
|         self.tempvars[value] = var
 | |
|         self._output.append(self.create_store(var))
 | |
| 
 | |
|     def foreach(self, items: Iterable[Union[VariableTracker, Source]]) -> None:
 | |
|         for i in items:
 | |
|             self(i)
 | |
| 
 | |
|     def create_binary_subscr(self) -> Instruction:
 | |
|         return create_binary_subscr()
 | |
| 
 | |
|     def setup_globally_cached(self, name: str, value: Any) -> list[Instruction]:
 | |
|         """Store value in a new global"""
 | |
|         name = re.sub(r"[^a-zA-Z0-9_]+", "_", name)
 | |
|         f_globals = self.tx.f_globals
 | |
|         if name in f_globals:
 | |
|             assert id(f_globals[name]) == id(value)
 | |
|         else:
 | |
|             f_globals[name] = value
 | |
|         return [self.create_load_global(name, add=True)]
 | |
| 
 | |
    def clear_tos(self) -> None:
        """Forget which value is tracked as top-of-stack, so the next request cannot be served by duplicating TOS."""
        self.top_of_stack = None
 | |
| 
 | |
    def append_output(self, inst: Instruction) -> None:
        """Append a single instruction and invalidate the tracked top-of-stack."""
        assert isinstance(inst, Instruction)
        self._output.append(inst)
        # The appended instruction may change the stack, so the previously
        # tracked TOS value can no longer be trusted.
        self.clear_tos()
 | |
| 
 | |
    def extend_output(self, insts: list[Instruction]) -> None:
        """Append several instructions and invalidate the tracked top-of-stack."""
        assert all(isinstance(x, Instruction) for x in insts)
        self._output.extend(insts)
        # Same reasoning as append_output: the tracked TOS is stale now.
        self.clear_tos()
 | |
| 
 | |
    def get_instructions(self) -> list[Instruction]:
        """Return the internal instruction list itself (not a copy)."""
        return self._output
 | |
| 
 | |
    def create_load(self, name: str) -> Instruction:
        """Build a LOAD_FAST for the local `name`, which must already be in `co_varnames`."""
        assert name in self.code_options["co_varnames"], f"{name} missing"
        return create_instruction("LOAD_FAST", argval=name)
 | |
| 
 | |
|     def create_load_closure(self, name: str) -> Instruction:
 | |
|         assert name in self.cell_and_freevars()
 | |
|         inst_name = "LOAD_FAST" if sys.version_info >= (3, 13) else "LOAD_CLOSURE"
 | |
|         return create_instruction(inst_name, argval=name)
 | |
| 
 | |
    def create_load_deref(self, name: str) -> Instruction:
        """Build a LOAD_DEREF for `name`, which must be one of the translator's cell/free variables."""
        assert name in self.cell_and_freevars()
        return create_instruction("LOAD_DEREF", argval=name)
 | |
| 
 | |
    def create_store(self, name: str) -> Instruction:
        """Build a STORE_FAST for the local `name`, which must already be in `co_varnames`."""
        assert name in self.code_options["co_varnames"], f"{name} missing"
        return create_instruction("STORE_FAST", argval=name)
 | |
| 
 | |
    def create_store_deref(self, name: str) -> Instruction:
        """Build a STORE_DEREF for `name`, which must be one of the translator's cell/free variables."""
        assert name in self.cell_and_freevars()
        return create_instruction("STORE_DEREF", argval=name)
 | |
| 
 | |
    def create_load_global(self, name: str, add: bool = False) -> Instruction:
        """
        Build a LOAD_GLOBAL for `name`.

        Args:
            name: global name to load; must be in `co_names` by the time the
                assertion runs.
            add: when True, register `name` through the output graph's
                `update_co_names` first.
        """
        if add:
            self.tx.output.update_co_names(name)
        assert name in self.code_options["co_names"], f"{name} not in co_names"
        return create_instruction("LOAD_GLOBAL", argval=name)
 | |
| 
 | |
    def create_load_const(self, value: Any) -> Instruction:
        """Build a constant-load instruction for `value` via the module-level `create_load_const` (default checking)."""
        return create_load_const(value)
 | |
| 
 | |
    def create_load_const_unchecked(self, value: Any) -> Instruction:
        """Build a constant-load instruction for `value` with `checked=False` passed to the module-level helper."""
        return create_load_const(value, checked=False)
 | |
| 
 | |
|     def load_method(self, name: str) -> None:
 | |
|         self.tx.output.update_co_names(name)
 | |
|         self.append_output(create_load_method(name))
 | |
| 
 | |
|     def call_method(self, nargs: int) -> None:
 | |
|         self.extend_output(create_call_method(nargs))
 | |
| 
 | |
|     def create_load_attr(self, name: str) -> Instruction:
 | |
|         if name not in self.code_options["co_names"]:
 | |
|             self.code_options["co_names"] += (name,)
 | |
|         return create_instruction("LOAD_ATTR", argval=name)
 | |
| 
 | |
|     def load_attr(self, name: str) -> None:
 | |
|         self.append_output(self.create_load_attr(name))
 | |
| 
 | |
|     def create_load_attrs(self, names: str) -> list[Instruction]:
 | |
|         return [self.create_load_attr(name) for name in names.split(".")]
 | |
| 
 | |
|     def create_store_attr(self, name: str) -> Instruction:
 | |
|         if name not in self.code_options["co_names"]:
 | |
|             self.code_options["co_names"] += (name,)
 | |
|         return create_instruction("STORE_ATTR", argval=name)
 | |
| 
 | |
|     def store_attr(self, name: str) -> None:
 | |
|         self.append_output(self.create_store_attr(name))
 | |
| 
 | |
|     def load_function_name(
 | |
|         self, fn_name: str, push_null: bool, num_on_stack: int = 0
 | |
|     ) -> list[Instruction]:
 | |
|         """Load the global fn_name on the stack num_on_stack down"""
 | |
|         output = []
 | |
|         if push_null and sys.version_info >= (3, 11):
 | |
|             output.extend(add_push_null(self.create_load_global(fn_name, add=True)))
 | |
|             if num_on_stack > 0:
 | |
|                 output.extend(
 | |
|                     [
 | |
|                         *self.rot_n(num_on_stack + 2),
 | |
|                         *self.rot_n(num_on_stack + 2),
 | |
|                     ]
 | |
|                 )
 | |
|         else:
 | |
|             output.extend(
 | |
|                 [
 | |
|                     self.create_load_global(fn_name, add=True),
 | |
|                     *self.rot_n(num_on_stack + 1),
 | |
|                 ]
 | |
|             )
 | |
|         return output
 | |
| 
 | |
|     def rot_n(self, n: int) -> list[Instruction]:
 | |
|         try:
 | |
|             return create_rot_n(n)
 | |
|         except AttributeError:
 | |
|             # desired rotate bytecode doesn't exist, generate equivalent bytecode
 | |
|             return [
 | |
|                 create_build_tuple(n),
 | |
|                 self.create_load_const_unchecked(rot_n_helper(n)),
 | |
|                 *create_rot_n(2),
 | |
|                 *create_call_function_ex(False, False),
 | |
|                 create_instruction("UNPACK_SEQUENCE", arg=n),
 | |
|             ]
 | |
| 
 | |
|     def pop_top(self) -> None:
 | |
|         self.append_output(create_instruction("POP_TOP"))
 | |
| 
 | |
|     def call_function(self, nargs: int, push_null: bool) -> None:
 | |
|         self.extend_output(create_call_function(nargs, push_null=push_null))
 | |
| 
 | |
|     def dup_top(self) -> None:
 | |
|         self.append_output(create_dup_top())
 | |
| 
 | |
|     def store(self, varname: str) -> None:
 | |
|         self.append_output(self.create_store(varname))
 | |
| 
 | |
|     def load_deref(self, varname: str) -> None:
 | |
|         self.append_output(self.create_load_deref(varname))
 | |
| 
 | |
|     def make_function_with_closure(
 | |
|         self,
 | |
|         fn_name: str,
 | |
|         code: types.CodeType,
 | |
|     ) -> None:
 | |
|         """Creates a closure with code object `code`.
 | |
| 
 | |
|         Expects the TOS to be the tuple of cells to use for this closure.
 | |
|         TOS will be popped to create the closure.
 | |
|         Args:
 | |
|             - fn_name: name of the function
 | |
|             - code: code object of the function
 | |
|                 (does not include the tuple of cells on the TOS)
 | |
|         """
 | |
|         output = self._output
 | |
| 
 | |
|         output.append(self.create_load_const(code))
 | |
|         if sys.version_info < (3, 11):
 | |
|             output.append(self.create_load_const(fn_name))
 | |
|         if sys.version_info >= (3, 13):
 | |
|             output.extend(
 | |
|                 [
 | |
|                     create_instruction("MAKE_FUNCTION"),
 | |
|                     create_instruction("SET_FUNCTION_ATTRIBUTE", arg=0x08),
 | |
|                 ]
 | |
|             )
 | |
|         else:
 | |
|             output.append(create_instruction("MAKE_FUNCTION", arg=0x08))
 | |
| 
 | |
|         self.clear_tos()
 | |
| 
 | |
|     def create_load_python_module(self, mod: types.ModuleType) -> Instruction:
 | |
|         """
 | |
|         Generate a LOAD_GLOBAL instruction to fetch a given python module.
 | |
|         """
 | |
|         output = self.tx.output
 | |
|         global_scope = output.global_scope
 | |
|         name = re.sub(r"^.*[.]", "", mod.__name__)
 | |
|         if global_scope.get(name, None) is mod:
 | |
|             return self.create_load_global(name, add=True)
 | |
|         prefix = f"___module_{name}"
 | |
|         global_name = self.tx.output.install_global_by_id(prefix, mod)
 | |
|         return self.create_load_global(global_name, add=True)
 | |
| 
 | |
|     def mark_source_temp(self, source: Source) -> None:
 | |
|         """
 | |
|         Mark a source as a temp variable, so that it can be reused.
 | |
|         """
 | |
|         if source not in self.tempvars:
 | |
|             self.tempvars[source] = None
 | |
| 
 | |
    def make_call_generated_code(self, fn_name: str) -> None:
        """
        Call the generated code function stored in fn_name.

        Emits, in order: the load of the global `fn_name` (with a NULL push
        where applicable), the reconstruction of every graph argument, and a
        call taking `len(graphargs)` arguments. When
        `config.record_runtime_overhead` is set, the argument reconstruction
        is bracketed by calls to `record_pregraph_bytecode_enter` and
        `record_pregraph_bytecode_exit`.
        """
        self.extend_output(self.load_function_name(fn_name, True))

        graphargs = self.tx.output.graphargs

        seen_sources: OrderedSet[Source] = OrderedSet()

        def collect_temp_source(source: Source) -> None:
            if source in seen_sources:
                # This source is used at least twice, so it can be reused
                self.mark_source_temp(source)
                # Don't trace source further. This prevents us from marking too
                # many nodes as temp sources.
                return

            seen_sources.add(source)

            # First visit: descend into the sources this one is built from.
            if isinstance(source, ChainedSource):
                collect_temp_source(source.base)

            if isinstance(source, DictGetItemSource) and isinstance(
                source.index, Source
            ):
                collect_temp_source(source.index)

        # Collect all the sources that are used more than once, so that we can
        # generate tmp variables in the generated pre-graph bytecode. This
        # essentially implements CSE.
        for arg in graphargs:
            if arg.source is not None:
                collect_temp_source(arg.source)

        cm_var = None
        if config.record_runtime_overhead:
            # Record the pregraph bytecode start
            self.add_push_null(
                lambda: self.load_import_from(
                    utils.__name__, "record_pregraph_bytecode_enter"
                )
            )
            self.extend_output(create_call_function(0, False))
            # Keep the enter call's result in a fresh local; it is handed to
            # the exit call below.
            cm_var = self.new_var()
            self.store(cm_var)

        for arg in graphargs:
            if arg.pass_arg_as_tensor:
                # Wrap the reconstructed argument as
                # `torch._as_tensor_fullprec(<arg>)`.
                self.add_push_null(
                    lambda: self.extend_output(
                        [
                            self.create_load_python_module(torch),
                            self.create_load_attr("_as_tensor_fullprec"),
                        ]
                    )
                )
                self.call_reconstruct(arg)
                self.extend_output(create_call_function(1, False))
            else:
                self.call_reconstruct(arg)

        if config.record_runtime_overhead:
            # Record the pregraph bytecode end
            self.add_push_null(
                lambda: self.load_import_from(
                    utils.__name__, "record_pregraph_bytecode_exit"
                )
            )
            assert cm_var is not None
            self.extend_output([self.create_load(cm_var)])
            self.extend_output(create_call_function(1, False))
            # Discard the exit call's return value so only the graph
            # arguments remain above the function on the stack.
            self.pop_top()

        self.extend_output(create_call_function(len(graphargs), False))
 | |
| 
 | |
    def create_import_name(self, module_name: str) -> Instruction:
        """Build an IMPORT_NAME instruction whose argval is `module_name`."""
        return create_instruction("IMPORT_NAME", argval=module_name)
 | |
| 
 | |
|     def load_import_from(self, module_name: str, object_name: str) -> None:
 | |
|         source = AttrSource(self.tx.import_source(module_name), object_name)
 | |
|         # Note: This approach is somewhat aggressive because typically, a source is marked
 | |
|         # as a tempvar only when it is used more than once. In this case, we're marking it
 | |
|         # as a tempvar without performing that analysis. However, this is a simple solution,
 | |
|         # and in many cases, load imports are reused multiple times.
 | |
|         self.mark_source_temp(source)
 | |
|         self(source)
 | |
| 
 | |
|     def create_call_function_kw(
 | |
|         self, nargs: int, kw_names: Iterable[str], push_null: bool
 | |
|     ) -> list[Instruction]:
 | |
|         if sys.version_info >= (3, 13):
 | |
|             output = create_call_function(nargs, push_null)
 | |
|             assert output[-1].opname == "CALL"
 | |
|             output.insert(-1, self.create_load_const(kw_names))
 | |
|             output[-1] = create_instruction("CALL_KW", arg=nargs)
 | |
|             return output
 | |
|         elif sys.version_info >= (3, 11):
 | |
|             output = create_call_function(nargs, push_null)
 | |
|             if sys.version_info >= (3, 12):
 | |
|                 idx = -1
 | |
|                 expected_inst = "CALL"
 | |
|             else:
 | |
|                 idx = -2
 | |
|                 expected_inst = "PRECALL"
 | |
|             assert output[idx].opname == expected_inst
 | |
|             kw_names_inst = create_instruction("KW_NAMES", argval=kw_names)
 | |
|             output.insert(idx, kw_names_inst)
 | |
|             return output
 | |
|         return [
 | |
|             self.create_load_const(kw_names),
 | |
|             create_instruction("CALL_FUNCTION_KW", arg=nargs),
 | |
|         ]
 | |
| 
 | |
    def create_delete(self, value: object) -> Instruction:
        """Build a DELETE_FAST instruction whose argval is `value`."""
        return create_instruction("DELETE_FAST", argval=value)
 |