Compare commits


5 Commits

Author SHA1 Message Date
dcbb8f9c12 [CI][3.14] Adjust test_skip_data_serialization_materialize_ for 3.14
By changing exception type and regex


ghstack-source-id: 5f91c18bc37e42997b594875d08a70980033e375
Pull-Request: https://github.com/pytorch/pytorch/pull/167355
2025-11-07 11:53:52 -08:00
e0cf9eb685 [CI] Fix some regex with Python-3.14
Not sure why, but running `test_weights_only_safe_globals_build` under `pytest` makes the global name `test_serialization.ClassThatUsesBuildInstruction` instead of the expected `__main__.ClassThatUsesBuildInstruction`
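
For context, the name pickle reports comes from the defining module's `__name__`, which is `__main__` when the test file is executed directly but the real module name when it is imported by `pytest`. A minimal illustration (not part of the change):

```
class ClassThatUsesBuildInstruction:
    pass

print(f"{__name__}.{ClassThatUsesBuildInstruction.__qualname__}")
# python test_serialization.py -> __main__.ClassThatUsesBuildInstruction
# pytest test_serialization.py -> test_serialization.ClassThatUsesBuildInstruction
```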

ghstack-source-id: 2f9b8c78490f1ea7f200790725f1686a5583ff89
Pull-Request: https://github.com/pytorch/pytorch/pull/167333
2025-11-07 11:53:48 -08:00
9a86ef7632 [BE][Typing][Dynamo] Type torch/_dynamo/variables/functions.py (#167103)
Provides type coverage to torch/_dynamo/variables/functions.py

Coverage report:
`mypy torch/_dynamo/variables/functions.py --linecount-report /tmp/coverage_log`

Comparing before and after, coverage goes from 0 lines and 0 funcs covered to 2698 lines and 166 funcs covered

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167103
Approved by: https://github.com/mlazos, https://github.com/fxdawnn
2025-11-07 00:40:49 +00:00
f47cadf75d [BE][Typing][Dynamo] Type torch/_dynamo/variables/lists.py (#167156)
Provides type coverage to torch/_dynamo/variables/lists.py

Coverage report:
`mypy torch/_dynamo/variables/lists.py --linecount-report /tmp/coverage_log`

Comparing before and after, coverage goes from 0 lines and 0 funcs covered to 1759 lines and 102 funcs covered

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167156
Approved by: https://github.com/Skylion007, https://github.com/rtimpe
2025-11-07 00:15:40 +00:00
2923b02c6e [DTensor] add explicit mode (ExplicitRedistributionContext) (#166593)
usage:

```
from torch.distributed.tensor import distribute_tensor, Shard
from torch.distributed.tensor._utils import ExplicitRedistributionContext

dx = distribute_tensor(x, device_mesh, [Shard(0)])
dA = distribute_tensor(A, device_mesh, [Shard(0)])
with ExplicitRedistributionContext():
    with self.assertRaisesRegex(RuntimeError, "Implicit redistribution"):
        # Shard(0) @ Shard(0) requires a redistribution
        torch.matmul(dx, dA)
```
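
Inside the context, a required redistribution must instead be requested manually; a minimal sketch of the allowed pattern (mirroring the new test below; `Replicate` comes from `torch.distributed.tensor`):

```
with ExplicitRedistributionContext():
    # an explicit redistribute is allowed; only implicit redistribution raises
    dA_repl = dA.redistribute(device_mesh, [Replicate()])
    torch.matmul(dx, dA_repl)
```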

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166593
Approved by: https://github.com/ezyang
2025-11-07 00:04:19 +00:00
23 changed files with 959 additions and 817 deletions


@@ -1,11 +1,18 @@
# Owner(s): ["oncall: distributed"]
import itertools
from contextlib import nullcontext
from typing import Any
import torch
import torch.distributed as dist
from torch.distributed._local_tensor import (
local_tensor_mode,
LocalTensor,
LocalTensorMode,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, DTensor
from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._utils import (
_compute_local_shape_and_global_offset,
@@ -14,6 +21,7 @@ from torch.distributed.tensor._utils import (
compute_global_tensor_shape,
compute_local_shape_and_global_offset,
compute_local_tensor_info,
ExplicitRedistributionContext,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.placement_types import (
@@ -851,5 +859,93 @@ class Test2DStridedLocalShard(DTensorTestBase):
self.assertEqual(global_tensor, dtensor_2d.full_tensor())
class LocalTensorTestBase(TestCase):
def assertEqual(self, lhs, rhs, **kwargs):
mode = local_tensor_mode()
with nullcontext() if mode is None else mode.disable():
if isinstance(lhs, LocalTensor) and isinstance(rhs, LocalTensor):
assert isinstance(lhs, LocalTensor) and isinstance(rhs, LocalTensor)
super().assertEqual(lhs._ranks, rhs._ranks)
for r in lhs._ranks:
super().assertEqual(
lhs._local_tensors[r],
rhs._local_tensors[r],
lambda m: f"rank {r}: {m}",
)
elif isinstance(lhs, LocalTensor) or isinstance(rhs, LocalTensor):
lhs, rhs = (lhs, rhs) if isinstance(lhs, LocalTensor) else (rhs, lhs)
for r in lhs._ranks:
super().assertEqual(
lhs._local_tensors[r], rhs, lambda m: f"rank {r}: {m}"
)
else:
return super().assertEqual(lhs, rhs, **kwargs)
@property
def world_size(self):
raise NotImplementedError("override world-size in your subclass")
def build_device_mesh(self) -> DeviceMesh:
return init_device_mesh("cpu", (self.world_size,))
def setUp(self):
super().setUp()
torch.distributed.init_process_group(
# TODO: test other ranks too
"fake",
rank=0,
world_size=self.world_size,
)
def tearDown(self):
super().tearDown()
try:
dist.destroy_process_group()
except AssertionError:
pass
class TestExplicitRedistribute(LocalTensorTestBase):
@property
def world_size(self):
return 4
def test_explicit_matmul(self):
with LocalTensorMode(self.world_size):
device_mesh = self.build_device_mesh()
dim = 128
x = torch.randn(8, dim, requires_grad=True)
A = torch.randn(dim, dim, requires_grad=True)
# Prepare DTensors
dx = distribute_tensor(x, device_mesh, [Shard(0)])
dA = distribute_tensor(A, device_mesh, [Shard(0)])
# implicit redistribute works as usual by default
with CommDebugMode() as comm_mode:
torch.matmul(dx, dA)
self.assertEqual(comm_mode.get_total_counts(), 1)
# explicit redistribute works too
with ExplicitRedistributionContext():
with self.assertRaisesRegex(RuntimeError, "Implicit redistribution"):
torch.matmul(dx, dA)
# explicit redistribute allows manual redistribute
with ExplicitRedistributionContext():
dA_repl = dA.redistribute(device_mesh, [Replicate()])
torch.matmul(dx, dA_repl)
dx = distribute_tensor(x, device_mesh, [Shard(0)])
dA = distribute_tensor(A, device_mesh, [Replicate()])
with ExplicitRedistributionContext():
dY = torch.matmul(dx, dA_repl)
loss = dY.sum()
# we now see the error during backwards
with self.assertRaisesRegex(RuntimeError, "Implicit redistribution"):
loss.backward()
if __name__ == "__main__":
run_tests()


@@ -27,9 +27,9 @@ from torch._inductor.fx_passes.post_grad import post_grad_passes
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code, run_and_get_cpp_code
from torch._inductor.virtualized import V
from torch.testing._internal.common_utils import IS_MACOS, skipIfRocm
from torch.testing._internal.common_utils import IS_MACOS
from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.profiler import profile, ProfilerActivity
try:
from .test_aot_inductor_utils import AOTIRunnerUtil
@@ -941,63 +941,5 @@ copy_tests(
)
from torch.profiler._utils import _enrich_profiler_traces
class TestProfilerStackTraceAugmentation(TestCase):
"""
Test that profiler events are correctly augmented with stack traces
from both FX metadata and inductor kernel stack traces.
"""
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
@config.patch("fallback_by_default", True) # TODO update the config patch to inductor lite mode
@torch.compiler.config.patch("force_disable_caches", True)
def test_profiler_inductor_stack_trace_augmentation(self):
"""
Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
augments profiler events with stack traces from inductor kernel metadata.
"""
# Test model similar to test.py
class TestModel(torch.nn.Module):
def forward(self, c):
d = c * 2
d = d + 1
return d
device = "cuda"
model = TestModel().to(device)
c = torch.randn((64, 32), device=device)
# Force disable caches to ensure fresh compilation
torch.compiler.config.force_disable_caches = True
# Compile the model
compiled_model = torch.compile(model, fullgraph=True)
# Warmup
for _ in range(3):
_ = compiled_model(c)
# Profile with the compiled model
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
compiled_model(c)
actual_traces = _enrich_profiler_traces(prof)
self.assertExpectedInline(actual_traces, """\
event=aten::mul node=torch.ops.aten.mul.Tensor:1 stack_trace=d = c * 2
event=cudaLaunchKernel node=torch.ops.aten.mul.Tensor:1 stack_trace=d = c * 2
event=aten::add node=torch.ops.aten.add.Tensor:2 stack_trace=d = d + 1
event=cudaLaunchKernel node=torch.ops.aten.add.Tensor:2 stack_trace=d = d + 1""")
# TODO: add test that when enrich is not turned on there is no recordfast generated.
if __name__ == "__main__":
run_tests()


@@ -76,8 +76,11 @@ from torch.testing._internal.common_utils import (
)
from torch.testing._internal.jit_utils import JitTestCase
import json
import tempfile
from torch.profiler import profile, ProfilerActivity
from torch.profiler._utils import _enrich_profiler_traces
from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace
from torch.autograd.profiler_util import _canonicalize_profiler_events
try:
from torchvision import models as torchvision_models
@@ -205,6 +208,36 @@ def side_effect_func(x: torch.Tensor):
print(x)
def _enrich_profiler_traces(prof):
"""
Helper function to extract and augment profiler events with stack traces.
Args:
prof: A torch.profiler.profile object
Returns:
A string representing enriched events
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f:
trace_file = f.name
prof.export_chrome_trace(trace_file)
with open(trace_file) as f:
trace_data = json.load(f)
map_recorded_events_to_aten_ops_with_stack_trace(
trace_data
)
events = []
for event in trace_data["traceEvents"]:
if "args" in event and "stack_trace" in event["args"]:
events.append(event)
actual_traces = _canonicalize_profiler_events(events)
return actual_traces
class TestFX(JitTestCase):
def setUp(self):
super().setUp()


@@ -1281,7 +1281,7 @@ class TestSerialization(TestCase, SerializationMixin):
torch.save(p, f)
f.seek(0)
with self.assertRaisesRegex(pickle.UnpicklingError,
"GLOBAL __main__.Point was not an allowed global by default"):
f"GLOBAL {__name__}.Point was not an allowed global by default"):
torch.load(f, weights_only=True)
f.seek(0)
with torch.serialization.safe_globals([Point]):
@@ -1300,7 +1300,7 @@ class TestSerialization(TestCase, SerializationMixin):
torch.save(c, f)
f.seek(0)
with self.assertRaisesRegex(pickle.UnpicklingError,
"GLOBAL __main__.ClassThatUsesBuildInstruction was not an allowed global by default"):
f"GLOBAL {__name__}.ClassThatUsesBuildInstruction was not an allowed global by default"):
torch.load(f, weights_only=True)
try:
with torch.serialization.safe_globals([ClassThatUsesBuildInstruction]):
@@ -1330,7 +1330,7 @@ class TestSerialization(TestCase, SerializationMixin):
torch.save(obj, f)
f.seek(0)
with self.assertRaisesRegex(pickle.UnpicklingError,
f"GLOBAL __main__.{obj_cls.__name__} was not an allowed global by default"):
f"GLOBAL {__name__}.{obj_cls.__name__} was not an allowed global by default"):
torch.load(f, weights_only=True)
f.seek(0)
@@ -4501,9 +4501,10 @@ class TestSerialization(TestCase, SerializationMixin):
# Test that without materialize_fake_tensor, behavior for fake_tensors is not altered by ctx
if not materialize_fake:
ft = converter.from_real_tensor(mode, torch.randn(2, device=t_device))
exc = pickle.PicklingError if sys.version_info >= (3, 14) else AttributeError
with self.assertRaisesRegex(
AttributeError,
"Can't (get|pickle) local object 'WeakValueDictionary.__init__.<locals>.remove'"
exc,
"Can't (get|pickle) local object (<function |')WeakValueDictionary.__init__.<locals>.remove"
):
with skip_data(), BytesIOContext() as f:
torch.save(ft, f)
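
The version gate above tracks a CPython 3.14 behavior change: pickling an unpicklable local object now surfaces `pickle.PicklingError` where older versions raised `AttributeError`, and the message embeds the function repr rather than a quoted name, hence the widened regex. A minimal reproduction sketch (illustrative, not part of this diff; exact messages vary by version):

```
import pickle

def outer():
    def inner():  # a local object: pickle cannot resolve it by qualified name
        pass
    return inner

try:
    pickle.dumps(outer())
except Exception as e:
    print(type(e).__name__, e)
# <= 3.13: AttributeError: Can't pickle local object 'outer.<locals>.inner'
# >= 3.14: PicklingError: Can't pickle local object <function outer.<locals>.inner at 0x...>
```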


@@ -3228,7 +3228,7 @@ class InstructionTranslatorBase(
def BUILD_SLICE(self, inst: Instruction) -> None:
items = self.popn(inst.argval)
self.push(SliceVariable(items, tx=self))
self.push(SliceVariable(items, tx=self)) # type: ignore[arg-type]
def BUILD_LIST(self, inst: Instruction) -> None:
items = self.popn(inst.argval)
@@ -3607,7 +3607,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg]
assert isinstance(obj, ListVariable)
assert obj.is_mutable()
obj.call_method(self, "extend", [v], {})
obj.call_method(self, "extend", [v], {}) # type: ignore[arg-type]
def LIST_TO_TUPLE(self, inst: Instruction) -> None:
self.push(BuiltinVariable(tuple).call_function(self, [self.pop()], {})) # type: ignore[arg-type]
@@ -3673,7 +3673,7 @@ class InstructionTranslatorBase(
def MATCH_KEYS(self, inst: Instruction) -> None:
tos = self.stack[-1]
assert isinstance(tos, TupleVariable)
keys = tos.unpack_var_sequence(self)
keys = tos.unpack_var_sequence(self) # type: ignore[arg-type]
tos1 = self.stack[-2]
assert isinstance(tos1, ConstDictVariable)


@@ -1991,7 +1991,7 @@ class BuiltinVariable(VariableTracker):
# If the object implements a __getitem__ method, iter(...) will call obj.__getitem__()
# with an integer argument starting at 0, until __getitem__ raises IndexError
ret = variables.UserFunctionVariable(
polyfills.builtins.iter_
polyfills.builtins.iter_ # type: ignore[arg-type]
).call_function(tx, [obj, *args], {})
if args:


@@ -1513,7 +1513,7 @@ class WithExitFunctionVariable(VariableTracker):
# Note here we reconstruct the context manager rather than the
# exit function. The handler generated by BlockStackEntry
# will re-enter the context in the resume function.
self.ctx.reconstruct_type(codegen) # type: ignore[attr-defined]
self.ctx.reconstruct_type(codegen) # type: ignore[union-attr]
if codegen.tx.output.partial_convert:
if sys.version_info >= (3, 11):
codegen.append_output(create_instruction("PUSH_NULL"))
@@ -1522,10 +1522,10 @@ class WithExitFunctionVariable(VariableTracker):
# We rely on classes subtyping `GenericContextWrappingVariable`
# to implement these fns and have these attributes
codegen.extend_output(
[codegen.create_load_const(val) for val in self.ctx.target_values] # type: ignore[arg-type]
[codegen.create_load_const(val) for val in self.ctx.target_values] # type: ignore[union-attr]
)
codegen.extend_output(
create_call_function(len(self.ctx.target_values), False) # type: ignore[arg-type]
create_call_function(len(self.ctx.target_values), False) # type: ignore[union-attr]
)
codegen.append_output(create_setup_with(self.target))
codegen.append_output(create_instruction("POP_TOP"))

File diff suppressed because it is too large.


@@ -82,7 +82,8 @@ class ItertoolsVariable(VariableTracker):
for item in itertools.product(*seqs, repeat=r)
]
return variables.ListIteratorVariable(
items, mutation_type=ValueMutationNew()
items, # type: ignore[arg-type]
mutation_type=ValueMutationNew(),
)
elif (
self.value is itertools.combinations
@@ -98,7 +99,8 @@ class ItertoolsVariable(VariableTracker):
for item in itertools.combinations(iterable, r):
items.append(variables.TupleVariable(list(item)))
return variables.ListIteratorVariable(
items, mutation_type=ValueMutationNew()
items, # type: ignore[arg-type]
mutation_type=ValueMutationNew(),
)
elif self.value is itertools.groupby:
if any(kw != "key" for kw in kwargs.keys()):
@@ -181,7 +183,8 @@ class ItertoolsVariable(VariableTracker):
from_exc=e,
)
return variables.ListIteratorVariable(
result, mutation_type=ValueMutationNew()
result, # type: ignore[arg-type]
mutation_type=ValueMutationNew(),
)
elif self.value is itertools.repeat:
if len(args) < 2:
@@ -212,7 +215,8 @@ class ItertoolsVariable(VariableTracker):
)
]
return variables.ListIteratorVariable(
items, mutation_type=ValueMutationNew()
items, # type: ignore[arg-type]
mutation_type=ValueMutationNew(),
)
else:
return super().call_function(tx, args, kwargs)
@@ -586,7 +590,7 @@ class FilterVariable(IteratorVariable):
else:
res = self.fn.call_function(tx, [item], {})
pred_res = variables.UserFunctionVariable(
polyfills.predicate
polyfills.predicate # type: ignore[arg-type]
).call_function(tx, [res], {})
if pred_res.as_python_constant():
return item


@@ -1,5 +1,3 @@
# mypy: ignore-errors
"""
Variable tracking implementations for list-like data structures in Dynamo.
@@ -20,7 +18,7 @@ import collections
import inspect
import operator
import sys
from typing import Optional, TYPE_CHECKING
from typing import Any, Optional, Sequence, TYPE_CHECKING
import torch
import torch.fx
@@ -60,11 +58,11 @@ if TYPE_CHECKING:
class BaseListVariable(VariableTracker):
@staticmethod
def cls_for_instance(obj):
def cls_for_instance(obj: Any) -> type["BaseListVariable"]:
return BaseListVariable.cls_for(type(obj))
@staticmethod
def cls_for(obj):
def cls_for(obj: Any) -> type:
return {
iter: ListIteratorVariable,
list: ListVariable,
@@ -80,34 +78,38 @@ class BaseListVariable(VariableTracker):
def __init__(
self,
items: list[VariableTracker],
**kwargs,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
assert isinstance(items, list)
assert all(isinstance(x, VariableTracker) for x in items)
self.items: list[VariableTracker] = items
def _as_proxy(self):
def _as_proxy(self) -> list[Any]:
return [x.as_proxy() for x in self.items]
def modified(self, items, **kwargs):
def modified(
self, items: list[VariableTracker], **kwargs: Any
) -> "BaseListVariable":
return type(self)(items, **kwargs)
@property
def value(self):
def value(self) -> Any:
return self.as_python_constant()
def debug_repr_helper(self, prefix, suffix):
def debug_repr_helper(self, prefix: str, suffix: str) -> str:
return prefix + ", ".join(i.debug_repr() for i in self.items) + suffix
def as_python_constant(self):
def as_python_constant(self) -> Any:
return self.python_type()([x.as_python_constant() for x in self.items])
def as_proxy(self):
def as_proxy(self) -> Any:
assert self.python_type() is not SizeVariable
return self.python_type()(self._as_proxy())
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
from .tensor import SymNodeVariable
if isinstance(arg, SymNodeVariable):
@@ -134,16 +136,16 @@ class BaseListVariable(VariableTracker):
IndexError, tx, args=["list index out of range"]
)
def unpack_var_sequence(self, tx):
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
return list(self.items)
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__getitem__":
from .tensor import TensorVariable
@@ -224,15 +226,15 @@ class BaseListVariable(VariableTracker):
if type(self) is not type(args[0]):
tp_name = self.python_type_name()
other = args[0].python_type_name()
msg = ConstantVariable.create(
msg_vt = ConstantVariable.create(
f'can only concatenate {tp_name} (not "{other}") to {tp_name}'
)
raise_observed_exception(TypeError, tx, args=[msg])
raise_observed_exception(TypeError, tx, args=[msg_vt])
if name == "__add__":
return type(self)(self.items + args[0].items, source=self.source)
return type(self)(self.items + args[0].items, source=self.source) # type: ignore[attr-defined]
else:
self.items += args[0].items
self.items += args[0].items # type: ignore[attr-defined]
return self
elif name in ("__mul__", "__imul__"):
if kwargs or len(args) != 1:
@@ -244,10 +246,10 @@ class BaseListVariable(VariableTracker):
)
if not (args[0].is_python_constant() and args[0].python_type() is int):
msg = ConstantVariable.create(
msg_vt = ConstantVariable.create(
f"can't multiply sequence by non-int type of '{args[0].python_type_name()}'"
)
raise_observed_exception(TypeError, tx, args=[msg])
raise_observed_exception(TypeError, tx, args=[msg_vt])
val = args[0].as_python_constant()
@@ -301,7 +303,7 @@ class BaseListVariable(VariableTracker):
class RangeVariable(BaseListVariable):
def __init__(self, items, **kwargs) -> None:
def __init__(self, items: Sequence[VariableTracker], **kwargs: Any) -> None:
items_to_map = items
start = variables.ConstantVariable.create(0)
stop = None
@@ -316,7 +318,7 @@ class RangeVariable(BaseListVariable):
else:
raise AssertionError
def maybe_as_int(x):
def maybe_as_int(x: VariableTracker) -> VariableTracker:
return (
ConstantVariable(int(x.value)) if isinstance(x, ConstantVariable) else x
)
@@ -329,22 +331,22 @@ class RangeVariable(BaseListVariable):
assert stop is not None
super().__init__([start, stop, step], **kwargs)
def debug_repr(self):
def debug_repr(self) -> str:
return self.debug_repr_helper("range(", ")")
def python_type(self):
def python_type(self) -> type:
return range
def start(self):
def start(self) -> Any:
return self.items[0].as_python_constant()
def stop(self):
def stop(self) -> Any:
return self.items[1].as_python_constant()
def step(self):
def step(self) -> Any:
return self.items[2].as_python_constant()
def range_length(self):
def range_length(self) -> int:
lo = self.start()
hi = self.stop()
step = self.step()
@@ -357,7 +359,7 @@ class RangeVariable(BaseListVariable):
else:
return 0
def _get_slice_indices(self, length, slice):
def _get_slice_indices(self, length: int, slice: slice) -> list[int]:
step_is_negative = 0
if slice.step is None:
@@ -406,7 +408,7 @@ class RangeVariable(BaseListVariable):
return [start, stop, step]
def apply_index(self, index):
def apply_index(self, index: int) -> VariableTracker:
length = self.range_length()
if index < 0:
index = length + index
@@ -421,12 +423,12 @@ class RangeVariable(BaseListVariable):
return variables.ConstantVariable.create(self.start() + (index * self.step()))
def apply_slice(self, slice):
def apply_slice(self, slice: slice) -> "RangeVariable":
(slice_start, slice_stop, slice_step) = self._get_slice_indices(
self.range_length(), slice
)
def compute_item(index):
def compute_item(index: int) -> int:
return self.start() + (index * self.step())
sub_step = self.step() * slice_step
@@ -442,10 +444,12 @@ class RangeVariable(BaseListVariable):
)
return result
def as_python_constant(self):
def as_python_constant(self) -> range:
return range(*[x.as_python_constant() for x in self.items])
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
# implementations mimics https://github.com/python/cpython/blob/main/Objects/rangeobject.c
index = arg.as_python_constant()
@@ -457,28 +461,30 @@ class RangeVariable(BaseListVariable):
msg = ConstantVariable("range indices must be integers or slices")
raise_observed_exception(TypeError, tx, args=[msg])
def as_proxy(self):
def as_proxy(self) -> range:
return self.python_type()(*self._as_proxy())
def unpack_var_sequence(self, tx=None):
def unpack_var_sequence(
self, tx: Optional["InstructionTranslator"] = None
) -> list[VariableTracker]:
return [variables.ConstantVariable.create(x) for x in self.as_python_constant()]
def reconstruct(self, codegen: "PyCodegen") -> None:
assert "range" not in codegen.tx.f_globals
codegen.add_push_null(
lambda: codegen.append_output(codegen.create_load_python_module(range))
lambda: codegen.append_output(codegen.create_load_python_module(range)) # type: ignore[arg-type]
)
codegen.foreach(self.items)
codegen.extend_output(create_call_function(3, False))
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
) -> VariableTracker:
if self.python_type() is range:
return variables.ConstantVariable.create(name in range.__dict__)
return super().call_obj_hasattr(tx, name)
def range_equals(self, other: "RangeVariable"):
def range_equals(self, other: "RangeVariable") -> bool:
r0, r1 = self, other
if (
self.range_length() != r1.range_length()
@@ -487,12 +493,12 @@ class RangeVariable(BaseListVariable):
):
return False
if len(r0) == 1:
if self.range_length() == 1:
return True
return r0.step() == r1.step()
def range_count(self, x: VariableTracker):
def range_count(self, x: VariableTracker) -> int:
# Based on CPython
# https://github.com/guilhermeleobas/cpython/blob/baefaa6cba1d69efd2f930cdc56bca682c54b139/Objects/rangeobject.c#L442-L486
x = x.as_python_constant()
@@ -511,7 +517,13 @@ class RangeVariable(BaseListVariable):
return int(re)
return 0
def call_method(self, tx, name, args, kwargs):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__iter__":
if not all(var.is_python_constant() for var in self.items):
# Can't represent a `range_iterator` without well defined bounds
@@ -545,7 +557,10 @@ class RangeVariable(BaseListVariable):
if pt is not range:
return ConstantVariable.create(NotImplemented)
cmp = self.range_equals(other)
if isinstance(other, RangeVariable):
cmp = self.range_equals(other)
else:
cmp = False
# Two ranges are equal if they produce the same sequence of values
if name == "__eq__":
@@ -554,7 +569,7 @@ class RangeVariable(BaseListVariable):
return ConstantVariable(not cmp)
return super().call_method(tx, name, args, kwargs)
def var_getattr(self, tx: "InstructionTranslator", name):
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
fields = ["start", "stop", "step"]
if name in fields:
return self.items[fields.index(name)]
@@ -568,11 +583,11 @@ class CommonListMethodsVariable(BaseListVariable):
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
from .tensor import SymNodeVariable
if name == "append" and self.is_mutable():
@@ -676,9 +691,9 @@ class CommonListMethodsVariable(BaseListVariable):
self.items[key.evaluate_expr()] = value
elif isinstance(key, SliceVariable):
if key.is_python_constant():
self.items[key.as_python_constant()] = list(value.items)
self.items[key.as_python_constant()] = list(value.items) # type: ignore[attr-defined]
else:
items = slice(
items_slice = slice(
*[
(
s.evaluate_expr()
@@ -688,7 +703,7 @@ class CommonListMethodsVariable(BaseListVariable):
for s in key.items
]
)
self.items[items] = list(value.items)
self.items[items_slice] = list(value.items) # type: ignore[attr-defined]
else:
self.items[key.as_python_constant()] = value
return ConstantVariable.create(None)
@@ -733,8 +748,8 @@ class CommonListMethodsVariable(BaseListVariable):
"0 args and 0 kwargs",
f"{len(args)} args and {len(kwargs)} kwargs",
)
items = list(self.items)
return self.modified(items, mutation_type=ValueMutationNew())
items_lst: list[VariableTracker] = list(self.items)
return self.modified(items_lst, mutation_type=ValueMutationNew())
elif name == "reverse" and self.is_mutable():
if args or kwargs:
raise_args_mismatch(
@@ -763,13 +778,13 @@ class CommonListMethodsVariable(BaseListVariable):
class ListVariable(CommonListMethodsVariable):
def python_type(self):
def python_type(self) -> type:
return list
def __repr__(self) -> str:
return f"{self.__class__.__name__}(length={len(self.items)})"
def debug_repr(self):
def debug_repr(self) -> str:
return self.debug_repr_helper("[", "]")
def reconstruct(self, codegen: "PyCodegen") -> None:
@@ -778,11 +793,11 @@ class ListVariable(CommonListMethodsVariable):
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
from .tensor import SymNodeVariable
if name == "__setitem__" and self.is_mutable():
@@ -805,14 +820,14 @@ class ListVariable(CommonListMethodsVariable):
msg = ConstantVariable.create("can only assign an iterable")
raise_observed_exception(TypeError, tx, args=[msg])
key = key.as_python_constant()
if key.step == 0:
key_as_const = key.as_python_constant()
if key_as_const.step == 0:
msg = ConstantVariable.create("slice step cannot be zero")
raise_observed_exception(ValueError, tx, args=[msg])
value = value.force_unpack_var_sequence(tx)
value_unpack = value.force_unpack_var_sequence(tx)
try:
self.items[key] = value
self.items[key_as_const] = value_unpack
except Exception as exc:
raise_observed_exception(
type(exc),
@@ -859,7 +874,7 @@ class ListVariable(CommonListMethodsVariable):
assert first_non_constant_key is not None
try:
python_type = first_non_constant_key.python_type()
python_type = str(first_non_constant_key.python_type())
except NotImplementedError:
python_type = "unknown"
@@ -904,7 +919,7 @@ class ListVariable(CommonListMethodsVariable):
return super().call_method(tx, name, args, kwargs)
def var_getattr(self, tx, name):
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
if name == "__class__":
source = AttrSource(self.source, name) if self.source else None
class_type = self.python_type()
@@ -916,14 +931,19 @@ class ListVariable(CommonListMethodsVariable):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
) -> VariableTracker:
if self.python_type() is not list:
return super().call_obj_hasattr(tx, name)
return variables.ConstantVariable.create(hasattr([], name))
class DequeVariable(CommonListMethodsVariable):
def __init__(self, items, maxlen=None, **kwargs) -> None:
def __init__(
self,
items: list[VariableTracker],
maxlen: Optional[VariableTracker] = None,
**kwargs: Any,
) -> None:
if maxlen is None:
maxlen = ConstantVariable.create(None)
assert maxlen.is_python_constant(), (
@@ -935,17 +955,17 @@ class DequeVariable(CommonListMethodsVariable):
items = items[-maxlen.as_python_constant() :]
super().__init__(items, **kwargs)
def python_type(self):
def python_type(self) -> type:
return collections.deque
def debug_repr(self):
def debug_repr(self) -> str:
if self.maxlen.as_python_constant() is None:
return self.debug_repr_helper(
"deque([", "], maxlen=" + self.maxlen.debug_repr() + ")"
)
return self.debug_repr_helper("deque([", "])")
def as_python_constant(self):
def as_python_constant(self) -> collections.deque[Any]:
return self.python_type()(
[x.as_python_constant() for x in self.items],
maxlen=self.maxlen.as_python_constant(),
@@ -954,7 +974,7 @@ class DequeVariable(CommonListMethodsVariable):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null(
lambda: codegen.append_output(
codegen.create_load_python_module(collections.deque)
codegen.create_load_python_module(collections.deque) # type: ignore[arg-type]
)
)
codegen.foreach(self.items)
@@ -962,18 +982,18 @@ class DequeVariable(CommonListMethodsVariable):
codegen(self.maxlen)
codegen.extend_output(codegen.create_call_function_kw(2, ("maxlen",), False))
def var_getattr(self, tx: "InstructionTranslator", name):
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
if name == "maxlen":
return self.maxlen
return super().var_getattr(tx, name)
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if (
name == "__setitem__"
and self.is_mutable()
@@ -1068,20 +1088,20 @@ class DequeVariable(CommonListMethodsVariable):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
) -> VariableTracker:
if self.python_type() is collections.deque:
return variables.ConstantVariable.create(name in collections.deque.__dict__)
return super().call_obj_hasattr(tx, name)
class TupleVariable(BaseListVariable):
def python_type(self):
def python_type(self) -> type[tuple]: # type: ignore[type-arg]
return tuple
def __repr__(self) -> str:
return f"{self.__class__.__name__}(length={len(self.items)})"
def debug_repr(self):
def debug_repr(self) -> str:
return self.debug_repr_helper("(", ")")
def reconstruct(self, codegen: "PyCodegen") -> None:
@@ -1090,14 +1110,14 @@ class TupleVariable(BaseListVariable):
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
return super().call_method(tx, name, args, kwargs)
def var_getattr(self, tx, name):
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
if name == "__class__":
source = AttrSource(self.source, name) if self.source else None
class_type = self.python_type()
@@ -1109,7 +1129,7 @@ class TupleVariable(BaseListVariable):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
) -> VariableTracker:
if self.python_type() is not tuple:
return super().call_obj_hasattr(tx, name)
return variables.ConstantVariable.create(hasattr((), name))
@@ -1127,18 +1147,18 @@ class SizeVariable(TupleVariable):
self,
items: list[VariableTracker],
proxy: Optional[torch.fx.Proxy] = None,
**kwargs,
**kwargs: Any,
) -> None:
self.proxy = proxy
super().__init__(items, **kwargs)
def debug_repr(self):
def debug_repr(self) -> str:
return self.debug_repr_helper("torch.Size([", "])")
def python_type(self):
def python_type(self) -> type:
return torch.Size
def as_proxy(self):
def as_proxy(self) -> Any:
if self.proxy is not None:
return self.proxy
@@ -1193,10 +1213,10 @@ class SizeVariable(TupleVariable):
] + create_call_function(1, False)
codegen.extend_output(build_torch_size)
def unpack_var_sequence(self, tx):
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
return list(self.items)
def numel(self, tx):
def numel(self, tx: "InstructionTranslator") -> VariableTracker:
from .builtin import BuiltinVariable
from .tensor import SymNodeVariable
@@ -1226,11 +1246,11 @@ class SizeVariable(TupleVariable):
def call_method(
self,
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__getitem__":
if kwargs or len(args) != 1:
raise_args_mismatch(
@@ -1253,7 +1273,9 @@ class SizeVariable(TupleVariable):
return super().call_method(tx, name, args, kwargs)
def get_item_dyn(self, tx: "InstructionTranslator", arg: VariableTracker):
def get_item_dyn(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
from .tensor import SymNodeVariable
if isinstance(arg, SymNodeVariable):
@@ -1269,7 +1291,7 @@ class SizeVariable(TupleVariable):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
) -> VariableTracker:
return variables.ConstantVariable.create(hasattr(torch.Size, name))
@@ -1280,33 +1302,39 @@ class NamedTupleVariable(TupleVariable):
*TupleVariable._nonvar_fields,
}
def __init__(self, items, tuple_cls, dynamic_attributes=None, **kwargs) -> None:
def __init__(
self,
items: list[VariableTracker],
tuple_cls: type,
dynamic_attributes: Optional[dict[str, VariableTracker]] = None,
**kwargs: Any,
) -> None:
super().__init__(items, **kwargs)
self.tuple_cls = tuple_cls
self.dynamic_attributes = dynamic_attributes if dynamic_attributes else {}
def is_namedtuple(self):
def is_namedtuple(self) -> bool:
return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable(
getattr(self.tuple_cls, "_make", None)
)
def is_structseq(self):
def is_structseq(self) -> bool:
return not self.is_namedtuple()
def fields(self):
def fields(self) -> tuple[str, ...]:
return namedtuple_fields(self.tuple_cls)
def debug_repr(self):
def debug_repr(self) -> str:
if self.is_structseq():
# StructSequenceType(iterable)
return repr(self.tuple_cls([Lit(x.debug_repr()) for x in self.items]))
# NamedTupleType(*iterable)
return repr(self.tuple_cls(*(Lit(x.debug_repr()) for x in self.items)))
def python_type(self):
def python_type(self) -> type:
return self.tuple_cls
def as_python_constant(self):
def as_python_constant(self) -> Any:
if self.is_structseq():
# StructSequenceType(iterable)
result = self.python_type()([x.as_python_constant() for x in self.items])
@@ -1328,7 +1356,7 @@ class NamedTupleVariable(TupleVariable):
return result
def as_proxy(self):
def as_proxy(self) -> Any:
assert self.python_type() is not SizeVariable
if self.is_structseq():
# StructSequenceType(iterable)
@@ -1342,7 +1370,10 @@ class NamedTupleVariable(TupleVariable):
# StructSequenceType(iterable)
# NamedTupleType(*iterable)
# NamedTupleType._make(iterable)
create_fn = self.tuple_cls if self.is_structseq() else self.tuple_cls._make
if self.is_structseq():
create_fn = self.tuple_cls
else:
create_fn = self.tuple_cls._make # type: ignore[attr-defined]
codegen.add_push_null(
lambda: codegen.append_output(
codegen.create_load_const_unchecked(create_fn)
@@ -1384,8 +1415,8 @@ class NamedTupleVariable(TupleVariable):
def call_method(
self,
tx,
name,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
@@ -1446,7 +1477,9 @@ class NamedTupleVariable(TupleVariable):
return super().call_method(tx, name, args, kwargs)
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
if isinstance(arg, SliceVariable):
# slicing a namedtuple produces a tuple
return TupleVariable(
@@ -1455,8 +1488,8 @@ class NamedTupleVariable(TupleVariable):
)
return super().getitem_const(tx, arg)
def var_getattr(self, tx: "InstructionTranslator", name):
def check_and_create_method():
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
def check_and_create_method() -> Optional[VariableTracker]:
method = inspect.getattr_static(self.tuple_cls, name, None)
if isinstance(method, classmethod):
# We need the unbounded cls method to avoid the inline __self__
@@ -1489,8 +1522,8 @@ class NamedTupleVariable(TupleVariable):
return super().var_getattr(tx, name)
if name == "_fields":
source = NamedTupleFieldsSource(self.source) if self.source else None
return VariableTracker.build(tx, self.fields(), source=source)
result_source = NamedTupleFieldsSource(self.source) if self.source else None
return VariableTracker.build(tx, self.fields(), source=result_source)
if name in self.dynamic_attributes:
return self.dynamic_attributes[name]
@@ -1505,14 +1538,19 @@ class NamedTupleVariable(TupleVariable):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> "VariableTracker":
) -> VariableTracker:
return variables.ConstantVariable.create(
name in self.dynamic_attributes or hasattr(self.tuple_cls, name)
)
class SliceVariable(VariableTracker):
def __init__(self, items, tx=None, **kwargs) -> None:
def __init__(
self,
items: Sequence[VariableTracker],
tx: Optional["InstructionTranslator"] = None,
**kwargs: Any,
) -> None:
items_to_map = items
start, stop, step = [variables.ConstantVariable.create(None)] * 3
@@ -1547,23 +1585,23 @@ class SliceVariable(VariableTracker):
super().__init__(**kwargs)
def debug_repr(self):
return self.debug_repr_helper("slice(", ")")
def debug_repr(self) -> str:
return "slice(" + ", ".join(i.debug_repr() for i in self.items) + ")"
def as_proxy(self):
def as_proxy(self) -> slice:
return slice(*[x.as_proxy() for x in self.items])
def python_type(self):
def python_type(self) -> type:
return slice
def as_python_constant(self):
def as_python_constant(self) -> slice:
return slice(*[guard_if_dyn(x) for x in self.items])
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.foreach(self.items)
codegen.append_output(create_instruction("BUILD_SLICE", arg=len(self.items)))
def var_getattr(self, tx: "InstructionTranslator", name):
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
if name in cmp_name_to_op_mapping:
return variables.GetAttrVariable(self, name)
fields = ["start", "stop", "step"]
@@ -1584,7 +1622,9 @@ class ListIteratorVariable(IteratorVariable):
*IteratorVariable._nonvar_fields,
}
def __init__(self, items, index: int = 0, **kwargs) -> None:
def __init__(
self, items: list[VariableTracker], index: int = 0, **kwargs: Any
) -> None:
super().__init__(**kwargs)
assert isinstance(items, list)
# Removing this check as it slows things down too much
@@ -1598,7 +1638,7 @@ class ListIteratorVariable(IteratorVariable):
def __repr__(self) -> str:
return f"{self.__class__.__name__}(length={len(self.items)}, index={repr(self.index)})"
def next_variable(self, tx):
def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
assert self.is_mutable()
old_index = self.index
if old_index >= len(self.items) or self.is_exhausted:
@@ -1609,27 +1649,31 @@ class ListIteratorVariable(IteratorVariable):
self.index += 1
return self.items[old_index]
def call_obj_hasattr(self, tx, name):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
return variables.ConstantVariable.create(hasattr(iter([]), name))
def python_type(self):
def python_type(self) -> type:
return type(iter([]))
def as_python_constant(self):
def as_python_constant(self) -> Any:
if self.index > 0:
raise NotImplementedError
return iter([x.as_python_constant() for x in self.items])
def has_unpack_var_sequence(self, tx):
def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool:
return True
def unpack_var_sequence(self, tx):
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
if self.is_exhausted:
return []
self.is_exhausted = True
return list(self.items[self.index :])
def force_unpack_var_sequence(self, tx) -> list[VariableTracker]:
def force_unpack_var_sequence(
self, tx: "InstructionTranslator"
) -> list[VariableTracker]:
return self.unpack_var_sequence(tx)
def reconstruct(self, codegen: "PyCodegen") -> None:
@@ -1656,27 +1700,37 @@ class RangeIteratorVariable(IteratorVariable):
"iter_obj",
}
def __init__(self, start: int, stop: int, step: int, len_: int, **kwargs):
def __init__(
self, start: int, stop: int, step: int, len_: int, **kwargs: Any
) -> None:
super().__init__(**kwargs)
self.start = start
self.stop = stop
self.step = step
self.len = len_
def call_method(self, tx, name, args, kwargs):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if name == "__next__":
return self.next_variable(tx)
elif name == "__iter__":
return self
return super().call_method(tx, name, args, kwargs)
def call_obj_hasattr(self, tx, name):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
if self.python_type() is range_iterator:
ri = iter(range(0))
return ConstantVariable(hasattr(ri, name))
return super().call_obj_hasattr(tx, name)
def next_variable(self, tx):
def next_variable(self, tx: "InstructionTranslator") -> VariableTracker:
if self.len <= 0:
raise_observed_exception(StopIteration, tx)
@@ -1685,12 +1739,12 @@ class RangeIteratorVariable(IteratorVariable):
self.start += self.step
return ConstantVariable.create(current)
def python_type(self):
def python_type(self) -> type:
return range_iterator
def reconstruct(self, codegen: "PyCodegen"):
def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null(
lambda: codegen.append_output(codegen.create_load_python_module(range))
lambda: codegen.append_output(codegen.create_load_python_module(range)) # type: ignore[arg-type]
)
codegen.append_output(codegen.create_load_const(self.start))
codegen.append_output(codegen.create_load_const(self.stop))


@@ -472,7 +472,12 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
)
elif self.value is torch.nn.attention.sdpa_kernel.__wrapped__: # type: ignore[attr-defined]
name_to_arg_map = bind_args_cached(
self.value, tx, self.source, args, kwargs
# pyrefly: ignore[bad-argument-type]
self.value,
tx,
self.source,
args,
kwargs,
)
backends = name_to_arg_map["backends"].as_python_constant()
set_priority = name_to_arg_map["set_priority"].as_python_constant()
@@ -1429,7 +1434,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
packed_input_vt = TupleVariable.build(
tx, (TupleVariable.build(tx, args), ConstDictVariable.build(tx, kwargs))
)
out_vt = variables.UserFunctionVariable(tree_flatten).call_function(
out_vt = variables.UserFunctionVariable(tree_flatten).call_function( # type: ignore[arg-type]
tx, [packed_input_vt], {}
)
assert isinstance(out_vt, TupleVariable) and len(out_vt.items) == 2


@@ -70,7 +70,7 @@ from .common import (
)
from .cpp_utils import cexpr
from .triton_utils import config_of, should_unwrap_unspec_arg, signature_to_meta
from torch.fx.experimental import _config as fx_experimental_config
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
@@ -1120,37 +1120,6 @@ class PythonWrapperCodegen(CodeGen):
# Additional files that are dependent to the wrapper (ex. cubin files)
self.additional_files = []
# This is used to emit RecordFunctionFast markers that can be matched
# with profiler traces for provenance tracking.
#
# Stores the (kernel_name, debug_handle) tuple
# for the currently being generated kernel.
self.current_kernel_debug_handle: Optional[tuple[str, int]] = None
# set_current_kernel_debug_handle: Flag that controls whether
# write_provenance_debug_handle() should update current_kernel_debug_handle.
# This flag is automatically managed by kernel_debug_handle_context().
self.set_current_kernel_debug_handle: bool = False
@contextlib.contextmanager
def kernel_debug_handle_context(self):
"""
Context manager for kernel debug handle tracking.
self.current_kernel_debug_handle can be updated within the context
with wrapper.write_provenance_debug_handle
and it will be reset after the context
"""
old_flag_value = self.set_current_kernel_debug_handle
old_handle_value = self.current_kernel_debug_handle
self.set_current_kernel_debug_handle = True
try:
yield
finally:
self.set_current_kernel_debug_handle = old_flag_value
self.current_kernel_debug_handle = old_handle_value
@staticmethod
def create(
is_subgraph: bool,
@@ -1541,27 +1510,8 @@ class PythonWrapperCodegen(CodeGen):
def generate_end(self, result: IndentedBuffer) -> None:
return
def generate_record_function_start(self) -> Optional[str]:
record_func = self.current_kernel_debug_handle and fx_experimental_config.enrich_profiler_metadata
if record_func:
assert self.current_kernel_debug_handle
kernel_name, debug_handle = self.current_kernel_debug_handle
kernel_debug_handle = f"{kernel_name}:{debug_handle}"
self.writeline(
f"_rf_enter = torch._C._profiler._RecordFunctionFast('## inductor_kernel:{kernel_debug_handle} ##'); _rf_enter.__enter__()"
)
return "_rf_enter"
else:
return None
def generate_record_function_end(self, record_func_var: Optional[str]):
if record_func_var:
self.writeline(f"{record_func_var}.__exit__(None, None, None)")
def generate_fallback_kernel(self, node: ir.FallbackKernel) -> None:
record_func_var = self.generate_record_function_start()
self.writeline(ExternKernelAllocLine(self, node))
self.generate_record_function_end(record_func_var)
def generate_extern_kernel_alloc(self, node: ir.ExternKernelAlloc):
node.codegen_comment(self)
@@ -1721,9 +1671,7 @@ class PythonWrapperCodegen(CodeGen):
raw_args: Sequence[Any],
outputs: Sequence[ir.Buffer],
) -> None:
record_func_var = self.generate_record_function_start()
self.writeline(f"{buf_name} = {python_kernel_name}({', '.join(get_args())})")
self.generate_record_function_end(record_func_var)
def generate(self, is_inference):
with dynamo_timed("PythonWrapperCodegen.generate"):
@@ -3194,8 +3142,6 @@ class PythonWrapperCodegen(CodeGen):
self.writeline(
f"{self.comment} [Provenance debug handles] {kernel_name}:{debug_handle}"
)
if self.set_current_kernel_debug_handle:
self.current_kernel_debug_handle = (kernel_name, debug_handle)
def make_buffer_reuse(self, old: BufferLike, new: BufferLike, delete_old: bool):
assert old.get_dtype() == new.get_dtype()


@@ -98,8 +98,6 @@ from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymExpr
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.monitor import _WaitCounter
from torch.utils._ordered_set import OrderedSet
from torch.fx.experimental import _config as fx_experimental_config
from .._dynamo.backends.common import aot_autograd
from .._dynamo.exc import ShortenTraceback, SkipFrame
@@ -1532,10 +1530,7 @@ class _InProcessFxCompile(FxCompile):
# Dump provenance artifacts for debugging trace
inductor_provenance_tracking_node_mappings = None
inductor_kernel_stack_trace_str = None
if (
config.trace.provenance_tracking_level != 0
or fx_experimental_config.enrich_profiler_metadata
):
if config.trace.provenance_tracking_level != 0:
inductor_provenance_tracking_node_mappings = json.dumps(
torch._inductor.debug.dump_inductor_provenance_info()
)


@@ -1179,8 +1179,6 @@ torchinductor_worker_logpath: str = Config(
default="",
)
fallback_by_default: bool = False
# config specific to codegen/cpp.py
class cpp:


@@ -1106,7 +1106,7 @@ def set_kernel_post_grad_provenance_tracing(
Returns a unique int debug handler for each call to this function.
"""
if config.trace.provenance_tracking_level == 0 and not config.fallback_by_default:
if config.trace.provenance_tracking_level == 0:
return None
try:


@@ -1628,7 +1628,6 @@ class GraphLowering(torch.fx.Interpreter):
"inductor", "lowerings", lambda: repr(n)
)
)
or (n.op == "call_function" and config.fallback_by_default)
):
debug("fallback_handler")
result = fallback_handler(n.target, add_to_fallback_set=False)(


@@ -8079,28 +8079,27 @@ class FallbackKernel(ExternKernelAlloc):
for v, a in zip(args_iter, kernel._schema.arguments)
)
with wrapper.kernel_debug_handle_context():
self.codegen_comment(wrapper)
if self.use_runtime_dispatch:
exported_args = self.export_extern_kernel_node()
assert self.python_kernel_name is not None
assert self.op_overload is not None
self.codegen_comment(wrapper)
if self.use_runtime_dispatch:
exported_args = self.export_extern_kernel_node()
assert self.python_kernel_name is not None
assert self.op_overload is not None
wrapper.generate_fallback_kernel_with_runtime_lookup(
self.get_name(),
self.python_kernel_name,
lambda: [*self.codegen_args(), *self.codegen_kwargs()],
self.op_overload,
exported_args,
# NOTE: [special handling of all_reduce_coalesced_'s return value]
self.outputs if self.outputs else self.mutation_outputs,
)
else:
wrapper.generate_fallback_kernel(self)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
self.codegen_alignment_asserts(wrapper)
self.codegen_memory_tracking(wrapper)
wrapper.generate_fallback_kernel_with_runtime_lookup(
self.get_name(),
self.python_kernel_name,
lambda: [*self.codegen_args(), *self.codegen_kwargs()],
self.op_overload,
exported_args,
# NOTE: [special handling of all_reduce_coalesced_'s return value]
self.outputs if self.outputs else self.mutation_outputs,
)
else:
wrapper.generate_fallback_kernel(self)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
self.codegen_alignment_asserts(wrapper)
self.codegen_memory_tracking(wrapper)
self.codegen_unbacked_symbol_defs(wrapper)


@@ -25,8 +25,6 @@ from __future__ import annotations
import dataclasses
import logging
import os
import random
import string
from functools import partial
from typing import Any, Optional, TYPE_CHECKING, TypeAlias, Union
@@ -72,10 +70,6 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
# Used for profiler post-processing to match
# for the same compiled run
CALL_COMPILED_PREFIX = "Call CompiledFxGraph"
@dataclasses.dataclass
class OutputCode:
@@ -618,18 +612,9 @@ class CompiledFxGraph(OutputCode):
try:
# Checking the profiler directly is faster than nullcontext
if torch.autograd.profiler._is_profiler_enabled:
# generate a random string to represent this unique run if no cache key
run_key = (
self._fx_graph_cache_key
if self._fx_graph_cache_key
else "".join(random.choices(string.ascii_lowercase, k=51))
)
run_name = f"{CALL_COMPILED_PREFIX} {run_key}"
if self.inductor_provenance_stack_traces_str:
torch.fx.traceback._register_fx_metadata(
run_name, self.inductor_provenance_stack_traces_str
)
with record_function(f"## {run_name} ##"):
with record_function(
f"## Call CompiledFxGraph {self._fx_graph_cache_key} ##"
):
return self.current_callable(inputs)
else:
return self.current_callable(inputs)


@@ -1224,3 +1224,43 @@ def _build_table(
f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}"
)
return "".join(result)
# Collect all events with stack traces and format them canonically
def _canonicalize_profiler_events(events):
"""
Extract and format all events with stack traces in a canonical way
for deterministic testing.
"""
events_with_traces = []
for event in events:
# Extract relevant fields
event_name = event.get("name", "")
node_name = event["args"].get("node_name", "")
stack_trace = event["args"].get("stack_trace", "")
# Get the last non-empty line of the stack trace
lines = [s.strip() for s in stack_trace.split("\n") if s.strip()]
stack_trace = lines[-1] if lines else ""
events_with_traces.append(
{
"event_name": event_name[:20],
"node_name": node_name,
"stack_trace": stack_trace,
"start_time": event.get("ts", 0),
}
)
# Sort by start time for deterministic ordering
events_with_traces.sort(key=lambda x: x["start_time"])
# Format as a string
lines: list[str] = []
for evt in events_with_traces:
lines.append(
f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
)
return "\n".join(lines)


@@ -20,7 +20,10 @@ from torch.distributed.tensor._tp_conv import (
convolution_backward_handler,
convolution_handler,
)
from torch.distributed.tensor._utils import try_find_mesh_from_args
from torch.distributed.tensor._utils import (
ExplicitRedistributionContext,
try_find_mesh_from_args,
)
from torch.distributed.tensor.placement_types import Partial, Placement, Replicate
from torch.utils._debug_mode import get_active_debug_mode
from torch.utils._python_dispatch import return_and_correct_aliasing
@@ -199,6 +202,10 @@ class OpDispatcher:
if participating:
# computation that happens in the current rank of the mesh, normal case
if output_sharding.needs_redistribute:
if ExplicitRedistributionContext.is_active():
raise RuntimeError(
f"Implicit redistribution occurred while ExplicitRedistributionContext was active for {op_info.schema}"
)
# If sharding propagation decision needs redistribute, perform redistribute
# on args first, which could potentially modify args (i.e. allgather certain arg)
assert output_sharding.redistribute_schema is not None


@@ -18,6 +18,33 @@ from torch.distributed.tensor.placement_types import (
from torch.utils._typing_utils import not_none
class ExplicitRedistributionContext:
"""
Within this context manager, DTensor will refuse to perform implicit redistribution,
instead raising an error. Manual calls to ``redistribute()`` are required wherever a redistribution
must occur to avoid erroring. This can be used to ensure that the user is aware of all redistribution.
Note: it is easier to use this mode on just the forward pass of a typical DTensor program, as the backwards pass
may contain implicit redistribution calls that are not visible to the user and difficult to replace with manual
calls. Redistribution during backward can be made explicit by writing `autograd.Function`s that are no-op
during forward and perform a manual redistribution during backwards.
"""
_explicit_redistribute_mode = False
@classmethod
def is_active(cls) -> bool:
return cls._explicit_redistribute_mode
def __enter__(self):
self.prev = ExplicitRedistributionContext._explicit_redistribute_mode
ExplicitRedistributionContext._explicit_redistribute_mode = True
return self
def __exit__(self, exc_type, exc_val, exc_tb):
ExplicitRedistributionContext._explicit_redistribute_mode = self.prev
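# Illustrative sketch (not part of this diff): the autograd.Function pattern the
# docstring above describes, keeping forward a no-op and making the backward
# redistribution explicit:
#
#     class RedistributeInBackward(torch.autograd.Function):
#         @staticmethod
#         def forward(ctx, dtensor, mesh, placements):
#             ctx.mesh, ctx.placements = mesh, placements
#             return dtensor
#
#         @staticmethod
#         def backward(ctx, grad):
#             return grad.redistribute(ctx.mesh, ctx.placements), None, None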
def _explicit_order_placements(
mesh_shape: ShapeType, placements: Sequence[Placement]
) -> Sequence[tuple[int, Placement]]:


@@ -43,10 +43,10 @@ should_preserve_node_meta = False
# =============================================================================
# Global in-memory registry for FX metadata
# Maps module_name -> metadata dict containing lineno_map and node_metadata
_FX_METADATA_REGISTRY: dict[str, str | dict[str, Any]] = {}
_FX_METADATA_REGISTRY: dict[str, dict[str, Any]] = {}
def _register_fx_metadata(module_name: str, metadata: str | dict[str, Any]) -> None:
def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None:
"""
Register FX metadata in the global in-memory registry.
@@ -55,7 +55,7 @@ def _register_fx_metadata(module_name: str, metadata: str | dict[str, Any]) -> N
Args:
module_name: The module identifier (content-addressed filename)
metadata: Metadata dict containing lineno_map, node_metadata, and source_code. If a str, it's a json dump that can be json loaded as a dict.
metadata: Metadata dict containing lineno_map, node_metadata, and source_code
"""
# TODO: add logging to tlparse
_FX_METADATA_REGISTRY[module_name] = metadata


@@ -1,11 +1,9 @@
# mypy: allow-untyped-defs
import functools
import json
import operator
import re
from collections import deque
from dataclasses import dataclass
from enum import Enum
from typing import Any, Literal, Optional, TYPE_CHECKING
from torch.autograd.profiler import profile
@ -404,31 +402,13 @@ def _init_for_cuda_graphs() -> None:
pass
class ContextType(Enum):
"""Types of contexts in the profiler stack."""
FX_GRAPH = "filename"
FX_NODE = "node"
COMPILED_GRAPH = "compiled_graph"
INDUCTOR_NODE = "inductor_node"
def get_parent_context_type(context_type: ContextType) -> Optional[ContextType]:
if context_type == ContextType.FX_NODE:
return ContextType.FX_GRAPH
elif context_type == ContextType.INDUCTOR_NODE:
return ContextType.COMPILED_GRAPH
else:
return None
@dataclass
class TimelineEvent:
"""Represents an event in the profiler timeline."""
timestamp: int
event_type: Literal["start", "end", "regular"]
marker_type: Optional[ContextType]
marker_type: Optional[Literal["filename", "node"]]
identifier: Optional[str | int]
event: dict[str, Any]
@ -437,7 +417,7 @@ class TimelineEvent:
class ContextStackEntry:
"""Represents a context (filename or node) in the stack."""
context_type: ContextType
context_type: Literal["filename", "node"]
identifier: str | int
metadata: Optional[dict]
tid: Optional[int] = None # Thread ID associated with this context
@ -458,8 +438,6 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
Returns:
Dict mapping recorded event names to their aten operations with added stack traces
"""
from torch._inductor.output_code import CALL_COMPILED_PREFIX
from torch.fx.graph_module import FX_GRAPH_MODULE_FILE_PREFIX
from torch.fx.traceback import _FX_METADATA_REGISTRY
trace_events = traced_data.get("traceEvents", [])
@ -469,7 +447,7 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
def is_fx_marker_event(event):
return (
event.get("cat") in ("cpu_op", "user_annotation")
event.get("cat") == "cpu_op"
and event.get("name", "").startswith("## ")
and event.get("name", "").endswith(" ##")
)
@ -491,27 +469,14 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
if is_fx_marker_event(event):
content = event["name"][3:-3]
# Try different event types
if content.startswith(FX_GRAPH_MODULE_FILE_PREFIX) and content.endswith(
".py"
):
# FX graph event
append_fx_marker_event(ContextType.FX_GRAPH, content, event)
elif content.startswith(CALL_COMPILED_PREFIX):
# Inductor compiled graph event
append_fx_marker_event(ContextType.COMPILED_GRAPH, content, event)
elif content.startswith("inductor_kernel:"):
append_fx_marker_event(
ContextType.INDUCTOR_NODE, content[len("inductor_kernel:") :], event
)
if content.endswith(".py"):
append_fx_marker_event("filename", content, event)
else:
# Try to parse as node index for FX graph
# TODO: change to start with fx_node
try:
node_index = int(content)
append_fx_marker_event(ContextType.FX_NODE, node_index, event)
except ValueError:
pass
append_fx_marker_event("node", node_index, event) # type: ignore[possibly-undefined]
else:
# Regular event that needs augmentation
@ -530,37 +495,23 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
case "start":
assert timeline_event.identifier is not None
if timeline_event.marker_type in (
ContextType.FX_GRAPH,
ContextType.COMPILED_GRAPH,
):
if timeline_event.marker_type == "filename":
assert isinstance(timeline_event.identifier, str)
# Push filename context - query metadata registry on-demand
metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier)
tid = timeline_event.event.get("tid")
# TODO: add get method in traceback to try - catch and get
if isinstance(metadata, str):
metadata = json.loads(metadata)
context_stack.append(
ContextStackEntry(
timeline_event.marker_type,
timeline_event.identifier,
metadata,
tid,
"filename", timeline_event.identifier, metadata, tid
)
)
elif timeline_event.marker_type in (
ContextType.FX_NODE,
ContextType.INDUCTOR_NODE,
):
elif timeline_event.marker_type == "node":
# Find the current filename from stack
current_file_metadata = None
tid = timeline_event.event.get("tid")
parent_type = get_parent_context_type(timeline_event.marker_type)
for ctx_entry in reversed(context_stack):
if (
ctx_entry.context_type == parent_type
ctx_entry.context_type == "filename"
and ctx_entry.tid == tid
):
current_file_metadata = ctx_entry.metadata
@ -569,39 +520,14 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
if current_file_metadata:
node_metadata = current_file_metadata.get("node_metadata", {})
if timeline_event.identifier in node_metadata:
if ctx_entry.context_type == ContextType.FX_NODE:
node_meta: Optional[dict] = node_metadata[
timeline_event.identifier
]
context_stack.append(
ContextStackEntry(
ContextType.FX_NODE,
timeline_event.identifier,
node_meta,
tid,
)
node_meta: Optional[dict] = node_metadata[
timeline_event.identifier
]
context_stack.append(
ContextStackEntry(
"node", timeline_event.identifier, node_meta, tid
)
if timeline_event.marker_type == ContextType.INDUCTOR_NODE:
# Look up stack traces for this kernel
# TODO: make a dictionary that maps from compiled key to stack traces dictionary
stack_traces = current_file_metadata.get(
timeline_event.identifier, []
)
if stack_traces:
# Store all stack traces as metadata
node_meta: Optional[dict] = {
"stack_trace": stack_traces,
"name": timeline_event.identifier,
}
context_stack.append(
ContextStackEntry(
ContextType.INDUCTOR_NODE,
timeline_event.identifier,
node_meta,
tid,
)
)
case "end":
# Pop from stack - search backwards to find matching context
@ -625,10 +551,7 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
for ctx_entry in reversed(context_stack):
# Only apply metadata from contexts with matching tid
if ctx_entry.tid == event_tid:
if (
ctx_entry.context_type == ContextType.FX_NODE
and ctx_entry.metadata
):
if ctx_entry.context_type == "node" and ctx_entry.metadata:
current_stack_trace = ctx_entry.metadata.get(
"stack_trace", "No model stack trace available"
)
@ -636,19 +559,6 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
# Do we want to only attach the stack trace of the lowest node or stack trace of all nodes
# if nodes are nested, e.g. in nested graph modules
break
elif (
ctx_entry.context_type == ContextType.INDUCTOR_NODE
and ctx_entry.metadata
):
# For inductor nodes, stack_trace is a list of traces
stack_traces_list = ctx_entry.metadata.get(
"stack_trace", []
)
if stack_traces_list:
# Store as a list - each trace gets its own entry
current_stack_trace = stack_traces_list
current_node_name = ctx_entry.metadata.get("name", "")
break
# Augment the event
if current_stack_trace or current_node_name:
@ -657,81 +567,3 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
args["stack_trace"] = current_stack_trace
if current_node_name:
args["node_name"] = current_node_name
import tempfile
import os
# Collect all events with stack traces and format them canonically
def _canonicalize_profiler_events(events):
"""
Extract and format all events with stack traces in a canonical way
for deterministic testing.
"""
events_with_traces = []
for event in events:
# Extract relevant fields
event_name = event.get("name", "")
node_name = event["args"].get("node_name", "")
stack_trace = event["args"].get("stack_trace", "")
if isinstance(stack_trace, list):
stack_trace = "\n".join(stack_trace)
# Get the last non-empty line of the stack trace
lines = [s.strip() for s in stack_trace.split("\n") if s.strip()]
stack_trace = lines[-1] if lines else ""
events_with_traces.append(
{
"event_name": event_name[:20],
"node_name": node_name,
"stack_trace": stack_trace,
"start_time": event.get("ts", 0),
}
)
# Sort by node_name for deterministic ordering
events_with_traces.sort(key=lambda x: x["start_time"])
# Format as a string
lines: list[str] = []
for evt in events_with_traces:
lines.append(
f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
)
return "\n".join(lines)
def _enrich_profiler_traces(prof):
"""
Helper function to extract and augment profiler events with stack traces.
Args:
prof: A torch.profiler.profile object
Returns:
A string representing enriched events
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
trace_file = f.name
try:
prof.export_chrome_trace(trace_file)
with open(trace_file) as f:
trace_data = json.load(f)
map_recorded_events_to_aten_ops_with_stack_trace(trace_data)
events = []
for event in trace_data["traceEvents"]:
if "args" in event and "stack_trace" in event["args"]:
events.append(event)
actual_traces = _canonicalize_profiler_events(events)
return actual_traces
finally:
if os.path.exists(trace_file):
os.remove(trace_file)
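With the enum and the inductor-specific branches removed, the parser above recognizes just two marker shapes. A self-contained sketch of that classification logic and the event format it expects (field values are invented; real events come from ``export_chrome_trace``):

```
# Toy chrome-trace events in the shape the simplified parser expects:
# a "## <file>.py ##" marker opens a filename context, and a numeric
# "## <node_index> ##" marker opens a node context within it.
graph_marker = {"cat": "cpu_op", "name": "## fx_generated_deadbeef.py ##", "ts": 100, "tid": 1}
node_marker = {"cat": "cpu_op", "name": "## 3 ##", "ts": 110, "tid": 1}


def classify_marker(event: dict):
    """Mirrors is_fx_marker_event plus the content dispatch above (sketch)."""
    name = event.get("name", "")
    if not (event.get("cat") == "cpu_op" and name.startswith("## ") and name.endswith(" ##")):
        return None
    content = name[3:-3]
    if content.endswith(".py"):
        return ("filename", content)
    try:
        return ("node", int(content))
    except ValueError:
        return None


assert classify_marker(graph_marker) == ("filename", "fx_generated_deadbeef.py")
assert classify_marker(node_marker) == ("node", 3)
```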