[pt2 event logging] send autotuning data for strides and hinted shapes (#158852)

Summary:
# Why

capture relevant data for offline lookup table generation

# What

report the hinted sizes not just the symbolic sizes

Test Plan:
```
buck2 run mode/opt scripts/coconutruben/torchmm:experiment 2>&1 | tee /tmp/epx040
```

This only validates that this change does not break anything, as the schema is not on scuba yet (not actualized)

Rollback Plan:

Reviewed By: stashuk-olek

Differential Revision: D77837548

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158852
Approved by: https://github.com/jingsh
This commit is contained in:
Ruben Rodriguez Buchillon
2025-07-23 06:44:27 +00:00
committed by PyTorch MergeBot
parent 1d302eaee8
commit 255a04baf1

View File

@ -2337,20 +2337,7 @@ class AlgorithmSelectorCache(PersistentCache):
f"{name}_template_autotuning",
log_pt2_compile_event=True,
dynamo_compile_column_us="compile_time_autotune_time_us",
metadata={
"autotune_strides": ", ".join(
[str(n.get_stride()) for n in input_nodes]
),
"autotune_dtypes": ", ".join(
[str(n.get_dtype()) for n in input_nodes]
),
"autotune_shape": ", ".join(
["x".join(map(str, n.get_size())) for n in input_nodes]
),
"autotune_offset": ", ".join(
[str(n.get_layout().offset) for n in input_nodes]
),
},
metadata=_autotune_metadata(input_nodes),
):
return benchmark(choices, hint_override=hint_override)
@ -3370,5 +3357,44 @@ class SymbolicGridFn:
return self.fn(*args, **kwargs, **self.kwargs_sym)
def _autotune_metadata(input_nodes):
"""Helper function to extract autotune metadata from input nodes."""
return {
"autotune_strides": ", ".join([str(n.get_stride()) for n in input_nodes]),
"autotune_dtypes": ", ".join([str(n.get_dtype()) for n in input_nodes]),
"autotune_shape": ", ".join(
["x".join(map(str, n.get_size())) for n in input_nodes]
),
"autotune_offset": ", ".join([str(n.get_layout().offset) for n in input_nodes]),
# TODO(coconutruben): replace this with taking KernelInputs as the
# argument, and extracting those out there directly
"autotune_strides_hinted": ", ".join(
[
str(
V.graph.sizevars.size_hints(
n.get_stride(),
fallback=config.unbacked_symint_fallback,
)
)
for n in input_nodes
]
),
"autotune_shape_hinted": ", ".join(
[
"x".join(
map(
str,
V.graph.sizevars.size_hints(
n.get_size(),
fallback=config.unbacked_symint_fallback,
),
)
)
for n in input_nodes
]
),
}
# ensure lowering is imported so that `extern_kernels.*` is populated
from . import lowering # noqa: F401