Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-07 01:50:04 +08:00)
Compare commits: ciflow/ind...annotate_a (34 commits)
Commit SHA1s:
89fb2567e7, fd5edda1ed, 872d1daec2, 6cd57e6fc2, d29efba8fa, a344069f2a, af829c0dad, 3869aa115b, 47eb34b7ac, 08200280ce, ad7a57262c, 711a775878, e9a688f02e, e69aaaf45a, fd8f368d31, 13d2cc7bd2, c6c913d18e, ef3f953966, ea44f12bce, a74fe75c45, 6d30666bc1, 8e8cbb85ee, fbd70fb84e, 6c5db82584, 6052a01b71, 14b153bcf2, 641de23c96, 89165c0a2b, dcc2ba4ca4, ad5c7c20e0, c86540f120, c17aa0f113, 4ff068c33a, 0c7a4a6b48
@ -7,13 +7,13 @@ ENV LC_ALL en_US.UTF-8
|
||||
ENV LANG en_US.UTF-8
|
||||
ENV LANGUAGE en_US.UTF-8
|
||||
|
||||
ARG DEVTOOLSET_VERSION=11
|
||||
ARG DEVTOOLSET_VERSION=13
|
||||
|
||||
RUN yum -y update
|
||||
RUN yum -y install epel-release
|
||||
# install glibc-langpack-en to make sure the en_US.UTF-8 locale is available
|
||||
RUN yum -y install glibc-langpack-en
|
||||
RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-toolchain
|
||||
RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
|
||||
# Add everything as a safe.directory for git, since these checkouts will be used with git in multiple places
|
||||
RUN git config --global --add safe.directory '*'
|
||||
ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
|
||||
@ -41,6 +41,7 @@ RUN bash ./install_conda.sh && rm install_conda.sh
|
||||
# Install CUDA
|
||||
FROM base as cuda
|
||||
ARG CUDA_VERSION=12.6
|
||||
ARG DEVTOOLSET_VERSION=13
|
||||
RUN rm -rf /usr/local/cuda-*
|
||||
ADD ./common/install_cuda.sh install_cuda.sh
|
||||
COPY ./common/install_nccl.sh install_nccl.sh
|
||||
@ -50,7 +51,8 @@ ENV CUDA_HOME=/usr/local/cuda-${CUDA_VERSION}
|
||||
# Preserve CUDA_VERSION for the builds
|
||||
ENV CUDA_VERSION=${CUDA_VERSION}
|
||||
# Make things in our path by default
|
||||
ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:$PATH
|
||||
ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
|
||||
|
||||
|
||||
FROM cuda as cuda12.6
|
||||
RUN bash ./install_cuda.sh 12.6
|
||||
@ -68,8 +70,22 @@ FROM cuda as cuda13.0
|
||||
RUN bash ./install_cuda.sh 13.0
|
||||
ENV DESIRED_CUDA=13.0
|
||||
|
||||
FROM ${ROCM_IMAGE} as rocm
|
||||
FROM ${ROCM_IMAGE} as rocm_base
|
||||
ARG DEVTOOLSET_VERSION=13
|
||||
ENV LC_ALL en_US.UTF-8
|
||||
ENV LANG en_US.UTF-8
|
||||
ENV LANGUAGE en_US.UTF-8
|
||||
# Install devtoolset on ROCm base image
|
||||
RUN yum -y update && \
|
||||
yum -y install epel-release && \
|
||||
yum -y install glibc-langpack-en && \
|
||||
yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
|
||||
RUN git config --global --add safe.directory '*'
|
||||
ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
|
||||
|
||||
FROM rocm_base as rocm
|
||||
ARG PYTORCH_ROCM_ARCH
|
||||
ARG DEVTOOLSET_VERSION=13
|
||||
ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}
|
||||
ADD ./common/install_mkl.sh install_mkl.sh
|
||||
RUN bash ./install_mkl.sh && rm install_mkl.sh
|
||||
@ -88,6 +104,7 @@ COPY --from=cuda13.0 /usr/local/cuda-13.0 /usr/local/cuda-13.0
|
||||
|
||||
# Final step
|
||||
FROM ${BASE_TARGET} as final
|
||||
ARG DEVTOOLSET_VERSION=13
|
||||
COPY --from=openssl /opt/openssl /opt/openssl
|
||||
COPY --from=patchelf /patchelf /usr/local/bin/patchelf
|
||||
COPY --from=conda /opt/conda /opt/conda
|
||||
|
||||
@ -63,7 +63,7 @@ docker build \
|
||||
--target final \
|
||||
--progress plain \
|
||||
--build-arg "BASE_TARGET=${BASE_TARGET}" \
|
||||
--build-arg "DEVTOOLSET_VERSION=11" \
|
||||
--build-arg "DEVTOOLSET_VERSION=13" \
|
||||
${EXTRA_BUILD_ARGS} \
|
||||
-t ${tmp_tag} \
|
||||
$@ \
|
||||
|
||||
@ -271,6 +271,16 @@ case "$tag" in
|
||||
# from pytorch/llvm:9.0.1 is x86 specific
|
||||
SKIP_LLVM_SRC_BUILD_INSTALL=yes
|
||||
;;
|
||||
pytorch-linux-jammy-aarch64-py3.10-clang21)
|
||||
ANACONDA_PYTHON_VERSION=3.10
|
||||
CLANG_VERSION=21
|
||||
ACL=yes
|
||||
VISION=yes
|
||||
OPENBLAS=yes
|
||||
# snadampal: skipping llvm src build install because the current version
|
||||
# from pytorch/llvm:9.0.1 is x86 specific
|
||||
SKIP_LLVM_SRC_BUILD_INSTALL=yes
|
||||
;;
|
||||
pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks)
|
||||
ANACONDA_PYTHON_VERSION=3.10
|
||||
GCC_VERSION=11
|
||||
|
||||
@ -1 +1 @@
|
||||
7416ffcb92cdbe98d9f97e4e6f95247e46dfc9fd
|
||||
bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7
|
||||
|
||||
@ -8,8 +8,8 @@ if [ -n "$CLANG_VERSION" ]; then
|
||||
# work around ubuntu apt-get conflicts
|
||||
sudo apt-get -y -f install
|
||||
wget --no-check-certificate -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
|
||||
if [[ $CLANG_VERSION == 18 ]]; then
|
||||
apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-18 main"
|
||||
if [[ $CLANG_VERSION -ge 18 ]]; then
|
||||
apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${CLANG_VERSION} main"
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
@ -129,7 +129,7 @@ function install_129 {
|
||||
}
|
||||
|
||||
function install_128 {
|
||||
CUDNN_VERSION=9.10.2.21
|
||||
CUDNN_VERSION=9.8.0.87
|
||||
echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1"
|
||||
# install CUDA 12.8.1 in the same container
|
||||
install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux
|
||||
|
||||
@ -10,6 +10,7 @@ git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION}" -
|
||||
|
||||
OPENBLAS_CHECKOUT_DIR="OpenBLAS"
|
||||
OPENBLAS_BUILD_FLAGS="
|
||||
CC=gcc
|
||||
NUM_THREADS=128
|
||||
USE_OPENMP=1
|
||||
NO_SHARED=0
|
||||
|
||||
@ -1 +1 @@
|
||||
3.5.0
|
||||
3.5.1
|
||||
|
||||
@ -272,18 +272,6 @@ def smoke_test_cuda(
|
||||
torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version())
|
||||
print(f"Torch cuDNN version: {torch_cudnn_version}")
|
||||
|
||||
torch_cudnn_compile_version = torch._C._cudnn.getCompileVersion()
|
||||
print(f"Torch cuDNN compile-time version: {torch_cudnn_compile_version}")
|
||||
torch_cudnn_runtime_version = tuple(
|
||||
[int(x) for x in torch_cudnn_version.split(".")]
|
||||
)
|
||||
if torch_cudnn_runtime_version != torch_cudnn_compile_version:
|
||||
raise RuntimeError(
|
||||
"cuDNN runtime version doesn't match comple version. "
|
||||
f"Loaded: {torch_cudnn_runtime_version} "
|
||||
f"Expected: {torch_cudnn_compile_version}"
|
||||
)
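Both calls used in the check above can be inspected directly in an interactive session; the example outputs in the comments are illustrative and depend on the installed build:

```python
import torch

# cuDNN version PyTorch was compiled against, as a tuple, e.g. (9, 10, 2)
print(torch._C._cudnn.getCompileVersion())
# cuDNN version actually loaded at runtime, as a packed int, e.g. 91002
print(torch.backends.cudnn.version())
```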
|
||||
|
||||
if sys.platform in ["linux", "linux2"]:
|
||||
torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version())
|
||||
print(f"Torch nccl; version: {torch_nccl_version}")
|
||||
|
||||
.github/workflows/docker-builds.yml (vendored, 2 changed lines)
@ -79,6 +79,8 @@ jobs:
|
||||
include:
|
||||
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11
|
||||
runner: linux.arm64.m7g.4xlarge
|
||||
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-clang21
|
||||
runner: linux.arm64.m7g.4xlarge
|
||||
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks
|
||||
runner: linux.arm64.m7g.4xlarge
|
||||
timeout-minutes: 600
|
||||
|
||||
@ -18,7 +18,7 @@ aspects of contributing to PyTorch.
|
||||
- [Python Unit Testing](#python-unit-testing)
|
||||
- [Better local unit tests with `pytest`](#better-local-unit-tests-with-pytest)
|
||||
- [Local linting](#local-linting)
|
||||
- [Running `mypy`](#running-mypy)
|
||||
- [Running `pyrefly`](#running-pyrefly)
|
||||
- [C++ Unit Testing](#c-unit-testing)
|
||||
- [Run Specific CI Jobs](#run-specific-ci-jobs)
|
||||
- [Merging your Change](#merging-your-change)
|
||||
@ -281,7 +281,7 @@ dependencies as well as the nightly binaries into the repo directory.
|
||||
**Prerequisites**:
|
||||
The following packages should be installed with `pip`:
|
||||
- `expecttest` and `hypothesis` - required to run tests
|
||||
- `mypy` - recommended for linting
|
||||
- `pyrefly` - recommended for type checking. [Pyrefly](https://pyrefly.org/)
|
||||
- `pytest` - recommended to run tests more selectively
|
||||
Running
|
||||
```
|
||||
@ -350,15 +350,32 @@ make lint
|
||||
|
||||
Learn more about the linter on the [lintrunner wiki page](https://github.com/pytorch/pytorch/wiki/lintrunner)
|
||||
|
||||
#### Running `mypy`
|
||||
#### Running `pyrefly`
|
||||
|
||||
`mypy` is an optional static type checker for Python. We have multiple `mypy`
|
||||
configs for the PyTorch codebase that are automatically validated against whenever the linter is run.
|
||||
[Pyrefly](https://pyrefly.org/) is a high-performance static type checker for Python. It provides fast type checking along with IDE features like autocomplete and instant error feedback.
|
||||
|
||||
PyTorch uses Pyrefly for type checking across the codebase. The configuration is managed in `pyrefly.toml` at the root of the repository.
|
||||
|
||||
**Getting Started with Pyrefly:**
|
||||
|
||||
To run type checking on the PyTorch codebase:
|
||||
```bash
|
||||
pyrefly check
|
||||
```
|
||||
|
||||
For more detailed error information with summaries:
|
||||
```bash
|
||||
pyrefly check --summarize-errors
|
||||
```
|
||||
|
||||
**Learn More:**
|
||||
- [Pyrefly Configuration](https://pyrefly.org/en/docs/configuration/) - Detailed configuration options
|
||||
- [Pyrefly IDE Features](https://pyrefly.org/en/docs/IDE-features/) - Set up Pyrefly in your editor for real-time type checking
|
||||
- [Python Typing Tutorial](https://pyrefly.org/en/docs/typing-for-python-developers/) - Learn about Python type annotations
|
||||
|
||||
See [Guide for adding type annotations to
|
||||
PyTorch](https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch)
|
||||
for more information on how to set up `mypy` and tackle type annotation
|
||||
tasks.
|
||||
for PyTorch-specific guidance on how to set up `pyrefly` and tackle type annotation tasks in this codebase.
|
||||
|
||||
### C++ Unit Testing
|
||||
|
||||
|
||||
@ -388,6 +388,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
||||
#ifndef USE_ROCM
|
||||
at::Half halpha;
|
||||
at::Half hbeta;
|
||||
uint32_t mask = -1;
|
||||
#endif
|
||||
void * alpha_ptr = α
|
||||
void * beta_ptr = β
|
||||
@ -427,7 +428,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
||||
auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
|
||||
if (fp16_reduction !=
|
||||
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
|
||||
uint32_t mask =
|
||||
mask =
|
||||
fp16_reduction ==
|
||||
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
|
||||
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
|
||||
@ -444,7 +445,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
||||
auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
|
||||
if (bf16_reduction !=
|
||||
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
|
||||
uint32_t mask =
|
||||
mask =
|
||||
bf16_reduction ==
|
||||
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
|
||||
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
|
||||
@ -511,17 +512,41 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
||||
cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS;
|
||||
cublasLtMatmulHeuristicResult_t heuristicResult = {};
|
||||
int returnedResult = 0;
|
||||
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
|
||||
ltHandle,
|
||||
computeDesc.descriptor(),
|
||||
Adesc.descriptor(),
|
||||
Bdesc.descriptor(),
|
||||
Cdesc.descriptor(),
|
||||
Cdesc.descriptor(),
|
||||
preference.descriptor(),
|
||||
1,
|
||||
&heuristicResult,
|
||||
&returnedResult));
|
||||
// on Blackwell+, we fake an n > 1 matmul when querying heuristics
|
||||
// to prevent cuBLASLt from dispatching to a GEMV kernel for batch-invariance
|
||||
#ifndef USE_ROCM
|
||||
const bool lie_to_cublaslt = mask == CUBLASLT_REDUCTION_SCHEME_NONE && n == 1 && at::cuda::getCurrentDeviceProperties()->major >= 10;
|
||||
#else
|
||||
const bool lie_to_cublaslt = false;
|
||||
#endif
|
||||
if (lie_to_cublaslt) {
|
||||
CuBlasLtMatrixLayout FakeBdesc(abType, k, 2, ldb, opb == CUBLAS_OP_T);
|
||||
CuBlasLtMatrixLayout FakeCdesc(cType, m, 2, ldc);
|
||||
|
||||
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
|
||||
ltHandle,
|
||||
computeDesc.descriptor(),
|
||||
Adesc.descriptor(),
|
||||
FakeBdesc.descriptor(),
|
||||
FakeCdesc.descriptor(),
|
||||
FakeCdesc.descriptor(),
|
||||
preference.descriptor(),
|
||||
1,
|
||||
&heuristicResult,
|
||||
&returnedResult));
|
||||
} else {
|
||||
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
|
||||
ltHandle,
|
||||
computeDesc.descriptor(),
|
||||
Adesc.descriptor(),
|
||||
Bdesc.descriptor(),
|
||||
Cdesc.descriptor(),
|
||||
Cdesc.descriptor(),
|
||||
preference.descriptor(),
|
||||
1,
|
||||
&heuristicResult,
|
||||
&returnedResult));
|
||||
}
|
||||
if (returnedResult == 0) {
|
||||
cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED;
|
||||
}
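The fake-B-descriptor query above is meant to keep n == 1 GEMMs on the same kernel family as wider GEMMs when reduced-precision reductions are disallowed, i.e. to preserve batch invariance. A rough way to probe that property from Python — a sketch that assumes a CUDA fp16 build, and treats bitwise equality as the goal rather than something guaranteed on every backend or GPU:

```python
import torch

# Disallow reduced-precision fp16 reductions so both shapes should use
# comparable reduction schemes.
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

A = torch.randn(128, 256, device="cuda", dtype=torch.float16)
B = torch.randn(256, 64, device="cuda", dtype=torch.float16)

full = A @ B             # n == 64 matmul
narrow = A @ B[:, :1]    # n == 1 matmul (GEMV-shaped)

# Batch invariance: column 0 should be bit-identical in both results.
print(torch.equal(full[:, :1], narrow))
```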
|
||||
|
||||
@ -24,7 +24,13 @@ namespace detail {
|
||||
// radix_sort_pairs doesn't interact with value_t other than to copy
|
||||
// the data, so we can save template instantiations by reinterpreting
|
||||
// it as an opaque type.
|
||||
// We use native integer types for 1/2/4/8-byte values to reduce
|
||||
// register usage in CUDA kernels. For sizes > 8 fall back to char array.
|
||||
template <int N> struct alignas(N) OpaqueType { char data[N]; };
|
||||
template <> struct alignas(1) OpaqueType<1> { uint8_t data; };
|
||||
template <> struct alignas(2) OpaqueType<2> { uint16_t data; };
|
||||
template <> struct alignas(4) OpaqueType<4> { uint32_t data; };
|
||||
template <> struct alignas(8) OpaqueType<8> { uint64_t data; };
|
||||
|
||||
template<typename key_t, int value_size>
|
||||
void radix_sort_pairs_impl(
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#include <ATen/core/ATen_fwd.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/core/SymInt.h>
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
@ -1710,11 +1711,37 @@ Tensor narrow_symint(
|
||||
"], but got ",
|
||||
start,
|
||||
")")
|
||||
if (start < 0) {
|
||||
start = start + cur_size;
|
||||
|
||||
auto cond1 = TORCH_GUARD_OR_FALSE(start.sym_lt(0));
|
||||
auto cond2 = TORCH_GUARD_OR_FALSE(start.sym_ge(0));
|
||||
|
||||
if (cond1 || cond2) {
|
||||
if (cond1) {
|
||||
start = start + cur_size;
|
||||
}
|
||||
|
||||
TORCH_SYM_CHECK(
|
||||
start.sym_le(cur_size - length),
|
||||
"start (",
|
||||
start,
|
||||
") + length (",
|
||||
length,
|
||||
") exceeds dimension size (",
|
||||
cur_size,
|
||||
").");
|
||||
return at::slice_symint(self, dim, start, start + length, 1);
|
||||
}
|
||||
|
||||
// Unbacked start handling!
|
||||
|
||||
// Bounds check without converting start:
|
||||
// - If start < 0: need (start + cur_size) + length <= cur_size, i.e., start +
|
||||
// length <= 0
|
||||
// - If start >= 0: need start + length <= cur_size
|
||||
auto end = start + length;
|
||||
TORCH_SYM_CHECK(
|
||||
start.sym_le(cur_size - length),
|
||||
(start.sym_lt(0).sym_and((end).sym_le(0)))
|
||||
.sym_or(start.sym_ge(0).sym_and((end).sym_le(cur_size))),
|
||||
"start (",
|
||||
start,
|
||||
") + length (",
|
||||
@ -1722,7 +1749,28 @@ Tensor narrow_symint(
|
||||
") exceeds dimension size (",
|
||||
cur_size,
|
||||
").");
|
||||
return at::slice_symint(self, dim, start, start + length, 1);
|
||||
|
||||
if (TORCH_GUARD_OR_FALSE(end.sym_ne(0))) {
|
||||
return at::slice_symint(self, dim, start, end, 1);
|
||||
} else {
|
||||
// Cannot statically determine the condition due to unbacked.
|
||||
// This is an interesting situation; when start is negative and
|
||||
// start + length == 0, slice and narrow do different things.
|
||||
// i.e., x.narrow(0, -2, 2) != x[-2:0]; in that case, we want to
|
||||
// pass curr_size instead of 0. Otherwise, they would do the same thing.
|
||||
// This says at runtime: if start < 0 and end == 0, then pass curr_size
|
||||
// instead of 0.
|
||||
|
||||
auto use_different = start.sym_lt(0).sym_and(end.sym_eq(0)).toSymInt();
|
||||
auto result =
|
||||
at::slice_symint(self, dim, start, end + use_different * cur_size, 1);
|
||||
|
||||
// Ensure slice allocated unbacked size is specialized to length.
|
||||
SymInt new_size = result.sym_size(dim);
|
||||
TORCH_SYM_CHECK(new_size.sym_eq(length), "")
|
||||
|
||||
return result;
|
||||
}
|
||||
}
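The corner case called out in the comments above is easy to see from Python; a small demonstration of why `narrow` with a negative start is not the same as the naive slice when start + length == 0:

```python
import torch

x = torch.arange(10).reshape(5, 2)

# narrow wraps a negative start: this selects rows 3 and 4 of x.
print(x.narrow(0, -2, 2).shape)  # torch.Size([2, 2])

# The naive translation x[-2:0] is empty, because the end index 0 is not
# remapped; this is exactly the start < 0, start + length == 0 case for which
# the unbacked handling passes cur_size instead of 0.
print(x[-2:0].shape)             # torch.Size([0, 2])
```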
|
||||
|
||||
// This overload exists purely for XLA, because they wanted to pass in
|
||||
|
||||
@ -247,8 +247,8 @@ void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) {
|
||||
});
|
||||
}
|
||||
|
||||
template <typename func_t, typename vec_func_t>
|
||||
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) {
|
||||
template <typename func_t, typename vec_func_t, typename ident_t = double>
|
||||
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, ident_t ident = static_cast<ident_t>(0)) {
|
||||
using traits = binary_function_traits<func_t>;
|
||||
static_assert(
|
||||
all_same<
|
||||
|
||||
@ -339,33 +339,13 @@ void or_kernel_impl(TensorIterator& iter) {
|
||||
}
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
struct MinValuesOps: public at::native::MinOps<scalar_t> {
|
||||
using arg_t = typename MinOps<scalar_t>::arg_t;
|
||||
static scalar_t project(arg_t arg) {
|
||||
return arg.first;
|
||||
}
|
||||
};
|
||||
|
||||
void min_values_kernel_impl(TensorIterator& iter) {
|
||||
// This case is special because Vectorized<int64_t> does not
|
||||
// handle upper_bound<int64_t>().
|
||||
// See: https://github.com/pytorch/pytorch/issues/43254
|
||||
if (iter.dtype() == kLong || iter.dtype() == kUInt64) {
|
||||
AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
|
||||
binary_kernel_reduce(
|
||||
iter,
|
||||
MinValuesOps<scalar_t>{},
|
||||
std::pair<scalar_t, int64_t>(upper_bound<scalar_t>(), -1));
|
||||
}), kLong, kUInt64);
|
||||
return;
|
||||
}
|
||||
AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
|
||||
binary_kernel_reduce_vec(
|
||||
iter,
|
||||
[](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); },
|
||||
[](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return minimum(a, b); },
|
||||
static_cast<double>(upper_bound<scalar_t>()));
|
||||
upper_bound<scalar_t>());
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
|
||||
}
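A likely motivation for letting the identity keep its own type (the new `ident_t` parameter) instead of always passing a `double`: the int64/uint64 upper bound does not survive a round-trip through `double`. A quick check:

```python
int64_max = 2**63 - 1
print(int64_max)               # 9223372036854775807
print(int(float(int64_max)))   # 9223372036854775808 (rounded up; off by one)
```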
|
||||
|
||||
|
||||
@ -22,6 +22,9 @@
|
||||
#include <ATen/native/cuda/RowwiseScaledMM.h>
|
||||
#include <ATen/native/cuda/ScaledGroupMM.h>
|
||||
#include <ATen/native/cuda/GroupMM.h>
|
||||
#ifdef USE_ROCM
|
||||
#include <ATen/native/hip/ck_group_gemm.h>
|
||||
#endif
|
||||
#include <ATen/ceil_div.h>
|
||||
|
||||
#ifdef USE_FBGEMM_GENAI
|
||||
@ -666,12 +669,19 @@ std::optional<c10::ScalarType> out_dtype) {
|
||||
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
|
||||
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
|
||||
bool use_fast_path = false;
|
||||
if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
|
||||
use_fast_path = true;
|
||||
}
|
||||
#endif
|
||||
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
if (use_fast_path) {
|
||||
// fast path, no d2h sync needed
|
||||
#ifndef USE_ROCM
|
||||
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
|
||||
#else
|
||||
at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
|
||||
#endif
|
||||
} else {
|
||||
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
|
||||
}
|
||||
|
||||
aten/src/ATen/native/hip/ck_group_gemm.h (new file, 19 lines)
@ -0,0 +1,19 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <optional>
|
||||
|
||||
namespace at {
|
||||
namespace hip {
|
||||
namespace detail {
|
||||
void group_gemm_ck(
|
||||
const at::Tensor& mat_a,
|
||||
const at::Tensor& mat_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
at::Tensor& out);
|
||||
|
||||
} // namespace detail
|
||||
} // namespace hip
|
||||
} // namespace at
|
||||
aten/src/ATen/native/hip/ck_group_gemm.hip (new file, 462 lines)
@ -0,0 +1,462 @@
|
||||
#undef __HIP_NO_HALF_CONVERSIONS__
|
||||
#include <ATen/hip/HIPContext.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/TensorAccessor.h>
|
||||
#include <c10/hip/HIPStream.h>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
||||
#include <ck/ck.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp>
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
#include <ck/utility/tuple.hpp>
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
namespace at {
|
||||
namespace hip {
|
||||
namespace detail {
|
||||
|
||||
namespace CkTypes {
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename DataType>
|
||||
using GroupedGemmKernel = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<
|
||||
ALayout, BLayout, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor,
|
||||
DataType, DataType, CkTypes::F32, DataType, ck::Tuple<>, DataType,
|
||||
CkTypes::PassThrough, CkTypes::PassThrough, CkTypes::PassThrough,
|
||||
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
|
||||
1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2,
|
||||
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
|
||||
3, 8, 8, 1,
|
||||
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
|
||||
3, 8, 8, 1,
|
||||
1, 1,
|
||||
S<1,32,1,8>, 4
|
||||
>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename DataType>
|
||||
void launch_grouped_bgemm_ck_impl_dispatch(
|
||||
const at::Tensor& mat_a,
|
||||
const at::Tensor& mat_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
at::Tensor& out)
|
||||
{
|
||||
using DeviceOp = GroupedGemmKernel<ALayout, BLayout, DataType>;
|
||||
using PassThrough = CkTypes::PassThrough;
|
||||
|
||||
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
|
||||
std::vector<const void*> p_a_ptrs, p_b_ptrs;
|
||||
std::vector<void*> p_e_ptrs;
|
||||
// Note: d_ptrs will be resized after we populate the other vectors
|
||||
|
||||
const int mat_a_dim = mat_a.dim();
|
||||
const int mat_b_dim = mat_b.dim();
|
||||
|
||||
const char* a_ptr_base = reinterpret_cast<const char*>(mat_a.data_ptr());
|
||||
const char* b_ptr_base = reinterpret_cast<const char*>(mat_b.data_ptr());
|
||||
char* out_ptr_base = reinterpret_cast<char*>(out.data_ptr());
|
||||
const size_t a_element_size = mat_a.element_size();
|
||||
const size_t b_element_size = mat_b.element_size();
|
||||
const size_t out_element_size = out.element_size();
|
||||
|
||||
// for each group, calculate m,n,k,lda,ldb,ldc and A,B,out pointer base addresses.
|
||||
if (mat_a_dim == 2 && mat_b_dim == 2) {
|
||||
// 2D*2D case requires offset tensor
|
||||
auto offs_accessor = offs->accessor<int, 1>();
|
||||
int num_groups = offs_accessor.size(0);
|
||||
const int M = mat_a.size(0); // number of rows in A
|
||||
const int N = mat_b.size(1); // number of columns in B
|
||||
const int K = mat_a.size(1); // columns in A == rows in B
|
||||
// for 2d*2d input, output is 3d.
|
||||
// for each group, A columns (K) are sliced. M and N dimensions are not sliced.
|
||||
for (int i = 0; i < num_groups; ++i) {
|
||||
int start_k = (i == 0) ? 0 : offs_accessor[i-1];
|
||||
int end_k = offs_accessor[i];
|
||||
int k = end_k - start_k;
|
||||
|
||||
// The K dimension is sliced, hence always select stride(1).
|
||||
//K dimension is always dimension 1, regardless of memory layout (row/column major)
|
||||
const void* group_a_ptr = a_ptr_base + start_k * mat_a.stride(1) * a_element_size;
|
||||
const void* group_b_ptr;
|
||||
int ldb;
|
||||
|
||||
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major B [K,N]: K values are horizontally adjacent, use stride(1) for K offset
|
||||
group_b_ptr = b_ptr_base + start_k * mat_b.stride(1) * b_element_size;
|
||||
// Leading dimension = distance between rows = stride(0)
|
||||
ldb = mat_b.stride(0);
|
||||
} else {
|
||||
// Column-major B [K,N]: K values are vertically adjacent, use stride(0) for K offset
|
||||
group_b_ptr = b_ptr_base + start_k * mat_b.stride(0) * b_element_size;
|
||||
// Leading dimension = distance between columns = stride(1)
|
||||
ldb = mat_b.stride(1);
|
||||
}
|
||||
|
||||
// Calculate output pointer for group i in 3D tensor [num_groups, M, N]
|
||||
// stride(0) = M*N elements between groups, so skip i*stride(0) elements to reach group i
|
||||
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
|
||||
int lda, ldc;
|
||||
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major A [M,K]: leading dimension = distance between rows = stride(0)
|
||||
lda = mat_a.stride(0);
|
||||
} else {
|
||||
// Column-major A [M,K]: leading dimension = distance between columns = stride(1)
|
||||
lda = mat_a.stride(1);
|
||||
}
|
||||
// Output is always row-major in 3D tensor [num_groups, M, N]
|
||||
// Leading dimension for each group's [M,N] slice = stride(1) = N
|
||||
ldc = out.stride(1);
|
||||
size_t output_group_bytes = M * N * out_element_size;
|
||||
void* group_e_ptr_end = (char*)group_e_ptr + output_group_bytes;
|
||||
|
||||
gemm_descs.push_back({
|
||||
static_cast<ck::index_t>(M),
|
||||
static_cast<ck::index_t>(N),
|
||||
static_cast<ck::index_t>(k),
|
||||
static_cast<ck::index_t>(lda),
|
||||
static_cast<ck::index_t>(ldb),
|
||||
static_cast<ck::index_t>(ldc),
|
||||
{} // --> stride_Ds_
|
||||
});
|
||||
p_a_ptrs.push_back(group_a_ptr);
|
||||
p_b_ptrs.push_back(group_b_ptr);
|
||||
p_e_ptrs.push_back(group_e_ptr);
|
||||
}
|
||||
} else if (mat_a_dim == 2 && mat_b_dim == 3) {
|
||||
// 2D*3D case requires offset tensor
|
||||
auto offs_accessor = offs->accessor<int, 1>();
|
||||
int num_groups = offs_accessor.size(0);
|
||||
|
||||
// 2d*3d input, output is 2d.
|
||||
// A: [m * n_groups, k], B: [n_groups, n, k] or [n_groups, k, n], Output: [m * n_groups, n]
|
||||
// Offset divides M dimension (rows of A), each group gets different rows of A and different batch of B
|
||||
const int K = mat_a.size(1); // columns in A
|
||||
// For 2D-3D case: The output determines N (result width)
|
||||
const int N = out.size(1); // N is the width of the output tensor
|
||||
|
||||
for (int i = 0; i < num_groups; ++i) {
|
||||
int start_m = (i == 0) ? 0 : offs_accessor[i - 1];
|
||||
int end_m = offs_accessor[i];
|
||||
int m = end_m - start_m;
|
||||
|
||||
// Skip zero-sized groups but continue processing subsequent groups
|
||||
if (m <= 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Select A rows for group i: skip start_m rows
|
||||
const void* group_a_ptr;
|
||||
int lda;
|
||||
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major A [total_m, K]: skip start_m rows, each row is stride(0) elements apart
|
||||
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
|
||||
lda = mat_a.stride(0); // distance between rows
|
||||
} else {
|
||||
// Column-major A [total_m, K]: skip start_m elements in the first dimension (stride(0) is between rows)
|
||||
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
|
||||
|
||||
// Detect stride pattern for A tensor to determine appropriate lda calculation
|
||||
bool a_is_strided_tensor = (mat_a.stride(0) > mat_a.size(0));
|
||||
|
||||
if (a_is_strided_tensor) {
|
||||
// For strided A tensors: stride(0) gives the actual leading dimension
|
||||
lda = mat_a.stride(0);
|
||||
} else {
|
||||
// For non-strided A tensors: use the M dimension (total rows)
|
||||
lda = mat_a.size(0); // Total M dimension for column-major layout
|
||||
}
|
||||
}
|
||||
|
||||
// Select B batch for group i: B[i, :, :]
|
||||
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
|
||||
int ldb;
|
||||
|
||||
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major GEMM: expecting B as [K, N] but we have [N, K], so transpose needed
|
||||
ldb = mat_b.stride(2); // Leading dimension for accessing as [K, N]
|
||||
} else {
|
||||
// Detect stride pattern to determine appropriate ldb calculation
|
||||
bool is_strided_tensor = (mat_b.stride(2) > mat_b.size(2));
|
||||
|
||||
if (is_strided_tensor) {
|
||||
// For strided tensors: stride(2) gives the actual leading dimension
|
||||
ldb = mat_b.stride(2);
|
||||
} else {
|
||||
// For non-strided tensors: use the N dimension
|
||||
ldb = mat_b.size(1);
|
||||
}
|
||||
}
|
||||
|
||||
// Output for this group: rows [start_m:end_m, :] in 2D output [total_m, N]
|
||||
void* group_e_ptr = out_ptr_base + start_m * out.stride(0) * out_element_size;
|
||||
int ldc = out.stride(0); // distance between rows in output (should be N for 2D case)
|
||||
|
||||
gemm_descs.push_back({
|
||||
static_cast<ck::index_t>(m),
|
||||
static_cast<ck::index_t>(N),
|
||||
static_cast<ck::index_t>(K),
|
||||
static_cast<ck::index_t>(lda),
|
||||
static_cast<ck::index_t>(ldb),
|
||||
static_cast<ck::index_t>(ldc),
|
||||
{} // --> stride_Ds_
|
||||
});
|
||||
p_a_ptrs.push_back(group_a_ptr);
|
||||
p_b_ptrs.push_back(group_b_ptr);
|
||||
p_e_ptrs.push_back(group_e_ptr);
|
||||
}
|
||||
} else if (mat_a_dim == 3 && mat_b_dim == 3) {
|
||||
// 3d*3d input, output is 3d - batched matrix multiplication
|
||||
// A: [batch, m, k], B: [batch, k, n] or [batch, n, k] (depending on transpose), Output: [batch, m, n]
|
||||
// Each batch is processed as a separate GEMM operation
|
||||
const int batch_size = mat_a.size(0);
|
||||
const int M = mat_a.size(1); // rows in each A matrix
|
||||
const int K = mat_a.size(2); // columns in A == rows in B (or columns if B is transposed)
|
||||
|
||||
// Determine N from B tensor - it could be B.size(1) or B.size(2) depending on layout
|
||||
int N;
|
||||
if (mat_b.size(1) == K) {
|
||||
// B is [batch, k, n] - normal layout
|
||||
N = mat_b.size(2);
|
||||
} else if (mat_b.size(2) == K) {
|
||||
// B is [batch, n, k] - transposed layout
|
||||
N = mat_b.size(1);
|
||||
} else {
|
||||
TORCH_CHECK(false, "CK Group GEMM 3D-3D: B tensor dimensions incompatible with A. A=[",
|
||||
batch_size, ",", M, ",", K, "], B=[", mat_b.size(0), ",", mat_b.size(1), ",", mat_b.size(2), "]");
|
||||
}
|
||||
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
// Select A batch for group i: A[i, :, :]
|
||||
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
|
||||
|
||||
// Select B batch for group i: B[i, :, :]
|
||||
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
|
||||
|
||||
// Select output batch for group i: Output[i, :, :]
|
||||
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
|
||||
|
||||
int lda, ldb, ldc;
|
||||
|
||||
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major A: leading dimension = distance between rows = stride(1)
|
||||
lda = mat_a.stride(1);
|
||||
} else {
|
||||
// Column-major A: leading dimension = distance between columns = stride(2)
|
||||
lda = mat_a.stride(2);
|
||||
}
|
||||
|
||||
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major B: leading dimension = distance between rows
|
||||
if (mat_b.size(1) == K) {
|
||||
// B is [batch, k, n] - normal layout
|
||||
ldb = mat_b.stride(1); // stride between K rows
|
||||
} else {
|
||||
// B is [batch, n, k] - transposed layout, treat as [k, n] for GEMM
|
||||
ldb = mat_b.stride(2); // stride between N rows (since we're accessing as [k,n])
|
||||
}
|
||||
} else {
|
||||
// Column-major B: leading dimension = distance between columns
|
||||
if (mat_b.size(1) == K) {
|
||||
// B is [batch, k, n] - normal layout
|
||||
ldb = mat_b.stride(2); // stride between N columns
|
||||
} else {
|
||||
// B is [batch, n, k] - transposed layout
|
||||
ldb = mat_b.stride(1); // stride between K columns (since we're accessing as [n,k]→[k,n])
|
||||
}
|
||||
}
|
||||
|
||||
// Output is typically row-major: leading dimension = distance between rows = stride(1)
|
||||
ldc = out.stride(1);
|
||||
|
||||
gemm_descs.push_back({
|
||||
static_cast<ck::index_t>(M),
|
||||
static_cast<ck::index_t>(N),
|
||||
static_cast<ck::index_t>(K),
|
||||
static_cast<ck::index_t>(lda),
|
||||
static_cast<ck::index_t>(ldb),
|
||||
static_cast<ck::index_t>(ldc),
|
||||
{} // --> stride_Ds_
|
||||
});
|
||||
p_a_ptrs.push_back(group_a_ptr);
|
||||
p_b_ptrs.push_back(group_b_ptr);
|
||||
p_e_ptrs.push_back(group_e_ptr);
|
||||
}
|
||||
} else if (mat_a_dim == 3 && mat_b_dim == 2) {
|
||||
// 3D*2D case requires offset tensor
|
||||
auto offs_accessor = offs->accessor<int, 1>();
|
||||
int num_groups = offs_accessor.size(0);
|
||||
// 3d*2d input, output is 3d.
|
||||
// A: [n_groups, m, k], B: [k, total_n] (assuming row-major for both)
|
||||
// Offset divides N dimension of B, each group gets different slice of B and different batch of A
|
||||
const int batch_size = mat_a.size(0); // n_groups
|
||||
const int M = mat_a.size(1); // rows in each A matrix
|
||||
const int K = mat_a.size(2); // columns in A
|
||||
|
||||
// For row-major A and B case: B should be [K, total_N]
|
||||
const int total_N = mat_b.size(1); // B is [K, total_N] for row-major
|
||||
|
||||
for (int i = 0; i < num_groups; ++i) {
|
||||
int start_n = (i == 0) ? 0 : offs_accessor[i - 1];
|
||||
int end_n = offs_accessor[i];
|
||||
int n = end_n - start_n;
|
||||
|
||||
// Skip zero-sized groups but continue processing subsequent groups
|
||||
if (n <= 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Select A batch for group i: A[i, :, :]
|
||||
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
|
||||
|
||||
// Select B slice for group i: B[:, start_n:end_n] (B[K, total_N])
|
||||
const void* group_b_ptr;
|
||||
int ldb;
|
||||
|
||||
// Check if B is row-major or column-major
|
||||
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major B [K, total_N]: slice columns [start_n:end_n]
|
||||
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
|
||||
ldb = mat_b.stride(0); // distance between rows (should be total_N)
|
||||
} else {
|
||||
// Column-major B [K, total_N]: slice columns [start_n:end_n]
|
||||
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
|
||||
ldb = mat_b.stride(1); // distance between columns (should be K)
|
||||
}
|
||||
|
||||
// Select output slice for group i: Output[:, start_n:end_n]
|
||||
void* group_e_ptr = out_ptr_base + start_n * out.stride(1) * out_element_size;
|
||||
|
||||
int lda, ldc;
|
||||
|
||||
// Row-major A: leading dimension = distance between rows = stride(1)
|
||||
lda = mat_a.stride(1);
|
||||
// Output is row-major: leading dimension = distance between rows = stride(0)
|
||||
ldc = out.stride(0);
|
||||
|
||||
gemm_descs.push_back({
|
||||
static_cast<ck::index_t>(M),
|
||||
static_cast<ck::index_t>(n),
|
||||
static_cast<ck::index_t>(K),
|
||||
static_cast<ck::index_t>(lda),
|
||||
static_cast<ck::index_t>(ldb),
|
||||
static_cast<ck::index_t>(ldc),
|
||||
{} // --> stride_Ds_
|
||||
});
|
||||
p_a_ptrs.push_back(group_a_ptr);
|
||||
p_b_ptrs.push_back(group_b_ptr);
|
||||
p_e_ptrs.push_back(group_e_ptr);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "CK Group GEMM: Unsupported dimensions, mat A dim is ", mat_a_dim, ", mat B dim is ", mat_b_dim);
|
||||
}
|
||||
|
||||
TORCH_INTERNAL_ASSERT(p_a_ptrs.size() > 0, "CK Group GEMM: No valid groups");
|
||||
|
||||
// Initialize d_ptrs with the correct size
|
||||
std::vector<std::array<const void*, 0>> d_ptrs(p_a_ptrs.size());
|
||||
|
||||
static DeviceOp gemm_instance;
|
||||
auto argument = gemm_instance.MakeArgument(
|
||||
p_a_ptrs, p_b_ptrs, d_ptrs, p_e_ptrs,
|
||||
gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}
|
||||
);
|
||||
TORCH_INTERNAL_ASSERT(gemm_instance.IsSupportedArgument(argument),
|
||||
"CK Group GEMM: argument unsupported (shape/strides/type config)");
|
||||
size_t arg_buf_size = gemm_instance.GetDeviceKernelArgSize(&argument);
|
||||
size_t ws_size = gemm_instance.GetWorkSpaceSize(&argument);
|
||||
|
||||
void* gemm_arg_buf = nullptr;
|
||||
void* ws_buf = nullptr;
|
||||
|
||||
hipMalloc(&gemm_arg_buf, arg_buf_size);
|
||||
hipMalloc(&ws_buf, ws_size);
|
||||
|
||||
gemm_instance.SetDeviceKernelArgs(&argument, gemm_arg_buf);
|
||||
gemm_instance.SetWorkSpacePointer(&argument, ws_buf);
|
||||
|
||||
auto invoker = gemm_instance.MakeInvoker();
|
||||
hipStream_t stream = c10::hip::getCurrentHIPStream();
|
||||
invoker.Run(argument, {stream});
|
||||
hipFree(gemm_arg_buf);
|
||||
hipFree(ws_buf);
|
||||
}
|
||||
|
||||
void group_gemm_ck(
|
||||
const at::Tensor& input_a,
|
||||
const at::Tensor& input_b_colmajor,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
const std::optional<at::Tensor>& /*bias*/,
|
||||
at::Tensor& out)
|
||||
{
|
||||
// Detect if input_a is row-major based on stride pattern
|
||||
bool a_row_major = (input_a.dim() == 3) ? (input_a.stride(2) == 1) : (input_a.stride(1) == 1);
|
||||
bool b_col_major = (input_b_colmajor.dim() == 3) ? (input_b_colmajor.stride(1) == 1) : (input_b_colmajor.stride(0) == 1);
|
||||
// Ensure tensor A is row-major and contiguous if not already
|
||||
at::Tensor mat_a = input_a;
|
||||
if (!a_row_major) {
|
||||
// If A is not row-major, make it contiguous (row-major)
|
||||
mat_a = input_a.contiguous();
|
||||
}
|
||||
// Force tensor B to be column-major using double transpose trick
|
||||
// This guarantees stride(0) == 1 and stride(1) == K for [K, N] shape
|
||||
at::Tensor mat_b = input_b_colmajor;
|
||||
if (!b_col_major) {
|
||||
mat_b = input_b_colmajor.transpose(-2, -1).contiguous().transpose(-2, -1);
|
||||
}
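The double-transpose trick used above can be checked quickly in Python: it keeps the logical shape and values but leaves the tensor with column-major strides, which is what the CK kernels expect for B.

```python
import torch

b = torch.randn(4, 6)                                  # row-major, strides (6, 1)
b_col = b.transpose(-2, -1).contiguous().transpose(-2, -1)

print(b_col.shape, b_col.stride())                     # torch.Size([4, 6]) (1, 4)
print(torch.equal(b, b_col))                           # True: same values
```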
|
||||
|
||||
// For 3D tensors, check the last dimension stride for row-major detection
|
||||
a_row_major = (mat_a.dim() == 3) ? (mat_a.stride(2) == 1) : (mat_a.stride(1) == 1);
|
||||
bool b_row_major = (mat_b.dim() == 3) ? (mat_b.stride(2) == 1) : (mat_b.stride(1) == 1);
|
||||
|
||||
if (mat_a.dtype() == at::kBFloat16) {
|
||||
// bf16 path
|
||||
if (a_row_major && b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
|
||||
} else if (a_row_major && !b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
|
||||
} else if (!a_row_major && b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
|
||||
} else {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
|
||||
}
|
||||
} else if (mat_a.dtype() == at::kHalf) {
|
||||
// fp16 path
|
||||
if (a_row_major && b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
|
||||
} else if (a_row_major && !b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
|
||||
} else if (!a_row_major && b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
|
||||
} else {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
|
||||
}
|
||||
} else if (mat_a.dtype() == at::kFloat) {
|
||||
// fp32 path
|
||||
if (a_row_major && b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
|
||||
} else if (a_row_major && !b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
|
||||
} else if (!a_row_major && b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
|
||||
} else {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "CK Group GEMM: Unsupported mat_a dtype");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace hip
|
||||
} // namespace at
|
||||
@ -1,4 +1,5 @@
|
||||
#include <c10/core/SymBool.h>
|
||||
#include <c10/core/SymInt.h>
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
|
||||
namespace c10 {
|
||||
@ -111,4 +112,17 @@ bool SymBool::has_hint() const {
|
||||
return toSymNodeImpl()->has_hint();
|
||||
}
|
||||
|
||||
SymInt SymBool::toSymInt() const {
|
||||
// If concrete bool, return concrete SymInt
|
||||
if (auto ma = maybe_as_bool()) {
|
||||
return SymInt(*ma ? 1 : 0);
|
||||
}
|
||||
|
||||
// Symbolic case: use sym_ite to convert bool to int (0 or 1)
|
||||
auto node = toSymNodeImpl();
|
||||
auto one_node = node->wrap_int(1);
|
||||
auto zero_node = node->wrap_int(0);
|
||||
return SymInt(node->sym_ite(one_node, zero_node));
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
@ -12,6 +12,8 @@
|
||||
|
||||
namespace c10 {
|
||||
|
||||
class SymInt;
|
||||
|
||||
class C10_API SymBool {
|
||||
public:
|
||||
/*implicit*/ SymBool(bool b) : data_(b) {}
|
||||
@ -80,6 +82,10 @@ class C10_API SymBool {
|
||||
return toSymNodeImplUnowned()->constant_bool();
|
||||
}
|
||||
|
||||
// Convert SymBool to SymInt (0 or 1)
|
||||
// This is the C++ equivalent of Python's cast_symbool_to_symint_guardless
|
||||
SymInt toSymInt() const;
|
||||
|
||||
bool is_heap_allocated() const {
|
||||
return ptr_;
|
||||
}
|
||||
|
||||
@ -47,20 +47,10 @@ Tensor sgd_out_of_place(
|
||||
STD_TORCH_CHECK(param.get_device() == -1, "CPU device index = -1");
|
||||
STD_TORCH_CHECK(param.get_device_index() == -1, "CPU device index = -1");
|
||||
|
||||
int64_t *param_sizes;
|
||||
int64_t *param_strides;
|
||||
aoti_torch_get_sizes(param.get(), ¶m_sizes);
|
||||
aoti_torch_get_strides(param.get(), ¶m_strides);
|
||||
// testing Tensor strides + stride
|
||||
STD_TORCH_CHECK(param.strides()[0] == param.stride(0));
|
||||
|
||||
int32_t param_dtype;
|
||||
aoti_torch_get_dtype(param.get(), ¶m_dtype);
|
||||
|
||||
int32_t param_device_type;
|
||||
aoti_torch_get_device_type(param.get(), ¶m_device_type);
|
||||
|
||||
AtenTensorHandle out_ath;
|
||||
aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath);
|
||||
auto out = Tensor(out_ath);
|
||||
auto out = new_empty(param, param.sizes());
|
||||
|
||||
sgd_math(
|
||||
reinterpret_cast<float*>(param.data_ptr()),
|
||||
@ -344,6 +334,8 @@ Tensor my_new_empty_dtype_variant(Tensor t) {
|
||||
// Still using a std::vector below even though people can just pass in an
|
||||
// initializer list (which will be implicitly converted to an HeaderOnlyArrayRef)
|
||||
// directly.
|
||||
// This is to test that passing in a std::vector works for BC. (It gets
|
||||
// implicitly converted to HeaderOnlyArrayRef too!)
|
||||
std::vector<int64_t> sizes = {2, 5};
|
||||
auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
|
||||
return new_empty(t, sizes, dtype);
|
||||
|
||||
@ -5,8 +5,16 @@ import contextlib
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
|
||||
from torch.distributed.tensor import (
|
||||
DeviceMesh,
|
||||
distribute_tensor,
|
||||
DTensor,
|
||||
Partial,
|
||||
Replicate,
|
||||
Shard,
|
||||
)
|
||||
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
@ -426,6 +434,31 @@ class TestDTensorDebugMode(TestCase):
|
||||
][-1]
|
||||
self.assertTrue("self.l2(self.l1(x))" in sum_op.fwd_stack_trace)
|
||||
|
||||
def test_pretty_print_dtensor_make_fx(self):
|
||||
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
|
||||
|
||||
A = torch.randn(8, 32)
|
||||
B = torch.randn(32, 32)
|
||||
dA = distribute_tensor(A, mesh, [Shard(0)]).requires_grad_()
|
||||
dB = distribute_tensor(B, mesh, [Replicate()]).requires_grad_()
|
||||
|
||||
def f(dA, dB):
|
||||
dy = dA @ dB
|
||||
loss = dy.sum()
|
||||
loss.backward()
|
||||
return dA.grad, dB.grad
|
||||
|
||||
# We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode.
|
||||
# make_fx has some logic to ensure we don't accidentally stash real tensors in the graph
|
||||
# so we won't stash our DTensors properly if they don't hold Fake inner tensors
|
||||
gm = make_fx(f, tracing_mode="fake")(dA, dB)
|
||||
# DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph
|
||||
gm.graph.eliminate_dead_code()
|
||||
gm.recompile()
|
||||
# Colored is nice for actual viewing, not using in this test though
|
||||
gm_str = gm.print_readable(colored=False, print_output=False)
|
||||
self.assertTrue('"DTensor(f32[8, 32], S(0))" = torch.ops.aten.mm' in gm_str)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestDTensorDebugMode)
|
||||
|
||||
|
||||
@ -5789,6 +5789,229 @@ class NCCLTraceTest(NCCLTraceTestBase):
|
||||
else:
|
||||
self.assertTrue("duration_ms" not in t["entries"][0])
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@parametrize("timing_enabled", [True, False])
|
||||
def test_fr_record_reset_circular_buffer_full(self, timing_enabled):
|
||||
"""
|
||||
Test that when the circular buffer in entries_ is full and we call reset,
|
||||
then fill the buffer with new entries, dump_entries returns only the new
|
||||
entries and not the old ones.
|
||||
"""
|
||||
if self.rank == self.MAIN_PROCESS_RANK:
|
||||
return
|
||||
|
||||
# Override buffer size to 10 for faster testing
|
||||
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
|
||||
|
||||
pg = self._create_process_group_nccl()
|
||||
if timing_enabled:
|
||||
pg._enable_collectives_timing()
|
||||
device = self.local_device
|
||||
self.set_thread_name("fr_test_thread")
|
||||
a = torch.full((3, 4), float(self.rank), device=device)
|
||||
|
||||
# Fill the buffer completely with 10 entries
|
||||
for _ in range(10):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# Verify buffer is full with 10 entries
|
||||
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
|
||||
self.assertEqual(len(t["entries"]), 10)
|
||||
|
||||
# Now reset the flight recorder
|
||||
torch._C._distributed_c10d._reset_fr_recording_nccl()
|
||||
|
||||
# Add new entries after reset - fill the buffer completely again
|
||||
for _ in range(10):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# Verify we get exactly 10 new entries, not 20
|
||||
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
|
||||
self.assertEqual(len(t["entries"]), 10)
|
||||
|
||||
# Verify all entries have the expected properties (from after reset)
|
||||
# After reset, record IDs should start from 0 again
|
||||
for i, entry in enumerate(t["entries"]):
|
||||
self.assertIn("profiling_name", entry)
|
||||
self.assertEqual(entry["profiling_name"], "nccl:all_reduce")
|
||||
self.assertIn("record_id", entry)
|
||||
# Record IDs should be sequential starting from 0 after reset
|
||||
self.assertEqual(entry["record_id"], i)
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@parametrize("timing_enabled", [True, False])
|
||||
def test_fr_record_reset_partial_overwrite(self, timing_enabled):
|
||||
"""
|
||||
Test that when the circular buffer is full, we reset, and then add fewer
|
||||
entries than the buffer size, we only get the new entries.
|
||||
This tests that old entries at the end of the circular buffer are properly
|
||||
filtered out based on reset_epoch.
|
||||
"""
|
||||
if self.rank == self.MAIN_PROCESS_RANK:
|
||||
return
|
||||
|
||||
# Override buffer size to 10 for faster testing
|
||||
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
|
||||
|
||||
pg = self._create_process_group_nccl()
|
||||
if timing_enabled:
|
||||
pg._enable_collectives_timing()
|
||||
device = self.local_device
|
||||
self.set_thread_name("fr_test_thread")
|
||||
a = torch.full((3, 4), float(self.rank), device=device)
|
||||
|
||||
# Fill the buffer completely
|
||||
for _ in range(10):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# Reset the flight recorder
|
||||
torch._C._distributed_c10d._reset_fr_recording_nccl()
|
||||
|
||||
# Add only 3 new entries (much less than buffer size)
|
||||
for _ in range(3):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# Verify we only get the 3 new entries, not 10
|
||||
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
|
||||
self.assertEqual(len(t["entries"]), 3)
|
||||
|
||||
# Verify record IDs start from 0 after reset
|
||||
for i, entry in enumerate(t["entries"]):
|
||||
self.assertIn("record_id", entry)
|
||||
self.assertEqual(entry["record_id"], i)
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@parametrize("timing_enabled", [True, False])
|
||||
def test_fr_record_reset_wraparound(self, timing_enabled):
|
||||
"""
|
||||
Test that when we reset in the middle of the circular buffer and then
|
||||
wrap around, dump_entries correctly returns only entries from the current
|
||||
epoch in the correct order.
|
||||
"""
|
||||
if self.rank == self.MAIN_PROCESS_RANK:
|
||||
return
|
||||
|
||||
# Override buffer size to 10 for faster testing
|
||||
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
|
||||
|
||||
pg = self._create_process_group_nccl()
|
||||
if timing_enabled:
|
||||
pg._enable_collectives_timing()
|
||||
device = self.local_device
|
||||
self.set_thread_name("fr_test_thread")
|
||||
a = torch.full((3, 4), float(self.rank), device=device)
|
||||
|
||||
# Fill half the buffer
|
||||
for _ in range(5):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# Reset at this point (reset happens at index 5)
|
||||
torch._C._distributed_c10d._reset_fr_recording_nccl()
|
||||
|
||||
# Now add 8 entries, which will wrap around
|
||||
# (5->9 fills rest of buffer, then 0->2 wraps around)
|
||||
for _ in range(8):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# Should get exactly 8 entries, properly ordered
|
||||
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
|
||||
self.assertEqual(len(t["entries"]), 8)
|
||||
|
||||
# Entries should be in chronological order
|
||||
# The dump_entries() method returns entries from next_ to end, then 0 to next_
|
||||
# After filtering old entries, we should have 8 entries in order
|
||||
# Verify record IDs start from 0 after reset (id_ is reset in reset_all())
|
||||
for i, entry in enumerate(t["entries"]):
|
||||
self.assertIn("profiling_name", entry)
|
||||
self.assertIn("record_id", entry)
|
||||
self.assertEqual(entry["record_id"], i)
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@parametrize("timing_enabled", [True, False])
|
||||
def test_fr_record_multiple_resets(self, timing_enabled):
|
||||
"""
|
||||
Test multiple consecutive resets to ensure each reset properly increments
|
||||
the epoch and filters out entries from previous epochs.
|
||||
"""
|
||||
if self.rank == self.MAIN_PROCESS_RANK:
|
||||
return
|
||||
|
||||
# Override buffer size to 10 for faster testing
|
||||
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
|
||||
|
||||
pg = self._create_process_group_nccl()
|
||||
if timing_enabled:
|
||||
pg._enable_collectives_timing()
|
||||
device = self.local_device
|
||||
self.set_thread_name("fr_test_thread")
|
||||
a = torch.full((3, 4), float(self.rank), device=device)
|
||||
|
||||
# First batch: 2 entries
|
||||
for _ in range(2):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# First reset
|
||||
torch._C._distributed_c10d._reset_fr_recording_nccl()
|
||||
|
||||
# Second batch: 3 entries
|
||||
for _ in range(3):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# Second reset
|
||||
torch._C._distributed_c10d._reset_fr_recording_nccl()
|
||||
|
||||
# Third batch: 4 entries
|
||||
for _ in range(4):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# Should only see the last 4 entries
|
||||
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
|
||||
self.assertEqual(len(t["entries"]), 4)
|
||||
|
||||
# Verify record IDs start from 0 after the last reset
|
||||
for i, entry in enumerate(t["entries"]):
|
||||
self.assertIn("record_id", entry)
|
||||
self.assertEqual(entry["record_id"], i)
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def check_if_test_is_skipped(fn):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
|
||||
@@ -8,21 +8,11 @@ from torch._dynamo.graph_deduplication import apply_graph_deduplication
from torch._dynamo.graph_utils import _detect_cycles
from torch._dynamo.output_graph import FakeRootModule
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import (
    AotEagerAndRecordGraphs,
    extract_graph_and_tracker,
    normalize_gm,
)
from torch._dynamo.testing import extract_graph, extract_graph_and_tracker, normalize_gm
from torch.compiler import allow_in_graph
from torch.utils._ordered_set import OrderedSet


def extract_graph(fn, *args, **kwargs):
    backend = AotEagerAndRecordGraphs()
    result = torch.compile(backend=backend)(fn)(*args, **kwargs)
    return result, backend.graphs, backend.fw_graphs


def graph_str(gm):
    return normalize_gm(gm.print_readable(print_output=False))

@@ -40,7 +30,7 @@ class GraphDededuplicationTests(TestCase):
        super().tearDown()

    def run_and_return_graphs(self, fn, *args, **kwargs):
        return extract_graph(fn, *args, **kwargs)
        return extract_graph(fn, *args, **kwargs)[0:3]

    def run_and_get_simple_graph(self):
        def fn(x, y):
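Note that the deduplication tests now reuse the shared extract_graph helper from torch._dynamo.testing (added later in this diff) instead of a local copy; the [0:3] slice keeps the old three-tuple shape because the shared helper also returns backward graphs. A hedged usage sketch, assuming the post-PR helper signature:

import torch
from torch._dynamo.testing import extract_graph, normalize_gm


def fn(x, y):
    return (x + y).relu()


# The shared helper compiles fn with an aot_eager backend that records graphs
# and returns (result, dynamo_graphs, fw_graphs, bw_graphs).
result, dynamo_graphs, fw_graphs, bw_graphs = extract_graph(
    fn, torch.randn(4), torch.randn(4)
)
print(len(dynamo_graphs), len(fw_graphs), len(bw_graphs))
print(normalize_gm(fw_graphs[0].print_readable(print_output=False)))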
@@ -1,7 +1,7 @@
# Owner(s): ["module: dynamo"]
import unittest
from collections.abc import Sequence
from typing import Any, Callable, Union
from collections.abc import Callable, Sequence
from typing import Any, Union

import torch
import torch._dynamo

@@ -1,5 +1,5 @@
# Owner(s): ["module: dynamo"]
from typing import Callable, NamedTuple, Optional
from typing import NamedTuple, Optional, TYPE_CHECKING

import torch
import torch._dynamo
@@ -7,6 +7,10 @@ from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter, same


if TYPE_CHECKING:
    from collections.abc import Callable


"""
This is an example of a pure-python version of autograd implemented by
@zdevito. It represents a rather challenging test case for TorchDynamo
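The import churn in these hunks follows one convention: Callable moves from typing to collections.abc, and imports needed only for annotations move under TYPE_CHECKING. A minimal sketch of the pattern:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by type checkers; nothing is imported at runtime.
    from collections.abc import Callable


def apply(fn: "Callable[[int], int]", x: int) -> int:
    return fn(x)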
@ -1,11 +1,13 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
import functools
|
||||
import re
|
||||
import unittest
|
||||
import weakref
|
||||
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
import torch._dynamo.testing
|
||||
from torch._dynamo.testing import extract_graph, remove_trailing_space
|
||||
from torch.testing._internal.common_cuda import TEST_MULTIGPU
|
||||
from torch.testing._internal.common_utils import requires_cuda
|
||||
|
||||
@ -15,6 +17,14 @@ requires_multigpu = functools.partial(
|
||||
)
|
||||
|
||||
|
||||
def remove_file_comment(gm_str: str) -> str:
|
||||
return remove_trailing_space(re.sub(r"File.*\n", "\n", gm_str))
|
||||
|
||||
|
||||
def print_graph(graph: torch.fx.GraphModule) -> str:
|
||||
return remove_file_comment(graph.print_readable())
|
||||
|
||||
|
||||
class TestStreams(torch._dynamo.test_case.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -36,9 +46,7 @@ class TestStreams(torch._dynamo.test_case.TestCase):
|
||||
|
||||
@requires_cuda
|
||||
def test_stream_enter_exit(self):
|
||||
def fn(x, y):
|
||||
s2 = torch.Stream()
|
||||
s1 = torch.Stream()
|
||||
def fn(x, y, s1, s2):
|
||||
with s1:
|
||||
z1 = torch.add(x, y)
|
||||
with s2:
|
||||
@ -47,13 +55,36 @@ class TestStreams(torch._dynamo.test_case.TestCase):
|
||||
|
||||
return y
|
||||
|
||||
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
|
||||
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2), torch.Stream(), torch.Stream())
|
||||
expected = fn(*inp)
|
||||
fn_opt = torch.compile(fn, fullgraph=True)
|
||||
actual = fn_opt(*inp)
|
||||
(
|
||||
actual,
|
||||
_,
|
||||
fw_graphs,
|
||||
_,
|
||||
) = extract_graph(fn, *inp)
|
||||
self.assertEqual(len(fw_graphs), 1)
|
||||
self.assertEqual(expected, actual)
|
||||
self.assertExpectedInline(
|
||||
print_graph(fw_graphs[0]),
|
||||
"""\
|
||||
class <lambda>(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
|
||||
# Annotation: {'stream': None}
|
||||
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
|
||||
|
||||
# Annotation: {'stream': None}
|
||||
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
|
||||
|
||||
# Annotation: {'stream': None}
|
||||
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
|
||||
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
|
||||
return (add_3,)
|
||||
""",
|
||||
)
|
||||
|
||||
@requires_cuda
|
||||
@unittest.skip("Needs graph break support with annotation context")
|
||||
def test_stream_context_graph_break(self):
|
||||
def fn(x, y):
|
||||
s2 = torch.Stream()
|
||||
@ -70,9 +101,16 @@ class TestStreams(torch._dynamo.test_case.TestCase):
|
||||
|
||||
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
|
||||
expected = fn(*inp)
|
||||
fn_opt = torch.compile(fn)
|
||||
actual = fn_opt(*inp)
|
||||
(
|
||||
actual,
|
||||
_,
|
||||
fw_graphs,
|
||||
_,
|
||||
) = extract_graph(fn, *inp)
|
||||
self.assertEqual(expected, actual)
|
||||
self.assertEqual(len(fw_graphs), 2)
|
||||
self.assertExpectedInline(print_graph(fw_graphs[0]), """""")
|
||||
self.assertExpectedInline(print_graph(fw_graphs[1]), """""")
|
||||
|
||||
@requires_cuda
|
||||
def test_stream_input(self):
|
||||
@ -155,22 +193,248 @@ class TestStreams(torch._dynamo.test_case.TestCase):
|
||||
self.assertEqual(s_act, s_exp)
|
||||
|
||||
def test_nested_stream_enter_exit(self):
|
||||
pass
|
||||
def fn(x, y, s0, s1, s2):
|
||||
with s1:
|
||||
with s2:
|
||||
z1 = torch.add(x, y)
|
||||
with s0:
|
||||
z0 = torch.add(x, y)
|
||||
with s2:
|
||||
y = 2 + z1
|
||||
|
||||
return z0, y
|
||||
|
||||
inp = (
|
||||
torch.ones(2, 2) + 1,
|
||||
torch.ones(2, 2),
|
||||
torch.Stream(),
|
||||
torch.Stream(),
|
||||
torch.Stream(),
|
||||
)
|
||||
expected = fn(*inp)
|
||||
(
|
||||
actual,
|
||||
_,
|
||||
fw_graphs,
|
||||
_,
|
||||
) = extract_graph(fn, *inp)
|
||||
self.assertEqual(len(fw_graphs), 1)
|
||||
self.assertEqual(expected, actual)
|
||||
self.assertExpectedInline(
|
||||
print_graph(fw_graphs[0]),
|
||||
"""\
|
||||
class <lambda>(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
|
||||
# Annotation: {'stream': None}
|
||||
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
|
||||
|
||||
# Annotation: {'stream': None}
|
||||
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
|
||||
|
||||
# Annotation: {'stream': None}
|
||||
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
|
||||
return (add_1, add_2)
|
||||
""",
|
||||
)
|
||||
|
||||
@unittest.skip("Needs graph break support with annotation context")
|
||||
def test_stream_enter_exit_graph_break(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Needs graph break support with annotation context")
|
||||
def test_nested_stream_enter_exit_graph_break(self):
|
||||
pass
|
||||
|
||||
def test_local_stream_enter_exit(self):
|
||||
pass
|
||||
def fn(x, y):
|
||||
s2 = torch.Stream()
|
||||
s1 = torch.Stream()
|
||||
with s1:
|
||||
z1 = torch.add(x, y)
|
||||
with s2:
|
||||
z = torch.add(x, y)
|
||||
y = z + 2 + z1
|
||||
|
||||
return y
|
||||
|
||||
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
|
||||
expected = fn(*inp)
|
||||
(
|
||||
actual,
|
||||
_,
|
||||
fw_graphs,
|
||||
_,
|
||||
) = extract_graph(fn, *inp)
|
||||
self.assertEqual(len(fw_graphs), 1)
|
||||
self.assertEqual(expected, actual)
|
||||
self.assertExpectedInline(
|
||||
print_graph(fw_graphs[0]),
|
||||
"""\
|
||||
class <lambda>(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
|
||||
# Annotation: {'stream': 1}
|
||||
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
|
||||
|
||||
# Annotation: {'stream': 0}
|
||||
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
|
||||
|
||||
# Annotation: {'stream': 0}
|
||||
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
|
||||
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
|
||||
return (add_3,)
|
||||
""",
|
||||
)
|
||||
|
||||
def test_local_stream_nested_enter_exit(self):
|
||||
pass
|
||||
def fn(x, y):
|
||||
s2 = torch.Stream()
|
||||
s1 = torch.Stream()
|
||||
s0 = torch.Stream()
|
||||
with s1:
|
||||
with s2:
|
||||
z1 = torch.add(x, y)
|
||||
with s0:
|
||||
z0 = torch.add(x, y)
|
||||
with s2:
|
||||
y = 2 + z1
|
||||
|
||||
return z0, y
|
||||
|
||||
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
|
||||
expected = fn(*inp)
|
||||
(
|
||||
actual,
|
||||
_,
|
||||
fw_graphs,
|
||||
_,
|
||||
) = extract_graph(fn, *inp)
|
||||
self.assertEqual(len(fw_graphs), 1)
|
||||
self.assertEqual(expected, actual)
|
||||
self.assertExpectedInline(
|
||||
print_graph(fw_graphs[0]),
|
||||
"""\
|
||||
class <lambda>(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
|
||||
# Annotation: {'stream': 0}
|
||||
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
|
||||
|
||||
# Annotation: {'stream': 2}
|
||||
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
|
||||
|
||||
# Annotation: {'stream': 0}
|
||||
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
|
||||
return (add_1, add_2)
|
||||
""",
|
||||
)
|
||||
|
||||
def test_stream_with_mutation(self):
|
||||
pass
|
||||
def fn(x, y):
|
||||
s2 = torch.Stream()
|
||||
s1 = torch.Stream()
|
||||
s0 = torch.Stream()
|
||||
with s1:
|
||||
with s2:
|
||||
x.add_(y)
|
||||
with s0:
|
||||
z1 = torch.add(y, y)
|
||||
z0 = torch.add(z1, y)
|
||||
with s2:
|
||||
y = 2 + z1
|
||||
|
||||
return z0, y
|
||||
|
||||
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
|
||||
expected = fn(*inp)
|
||||
(
|
||||
actual,
|
||||
_,
|
||||
fw_graphs,
|
||||
_,
|
||||
) = extract_graph(fn, *inp)
|
||||
self.assertEqual(len(fw_graphs), 1)
|
||||
self.assertEqual(expected, actual)
|
||||
self.assertExpectedInline(
|
||||
print_graph(fw_graphs[0]),
|
||||
"""\
|
||||
class <lambda>(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
|
||||
# Annotation: {'stream': 0}
|
||||
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
|
||||
|
||||
# Annotation: {'stream': 2}
|
||||
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1)
|
||||
|
||||
# Annotation: {'stream': 2}
|
||||
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, arg1_1); arg1_1 = None
|
||||
|
||||
# Annotation: {'stream': 0}
|
||||
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
|
||||
|
||||
#
|
||||
copy_: "f32[2, 2]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None
|
||||
return (add_2, add_3)
|
||||
""",
|
||||
)
|
||||
|
||||
def test_stream_backward(self) -> None:
|
||||
def fn(x, y):
|
||||
s2 = torch.Stream()
|
||||
s0 = torch.Stream()
|
||||
with s0:
|
||||
y0 = 2 * x + y
|
||||
with s2:
|
||||
z = 2 * x + y
|
||||
|
||||
return y0, z
|
||||
|
||||
inp = (
|
||||
torch.ones(2, 2, requires_grad=True) + 1,
|
||||
torch.ones(2, 2, requires_grad=True),
|
||||
)
|
||||
expected = fn(*inp)
|
||||
(
|
||||
actual,
|
||||
_,
|
||||
fw_graphs,
|
||||
bw_graphs,
|
||||
) = extract_graph(fn, *inp)
|
||||
self.assertEqual(len(fw_graphs), 1)
|
||||
self.assertEqual(expected, actual)
|
||||
self.assertExpectedInline(
|
||||
print_graph(fw_graphs[0]),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"):
|
||||
# Annotation: {'stream': 1}
|
||||
mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None
|
||||
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2)
|
||||
|
||||
# Annotation: {'stream': 0}
|
||||
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None
|
||||
return (add, add_1)
|
||||
""",
|
||||
)
|
||||
|
||||
actual[1].sum().backward()
|
||||
self.assertExpectedInline(
|
||||
print_graph(bw_graphs[0]),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"):
|
||||
# Annotation: {'stream': 0}
|
||||
mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2)
|
||||
|
||||
#
|
||||
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None
|
||||
|
||||
# Annotation: {'stream': 1}
|
||||
mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
|
||||
|
||||
#
|
||||
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
|
||||
return (add_3, add_2)
|
||||
""",
|
||||
)
|
||||
|
||||
@requires_cuda
|
||||
def test_run_opcheck(self):
|
||||
|
||||
@ -721,6 +721,34 @@ class TestExport(TestCase):
|
||||
)
|
||||
self.assertEqual(node.meta["from_node"][-1].graph_id, graph_id)
|
||||
|
||||
def test_annotate_on_assert(self):
|
||||
# nodes added in `apply_runtime_assertion_pass` will be annotated
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
with torch.fx.traceback.annotate({"moo": 0}):
|
||||
x = torch.cat([x, x])
|
||||
b = y.item()
|
||||
torch._check(b >= x.shape[0])
|
||||
return x * b
|
||||
|
||||
with torch.fx.traceback.preserve_node_meta():
|
||||
ep = torch.export.export(
|
||||
M(),
|
||||
(torch.randn(3), torch.tensor(6)),
|
||||
dynamic_shapes={"x": {0: Dim("b")}, "y": None},
|
||||
)
|
||||
|
||||
custom_metadata = torch.fx.traceback._get_custom_metadata(ep.module())
|
||||
self.assertExpectedInline(
|
||||
str(custom_metadata),
|
||||
"""\
|
||||
('call_function', 'cat', {'moo': 0})
|
||||
('call_function', 'item', {'moo': 0})
|
||||
('call_function', 'ge_1', {'moo': 0})
|
||||
('call_function', '_assert_scalar_default', {'moo': 0})
|
||||
('call_function', 'mul', {'moo': 0})""",
|
||||
)
|
||||
|
||||
@requires_gpu
|
||||
def test_flex_attention_export(self):
|
||||
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
|
||||
@ -6093,26 +6121,19 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
retry_export(
|
||||
cf_implicitsize(),
|
||||
(torch.tensor(2), torch.randn(10)),
|
||||
fixes=[
|
||||
# Could not guard on data-dependent expression u0 < 0
|
||||
"torch._check(i >= 0)",
|
||||
],
|
||||
fixes=[],
|
||||
)
|
||||
|
||||
class cf_stacklist(torch.nn.Module):
|
||||
def forward(self, xs, y, fixes):
|
||||
i = y.item()
|
||||
eval(fixes)
|
||||
# instead of xs[i]
|
||||
return torch.stack(xs, 0).narrow(0, i, 1).squeeze()
|
||||
|
||||
retry_export(
|
||||
cf_stacklist(),
|
||||
([torch.ones(5) * i for i in range(10)], torch.tensor(2)),
|
||||
fixes=[
|
||||
# Could not guard on data-dependent expression u0 < 0
|
||||
"torch._check(i >= 0)",
|
||||
],
|
||||
fixes=[],
|
||||
)
|
||||
|
||||
class cf_tensorsplit(torch.nn.Module):
|
||||
@ -6166,7 +6187,12 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
class cf_stacklist(torch.nn.Module):
|
||||
def forward(self, xs, y):
|
||||
# y.item() is not a local, so we can't suggest a fix
|
||||
return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze()
|
||||
if y.item() < 0:
|
||||
return (
|
||||
torch.stack(xs, 0).narrow(0, y.item() + xs.size(), 1).squeeze()
|
||||
)
|
||||
else:
|
||||
return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze()
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
error_type,
|
||||
@ -6196,7 +6222,18 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
def forward(self, xs, y):
|
||||
box = Box(y.item())
|
||||
# box.content is not a local, so we can't suggest a fix
|
||||
return torch.stack(xs, 0).narrow(0, box.content, 1).squeeze()
|
||||
if box.content < 0:
|
||||
return (
|
||||
torch.stack(xs, 0)
|
||||
.narrow(0, box.content + xs.size(), 1)
|
||||
.squeeze()
|
||||
)
|
||||
else:
|
||||
return (
|
||||
torch.stack(xs, 0)
|
||||
.narrow(0, box.content + xs.size(), 1)
|
||||
.squeeze()
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
error_type,
|
||||
|
||||
test/test_as_strided.py (new file, 176 lines)
@@ -0,0 +1,176 @@
|
||||
# Owner(s): ["oncall: pt2"]
|
||||
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
def get_state(t: torch.Tensor) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
"""Extract (sizes, strides) tuple from a tensor."""
|
||||
return (tuple(t.size()), tuple(t.stride()))
|
||||
|
||||
|
||||
def enumerate_reachable_states(
|
||||
initial_size: int,
|
||||
) -> set[tuple[tuple[int, ...], tuple[int, ...]]]:
|
||||
"""
|
||||
Use BFS with DP to enumerate all reachable (size, stride) states from
|
||||
a 1D contiguous tensor via valid view operations.
|
||||
|
||||
We only explore states with offset=0 (you can retroactively change the offset).
|
||||
We reject states with size=0 or size=1 dimensions as they are degenerate.
|
||||
"""
|
||||
# Create initial 1D contiguous tensor
|
||||
initial_tensor = torch.arange(initial_size)
|
||||
|
||||
initial_state = get_state(initial_tensor)
|
||||
|
||||
# Map from state to tensor for that state
|
||||
state_to_tensor: dict[tuple[tuple[int, ...], tuple[int, ...]], torch.Tensor] = {
|
||||
initial_state: initial_tensor
|
||||
}
|
||||
visited: set[tuple[tuple[int, ...], tuple[int, ...]]] = {initial_state}
|
||||
queue: deque[tuple[tuple[int, ...], tuple[int, ...]]] = deque([initial_state])
|
||||
|
||||
while queue:
|
||||
state = queue.popleft()
|
||||
t = state_to_tensor[state]
|
||||
sizes, strides = state
|
||||
ndim = len(sizes)
|
||||
|
||||
def add_state(new_t: torch.Tensor) -> None:
|
||||
new_state = get_state(new_t)
|
||||
sizes, strides = new_state
|
||||
# Skip if has size-0 or size-1 dimensions
|
||||
if any(s == 0 or s == 1 for s in sizes):
|
||||
return
|
||||
# Only accept states where strides are in descending order
|
||||
if list(strides) != sorted(strides, reverse=True):
|
||||
return
|
||||
if new_state not in visited:
|
||||
visited.add(new_state)
|
||||
queue.append(new_state)
|
||||
state_to_tensor[new_state] = new_t
|
||||
|
||||
# 1. Unflatten: try factoring each dimension
|
||||
for dim in range(ndim):
|
||||
size = sizes[dim]
|
||||
assert size > 1
|
||||
# Try all factorizations x * y = size where both x, y >= 2
|
||||
# We only need to check x up to size // 2 since when x > size // 2,
|
||||
# y = size // x < 2, which we reject
|
||||
for x in range(2, size // 2 + 1):
|
||||
if size % x == 0:
|
||||
y = size // x
|
||||
add_state(t.unflatten(dim, (x, y)))
|
||||
|
||||
# 2. Slice: exhaustively check all possible slicing parameters
|
||||
for dim in range(ndim):
|
||||
size = sizes[dim]
|
||||
for start in range(size):
|
||||
for stop in range(start + 1, size + 1):
|
||||
for step in range(1, size + 1):
|
||||
slices = [slice(None)] * ndim
|
||||
slices[dim] = slice(start, stop, step)
|
||||
add_state(t[tuple(slices)])
|
||||
|
||||
# 3. Flatten: merge adjacent dimensions
|
||||
for dim in range(ndim - 1):
|
||||
add_state(t.flatten(dim, dim + 1))
|
||||
|
||||
return visited
|
||||
|
||||
|
||||
class TestAsStrided(TestCase):
|
||||
def test_size_10_exhaustive(self) -> None:
|
||||
"""Test that size 10 produces exactly the expected 54 states."""
|
||||
expected_states = {
|
||||
((2,), (1,)),
|
||||
((2,), (2,)),
|
||||
((2,), (3,)),
|
||||
((2,), (4,)),
|
||||
((2,), (5,)),
|
||||
((2,), (6,)),
|
||||
((2,), (7,)),
|
||||
((2,), (8,)),
|
||||
((2,), (9,)),
|
||||
((2, 2), (2, 1)),
|
||||
((2, 2), (3, 1)),
|
||||
((2, 2), (3, 2)),
|
||||
((2, 2), (4, 1)),
|
||||
((2, 2), (4, 2)),
|
||||
((2, 2), (4, 3)),
|
||||
((2, 2), (5, 1)),
|
||||
((2, 2), (5, 2)),
|
||||
((2, 2), (5, 3)),
|
||||
((2, 2), (5, 4)),
|
||||
((2, 2), (6, 1)),
|
||||
((2, 2), (6, 2)),
|
||||
((2, 2), (6, 3)),
|
||||
((2, 2), (8, 1)),
|
||||
((2, 2, 2), (4, 2, 1)),
|
||||
((2, 2, 2), (5, 2, 1)),
|
||||
((2, 3), (3, 1)),
|
||||
((2, 3), (4, 1)),
|
||||
((2, 3), (5, 1)),
|
||||
((2, 3), (5, 2)),
|
||||
((2, 3), (6, 1)),
|
||||
((2, 4), (4, 1)),
|
||||
((2, 4), (5, 1)),
|
||||
((2, 5), (5, 1)),
|
||||
((3,), (1,)),
|
||||
((3,), (2,)),
|
||||
((3,), (3,)),
|
||||
((3,), (4,)),
|
||||
((3, 2), (2, 1)),
|
||||
((3, 2), (3, 1)),
|
||||
((3, 2), (3, 2)),
|
||||
((3, 2), (4, 1)),
|
||||
((3, 3), (3, 1)),
|
||||
((4,), (1,)),
|
||||
((4,), (2,)),
|
||||
((4,), (3,)),
|
||||
((4, 2), (2, 1)),
|
||||
((5,), (1,)),
|
||||
((5,), (2,)),
|
||||
((5, 2), (2, 1)),
|
||||
((6,), (1,)),
|
||||
((7,), (1,)),
|
||||
((8,), (1,)),
|
||||
((9,), (1,)),
|
||||
((10,), (1,)),
|
||||
}
|
||||
|
||||
actual_states = enumerate_reachable_states(10)
|
||||
|
||||
self.assertEqual(len(actual_states), 54)
|
||||
self.assertEqual(actual_states, expected_states)
|
||||
|
||||
def test_subset_property(self) -> None:
|
||||
"""
|
||||
Test that for sizes 2..10, each smaller tensor results in a strict
|
||||
subset of possible states compared to the next one.
|
||||
"""
|
||||
prev_states: Optional[set[tuple[tuple[int, ...], tuple[int, ...]]]] = None
|
||||
for size in range(2, 11):
|
||||
current_states = enumerate_reachable_states(size)
|
||||
|
||||
if prev_states is not None:
|
||||
# Check that prev_states is a strict subset of current_states
|
||||
self.assertTrue(
|
||||
prev_states.issubset(current_states),
|
||||
f"States from size {size - 1} are not a subset of size {size}",
|
||||
)
|
||||
# Check that it's a strict subset (not equal)
|
||||
self.assertTrue(
|
||||
len(prev_states) < len(current_states),
|
||||
f"States from size {size - 1} should be strictly fewer than size {size}",
|
||||
)
|
||||
|
||||
prev_states = current_states
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
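Every state reported by the BFS above is reachable with ordinary view ops, so entries in the expected set can be reproduced by hand. For example, ((2, 3), (5, 2)) from the size-10 set (a worked example using only the operations the search applies):

import torch

t = torch.arange(10)        # size (10,), stride (1,)
u = t.unflatten(0, (2, 5))  # unflatten: size (2, 5), stride (5, 1)
v = u[:, ::2]               # slice with step 2: size (2, 3), stride (5, 2)
assert (tuple(v.size()), tuple(v.stride())) == ((2, 3), (5, 2))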
|
||||
@ -4401,6 +4401,57 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
|
||||
|
||||
self.assertEqual(compiled(a, b), func(a, b))
|
||||
|
||||
@fresh_cache()
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
def test_narrow_unbacked_start(self):
|
||||
def func(x, start, length):
|
||||
# unbacked start
|
||||
u0 = start.item()
|
||||
return torch.narrow(x, 0, u0, length)
|
||||
|
||||
compiled_func = torch.compile(func, fullgraph=True, backend="inductor")
|
||||
|
||||
x = torch.tensor([1, 2, 3, 4, 5, 6])
|
||||
|
||||
# Test cases: (start, length)
|
||||
test_cases = [
|
||||
# Negative starts
|
||||
(-2, 2), # Start from second-to-last element
|
||||
(-1, 1), # Start from last element
|
||||
(-3, 3), # Start from third-to-last element
|
||||
(-6, 2), # Start from beginning (negative)
|
||||
(-4, 1), # Start from fourth-to-last element
|
||||
# Positive starts
|
||||
(0, 2), # Start from beginning
|
||||
(1, 3), # Start from second element
|
||||
(2, 2), # Start from third element
|
||||
(4, 2), # Start near end
|
||||
# Edge cases
|
||||
(0, 6), # Full tensor
|
||||
(0, 1), # Single element from start
|
||||
(5, 1), # Single element from end
|
||||
]
|
||||
|
||||
for start_val, length in test_cases:
|
||||
with self.subTest(start=start_val, length=length):
|
||||
start = torch.tensor([start_val])
|
||||
|
||||
# Test with compiled function
|
||||
result_compiled = compiled_func(x, start, length)
|
||||
|
||||
# Test with eager function (expected behavior)
|
||||
result_eager = func(x, start, length)
|
||||
|
||||
# Compare results
|
||||
self.assertEqual(result_compiled, result_eager)
|
||||
|
||||
@fresh_cache()
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
@torch._inductor.config.patch("cpp_wrapper", True)
|
||||
def test_narrow_unbacked_start_cpp_wrapper(self):
|
||||
"""Test narrow with unbacked start with cpp_wrapper"""
|
||||
self.test_narrow_unbacked_start()
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestUnbacked)
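For reference, torch.narrow accepts a negative start, which is resolved against the end of the dimension, so the compiled results above can be sanity-checked against plain slicing in eager mode:

import torch

x = torch.tensor([1, 2, 3, 4, 5, 6])
assert torch.equal(torch.narrow(x, 0, -2, 2), x[-2:])   # start=-2, length=2
assert torch.equal(torch.narrow(x, 0, 1, 3), x[1:4])    # positive start
assert torch.equal(torch.narrow(x, 0, 0, 6), x)         # full tensor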
|
||||
|
||||
|
||||
@ -72,6 +72,7 @@ from torch.testing._internal.common_utils import (
|
||||
IS_WINDOWS,
|
||||
run_tests,
|
||||
skipIfTorchDynamo,
|
||||
skipIfRocm,
|
||||
)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
@ -4249,6 +4250,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
|
||||
torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
||||
@skipIfRocm
|
||||
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
|
||||
def test_profiler_stack_trace_augmentation(self):
|
||||
"""
|
||||
@ -4304,6 +4306,7 @@ event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)"""
|
||||
)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
||||
@skipIfRocm
|
||||
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
|
||||
def test_profiler_multiple_modules(self):
|
||||
"""
|
||||
@ -4347,6 +4350,7 @@ event=cudaLaunchKernel node=sub stack_trace=return x - 1"""
|
||||
)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
||||
@skipIfRocm
|
||||
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
|
||||
def test_profiler_nested_graph_modules(self):
|
||||
"""
|
||||
|
||||
@ -359,6 +359,29 @@ class TestMatmulCuda(InductorTestCase):
|
||||
self.assertEqual(agrad, a.grad)
|
||||
self.assertEqual(bgrad, b.grad)
|
||||
|
||||
@onlyCUDA
|
||||
@skipIfRocm
|
||||
@dtypes(torch.half, torch.bfloat16)
|
||||
@unittest.skipIf(not SM100OrLater, "cuBLAS integration for batch invariance is only on Blackwell")
|
||||
@serialTest()
|
||||
def test_cublas_batch_invariance_blackwell(self, device, dtype):
|
||||
orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
|
||||
orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
|
||||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (False, False)
|
||||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (False, False)
|
||||
with blas_library_context('cublaslt'):
|
||||
N = 2048
|
||||
K = 6144
|
||||
M_max = 32
|
||||
x = torch.randn(M_max, K, device="cuda", dtype=torch.bfloat16)
|
||||
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16).t()
|
||||
full = x @ w
|
||||
xx = x[:1]
|
||||
out = xx @ w
|
||||
self.assertEqual(full[:1], out, atol=0., rtol=0.)
|
||||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16
|
||||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16
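The new Blackwell test asserts "batch invariance": with reduced-precision reductions disabled and the cuBLASLt backend forced, row 0 of a full matmul must be bitwise identical to running that single row through the same matmul. The property itself is easy to probe on any device; a small device-agnostic sketch (illustrative only, not the cuBLAS-specific test):

import torch

def first_row_mismatch(matmul, x, w):
    """Max abs difference between row 0 of a full matmul and the same row
    computed alone; 0.0 means the kernel is batch-invariant for this shape."""
    full = matmul(x, w)
    single = matmul(x[:1], w)
    return (full[:1] - single).abs().max().item()

x = torch.randn(32, 64)
w = torch.randn(64, 16)
print(first_row_mismatch(torch.matmul, x, w))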
|
||||
|
||||
@unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
|
||||
@parametrize("strided", [False, True])
|
||||
@parametrize("a_row_major", [False, True])
|
||||
@ -490,8 +513,6 @@ class TestMatmulCuda(InductorTestCase):
|
||||
@parametrize("b_row_major", [False, True])
|
||||
@dtypes(torch.bfloat16, torch.float32, torch.float16)
|
||||
def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype):
|
||||
if TEST_WITH_ROCM and a_row_major and b_row_major and dtype in [torch.bfloat16, torch.float16]:
|
||||
self.skipTest("failed using hipblaslt on rocm 6.4.2")
|
||||
device = "cuda"
|
||||
s_int = int(strided)
|
||||
m, n, k, n_groups = 16, 32, 64, 4
|
||||
|
||||
@ -1864,6 +1864,8 @@ class TestFP8Matmul(TestCase):
|
||||
], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
|
||||
@parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"])
|
||||
def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
|
||||
if torch.version.hip and recipe == "nvfp4":
|
||||
raise unittest.SkipTest("nvfp4 not supported on ROCm, skipping")
|
||||
if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
|
||||
raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
|
||||
|
||||
|
||||
@ -257,34 +257,6 @@ class TestFuzzerCompileIssues(TestCase):
|
||||
out_compiled.sum().backward()
|
||||
print("Compile Success! ✅")
|
||||
|
||||
@pytest.mark.xfail(reason="Issue #163971")
|
||||
def test_fuzzer_issue_163971(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
def foo(arg0):
|
||||
t0 = arg0 # size=(), stride=(), dtype=bfloat16, device=cuda
|
||||
t1 = torch.softmax(
|
||||
t0, dim=0
|
||||
) # size=(), stride=(), dtype=bfloat16, device=cuda
|
||||
t2 = torch.nn.functional.gelu(
|
||||
t1
|
||||
) # size=(), stride=(), dtype=bfloat16, device=cuda
|
||||
t3 = torch.softmax(
|
||||
t2, dim=0
|
||||
) # size=(), stride=(), dtype=bfloat16, device=cuda
|
||||
output = t3
|
||||
return output
|
||||
|
||||
arg0 = torch.rand([], dtype=torch.bfloat16, device="cuda", requires_grad=True)
|
||||
|
||||
out_eager = foo(arg0)
|
||||
out_eager.sum().backward()
|
||||
print("Eager Success! ✅")
|
||||
compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
|
||||
out_compiled = compiled_foo(arg0)
|
||||
out_compiled.sum().backward()
|
||||
print("Compile Success! ✅")
|
||||
|
||||
@pytest.mark.xfail(reason="Issue #164059")
|
||||
def test_fuzzer_issue_164059(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@ -1914,6 +1914,7 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
q, k, v, None, 0.0, is_causal=True))
|
||||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
|
||||
def test_mem_eff_attention_fail_with_batch_size_geq_65536(self):
|
||||
batch_size = 2**16
|
||||
query = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True)
|
||||
@ -1935,6 +1936,7 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
self.assertEqual(value.grad, v_cpu.grad, atol=2e-3, rtol=1e-4)
|
||||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
|
||||
def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self):
|
||||
query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
|
||||
key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
|
||||
@ -1948,6 +1950,7 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
|
||||
@largeTensorTest("15GB", "cuda")
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
|
||||
def test_mem_eff_attention_large_seq_len_uniform_attention(self):
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from typing import Union
|
||||
from typing_extensions import assert_type, TypeAlias
|
||||
from typing import TypeAlias, Union
|
||||
from typing_extensions import assert_type
|
||||
|
||||
from torch import randn, Tensor
|
||||
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# mypy: disable-error-code="type-arg"
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, overload, Union
|
||||
from typing import Any, Optional, overload, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
@ -3320,7 +3320,7 @@ class InstructionTranslatorBase(
|
||||
obj = self.stack[-inst.arg]
|
||||
assert isinstance(obj, SetVariable)
|
||||
assert obj.is_mutable()
|
||||
obj.call_method(self, "add", [v], {})
|
||||
obj.call_method(self, "add", [v], {}) # type: ignore[arg-type]
|
||||
|
||||
def SET_UPDATE(self, inst: Instruction) -> None:
|
||||
v = self.pop()
|
||||
@ -3329,7 +3329,7 @@ class InstructionTranslatorBase(
|
||||
obj = self.stack[-inst.arg]
|
||||
assert isinstance(obj, SetVariable)
|
||||
assert obj.is_mutable()
|
||||
obj.call_method(self, "update", [v], {})
|
||||
obj.call_method(self, "update", [v], {}) # type: ignore[arg-type]
|
||||
|
||||
def LIST_APPEND(self, inst: Instruction) -> None:
|
||||
v = self.pop()
|
||||
@ -3637,7 +3637,7 @@ class InstructionTranslatorBase(
|
||||
obj = self.stack[-inst.arg].realize()
|
||||
assert isinstance(obj, ConstDictVariable)
|
||||
assert obj.is_mutable()
|
||||
obj.call_method(self, "update", [v], {})
|
||||
obj.call_method(self, "update", [v], {}) # type: ignore[arg-type]
|
||||
|
||||
DICT_UPDATE = DICT_MERGE
|
||||
|
||||
|
||||
@ -87,6 +87,12 @@ def extract_graph_and_tracker(fn, *args, **kwargs): # type: ignore[no-untyped-d
|
||||
return gm.graph, region_tracker # type: ignore[union-attr]
|
||||
|
||||
|
||||
def extract_graph(fn, *args, **kwargs): # type: ignore[no-untyped-def]
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
result = torch.compile(backend=backend)(fn)(*args, **kwargs)
|
||||
return result, backend.graphs, backend.fw_graphs, backend.bw_graphs
|
||||
|
||||
|
||||
def collect_results(
|
||||
model: torch.nn.Module, prediction: Any, loss: Any, example_inputs: Any
|
||||
) -> list[Any]:
|
||||
|
||||
@ -21,9 +21,9 @@ restoring state changes.
|
||||
import inspect
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import Callable, Sequence
|
||||
from collections.abc import Callable, Sequence, Sized
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, ContextManager, Optional, Sized, TYPE_CHECKING, Union
|
||||
from typing import Any, ContextManager, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch._C
|
||||
from torch._guards import Guard
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
# mypy: ignore-errors
|
||||
|
||||
"""
|
||||
Dictionary-related variable tracking classes for PyTorch Dynamo.
|
||||
|
||||
@ -26,7 +24,7 @@ import inspect
|
||||
import operator
|
||||
import types
|
||||
from collections.abc import Hashable as py_Hashable
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from torch._subclasses.fake_tensor import is_fake
|
||||
|
||||
@ -59,11 +57,13 @@ if TYPE_CHECKING:
|
||||
# - (perhaps) Define how it is compared in _HashableTracker._eq_impl
|
||||
|
||||
|
||||
def was_instancecheck_override(obj):
|
||||
def was_instancecheck_override(obj: Any) -> bool:
|
||||
return type(obj).__dict__.get("__instancecheck__", False)
|
||||
|
||||
|
||||
def raise_unhashable(arg, tx=None):
|
||||
def raise_unhashable(
|
||||
arg: VariableTracker, tx: Optional["InstructionTranslator"] = None
|
||||
) -> None:
|
||||
if tx is None:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
@ -75,7 +75,7 @@ def raise_unhashable(arg, tx=None):
|
||||
)
|
||||
|
||||
|
||||
def is_hashable(x):
|
||||
def is_hashable(x: VariableTracker) -> bool:
|
||||
# NB - performing isinstance check on a LazyVT realizes the VT, accidentally
|
||||
# inserting the guard. To avoid this, the LazyVT `is_hashable` method looks at
|
||||
# the underlying value without realizing the VT. Consider updating the
|
||||
@ -143,7 +143,7 @@ class ConstDictVariable(VariableTracker):
|
||||
Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing
|
||||
"""
|
||||
|
||||
def __init__(self, vt) -> None:
|
||||
def __init__(self, vt: VariableTracker) -> None:
|
||||
# We specialize SymNodes
|
||||
vt = specialize_symnode(vt)
|
||||
# TODO Temporarily remove to figure out what keys are we breaking on
|
||||
@ -153,7 +153,7 @@ class ConstDictVariable(VariableTracker):
|
||||
self.vt = vt
|
||||
|
||||
@property
|
||||
def underlying_value(self):
|
||||
def underlying_value(self) -> Any:
|
||||
if (
|
||||
isinstance(self.vt, variables.LazyVariableTracker)
|
||||
and not self.vt.is_realized()
|
||||
@ -178,7 +178,8 @@ class ConstDictVariable(VariableTracker):
|
||||
elif isinstance(self.vt, variables.FrozenDataClassVariable):
|
||||
Hashable = ConstDictVariable._HashableTracker
|
||||
fields_values = {
|
||||
k: Hashable(v).underlying_value for k, v in self.vt.fields.items()
|
||||
k: Hashable(v).underlying_value
|
||||
for k, v in self.vt.fields.items() # type: ignore[attr-defined]
|
||||
}
|
||||
return variables.FrozenDataClassVariable.HashWrapper(
|
||||
self.vt.python_type(), fields_values
|
||||
@ -187,16 +188,16 @@ class ConstDictVariable(VariableTracker):
|
||||
# The re module in Python 3.13+ has a dictionary (_cache2) with
|
||||
# an object as key (`class _ZeroSentinel(int): ...`):
|
||||
# python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual
|
||||
return self.vt.value
|
||||
return self.vt.value # type: ignore[attr-defined,union-attr]
|
||||
else:
|
||||
x = self.vt.as_python_constant()
|
||||
return x
|
||||
|
||||
def __hash__(self):
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.underlying_value)
|
||||
|
||||
@staticmethod
|
||||
def _eq_impl(a, b):
|
||||
def _eq_impl(a: Any, b: Any) -> bool:
|
||||
# TODO: Put this in utils and share it between variables/builtin.py and here
|
||||
type_a, type_b = type(a), type(b)
|
||||
if not (issubclass(type_a, type_b) or issubclass(type_b, type_a)):
|
||||
@ -212,7 +213,7 @@ class ConstDictVariable(VariableTracker):
|
||||
else:
|
||||
return a == b
|
||||
|
||||
def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
Hashable = ConstDictVariable._HashableTracker
|
||||
assert isinstance(other, Hashable) or ConstantVariable.is_literal(other), (
|
||||
type(other)
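Most of the changes in this file are type annotations, but they are easier to review with the _HashableTracker idea in mind: dict keys in the traced program are VariableTrackers, which do not hash by value, so they are wrapped in an object that hashes and compares on an extracted underlying value. A rough sketch of the pattern (illustrative only, not the real class):

class HashableWrapper:
    """Make wrapped objects usable as dict keys by delegating hashing and
    equality to an extracted underlying value."""

    def __init__(self, vt):
        self.vt = vt

    @property
    def underlying_value(self):
        # The real tracker specializes SymNodes, frozen dataclasses, etc.;
        # this sketch just unwraps a .value attribute when present.
        return getattr(self.vt, "value", self.vt)

    def __hash__(self):
        return hash(self.underlying_value)

    def __eq__(self, other):
        return (
            isinstance(other, HashableWrapper)
            and self.underlying_value == other.underlying_value
        )


class Const:
    def __init__(self, value):
        self.value = value


d = {HashableWrapper(Const(1)): "one"}
assert d[HashableWrapper(Const(1))] == "one"   # distinct wrappers, same key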
|
||||
@ -226,8 +227,8 @@ class ConstDictVariable(VariableTracker):
|
||||
def __init__(
|
||||
self,
|
||||
items: dict[VariableTracker, VariableTracker],
|
||||
user_cls=dict,
|
||||
**kwargs,
|
||||
user_cls: type = dict,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# .clone() pass these arguments in kwargs but they're recreated a few
|
||||
# lines below
|
||||
@ -247,18 +248,22 @@ class ConstDictVariable(VariableTracker):
|
||||
for x, v in items.items()
|
||||
)
|
||||
|
||||
def make_hashable(key):
|
||||
def make_hashable(
|
||||
key: Union[VariableTracker, "ConstDictVariable._HashableTracker"],
|
||||
) -> "ConstDictVariable._HashableTracker":
|
||||
return key if isinstance(key, Hashable) else Hashable(key)
|
||||
|
||||
dict_cls = self._get_dict_cls_from_user_cls(user_cls)
|
||||
self.items = dict_cls({make_hashable(x): v for x, v in items.items()})
|
||||
# need to reconstruct everything if the dictionary is an intermediate value
|
||||
# or if a pop/delitem was executed
|
||||
self.should_reconstruct_all = not is_from_local_source(self.source)
|
||||
self.should_reconstruct_all = (
|
||||
not is_from_local_source(self.source) if self.source else True
|
||||
)
|
||||
self.original_items = items.copy()
|
||||
self.user_cls = user_cls
|
||||
|
||||
def _get_dict_cls_from_user_cls(self, user_cls):
|
||||
def _get_dict_cls_from_user_cls(self, user_cls: type) -> type:
|
||||
accepted_dict_types = (dict, collections.OrderedDict, collections.defaultdict)
|
||||
|
||||
# avoid executing user code if user_cls is a dict subclass
|
||||
@ -277,10 +282,10 @@ class ConstDictVariable(VariableTracker):
|
||||
dict_cls = dict
|
||||
return dict_cls
|
||||
|
||||
def as_proxy(self):
|
||||
def as_proxy(self) -> dict[Any, Any]:
|
||||
return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
return (
|
||||
"{"
|
||||
+ ", ".join(
|
||||
@ -289,20 +294,20 @@ class ConstDictVariable(VariableTracker):
|
||||
+ "}"
|
||||
)
|
||||
|
||||
def as_python_constant(self):
|
||||
def as_python_constant(self) -> dict[Any, Any]:
|
||||
return {
|
||||
k.vt.as_python_constant(): v.as_python_constant()
|
||||
for k, v in self.items.items()
|
||||
}
|
||||
|
||||
def keys_as_python_constant(self):
|
||||
def keys_as_python_constant(self) -> dict[Any, VariableTracker]:
|
||||
self.install_dict_keys_match_guard()
|
||||
return {k.vt.as_python_constant(): v for k, v in self.items.items()}
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return self.user_cls
|
||||
|
||||
def __contains__(self, vt) -> bool:
|
||||
def __contains__(self, vt: VariableTracker) -> bool:
|
||||
assert isinstance(vt, VariableTracker)
|
||||
Hashable = ConstDictVariable._HashableTracker
|
||||
return (
|
||||
@ -322,13 +327,15 @@ class ConstDictVariable(VariableTracker):
|
||||
for key, value in self.items.items()
|
||||
)
|
||||
|
||||
def is_new_item(self, value, other):
|
||||
def is_new_item(
|
||||
self, value: Optional[VariableTracker], other: VariableTracker
|
||||
) -> bool:
|
||||
# compare the id of the realized values if both values are not lazy VTs
|
||||
if value and value.is_realized() and other.is_realized():
|
||||
return id(value.realize()) != id(other.realize())
|
||||
return id(value) != id(other)
|
||||
|
||||
def reconstruct_kvs_into_new_dict(self, codegen):
|
||||
def reconstruct_kvs_into_new_dict(self, codegen: "PyCodegen") -> None:
|
||||
# Build a dictionary that contains the keys and values.
|
||||
num_args = 0
|
||||
for key, value in self.items.items():
|
||||
@ -340,7 +347,7 @@ class ConstDictVariable(VariableTracker):
|
||||
num_args += 1
|
||||
codegen.append_output(create_instruction("BUILD_MAP", arg=num_args))
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
if self.user_cls is collections.OrderedDict:
|
||||
# emit `OrderedDict(constructed_dict)`
|
||||
codegen.add_push_null(
|
||||
@ -358,19 +365,21 @@ class ConstDictVariable(VariableTracker):
|
||||
|
||||
def getitem_const_raise_exception_if_absent(
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
):
|
||||
) -> VariableTracker:
|
||||
key = ConstDictVariable._HashableTracker(arg)
|
||||
if key not in self.items:
|
||||
raise_observed_exception(KeyError, tx)
|
||||
return self.items[key]
|
||||
|
||||
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
|
||||
def getitem_const(
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
) -> VariableTracker:
|
||||
key = ConstDictVariable._HashableTracker(arg)
|
||||
if key not in self.items:
|
||||
msg = f"Dictionary key {arg.value} not found during tracing"
|
||||
msg = f"Dictionary key {arg.value} not found during tracing" # type: ignore[attr-defined]
|
||||
unimplemented_v2(
|
||||
gb_type="key not found in dict",
|
||||
context=f"Key {arg.value}",
|
||||
context=f"Key {arg.value}", # type: ignore[attr-defined]
|
||||
explanation=msg,
|
||||
hints=[
|
||||
"Check if the key exists in the dictionary before accessing it.",
|
||||
@ -379,13 +388,13 @@ class ConstDictVariable(VariableTracker):
|
||||
)
|
||||
return self.items[key]
|
||||
|
||||
def maybe_getitem_const(self, arg: VariableTracker):
|
||||
def maybe_getitem_const(self, arg: VariableTracker) -> Optional[VariableTracker]:
|
||||
key = ConstDictVariable._HashableTracker(arg)
|
||||
if key not in self.items:
|
||||
return None
|
||||
return self.items[key]
|
||||
|
||||
def realize_key_vt(self, arg: VariableTracker):
|
||||
def realize_key_vt(self, arg: VariableTracker) -> None:
|
||||
# Realize the LazyVT on a particular index
|
||||
assert arg in self
|
||||
key = ConstDictVariable._HashableTracker(arg)
|
||||
@ -394,11 +403,13 @@ class ConstDictVariable(VariableTracker):
|
||||
if isinstance(original_key_vt, variables.LazyVariableTracker):
|
||||
original_key_vt.realize()
|
||||
|
||||
def install_dict_keys_match_guard(self):
|
||||
def install_dict_keys_match_guard(self) -> None:
|
||||
if self.source:
|
||||
install_guard(self.make_guard(GuardBuilder.DICT_KEYS_MATCH))
|
||||
|
||||
def install_dict_contains_guard(self, tx, args):
|
||||
def install_dict_contains_guard(
|
||||
self, tx: "InstructionTranslator", args: list[VariableTracker]
|
||||
) -> None:
|
||||
# Key guarding - These are the cases to consider
|
||||
# 1) The dict has been mutated. In this case, we would have already
|
||||
# inserted a DICT_KEYS_MATCH guard, so we can skip.
|
||||
@ -439,11 +450,11 @@ class ConstDictVariable(VariableTracker):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: "list[VariableTracker]",
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
# NB - Both key and value are LazyVariableTrackers in the beginning. So,
|
||||
# we have to insert guards when a dict method is accessed. For this to
|
||||
# be simple, we are conservative and overguard. We skip guard only for
|
||||
@ -462,7 +473,7 @@ class ConstDictVariable(VariableTracker):
|
||||
tx, *args, **kwargs
|
||||
)
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.items.update(temp_dict_vt.items)
|
||||
self.items.update(temp_dict_vt.items) # type: ignore[attr-defined]
|
||||
return ConstantVariable.create(None)
|
||||
elif name == "__getitem__":
|
||||
# Key guarding - Nothing to do. LazyVT for value will take care.
|
||||
@ -526,7 +537,7 @@ class ConstDictVariable(VariableTracker):
|
||||
return ConstantVariable.create(len(self.items))
|
||||
elif name == "__setitem__" and self.is_mutable():
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0])
|
||||
raise_unhashable(args[0], tx)
|
||||
|
||||
self.install_dict_keys_match_guard()
|
||||
if kwargs or len(args) != 2:
|
||||
@ -550,7 +561,7 @@ class ConstDictVariable(VariableTracker):
|
||||
raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args")
|
||||
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0])
|
||||
raise_unhashable(args[0], tx)
|
||||
|
||||
if args[0] not in self:
|
||||
self.install_dict_contains_guard(tx, args)
|
||||
@ -565,7 +576,7 @@ class ConstDictVariable(VariableTracker):
|
||||
raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args")
|
||||
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0])
|
||||
raise_unhashable(args[0], tx)
|
||||
|
||||
if args[0] not in self:
|
||||
# missing item, return the default value. Install no DICT_CONTAINS guard.
|
||||
@ -599,7 +610,7 @@ class ConstDictVariable(VariableTracker):
|
||||
last = v.value
|
||||
else:
|
||||
raise_args_mismatch(tx, name)
|
||||
k, v = self.items.popitem(last=last)
|
||||
k, v = self.items.popitem(last=last) # type: ignore[possibly-undefined]
|
||||
else:
|
||||
k, v = self.items.popitem()
|
||||
|
||||
@ -632,17 +643,17 @@ class ConstDictVariable(VariableTracker):
|
||||
# NB - Guard on all the keys of the other dict to ensure
|
||||
# correctness.
|
||||
args[0].install_dict_keys_match_guard()
|
||||
dict_vt = args[0]
|
||||
dict_vt: ConstDictVariable = args[0]
|
||||
else:
|
||||
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
|
||||
self.items.update(dict_vt.items)
|
||||
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) # type: ignore[assignment]
|
||||
self.items.update(dict_vt.items) # type: ignore[attr-defined]
|
||||
if has_kwargs:
|
||||
# Handle kwargs
|
||||
kwargs = {
|
||||
kwargs_hashable = {
|
||||
Hashable(ConstantVariable.create(k)): v
|
||||
for k, v in kwargs.items()
|
||||
}
|
||||
self.items.update(kwargs)
|
||||
self.items.update(kwargs_hashable)
|
||||
return ConstantVariable.create(None)
|
||||
else:
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
@ -656,7 +667,7 @@ class ConstDictVariable(VariableTracker):
|
||||
)
|
||||
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0])
|
||||
raise_unhashable(args[0], tx)
|
||||
|
||||
self.install_dict_contains_guard(tx, args)
|
||||
contains = args[0] in self
|
||||
@ -671,7 +682,7 @@ class ConstDictVariable(VariableTracker):
|
||||
)
|
||||
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0])
|
||||
raise_unhashable(args[0], tx)
|
||||
|
||||
self.install_dict_keys_match_guard()
|
||||
if kwargs or len(args) > 2:
|
||||
@ -707,7 +718,7 @@ class ConstDictVariable(VariableTracker):
|
||||
and "last" in kwargs
|
||||
and isinstance(kwargs["last"], ConstantVariable)
|
||||
):
|
||||
last = kwargs.get("last").value
|
||||
last = kwargs.get("last").value # type: ignore[union-attr]
|
||||
|
||||
key = Hashable(args[0])
|
||||
self.items.move_to_end(key, last=last)
|
||||
@ -723,7 +734,7 @@ class ConstDictVariable(VariableTracker):
|
||||
)
|
||||
elif name == "__ne__":
|
||||
return ConstantVariable.create(
|
||||
not self.call_method(tx, "__eq__", args, kwargs).value
|
||||
not self.call_method(tx, "__eq__", args, kwargs).value # type: ignore[attr-defined]
|
||||
)
|
||||
elif name == "__or__":
|
||||
if len(args) != 1:
|
||||
@ -750,14 +761,14 @@ class ConstDictVariable(VariableTracker):
|
||||
if not istype(
|
||||
other, (ConstDictVariable, variables.UserDefinedDictVariable)
|
||||
):
|
||||
msg = (
|
||||
err_msg = (
|
||||
f"unsupported operand type(s) for |: '{self.python_type().__name__}'"
|
||||
f"and '{other.python_type().__name__}'"
|
||||
)
|
||||
raise_observed_exception(TypeError, tx, args=[msg])
|
||||
raise_observed_exception(TypeError, tx, args=[err_msg])
|
||||
|
||||
# OrderedDict overloads __ror__
|
||||
ts = {self.user_cls, other.user_cls}
|
||||
ts = {self.user_cls, other.user_cls} # type: ignore[attr-defined]
|
||||
user_cls = (
|
||||
collections.OrderedDict
|
||||
if any(issubclass(t, collections.OrderedDict) for t in ts)
|
||||
@ -774,8 +785,8 @@ class ConstDictVariable(VariableTracker):
|
||||
|
||||
# NB - Guard on all the keys of the other dict to ensure
|
||||
# correctness.
|
||||
args[0].install_dict_keys_match_guard()
|
||||
new_dict_vt.items.update(args[0].items)
|
||||
args[0].install_dict_keys_match_guard() # type: ignore[attr-defined]
|
||||
new_dict_vt.items.update(args[0].items) # type: ignore[attr-defined]
|
||||
return new_dict_vt
|
||||
elif name == "__ior__":
|
||||
self.call_method(tx, "update", args, kwargs)
|
||||
@ -789,11 +800,13 @@ class ConstDictVariable(VariableTracker):
|
||||
else:
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def unpack_var_sequence(self, tx):
|
||||
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
|
||||
self.install_dict_keys_match_guard()
|
||||
return [x.vt for x in self.items.keys()]
|
||||
|
||||
def call_obj_hasattr(self, tx, name):
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> VariableTracker:
|
||||
# dict not allow setting arbitrary attributes. OrderedDict and
|
||||
# defaultdict allow arbitrary setattr, but not deletion of default attrs
|
||||
if any(
|
||||
@ -816,25 +829,25 @@ class ConstDictVariable(VariableTracker):
|
||||
],
|
||||
)
|
||||
|
||||
def clone(self, **kwargs):
|
||||
def clone(self, **kwargs: Any) -> VariableTracker:
|
||||
self.install_dict_keys_match_guard()
|
||||
return super().clone(**kwargs)
|
||||
|
||||
|
||||
class MappingProxyVariable(VariableTracker):
|
||||
# proxies to the original dict_vt
|
||||
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
|
||||
def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
assert isinstance(dv_dict, ConstDictVariable)
|
||||
self.dv_dict = dv_dict
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return types.MappingProxyType
|
||||
|
||||
def unpack_var_sequence(self, tx):
|
||||
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
|
||||
return self.dv_dict.unpack_var_sequence(tx)
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
# load types.MappingProxyType
|
||||
if self.source:
|
||||
msg = (
|
||||
@ -863,11 +876,11 @@ class MappingProxyVariable(VariableTracker):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if self.source and tx.output.side_effects.has_existing_dict_mutation():
|
||||
msg = (
|
||||
"A dict has been modified while we have an existing mappingproxy object. "
|
||||
@ -892,7 +905,7 @@ class MappingProxyVariable(VariableTracker):
|
||||
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> "VariableTracker":
|
||||
) -> VariableTracker:
|
||||
if self.python_type() is types.MappingProxyType:
|
||||
return ConstantVariable.create(name in types.MappingProxyType.__dict__)
|
||||
return super().call_obj_hasattr(tx, name)
|
||||
@ -900,35 +913,44 @@ class MappingProxyVariable(VariableTracker):
|
||||
|
||||
class NNModuleHooksDictVariable(ConstDictVariable):
|
||||
# Special class to avoid adding any guards on the nn module hook ids.
|
||||
def install_dict_keys_match_guard(self):
|
||||
def install_dict_keys_match_guard(self) -> None:
|
||||
pass
|
||||
|
||||
def install_dict_contains_guard(self, tx, args):
|
||||
def install_dict_contains_guard(
|
||||
self, tx: "InstructionTranslator", args: list[VariableTracker]
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class DefaultDictVariable(ConstDictVariable):
|
||||
def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
items: dict[VariableTracker, VariableTracker],
|
||||
user_cls: type,
|
||||
default_factory: Optional[VariableTracker] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(items, user_cls, **kwargs)
|
||||
assert user_cls is collections.defaultdict
|
||||
if default_factory is None:
|
||||
default_factory = ConstantVariable.create(None)
|
||||
self.default_factory = default_factory
|
||||
|
||||
def is_python_constant(self):
|
||||
def is_python_constant(self) -> bool:
|
||||
# Return false for unsupported defaults. This ensures that a bad handler
|
||||
# path is not taken in BuiltinVariable for getitem.
|
||||
if self.default_factory not in [list, tuple, dict] and not self.items:
|
||||
return False
|
||||
return super().is_python_constant()
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
assert self.default_factory is not None
|
||||
return (
|
||||
f"defaultdict({self.default_factory.debug_repr()}, {super().debug_repr()})"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_supported_arg(arg):
|
||||
def is_supported_arg(arg: VariableTracker) -> bool:
|
||||
if isinstance(arg, variables.BuiltinVariable):
|
||||
return arg.fn in (list, tuple, dict, set)
|
||||
else:
|
||||
@ -942,11 +964,11 @@ class DefaultDictVariable(ConstDictVariable):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: "list[VariableTracker]",
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if name == "__getitem__":
|
||||
if len(args) != 1:
|
||||
raise_args_mismatch(tx, name, "1 args", f"{len(args)} args")
|
||||
@ -962,13 +984,13 @@ class DefaultDictVariable(ConstDictVariable):
|
||||
else:
|
||||
default_var = self.default_factory.call_function(tx, [], {})
|
||||
super().call_method(
|
||||
tx, "__setitem__", (args[0], default_var), kwargs
|
||||
tx, "__setitem__", [args[0], default_var], kwargs
|
||||
)
|
||||
return default_var
|
||||
else:
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
# emit `defaultdict(default_factory, new_dict)`
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.extend_output(
|
||||
@ -994,40 +1016,48 @@ class SetVariable(ConstDictVariable):
|
||||
def __init__(
|
||||
self,
|
||||
items: list[VariableTracker],
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# pyrefly: ignore[bad-assignment]
|
||||
items = dict.fromkeys(items, SetVariable._default_value())
|
||||
# pyrefly: ignore[bad-argument-type]
|
||||
super().__init__(items, **kwargs)
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
if not self.items:
|
||||
return "set()"
|
||||
else:
|
||||
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
|
||||
|
||||
@property
|
||||
def set_items(self):
|
||||
def set_items(self) -> set["ConstDictVariable._HashableTracker"]:
|
||||
return set(self.items.keys())
|
||||
|
||||
@staticmethod
|
||||
def _default_value():
|
||||
def _default_value() -> VariableTracker:
|
||||
# Variable to fill in the keys of the dictionary
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
def as_proxy(self):
|
||||
def as_proxy(self) -> Any:
|
||||
return {k.vt.as_proxy() for k in self.set_items}
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return set
|
||||
|
||||
def as_python_constant(self):
|
||||
def as_python_constant(self) -> Any:
|
||||
return {k.vt.as_python_constant() for k in self.set_items}
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.foreach([x.vt for x in self.set_items])
|
||||
codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items)))
|
||||
|
||||
def _fast_set_method(self, tx, fn, args, kwargs):
|
||||
def _fast_set_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
fn: Any,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
try:
|
||||
res = fn(
|
||||
*[x.as_python_constant() for x in [self, *args]],
|
||||
@ -1037,15 +1067,16 @@ class SetVariable(ConstDictVariable):
|
||||
raise_observed_exception(
|
||||
type(exc), tx, args=list(map(ConstantVariable.create, exc.args))
|
||||
)
|
||||
# pyrefly: ignore[unbound-name]
|
||||
return VariableTracker.build(tx, res)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> "VariableTracker":
|
||||
) -> VariableTracker:
|
||||
# We forward the calls to the dictionary model
|
||||
from ..utils import check_constant_args
|
||||
|
||||
@ -1065,10 +1096,10 @@ class SetVariable(ConstDictVariable):
|
||||
return self._fast_set_method(tx, getattr(py_type, name), args, kwargs)
|
||||
|
||||
if name == "__init__":
|
||||
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, *kwargs)
|
||||
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, **kwargs)
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.items.clear()
|
||||
self.items.update(temp_set_vt.items)
|
||||
self.items.update(temp_set_vt.items) # type: ignore[attr-defined]
|
||||
return ConstantVariable.create(None)
|
||||
elif name == "add":
|
||||
if kwargs or len(args) != 1:
|
||||
@ -1079,7 +1110,7 @@ class SetVariable(ConstDictVariable):
|
||||
f"{len(args)} args and {len(kwargs)} kwargs",
|
||||
)
|
||||
name = "__setitem__"
|
||||
args = (args[0], SetVariable._default_value())
|
||||
args = [args[0], SetVariable._default_value()]
|
||||
elif name == "pop":
|
||||
if kwargs or args:
|
||||
raise_args_mismatch(
|
||||
@ -1090,12 +1121,14 @@ class SetVariable(ConstDictVariable):
|
||||
)
|
||||
# Choose an item at random and pop it via the Dict.pop method
|
||||
try:
|
||||
result = self.set_items.pop().vt
|
||||
result: VariableTracker = self.set_items.pop().vt # type: ignore[assignment]
|
||||
except KeyError as e:
|
||||
raise_observed_exception(
|
||||
KeyError, tx, args=list(map(ConstantVariable.create, e.args))
|
||||
)
|
||||
super().call_method(tx, name, (result,), kwargs)
|
||||
# pyrefly: ignore[unbound-name]
|
||||
super().call_method(tx, name, [result], kwargs)
|
||||
# pyrefly: ignore[unbound-name]
|
||||
return result
|
||||
elif name == "isdisjoint":
|
||||
if kwargs or len(args) != 1:
|
||||
@ -1217,6 +1250,7 @@ class SetVariable(ConstDictVariable):
|
||||
f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'"
|
||||
)
|
||||
raise_observed_exception(TypeError, tx, args=[msg])
|
||||
assert m is not None
|
||||
return self.call_method(tx, m, args, kwargs)
|
||||
elif name in ("__iand__", "__ior__", "__ixor__", "__isub__"):
|
||||
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
|
||||
@ -1230,29 +1264,34 @@ class SetVariable(ConstDictVariable):
|
||||
"__ixor__": "symmetric_difference_update",
|
||||
"__isub__": "difference_update",
|
||||
}.get(name)
|
||||
assert m is not None
|
||||
self.call_method(tx, m, args, kwargs)
|
||||
return self
|
||||
elif name == "__eq__":
|
||||
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
|
||||
return ConstantVariable.create(False)
|
||||
r = self.call_method(tx, "symmetric_difference", args, kwargs)
|
||||
return ConstantVariable.create(len(r.set_items) == 0)
|
||||
return ConstantVariable.create(len(r.set_items) == 0) # type: ignore[attr-defined]
|
||||
elif name in cmp_name_to_op_mapping:
|
||||
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
|
||||
return ConstantVariable.create(NotImplemented)
|
||||
return ConstantVariable.create(
|
||||
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
|
||||
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined]
|
||||
)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
|
||||
def getitem_const(
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
) -> VariableTracker:
|
||||
raise RuntimeError("Illegal to getitem on a set")
|
||||
|
||||
def install_dict_keys_match_guard(self):
|
||||
def install_dict_keys_match_guard(self) -> None:
|
||||
# Already EQUALS_MATCH guarded
|
||||
pass
|
||||
|
||||
def install_dict_contains_guard(self, tx, args):
|
||||
def install_dict_contains_guard(
|
||||
self, tx: "InstructionTranslator", args: list[VariableTracker]
|
||||
) -> None:
|
||||
super().install_dict_contains_guard(tx, args)
|
||||
|
||||
|
||||
@ -1260,27 +1299,27 @@ class FrozensetVariable(SetVariable):
|
||||
def __init__(
|
||||
self,
|
||||
items: list[VariableTracker],
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(items, **kwargs)
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
if not self.items:
|
||||
return "frozenset()"
|
||||
else:
|
||||
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
|
||||
|
||||
@property
|
||||
def set_items(self):
|
||||
def set_items(self) -> set["ConstDictVariable._HashableTracker"]:
|
||||
return self.items.keys()
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return frozenset
|
||||
|
||||
def as_python_constant(self):
|
||||
def as_python_constant(self) -> Any:
|
||||
return frozenset({k.vt.as_python_constant() for k in self.set_items})
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
codegen.foreach([x.vt for x in self.set_items])
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.extend_output(
|
||||
@ -1293,11 +1332,11 @@ class FrozensetVariable(SetVariable):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> "VariableTracker":
|
||||
) -> VariableTracker:
|
||||
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
|
||||
raise RuntimeError(f"Illegal call_method {name} on a frozenset")
|
||||
elif name == "__init__":
|
||||
@ -1316,7 +1355,7 @@ class FrozensetVariable(SetVariable):
|
||||
"symmetric_difference",
|
||||
):
|
||||
r = super().call_method(tx, name, args, kwargs)
|
||||
return FrozensetVariable(r.items)
|
||||
return FrozensetVariable(r.items) # type: ignore[attr-defined]
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
|
||||
@ -1324,11 +1363,11 @@ class DictKeySetVariable(SetVariable):
|
||||
def __init__(
|
||||
self,
|
||||
items: list[VariableTracker],
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(items, **kwargs)
|
||||
|
||||
def debug_repr(self):
|
||||
def debug_repr(self) -> str:
|
||||
if not self.items:
|
||||
return "dict_keys([])"
|
||||
else:
|
||||
@ -1338,33 +1377,35 @@ class DictKeySetVariable(SetVariable):
|
||||
+ "])"
|
||||
)
|
||||
|
||||
def install_dict_keys_match_guard(self):
|
||||
def install_dict_keys_match_guard(self) -> None:
|
||||
# Already EQUALS_MATCH guarded
|
||||
pass
|
||||
|
||||
def install_dict_contains_guard(self, tx, args):
|
||||
def install_dict_contains_guard(
|
||||
self, tx: "InstructionTranslator", args: list[VariableTracker]
|
||||
) -> None:
|
||||
# Already EQUALS_MATCH guarded
|
||||
pass
|
||||
|
||||
@property
|
||||
def set_items(self):
|
||||
def set_items(self) -> Any:
|
||||
return self.items
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return dict_keys
|
||||
|
||||
def as_python_constant(self):
|
||||
def as_python_constant(self) -> Any:
|
||||
return dict.fromkeys(
|
||||
{k.vt.as_python_constant() for k in self.set_items}, None
|
||||
).keys()
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> "VariableTracker":
|
||||
) -> VariableTracker:
|
||||
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
|
||||
raise RuntimeError(f"Illegal call_method {name} on a dict_keys")
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
@ -1379,42 +1420,47 @@ class DictViewVariable(VariableTracker):
|
||||
|
||||
kv: Optional[str] = None
|
||||
|
||||
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
|
||||
def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
assert self.kv in ("keys", "values", "items")
|
||||
assert isinstance(dv_dict, ConstDictVariable)
|
||||
self.dv_dict = dv_dict
|
||||
|
||||
@property
|
||||
def view_items(self):
|
||||
def view_items(self) -> Any:
|
||||
assert self.kv is not None
|
||||
return getattr(self.dv_dict.items, self.kv)()
|
||||
|
||||
@property
|
||||
def view_items_vt(self):
|
||||
def view_items_vt(self) -> list[VariableTracker]:
|
||||
# Returns an iterable of the unpacked items
|
||||
# Implement in the subclasses
|
||||
raise NotImplementedError
|
||||
|
||||
def unpack_var_sequence(self, tx):
|
||||
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
|
||||
return self.view_items_vt
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
assert self.kv is not None
|
||||
codegen(self.dv_dict)
|
||||
codegen.load_method(self.kv)
|
||||
codegen.call_method(0)
|
||||
|
||||
def call_obj_hasattr(self, tx, name):
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> VariableTracker:
|
||||
assert self.kv is not None
|
||||
if name in self.python_type().__dict__:
|
||||
return ConstantVariable.create(True)
|
||||
return ConstantVariable.create(False)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if name == "__len__":
|
||||
return self.dv_dict.call_method(tx, name, args, kwargs)
|
||||
elif name == "__iter__":
|
||||
@ -1428,24 +1474,24 @@ class DictKeysVariable(DictViewVariable):
|
||||
kv = "keys"
|
||||
|
||||
@property
|
||||
def set_items(self):
|
||||
def set_items(self) -> set[VariableTracker]:
|
||||
return set(self.view_items)
|
||||
|
||||
@property
|
||||
def view_items_vt(self):
|
||||
def view_items_vt(self) -> list[VariableTracker]:
|
||||
# Returns an iterable of the unpacked items
|
||||
return [x.vt for x in self.view_items]
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return dict_keys
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if name == "__contains__":
|
||||
return self.dv_dict.call_method(tx, name, args, kwargs)
|
||||
elif name in (
|
||||
@ -1460,13 +1506,13 @@ class DictKeysVariable(DictViewVariable):
|
||||
):
|
||||
# These methods always returns a set
|
||||
m = getattr(self.set_items, name)
|
||||
r = m(args[0].set_items)
|
||||
r = m(args[0].set_items) # type: ignore[attr-defined]
|
||||
return SetVariable(r)
|
||||
if name in cmp_name_to_op_mapping:
|
||||
if not isinstance(args[0], (SetVariable, DictKeysVariable)):
|
||||
return ConstantVariable.create(NotImplemented)
|
||||
return ConstantVariable.create(
|
||||
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
|
||||
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined]
|
||||
)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
@ -1476,10 +1522,10 @@ class DictValuesVariable(DictViewVariable):
|
||||
kv = "values"
|
||||
|
||||
@property
|
||||
def view_items_vt(self):
|
||||
def view_items_vt(self) -> list[VariableTracker]:
|
||||
return list(self.view_items)
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return dict_values
|
||||
|
||||
|
||||
@ -1487,14 +1533,20 @@ class DictItemsVariable(DictViewVariable):
|
||||
kv = "items"
|
||||
|
||||
@property
|
||||
def view_items_vt(self):
|
||||
def view_items_vt(self) -> list[VariableTracker]:
|
||||
# Returns an iterable of the unpacked items
|
||||
return [variables.TupleVariable([k.vt, v]) for k, v in self.view_items]
|
||||
|
||||
def python_type(self):
|
||||
def python_type(self) -> type:
|
||||
return dict_items
|
||||
|
||||
def call_method(self, tx, name, args, kwargs):
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
# TODO(guilhermeleobas): This should actually check if args[0]
|
||||
# implements the mapping protocol.
|
||||
if name == "__eq__":
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
|
||||
import sympy # noqa: TC002
|
||||
|
||||
@ -17,6 +17,8 @@ from .simd import SIMDKernel, SIMDScheduling
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
from ..ir import IRNode
|
||||
from ..scheduler import BaseSchedulerNode
|
||||
|
||||
|
||||
@ -627,7 +627,7 @@ class ComboKernel(Kernel):
|
||||
if heuristics == "foreach":
|
||||
heuristics_line = f"""
|
||||
@triton_heuristics.foreach(
|
||||
num_warps={self.num_warps},
|
||||
filename=__file__,
|
||||
triton_meta={triton_meta!r},
|
||||
inductor_meta={inductor_meta!r},
|
||||
)
|
||||
|
||||
@ -2063,7 +2063,8 @@ class PythonWrapperCodegen(CodeGen):
|
||||
neg = self.codegen_sizevar(
|
||||
sympy.Max(0, sympy.Min(x + node.size, node.size))
|
||||
)
|
||||
return f"{pos} if {x} >= 0 else {neg}"
|
||||
x_cond = self.codegen_sizevar(x)
|
||||
return f"{pos} if {x_cond} >= 0 else {neg}"
|
||||
|
||||
def codegen_with_step(start_var, end_var, step):
|
||||
if step == 1:
|
||||
|
||||
@ -360,7 +360,7 @@ def estimate_nccl_collective_runtime_from_fx_node(
|
||||
fx_node: torch.fx.Node,
|
||||
override_size: Optional[int] = None,
|
||||
# TODO(ivankobzarev): NCCL estimator sometimes fail unexpectedly, enable back after fix.
|
||||
use_nccl_estimator: bool = False,
|
||||
use_nccl_estimator: bool = True,
|
||||
) -> float:
|
||||
"""
|
||||
Returns estimated NCCL collective runtime in nanoseconds (ns).
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from functools import cache, partial
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from torch._environment import is_fbcode
|
||||
|
||||
@ -3586,13 +3586,24 @@ def user_autotune(
|
||||
)
|
||||
|
||||
|
||||
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
|
||||
def foreach(triton_meta, filename=None, inductor_meta=None):
|
||||
"""
|
||||
Compile a triton foreach kernel
|
||||
"""
|
||||
configs = []
|
||||
|
||||
# Naive autotuning path for num_warps
|
||||
if not (
|
||||
inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
|
||||
):
|
||||
configs.append(triton.Config({}, num_stages=1, num_warps=8))
|
||||
else:
|
||||
for warps in [1, 2, 4, 8]:
|
||||
configs.append(triton.Config({}, num_stages=1, num_warps=warps))
|
||||
|
||||
return cached_autotune(
|
||||
None,
|
||||
[triton.Config({}, num_stages=1, num_warps=num_warps)],
|
||||
configs,
|
||||
triton_meta=triton_meta,
|
||||
inductor_meta=inductor_meta,
|
||||
heuristic_type=HeuristicType.TEMPLATE,
|
||||
|
||||
@ -52,26 +52,7 @@ __all__ = [
|
||||
"MemRecordsAcc",
|
||||
]
|
||||
|
||||
try:
|
||||
# Available in Python >= 3.2
|
||||
from contextlib import ContextDecorator as _ContextDecorator
|
||||
except ImportError:
|
||||
import functools
|
||||
|
||||
class _ContextDecorator: # type: ignore[no-redef]
|
||||
def __enter__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, func):
|
||||
@functools.wraps(func)
|
||||
def wrapped(*args, **kwargs):
|
||||
with self:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
from contextlib import ContextDecorator
|
||||
|
||||
|
||||
# global python state - whether profiler is currently enabled
|
||||
@ -744,8 +725,7 @@ class profile:
|
||||
return all_function_events
|
||||
|
||||
|
||||
# pyrefly: ignore [invalid-inheritance]
|
||||
class record_function(_ContextDecorator):
|
||||
class record_function(ContextDecorator):
|
||||
"""Context manager/function decorator that adds a label to a code block/function when running autograd profiler.
|
||||
Label will only appear if CPU activity tracing is enabled.
|
||||
|
||||
|
||||
@ -108,12 +108,14 @@ struct FlightRecorder {
|
||||
capture_cpp_stack_ = getCvarBool(
|
||||
{"TORCH_FR_CPP_STACK", "TORCH_NCCL_TRACE_CPP_STACK"}, false);
|
||||
enabled_ = max_entries_ > 0;
|
||||
reset_epoch_start_idx_[0] = 0;
|
||||
}
|
||||
struct Entry {
|
||||
size_t id_; // incremented id in the trace buffer
|
||||
// used to figure out where in the circular entries
|
||||
// buffer this entry will be located to
|
||||
// update state information
|
||||
size_t reset_epoch_; // epoch when this entry was created
|
||||
size_t pg_id_;
|
||||
std::tuple<std::string, std::string> pg_name_; // <group_name, group_desc>
|
||||
|
||||
@ -183,11 +185,34 @@ struct FlightRecorder {
|
||||
size_t max_entries_ = 0;
|
||||
size_t next_ = 0;
|
||||
size_t id_ = 0;
|
||||
size_t reset_epoch_ = 0;
|
||||
std::unordered_map<size_t, size_t>
|
||||
reset_epoch_start_idx_; // maps reset_epoch to the idx where it starts
|
||||
std::map<size_t, std::shared_ptr<ProcessGroupStatus>> all_pg_status_;
|
||||
std::map<std::tuple<std::string, std::string>, std::vector<uint64_t>>
|
||||
pg_name_to_ranks_;
|
||||
std::string comm_lib_version_;
|
||||
|
||||
struct TraceIdentifier {
|
||||
std::optional<size_t> id;
|
||||
std::optional<size_t> reset_epoch;
|
||||
};
|
||||
|
||||
TraceIdentifier recordWithResetEnabled(
|
||||
size_t pg_id,
|
||||
const std::tuple<std::string, std::string>& pg_name,
|
||||
size_t collective_seq_id,
|
||||
size_t p2p_seq_id,
|
||||
size_t op_id,
|
||||
std::string profiling_name,
|
||||
const std::vector<at::Tensor>& inputs,
|
||||
const std::vector<at::Tensor>& outputs,
|
||||
EventType* start,
|
||||
EventType* end,
|
||||
std::chrono::milliseconds timeout_ms,
|
||||
std::shared_ptr<ProcessGroupStatus> pg_status,
|
||||
bool isP2P);
|
||||
|
||||
std::optional<size_t> record(
|
||||
size_t pg_id,
|
||||
const std::tuple<std::string, std::string>& pg_name,
|
||||
@ -213,8 +238,16 @@ struct FlightRecorder {
|
||||
|
||||
std::vector<Entry> dump_entries();
|
||||
|
||||
// Returns the entry with the given id, if it exists. Otherwise, returns
|
||||
// std::nullopt.
|
||||
// Returns the index in entries_ for the given id and reset_epoch.
|
||||
// Caller must hold mutex_lock before calling this method.
|
||||
size_t getIdxFromId(size_t id, size_t reset_epoch) const;
|
||||
|
||||
// Returns the entry with the given id and reset_epoch, if it exists.
|
||||
// Otherwise, returns std::nullopt.
|
||||
TORCH_API std::optional<Entry> getEntry(
|
||||
std::optional<size_t> id,
|
||||
std::optional<size_t> reset_epoch);
|
||||
|
||||
TORCH_API std::optional<Entry> getEntry(std::optional<size_t> id);
|
||||
|
||||
/*
|
||||
@ -227,6 +260,11 @@ struct FlightRecorder {
|
||||
never hang. (timing must also be enabled for compute_duration - see
|
||||
TORCH_NCCL_ENABLE_TIMING).
|
||||
*/
|
||||
TORCH_API void retire_id(
|
||||
std::optional<size_t> id,
|
||||
std::optional<size_t> reset_epoch,
|
||||
bool compute_duration = true);
|
||||
|
||||
TORCH_API void retire_id(
|
||||
std::optional<size_t> id,
|
||||
bool compute_duration = true);
|
||||
|
||||
@ -53,8 +53,41 @@ std::optional<size_t> FlightRecorder<EventType>::record(
|
||||
std::chrono::milliseconds timeout_ms,
|
||||
std::shared_ptr<ProcessGroupStatus> pg_status,
|
||||
bool isP2P) {
|
||||
auto result = recordWithResetEnabled(
|
||||
pg_id,
|
||||
pg_name,
|
||||
collective_seq_id,
|
||||
p2p_seq_id,
|
||||
op_id,
|
||||
std::move(profiling_name),
|
||||
inputs,
|
||||
outputs,
|
||||
start,
|
||||
end,
|
||||
timeout_ms,
|
||||
std::move(pg_status),
|
||||
isP2P);
|
||||
return result.id;
|
||||
}
|
||||
|
||||
template <typename EventType>
|
||||
typename FlightRecorder<EventType>::TraceIdentifier FlightRecorder<EventType>::
|
||||
recordWithResetEnabled(
|
||||
size_t pg_id,
|
||||
const std::tuple<std::string, std::string>& pg_name,
|
||||
size_t collective_seq_id,
|
||||
size_t p2p_seq_id,
|
||||
size_t op_id,
|
||||
std::string profiling_name,
|
||||
const std::vector<at::Tensor>& inputs,
|
||||
const std::vector<at::Tensor>& outputs,
|
||||
EventType* start,
|
||||
EventType* end,
|
||||
std::chrono::milliseconds timeout_ms,
|
||||
std::shared_ptr<ProcessGroupStatus> pg_status,
|
||||
bool isP2P) {
|
||||
if (!enabled_) {
|
||||
return std::nullopt;
|
||||
return TraceIdentifier{std::nullopt, std::nullopt};
|
||||
}
|
||||
if (all_pg_status_.find(pg_id) == all_pg_status_.end()) {
|
||||
// Current pg_status is not in FR.
|
||||
@ -64,8 +97,13 @@ std::optional<size_t> FlightRecorder<EventType>::record(
|
||||
torch::CapturedTraceback::gather(true, true, capture_cpp_stack_);
|
||||
std::lock_guard<std::mutex> guard(mutex_);
|
||||
|
||||
TORCH_CHECK(
|
||||
reset_epoch_start_idx_.find(reset_epoch_) !=
|
||||
reset_epoch_start_idx_.end());
|
||||
|
||||
auto te = Entry{
|
||||
id_,
|
||||
reset_epoch_,
|
||||
pg_id,
|
||||
pg_name,
|
||||
collective_seq_id,
|
||||
@ -104,15 +142,20 @@ std::optional<size_t> FlightRecorder<EventType>::record(
|
||||
te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end());
|
||||
}
|
||||
|
||||
const auto next = next_++;
|
||||
|
||||
if (entries_.size() < max_entries_) {
|
||||
entries_.emplace_back(std::move(te));
|
||||
} else {
|
||||
entries_[next_++] = std::move(te);
|
||||
if (next_ == max_entries_) {
|
||||
next_ = 0;
|
||||
}
|
||||
entries_[next] = std::move(te);
|
||||
}
|
||||
return id_++;
|
||||
|
||||
if (next_ == max_entries_) {
|
||||
next_ = 0;
|
||||
}
|
||||
|
||||
const auto id = id_++;
|
||||
return TraceIdentifier{id, reset_epoch_};
|
||||
}
|
||||
|
||||
template <typename EventType>
|
||||
@ -163,15 +206,20 @@ std::vector<typename FlightRecorder<EventType>::Entry> FlightRecorder<
|
||||
std::vector<Entry> result;
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(mutex_);
|
||||
result.reserve(entries_.size());
|
||||
result.insert(
|
||||
result.end(),
|
||||
// Filter entries during insertion - only keep entries from current epoch
|
||||
auto filter = [this](const Entry& e) {
|
||||
return e.reset_epoch_ == reset_epoch_;
|
||||
};
|
||||
std::copy_if(
|
||||
entries_.begin() + static_cast<std::ptrdiff_t>(next_),
|
||||
entries_.end());
|
||||
result.insert(
|
||||
result.end(),
|
||||
entries_.end(),
|
||||
std::back_inserter(result),
|
||||
filter);
|
||||
std::copy_if(
|
||||
entries_.begin(),
|
||||
entries_.begin() + static_cast<std::ptrdiff_t>(next_));
|
||||
entries_.begin() + static_cast<std::ptrdiff_t>(next_),
|
||||
std::back_inserter(result),
|
||||
filter);
|
||||
}
|
||||
// query any remaining events
|
||||
for (auto& r : result) {
|
||||
@ -182,28 +230,47 @@ std::vector<typename FlightRecorder<EventType>::Entry> FlightRecorder<
|
||||
}
|
||||
|
||||
template <typename EventType>
|
||||
// Returns the entry with the given id, if it exists. Otherwise, returns
|
||||
// std::nullopt.
|
||||
// Returns the index in entries_ for the given id and reset_epoch.
|
||||
// Caller must hold mutex_lock before calling this method.
|
||||
size_t FlightRecorder<EventType>::getIdxFromId(size_t id, size_t reset_epoch)
|
||||
const {
|
||||
// Look up the starting idx for the given reset epoch
|
||||
auto it = reset_epoch_start_idx_.find(reset_epoch);
|
||||
TORCH_CHECK(it != reset_epoch_start_idx_.end());
|
||||
// Calculate idx based on where the epoch started
|
||||
return (it->second + id) % max_entries_;
|
||||
}
|
||||
|
||||
template <typename EventType>
|
||||
// Returns the entry with the given id and reset_epoch, if it exists. Otherwise,
|
||||
// returns std::nullopt.
|
||||
std::optional<typename FlightRecorder<EventType>::Entry> FlightRecorder<
|
||||
EventType>::getEntry(std::optional<size_t> id) {
|
||||
if (!enabled_ || !id) {
|
||||
EventType>::
|
||||
getEntry(std::optional<size_t> id, std::optional<size_t> reset_epoch) {
|
||||
if (!enabled_ || !id || !reset_epoch) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::unique_lock<std::mutex> guard(mutex_);
|
||||
Entry entry = entries_.at(*id % max_entries_);
|
||||
if (entry.id_ == *id) {
|
||||
Entry entry = entries_.at(getIdxFromId(*id, *reset_epoch));
|
||||
if (entry.id_ == *id && entry.reset_epoch_ == *reset_epoch) {
|
||||
return entry;
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
template <typename EventType>
|
||||
std::optional<typename FlightRecorder<EventType>::Entry> FlightRecorder<
|
||||
EventType>::getEntry(std::optional<size_t> id) {
|
||||
return getEntry(id, 0);
|
||||
}
|
||||
|
||||
template <typename EventType>
|
||||
void FlightRecorder<EventType>::retire_id(
|
||||
std::optional<size_t> id,
|
||||
std::optional<size_t> reset_epoch,
|
||||
bool compute_duration) {
|
||||
if (!enabled_ || !id) {
|
||||
if (!enabled_ || !id || !reset_epoch) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -214,8 +281,8 @@ void FlightRecorder<EventType>::retire_id(
|
||||
|
||||
std::unique_lock<std::mutex> guard(mutex_);
|
||||
|
||||
Entry* entry = &entries_.at(*id % max_entries_);
|
||||
if (entry->id_ == *id) {
|
||||
Entry* entry = &entries_.at(getIdxFromId(*id, *reset_epoch));
|
||||
if (entry->id_ == *id && entry->reset_epoch_ == *reset_epoch) {
|
||||
update_state(*entry);
|
||||
|
||||
if (compute_duration) {
|
||||
@ -237,8 +304,8 @@ void FlightRecorder<EventType>::retire_id(
|
||||
guard.lock();
|
||||
|
||||
// Refresh the entry pointer, see if the entry has been overwritten
|
||||
entry = &entries_.at(*id % max_entries_);
|
||||
if (entry->id_ != *id) {
|
||||
entry = &entries_.at(getIdxFromId(*id, *reset_epoch));
|
||||
if (!(entry->id_ == *id && entry->reset_epoch_ == *reset_epoch)) {
|
||||
LOG(INFO) << "retire_id abandoned for id " << *id
|
||||
<< ", event was overwritten while waiting to compute duration.";
|
||||
return;
|
||||
@ -249,12 +316,23 @@ void FlightRecorder<EventType>::retire_id(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename EventType>
|
||||
void FlightRecorder<EventType>::retire_id(
|
||||
std::optional<size_t> id,
|
||||
bool compute_duration) {
|
||||
retire_id(id, 0, compute_duration);
|
||||
}
|
||||
|
||||
template <typename EventType>
|
||||
void FlightRecorder<EventType>::reset_all() {
|
||||
std::lock_guard<std::mutex> guard(mutex_);
|
||||
next_ = 0;
|
||||
id_ = 0;
|
||||
entries_.clear();
|
||||
if (!entries_.empty()) {
|
||||
// Soft delete: increment epoch to mark all existing entries as old
|
||||
// Store where the new epoch starts in the circular buffer
|
||||
reset_epoch_++;
|
||||
reset_epoch_start_idx_[reset_epoch_] = next_;
|
||||
id_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename EventType>
|
||||
|
||||
@ -708,7 +708,8 @@ void ProcessGroupGloo::runLoop(int workerIndex) {
|
||||
// TODO: We need to have numel of tensors for gloo as well.
|
||||
pgStatus_->lastCompletedNumelIn = 0;
|
||||
pgStatus_->lastCompletedNumelOut = 0;
|
||||
FlightRecorder<c10::Event>::get()->retire_id(work->trace_id_, false);
|
||||
FlightRecorder<c10::Event>::get()->retire_id(
|
||||
work->trace_id_, work->trace_reset_epoch_, false);
|
||||
lock.lock();
|
||||
workInProgress_[workerIndex].reset();
|
||||
}
|
||||
@ -780,7 +781,7 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
|
||||
pgStatus_->lastEnqueuedNumelOut = 0;
|
||||
// using c10d::FlightRecorder;
|
||||
// TODO: We need to have a way to use c10::Event inside gloo as well.
|
||||
work->trace_id_ = FlightRecorder<c10::Event>::get()->record(
|
||||
auto traceId = FlightRecorder<c10::Event>::get()->recordWithResetEnabled(
|
||||
local_id_,
|
||||
std::make_tuple(pg_uid_, pg_desc_),
|
||||
collectiveCounter_,
|
||||
@ -795,6 +796,8 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
|
||||
work->getTimeout(),
|
||||
pgStatus_,
|
||||
false);
|
||||
work->trace_id_ = traceId.id;
|
||||
work->trace_reset_epoch_ = traceId.reset_epoch;
|
||||
workQueue_.push_back(std::move(work));
|
||||
lock.unlock();
|
||||
|
||||
|
||||
@ -99,6 +99,7 @@ class TORCH_API ProcessGroupGloo : public Backend {
|
||||
// unique id used to tell the trace buffer that this
|
||||
// work has completed
|
||||
std::optional<uint64_t> trace_id_;
|
||||
std::optional<uint64_t> trace_reset_epoch_;
|
||||
std::shared_ptr<gloo::Context> context_;
|
||||
const std::chrono::milliseconds timeout_;
|
||||
|
||||
|
||||
@ -575,6 +575,7 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
|
||||
futureWorkResult_(w.futureWorkResult_),
|
||||
timingEnabled_(w.timingEnabled_),
|
||||
trace_id_(w.trace_id_),
|
||||
trace_reset_epoch_(w.trace_reset_epoch_),
|
||||
distDebugLevel_(w.distDebugLevel_) {
|
||||
exception_ = w.exception_;
|
||||
}
|
||||
@ -704,9 +705,9 @@ bool ProcessGroupNCCL::WorkNCCL::checkTimeout(
|
||||
// Print the traceback of the collective at call time
|
||||
std::string ProcessGroupNCCL::WorkNCCL::getTraceback() const {
|
||||
// First step we get the corresponding record entry from FR, based on work's
|
||||
// trace_id_
|
||||
// trace_id_ and trace_reset_epoch_
|
||||
std::optional<FlightRecorderCUDA::Entry> entry =
|
||||
FlightRecorderCUDA::get()->getEntry(trace_id_);
|
||||
FlightRecorderCUDA::get()->getEntry(trace_id_, trace_reset_epoch_);
|
||||
if (entry.has_value()) {
|
||||
auto entryVal = entry.value();
|
||||
// Get stack trace from FR entry, in string format
|
||||
@ -2394,7 +2395,8 @@ void ProcessGroupNCCL::Watchdog::runLoop() {
|
||||
pg_->pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_);
|
||||
pg_->pgStatus_->lastCompletedNumelIn = work.numelIn_;
|
||||
pg_->pgStatus_->lastCompletedNumelOut = work.numelOut_;
|
||||
FlightRecorderCUDA::get()->retire_id(work.trace_id_, true);
|
||||
FlightRecorderCUDA::get()->retire_id(
|
||||
work.trace_id_, work.trace_reset_epoch_, true);
|
||||
if (pg_->onCompletionHook_) {
|
||||
// Move Work object to completedWorkList_ to be consumed by the hook
|
||||
// thread
|
||||
@ -3360,7 +3362,7 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
|
||||
// these objects to the Work because it has implications for keeping those
|
||||
// tensors alive longer and adds overhead when copying Work objects
|
||||
// between threads
|
||||
r->trace_id_ = FlightRecorderCUDA::get()->record(
|
||||
auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled(
|
||||
local_id_,
|
||||
std::make_tuple(pg_uid_, pg_desc_),
|
||||
seqCollective_,
|
||||
@ -3374,6 +3376,8 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
|
||||
options_->timeout,
|
||||
pgStatus_,
|
||||
isP2P);
|
||||
r->trace_id_ = traceId.id;
|
||||
r->trace_reset_epoch_ = traceId.reset_epoch;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
@ -3593,6 +3597,7 @@ float ProcessGroupNCCL::endTimeEstimate() {
|
||||
#ifdef NCCL_SIM_INFO_INITIALIZER
|
||||
ncclSimInfo_t simInfo = NCCL_SIM_INFO_INITIALIZER;
|
||||
C10D_NCCL_CHECK(ncclGroupSimulateEnd(&simInfo), std::nullopt);
|
||||
--ncclActiveGroupCounter_;
|
||||
return simInfo.estimatedTime;
|
||||
#else
|
||||
TORCH_CHECK(
|
||||
@ -3676,7 +3681,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
|
||||
// later in endCoalescing we record a 'coalesced' Work which has
|
||||
// timing/state updates via watchdog thread, but lacks op metadata such as
|
||||
// input/output sizes and profilingTitle per-op in the group.
|
||||
FlightRecorderCUDA::get()->record(
|
||||
FlightRecorderCUDA::get()->recordWithResetEnabled(
|
||||
local_id_,
|
||||
std::make_tuple(pg_uid_, pg_desc_),
|
||||
seqCollective_,
|
||||
@ -4168,7 +4173,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
|
||||
// TODO(whc) because we don't pass output {tensor} to initWork, we tell
|
||||
// initWork to not record, and then we manually call record passing all the
|
||||
// information it wants.
|
||||
work->trace_id_ = FlightRecorderCUDA::get()->record(
|
||||
auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled(
|
||||
local_id_,
|
||||
std::make_tuple(pg_uid_, pg_desc_),
|
||||
seqCollective_,
|
||||
@ -4182,6 +4187,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
|
||||
options_->timeout,
|
||||
pgStatus_,
|
||||
/*isP2P=*/true);
|
||||
work->trace_id_ = traceId.id;
|
||||
work->trace_reset_epoch_ = traceId.reset_epoch;
|
||||
}
|
||||
|
||||
// Only check for NaN for send ops, for recv ops `tensor` can be a random
|
||||
|
||||
@ -505,6 +505,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
// unique id used to tell the trace buffer that this
|
||||
// work has completed
|
||||
std::optional<uint64_t> trace_id_;
|
||||
std::optional<uint64_t> trace_reset_epoch_;
|
||||
DebugLevel distDebugLevel_;
|
||||
friend class ProcessGroupNCCL;
|
||||
};
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include <torch/headeronly/macros/Macros.h>
|
||||
#include <torch/headeronly/util/Exception.h>
|
||||
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
#include <climits>
|
||||
#include <memory>
|
||||
@ -13,6 +14,7 @@
|
||||
HIDDEN_NAMESPACE_BEGIN(torch, stable)
|
||||
|
||||
using accelerator::DeviceIndex;
|
||||
using torch::headeronly::IntHeaderOnlyArrayRef;
|
||||
using torch::headeronly::ScalarType;
|
||||
|
||||
// The torch::stable::Tensor class is a highlevel C++ wrapper around
|
||||
@ -93,6 +95,32 @@ class Tensor {
|
||||
return numel;
|
||||
}
|
||||
|
||||
// note: this API is, for all intents and purposes, the same as the one in
|
||||
// TensorBase.h: it returns a borrowed reference of the dimension sizes of
|
||||
// a Tensor.
|
||||
//
|
||||
// The only difference is that it returns a header-only IntHeaderOnlyArrayRef,
|
||||
// which has slightly less functionality than a regular IntArrayRef. See
|
||||
// [HeaderOnlyArrayRef vs ArrayRef note] for more details.
|
||||
IntHeaderOnlyArrayRef sizes() const {
|
||||
int64_t* sizes;
|
||||
TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(ath_.get(), &sizes));
|
||||
return IntHeaderOnlyArrayRef(sizes, dim());
|
||||
}
|
||||
|
||||
// note: this API is, for all intents and purposes, the same as the one in
|
||||
// TensorBase.h: it returns a borrowed reference of the strides of a
|
||||
// Tensor.
|
||||
//
|
||||
// The only difference is that it returns a header-only IntHeaderOnlyArrayRef,
|
||||
// which has slightly less functionality than a regular IntArrayRef. See
|
||||
// [HeaderOnlyArrayRef vs ArrayRef note] for more details.
|
||||
IntHeaderOnlyArrayRef strides() const {
|
||||
int64_t* strides;
|
||||
TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(ath_.get(), &strides));
|
||||
return IntHeaderOnlyArrayRef(strides, dim());
|
||||
}
|
||||
|
||||
// note: this is a subset of the original TensorBase API. It takes no
|
||||
// arguments whereas the original API takes in a kwarg of memory format.
|
||||
// Here, we assume the default contiguous memory format.
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
import functools
|
||||
import math
|
||||
import operator
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from datetime import timedelta
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from torch._C import ScriptObject
|
||||
|
||||
@ -10,6 +10,7 @@ from ._context_parallel._attention import (
|
||||
_enable_context_parallel_dispatcher,
|
||||
_is_causal_behavior,
|
||||
_RotateMethod,
|
||||
_templated_ring_attention,
|
||||
context_parallel,
|
||||
context_parallel_unshard,
|
||||
set_rotate_method,
|
||||
@ -22,6 +23,7 @@ from ._context_parallel._load_balancer import (
|
||||
)
|
||||
|
||||
|
||||
# TODO(fegin): add deprecation message once the final interfaces are concluded.
|
||||
__all__ = [
|
||||
"_CausalBehavior",
|
||||
"_context_parallel_shard",
|
||||
@ -31,6 +33,7 @@ __all__ = [
|
||||
"_enable_context_parallel_dispatcher",
|
||||
"_is_causal_behavior",
|
||||
"_RotateMethod",
|
||||
"_templated_ring_attention",
|
||||
"context_parallel",
|
||||
"context_parallel_unshard",
|
||||
"set_rotate_method",
|
||||
|
||||
@ -547,6 +547,7 @@ def rebind_unbacked(
|
||||
assert shape_env is not None
|
||||
for raw_u0, path in bindings.items():
|
||||
u1 = pytree.key_get(result, path)
|
||||
|
||||
# Sometimes, things were previously unbacked bindings become constants.
|
||||
# There are two situations this can happen.
|
||||
#
|
||||
@ -602,7 +603,23 @@ def rebind_unbacked(
|
||||
if u1.node.hint is not None:
|
||||
continue
|
||||
|
||||
raw_u1 = u1.node.expr
|
||||
# unbacked symbols bindings might be replaced to other backed or
|
||||
# unbacked replacements.
|
||||
#
|
||||
# Example:
|
||||
# u = x.item()
|
||||
# torch._check(u == 5)
|
||||
#
|
||||
# The safest approach is to retrieve raw_u1 from u1.node._expr
|
||||
# and perform the rebinding on the original unbacked symbol,
|
||||
# even if it’s no longer directly referenced.
|
||||
#
|
||||
# In other words, we should always rebind the original symbol
|
||||
# before any replacements are applied.
|
||||
# u0 -> u0 == s1
|
||||
raw_u1 = u1.node._expr
|
||||
|
||||
# TODO Do we still need this logic below?
|
||||
# Simplify SymBool binding
|
||||
if (
|
||||
isinstance(raw_u1, sympy.Piecewise)
|
||||
|
||||
@ -648,6 +648,15 @@ class CodeGen:
|
||||
|
||||
if verbose:
|
||||
# override annotation with more detailed information
|
||||
try:
|
||||
from torch.distributed.tensor._api import DTensor, DTensorSpec
|
||||
|
||||
dtensorspec_format_shard_order_str = (
|
||||
DTensorSpec.format_shard_order_str
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
DTensor = None # type: ignore[assignment,misc]
|
||||
dtensorspec_format_shard_order_str = None
|
||||
from torch.fx.experimental.proxy_tensor import py_sym_types
|
||||
from torch.fx.passes.shape_prop import TensorMetadata
|
||||
|
||||
@ -678,6 +687,16 @@ class CodeGen:
|
||||
core = _tensor_annotation(meta_val)
|
||||
if is_plain:
|
||||
maybe_type_annotation = f': "{core}"'
|
||||
elif type(meta_val) is DTensor:
|
||||
assert dtensorspec_format_shard_order_str is not None
|
||||
dtensor_meta = dtensorspec_format_shard_order_str(
|
||||
meta_val._spec.placements, # type: ignore[attr-defined]
|
||||
meta_val._spec.shard_order, # type: ignore[attr-defined]
|
||||
)
|
||||
cls = meta_val.__class__.__name__
|
||||
maybe_type_annotation = (
|
||||
f': "{cls}({core}, {dim_green(dtensor_meta)})"'
|
||||
)
|
||||
else:
|
||||
cls = meta_val.__class__.__name__
|
||||
maybe_type_annotation = f': "{cls}({core})"'
|
||||
|
||||
@ -165,6 +165,7 @@ def insert_deferred_runtime_asserts(
|
||||
node: torch.fx.Node,
|
||||
stack_trace: Optional[str] = None,
|
||||
nn_module_stack: Optional[dict[str, Any]] = None,
|
||||
custom: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
fake_args = pytree.tree_map(
|
||||
lambda arg: (
|
||||
@ -188,6 +189,8 @@ def insert_deferred_runtime_asserts(
|
||||
node.meta["stack_trace"] = stack_trace
|
||||
if nn_module_stack is not None:
|
||||
node.meta["nn_module_stack"] = nn_module_stack
|
||||
if custom is not None:
|
||||
node.meta["custom"] = custom
|
||||
|
||||
# Track asserts/checks we've added
|
||||
added_asserts: set[sympy.Expr] = set()
|
||||
@ -617,6 +620,9 @@ def insert_deferred_runtime_asserts(
|
||||
_node_metadata_hook,
|
||||
stack_trace=node.meta.get("stack_trace"),
|
||||
nn_module_stack=node.meta.get("nn_module_stack"),
|
||||
# nodes added in `apply_runtime_assertion_pass` will have the same annotation
|
||||
# as the input node to the assertion
|
||||
custom=node.meta.get("custom"),
|
||||
),
|
||||
):
|
||||
if (min_val := convert(vr.lower)) is not None:
|
||||
|
||||
@ -210,7 +210,8 @@ class _KinetoProfile:
|
||||
def start_trace(self) -> None:
|
||||
if self.execution_trace_observer:
|
||||
self.execution_trace_observer.start()
|
||||
assert self.profiler is not None
|
||||
if self.profiler is None:
|
||||
raise AssertionError("Profiler must be initialized before starting trace")
|
||||
self.profiler._start_trace()
|
||||
|
||||
if self.profile_memory:
|
||||
@ -256,7 +257,8 @@ class _KinetoProfile:
|
||||
def stop_trace(self) -> None:
|
||||
if self.execution_trace_observer:
|
||||
self.execution_trace_observer.stop()
|
||||
assert self.profiler is not None
|
||||
if self.profiler is None:
|
||||
raise AssertionError("Profiler must be initialized before stopping trace")
|
||||
self.profiler.__exit__(None, None, None)
|
||||
|
||||
def export_chrome_trace(self, path: str):
|
||||
@ -264,7 +266,10 @@ class _KinetoProfile:
|
||||
Exports the collected trace in Chrome JSON format. If kineto is enabled, only
|
||||
last cycle in schedule is exported.
|
||||
"""
|
||||
assert self.profiler
|
||||
if self.profiler is None:
|
||||
raise AssertionError(
|
||||
"Profiler must be initialized before exporting chrome trace"
|
||||
)
|
||||
if path.endswith(".gz"):
|
||||
fp = tempfile.NamedTemporaryFile("w+b", suffix=".json", delete=False)
|
||||
fp.close()
|
||||
@ -284,7 +289,8 @@ class _KinetoProfile:
|
||||
path (str): save stacks file to this location;
|
||||
metric (str): metric to use: "self_cpu_time_total" or "self_cuda_time_total"
|
||||
"""
|
||||
assert self.profiler
|
||||
if self.profiler is None:
|
||||
raise AssertionError("Profiler must be initialized before exporting stacks")
|
||||
return self.profiler.export_stacks(path, metric)
|
||||
|
||||
def toggle_collection_dynamic(
|
||||
@ -316,7 +322,7 @@ class _KinetoProfile:
|
||||
print(p.key_averages().table(
|
||||
sort_by="self_cuda_time_total", row_limit=-1))
|
||||
"""
|
||||
if not self.profiler:
|
||||
if self.profiler is None:
|
||||
return
|
||||
self.profiler.toggle_collection_dynamic(enable, activities)
|
||||
|
||||
@ -333,7 +339,10 @@ class _KinetoProfile:
|
||||
To use shape/stack functionality make sure to set record_shapes/with_stack
|
||||
when creating profiler context manager.
|
||||
"""
|
||||
assert self.profiler
|
||||
if self.profiler is None:
|
||||
raise AssertionError(
|
||||
"Profiler must be initialized before getting key averages"
|
||||
)
|
||||
return self.profiler.key_averages(
|
||||
group_by_input_shape, group_by_stack_n, group_by_overload_name
|
||||
)
|
||||
@ -343,7 +352,8 @@ class _KinetoProfile:
|
||||
Returns the list of unaggregated profiler events,
|
||||
to be used in the trace callback or after the profiling is finished
|
||||
"""
|
||||
assert self.profiler
|
||||
if self.profiler is None:
|
||||
raise AssertionError("Profiler must be initialized before accessing events")
|
||||
return self.profiler.function_events
|
||||
|
||||
def add_metadata(self, key: str, value: str) -> None:
|
||||
@ -395,7 +405,10 @@ class _KinetoProfile:
|
||||
if missing:
|
||||
raise ValueError(f"{', '.join(missing)} required for memory profiling.")
|
||||
|
||||
assert self.profiler is not None and self.profiler.kineto_results is not None
|
||||
if self.profiler is None or self.profiler.kineto_results is None:
|
||||
raise AssertionError(
|
||||
"Profiler and kineto_results must be initialized for memory profiling"
|
||||
)
|
||||
return MemoryProfile(self.profiler.kineto_results)
|
||||
|
||||
def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None:
|
||||
@ -485,7 +498,8 @@ def schedule(
|
||||
"""
|
||||
|
||||
def schedule_fn(step: int) -> ProfilerAction:
|
||||
assert step >= 0
|
||||
if step < 0:
|
||||
raise AssertionError(f"Step must be non-negative. Got {step}.")
|
||||
if step < skip_first:
|
||||
return ProfilerAction.NONE
|
||||
else:
|
||||
@ -508,9 +522,11 @@ def schedule(
|
||||
else ProfilerAction.RECORD_AND_SAVE
|
||||
)
|
||||
|
||||
assert (
|
||||
wait >= 0 and warmup >= 0 and active > 0 and repeat >= 0 and skip_first >= 0
|
||||
), "Invalid profiler schedule arguments"
|
||||
if wait < 0 or warmup < 0 or active <= 0 or repeat < 0 or skip_first < 0:
|
||||
raise AssertionError(
|
||||
f"Invalid profiler schedule arguments. Got wait={wait} (need >= 0), warmup={warmup} (need >= 0), "
|
||||
f"active={active} (need > 0), repeat={repeat} (need >= 0), skip_first={skip_first} (need >= 0)."
|
||||
)
|
||||
if warmup == 0:
|
||||
warn(
|
||||
"Profiler won't be using warmup, this can skew profiler results",
|
||||
@ -717,7 +733,8 @@ class profile(_KinetoProfile):
|
||||
activities_set.add(ProfilerActivity.CUDA)
|
||||
elif ProfilerActivity.CUDA in activities_set:
|
||||
activities_set.remove(ProfilerActivity.CUDA)
|
||||
assert len(activities_set) > 0, "No valid profiler activities found"
|
||||
if len(activities_set) == 0:
|
||||
raise AssertionError("No valid profiler activities found")
|
||||
|
||||
super().__init__(
|
||||
activities=activities,
|
||||
|
||||
@ -306,6 +306,24 @@ class PythonPrinter(ExprPrinter):
|
||||
raise TypeError("ndigits must be an instance of sympy.Integer")
|
||||
return f"round({self._print(number)}, {ndigits})"
|
||||
|
||||
def _print_Piecewise(self, expr: sympy.Expr) -> str:
|
||||
# Convert Piecewise(expr_cond_pairs) to nested ternary expressions
|
||||
# Piecewise((e1, c1), (e2, c2), ..., (eN, cN))
|
||||
# becomes: e1 if c1 else (e2 if c2 else (... else eN))
|
||||
result: Optional[str] = None
|
||||
for expr_i, cond_i in reversed(expr.args):
|
||||
expr_str = self._print(expr_i)
|
||||
if cond_i == True: # noqa: E712
|
||||
# This is the default case
|
||||
result = expr_str
|
||||
else:
|
||||
cond_str = self._print(cond_i)
|
||||
if result is None:
|
||||
result = expr_str
|
||||
else:
|
||||
result = f"({expr_str} if {cond_str} else {result})"
|
||||
return result if result else "0"
|
||||
|
||||
|
||||
class CppPrinter(ExprPrinter):
|
||||
def _print_Integer(self, expr: sympy.Expr) -> str:
|
||||
@ -327,6 +345,24 @@ class CppPrinter(ExprPrinter):
|
||||
)
|
||||
return f"{c} ? {p} : {q}"
|
||||
|
||||
def _print_Piecewise(self, expr: sympy.Expr) -> str:
|
||||
# Convert Piecewise(expr_cond_pairs) to nested ternary operators
|
||||
# Piecewise((e1, c1), (e2, c2), ..., (eN, cN))
|
||||
# becomes: c1 ? e1 : (c2 ? e2 : (... : eN))
|
||||
result: Optional[str] = None
|
||||
for expr_i, cond_i in reversed(expr.args):
|
||||
expr_str = self.parenthesize(expr_i, PRECEDENCE["Atom"] - 0.5)
|
||||
if cond_i == True: # noqa: E712
|
||||
# This is the default case
|
||||
result = expr_str
|
||||
else:
|
||||
cond_str = self.parenthesize(cond_i, PRECEDENCE["Atom"] - 0.5)
|
||||
if result is None:
|
||||
result = expr_str
|
||||
else:
|
||||
result = f"{cond_str} ? {expr_str} : {result}"
|
||||
return f"({result})" if result else "0"
|
||||
|
||||
def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
|
||||
x, div, mod = expr.args
|
||||
x = self.doprint(x)
|
||||
|
||||
Reference in New Issue
Block a user