Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-14 06:07:55 +08:00

Compare commits: ciflow/tru ... ciflow/pul (18 commits)

| SHA1 |
|---|
| 0911360736 |
| 38de8d0d33 |
| a9fe64bee2 |
| 84436662a3 |
| 794e09311c |
| 641d0bae63 |
| f9851af59b |
| eeebf9f664 |
| d9a50bf9a8 |
| 2984331c87 |
| 9b68682df2 |
| 8f5f89c9a0 |
| 8919f69362 |
| 19c867873a |
| e3dadb1d36 |
| c9b09a31e8 |
| 35571fe94b |
| 485f2b607a |
@@ -116,7 +116,7 @@ case "$tag" in
     INSTALL_MINGW=yes
     ;;
   pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11)
-    CUDA_VERSION=13.0.0
+    CUDA_VERSION=13.0.2
    ANACONDA_PYTHON_VERSION=3.10
     GCC_VERSION=11
     VISION=yes
@@ -125,6 +125,16 @@ case "$tag" in
     UCC_COMMIT=${_UCC_COMMIT}
     TRITON=yes
     ;;
+  pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc9)
+    CUDA_VERSION=13.0.2
+    ANACONDA_PYTHON_VERSION=3.10
+    GCC_VERSION=9
+    VISION=yes
+    KATEX=yes
+    UCX_COMMIT=${_UCX_COMMIT}
+    UCC_COMMIT=${_UCC_COMMIT}
+    TRITON=yes
+    ;;
   pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks)
     CUDA_VERSION=12.8.1
     ANACONDA_PYTHON_VERSION=3.10
@@ -1680,6 +1680,22 @@ test_operator_microbenchmark() {
   done
 }
 
+test_attention_microbenchmark() {
+  TEST_REPORTS_DIR=$(pwd)/test/test-reports
+  mkdir -p "$TEST_REPORTS_DIR"
+  TEST_DIR=$(pwd)
+
+  # Install attention-gym dependency
+  echo "Installing attention-gym..."
+  python -m pip install git+https://github.com/meta-pytorch/attention-gym.git@main
+  pip show triton
+
+  cd "${TEST_DIR}"/benchmarks/transformer
+
+  $TASKSET python score_mod.py --config configs/config_basic.yaml \
+    --output-json-for-dashboard "${TEST_REPORTS_DIR}/attention_microbenchmark.json"
+}
+
 if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then
   (cd test && python -c "import torch; print(torch.__config__.show())")
   (cd test && python -c "import torch; print(torch.__config__.parallel_info())")
@@ -1737,6 +1753,8 @@ elif [[ "${TEST_CONFIG}" == *operator_benchmark* ]]; then
   fi
 elif [[ "${TEST_CONFIG}" == *operator_microbenchmark* ]]; then
   test_operator_microbenchmark
+elif [[ "${TEST_CONFIG}" == *attention_microbenchmark* ]]; then
+  test_attention_microbenchmark
 elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
   test_inductor_distributed
 elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then
73  .github/workflows/attention_op_microbenchmark.yml  vendored  Normal file
@@ -0,0 +1,73 @@
name: attention_op_microbenchmark

on:
  push:
    tags:
      - ciflow/op-benchmark/*
  workflow_dispatch:
  schedule:
    # Run at 06:00 UTC everyday
    - cron: 0 7 * * *

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
  cancel-in-progress: true

permissions:
  id-token: write
  contents: read

jobs:
  attn-microbenchmark-build:
    if: github.repository_owner == 'pytorch'
    uses: ./.github/workflows/_linux-build.yml
    with:
      runner: linux.12xlarge.memory
      build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
      cuda-arch-list: '8.0 9.0'
      test-matrix: |
        { include: [
          { config: "attention_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.aws.a100" },
          { config: "attention_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.aws.h100" },
        ]}
    secrets: inherit

  attn-microbenchmark-test:
    name: attn-microbenchmark-test
    uses: ./.github/workflows/_linux-test.yml
    needs: attn-microbenchmark-build
    with:
      timeout-minutes: 500
      build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
      docker-image: ${{ needs.attn-microbenchmark-build.outputs.docker-image }}
      test-matrix: ${{ needs.attn-microbenchmark-build.outputs.test-matrix }}
    secrets: inherit

  # B200 runner
  opmicrobenchmark-build-b200:
    if: github.repository_owner == 'pytorch'
    name: opmicrobenchmark-build-b200
    uses: ./.github/workflows/_linux-build.yml
    with:
      runner: linux.12xlarge.memory
      build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
      cuda-arch-list: '10.0'
      test-matrix: |
        { include: [
          { config: "operator_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.dgx.b200" },
        ]}
    secrets: inherit

  opmicrobenchmark-test-b200:
    name: opmicrobenchmark-test-b200
    uses: ./.github/workflows/_linux-test.yml
    needs: opmicrobenchmark-build-b200
    with:
      timeout-minutes: 500
      build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
      docker-image: ${{ needs.opmicrobenchmark-build-b200.outputs.docker-image }}
      test-matrix: ${{ needs.opmicrobenchmark-build-b200.outputs.test-matrix }}
      aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
    secrets: inherit
1  .github/workflows/docker-builds.yml  vendored
@@ -54,6 +54,7 @@ jobs:
           pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm,
           pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks,
           pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9,
+          pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc9,
           pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11,
           pytorch-linux-jammy-py3.10-clang12,
           pytorch-linux-jammy-py3.11-clang12,
33  .github/workflows/periodic.yml  vendored
@ -204,6 +204,39 @@ jobs:
|
||||
test-matrix: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc11-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-cuda13_0-py3_10-gcc9-debug-build:
|
||||
name: linux-jammy-cuda13.0-py3.10-gcc9-debug
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-cuda13.0-py3.10-gcc9-debug
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc9
|
||||
cuda-arch-list: 8.6
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-cuda13_0-py3_10-gcc9-debug-test:
|
||||
name: linux-jammy-cuda13.0-py3.10-gcc9-debug
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs:
|
||||
- linux-jammy-cuda13_0-py3_10-gcc9-debug-build
|
||||
- target-determination
|
||||
with:
|
||||
build-environment: linux-jammy-cuda13.0-py3.10-gcc9-debug
|
||||
docker-image: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc9-debug-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc9-debug-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-cuda12_8-py3-gcc11-slow-gradcheck-build:
|
||||
name: linux-jammy-cuda12.8-py3-gcc11-slow-gradcheck
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
|
||||
29  .github/workflows/pull.yml  vendored
@ -268,6 +268,35 @@ jobs:
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-cuda13_0-py3_10-gcc9-build:
|
||||
name: linux-jammy-cuda13.0-py3.10-gcc9
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-cuda13.0-py3.10-gcc9
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc9
|
||||
cuda-arch-list: 8.9
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
|
||||
{ config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
|
||||
{ config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
|
||||
{ config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
|
||||
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-cuda13_0-py3_10-gcc9-test:
|
||||
name: linux-jammy-cuda13.0-py3.10-gcc9
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: linux-jammy-cuda13_0-py3_10-gcc9-build
|
||||
with:
|
||||
build-environment: linux-jammy-cuda13.0-py3.10-gcc9
|
||||
docker-image: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc9-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc9-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-cpu-py3_10-gcc11-bazel-test:
|
||||
name: linux-jammy-cpu-py3.10-gcc11-bazel-test
|
||||
uses: ./.github/workflows/_bazel-build-test.yml
|
||||
|
||||
29  .github/workflows/slow.yml  vendored
@ -78,6 +78,35 @@ jobs:
|
||||
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm86-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-cuda13_0-py3_10-gcc11-sm86-build:
|
||||
name: linux-jammy-cuda13.0-py3.10-gcc11-sm86
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-cuda13.0-py3.10-gcc11-sm86
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11
|
||||
cuda-arch-list: 8.6
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "slow", shard: 1, num_shards: 3, runner: "linux.g5.4xlarge.nvidia.gpu" },
|
||||
{ config: "slow", shard: 2, num_shards: 3, runner: "linux.g5.4xlarge.nvidia.gpu" },
|
||||
{ config: "slow", shard: 3, num_shards: 3, runner: "linux.g5.4xlarge.nvidia.gpu" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-cuda13_0-py3_10-gcc11-sm86-test:
|
||||
name: linux-jammy-cuda13.0-py3.10-gcc11-sm86
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs:
|
||||
- linux-jammy-cuda13_0-py3_10-gcc11-sm86-build
|
||||
- target-determination
|
||||
with:
|
||||
build-environment: linux-jammy-cuda13.0-py3.10-gcc11-sm86
|
||||
docker-image: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc11-sm86-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc11-sm86-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-py3_10-clang12-build:
|
||||
name: linux-jammy-py3.10-clang12
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
|
||||
67  .github/workflows/trunk.yml  vendored
@ -63,6 +63,23 @@ jobs:
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
libtorch-linux-jammy-cuda13_0-py3_10-gcc11-debug-build:
|
||||
name: libtorch-linux-jammy-cuda13.0-py3.10-gcc11-debug
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
build-environment: libtorch-linux-jammy-cuda13.0-py3.10-gcc11
|
||||
cuda-arch-list: '7.5 8.9'
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11
|
||||
build-generates-artifacts: false
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
runner: "linux.c7i.4xlarge"
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 1 },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-cuda12_8-py3_10-gcc11-build:
|
||||
name: linux-jammy-cuda12.8-py3.10-gcc11
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
@ -99,6 +116,41 @@ jobs:
|
||||
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-cuda13_0-py3_10-gcc11-build:
|
||||
name: linux-jammy-cuda13.0-py3.10-gcc11
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-cuda13.0-py3.10-gcc11
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11
|
||||
cuda-arch-list: '7.5 8.9'
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
|
||||
{ config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
|
||||
{ config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
|
||||
{ config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
|
||||
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
|
||||
{ config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" },
|
||||
{ config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" },
|
||||
{ config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" },
|
||||
{ config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-cuda13_0-py3_10-gcc11-test:
|
||||
name: linux-jammy-cuda13.0-py3.10-gcc11
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs:
|
||||
- linux-jammy-cuda13_0-py3_10-gcc11-build
|
||||
- target-determination
|
||||
with:
|
||||
timeout-minutes: 360
|
||||
build-environment: linux-jammy-cuda13.0-py3.10-gcc11
|
||||
docker-image: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc11-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc11-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
# no-ops builds test USE_PER_OPERATOR_HEADERS=0 where ATen/ops is not generated
|
||||
linux-jammy-cuda12_8-py3_10-gcc11-no-ops-build:
|
||||
@ -115,6 +167,21 @@ jobs:
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-cuda13_0-py3_10-gcc11-no-ops-build:
|
||||
name: linux-jammy-cuda13.0-py3.10-gcc11-no-ops
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-cuda13.0-py3.10-gcc11-no-ops
|
||||
cuda-arch-list: '7.5 8.9'
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 1 },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
macos-py3-arm64-build:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
name: macos-py3-arm64
|
||||
|
||||
@@ -3,7 +3,6 @@
 
 #include <cstdint>
 #include <map>
-#include <shared_mutex>
 
 #include <cuda_runtime_api.h>
 #include <cusparse.h>
@@ -89,13 +88,8 @@ TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
 TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
 
 TORCH_CUDA_CPP_API void clearCublasWorkspaces();
-struct WorkspaceMapWithMutex {
-  std::map<std::tuple<void*, void*>, at::DataPtr> map;
-  std::shared_mutex mutex;
-};
-
-TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublas_handle_stream_to_workspace();
-TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace();
+TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace();
+TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace();
 TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize();
 TORCH_CUDA_CPP_API size_t getCUDABlasLtWorkspaceSize();
 TORCH_CUDA_CPP_API void* getCUDABlasLtWorkspace();
@@ -99,7 +99,7 @@ void destroyCublasHandle(cublasHandle_t handle) {
 // - Comments of @soumith copied from cuDNN handle pool implementation
 #ifdef NO_CUDNN_DESTROY_HANDLE
 #else
-  cublasDestroy(handle);
+  cublasDestroy(handle);
 #endif
 }
 
@@ -107,27 +107,19 @@ using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle
 
 } // namespace
 
-WorkspaceMapWithMutex& cublas_handle_stream_to_workspace() {
-  static auto& instance = *new WorkspaceMapWithMutex;
+std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
+  static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
   return instance;
 }
 
-WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace() {
-  static auto& instance = *new WorkspaceMapWithMutex;
+std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace() {
+  static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
   return instance;
 }
 
 void clearCublasWorkspaces() {
-  {
-    auto& workspace = cublas_handle_stream_to_workspace();
-    std::unique_lock<std::shared_mutex> lock(workspace.mutex);
-    workspace.map.clear();
-  }
-  {
-    auto& workspace = cublaslt_handle_stream_to_workspace();
-    std::unique_lock<std::shared_mutex> lock(workspace.mutex);
-    workspace.map.clear();
-  }
+  cublas_handle_stream_to_workspace().clear();
+  cublaslt_handle_stream_to_workspace().clear();
 }
 
 size_t parseChosenWorkspaceSize() {
@@ -249,10 +241,8 @@ void* getCUDABlasLtWorkspace() {
     auto stream = c10::cuda::getCurrentCUDAStream();
     cudaStream_t _stream = stream;
     auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
-    auto& workspace = at::cuda::cublas_handle_stream_to_workspace();
-    std::shared_lock<std::shared_mutex> lock(workspace.mutex);
-    auto workspace_it = workspace.map.find(key);
-    TORCH_INTERNAL_ASSERT(workspace_it != workspace.map.end());
+    auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key);
+    TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end());
     return workspace_it->second.mutable_get();
   }
 #endif
@@ -260,34 +250,11 @@ void* getCUDABlasLtWorkspace() {
   auto stream = c10::cuda::getCurrentCUDAStream();
   cudaStream_t _stream = stream;
   auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
-
-  auto& workspace = cublaslt_handle_stream_to_workspace();
-
-  // Fast path: check if workspace already exists
-  {
-    std::shared_lock<std::shared_mutex> lock(workspace.mutex);
-    auto workspace_it = workspace.map.find(key);
-    if (workspace_it != workspace.map.end()) {
-      return workspace_it->second.mutable_get();
-    }
-  }
-
-  // Slow path: allocate workspace outside the lock
-  auto new_workspace = getNewCUDABlasLtWorkspace();
-
-  // Insert with lock (double-check in case another thread inserted while we
-  // were allocating)
-  {
-    std::unique_lock<std::shared_mutex> lock(workspace.mutex);
-    auto workspace_it = workspace.map.find(key);
-    if (workspace_it == workspace.map.end()) {
-      workspace_it =
-          workspace.map.emplace(key, std::move(new_workspace)).first;
-    }
-    // else: another thread inserted it, our new_workspace will be automatically
-    // freed
-    return workspace_it->second.mutable_get();
+  auto workspace_it = cublaslt_handle_stream_to_workspace().find(key);
+  if (workspace_it == cublaslt_handle_stream_to_workspace().end()) {
+    workspace_it = cublaslt_handle_stream_to_workspace().insert(workspace_it, {key, getNewCUDABlasLtWorkspace()});
+  }
+  return workspace_it->second.mutable_get();
 }
 
 cublasHandle_t getCurrentCUDABlasHandle() {
@@ -333,39 +300,11 @@ cublasHandle_t getCurrentCUDABlasHandle() {
   // all the memory and cublas's cudaMallocAsync will return OOM
   cudaStream_t _stream = stream;
   auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
-
-  auto& workspace = cublas_handle_stream_to_workspace();
-
-  size_t workspace_size = getChosenWorkspaceSize();
-
-  // Fast path: check if workspace already exists
-  {
-    std::shared_lock<std::shared_mutex> lock(workspace.mutex);
-    auto workspace_it = workspace.map.find(key);
-    if (workspace_it != workspace.map.end()) {
-      TORCH_CUDABLAS_CHECK(cublasSetWorkspace(
-          handle, workspace_it->second.get(), workspace_size));
-      return handle;
-    }
-  }
-
-  // Slow path: allocate workspace outside the lock
-  auto new_workspace = getNewWorkspace();
-
-  // Insert with lock (double-check in case another thread inserted while we
-  // were allocating)
-  {
-    std::unique_lock<std::shared_mutex> lock(workspace.mutex);
-    auto workspace_it = workspace.map.find(key);
-    if (workspace_it == workspace.map.end()) {
-      workspace_it =
-          workspace.map.emplace(key, std::move(new_workspace)).first;
-    }
-    // else: another thread inserted it, our new_workspace will be automatically
-    // freed
-    TORCH_CUDABLAS_CHECK(
-        cublasSetWorkspace(handle, workspace_it->second.get(), workspace_size));
+  auto workspace_it = cublas_handle_stream_to_workspace().find(key);
+  if (workspace_it == cublas_handle_stream_to_workspace().end()) {
+    workspace_it = cublas_handle_stream_to_workspace().insert(workspace_it, {key, getNewWorkspace()});
+  }
+  TORCH_CUDABLAS_CHECK(cublasSetWorkspace(handle, workspace_it->second.get(), getChosenWorkspaceSize()));
 #if !defined(USE_ROCM)
   // On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
   // FP32 data type calculations based on the value of the allow_tf32 flag.
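For orientation, the lines removed above implemented a double-checked, fast-path/slow-path lookup around the per-(handle, stream) workspace map, while the replacement goes back to a plain map lookup. Below is a minimal sketch of that double-checked pattern, written in Python only for brevity; the class and variable names are hypothetical, and the real C++ code uses std::shared_mutex so concurrent readers share the lock rather than serializing on it.

```python
import threading

class WorkspaceCache:
    """Sketch of the fast-path/slow-path (double-checked) cache pattern."""

    def __init__(self):
        self._map = {}                 # (handle, stream) -> workspace buffer
        self._lock = threading.Lock()  # stands in for std::shared_mutex

    def get_or_create(self, key, allocate):
        # Fast path: look up under the lock (a shared/reader lock in C++).
        with self._lock:
            found = self._map.get(key)
        if found is not None:
            return found
        # Slow path: allocate outside the lock so other threads keep running.
        new_workspace = allocate()
        # Re-check under the lock before inserting; if another thread raced
        # ahead, its entry wins and the extra allocation is simply discarded.
        with self._lock:
            return self._map.setdefault(key, new_workspace)

cache = WorkspaceCache()
ws = cache.get_or_create(("handle", "stream"), lambda: bytearray(4096))
```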
@@ -61,7 +61,6 @@ list(APPEND ATen_CUDA_TEST_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_math_test.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_test.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_cub_test.cu
-  ${CMAKE_CURRENT_SOURCE_DIR}/cuda_cublas_handle_pool_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_device_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_distributions_test.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_dlconvertor_test.cpp
@@ -1,77 +0,0 @@
#include <gtest/gtest.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>

#include <atomic>
#include <thread>
#include <vector>

// Test concurrent access to getCurrentCUDABlasHandle and getCUDABlasLtWorkspace
// to verify that the data race fix is working correctly

TEST(CUDABlasHandlePoolTest, ConcurrentGetAndClearWorkspaces) {
  if (!at::cuda::is_available()) {
    return;
  }

  constexpr int num_accessor_threads = 15;
  constexpr int num_clear_threads = 5;
  constexpr int iterations_per_thread = 50;

  std::atomic<bool> stop{false};
  std::atomic<int> error_count{0};
  std::vector<std::thread> threads;
  threads.reserve(num_accessor_threads + num_clear_threads);

  // Launch accessor threads
  for (int i = 0; i < num_accessor_threads; ++i) {
    threads.emplace_back([&stop, &error_count]() {
      try {
        at::cuda::CUDAGuard device_guard(0);

        while (!stop.load(std::memory_order_relaxed)) {
          const auto handle = at::cuda::getCurrentCUDABlasHandle();
          const auto workspace = at::cuda::getCUDABlasLtWorkspace();

          if (handle == nullptr || workspace == nullptr) {
            error_count++;
          }
        }
      } catch (const std::exception& e) {
        error_count++;
      }
    });
  }

  // Launch threads that clear workspaces
  for (int i = 0; i < num_clear_threads; ++i) {
    threads.emplace_back([&error_count]() {
      try {
        for (int j = 0; j < iterations_per_thread; ++j) {
          at::cuda::clearCublasWorkspaces();
          std::this_thread::yield();
        }
      } catch (const std::exception& e) {
        error_count++;
      }
    });
  }

  // Let them run for a bit
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
  stop.store(true, std::memory_order_relaxed);

  for (auto& thread : threads) {
    thread.join();
  }

  EXPECT_EQ(error_count.load(), 0);
}

int main(int argc, char* argv[]) {
  ::testing::InitGoogleTest(&argc, argv);
  c10::cuda::CUDACachingAllocator::init(1);
  return RUN_ALL_TESTS();
}
@@ -125,6 +125,17 @@ AttentionType = Literal[
 ]
 DtypeString = Literal["bfloat16", "float16", "float32"]
 SpeedupType = Literal["fwd", "bwd"]
+# Operator Name mapping
+backend_to_operator_name = {
+    "math": "math attention kernel",
+    "efficient": "efficient attention kernel",
+    "cudnn": "cudnn attention kernel",
+    "fav2": "flash attention 2 kernel",
+    "fav3": "flash attention 3 kernel",
+    "fakv": "flash attention kv cache kernel",
+    "og-eager": "eager attention kernel",
+    "flex": "flex attention kernel",
+}
 
 
 def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
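The mapping added above is consumed a few hunks below in `_output_json_for_dashboard`, which falls back to the raw backend key when no entry exists. A small sketch of that lookup (the unknown backend string here is hypothetical):

```python
# Mirrors: operator_name = backend_to_operator_name.get(backend, backend)
backend_to_operator_name = {
    "fav2": "flash attention 2 kernel",
    "flex": "flex attention kernel",
}

def operator_display_name(backend: str) -> str:
    # Known backends map to a human-readable kernel name; unknown ones pass through.
    return backend_to_operator_name.get(backend, backend)

assert operator_display_name("flex") == "flex attention kernel"
assert operator_display_name("some-new-backend") == "some-new-backend"
```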
@@ -1265,12 +1276,14 @@ def _output_json_for_dashboard(
         model: ModelInfo
         metric: MetricInfo
 
+    operator_name = backend_to_operator_name.get(backend, backend)
+
     # Benchmark extra info
     benchmark_extra_info = {
         "input_config": input_config,
         "device": device,
         "arch": device_arch,
-        "operator_name": backend,
+        "operator_name": operator_name,
         "attn_type": config.attn_type,
         "shape": str(config.shape),
         "max_autotune": config.max_autotune,
@@ -1288,7 +1301,7 @@
                 type="attention-benchmark",
                 origins=["pytorch"],
                 extra_info={
-                    "operator_name": backend,
+                    "operator_name": operator_name,
                     "attn_type": config.attn_type,
                 },
             ),
@@ -1315,7 +1328,7 @@
                 type="attention-benchmark",
                 origins=["pytorch"],
                 extra_info={
-                    "operator_name": backend,
+                    "operator_name": operator_name,
                 },
             ),
             metric=MetricInfo(
@@ -1341,7 +1354,7 @@
                 type="attention-benchmark",
                 origins=["pytorch"],
                 extra_info={
-                    "operator_name": backend,
+                    "operator_name": operator_name,
                 },
             ),
             metric=MetricInfo(
@@ -1371,7 +1384,7 @@
                 type="attention-benchmark",
                 origins=["pytorch"],
                 extra_info={
-                    "operator_name": backend,
+                    "operator_name": operator_name,
                 },
             ),
             metric=MetricInfo(
@ -1394,6 +1394,9 @@ if(NOT INTERN_BUILD_MOBILE)
|
||||
# https://github.com/pytorch/pytorch/pull/55292
|
||||
string(APPEND CMAKE_CUDA_FLAGS " -DCUB_WRAPPED_NAMESPACE=at_cuda_detail")
|
||||
|
||||
# Suppress cusparse warnings
|
||||
string(APPEND CMAKE_CUDA_FLAGS " -DDISABLE_CUSPARSE_DEPRECATED")
|
||||
|
||||
message(STATUS "Found CUDA with FP16 support, compiling with torch.cuda.HalfTensor")
|
||||
string(APPEND CMAKE_CUDA_FLAGS " -DCUDA_HAS_FP16=1"
|
||||
" -D__CUDA_NO_HALF_OPERATORS__"
|
||||
|
||||
@ -54,12 +54,10 @@ from torch.testing._internal.common_distributed import (
|
||||
verify_ddp_error_logged,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
MI300_ARCH,
|
||||
retry_on_connect_failures,
|
||||
run_tests,
|
||||
skip_but_pass_in_sandcastle,
|
||||
skipIfRocm,
|
||||
skipIfRocmArch,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
@ -1233,7 +1231,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
|
||||
self._test_gather_stress(inputs, lambda t: t.clone())
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skipIfRocmArch(MI300_ARCH)
|
||||
@skipIfRocm
|
||||
@requires_gloo()
|
||||
def test_gather_stress_cuda(self):
|
||||
inputs = [torch.tensor([i + self.rank]).cuda() for i in range(1000)]
|
||||
|
||||
@ -18,15 +18,16 @@ from functorch.compile import (
|
||||
nop,
|
||||
)
|
||||
from torch._functorch.aot_autograd import aot_export_module
|
||||
from torch._higher_order_ops.effects import with_effects
|
||||
from torch._higher_order_ops.effects import (
|
||||
_EffectType,
|
||||
_get_effect,
|
||||
_register_effectful_op,
|
||||
with_effects,
|
||||
)
|
||||
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_cuda import (
|
||||
_get_torch_cuda_version,
|
||||
SM70OrLater,
|
||||
SM80OrLater,
|
||||
)
|
||||
from torch.testing._internal.common_cuda import SM70OrLater, SM80OrLater
|
||||
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_WINDOWS,
|
||||
@ -300,7 +301,6 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
@unittest.skipIf(IS_WINDOWS, "triton")
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "triton")
|
||||
@unittest.skipIf(not SM80OrLater, "triton")
|
||||
@unittest.skipIf(_get_torch_cuda_version() >= (11, 7), "triton")
|
||||
@unittest.skipIf(not TEST_CUDA, "triton")
|
||||
@skipIfNoDynamoSupport
|
||||
def test_register_effectful_custom_op(self):
|
||||
@ -308,41 +308,23 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
torch._dynamo.config.capture_scalar_outputs = True
|
||||
torch._dynamo.config.capture_dynamic_output_shape_ops = True
|
||||
|
||||
torch.library.define(
|
||||
"mylib::record_scalar_tensor",
|
||||
"(Tensor x, str prefix) -> ()",
|
||||
lib=lib,
|
||||
)
|
||||
|
||||
# global variable to store the recorded tensor and prefix.
|
||||
recorded_dict = {}
|
||||
|
||||
# Pytorch custorm op implementation
|
||||
@torch.library.impl(
|
||||
"mylib::record_scalar_tensor",
|
||||
"CompositeExplicitAutograd",
|
||||
lib=lib,
|
||||
)
|
||||
def record_scalar_tensor(x, prefix):
|
||||
# Pytorch custom op implementation
|
||||
@torch.library.custom_op("mylib::record_scalar_tensor", mutates_args=())
|
||||
def record_scalar_tensor(x: torch.Tensor, prefix: str) -> None:
|
||||
recorded_dict[prefix] = x.clone()
|
||||
return
|
||||
|
||||
# Meta function of the custom op
|
||||
@torch.library.register_fake(
|
||||
"mylib::record_scalar_tensor",
|
||||
lib=lib,
|
||||
)
|
||||
@record_scalar_tensor.register_fake
|
||||
def record_scalar_tensor_meta(x, prefix):
|
||||
return
|
||||
|
||||
from torch._higher_order_ops.effects import (
|
||||
_EffectType,
|
||||
_register_effectful_op,
|
||||
)
|
||||
record_scalar_tensor.register_effect(_EffectType.ORDERED)
|
||||
|
||||
_register_effectful_op(
|
||||
torch.ops.mylib.record_scalar_tensor.default, _EffectType.ORDERED
|
||||
)
|
||||
self.assertEqual(_get_effect(record_scalar_tensor), _EffectType.ORDERED)
|
||||
|
||||
my_config = {}
|
||||
my_config["MockModule"] = "mean"
|
||||
@ -469,13 +451,12 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
|
||||
torch.library.register_autograd("_mylib::zoo", foo_bwd, lib=lib)
|
||||
|
||||
from torch._higher_order_ops.effects import (
|
||||
_EffectType,
|
||||
_register_effectful_op,
|
||||
torch.library._register_effectful_op(
|
||||
torch.ops._mylib.zoo.default, _EffectType.ORDERED
|
||||
)
|
||||
torch.library._register_effectful_op(
|
||||
torch.ops._mylib.zoo2.default, _EffectType.ORDERED
|
||||
)
|
||||
|
||||
_register_effectful_op(torch.ops._mylib.zoo.default, _EffectType.ORDERED)
|
||||
_register_effectful_op(torch.ops._mylib.zoo2.default, _EffectType.ORDERED)
|
||||
|
||||
def fn(x, y):
|
||||
return torch.ops._mylib.zoo(x) + y
|
||||
@ -687,13 +668,13 @@ def forward(self, arg0_1, arg1_1):
|
||||
|
||||
torch.library.register_autograd("_mylib::foo", foo_bwd, lib=lib)
|
||||
|
||||
from torch._higher_order_ops.effects import (
|
||||
_deregister_effectful_op,
|
||||
_EffectType,
|
||||
_register_effectful_op,
|
||||
handle = _register_effectful_op(
|
||||
torch.ops._mylib.foo.default, _EffectType.ORDERED
|
||||
)
|
||||
self.assertEqual(
|
||||
_get_effect(torch.ops._mylib.foo.default), _EffectType.ORDERED
|
||||
)
|
||||
|
||||
_register_effectful_op(torch.ops._mylib.foo.default, _EffectType.ORDERED)
|
||||
try:
|
||||
|
||||
def fn(x, y):
|
||||
@ -779,17 +760,13 @@ def forward(self, tangents_1, tangents_2, tangents_token):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
finally:
|
||||
_deregister_effectful_op(torch.ops._mylib.foo.default)
|
||||
handle.destroy()
|
||||
|
||||
self.assertEqual(_get_effect(torch.ops._mylib.foo.default), None)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
def test_regular_effectful_op_only_in_backward(self):
|
||||
from torch._higher_order_ops.effects import (
|
||||
_deregister_effectful_op,
|
||||
_EffectType,
|
||||
_register_effectful_op,
|
||||
)
|
||||
|
||||
_register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
|
||||
handle = _register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
|
||||
try:
|
||||
|
||||
def fn(x):
|
||||
@ -852,17 +829,11 @@ def forward(self, primals_1, primals_2, tangents_1, tangents_2, tangents_token):
|
||||
return (mul, mul_1, getitem_2)""",
|
||||
)
|
||||
finally:
|
||||
_deregister_effectful_op(torch.ops.aten.cos.default)
|
||||
handle.destroy()
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
def test_regular_effectful_op_in_forward_and_backward(self):
|
||||
from torch._higher_order_ops.effects import (
|
||||
_deregister_effectful_op,
|
||||
_EffectType,
|
||||
_register_effectful_op,
|
||||
)
|
||||
|
||||
_register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
|
||||
handle = _register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
|
||||
try:
|
||||
|
||||
def fn(x):
|
||||
@ -897,7 +868,7 @@ def forward(self, primals_2, getitem_1, tangents_1, tangents_token):
|
||||
return (mul_1, getitem_2)""",
|
||||
)
|
||||
finally:
|
||||
_deregister_effectful_op(torch.ops.aten.cos.default)
|
||||
handle.destroy()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -136,12 +136,59 @@ class TestStandaloneInductor(TestCase):
|
||||
mod_opt = inductor.compile(mod, inp)
|
||||
self.assertEqual(mod(*inp), mod_opt(*inp))
|
||||
|
||||
@mock.patch.dict(os.environ, {"TORCHINDUCTOR_DEBUG_COMPILE": "1"})
|
||||
def test_inductor_generate_debug_compile(self):
|
||||
cpp_code = """
|
||||
int main(){
|
||||
return 0;
|
||||
}
|
||||
"""
|
||||
|
||||
_, source_path = write(
|
||||
cpp_code,
|
||||
"cpp",
|
||||
)
|
||||
build_option = CppOptions()
|
||||
cpp_builder = CppBuilder(
|
||||
name="test_compile",
|
||||
sources=source_path,
|
||||
output_dir=os.path.dirname(source_path),
|
||||
BuildOption=build_option,
|
||||
)
|
||||
cpp_builder.build()
|
||||
binary_path = cpp_builder.get_target_file_path()
|
||||
|
||||
"""
|
||||
When we turn on generate debug compile.
|
||||
On Windows, it should create a [module_name].pdb file. It helps debug by WinDBG.
|
||||
On Linux, it should create some debug sections in binary file.
|
||||
"""
|
||||
|
||||
def check_linux_debug_section(module_path: str):
|
||||
check_cmd = shlex.split(f"readelf -S {module_path}")
|
||||
output = safe_command_output(check_cmd)
|
||||
has_debug_sym = ".debug_info" in output
|
||||
self.assertEqual(has_debug_sym, True)
|
||||
|
||||
def check_windows_pdb_exist(module_path: str):
|
||||
file_name_no_ext = os.path.splitext(module_path)[0]
|
||||
file_name_pdb = f"{file_name_no_ext}.pdb"
|
||||
has_pdb_file = os.path.exists(file_name_pdb)
|
||||
self.assertEqual(has_pdb_file, True)
|
||||
|
||||
if _IS_WINDOWS:
|
||||
check_windows_pdb_exist(binary_path)
|
||||
elif _IS_MACOS:
|
||||
pass # MacOS not sure that if it should be works.
|
||||
else:
|
||||
check_linux_debug_section(binary_path)
|
||||
|
||||
@mock.patch.dict(os.environ, {"TORCHINDUCTOR_DEBUG_SYMBOL": "1"})
|
||||
def test_inductor_generate_debug_symbol(self):
|
||||
cpp_code = """
|
||||
int main(){
|
||||
return 0;
|
||||
}
|
||||
int main(){
|
||||
return 0;
|
||||
}
|
||||
"""
|
||||
|
||||
_, source_path = write(
|
||||
|
||||
@ -683,6 +683,16 @@ class TestNumPyInterop(TestCase):
|
||||
):
|
||||
f(xs)
|
||||
|
||||
def test_copy_mode(self):
|
||||
def f(x):
|
||||
return np.array(x, copy=np._CopyMode.IF_NEEDED)
|
||||
|
||||
opt_f = torch.compile(backend="eager", fullgraph=True)(f)
|
||||
x = np.array([1, 2, 3])
|
||||
# Should run without throwing an exception
|
||||
y = opt_f(x)
|
||||
self.assertEqual(y, f(x))
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestNumPyInterop, globals())
|
||||
|
||||
|
||||
@ -90,7 +90,7 @@ class TestOpaqueObject(TestCase):
|
||||
# This is not accurate since the queue could have tensors that are
|
||||
# not rank 1
|
||||
ctx = torch._custom_op.impl.get_ctx()
|
||||
u0 = ctx.create_unbacked_symint()
|
||||
u0 = ctx.new_dynamic_size()
|
||||
return torch.empty(u0)
|
||||
|
||||
self.lib._register_fake("queue_pop", pop_impl_fake)
|
||||
@ -107,8 +107,7 @@ class TestOpaqueObject(TestCase):
|
||||
@size_impl.register_fake
|
||||
def size_impl_fake(q: torch._C.ScriptObject) -> int:
|
||||
ctx = torch._custom_op.impl.get_ctx()
|
||||
u0 = ctx.create_unbacked_symint()
|
||||
torch._check_is_size(u0)
|
||||
u0 = ctx.new_dynamic_size()
|
||||
return u0
|
||||
|
||||
super().setUp()
|
||||
|
||||
@ -1,12 +1,22 @@
|
||||
# Owner(s): ["module: custom-operators"]
|
||||
|
||||
import random
|
||||
from contextlib import ExitStack
|
||||
|
||||
import torch
|
||||
from torch._dynamo.test_case import run_tests, TestCase
|
||||
from torch._dynamo.testing import AotEagerAndRecordGraphs
|
||||
from torch._functorch.aot_autograd import (
|
||||
aot_compile_joint_with_descriptors,
|
||||
aot_export_joint_with_descriptors,
|
||||
aot_export_module,
|
||||
)
|
||||
from torch._library.effects import EffectType
|
||||
from torch._library.fake_class_registry import FakeScriptObject
|
||||
from torch._library.opaque_object import register_opaque_type
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
@ -41,11 +51,21 @@ class OpaqueQueue:
|
||||
|
||||
class RNGState:
|
||||
def __init__(self, seed):
|
||||
self.rng = random.Random(seed)
|
||||
self.seed = seed
|
||||
self.rng = random.Random(self.seed)
|
||||
|
||||
|
||||
class Counter:
|
||||
def __init__(self, start):
|
||||
self.counter = torch.tensor(start)
|
||||
|
||||
def increment_counter(self):
|
||||
self.counter += 1
|
||||
|
||||
|
||||
register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue")
|
||||
register_opaque_type(RNGState, "_TestOpaqueObject_RNGState")
|
||||
register_opaque_type(Counter, "_TestOpaqueObject_Counter")
|
||||
|
||||
|
||||
class TestOpaqueObject(TestCase):
|
||||
@ -125,6 +145,20 @@ class TestOpaqueObject(TestCase):
|
||||
def noisy_inject_fake(x: torch.Tensor, obj: RNGState) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
@torch.library.custom_op(
|
||||
"_TestOpaqueObject::increment_counter",
|
||||
mutates_args=["prev"],
|
||||
)
|
||||
def increment_counter_impl(c: Counter, prev: torch.Tensor) -> torch.Tensor:
|
||||
assert isinstance(c, Counter)
|
||||
prev.copy_(c.counter)
|
||||
c.increment_counter()
|
||||
return c.counter
|
||||
|
||||
@increment_counter_impl.register_fake
|
||||
def increment_counter_fake(c: Counter, prev: torch.Tensor) -> torch.Tensor:
|
||||
return torch.empty_like(prev)
|
||||
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
@ -233,6 +267,235 @@ def forward(self, arg0_1, arg1_1):
|
||||
):
|
||||
make_fx(f, tracing_mode=make_fx_tracing_mode)(RNGState(0), torch.ones(3))
|
||||
|
||||
def test_aot_export(self):
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, rng_state, x):
|
||||
x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
|
||||
x = x * x
|
||||
x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
|
||||
x = x + x
|
||||
return (x,)
|
||||
|
||||
mod = Model()
|
||||
rng = RNGState(0)
|
||||
x = torch.ones(2, 3)
|
||||
|
||||
fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
|
||||
fake_rng = torch._library.fake_class_registry.maybe_to_fake_obj(fake_mode, rng)
|
||||
fake_x = fake_mode.from_tensor(x)
|
||||
gm = aot_export_module(mod, (fake_rng, fake_x), trace_joint=False)[0]
|
||||
|
||||
# By default we don't register ops containing PyObjs as being effectful
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1):
|
||||
noisy_inject = torch.ops._TestOpaqueObject.noisy_inject.default(arg1_1, arg0_1); arg1_1 = None
|
||||
mul = torch.ops.aten.mul.Tensor(noisy_inject, noisy_inject); noisy_inject = None
|
||||
noisy_inject_1 = torch.ops._TestOpaqueObject.noisy_inject.default(mul, arg0_1); mul = arg0_1 = None
|
||||
add = torch.ops.aten.add.Tensor(noisy_inject_1, noisy_inject_1); noisy_inject_1 = None
|
||||
return (add,)""", # noqa: B950
|
||||
)
|
||||
|
||||
torch.library._register_effectful_op(
|
||||
"_TestOpaqueObject::noisy_inject", EffectType.ORDERED
|
||||
)
|
||||
try:
|
||||
gm = aot_export_module(mod, (rng, fake_x), trace_joint=False)[0]
|
||||
# inputs: token, rng, x
|
||||
# return: token, res
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops._TestOpaqueObject.noisy_inject.default, arg2_1, arg1_1); arg0_1 = arg2_1 = None
|
||||
getitem = with_effects[0]
|
||||
getitem_1 = with_effects[1]; with_effects = None
|
||||
mul = torch.ops.aten.mul.Tensor(getitem_1, getitem_1); getitem_1 = None
|
||||
with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TestOpaqueObject.noisy_inject.default, mul, arg1_1); getitem = mul = arg1_1 = None
|
||||
getitem_2 = with_effects_1[0]
|
||||
getitem_3 = with_effects_1[1]; with_effects_1 = None
|
||||
add = torch.ops.aten.add.Tensor(getitem_3, getitem_3); getitem_3 = None
|
||||
return (getitem_2, add)""", # noqa: B950
|
||||
)
|
||||
finally:
|
||||
torch.library._register_effectful_op(
|
||||
"_TestOpaqueObject::noisy_inject", None
|
||||
)
|
||||
|
||||
def test_compile(self):
|
||||
def foo(rng_state, x):
|
||||
x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
|
||||
x = x * x
|
||||
x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
|
||||
x = x + x
|
||||
return x
|
||||
|
||||
rng = RNGState(0)
|
||||
x = torch.ones(2, 3)
|
||||
|
||||
res = torch.compile(foo, fullgraph=True, backend="inductor")(rng, x)
|
||||
self.assertFalse(torch.allclose(res, x * x + x))
|
||||
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
torch.compile(foo, fullgraph=True, backend=backend)(rng, x)
|
||||
self.assertExpectedInline(
|
||||
backend.graphs[0].code.strip(),
|
||||
"""\
|
||||
def forward(self, L_x_ : torch.Tensor, L_rng_state_ : __main___RNGState):
|
||||
l_x_ = L_x_
|
||||
l_rng_state_ = L_rng_state_
|
||||
x = torch.ops._TestOpaqueObject.noisy_inject(l_x_, l_rng_state_); l_x_ = None
|
||||
x_1 = x * x; x = None
|
||||
x_2 = torch.ops._TestOpaqueObject.noisy_inject(x_1, l_rng_state_); x_1 = l_rng_state_ = None
|
||||
x_3 = x_2 + x_2; x_2 = None
|
||||
return (x_3,)""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
backend.fw_graphs[0].code.strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1):
|
||||
noisy_inject = torch.ops._TestOpaqueObject.noisy_inject.default(arg0_1, arg1_1); arg0_1 = None
|
||||
mul = torch.ops.aten.mul.Tensor(noisy_inject, noisy_inject); noisy_inject = None
|
||||
noisy_inject_1 = torch.ops._TestOpaqueObject.noisy_inject.default(mul, arg1_1); mul = arg1_1 = None
|
||||
add = torch.ops.aten.add.Tensor(noisy_inject_1, noisy_inject_1); noisy_inject_1 = None
|
||||
return (add,)""", # noqa: B950
|
||||
)
|
||||
|
||||
def test_compile_intermediate(self):
|
||||
counter = Counter(0)
|
||||
|
||||
def foo(x, y):
|
||||
z = torch.ops._TestOpaqueObject.increment_counter(counter, y)
|
||||
x = x * z
|
||||
z = torch.ops._TestOpaqueObject.increment_counter(counter, y)
|
||||
x = x + z
|
||||
return x, counter
|
||||
|
||||
inp = (torch.tensor(1), torch.tensor(0))
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
opt_f = torch.compile(foo, fullgraph=True, backend=backend)
|
||||
res = opt_f(*inp)
|
||||
self.assertEqual(res[0], torch.tensor(3))
|
||||
self.assertEqual(res[1].counter, torch.tensor(2))
|
||||
|
||||
res = opt_f(*inp)
|
||||
self.assertEqual(res[0], torch.tensor(7))
|
||||
self.assertEqual(res[1].counter, torch.tensor(4))
|
||||
|
||||
# counter is automatically lifted as an input
|
||||
# Even though we returned counter in the eager code, it does not get
|
||||
# returned in the graph because dynamo does not detect that the object
|
||||
# is mutated.
|
||||
self.assertExpectedInline(
|
||||
backend.fw_graphs[0].code.strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops._TestOpaqueObject.increment_counter.default, c = arg1_1, _prev_base_index = 0, _all_bases = [arg0_1])
|
||||
getitem = auto_functionalized_v2[0]
|
||||
getitem_1 = auto_functionalized_v2[1]; auto_functionalized_v2 = None
|
||||
mul = torch.ops.aten.mul.Tensor(arg2_1, getitem); arg2_1 = getitem = None
|
||||
auto_functionalized_v2_1 = torch.ops.higher_order.auto_functionalized_v2(torch.ops._TestOpaqueObject.increment_counter.default, c = arg1_1, _prev_base_index = 0, _all_bases = [getitem_1]); arg1_1 = getitem_1 = None
|
||||
getitem_2 = auto_functionalized_v2_1[0]
|
||||
getitem_3 = auto_functionalized_v2_1[1]; auto_functionalized_v2_1 = None
|
||||
add = torch.ops.aten.add.Tensor(mul, getitem_2); mul = getitem_2 = None
|
||||
copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_3); arg0_1 = getitem_3 = copy_ = None
|
||||
return (add,)""", # noqa: B950
|
||||
)
|
||||
|
||||
def test_compile_attribute(self):
|
||||
counter = Counter(0)
|
||||
|
||||
def foo(counter, x):
|
||||
x = x * x
|
||||
counter.increment_counter()
|
||||
return x
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Attempted to access attributes/methods on an OpaqueObject"
|
||||
):
|
||||
torch.compile(foo)(counter, torch.ones(2, 3))
|
||||
|
||||
def bar(counter, x):
|
||||
x = x * x
|
||||
x += counter.counter
|
||||
return x
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Attempted to access attributes/methods on an OpaqueObject"
|
||||
):
|
||||
torch.compile(bar)(counter, torch.ones(2, 3))
|
||||
|
||||
def test_export_joint(self):
|
||||
class Moo(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
return x * y
|
||||
|
||||
register_opaque_type(Moo, "_TestOpaqueObject_Moo")
|
||||
|
||||
torch.library.define(
|
||||
"_TestOpaqueObject::module_mul",
|
||||
"(_TestOpaqueObject_Moo a, Tensor b, SymInt c) -> Tensor",
|
||||
tags=torch.Tag.pt2_compliant_tag,
|
||||
lib=self.lib,
|
||||
)
|
||||
|
||||
@torch.library.impl(
|
||||
"_TestOpaqueObject::module_mul", "CompositeExplicitAutograd", lib=self.lib
|
||||
)
|
||||
def module_mul_impl(m: Moo, a: torch.Tensor, b: int) -> torch.Tensor:
|
||||
assert isinstance(m, Moo)
|
||||
return m(a, b)
|
||||
|
||||
@torch.library.register_fake("_TestOpaqueObject::module_mul", lib=self.lib)
|
||||
def module_mul_fake(m: Moo, a: torch.Tensor, b: int) -> torch.Tensor:
|
||||
return torch.empty_like(a)
|
||||
|
||||
def module_mul_setup_context(ctx, inputs, output):
|
||||
m, a, b = inputs
|
||||
ctx.b = b
|
||||
|
||||
def module_mul_backward(ctx, grad) -> torch.Tensor:
|
||||
return None, grad * ctx.b, None
|
||||
|
||||
torch.library.register_autograd(
|
||||
"_TestOpaqueObject::module_mul",
|
||||
module_mul_backward,
|
||||
setup_context=module_mul_setup_context,
|
||||
lib=self.lib,
|
||||
)
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.moo = Moo()
|
||||
|
||||
def forward(self, x, y):
|
||||
b = y.item()
|
||||
return torch.ops._TestOpaqueObject.module_mul(self.moo, x, b)
|
||||
|
||||
inp = (torch.randn(3, requires_grad=True), torch.tensor(4))
|
||||
with ExitStack() as stack:
|
||||
with FakeTensorMode(shape_env=ShapeEnv()):
|
||||
joint = aot_export_joint_with_descriptors(stack, M(), inp)
|
||||
self.assertExpectedInline(
|
||||
joint.graph_module.code.strip(),
|
||||
"""\
|
||||
def forward(self, primals, tangents):
|
||||
primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
|
||||
_local_scalar_dense = torch.ops.aten._local_scalar_dense.default(primals_2); primals_2 = None
|
||||
_opaque_obj0 = self._opaque_obj0
|
||||
module_mul = torch.ops._TestOpaqueObject.module_mul.default(_opaque_obj0, primals_1, _local_scalar_dense); _opaque_obj0 = primals_1 = None
|
||||
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, _local_scalar_dense); tangents_1 = _local_scalar_dense = None
|
||||
return pytree.tree_unflatten([module_mul, mul_1, None], self._out_spec)""", # noqa: B950
|
||||
)
|
||||
compiled_fn = aot_compile_joint_with_descriptors(joint)
|
||||
|
||||
self.assertEqual(compiled_fn(*inp), M()(*inp))
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestOpaqueObject)
|
||||
|
||||
|
||||
@ -796,6 +796,27 @@ def forward(self, x_1):
|
||||
|
||||
self._test(f, [torch.randn(1, 10), torch.zeros(1, dtype=torch.long)])
|
||||
|
||||
@unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
|
||||
def test_T244632748(self):
|
||||
class TestModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x + (x.shape[0] * 2)
|
||||
|
||||
mod = TestModule()
|
||||
sample = torch.randn((5, 5)).to("cuda")
|
||||
dim0 = torch.export.Dim.DYNAMIC(max=100)
|
||||
dynamic_shapes = {"x": (dim0, torch.export.Dim.STATIC)}
|
||||
ep = torch.export.export(mod, (sample,), dynamic_shapes=dynamic_shapes)
|
||||
gm = ep.module()
|
||||
symint = list(gm.graph.nodes)[3].meta["val"]
|
||||
list(gm.graph.nodes)[3].replace_all_uses_with(symint)
|
||||
gm.graph.eliminate_dead_code()
|
||||
|
||||
inductor_fx = torch._inductor.aot_compile(
|
||||
gm, (sample,), options={"fx_wrapper": True, "compile_threads": 1}
|
||||
)
|
||||
|
||||
|
||||
class TestGenericProxyTensorReal(TestGenericProxyTensor):
|
||||
tracing_mode = "real"
|
||||
|
||||
|
||||
@ -310,7 +310,7 @@ class TestHistogram(TestCase):
|
||||
)
|
||||
|
||||
# these should not crash
|
||||
np.histogram([np.array(0.5) for i in range(10)] + [0.500000000000001])
|
||||
np.histogram([np.array(0.5) for i in range(10)] + [0.500000000000002])
|
||||
np.histogram([np.array(0.5) for i in range(10)] + [0.5])
|
||||
|
||||
@xpassIfTorchDynamo_np # (reason="bins='auto'")
|
||||
|
||||
@ -3657,5 +3657,15 @@
|
||||
"Explanation": "Encountered triton kernel unsupported feature: {msg}",
|
||||
"Hints": []
|
||||
}
|
||||
],
|
||||
"GB0362": [
|
||||
{
|
||||
"Gb_type": "Attempted to access attributes/methods on an OpaqueObject",
|
||||
"Context": "value={self.value}, attr={name}",
|
||||
"Explanation": "Attribute/method access of OpaqueObjects is not supported.",
|
||||
"Hints": [
|
||||
"Use custom operators instead of direct attribute/method access."
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -56,6 +56,7 @@ from torch._guards import (
tracing,
TracingContext,
)
from torch._library.opaque_object import is_opaque_type
from torch._subclasses.fake_tensor import FakeTensor
from torch._utils_internal import signpost_event
from torch.export.dynamic_shapes import _ConstraintTarget
@@ -2605,6 +2606,8 @@ class OutputGraph(OutputGraphCommon):
fake_attr_val,
)
continue
if is_opaque_type(type(node.meta["grapharg"].example)):
continue
fake = (
arg.fake_tensor if arg.fake_tensor is not None else arg.example
)

@@ -58,6 +58,7 @@ from torch._dynamo.utils import (
from torch._guards import TracingContext
from torch._higher_order_ops.flat_apply import flat_apply
from torch._higher_order_ops.torchbind import call_torchbind
from torch._library.opaque_object import is_opaque_type
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
from torch._subclasses.meta_utils import is_sparse_any, safe_grad
@@ -1452,27 +1453,32 @@ class VariableBuilder:
source=self.source,
)

# This exists to allow a smoother transition.
# The implications are:
# The script objects won't be tracked as proxies.
# Methods on these objects won't show up in the graph.
# The original script object might be mutated.
if not hasattr(value, "__obj_flatten__"):
return self.wrap_user_defined(value)
if is_opaque_type(type(value)):
self.install_guards(GuardBuilder.TYPE_MATCH)

# Install the guards on the fully qualified name of the script object
LazyVariableTracker.realize_all(
VariableBuilder(self.tx, ScriptObjectQualifiedNameSource(self.source))(
value._type().qualified_name() # type: ignore[attr-defined]
elif not hasattr(value, "__obj_flatten__"):
# This exists to allow a smoother transition.
# The implications are:
# The script objects won't be tracked as proxies.
# Methods on these objects won't show up in the graph.
# The original script object might be mutated.
return self.wrap_user_defined(value)
else:
# Install the guards on the fully qualified name of the script object
LazyVariableTracker.realize_all(
VariableBuilder(
self.tx, ScriptObjectQualifiedNameSource(self.source)
)(
value._type().qualified_name() # type: ignore[attr-defined]
)
)
)
# Install the guards on the content of the script object by setting the source
# to be FlattenScriptObjectSource, which calls __obj_flatten__() to get the contents.
LazyVariableTracker.realize_all(
VariableBuilder(self.tx, FlattenScriptObjectSource(self.source))(
value.__obj_flatten__()
# Install the guards on the content of the script object by setting the source
# to be FlattenScriptObjectSource, which calls __obj_flatten__() to get the contents.
LazyVariableTracker.realize_all(
VariableBuilder(self.tx, FlattenScriptObjectSource(self.source))(
value.__obj_flatten__()
)
)
)

fake_script_obj = torch._library.fake_class_registry.maybe_to_fake_obj(
self.tx.output.fake_mode, value
@@ -18,6 +18,7 @@ Key classes include:
"""

import dataclasses
import enum
import functools
import inspect
import itertools
@@ -1604,11 +1605,16 @@ class NumpyVariable(VariableTracker):
return self.value

def as_proxy(self):
if config.trace_numpy and isinstance(self.value, type):
# This handles numpy dtype attributes such as np.float32
# We return a string as we don't want to serialize non-PyTorch objects in the output FX graph
# In torch/_numpy we normalize strings to their dtypes when the input is a dtype, as NumPy does
return self.value.__name__
if config.trace_numpy:
# Can replace with EnumType once we drop 3.10 support
if isinstance(self.value, enum.EnumMeta):
# This is mostly for np._CopyMode
return self.value
if isinstance(self.value, type):
# This handles numpy dtype attributes such as np.float32
# We return a string as we don't want to serialize non-PyTorch objects in the output FX graph
# In torch/_numpy we normalize strings to their dtypes when the input is a dtype, as NumPy does
return self.value.__name__

return super().as_proxy()
@@ -25,6 +25,7 @@ from typing_extensions import ParamSpec

import torch
from torch._guards import Source
from torch._library.opaque_object import is_opaque_type, OpaqueTypeStr
from torch.fx.proxy import Proxy

from .. import graph_break_hints
@@ -61,7 +62,7 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable):

@classmethod
def is_matching_cls(cls, user_cls: type) -> bool:
return issubclass(user_cls, torch.ScriptObject)
return issubclass(user_cls, torch.ScriptObject) or is_opaque_type(user_cls)

@staticmethod
def create(proxy: Proxy, value: Any, **options: Any) -> "TorchScriptObjectVariable":
@@ -80,6 +81,16 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable):
"Dynamo cannot safely trace script object due to graph break."
)
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
if getattr(self.value, "script_class_name", "") == OpaqueTypeStr:
unimplemented(
gb_type="Attempted to access attributes/methods on an OpaqueObject",
context=f"value={self.value}, attr={name}",
explanation="Attribute/method access of OpaqueObjects is not supported.",
hints=[
"Use custom operators instead of direct attribute/method access.",
],
)

from torch._higher_order_ops.torchbind import call_torchbind

from ..source import AttrSource

@@ -24,6 +24,7 @@ from torch._export.passes.lift_constants_pass import ConstantAttrMap
from torch._export.utils import _fakify_params_buffers
from torch._guards import Source
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.opaque_object import is_opaque_type
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.export import Constraint
from torch.export.dynamic_shapes import (
@@ -946,7 +947,9 @@ def _fakify_script_objects(

try:
for obj, fqns in constant_attrs.items():
if torch._library.fake_class_registry._is_script_object(obj):
if torch._library.fake_class_registry._is_script_object(
obj
) or is_opaque_type(obj):
fake_script_obj = _maybe_fakify_obj(obj)
for fqn in fqns:
cur_mod, attr = _leaf_mod_and_attr(mod, fqn)

@@ -8,6 +8,7 @@ from typing import Any, Optional
import torch
import torch.utils._pytree as pytree
from torch._guards import detect_fake_mode
from torch._library.opaque_object import is_opaque_type
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx.experimental.proxy_tensor import _pytree_subclasses_that_lose_info
from torch.fx.experimental.symbolic_shapes import ShapeEnv
@@ -46,7 +47,7 @@ def process_inputs(
hint=x,
source=source,
)
if isinstance(x, torch.ScriptObject):
if isinstance(x, torch.ScriptObject) or is_opaque_type(type(x)):
return torch._library.fake_class_registry.maybe_to_fake_obj(
fake_mode, x
)

@@ -534,6 +534,7 @@ def create_aot_state(
stack.enter_context(autograd_fallback_mode("error"))

from torch._library.fake_class_registry import FakeScriptObject, maybe_to_fake_obj
from torch._library.opaque_object import is_opaque_type

# Tracing may mutate the states the fake script object,
# so we need to duplicate the fake script objects so that subsequent tracing
@@ -541,7 +542,7 @@ def create_aot_state(
def _dup_fake_script_obj(fake_flat_args):
return [
maybe_to_fake_obj(detect_fake_mode(fake_flat_args), arg.real_obj)
if isinstance(arg, FakeScriptObject)
if isinstance(arg, FakeScriptObject) or is_opaque_type(type(arg))
else arg
for arg in fake_flat_args
]

@@ -1,13 +1,13 @@
# mypy: allow-untyped-defs
from enum import Enum
from typing import Any, Optional, Union
from weakref import WeakKeyDictionary

import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.torchbind import call_torchbind
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.custom_ops import CustomOpDef
from torch._library.effects import EffectType
from torch._library.utils import RegistrationHandle
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
@@ -17,39 +17,50 @@ from torch.fx.experimental.proxy_tensor import (
)

class _EffectType(Enum):
ORDERED = "Ordered"
_op_identifier = Union[
str,
"torch._ops.OpOverload",
"torch._library.custom_ops.CustomOpDef",
"torch._ops.HigherOrderOperator",
]
OpType = Union["torch._ops.HigherOrderOperator", "torch._ops.OpOverload"]

_EffectType = EffectType

OpType = Union[torch._ops.HigherOrderOperator, torch._ops.OpOverload]
def _get_op_qualname(op: _op_identifier) -> str:
"""Convert an op identifier to a qualified string key."""
if isinstance(op, torch._ops.OpOverload):
return op._name
elif isinstance(op, torch._ops.HigherOrderOperator):
return f"{op.namespace}::{op.name()}"
elif isinstance(op, CustomOpDef):
return op._qualname
elif isinstance(op, str):
return op

raise ValueError(f"Invalid operator input {op}")

SIDE_EFFECTS = WeakKeyDictionary[OpType, _EffectType](
[
(torch.ops.aten._print.default, _EffectType.ORDERED),
(torch.ops.aten._async_error.default, _EffectType.ORDERED),
(call_torchbind, _EffectType.ORDERED),
]
)
def _register_effectful_op(
op: _op_identifier, effect: Optional[EffectType]
) -> RegistrationHandle:
qualname = _get_op_qualname(op)
entry = torch._library.simple_registry.singleton.find(qualname)
handle = entry.effect.register(effect)
return handle

def _register_effectful_op(op: OpType, effect: _EffectType):
assert isinstance(
op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
) and not has_aliasing(op)
if op in SIDE_EFFECTS and SIDE_EFFECTS[op] != effect:
raise RuntimeError(
f"Already registered effect type {SIDE_EFFECTS[op]} to op {op}, "
f"trying to register a different effect type {effect}."
)
SIDE_EFFECTS[op] = effect
def _get_effect(op: _op_identifier) -> Optional[_EffectType]:
qualname = _get_op_qualname(op)
entry = torch._library.simple_registry.singleton.find(qualname)
return entry.effect.effect

def _deregister_effectful_op(op: OpType):
if op not in SIDE_EFFECTS:
raise RuntimeError(f"Op {op} is not registered as effectful")

del SIDE_EFFECTS[op]
_register_effectful_op("aten::_print", _EffectType.ORDERED)
_register_effectful_op("aten::_async_error", _EffectType.ORDERED)
_register_effectful_op("profiler::_record_function_exit._RecordFunction", None)
_register_effectful_op(call_torchbind, _EffectType.ORDERED)

class WithEffects(HigherOrderOperator):
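The rewrite above retires the `SIDE_EFFECTS` WeakKeyDictionary in favor of per-op `EffectHolder` entries keyed by qualified name, so strings, `OpOverload`s, `CustomOpDef`s, and higher-order ops all resolve to the same registry slot. A rough illustration of the lookup path, assuming the private helpers behave as shown in this diff:

```python
import torch
from torch._higher_order_ops.effects import _EffectType, _get_effect, _get_op_qualname
from torch._higher_order_ops.torchbind import call_torchbind

# Every supported identifier collapses to one string key.
print(_get_op_qualname("aten::_print"))                 # aten::_print
print(_get_op_qualname(torch.ops.aten._print.default))  # aten::_print
print(_get_op_qualname(call_torchbind))                 # higher_order::call_torchbind

# Effects are now looked up per qualified name rather than per object identity.
print(_get_effect(torch.ops.aten._print.default))  # _EffectType.ORDERED (registered above)
print(_get_effect(torch.ops.aten.add.Tensor))      # None: no effect registered
```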
@@ -78,7 +89,7 @@ class WithEffects(HigherOrderOperator):
) -> tuple[Any, ...]:
assert isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload))
assert not has_aliasing(op), "Ops with aliasing is not supported"
assert has_effects(op, args, kwargs)
assert has_effects(op)
assert isinstance(kwargs, dict)
return super().__call__(token, op, *args, **kwargs)

@@ -89,7 +100,7 @@ with_effects = WithEffects()
def has_aliasing(op: OpType):
# NOT FOR PUBLIC USE
if isinstance(op, torch._ops.HigherOrderOperator):
return op not in SIDE_EFFECTS
return not _get_effect(op)

for arg in op._schema.arguments:
if arg.alias_info is not None:
@@ -100,7 +111,7 @@ def has_aliasing(op: OpType):
return False

def has_effects(op, args, kwargs) -> bool:
def has_effects(op) -> bool:
# Skip over the profiler's RecordFunction as they should not show up in the graph
_skip_ops = {torch.ops.profiler._record_function_exit._RecordFunction}
if op in _skip_ops:
@@ -109,31 +120,10 @@ def has_effects(op, args, kwargs) -> bool:
return (
isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload))
and not has_aliasing(op)
and get_effect_key(op, args, kwargs) is not None
and _get_effect(op) is not None
)

def get_effect_key(op, args, kwargs) -> Optional[_EffectType]:
if op in SIDE_EFFECTS:
return SIDE_EFFECTS[op]

for arg in args:
if isinstance(arg, (torch.ScriptObject, FakeScriptObject)):
# Add it to the table so that next time we see the same op we don't
# have to parse through the args again
SIDE_EFFECTS[op] = _EffectType.ORDERED
return _EffectType.ORDERED

for arg in kwargs.values():
if isinstance(arg, (torch.ScriptObject, FakeScriptObject)):
# Add it to the table so that next time we see the same op we don't
# have to parse through the args again
SIDE_EFFECTS[op] = _EffectType.ORDERED
return _EffectType.ORDERED

return None

def new_token_tensor() -> torch.Tensor:
return torch.tensor([])

@@ -238,7 +228,7 @@ def handle_effects(
# Get a token. We can't do `tokens.get(op, torch.tensor([]))` because
# this will create an empty tensor during proxy mode tracing if the token
# doesn't exist. But the tokens should always exist during proxy mode tracing.
key = get_effect_key(op, args, kwargs)
key = _get_effect(op)
assert key is not None
if key not in tokens:
assert allow_token_discovery, (

@@ -2122,6 +2122,10 @@ class PythonWrapperCodegen(CodeGen):
output.writeline(f"{name} = {val}")

def add_torchbind_input(name, value):
if value is None:
output.writeline(f"{name} = None")
return

import pickle

assert isinstance(value, torch.ScriptObject)

@@ -91,6 +91,7 @@ from torch._inductor.utils import (
tensor_is_aligned,
)
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.opaque_object import is_opaque_type
from torch._logging import trace_structured
from torch._utils_internal import compile_time_strobelight_meta
from torch.fx import GraphModule
@@ -2747,7 +2748,9 @@ def _compile_fx_main(
node.meta["val"] = fake_mode.from_tensor(
target, static_shapes=True
)
elif isinstance(target, torch.ScriptObject):
elif isinstance(target, torch.ScriptObject) or is_opaque_type(
type(target)
):
node.meta["val"] = (
torch._library.fake_class_registry.maybe_to_fake_obj(
fake_mode, target
@@ -883,11 +883,12 @@ def _get_optimization_cflags(

should_use_optimized_flags = not (
config.aot_inductor.debug_compile
or os.environ.get("TORCHINDUCTOR_DEBUG_SYMBOL", "0") == "1"
or os.environ.get("TORCHINDUCTOR_DEBUG_COMPILE", "0") == "1"
)
should_add_debug_symbol_flags = (
config.aot_inductor.debug_compile
or config.aot_inductor.debug_symbols
or os.environ.get("TORCHINDUCTOR_DEBUG_COMPILE", "0") == "1"
or os.environ.get("TORCHINDUCTOR_DEBUG_SYMBOL", "0") == "1"
)
if should_use_optimized_flags:
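With this change either spelling of the environment toggle works: `TORCHINDUCTOR_DEBUG_COMPILE` now also disables the optimized flags, and either variable (or the config knobs) pulls in debug symbols. A small sketch of switching on a debuggable Inductor C++ build, assuming the knobs behave as described in the hunk above:

```python
import os

# Either variable now implies an unoptimized build with debug symbols for
# Inductor's C++ outputs (e.g. AOTInductor shared libraries).
os.environ["TORCHINDUCTOR_DEBUG_COMPILE"] = "1"
# os.environ["TORCHINDUCTOR_DEBUG_SYMBOL"] = "1"  # equivalent for symbols

# The same switches exist as config flags consulted by _get_optimization_cflags.
from torch._inductor import config
config.aot_inductor.debug_compile = True  # unoptimized flags + debug symbols
config.aot_inductor.debug_symbols = True  # debug symbols only
```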
@@ -9242,12 +9242,9 @@ class EffectfulKernel(FallbackKernel):
unbacked_bindings=unbacked_bindings,
)

from torch._higher_order_ops.effects import get_effect_key
from torch._higher_order_ops.effects import _get_effect

uncovered_args = [
a.value if isinstance(a, TorchBindObject) else a for a in tensor_args
]
effect_type = get_effect_key(kernel, (*nontensor_args, *uncovered_args), kwargs)
effect_type = _get_effect(kernel)
assert effect_type is not None
self.effect_type = effect_type
self.prev_effect_buffer = V.graph.effectful_ops.get(effect_type, None)
@@ -9298,6 +9295,10 @@ class TorchBindObject(NonTensorObj):
def get_buf_bytes(self) -> int:
# Returns the sum of all tensors in the flattened object
real_script_obj = self.get_real_obj()

if real_script_obj is None:
return 0

assert hasattr(real_script_obj, "__obj_flatten__")
flat_dict = dict(real_script_obj.__obj_flatten__())
flat_elems = pytree.tree_flatten(flat_dict)[0]

@@ -26,6 +26,7 @@ import torch.utils._pytree as pytree
from torch._dynamo.utils import counters
from torch._higher_order_ops.associative_scan import associative_scan_op
from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.utils import get_layout_constraint_tag
from torch._prims_common import ( # pyrefly: ignore # deprecated; pyrefly: ignore [deprecated]
canonicalize_dim,
@@ -2704,6 +2705,8 @@ def require_channels_last(_, *args, **kwargs):

def constrain_to_fake_tensor(arg, fake_arg):
if isinstance(fake_arg, FakeScriptObject):
return arg
if isinstance(arg, ir.IRNode):
meta_stride_expr = [
s.node.expr if isinstance(s, torch.SymInt) else s for s in fake_arg.stride()
@@ -7453,9 +7456,9 @@ def _sink_tokens(tokens):
def with_effects(token, op, *args, **kwargs):
result = ir.EffectfulKernel.create(op, *args, **kwargs)

from torch._higher_order_ops.effects import get_effect_key
from torch._higher_order_ops.effects import _get_effect

effect_type = get_effect_key(op, args, kwargs)
effect_type = _get_effect(op)
assert effect_type is not None
effectful_kernel = V.graph.effectful_ops[effect_type]
@@ -2740,163 +2740,10 @@ class AlgorithmSelectorCache(PersistentCache):

inputs_key = create_inputs_key(input_nodes)

# TODO(nmacchioni): remove this hacky way to tell if we ran benchmarking
has_autotuned = False

def benchmark(choices, hint_override: Optional[int] = None):
nonlocal has_autotuned
# TODO(nmacchioni): remove this hacky way to tell if we ran benchmarking
has_autotuned = True
counters["inductor"]["select_algorithm_autotune"] += 1
# TODO(nmacchioni): remove this layer of abstraction
# construct `benchmark_fn` which should pick between in-process and sub-process autotuning
benchmark_fn = self.make_benchmark_fn(
choices, input_nodes, layout, input_gen_fns, hint_override=hint_override
)
# `benchmark_fn(choices)` will execute each choice, and return a dict[choice, timing] which
# maps each choice to its runtime, calculated by the specified benchmarker, in milliseconds
return benchmark_fn(choices)

def autotune(choices, hint_override: Optional[int] = None):
log.debug("Starting autotuning")

with dynamo_timed(
f"{name}_template_autotuning",
log_pt2_compile_event=True,
dynamo_compile_column_us="compile_time_autotune_time_us",
metadata=_autotune_metadata(input_nodes),
):
benchmark_results = benchmark(choices, hint_override=hint_override)
if config.max_autotune_report_choices_stats:
_log_autotune_choices_stats(
f"{name}_template_autotuning", benchmark_results
)
return benchmark_results

if config.autotune_in_subproc:
# Initialize the suprocess pool so it will warmup early.
torch._inductor.autotune_process.get_tuning_process_pool()

def do_autotuning(choices, precompile_fn, hint_override: Optional[int] = None):
precompile_start_ts = time.time()
with dynamo_timed(
f"{name}_template_precompiling",
log_pt2_compile_event=True,
dynamo_compile_column_us="compile_time_autotune_time_us",
):
precompile_fn()
precompile_elapse = time.time() - precompile_start_ts
log.debug("Precompilation elapsed time: %.02fs", precompile_elapse)
# Prune anything that failed to compile
choices = [c for c in choices if not c.failed]
if len(choices) == 0:
raise self.create_no_valid_choices(
name, "All choices failed to compile for backend."
)

candidates = self.prescreen_choices(
choices, name, inputs_key, self.prescreening_cache
)
prescreening_elapse: Optional[float] = None
if candidates:
prescreening_start_ts = time.time()
timings = self.lookup(
candidates,
name,
inputs_key,
lambda choices: autotune(choices, hint_override=hint_override),
hint_override=hint_override,
)
choices = self.prune_choices_postscreen(
choices, timings, name, inputs_key, self.prescreening_cache
)
prescreening_elapse = time.time() - prescreening_start_ts
log.debug("Prescreening elapsed time: %.02fs", prescreening_elapse)

autotune_start_ts = time.time()

if best_config_future is not None:
best_config = await_sync(best_config_future)

important_keys = [
"ACC_TYPE",
"ALLOW_TF32",
"BLOCK_K",
"BLOCK_M",
"BLOCK_N",
"EVEN_K",
"GROUP_M",
"USE_FAST_ACCUM",
"num_stages",
"num_warps",
"num_consumer_groups",
"num_buffers_warp_spec",
]
choices = [
choice
for choice in choices
if all(
f"{k}={best_config[k]}" in choice.description
for k in important_keys
)
for k in important_keys
]
log.info("Filtered to %d choices based on best_config", len(choices))

timings = self.lookup(
choices,
name,
inputs_key,
lambda choices: autotune(choices, hint_override=hint_override),
hint_override=hint_override,
)

autotune_elapse = time.time() - autotune_start_ts
log.debug("Autotuning elapsed time: %.02fs", autotune_elapse)

if timings and all(
not math.isfinite(timing) for timing in timings.values()
):
raise NoValidChoicesError

if (
has_autotuned
or log.getEffectiveLevel() == logging.DEBUG
or config.trace.log_autotuning_results
):
self.log_results(
name,
input_nodes,
timings,
autotune_elapse,
precompile_elapse,
prescreening_elapse,
hint_override=hint_override,
)

def profiler_bench_function():
# we're not running through the normal caching autotuner method here because we want to avoid returning
# the cached value.
# Avoid benchmarking in a separate process because it's not easy to signal to the TuningProcess that we
# should use the profiler.
with config.patch(
profile_bandwidth_with_do_bench_using_profiling=True,
autotune_in_subproc=False,
):
return benchmark(choices)

for feedback_fn in self.feedback_saver_fns:
# re-benchmarking the same choices with profiler is a bit expensive, so pass it in as a thunk.
feedback_fn(
timings,
name,
input_nodes,
choices,
profiler_bench_function,
)

return timings

precompile_fn = self.make_precompile_fn(
choices,
name,
@@ -2913,8 +2760,16 @@ class AlgorithmSelectorCache(PersistentCache):
if not hasattr(c, "hint_override")
or c.hint_override == hint_override
]
timings = do_autotuning(
filtered_choices, precompile_fn, hint_override=hint_override
timings = self.do_autotuning(
name,
input_nodes,
layout,
input_gen_fns,
inputs_key,
filtered_choices,
precompile_fn,
hint_override=hint_override,
best_config_future=best_config_future,
)
min_extern_choice = float("inf")
for choice, timing in timings.items():
@@ -2950,7 +2805,16 @@ class AlgorithmSelectorCache(PersistentCache):
)
)

timings = do_autotuning(choices, precompile_fn)
timings = self.do_autotuning(
name,
input_nodes,
layout,
input_gen_fns,
inputs_key,
choices,
precompile_fn,
best_config_future=best_config_future,
)
# if timings is empty, we really have no choice but to return a semi-random
# choice. returning the first `ExternKernelCaller` is probably the safest bet
# in this case, since it will generally be the ATen kernel. if there are no
@@ -2986,6 +2850,229 @@ class AlgorithmSelectorCache(PersistentCache):
return node, choice
return node
def benchmark(
self,
choices,
input_nodes,
layout,
input_gen_fns,
hint_override: Optional[int] = None,
):
counters["inductor"]["select_algorithm_autotune"] += 1
# TODO(nmacchioni): remove this layer of abstraction
# construct `benchmark_fn` which should pick between in-process and sub-process autotuning
benchmark_fn = self.make_benchmark_fn(
choices, input_nodes, layout, input_gen_fns, hint_override=hint_override
)
# `benchmark_fn(choices)` will execute each choice, and return a dict[choice, timing] which
# maps each choice to its runtime, calculated by the specified benchmarker, in milliseconds
return benchmark_fn(choices)

def autotune(
self,
name,
input_nodes,
layout,
input_gen_fns,
choices,
hint_override: Optional[int] = None,
):
log.debug("Starting autotuning")

with dynamo_timed(
f"{name}_template_autotuning",
log_pt2_compile_event=True,
dynamo_compile_column_us="compile_time_autotune_time_us",
metadata=_autotune_metadata(input_nodes),
):
benchmark_results = self.benchmark(
choices, input_nodes, layout, input_gen_fns, hint_override=hint_override
)
if config.max_autotune_report_choices_stats:
_log_autotune_choices_stats(
f"{name}_template_autotuning", benchmark_results
)
return benchmark_results

def do_autotuning(
self,
name,
input_nodes,
layout,
input_gen_fns,
inputs_key,
choices,
precompile_fn,
hint_override: Optional[int] = None,
best_config_future=None,
):
"""Execute the autotuning process for kernel algorithm selection.

This method orchestrates the complete autotuning pipeline including precompilation,
prescreening, benchmarking, and feedback collection to select the optimal kernel
implementation for given inputs.

Args:
name: Name identifier for the operation being autotuned (e.g., 'mm', 'convolution').
input_nodes: List of input IR nodes used for benchmarking.
layout: Layout information specifying device and memory format for the operation.
input_gen_fns: Optional dict mapping argument indices to functions that generate
torch.Tensor inputs from ir.Buffer for benchmarking. If provided, these are
used instead of random tensors.
inputs_key: Cache key representing the input characteristics (sizes, strides, dtypes).
choices: List of ChoiceCaller objects representing candidate kernel implementations.
precompile_fn: Callable that precompiles all kernel choices before benchmarking.
hint_override: Optional index to override which choice is selected, used for testing
or forced selection.
best_config_future: Optional future containing pre-determined best configuration to
filter choices by specific config parameters.

Returns:
dict: Mapping from ChoiceCaller to benchmark timing in seconds. Choices with
non-finite timings (inf/nan) indicate failures.

Raises:
NoValidChoicesError: When all choices fail to compile or benchmark, or when all
timing results are non-finite.
"""
precompile_start_ts = time.time()
with dynamo_timed(
f"{name}_template_precompiling",
log_pt2_compile_event=True,
dynamo_compile_column_us="compile_time_autotune_time_us",
):
precompile_fn()
precompile_elapse = time.time() - precompile_start_ts
log.debug("Precompilation elapsed time: %.02fs", precompile_elapse)
# Prune anything that failed to compile
choices = [c for c in choices if not c.failed]
if len(choices) == 0:
raise self.create_no_valid_choices(
name, "All choices failed to compile for backend."
)

candidates = self.prescreen_choices(
choices, name, inputs_key, self.prescreening_cache
)
prescreening_elapse: Optional[float] = None
if candidates:
prescreening_start_ts = time.time()
timings = self.lookup(
candidates,
name,
inputs_key,
lambda choices: self.autotune(
name,
input_nodes,
layout,
input_gen_fns,
choices,
hint_override=hint_override,
),
hint_override=hint_override,
)
choices = self.prune_choices_postscreen(
choices, timings, name, inputs_key, self.prescreening_cache
)
prescreening_elapse = time.time() - prescreening_start_ts
log.debug("Prescreening elapsed time: %.02fs", prescreening_elapse)

autotune_start_ts = time.time()

if best_config_future is not None:
best_config = await_sync(best_config_future)

important_keys = [
"ACC_TYPE",
"ALLOW_TF32",
"BLOCK_K",
"BLOCK_M",
"BLOCK_N",
"EVEN_K",
"GROUP_M",
"USE_FAST_ACCUM",
"num_stages",
"num_warps",
"num_consumer_groups",
"num_buffers_warp_spec",
]
choices = [
choice
for choice in choices
if all(
f"{k}={best_config[k]}" in choice.description
for k in important_keys
)
for k in important_keys
]
log.info("Filtered to %d choices based on best_config", len(choices))

has_autotuned: bool = False

def track_has_autotuned(choices):
nonlocal has_autotuned
has_autotuned = True
return self.autotune(
name,
input_nodes,
layout,
input_gen_fns,
choices,
hint_override=hint_override,
)

timings = self.lookup(
choices,
name,
inputs_key,
track_has_autotuned,
hint_override=hint_override,
)

autotune_elapse = time.time() - autotune_start_ts
log.debug("Autotuning elapsed time: %.02fs", autotune_elapse)

if timings and all(not math.isfinite(timing) for timing in timings.values()):
raise NoValidChoicesError

if (
has_autotuned
or log.getEffectiveLevel() == logging.DEBUG
or config.trace.log_autotuning_results
):
self.log_results(
name,
input_nodes,
timings,
autotune_elapse,
precompile_elapse,
prescreening_elapse,
hint_override=hint_override,
)

def profiler_bench_function():
# we're not running through the normal caching autotuner method here because we want to avoid returning
# the cached value.
# Avoid benchmarking in a separate process because it's not easy to signal to the TuningProcess that we
# should use the profiler.
with config.patch(
profile_bandwidth_with_do_bench_using_profiling=True,
autotune_in_subproc=False,
):
return self.benchmark(choices, input_nodes, layout, input_gen_fns)

for feedback_fn in self.feedback_saver_fns:
# re-benchmarking the same choices with profiler is a bit expensive, so pass it in as a thunk.
feedback_fn(
timings,
name,
input_nodes,
choices,
profiler_bench_function,
)

return timings

def create_no_valid_choices(self, name: str, reason: str) -> NoValidChoicesError:
backend_config = (
"max_autotune_gemm_backends"
@@ -13,6 +13,7 @@ from torch.types import _dtype
from torch.utils._exposed_in import exposed_in

from . import autograd, utils
from .effects import EffectType

device_types_t = Optional[Union[str, Sequence[str]]]
@@ -471,6 +472,9 @@ class CustomOpDef:
self._abstract_fn = fn
return fn

def register_effect(self, effect: Optional[EffectType]) -> None:
self._lib._register_effectful_op(self._qualname, effect)

def register_torch_dispatch(
self, torch_dispatch_class: Any, fn: Optional[Callable] = None, /
) -> Callable:

torch/_library/effects.py (new file, 68 lines)
@@ -0,0 +1,68 @@
from enum import Enum
from typing import Optional

import torch

class EffectType(Enum):
ORDERED = "Ordered"

from torch._library.utils import RegistrationHandle

class EffectHolder:
"""A holder where one can register an effect impl to."""

def __init__(self, qualname: str):
self.qualname: str = qualname
self._set_default_effect()

def _set_default_effect(self) -> None:
self._effect: Optional[EffectType] = None

# If the op contains a ScriptObject input, we want to mark it as having effects
namespace, opname = torch._library.utils.parse_namespace(self.qualname)
split = opname.split(".")
if len(split) > 1:
assert len(split) == 2, (
f"Tried to split {opname} based on '.' but found more than 1 '.'"
)
opname, overload = split
else:
overload = ""

if namespace == "higher_order":
return

opname = f"{namespace}::{opname}"
if torch._C._get_operation_overload(opname, overload) is not None:
# Since we call this when destroying the library, sometimes the
# schema will be gone already at that time.
schema = torch._C._get_schema(opname, overload)
for arg in schema.arguments:
if isinstance(arg.type, torch.ClassType):
self._effect = EffectType.ORDERED
return

@property
def effect(self) -> Optional[EffectType]:
return self._effect

@effect.setter
def effect(self, _):
raise RuntimeError("Unable to directly set kernel.")

def register(self, effect: Optional[EffectType]) -> RegistrationHandle:
"""Register an effect

Returns a RegistrationHandle that one can use to de-register this
effect.
"""
self._effect = effect

def deregister_effect():
self._set_default_effect()

handle = RegistrationHandle(deregister_effect)
return handle
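`EffectHolder` defaults an op to `ORDERED` when its schema takes a ScriptObject (`torch.ClassType`) argument and otherwise leaves the effect unset; explicit registration returns a `RegistrationHandle` that restores the default. A small sketch of the holder in isolation, assuming the module layout introduced above:

```python
from torch._library.effects import EffectHolder, EffectType

# aten::add.Tensor has no ScriptObject arguments, so the default effect is None.
holder = EffectHolder("aten::add.Tensor")
print(holder.effect)  # None

# Explicit registration overrides the default; the handle undoes it.
handle = holder.register(EffectType.ORDERED)
print(holder.effect)  # EffectType.ORDERED
handle.destroy()
print(holder.effect)  # None again (default recomputed)

# holder.effect = EffectType.ORDERED  # would raise: the property setter is blocked
```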
@@ -1,6 +1,7 @@
from collections.abc import Callable
from typing import Any, Optional

from .effects import EffectHolder
from .fake_impl import FakeImplHolder
from .utils import RegistrationHandle

@@ -51,6 +52,8 @@ class SimpleOperatorEntry:
GenericTorchDispatchRuleHolder(qualname)
)

self.effect: EffectHolder = EffectHolder(qualname)

# For compatibility reasons. We can delete this soon.
@property
def abstract_impl(self) -> FakeImplHolder:

@@ -230,6 +230,12 @@ def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
if ndim_extra > 0:
tensor = tensor.view((1,) * ndim_extra + tensor.shape)

# special handling for np._CopyMode
try:
copy = bool(copy)
except ValueError:
# TODO handle _CopyMode.IF_NEEDED correctly
copy = False
# copy if requested
if copy:
tensor = tensor.clone()
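The `_coerce_to_tensor` hunk makes torch._numpy tolerate `np._CopyMode` values for `copy=` by coercing them with `bool()`: `ALWAYS` and `NEVER` convert cleanly, while `IF_NEEDED` raises `ValueError` and currently falls back to no copy (see the TODO). A hedged illustration under the NumPy compat layer:

```python
import numpy as np
import torch._numpy as tnp

base = tnp.asarray([1.0, 2.0, 3.0])

# bool(np._CopyMode.ALWAYS) is True, so this behaves like copy=True.
copied = tnp.array(base, copy=np._CopyMode.ALWAYS)

# bool(np._CopyMode.IF_NEEDED) raises ValueError, so for now it is treated
# as copy=False rather than "copy only when necessary".
maybe_shared = tnp.array(base, copy=np._CopyMode.IF_NEEDED)
```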
@@ -1023,6 +1023,7 @@ class TorchBindOpOverload(OpOverload[_P, _T]):
DispatchKey.BackendSelect,
DispatchKey.PythonTLSSnapshot,
DispatchKey.PythonDispatcher,
DispatchKey.Functionalize,
]

def _may_use_fallthrough_instead_of_fallback(key: DispatchKey):
@@ -1046,17 +1047,23 @@ class TorchBindOpOverload(OpOverload[_P, _T]):
def _register_as_effectful_op_temporarily(self):
from torch._higher_order_ops.effects import (
_EffectType,
_get_effect,
_register_effectful_op,
SIDE_EFFECTS,
)

try:
if self not in SIDE_EFFECTS:
_register_effectful_op(self, _EffectType.ORDERED)
# We don't want to register the effect if there already exists a
# registration, especially if the registration is None (explicitly
# no effect)
register_tmp_effect = _get_effect(self) is None
handle = None
if register_tmp_effect:
handle = _register_effectful_op(self, _EffectType.ORDERED)
yield
finally:
if self in SIDE_EFFECTS:
del SIDE_EFFECTS[self]
if register_tmp_effect:
assert handle is not None
handle.destroy()

# Use positional-only argument to avoid naming collision with aten ops arguments
# that are named "self". This way, all the aten ops can be called by kwargs.

@@ -11,7 +11,7 @@ import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._C import _functionalization_reapply_views_tls as _reapply_views
from torch._ops import _get_dispatch_mode_pre_dispatch
from torch._ops import _get_dispatch_mode_pre_dispatch, TorchBindOpOverload
from torch._subclasses.meta_utils import is_sparse_any
from torch.utils._python_dispatch import (
_detect_infra_mode,
@@ -471,7 +471,7 @@ class FunctionalTensorMode(TorchDispatchMode):

from torch._higher_order_ops.effects import handle_effects, has_effects

if has_effects(func, args, kwargs):
if has_effects(func):
assert not torch._C._dispatch_has_kernel_for_dispatch_key(
func.name(), torch._C.DispatchKey.Functionalize
)
@@ -504,65 +504,81 @@ class FunctionalTensorMode(TorchDispatchMode):
- FunctionalTensor._extra_dispatch_keys
)

# All we want to do here is reuse the existing C++ functionalization logic.
# This requires swizzling our TLS dispatch keys so that the Functionalize key is active.
with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
try:
# By default for python functionalization (for AOTAutograd), we reapply views.
old_apply_views = torch._functionalize_enable_reapply_views(True) # type: ignore[attr-defined]
if isinstance(func, TorchBindOpOverload):
# When the function is a TorchBindOpOverload, meaning some of the
# inputs are FakeScriptObjects, we need to skip c++ dispatcher and
# dispatch in python because C++ dispatcher will check the schema
# and cannot recognize FakeScriptObject.
ctx = PythonFunctionalizeAPI()
fully_unwrapped_args = ctx.unwrap_tensors(args)
fully_unwrapped_kwargs = ctx.unwrap_tensors(
kwargs # pyrefly: ignore[bad-argument-type]
)
outs_unwrapped = func(
*fully_unwrapped_args,
**fully_unwrapped_kwargs,
)
outs_wrapped = ctx.wrap_tensors(outs_unwrapped)
else:
# All we want to do here is reuse the existing C++ functionalization logic.
# This requires swizzling our TLS dispatch keys so that the Functionalize key is active.
with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
try:
# By default for python functionalization (for AOTAutograd), we reapply views.
old_apply_views = torch._functionalize_enable_reapply_views(True) # type: ignore[attr-defined]

# Sometimes these functions cannot be directly dispatched to functionalize key
# because args are sometimes not functional tensors for some reason?
if func in FunctionalTensor.metadata_fns:
outs_unwrapped = func(*args_unwrapped, **kwargs_unwrapped)
outs_wrapped = pytree.tree_map_only(
torch.Tensor, wrap, outs_unwrapped
)
else:
# Note: [Functionalization View Replay Annotation]
# When functionalization encounters a mutation, it handles aliases by lazily regenerating the aliases
# at the first time they are next used.
# This is a problem when plumbing user annotations during tracing. We want the view ops from view replay
# to have the same annotation that the user specified on the original views. But view replay in
# functionalization happens the next time the alias is used (e.g. second_op(alias_with_pending_mutation)),
# so when we regenerate views before calling into second_op, those views will end up getting the metadata
# for second_op!
#
# Instead, we need to remember the node metadata from the original views, and ensure that this node metadata
# is globally set when we lazily perform view replay.
# The globally set metadata will be used to populate the fx node created for the replayed operation.
if m := torch._C._get_dispatch_mode(
torch._C._TorchDispatchModeKey.PROXY
):
for a in pytree.tree_leaves([args, kwargs]):
if not isinstance(a, FunctionalTensor):
continue
curr_node = m.tracer.tensor_tracker[
torch._from_functional_tensor(a.elem)
].proxy.node
with fx_traceback.set_current_replay_node(curr_node):
torch._sync(a)
# Sometimes these functions cannot be directly dispatched to functionalize key
# because args are sometimes not functional tensors for some reason?
if func in FunctionalTensor.metadata_fns:
outs_unwrapped = func(*args_unwrapped, **kwargs_unwrapped)
outs_wrapped = pytree.tree_map_only(
torch.Tensor, wrap, outs_unwrapped
)
else:
# Note: [Functionalization View Replay Annotation]
# When functionalization encounters a mutation, it handles aliases by lazily regenerating the aliases
# at the first time they are next used.
# This is a problem when plumbing user annotations during tracing. We want the view ops from view replay
# to have the same annotation that the user specified on the original views. But view replay in
# functionalization happens the next time the alias is used (e.g. second_op(alias_with_pending_mutation)),
# so when we regenerate views before calling into second_op, those views will end up getting the metadata
# for second_op!
#
# Instead, we need to remember the node metadata from the original views, and ensure that this node metadata
# is globally set when we lazily perform view replay.
# The globally set metadata will be used to populate the fx node created for the replayed operation.
if m := torch._C._get_dispatch_mode(
torch._C._TorchDispatchModeKey.PROXY
):
for a in pytree.tree_leaves([args, kwargs]):
if not isinstance(a, FunctionalTensor):
continue
curr_node = m.tracer.tensor_tracker[
torch._from_functional_tensor(a.elem)
].proxy.node
with fx_traceback.set_current_replay_node(curr_node):
torch._sync(a)

# When we dispatch to the C++ functionalization kernel, we might need to jump back to the
# PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath
# FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch
# from the TLS in order to avoid infinite looping, but this would prevent us from coming
# back to PreDispatch later
outs_unwrapped = func._op_dk(
torch._C.DispatchKey.Functionalize,
*args_unwrapped,
**kwargs_unwrapped,
)
# When we dispatch to the C++ functionalization kernel, we might need to jump back to the
# PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath
# FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch
# from the TLS in order to avoid infinite looping, but this would prevent us from coming
# back to PreDispatch later
outs_unwrapped = func._op_dk(
torch._C.DispatchKey.Functionalize,
*args_unwrapped,
**kwargs_unwrapped,
)

if self.export:
if func is torch.ops.aten.dropout.default:
torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined]
outs_wrapped = pytree.tree_map_only(
torch.Tensor, wrap, outs_unwrapped
)
finally:
torch._disable_functionalization()
torch._functionalize_enable_reapply_views(old_apply_views) # type: ignore[attr-defined]
if self.export:
if func is torch.ops.aten.dropout.default:
torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined]
outs_wrapped = pytree.tree_map_only(
torch.Tensor, wrap, outs_unwrapped
)
finally:
torch._disable_functionalization()
torch._functionalize_enable_reapply_views(old_apply_views) # type: ignore[attr-defined]

is_included = torch._C._dispatch_tls_is_dispatch_key_included(
torch._C.DispatchKey.Functionalize
@@ -18,6 +18,7 @@ import torch
import torch.utils._pytree as pytree
from torch._C import ScriptObject # type: ignore[attr-defined]
from torch._library.fake_class_registry import FakeScriptObject
from torch._library.opaque_object import is_opaque_type

from ._compatibility import compatibility
from ._lazy_graph_module import _make_graph_module
@@ -421,8 +422,10 @@ class Tracer(TracerBase):
# a get_attr to retrieve that tensor. Otherwise, we'll store away the
# tensor value into a special attribute on the Module s.t. we can
# retrieve it with a get_attr.
if isinstance(a, _constant_attribute_types):
qualname: Optional[str] = self.tensor_attrs.get(a)
if isinstance(a, _constant_attribute_types) or is_opaque_type(type(a)):
qualname: Optional[str] = self.tensor_attrs.get(
a
) # pyrefly: ignore[no-matching-overload]

# Tensor was not found in the Module hierarchy, stow it away in a
# special attribute and set the qualname to refer to that
@@ -433,13 +436,17 @@
base_name = "_torchbind_obj"
elif isinstance(a, pytree.TreeSpec):
base_name = "_tree_spec_constant"
elif is_opaque_type(type(a)):
base_name = "_opaque_obj"
else:
raise RuntimeError(
f"cannot create constant arg for {a} of type {type(a)}."
)
qualname = self.get_fresh_qualname(base_name)
assert isinstance(qualname, str)
self.tensor_attrs[a] = qualname
self.tensor_attrs[a] = ( # pyrefly: ignore[unsupported-operation]
qualname
)
setattr(self.root, qualname, a)

return self.create_node("get_attr", qualname, (), {})

@@ -84,7 +84,7 @@ if TYPE_CHECKING:

from torch._ops import OpOverload
from torch.fx._symbolic_trace import PHBase
from torch.types import IntLikeType
from torch.types import BoolLikeType, FloatLikeType, IntLikeType

__all__ = [
"PythonKeyTracer",
@@ -458,7 +458,7 @@ def _sympy_handlers() -> dict[type[sympy.Expr], Callable[..., Any]]:

def _build_proxy_for_sym_expr(
tracer: _ProxyTracer, expr: sympy.Expr, out: PySymType | None = None
) -> PySymType | None:
) -> IntLikeType | FloatLikeType | BoolLikeType | None:
"""
Decompose `expr` and look for the pieces as inputs. If `out` is provided
then that will be the resulting SymNode (and `out.expr` must be the same as
@@ -532,6 +532,13 @@ def _build_proxy_for_sym_expr(
assert not out
return value.value

if isinstance(expr, (int, float, bool)):
return expr
if expr.is_Integer:
return int(expr)
if expr.is_Float:
return float(expr)

args = []
for arg in expr.args:
if (arg_value := _build_proxy_for_sym_expr(tracer, arg)) is None:

@@ -19,6 +19,7 @@ from torch._library.custom_ops import (
CustomOpDef,
device_types_t,
)
from torch._library.effects import EffectType
from torch._library.infer_schema import infer_schema # noqa: F401
from torch._library.triton import triton_op, wrap_triton
from torch._ops import OpOverload
@@ -398,6 +399,22 @@ class Library:

self.m.fallback(dispatch_key, fn, with_keyset)

def _register_effectful_op(self, op_name: str, effect: Optional[EffectType]):
"""
Registers an effect to an operator. This is used to register an op that
has side effects that is not capturable by the schema.

Args:
op_name: operator name (along with the overload) or OpOverload object.
effect: The effect of the op.
"""
from torch._higher_order_ops.effects import (
_register_effectful_op as hoo_register_effect,
)

handle = hoo_register_effect(op_name, effect)
self._registration_handles.append(handle)

def _destroy(self):
if self.m is not None:
self.m.reset()
@@ -1065,6 +1082,44 @@ def register_fake(
return register(func)

def _register_effectful_op(
op: _op_identifier,
effect: Optional[EffectType],
*,
lib: Optional[Library] = None,
) -> None:
r"""
To specify that an operator has side-effects, we must register an effect
type for the operator. This will prevent graph passes in torch.compile from
reordering operations with the same effect type.

Args:
op_name: Operator name (along with the overload) or OpOverload object.
effect: Effect type to register. None means the operator is not effectful.
"""
if not isinstance(
op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
):
raise ValueError(
f"register_effectful_op({op}): got unexpected type for op: {type(op)}"
)

if isinstance(op, torch._ops.OpOverload):
op = op._name
opdef = _maybe_get_opdef(op)
if opdef is not None:
opdef.register_effect(effect)
assert isinstance(op, str)

namespace, _ = torch._library.utils.parse_namespace(op)
if lib is None:
use_lib = Library(namespace, "FRAGMENT")
_keep_alive.append(use_lib)
else:
use_lib = lib
use_lib._register_effectful_op(op, effect)

def register_autograd(
op: _op_identifier,
backward: Callable,
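Together with `CustomOpDef.register_effect` above, this gives custom-op authors a way to declare side effects that the schema cannot express, so torch.compile keeps such calls ordered rather than reordering or pruning them. An end-to-end sketch; the `mylib::log_value` op is hypothetical and `_register_effectful_op` is still a private entry point in this diff:

```python
import torch
from torch._library.effects import EffectType

@torch.library.custom_op("mylib::log_value", mutates_args=())
def log_value(x: torch.Tensor) -> None:
    # Side effect that is invisible to the operator schema.
    print("value:", x.sum().item())

@log_value.register_fake
def _(x: torch.Tensor) -> None:
    return None

# Accepts a string qualname, an OpOverload, or a CustomOpDef.
torch.library._register_effectful_op("mylib::log_value", EffectType.ORDERED)

@torch.compile
def f(x):
    log_value(x)  # stays ordered relative to other ORDERED effects
    return x * 2
```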
@@ -37,7 +37,7 @@ import functools
import traceback
import weakref
from collections.abc import Callable
from typing import Any, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING # noqa: F401

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode