Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Builds on top of https://github.com/pytorch/pytorch/pull/163673 and https://github.com/pytorch/pytorch/pull/164174. This will be used in follow-up PRs to apply regional Inductor compilation.

The existing implementation let Dynamo trace into `torch.fx.traceback.annotate`, but that is not what we want. Instead, Dynamo should essentially run the `torch.fx.traceback.annotate` function in eager, so that every FX node created in the Dynamo FX graph carries the custom metadata.

What does not work yet?

* The context manager `torch.fx.traceback.preserve_node_meta()` still has to be set in the user code because CI was unhappy otherwise. This can be fixed, but with some perseverance.
* Graph breaks are not handled yet. If needed, we can solve that problem in a separate PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164678
Approved by: https://github.com/SherlockNoMad, https://github.com/jansel, https://github.com/xmfan
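For context, a minimal usage sketch of the behavior described above. It is not taken from this PR: the dict-style payload passed to `torch.fx.traceback.annotate` and the `backend="eager"` choice are assumptions for illustration, and, per the limitations listed above, `torch.fx.traceback.preserve_node_meta()` still has to wrap the user code.

import torch
import torch.fx.traceback as fx_traceback


def fn(x):
    # FX nodes created while Dynamo traces this region should carry the
    # custom annotation in their metadata (the intent described in the PR).
    with fx_traceback.annotate({"region": "pointwise"}):  # dict payload is an assumption
        x = x + 1
    return torch.relu(x)


# Still required in user code for now, per the notes above.
with fx_traceback.preserve_node_meta():
    compiled_fn = torch.compile(fn, backend="eager")
    out = compiled_fn(torch.randn(8))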
232 lines · 6.7 KiB · Python
"""
|
|
This package implements variable tracking and symbolic execution capabilities for Dynamo,
|
|
which are essential for converting Python code into FX graphs. It provides a comprehensive
|
|
set of variable types that handle different Python constructs during tracing.
|
|
|
|
Each variable type (like BuiltinVariable, TensorVariable, NNModuleVariable, etc.) is responsible
|
|
for tracking and symbolically executing operations on specific Python objects. This enables
|
|
Dynamo to:
|
|
- Track the flow of values through Python code
|
|
- Maintain correct semantics during graph conversion
|
|
- Handle complex Python features like context managers, iterators, and custom objects
|
|
- Support both eager and symbolic execution modes
|
|
|
|
The VariableTracker base class provides the foundation for all variable types, with each
|
|
subclass implementing specific behavior for different Python constructs. This modular design
|
|
allows Dynamo to accurately trace and optimize Python code while preserving its semantics.
|
|
"""
|
|
|
|
from .base import VariableTracker
|
|
from .builtin import BuiltinVariable
|
|
from .constant import ConstantVariable, EnumVariable
|
|
from .ctx_manager import (
|
|
CatchWarningsCtxManagerVariable,
|
|
ContextWrappingVariable,
|
|
CUDADeviceVariable,
|
|
DeterministicAlgorithmsVariable,
|
|
DisabledSavedTensorsHooksVariable,
|
|
DualLevelContextManager,
|
|
DynamoConfigPatchVariable,
|
|
ErrorOnGraphBreakVariable,
|
|
FSDPParamGroupUseTrainingStateVariable,
|
|
FxTracebackAnnotateVariable,
|
|
GradIncrementNestingCtxManagerVariable,
|
|
GradInplaceRequiresGradCtxManagerVariable,
|
|
GradModeVariable,
|
|
InferenceModeVariable,
|
|
JvpIncrementNestingCtxManagerVariable,
|
|
SDPAKernelVariable,
|
|
SetFwdGradEnabledContextManager,
|
|
StreamContextVariable,
|
|
StreamVariable,
|
|
TemporarilyPopInterpreterStackCtxManagerVariable,
|
|
VmapIncrementNestingCtxManagerVariable,
|
|
WithEnterFunctionVariable,
|
|
WithExitFunctionVariable,
|
|
)
|
|
from .dicts import (
|
|
ConstDictVariable,
|
|
DefaultDictVariable,
|
|
DictKeySetVariable,
|
|
FrozensetVariable,
|
|
MappingProxyVariable,
|
|
NNModuleHooksDictVariable,
|
|
SetVariable,
|
|
)
|
|
from .distributed import BackwardHookVariable, DistributedVariable, PlacementVariable
|
|
from .functions import (
|
|
BuiltinMethodVariable,
|
|
CollectionsNamedTupleFunction,
|
|
CreateTMADescriptorExperimentalVariable,
|
|
CreateTMADescriptorStableVariable,
|
|
FunctionDecoratedByContextlibContextManagerVariable,
|
|
FunctoolsPartialVariable,
|
|
FunctoolsWrapsVariable,
|
|
LocalGeneratorFunctionVariable,
|
|
LocalGeneratorObjectVariable,
|
|
NestedUserFunctionVariable,
|
|
PolyfilledFunctionVariable,
|
|
SkipFunctionVariable,
|
|
TMADescriptorExperimentalVariable,
|
|
TMADescriptorStableVariable,
|
|
UserFunctionVariable,
|
|
UserMethodVariable,
|
|
WrapperUserFunctionVariable,
|
|
WrapperUserMethodVariable,
|
|
)
|
|
from .higher_order_ops import (
|
|
FunctionalCallVariable,
|
|
FunctorchHigherOrderVariable,
|
|
ReparametrizeModuleCallVariable,
|
|
TorchHigherOrderOperatorVariable,
|
|
)
|
|
from .iter import (
|
|
CountIteratorVariable,
|
|
FilterVariable,
|
|
IteratorVariable,
|
|
ItertoolsVariable,
|
|
MapVariable,
|
|
ObjectIteratorVariable,
|
|
RepeatIteratorVariable,
|
|
ZipVariable,
|
|
)
|
|
from .lazy import LazyVariableTracker
|
|
from .lists import (
|
|
BaseListVariable,
|
|
ListIteratorVariable,
|
|
ListVariable,
|
|
NamedTupleVariable,
|
|
RangeVariable,
|
|
SliceVariable,
|
|
TupleIteratorVariable,
|
|
TupleVariable,
|
|
)
|
|
from .misc import (
|
|
AutogradFunctionContextVariable,
|
|
AutogradFunctionVariable,
|
|
CellVariable,
|
|
DeletedVariable,
|
|
ExceptionVariable,
|
|
GetAttrVariable,
|
|
LambdaVariable,
|
|
MethodWrapperVariable,
|
|
NewGlobalVariable,
|
|
NumpyVariable,
|
|
PythonModuleVariable,
|
|
RandomClassVariable,
|
|
RandomVariable,
|
|
RegexPatternVariable,
|
|
StringFormatVariable,
|
|
SuperVariable,
|
|
TorchVersionVariable,
|
|
TypingVariable,
|
|
UnknownVariable,
|
|
WeakRefVariable,
|
|
)
|
|
from .nn_module import (
|
|
FSDPManagedNNModuleVariable,
|
|
NNModuleVariable,
|
|
UnspecializedBuiltinNNModuleVariable,
|
|
UnspecializedNNModuleVariable,
|
|
)
|
|
from .optimizer import OptimizerVariable
|
|
from .sdpa import SDPAParamsVariable
|
|
from .tensor import (
|
|
DataPtrVariable,
|
|
FakeItemVariable,
|
|
NumpyNdarrayVariable,
|
|
SymNodeVariable,
|
|
TensorVariable,
|
|
UnspecializedPythonVariable,
|
|
UntypedStorageVariable,
|
|
)
|
|
from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
|
|
from .user_defined import (
|
|
FrozenDataClassVariable,
|
|
MutableMappingVariable,
|
|
RemovableHandleVariable,
|
|
UserDefinedClassVariable,
|
|
UserDefinedDictVariable,
|
|
UserDefinedExceptionClassVariable,
|
|
UserDefinedExceptionObjectVariable,
|
|
UserDefinedListVariable,
|
|
UserDefinedObjectVariable,
|
|
UserDefinedSetVariable,
|
|
UserDefinedTupleVariable,
|
|
)
|
|
|
|
|
|
__all__ = [
|
|
"AutogradFunctionContextVariable",
|
|
"AutogradFunctionVariable",
|
|
"BackwardHookVariable",
|
|
"BaseListVariable",
|
|
"BuiltinVariable",
|
|
"CatchWarningsCtxManagerVariable",
|
|
"ConstantVariable",
|
|
"ConstDictVariable",
|
|
"ContextWrappingVariable",
|
|
"CountIteratorVariable",
|
|
"CreateTMADescriptorExperimentalVariable",
|
|
"CreateTMADescriptorStableVariable",
|
|
"CUDADeviceVariable",
|
|
"DataPtrVariable",
|
|
"DefaultDictVariable",
|
|
"DeletedVariable",
|
|
"DeterministicAlgorithmsVariable",
|
|
"DictKeySetVariable",
|
|
"DynamoConfigPatchVariable",
|
|
"EnumVariable",
|
|
"FakeItemVariable",
|
|
"GetAttrVariable",
|
|
"GradModeVariable",
|
|
"IteratorVariable",
|
|
"ItertoolsVariable",
|
|
"LambdaVariable",
|
|
"LazyVariableTracker",
|
|
"ListIteratorVariable",
|
|
"ListVariable",
|
|
"NamedTupleVariable",
|
|
"NestedUserFunctionVariable",
|
|
"CellVariable",
|
|
"NewGlobalVariable",
|
|
"NNModuleVariable",
|
|
"NumpyNdarrayVariable",
|
|
"NumpyVariable",
|
|
"OptimizerVariable",
|
|
"PlacementVariable",
|
|
"PolyfilledFunctionVariable",
|
|
"PythonModuleVariable",
|
|
"RangeVariable",
|
|
"RegexPatternVariable",
|
|
"RemovableHandleVariable",
|
|
"RepeatIteratorVariable",
|
|
"SDPAParamsVariable",
|
|
"ErrorOnGraphBreakVariable",
|
|
"SkipFunctionVariable",
|
|
"SliceVariable",
|
|
"StringFormatVariable",
|
|
"SuperVariable",
|
|
"TemporarilyPopInterpreterStackCtxManagerVariable",
|
|
"TensorVariable",
|
|
"TMADescriptorExperimentalVariable",
|
|
"TMADescriptorStableVariable",
|
|
"TorchCtxManagerClassVariable",
|
|
"TorchInGraphFunctionVariable",
|
|
"TorchVersionVariable",
|
|
"TupleVariable",
|
|
"UnknownVariable",
|
|
"UnspecializedNNModuleVariable",
|
|
"UnspecializedPythonVariable",
|
|
"UntypedStorageVariable",
|
|
"UserDefinedClassVariable",
|
|
"UserDefinedTupleVariable",
|
|
"UserDefinedObjectVariable",
|
|
"UserFunctionVariable",
|
|
"UserMethodVariable",
|
|
"VariableTracker",
|
|
"WithEnterFunctionVariable",
|
|
"WithExitFunctionVariable",
|
|
"MappingProxyVariable",
|
|
]
|
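The module docstring above explains that each exported VariableTracker subclass models one kind of Python value seen during tracing. As a consumer-side orientation only, here is a minimal sketch of dispatching on those exported classes; the package path `torch._dynamo.variables` is assumed from context rather than stated in this listing, and the `describe` helper is hypothetical.

from torch._dynamo.variables import (  # assumed package path
    ConstantVariable,
    TensorVariable,
    VariableTracker,
)


def describe(vt: VariableTracker) -> str:
    # Each subclass tracks a specific kind of Python object during tracing.
    if isinstance(vt, TensorVariable):
        return "a symbolically traced torch.Tensor"
    if isinstance(vt, ConstantVariable):
        return "a Python constant"
    return type(vt).__name__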