Files
pytorch/torch/_dynamo/variables/__init__.py
Animesh Jain 4308b8a28f [dynamo] Support torch.fx.traceback.annotate (#164678)
Builds on top of https://github.com/pytorch/pytorch/pull/163673 and https://github.com/pytorch/pytorch/pull/164174. This will be used in the followup PRs to apply regional inductor compilation.

The existing implementation let Dynamo trace into `torch.fx.traceback.annotate`, but that's not what we want. We want Dynamo to essentially run the `torch.fx.traceback.annotate` function eagerly, so that every FX node created in the Dynamo FX graph has the custom node meta.

What does not work?
* We still have to enter the `torch.fx.traceback.preserve_node_meta()` context manager in the user code because CI was unhappy. This can be fixed, but it will take some perseverance.
* This does not work with graph breaks yet. But we can solve that problem, if needed, in a separate PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164678
Approved by: https://github.com/SherlockNoMad, https://github.com/jansel, https://github.com/xmfan
2025-10-08 22:41:00 +00:00

232 lines
6.7 KiB
Python

"""
This package implements variable tracking and symbolic execution capabilities for Dynamo,
which are essential for converting Python code into FX graphs. It provides a comprehensive
set of variable types that handle different Python constructs during tracing.
Each variable type (like BuiltinVariable, TensorVariable, NNModuleVariable, etc.) is responsible
for tracking and symbolically executing operations on specific Python objects. This enables
Dynamo to:
- Track the flow of values through Python code
- Maintain correct semantics during graph conversion
- Handle complex Python features like context managers, iterators, and custom objects
- Support both eager and symbolic execution modes
The VariableTracker base class provides the foundation for all variable types, with each
subclass implementing specific behavior for different Python constructs. This modular design
allows Dynamo to accurately trace and optimize Python code while preserving its semantics.
"""
from .base import VariableTracker
from .builtin import BuiltinVariable
from .constant import ConstantVariable, EnumVariable
from .ctx_manager import (
CatchWarningsCtxManagerVariable,
ContextWrappingVariable,
CUDADeviceVariable,
DeterministicAlgorithmsVariable,
DisabledSavedTensorsHooksVariable,
DualLevelContextManager,
DynamoConfigPatchVariable,
ErrorOnGraphBreakVariable,
FSDPParamGroupUseTrainingStateVariable,
FxTracebackAnnotateVariable,
GradIncrementNestingCtxManagerVariable,
GradInplaceRequiresGradCtxManagerVariable,
GradModeVariable,
InferenceModeVariable,
JvpIncrementNestingCtxManagerVariable,
SDPAKernelVariable,
SetFwdGradEnabledContextManager,
StreamContextVariable,
StreamVariable,
TemporarilyPopInterpreterStackCtxManagerVariable,
VmapIncrementNestingCtxManagerVariable,
WithEnterFunctionVariable,
WithExitFunctionVariable,
)
from .dicts import (
ConstDictVariable,
DefaultDictVariable,
DictKeySetVariable,
FrozensetVariable,
MappingProxyVariable,
NNModuleHooksDictVariable,
SetVariable,
)
from .distributed import BackwardHookVariable, DistributedVariable, PlacementVariable
from .functions import (
BuiltinMethodVariable,
CollectionsNamedTupleFunction,
CreateTMADescriptorExperimentalVariable,
CreateTMADescriptorStableVariable,
FunctionDecoratedByContextlibContextManagerVariable,
FunctoolsPartialVariable,
FunctoolsWrapsVariable,
LocalGeneratorFunctionVariable,
LocalGeneratorObjectVariable,
NestedUserFunctionVariable,
PolyfilledFunctionVariable,
SkipFunctionVariable,
TMADescriptorExperimentalVariable,
TMADescriptorStableVariable,
UserFunctionVariable,
UserMethodVariable,
WrapperUserFunctionVariable,
WrapperUserMethodVariable,
)
from .higher_order_ops import (
FunctionalCallVariable,
FunctorchHigherOrderVariable,
ReparametrizeModuleCallVariable,
TorchHigherOrderOperatorVariable,
)
from .iter import (
CountIteratorVariable,
FilterVariable,
IteratorVariable,
ItertoolsVariable,
MapVariable,
ObjectIteratorVariable,
RepeatIteratorVariable,
ZipVariable,
)
from .lazy import LazyVariableTracker
from .lists import (
BaseListVariable,
ListIteratorVariable,
ListVariable,
NamedTupleVariable,
RangeVariable,
SliceVariable,
TupleIteratorVariable,
TupleVariable,
)
from .misc import (
AutogradFunctionContextVariable,
AutogradFunctionVariable,
CellVariable,
DeletedVariable,
ExceptionVariable,
GetAttrVariable,
LambdaVariable,
MethodWrapperVariable,
NewGlobalVariable,
NumpyVariable,
PythonModuleVariable,
RandomClassVariable,
RandomVariable,
RegexPatternVariable,
StringFormatVariable,
SuperVariable,
TorchVersionVariable,
TypingVariable,
UnknownVariable,
WeakRefVariable,
)
from .nn_module import (
FSDPManagedNNModuleVariable,
NNModuleVariable,
UnspecializedBuiltinNNModuleVariable,
UnspecializedNNModuleVariable,
)
from .optimizer import OptimizerVariable
from .sdpa import SDPAParamsVariable
from .tensor import (
DataPtrVariable,
FakeItemVariable,
NumpyNdarrayVariable,
SymNodeVariable,
TensorVariable,
UnspecializedPythonVariable,
UntypedStorageVariable,
)
from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
from .user_defined import (
FrozenDataClassVariable,
MutableMappingVariable,
RemovableHandleVariable,
UserDefinedClassVariable,
UserDefinedDictVariable,
UserDefinedExceptionClassVariable,
UserDefinedExceptionObjectVariable,
UserDefinedListVariable,
UserDefinedObjectVariable,
UserDefinedSetVariable,
UserDefinedTupleVariable,
)
# Public re-exports of this package.
#
# Every variable type imported above is listed here exactly once. This lets a
# star import of this package expose the same names that can be imported
# explicitly. It also keeps static type checkers from treating those names as
# private: checkers that enforce explicit re-exports (e.g. mypy's
# ``--no-implicit-reexport``) only consider a plain ``from .x import Y`` in a
# package ``__init__`` to be exported when ``Y`` appears in ``__all__``.
#
# Keep this list sorted case-insensitively, like the import lists above.
# Add an entry here whenever a new variable type is imported.
__all__ = [
    "AutogradFunctionContextVariable",
    "AutogradFunctionVariable",
    "BackwardHookVariable",
    "BaseListVariable",
    "BuiltinMethodVariable",
    "BuiltinVariable",
    "CatchWarningsCtxManagerVariable",
    "CellVariable",
    "CollectionsNamedTupleFunction",
    "ConstantVariable",
    "ConstDictVariable",
    "ContextWrappingVariable",
    "CountIteratorVariable",
    "CreateTMADescriptorExperimentalVariable",
    "CreateTMADescriptorStableVariable",
    "CUDADeviceVariable",
    "DataPtrVariable",
    "DefaultDictVariable",
    "DeletedVariable",
    "DeterministicAlgorithmsVariable",
    "DictKeySetVariable",
    "DisabledSavedTensorsHooksVariable",
    "DistributedVariable",
    "DualLevelContextManager",
    "DynamoConfigPatchVariable",
    "EnumVariable",
    "ErrorOnGraphBreakVariable",
    "ExceptionVariable",
    "FakeItemVariable",
    "FilterVariable",
    "FrozenDataClassVariable",
    "FrozensetVariable",
    "FSDPManagedNNModuleVariable",
    "FSDPParamGroupUseTrainingStateVariable",
    "FunctionalCallVariable",
    "FunctionDecoratedByContextlibContextManagerVariable",
    "FunctoolsPartialVariable",
    "FunctoolsWrapsVariable",
    "FunctorchHigherOrderVariable",
    "FxTracebackAnnotateVariable",
    "GetAttrVariable",
    "GradIncrementNestingCtxManagerVariable",
    "GradInplaceRequiresGradCtxManagerVariable",
    "GradModeVariable",
    "InferenceModeVariable",
    "IteratorVariable",
    "ItertoolsVariable",
    "JvpIncrementNestingCtxManagerVariable",
    "LambdaVariable",
    "LazyVariableTracker",
    "ListIteratorVariable",
    "ListVariable",
    "LocalGeneratorFunctionVariable",
    "LocalGeneratorObjectVariable",
    "MappingProxyVariable",
    "MapVariable",
    "MethodWrapperVariable",
    "MutableMappingVariable",
    "NamedTupleVariable",
    "NestedUserFunctionVariable",
    "NewGlobalVariable",
    "NNModuleHooksDictVariable",
    "NNModuleVariable",
    "NumpyNdarrayVariable",
    "NumpyVariable",
    "ObjectIteratorVariable",
    "OptimizerVariable",
    "PlacementVariable",
    "PolyfilledFunctionVariable",
    "PythonModuleVariable",
    "RandomClassVariable",
    "RandomVariable",
    "RangeVariable",
    "RegexPatternVariable",
    "RemovableHandleVariable",
    "ReparametrizeModuleCallVariable",
    "RepeatIteratorVariable",
    "SDPAKernelVariable",
    "SDPAParamsVariable",
    "SetFwdGradEnabledContextManager",
    "SetVariable",
    "SkipFunctionVariable",
    "SliceVariable",
    "StreamContextVariable",
    "StreamVariable",
    "StringFormatVariable",
    "SuperVariable",
    "SymNodeVariable",
    "TemporarilyPopInterpreterStackCtxManagerVariable",
    "TensorVariable",
    "TMADescriptorExperimentalVariable",
    "TMADescriptorStableVariable",
    "TorchCtxManagerClassVariable",
    "TorchHigherOrderOperatorVariable",
    "TorchInGraphFunctionVariable",
    "TorchVersionVariable",
    "TupleIteratorVariable",
    "TupleVariable",
    "TypingVariable",
    "UnknownVariable",
    "UnspecializedBuiltinNNModuleVariable",
    "UnspecializedNNModuleVariable",
    "UnspecializedPythonVariable",
    "UntypedStorageVariable",
    "UserDefinedClassVariable",
    "UserDefinedDictVariable",
    "UserDefinedExceptionClassVariable",
    "UserDefinedExceptionObjectVariable",
    "UserDefinedListVariable",
    "UserDefinedObjectVariable",
    "UserDefinedSetVariable",
    "UserDefinedTupleVariable",
    "UserFunctionVariable",
    "UserMethodVariable",
    "VariableTracker",
    "VmapIncrementNestingCtxManagerVariable",
    "WeakRefVariable",
    "WithEnterFunctionVariable",
    "WithExitFunctionVariable",
    "WrapperUserFunctionVariable",
    "WrapperUserMethodVariable",
    "ZipVariable",
]