mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-14 22:25:03 +08:00
Update base for Update on "[Inductor XPU GEMM] Step 1/N: Refactor cutlass configuration."
This PR is the first step toward implementing RFC #160175. Currently, all Cutlass-related Torch Inductor configs are located in `torch._inductor.config.cuda`. This PR refactors the device-agnostic Cutlass configurations into `torch._inductor.config.cutlass`, so they can be shared and reused by XPU as well. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos [ghstack-poisoned]
This commit is contained in:
@ -7,13 +7,13 @@ ENV LC_ALL en_US.UTF-8
|
||||
ENV LANG en_US.UTF-8
|
||||
ENV LANGUAGE en_US.UTF-8
|
||||
|
||||
ARG DEVTOOLSET_VERSION=11
|
||||
ARG DEVTOOLSET_VERSION=13
|
||||
|
||||
RUN yum -y update
|
||||
RUN yum -y install epel-release
|
||||
# Install glibc-langpack-en to make sure the en_US.UTF-8 locale is available
|
||||
RUN yum -y install glibc-langpack-en
|
||||
RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-toolchain
|
||||
RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
|
||||
# Just add everything as a safe.directory for git since these will be used in multiple places with git
|
||||
RUN git config --global --add safe.directory '*'
|
||||
ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
|
||||
@ -41,6 +41,7 @@ RUN bash ./install_conda.sh && rm install_conda.sh
|
||||
# Install CUDA
|
||||
FROM base as cuda
|
||||
ARG CUDA_VERSION=12.6
|
||||
ARG DEVTOOLSET_VERSION=13
|
||||
RUN rm -rf /usr/local/cuda-*
|
||||
ADD ./common/install_cuda.sh install_cuda.sh
|
||||
COPY ./common/install_nccl.sh install_nccl.sh
|
||||
@ -50,7 +51,8 @@ ENV CUDA_HOME=/usr/local/cuda-${CUDA_VERSION}
|
||||
# Preserve CUDA_VERSION for the builds
|
||||
ENV CUDA_VERSION=${CUDA_VERSION}
|
||||
# Make things in our path by default
|
||||
ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:$PATH
|
||||
ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
|
||||
|
||||
|
||||
FROM cuda as cuda12.6
|
||||
RUN bash ./install_cuda.sh 12.6
|
||||
@ -68,8 +70,22 @@ FROM cuda as cuda13.0
|
||||
RUN bash ./install_cuda.sh 13.0
|
||||
ENV DESIRED_CUDA=13.0
|
||||
|
||||
FROM ${ROCM_IMAGE} as rocm
|
||||
FROM ${ROCM_IMAGE} as rocm_base
|
||||
ARG DEVTOOLSET_VERSION=13
|
||||
ENV LC_ALL en_US.UTF-8
|
||||
ENV LANG en_US.UTF-8
|
||||
ENV LANGUAGE en_US.UTF-8
|
||||
# Install devtoolset on ROCm base image
|
||||
RUN yum -y update && \
|
||||
yum -y install epel-release && \
|
||||
yum -y install glibc-langpack-en && \
|
||||
yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
|
||||
RUN git config --global --add safe.directory '*'
|
||||
ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
|
||||
|
||||
FROM rocm_base as rocm
|
||||
ARG PYTORCH_ROCM_ARCH
|
||||
ARG DEVTOOLSET_VERSION=13
|
||||
ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}
|
||||
ADD ./common/install_mkl.sh install_mkl.sh
|
||||
RUN bash ./install_mkl.sh && rm install_mkl.sh
|
||||
@ -88,6 +104,7 @@ COPY --from=cuda13.0 /usr/local/cuda-13.0 /usr/local/cuda-13.0
|
||||
|
||||
# Final step
|
||||
FROM ${BASE_TARGET} as final
|
||||
ARG DEVTOOLSET_VERSION=13
|
||||
COPY --from=openssl /opt/openssl /opt/openssl
|
||||
COPY --from=patchelf /patchelf /usr/local/bin/patchelf
|
||||
COPY --from=conda /opt/conda /opt/conda
|
||||
|
||||
@ -63,7 +63,7 @@ docker build \
|
||||
--target final \
|
||||
--progress plain \
|
||||
--build-arg "BASE_TARGET=${BASE_TARGET}" \
|
||||
--build-arg "DEVTOOLSET_VERSION=11" \
|
||||
--build-arg "DEVTOOLSET_VERSION=13" \
|
||||
${EXTRA_BUILD_ARGS} \
|
||||
-t ${tmp_tag} \
|
||||
$@ \
|
||||
|
||||
@ -388,6 +388,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
||||
#ifndef USE_ROCM
|
||||
at::Half halpha;
|
||||
at::Half hbeta;
|
||||
uint32_t mask = -1;
|
||||
#endif
|
||||
void * alpha_ptr = &alpha;
|
||||
void * beta_ptr = &beta;
|
||||
@ -427,7 +428,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
||||
auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
|
||||
if (fp16_reduction !=
|
||||
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
|
||||
uint32_t mask =
|
||||
mask =
|
||||
fp16_reduction ==
|
||||
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
|
||||
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
|
||||
@ -444,7 +445,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
||||
auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
|
||||
if (bf16_reduction !=
|
||||
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
|
||||
uint32_t mask =
|
||||
mask =
|
||||
bf16_reduction ==
|
||||
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
|
||||
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
|
||||
@ -511,17 +512,41 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
||||
cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS;
|
||||
cublasLtMatmulHeuristicResult_t heuristicResult = {};
|
||||
int returnedResult = 0;
|
||||
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
|
||||
ltHandle,
|
||||
computeDesc.descriptor(),
|
||||
Adesc.descriptor(),
|
||||
Bdesc.descriptor(),
|
||||
Cdesc.descriptor(),
|
||||
Cdesc.descriptor(),
|
||||
preference.descriptor(),
|
||||
1,
|
||||
&heuristicResult,
|
||||
&returnedResult));
|
||||
// on Blackwell+, we fake a n > 1 matmul when querying heuristics
|
||||
// to prevent cuBLASLt from dispatching to a GEMV kernel for batch-invariance
|
||||
#ifndef USE_ROCM
|
||||
const bool lie_to_cublaslt = mask == CUBLASLT_REDUCTION_SCHEME_NONE && n == 1 && at::cuda::getCurrentDeviceProperties()->major >= 10;
|
||||
#else
|
||||
const bool lie_to_cublaslt = false;
|
||||
#endif
|
||||
if (lie_to_cublaslt) {
|
||||
CuBlasLtMatrixLayout FakeBdesc(abType, k, 2, ldb, opb == CUBLAS_OP_T);
|
||||
CuBlasLtMatrixLayout FakeCdesc(cType, m, 2, ldc);
|
||||
|
||||
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
|
||||
ltHandle,
|
||||
computeDesc.descriptor(),
|
||||
Adesc.descriptor(),
|
||||
FakeBdesc.descriptor(),
|
||||
FakeCdesc.descriptor(),
|
||||
FakeCdesc.descriptor(),
|
||||
preference.descriptor(),
|
||||
1,
|
||||
&heuristicResult,
|
||||
&returnedResult));
|
||||
} else {
|
||||
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
|
||||
ltHandle,
|
||||
computeDesc.descriptor(),
|
||||
Adesc.descriptor(),
|
||||
Bdesc.descriptor(),
|
||||
Cdesc.descriptor(),
|
||||
Cdesc.descriptor(),
|
||||
preference.descriptor(),
|
||||
1,
|
||||
&heuristicResult,
|
||||
&returnedResult));
|
||||
}
|
||||
if (returnedResult == 0) {
|
||||
cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED;
|
||||
}
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#include <ATen/core/ATen_fwd.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/core/SymInt.h>
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
@ -1710,11 +1711,37 @@ Tensor narrow_symint(
|
||||
"], but got ",
|
||||
start,
|
||||
")")
|
||||
if (start < 0) {
|
||||
start = start + cur_size;
|
||||
|
||||
auto cond1 = TORCH_GUARD_OR_FALSE(start.sym_lt(0));
|
||||
auto cond2 = TORCH_GUARD_OR_FALSE(start.sym_ge(0));
|
||||
|
||||
if (cond1 || cond2) {
|
||||
if (cond1) {
|
||||
start = start + cur_size;
|
||||
}
|
||||
|
||||
TORCH_SYM_CHECK(
|
||||
start.sym_le(cur_size - length),
|
||||
"start (",
|
||||
start,
|
||||
") + length (",
|
||||
length,
|
||||
") exceeds dimension size (",
|
||||
cur_size,
|
||||
").");
|
||||
return at::slice_symint(self, dim, start, start + length, 1);
|
||||
}
|
||||
|
||||
// Unbacked start handling!
|
||||
|
||||
// Bounds check without converting start:
|
||||
// - If start < 0: need (start + cur_size) + length <= cur_size, i.e., start +
|
||||
// length <= 0
|
||||
// - If start >= 0: need start + length <= cur_size
|
||||
auto end = start + length;
|
||||
TORCH_SYM_CHECK(
|
||||
start.sym_le(cur_size - length),
|
||||
(start.sym_lt(0).sym_and((end).sym_le(0)))
|
||||
.sym_or(start.sym_ge(0).sym_and((end).sym_le(cur_size))),
|
||||
"start (",
|
||||
start,
|
||||
") + length (",
|
||||
@ -1722,7 +1749,28 @@ Tensor narrow_symint(
|
||||
") exceeds dimension size (",
|
||||
cur_size,
|
||||
").");
|
||||
return at::slice_symint(self, dim, start, start + length, 1);
|
||||
|
||||
if (TORCH_GUARD_OR_FALSE(end.sym_ne(0))) {
|
||||
return at::slice_symint(self, dim, start, end, 1);
|
||||
} else {
|
||||
// Cannot statically determine the condition due to unbacked.
|
||||
// This is an interesting situation; when start is negative and
|
||||
// start + length == 0, slice and narrow do different things.
|
||||
// i.e., x.narrow(0, -2, 2) != x[-2:0]; in that case, we want to
|
||||
// pass cur_size instead of 0. Otherwise, they would do the same thing.
|
||||
// This says at runtime: if start < 0 and end == 0, then pass cur_size
|
||||
// instead of 0.
|
||||
|
||||
auto use_different = start.sym_lt(0).sym_and(end.sym_eq(0)).toSymInt();
|
||||
auto result =
|
||||
at::slice_symint(self, dim, start, end + use_different * cur_size, 1);
|
||||
|
||||
// Ensure slice allocated unbacked size is specialized to length.
|
||||
SymInt new_size = result.sym_size(dim);
|
||||
TORCH_SYM_CHECK(new_size.sym_eq(length), "")
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// This overload exists purely for XLA, because they wanted to pass in
|
||||
|
||||
@ -212,17 +212,12 @@ static Tensor& bce_loss_out_impl(const Tensor& input,
|
||||
loss.resize_((reduction == Reduction::None || grad_output.defined()) ? target.sizes() : IntArrayRef({}));
|
||||
TORCH_CHECK(loss.is_mps());
|
||||
|
||||
Tensor loss_squeezed = loss.squeeze();
|
||||
Tensor input_squeezed = input.squeeze();
|
||||
Tensor target_squeezed = target.squeeze();
|
||||
|
||||
@autoreleasepool {
|
||||
std::string key =
|
||||
op_name + reductionToString(reduction) + getTensorsStringKey({input_squeezed, target_squeezed, weight});
|
||||
std::string key = op_name + reductionToString(reduction) + getTensorsStringKey({input, target, weight});
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_squeezed);
|
||||
newCachedGraph->targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target_squeezed);
|
||||
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
|
||||
newCachedGraph->targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target);
|
||||
|
||||
MPSGraphTensor* bceLossUnweighted = nil;
|
||||
// if grad_output is defined, then it's a backward pass
|
||||
@ -252,12 +247,12 @@ static Tensor& bce_loss_out_impl(const Tensor& input,
|
||||
newCachedGraph->gradInputTensor = bceLoss;
|
||||
}
|
||||
} else {
|
||||
newCachedGraph->lossTensor = reduceTensor(bceLoss, reduction, mpsGraph, input_squeezed.sizes().size());
|
||||
newCachedGraph->lossTensor = reduceTensor(bceLoss, reduction, mpsGraph, input.sizes().size());
|
||||
}
|
||||
});
|
||||
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input_squeezed);
|
||||
Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor, target_squeezed);
|
||||
Placeholder lossPlaceholder = Placeholder(cachedGraph->lossTensor, loss_squeezed);
|
||||
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input);
|
||||
Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor, target);
|
||||
Placeholder lossPlaceholder = Placeholder(cachedGraph->lossTensor, loss);
|
||||
|
||||
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
#include <c10/core/SymBool.h>
|
||||
#include <c10/core/SymInt.h>
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
|
||||
namespace c10 {
|
||||
@ -111,4 +112,17 @@ bool SymBool::has_hint() const {
|
||||
return toSymNodeImpl()->has_hint();
|
||||
}
|
||||
|
||||
SymInt SymBool::toSymInt() const {
|
||||
// If concrete bool, return concrete SymInt
|
||||
if (auto ma = maybe_as_bool()) {
|
||||
return SymInt(*ma ? 1 : 0);
|
||||
}
|
||||
|
||||
// Symbolic case: use sym_ite to convert bool to int (0 or 1)
|
||||
auto node = toSymNodeImpl();
|
||||
auto one_node = node->wrap_int(1);
|
||||
auto zero_node = node->wrap_int(0);
|
||||
return SymInt(node->sym_ite(one_node, zero_node));
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
@ -12,6 +12,8 @@
|
||||
|
||||
namespace c10 {
|
||||
|
||||
class SymInt;
|
||||
|
||||
class C10_API SymBool {
|
||||
public:
|
||||
/*implicit*/ SymBool(bool b) : data_(b) {}
|
||||
@ -80,6 +82,10 @@ class C10_API SymBool {
|
||||
return toSymNodeImplUnowned()->constant_bool();
|
||||
}
|
||||
|
||||
// Convert SymBool to SymInt (0 or 1)
|
||||
// This is the C++ equivalent of Python's cast_symbool_to_symint_guardless
|
||||
SymInt toSymInt() const;
|
||||
|
||||
bool is_heap_allocated() const {
|
||||
return ptr_;
|
||||
}
|
||||
|
||||
@ -23,7 +23,7 @@ torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Ca
|
||||
torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None)
|
||||
torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node
|
||||
torch.fx.graph.Graph.print_tabular(self)
|
||||
torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False) -> torch.fx.graph.PythonCode
|
||||
torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False, record_func: bool = False) -> torch.fx.graph.PythonCode
|
||||
torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule')
|
||||
torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool
|
||||
torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None
|
||||
|
||||
@ -6093,26 +6093,19 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
retry_export(
|
||||
cf_implicitsize(),
|
||||
(torch.tensor(2), torch.randn(10)),
|
||||
fixes=[
|
||||
# Could not guard on data-dependent expression u0 < 0
|
||||
"torch._check(i >= 0)",
|
||||
],
|
||||
fixes=[],
|
||||
)
|
||||
|
||||
class cf_stacklist(torch.nn.Module):
|
||||
def forward(self, xs, y, fixes):
|
||||
i = y.item()
|
||||
eval(fixes)
|
||||
# instead of xs[i]
|
||||
return torch.stack(xs, 0).narrow(0, i, 1).squeeze()
|
||||
|
||||
retry_export(
|
||||
cf_stacklist(),
|
||||
([torch.ones(5) * i for i in range(10)], torch.tensor(2)),
|
||||
fixes=[
|
||||
# Could not guard on data-dependent expression u0 < 0
|
||||
"torch._check(i >= 0)",
|
||||
],
|
||||
fixes=[],
|
||||
)
|
||||
|
||||
class cf_tensorsplit(torch.nn.Module):
|
||||
@ -6166,7 +6159,12 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
class cf_stacklist(torch.nn.Module):
|
||||
def forward(self, xs, y):
|
||||
# y.item() is not a local, so we can't suggest a fix
|
||||
return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze()
|
||||
if y.item() < 0:
|
||||
return (
|
||||
torch.stack(xs, 0).narrow(0, y.item() + xs.size(), 1).squeeze()
|
||||
)
|
||||
else:
|
||||
return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze()
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
error_type,
|
||||
@ -6196,7 +6194,18 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
def forward(self, xs, y):
|
||||
box = Box(y.item())
|
||||
# box.content is not a local, so we can't suggest a fix
|
||||
return torch.stack(xs, 0).narrow(0, box.content, 1).squeeze()
|
||||
if box.content < 0:
|
||||
return (
|
||||
torch.stack(xs, 0)
|
||||
.narrow(0, box.content + xs.size(), 1)
|
||||
.squeeze()
|
||||
)
|
||||
else:
|
||||
return (
|
||||
torch.stack(xs, 0)
|
||||
.narrow(0, box.content + xs.size(), 1)
|
||||
.squeeze()
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
error_type,
|
||||
|
||||
@ -4401,6 +4401,57 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
|
||||
|
||||
self.assertEqual(compiled(a, b), func(a, b))
|
||||
|
||||
@fresh_cache()
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
def test_narrow_unbacked_start(self):
|
||||
def func(x, start, length):
|
||||
# unbacked start
|
||||
u0 = start.item()
|
||||
return torch.narrow(x, 0, u0, length)
|
||||
|
||||
compiled_func = torch.compile(func, fullgraph=True, backend="inductor")
|
||||
|
||||
x = torch.tensor([1, 2, 3, 4, 5, 6])
|
||||
|
||||
# Test cases: (start, length)
|
||||
test_cases = [
|
||||
# Negative starts
|
||||
(-2, 2), # Start from second-to-last element
|
||||
(-1, 1), # Start from last element
|
||||
(-3, 3), # Start from third-to-last element
|
||||
(-6, 2), # Start from beginning (negative)
|
||||
(-4, 1), # Start from fourth-to-last element
|
||||
# Positive starts
|
||||
(0, 2), # Start from beginning
|
||||
(1, 3), # Start from second element
|
||||
(2, 2), # Start from third element
|
||||
(4, 2), # Start near end
|
||||
# Edge cases
|
||||
(0, 6), # Full tensor
|
||||
(0, 1), # Single element from start
|
||||
(5, 1), # Single element from end
|
||||
]
|
||||
|
||||
for start_val, length in test_cases:
|
||||
with self.subTest(start=start_val, length=length):
|
||||
start = torch.tensor([start_val])
|
||||
|
||||
# Test with compiled function
|
||||
result_compiled = compiled_func(x, start, length)
|
||||
|
||||
# Test with eager function (expected behavior)
|
||||
result_eager = func(x, start, length)
|
||||
|
||||
# Compare results
|
||||
self.assertEqual(result_compiled, result_eager)
|
||||
|
||||
@fresh_cache()
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
@torch._inductor.config.patch("cpp_wrapper", True)
|
||||
def test_narrow_unbacked_start_cpp_wrapper(self):
|
||||
"""Test narrow with unbacked start with cpp_wrapper"""
|
||||
self.test_narrow_unbacked_start()
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestUnbacked)
|
||||
|
||||
|
||||
184
test/test_fx.py
184
test/test_fx.py
@ -72,9 +72,16 @@ from torch.testing._internal.common_utils import (
|
||||
IS_WINDOWS,
|
||||
run_tests,
|
||||
skipIfTorchDynamo,
|
||||
skipIfRocm,
|
||||
)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from torch.profiler import profile, ProfilerActivity
|
||||
from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace
|
||||
from torch.autograd.profiler_util import _canonicalize_profiler_events
|
||||
|
||||
try:
|
||||
from torchvision import models as torchvision_models
|
||||
|
||||
@ -201,6 +208,36 @@ def side_effect_func(x: torch.Tensor):
|
||||
print(x)
|
||||
|
||||
|
||||
def _enrich_profiler_traces(prof):
|
||||
"""
|
||||
Helper function to extract and augment profiler events with stack traces.
|
||||
|
||||
Args:
|
||||
prof: A torch.profiler.profile object
|
||||
|
||||
Returns:
|
||||
A string representing enriched events
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f:
|
||||
trace_file = f.name
|
||||
prof.export_chrome_trace(trace_file)
|
||||
|
||||
with open(trace_file) as f:
|
||||
trace_data = json.load(f)
|
||||
|
||||
map_recorded_events_to_aten_ops_with_stack_trace(
|
||||
trace_data
|
||||
)
|
||||
|
||||
events = []
|
||||
for event in trace_data["traceEvents"]:
|
||||
if "args" in event and "stack_trace" in event["args"]:
|
||||
events.append(event)
|
||||
|
||||
actual_traces = _canonicalize_profiler_events(events)
|
||||
return actual_traces
|
||||
|
||||
|
||||
class TestFX(JitTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
@ -4212,6 +4249,153 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
|
||||
# recover mutable checking flag
|
||||
torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
||||
@skipIfRocm
|
||||
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
|
||||
def test_profiler_stack_trace_augmentation(self):
|
||||
"""
|
||||
Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
|
||||
augments profiler events with stack traces from FX metadata registry.
|
||||
"""
|
||||
|
||||
# Simple test model
|
||||
class TestModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(10, 16)
|
||||
self.relu = torch.nn.ReLU()
|
||||
self.linear2 = torch.nn.Linear(16, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear1(x)
|
||||
x = self.relu(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
|
||||
model = TestModel().cuda()
|
||||
|
||||
# Compile the model
|
||||
compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True)
|
||||
|
||||
# Warmup
|
||||
for _ in range(3):
|
||||
_ = compiled_model(torch.randn(10, 10, device="cuda"))
|
||||
|
||||
# Profile with the compiled model
|
||||
with profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
) as prof:
|
||||
result = compiled_model(torch.randn(10, 10, device="cuda"))
|
||||
|
||||
actual_traces = _enrich_profiler_traces(prof)
|
||||
|
||||
self.assertExpectedInline(actual_traces, """\
|
||||
event=aten::t node=t stack_trace=x = self.linear1(x)
|
||||
event=aten::transpose node=t stack_trace=x = self.linear1(x)
|
||||
event=aten::as_strided node=t stack_trace=x = self.linear1(x)
|
||||
event=aten::addmm node=addmm stack_trace=x = self.linear1(x)
|
||||
event=cudaLaunchKernel node=addmm stack_trace=x = self.linear1(x)
|
||||
event=aten::relu node=relu stack_trace=x = self.relu(x)
|
||||
event=aten::clamp_min node=relu stack_trace=x = self.relu(x)
|
||||
event=cudaLaunchKernel node=relu stack_trace=x = self.relu(x)
|
||||
event=aten::t node=t_1 stack_trace=x = self.linear2(x)
|
||||
event=aten::transpose node=t_1 stack_trace=x = self.linear2(x)
|
||||
event=aten::as_strided node=t_1 stack_trace=x = self.linear2(x)
|
||||
event=aten::addmm node=addmm_1 stack_trace=x = self.linear2(x)
|
||||
event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)"""
|
||||
)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
||||
@skipIfRocm
|
||||
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
|
||||
def test_profiler_multiple_modules(self):
|
||||
"""
|
||||
Test that multiple compiled modules under the same profiler session
|
||||
have their events correctly augmented with stack traces.
|
||||
"""
|
||||
|
||||
class ModelA(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x + 1
|
||||
|
||||
class ModelB(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x - 1
|
||||
|
||||
model_a = ModelA().cuda()
|
||||
model_b = ModelB().cuda()
|
||||
|
||||
# Compile both models
|
||||
compiled_a = torch.compile(model_a, backend="aot_eager", fullgraph=True)
|
||||
compiled_b = torch.compile(model_b, backend="aot_eager", fullgraph=True)
|
||||
|
||||
# Warmup
|
||||
for _ in range(3):
|
||||
_ = compiled_a(torch.randn(10, 10, device="cuda"))
|
||||
_ = compiled_b(torch.randn(1, 3, 8, 8, device="cuda"))
|
||||
|
||||
# Profile both models in the same session
|
||||
with profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
) as prof:
|
||||
result_a = compiled_a(torch.randn(10, 10, device="cuda"))
|
||||
result_b = compiled_b(torch.randn(1, 3, 8, 8, device="cuda"))
|
||||
|
||||
actual_traces = _enrich_profiler_traces(prof)
|
||||
self.assertExpectedInline(actual_traces, """\
|
||||
event=aten::add node=add stack_trace=return x + 1
|
||||
event=cudaLaunchKernel node=add stack_trace=return x + 1
|
||||
event=aten::sub node=sub stack_trace=return x - 1
|
||||
event=cudaLaunchKernel node=sub stack_trace=return x - 1"""
|
||||
)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
||||
@skipIfRocm
|
||||
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
|
||||
def test_profiler_nested_graph_modules(self):
|
||||
"""
|
||||
Test that nested graph modules (e.g., graph modules calling subgraphs)
|
||||
have their events correctly augmented with stack traces.
|
||||
"""
|
||||
|
||||
# Model with nested structure
|
||||
class Mod(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.c = 5
|
||||
|
||||
@torch.compiler.nested_compile_region
|
||||
def forward(self, x, y):
|
||||
m = torch.mul(x, y)
|
||||
s = m.sin()
|
||||
a = s + self.c
|
||||
return a
|
||||
|
||||
model = Mod().cuda()
|
||||
|
||||
# Compile the model (this may create nested graph modules)
|
||||
compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True)
|
||||
|
||||
# Warmup
|
||||
for _ in range(3):
|
||||
_ = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda"))
|
||||
|
||||
# Profile
|
||||
with profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
) as prof:
|
||||
result = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda"))
|
||||
|
||||
actual_traces = _enrich_profiler_traces(prof)
|
||||
self.assertExpectedInline(actual_traces, """\
|
||||
event=aten::mul node=mul stack_trace=m = torch.mul(x, y)
|
||||
event=cudaLaunchKernel node=mul stack_trace=m = torch.mul(x, y)
|
||||
event=aten::sin node=sin stack_trace=s = m.sin()
|
||||
event=cudaLaunchKernel node=sin stack_trace=s = m.sin()
|
||||
event=aten::add node=add stack_trace=a = s + self.c
|
||||
event=cudaLaunchKernel node=add stack_trace=a = s + self.c"""
|
||||
)
|
||||
|
||||
|
||||
def run_getitem_target():
|
||||
from torch.fx._symbolic_trace import _wrapped_methods_to_patch
|
||||
|
||||
@ -359,6 +359,29 @@ class TestMatmulCuda(InductorTestCase):
|
||||
self.assertEqual(agrad, a.grad)
|
||||
self.assertEqual(bgrad, b.grad)
|
||||
|
||||
@onlyCUDA
|
||||
@skipIfRocm
|
||||
@dtypes(torch.half, torch.bfloat16)
|
||||
@unittest.skipIf(not SM100OrLater, "cuBLAS integration for batch invariance is only on Blackwell")
|
||||
@serialTest()
|
||||
def test_cublas_batch_invariance_blackwell(self, device, dtype):
|
||||
orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
|
||||
orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
|
||||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (False, False)
|
||||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (False, False)
|
||||
with blas_library_context('cublaslt'):
|
||||
N = 2048
|
||||
K = 6144
|
||||
M_max = 32
|
||||
x = torch.randn(M_max, K, device="cuda", dtype=torch.bfloat16)
|
||||
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16).t()
|
||||
full = x @ w
|
||||
xx = x[:1]
|
||||
out = xx @ w
|
||||
self.assertEqual(full[:1], out, atol=0., rtol=0.)
|
||||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16
|
||||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16
|
||||
|
||||
@unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
|
||||
@parametrize("strided", [False, True])
|
||||
@parametrize("a_row_major", [False, True])
|
||||
|
||||
@ -4470,6 +4470,14 @@ class TestMPS(TestCaseMPS):
|
||||
|
||||
self.assertEqual(out1, out2)
|
||||
|
||||
def test_bce_backward_with_no_reduction_and_one_in_shape(self):
|
||||
# Regression test for https://github.com/pytorch/pytorch/issues/166746
|
||||
output = torch.zeros(3, 2, 1, requires_grad=True, device='mps')
|
||||
target = torch.zeros(3, 2, 1, device='mps')
|
||||
torch.sum(nn.BCELoss(reduction='none')(output, target)).backward()
|
||||
expected_grad = torch.zeros(3, 2, 1, device='mps')
|
||||
self.assertEqual(output.grad, expected_grad)
|
||||
|
||||
def test_cross_entropy_loss(self):
|
||||
# Regression test for https://github.com/pytorch/pytorch/issues/116095
|
||||
loss = nn.CrossEntropyLoss()
|
||||
|
||||
@ -257,34 +257,6 @@ class TestFuzzerCompileIssues(TestCase):
|
||||
out_compiled.sum().backward()
|
||||
print("Compile Success! ✅")
|
||||
|
||||
@pytest.mark.xfail(reason="Issue #163971")
|
||||
def test_fuzzer_issue_163971(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
def foo(arg0):
|
||||
t0 = arg0 # size=(), stride=(), dtype=bfloat16, device=cuda
|
||||
t1 = torch.softmax(
|
||||
t0, dim=0
|
||||
) # size=(), stride=(), dtype=bfloat16, device=cuda
|
||||
t2 = torch.nn.functional.gelu(
|
||||
t1
|
||||
) # size=(), stride=(), dtype=bfloat16, device=cuda
|
||||
t3 = torch.softmax(
|
||||
t2, dim=0
|
||||
) # size=(), stride=(), dtype=bfloat16, device=cuda
|
||||
output = t3
|
||||
return output
|
||||
|
||||
arg0 = torch.rand([], dtype=torch.bfloat16, device="cuda", requires_grad=True)
|
||||
|
||||
out_eager = foo(arg0)
|
||||
out_eager.sum().backward()
|
||||
print("Eager Success! ✅")
|
||||
compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
|
||||
out_compiled = compiled_foo(arg0)
|
||||
out_compiled.sum().backward()
|
||||
print("Compile Success! ✅")
|
||||
|
||||
@pytest.mark.xfail(reason="Issue #164059")
|
||||
def test_fuzzer_issue_164059(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@ -1914,6 +1914,7 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
q, k, v, None, 0.0, is_causal=True))
|
||||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
|
||||
def test_mem_eff_attention_fail_with_batch_size_geq_65536(self):
|
||||
batch_size = 2**16
|
||||
query = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True)
|
||||
@ -1935,6 +1936,7 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
self.assertEqual(value.grad, v_cpu.grad, atol=2e-3, rtol=1e-4)
|
||||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
|
||||
def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self):
|
||||
query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
|
||||
key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
|
||||
@ -1948,6 +1950,7 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
|
||||
@largeTensorTest("15GB", "cuda")
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
|
||||
def test_mem_eff_attention_large_seq_len_uniform_attention(self):
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
|
||||
@ -2063,7 +2063,8 @@ class PythonWrapperCodegen(CodeGen):
|
||||
neg = self.codegen_sizevar(
|
||||
sympy.Max(0, sympy.Min(x + node.size, node.size))
|
||||
)
|
||||
return f"{pos} if {x} >= 0 else {neg}"
|
||||
x_cond = self.codegen_sizevar(x)
|
||||
return f"{pos} if {x_cond} >= 0 else {neg}"
|
||||
|
||||
def codegen_with_step(start_var, end_var, step):
|
||||
if step == 1:
|
||||
|
||||
@ -1224,3 +1224,43 @@ def _build_table(
|
||||
f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}"
|
||||
)
|
||||
return "".join(result)
|
||||
|
||||
|
||||
# Collect all events with stack traces and format them canonically
|
||||
def _canonicalize_profiler_events(events):
    """
    Render profiler trace events as a deterministic, one-line-per-event string
    for testing.

    Every event contributes one line of the form
    ``event=<name> node=<node_name> stack_trace=<last frame line>``. The event
    name is truncated to 20 characters and only the last non-blank line of the
    stack trace is kept. Lines are ordered by the event's ``ts`` value (events
    without ``ts`` sort as 0); events with equal ``ts`` keep their input order.

    Args:
        events: iterable of trace event dicts. ``name``, ``ts`` and ``args``
            are all optional; a missing or ``None`` ``args`` is treated as
            carrying no node name and no stack trace.

    Returns:
        The formatted lines joined with newlines ("" when there are no events).
    """
    canonical = []

    for event in events:
        # Not every trace event carries "args"; treat that as "no metadata"
        # instead of failing with a KeyError.
        args = event.get("args") or {}
        node_name = args.get("node_name", "")

        # Keep only the last non-empty line of the stack trace
        stripped = (s.strip() for s in args.get("stack_trace", "").split("\n"))
        trace_lines = [s for s in stripped if s]
        last_frame = trace_lines[-1] if trace_lines else ""

        canonical.append(
            {
                "event_name": event.get("name", "")[:20],
                "node_name": node_name,
                "stack_trace": last_frame,
                "start_time": event.get("ts", 0),
            }
        )

    # Sort by start time for deterministic ordering (the sort is stable, so
    # events sharing a timestamp stay in input order)
    canonical.sort(key=lambda x: x["start_time"])

    # Format as a string
    return "\n".join(
        f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
        for evt in canonical
    )
|
||||
|
||||
@ -547,6 +547,7 @@ def rebind_unbacked(
|
||||
assert shape_env is not None
|
||||
for raw_u0, path in bindings.items():
|
||||
u1 = pytree.key_get(result, path)
|
||||
|
||||
# Sometimes, things that were previously unbacked bindings become constants.
|
||||
# There are two situations this can happen.
|
||||
#
|
||||
@ -602,7 +603,23 @@ def rebind_unbacked(
|
||||
if u1.node.hint is not None:
|
||||
continue
|
||||
|
||||
raw_u1 = u1.node.expr
|
||||
# unbacked symbol bindings might be replaced with other backed or
|
||||
# unbacked replacements.
|
||||
#
|
||||
# Example:
|
||||
# u = x.item()
|
||||
# torch._check(u == 5)
|
||||
#
|
||||
# The safest approach is to retrieve raw_u1 from u1.node._expr
|
||||
# and perform the rebinding on the original unbacked symbol,
|
||||
# even if it’s no longer directly referenced.
|
||||
#
|
||||
# In other words, we should always rebind the original symbol
|
||||
# before any replacements are applied.
|
||||
# u0 -> u0 == s1
|
||||
raw_u1 = u1.node._expr
|
||||
|
||||
# TODO Do we still need this logic below?
|
||||
# Simplify SymBool binding
|
||||
if (
|
||||
isinstance(raw_u1, sympy.Piecewise)
|
||||
|
||||
@ -443,6 +443,7 @@ class CodeGen:
|
||||
colored: bool = False,
|
||||
# Render each argument on its own line
|
||||
expanded_def: bool = False,
|
||||
record_func: bool = False,
|
||||
) -> PythonCode:
|
||||
free_vars: list[str] = []
|
||||
body: list[str] = []
|
||||
@ -817,6 +818,10 @@ class CodeGen:
|
||||
return
|
||||
raise NotImplementedError(f"node: {node.op} {node.target}")
|
||||
|
||||
if record_func:
|
||||
body.append(
|
||||
"_rf = torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##'); _rf.__enter__()\n"
|
||||
)
|
||||
for i, node in enumerate(nodes):
|
||||
# NOTE: emit_node does not emit a string with newline. It depends
|
||||
# on delete_unused_values to append one
|
||||
@ -826,8 +831,22 @@ class CodeGen:
|
||||
# node index, which will be deleted later
|
||||
# after going through _body_transformer
|
||||
body.append(f"# COUNTER: {i}\n")
|
||||
do_record = record_func and node.op in (
|
||||
"call_function",
|
||||
"call_method",
|
||||
"call_module",
|
||||
)
|
||||
if do_record:
|
||||
# The double hash ## convention is used by post-processing to find the fx markers
|
||||
body.append(
|
||||
f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {i} ##'); _rf_{node.name}.__enter__()\n"
|
||||
)
|
||||
emit_node(node)
|
||||
delete_unused_values(node)
|
||||
if do_record:
|
||||
body.append(f"_rf_{node.name}.__exit__(None, None, None)\n")
|
||||
if record_func:
|
||||
body.append("_rf.__exit__(None, None, None)\n")
|
||||
|
||||
if len(body) == 0:
|
||||
# If the Graph has no non-placeholder nodes, no lines for the body
|
||||
@ -1779,6 +1798,7 @@ class Graph:
|
||||
include_device: bool = False,
|
||||
colored: bool = False,
|
||||
expanded_def: bool = False,
|
||||
record_func: bool = False,
|
||||
) -> PythonCode:
|
||||
"""
|
||||
Turn this ``Graph`` into valid Python code.
|
||||
@ -1846,6 +1866,7 @@ class Graph:
|
||||
include_device=include_device,
|
||||
colored=colored,
|
||||
expanded_def=expanded_def,
|
||||
record_func=record_func,
|
||||
)
|
||||
|
||||
def _python_code(
|
||||
@ -1858,6 +1879,7 @@ class Graph:
|
||||
include_device: bool = False,
|
||||
colored: bool = False,
|
||||
expanded_def: bool = False,
|
||||
record_func: bool = False,
|
||||
) -> PythonCode:
|
||||
return self._codegen._gen_python_code(
|
||||
self.nodes,
|
||||
@ -1868,6 +1890,7 @@ class Graph:
|
||||
include_device=include_device,
|
||||
colored=colored,
|
||||
expanded_def=expanded_def,
|
||||
record_func=record_func,
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
||||
@ -861,14 +861,18 @@ class {module_name}(torch.nn.Module):
|
||||
if isinstance(self._graph._codegen, _PyTreeCodeGen):
|
||||
self._in_spec = self._graph._codegen.pytree_info.in_spec
|
||||
self._out_spec = self._graph._codegen.pytree_info.out_spec
|
||||
python_code = self._graph.python_code(root_module="self")
|
||||
|
||||
from torch._dynamo import config as dynamo_config
|
||||
|
||||
python_code = self._graph.python_code(
|
||||
root_module="self", record_func=dynamo_config.enrich_profiler_metadata
|
||||
)
|
||||
self._code = python_code.src
|
||||
self._lineno_map = python_code._lineno_map
|
||||
self._prologue_start = python_code._prologue_start
|
||||
|
||||
cls = type(self)
|
||||
co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
|
||||
from torch._dynamo import config as dynamo_config
|
||||
|
||||
if dynamo_config.enrich_profiler_metadata:
|
||||
# Generate metadata and register for profiler augmentation
|
||||
@ -885,7 +889,6 @@ class {module_name}(torch.nn.Module):
|
||||
# This ensures the same code+metadata always generates the same filename
|
||||
hash_value = _metadata_hash(self._code, node_metadata)
|
||||
file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}"
|
||||
|
||||
filename = f"{file_stem}.py"
|
||||
|
||||
# Only include co_filename to use it directly as the cache key
|
||||
@ -905,6 +908,13 @@ class {module_name}(torch.nn.Module):
|
||||
|
||||
_register_fx_metadata(filename, metadata)
|
||||
|
||||
# Replace the placeholder in generated code with actual filename
|
||||
# The double hash ## convention is used by post-processing to find the fx markers
|
||||
self._code = self._code.replace(
|
||||
"torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##')",
|
||||
f"torch._C._profiler._RecordFunctionFast('## {filename} ##')",
|
||||
)
|
||||
|
||||
cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)
|
||||
|
||||
# Determine whether this class explicitly defines a __call__ implementation
|
||||
|
||||
@ -4,7 +4,7 @@ import operator
|
||||
import re
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Any, Literal, Optional, TYPE_CHECKING
|
||||
|
||||
from torch.autograd.profiler import profile
|
||||
from torch.profiler import DeviceType
|
||||
@ -400,3 +400,170 @@ def _init_for_cuda_graphs() -> None:
|
||||
|
||||
with profile():
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimelineEvent:
|
||||
"""Represents an event in the profiler timeline."""
|
||||
|
||||
timestamp: int
|
||||
event_type: Literal["start", "end", "regular"]
|
||||
marker_type: Optional[Literal["filename", "node"]]
|
||||
identifier: Optional[str | int]
|
||||
event: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextStackEntry:
|
||||
"""Represents a context (filename or node) in the stack."""
|
||||
|
||||
context_type: Literal["filename", "node"]
|
||||
identifier: str | int
|
||||
metadata: Optional[dict]
|
||||
tid: Optional[int] = None # Thread ID associated with this context
|
||||
|
||||
|
||||
def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
|
||||
"""
|
||||
Maps recorded profiler events to their corresponding fx nodes and adds stack traces.
|
||||
|
||||
Builds a timeline of all events (regular ops and FX markers for filenames/nodes),
|
||||
sorts by timestamp, then processes chronologically while maintaining a context stack of active
|
||||
filename/node scopes. Regular events are augmented with stack traces and node names from the
|
||||
innermost active context. Runtime is O(n log n) for n events.
|
||||
|
||||
Args:
|
||||
traced_data: Json of profiler events from Chrome trace
|
||||
|
||||
Returns:
|
||||
Dict mapping recorded event names to their aten operations with added stack traces
|
||||
"""
|
||||
from torch.fx.traceback import _FX_METADATA_REGISTRY
|
||||
|
||||
trace_events = traced_data.get("traceEvents", [])
|
||||
|
||||
# Create event timeline
|
||||
event_timeline: list[TimelineEvent] = []
|
||||
|
||||
def is_fx_marker_event(event):
|
||||
return (
|
||||
event.get("cat") == "cpu_op"
|
||||
and event.get("name", "").startswith("## ")
|
||||
and event.get("name", "").endswith(" ##")
|
||||
)
|
||||
|
||||
def append_fx_marker_event(event_type, identifier, event):
|
||||
start_ts = event["ts"]
|
||||
end_ts = start_ts + event["dur"]
|
||||
event_timeline.append(
|
||||
TimelineEvent(start_ts, "start", event_type, identifier, event)
|
||||
)
|
||||
event_timeline.append(
|
||||
TimelineEvent(end_ts, "end", event_type, identifier, event)
|
||||
)
|
||||
|
||||
for event in trace_events:
|
||||
if "ts" not in event or "dur" not in event:
|
||||
continue
|
||||
|
||||
if is_fx_marker_event(event):
|
||||
content = event["name"][3:-3]
|
||||
|
||||
if content.endswith(".py"):
|
||||
append_fx_marker_event("filename", content, event)
|
||||
else:
|
||||
try:
|
||||
node_index = int(content)
|
||||
except ValueError:
|
||||
pass
|
||||
append_fx_marker_event("node", node_index, event) # type: ignore[possibly-undefined]
|
||||
|
||||
else:
|
||||
# Regular event that needs augmentation
|
||||
start_ts = event["ts"]
|
||||
event_timeline.append(TimelineEvent(start_ts, "regular", None, None, event))
|
||||
|
||||
# Sort by timestamp
|
||||
event_timeline.sort(key=lambda x: x.timestamp)
|
||||
|
||||
# Process events in chronological order with a stack
|
||||
context_stack: list[ContextStackEntry] = []
|
||||
|
||||
# Invariant: all start event has a corresponding end event
|
||||
for timeline_event in event_timeline:
|
||||
match timeline_event.event_type:
|
||||
case "start":
|
||||
assert timeline_event.identifier is not None
|
||||
|
||||
if timeline_event.marker_type == "filename":
|
||||
assert isinstance(timeline_event.identifier, str)
|
||||
# Push filename context - query metadata registry on-demand
|
||||
metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier)
|
||||
tid = timeline_event.event.get("tid")
|
||||
context_stack.append(
|
||||
ContextStackEntry(
|
||||
"filename", timeline_event.identifier, metadata, tid
|
||||
)
|
||||
)
|
||||
elif timeline_event.marker_type == "node":
|
||||
# Find the current filename from stack
|
||||
current_file_metadata = None
|
||||
tid = timeline_event.event.get("tid")
|
||||
for ctx_entry in reversed(context_stack):
|
||||
if (
|
||||
ctx_entry.context_type == "filename"
|
||||
and ctx_entry.tid == tid
|
||||
):
|
||||
current_file_metadata = ctx_entry.metadata
|
||||
break
|
||||
|
||||
if current_file_metadata:
|
||||
node_metadata = current_file_metadata.get("node_metadata", {})
|
||||
if timeline_event.identifier in node_metadata:
|
||||
node_meta: Optional[dict] = node_metadata[
|
||||
timeline_event.identifier
|
||||
]
|
||||
context_stack.append(
|
||||
ContextStackEntry(
|
||||
"node", timeline_event.identifier, node_meta, tid
|
||||
)
|
||||
)
|
||||
|
||||
case "end":
|
||||
# Pop from stack - search backwards to find matching context
|
||||
for i in range(len(context_stack) - 1, -1, -1):
|
||||
ctx_entry = context_stack[i]
|
||||
if (
|
||||
timeline_event.marker_type == ctx_entry.context_type
|
||||
and timeline_event.identifier == ctx_entry.identifier
|
||||
):
|
||||
context_stack.pop(i)
|
||||
break
|
||||
|
||||
case "regular":
|
||||
# Apply metadata from current context stack
|
||||
# Find the most specific context (node takes precedence over filename)
|
||||
# Only augment events with the same tid as the file/node event matched
|
||||
current_stack_trace = None
|
||||
current_node_name = None
|
||||
event_tid = timeline_event.event.get("tid")
|
||||
|
||||
for ctx_entry in reversed(context_stack):
|
||||
# Only apply metadata from contexts with matching tid
|
||||
if ctx_entry.tid == event_tid:
|
||||
if ctx_entry.context_type == "node" and ctx_entry.metadata:
|
||||
current_stack_trace = ctx_entry.metadata.get(
|
||||
"stack_trace", "No model stack trace available"
|
||||
)
|
||||
current_node_name = ctx_entry.metadata.get("name", "")
|
||||
# Do we want to only attach the stack trace of the lowest node or stack trace of all nodes
|
||||
# if nodes are nested, e.g. in nested graph modules
|
||||
break
|
||||
|
||||
# Augment the event
|
||||
if current_stack_trace or current_node_name:
|
||||
args = timeline_event.event.setdefault("args", {})
|
||||
if current_stack_trace:
|
||||
args["stack_trace"] = current_stack_trace
|
||||
if current_node_name:
|
||||
args["node_name"] = current_node_name
|
||||
|
||||
@ -306,6 +306,24 @@ class PythonPrinter(ExprPrinter):
|
||||
raise TypeError("ndigits must be an instance of sympy.Integer")
|
||||
return f"round({self._print(number)}, {ndigits})"
|
||||
|
||||
def _print_Piecewise(self, expr: sympy.Expr) -> str:
    """
    Print a Piecewise as nested Python conditional expressions.

    Piecewise((e1, c1), (e2, c2), ..., (eN, cN)) is rendered as
    ``(e1 if c1 else (e2 if c2 else (... else eN)))``. The pairs are folded
    from the last one backwards, so the final branch is emitted bare and
    acts as the fallback. An empty result prints as "0".
    """
    rendered: Optional[str] = None
    for value, condition in reversed(expr.args):
        value_src = self._print(value)
        if condition == True:  # noqa: E712
            # Default case: whatever followed it can never be selected
            rendered = value_src
            continue
        condition_src = self._print(condition)
        if rendered is None:
            # Innermost branch with no explicit default: use it as fallback
            rendered = value_src
        else:
            rendered = f"({value_src} if {condition_src} else {rendered})"
    return rendered or "0"
|
||||
|
||||
|
||||
class CppPrinter(ExprPrinter):
|
||||
def _print_Integer(self, expr: sympy.Expr) -> str:
|
||||
@ -327,6 +345,24 @@ class CppPrinter(ExprPrinter):
|
||||
)
|
||||
return f"{c} ? {p} : {q}"
|
||||
|
||||
def _print_Piecewise(self, expr: sympy.Expr) -> str:
    """
    Print a Piecewise as a chain of C++ ternary operators.

    Piecewise((e1, c1), (e2, c2), ..., (eN, cN)) is rendered as
    ``(c1 ? e1 : c2 ? e2 : ... : eN)``. The pairs are folded from the last
    one backwards, so the final branch is emitted without a condition and
    acts as the fallback. Each operand goes through ``self.parenthesize`` and
    the whole chain is wrapped in one pair of parentheses. An empty result
    prints as "0".
    """
    operand_level = PRECEDENCE["Atom"] - 0.5
    chain: Optional[str] = None
    for value, condition in reversed(expr.args):
        value_src = self.parenthesize(value, operand_level)
        if condition == True:  # noqa: E712
            # Default case: whatever followed it can never be selected
            chain = value_src
            continue
        condition_src = self.parenthesize(condition, operand_level)
        if chain is None:
            # Innermost branch with no explicit default: use it as fallback
            chain = value_src
        else:
            chain = f"{condition_src} ? {value_src} : {chain}"
    if not chain:
        return "0"
    return f"({chain})"
|
||||
|
||||
def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
|
||||
x, div, mod = expr.args
|
||||
x = self.doprint(x)
|
||||
|
||||
Reference in New Issue
Block a user