mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-07 10:01:39 +08:00

Compare commits
2 Commits
annotate_a ... gh/mlazos/

| Author | SHA1 | Date |
|---|---|---|
| | 440a72e12f | |
| | 6badbfdc3e | |
@@ -7,13 +7,13 @@ ENV LC_ALL en_US.UTF-8
ENV LANG en_US.UTF-8
ENV LANGUAGE en_US.UTF-8

ARG DEVTOOLSET_VERSION=13
ARG DEVTOOLSET_VERSION=11

RUN yum -y update
RUN yum -y install epel-release
# install glibc-langpack-en to make sure the en_US.UTF-8 locale is available
RUN yum -y install glibc-langpack-en
RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-toolchain
# Just add everything as a safe.directory for git since these will be used in multiple places with git
RUN git config --global --add safe.directory '*'
ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
@@ -41,7 +41,6 @@ RUN bash ./install_conda.sh && rm install_conda.sh
# Install CUDA
FROM base as cuda
ARG CUDA_VERSION=12.6
ARG DEVTOOLSET_VERSION=13
RUN rm -rf /usr/local/cuda-*
ADD ./common/install_cuda.sh install_cuda.sh
COPY ./common/install_nccl.sh install_nccl.sh
@@ -51,8 +50,7 @@ ENV CUDA_HOME=/usr/local/cuda-${CUDA_VERSION}
# Preserve CUDA_VERSION for the builds
ENV CUDA_VERSION=${CUDA_VERSION}
# Make things in our path by default
ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:$PATH

FROM cuda as cuda12.6
RUN bash ./install_cuda.sh 12.6
@@ -70,22 +68,8 @@ FROM cuda as cuda13.0
RUN bash ./install_cuda.sh 13.0
ENV DESIRED_CUDA=13.0

FROM ${ROCM_IMAGE} as rocm_base
ARG DEVTOOLSET_VERSION=13
ENV LC_ALL en_US.UTF-8
ENV LANG en_US.UTF-8
ENV LANGUAGE en_US.UTF-8
# Install devtoolset on ROCm base image
RUN yum -y update && \
    yum -y install epel-release && \
    yum -y install glibc-langpack-en && \
    yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
RUN git config --global --add safe.directory '*'
ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH

FROM rocm_base as rocm
FROM ${ROCM_IMAGE} as rocm
ARG PYTORCH_ROCM_ARCH
ARG DEVTOOLSET_VERSION=13
ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}
ADD ./common/install_mkl.sh install_mkl.sh
RUN bash ./install_mkl.sh && rm install_mkl.sh
@@ -104,7 +88,6 @@ COPY --from=cuda13.0 /usr/local/cuda-13.0 /usr/local/cuda-13.0

# Final step
FROM ${BASE_TARGET} as final
ARG DEVTOOLSET_VERSION=13
COPY --from=openssl /opt/openssl /opt/openssl
COPY --from=patchelf /patchelf /usr/local/bin/patchelf
COPY --from=conda /opt/conda /opt/conda
@@ -63,7 +63,7 @@ docker build \
    --target final \
    --progress plain \
    --build-arg "BASE_TARGET=${BASE_TARGET}" \
    --build-arg "DEVTOOLSET_VERSION=13" \
    --build-arg "DEVTOOLSET_VERSION=11" \
    ${EXTRA_BUILD_ARGS} \
    -t ${tmp_tag} \
    $@ \
@@ -388,7 +388,6 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
#ifndef USE_ROCM
  at::Half halpha;
  at::Half hbeta;
  uint32_t mask = -1;
#endif
  void * alpha_ptr = &alpha;
  void * beta_ptr = &beta;
@@ -428,7 +427,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
  auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
  if (fp16_reduction !=
      at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
    mask =
    uint32_t mask =
        fp16_reduction ==
            at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
        ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
@@ -445,7 +444,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
  auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
  if (bf16_reduction !=
      at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
    mask =
    uint32_t mask =
        bf16_reduction ==
            at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
        ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
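The reduction options above are driven by user-facing matmul switches. As a point of reference, the test later in this compare toggles them like this; a minimal sketch (the tuple form mirrors that test, not a documented contract):

import torch

# Forbid reduced-precision and split-K reductions for fp16/bf16 matmuls, so
# cuBLASLt must pick a deterministic, full-precision reduction scheme.
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (False, False)
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (False, False)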
@@ -512,41 +511,17 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
  cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  int returnedResult = 0;
  // on Blackwell+, we fake a n > 1 matmul when querying heuristics
  // to prevent cuBLASLt from dispatching to a GEMV kernel for batch-invariance
#ifndef USE_ROCM
  const bool lie_to_cublaslt = mask == CUBLASLT_REDUCTION_SCHEME_NONE && n == 1 && at::cuda::getCurrentDeviceProperties()->major >= 10;
#else
  const bool lie_to_cublaslt = false;
#endif
  if (lie_to_cublaslt) {
    CuBlasLtMatrixLayout FakeBdesc(abType, k, 2, ldb, opb == CUBLAS_OP_T);
    CuBlasLtMatrixLayout FakeCdesc(cType, m, 2, ldc);

    TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
        ltHandle,
        computeDesc.descriptor(),
        Adesc.descriptor(),
        FakeBdesc.descriptor(),
        FakeCdesc.descriptor(),
        FakeCdesc.descriptor(),
        preference.descriptor(),
        1,
        &heuristicResult,
        &returnedResult));
  } else {
    TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
        ltHandle,
        computeDesc.descriptor(),
        Adesc.descriptor(),
        Bdesc.descriptor(),
        Cdesc.descriptor(),
        Cdesc.descriptor(),
        preference.descriptor(),
        1,
        &heuristicResult,
        &returnedResult));
  }
  TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
      ltHandle,
      computeDesc.descriptor(),
      Adesc.descriptor(),
      Bdesc.descriptor(),
      Cdesc.descriptor(),
      Cdesc.descriptor(),
      preference.descriptor(),
      1,
      &heuristicResult,
      &returnedResult));
  if (returnedResult == 0) {
    cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED;
  }
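What "batch invariance" buys here, in a minimal sketch: row 0 of a larger matmul should be bitwise identical to the same row computed as an m = 1 problem, which cuBLASLt would otherwise route to a GEMV kernel with a different reduction order. The shapes below are illustrative, not taken from the hunk:

import torch

x = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16)
w = torch.randn(64, 16, device="cuda", dtype=torch.bfloat16)
full = x @ w          # m = 32 GEMM
single = x[:1] @ w    # m = 1 problem that could dispatch to a GEMV kernel
# batch invariance holds when both paths reduce in the same order:
print(torch.equal(full[:1], single))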
@@ -24,13 +24,7 @@ namespace detail {

// radix_sort_pairs doesn't interact with value_t other than to copy
// the data, so we can save template instantiations by reinterpreting
// it as an opaque type.
// We use native integer types for 1/2/4/8-byte values to reduce
// register usage in CUDA kernels. For sizes > 8 fall back to char array.
template <int N> struct alignas(N) OpaqueType { char data[N]; };
template <> struct alignas(1) OpaqueType<1> { uint8_t data; };
template <> struct alignas(2) OpaqueType<2> { uint16_t data; };
template <> struct alignas(4) OpaqueType<4> { uint32_t data; };
template <> struct alignas(8) OpaqueType<8> { uint64_t data; };

template<typename key_t, int value_size>
void radix_sort_pairs_impl(
@@ -1,6 +1,5 @@
#include <ATen/core/ATen_fwd.h>
#include <c10/core/ScalarType.h>
#include <c10/core/SymInt.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
@@ -1711,37 +1710,11 @@ Tensor narrow_symint(
      "], but got ",
      start,
      ")")

  auto cond1 = TORCH_GUARD_OR_FALSE(start.sym_lt(0));
  auto cond2 = TORCH_GUARD_OR_FALSE(start.sym_ge(0));

  if (cond1 || cond2) {
    if (cond1) {
      start = start + cur_size;
    }

    TORCH_SYM_CHECK(
        start.sym_le(cur_size - length),
        "start (",
        start,
        ") + length (",
        length,
        ") exceeds dimension size (",
        cur_size,
        ").");
    return at::slice_symint(self, dim, start, start + length, 1);
  if (start < 0) {
    start = start + cur_size;
  }

  // Unbacked start handling!

  // Bounds check without converting start:
  // - If start < 0: need (start + cur_size) + length <= cur_size, i.e., start +
  //   length <= 0
  // - If start >= 0: need start + length <= cur_size
  auto end = start + length;
  TORCH_SYM_CHECK(
      (start.sym_lt(0).sym_and((end).sym_le(0)))
          .sym_or(start.sym_ge(0).sym_and((end).sym_le(cur_size))),
      start.sym_le(cur_size - length),
      "start (",
      start,
      ") + length (",
@@ -1749,28 +1722,7 @@ Tensor narrow_symint(
      ") exceeds dimension size (",
      cur_size,
      ").");

  if (TORCH_GUARD_OR_FALSE(end.sym_ne(0))) {
    return at::slice_symint(self, dim, start, end, 1);
  } else {
    // Cannot statically determine the condition due to unbacked.
    // This is an interesting situation; when start is negative and
    // start + length == 0, slice and narrow do different things.
    // i.e., x.narrow(0, -2, 2) != x[-2:0]; in that case, we want to
    // pass curr_size instead of 0. Otherwise, they would do the same thing.
    // This says at runtime: if start < 0 and end == 0, then pass curr_size
    // instead of 0.

    auto use_different = start.sym_lt(0).sym_and(end.sym_eq(0)).toSymInt();
    auto result =
        at::slice_symint(self, dim, start, end + use_different * cur_size, 1);

    // Ensure slice allocated unbacked size is specialized to length.
    SymInt new_size = result.sym_size(dim);
    TORCH_SYM_CHECK(new_size.sym_eq(length), "")

    return result;
  }
  return at::slice_symint(self, dim, start, start + length, 1);
}
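The x.narrow(0, -2, 2) != x[-2:0] corner case called out in the removed comment is easy to reproduce in eager mode:

import torch

x = torch.arange(6)
print(x.narrow(0, -2, 2))  # tensor([4, 5]): start -2 wraps to 4, length 2
print(x[-2:0])             # tensor([]): a slice ending at 0 is empty
print(x[-2:6])             # tensor([4, 5]): passing cur_size (6) as the end matches narrow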
// This overload exists purely for XLA, because they wanted to pass in
@@ -1,5 +1,4 @@
#include <c10/core/SymBool.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymNodeImpl.h>

namespace c10 {
@@ -112,17 +111,4 @@ bool SymBool::has_hint() const {
  return toSymNodeImpl()->has_hint();
}

SymInt SymBool::toSymInt() const {
  // If concrete bool, return concrete SymInt
  if (auto ma = maybe_as_bool()) {
    return SymInt(*ma ? 1 : 0);
  }

  // Symbolic case: use sym_ite to convert bool to int (0 or 1)
  auto node = toSymNodeImpl();
  auto one_node = node->wrap_int(1);
  auto zero_node = node->wrap_int(0);
  return SymInt(node->sym_ite(one_node, zero_node));
}

} // namespace c10
@@ -12,8 +12,6 @@

namespace c10 {

class SymInt;

class C10_API SymBool {
 public:
  /*implicit*/ SymBool(bool b) : data_(b) {}
@@ -82,10 +80,6 @@ class C10_API SymBool {
    return toSymNodeImplUnowned()->constant_bool();
  }

  // Convert SymBool to SymInt (0 or 1)
  // This is the C++ equivalent of Python's cast_symbool_to_symint_guardless
  SymInt toSymInt() const;

  bool is_heap_allocated() const {
    return ptr_;
  }
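A rough Python analogue of the toSymInt conversion above (the source comment names cast_symbool_to_symint_guardless): a concrete bool maps straight to 0/1, while a symbolic bool goes through an if-then-else node instead of a guard. torch.sym_ite is a real helper; treating it as equivalent to the C++ sym_ite path here is an assumption:

import torch

def to_sym_int(b):
    if isinstance(b, bool):
        return 1 if b else 0       # concrete case: plain int
    return torch.sym_ite(b, 1, 0)  # symbolic case: select 1 or 0 without guarding

print(to_sym_int(True), to_sym_int(False))  # 1 0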
@@ -5789,229 +5789,6 @@ class NCCLTraceTest(NCCLTraceTestBase):
        else:
            self.assertTrue("duration_ms" not in t["entries"][0])

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("timing_enabled", [True, False])
    def test_fr_record_reset_circular_buffer_full(self, timing_enabled):
        """
        Test that when the circular buffer in entries_ is full and we call reset,
        then fill the buffer with new entries, dump_entries returns only the new
        entries and not the old ones.
        """
        if self.rank == self.MAIN_PROCESS_RANK:
            return

        # Override buffer size to 10 for faster testing
        os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"

        pg = self._create_process_group_nccl()
        if timing_enabled:
            pg._enable_collectives_timing()
        device = self.local_device
        self.set_thread_name("fr_test_thread")
        a = torch.full((3, 4), float(self.rank), device=device)

        # Fill the buffer completely with 10 entries
        for _ in range(10):
            f = pg.allreduce(a)
        f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # Verify buffer is full with 10 entries
        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        self.assertEqual(len(t["entries"]), 10)

        # Now reset the flight recorder
        torch._C._distributed_c10d._reset_fr_recording_nccl()

        # Add new entries after reset - fill the buffer completely again
        for _ in range(10):
            f = pg.allreduce(a)
        f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # Verify we get exactly 10 new entries, not 20
        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        self.assertEqual(len(t["entries"]), 10)

        # Verify all entries have the expected properties (from after reset)
        # After reset, record IDs should start from 0 again
        for i, entry in enumerate(t["entries"]):
            self.assertIn("profiling_name", entry)
            self.assertEqual(entry["profiling_name"], "nccl:all_reduce")
            self.assertIn("record_id", entry)
            # Record IDs should be sequential starting from 0 after reset
            self.assertEqual(entry["record_id"], i)

        dist.destroy_process_group()

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("timing_enabled", [True, False])
    def test_fr_record_reset_partial_overwrite(self, timing_enabled):
        """
        Test that when the circular buffer is full, we reset, and then add fewer
        entries than the buffer size, we only get the new entries.
        This tests that old entries at the end of the circular buffer are properly
        filtered out based on reset_epoch.
        """
        if self.rank == self.MAIN_PROCESS_RANK:
            return

        # Override buffer size to 10 for faster testing
        os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"

        pg = self._create_process_group_nccl()
        if timing_enabled:
            pg._enable_collectives_timing()
        device = self.local_device
        self.set_thread_name("fr_test_thread")
        a = torch.full((3, 4), float(self.rank), device=device)

        # Fill the buffer completely
        for _ in range(10):
            f = pg.allreduce(a)
        f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # Reset the flight recorder
        torch._C._distributed_c10d._reset_fr_recording_nccl()

        # Add only 3 new entries (much less than buffer size)
        for _ in range(3):
            f = pg.allreduce(a)
        f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # Verify we only get the 3 new entries, not 10
        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        self.assertEqual(len(t["entries"]), 3)

        # Verify record IDs start from 0 after reset
        for i, entry in enumerate(t["entries"]):
            self.assertIn("record_id", entry)
            self.assertEqual(entry["record_id"], i)

        dist.destroy_process_group()

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("timing_enabled", [True, False])
    def test_fr_record_reset_wraparound(self, timing_enabled):
        """
        Test that when we reset in the middle of the circular buffer and then
        wrap around, dump_entries correctly returns only entries from the current
        epoch in the correct order.
        """
        if self.rank == self.MAIN_PROCESS_RANK:
            return

        # Override buffer size to 10 for faster testing
        os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"

        pg = self._create_process_group_nccl()
        if timing_enabled:
            pg._enable_collectives_timing()
        device = self.local_device
        self.set_thread_name("fr_test_thread")
        a = torch.full((3, 4), float(self.rank), device=device)

        # Fill half the buffer
        for _ in range(5):
            f = pg.allreduce(a)
        f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # Reset at this point (reset happens at index 5)
        torch._C._distributed_c10d._reset_fr_recording_nccl()

        # Now add 8 entries, which will wrap around
        # (5->9 fills rest of buffer, then 0->2 wraps around)
        for _ in range(8):
            f = pg.allreduce(a)
        f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # Should get exactly 8 entries, properly ordered
        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        self.assertEqual(len(t["entries"]), 8)

        # Entries should be in chronological order
        # The dump_entries() method returns entries from next_ to end, then 0 to next_
        # After filtering old entries, we should have 8 entries in order
        # Verify record IDs start from 0 after reset (id_ is reset in reset_all())
        for i, entry in enumerate(t["entries"]):
            self.assertIn("profiling_name", entry)
            self.assertIn("record_id", entry)
            self.assertEqual(entry["record_id"], i)

        dist.destroy_process_group()

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("timing_enabled", [True, False])
    def test_fr_record_multiple_resets(self, timing_enabled):
        """
        Test multiple consecutive resets to ensure each reset properly increments
        the epoch and filters out entries from previous epochs.
        """
        if self.rank == self.MAIN_PROCESS_RANK:
            return

        # Override buffer size to 10 for faster testing
        os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"

        pg = self._create_process_group_nccl()
        if timing_enabled:
            pg._enable_collectives_timing()
        device = self.local_device
        self.set_thread_name("fr_test_thread")
        a = torch.full((3, 4), float(self.rank), device=device)

        # First batch: 2 entries
        for _ in range(2):
            f = pg.allreduce(a)
        f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # First reset
        torch._C._distributed_c10d._reset_fr_recording_nccl()

        # Second batch: 3 entries
        for _ in range(3):
            f = pg.allreduce(a)
        f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # Second reset
        torch._C._distributed_c10d._reset_fr_recording_nccl()

        # Third batch: 4 entries
        for _ in range(4):
            f = pg.allreduce(a)
        f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # Should only see the last 4 entries
        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        self.assertEqual(len(t["entries"]), 4)

        # Verify record IDs start from 0 after the last reset
        for i, entry in enumerate(t["entries"]):
            self.assertIn("record_id", entry)
            self.assertEqual(entry["record_id"], i)

        dist.destroy_process_group()
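The behavior these removed tests pin down is easy to model. A minimal sketch (names assumed, not the real C++ API) of an epoch-filtered circular buffer: reset() soft-deletes by bumping an epoch and restarting record IDs, and dump() keeps only current-epoch entries in chronological order:

class FlightBuffer:
    """Toy model of the flight recorder's reset-epoch circular buffer."""

    def __init__(self, max_entries):
        self.max_entries = max_entries
        self.entries = []   # slots hold (epoch, record_id, payload)
        self.next = 0       # next slot to write in the circular buffer
        self.id = 0         # per-epoch record id, restarts at 0 on reset
        self.epoch = 0

    def record(self, payload):
        entry = (self.epoch, self.id, payload)
        if len(self.entries) < self.max_entries:
            self.entries.append(entry)
        else:
            self.entries[self.next] = entry
        self.next = (self.next + 1) % self.max_entries
        self.id += 1

    def reset(self):
        self.epoch += 1  # soft delete: old entries stay but stop matching
        self.id = 0

    def dump(self):
        # oldest-first order, keeping only current-epoch entries
        ordered = self.entries[self.next:] + self.entries[:self.next]
        return [e for e in ordered if e[0] == self.epoch]

buf = FlightBuffer(10)
for i in range(10):
    buf.record(f"old{i}")
buf.reset()
for i in range(3):
    buf.record(f"new{i}")
print(buf.dump())  # only the 3 post-reset entries, record_ids 0, 1, 2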

def check_if_test_is_skipped(fn):
    def wrapper(self, *args, **kwargs):
@@ -23,7 +23,7 @@ torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Ca
torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None)
torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node
torch.fx.graph.Graph.print_tabular(self)
torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False, record_func: bool = False) -> torch.fx.graph.PythonCode
torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False) -> torch.fx.graph.PythonCode
torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule')
torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool
torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None
@@ -721,34 +721,6 @@ class TestExport(TestCase):
        )
        self.assertEqual(node.meta["from_node"][-1].graph_id, graph_id)

    def test_annotate_on_assert(self):
        # nodes added in `apply_runtime_assertion_pass` will be annotated
        class M(torch.nn.Module):
            def forward(self, x, y):
                with torch.fx.traceback.annotate({"moo": 0}):
                    x = torch.cat([x, x])
                    b = y.item()
                    torch._check(b >= x.shape[0])
                return x * b

        with torch.fx.traceback.preserve_node_meta():
            ep = torch.export.export(
                M(),
                (torch.randn(3), torch.tensor(6)),
                dynamic_shapes={"x": {0: Dim("b")}, "y": None},
            )

        custom_metadata = torch.fx.traceback._get_custom_metadata(ep.module())
        self.assertExpectedInline(
            str(custom_metadata),
            """\
('call_function', 'cat', {'moo': 0})
('call_function', 'item', {'moo': 0})
('call_function', 'ge_1', {'moo': 0})
('call_function', '_assert_scalar_default', {'moo': 0})
('call_function', 'mul', {'moo': 0})""",
        )

    @requires_gpu
    def test_flex_attention_export(self):
        from torch.nn.attention.flex_attention import create_block_mask, flex_attention
@@ -6121,19 +6093,26 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
        retry_export(
            cf_implicitsize(),
            (torch.tensor(2), torch.randn(10)),
            fixes=[],
            fixes=[
                # Could not guard on data-dependent expression u0 < 0
                "torch._check(i >= 0)",
            ],
        )

        class cf_stacklist(torch.nn.Module):
            def forward(self, xs, y, fixes):
                i = y.item()
                eval(fixes)
                # instead of xs[i]
                return torch.stack(xs, 0).narrow(0, i, 1).squeeze()

        retry_export(
            cf_stacklist(),
            ([torch.ones(5) * i for i in range(10)], torch.tensor(2)),
            fixes=[],
            fixes=[
                # Could not guard on data-dependent expression u0 < 0
                "torch._check(i >= 0)",
            ],
        )

        class cf_tensorsplit(torch.nn.Module):
@@ -6187,12 +6166,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
        class cf_stacklist(torch.nn.Module):
            def forward(self, xs, y):
                # y.item() is not a local, so we can't suggest a fix
                if y.item() < 0:
                    return (
                        torch.stack(xs, 0).narrow(0, y.item() + xs.size(), 1).squeeze()
                    )
                else:
                    return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze()
                return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze()

        with self.assertRaisesRegex(
            error_type,
@@ -6222,18 +6196,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
            def forward(self, xs, y):
                box = Box(y.item())
                # box.content is not a local, so we can't suggest a fix
                if box.content < 0:
                    return (
                        torch.stack(xs, 0)
                        .narrow(0, box.content + xs.size(), 1)
                        .squeeze()
                    )
                else:
                    return (
                        torch.stack(xs, 0)
                        .narrow(0, box.content + xs.size(), 1)
                        .squeeze()
                    )
                return torch.stack(xs, 0).narrow(0, box.content, 1).squeeze()

        with self.assertRaisesRegex(
            error_type,
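Why the re-added torch._check(i >= 0) fix works: it tells the compiler the unbacked .item() value is non-negative, so narrow's negative-start branch no longer needs a data-dependent guard. A standalone sketch of the same pattern:

import torch
torch._dynamo.config.capture_scalar_outputs = True  # allow .item() in the graph

@torch.compile(fullgraph=True)
def f(xs, y):
    i = y.item()                  # unbacked scalar
    torch._check(i >= 0)          # resolves "Could not guard on ... u0 < 0"
    return torch.stack(xs, 0).narrow(0, i, 1).squeeze()

out = f([torch.ones(5) * i for i in range(10)], torch.tensor(2))
print(out)  # tensor([2., 2., 2., 2., 2.])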
@@ -4401,57 +4401,6 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]

        self.assertEqual(compiled(a, b), func(a, b))

    @fresh_cache()
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_narrow_unbacked_start(self):
        def func(x, start, length):
            # unbacked start
            u0 = start.item()
            return torch.narrow(x, 0, u0, length)

        compiled_func = torch.compile(func, fullgraph=True, backend="inductor")

        x = torch.tensor([1, 2, 3, 4, 5, 6])

        # Test cases: (start, length)
        test_cases = [
            # Negative starts
            (-2, 2),  # Start from second-to-last element
            (-1, 1),  # Start from last element
            (-3, 3),  # Start from third-to-last element
            (-6, 2),  # Start from beginning (negative)
            (-4, 1),  # Start from fourth-to-last element
            # Positive starts
            (0, 2),  # Start from beginning
            (1, 3),  # Start from second element
            (2, 2),  # Start from third element
            (4, 2),  # Start near end
            # Edge cases
            (0, 6),  # Full tensor
            (0, 1),  # Single element from start
            (5, 1),  # Single element from end
        ]

        for start_val, length in test_cases:
            with self.subTest(start=start_val, length=length):
                start = torch.tensor([start_val])

                # Test with compiled function
                result_compiled = compiled_func(x, start, length)

                # Test with eager function (expected behavior)
                result_eager = func(x, start, length)

                # Compare results
                self.assertEqual(result_compiled, result_eager)

    @fresh_cache()
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    @torch._inductor.config.patch("cpp_wrapper", True)
    def test_narrow_unbacked_start_cpp_wrapper(self):
        """Test narrow with unbacked start with cpp_wrapper"""
        self.test_narrow_unbacked_start()


instantiate_parametrized_tests(TestUnbacked)
184 test/test_fx.py
@@ -72,16 +72,9 @@ from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    skipIfRocm,
)
from torch.testing._internal.jit_utils import JitTestCase

import json
import tempfile
from torch.profiler import profile, ProfilerActivity
from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace
from torch.autograd.profiler_util import _canonicalize_profiler_events

try:
    from torchvision import models as torchvision_models

@@ -208,36 +201,6 @@ def side_effect_func(x: torch.Tensor):
    print(x)


def _enrich_profiler_traces(prof):
    """
    Helper function to extract and augment profiler events with stack traces.

    Args:
        prof: A torch.profiler.profile object

    Returns:
        A string representing enriched events
    """
    with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f:
        trace_file = f.name
        prof.export_chrome_trace(trace_file)

        with open(trace_file) as f:
            trace_data = json.load(f)

        map_recorded_events_to_aten_ops_with_stack_trace(
            trace_data
        )

        events = []
        for event in trace_data["traceEvents"]:
            if "args" in event and "stack_trace" in event["args"]:
                events.append(event)

        actual_traces = _canonicalize_profiler_events(events)
        return actual_traces


class TestFX(JitTestCase):
    def setUp(self):
        super().setUp()
@@ -4249,153 +4212,6 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
        # recover mutable checking flag
        torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    @skipIfRocm
    @torch._dynamo.config.patch("enrich_profiler_metadata", True)
    def test_profiler_stack_trace_augmentation(self):
        """
        Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
        augments profiler events with stack traces from FX metadata registry.
        """

        # Simple test model
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = torch.nn.Linear(10, 16)
                self.relu = torch.nn.ReLU()
                self.linear2 = torch.nn.Linear(16, 10)

            def forward(self, x):
                x = self.linear1(x)
                x = self.relu(x)
                x = self.linear2(x)
                return x

        model = TestModel().cuda()

        # Compile the model
        compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True)

        # Warmup
        for _ in range(3):
            _ = compiled_model(torch.randn(10, 10, device="cuda"))

        # Profile with the compiled model
        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        ) as prof:
            result = compiled_model(torch.randn(10, 10, device="cuda"))

        actual_traces = _enrich_profiler_traces(prof)

        self.assertExpectedInline(actual_traces, """\
event=aten::t node=t stack_trace=x = self.linear1(x)
event=aten::transpose node=t stack_trace=x = self.linear1(x)
event=aten::as_strided node=t stack_trace=x = self.linear1(x)
event=aten::addmm node=addmm stack_trace=x = self.linear1(x)
event=cudaLaunchKernel node=addmm stack_trace=x = self.linear1(x)
event=aten::relu node=relu stack_trace=x = self.relu(x)
event=aten::clamp_min node=relu stack_trace=x = self.relu(x)
event=cudaLaunchKernel node=relu stack_trace=x = self.relu(x)
event=aten::t node=t_1 stack_trace=x = self.linear2(x)
event=aten::transpose node=t_1 stack_trace=x = self.linear2(x)
event=aten::as_strided node=t_1 stack_trace=x = self.linear2(x)
event=aten::addmm node=addmm_1 stack_trace=x = self.linear2(x)
event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)"""
        )

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    @skipIfRocm
    @torch._dynamo.config.patch("enrich_profiler_metadata", True)
    def test_profiler_multiple_modules(self):
        """
        Test that multiple compiled modules under the same profiler session
        have their events correctly augmented with stack traces.
        """

        class ModelA(torch.nn.Module):
            def forward(self, x):
                return x + 1

        class ModelB(torch.nn.Module):
            def forward(self, x):
                return x - 1

        model_a = ModelA().cuda()
        model_b = ModelB().cuda()

        # Compile both models
        compiled_a = torch.compile(model_a, backend="aot_eager", fullgraph=True)
        compiled_b = torch.compile(model_b, backend="aot_eager", fullgraph=True)

        # Warmup
        for _ in range(3):
            _ = compiled_a(torch.randn(10, 10, device="cuda"))
            _ = compiled_b(torch.randn(1, 3, 8, 8, device="cuda"))

        # Profile both models in the same session
        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        ) as prof:
            result_a = compiled_a(torch.randn(10, 10, device="cuda"))
            result_b = compiled_b(torch.randn(1, 3, 8, 8, device="cuda"))

        actual_traces = _enrich_profiler_traces(prof)
        self.assertExpectedInline(actual_traces, """\
event=aten::add node=add stack_trace=return x + 1
event=cudaLaunchKernel node=add stack_trace=return x + 1
event=aten::sub node=sub stack_trace=return x - 1
event=cudaLaunchKernel node=sub stack_trace=return x - 1"""
        )

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    @skipIfRocm
    @torch._dynamo.config.patch("enrich_profiler_metadata", True)
    def test_profiler_nested_graph_modules(self):
        """
        Test that nested graph modules (e.g., graph modules calling subgraphs)
        have their events correctly augmented with stack traces.
        """

        # Model with nested structure
        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.c = 5

            @torch.compiler.nested_compile_region
            def forward(self, x, y):
                m = torch.mul(x, y)
                s = m.sin()
                a = s + self.c
                return a

        model = Mod().cuda()

        # Compile the model (this may create nested graph modules)
        compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True)

        # Warmup
        for _ in range(3):
            _ = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda"))

        # Profile
        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        ) as prof:
            result = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda"))

        actual_traces = _enrich_profiler_traces(prof)
        self.assertExpectedInline(actual_traces, """\
event=aten::mul node=mul stack_trace=m = torch.mul(x, y)
event=cudaLaunchKernel node=mul stack_trace=m = torch.mul(x, y)
event=aten::sin node=sin stack_trace=s = m.sin()
event=cudaLaunchKernel node=sin stack_trace=s = m.sin()
event=aten::add node=add stack_trace=a = s + self.c
event=cudaLaunchKernel node=add stack_trace=a = s + self.c"""
        )


def run_getitem_target():
    from torch.fx._symbolic_trace import _wrapped_methods_to_patch
@@ -359,29 +359,6 @@ class TestMatmulCuda(InductorTestCase):
        self.assertEqual(agrad, a.grad)
        self.assertEqual(bgrad, b.grad)

    @onlyCUDA
    @skipIfRocm
    @dtypes(torch.half, torch.bfloat16)
    @unittest.skipIf(not SM100OrLater, "cuBLAS integration for batch invariance is only on Blackwell")
    @serialTest()
    def test_cublas_batch_invariance_blackwell(self, device, dtype):
        orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
        orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (False, False)
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (False, False)
        with blas_library_context('cublaslt'):
            N = 2048
            K = 6144
            M_max = 32
            x = torch.randn(M_max, K, device="cuda", dtype=torch.bfloat16)
            w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16).t()
            full = x @ w
            xx = x[:1]
            out = xx @ w
            self.assertEqual(full[:1], out, atol=0., rtol=0.)
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16

    @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
    @parametrize("strided", [False, True])
    @parametrize("a_row_major", [False, True])
@@ -1864,8 +1864,6 @@ class TestFP8Matmul(TestCase):
    ], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
    @parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"])
    def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
        if torch.version.hip and recipe == "nvfp4":
            raise unittest.SkipTest("nvfp4 not supported on ROCm, skipping")
        if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
            raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
@@ -257,6 +257,34 @@ class TestFuzzerCompileIssues(TestCase):
        out_compiled.sum().backward()
        print("Compile Success! ✅")

    @pytest.mark.xfail(reason="Issue #163971")
    def test_fuzzer_issue_163971(self):
        torch.manual_seed(0)

        def foo(arg0):
            t0 = arg0  # size=(), stride=(), dtype=bfloat16, device=cuda
            t1 = torch.softmax(
                t0, dim=0
            )  # size=(), stride=(), dtype=bfloat16, device=cuda
            t2 = torch.nn.functional.gelu(
                t1
            )  # size=(), stride=(), dtype=bfloat16, device=cuda
            t3 = torch.softmax(
                t2, dim=0
            )  # size=(), stride=(), dtype=bfloat16, device=cuda
            output = t3
            return output

        arg0 = torch.rand([], dtype=torch.bfloat16, device="cuda", requires_grad=True)

        out_eager = foo(arg0)
        out_eager.sum().backward()
        print("Eager Success! ✅")
        compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
        out_compiled = compiled_foo(arg0)
        out_compiled.sum().backward()
        print("Compile Success! ✅")

    @pytest.mark.xfail(reason="Issue #164059")
    def test_fuzzer_issue_164059(self):
        torch.manual_seed(0)
@@ -1914,7 +1914,6 @@ class TestSDPAFailureModes(NNTestCase):
                q, k, v, None, 0.0, is_causal=True))

    @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
    def test_mem_eff_attention_fail_with_batch_size_geq_65536(self):
        batch_size = 2**16
        query = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True)
@@ -1936,7 +1935,6 @@ class TestSDPAFailureModes(NNTestCase):
        self.assertEqual(value.grad, v_cpu.grad, atol=2e-3, rtol=1e-4)

    @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
    def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self):
        query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
        key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
@@ -1950,7 +1948,6 @@ class TestSDPAFailureModes(NNTestCase):

    @largeTensorTest("15GB", "cuda")
    @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
    def test_mem_eff_attention_large_seq_len_uniform_attention(self):
        device = torch.device("cuda")
        dtype = torch.bfloat16
@@ -2063,8 +2063,7 @@ class PythonWrapperCodegen(CodeGen):
            neg = self.codegen_sizevar(
                sympy.Max(0, sympy.Min(x + node.size, node.size))
            )
            x_cond = self.codegen_sizevar(x)
            return f"{pos} if {x_cond} >= 0 else {neg}"
            return f"{pos} if {x} >= 0 else {neg}"

        def codegen_with_step(start_var, end_var, step):
            if step == 1:
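For intuition, the negative branch of the emitted expression clamps a wrapped start into [0, size]. A plain-Python sketch (the pos formula is assumed here for illustration; only neg and the conditional appear in the hunk):

def normalize_start(x, size):
    pos = x                            # assumed: the non-negative branch
    neg = max(0, min(x + size, size))  # matches sympy.Max(0, sympy.Min(x + size, size))
    return pos if x >= 0 else neg

print(normalize_start(-2, 6))  # 4: start -2 wraps to 4
print(normalize_start(-9, 6))  # 0: clamped at the low end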
@@ -1224,43 +1224,3 @@ def _build_table(
        f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}"
    )
    return "".join(result)


# Collect all events with stack traces and format them canonically
def _canonicalize_profiler_events(events):
    """
    Extract and format all events with stack traces in a canonical way
    for deterministic testing.
    """
    events_with_traces = []

    for event in events:
        # Extract relevant fields
        event_name = event.get("name", "")
        node_name = event["args"].get("node_name", "")
        stack_trace = event["args"].get("stack_trace", "")

        # Get the last non-empty line of the stack trace
        lines = [s.strip() for s in stack_trace.split("\n") if s.strip()]
        stack_trace = lines[-1] if lines else ""

        events_with_traces.append(
            {
                "event_name": event_name[:20],
                "node_name": node_name,
                "stack_trace": stack_trace,
                "start_time": event.get("ts", 0),
            }
        )

    # Sort by start time for deterministic ordering
    events_with_traces.sort(key=lambda x: x["start_time"])

    # Format as a string
    lines: list[str] = []
    for evt in events_with_traces:
        lines.append(
            f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
        )

    return "\n".join(lines)
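A quick usage sketch (input shape assumed from _enrich_profiler_traces earlier in this compare: chrome-trace event dicts whose "args" carry node_name and stack_trace):

events = [
    {"name": "aten::add", "ts": 2,
     "args": {"node_name": "add", "stack_trace": "File ...\n    return x + 1"}},
    {"name": "aten::mul", "ts": 1,
     "args": {"node_name": "mul", "stack_trace": "File ...\n    m = x * y"}},
]
print(_canonicalize_profiler_events(events))
# event=aten::mul node=mul stack_trace=m = x * y
# event=aten::add node=add stack_trace=return x + 1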
@@ -108,14 +108,12 @@ struct FlightRecorder {
    capture_cpp_stack_ = getCvarBool(
        {"TORCH_FR_CPP_STACK", "TORCH_NCCL_TRACE_CPP_STACK"}, false);
    enabled_ = max_entries_ > 0;
    reset_epoch_start_idx_[0] = 0;
  }
  struct Entry {
    size_t id_; // incremented id in the trace buffer
                // used to figure out where in the circular entries
                // buffer this entry will be located to
                // update state information
    size_t reset_epoch_; // epoch when this entry was created
    size_t pg_id_;
    std::tuple<std::string, std::string> pg_name_; // <group_name, group_desc>
@@ -185,34 +183,11 @@ struct FlightRecorder {
  size_t max_entries_ = 0;
  size_t next_ = 0;
  size_t id_ = 0;
  size_t reset_epoch_ = 0;
  std::unordered_map<size_t, size_t>
      reset_epoch_start_idx_; // maps reset_epoch to the idx where it starts
  std::map<size_t, std::shared_ptr<ProcessGroupStatus>> all_pg_status_;
  std::map<std::tuple<std::string, std::string>, std::vector<uint64_t>>
      pg_name_to_ranks_;
  std::string comm_lib_version_;

  struct TraceIdentifier {
    std::optional<size_t> id;
    std::optional<size_t> reset_epoch;
  };

  TraceIdentifier recordWithResetEnabled(
      size_t pg_id,
      const std::tuple<std::string, std::string>& pg_name,
      size_t collective_seq_id,
      size_t p2p_seq_id,
      size_t op_id,
      std::string profiling_name,
      const std::vector<at::Tensor>& inputs,
      const std::vector<at::Tensor>& outputs,
      EventType* start,
      EventType* end,
      std::chrono::milliseconds timeout_ms,
      std::shared_ptr<ProcessGroupStatus> pg_status,
      bool isP2P);

  std::optional<size_t> record(
      size_t pg_id,
      const std::tuple<std::string, std::string>& pg_name,
@@ -238,16 +213,8 @@ struct FlightRecorder {

  std::vector<Entry> dump_entries();

  // Returns the index in entries_ for the given id and reset_epoch.
  // Caller must hold mutex_lock before calling this method.
  size_t getIdxFromId(size_t id, size_t reset_epoch) const;

  // Returns the entry with the given id and reset_epoch, if it exists.
  // Otherwise, returns std::nullopt.
  TORCH_API std::optional<Entry> getEntry(
      std::optional<size_t> id,
      std::optional<size_t> reset_epoch);

  // Returns the entry with the given id, if it exists. Otherwise, returns
  // std::nullopt.
  TORCH_API std::optional<Entry> getEntry(std::optional<size_t> id);

  /*
@@ -260,11 +227,6 @@ struct FlightRecorder {
    never hang. (timing must also be enabled for compute_duration - see
    TORCH_NCCL_ENABLE_TIMING).
  */
  TORCH_API void retire_id(
      std::optional<size_t> id,
      std::optional<size_t> reset_epoch,
      bool compute_duration = true);

  TORCH_API void retire_id(
      std::optional<size_t> id,
      bool compute_duration = true);
@@ -53,41 +53,8 @@ std::optional<size_t> FlightRecorder<EventType>::record(
    std::chrono::milliseconds timeout_ms,
    std::shared_ptr<ProcessGroupStatus> pg_status,
    bool isP2P) {
  auto result = recordWithResetEnabled(
      pg_id,
      pg_name,
      collective_seq_id,
      p2p_seq_id,
      op_id,
      std::move(profiling_name),
      inputs,
      outputs,
      start,
      end,
      timeout_ms,
      std::move(pg_status),
      isP2P);
  return result.id;
}

template <typename EventType>
typename FlightRecorder<EventType>::TraceIdentifier FlightRecorder<EventType>::
    recordWithResetEnabled(
        size_t pg_id,
        const std::tuple<std::string, std::string>& pg_name,
        size_t collective_seq_id,
        size_t p2p_seq_id,
        size_t op_id,
        std::string profiling_name,
        const std::vector<at::Tensor>& inputs,
        const std::vector<at::Tensor>& outputs,
        EventType* start,
        EventType* end,
        std::chrono::milliseconds timeout_ms,
        std::shared_ptr<ProcessGroupStatus> pg_status,
        bool isP2P) {
  if (!enabled_) {
    return TraceIdentifier{std::nullopt, std::nullopt};
    return std::nullopt;
  }
  if (all_pg_status_.find(pg_id) == all_pg_status_.end()) {
    // Current pg_status is not in FR.
@@ -97,13 +64,8 @@ typename FlightRecorder<EventType>::TraceIdentifier FlightRecorder<EventType>::
      torch::CapturedTraceback::gather(true, true, capture_cpp_stack_);
  std::lock_guard<std::mutex> guard(mutex_);

  TORCH_CHECK(
      reset_epoch_start_idx_.find(reset_epoch_) !=
      reset_epoch_start_idx_.end());

  auto te = Entry{
      id_,
      reset_epoch_,
      pg_id,
      pg_name,
      collective_seq_id,
@@ -142,20 +104,15 @@ typename FlightRecorder<EventType>::TraceIdentifier FlightRecorder<EventType>::
    te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end());
  }

  const auto next = next_++;

  if (entries_.size() < max_entries_) {
    entries_.emplace_back(std::move(te));
  } else {
    entries_[next] = std::move(te);
    entries_[next_++] = std::move(te);
    if (next_ == max_entries_) {
      next_ = 0;
    }
  }

  if (next_ == max_entries_) {
    next_ = 0;
  }

  const auto id = id_++;
  return TraceIdentifier{id, reset_epoch_};
  return id_++;
}
template <typename EventType>
@@ -206,20 +163,15 @@ std::vector<typename FlightRecorder<EventType>::Entry> FlightRecorder<
  std::vector<Entry> result;
  {
    std::lock_guard<std::mutex> guard(mutex_);
    // Filter entries during insertion - only keep entries from current epoch
    auto filter = [this](const Entry& e) {
      return e.reset_epoch_ == reset_epoch_;
    };
    std::copy_if(
    result.reserve(entries_.size());
    result.insert(
        result.end(),
        entries_.begin() + static_cast<std::ptrdiff_t>(next_),
        entries_.end(),
        std::back_inserter(result),
        filter);
    std::copy_if(
        entries_.end());
    result.insert(
        result.end(),
        entries_.begin(),
        entries_.begin() + static_cast<std::ptrdiff_t>(next_),
        std::back_inserter(result),
        filter);
        entries_.begin() + static_cast<std::ptrdiff_t>(next_));
  }
  // query any remaining events
  for (auto& r : result) {
@@ -230,47 +182,28 @@ std::vector<typename FlightRecorder<EventType>::Entry> FlightRecorder<
}

template <typename EventType>
// Returns the index in entries_ for the given id and reset_epoch.
// Caller must hold mutex_lock before calling this method.
size_t FlightRecorder<EventType>::getIdxFromId(size_t id, size_t reset_epoch)
    const {
  // Look up the starting idx for the given reset epoch
  auto it = reset_epoch_start_idx_.find(reset_epoch);
  TORCH_CHECK(it != reset_epoch_start_idx_.end());
  // Calculate idx based on where the epoch started
  return (it->second + id) % max_entries_;
}

template <typename EventType>
// Returns the entry with the given id and reset_epoch, if it exists. Otherwise,
// returns std::nullopt.
// Returns the entry with the given id, if it exists. Otherwise, returns
// std::nullopt.
std::optional<typename FlightRecorder<EventType>::Entry> FlightRecorder<
    EventType>::
    getEntry(std::optional<size_t> id, std::optional<size_t> reset_epoch) {
  if (!enabled_ || !id || !reset_epoch) {
    EventType>::getEntry(std::optional<size_t> id) {
  if (!enabled_ || !id) {
    return std::nullopt;
  }

  std::unique_lock<std::mutex> guard(mutex_);
  Entry entry = entries_.at(getIdxFromId(*id, *reset_epoch));
  if (entry.id_ == *id && entry.reset_epoch_ == *reset_epoch) {
  Entry entry = entries_.at(*id % max_entries_);
  if (entry.id_ == *id) {
    return entry;
  } else {
    return std::nullopt;
  }
  return std::nullopt;
}

template <typename EventType>
std::optional<typename FlightRecorder<EventType>::Entry> FlightRecorder<
    EventType>::getEntry(std::optional<size_t> id) {
  return getEntry(id, 0);
}
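The index math in getIdxFromId above is worth a sanity check; a tiny script mirroring the wraparound scenario from the removed tests (reset at slot 5, buffer of 10):

max_entries = 10
epoch_start = 5  # reset happened when next_ == 5
for record_id in range(8):
    slot = (epoch_start + record_id) % max_entries
    print(record_id, slot)
# record_ids 0..4 land in slots 5..9, then 5..7 wrap around to slots 0..2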
template <typename EventType>
void FlightRecorder<EventType>::retire_id(
    std::optional<size_t> id,
    std::optional<size_t> reset_epoch,
    bool compute_duration) {
  if (!enabled_ || !id || !reset_epoch) {
  if (!enabled_ || !id) {
    return;
  }

@@ -281,8 +214,8 @@ void FlightRecorder<EventType>::retire_id(

  std::unique_lock<std::mutex> guard(mutex_);

  Entry* entry = &entries_.at(getIdxFromId(*id, *reset_epoch));
  if (entry->id_ == *id && entry->reset_epoch_ == *reset_epoch) {
  Entry* entry = &entries_.at(*id % max_entries_);
  if (entry->id_ == *id) {
    update_state(*entry);

    if (compute_duration) {
@@ -304,8 +237,8 @@ void FlightRecorder<EventType>::retire_id(
    guard.lock();

    // Refresh the entry pointer, see if the entry has been overwritten
    entry = &entries_.at(getIdxFromId(*id, *reset_epoch));
    if (!(entry->id_ == *id && entry->reset_epoch_ == *reset_epoch)) {
    entry = &entries_.at(*id % max_entries_);
    if (entry->id_ != *id) {
      LOG(INFO) << "retire_id abandoned for id " << *id
                << ", event was overwritten while waiting to compute duration.";
      return;
@@ -316,23 +249,12 @@ void FlightRecorder<EventType>::retire_id(
  }
}

template <typename EventType>
void FlightRecorder<EventType>::retire_id(
    std::optional<size_t> id,
    bool compute_duration) {
  retire_id(id, 0, compute_duration);
}

template <typename EventType>
void FlightRecorder<EventType>::reset_all() {
  std::lock_guard<std::mutex> guard(mutex_);
  if (!entries_.empty()) {
    // Soft delete: increment epoch to mark all existing entries as old
    // Store where the new epoch starts in the circular buffer
    reset_epoch_++;
    reset_epoch_start_idx_[reset_epoch_] = next_;
    id_ = 0;
  }
  next_ = 0;
  id_ = 0;
  entries_.clear();
}

template <typename EventType>
@@ -708,8 +708,7 @@ void ProcessGroupGloo::runLoop(int workerIndex) {
        // TODO: We need to have numel of tensors for gloo as well.
        pgStatus_->lastCompletedNumelIn = 0;
        pgStatus_->lastCompletedNumelOut = 0;
        FlightRecorder<c10::Event>::get()->retire_id(
            work->trace_id_, work->trace_reset_epoch_, false);
        FlightRecorder<c10::Event>::get()->retire_id(work->trace_id_, false);
        lock.lock();
        workInProgress_[workerIndex].reset();
      }
@@ -781,7 +780,7 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
  pgStatus_->lastEnqueuedNumelOut = 0;
  // using c10d::FlightRecorder;
  // TODO: We need to have a way to use c10::Event inside gloo as well.
  auto traceId = FlightRecorder<c10::Event>::get()->recordWithResetEnabled(
  work->trace_id_ = FlightRecorder<c10::Event>::get()->record(
      local_id_,
      std::make_tuple(pg_uid_, pg_desc_),
      collectiveCounter_,
@@ -796,8 +795,6 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
      work->getTimeout(),
      pgStatus_,
      false);
  work->trace_id_ = traceId.id;
  work->trace_reset_epoch_ = traceId.reset_epoch;
  workQueue_.push_back(std::move(work));
  lock.unlock();
@@ -99,7 +99,6 @@ class TORCH_API ProcessGroupGloo : public Backend {
    // unique id used to tell the trace buffer that this
    // work has completed
    std::optional<uint64_t> trace_id_;
    std::optional<uint64_t> trace_reset_epoch_;
    std::shared_ptr<gloo::Context> context_;
    const std::chrono::milliseconds timeout_;
@ -575,7 +575,6 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
|
||||
futureWorkResult_(w.futureWorkResult_),
|
||||
timingEnabled_(w.timingEnabled_),
|
||||
trace_id_(w.trace_id_),
|
||||
trace_reset_epoch_(w.trace_reset_epoch_),
|
||||
distDebugLevel_(w.distDebugLevel_) {
|
||||
exception_ = w.exception_;
|
||||
}
|
||||
@ -705,9 +704,9 @@ bool ProcessGroupNCCL::WorkNCCL::checkTimeout(
|
||||
// Print the traceback of the collective at call time
|
||||
std::string ProcessGroupNCCL::WorkNCCL::getTraceback() const {
|
||||
// First step we get the corresponding record entry from FR, based on work's
|
||||
// trace_id_ and trace_reset_epoch_
|
||||
// trace_id_
|
||||
std::optional<FlightRecorderCUDA::Entry> entry =
|
||||
FlightRecorderCUDA::get()->getEntry(trace_id_, trace_reset_epoch_);
|
||||
FlightRecorderCUDA::get()->getEntry(trace_id_);
|
||||
if (entry.has_value()) {
|
||||
auto entryVal = entry.value();
|
||||
// Get stack trace from FR entry, in string format
|
||||
@ -2395,8 +2394,7 @@ void ProcessGroupNCCL::Watchdog::runLoop() {
|
||||
pg_->pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_);
|
||||
pg_->pgStatus_->lastCompletedNumelIn = work.numelIn_;
|
||||
pg_->pgStatus_->lastCompletedNumelOut = work.numelOut_;
|
||||
FlightRecorderCUDA::get()->retire_id(
|
||||
work.trace_id_, work.trace_reset_epoch_, true);
|
||||
FlightRecorderCUDA::get()->retire_id(work.trace_id_, true);
|
||||
if (pg_->onCompletionHook_) {
|
||||
// Move Work object to completedWorkList_ to be consumed by the hook
|
||||
// thread
|
||||
@ -3362,7 +3360,7 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
|
||||
// these objects to the Work because it has implications for keeping those
|
||||
// tensors alive longer and adds overhead when copying Work objects
|
||||
// between threads
|
||||
auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled(
|
||||
r->trace_id_ = FlightRecorderCUDA::get()->record(
|
||||
local_id_,
|
||||
std::make_tuple(pg_uid_, pg_desc_),
|
||||
seqCollective_,
|
||||
@ -3376,8 +3374,6 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
|
||||
options_->timeout,
|
||||
pgStatus_,
|
||||
isP2P);
|
||||
r->trace_id_ = traceId.id;
|
||||
r->trace_reset_epoch_ = traceId.reset_epoch;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
@ -3681,7 +3677,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
|
||||
// later in endCoalescing we record a 'coalesced' Work which has
|
||||
// timing/state updates via watchdog thread, but lacks op metadata such as
|
||||
// input/output sizes and profilingTitle per-op in the group.
|
||||
FlightRecorderCUDA::get()->recordWithResetEnabled(
|
||||
FlightRecorderCUDA::get()->record(
|
||||
local_id_,
|
||||
std::make_tuple(pg_uid_, pg_desc_),
|
||||
seqCollective_,
|
||||
@ -4173,7 +4169,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
    // TODO(whc) because we don't pass output {tensor} to initWork, we tell
    // initWork to not record, and then we manually call record passing all the
    // information it wants.
    auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled(
    work->trace_id_ = FlightRecorderCUDA::get()->record(
        local_id_,
        std::make_tuple(pg_uid_, pg_desc_),
        seqCollective_,

@ -4187,8 +4183,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
        options_->timeout,
        pgStatus_,
        /*isP2P=*/true);
    work->trace_id_ = traceId.id;
    work->trace_reset_epoch_ = traceId.reset_epoch;
  }

  // Only check for NaN for send ops, for recv ops `tensor` can be a random

@ -505,7 +505,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
    // unique id used to tell the trace buffer that this
    // work has completed
    std::optional<uint64_t> trace_id_;
    std::optional<uint64_t> trace_reset_epoch_;
    DebugLevel distDebugLevel_;
    friend class ProcessGroupNCCL;
  };

@ -547,7 +547,6 @@ def rebind_unbacked(
    assert shape_env is not None
    for raw_u0, path in bindings.items():
        u1 = pytree.key_get(result, path)

        # Sometimes, things that were previously unbacked bindings become constants.
        # There are two situations this can happen.
        #

@ -603,23 +602,7 @@ def rebind_unbacked(
        if u1.node.hint is not None:
            continue

        # unbacked symbols bindings might be replaced to other backed or
        # unbacked replacements.
        #
        # Example:
        # u = x.item()
        # torch._check(u == 5)
        #
        # The safest approach is to retrieve raw_u1 from u1.node._expr
        # and perform the rebinding on the original unbacked symbol,
        # even if it's no longer directly referenced.
        #
        # In other words, we should always rebind the original symbol
        # before any replacements are applied.
        # u0 -> u0 == s1
        raw_u1 = u1.node._expr

        # TODO Do we still need this logic below?
        raw_u1 = u1.node.expr
        # Simplify SymBool binding
        if (
            isinstance(raw_u1, sympy.Piecewise)

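The comment block deleted above concerns the distinction between a SymNode's raw expression and its post-replacement expression. A rough sympy-only illustration of that distinction, under the assumption that node._expr holds the original unbacked symbol while node.expr resolves replacements such as those introduced by torch._check:

    import sympy

    # u0 plays the role of the raw unbacked symbol a .item() call produces.
    u0 = sympy.Symbol("u0", integer=True)
    raw = u0                               # analogous to node._expr (pre-replacement)
    replacements = {u0: sympy.Integer(5)}  # effect of a torch._check(u == 5)
    resolved = raw.xreplace(replacements)  # analogous to node.expr (post-replacement)

    print(raw, "->", resolved)             # u0 -> 5
    # Rebinding on the raw symbol keeps the original u0 addressable even
    # after it has been replaced everywhere else.
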
@ -443,7 +443,6 @@ class CodeGen:
        colored: bool = False,
        # Render each argument on its own line
        expanded_def: bool = False,
        record_func: bool = False,
    ) -> PythonCode:
        free_vars: list[str] = []
        body: list[str] = []

@ -818,10 +817,6 @@ class CodeGen:
                return
            raise NotImplementedError(f"node: {node.op} {node.target}")

        if record_func:
            body.append(
                "_rf = torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##'); _rf.__enter__()\n"
            )
        for i, node in enumerate(nodes):
            # NOTE: emit_node does not emit a string with newline. It depends
            # on delete_unused_values to append one

@ -831,22 +826,8 @@ class CodeGen:
                # node index, which will be deleted later
                # after going through _body_transformer
                body.append(f"# COUNTER: {i}\n")
            do_record = record_func and node.op in (
                "call_function",
                "call_method",
                "call_module",
            )
            if do_record:
                # The double hash ## convention is used by post-processing to find the fx markers
                body.append(
                    f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {i} ##'); _rf_{node.name}.__enter__()\n"
                )
            emit_node(node)
            delete_unused_values(node)
            if do_record:
                body.append(f"_rf_{node.name}.__exit__(None, None, None)\n")
        if record_func:
            body.append("_rf.__exit__(None, None, None)\n")

        if len(body) == 0:
            # If the Graph has no non-placeholder nodes, no lines for the body

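The deleted branch wrapped each call_function/call_method/call_module node in a _RecordFunctionFast enter/exit pair named with the double-hash convention, so post-processing could find per-node markers in the trace. A small sketch of the emitted pattern, assuming a PyTorch build that exposes the private torch._C._profiler._RecordFunctionFast helper:

    import torch

    def add_with_marker(a, b):
        # Emulates the removed codegen output: a fast RecordFunction marker
        # wrapped around one node, named with the '## <index> ##' convention.
        _rf = torch._C._profiler._RecordFunctionFast("## 0 ##")
        _rf.__enter__()
        try:
            return torch.add(a, b)
        finally:
            _rf.__exit__(None, None, None)

    with torch.profiler.profile() as prof:
        add_with_marker(torch.ones(2), torch.ones(2))

    # The marker shows up as an ordinary event alongside the aten ops.
    print([e.name for e in prof.events() if e.name.startswith("## ")])
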
@ -1798,7 +1779,6 @@ class Graph:
        include_device: bool = False,
        colored: bool = False,
        expanded_def: bool = False,
        record_func: bool = False,
    ) -> PythonCode:
        """
        Turn this ``Graph`` into valid Python code.

@ -1866,7 +1846,6 @@ class Graph:
            include_device=include_device,
            colored=colored,
            expanded_def=expanded_def,
            record_func=record_func,
        )

    def _python_code(

@ -1879,7 +1858,6 @@ class Graph:
        include_device: bool = False,
        colored: bool = False,
        expanded_def: bool = False,
        record_func: bool = False,
    ) -> PythonCode:
        return self._codegen._gen_python_code(
            self.nodes,

@ -1890,7 +1868,6 @@ class Graph:
            include_device=include_device,
            colored=colored,
            expanded_def=expanded_def,
            record_func=record_func,
        )

    def __str__(self) -> str:

@ -861,18 +861,14 @@ class {module_name}(torch.nn.Module):
        if isinstance(self._graph._codegen, _PyTreeCodeGen):
            self._in_spec = self._graph._codegen.pytree_info.in_spec
            self._out_spec = self._graph._codegen.pytree_info.out_spec

        from torch._dynamo import config as dynamo_config

        python_code = self._graph.python_code(
            root_module="self", record_func=dynamo_config.enrich_profiler_metadata
        )
        python_code = self._graph.python_code(root_module="self")
        self._code = python_code.src
        self._lineno_map = python_code._lineno_map
        self._prologue_start = python_code._prologue_start

        cls = type(self)
        co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
        from torch._dynamo import config as dynamo_config

        if dynamo_config.enrich_profiler_metadata:
            # Generate metadata and register for profiler augmentation

@ -889,6 +885,7 @@ class {module_name}(torch.nn.Module):
            # This ensures the same code+metadata always generates the same filename
            hash_value = _metadata_hash(self._code, node_metadata)
            file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}"

            filename = f"{file_stem}.py"

            # Only include co_filename to use it directly as the cache key

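The surviving lines above derive the module filename from a content hash, so identical code plus node metadata always yield the same registry entry and cache key. A hedged sketch of that idea; toy_metadata_hash and the fx_generated prefix are stand-ins, since _metadata_hash and FX_GRAPH_MODULE_FILE_PREFIX are internal:

    import hashlib
    import json

    def toy_metadata_hash(code: str, node_metadata: dict) -> str:
        # Stable serialization in, stable digest out: the same code plus
        # metadata always produce the same filename, hence the same cache key.
        payload = code + json.dumps(node_metadata, sort_keys=True)
        return hashlib.sha256(payload.encode()).hexdigest()

    file_stem = "fx_generated_" + toy_metadata_hash(
        "def forward(self, x): return x + 1", {"0": {"name": "add"}}
    )
    print(f"{file_stem}.py")  # deterministic for identical inputs
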
@ -908,13 +905,6 @@ class {module_name}(torch.nn.Module):

            _register_fx_metadata(filename, metadata)

            # Replace the placeholder in generated code with actual filename
            # The double hash ## convention is used by post-processing to find the fx markers
            self._code = self._code.replace(
                "torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##')",
                f"torch._C._profiler._RecordFunctionFast('## {filename} ##')",
            )

        cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)

        # Determine whether this class explicitly defines a __call__ implementation

@ -165,7 +165,6 @@ def insert_deferred_runtime_asserts(
        node: torch.fx.Node,
        stack_trace: Optional[str] = None,
        nn_module_stack: Optional[dict[str, Any]] = None,
        custom: Optional[dict[str, Any]] = None,
    ) -> None:
        fake_args = pytree.tree_map(
            lambda arg: (

@ -189,8 +188,6 @@ def insert_deferred_runtime_asserts(
            node.meta["stack_trace"] = stack_trace
        if nn_module_stack is not None:
            node.meta["nn_module_stack"] = nn_module_stack
        if custom is not None:
            node.meta["custom"] = custom

    # Track asserts/checks we've added
    added_asserts: set[sympy.Expr] = set()

@ -620,9 +617,6 @@ def insert_deferred_runtime_asserts(
                    _node_metadata_hook,
                    stack_trace=node.meta.get("stack_trace"),
                    nn_module_stack=node.meta.get("nn_module_stack"),
                    # nodes added in `apply_runtime_assertion_pass` will have the same annotation
                    # as the input node to the assertion
                    custom=node.meta.get("custom"),
                ),
            ):
                if (min_val := convert(vr.lower)) is not None:

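For reference, the hook pattern above applies the guarded node's metadata to assertion nodes as they are inserted; before this change it also forwarded the custom field. A simplified sketch of the partial-application pattern, using a toy hook and node rather than the pass internals:

    import functools

    class FakeNode:
        # Minimal stand-in for torch.fx.Node; only the meta dict matters here.
        def __init__(self):
            self.meta = {}

    def toy_metadata_hook(node, stack_trace=None, nn_module_stack=None, custom=None):
        # Mirrors the shape of _node_metadata_hook: copy the source node's
        # metadata onto every node the pass inserts.
        if stack_trace is not None:
            node.meta["stack_trace"] = stack_trace
        if nn_module_stack is not None:
            node.meta["nn_module_stack"] = nn_module_stack
        if custom is not None:
            node.meta["custom"] = custom

    source_meta = {"stack_trace": 'File "model.py", line 10', "custom": {"pp_stage": 0}}
    hook = functools.partial(
        toy_metadata_hook,
        stack_trace=source_meta.get("stack_trace"),
        custom=source_meta.get("custom"),
    )
    assert_node = FakeNode()
    hook(assert_node)
    print(assert_node.meta)  # inherits stack_trace and custom from the source node
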
@ -4,7 +4,7 @@ import operator
import re
from collections import deque
from dataclasses import dataclass
from typing import Any, Literal, Optional, TYPE_CHECKING
from typing import TYPE_CHECKING

from torch.autograd.profiler import profile
from torch.profiler import DeviceType

@ -400,170 +400,3 @@ def _init_for_cuda_graphs() -> None:

    with profile():
        pass


@dataclass
class TimelineEvent:
    """Represents an event in the profiler timeline."""

    timestamp: int
    event_type: Literal["start", "end", "regular"]
    marker_type: Optional[Literal["filename", "node"]]
    identifier: Optional[str | int]
    event: dict[str, Any]


@dataclass
class ContextStackEntry:
    """Represents a context (filename or node) in the stack."""

    context_type: Literal["filename", "node"]
    identifier: str | int
    metadata: Optional[dict]
    tid: Optional[int] = None  # Thread ID associated with this context


def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
    """
    Maps recorded profiler events to their corresponding fx nodes and adds stack traces.

    Builds a timeline of all events (regular ops and FX markers for filenames/nodes),
    sorts by timestamp, then processes chronologically while maintaining a context stack of active
    filename/node scopes. Regular events are augmented with stack traces and node names from the
    innermost active context. Runtime is O(n log n) for n events.

    Args:
        traced_data: JSON of profiler events from a Chrome trace

    Returns:
        Dict mapping recorded event names to their aten operations with added stack traces
    """
    from torch.fx.traceback import _FX_METADATA_REGISTRY

    trace_events = traced_data.get("traceEvents", [])

    # Create event timeline
    event_timeline: list[TimelineEvent] = []

    def is_fx_marker_event(event):
        return (
            event.get("cat") == "cpu_op"
            and event.get("name", "").startswith("## ")
            and event.get("name", "").endswith(" ##")
        )

    def append_fx_marker_event(event_type, identifier, event):
        start_ts = event["ts"]
        end_ts = start_ts + event["dur"]
        event_timeline.append(
            TimelineEvent(start_ts, "start", event_type, identifier, event)
        )
        event_timeline.append(
            TimelineEvent(end_ts, "end", event_type, identifier, event)
        )

    for event in trace_events:
        if "ts" not in event or "dur" not in event:
            continue

        if is_fx_marker_event(event):
            content = event["name"][3:-3]

            if content.endswith(".py"):
                append_fx_marker_event("filename", content, event)
            else:
                try:
                    node_index = int(content)
                except ValueError:
                    pass
                append_fx_marker_event("node", node_index, event)  # type: ignore[possibly-undefined]

        else:
            # Regular event that needs augmentation
            start_ts = event["ts"]
            event_timeline.append(TimelineEvent(start_ts, "regular", None, None, event))

    # Sort by timestamp
    event_timeline.sort(key=lambda x: x.timestamp)

    # Process events in chronological order with a stack
    context_stack: list[ContextStackEntry] = []

    # Invariant: every start event has a corresponding end event
    for timeline_event in event_timeline:
        match timeline_event.event_type:
            case "start":
                assert timeline_event.identifier is not None

                if timeline_event.marker_type == "filename":
                    assert isinstance(timeline_event.identifier, str)
                    # Push filename context - query metadata registry on-demand
                    metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier)
                    tid = timeline_event.event.get("tid")
                    context_stack.append(
                        ContextStackEntry(
                            "filename", timeline_event.identifier, metadata, tid
                        )
                    )
                elif timeline_event.marker_type == "node":
                    # Find the current filename from stack
                    current_file_metadata = None
                    tid = timeline_event.event.get("tid")
                    for ctx_entry in reversed(context_stack):
                        if (
                            ctx_entry.context_type == "filename"
                            and ctx_entry.tid == tid
                        ):
                            current_file_metadata = ctx_entry.metadata
                            break

                    if current_file_metadata:
                        node_metadata = current_file_metadata.get("node_metadata", {})
                        if timeline_event.identifier in node_metadata:
                            node_meta: Optional[dict] = node_metadata[
                                timeline_event.identifier
                            ]
                            context_stack.append(
                                ContextStackEntry(
                                    "node", timeline_event.identifier, node_meta, tid
                                )
                            )

            case "end":
                # Pop from stack - search backwards to find matching context
                for i in range(len(context_stack) - 1, -1, -1):
                    ctx_entry = context_stack[i]
                    if (
                        timeline_event.marker_type == ctx_entry.context_type
                        and timeline_event.identifier == ctx_entry.identifier
                    ):
                        context_stack.pop(i)
                        break

            case "regular":
                # Apply metadata from current context stack
                # Find the most specific context (node takes precedence over filename)
                # Only augment events with the same tid as the file/node event matched
                current_stack_trace = None
                current_node_name = None
                event_tid = timeline_event.event.get("tid")

                for ctx_entry in reversed(context_stack):
                    # Only apply metadata from contexts with matching tid
                    if ctx_entry.tid == event_tid:
                        if ctx_entry.context_type == "node" and ctx_entry.metadata:
                            current_stack_trace = ctx_entry.metadata.get(
                                "stack_trace", "No model stack trace available"
                            )
                            current_node_name = ctx_entry.metadata.get("name", "")
                            # Do we want to only attach the stack trace of the lowest node or stack trace of all nodes
                            # if nodes are nested, e.g. in nested graph modules
                            break

                # Augment the event
                if current_stack_trace or current_node_name:
                    args = timeline_event.event.setdefault("args", {})
                    if current_stack_trace:
                        args["stack_trace"] = current_stack_trace
                    if current_node_name:
                        args["node_name"] = current_node_name

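The deleted helper above consumed the parsed Chrome trace exported by the profiler and annotated events in place. A hedged sketch of the old call pattern, shown only to document the intended flow, since the function no longer exists after this change:

    import json

    # Assumed workflow: a trace produced by torch.profiler, e.g.
    #   with torch.profiler.profile() as prof:
    #       compiled_model(inp)
    #   prof.export_chrome_trace("trace.json")
    with open("trace.json") as f:
        traced_data = json.load(f)

    # Old call pattern, removed by this diff: mutate the parsed events in
    # place, attaching stack_trace/node_name args recovered from the FX
    # metadata registry.
    # map_recorded_events_to_aten_ops_with_stack_trace(traced_data)

    with open("trace_with_stack_traces.json", "w") as f:
        json.dump(traced_data, f)
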
@ -306,24 +306,6 @@ class PythonPrinter(ExprPrinter):
            raise TypeError("ndigits must be an instance of sympy.Integer")
        return f"round({self._print(number)}, {ndigits})"

    def _print_Piecewise(self, expr: sympy.Expr) -> str:
        # Convert Piecewise(expr_cond_pairs) to nested ternary expressions
        # Piecewise((e1, c1), (e2, c2), ..., (eN, cN))
        # becomes: e1 if c1 else (e2 if c2 else (... else eN))
        result: Optional[str] = None
        for expr_i, cond_i in reversed(expr.args):
            expr_str = self._print(expr_i)
            if cond_i == True:  # noqa: E712
                # This is the default case
                result = expr_str
            else:
                cond_str = self._print(cond_i)
                if result is None:
                    result = expr_str
                else:
                    result = f"({expr_str} if {cond_str} else {result})"
        return result if result else "0"


class CppPrinter(ExprPrinter):
    def _print_Integer(self, expr: sympy.Expr) -> str:

@ -345,24 +327,6 @@ class CppPrinter(ExprPrinter):
        )
        return f"{c} ? {p} : {q}"

    def _print_Piecewise(self, expr: sympy.Expr) -> str:
        # Convert Piecewise(expr_cond_pairs) to nested ternary operators
        # Piecewise((e1, c1), (e2, c2), ..., (eN, cN))
        # becomes: c1 ? e1 : (c2 ? e2 : (... : eN))
        result: Optional[str] = None
        for expr_i, cond_i in reversed(expr.args):
            expr_str = self.parenthesize(expr_i, PRECEDENCE["Atom"] - 0.5)
            if cond_i == True:  # noqa: E712
                # This is the default case
                result = expr_str
            else:
                cond_str = self.parenthesize(cond_i, PRECEDENCE["Atom"] - 0.5)
                if result is None:
                    result = expr_str
                else:
                    result = f"{cond_str} ? {expr_str} : {result}"
        return f"({result})" if result else "0"

    def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
        x, div, mod = expr.args
        x = self.doprint(x)

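Both _print_Piecewise methods removed above perform the same fold: walk the (expression, condition) pairs in reverse and nest each earlier branch around the result built so far. A standalone Python sketch of that fold, independent of the deleted printer classes:

    import sympy

    def piecewise_to_ternary(expr: sympy.Piecewise) -> str:
        # Same fold as the deleted printers: start from the last (default)
        # branch and wrap each earlier branch around the expression so far,
        # yielding e1 if c1 else (e2 if c2 else ... else eN).
        result = None
        for expr_i, cond_i in reversed(expr.args):
            if cond_i == sympy.true:  # the default branch
                result = str(expr_i)
            elif result is None:
                result = str(expr_i)
            else:
                result = f"({expr_i} if {cond_i} else {result})"
        return result or "0"

    x = sympy.Symbol("x")
    print(piecewise_to_ternary(sympy.Piecewise((x + 1, x > 0), (0, True))))
    # -> (x + 1 if x > 0 else 0)
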