Compare commits

..

9 Commits

Author SHA1 Message Date
2056d7fa22 [profiler] Add stack trace to Fallback kernels for inductor lite mode 2025-11-06 16:09:21 -08:00
4b9ba0fb26 [user-streams] Add requires cuda to all test cases (#167195)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167195
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #167175, #167176, #167180
2025-11-06 23:13:47 +00:00
106d34c80a [user-streams] add requires cuda decorator (#167180)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167180
Approved by: https://github.com/donigian, https://github.com/Lucaskabela, https://github.com/Skylion007
ghstack dependencies: #167175, #167176
2025-11-06 23:13:47 +00:00
0b06109412 [user-streams] Fix bug in object bytecode construction (#167176)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167176
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #167175
2025-11-06 23:13:47 +00:00
2073af5790 [user-streams] Refactor user object index in streams (#167175)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167175
Approved by: https://github.com/Lucaskabela
2025-11-06 23:13:47 +00:00
9b4ac45d2f Revert "[Inductor] addmm with bias -> unfuse bias if there is a pointwise/reduction consumer (#166165)"
This reverts commit eefa16342c9f322b56c7c0cd6d309c3ed8f0b882.

Reverted https://github.com/pytorch/pytorch/pull/166165 on behalf of https://github.com/jeanschmidt due to Breaking internal tests D86216934 ([comment](https://github.com/pytorch/pytorch/pull/166165#issuecomment-3499645688))
2025-11-06 22:34:48 +00:00
a45a17f65e Fix boxcox to return same result for same input in one batch (#166986)
Summary:
The SIMD path uses the SLEEF version of pow, which is slightly different from std::pow. The fix is to use the same vectorized code (with partial load and store) for the trailing data as well, so that results stay consistent within a batch.

Deploy:
Need to make a hotfix in waas to monitor release signals, since this diff can cause testing failures in veloski and waas release correctness tests.

Test Plan: Sandcastle.

Differential Revision: D86218207

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166986
Approved by: https://github.com/swolchok
2025-11-06 22:33:26 +00:00
c5593e75b3 Fix flaky memory profiler test (#167168)
Fixes #167037

Do not check the exact number of frames.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167168
Approved by: https://github.com/angelayi
2025-11-06 21:39:44 +00:00
c90a976370 Update pythoncapi_compat.h (#167138)
Update to commit 44c8e14bbbb5d5135ae90957036a61397e4df577.

Should slightly simplify https://github.com/pytorch/pytorch/pull/166342
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167138
Approved by: https://github.com/albanD
2025-11-06 21:31:58 +00:00
26 changed files with 1831 additions and 259 deletions

View File

@ -73,6 +73,19 @@ void box_cox_zero_lambda(
}
}
template <typename T>
at::vec::Vectorized<T> box_cox_nonzero_lambda_impl(
at::vec::Vectorized<T> data,
at::vec::Vectorized<T> lambda1,
at::vec::Vectorized<T> lambda2,
at::vec::Vectorized<T> k_eps) {
auto sum = data + lambda2;
auto max = at::vec::max(sum, k_eps);
auto lambda_over_1 = at::vec::fast_recieprocal(lambda1);
auto pow = max.pow(lambda1);
return at::vec::fmsub(pow, lambda_over_1, lambda_over_1);
}
template <typename T>
void box_cox_nonzero_lambda(
int64_t D,
@ -88,21 +101,18 @@ void box_cox_nonzero_lambda(
auto k_eps_vec = Vec(k_eps);
for(; j + VLEN < D; j += VLEN) {
auto data = Vec::loadu(data_ptr + j);
auto lambda2 = Vec::loadu(lambda2_ptr + j);
auto sum = data + lambda2;
auto max = at::vec::max(sum, k_eps_vec);
auto lambda1 = Vec::loadu(lambda1_ptr + j);
auto lambda_over_1 = at::vec::fast_recieprocal(lambda1);
auto pow = max.pow(lambda1);
auto res = at::vec::fmsub(pow, lambda_over_1, lambda_over_1);
auto lambda2 = Vec::loadu(lambda2_ptr + j);
auto res = box_cox_nonzero_lambda_impl(data, lambda1, lambda2, k_eps_vec);
res.store(out + j);
}
for ( ;j < D; ++j) {
auto sum = data_ptr[j] + lambda2_ptr[j];
auto max = std::max(sum, k_eps);
auto lambda_over_1 = at::vec::fast_recieprocal(lambda1_ptr[j]);
auto pow = std::pow(max, lambda1_ptr[j]);
out[j] = pow * lambda_over_1 - lambda_over_1;
if (j < D) {
auto remaining = D - j;
auto data = Vec::loadu(data_ptr + j, remaining);
auto lambda1 = Vec::loadu(lambda1_ptr + j, remaining);
auto lambda2 = Vec::loadu(lambda2_ptr + j, remaining);
auto res = box_cox_nonzero_lambda_impl(data, lambda1, lambda2, k_eps_vec);
res.store(out + j, remaining);
}
}
#else
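
For readers following the hunk above, here is a scalar reference of the transform that box_cox_nonzero_lambda_impl computes, added only as a reading aid; the eps default and sample values are illustrative and not taken from the kernel. The point of the change is that the trailing elements now go through the same vectorized pow as the main loop rather than std::pow, so identical inputs within one batch take the same code path and get the same result.

def box_cox_nonzero_lambda_reference(x, lambda1, lambda2, eps=1e-6):
    # matches fmsub(pow, 1/lambda1, 1/lambda1): (max(x + lambda2, eps) ** lambda1 - 1) / lambda1
    base = max(x + lambda2, eps)
    return (base ** lambda1 - 1.0) / lambda1

print(box_cox_nonzero_lambda_reference(2.0, 0.5, 1.0))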

View File

@ -74,13 +74,13 @@ class TestStreams(torch._dynamo.test_case.TestCase):
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': None}
# Annotation: {'stream': 0}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': None}
# Annotation: {'stream': 1}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': None}
# Annotation: {'stream': 1}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
return (add_3,)
@ -196,6 +196,7 @@ class <lambda>(torch.nn.Module):
s_exp = fn(*inp)
self.assertEqual(s_act, s_exp)
@requires_cuda
def test_nested_stream_enter_exit(self):
def fn(x, y, s0, s1, s2):
with s1:
@ -229,13 +230,13 @@ class <lambda>(torch.nn.Module):
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': None}
# Annotation: {'stream': 1}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': None}
# Annotation: {'stream': 2}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': None}
# Annotation: {'stream': 1}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
return (add_1, add_2)
""",
@ -249,6 +250,7 @@ class <lambda>(torch.nn.Module):
def test_nested_stream_enter_exit_graph_break(self):
pass
@requires_cuda
def test_local_stream_enter_exit(self):
def fn(x, y):
s2 = torch.Stream()
@ -289,6 +291,7 @@ class <lambda>(torch.nn.Module):
""",
)
@requires_cuda
def test_local_stream_nested_enter_exit(self):
def fn(x, y):
s2 = torch.Stream()
@ -331,6 +334,7 @@ class <lambda>(torch.nn.Module):
""",
)
@requires_cuda
def test_stream_with_mutation(self):
def fn(x, y):
s2 = torch.Stream()
@ -380,6 +384,7 @@ class <lambda>(torch.nn.Module):
""",
)
@requires_cuda
def test_stream_backward(self) -> None:
def fn(x, y):
s2 = torch.Stream()

View File

@ -500,13 +500,8 @@ class PaddingTest(TestCaseBase):
forward_wrapper = wrapper_codes[0]
# make sure the load for softmax is aligned
if bias:
# addmm -> mm + bias and bias is fused with softmax
softmax_load_str = "tl.load(in_out_ptr0 + (r0_1 + 30528*x0)"
else:
softmax_load_str = "tl.load(in_ptr0 + (r0_1 + 30528*x0)"
self.assertTrue(
softmax_load_str in forward_wrapper,
"tl.load(in_ptr0 + (r0_1 + 30528*x0)" in forward_wrapper,
f"forward_wrapper: {forward_wrapper}",
)

View File

@ -27,9 +27,9 @@ from torch._inductor.fx_passes.post_grad import post_grad_passes
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code, run_and_get_cpp_code
from torch._inductor.virtualized import V
from torch.testing._internal.common_utils import IS_MACOS
from torch.testing._internal.common_utils import IS_MACOS, skipIfRocm
from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.profiler import profile, ProfilerActivity
try:
from .test_aot_inductor_utils import AOTIRunnerUtil
@ -941,5 +941,63 @@ copy_tests(
)
from torch.profiler._utils import _enrich_profiler_traces
class TestProfilerStackTraceAugmentation(TestCase):
"""
Test that profiler events are correctly augmented with stack traces
from both FX metadata and inductor kernel stack traces.
"""
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
@config.patch("fallback_by_default", True) # TODO update the config patch to inductor lite mode
@torch.compiler.config.patch("force_disable_caches", True)
def test_profiler_inductor_stack_trace_augmentation(self):
"""
Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
augments profiler events with stack traces from inductor kernel metadata.
"""
# Test model similar to test.py
class TestModel(torch.nn.Module):
def forward(self, c):
d = c * 2
d = d + 1
return d
device = "cuda"
model = TestModel().to(device)
c = torch.randn((64, 32), device=device)
# Force disable caches to ensure fresh compilation
torch.compiler.config.force_disable_caches = True
# Compile the model
compiled_model = torch.compile(model, fullgraph=True)
# Warmup
for _ in range(3):
_ = compiled_model(c)
# Profile with the compiled model
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
compiled_model(c)
actual_traces = _enrich_profiler_traces(prof)
self.assertExpectedInline(actual_traces, """\
event=aten::mul node=torch.ops.aten.mul.Tensor:1 stack_trace=d = c * 2
event=cudaLaunchKernel node=torch.ops.aten.mul.Tensor:1 stack_trace=d = c * 2
event=aten::add node=torch.ops.aten.add.Tensor:2 stack_trace=d = d + 1
event=cudaLaunchKernel node=torch.ops.aten.add.Tensor:2 stack_trace=d = d + 1""")
# TODO: add test that when enrich is not turned on there is no recordfast generated.
if __name__ == "__main__":
run_tests()

View File

@ -15310,7 +15310,7 @@ if RUN_GPU:
),
(
fn3,
"triton_poi_fused_addmm_native_layer_norm",
"triton_poi_fused_native_layer_norm_relu",
(torch.randn(4, 4, device=GPU_TYPE),),
),
]
@ -15323,7 +15323,7 @@ if RUN_GPU:
),
(
fn3,
"triton_poi_fused_LayerNorm_Linear_ReLU",
"triton_poi_fused_LayerNorm_ReLU",
(torch.randn(4, 4, device=GPU_TYPE),),
),
]

View File

@ -7508,6 +7508,8 @@ class TestFXMemoryProfiler(TestCase):
device = "cuda"
mod = MLPModule(device)
with tempfile.TemporaryDirectory() as tmpdir:
# reset cache to start fresh
torch.cuda.memory.empty_cache()
torch.cuda.memory._record_memory_history()
compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
result = compiled(torch.randn(10, 10, device=device))
@ -7518,10 +7520,7 @@ class TestFXMemoryProfiler(TestCase):
torch.cuda.empty_cache()
fx_frames = self.collect_frames(augmented_snapshot)
if TEST_WITH_ROCM:
self.assertGreater(len(fx_frames), 0)
else:
self.assertEqual(len(fx_frames), 12)
self.assertGreater(len(fx_frames), 2)
for frame in fx_frames:
# Every FX frame should have both node_op and node_name

View File

@ -76,11 +76,8 @@ from torch.testing._internal.common_utils import (
)
from torch.testing._internal.jit_utils import JitTestCase
import json
import tempfile
from torch.profiler import profile, ProfilerActivity
from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace
from torch.autograd.profiler_util import _canonicalize_profiler_events
from torch.profiler._utils import _enrich_profiler_traces
try:
from torchvision import models as torchvision_models
@ -208,36 +205,6 @@ def side_effect_func(x: torch.Tensor):
print(x)
def _enrich_profiler_traces(prof):
"""
Helper function to extract and augment profiler events with stack traces.
Args:
prof: A torch.profiler.profile object
Returns:
A string representing enriched events
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f:
trace_file = f.name
prof.export_chrome_trace(trace_file)
with open(trace_file) as f:
trace_data = json.load(f)
map_recorded_events_to_aten_ops_with_stack_trace(
trace_data
)
events = []
for event in trace_data["traceEvents"]:
if "args" in event and "stack_trace" in event["args"]:
events.append(event)
actual_traces = _canonicalize_profiler_events(events)
return actual_traces
class TestFX(JitTestCase):
def setUp(self):
super().setUp()

View File

@ -1542,7 +1542,7 @@ class OutputGraph(OutputGraphCommon):
)
)
tmp_vars = []
for constructor in reversed(index_to_bytecode_constructor.values()):
for constructor in index_to_bytecode_constructor.values():
constructor(codegen)
var_name = (
self.new_var()
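
A minimal, hypothetical sketch of the ordering issue that iterating the constructors in registration order (rather than reversed) avoids; the dict below is an illustrative stand-in, not the actual dynamo structure, and the real bug may involve more than this.

index_to_constructor = {0: lambda: "obj_for_index_0", 1: lambda: "obj_for_index_1"}

# replaying in reverse builds the object registered under index 1 first, so the
# positional results no longer line up with the registration indices
assert [c() for c in reversed(index_to_constructor.values())] == ["obj_for_index_1", "obj_for_index_0"]
# insertion order keeps index i paired with the i-th constructed object
assert [c() for c in index_to_constructor.values()] == ["obj_for_index_0", "obj_for_index_1"]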

View File

@ -1061,9 +1061,7 @@ class VariableBuilder:
)
set_example_value(stream_proxy.node, value)
var = StreamVariable(
stream_proxy,
value,
source=self.source,
stream_proxy, value, source=self.source, user_object_index=index
)
return self.tx.output.side_effects.track_object_existing(value, var)
elif isinstance(value, (torch._C._SDPAParams)):
@ -3006,14 +3004,16 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
return SymNodeVariable(proxy, example_value, **options)
elif (
isinstance(example_value, torch.Stream)
and proxy.node.target
in (get_external_object_by_index, torch.accelerator.current_stream)
and proxy.node.target == get_external_object_by_index
) or proxy.node.target in [
device_interface.current_stream
for _, device_interface in get_registered_device_interfaces()
]:
set_example_value(proxy.node, example_value)
return StreamVariable(proxy, example_value, **options)
index = None
if proxy.node.target == get_external_object_by_index:
index = proxy.node.args[0]
return StreamVariable(proxy, example_value, index, **options)
elif (
inspect.isclass(proxy.node.target)
and issubclass(proxy.node.target, torch.Event)

View File

@ -204,11 +204,11 @@ class StreamVariable(StreamContextVariable):
self,
proxy: Proxy,
value: torch.Stream,
user_object_index: Optional[int] = None,
**kwargs: Any,
) -> None:
# Index into the user object table
# used to pass arbitrary objects to the graph
user_object_index = kwargs.pop("user_obj_index", None)
if proxy is not None and "example_value" in proxy.node.meta:
assert proxy.node.meta["example_value"] == value
@ -300,7 +300,7 @@ class StreamVariable(StreamContextVariable):
codegen.append_output(codegen.create_load_const(self.user_object_index))
codegen.extend_output(create_call_function(1, False))
else:
# TODO mlazos: evaluate if we still need this
# This will support the legacy behavior
prefix = f"_stream_{self.device}"
name = codegen.tx.output.install_global_by_id(prefix, self.value)
codegen.append_output(codegen.create_load_global(name, add=True))

View File

@ -838,7 +838,6 @@ class UserDefinedClassVariable(UserDefinedVariable):
proxy=tx.output.create_proxy(
"call_function", get_external_object_by_index, (ind,), {}
),
user_obj_index=ind,
)
else:
tensor_variable = wrap_fx_proxy(

View File

@ -70,7 +70,7 @@ from .common import (
)
from .cpp_utils import cexpr
from .triton_utils import config_of, should_unwrap_unspec_arg, signature_to_meta
from torch.fx.experimental import _config as fx_experimental_config
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
@ -1120,6 +1120,37 @@ class PythonWrapperCodegen(CodeGen):
# Additional files that are dependent to the wrapper (ex. cubin files)
self.additional_files = []
# This is used to emit RecordFunctionFast markers that can be matched
# with profiler traces for provenance tracking.
#
# Stores the (kernel_name, debug_handle) tuple
# for the kernel currently being generated.
self.current_kernel_debug_handle: Optional[tuple[str, int]] = None
# set_current_kernel_debug_handle: Flag that controls whether
# write_provenance_debug_handle() should update current_kernel_debug_handle.
# This flag is automatically managed by kernel_debug_handle_context().
self.set_current_kernel_debug_handle: bool = False
@contextlib.contextmanager
def kernel_debug_handle_context(self):
"""
Context manager for kernel debug handle tracking.
self.current_kernel_debug_handle can be updated within the context
via wrapper.write_provenance_debug_handle, and it is restored to its
previous value when the context exits
"""
old_flag_value = self.set_current_kernel_debug_handle
old_handle_value = self.current_kernel_debug_handle
self.set_current_kernel_debug_handle = True
try:
yield
finally:
self.set_current_kernel_debug_handle = old_flag_value
self.current_kernel_debug_handle = old_handle_value
@staticmethod
def create(
is_subgraph: bool,
@ -1510,8 +1541,27 @@ class PythonWrapperCodegen(CodeGen):
def generate_end(self, result: IndentedBuffer) -> None:
return
def generate_record_function_start(self) -> Optional[str]:
record_func = self.current_kernel_debug_handle and fx_experimental_config.enrich_profiler_metadata
if record_func:
assert self.current_kernel_debug_handle
kernel_name, debug_handle = self.current_kernel_debug_handle
kernel_debug_handle = f"{kernel_name}:{debug_handle}"
self.writeline(
f"_rf_enter = torch._C._profiler._RecordFunctionFast('## inductor_kernel:{kernel_debug_handle} ##'); _rf_enter.__enter__()"
)
return "_rf_enter"
else:
return None
def generate_record_function_end(self, record_func_var: Optional[str]):
if record_func_var:
self.writeline(f"{record_func_var}.__exit__(None, None, None)")
def generate_fallback_kernel(self, node: ir.FallbackKernel) -> None:
record_func_var = self.generate_record_function_start()
self.writeline(ExternKernelAllocLine(self, node))
self.generate_record_function_end(record_func_var)
def generate_extern_kernel_alloc(self, node: ir.ExternKernelAlloc):
node.codegen_comment(self)
@ -1671,7 +1721,9 @@ class PythonWrapperCodegen(CodeGen):
raw_args: Sequence[Any],
outputs: Sequence[ir.Buffer],
) -> None:
record_func_var = self.generate_record_function_start()
self.writeline(f"{buf_name} = {python_kernel_name}({', '.join(get_args())})")
self.generate_record_function_end(record_func_var)
def generate(self, is_inference):
with dynamo_timed("PythonWrapperCodegen.generate"):
@ -3142,6 +3194,8 @@ class PythonWrapperCodegen(CodeGen):
self.writeline(
f"{self.comment} [Provenance debug handles] {kernel_name}:{debug_handle}"
)
if self.set_current_kernel_debug_handle:
self.current_kernel_debug_handle = (kernel_name, debug_handle)
def make_buffer_reuse(self, old: BufferLike, new: BufferLike, delete_old: bool):
assert old.get_dtype() == new.get_dtype()
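
A rough, standalone sketch of the wrapper code the hunks above emit around a fallback kernel when enrich_profiler_metadata is enabled; the kernel name, debug handle, and the aten call are illustrative stand-ins rather than real generated output.

import torch

x = torch.randn(4, 4)
# generate_record_function_start(): open a marker tagged with kernel_name:debug_handle
_rf_enter = torch._C._profiler._RecordFunctionFast("## inductor_kernel:aten.mm.default:0 ##")
_rf_enter.__enter__()
buf0 = torch.ops.aten.mm.default(x, x)  # stands in for the extern kernel call
# generate_record_function_end(): close the marker
_rf_enter.__exit__(None, None, None)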

View File

@ -98,6 +98,8 @@ from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymExpr
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.monitor import _WaitCounter
from torch.utils._ordered_set import OrderedSet
from torch.fx.experimental import _config as fx_experimental_config
from .._dynamo.backends.common import aot_autograd
from .._dynamo.exc import ShortenTraceback, SkipFrame
@ -1530,7 +1532,10 @@ class _InProcessFxCompile(FxCompile):
# Dump provenance artifacts for debugging trace
inductor_provenance_tracking_node_mappings = None
inductor_kernel_stack_trace_str = None
if config.trace.provenance_tracking_level != 0:
if (
config.trace.provenance_tracking_level != 0
or fx_experimental_config.enrich_profiler_metadata
):
inductor_provenance_tracking_node_mappings = json.dumps(
torch._inductor.debug.dump_inductor_provenance_info()
)

View File

@ -1179,6 +1179,8 @@ torchinductor_worker_logpath: str = Config(
default="",
)
fallback_by_default: bool = False
# config specific to codegen/cpp.py
class cpp:

View File

@ -1106,7 +1106,7 @@ def set_kernel_post_grad_provenance_tracing(
Returns a unique int debug handler for each call to this function.
"""
if config.trace.provenance_tracking_level == 0:
if config.trace.provenance_tracking_level == 0 and not config.fallback_by_default:
return None
try:

View File

@ -52,8 +52,8 @@ from ..utils import (
decode_device,
get_all_devices,
get_gpu_type,
has_uses_tagged_as,
is_gpu,
is_pointwise_use,
OPTIMUS_EXCLUDE_POST_GRAD,
)
from ..virtualized import V
@ -1511,10 +1511,8 @@ def should_prefer_unfused_addmm(match):
if not is_gpu(inp.meta["val"].device.type):
return False
return has_uses_tagged_as(
match.output_node(),
(torch.Tag.pointwise, torch.Tag.reduction),
)
output = match.output_node()
return all(is_pointwise_use(use) for use in output.users)
@register_graph_pattern(

View File

@ -1628,6 +1628,7 @@ class GraphLowering(torch.fx.Interpreter):
"inductor", "lowerings", lambda: repr(n)
)
)
or (n.op == "call_function" and config.fallback_by_default)
):
debug("fallback_handler")
result = fallback_handler(n.target, add_to_fallback_set=False)(

View File

@ -8079,27 +8079,28 @@ class FallbackKernel(ExternKernelAlloc):
for v, a in zip(args_iter, kernel._schema.arguments)
)
self.codegen_comment(wrapper)
if self.use_runtime_dispatch:
exported_args = self.export_extern_kernel_node()
assert self.python_kernel_name is not None
assert self.op_overload is not None
with wrapper.kernel_debug_handle_context():
self.codegen_comment(wrapper)
if self.use_runtime_dispatch:
exported_args = self.export_extern_kernel_node()
assert self.python_kernel_name is not None
assert self.op_overload is not None
wrapper.generate_fallback_kernel_with_runtime_lookup(
self.get_name(),
self.python_kernel_name,
lambda: [*self.codegen_args(), *self.codegen_kwargs()],
self.op_overload,
exported_args,
# NOTE: [special handling of all_reduce_coalesced_'s return value]
self.outputs if self.outputs else self.mutation_outputs,
)
else:
wrapper.generate_fallback_kernel(self)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
self.codegen_alignment_asserts(wrapper)
self.codegen_memory_tracking(wrapper)
wrapper.generate_fallback_kernel_with_runtime_lookup(
self.get_name(),
self.python_kernel_name,
lambda: [*self.codegen_args(), *self.codegen_kwargs()],
self.op_overload,
exported_args,
# NOTE: [special handling of all_reduce_coalesced_'s return value]
self.outputs if self.outputs else self.mutation_outputs,
)
else:
wrapper.generate_fallback_kernel(self)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
self.codegen_alignment_asserts(wrapper)
self.codegen_memory_tracking(wrapper)
self.codegen_unbacked_symbol_defs(wrapper)

View File

@ -7096,19 +7096,30 @@ def sym_constrain_range(a, min=None, max=None):
@register_lowering(aten.sym_size.int)
def sym_size(a, dim):
val = V.graph.current_node.meta["val"]
if isinstance(val, torch.SymInt):
return val.node.expr
else:
return int(val)
# Note [Can val be an int?]
# ~~~~~~~~~~~~~~~~~~~~~~~~~
# In principle, someone could construct an FX graph where
# a call to size/stride has a val that is a plain int (not
# SymInt). However, we will maintain the invariant that
# this is not possible: if you are constructing an FX graph
# where there is a call to size/stride that returns an
# int, but you KNOW that int must always be a constant,
then you do not need to trace that call at all (and can just
constant-propagate the integer as is).
assert isinstance(val, torch.SymInt), (
f"Expect val to be torch.SymInt but got val={val}"
)
return val.node.expr
@register_lowering(aten.sym_stride.int)
def sym_stride(a, dim):
val = V.graph.current_node.meta["val"]
if isinstance(val, torch.SymInt):
return val.node.expr
else:
return int(val)
# See Note [Can val be an int?]
assert isinstance(val, torch.SymInt), (
f"Expect val to be torch.SymInt but got val={val}"
)
return val.node.expr
@register_lowering(aten.sym_numel)

View File

@ -25,6 +25,8 @@ from __future__ import annotations
import dataclasses
import logging
import os
import random
import string
from functools import partial
from typing import Any, Optional, TYPE_CHECKING, TypeAlias, Union
@ -70,6 +72,10 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
# Used in profiler post-processing to match events
# belonging to the same compiled run
CALL_COMPILED_PREFIX = "Call CompiledFxGraph"
@dataclasses.dataclass
class OutputCode:
@ -612,9 +618,18 @@ class CompiledFxGraph(OutputCode):
try:
# Checking the profiler directly is faster than nullcontext
if torch.autograd.profiler._is_profiler_enabled:
with record_function(
f"## Call CompiledFxGraph {self._fx_graph_cache_key} ##"
):
# generate a random string to represent this unique run if no cache key
run_key = (
self._fx_graph_cache_key
if self._fx_graph_cache_key
else "".join(random.choices(string.ascii_lowercase, k=51))
)
run_name = f"{CALL_COMPILED_PREFIX} {run_key}"
if self.inductor_provenance_stack_traces_str:
torch.fx.traceback._register_fx_metadata(
run_name, self.inductor_provenance_stack_traces_str
)
with record_function(f"## {run_name} ##"):
return self.current_callable(inputs)
else:
return self.current_callable(inputs)
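
A standalone sketch, with illustrative values, of how the hunk above assembles the profiling marker name for one compiled-graph run when no FX graph cache key is available.

import random
import string

CALL_COMPILED_PREFIX = "Call CompiledFxGraph"

fx_graph_cache_key = None  # e.g. caching disabled for this run
run_key = fx_graph_cache_key or "".join(random.choices(string.ascii_lowercase, k=51))
run_name = f"{CALL_COMPILED_PREFIX} {run_key}"
print(f"## {run_name} ##")  # the name passed to record_function()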

View File

@ -549,70 +549,6 @@ def is_pointwise_use(
return torch.Tag.pointwise in target.tags or is_pointwise_fn(target)
class LogicalConnective(enum.Enum):
OR = enum.auto()
AND = enum.auto()
def has_uses(
target: Node,
use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
use_aggregate_type: LogicalConnective = LogicalConnective.OR,
) -> bool:
"""
Given a target, explore the uses of `target` by applying `use_selector_fn`
on them, and then aggregate these booleans with the `use_aggregate_type`
logical connective.
Uses in view ops will follow the views uses.
"""
def get_use_aggregate_fn(
use_aggregate_type: LogicalConnective,
) -> Callable[[Iterator[Any]], bool]:
match use_aggregate_type:
case LogicalConnective.AND:
return all
case LogicalConnective.OR:
return any
case _:
return any
use_aggregate_fn = get_use_aggregate_fn(use_aggregate_type)
def has_uses_impl(use: Node) -> bool:
if use.op != "call_function":
return False
if not (
isinstance(use.target, torch._ops.OpOverload)
or use.target is operator.getitem
):
return False
target = cast(torch._ops.OpOverload, use.target)
# Process getitem and view
if target is operator.getitem or is_view(target):
return use_aggregate_fn(has_uses_impl(user) for user in use.users)
return use_selector_fn(target)
return use_aggregate_fn(has_uses_impl(user) for user in target.users)
def has_uses_tagged_as(
target: Node,
use_tags: Collection[torch.Tag],
use_aggregate_type: LogicalConnective = LogicalConnective.OR,
) -> bool:
"""
Is there a use with given tags?
"""
return has_uses(
target, lambda use: any(tag in use_tags for tag in use.tags), use_aggregate_type
)
def gen_gm_and_inputs(
target: Any, args: list[Any], kwargs: dict[str, Any]
) -> tuple[GraphModule, list[torch.Tensor]]:

View File

@ -1224,43 +1224,3 @@ def _build_table(
f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}"
)
return "".join(result)
# Collect all events with stack traces and format them canonically
def _canonicalize_profiler_events(events):
"""
Extract and format all events with stack traces in a canonical way
for deterministic testing.
"""
events_with_traces = []
for event in events:
# Extract relevant fields
event_name = event.get("name", "")
node_name = event["args"].get("node_name", "")
stack_trace = event["args"].get("stack_trace", "")
# Get the last non-empty line of the stack trace
lines = [s.strip() for s in stack_trace.split("\n") if s.strip()]
stack_trace = lines[-1] if lines else ""
events_with_traces.append(
{
"event_name": event_name[:20],
"node_name": node_name,
"stack_trace": stack_trace,
"start_time": event.get("ts", 0),
}
)
# Sort by node_name for deterministic ordering
events_with_traces.sort(key=lambda x: x["start_time"])
# Format as a string
lines: list[str] = []
for evt in events_with_traces:
lines.append(
f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
)
return "\n".join(lines)

View File

@ -33,14 +33,6 @@ static inline int PyCode_GetNFreevars(PyCodeObject* code) {
#endif
}
// Provided by CPython but getting the header for them is very hard
#if IS_PYTHON_3_11_PLUS
// NOLINTNEXTLINE(readability-redundant-declaration)
PyAPI_FUNC(void) _PyWeakref_ClearRef(PyWeakReference* self);
#else
extern void _PyWeakref_ClearRef(PyWeakReference* self);
#endif
#ifdef __cplusplus
}
#endif

File diff suppressed because it is too large.

View File

@ -43,10 +43,10 @@ should_preserve_node_meta = False
# =============================================================================
# Global in-memory registry for FX metadata
# Maps module_name -> metadata dict containing lineno_map and node_metadata
_FX_METADATA_REGISTRY: dict[str, dict[str, Any]] = {}
_FX_METADATA_REGISTRY: dict[str, str | dict[str, Any]] = {}
def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None:
def _register_fx_metadata(module_name: str, metadata: str | dict[str, Any]) -> None:
"""
Register FX metadata in the global in-memory registry.
@ -55,7 +55,7 @@ def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None:
Args:
module_name: The module identifier (content-addressed filename)
metadata: Metadata dict containing lineno_map, node_metadata, and source_code
metadata: Metadata dict containing lineno_map, node_metadata, and source_code, or a JSON string that can be loaded into such a dict.
"""
# TODO: add logging to tlparse
_FX_METADATA_REGISTRY[module_name] = metadata

View File

@ -1,9 +1,11 @@
# mypy: allow-untyped-defs
import functools
import json
import operator
import re
from collections import deque
from dataclasses import dataclass
from enum import Enum
from typing import Any, Literal, Optional, TYPE_CHECKING
from torch.autograd.profiler import profile
@ -402,13 +404,31 @@ def _init_for_cuda_graphs() -> None:
pass
class ContextType(Enum):
"""Types of contexts in the profiler stack."""
FX_GRAPH = "filename"
FX_NODE = "node"
COMPILED_GRAPH = "compiled_graph"
INDUCTOR_NODE = "inductor_node"
def get_parent_context_type(context_type: ContextType) -> Optional[ContextType]:
if context_type == ContextType.FX_NODE:
return ContextType.FX_GRAPH
elif context_type == ContextType.INDUCTOR_NODE:
return ContextType.COMPILED_GRAPH
else:
return None
@dataclass
class TimelineEvent:
"""Represents an event in the profiler timeline."""
timestamp: int
event_type: Literal["start", "end", "regular"]
marker_type: Optional[Literal["filename", "node"]]
marker_type: Optional[ContextType]
identifier: Optional[str | int]
event: dict[str, Any]
@ -417,7 +437,7 @@ class TimelineEvent:
class ContextStackEntry:
"""Represents a context (filename or node) in the stack."""
context_type: Literal["filename", "node"]
context_type: ContextType
identifier: str | int
metadata: Optional[dict]
tid: Optional[int] = None # Thread ID associated with this context
@ -438,6 +458,8 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
Returns:
Dict mapping recorded event names to their aten operations with added stack traces
"""
from torch._inductor.output_code import CALL_COMPILED_PREFIX
from torch.fx.graph_module import FX_GRAPH_MODULE_FILE_PREFIX
from torch.fx.traceback import _FX_METADATA_REGISTRY
trace_events = traced_data.get("traceEvents", [])
@ -447,7 +469,7 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
def is_fx_marker_event(event):
return (
event.get("cat") == "cpu_op"
event.get("cat") in ("cpu_op", "user_annotation")
and event.get("name", "").startswith("## ")
and event.get("name", "").endswith(" ##")
)
@ -469,14 +491,27 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
if is_fx_marker_event(event):
content = event["name"][3:-3]
if content.endswith(".py"):
append_fx_marker_event("filename", content, event)
# Try different event types
if content.startswith(FX_GRAPH_MODULE_FILE_PREFIX) and content.endswith(
".py"
):
# FX graph event
append_fx_marker_event(ContextType.FX_GRAPH, content, event)
elif content.startswith(CALL_COMPILED_PREFIX):
# Inductor compiled graph event
append_fx_marker_event(ContextType.COMPILED_GRAPH, content, event)
elif content.startswith("inductor_kernel:"):
append_fx_marker_event(
ContextType.INDUCTOR_NODE, content[len("inductor_kernel:") :], event
)
else:
# Try to parse as node index for FX graph
# TODO: change to start with fx_node
try:
node_index = int(content)
append_fx_marker_event(ContextType.FX_NODE, node_index, event)
except ValueError:
pass
append_fx_marker_event("node", node_index, event) # type: ignore[possibly-undefined]
else:
# Regular event that needs augmentation
@ -495,23 +530,37 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
case "start":
assert timeline_event.identifier is not None
if timeline_event.marker_type == "filename":
if timeline_event.marker_type in (
ContextType.FX_GRAPH,
ContextType.COMPILED_GRAPH,
):
assert isinstance(timeline_event.identifier, str)
# Push filename context - query metadata registry on-demand
metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier)
tid = timeline_event.event.get("tid")
# TODO: add get method in traceback to try - catch and get
if isinstance(metadata, str):
metadata = json.loads(metadata)
context_stack.append(
ContextStackEntry(
"filename", timeline_event.identifier, metadata, tid
timeline_event.marker_type,
timeline_event.identifier,
metadata,
tid,
)
)
elif timeline_event.marker_type == "node":
elif timeline_event.marker_type in (
ContextType.FX_NODE,
ContextType.INDUCTOR_NODE,
):
# Find the current filename from stack
current_file_metadata = None
tid = timeline_event.event.get("tid")
parent_type = get_parent_context_type(timeline_event.marker_type)
for ctx_entry in reversed(context_stack):
if (
ctx_entry.context_type == "filename"
ctx_entry.context_type == parent_type
and ctx_entry.tid == tid
):
current_file_metadata = ctx_entry.metadata
@ -520,14 +569,39 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
if current_file_metadata:
node_metadata = current_file_metadata.get("node_metadata", {})
if timeline_event.identifier in node_metadata:
node_meta: Optional[dict] = node_metadata[
timeline_event.identifier
]
context_stack.append(
ContextStackEntry(
"node", timeline_event.identifier, node_meta, tid
if ctx_entry.context_type == ContextType.FX_NODE:
node_meta: Optional[dict] = node_metadata[
timeline_event.identifier
]
context_stack.append(
ContextStackEntry(
ContextType.FX_NODE,
timeline_event.identifier,
node_meta,
tid,
)
)
if timeline_event.marker_type == ContextType.INDUCTOR_NODE:
# Look up stack traces for this kernel
# TODO: make a dictionary that maps from compiled key to stack traces dictionary
stack_traces = current_file_metadata.get(
timeline_event.identifier, []
)
if stack_traces:
# Store all stack traces as metadata
node_meta: Optional[dict] = {
"stack_trace": stack_traces,
"name": timeline_event.identifier,
}
context_stack.append(
ContextStackEntry(
ContextType.INDUCTOR_NODE,
timeline_event.identifier,
node_meta,
tid,
)
)
case "end":
# Pop from stack - search backwards to find matching context
@ -551,7 +625,10 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
for ctx_entry in reversed(context_stack):
# Only apply metadata from contexts with matching tid
if ctx_entry.tid == event_tid:
if ctx_entry.context_type == "node" and ctx_entry.metadata:
if (
ctx_entry.context_type == ContextType.FX_NODE
and ctx_entry.metadata
):
current_stack_trace = ctx_entry.metadata.get(
"stack_trace", "No model stack trace available"
)
@ -559,6 +636,19 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
# Do we want to only attach the stack trace of the lowest node or stack trace of all nodes
# if nodes are nested, e.g. in nested graph modules
break
elif (
ctx_entry.context_type == ContextType.INDUCTOR_NODE
and ctx_entry.metadata
):
# For inductor nodes, stack_trace is a list of traces
stack_traces_list = ctx_entry.metadata.get(
"stack_trace", []
)
if stack_traces_list:
# Store as a list - each trace gets its own entry
current_stack_trace = stack_traces_list
current_node_name = ctx_entry.metadata.get("name", "")
break
# Augment the event
if current_stack_trace or current_node_name:
@ -567,3 +657,81 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
args["stack_trace"] = current_stack_trace
if current_node_name:
args["node_name"] = current_node_name
import tempfile
import os
# Collect all events with stack traces and format them canonically
def _canonicalize_profiler_events(events):
"""
Extract and format all events with stack traces in a canonical way
for deterministic testing.
"""
events_with_traces = []
for event in events:
# Extract relevant fields
event_name = event.get("name", "")
node_name = event["args"].get("node_name", "")
stack_trace = event["args"].get("stack_trace", "")
if isinstance(stack_trace, list):
stack_trace = "\n".join(stack_trace)
# Get the last non-empty line of the stack trace
lines = [s.strip() for s in stack_trace.split("\n") if s.strip()]
stack_trace = lines[-1] if lines else ""
events_with_traces.append(
{
"event_name": event_name[:20],
"node_name": node_name,
"stack_trace": stack_trace,
"start_time": event.get("ts", 0),
}
)
# Sort by node_name for deterministic ordering
events_with_traces.sort(key=lambda x: x["start_time"])
# Format as a string
lines: list[str] = []
for evt in events_with_traces:
lines.append(
f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
)
return "\n".join(lines)
def _enrich_profiler_traces(prof):
"""
Helper function to extract and augment profiler events with stack traces.
Args:
prof: A torch.profiler.profile object
Returns:
A string representing enriched events
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
trace_file = f.name
try:
prof.export_chrome_trace(trace_file)
with open(trace_file) as f:
trace_data = json.load(f)
map_recorded_events_to_aten_ops_with_stack_trace(trace_data)
events = []
for event in trace_data["traceEvents"]:
if "args" in event and "stack_trace" in event["args"]:
events.append(event)
actual_traces = _canonicalize_profiler_events(events)
return actual_traces
finally:
if os.path.exists(trace_file):
os.remove(trace_file)
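
To close, a minimal usage sketch of _enrich_profiler_traces that mirrors the test added earlier in this compare; it assumes a CUDA build, that the enrich_profiler_metadata config patch also works as a context manager, and the compiled function is an illustrative stand-in.

import torch
from torch.profiler import profile, ProfilerActivity
from torch.profiler._utils import _enrich_profiler_traces

with torch.fx.experimental._config.patch("enrich_profiler_metadata", True):
    compiled = torch.compile(lambda c: c * 2 + 1, fullgraph=True)
    c = torch.randn(64, 32, device="cuda")
    compiled(c)  # warm up / compile outside the profiled region
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        compiled(c)

# prints one "event=<name> node=<node> stack_trace=<last line>" entry per enriched event
print(_enrich_profiler_traces(prof))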