mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-12 14:54:55 +08:00

Compare commits: gh/laithsa...annotate_f (9 commits)
| SHA1 |
|---|
| 2056d7fa22 |
| 4b9ba0fb26 |
| 106d34c80a |
| 0b06109412 |
| 2073af5790 |
| 9b4ac45d2f |
| a45a17f65e |
| c5593e75b3 |
| c90a976370 |
@@ -73,6 +73,19 @@ void box_cox_zero_lambda(
}
}

template <typename T>
at::vec::Vectorized<T> box_cox_nonzero_lambda_impl(
at::vec::Vectorized<T> data,
at::vec::Vectorized<T> lambda1,
at::vec::Vectorized<T> lambda2,
at::vec::Vectorized<T> k_eps) {
auto sum = data + lambda2;
auto max = at::vec::max(sum, k_eps);
auto lambda_over_1 = at::vec::fast_recieprocal(lambda1);
auto pow = max.pow(lambda1);
return at::vec::fmsub(pow, lambda_over_1, lambda_over_1);
}

template <typename T>
void box_cox_nonzero_lambda(
int64_t D,

@@ -88,21 +101,18 @@ void box_cox_nonzero_lambda(
auto k_eps_vec = Vec(k_eps);
for(; j + VLEN < D; j += VLEN) {
auto data = Vec::loadu(data_ptr + j);
auto lambda2 = Vec::loadu(lambda2_ptr + j);
auto sum = data + lambda2;
auto max = at::vec::max(sum, k_eps_vec);
auto lambda1 = Vec::loadu(lambda1_ptr + j);
auto lambda_over_1 = at::vec::fast_recieprocal(lambda1);
auto pow = max.pow(lambda1);
auto res = at::vec::fmsub(pow, lambda_over_1, lambda_over_1);
auto lambda2 = Vec::loadu(lambda2_ptr + j);
auto res = box_cox_nonzero_lambda_impl(data, lambda1, lambda2, k_eps_vec);
res.store(out + j);
}
for ( ;j < D; ++j) {
auto sum = data_ptr[j] + lambda2_ptr[j];
auto max = std::max(sum, k_eps);
auto lambda_over_1 = at::vec::fast_recieprocal(lambda1_ptr[j]);
auto pow = std::pow(max, lambda1_ptr[j]);
out[j] = pow * lambda_over_1 - lambda_over_1;
if (j < D) {
auto remaining = D - j;
auto data = Vec::loadu(data_ptr + j, remaining);
auto lambda1 = Vec::loadu(lambda1_ptr + j, remaining);
auto lambda2 = Vec::loadu(lambda2_ptr + j, remaining);
auto res = box_cox_nonzero_lambda_impl(data, lambda1, lambda2, k_eps_vec);
res.store(out + j, remaining);
}
}
#else
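For reference, the non-zero-lambda path above is the standard Box-Cox transform applied elementwise to the shifted, clamped input. A minimal scalar sketch (NumPy, illustrative only; the function name is made up and this is not the SIMD kernel itself):

```python
import numpy as np

def box_cox_nonzero_lambda_ref(data, lambda1, lambda2, k_eps=1e-6):
    # Mirrors the vectorized kernel:
    #   out = (max(data + lambda2, k_eps) ** lambda1 - 1) / lambda1
    shifted = np.maximum(data + lambda2, k_eps)
    inv_lambda1 = 1.0 / lambda1          # fast_recieprocal() in the SIMD path
    # fmsub(pow, inv, inv) == pow * inv - inv == (pow - 1) * inv
    return np.power(shifted, lambda1) * inv_lambda1 - inv_lambda1
```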
@@ -74,13 +74,13 @@ class TestStreams(torch._dynamo.test_case.TestCase):
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': None}
# Annotation: {'stream': 0}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)

# Annotation: {'stream': None}
# Annotation: {'stream': 1}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None

# Annotation: {'stream': None}
# Annotation: {'stream': 1}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
return (add_3,)

@@ -196,6 +196,7 @@ class <lambda>(torch.nn.Module):
s_exp = fn(*inp)
self.assertEqual(s_act, s_exp)

@requires_cuda
def test_nested_stream_enter_exit(self):
def fn(x, y, s0, s1, s2):
with s1:

@@ -229,13 +230,13 @@ class <lambda>(torch.nn.Module):
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': None}
# Annotation: {'stream': 1}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)

# Annotation: {'stream': None}
# Annotation: {'stream': 2}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None

# Annotation: {'stream': None}
# Annotation: {'stream': 1}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
return (add_1, add_2)
""",

@@ -249,6 +250,7 @@ class <lambda>(torch.nn.Module):
def test_nested_stream_enter_exit_graph_break(self):
pass

@requires_cuda
def test_local_stream_enter_exit(self):
def fn(x, y):
s2 = torch.Stream()

@@ -289,6 +291,7 @@ class <lambda>(torch.nn.Module):
""",
)

@requires_cuda
def test_local_stream_nested_enter_exit(self):
def fn(x, y):
s2 = torch.Stream()

@@ -331,6 +334,7 @@ class <lambda>(torch.nn.Module):
""",
)

@requires_cuda
def test_stream_with_mutation(self):
def fn(x, y):
s2 = torch.Stream()

@@ -380,6 +384,7 @@ class <lambda>(torch.nn.Module):
""",
)

@requires_cuda
def test_stream_backward(self) -> None:
def fn(x, y):
s2 = torch.Stream()
@@ -500,13 +500,8 @@ class PaddingTest(TestCaseBase):
forward_wrapper = wrapper_codes[0]

# make sure the load for softmax is aligned
if bias:
# addmm -> mm + bias and bias is fused with softmax
softmax_load_str = "tl.load(in_out_ptr0 + (r0_1 + 30528*x0)"
else:
softmax_load_str = "tl.load(in_ptr0 + (r0_1 + 30528*x0)"
self.assertTrue(
softmax_load_str in forward_wrapper,
"tl.load(in_ptr0 + (r0_1 + 30528*x0)" in forward_wrapper,
f"forward_wrapper: {forward_wrapper}",
)
@@ -27,9 +27,9 @@ from torch._inductor.fx_passes.post_grad import post_grad_passes
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code, run_and_get_cpp_code
from torch._inductor.virtualized import V
from torch.testing._internal.common_utils import IS_MACOS
from torch.testing._internal.common_utils import IS_MACOS, skipIfRocm
from torch.testing._internal.triton_utils import requires_cuda_and_triton

from torch.profiler import profile, ProfilerActivity

try:
from .test_aot_inductor_utils import AOTIRunnerUtil

@@ -941,5 +941,63 @@ copy_tests(
)

from torch.profiler._utils import _enrich_profiler_traces


class TestProfilerStackTraceAugmentation(TestCase):
"""
Test that profiler events are correctly augmented with stack traces
from both FX metadata and inductor kernel stack traces.
"""

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
@config.patch("fallback_by_default", True) # TODO update the config patch to inductor lite mode
@torch.compiler.config.patch("force_disable_caches", True)
def test_profiler_inductor_stack_trace_augmentation(self):
"""
Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
augments profiler events with stack traces from inductor kernel metadata.
"""

# Test model similar to test.py
class TestModel(torch.nn.Module):
def forward(self, c):
d = c * 2
d = d + 1
return d

device = "cuda"
model = TestModel().to(device)
c = torch.randn((64, 32), device=device)

# Force disable caches to ensure fresh compilation
torch.compiler.config.force_disable_caches = True

# Compile the model
compiled_model = torch.compile(model, fullgraph=True)

# Warmup
for _ in range(3):
_ = compiled_model(c)

# Profile with the compiled model
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
compiled_model(c)

actual_traces = _enrich_profiler_traces(prof)

self.assertExpectedInline(actual_traces, """\
event=aten::mul node=torch.ops.aten.mul.Tensor:1 stack_trace=d = c * 2
event=cudaLaunchKernel node=torch.ops.aten.mul.Tensor:1 stack_trace=d = c * 2
event=aten::add node=torch.ops.aten.add.Tensor:2 stack_trace=d = d + 1
event=cudaLaunchKernel node=torch.ops.aten.add.Tensor:2 stack_trace=d = d + 1""")

# TODO: add test that when enrich is not turned on there is no recordfast generated.


if __name__ == "__main__":
run_tests()
@@ -15310,7 +15310,7 @@ if RUN_GPU:
),
(
fn3,
"triton_poi_fused_addmm_native_layer_norm",
"triton_poi_fused_native_layer_norm_relu",
(torch.randn(4, 4, device=GPU_TYPE),),
),
]

@@ -15323,7 +15323,7 @@
),
(
fn3,
"triton_poi_fused_LayerNorm_Linear_ReLU",
"triton_poi_fused_LayerNorm_ReLU",
(torch.randn(4, 4, device=GPU_TYPE),),
),
]
@@ -7508,6 +7508,8 @@ class TestFXMemoryProfiler(TestCase):
device = "cuda"
mod = MLPModule(device)
with tempfile.TemporaryDirectory() as tmpdir:
# reset cache to start fresh
torch.cuda.memory.empty_cache()
torch.cuda.memory._record_memory_history()
compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
result = compiled(torch.randn(10, 10, device=device))

@@ -7518,10 +7520,7 @@
torch.cuda.empty_cache()

fx_frames = self.collect_frames(augmented_snapshot)
if TEST_WITH_ROCM:
self.assertGreater(len(fx_frames), 0)
else:
self.assertEqual(len(fx_frames), 12)
self.assertGreater(len(fx_frames), 2)

for frame in fx_frames:
# Every FX frame should have both node_op and node_name
@@ -76,11 +76,8 @@ from torch.testing._internal.common_utils import (
)
from torch.testing._internal.jit_utils import JitTestCase

import json
import tempfile
from torch.profiler import profile, ProfilerActivity
from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace
from torch.autograd.profiler_util import _canonicalize_profiler_events
from torch.profiler._utils import _enrich_profiler_traces

try:
from torchvision import models as torchvision_models

@@ -208,36 +205,6 @@ def side_effect_func(x: torch.Tensor):
print(x)


def _enrich_profiler_traces(prof):
"""
Helper function to extract and augment profiler events with stack traces.

Args:
prof: A torch.profiler.profile object

Returns:
A string representing enriched events
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f:
trace_file = f.name
prof.export_chrome_trace(trace_file)

with open(trace_file) as f:
trace_data = json.load(f)

map_recorded_events_to_aten_ops_with_stack_trace(
trace_data
)

events = []
for event in trace_data["traceEvents"]:
if "args" in event and "stack_trace" in event["args"]:
events.append(event)

actual_traces = _canonicalize_profiler_events(events)
return actual_traces


class TestFX(JitTestCase):
def setUp(self):
super().setUp()
@@ -1542,7 +1542,7 @@ class OutputGraph(OutputGraphCommon):
)
)
tmp_vars = []
for constructor in reversed(index_to_bytecode_constructor.values()):
for constructor in index_to_bytecode_constructor.values():
constructor(codegen)
var_name = (
self.new_var()
@@ -1061,9 +1061,7 @@ class VariableBuilder:
)
set_example_value(stream_proxy.node, value)
var = StreamVariable(
stream_proxy,
value,
source=self.source,
stream_proxy, value, source=self.source, user_object_index=index
)
return self.tx.output.side_effects.track_object_existing(value, var)
elif isinstance(value, (torch._C._SDPAParams)):

@@ -3006,14 +3004,16 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
return SymNodeVariable(proxy, example_value, **options)
elif (
isinstance(example_value, torch.Stream)
and proxy.node.target
in (get_external_object_by_index, torch.accelerator.current_stream)
and proxy.node.target == get_external_object_by_index
) or proxy.node.target in [
device_interface.current_stream
for _, device_interface in get_registered_device_interfaces()
]:
set_example_value(proxy.node, example_value)
return StreamVariable(proxy, example_value, **options)
index = None
if proxy.node.target == get_external_object_by_index:
index = proxy.node.args[0]
return StreamVariable(proxy, example_value, index, **options)
elif (
inspect.isclass(proxy.node.target)
and issubclass(proxy.node.target, torch.Event)
@@ -204,11 +204,11 @@ class StreamVariable(StreamContextVariable):
self,
proxy: Proxy,
value: torch.Stream,
user_object_index: Optional[int] = None,
**kwargs: Any,
) -> None:
# Index into the user object table
# used to pass arbitrary objects to the graph
user_object_index = kwargs.pop("user_obj_index", None)
if proxy is not None and "example_value" in proxy.node.meta:
assert proxy.node.meta["example_value"] == value

@@ -300,7 +300,7 @@ class StreamVariable(StreamContextVariable):
codegen.append_output(codegen.create_load_const(self.user_object_index))
codegen.extend_output(create_call_function(1, False))
else:
# TODO mlazos: evaluate if we still need this
# This will support the legacy behavior
prefix = f"_stream_{self.device}"
name = codegen.tx.output.install_global_by_id(prefix, self.value)
codegen.append_output(codegen.create_load_global(name, add=True))
@@ -838,7 +838,6 @@ class UserDefinedClassVariable(UserDefinedVariable):
proxy=tx.output.create_proxy(
"call_function", get_external_object_by_index, (ind,), {}
),
user_obj_index=ind,
)
else:
tensor_variable = wrap_fx_proxy(
@@ -70,7 +70,7 @@ from .common import (
)
from .cpp_utils import cexpr
from .triton_utils import config_of, should_unwrap_unspec_arg, signature_to_meta

from torch.fx.experimental import _config as fx_experimental_config

if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
@@ -1120,6 +1120,37 @@ class PythonWrapperCodegen(CodeGen):
# Additional files that are dependent to the wrapper (ex. cubin files)
self.additional_files = []

# This is used to emit RecordFunctionFast markers that can be matched
# with profiler traces for provenance tracking.
#
# Stores the (kernel_name, debug_handle) tuple
# for the currently being generated kernel.
self.current_kernel_debug_handle: Optional[tuple[str, int]] = None

# set_current_kernel_debug_handle: Flag that controls whether
# write_provenance_debug_handle() should update current_kernel_debug_handle.
# This flag is automatically managed by kernel_debug_handle_context().
self.set_current_kernel_debug_handle: bool = False

@contextlib.contextmanager
def kernel_debug_handle_context(self):
"""
Context manager for kernel debug handle tracking.

self.current_kernel_debug_handle can be updated within the context
with wrapper.write_provenance_debug_handle
and it will be reset after the context
"""
old_flag_value = self.set_current_kernel_debug_handle
old_handle_value = self.current_kernel_debug_handle
self.set_current_kernel_debug_handle = True

try:
yield
finally:
self.set_current_kernel_debug_handle = old_flag_value
self.current_kernel_debug_handle = old_handle_value

@staticmethod
def create(
is_subgraph: bool,
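A short sketch of the intended flow, based on the docstring above and the ir.py hunk further down (hedged; `kernel_name` and `debug_handle` are illustrative):

```python
# Inside code that emits one fallback/extern kernel:
with wrapper.kernel_debug_handle_context():
    # While the flag is set, this call also records the handle of the kernel
    # being generated (see write_provenance_debug_handle in the hunk below).
    wrapper.write_provenance_debug_handle(kernel_name, debug_handle)
    # Codegen running here can read wrapper.current_kernel_debug_handle,
    # e.g. generate_record_function_start() in the next hunk.
# On exit, both the flag and the previous handle value are restored.
```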
@@ -1510,8 +1541,27 @@ class PythonWrapperCodegen(CodeGen):
def generate_end(self, result: IndentedBuffer) -> None:
return

def generate_record_function_start(self) -> Optional[str]:
record_func = self.current_kernel_debug_handle and fx_experimental_config.enrich_profiler_metadata
if record_func:
assert self.current_kernel_debug_handle
kernel_name, debug_handle = self.current_kernel_debug_handle
kernel_debug_handle = f"{kernel_name}:{debug_handle}"
self.writeline(
f"_rf_enter = torch._C._profiler._RecordFunctionFast('## inductor_kernel:{kernel_debug_handle} ##'); _rf_enter.__enter__()"
)
return "_rf_enter"
else:
return None

def generate_record_function_end(self, record_func_var: Optional[str]):
if record_func_var:
self.writeline(f"{record_func_var}.__exit__(None, None, None)")

def generate_fallback_kernel(self, node: ir.FallbackKernel) -> None:
record_func_var = self.generate_record_function_start()
self.writeline(ExternKernelAllocLine(self, node))
self.generate_record_function_end(record_func_var)

def generate_extern_kernel_alloc(self, node: ir.ExternKernelAlloc):
node.codegen_comment(self)
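With enrich_profiler_metadata enabled and a debug handle recorded, the generated wrapper code for a fallback kernel therefore looks roughly like this (kernel name, handle, and the middle line are illustrative):

```python
# Emitted by generate_record_function_start():
_rf_enter = torch._C._profiler._RecordFunctionFast('## inductor_kernel:buf0_kernel:3 ##'); _rf_enter.__enter__()
# ... the ExternKernelAllocLine / extern kernel call for this node goes here ...
# Emitted by generate_record_function_end():
_rf_enter.__exit__(None, None, None)
```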
@@ -1671,7 +1721,9 @@ class PythonWrapperCodegen(CodeGen):
raw_args: Sequence[Any],
outputs: Sequence[ir.Buffer],
) -> None:
record_func_var = self.generate_record_function_start()
self.writeline(f"{buf_name} = {python_kernel_name}({', '.join(get_args())})")
self.generate_record_function_end(record_func_var)

def generate(self, is_inference):
with dynamo_timed("PythonWrapperCodegen.generate"):

@@ -3142,6 +3194,8 @@ class PythonWrapperCodegen(CodeGen):
self.writeline(
f"{self.comment} [Provenance debug handles] {kernel_name}:{debug_handle}"
)
if self.set_current_kernel_debug_handle:
self.current_kernel_debug_handle = (kernel_name, debug_handle)

def make_buffer_reuse(self, old: BufferLike, new: BufferLike, delete_old: bool):
assert old.get_dtype() == new.get_dtype()
@@ -98,6 +98,8 @@ from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymExpr
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.monitor import _WaitCounter
from torch.utils._ordered_set import OrderedSet
from torch.fx.experimental import _config as fx_experimental_config


from .._dynamo.backends.common import aot_autograd
from .._dynamo.exc import ShortenTraceback, SkipFrame

@@ -1530,7 +1532,10 @@ class _InProcessFxCompile(FxCompile):
# Dump provenance artifacts for debugging trace
inductor_provenance_tracking_node_mappings = None
inductor_kernel_stack_trace_str = None
if config.trace.provenance_tracking_level != 0:
if (
config.trace.provenance_tracking_level != 0
or fx_experimental_config.enrich_profiler_metadata
):
inductor_provenance_tracking_node_mappings = json.dumps(
torch._inductor.debug.dump_inductor_provenance_info()
)
@@ -1179,6 +1179,8 @@ torchinductor_worker_logpath: str = Config(
default="",
)

fallback_by_default: bool = False


# config specific to codegen/cpp.py
class cpp:

@@ -1106,7 +1106,7 @@ def set_kernel_post_grad_provenance_tracing(
Returns a unique int debug handler for each call to this function.
"""

if config.trace.provenance_tracking_level == 0:
if config.trace.provenance_tracking_level == 0 and not config.fallback_by_default:
return None

try:
@@ -52,8 +52,8 @@ from ..utils import (
decode_device,
get_all_devices,
get_gpu_type,
has_uses_tagged_as,
is_gpu,
is_pointwise_use,
OPTIMUS_EXCLUDE_POST_GRAD,
)
from ..virtualized import V

@@ -1511,10 +1511,8 @@ def should_prefer_unfused_addmm(match):
if not is_gpu(inp.meta["val"].device.type):
return False

return has_uses_tagged_as(
match.output_node(),
(torch.Tag.pointwise, torch.Tag.reduction),
)
output = match.output_node()
return all(is_pointwise_use(use) for use in output.users)


@register_graph_pattern(
@@ -1628,6 +1628,7 @@ class GraphLowering(torch.fx.Interpreter):
"inductor", "lowerings", lambda: repr(n)
)
)
or (n.op == "call_function" and config.fallback_by_default)
):
debug("fallback_handler")
result = fallback_handler(n.target, add_to_fallback_set=False)(
@@ -8079,27 +8079,28 @@ class FallbackKernel(ExternKernelAlloc):
for v, a in zip(args_iter, kernel._schema.arguments)
)

self.codegen_comment(wrapper)
if self.use_runtime_dispatch:
exported_args = self.export_extern_kernel_node()
assert self.python_kernel_name is not None
assert self.op_overload is not None
with wrapper.kernel_debug_handle_context():
self.codegen_comment(wrapper)
if self.use_runtime_dispatch:
exported_args = self.export_extern_kernel_node()
assert self.python_kernel_name is not None
assert self.op_overload is not None

wrapper.generate_fallback_kernel_with_runtime_lookup(
self.get_name(),
self.python_kernel_name,
lambda: [*self.codegen_args(), *self.codegen_kwargs()],
self.op_overload,
exported_args,
# NOTE: [special handling of all_reduce_coalesced_'s return value]
self.outputs if self.outputs else self.mutation_outputs,
)
else:
wrapper.generate_fallback_kernel(self)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
self.codegen_alignment_asserts(wrapper)
self.codegen_memory_tracking(wrapper)
wrapper.generate_fallback_kernel_with_runtime_lookup(
self.get_name(),
self.python_kernel_name,
lambda: [*self.codegen_args(), *self.codegen_kwargs()],
self.op_overload,
exported_args,
# NOTE: [special handling of all_reduce_coalesced_'s return value]
self.outputs if self.outputs else self.mutation_outputs,
)
else:
wrapper.generate_fallback_kernel(self)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
self.codegen_alignment_asserts(wrapper)
self.codegen_memory_tracking(wrapper)

self.codegen_unbacked_symbol_defs(wrapper)
@@ -7096,19 +7096,30 @@ def sym_constrain_range(a, min=None, max=None):
@register_lowering(aten.sym_size.int)
def sym_size(a, dim):
val = V.graph.current_node.meta["val"]
if isinstance(val, torch.SymInt):
return val.node.expr
else:
return int(val)
# Note [Can val be an int?]
# ~~~~~~~~~~~~~~~~~~~~~~~~~
# In principle, someone could construct an FX graph where
# a call to size/stride has a val that is a plain int (not
# SymInt). However, we will maintain the invariant that
# this is not possible: if you are constructing an FX graph
# where there is a call to size/stride that returns an
# int, but you KNOW that int must always be a constant,
# then you do not need trace that call at all (and just
# constant propagate the integer as is.)
assert isinstance(val, torch.SymInt), (
f"Expect val to be torch.SymInt but got val={val}"
)
return val.node.expr


@register_lowering(aten.sym_stride.int)
def sym_stride(a, dim):
val = V.graph.current_node.meta["val"]
if isinstance(val, torch.SymInt):
return val.node.expr
else:
return int(val)
# See Note [Can val be an int?]
assert isinstance(val, torch.SymInt), (
f"Expect val to be torch.SymInt but got val={val}"
)
return val.node.expr


@register_lowering(aten.sym_numel)
@@ -25,6 +25,8 @@ from __future__ import annotations
import dataclasses
import logging
import os
import random
import string
from functools import partial
from typing import Any, Optional, TYPE_CHECKING, TypeAlias, Union

@@ -70,6 +72,10 @@ if TYPE_CHECKING:

log = logging.getLogger(__name__)

# Used for profiler post-processing to match
# for the same compiled run
CALL_COMPILED_PREFIX = "Call CompiledFxGraph"


@dataclasses.dataclass
class OutputCode:
@@ -612,9 +618,18 @@ class CompiledFxGraph(OutputCode):
try:
# Checking the profiler directly is faster than nullcontext
if torch.autograd.profiler._is_profiler_enabled:
with record_function(
f"## Call CompiledFxGraph {self._fx_graph_cache_key} ##"
):
# generate a random string to represent this unique run if no cache key
run_key = (
self._fx_graph_cache_key
if self._fx_graph_cache_key
else "".join(random.choices(string.ascii_lowercase, k=51))
)
run_name = f"{CALL_COMPILED_PREFIX} {run_key}"
if self.inductor_provenance_stack_traces_str:
torch.fx.traceback._register_fx_metadata(
run_name, self.inductor_provenance_stack_traces_str
)
with record_function(f"## {run_name} ##"):
return self.current_callable(inputs)
else:
return self.current_callable(inputs)
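So, under the profiler, each run of the compiled graph is bracketed by a marker of the form below, and the same run_name keys the stack-trace metadata that the trace post-processing later looks up (sketch in comments; the example key value is invented):

```python
# Marker wrapped around the compiled call:
#   f"## {CALL_COMPILED_PREFIX} {run_key} ##"
#   e.g. "## Call CompiledFxGraph fvgs4kmn... ##"  (cache key, or a random 51-char key)
# The matching metadata registration, keyed by the same run_name:
#   torch.fx.traceback._register_fx_metadata(run_name, self.inductor_provenance_stack_traces_str)
```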
@@ -549,70 +549,6 @@ def is_pointwise_use(
return torch.Tag.pointwise in target.tags or is_pointwise_fn(target)


class LogicalConnective(enum.Enum):
OR = enum.auto()
AND = enum.auto()


def has_uses(
target: Node,
use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
use_aggregate_type: LogicalConnective = LogicalConnective.OR,
) -> bool:
"""
Given a target, explore the uses of `target` by applying `use_selector_fn`
on them, and then aggregate these booleans with the `use_aggregate_type`
logical connective.

Uses in view ops will follow the views uses.
"""

def get_use_aggregate_fn(
use_aggregate_type: LogicalConnective,
) -> Callable[[Iterator[Any]], bool]:
match use_aggregate_type:
case LogicalConnective.AND:
return all
case LogicalConnective.OR:
return any
case _:
return any

use_aggregate_fn = get_use_aggregate_fn(use_aggregate_type)

def has_uses_impl(use: Node) -> bool:
if use.op != "call_function":
return False
if not (
isinstance(use.target, torch._ops.OpOverload)
or use.target is operator.getitem
):
return False

target = cast(torch._ops.OpOverload, use.target)
# Process getitem and view
if target is operator.getitem or is_view(target):
return use_aggregate_fn(has_uses_impl(user) for user in use.users)

return use_selector_fn(target)

return use_aggregate_fn(has_uses_impl(user) for user in target.users)


def has_uses_tagged_as(
target: Node,
use_tags: Collection[torch.Tag],
use_aggregate_type: LogicalConnective = LogicalConnective.OR,
) -> bool:
"""
Is there a use with given tags?
"""

return has_uses(
target, lambda use: any(tag in use_tags for tag in use.tags), use_aggregate_type
)


def gen_gm_and_inputs(
target: Any, args: list[Any], kwargs: dict[str, Any]
) -> tuple[GraphModule, list[torch.Tensor]]:
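For context, a short usage sketch of has_uses_tagged_as as it appears in the should_prefer_unfused_addmm hunk earlier (the result variable name is illustrative):

```python
# Does the matched addmm output feed, possibly through views/getitem,
# into any op tagged pointwise or reduction? Default aggregation is
# LogicalConnective.OR, i.e. "any use matches".
prefer_unfused = has_uses_tagged_as(
    match.output_node(),
    (torch.Tag.pointwise, torch.Tag.reduction),
)
```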
@@ -1224,43 +1224,3 @@ def _build_table(
f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}"
)
return "".join(result)


# Collect all events with stack traces and format them canonically
def _canonicalize_profiler_events(events):
"""
Extract and format all events with stack traces in a canonical way
for deterministic testing.
"""
events_with_traces = []

for event in events:
# Extract relevant fields
event_name = event.get("name", "")
node_name = event["args"].get("node_name", "")
stack_trace = event["args"].get("stack_trace", "")

# Get the last non-empty line of the stack trace
lines = [s.strip() for s in stack_trace.split("\n") if s.strip()]
stack_trace = lines[-1] if lines else ""

events_with_traces.append(
{
"event_name": event_name[:20],
"node_name": node_name,
"stack_trace": stack_trace,
"start_time": event.get("ts", 0),
}
)

# Sort by node_name for deterministic ordering
events_with_traces.sort(key=lambda x: x["start_time"])

# Format as a string
lines: list[str] = []
for evt in events_with_traces:
lines.append(
f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
)

return "\n".join(lines)
@@ -33,14 +33,6 @@ static inline int PyCode_GetNFreevars(PyCodeObject* code) {
#endif
}

// Provided by CPython but getting the header for them is very hard
#if IS_PYTHON_3_11_PLUS
// NOLINTNEXTLINE(readability-redundant-declaration)
PyAPI_FUNC(void) _PyWeakref_ClearRef(PyWeakReference* self);
#else
extern void _PyWeakref_ClearRef(PyWeakReference* self);
#endif

#ifdef __cplusplus
}
#endif
(File diff suppressed because it is too large.)
@@ -43,10 +43,10 @@ should_preserve_node_meta = False
# =============================================================================
# Global in-memory registry for FX metadata
# Maps module_name -> metadata dict containing lineno_map and node_metadata
_FX_METADATA_REGISTRY: dict[str, dict[str, Any]] = {}
_FX_METADATA_REGISTRY: dict[str, str | dict[str, Any]] = {}


def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None:
def _register_fx_metadata(module_name: str, metadata: str | dict[str, Any]) -> None:
"""
Register FX metadata in the global in-memory registry.

@@ -55,7 +55,7 @@ def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None:

Args:
module_name: The module identifier (content-addressed filename)
metadata: Metadata dict containing lineno_map, node_metadata, and source_code
metadata: Metadata dict containing lineno_map, node_metadata, and source_code. If a str, it's a json dump that can be json loaded as a dict.
"""
# TODO: add logging to tlparse
_FX_METADATA_REGISTRY[module_name] = metadata
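A minimal sketch of the two accepted value shapes, based on the docstring above and the json.loads() call in the profiler utils further down (module names and payloads are invented for illustration):

```python
import json
from torch.fx.traceback import _register_fx_metadata

# dict form, as recorded for an FX graph module:
_register_fx_metadata(
    "some_generated_module.py",  # hypothetical module_name
    {"lineno_map": {}, "node_metadata": {}, "source_code": "..."},
)

# str form, a JSON dump registered for a compiled-graph run; consumers
# json.loads() it on demand when they hit the matching profiler marker:
_register_fx_metadata(
    "Call CompiledFxGraph deadbeef",  # hypothetical run_name
    json.dumps({"some_kernel_name:1": ["d = c * 2"]}),  # assumed payload shape
)
```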
@@ -1,9 +1,11 @@
# mypy: allow-untyped-defs
import functools
import json
import operator
import re
from collections import deque
from dataclasses import dataclass
from enum import Enum
from typing import Any, Literal, Optional, TYPE_CHECKING

from torch.autograd.profiler import profile
@@ -402,13 +404,31 @@ def _init_for_cuda_graphs() -> None:
pass


class ContextType(Enum):
"""Types of contexts in the profiler stack."""

FX_GRAPH = "filename"
FX_NODE = "node"
COMPILED_GRAPH = "compiled_graph"
INDUCTOR_NODE = "inductor_node"


def get_parent_context_type(context_type: ContextType) -> Optional[ContextType]:
if context_type == ContextType.FX_NODE:
return ContextType.FX_GRAPH
elif context_type == ContextType.INDUCTOR_NODE:
return ContextType.COMPILED_GRAPH
else:
return None


@dataclass
class TimelineEvent:
"""Represents an event in the profiler timeline."""

timestamp: int
event_type: Literal["start", "end", "regular"]
marker_type: Optional[Literal["filename", "node"]]
marker_type: Optional[ContextType]
identifier: Optional[str | int]
event: dict[str, Any]

@@ -417,7 +437,7 @@ class TimelineEvent:
class ContextStackEntry:
"""Represents a context (filename or node) in the stack."""

context_type: Literal["filename", "node"]
context_type: ContextType
identifier: str | int
metadata: Optional[dict]
tid: Optional[int] = None  # Thread ID associated with this context
@@ -438,6 +458,8 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
Returns:
Dict mapping recorded event names to their aten operations with added stack traces
"""
from torch._inductor.output_code import CALL_COMPILED_PREFIX
from torch.fx.graph_module import FX_GRAPH_MODULE_FILE_PREFIX
from torch.fx.traceback import _FX_METADATA_REGISTRY

trace_events = traced_data.get("traceEvents", [])

@@ -447,7 +469,7 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):

def is_fx_marker_event(event):
return (
event.get("cat") == "cpu_op"
event.get("cat") in ("cpu_op", "user_annotation")
and event.get("name", "").startswith("## ")
and event.get("name", "").endswith(" ##")
)
@@ -469,14 +491,27 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
if is_fx_marker_event(event):
content = event["name"][3:-3]

if content.endswith(".py"):
append_fx_marker_event("filename", content, event)
# Try different event types
if content.startswith(FX_GRAPH_MODULE_FILE_PREFIX) and content.endswith(
".py"
):
# FX graph event
append_fx_marker_event(ContextType.FX_GRAPH, content, event)
elif content.startswith(CALL_COMPILED_PREFIX):
# Inductor compiled graph event
append_fx_marker_event(ContextType.COMPILED_GRAPH, content, event)
elif content.startswith("inductor_kernel:"):
append_fx_marker_event(
ContextType.INDUCTOR_NODE, content[len("inductor_kernel:") :], event
)
else:
# Try to parse as node index for FX graph
# TODO: change to start with fx_node
try:
node_index = int(content)
append_fx_marker_event(ContextType.FX_NODE, node_index, event)
except ValueError:
pass
append_fx_marker_event("node", node_index, event)  # type: ignore[possibly-undefined]

else:
# Regular event that needs augmentation
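Per this parsing, the `## ... ##` marker contents map to context types roughly as follows (example contents are illustrative; the file prefix is whatever FX_GRAPH_MODULE_FILE_PREFIX expands to):

```python
# "<FX_GRAPH_MODULE_FILE_PREFIX>...<hash>.py"  -> ContextType.FX_GRAPH       (FX graph module file)
# "Call CompiledFxGraph <run_key>"             -> ContextType.COMPILED_GRAPH (CALL_COMPILED_PREFIX)
# "inductor_kernel:<kernel_name>:<handle>"     -> ContextType.INDUCTOR_NODE  (prefix stripped off)
# "3"                                          -> ContextType.FX_NODE        (bare FX node index)
```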
@@ -495,23 +530,37 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
case "start":
assert timeline_event.identifier is not None

if timeline_event.marker_type == "filename":
if timeline_event.marker_type in (
ContextType.FX_GRAPH,
ContextType.COMPILED_GRAPH,
):
assert isinstance(timeline_event.identifier, str)
# Push filename context - query metadata registry on-demand
metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier)
tid = timeline_event.event.get("tid")

# TODO: add get method in traceback to try - catch and get
if isinstance(metadata, str):
metadata = json.loads(metadata)
context_stack.append(
ContextStackEntry(
"filename", timeline_event.identifier, metadata, tid
timeline_event.marker_type,
timeline_event.identifier,
metadata,
tid,
)
)
elif timeline_event.marker_type == "node":
elif timeline_event.marker_type in (
ContextType.FX_NODE,
ContextType.INDUCTOR_NODE,
):
# Find the current filename from stack
current_file_metadata = None
tid = timeline_event.event.get("tid")
parent_type = get_parent_context_type(timeline_event.marker_type)
for ctx_entry in reversed(context_stack):
if (
ctx_entry.context_type == "filename"
ctx_entry.context_type == parent_type
and ctx_entry.tid == tid
):
current_file_metadata = ctx_entry.metadata
@@ -520,14 +569,39 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
if current_file_metadata:
node_metadata = current_file_metadata.get("node_metadata", {})
if timeline_event.identifier in node_metadata:
node_meta: Optional[dict] = node_metadata[
timeline_event.identifier
]
context_stack.append(
ContextStackEntry(
"node", timeline_event.identifier, node_meta, tid
if ctx_entry.context_type == ContextType.FX_NODE:
node_meta: Optional[dict] = node_metadata[
timeline_event.identifier
]
context_stack.append(
ContextStackEntry(
ContextType.FX_NODE,
timeline_event.identifier,
node_meta,
tid,
)
)

if timeline_event.marker_type == ContextType.INDUCTOR_NODE:
# Look up stack traces for this kernel
# TODO: make a dictionary that maps from compiled key to stack traces dictionary
stack_traces = current_file_metadata.get(
timeline_event.identifier, []
)
if stack_traces:
# Store all stack traces as metadata
node_meta: Optional[dict] = {
"stack_trace": stack_traces,
"name": timeline_event.identifier,
}
context_stack.append(
ContextStackEntry(
ContextType.INDUCTOR_NODE,
timeline_event.identifier,
node_meta,
tid,
)
)

case "end":
# Pop from stack - search backwards to find matching context
@@ -551,7 +625,10 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
for ctx_entry in reversed(context_stack):
# Only apply metadata from contexts with matching tid
if ctx_entry.tid == event_tid:
if ctx_entry.context_type == "node" and ctx_entry.metadata:
if (
ctx_entry.context_type == ContextType.FX_NODE
and ctx_entry.metadata
):
current_stack_trace = ctx_entry.metadata.get(
"stack_trace", "No model stack trace available"
)

@@ -559,6 +636,19 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
# Do we want to only attach the stack trace of the lowest node or stack trace of all nodes
# if nodes are nested, e.g. in nested graph modules
break
elif (
ctx_entry.context_type == ContextType.INDUCTOR_NODE
and ctx_entry.metadata
):
# For inductor nodes, stack_trace is a list of traces
stack_traces_list = ctx_entry.metadata.get(
"stack_trace", []
)
if stack_traces_list:
# Store as a list - each trace gets its own entry
current_stack_trace = stack_traces_list
current_node_name = ctx_entry.metadata.get("name", "")
break

# Augment the event
if current_stack_trace or current_node_name:
@@ -567,3 +657,81 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
args["stack_trace"] = current_stack_trace
if current_node_name:
args["node_name"] = current_node_name

import tempfile
import os

# Collect all events with stack traces and format them canonically
def _canonicalize_profiler_events(events):
"""
Extract and format all events with stack traces in a canonical way
for deterministic testing.
"""
events_with_traces = []

for event in events:
# Extract relevant fields
event_name = event.get("name", "")
node_name = event["args"].get("node_name", "")
stack_trace = event["args"].get("stack_trace", "")

if isinstance(stack_trace, list):
stack_trace = "\n".join(stack_trace)

# Get the last non-empty line of the stack trace
lines = [s.strip() for s in stack_trace.split("\n") if s.strip()]
stack_trace = lines[-1] if lines else ""

events_with_traces.append(
{
"event_name": event_name[:20],
"node_name": node_name,
"stack_trace": stack_trace,
"start_time": event.get("ts", 0),
}
)

# Sort by node_name for deterministic ordering
events_with_traces.sort(key=lambda x: x["start_time"])

# Format as a string
lines: list[str] = []
for evt in events_with_traces:
lines.append(
f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
)

return "\n".join(lines)


def _enrich_profiler_traces(prof):
"""
Helper function to extract and augment profiler events with stack traces.

Args:
prof: A torch.profiler.profile object

Returns:
A string representing enriched events
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
trace_file = f.name

try:
prof.export_chrome_trace(trace_file)

with open(trace_file) as f:
trace_data = json.load(f)

map_recorded_events_to_aten_ops_with_stack_trace(trace_data)

events = []
for event in trace_data["traceEvents"]:
if "args" in event and "stack_trace" in event["args"]:
events.append(event)

actual_traces = _canonicalize_profiler_events(events)
return actual_traces
finally:
if os.path.exists(trace_file):
os.remove(trace_file)
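Putting the pieces together, the helper is exercised like this in the inductor test added earlier (a trimmed sketch; `model` and `inp` stand in for the test's module and input):

```python
import torch
from torch.profiler import profile, ProfilerActivity
from torch.profiler._utils import _enrich_profiler_traces

compiled = torch.compile(model, fullgraph=True)
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    compiled(inp)

# One line per augmented event:
#   "event=<event name> node=<node name>:<handle> stack_trace=<last trace line>"
print(_enrich_profiler_traces(prof))
```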