Update base for Update on "[Inductor XPU GEMM] Step 1/N: Refactor cutlass configuration."

This PR is the first step toward implementing RFC #160175.
Currently, all Cutlass-related Torch Inductor configs are located in `torch._inductor.config.cuda`. This PR refactors the device-agnostic Cutlass configurations into `torch._inductor.config.cutlass`, so they can be shared and reused by XPU as well.
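
A minimal sketch of what this enables (the option name below is illustrative; the actual set of moved options is defined by this diff):

```python
from torch._inductor import config

# Before: device-agnostic Cutlass knobs lived under the CUDA-specific namespace.
config.cuda.cutlass_max_profiling_configs = 4

# After: shared Cutlass knobs live under a device-agnostic namespace,
# so XPU backends can reuse the same configuration surface.
config.cutlass.cutlass_max_profiling_configs = 4
```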


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
xinan.lin
2025-11-06 15:13:45 +00:00
22 changed files with 729 additions and 80 deletions

View File

@ -7,13 +7,13 @@ ENV LC_ALL en_US.UTF-8
ENV LANG en_US.UTF-8
ENV LANGUAGE en_US.UTF-8
ARG DEVTOOLSET_VERSION=11
ARG DEVTOOLSET_VERSION=13
RUN yum -y update
RUN yum -y install epel-release
# install glibc-langpack-en to make sure the en_US.UTF-8 locale is available
RUN yum -y install glibc-langpack-en
RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-toolchain
RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
# Just add everything as a safe.directory for git since these will be used in multiple places with git
RUN git config --global --add safe.directory '*'
ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
@ -41,6 +41,7 @@ RUN bash ./install_conda.sh && rm install_conda.sh
# Install CUDA
FROM base as cuda
ARG CUDA_VERSION=12.6
ARG DEVTOOLSET_VERSION=13
RUN rm -rf /usr/local/cuda-*
ADD ./common/install_cuda.sh install_cuda.sh
COPY ./common/install_nccl.sh install_nccl.sh
@ -50,7 +51,8 @@ ENV CUDA_HOME=/usr/local/cuda-${CUDA_VERSION}
# Preserve CUDA_VERSION for the builds
ENV CUDA_VERSION=${CUDA_VERSION}
# Put these on our PATH by default
ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:$PATH
ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
FROM cuda as cuda12.6
RUN bash ./install_cuda.sh 12.6
@ -68,8 +70,22 @@ FROM cuda as cuda13.0
RUN bash ./install_cuda.sh 13.0
ENV DESIRED_CUDA=13.0
FROM ${ROCM_IMAGE} as rocm
FROM ${ROCM_IMAGE} as rocm_base
ARG DEVTOOLSET_VERSION=13
ENV LC_ALL en_US.UTF-8
ENV LANG en_US.UTF-8
ENV LANGUAGE en_US.UTF-8
# Install devtoolset on ROCm base image
RUN yum -y update && \
yum -y install epel-release && \
yum -y install glibc-langpack-en && \
yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
RUN git config --global --add safe.directory '*'
ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
FROM rocm_base as rocm
ARG PYTORCH_ROCM_ARCH
ARG DEVTOOLSET_VERSION=13
ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}
ADD ./common/install_mkl.sh install_mkl.sh
RUN bash ./install_mkl.sh && rm install_mkl.sh
@ -88,6 +104,7 @@ COPY --from=cuda13.0 /usr/local/cuda-13.0 /usr/local/cuda-13.0
# Final step
FROM ${BASE_TARGET} as final
ARG DEVTOOLSET_VERSION=13
COPY --from=openssl /opt/openssl /opt/openssl
COPY --from=patchelf /patchelf /usr/local/bin/patchelf
COPY --from=conda /opt/conda /opt/conda

View File

@ -63,7 +63,7 @@ docker build \
--target final \
--progress plain \
--build-arg "BASE_TARGET=${BASE_TARGET}" \
--build-arg "DEVTOOLSET_VERSION=11" \
--build-arg "DEVTOOLSET_VERSION=13" \
${EXTRA_BUILD_ARGS} \
-t ${tmp_tag} \
$@ \

View File

@ -388,6 +388,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
#ifndef USE_ROCM
at::Half halpha;
at::Half hbeta;
uint32_t mask = -1;
#endif
void * alpha_ptr = &alpha;
void * beta_ptr = &beta;
@ -427,7 +428,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
if (fp16_reduction !=
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
uint32_t mask =
mask =
fp16_reduction ==
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
@ -444,7 +445,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
if (bf16_reduction !=
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
uint32_t mask =
mask =
bf16_reduction ==
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
@ -511,17 +512,41 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
Adesc.descriptor(),
Bdesc.descriptor(),
Cdesc.descriptor(),
Cdesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
&returnedResult));
// on Blackwell+, we fake an n > 1 matmul when querying heuristics
// to prevent cuBLASLt from dispatching to a GEMV kernel, preserving batch invariance
#ifndef USE_ROCM
const bool lie_to_cublaslt = mask == CUBLASLT_REDUCTION_SCHEME_NONE && n == 1 && at::cuda::getCurrentDeviceProperties()->major >= 10;
#else
const bool lie_to_cublaslt = false;
#endif
if (lie_to_cublaslt) {
CuBlasLtMatrixLayout FakeBdesc(abType, k, 2, ldb, opb == CUBLAS_OP_T);
CuBlasLtMatrixLayout FakeCdesc(cType, m, 2, ldc);
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
Adesc.descriptor(),
FakeBdesc.descriptor(),
FakeCdesc.descriptor(),
FakeCdesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
&returnedResult));
} else {
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
Adesc.descriptor(),
Bdesc.descriptor(),
Cdesc.descriptor(),
Cdesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
&returnedResult));
}
if (returnedResult == 0) {
cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED;
}

View File

@ -1,5 +1,6 @@
#include <ATen/core/ATen_fwd.h>
#include <c10/core/ScalarType.h>
#include <c10/core/SymInt.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
@ -1710,11 +1711,37 @@ Tensor narrow_symint(
"], but got ",
start,
")")
if (start < 0) {
start = start + cur_size;
auto cond1 = TORCH_GUARD_OR_FALSE(start.sym_lt(0));
auto cond2 = TORCH_GUARD_OR_FALSE(start.sym_ge(0));
if (cond1 || cond2) {
if (cond1) {
start = start + cur_size;
}
TORCH_SYM_CHECK(
start.sym_le(cur_size - length),
"start (",
start,
") + length (",
length,
") exceeds dimension size (",
cur_size,
").");
return at::slice_symint(self, dim, start, start + length, 1);
}
// Unbacked start handling!
// Bounds check without converting start:
// - If start < 0: need (start + cur_size) + length <= cur_size, i.e., start +
// length <= 0
// - If start >= 0: need start + length <= cur_size
auto end = start + length;
TORCH_SYM_CHECK(
start.sym_le(cur_size - length),
(start.sym_lt(0).sym_and((end).sym_le(0)))
.sym_or(start.sym_ge(0).sym_and((end).sym_le(cur_size))),
"start (",
start,
") + length (",
@ -1722,7 +1749,28 @@ Tensor narrow_symint(
") exceeds dimension size (",
cur_size,
").");
return at::slice_symint(self, dim, start, start + length, 1);
if (TORCH_GUARD_OR_FALSE(end.sym_ne(0))) {
return at::slice_symint(self, dim, start, end, 1);
} else {
// Cannot statically determine the condition due to the unbacked symbol.
// This is an interesting situation: when start is negative and
// start + length == 0, slice and narrow do different things,
// i.e., x.narrow(0, -2, 2) != x[-2:0]; in that case we want to
// pass cur_size instead of 0. Otherwise they do the same thing.
// At runtime this says: if start < 0 and end == 0, then pass cur_size
// instead of 0.
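// For illustration (eager PyTorch, 1-D tensor of size 6):
//   x = torch.arange(6)
//   torch.narrow(x, 0, -2, 2)  -> tensor([4, 5])  (negative start wraps to cur_size - 2)
//   x[-2:0]                    -> tensor([])      (slicing up to end == 0 is empty)
//   x[-2:6]                    -> tensor([4, 5])  (passing cur_size instead of 0 matches narrow)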
auto use_different = start.sym_lt(0).sym_and(end.sym_eq(0)).toSymInt();
auto result =
at::slice_symint(self, dim, start, end + use_different * cur_size, 1);
// Ensure slice allocated unbacked size is specialized to length.
SymInt new_size = result.sym_size(dim);
TORCH_SYM_CHECK(new_size.sym_eq(length), "")
return result;
}
}
// This overload exists purely for XLA, because they wanted to pass in

View File

@ -212,17 +212,12 @@ static Tensor& bce_loss_out_impl(const Tensor& input,
loss.resize_((reduction == Reduction::None || grad_output.defined()) ? target.sizes() : IntArrayRef({}));
TORCH_CHECK(loss.is_mps());
Tensor loss_squeezed = loss.squeeze();
Tensor input_squeezed = input.squeeze();
Tensor target_squeezed = target.squeeze();
@autoreleasepool {
std::string key =
op_name + reductionToString(reduction) + getTensorsStringKey({input_squeezed, target_squeezed, weight});
std::string key = op_name + reductionToString(reduction) + getTensorsStringKey({input, target, weight});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_squeezed);
newCachedGraph->targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target_squeezed);
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
newCachedGraph->targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target);
MPSGraphTensor* bceLossUnweighted = nil;
// if grad_output is defined, then it's a backward pass
@ -252,12 +247,12 @@ static Tensor& bce_loss_out_impl(const Tensor& input,
newCachedGraph->gradInputTensor = bceLoss;
}
} else {
newCachedGraph->lossTensor = reduceTensor(bceLoss, reduction, mpsGraph, input_squeezed.sizes().size());
newCachedGraph->lossTensor = reduceTensor(bceLoss, reduction, mpsGraph, input.sizes().size());
}
});
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input_squeezed);
Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor, target_squeezed);
Placeholder lossPlaceholder = Placeholder(cachedGraph->lossTensor, loss_squeezed);
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input);
Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor, target);
Placeholder lossPlaceholder = Placeholder(cachedGraph->lossTensor, loss);
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];

View File

@ -1,4 +1,5 @@
#include <c10/core/SymBool.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymNodeImpl.h>
namespace c10 {
@ -111,4 +112,17 @@ bool SymBool::has_hint() const {
return toSymNodeImpl()->has_hint();
}
SymInt SymBool::toSymInt() const {
// If concrete bool, return concrete SymInt
if (auto ma = maybe_as_bool()) {
return SymInt(*ma ? 1 : 0);
}
// Symbolic case: use sym_ite to convert bool to int (0 or 1)
auto node = toSymNodeImpl();
auto one_node = node->wrap_int(1);
auto zero_node = node->wrap_int(0);
return SymInt(node->sym_ite(one_node, zero_node));
}
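// A rough Python-side mirror of this conversion (helper name is hypothetical,
// shown only for illustration):
//   def symbool_to_symint(b):
//       # concrete bools short-circuit; symbolic ones go through sym_ite(b, 1, 0)
//       return torch.sym_ite(b, 1, 0)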
} // namespace c10

View File

@ -12,6 +12,8 @@
namespace c10 {
class SymInt;
class C10_API SymBool {
public:
/*implicit*/ SymBool(bool b) : data_(b) {}
@ -80,6 +82,10 @@ class C10_API SymBool {
return toSymNodeImplUnowned()->constant_bool();
}
// Convert SymBool to SymInt (0 or 1)
// This is the C++ equivalent of Python's cast_symbool_to_symint_guardless
SymInt toSymInt() const;
bool is_heap_allocated() const {
return ptr_;
}

View File

@ -23,7 +23,7 @@ torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Ca
torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None)
torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node
torch.fx.graph.Graph.print_tabular(self)
torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False) -> torch.fx.graph.PythonCode
torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False, record_func: bool = False) -> torch.fx.graph.PythonCode
torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule')
torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool
torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None

View File

@ -6093,26 +6093,19 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
retry_export(
cf_implicitsize(),
(torch.tensor(2), torch.randn(10)),
fixes=[
# Could not guard on data-dependent expression u0 < 0
"torch._check(i >= 0)",
],
fixes=[],
)
class cf_stacklist(torch.nn.Module):
def forward(self, xs, y, fixes):
i = y.item()
eval(fixes)
# instead of xs[i]
return torch.stack(xs, 0).narrow(0, i, 1).squeeze()
retry_export(
cf_stacklist(),
([torch.ones(5) * i for i in range(10)], torch.tensor(2)),
fixes=[
# Could not guard on data-dependent expression u0 < 0
"torch._check(i >= 0)",
],
fixes=[],
)
class cf_tensorsplit(torch.nn.Module):
@ -6166,7 +6159,12 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
class cf_stacklist(torch.nn.Module):
def forward(self, xs, y):
# y.item() is not a local, so we can't suggest a fix
return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze()
if y.item() < 0:
return (
torch.stack(xs, 0).narrow(0, y.item() + xs.size(), 1).squeeze()
)
else:
return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze()
with self.assertRaisesRegex(
error_type,
@ -6196,7 +6194,18 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
def forward(self, xs, y):
box = Box(y.item())
# box.content is not a local, so we can't suggest a fix
return torch.stack(xs, 0).narrow(0, box.content, 1).squeeze()
if box.content < 0:
return (
torch.stack(xs, 0)
.narrow(0, box.content + xs.size(), 1)
.squeeze()
)
else:
return (
torch.stack(xs, 0)
.narrow(0, box.content + xs.size(), 1)
.squeeze()
)
with self.assertRaisesRegex(
error_type,

View File

@ -4401,6 +4401,57 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
self.assertEqual(compiled(a, b), func(a, b))
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_narrow_unbacked_start(self):
def func(x, start, length):
# unbacked start
u0 = start.item()
return torch.narrow(x, 0, u0, length)
compiled_func = torch.compile(func, fullgraph=True, backend="inductor")
x = torch.tensor([1, 2, 3, 4, 5, 6])
# Test cases: (start, length)
test_cases = [
# Negative starts
(-2, 2), # Start from second-to-last element
(-1, 1), # Start from last element
(-3, 3), # Start from third-to-last element
(-6, 2), # Start from beginning (negative)
(-4, 1), # Start from fourth-to-last element
# Positive starts
(0, 2), # Start from beginning
(1, 3), # Start from second element
(2, 2), # Start from third element
(4, 2), # Start near end
# Edge cases
(0, 6), # Full tensor
(0, 1), # Single element from start
(5, 1), # Single element from end
]
for start_val, length in test_cases:
with self.subTest(start=start_val, length=length):
start = torch.tensor([start_val])
# Test with compiled function
result_compiled = compiled_func(x, start, length)
# Test with eager function (expected behavior)
result_eager = func(x, start, length)
# Compare results
self.assertEqual(result_compiled, result_eager)
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._inductor.config.patch("cpp_wrapper", True)
def test_narrow_unbacked_start_cpp_wrapper(self):
"""Test narrow with unbacked start with cpp_wrapper"""
self.test_narrow_unbacked_start()
instantiate_parametrized_tests(TestUnbacked)

View File

@ -72,9 +72,16 @@ from torch.testing._internal.common_utils import (
IS_WINDOWS,
run_tests,
skipIfTorchDynamo,
skipIfRocm,
)
from torch.testing._internal.jit_utils import JitTestCase
import json
import tempfile
from torch.profiler import profile, ProfilerActivity
from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace
from torch.autograd.profiler_util import _canonicalize_profiler_events
try:
from torchvision import models as torchvision_models
@ -201,6 +208,36 @@ def side_effect_func(x: torch.Tensor):
print(x)
def _enrich_profiler_traces(prof):
"""
Helper function to extract and augment profiler events with stack traces.
Args:
prof: A torch.profiler.profile object
Returns:
A string representing enriched events
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f:
trace_file = f.name
prof.export_chrome_trace(trace_file)
with open(trace_file) as f:
trace_data = json.load(f)
map_recorded_events_to_aten_ops_with_stack_trace(
trace_data
)
events = []
for event in trace_data["traceEvents"]:
if "args" in event and "stack_trace" in event["args"]:
events.append(event)
actual_traces = _canonicalize_profiler_events(events)
return actual_traces
class TestFX(JitTestCase):
def setUp(self):
super().setUp()
@ -4212,6 +4249,153 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
# recover mutable checking flag
torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_profiler_stack_trace_augmentation(self):
"""
Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
augments profiler events with stack traces from FX metadata registry.
"""
# Simple test model
class TestModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(10, 16)
self.relu = torch.nn.ReLU()
self.linear2 = torch.nn.Linear(16, 10)
def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
model = TestModel().cuda()
# Compile the model
compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True)
# Warmup
for _ in range(3):
_ = compiled_model(torch.randn(10, 10, device="cuda"))
# Profile with the compiled model
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
result = compiled_model(torch.randn(10, 10, device="cuda"))
actual_traces = _enrich_profiler_traces(prof)
self.assertExpectedInline(actual_traces, """\
event=aten::t node=t stack_trace=x = self.linear1(x)
event=aten::transpose node=t stack_trace=x = self.linear1(x)
event=aten::as_strided node=t stack_trace=x = self.linear1(x)
event=aten::addmm node=addmm stack_trace=x = self.linear1(x)
event=cudaLaunchKernel node=addmm stack_trace=x = self.linear1(x)
event=aten::relu node=relu stack_trace=x = self.relu(x)
event=aten::clamp_min node=relu stack_trace=x = self.relu(x)
event=cudaLaunchKernel node=relu stack_trace=x = self.relu(x)
event=aten::t node=t_1 stack_trace=x = self.linear2(x)
event=aten::transpose node=t_1 stack_trace=x = self.linear2(x)
event=aten::as_strided node=t_1 stack_trace=x = self.linear2(x)
event=aten::addmm node=addmm_1 stack_trace=x = self.linear2(x)
event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)"""
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_profiler_multiple_modules(self):
"""
Test that multiple compiled modules under the same profiler session
have their events correctly augmented with stack traces.
"""
class ModelA(torch.nn.Module):
def forward(self, x):
return x + 1
class ModelB(torch.nn.Module):
def forward(self, x):
return x - 1
model_a = ModelA().cuda()
model_b = ModelB().cuda()
# Compile both models
compiled_a = torch.compile(model_a, backend="aot_eager", fullgraph=True)
compiled_b = torch.compile(model_b, backend="aot_eager", fullgraph=True)
# Warmup
for _ in range(3):
_ = compiled_a(torch.randn(10, 10, device="cuda"))
_ = compiled_b(torch.randn(1, 3, 8, 8, device="cuda"))
# Profile both models in the same session
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
result_a = compiled_a(torch.randn(10, 10, device="cuda"))
result_b = compiled_b(torch.randn(1, 3, 8, 8, device="cuda"))
actual_traces = _enrich_profiler_traces(prof)
self.assertExpectedInline(actual_traces, """\
event=aten::add node=add stack_trace=return x + 1
event=cudaLaunchKernel node=add stack_trace=return x + 1
event=aten::sub node=sub stack_trace=return x - 1
event=cudaLaunchKernel node=sub stack_trace=return x - 1"""
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_profiler_nested_graph_modules(self):
"""
Test that nested graph modules (e.g., graph modules calling subgraphs)
have their events correctly augmented with stack traces.
"""
# Model with nested structure
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.c = 5
@torch.compiler.nested_compile_region
def forward(self, x, y):
m = torch.mul(x, y)
s = m.sin()
a = s + self.c
return a
model = Mod().cuda()
# Compile the model (this may create nested graph modules)
compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True)
# Warmup
for _ in range(3):
_ = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda"))
# Profile
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
result = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda"))
actual_traces = _enrich_profiler_traces(prof)
self.assertExpectedInline(actual_traces, """\
event=aten::mul node=mul stack_trace=m = torch.mul(x, y)
event=cudaLaunchKernel node=mul stack_trace=m = torch.mul(x, y)
event=aten::sin node=sin stack_trace=s = m.sin()
event=cudaLaunchKernel node=sin stack_trace=s = m.sin()
event=aten::add node=add stack_trace=a = s + self.c
event=cudaLaunchKernel node=add stack_trace=a = s + self.c"""
)
def run_getitem_target():
from torch.fx._symbolic_trace import _wrapped_methods_to_patch

View File

@ -359,6 +359,29 @@ class TestMatmulCuda(InductorTestCase):
self.assertEqual(agrad, a.grad)
self.assertEqual(bgrad, b.grad)
@onlyCUDA
@skipIfRocm
@dtypes(torch.half, torch.bfloat16)
@unittest.skipIf(not SM100OrLater, "cuBLAS integration for batch invariance is only on Blackwell")
@serialTest()
def test_cublas_batch_invariance_blackwell(self, device, dtype):
orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (False, False)
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (False, False)
with blas_library_context('cublaslt'):
N = 2048
K = 6144
M_max = 32
x = torch.randn(M_max, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16).t()
full = x @ w
xx = x[:1]
out = xx @ w
self.assertEqual(full[:1], out, atol=0., rtol=0.)
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16
@unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
@parametrize("strided", [False, True])
@parametrize("a_row_major", [False, True])

View File

@ -4470,6 +4470,14 @@ class TestMPS(TestCaseMPS):
self.assertEqual(out1, out2)
def test_bce_backward_with_no_reduction_and_one_in_shape(self):
# Regression test for https://github.com/pytorch/pytorch/issues/166746
output = torch.zeros(3, 2, 1, requires_grad=True, device='mps')
target = torch.zeros(3, 2, 1, device='mps')
torch.sum(nn.BCELoss(reduction='none')(output, target)).backward()
expected_grad = torch.zeros(3, 2, 1, device='mps')
self.assertEqual(output.grad, expected_grad)
def test_cross_entropy_loss(self):
# Regression test for https://github.com/pytorch/pytorch/issues/116095
loss = nn.CrossEntropyLoss()

View File

@ -257,34 +257,6 @@ class TestFuzzerCompileIssues(TestCase):
out_compiled.sum().backward()
print("Compile Success! ✅")
@pytest.mark.xfail(reason="Issue #163971")
def test_fuzzer_issue_163971(self):
torch.manual_seed(0)
def foo(arg0):
t0 = arg0 # size=(), stride=(), dtype=bfloat16, device=cuda
t1 = torch.softmax(
t0, dim=0
) # size=(), stride=(), dtype=bfloat16, device=cuda
t2 = torch.nn.functional.gelu(
t1
) # size=(), stride=(), dtype=bfloat16, device=cuda
t3 = torch.softmax(
t2, dim=0
) # size=(), stride=(), dtype=bfloat16, device=cuda
output = t3
return output
arg0 = torch.rand([], dtype=torch.bfloat16, device="cuda", requires_grad=True)
out_eager = foo(arg0)
out_eager.sum().backward()
print("Eager Success! ✅")
compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
out_compiled = compiled_foo(arg0)
out_compiled.sum().backward()
print("Compile Success! ✅")
@pytest.mark.xfail(reason="Issue #164059")
def test_fuzzer_issue_164059(self):
torch.manual_seed(0)

View File

@ -1914,6 +1914,7 @@ class TestSDPAFailureModes(NNTestCase):
q, k, v, None, 0.0, is_causal=True))
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
def test_mem_eff_attention_fail_with_batch_size_geq_65536(self):
batch_size = 2**16
query = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True)
@ -1935,6 +1936,7 @@ class TestSDPAFailureModes(NNTestCase):
self.assertEqual(value.grad, v_cpu.grad, atol=2e-3, rtol=1e-4)
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self):
query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
@ -1948,6 +1950,7 @@ class TestSDPAFailureModes(NNTestCase):
@largeTensorTest("15GB", "cuda")
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
def test_mem_eff_attention_large_seq_len_uniform_attention(self):
device = torch.device("cuda")
dtype = torch.bfloat16

View File

@ -2063,7 +2063,8 @@ class PythonWrapperCodegen(CodeGen):
neg = self.codegen_sizevar(
sympy.Max(0, sympy.Min(x + node.size, node.size))
)
return f"{pos} if {x} >= 0 else {neg}"
x_cond = self.codegen_sizevar(x)
return f"{pos} if {x_cond} >= 0 else {neg}"
def codegen_with_step(start_var, end_var, step):
if step == 1:

View File

@ -1224,3 +1224,43 @@ def _build_table(
f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}"
)
return "".join(result)
# Collect all events with stack traces and format them canonically
def _canonicalize_profiler_events(events):
"""
Extract and format all events with stack traces in a canonical way
for deterministic testing.
"""
events_with_traces = []
for event in events:
# Extract relevant fields
event_name = event.get("name", "")
node_name = event["args"].get("node_name", "")
stack_trace = event["args"].get("stack_trace", "")
# Get the last non-empty line of the stack trace
lines = [s.strip() for s in stack_trace.split("\n") if s.strip()]
stack_trace = lines[-1] if lines else ""
events_with_traces.append(
{
"event_name": event_name[:20],
"node_name": node_name,
"stack_trace": stack_trace,
"start_time": event.get("ts", 0),
}
)
# Sort by node_name for deterministic ordering
events_with_traces.sort(key=lambda x: x["start_time"])
# Format as a string
lines: list[str] = []
for evt in events_with_traces:
lines.append(
f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
)
return "\n".join(lines)

View File

@ -547,6 +547,7 @@ def rebind_unbacked(
assert shape_env is not None
for raw_u0, path in bindings.items():
u1 = pytree.key_get(result, path)
# Sometimes, things that were previously unbacked bindings become constants.
# There are two situations this can happen.
#
@ -602,7 +603,23 @@ def rebind_unbacked(
if u1.node.hint is not None:
continue
raw_u1 = u1.node.expr
# Unbacked symbol bindings might have been replaced with other backed or
# unbacked expressions.
#
# Example:
# u = x.item()
# torch._check(u == 5)
#
# The safest approach is to retrieve raw_u1 from u1.node._expr
# and perform the rebinding on the original unbacked symbol,
# even if it's no longer directly referenced.
#
# In other words, we should always rebind the original symbol
# before any replacements are applied.
# u0 -> u0 == s1
raw_u1 = u1.node._expr
# TODO Do we still need this logic below?
# Simplify SymBool binding
if (
isinstance(raw_u1, sympy.Piecewise)

View File

@ -443,6 +443,7 @@ class CodeGen:
colored: bool = False,
# Render each argument on its own line
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
free_vars: list[str] = []
body: list[str] = []
@ -817,6 +818,10 @@ class CodeGen:
return
raise NotImplementedError(f"node: {node.op} {node.target}")
if record_func:
body.append(
"_rf = torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##'); _rf.__enter__()\n"
)
for i, node in enumerate(nodes):
# NOTE: emit_node does not emit a string with newline. It depends
# on delete_unused_values to append one
@ -826,8 +831,22 @@ class CodeGen:
# node index, which will be deleted later
# after going through _body_transformer
body.append(f"# COUNTER: {i}\n")
do_record = record_func and node.op in (
"call_function",
"call_method",
"call_module",
)
if do_record:
# The double hash ## convention is used by post-processing to find the fx markers
body.append(
f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {i} ##'); _rf_{node.name}.__enter__()\n"
)
emit_node(node)
delete_unused_values(node)
if do_record:
body.append(f"_rf_{node.name}.__exit__(None, None, None)\n")
if record_func:
body.append("_rf.__exit__(None, None, None)\n")
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
@ -1779,6 +1798,7 @@ class Graph:
include_device: bool = False,
colored: bool = False,
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
"""
Turn this ``Graph`` into valid Python code.
@ -1846,6 +1866,7 @@ class Graph:
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
record_func=record_func,
)
def _python_code(
@ -1858,6 +1879,7 @@ class Graph:
include_device: bool = False,
colored: bool = False,
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
return self._codegen._gen_python_code(
self.nodes,
@ -1868,6 +1890,7 @@ class Graph:
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
record_func=record_func,
)
def __str__(self) -> str:

View File

@ -861,14 +861,18 @@ class {module_name}(torch.nn.Module):
if isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec
python_code = self._graph.python_code(root_module="self")
from torch._dynamo import config as dynamo_config
python_code = self._graph.python_code(
root_module="self", record_func=dynamo_config.enrich_profiler_metadata
)
self._code = python_code.src
self._lineno_map = python_code._lineno_map
self._prologue_start = python_code._prologue_start
cls = type(self)
co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
from torch._dynamo import config as dynamo_config
if dynamo_config.enrich_profiler_metadata:
# Generate metadata and register for profiler augmentation
@ -885,7 +889,6 @@ class {module_name}(torch.nn.Module):
# This ensures the same code+metadata always generates the same filename
hash_value = _metadata_hash(self._code, node_metadata)
file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}"
filename = f"{file_stem}.py"
# Only include co_filename to use it directly as the cache key
@ -905,6 +908,13 @@ class {module_name}(torch.nn.Module):
_register_fx_metadata(filename, metadata)
# Replace the placeholder in generated code with actual filename
# The double hash ## convention is used by post-processing to find the fx markers
self._code = self._code.replace(
"torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##')",
f"torch._C._profiler._RecordFunctionFast('## {filename} ##')",
)
cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)
# Determine whether this class explicitly defines a __call__ implementation

View File

@ -4,7 +4,7 @@ import operator
import re
from collections import deque
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import Any, Literal, Optional, TYPE_CHECKING
from torch.autograd.profiler import profile
from torch.profiler import DeviceType
@ -400,3 +400,170 @@ def _init_for_cuda_graphs() -> None:
with profile():
pass
@dataclass
class TimelineEvent:
"""Represents an event in the profiler timeline."""
timestamp: int
event_type: Literal["start", "end", "regular"]
marker_type: Optional[Literal["filename", "node"]]
identifier: Optional[str | int]
event: dict[str, Any]
@dataclass
class ContextStackEntry:
"""Represents a context (filename or node) in the stack."""
context_type: Literal["filename", "node"]
identifier: str | int
metadata: Optional[dict]
tid: Optional[int] = None # Thread ID associated with this context
def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
"""
Maps recorded profiler events to their corresponding fx nodes and adds stack traces.
Builds a timeline of all events (regular ops and FX markers for filenames/nodes),
sorts by timestamp, then processes chronologically while maintaining a context stack of active
filename/node scopes. Regular events are augmented with stack traces and node names from the
innermost active context. Runtime is O(n log n) for n events.
Args:
traced_data: JSON of profiler events from a Chrome trace
Returns:
Dict mapping recorded event names to their aten operations with added stack traces
"""
from torch.fx.traceback import _FX_METADATA_REGISTRY
trace_events = traced_data.get("traceEvents", [])
# Create event timeline
event_timeline: list[TimelineEvent] = []
def is_fx_marker_event(event):
return (
event.get("cat") == "cpu_op"
and event.get("name", "").startswith("## ")
and event.get("name", "").endswith(" ##")
)
def append_fx_marker_event(event_type, identifier, event):
start_ts = event["ts"]
end_ts = start_ts + event["dur"]
event_timeline.append(
TimelineEvent(start_ts, "start", event_type, identifier, event)
)
event_timeline.append(
TimelineEvent(end_ts, "end", event_type, identifier, event)
)
for event in trace_events:
if "ts" not in event or "dur" not in event:
continue
if is_fx_marker_event(event):
content = event["name"][3:-3]
if content.endswith(".py"):
append_fx_marker_event("filename", content, event)
else:
try:
node_index = int(content)
except ValueError:
pass
append_fx_marker_event("node", node_index, event) # type: ignore[possibly-undefined]
else:
# Regular event that needs augmentation
start_ts = event["ts"]
event_timeline.append(TimelineEvent(start_ts, "regular", None, None, event))
# Sort by timestamp
event_timeline.sort(key=lambda x: x.timestamp)
# Process events in chronological order with a stack
context_stack: list[ContextStackEntry] = []
# Invariant: every start event has a corresponding end event
for timeline_event in event_timeline:
match timeline_event.event_type:
case "start":
assert timeline_event.identifier is not None
if timeline_event.marker_type == "filename":
assert isinstance(timeline_event.identifier, str)
# Push filename context - query metadata registry on-demand
metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier)
tid = timeline_event.event.get("tid")
context_stack.append(
ContextStackEntry(
"filename", timeline_event.identifier, metadata, tid
)
)
elif timeline_event.marker_type == "node":
# Find the current filename from stack
current_file_metadata = None
tid = timeline_event.event.get("tid")
for ctx_entry in reversed(context_stack):
if (
ctx_entry.context_type == "filename"
and ctx_entry.tid == tid
):
current_file_metadata = ctx_entry.metadata
break
if current_file_metadata:
node_metadata = current_file_metadata.get("node_metadata", {})
if timeline_event.identifier in node_metadata:
node_meta: Optional[dict] = node_metadata[
timeline_event.identifier
]
context_stack.append(
ContextStackEntry(
"node", timeline_event.identifier, node_meta, tid
)
)
case "end":
# Pop from stack - search backwards to find matching context
for i in range(len(context_stack) - 1, -1, -1):
ctx_entry = context_stack[i]
if (
timeline_event.marker_type == ctx_entry.context_type
and timeline_event.identifier == ctx_entry.identifier
):
context_stack.pop(i)
break
case "regular":
# Apply metadata from current context stack
# Find the most specific context (node takes precedence over filename)
# Only augment events with the same tid as the file/node event matched
current_stack_trace = None
current_node_name = None
event_tid = timeline_event.event.get("tid")
for ctx_entry in reversed(context_stack):
# Only apply metadata from contexts with matching tid
if ctx_entry.tid == event_tid:
if ctx_entry.context_type == "node" and ctx_entry.metadata:
current_stack_trace = ctx_entry.metadata.get(
"stack_trace", "No model stack trace available"
)
current_node_name = ctx_entry.metadata.get("name", "")
# Do we want to attach only the stack trace of the innermost node, or the
# stack traces of all nodes, if nodes are nested (e.g. in nested graph modules)?
break
# Augment the event
if current_stack_trace or current_node_name:
args = timeline_event.event.setdefault("args", {})
if current_stack_trace:
args["stack_trace"] = current_stack_trace
if current_node_name:
args["node_name"] = current_node_name

View File

@ -306,6 +306,24 @@ class PythonPrinter(ExprPrinter):
raise TypeError("ndigits must be an instance of sympy.Integer")
return f"round({self._print(number)}, {ndigits})"
def _print_Piecewise(self, expr: sympy.Expr) -> str:
# Convert Piecewise(expr_cond_pairs) to nested ternary expressions
# Piecewise((e1, c1), (e2, c2), ..., (eN, cN))
# becomes: e1 if c1 else (e2 if c2 else (... else eN))
result: Optional[str] = None
for expr_i, cond_i in reversed(expr.args):
expr_str = self._print(expr_i)
if cond_i == True: # noqa: E712
# This is the default case
result = expr_str
else:
cond_str = self._print(cond_i)
if result is None:
result = expr_str
else:
result = f"({expr_str} if {cond_str} else {result})"
return result if result else "0"
class CppPrinter(ExprPrinter):
def _print_Integer(self, expr: sympy.Expr) -> str:
@ -327,6 +345,24 @@ class CppPrinter(ExprPrinter):
)
return f"{c} ? {p} : {q}"
def _print_Piecewise(self, expr: sympy.Expr) -> str:
# Convert Piecewise(expr_cond_pairs) to nested ternary operators
# Piecewise((e1, c1), (e2, c2), ..., (eN, cN))
# becomes: c1 ? e1 : (c2 ? e2 : (... : eN))
result: Optional[str] = None
for expr_i, cond_i in reversed(expr.args):
expr_str = self.parenthesize(expr_i, PRECEDENCE["Atom"] - 0.5)
if cond_i == True: # noqa: E712
# This is the default case
result = expr_str
else:
cond_str = self.parenthesize(cond_i, PRECEDENCE["Atom"] - 0.5)
if result is None:
result = expr_str
else:
result = f"{cond_str} ? {expr_str} : {result}"
return f"({result})" if result else "0"
def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
x, div, mod = expr.args
x = self.doprint(x)