mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00

Compare commits: ciflow/ind… ... main-enabl… (122 commits)
@@ -187,19 +187,22 @@ if [[ $CUDA_VERSION == 12* || $CUDA_VERSION == 13* ]]; then
        export USE_CUFILE=0
    else
        DEPS_LIST+=(
            "/usr/local/cuda/lib64/libnvToolsExt.so.1"
            "/usr/local/cuda/lib64/libcublas.so.12"
            "/usr/local/cuda/lib64/libcublasLt.so.12"
            "/usr/local/cuda/lib64/libcudart.so.12"
            "/usr/local/cuda/lib64/libnvrtc.so.12"
            "/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12")
        DEPS_SONAME+=(
            "libnvToolsExt.so.1"
            "libcublas.so.12"
            "libcublasLt.so.12"
            "libcudart.so.12"
            "libnvrtc.so.12"
            "libcupti.so.12")

        if [[ $CUDA_VERSION != 12.9* ]]; then
            DEPS_LIST+=("/usr/local/cuda/lib64/libnvToolsExt.so.1")
            DEPS_SONAME+=("libnvToolsExt.so.1")
        fi
    fi
else
    echo "Using nvidia libs from pypi."
1 .github/ISSUE_TEMPLATE/ci-sev.md vendored

@@ -8,6 +8,7 @@ assignees: ''
---

> NOTE: Remember to label this issue with "`ci: sev`"
> If you want autorevert to be disabled, keep the ci: disable-autorevert label

<!-- Add the `merge blocking` label to this PR to prevent PRs from being merged while this issue is open -->
4 .github/ISSUE_TEMPLATE/disable-autorevert.md vendored

@@ -1,7 +1,7 @@
---
name: DISABLE AUTOREVERT
name: "D❌\U0001F519 ISABLE AUTOREVERT"
about: Disables autorevert when open
title: "❌\U0001F519 [DISABLE AUTOREVERT]"
title: "[DISABLE AUTOREVERT]"
labels: 'ci: disable-autorevert'
assignees: ''
@@ -65,7 +65,7 @@ runs:
      cd .ci/lumen_cli
      python3 -m pip install -e .
    )
    MAX_JOBS="$(nproc --ignore=6)"
    MAX_JOBS="$(nproc --ignore=10)"
    export MAX_JOBS

    # Split the comma-separated list and build each target
2 .github/ci_commit_pins/audio.txt vendored

@@ -1 +1 @@
8ad2aa5d354d1bf432339113860185d5a5d1abbd
1b013f5b5a87a1882eb143c26d79d091150d6a37
2 .github/ci_commit_pins/vision.txt vendored

@@ -1 +1 @@
f5c6c2ec6490455e86f67b2a25c10390d60a27f7
faffd5cf673615583da6517275e361cb3dbc77e6
4 .github/pytorch-probot.yml vendored

@@ -3,6 +3,7 @@ ciflow_tracking_issue: 64124
ciflow_push_tags:
- ciflow/b200
- ciflow/b200-symm-mem
- ciflow/b200-distributed
- ciflow/binaries
- ciflow/binaries_libtorch
- ciflow/binaries_wheel
@@ -15,7 +16,8 @@ ciflow_push_tags:
- ciflow/inductor-micro-benchmark
- ciflow/inductor-micro-benchmark-cpu-x86
- ciflow/inductor-perf-compare
- ciflow/inductor-perf-test-nightly-rocm
- ciflow/inductor-perf-test-nightly-rocm-mi300
- ciflow/inductor-perf-test-nightly-rocm-mi355
- ciflow/inductor-perf-test-nightly-x86-zen
- ciflow/inductor-periodic
- ciflow/inductor-rocm
62 .github/workflows/b200-distributed.yml vendored Normal file

@@ -0,0 +1,62 @@
name: CI for distributed tests on B200

on:
  pull_request:
    paths:
      - .github/workflows/b200-distributed.yml
  workflow_dispatch:
  push:
    tags:
      - ciflow/b200-distributed/*
  schedule:
    - cron: 46 8 * * *  # about 1:46am PDT

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
  cancel-in-progress: true

permissions:
  id-token: write
  contents: read

jobs:

  get-label-type:
    if: github.repository_owner == 'pytorch'
    name: get-label-type
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
    with:
      triggering_actor: ${{ github.triggering_actor }}
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
      curr_branch: ${{ github.head_ref || github.ref_name }}
      curr_ref_type: ${{ github.ref_type }}

  linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200:
    name: linux-jammy-cuda12.8-py3.10-gcc11-build-distributed-b200
    uses: ./.github/workflows/_linux-build.yml
    needs: get-label-type
    with:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      runner: linux.12xlarge.memory
      build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
      cuda-arch-list: '10.0'
      test-matrix: |
        { include: [
          { config: "distributed", shard: 1, num_shards: 2, runner: "linux.dgx.b200.8" },
          { config: "distributed", shard: 2, num_shards: 2, runner: "linux.dgx.b200.8" },
        ]}
    secrets: inherit

  linux-jammy-cuda12_8-py3_10-gcc11-test-distributed-b200:
    name: linux-jammy-cuda12.8-py3.10-gcc11-test-b200
    uses: ./.github/workflows/_linux-test.yml
    needs:
      - linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200
    with:
      timeout-minutes: 1200
      build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
      docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.docker-image }}
      test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.test-matrix }}
      aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
    secrets: inherit
19 .github/workflows/build-vllm-wheel.yml vendored

@@ -27,9 +27,8 @@ jobs:
      fail-fast: false
      matrix:
        python-version: [ '3.12' ]
        # TODO (huydhn): Add cu130 after https://github.com/vllm-project/vllm/issues/24464 is resolved
        platform: [ 'manylinux_2_28_x86_64', 'manylinux_2_28_aarch64' ]
        device: [ 'cu128', 'cu129' ]
        device: [ 'cu128', 'cu129', 'cu130' ]
        include:
          - platform: manylinux_2_28_x86_64
            device: cu128
@@ -39,6 +38,10 @@
            device: cu129
            manylinux-image: 'pytorch/manylinux2_28-builder:cuda12.9'
            runner: linux.12xlarge.memory
          - platform: manylinux_2_28_x86_64
            device: cu130
            manylinux-image: 'pytorch/manylinux2_28-builder:cuda13.0'
            runner: linux.12xlarge.memory
          - platform: manylinux_2_28_aarch64
            device: cu128
            manylinux-image: 'pytorch/manylinuxaarch64-builder:cuda12.8'
@@ -47,6 +50,11 @@
            device: cu129
            manylinux-image: 'pytorch/manylinuxaarch64-builder:cuda12.9'
            runner: linux.arm64.r7g.12xlarge.memory
        exclude:
          # TODO (huydhn): Add cu130 aarch64 once PyTorch is on 2.9+ and
          # xformers is update to support 13.0
          - platform: manylinux_2_28_aarch64
            device: cu130
    name: "Build ${{ matrix.device }} vLLM wheel on ${{ matrix.platform }}"
    runs-on: ${{ matrix.runner }}
    timeout-minutes: 480
@@ -169,7 +177,12 @@ jobs:
      fail-fast: false
      matrix:
        platform: [ 'manylinux_2_28_x86_64', 'manylinux_2_28_aarch64' ]
        device: [ 'cu128', 'cu129' ]
        device: [ 'cu128', 'cu129', 'cu130' ]
        exclude:
          # TODO (huydhn): Add cu130 aarch64 once PyTorch is on 2.9+ and
          # xformers is update to support 13.0
          - platform: manylinux_2_28_aarch64
            device: cu130
    env:
      PLATFORM: ${{ matrix.platform }}
      BUILD_DEVICE: ${{ matrix.device }}
132 .github/workflows/inductor-perf-test-nightly-rocm-mi300.yml vendored Normal file

@@ -0,0 +1,132 @@
name: inductor-perf-nightly-rocm-mi300

on:
  push:
    tags:
      - ciflow/inductor-perf-test-nightly-rocm-mi300/*
  schedule:
    - cron: 15 0 * * *
  # NB: GitHub has an upper limit of 10 inputs here, so before we can sort it
  # out, let try to run torchao cudagraphs_low_precision as part of cudagraphs
  workflow_dispatch:
    inputs:
      training:
        description: Run training (on by default)?
        required: false
        type: boolean
        default: true
      inference:
        description: Run inference (on by default)?
        required: false
        type: boolean
        default: true
      default:
        description: Run inductor_default?
        required: false
        type: boolean
        default: false
      dynamic:
        description: Run inductor_dynamic_shapes?
        required: false
        type: boolean
        default: false
      cppwrapper:
        description: Run inductor_cpp_wrapper?
        required: false
        type: boolean
        default: false
      cudagraphs:
        description: Run inductor_cudagraphs?
        required: false
        type: boolean
        default: true
      freezing_cudagraphs:
        description: Run inductor_cudagraphs with freezing for inference?
        required: false
        type: boolean
        default: false
      aotinductor:
        description: Run aot_inductor for inference?
        required: false
        type: boolean
        default: false
      maxautotune:
        description: Run inductor_max_autotune?
        required: false
        type: boolean
        default: false
      benchmark_configs:
        description: The list of configs used the benchmark
        required: false
        type: string
        default: inductor_huggingface_perf_rocm_mi300,inductor_timm_perf_rocm_mi300,inductor_torchbench_perf_rocm_mi300

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
  cancel-in-progress: true

permissions: read-all

jobs:
  get-label-type:
    name: get-label-type
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
    with:
      triggering_actor: ${{ github.triggering_actor }}
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
      curr_branch: ${{ github.head_ref || github.ref_name }}
      curr_ref_type: ${{ github.ref_type }}
      opt_out_experiments: lf

  linux-jammy-rocm-py3_10-inductor-benchmark-build:
    if: github.repository_owner == 'pytorch'
    name: rocm-py3_10-inductor-benchmark-build
    uses: ./.github/workflows/_linux-build.yml
    with:
      build-environment: linux-jammy-rocm-py3_10
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
      test-matrix: |
        { include: [
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm_mi300", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm_mi300", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm_mi300", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.gfx942.1" },
        ]}
    secrets: inherit

  linux-jammy-rocm-py3_10-inductor-benchmark-test:
    permissions:
      id-token: write
      contents: read
    name: rocm-py3_10-inductor-benchmark-test
    uses: ./.github/workflows/_rocm-test.yml
    needs: linux-jammy-rocm-py3_10-inductor-benchmark-build
    with:
      build-environment: linux-jammy-rocm-py3_10
      dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true
      docker-image: ${{ needs.linux-jammy-rocm-py3_10-inductor-benchmark-build.outputs.docker-image }}
      test-matrix: ${{ needs.linux-jammy-rocm-py3_10-inductor-benchmark-build.outputs.test-matrix }}
      timeout-minutes: 720
      # Disable monitor in perf tests for more investigation
      disable-monitor: true
      monitor-log-interval: 10
      monitor-data-collect-interval: 2
    secrets: inherit
@@ -1,11 +1,11 @@
name: inductor-perf-nightly-rocm
name: inductor-perf-nightly-rocm-mi355

on:
  push:
    tags:
      - ciflow/inductor-perf-test-nightly-rocm/*
      - ciflow/inductor-perf-test-nightly-rocm-mi355/*
  schedule:
    - cron: 0 7 * * 0,3
    - cron: 15 0 * * *
  # NB: GitHub has an upper limit of 10 inputs here, so before we can sort it
  # out, let try to run torchao cudagraphs_low_precision as part of cudagraphs
  workflow_dispatch:
@@ -59,7 +59,7 @@ on:
        description: The list of configs used the benchmark
        required: false
        type: string
        default: inductor_huggingface_perf_rocm,inductor_timm_perf_rocm,inductor_torchbench_perf_rocm
        default: inductor_huggingface_perf_rocm_mi355,inductor_timm_perf_rocm_mi355,inductor_torchbench_perf_rocm_mi355

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
@@ -88,23 +88,27 @@ jobs:
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
      test-matrix: |
        { include: [
          { config: "inductor_huggingface_perf_rocm", shard: 1, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm", shard: 2, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm", shard: 3, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm", shard: 4, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 1, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 2, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 3, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 4, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 5, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 6, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 7, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 8, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
          { config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
        ]}
    secrets: inherit
23 .github/workflows/operator_benchmark.yml vendored

@@ -7,9 +7,11 @@ on:
  workflow_dispatch:
    inputs:
      test_mode:
        required: false
        type: string
        default: 'short'
        type: choice
        options:
          - 'short'
          - 'long'
          - 'all'
        description: tag filter for operator benchmarks, options from long, short, all
  schedule:
    # Run at 07:00 UTC every Sunday
@@ -37,20 +39,7 @@ jobs:
      docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks
      test-matrix: |
        { include: [
          { config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
        ]}
    secrets: inherit

  opbenchmark-on-demand-build:
    if: ${{ github.event_name == 'workflow_dispatch' && github.repository_owner == 'pytorch' }}
    name: opbenchmark-on-demand-build
    uses: ./.github/workflows/_linux-build.yml
    with:
      build-environment: linux-jammy-py3.10-gcc11-build
      docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks
      test-matrix: |
        { include: [
          { config: "cpu_operator_benchmark_${{ inputs.test_mode }}", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
          { config: "cpu_operator_benchmark_${{ inputs.test_mode || 'short' }}", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
        ]}
    secrets: inherit
8 .github/workflows/trunk.yml vendored

@@ -180,13 +180,13 @@ jobs:
      disable-monitor: false
    secrets: inherit

  win-vs2022-cuda12_6-py3-build:
    name: win-vs2022-cuda12.6-py3
  win-vs2022-cuda12_8-py3-build:
    name: win-vs2022-cuda12.8-py3
    uses: ./.github/workflows/_win-build.yml
    needs: get-label-type
    with:
      build-environment: win-vs2022-cuda12.6-py3
      cuda-version: "12.6"
      build-environment: win-vs2022-cuda12.8-py3
      cuda-version: "12.8"
      runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
    secrets: inherit
@@ -256,6 +256,7 @@ endif()
IF(USE_FBGEMM_GENAI)
  set(FBGEMM_THIRD_PARTY ${PROJECT_SOURCE_DIR}/third_party/fbgemm/external/)
  set(FBGEMM_GENAI_SRCS ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize)

  if(USE_CUDA)
    # To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
    # If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
@@ -292,58 +293,64 @@ IF(USE_FBGEMM_GENAI)
      "${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
    )

    target_include_directories(fbgemm_genai PUBLIC
    target_include_directories(fbgemm_genai PRIVATE
      ${FBGEMM_THIRD_PARTY}/cutlass/include
      ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
      ${fbgemm_genai_mx8mx8bf16_grouped}
      ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
      ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h
    )
  else()
    if(USE_ROCM)
      # Only include the kernels we want to build to avoid increasing binary size.
      file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
        "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
        "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
      set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)

      # Add additional HIPCC compiler flags for performance
      set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
        -mllvm
        -amdgpu-coerce-illegal-types=1
        -mllvm
        -enable-post-misched=0
        -mllvm
        -greedy-reverse-local-assignment=1
        -fhip-new-launch-api)
    # Add FBGEMM_GENAI include directories for torch_ops.h
    list(APPEND ATen_CUDA_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
    list(APPEND ATen_CUDA_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
  elseif(USE_ROCM)
    # Only include the kernels we want to build to avoid increasing binary size.
    file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
      "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
      "${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
    set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)

    # Only compile for gfx942 for now.
    # This is rather hacky, I could not figure out a clean solution :(
    set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
    string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
    if("gfx942" IN_LIST PYTORCH_ROCM_ARCH)
      list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
    endif()
    set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})
    # Add additional HIPCC compiler flags for performance
    set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
      -mllvm
      -amdgpu-coerce-illegal-types=1
      -mllvm
      -enable-post-misched=0
      -mllvm
      -greedy-reverse-local-assignment=1
      -fhip-new-launch-api)

    hip_add_library(
      fbgemm_genai STATIC
      ${fbgemm_genai_native_rocm_hip}
      HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
    set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
    set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
    target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)

    target_include_directories(fbgemm_genai PUBLIC
      # FBGEMM version of Composable Kernel is used due to some customizations
      ${FBGEMM_THIRD_PARTY}/composable_kernel/include
      ${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
      ${FBGEMM_THIRD_PARTY}/cutlass/include
      ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
      ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
      ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h
    )
    # Only compile for gfx942 for now.
    # This is rather hacky, I could not figure out a clean solution :(
    set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
    string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
    if("gfx942" IN_LIST PYTORCH_ROCM_ARCH)
      list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
    endif()
    set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})

    hip_add_library(
      fbgemm_genai STATIC
      ${fbgemm_genai_native_rocm_hip}
      HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
    set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
    set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
    target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)

    target_include_directories(fbgemm_genai PRIVATE
      # FBGEMM version of Composable Kernel is used due to some customizations
      ${FBGEMM_THIRD_PARTY}/composable_kernel/include
      ${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
      ${FBGEMM_THIRD_PARTY}/cutlass/include
      ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
      ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
      ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h
    )

    # Add FBGEMM_GENAI include directories for torch_ops.h
    list(APPEND ATen_HIP_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
    list(APPEND ATen_HIP_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
  endif()
endif()

@@ -692,12 +699,6 @@ if(USE_CUDA AND NOT USE_ROCM)
  list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
  list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include)

  # Add FBGEMM_GENAI include directories for torch_ops.h
  if(USE_FBGEMM_GENAI)
    list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
    list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
  endif()

  if($ENV{ATEN_STATIC_CUDA})
    if(CUDA_VERSION VERSION_LESS_EQUAL 12.9)
      list(APPEND ATen_CUDA_DEPENDENCY_LIBS
@@ -389,37 +389,16 @@ void fillVersion<DLManagedTensorVersioned>(
// constructed out of ATen tensor
template <class T>
T* toDLPackImpl(const Tensor& src) {
  auto view = src;

  // Detect whether there is need to normalize the strides
  // Background: gh-83069
  //
  // However, normalizing strides can come at a high-cost
  // to slow down toDLPack conversion 3x, so we
  // only normalize if needed.
  //
  // The following code detects whether the src follows
  // a continuous pattern. If the src follows such pattern (common-case)
  // then we do not need to normalize the strides.
  bool need_normalize_strides = src.dim() == 1 && src.size(0) == 1 && src.stride(0) != 1;
  // less common case, try normalizing the strides
  if (need_normalize_strides) {
    // create a new tensor with possibly normalized strides
    // gh-83069
    auto shape = src.sizes();
    view = src.as_strided(shape, {1}, src.storage_offset());
  }

  ATenDLMTensor<T>* atDLMTensor(new ATenDLMTensor<T>);
  atDLMTensor->handle = view;
  atDLMTensor->handle = src;
  atDLMTensor->tensor.manager_ctx = atDLMTensor;
  atDLMTensor->tensor.deleter = &deleter<T>;
  atDLMTensor->tensor.dl_tensor.data = view.data_ptr();
  atDLMTensor->tensor.dl_tensor.data = src.data_ptr();
  atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDevice(src.device());
  atDLMTensor->tensor.dl_tensor.ndim = static_cast<int32_t>(src.dim());
  atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
  atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(view.sizes().data());
  atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(view.strides().data());
  atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(src.sizes().data());
  atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(src.strides().data());
  atDLMTensor->tensor.dl_tensor.byte_offset = 0;
  fillVersion(&atDLMTensor->tensor);
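The hunk above reworks how toDLPackImpl fills the DLPack capsule (see gh-83069 for the original stride-normalization background); the updated lines appear to take data, sizes, and strides directly from the source tensor. A minimal Python sketch of the user-visible round trip, assuming only the public torch.utils.dlpack API; the tensor shape and values are arbitrary illustration:

import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

# A 1-element, 1-D tensor whose stride is not 1: size (1,), stride (5,).
x = torch.arange(10.0)[::5][:1]

# Round-trip through a DLPack capsule built from this tensor.
y = from_dlpack(to_dlpack(x))
assert torch.equal(x, y)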
@@ -624,7 +624,14 @@ struct TORCH_API IValue final {
  IValue(const c10::SymBool& i) {
    if (auto mi = i.maybe_as_bool()) {
      tag = Tag::Bool;
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
      payload.u.as_int = *mi;
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
      /* due to byteorder if value assigned as_int, as_bool actually is not set correctly */
      payload.u.as_bool = *mi;
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
    } else {
      tag = Tag::SymBool;
      payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
@@ -13,6 +13,7 @@
#include <c10/core/ScalarType.h>

#include <ATen/cuda/tunable/TunableOp.h>
#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/util/StringUtil.h>
@@ -150,6 +151,7 @@ inline std::string ScalarTypeToBLASType(c10::ScalarType scalar_type) {
    BLASType = "unknown";
  }
  return BLASType;

}

// Similar to Compute Type in GemmRocblas.h
@@ -244,33 +246,25 @@ inline std::string to_string_epilogue(const at::cuda::blas::GEMMAndBiasActivatio

namespace detail {

static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) {
static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size, const NumericalCheckConfig& config) {

  if (!config.enabled) {
    return true; // skip when disabled
  }

  auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA);
  // comparison done as 1D tensor
  at::Tensor ref = at::from_blob(c, {size}, options);
  at::Tensor oth = at::from_blob(other_c, {size}, options);
  at::Tensor ref_float = ref.to(at::kFloat);
  at::Tensor oth_float = oth.to(at::kFloat);
  std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
  std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
  double last_succeed_atol = 1;
  double last_succeed_rtol = 1;
  for (auto& atol : atols) {
    for (auto& rtol : rtols) {
      if (at::allclose(ref_float, oth_float, rtol, atol)) {
        last_succeed_atol = atol;
        last_succeed_rtol = rtol;
      }
    }
  }
  if (last_succeed_atol == 1) {
    return false;
  }
  else {
    TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
  }

  return true;
  const bool ok = at::allclose(ref_float, oth_float, config.rtol, config.atol);
  if (ok) {
    TUNABLE_LOG3("├──verify numerics: PASSED with atol=", config.atol, ", rtol=", config.rtol);
  } else {
    TUNABLE_LOG3("├──verify numerics: FAILED with atol=", config.atol, ", rtol=", config.rtol);
  }
  return ok;
}

}
@@ -355,8 +349,10 @@ struct GemmParams : OpParams {
  }

  TuningStatus NumericalCheck(GemmParams<T> *other) {
    auto* ctx = getTuningContext();
    auto cfg = ctx->GetNumericalCheckConfig();
    auto c_dtype = c10::CppTypeToScalarType<T>::value;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
  }

  char transa{};
@@ -449,8 +445,10 @@ struct GemmAndBiasParams : OpParams {
  }

  TuningStatus NumericalCheck(GemmAndBiasParams<T> *other) {
    auto* ctx = getTuningContext();
    auto cfg = ctx->GetNumericalCheckConfig();
    auto c_dtype = c10::CppTypeToScalarType<T>::value;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
  }

  char transa{};
@@ -546,8 +544,10 @@ struct GemmStridedBatchedParams : OpParams {
  }

  TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
    auto* ctx = getTuningContext();
    auto cfg = ctx->GetNumericalCheckConfig();
    auto c_dtype = c10::CppTypeToScalarType<C_Dtype>::value;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
  }

  char transa{};
@@ -663,7 +663,9 @@ struct ScaledGemmParams : OpParams {
  }

  TuningStatus NumericalCheck(ScaledGemmParams<T> *other) {
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
    auto* ctx = getTuningContext();
    auto cfg = ctx->GetNumericalCheckConfig();
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
  }

  char transa{};
@@ -145,7 +145,7 @@ programmatically since the settings become fixed. Use the C++ or Python APIs ins
| PYTORCH_TUNABLEOP_VERBOSE | Default is 0. Set to 1 to enable basic logging. 2 for basic tuning status. 3 for full trace. |
| PYTORCH_TUNABLEOP_VERBOSE_FILENAME | Default is "err" for stderr. Set to "out" for stdout or a filename for capturing verbose logging. |
| PYTORCH_TUNABLEOP_FILENAME | Default is 'tunableop_results.csv'. |
| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is 0. Set to 1 to enable. |
| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is off. Set 'atol_rtol' to enable, for example "1e-5_1e-5". |
| PYTORCH_TUNABLEOP_ROCBLAS_ENABLED | Default is 1. Set to 0 to disable rocblas being considered during tuning. |
| PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED | Default is 1. Set to 0 to disable hipblaslt being considered during tuning. |
| PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS | Default is 30. Unit is milliseconds. |
@@ -173,10 +173,9 @@ All python APIs exist in the `torch.cuda.tunable` module.
| get_max_tuning_iterations() -> int | |
| set_filename(filename: str, insert_device_ordinal: bool = False) -> None | |
| get_filename() -> str | |
| set_numerical_check_tolerances(enable: bool, atol: float, rtol: float) -> None | Enable or disable numerical checking; atol and rtol default to 1e-5. |
| get_results() -> Tuple[str, str, str, float] | |
| get_validators() -> Tuple[str, str] | |
| write_file_on_exit(val: bool) -> None | Default is True. |
| write_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
| read_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
| tune_gemm_in_file(filename: str) -> None | read an untuned file and tune GEMMs in it. |
| mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None: -> None | read one or more untuned files and tune all unique GEMMs on one or more GPUs. |
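The two tables above document both ways of turning on TunableOp's numerical checking: the PYTORCH_TUNABLEOP_NUMERICAL_CHECK environment variable in "atol_rtol" form, and the new set_numerical_check_tolerances Python API. A minimal usage sketch, assuming a ROCm build with TunableOp available; the matrix sizes and tolerances are arbitrary, and the environment variable has to be set before the process starts using CUDA/HIP:

import os
os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1e-4_1e-4"  # parsed as "atol_rtol"

import torch

torch.cuda.tunable.enable(True)  # turn TunableOp on
# or configure checking programmatically with the API from the table above
torch.cuda.tunable.set_numerical_check_tolerances(True, 1e-4, 1e-4)

a = torch.randn(512, 512, device="cuda", dtype=torch.half)
b = torch.randn(512, 512, device="cuda", dtype=torch.half)
c = a @ b  # each new GEMM shape is tuned (and numerically verified) on first use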
@@ -107,14 +107,30 @@ void TuningResultsManager::AddImpl(const std::string& op_signature,
}

void TuningResultsManager::Add(const std::string& op_signature, const std::string& params_signature, ResultEntry best) {
  std::scoped_lock l{lock_};
  bool is_new = false;
  ResultEntry inserted = ResultEntry::Null();

  auto it = results_.find(op_signature);
  if (it == results_.end()) {
    it = results_.insert({op_signature, {}}).first;
  // ---- mutate maps under results lock ----
  {
    std::scoped_lock l{lock_};
    auto& km = results_[op_signature];  // creates if missing
    is_new = (km.find(params_signature) == km.end());
    AddImpl(op_signature, params_signature, std::move(best), km);
    if (is_new) {
      inserted = km.at(params_signature);  // snapshot for I/O after unlocking
    }
  }
  if (!is_new) return;  // only write once per unique (op, params)

  TuningContext* ctx = getTuningContext();
  if (ctx->IsTuningEnabled() && !ctx->IsRecordUntunedEnabled()) {
    InitRealtimeAppend(ctx->GetFilename(), ctx->GetTuningResultsValidator().GetAllValidators());

    if (is_new && realtime_out_ && realtime_out_->good()) {
      AppendResultLine(op_signature, params_signature, inserted);
    }
  }

  AddImpl(op_signature, params_signature, std::move(best), it->second);
}

void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
@@ -150,6 +166,77 @@ void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std
  }
}

void TuningResultsManager::InitRealtimeAppend(const std::string& filename, const std::unordered_map<std::string, std::string>& validators) {
  std::scoped_lock fl{realtime_file_mutex_};

  if (realtime_out_ && realtime_out_->good() && realtime_filename_ == filename) {
    return;
  }

  if (realtime_out_ && realtime_filename_ != filename) {
    realtime_out_->flush();
    realtime_out_->close();
    realtime_out_.reset();
    validators_written_ = false;
  }

  bool file_exists = false;
  bool file_empty = true;

  {
    std::ifstream check_file(filename);
    if (check_file.good()) {
      file_exists = true;
      file_empty = (check_file.peek() == std::ifstream::traits_type::eof());
    }
  }

  realtime_out_ = std::make_unique<std::ofstream>(filename, std::ios::out | std::ios::app);

  if (!realtime_out_->good()) {
    TORCH_WARN("TunableOp realtime append: failed to open '", filename,"'");
    realtime_out_.reset();
    return;
  }

  if(!file_exists || file_empty) {
    for(const auto& [key, val] : validators) {
      (*realtime_out_) << "Validator," << key << "," << val << std::endl;
      realtime_out_->flush();
    }
    validators_written_ = true;

    TUNABLE_LOG2("Wrote validators to realtime output file");
  }

  realtime_filename_ = filename;
}

void TuningResultsManager::AppendResultLine(const std::string& op_sig, const std::string& param_sig, const ResultEntry& result) {
  std::scoped_lock fl{realtime_file_mutex_};

  if(!realtime_out_ || !realtime_out_->good()) {
    return;
  }

  (*realtime_out_) << op_sig << "," << param_sig << "," << result << std::endl;
  realtime_out_->flush(); //ensure immediate write to disk

  TUNABLE_LOG3("Realtime append: ", op_sig, "(", param_sig, ") -> ", result);
}

void TuningResultsManager::CloseRealtimeAppend() {
  std::scoped_lock fl{realtime_file_mutex_};

  if(realtime_out_) {
    realtime_out_->flush();
    realtime_out_->close();
    realtime_out_.reset();
    TUNABLE_LOG2("Closed realtime output file");
  }
}

void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) {
  std::scoped_lock l{lock_};
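InitRealtimeAppend above writes one "Validator,<key>,<value>" line per validator to a fresh results file, and AppendResultLine then appends one "op_signature,params_signature,result" row per newly tuned GEMM. A small Python sketch for reading such a file; the default filename comes from the TunableOp docs, and the exact number of comma-separated fields in a result row is treated as an assumption here:

import csv

validators, results = {}, {}
with open("tunableop_results.csv", newline="") as f:  # default name; may differ if set_filename() was used
    for row in csv.reader(f):
        if not row:
            continue
        if row[0] == "Validator":
            validators[row[1]] = row[2]
        else:
            # op signature, params signature, then the result entry fields
            results[(row[0], row[1])] = row[2:]

print(len(validators), "validators,", len(results), "tuned entries")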
@@ -396,7 +483,6 @@ TuningContext::TuningContext() :
    tuning_enable_{true},
    record_untuned_enable_{false},
    manager_initialized_{false},
    write_file_on_exit_{true},
    numerics_check_enable_{false},
    max_tuning_duration_ms_{30},
    max_tuning_iterations_{100},
@@ -417,20 +503,8 @@ TuningContext::~TuningContext() {
    // but doesn't do any computation itself.
    return;
  }
  auto filename = GetFilename();
  if (IsTunableOpEnabled() && IsTuningEnabled() && !filename.empty() && write_file_on_exit_) {
    if (results_count_from_input_file_ < GetTuningResultsManager().GetSize()) {
      if (results_count_from_input_file_ > 0) {
        TUNABLE_LOG1("additional tuning results available, rewriting file ", filename);
      }
      else {
        TUNABLE_LOG1("writing file ", filename);
      }
      if (!WriteFile(filename)) {
        TUNABLE_LOG1("failed to write file ", filename);
      }
    }
  }
  TUNABLE_LOG1("Closing File");
  GetTuningResultsManager().CloseRealtimeAppend(); // Since, we do instant logging by default now.

  if (untuned_file_.good()) {
    untuned_file_.close();
@@ -511,20 +585,54 @@ std::ofstream& TuningContext::GetUntunedFile(){
  return untuned_file_;
}

void TuningContext::WriteFileOnExit(bool value) {
  write_file_on_exit_ = value;
}

void TuningContext::EnableNumericsCheck(bool value) {
  numerics_check_enable_ = value;
}

bool TuningContext::IsNumericsCheckEnabled() const {
  const auto env = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
  if (env == "1") {
    return true;
NumericalCheckConfig TuningContext::GetNumericalCheckConfig() const {
  const auto env_opt = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");

  if (!env_opt.has_value()) {
    return numerics_cfg_;
  }
  return numerics_check_enable_;

  const std::string& env = env_opt.value();

  if (env == "0") {
    return NumericalCheckConfig(false, 1e-5, 1e-5);
  }

  const size_t underscore = env.find('_');

  TORCH_CHECK(
      underscore != std::string::npos,
      "Invalid PYTORCH_TUNABLEOP_NUMERICAL_CHECK format. "
      "Expected 'atol_rtol', got: ",
      env);

  double atol = 0.0;
  double rtol = 0.0;

  try {
    atol = std::stod(env.substr(0, underscore));
    rtol = std::stod(env.substr(underscore + 1));
  } catch (const std::exception& e) {
    TORCH_CHECK(false, "Failed to parse PYTORCH_TUNABLEOP_NUMERICAL_CHECK: ", e.what());
  }

  TORCH_CHECK( atol > 0.0 && rtol > 0.0, "Tolerance values must be positive. atol=", atol, ", rtol=", rtol);
  return NumericalCheckConfig(true, atol, rtol);
}

void TuningContext::SetNumericalCheckConfig(bool enabled, double atol, double rtol) {
  TORCH_CHECK(atol > 0.0 && rtol > 0.0, "Numerical check tolerances must be positive");
  numerics_cfg_ = {enabled, atol, rtol};
}

bool TuningContext::IsNumericsCheckEnabled() const {
  const auto cfg = GetNumericalCheckConfig();
  return cfg.enabled || numerics_check_enable_;
}

void TuningContext::SetMaxTuningDurationMs(int max_duration_ms) {
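GetNumericalCheckConfig above accepts either "0" (explicitly disabled) or an "atol_rtol" pair, and rejects anything else or non-positive tolerances. A small Python mirror of that parsing, purely illustrative (it is not part of PyTorch), to show the accepted forms:

def parse_numerical_check(env: str):
    # "0" disables checking; tolerances fall back to the 1e-5 defaults.
    if env == "0":
        return (False, 1e-5, 1e-5)
    atol_s, sep, rtol_s = env.partition("_")
    if not sep:
        raise ValueError(f"expected 'atol_rtol', got: {env!r}")
    atol, rtol = float(atol_s), float(rtol_s)
    if atol <= 0 or rtol <= 0:
        raise ValueError("tolerance values must be positive")
    return (True, atol, rtol)

assert parse_numerical_check("1e-4_1e-5") == (True, 1e-4, 1e-5)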
@@ -634,11 +742,6 @@ TuningResultsManager& TuningContext::GetTuningResultsManager() {
    auto filename = GetFilename();
    if (!filename.empty() && !IsRecordUntunedEnabled()) {
      ReadFile(filename);
      // attempt immediately to open file for writing to catch errors early
      std::ofstream file(filename, std::ios::out | std::ios::app);
      if (!file.good()) {
        TORCH_WARN("failed to open file '", filename, "' for writing; your tuning results will not be saved");
      }
    }
  });
  return manager_;
@@ -744,27 +847,6 @@ bool TuningContext::ReadFile(const std::string& filename_) {
  return true;
}

bool TuningContext::WriteFile(const std::string& filename_) {
  std::string filename = filename_.empty() ? GetFilename() : filename_;
  std::ofstream file(filename, std::ios::out | std::ios::trunc);
  if (!file.good()) {
    TUNABLE_LOG1("error opening tuning results file for writing ", filename);
    return false;
  }
  auto validators = GetTuningResultsValidator().GetAllValidators();
  for (const auto& [key, val] : validators) {
    file << "Validator," << key << "," << val << std::endl;
  }
  auto results = GetTuningResultsManager().Dump();
  for (const auto& [op_sig, kernelmap] : results) {
    for (const auto& [param_sig, result] : kernelmap) {
      file << op_sig << "," << param_sig << "," << result << std::endl;
    }
  }
  file.close();
  return true;
}

namespace {

struct MaybeDelete {
@@ -103,10 +103,24 @@ class TORCH_CUDA_CPP_API TuningResultsManager {

    void RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
      const std::string& params_signature, const std::string& blas_signature);

    void InitRealtimeAppend(
        const std::string& filename,
        const std::unordered_map<std::string, std::string>& validators);

    void AppendResultLine(const std::string& op_sig,
        const std::string& param_sig,
        const ResultEntry& result);

    void CloseRealtimeAppend(); // For clean shutdown
  private:
    std::mutex lock_;
    std::mutex realtime_file_mutex_;
    std::unique_ptr<std::ofstream> realtime_out_;
    std::string realtime_filename_;
    ResultsMap results_;
    UntunedMap untuned_results_;
    bool validators_written_ = false;

};

@@ -134,6 +148,16 @@ class TORCH_CUDA_CPP_API TuningResultsValidator {
    GetValidateFuncs validators_;
};

struct NumericalCheckConfig {
  bool enabled{false};
  double atol{1e-5};
  double rtol{1e-5};

  NumericalCheckConfig() = default;
  NumericalCheckConfig(bool e, double a, double r) : enabled(e), atol(a), rtol(r) {}
};


class TORCH_CUDA_CPP_API TuningContext {
  public:
    TuningContext();
@@ -155,6 +179,8 @@ class TORCH_CUDA_CPP_API TuningContext {

    void EnableNumericsCheck(bool value);
    bool IsNumericsCheckEnabled() const;
    void SetNumericalCheckConfig(bool enabled, double atol, double rtol);
    NumericalCheckConfig GetNumericalCheckConfig() const;

    void SetMaxTuningDurationMs(int max_duration_ms);
    int GetMaxTuningDurationMs() const;
@@ -185,10 +211,7 @@ class TORCH_CUDA_CPP_API TuningContext {
    void SetFilename(const std::string& filename, bool insert_device_ordinal=false);
    std::string GetFilename() const;

    void WriteFileOnExit(bool value);

    bool ReadFile(const std::string& filename={});
    bool WriteFile(const std::string& filename={});

    template<class... Types>
    void Log(int level, Types... args) {
@@ -207,7 +230,6 @@ class TORCH_CUDA_CPP_API TuningContext {
    bool tuning_enable_;
    bool record_untuned_enable_;
    bool manager_initialized_;
    bool write_file_on_exit_;
    bool numerics_check_enable_;
    int max_tuning_duration_ms_;
    int max_tuning_iterations_;
@@ -222,6 +244,8 @@ class TORCH_CUDA_CPP_API TuningContext {
    std::ofstream untuned_file_;
    size_t results_count_from_input_file_;
    bool is_shutting_down_;

    NumericalCheckConfig numerics_cfg_{};
};

TORCH_CUDA_CPP_API TuningContext* getTuningContext();
@@ -267,27 +267,10 @@ class TunableOp {
      for (size_t i = 0; i < op_names_.size(); i++) {
        auto* candidate = ops_[op_names_[i]].get(); // borrow pointer

        if (do_numerics_check) {
          ParamsT* numerical_params = params->DeepCopy(false);
          auto status = candidate->Call(numerical_params);
          if (status != OK) {
            numerical_params->Delete();
            TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
            continue;
          }
          status = reference_params->NumericalCheck(numerical_params);
          numerical_params->Delete();
          if (status != OK) {
            TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
            continue;
          }
        }
        else {
          auto status = candidate->Call(reusable_params[0]);
          if (status != OK) {
            TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
            continue;
          }
        auto status = candidate->Call(reusable_params[0]);
        if (status != OK) {
          TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
          continue;
        }

        // collect a small profile
@@ -310,6 +293,22 @@ class TunableOp {
          continue;
        }

        if (do_numerics_check) {
          ParamsT* numerical_params = params->DeepCopy(false);
          auto status = candidate->Call(numerical_params);
          if (status != OK) {
            numerical_params->Delete();
            TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
            continue;
          }
          status = reference_params->NumericalCheck(numerical_params);
          numerical_params->Delete();
          if (status != OK) {
            TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
            continue;
          }
        }

        // for warmup does user set max duration, max iters, or both?
        // warmup is skipped by default, i.e. warmup_iter = 0
        // warmup will be set to the non-zero value of max_warmup_duration
@ -213,40 +213,22 @@ static cudnn_grid_sample_backward_batch_rule(
|
||||
return grid_sample_backward_helper_out(std::move(bw_out), 0, 0, bdim_size);
|
||||
}
|
||||
|
||||
// TODO: replace with targetable functionalization
// uses functional formulation for one_hot under vmap to be compatible with
// fakeTensor/dynamic shapes and compiled functorch transforms.
// mirrors the meta path in aten/src/ATen/native/Onehot.cpp,
// but requires explicit positive num_classes under vmap to avoid
// data-dependent output shapes.
static Tensor one_hot_decomposition_hack(const Tensor &self, int64_t num_classes) {
TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor.");
auto shape = self.sym_sizes().vec();

// empty tensor could be converted to one hot representation,
// but shape inference is not possible.
if (self.sym_numel() == 0) {
if (num_classes <= 0) {
TORCH_CHECK(false, "Can not infer total number of classes from empty tensor.");
} else {
shape.emplace_back(num_classes);
return at::empty_symint(shape, self.options());
}
}

// disallow implicit inference under vmap; this would be data-dependent
// and is intentionally guarded by Dynamo in torch/_dynamo/variables/torch.py.
TORCH_CHECK(num_classes > 0, "When vmap-ing torch.nn.functional.one_hot, please "
"provide an explicit positive num_classes argument.");

// Disabling all of the following checks. This is OK because scatter has checks too.
// Maybe one_hot should be a primitive wrt autograd so we don't have to deal with this.
// // non-empty tensor
// if (self.device().type() != at::kCUDA) {
// //for cuda, rely on device assert thrown by scatter
// TORCH_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative.");
// }
// if (self.device().type() != at::kCUDA) {
// //rely on device asserts from scatter to avoid sync here
// TORCH_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes.");
// }

shape.emplace_back(num_classes);
Tensor ret = at::zeros_symint(shape, self.options());
return ret.scatter(-1, self.unsqueeze(-1), 1);
const auto options = self.options();
at::Tensor index = at::arange(num_classes, options);
return at::eq(self.unsqueeze(-1), index).to(at::kLong);
}

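// Illustrative sketch (not part of this change): for in-range Long indices the
// arange/eq formulation above matches the scatter-based eager path; the example
// below assumes a small hypothetical index tensor `idx` and a positive num_classes.
//
//   at::Tensor idx = at::tensor({0, 2, 1}, at::kLong);
//   int64_t num_classes = 3;
//   auto via_scatter = at::zeros({3, 3}, idx.options())
//                          .scatter(-1, idx.unsqueeze(-1), 1);
//   auto via_eq = at::eq(idx.unsqueeze(-1), at::arange(num_classes, idx.options()))
//                     .to(at::kLong);
//   TORCH_INTERNAL_ASSERT(at::equal(via_scatter, via_eq));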
template <typename A, A a, typename C>
|
||||
|
@ -34,16 +34,16 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) {
|
||||
}
|
||||
}
|
||||
|
||||
auto shape = self.sizes().vec();
|
||||
auto shape = self.sym_sizes().vec();
|
||||
|
||||
// empty tensor could be converted to one hot representation,
|
||||
// but shape inference is not possible.
|
||||
if (self.numel() == 0) {
|
||||
if (self.sym_numel() == 0) {
|
||||
if (num_classes <= 0) {
|
||||
TORCH_CHECK(false, "Can not infer total number of classes from empty tensor.");
|
||||
} else {
|
||||
shape.push_back(num_classes);
|
||||
return at::empty(shape, self.options());
|
||||
shape.emplace_back(num_classes);
|
||||
return at::empty_symint(shape, self.options());
|
||||
}
|
||||
}
|
||||
|
||||
@ -66,8 +66,8 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) {
|
||||
}
|
||||
}
|
||||
|
||||
shape.push_back(num_classes);
|
||||
Tensor ret = at::zeros(shape, self.options());
|
||||
shape.emplace_back(num_classes);
|
||||
Tensor ret = at::zeros_symint(shape, self.options());
|
||||
ret.scatter_(-1, self.unsqueeze(-1), 1);
|
||||
return ret;
|
||||
}
|
||||
|
@ -120,7 +120,7 @@ static void pow_tensor_scalar_kernel(
|
||||
} else if (dtype == ScalarType::Half) {
|
||||
[&]() {
|
||||
using scalar_t =
|
||||
decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t);
|
||||
c10::impl::ScalarTypeToCPPTypeT<ScalarType::Half>;
|
||||
const auto exp = exp_scalar.to<scalar_t>();
|
||||
using Vec = Vectorized<scalar_t>;
|
||||
cpu_kernel_vec(iter,
|
||||
|
@ -1230,8 +1230,205 @@ std::pair<ScalingType, ScalingType> get_joint_scaling(
|
||||
);
|
||||
}
|
||||
|
||||
Tensor&
|
||||
_tunable_scaled_gemm_rocm(
|
||||
cublasCommonArgs& args,
|
||||
const Tensor& mat1, const Tensor& mat2,
|
||||
const Tensor& scale_a, const Tensor& scale_b,
|
||||
const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
|
||||
const std::optional<Tensor>& bias,
|
||||
const bool use_fast_accum,
|
||||
const at::ScalarType out_dtype,
|
||||
Tensor& out) {
|
||||
#ifdef USE_ROCM
|
||||
#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B) \
|
||||
if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
|
||||
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
} \
|
||||
else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
|
||||
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
} \
|
||||
else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) { \
|
||||
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e4m3fn, at::Float8_e5m2, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
} \
|
||||
else if (mat1.scalar_type() == ScalarType::Float8_e5m2) { \
|
||||
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e5m2, at::Float8_e4m3fn, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e5m2, at::Float8_e5m2, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
}
|
||||
AT_DISPATCH_V2(out_dtype, "_tunable_scaled_gemm", AT_WRAP([&] {
|
||||
bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
|
||||
bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
|
||||
at::cuda::tunable::ScaledGemmParams<scalar_t> params;
|
||||
params.transa = args.transa;
|
||||
params.transb = args.transb;
|
||||
params.m = args.m;
|
||||
params.n = args.n;
|
||||
params.k = args.k;
|
||||
params.a = args.mata->data_ptr();
|
||||
params.a_scale_ptr = args.scale_mata_ptr;
|
||||
params.a_scale_dtype = args.scale_mata_dtype.value();
|
||||
params.lda = args.lda;
|
||||
params.a_dtype = args.mata->scalar_type();
|
||||
params.a_scale_dtype = args.scale_mata_dtype.value();
|
||||
params.a_scaling_type = args.scaling_mata_type.value();
|
||||
params.b = args.matb->data_ptr();
|
||||
params.b_scale_ptr = args.scale_matb_ptr;
|
||||
params.b_scale_dtype = args.scale_matb_dtype.value();
|
||||
params.ldb = args.ldb;
|
||||
params.b_dtype = args.matb->scalar_type();
|
||||
params.b_scale_dtype = args.scale_matb_dtype.value();
|
||||
params.b_scaling_type = args.scaling_matb_type.value();
|
||||
params.bias_ptr = bias ? bias->data_ptr(): nullptr;
|
||||
params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype) ? at::ScalarType::Half : out_dtype;
|
||||
params.c = args.result->data_ptr();
|
||||
params.c_scale_ptr = args.scale_result_ptr;
|
||||
params.ldc = args.result_ld;
|
||||
params.c_dtype = out_dtype;
|
||||
params.use_fast_accum = use_fast_accum;
|
||||
if (transa_ && transb_) {
|
||||
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
|
||||
}
|
||||
else if (transa_ && !transb_) {
|
||||
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N)
|
||||
}
|
||||
else if (!transa_ && transb_) {
|
||||
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T)
|
||||
}
|
||||
else if (!transa_ && !transb_) {
|
||||
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N)
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK(false, "unreachable");
|
||||
}
|
||||
}),
|
||||
kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
|
||||
#undef TUNABLE_DISPATCH
|
||||
return out;
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "_scaled_gemm_rocm only callable on ROCM devices");
|
||||
#endif
|
||||
}
|
||||
|
||||
Tensor&
|
||||
_scaled_gemm(
|
||||
const Tensor& mat1, const Tensor& mat2,
|
||||
const Tensor& scale_a, const Tensor& scale_b,
|
||||
const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
|
||||
const std::optional<Tensor>& bias,
|
||||
const bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b);
|
||||
const auto out_dtype_ = args.result->scalar_type();
|
||||
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
|
||||
|
||||
// ROCM enables the TunableOp path only
|
||||
// but can fallback to at::cuda::blas::scaled_gemm
|
||||
#ifdef USE_ROCM
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
bool tunable_op_enabled = tuning_ctx->IsTunableOpEnabled();
|
||||
#else
|
||||
bool tunable_op_enabled = false;
|
||||
#endif
|
||||
if (tunable_op_enabled) {
|
||||
// Only available on ROCM
|
||||
return _tunable_scaled_gemm_rocm(
|
||||
args,
|
||||
mat1, mat2,
|
||||
scale_a, scale_b,
|
||||
scaling_choice_a, scaling_choice_b,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out_dtype_,
|
||||
out);
|
||||
}
|
||||
else
|
||||
{
|
||||
at::cuda::blas::scaled_gemm(
|
||||
args.transa,
|
||||
args.transb,
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
args.mata->data_ptr(),
|
||||
args.scale_mata_ptr,
|
||||
args.lda,
|
||||
args.mata->scalar_type(),
|
||||
args.scale_mata_dtype.value(),
|
||||
args.scaling_mata_type.value(),
|
||||
args.matb->data_ptr(),
|
||||
args.scale_matb_ptr,
|
||||
args.ldb,
|
||||
args.matb->scalar_type(),
|
||||
args.scale_matb_dtype.value(),
|
||||
args.scaling_matb_type.value(),
|
||||
bias ? bias->data_ptr(): nullptr,
|
||||
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
|
||||
args.result->data_ptr(),
|
||||
args.scale_result_ptr,
|
||||
args.result_ld,
|
||||
out_dtype_,
|
||||
use_fast_accum);
|
||||
return out;
|
||||
}
|
||||
}
|
||||
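// Note (illustrative, not part of this change): the ROCm-only branch above is taken when
// TunableOp is enabled on the current TuningContext -- typically via the
// PYTORCH_TUNABLEOP_ENABLED=1 environment variable -- and otherwise, on both CUDA and
// ROCm, the call falls through to the plain at::cuda::blas::scaled_gemm path.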
|
||||
} // namespace
|
||||
|
||||
// NOTE(slayton58): This is defined as part of the _v2 code (way) below - declare the signature here
|
||||
// to help cleanup v1 call structure.
|
||||
Tensor&
|
||||
_scaled_rowwise_rowwise(
|
||||
const Tensor&, const Tensor&,
|
||||
const Tensor&, const Tensor&,
|
||||
const std::optional<Tensor>&,
|
||||
const c10::ScalarType,
|
||||
bool,
|
||||
Tensor&);
|
||||
|
||||
|
||||
// Computes matrix multiply + bias while applying scaling to input and output matrices
|
||||
// Scales are only applicable when matrices are of Float8 type and assumed to be equal to 1.0 by default.
|
||||
// If output matrix type is 16 or 32-bit type, scale_result is not applied.
|
||||
@ -1273,6 +1470,10 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
// by decreasing priority. We prefer "simpler" schemes as they are supported
|
||||
// more broadly (more GPU archs, more CUDA versions) and because they are more
|
||||
// efficient. This tends to matter only for small matmuls (e.g., 1x1x128).
|
||||
|
||||
// List of supported BlockWise pairs for FP8:
|
||||
// https://docs.nvidia.com/cuda/cublas/#element-1d-and-128x128-2d-block-scaling-for-fp8-data-types
|
||||
|
||||
auto [scaling_choice_a, scaling_choice_b] = get_joint_scaling(
|
||||
{
|
||||
std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise),
|
||||
@ -1305,7 +1506,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
TORCH_CHECK(isFloat8Type(mat2.scalar_type()) || mat2.scalar_type() == ScalarType::Float4_e2m1fn_x2, "Expected mat2 to be Float8 or Float4_x2 matrix got ", mat2.scalar_type());
|
||||
#ifndef USE_ROCM
|
||||
// Type restrictions imposed by CuBLASLt as of CUDA-12.1
|
||||
TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
|
||||
TORCH_CHECK_VALUE(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
|
||||
"Multiplication of two Float8_e5m2 matrices is not supported");
|
||||
#endif
|
||||
if (use_fast_accum) {
|
||||
@ -1371,41 +1572,44 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
|
||||
// NVIDIA's cuBLAS only started supporting row-wise scaling in version 12.9,
|
||||
// and only for compute capability 9.0+. In other cases we use CUTLASS.
|
||||
#ifndef USE_ROCM
|
||||
// We are doing row-wise scaling
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise
|
||||
&& ((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)
|
||||
// cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales
|
||||
|| (dprops->major >= 10 && (!scale_a.sizes().empty() || !scale_b.sizes().empty())))) {
|
||||
TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
|
||||
at::cuda::detail::f8f8bf16_rowwise(
|
||||
mat1,
|
||||
mat2,
|
||||
scale_a,
|
||||
scale_b,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
return out;
|
||||
}
|
||||
#else
|
||||
if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise) {
|
||||
#ifndef USE_ROCM
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
if ((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)
|
||||
// cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales
|
||||
|| (dprops->major >= 10 && (!scale_a.sizes().empty() || !scale_b.sizes().empty()))) {
|
||||
TORCH_CHECK_VALUE(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
|
||||
return _scaled_rowwise_rowwise(
|
||||
mat1,
|
||||
mat2,
|
||||
scale_a,
|
||||
scale_b,
|
||||
bias,
|
||||
out.scalar_type(),
|
||||
use_fast_accum,
|
||||
out);
|
||||
}
|
||||
#else
|
||||
// For ROCm, match behavior of f8f8bf16_rowwise type checking, for unit test purposes.
|
||||
Tensor b = mat2;
|
||||
if (_scaled_mm_is_fnuz()) {
|
||||
TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fnuz);
|
||||
TORCH_CHECK_VALUE(b.dtype() == at::kFloat8_e4m3fnuz,
|
||||
"Expected b.dtype() == at::kFloat8_e4m3fnuz, got: ", b.dtype());
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fn);
|
||||
TORCH_CHECK_VALUE(b.dtype() == at::kFloat8_e4m3fn,
|
||||
"Expected b.dtype() == at::kFloat8_e4m3fn, got: ", b.dtype());
|
||||
}
|
||||
// Until more than bf16 is supported.
|
||||
TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16,
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16,
|
||||
"hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type());
|
||||
#endif
|
||||
}
|
||||
else if (scaling_choice_a == ScalingType::BlockWise1x32 && scaling_choice_b == ScalingType::BlockWise1x32) {
|
||||
#ifdef USE_ROCM
|
||||
#if ROCM_VERSION >= 70000
|
||||
TORCH_CHECK(at::detail::getCUDAHooks().isGPUArch({"gfx950"}),
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}),
|
||||
"Block-wise scaling for Float8_e8m0fnu is only supported on gfx950");
|
||||
|
||||
int packed_factor = 1;
|
||||
@ -1414,163 +1618,20 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
// effectively packing two elements into one byte.
|
||||
packed_factor = 2;
|
||||
}
|
||||
TORCH_CHECK(mat1.size(0) % 16 == 0 && (mat1.size(1) * packed_factor) % 128 == 0 &&
|
||||
TORCH_CHECK_VALUE(mat1.size(0) % 16 == 0 && (mat1.size(1) * packed_factor) % 128 == 0 &&
|
||||
mat2.size(1) % 16 == 0,
|
||||
"M, N must be multiples of 16 and K must be multiple of 128 for block-wise scaling");
|
||||
|
||||
TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16 ||
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
|
||||
out.scalar_type() == ScalarType::Half,
|
||||
"Block-wise scaling only supports BFloat16 or Half output types");
|
||||
#else
|
||||
TORCH_CHECK(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result, scaling_choice_a, scaling_choice_b);
|
||||
const auto out_dtype_ = args.result->scalar_type();
|
||||
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
|
||||
|
||||
#ifdef USE_ROCM
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B) \
|
||||
if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
|
||||
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
} \
|
||||
else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
|
||||
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
} \
|
||||
else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) { \
|
||||
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e4m3fn, at::Float8_e5m2, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
} \
|
||||
else if (mat1.scalar_type() == ScalarType::Float8_e5m2) { \
|
||||
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e5m2, at::Float8_e4m3fn, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e5m2, at::Float8_e5m2, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
}
|
||||
AT_DISPATCH_V2(out_dtype_, "_tunable_scaled_gemm", AT_WRAP([&] {
|
||||
bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
|
||||
bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
|
||||
at::cuda::tunable::ScaledGemmParams<scalar_t> params;
|
||||
params.transa = args.transa;
|
||||
params.transb = args.transb;
|
||||
params.m = args.m;
|
||||
params.n = args.n;
|
||||
params.k = args.k;
|
||||
params.a = args.mata->data_ptr();
|
||||
params.a_scale_ptr = args.scale_mata_ptr;
|
||||
params.a_scale_dtype = args.scale_mata_dtype.value();
|
||||
params.lda = args.lda;
|
||||
params.a_dtype = args.mata->scalar_type();
|
||||
params.a_scale_dtype = args.scale_mata_dtype.value();
|
||||
params.a_scaling_type = args.scaling_mata_type.value();
|
||||
params.b = args.matb->data_ptr();
|
||||
params.b_scale_ptr = args.scale_matb_ptr;
|
||||
params.b_scale_dtype = args.scale_matb_dtype.value();
|
||||
params.ldb = args.ldb;
|
||||
params.b_dtype = args.matb->scalar_type();
|
||||
params.b_scale_dtype = args.scale_matb_dtype.value();
|
||||
params.b_scaling_type = args.scaling_matb_type.value();
|
||||
params.bias_ptr = bias ? bias->data_ptr(): nullptr;
|
||||
params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_;
|
||||
params.c = args.result->data_ptr();
|
||||
params.c_scale_ptr = args.scale_result_ptr;
|
||||
params.ldc = args.result_ld;
|
||||
params.c_dtype = out_dtype_;
|
||||
params.use_fast_accum = use_fast_accum;
|
||||
if (transa_ && transb_) {
|
||||
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
|
||||
}
|
||||
else if (transa_ && !transb_) {
|
||||
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N)
|
||||
}
|
||||
else if (!transa_ && transb_) {
|
||||
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T)
|
||||
}
|
||||
else if (!transa_ && !transb_) {
|
||||
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N)
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK(false, "unreachable");
|
||||
}
|
||||
}),
|
||||
kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
|
||||
#undef TUNABLE_DISPATCH
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
at::cuda::blas::scaled_gemm(
|
||||
args.transa,
|
||||
args.transb,
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
args.mata->data_ptr(),
|
||||
args.scale_mata_ptr,
|
||||
args.lda,
|
||||
args.mata->scalar_type(),
|
||||
args.scale_mata_dtype.value(),
|
||||
args.scaling_mata_type.value(),
|
||||
args.matb->data_ptr(),
|
||||
args.scale_matb_ptr,
|
||||
args.ldb,
|
||||
args.matb->scalar_type(),
|
||||
args.scale_matb_dtype.value(),
|
||||
args.scaling_matb_type.value(),
|
||||
bias ? bias->data_ptr(): nullptr,
|
||||
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
|
||||
args.result->data_ptr(),
|
||||
args.scale_result_ptr,
|
||||
args.result_ld,
|
||||
out_dtype_,
|
||||
use_fast_accum);
|
||||
}
|
||||
|
||||
return out;
|
||||
return _scaled_gemm(mat1, mat2, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -1910,159 +1971,6 @@ std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 8>
|
||||
{ "nvfp4_nvfp4_single_scale", check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE },
|
||||
{ "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
|
||||
|
||||
Tensor&
|
||||
_cutlass_scaled_gemm(
|
||||
const Tensor& mat1, const Tensor& mat2,
|
||||
const Tensor& scale_a, const Tensor& scale_b,
|
||||
const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
|
||||
const std::optional<Tensor>& bias,
|
||||
const bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b);
|
||||
const auto out_dtype_ = args.result->scalar_type();
|
||||
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
|
||||
|
||||
#ifdef USE_ROCM
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B) \
|
||||
if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
|
||||
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
} \
|
||||
else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
|
||||
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
} \
|
||||
else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) { \
|
||||
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e4m3fn, at::Float8_e5m2, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
} \
|
||||
else if (mat1.scalar_type() == ScalarType::Float8_e5m2) { \
|
||||
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e5m2, at::Float8_e4m3fn, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
|
||||
static at::cuda::tunable::ScaledGemmTunableOp< \
|
||||
at::Float8_e5m2, at::Float8_e5m2, scalar_t, \
|
||||
BLASOP_A, BLASOP_B> scaledgemm{}; \
|
||||
scaledgemm(¶ms); \
|
||||
} \
|
||||
}
|
||||
AT_DISPATCH_V2(out_dtype_, "_tunable_scaled_gemm", AT_WRAP([&] {
|
||||
bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
|
||||
bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
|
||||
at::cuda::tunable::ScaledGemmParams<scalar_t> params;
|
||||
params.transa = args.transa;
|
||||
params.transb = args.transb;
|
||||
params.m = args.m;
|
||||
params.n = args.n;
|
||||
params.k = args.k;
|
||||
params.a = args.mata->data_ptr();
|
||||
params.a_scale_ptr = args.scale_mata_ptr;
|
||||
params.a_scale_dtype = args.scale_mata_dtype.value();
|
||||
params.lda = args.lda;
|
||||
params.a_dtype = args.mata->scalar_type();
|
||||
params.a_scale_dtype = args.scale_mata_dtype.value();
|
||||
params.a_scaling_type = args.scaling_mata_type.value();
|
||||
params.b = args.matb->data_ptr();
|
||||
params.b_scale_ptr = args.scale_matb_ptr;
|
||||
params.b_scale_dtype = args.scale_matb_dtype.value();
|
||||
params.ldb = args.ldb;
|
||||
params.b_dtype = args.matb->scalar_type();
|
||||
params.b_scale_dtype = args.scale_matb_dtype.value();
|
||||
params.b_scaling_type = args.scaling_matb_type.value();
|
||||
params.bias_ptr = bias ? bias->data_ptr(): nullptr;
|
||||
params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_;
|
||||
params.c = args.result->data_ptr();
|
||||
params.c_scale_ptr = args.scale_result_ptr;
|
||||
params.ldc = args.result_ld;
|
||||
params.c_dtype = out_dtype_;
|
||||
params.use_fast_accum = use_fast_accum;
|
||||
if (transa_ && transb_) {
|
||||
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
|
||||
}
|
||||
else if (transa_ && !transb_) {
|
||||
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N)
|
||||
}
|
||||
else if (!transa_ && transb_) {
|
||||
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T)
|
||||
}
|
||||
else if (!transa_ && !transb_) {
|
||||
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N)
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK(false, "unreachable");
|
||||
}
|
||||
}),
|
||||
kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
|
||||
#undef TUNABLE_DISPATCH
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
at::cuda::blas::scaled_gemm(
|
||||
args.transa,
|
||||
args.transb,
|
||||
args.m,
|
||||
args.n,
|
||||
args.k,
|
||||
args.mata->data_ptr(),
|
||||
args.scale_mata_ptr,
|
||||
args.lda,
|
||||
args.mata->scalar_type(),
|
||||
args.scale_mata_dtype.value(),
|
||||
args.scaling_mata_type.value(),
|
||||
args.matb->data_ptr(),
|
||||
args.scale_matb_ptr,
|
||||
args.ldb,
|
||||
args.matb->scalar_type(),
|
||||
args.scale_matb_dtype.value(),
|
||||
args.scaling_matb_type.value(),
|
||||
bias ? bias->data_ptr(): nullptr,
|
||||
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
|
||||
args.result->data_ptr(),
|
||||
args.scale_result_ptr,
|
||||
args.result_ld,
|
||||
out_dtype_,
|
||||
use_fast_accum);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
Tensor&
|
||||
_scaled_tensorwise_tensorwise(
|
||||
const Tensor& mat_a, const Tensor& mat_b,
|
||||
@ -2082,7 +1990,7 @@ _scaled_tensorwise_tensorwise(
|
||||
auto scaling_choice_a = ScalingType::TensorWise;
|
||||
auto scaling_choice_b = ScalingType::TensorWise;
|
||||
|
||||
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
}
|
||||
@ -2118,7 +2026,7 @@ _scaled_rowwise_rowwise(
|
||||
if (((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)
|
||||
// cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales
|
||||
|| (dprops->major == 10 && (scale_a.sizes().size() || scale_b.sizes().size())))) {
|
||||
TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
|
||||
TORCH_CHECK_VALUE(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
|
||||
at::cuda::detail::f8f8bf16_rowwise(
|
||||
mat_a,
|
||||
mat_b,
|
||||
@ -2144,11 +2052,38 @@ _scaled_rowwise_rowwise(
|
||||
"hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type());
|
||||
#endif
|
||||
|
||||
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
// Check the shapes & sizes of scales for deepseek-style (1x128, 128x128) scaling.
// Wraps check_size_stride for easier integration; correctly handles cases where a dimension of the scale == 1
// and the corresponding stride becomes effectively meaningless.
void _check_deepseek_scale_stride(const Tensor& scale, const Tensor& t, const ScalingType scale_type) {
|
||||
if (scale_type == ScalingType::BlockWise1x128) {
|
||||
TORCH_CHECK_VALUE(check_size_stride(scale, 0, t.size(0), 1),
|
||||
"at dim=0 scale should have ", t.size(0), "elements and stride(0) ", 1, "if ", t.size(0), " > 1 - Got: ",
|
||||
"shape=", scale.sizes(), ", stride=", scale.strides());
|
||||
auto expected_size = ceil_div<int64_t>(t.size(1), 128);
|
||||
TORCH_CHECK_VALUE(check_size_stride(scale, 1, expected_size, t.size(0)),
|
||||
"at dim=1 scale should have ", expected_size, "elements and stride ", t.size(0), "if ", expected_size, " > 1 - Got: ",
|
||||
"shape=", scale.sizes(), ", stride=", scale.strides());
|
||||
} else if (scale_type == ScalingType::BlockWise128x128) {
|
||||
TORCH_CHECK_VALUE(check_size_stride(
|
||||
scale,
|
||||
0,
|
||||
ceil_div<int64_t>(t.size(0), 128),
|
||||
ceil_div<int64_t>(t.size(1), 128)),
|
||||
"at dim=0 scale should have ", ceil_div<int64_t>(t.size(0), 128), "elements and stride(0) ", ceil_div<int64_t>(t.size(1), 128), "if ", ceil_div<int64_t>(t.size(0), 128), " > 1 - Got: ",
|
||||
"shape=", scale.sizes(), ", stride=", scale.strides());
|
||||
TORCH_CHECK(check_size_stride(
|
||||
scale, 1, ceil_div<int64_t>(t.size(1), 128), 1),
|
||||
"at dim=1 scale should have ", ceil_div<int64_t>(t.size(1), 128), "elements and stride(1) ", 1, "if ", ceil_div<int64_t>(t.size(1), 128), " > 1 - Got: ",
|
||||
"shape=", scale.sizes(), ", stride=", scale.strides());
|
||||
}
|
||||
}
|
||||
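// Worked example (illustrative, not part of this change): for an fp8 operand t of shape
// (M, K) = (256, 512), BlockWise1x128 expects a float32 scale of shape
// (256, ceil(512/128)) = (256, 4) with strides (1, 256), while BlockWise128x128 expects
// shape (ceil(256/128), ceil(512/128)) = (2, 4) with strides (4, 1). The checks only
// relax these stride requirements along dimensions whose size is 1, and the call sites
// pass scale_b.t() / mat_b.t() so the same row-oriented check also covers the column-major B.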
|
||||
Tensor&
|
||||
_scaled_block1x128_block1x128(
|
||||
const Tensor& mat_a, const Tensor& mat_b,
|
||||
@ -2166,15 +2101,14 @@ _scaled_block1x128_block1x128(
|
||||
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
|
||||
|
||||
TORCH_CHECK(scale_a.stride(0) == 1, "expected scale_a.stride(0) to be 1, but got ", scale_a.stride(0));
|
||||
TORCH_CHECK(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1));
|
||||
TORCH_CHECK(scale_b.stride(0) == scale_b.size(1),
|
||||
"expected scale_b.stride(0) to be ", scale_b.size(1), ", but got ", scale_b.size(1));
|
||||
|
||||
auto scaling_choice_a = ScalingType::BlockWise1x128;
|
||||
auto scaling_choice_b = ScalingType::BlockWise1x128;
|
||||
|
||||
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
// Check scale strides (including stride=1 small cases)
|
||||
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
|
||||
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
|
||||
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
}
|
||||
@ -2189,6 +2123,8 @@ _scaled_block128x128_block1x128(
|
||||
Tensor& out) {
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32, shape K//128
|
||||
std::cout << "mat_b: " << mat_b.dim() << ", " << mat_b.sizes() << ", " << mat_b.strides() << std::endl;
|
||||
std::cout << "scale_b: " << scale_b.dim() << ", " << scale_b.sizes() << ", " << scale_b.strides() << std::endl;
|
||||
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.sizes()[0] == ceil_div<int64_t>(mat_a.sizes()[0], 128) && scale_a.sizes()[1] == ceil_div<int64_t>(mat_a.sizes()[1], 128) && scale_a.scalar_type() == kFloat,
|
||||
@ -2196,15 +2132,14 @@ _scaled_block128x128_block1x128(
|
||||
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
|
||||
|
||||
TORCH_CHECK_VALUE(scale_a.stride(1) == 1, "expected scale_a.stride(1) to be 1, but got ", scale_a.stride(1));
|
||||
TORCH_CHECK_VALUE(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1));
|
||||
TORCH_CHECK_VALUE(scale_b.stride(0) == scale_b.size(1),
|
||||
"expected scale_b.stride(0) to be ", scale_b.size(1), ", but got ", scale_b.stride(0));
|
||||
|
||||
auto scaling_choice_a = ScalingType::BlockWise128x128;
|
||||
auto scaling_choice_b = ScalingType::BlockWise1x128;
|
||||
|
||||
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
// Check scale strides (including stride=1 small cases)
|
||||
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
|
||||
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
|
||||
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
}
|
||||
@ -2226,15 +2161,14 @@ _scaled_block1x128_block128x128(
|
||||
TORCH_CHECK_VALUE(scale_b.sizes()[0] == mat_b.sizes()[0] / 128 && scale_b.sizes()[1] == mat_b.sizes()[1] / 128 && scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", mat_b.sizes()[0] / 128, " x ", mat_b.sizes()[1] / 128, " Float elements, got ", scale_b.sizes())
|
||||
|
||||
TORCH_CHECK_VALUE(scale_a.stride(0) == 1, "expected scale_a.stride(0) to be 1, but got ", scale_a.stride(0));
|
||||
TORCH_CHECK_VALUE(scale_b.stride(0) == 1, "expected scale_b.stride(0) to be 1, but got ", scale_b.stride(0));
|
||||
TORCH_CHECK_VALUE(scale_b.stride(1) == scale_b.size(0),
|
||||
"expected scale_b.stride(1) to be ", scale_b.size(0), ", but got ", scale_b.stride(1));
|
||||
|
||||
auto scaling_choice_a = ScalingType::BlockWise1x128;
|
||||
auto scaling_choice_b = ScalingType::BlockWise128x128;
|
||||
|
||||
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
// Check scale strides (including stride=1 small cases)
|
||||
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
|
||||
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
|
||||
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
}
|
||||
@ -2288,7 +2222,7 @@ _scaled_mxfp8_mxfp8(
|
||||
#endif
|
||||
#endif
|
||||
|
||||
return _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
|
||||
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
|
||||
}
|
||||
|
||||
Tensor&
|
||||
@ -2325,7 +2259,7 @@ _scaled_nvfp4_nvfp4(
|
||||
|
||||
auto scaling_choice_a = ScalingType::BlockWise1x16;
|
||||
auto scaling_choice_b = ScalingType::BlockWise1x16;
|
||||
return _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
|
||||
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
|
||||
}
|
||||
|
||||
|
||||
@ -2574,7 +2508,9 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const SwizzleType& swizzle_a,
|
||||
const Tensor& scale_b,
|
||||
const SwizzleType& swizzle_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
Tensor& out) {
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
@ -2585,6 +2521,16 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
TORCH_CHECK_VALUE(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases");
|
||||
TORCH_CHECK_VALUE(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs requires offsets");
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm");
|
||||
// MXFP8 expects float8_e8m0fnu scales.
|
||||
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu,
|
||||
"For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors.");
|
||||
#ifdef USE_ROCM
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE && swizzle_b == SwizzleType::NO_SWIZZLE,
|
||||
"For ROCM MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_NONE");
|
||||
#else
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4 && swizzle_b == SwizzleType::SWIZZLE_32_4_4,
|
||||
"For CUDA MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_32_4_4");
|
||||
#endif
|
||||
|
||||
#if defined(USE_FBGEMM_GENAI) and !defined(USE_ROCM)
|
||||
fbgemm_gpu::mx8mx8bf16_grouped_mm(
|
||||
@ -2669,6 +2615,9 @@ _f8_f8_bf16_rowwise_grouped_mm(
|
||||
const std::optional<Tensor>& bias,
|
||||
bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
// FP8 per-tensor and per-row scaling expect fp32 scales.
|
||||
TORCH_CHECK_VALUE(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
|
||||
"For grouped FP8 rowwise, both scales must be float32 tensors");
|
||||
#ifndef USE_ROCM
|
||||
return _f8_f8_bf16_rowwise_grouped_mm_cuda(
|
||||
mat_a,
|
||||
@ -2768,11 +2717,15 @@ _scaled_grouped_mm_cuda(
|
||||
#endif
|
||||
|
||||
if (is_mx8mx8bf16) {
|
||||
// Note: Passing implied SwizzleType here, correctness of scale previously checked
|
||||
// in `check_scale` call
|
||||
return _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
SwizzleType::SWIZZLE_32_4_4,
|
||||
scale_b,
|
||||
SwizzleType::SWIZZLE_32_4_4,
|
||||
offs.value(),
|
||||
out);
|
||||
}
|
||||
@ -2789,6 +2742,140 @@ _scaled_grouped_mm_cuda(
|
||||
out);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
|
||||
{ "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
|
||||
{ "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
Tensor
|
||||
_scaled_grouped_mm_cuda_v2(
|
||||
const Tensor& mat_a, const Tensor& mat_b,
|
||||
ArrayRef<Tensor> scale_a,
|
||||
IntArrayRef scale_recipe_a,
|
||||
IntArrayRef swizzle_a,
|
||||
ArrayRef<Tensor> scale_b,
|
||||
IntArrayRef scale_recipe_b,
|
||||
IntArrayRef swizzle_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
const std::optional<Tensor>& bias,
|
||||
const std::optional<c10::ScalarType> out_dtype,
|
||||
IntArrayRef contraction_dim,
|
||||
bool use_fast_accum) {
|
||||
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
|
||||
TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
|
||||
|
||||
TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
|
||||
TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
|
||||
TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
|
||||
TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
const bool b_is_2d = mat_b.dim() == 2;
|
||||
|
||||
// NOTE(slayton): For sub-1B formats want contraction_dim argument?
|
||||
if (!a_is_2d || !b_is_2d) {
|
||||
if (contraction_dim.size() > 0) {
|
||||
const int dim_a = contraction_dim[0], dim_b = contraction_dim[1];
TORCH_CHECK_VALUE(mat_a.size(dim_a) == mat_b.size(dim_b),
"Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ",
mat_b.size(dim_b));
// Note: only (-1, -2) is currently supported
TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Currently contraction dims must be (-1, -2) only");
} else {
|
||||
TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
|
||||
}
|
||||
}
|
||||
TORCH_CHECK_VALUE(
|
||||
mat_a.size(-1) % 16 == 0,
|
||||
"Expected trailing dimension of mat_a to be divisible by 16 ",
|
||||
"but got mat1 shape: (",
|
||||
mat_a.sizes(),
|
||||
").");
|
||||
TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
|
||||
"Expected mat_b shape to be divisible by 16 ",
|
||||
"but got mat_b shape: (",
|
||||
mat_b.sizes(),
|
||||
").");
|
||||
|
||||
TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
|
||||
TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
|
||||
|
||||
// NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
|
||||
// for rowwise, no offsets implies 3d-3d and is handled by lower-level
|
||||
// routines
|
||||
if (offs.has_value()) {
|
||||
TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
|
||||
TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
|
||||
}
|
||||
|
||||
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
|
||||
TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
|
||||
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
|
||||
// Conversion of implicitly-defined enums to explicit
|
||||
auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a);
|
||||
auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a);
|
||||
auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b);
|
||||
auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b);
|
||||
|
||||
// at this point we can start working out what we want to be doing
|
||||
// Try to do as few steps as possible.
|
||||
// NOTE: support is deliberately sparse, can explicitly enumerate all combinations allowed.
|
||||
// Do this via a list of defined (name, acceptance, concrete_impl) tuples.
|
||||
ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;
|
||||
for (const auto& fn_entry : scale_grouped_kernel_dispatch) {
|
||||
const auto [name, accept_fn, scaled_gemm_impl] = fn_entry;
|
||||
bool ok = accept_fn(mat_a.scalar_type(),
|
||||
scale_recipe_a_enum,
|
||||
scale_a,
|
||||
mat_b.scalar_type(),
|
||||
scale_recipe_b_enum,
|
||||
scale_b);
|
||||
if (ok) {
|
||||
gemm_impl = scaled_gemm_impl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK_VALUE(gemm_impl != ScaledGemmImplementation::NONE,
|
||||
"No gemm implementation was found");
|
||||
|
||||
switch (gemm_impl) {
|
||||
case ScaledGemmImplementation::ROWWISE_ROWWISE: {
|
||||
const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
|
||||
_check_scales_fp8_rowwise(mat_a, scale_a[0], 0 /* dim */ , 0 /* arg_idx */, scale_multiplier);
|
||||
_check_scales_fp8_rowwise(mat_b, scale_b[0], 1 /* dim */ , 1 /* arg_idx */, scale_multiplier);
|
||||
return _f8_f8_bf16_rowwise_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a[0],
|
||||
scale_b[0],
|
||||
offs,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
}
|
||||
case ScaledGemmImplementation::MXFP8_MXFP8: {
|
||||
_check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
|
||||
_check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
|
||||
return _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a[0],
|
||||
swizzle_a_enum[0],
|
||||
scale_b[0],
|
||||
swizzle_b_enum[0],
|
||||
offs.value(),
|
||||
out);
|
||||
}
|
||||
default:
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");
|
||||
}
|
||||
}
|
||||
|
||||
Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
|
@ -856,9 +856,13 @@ struct type_specialized_kernel_launcher {
|
||||
out_calc_t output_offset_calculator,
|
||||
loader_t loader,
|
||||
storer_t storer) {
|
||||
if (ret_t == rt_binary_specializations[arg_index][0] &&
|
||||
arg0_t == rt_binary_specializations[arg_index][1] &&
|
||||
arg1_t == rt_binary_specializations[arg_index][2])
|
||||
constexpr ScalarType sret_t = rt_binary_specializations[arg_index][0];
|
||||
constexpr ScalarType sarg0_t = rt_binary_specializations[arg_index][1];
|
||||
constexpr ScalarType sarg1_t = rt_binary_specializations[arg_index][2];
|
||||
if (ret_t == sret_t && arg0_t == sarg0_t && arg1_t == sarg1_t) {
|
||||
using cret_t = c10::impl::ScalarTypeToCPPTypeT<sret_t>;
|
||||
using carg0_t = c10::impl::ScalarTypeToCPPTypeT<sarg0_t>;
|
||||
using carg1_t = c10::impl::ScalarTypeToCPPTypeT<sarg1_t>;
|
||||
launch_vectorized_templated_kernel<
|
||||
func_t,
|
||||
array_t,
|
||||
@ -866,12 +870,9 @@ struct type_specialized_kernel_launcher {
|
||||
out_calc_t,
|
||||
loader_t,
|
||||
storer_t,
|
||||
decltype(c10::impl::ScalarTypeToCPPType<
|
||||
rt_binary_specializations[arg_index][0]>::t),
|
||||
decltype(c10::impl::ScalarTypeToCPPType<
|
||||
rt_binary_specializations[arg_index][1]>::t),
|
||||
decltype(c10::impl::ScalarTypeToCPPType<
|
||||
rt_binary_specializations[arg_index][2]>::t)>(
|
||||
cret_t,
|
||||
carg0_t,
|
||||
carg1_t>(
|
||||
numel,
|
||||
f,
|
||||
data,
|
||||
@ -879,6 +880,7 @@ struct type_specialized_kernel_launcher {
|
||||
output_offset_calculator,
|
||||
loader,
|
||||
storer);
|
||||
}
|
||||
}
|
||||
};
|
||||
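// Note (illustrative, not part of this change): c10::impl::ScalarTypeToCPPTypeT<S> names
// the same type as decltype(c10::impl::ScalarTypeToCPPType<S>::t), e.g.
//   static_assert(std::is_same_v<
//       c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::Half>,
//       decltype(c10::impl::ScalarTypeToCPPType<c10::ScalarType::Half>::t)>);
// so hoisting the specialization triple into constexpr locals and using the _T alias is
// purely a readability change; the instantiated kernel types are unchanged.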
|
||||
|
@ -655,8 +655,14 @@ struct ReduceOp {
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Intra-warp reduction. On CUDA, iterate offsets in decreasing order for better
// numerics (matching Triton, etc.); TODO: apply the same order on AMD.
#ifdef USE_ROCM
|
||||
for (int offset = 1; offset < dim_x; offset <<= 1) {
|
||||
#else
|
||||
for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
|
||||
#endif
|
||||
#pragma unroll
|
||||
for (int i = 0; i < output_vec_size; i++) {
|
||||
arg_t other = ops.warp_shfl_down(value[i], offset);
|
||||
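// Note (illustrative, not part of this change): the offset order changes the pairing of
// the floating-point tree reduction, not its exact-arithmetic result. With four lanes
// holding v0..v3, decreasing offsets (2, then 1) compute (v0+v2) + (v1+v3), while
// increasing offsets (1, then 2) compute (v0+v1) + (v2+v3); the different rounding of
// the partial sums is what "better numerics" in the comment above refers to.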
|
@ -77,8 +77,8 @@ struct nansum_functor_complex {
|
||||
#if AT_USE_JITERATOR()
|
||||
void operator()(TensorIterator& iter) {
|
||||
std::string func = jiterator_stringify(
|
||||
arg_t combine(arg_t a, scalar_t b) {
|
||||
return a + (std::isnan(b) ? arg_t{0.} : arg_t{b});
|
||||
arg_t combine(arg_t a, arg_t b) {
|
||||
return a + (std::isnan(b) ? arg_t{0.} : b);
|
||||
}
|
||||
);
|
||||
jitted_gpu_reduce_kernel<nansum_name, scalar_t, scalar_t>(
|
||||
|
@ -464,6 +464,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
|
||||
}
|
||||
#endif
|
||||
int32_t trailingSize;
|
||||
int nDimsLocal = nDims;
|
||||
TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> kernelOutputParam;
|
||||
if (isInOutAligned) {
|
||||
// in this case we can and should flatten the tensors after the cat dim
|
||||
@ -477,7 +478,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
|
||||
// and divide all strides except last by elems_per_vec (last stride is 1 always)
|
||||
// for input, we will fix up the sizes and strides in the kernel directly
|
||||
kernelOutputParam = outputParam;
|
||||
nDims = dimension + 1;
|
||||
nDimsLocal = dimension + 1;
|
||||
constexpr auto elems_per_vec = alignment / sizeof(scalar_t);
|
||||
auto out_size = dimension == 0 ? out.numel() : kernelOutputParam.tensorStride[dimension-1];
|
||||
kernelOutputParam.tensorSize[dimension] = out_size / elems_per_vec;
|
||||
@ -494,7 +495,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
|
||||
case 0:
|
||||
break;
|
||||
case 1:
|
||||
cat_dim = nDims - cat_dim;
|
||||
cat_dim = nDimsLocal - cat_dim;
|
||||
break;
|
||||
default:
|
||||
cat_dim--;
|
||||
@ -525,7 +526,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
|
||||
data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
|
||||
}\
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
switch (nDims) {
|
||||
switch (nDimsLocal) {
|
||||
case 1:
|
||||
HANDLE_CASE(1);
|
||||
break;
|
||||
|
@ -21,9 +21,15 @@ namespace {
|
||||
struct offset_t {
|
||||
int stride;
|
||||
int begin;
|
||||
__device__ int operator[](int i) {
|
||||
__device__ int operator[](int i) const {
|
||||
return stride * (begin + i);
|
||||
}
|
||||
#if CCCL_VERSION >= 3001000
|
||||
__device__ offset_t& operator+=(int i) {
|
||||
begin += i;
|
||||
return *this;
|
||||
}
|
||||
#endif
|
||||
};
|
||||
// Segmented sort by full sort algorithm:.
|
||||
// Say we are sorting a (2, 3) tensor. We have in flattened form:
|
||||
|
@ -127,6 +127,29 @@ __global__ void upsample_bilinear2d_nhwc_out_frame(
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
// Helper function to compute output pixel range that can contribute to input pixel
template <typename accscalar_t>
__device__ __forceinline__ void compute_output_range(
int input_pos,
accscalar_t scale,
int output_size,
bool align_corners,
int& min_output,
int& max_output) {
accscalar_t lo, hi;
if (align_corners) {
lo = static_cast<accscalar_t>(input_pos - 1) / scale;
hi = static_cast<accscalar_t>(input_pos + 1) / scale;
} else {
lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5);
hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5);
}
min_output = max(0, static_cast<int>(ceil(lo)));
max_output = min(output_size - 1, static_cast<int>(floor(hi)));
}
#endif
|
||||
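// Worked example (illustrative, not part of this change): upsampling height 4 -> 8 with
// align_corners=false gives scale = rheight = 4/8 = 0.5. For input row input_pos = 1:
//   lo = (1 - 0.5)/0.5 - 0.5 = 0.5,  hi = (1 + 1.5)/0.5 - 0.5 = 4.5,
// so min_output = ceil(0.5) = 1 and max_output = floor(4.5) = 4, i.e. only output rows
// 1..4 can interpolate from input row 1. Iterating over this bounded range is what lets
// the ROCm backward kernel below accumulate each input gradient without atomics.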
// Backward (adjoint) operation 1 <- 2 (accumulates)
|
||||
template <typename scalar_t, typename accscalar_t>
|
||||
C10_LAUNCH_BOUNDS_1(1024)
|
||||
@ -141,8 +164,74 @@ __global__ void upsample_bilinear2d_backward_out_frame(
|
||||
const bool align_corners,
|
||||
scalar_t* __restrict__ idata,
|
||||
const scalar_t* __restrict__ odata) {
|
||||
const size_t o_numel = nc * width2 * height2;
|
||||
// In C++, integer multiplication, like in standard arithmetic, is generally commutative.
|
||||
const size_t i_numel = nc * width1 * height1;
|
||||
#ifdef USE_ROCM
|
||||
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel;
|
||||
index += blockDim.x * gridDim.x) {
|
||||
// Decode input pixel coordinates
|
||||
size_t index_temp = index;
|
||||
const int w1 = index_temp % width1;
|
||||
index_temp /= width1;
|
||||
const int h1 = index_temp % height1;
|
||||
const size_t nc_idx = index_temp / height1;
|
||||
|
||||
accscalar_t grad_sum = 0;
|
||||
|
||||
// Find range of output pixels that could interpolate from this input pixel
|
||||
int h2_min, h2_max, w2_min, w2_max;
|
||||
compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max);
|
||||
compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max);
|
||||
|
||||
// Iterate over potential output pixels
|
||||
for (int h2 = h2_min; h2 <= h2_max; h2++) {
|
||||
for (int w2 = w2_min; w2 <= w2_max; w2++) {
|
||||
// Compute source coordinates for this output pixel
|
||||
const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
|
||||
rheight, h2, align_corners, /*cubic=*/false);
|
||||
const int h1_base = (int)h1r;
|
||||
const int h1p = (h1_base < height1 - 1) ? 1 : 0;
|
||||
const accscalar_t h1lambda = h1r - h1_base;
|
||||
const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;
|
||||
|
||||
const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
|
||||
rwidth, w2, align_corners, /*cubic=*/false);
|
||||
const int w1_base = (int)w1r;
|
||||
const int w1p = (w1_base < width1 - 1) ? 1 : 0;
|
||||
const accscalar_t w1lambda = w1r - w1_base;
|
||||
const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
|
||||
|
||||
// Check if our input pixel participates in this interpolation and accumulate all weights
|
||||
// At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse
|
||||
// to the same pixel, so we need to accumulate weights from all matching positions
|
||||
accscalar_t weight = 0;
|
||||
|
||||
// Check all four interpolation positions and accumulate weights
|
||||
if (h1 == h1_base && w1 == w1_base) {
|
||||
weight += h0lambda * w0lambda; // top-left
|
||||
}
|
||||
if (h1 == h1_base && w1 == w1_base + w1p) {
|
||||
weight += h0lambda * w1lambda; // top-right (may be same as top-left if w1p=0)
|
||||
}
|
||||
if (h1 == h1_base + h1p && w1 == w1_base) {
|
||||
weight += h1lambda * w0lambda; // bottom-left (may be same as top-left if h1p=0)
|
||||
}
|
||||
if (h1 == h1_base + h1p && w1 == w1_base + w1p) {
|
||||
weight += h1lambda * w1lambda; // bottom-right (may collapse to other positions)
|
||||
}
|
||||
|
||||
if (weight > 0) {
|
||||
const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2;
|
||||
grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write accumulated gradient (no atomics needed)
|
||||
idata[index] = static_cast<scalar_t>(grad_sum);
|
||||
}
|
||||
#else
|
||||
const size_t o_numel = nc * width2 * height2;
|
||||
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel;
|
||||
index += blockDim.x * gridDim.x) {
|
||||
size_t index_temp = index;
|
||||
@ -191,6 +280,7 @@ __global__ void upsample_bilinear2d_backward_out_frame(
|
||||
static_cast<scalar_t>(h1lambda * w1lambda * d2val),
|
||||
true);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename accscalar_t>
@@ -387,7 +477,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
  // threads are not covering the whole input tensor.
  grad_input.zero_();

  const size_t num_kernels = nbatch * channels * output_height * output_width;
  const int num_threads = std::min(
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -397,6 +486,12 @@ static void upsample_bilinear2d_backward_out_cuda_template(
    return;
  }

#ifdef USE_ROCM
  constexpr bool use_input = true;
#else
  constexpr bool use_input = false;
#endif

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] {
@@ -414,6 +509,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
        const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
            input_width, output_width, align_corners, scales_w);

        const size_t num_kernels = nbatch * channels * output_height * output_width;

        upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t>
            <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>(
                input_height,
@@ -444,6 +541,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
        const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
            input_width, output_width, align_corners, scales_w);

        const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width);

        upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t>
            <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)),
               num_threads,

@@ -466,7 +466,11 @@ struct ReduceJitOp {

    __syncthreads();

#ifdef USE_ROCM
    for (int offset = 1; offset < dim_x; offset <<= 1) {
#else
    for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
#endif
#pragma unroll
      for (int i = 0; i < output_vec_size; i++) {
        arg_t other = reducer::warp_shfl_down(value[i], offset);

@@ -512,7 +512,7 @@ TORCH_IMPL_FUNC(threshold_backward_out_mps)
}

static MPSGraphTensor* normcdf(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
  // (1.0f + erf(x*SQRT1_2)) * 0.5f * x;
  // (1.0f + erf(x*SQRT1_2)) * 0.5f;
  auto dataType = [inputTensor dataType];
  const float SQRT1_2 = 0.707106781186547524400844362104849039f;
  MPSGraphTensor* sqrt1_2 = [mpsGraph constantWithScalar:SQRT1_2 shape:@[ @1 ] dataType:dataType];

@@ -54,6 +54,10 @@ Tensor dot_mps(const Tensor& self, const Tensor& other) {
  using namespace mps;
  using CachedGraph = MPSBinaryCachedGraph;

  if (self.numel() == 0 & other.numel() == 0) {
    return zeros({}, self.options());
  }

  dot_check(self, other);

  auto output = at::empty({}, self.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);

@@ -7183,6 +7183,12 @@
    CUDA: _scaled_grouped_mm_cuda
  tags: needs_exact_strides

- func: _scaled_grouped_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor
  variants: function
  dispatch:
    CUDA: _scaled_grouped_mm_cuda_v2
  tags: needs_exact_strides

- func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor
  variants: function
  dispatch:

@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,0
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
|
||||
|
||||
|
||||
visformer_small,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,0
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,6
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
|
||||
|
||||
|
||||
visformer_small,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,7
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,0
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
|
||||
|
||||
|
||||
visformer_small,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,0
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,0
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
|
||||
|
||||
|
||||
visformer_small,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,0
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,0
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
|
||||
|
||||
|
||||
visformer_small,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,0
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,0
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
|
||||
|
||||
|
||||
visformer_small,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,0
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,0
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
|
||||
|
||||
|
||||
visformer_small,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,0
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,0
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
|
||||
|
||||
|
||||
visformer_small,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,0
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,6
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
|
||||
|
||||
|
||||
visformer_small,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,7
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,0
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
|
||||
|
||||
|
||||
visformer_small,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,0
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,0
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
|
||||
|
||||
|
||||
visformer_small,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,0
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,0
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
|
||||
|
||||
|
||||
visformer_small,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,0
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,6
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
|
||||
|
||||
|
||||
visformer_small,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,7
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,0
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
|
||||
|
||||
|
||||
visformer_small,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,0
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,6
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
|
||||
|
||||
|
||||
visformer_small,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,7
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,0
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
|
||||
|
||||
|
||||
visformer_small,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,0
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,fail_accuracy,7
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,6
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
|
||||
|
||||
|
||||
visformer_small,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,fail_accuracy,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,7
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,0
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
|
||||
|
||||
|
||||
visformer_small,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,0
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,6
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
|
||||
|
||||
|
||||
visformer_small,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,7
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,0
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
|
||||
|
||||
|
||||
visformer_small,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,0
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,0
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
|
||||
|
||||
|
||||
visformer_small,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,0
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,6
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
|
||||
|
||||
|
||||
visformer_small,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,7
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,0
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
|
||||
|
||||
|
||||
visformer_small,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,0
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,6
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
|
||||
|
||||
|
||||
visformer_small,fail_accuracy,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,7
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,0
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
|
||||
|
||||
|
||||
visformer_small,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,0
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,6
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
|
||||
|
||||
|
||||
visformer_small,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,7
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,0
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,0
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,0
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,0
|
||||
|
||||
|
||||
visformer_small,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,pass,0
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,0
|
||||
|
|
@ -10,10 +10,18 @@ beit_base_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,fail_accuracy,7
|
||||
|
||||
|
||||
|
||||
deit_base_distilled_patch16_224,pass,7
|
||||
|
||||
|
||||
|
||||
deit_tiny_patch16_224.fb_in1k,pass,7
|
||||
|
||||
|
||||
|
||||
dm_nfnet_f0,pass,6
|
||||
|
||||
|
||||
@ -55,3 +63,11 @@ tf_efficientnet_b0,pass,6
|
||||
|
||||
|
||||
visformer_small,pass,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch14_dinov2.lvd142m,fail_accuracy,7
|
||||
|
||||
|
||||
|
||||
vit_base_patch16_siglip_256,pass,7
|
||||
|
|
@@ -1060,6 +1060,8 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
        frozen_model_iter_fn = export_nativert(model, example_inputs)
    elif args.torchscript_jit_trace:
        frozen_model_iter_fn = torchscript_jit_trace(model, example_inputs)
    elif args.aot_precompile:
        frozen_model_iter_fn = aot_precompile(model, example_inputs)
    else:
        if kwargs["hf_llm"]:
            # If it's an llm, we want to optimize model.forward, and use
@@ -1495,6 +1497,37 @@ def export(model, example_inputs):
    return opt_export


def aot_precompile(model, example_inputs):
    example_args, example_kwargs = _normalize_bench_inputs(example_inputs)

    with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
        save_path = f.name

    with fresh_cache(), torch._dynamo.config.patch("enable_aot_compile", True):
        compiled_fn = torch.compile(
            model,
            fullgraph=True,
            options={"guard_filter_fn": lambda guards: [False for _ in guards]},
        ).forward.aot_compile((example_args, example_kwargs))

        compiled_fn.save_compiled_function(save_path)

    torch._dynamo.reset()
    with open(save_path, "rb") as f:
        load_start_time = time.perf_counter()
        loaded_fn = torch.compiler.load_compiled_function(f)
        load_end_time = time.perf_counter()
        print(
            f"AOT Precompile loading time: {load_end_time - load_start_time} seconds"
        )

    def opt_aot_precompile(_, example_inputs, collect_outputs=False):
        example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
        return loaded_fn(model, *example_args, **example_kwargs)

    return opt_aot_precompile

def export_nativert(model, example_inputs):
|
||||
optimized = NativeRTCache.load(model, example_inputs)
|
||||
|
||||
@ -2274,6 +2307,7 @@ class BenchmarkRunner:
|
||||
or self.args.export_aot_inductor
|
||||
or self.args.export_nativert
|
||||
or self.args.torchscript_jit_trace
|
||||
or self.args.aot_precompile
|
||||
):
|
||||
# apply export on module directly
|
||||
# no need for n iterations
|
||||
@ -2729,6 +2763,7 @@ class BenchmarkRunner:
|
||||
self.args.export_aot_inductor
|
||||
or self.args.export_nativert
|
||||
or self.args.torchscript_jit_trace
|
||||
or self.args.aot_precompile
|
||||
):
|
||||
optimized_model_iter_fn = optimize_ctx
|
||||
else:
|
||||
@ -3505,6 +3540,11 @@ def parse_args(args=None):
|
||||
action="store_true",
|
||||
help="Measure pass rate with Export+AOTInductor",
|
||||
)
|
||||
group.add_argument(
|
||||
"--aot-precompile",
|
||||
action="store_true",
|
||||
help="Measure pass rate with AOT Precompile",
|
||||
)
|
||||
group.add_argument(
|
||||
"--export-nativert",
|
||||
action="store_true",
|
||||
@ -3935,6 +3975,10 @@ def run(runner, args, original_dir=None):
|
||||
optimize_ctx = export
|
||||
experiment = speedup_experiment
|
||||
output_filename = "export.csv"
|
||||
elif args.aot_precompile:
|
||||
optimize_ctx = aot_precompile
|
||||
experiment = speedup_experiment
|
||||
output_filename = "aot_precompile.csv"
|
||||
elif args.export_nativert:
|
||||
optimize_ctx = export_nativert
|
||||
experiment = speedup_experiment
|
||||
|
@ -271,8 +271,6 @@ class TimmRunner(BenchmarkRunner):
|
||||
memory_format=torch.channels_last if channels_last else None,
|
||||
)
|
||||
|
||||
self.num_classes = model.num_classes
|
||||
|
||||
data_config = resolve_data_config(
|
||||
vars(self._args) if timmversion >= "0.8.0" else self._args,
|
||||
model=model,
|
||||
@ -302,7 +300,6 @@ class TimmRunner(BenchmarkRunner):
|
||||
example_inputs = [
|
||||
example_inputs,
|
||||
]
|
||||
self.target = self._gen_target(batch_size, device)
|
||||
|
||||
self.loss = torch.nn.CrossEntropyLoss().to(device)
|
||||
|
||||
@ -370,11 +367,6 @@ class TimmRunner(BenchmarkRunner):
|
||||
tolerance = 1e-2
|
||||
return tolerance, cosine
|
||||
|
||||
def _gen_target(self, batch_size, device):
|
||||
return torch.empty((batch_size,) + (), device=device, dtype=torch.long).random_(
|
||||
self.num_classes
|
||||
)
|
||||
|
||||
def compute_loss(self, pred):
|
||||
# High loss values make gradient checking harder, as small changes in
|
||||
# accumulation order upsets accuracy checks.
|
||||
|
@ -1,6 +1,8 @@
|
||||
adv_inception_v3 128
|
||||
beit_base_patch16_224 128
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k 128
|
||||
deit_base_distilled_patch16_224 128
|
||||
deit_tiny_patch16_224.fb_in1k 128
|
||||
dm_nfnet_f0 128
|
||||
ghostnet_100 512
|
||||
inception_v3 128
|
||||
@ -12,3 +14,5 @@ repvgg_a2 128
|
||||
swin_base_patch4_window7_224 128
|
||||
tf_efficientnet_b0 128
|
||||
visformer_small 128
|
||||
vit_base_patch14_dinov2.lvd142m 128
|
||||
vit_base_patch16_siglip_256 128
|
@ -1,6 +1,8 @@
|
||||
adv_inception_v3,128
|
||||
beit_base_patch16_224,64
|
||||
convnextv2_nano.fcmae_ft_in22k_in1k,128
|
||||
deit_base_distilled_patch16_224,64
|
||||
deit_tiny_patch16_224.fb_in1k,128
|
||||
dm_nfnet_f0,128
|
||||
ghostnet_100,128
|
||||
inception_v3,128
|
||||
@ -12,3 +14,5 @@ repvgg_a2,128
|
||||
swin_base_patch4_window7_224,64
|
||||
tf_efficientnet_b0,128
|
||||
visformer_small,128
|
||||
vit_base_patch14_dinov2.lvd142m,128
|
||||
ViT-B-16-SigLIP-i18n-256,128
|
@@ -28,101 +28,8 @@

namespace c10 {

// [dtype Macros note] For the macros below:
//
// For users: If you want to macro some code for all non-QInt scalar types
// (i.e. types with complete information, you probably want one of the
// AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND macros below, which are
// designed to behave similarly to the Dispatch macros with the same name.
//
// For adding a new dtype: In the beginning, we had an idea that there was a
// list of all scalar types, and you could use AT_FORALL_SCALAR_TYPES to
// iterate over them. But over the years we added weird types which couldn't
// be handled uniformly everywhere and so in the end we ended up with some
// mish-mosh of some helper macros, but mostly use sites making a call about
// what dtypes they can or can't support. So if you want to add a new dtype,
// the preferred resolution is to find a dtype similar to what you want,
// grep for it and edit all the sites you find this way. If you need to add
// a completely new kind of dtype, you're going to have to laboriously audit
// all of the sites everywhere to figure out how it should work. Consulting
// some old PRs where we added new dtypes (check history of this file) can
// help give you an idea where to start.

// If you want to support ComplexHalf for real, add ComplexHalf
// into this macro (and change the name). But beware: convert()
// doesn't work for all the conversions you need...
//
// TODO: To add unsigned int types here, we must define accumulate type.
// But uint8 currently accumulates into int64, so we would have to make
// an inconsistent choice for the larger types. Difficult.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(_) \
  _(uint8_t, Byte)                                                      \
  _(int8_t, Char)                                                       \
  _(int16_t, Short)                                                     \
  _(int, Int)                                                           \
  _(int64_t, Long)                                                      \
  _(at::Half, Half)                                                     \
  _(float, Float)                                                       \
  _(double, Double)                                                     \
  _(c10::complex<float>, ComplexFloat)                                  \
  _(c10::complex<double>, ComplexDouble)                                \
  _(bool, Bool)                                                         \
  _(at::BFloat16, BFloat16)                                             \
  _(at::Float8_e5m2, Float8_e5m2)                                       \
  _(at::Float8_e4m3fn, Float8_e4m3fn)

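Since the block above is an X-macro table, it may help to recall how such macros are consumed: the caller supplies a two-argument callback macro that receives the C++ type and the `ScalarType` name, and the `AT_FORALL_*` macro stamps it out once per dtype. This is a generic usage sketch; only the `AT_FORALL_*` macro itself comes from the header, while the local `PRINT_SIZE` helper and the function around it are hypothetical.

```cpp
#include <iostream>

#include <c10/core/ScalarType.h> // brings the AT_FORALL_* macros into scope

// Each _(cpp_type, ScalarTypeName) entry in the table is expanded once with the
// callback macro supplied by the caller.
#define PRINT_SIZE(cpp_type, scalar_name) \
  std::cout << #scalar_name << ": sizeof = " << sizeof(cpp_type) << " bytes\n";

void print_scalar_type_sizes() {
  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(PRINT_SIZE)
}

#undef PRINT_SIZE
```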
// This macro controls many of our C++ APIs, including constructors
|
||||
// for Scalar as well as the data() and item() accessors on Tensor
|
||||
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(at::Half, Half) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(c10::complex<c10::Half>, ComplexHalf) \
|
||||
_(c10::complex<float>, ComplexFloat) \
|
||||
_(c10::complex<double>, ComplexDouble) \
|
||||
_(bool, Bool) \
|
||||
_(at::BFloat16, BFloat16) \
|
||||
_(at::Float8_e5m2, Float8_e5m2) \
|
||||
_(at::Float8_e4m3fn, Float8_e4m3fn) \
|
||||
_(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
|
||||
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
|
||||
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
|
||||
|
||||
namespace impl {
|
||||
|
||||
// These are used to map ScalarTypes to C++ types.
|
||||
|
||||
template <c10::ScalarType N>
|
||||
struct ScalarTypeToCPPType;
|
||||
|
||||
#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, scalar_type) \
|
||||
template <> \
|
||||
struct ScalarTypeToCPPType<c10::ScalarType::scalar_type> { \
|
||||
using type = cpp_type; \
|
||||
\
|
||||
/* This is a workaround for the CUDA bug which prevents */ \
|
||||
/* ::detail::ScalarTypeToCType<T>::type being used directly due to */ \
|
||||
/* ambiguous reference which can't to be resolved. For some reason it */ \
|
||||
/* can't pick between at::detail and at::cuda::detail. */ \
|
||||
/* For repro example, please see: */ \
|
||||
/* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \
|
||||
/* TODO: remove once the bug is fixed. */ \
|
||||
static type t; \
|
||||
};
|
||||
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType)
|
||||
|
||||
#undef SPECIALIZE_ScalarTypeToCPPType
|
||||
|
||||
template <c10::ScalarType N>
|
||||
using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType<N>::type;
|
||||
|
||||
} // namespace impl
|
||||
// See [dtype Macros note] in torch/headeronly/core/ScalarType.h
|
||||
// regarding macros.
|
||||
|
||||
template <typename T>
|
||||
struct CppTypeToScalarType;
|
||||
@ -138,130 +45,6 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
|
||||
|
||||
#undef SPECIALIZE_CppTypeToScalarType
|
||||
|
||||
// NB: despite its generic sounding name, the macros that don't take _AND
|
||||
// are mostly only used by tensorexpr
|
||||
#define AT_FORALL_INT_TYPES(_) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long)
|
||||
|
||||
#define AT_FORALL_SCALAR_TYPES(_) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(float, Float) \
|
||||
_(double, Double)
|
||||
|
||||
// These macros are often controlling how many template instantiations we
|
||||
// create for kernels. It is typically inappropriate to add new dtypes here,
|
||||
// instead, new types should be added to use sites on a case-by-case basis.
|
||||
// We generally are not accepting new dtypes due to binary size concerns.
|
||||
|
||||
#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE>::t), \
|
||||
SCALARTYPE)
|
||||
|
||||
#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE1>::t), \
|
||||
SCALARTYPE1) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE2>::t), \
|
||||
SCALARTYPE2)
|
||||
|
||||
#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE1>::t), \
|
||||
SCALARTYPE1) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE2>::t), \
|
||||
SCALARTYPE2) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE3>::t), \
|
||||
SCALARTYPE3)
|
||||
|
||||
#define AT_FORALL_SCALAR_TYPES_AND7( \
|
||||
SCALARTYPE1, \
|
||||
SCALARTYPE2, \
|
||||
SCALARTYPE3, \
|
||||
SCALARTYPE4, \
|
||||
SCALARTYPE5, \
|
||||
SCALARTYPE6, \
|
||||
SCALARTYPE7, \
|
||||
_) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE1>::t), \
|
||||
SCALARTYPE1) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE2>::t), \
|
||||
SCALARTYPE2) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE3>::t), \
|
||||
SCALARTYPE3) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE4>::t), \
|
||||
SCALARTYPE4) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE5>::t), \
|
||||
SCALARTYPE5) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE6>::t), \
|
||||
SCALARTYPE6) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE7>::t), \
|
||||
SCALARTYPE7)
|
||||
|
||||
#define AT_FORALL_QINT_TYPES(_) \
|
||||
_(c10::qint8, QInt8) \
|
||||
_(c10::quint8, QUInt8) \
|
||||
_(c10::qint32, QInt32) \
|
||||
_(c10::quint4x2, QUInt4x2) \
|
||||
_(c10::quint2x4, QUInt2x4)
|
||||
|
||||
#define AT_FORALL_FLOAT8_TYPES(_) \
|
||||
_(at::Float8_e5m2, Float8_e5m2) \
|
||||
_(at::Float8_e4m3fn, Float8_e4m3fn) \
|
||||
_(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
|
||||
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
|
||||
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
|
||||
|
||||
#define AT_FORALL_COMPLEX_TYPES(_) \
|
||||
_(c10::complex<float>, ComplexFloat) \
|
||||
_(c10::complex<double>, ComplexDouble)
|
||||
|
||||
#define DEFINE_CONSTANT(_, name) \
|
||||
constexpr ScalarType k##name = ScalarType::name;
|
||||
|
||||
@ -269,19 +52,6 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
|
||||
#undef DEFINE_CONSTANT
|
||||
|
||||
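For orientation, applying `DEFINE_CONSTANT` over the full dtype table is what generates the familiar shorthand constants such as `c10::kFloat` and `c10::kBool`; they are plain `ScalarType` enumerator aliases, so they can be compared and switched on directly. A small illustrative sketch (the helper function below is made up for this example):

```cpp
#include <c10/core/ScalarType.h>

// kFloat, kHalf, kBFloat16, ... are constexpr ScalarType values produced by the
// DEFINE_CONSTANT expansion above.
static_assert(c10::kFloat == c10::ScalarType::Float, "k-constants alias the enumerators");

// Hypothetical helper: true for the dense floating-point dtypes used most often.
inline bool is_common_float_dtype(c10::ScalarType t) {
  return t == c10::kFloat || t == c10::kDouble || t == c10::kHalf ||
      t == c10::kBFloat16;
}
```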
inline const char* toString(ScalarType t) {
|
||||
#define DEFINE_CASE(_, name) \
|
||||
case ScalarType::name: \
|
||||
return #name;
|
||||
|
||||
switch (t) {
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
|
||||
default:
|
||||
return "UNKNOWN_SCALAR";
|
||||
}
|
||||
#undef DEFINE_CASE
|
||||
}
|
||||
|
||||
inline size_t elementSize(ScalarType t) {
|
||||
#define CASE_ELEMENTSIZE_CASE(ctype, name) \
|
||||
case ScalarType::name: \
|
||||
@ -525,12 +295,6 @@ inline bool canCast(const ScalarType from, const ScalarType to) {
|
||||
|
||||
C10_API ScalarType promoteTypes(ScalarType a, ScalarType b);
|
||||
|
||||
inline std::ostream& operator<<(
|
||||
std::ostream& stream,
|
||||
at::ScalarType scalar_type) {
|
||||
return stream << toString(scalar_type);
|
||||
}
|
||||
|
||||
// Returns a pair of strings representing the names for each dtype.
|
||||
// The returned pair is (name, legacy_name_if_applicable)
|
||||
C10_API std::pair<std::string, std::string> getDtypeNames(
|
||||
|
@ -65,7 +65,7 @@ struct default_constructible
|
||||
|
||||
namespace impl {
|
||||
template <typename T>
|
||||
constexpr bool supports_default_construction(const ::strong::default_constructible::modifier<T>*)
|
||||
constexpr bool supports_default_construction(const ::strong::default_constructible::modifier<T>* /*unused*/)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
@ -76,7 +76,7 @@ class type : public modifier<M, type<T, Tag, M...>>...
|
||||
{
|
||||
public:
|
||||
template <typename TT = T, typename = std::enable_if_t<std::is_trivially_constructible<TT>{}>>
|
||||
explicit type(uninitialized_t)
|
||||
explicit type(uninitialized_t /*unused*/)
|
||||
noexcept
|
||||
{
|
||||
}
|
||||
@ -138,7 +138,7 @@ private:
|
||||
|
||||
namespace impl {
|
||||
template <typename T, typename Tag, typename ... Ms>
|
||||
constexpr bool is_strong_type_func(const strong::type<T, Tag, Ms...>*) { return true;}
|
||||
constexpr bool is_strong_type_func(const strong::type<T, Tag, Ms...>* /*unused*/) { return true;}
|
||||
constexpr bool is_strong_type_func(...) { return false;}
|
||||
template <typename T, typename Tag, typename ... Ms>
|
||||
constexpr T underlying_type(strong::type<T, Tag, Ms...>*);
|
||||
|
@ -68,14 +68,6 @@
|
||||
.. autofunction:: get_validators
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: write_file_on_exit
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: write_file
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: read_file
|
||||
```
|
||||
@ -95,3 +87,7 @@
|
||||
```{eval-rst}
|
||||
.. autofunction:: get_rotating_buffer_size
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: set_numerical_check_tolerances
|
||||
```
|
@ -123,3 +123,7 @@ The frontend API is `fully_shard` that can be called on a `module`:
|
||||
.. autoclass:: CPUOffloadPolicy
|
||||
:members:
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: share_comm_ctx
|
||||
```
|
||||
|
@ -51,7 +51,7 @@ MPI supports CUDA only if the implementation used to build PyTorch supports it.
|
||||
+----------------+-----+-----+-----+-----+-----+-----+-----+-----+
|
||||
| reduce_scatter | ✓ | ✓ | ✘ | ✘ | ✘ | ✓ | ✘ | ✓ |
|
||||
+----------------+-----+-----+-----+-----+-----+-----+-----+-----+
|
||||
| all_to_all | ✓ | ✓ | ✓ | ? | ✘ | ✓ | ✘ | ✓ |
|
||||
| all_to_all | ✘ | ✘ | ✓ | ? | ✘ | ✓ | ✘ | ✓ |
|
||||
+----------------+-----+-----+-----+-----+-----+-----+-----+-----+
|
||||
| barrier | ✓ | ✘ | ✓ | ? | ✘ | ✓ | ✘ | ✓ |
|
||||
+----------------+-----+-----+-----+-----+-----+-----+-----+-----+
|
||||
|
@ -23,6 +23,7 @@ Submodules
|
||||
flex_attention
|
||||
bias
|
||||
experimental
|
||||
varlen
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
@ -30,3 +31,4 @@ Submodules
|
||||
nn.attention.flex_attention
|
||||
nn.attention.bias
|
||||
nn.attention.experimental
|
||||
nn.attention.varlen
|
||||
|
17
docs/source/nn.attention.varlen.md
Normal file
@ -0,0 +1,17 @@
|
||||
```{eval-rst}
|
||||
.. role:: hidden
|
||||
:class: hidden-section
|
||||
```
|
||||
|
||||
# torch.nn.attention.varlen
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: torch.nn.attention.varlen
|
||||
.. currentmodule:: torch.nn.attention.varlen
|
||||
```
|
||||
```{eval-rst}
|
||||
.. autofunction:: varlen_attn
|
||||
```
|
||||
```{eval-rst}
|
||||
.. autoclass:: AuxRequest
|
||||
```
|
@ -228,3 +228,4 @@ Low-Precision functions
|
||||
ScalingType
|
||||
SwizzleType
|
||||
scaled_mm
|
||||
scaled_grouped_mm
|
||||
|
@ -1,14 +1,12 @@
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.compiler.config
|
||||
|
||||
```
|
||||
|
||||
# torch.compiler.config
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: torch.compiler.config
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autodata:: torch.compiler.config.job_id
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
```
|
||||
|
@ -816,6 +816,10 @@ Operator Tags
|
||||
.. py:module:: torch.types
|
||||
.. py:module:: torch.version
|
||||
|
||||
.. Compiler configuration module - documented in torch.compiler.config.md
|
||||
.. py:module:: torch.compiler.config
|
||||
:noindex:
|
||||
|
||||
.. Hidden aliases (e.g. torch.functional.broadcast_tensors()). We want `torch.broadcast_tensors()` to
|
||||
be visible only.
|
||||
.. toctree::
|
||||
|
@ -10,6 +10,7 @@ set(AOTI_ABI_CHECK_TEST_SRCS
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_scalartype.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_vec.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_vec_half.cpp
|
||||
)
|
||||
|
76
test/cpp/aoti_abi_check/test_scalartype.cpp
Normal file
@ -0,0 +1,76 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
|
||||
TEST(TestScalarType, ScalarTypeToCPPTypeT) {
|
||||
using torch::headeronly::ScalarType;
|
||||
using torch::headeronly::impl::ScalarTypeToCPPTypeT;
|
||||
|
||||
#define DEFINE_CHECK(TYPE, SCALARTYPE) \
|
||||
EXPECT_EQ(typeid(ScalarTypeToCPPTypeT<ScalarType::SCALARTYPE>), typeid(TYPE));
|
||||
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
|
||||
#undef DEFINE_CHECK
|
||||
}
|
||||
|
||||
#define DEFINE_CHECK(TYPE, SCALARTYPE) \
|
||||
{ \
|
||||
EXPECT_EQ( \
|
||||
typeid(ScalarTypeToCPPTypeT<ScalarType::SCALARTYPE>), typeid(TYPE)); \
|
||||
count++; \
|
||||
}
|
||||
|
||||
#define TEST_FORALL(M, EXPECTEDCOUNT, ...) \
|
||||
TEST(TestScalarType, M) { \
|
||||
using torch::headeronly::ScalarType; \
|
||||
using torch::headeronly::impl::ScalarTypeToCPPTypeT; \
|
||||
int8_t count = 0; \
|
||||
M(__VA_ARGS__ DEFINE_CHECK); \
|
||||
EXPECT_EQ(count, EXPECTEDCOUNT); \
|
||||
}
|
||||
|
||||
TEST_FORALL(AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ, 14)
|
||||
TEST_FORALL(AT_FORALL_SCALAR_TYPES_WITH_COMPLEX, 18)
|
||||
TEST_FORALL(AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS, 46)
|
||||
TEST_FORALL(AT_FORALL_INT_TYPES, 5)
|
||||
TEST_FORALL(AT_FORALL_SCALAR_TYPES, 7)
|
||||
TEST_FORALL(AT_FORALL_SCALAR_TYPES_AND, 8, Bool, )
|
||||
TEST_FORALL(AT_FORALL_SCALAR_TYPES_AND2, 9, Bool, Half, )
|
||||
TEST_FORALL(AT_FORALL_SCALAR_TYPES_AND3, 10, Bool, Half, ComplexFloat, )
|
||||
TEST_FORALL(
|
||||
AT_FORALL_SCALAR_TYPES_AND7,
|
||||
14,
|
||||
Bool,
|
||||
Half,
|
||||
ComplexHalf,
|
||||
ComplexFloat,
|
||||
ComplexDouble,
|
||||
UInt16,
|
||||
UInt32, )
|
||||
TEST_FORALL(AT_FORALL_QINT_TYPES, 5)
|
||||
TEST_FORALL(AT_FORALL_FLOAT8_TYPES, 5)
|
||||
TEST_FORALL(AT_FORALL_COMPLEX_TYPES, 2)
|
||||
|
||||
#undef DEFINE_CHECK
|
||||
#undef TEST_FORALL
|
||||
|
||||
TEST(TestScalarType, toString) {
|
||||
using torch::headeronly::ScalarType;
|
||||
|
||||
#define DEFINE_CHECK(_, name) EXPECT_EQ(toString(ScalarType::name), #name);
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
|
||||
#undef DEFINE_CHECK
|
||||
}
|
||||
|
||||
TEST(TestScalarType, operator_left_shift) {
|
||||
using torch::headeronly::ScalarType;
|
||||
|
||||
#define DEFINE_CHECK(_, name) \
|
||||
{ \
|
||||
std::stringstream ss; \
|
||||
ss << ScalarType::name; \
|
||||
EXPECT_EQ(ss.str(), #name); \
|
||||
}
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
|
||||
#undef DEFINE_CHECK
|
||||
}
|
@ -6,7 +6,7 @@ import functools
|
||||
import itertools
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -24,6 +24,11 @@ from torch.distributed.fsdp import (
|
||||
fully_shard,
|
||||
OffloadPolicy,
|
||||
register_fsdp_forward_method,
|
||||
share_comm_ctx,
|
||||
)
|
||||
from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
|
||||
foreach_all_gather,
|
||||
foreach_reduce,
|
||||
)
|
||||
from torch.distributed.tensor import DTensor, init_device_mesh, Shard
|
||||
from torch.distributed.tensor.debug import CommDebugMode
|
||||
@ -39,6 +44,8 @@ from torch.testing._internal.common_fsdp import (
|
||||
MLP,
|
||||
MLPStack,
|
||||
patch_all_gather,
|
||||
patch_foreach_all_gather,
|
||||
patch_foreach_reduce,
|
||||
patch_reduce_scatter,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
@ -1487,6 +1494,116 @@ class TestFullyShardCustomForwardMethod(FSDPTest):
|
||||
check_sharded_parity(self, ref_model, model)
|
||||
|
||||
|
||||
class TestFullyShardShareCommContext(FSDPTest):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return min(torch.get_device_module(device_type).device_count(), 2)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_share_comm_context(self):
|
||||
torch.manual_seed(42)
|
||||
n_layers = 3
|
||||
lin_dim = 16
|
||||
model = nn.Sequential(
|
||||
*[MLP(lin_dim, torch.device("cpu")) for _ in range(n_layers)]
|
||||
)
|
||||
ref_model = copy.deepcopy(model).to(device_type)
|
||||
for layer in model:
|
||||
fully_shard(layer)
|
||||
layer._get_fsdp_state()._lazy_init()
|
||||
share_comm_ctx(list(model))
|
||||
|
||||
torch.manual_seed(42 + self.rank + 1)
|
||||
inp = torch.randn(4, 3, lin_dim, device=device_type.type)
|
||||
ref_loss = ref_model(inp).sum()
|
||||
|
||||
all_gather_streams = set()
|
||||
reduce_scatter_streams = set()
|
||||
|
||||
from torch.distributed.fsdp._fully_shard._fsdp_api import (
|
||||
AllGather,
|
||||
ReduceScatter,
|
||||
)
|
||||
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
||||
|
||||
orig_foreach_all_gather = foreach_all_gather
|
||||
|
||||
def foreach_all_gather_with_assert(
|
||||
fsdp_params: list[FSDPParam],
|
||||
group: dist.ProcessGroup,
|
||||
async_op: bool,
|
||||
all_gather_copy_in_stream: torch.Stream,
|
||||
all_gather_stream: torch.Stream,
|
||||
device: torch.device,
|
||||
all_gather_comm: AllGather,
|
||||
):
|
||||
nonlocal all_gather_streams
|
||||
all_gather_streams.add(all_gather_stream)
|
||||
return orig_foreach_all_gather(
|
||||
fsdp_params,
|
||||
group,
|
||||
async_op,
|
||||
all_gather_copy_in_stream,
|
||||
all_gather_stream,
|
||||
device,
|
||||
all_gather_comm,
|
||||
)
|
||||
|
||||
orig_foreach_reduce = foreach_reduce
|
||||
|
||||
@torch.no_grad()
|
||||
def foreach_reduce_with_assert(
|
||||
fsdp_params: list[FSDPParam],
|
||||
unsharded_grads: list[torch.Tensor],
|
||||
reduce_scatter_group: dist.ProcessGroup,
|
||||
reduce_scatter_stream: torch.Stream,
|
||||
reduce_scatter_comm: ReduceScatter,
|
||||
orig_dtype: Optional[torch.dtype],
|
||||
reduce_dtype: Optional[torch.dtype],
|
||||
device: torch.device,
|
||||
gradient_divide_factor: Optional[float],
|
||||
all_reduce_group: Optional[dist.ProcessGroup], # not `None` iff HSDP
|
||||
all_reduce_stream: torch.Stream,
|
||||
all_reduce_grads: bool,
|
||||
partial_reduce_output: Optional[torch.Tensor], # only used for HSDP
|
||||
all_reduce_hook: Optional[Callable[[torch.Tensor], None]],
|
||||
force_sum_reduction_for_comms: bool = False,
|
||||
):
|
||||
nonlocal reduce_scatter_streams
|
||||
reduce_scatter_streams.add(reduce_scatter_stream)
|
||||
return orig_foreach_reduce(
|
||||
fsdp_params,
|
||||
unsharded_grads,
|
||||
reduce_scatter_group,
|
||||
reduce_scatter_stream,
|
||||
reduce_scatter_comm,
|
||||
orig_dtype,
|
||||
reduce_dtype,
|
||||
device,
|
||||
gradient_divide_factor,
|
||||
all_reduce_group,
|
||||
all_reduce_stream,
|
||||
all_reduce_grads,
|
||||
partial_reduce_output,
|
||||
all_reduce_hook,
|
||||
force_sum_reduction_for_comms,
|
||||
)
|
||||
|
||||
with (
|
||||
patch_foreach_all_gather(foreach_all_gather_with_assert),
|
||||
patch_foreach_reduce(foreach_reduce_with_assert),
|
||||
):
|
||||
loss = model(inp).sum()
|
||||
self.assertEqual(ref_loss, loss)
|
||||
ref_loss.backward()
|
||||
loss.backward()
|
||||
for param in ref_model.parameters():
|
||||
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
|
||||
self.assertEqual(len(all_gather_streams), 1)
|
||||
self.assertEqual(len(reduce_scatter_streams), 1)
|
||||
check_sharded_parity(self, ref_model, model)
|
||||
|
||||
|
||||
class TestFullyShardWorldSize1(FSDPTest):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
|
@ -559,7 +559,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
||||
FAIL = 138
|
||||
pc = start_processes(
|
||||
name="echo",
|
||||
entrypoint=bin("echo1.py"),
|
||||
entrypoint=bin("echo4.py"),
|
||||
args={0: ("--exitcode", FAIL, "foo"), 1: ("--exitcode", 0, "bar")},
|
||||
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
|
||||
logs_specs=DefaultLogsSpecs(
|
||||
|
@ -9,7 +9,6 @@
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -24,6 +23,5 @@ if __name__ == "__main__":
|
||||
print(f"exit {exitcode} from {rank}", file=sys.stderr)
|
||||
sys.exit(exitcode)
|
||||
else:
|
||||
time.sleep(1000)
|
||||
print(f"{args.msg} stdout from {rank}")
|
||||
print(f"{args.msg} stderr from {rank}", file=sys.stderr)
|
||||
|
29
test/distributed/elastic/multiprocessing/bin/echo4.py
Executable file
@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="test binary, exits with exitcode")
|
||||
parser.add_argument("--exitcode", type=int, default=0)
|
||||
parser.add_argument("msg", type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
rank = int(os.environ["RANK"])
|
||||
exitcode = args.exitcode
|
||||
if exitcode != 0:
|
||||
print(f"exit {exitcode} from {rank}", file=sys.stderr)
|
||||
sys.exit(exitcode)
|
||||
else:
|
||||
time.sleep(1000)
|
||||
print(f"{args.msg} stdout from {rank}")
|
||||
print(f"{args.msg} stderr from {rank}", file=sys.stderr)
|
@ -536,6 +536,23 @@ class TestScheduleLowering(TestCase):
|
||||
"compute": ["0F0", "0F1", " ", "0B0", "0B1"],
|
||||
"comms": ["0UNSHARD", "0F0", "0F1", "0B0", "0B1", "0RESHARD"],
|
||||
},
|
||||
{
|
||||
"compute": ["0F0", "0F1", "1F0", "1F1", "1B0", "1B1", "0B0", "0B1"],
|
||||
"comms": [
|
||||
"0UNSHARD",
|
||||
"1UNSHARD",
|
||||
"0F0",
|
||||
"0F1",
|
||||
"1F0",
|
||||
"1F1",
|
||||
"1B0",
|
||||
"1B1",
|
||||
"1RESHARD",
|
||||
"0B0",
|
||||
"0B1",
|
||||
"0RESHARD",
|
||||
],
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_unshard_reshard(self, test_info):
|
||||
|
@ -1020,6 +1020,19 @@ class DTensorMeshTest(DTensorTestBase):
|
||||
self.fail("Unexpected ValueError raised with run_check=False")
|
||||
|
||||
|
||||
DTensorMeshTestWithLocalTensor = create_local_tensor_test_class(
|
||||
DTensorMeshTest,
|
||||
skipped_tests=[
|
||||
# Submeshes are not supported by local tensor mode
|
||||
"test_from_local_sub_mesh",
|
||||
"test_default_value_sub_mesh",
|
||||
"test_redistribute_sub_mesh",
|
||||
# Local tensor mode doesn't support tensors of different types on different ranks
|
||||
"test_metadata_consistency_check",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestDTensorPlacementTypes(DTensorTestBase):
|
||||
@property
|
||||
def world_size(self):
|
||||
@ -1086,6 +1099,11 @@ class TestDTensorPlacementTypes(DTensorTestBase):
|
||||
assert_array_equal(expected_is_tensor_empty, is_tensor_empty)
|
||||
|
||||
|
||||
TestDTensorPlacementTypesWithLocalTensor = create_local_tensor_test_class(
|
||||
TestDTensorPlacementTypes,
|
||||
)
|
||||
|
||||
|
||||
class TestDTensorSpec(DTensorTestBase):
|
||||
@property
|
||||
def world_size(self):
|
||||
@ -1265,5 +1283,9 @@ class TestDTensorSpec(DTensorTestBase):
|
||||
)
|
||||
|
||||
|
||||
TestDTensorSpecWithLocalTensor = create_local_tensor_test_class(
|
||||
TestDTensorSpec,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -239,9 +239,7 @@ class DTensorExportTest(TestCase):
|
||||
"view_9",
|
||||
"t_15",
|
||||
"detach",
|
||||
"detach_1",
|
||||
"detach_6",
|
||||
"detach_7",
|
||||
"detach_3",
|
||||
"threshold_backward_1",
|
||||
"t_16",
|
||||
"mm_6",
|
||||
@ -259,10 +257,8 @@ class DTensorExportTest(TestCase):
|
||||
"sum_1",
|
||||
"view_7",
|
||||
"t_7",
|
||||
"detach_1",
|
||||
"detach_2",
|
||||
"detach_3",
|
||||
"detach_4",
|
||||
"detach_5",
|
||||
"threshold_backward",
|
||||
"mm_2",
|
||||
"t_9",
|
||||
|
@ -20,6 +20,7 @@ from torch.distributed.tensor._collective_utils import shard_dim_alltoall
|
||||
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
|
||||
from torch.distributed.tensor._redistribute import redistribute_local_tensor
|
||||
from torch.distributed.tensor.debug import CommDebugMode
|
||||
from torch.distributed.tensor.placement_types import _StridedShard
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
@ -1145,6 +1146,22 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
|
||||
sharded_dt, mesh, tgt_placement, shard_order=None
|
||||
)
|
||||
|
||||
@with_comms
|
||||
def test_shard_order_same_data_as_strided_shard(self):
|
||||
device_mesh = init_device_mesh(self.device_type, (4, 2))
|
||||
x = torch.randn(8, 4, device=self.device_type)
|
||||
# specify right-to-left order use _StridedShard
|
||||
strided_placement = [_StridedShard(-2, split_factor=2), Shard(-2)]
|
||||
x_strided_dt = distribute_tensor(x, device_mesh, strided_placement)
|
||||
# specify right-to-left order use ordered shard
|
||||
x_ordered_dt = self.distribute_tensor(
|
||||
x,
|
||||
device_mesh,
|
||||
placements=[Shard(0), Shard(0)],
|
||||
shard_order=(ShardOrderEntry(tensor_dim=0, mesh_dims=(1, 0)),),
|
||||
)
|
||||
self.assertEqual(x_ordered_dt.to_local(), x_strided_dt.to_local())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -70,6 +70,8 @@ def get_patches():
|
||||
"force_disable_caches": True,
|
||||
# Messes up existing test strings
|
||||
"test_configs.aten_fx_overlap_insert_overlap_deps": False,
|
||||
# interferes with testing, / custom estimation
|
||||
"test_configs.assume_bucketing_reduces_latency": False,
|
||||
}
|
||||
|
||||
|
||||
@ -364,6 +366,8 @@ def get_bucket_patches(compute_multiplier=1.0):
|
||||
"force_disable_caches": True,
|
||||
# messes up test strings
|
||||
"test_configs.aten_fx_overlap_insert_overlap_deps": False,
|
||||
# interferes with testing, / custom estimation
|
||||
"test_configs.assume_bucketing_reduces_latency": False,
|
||||
}
|
||||
|
||||
|
||||
@ -579,7 +583,7 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches(2.0))
|
||||
def test_bucketing_split_for_overlap_blocking(self):
|
||||
def test_bucketing_split_for_overlap_blocking_no_deps(self):
|
||||
"""Test that 4 independent all-gathers split into 2+2 buckets for better overlap with compute."""
|
||||
|
||||
def func(a, b, c, d, *, ranks):
|
||||
|
@ -7,8 +7,13 @@ from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch.multiprocessing.reductions import reduce_tensor
|
||||
from torch.testing._internal.common_cuda import SM100OrLater
|
||||
from torch.testing._internal.common_distributed import MultiProcContinuousTest
|
||||
from torch.testing._internal.common_utils import requires_cuda_p2p_access, run_tests
|
||||
from torch.testing._internal.common_utils import (
|
||||
requires_cuda_p2p_access,
|
||||
run_tests,
|
||||
skip_but_pass_in_sandcastle_if,
|
||||
)
|
||||
|
||||
|
||||
# So that tests are written in device-agnostic way
|
||||
@ -59,6 +64,10 @@ class CupyAsTensorTest(MultiProcContinuousTest):
|
||||
def device(self) -> torch.device:
|
||||
return torch.device(device_type, self.rank)
|
||||
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
SM100OrLater,
|
||||
"Fails if ran in docker environment without privileged access (https://github.com/pytorch/pytorch/issues/165170)",
|
||||
)
|
||||
def test_cupy_as_tensor(self) -> None:
|
||||
"""
|
||||
Test that torch.as_tensor works for cupy array interface
|
||||
|
@@ -2,6 +2,7 @@
# Owner(s): ["oncall: distributed"]
import os
import unittest
from datetime import timedelta

import torch
import torch.distributed as dist

@@ -40,6 +41,13 @@ from torch.utils._typing_utils import not_none
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
device_count = torch.accelerator.device_count()

try:
import torch._C._distributed_c10d.ProcessGroupNCCL

_NCCL_AVAILABLE = True
except ImportError:
_NCCL_AVAILABLE = False

def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0, local_rank=-1):
os.environ["MASTER_ADDR"] = addr

@@ -962,6 +970,85 @@ class TestDeviceMeshGetItem(DTensorTestBase):
# check flattened mesh dependency
self.assertEqual(dp_cp_mesh._get_root_mesh(), mesh_4d)

@with_comms
def test_unflatten_mesh_2d(self):
mesh_shape = (4, 2)
mesh_dim_names = ("dp", "tp")
mesh_2d = init_device_mesh(
self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
)
unflatten_mesh = mesh_2d._unflatten(0, (2, 2), ("dp_shard", "dp_replicate"))
self.assertEqual(
unflatten_mesh.mesh_dim_names, ["dp_shard", "dp_replicate", "tp"]
)
self.assertEqual(mesh_2d["tp"].mesh, unflatten_mesh["tp"].mesh)
self.assertEqual(mesh_2d["tp"].get_group(), unflatten_mesh["tp"].get_group())

# Not supporting slicing out unflatten dim name from root mesh.
with self.assertRaises(KeyError):
self.assertEqual(mesh_2d["dp_shard"].mesh, unflatten_mesh["dp_shard"].mesh)

@with_comms
def test_unflatten_mesh_3d(self):
# Test unflatten from a dummy world mesh, which is the case we need for Expert Parallelism(EP).
global_mesh = init_device_mesh(
self.device_type,
(8,),
mesh_dim_names=("world",),
)
non_ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "cp", "tp"))
ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "ep", "ep_tp"))
self.assertEqual(non_ep_mesh["cp"].mesh, ep_mesh["ep"].mesh)
self.assertEqual(non_ep_mesh["tp"].mesh, ep_mesh["ep_tp"].mesh)
mesh_3d = global_mesh._unflatten(0, (4, 2, 1), ("dp", "cp", "tp"))
unflatten_mesh = mesh_3d._unflatten(0, (2, 2), ("dp_shard", "dp_replicate"))
self.assertEqual(
unflatten_mesh.mesh_dim_names, ["dp_shard", "dp_replicate", "cp", "tp"]
)
self.assertEqual(mesh_3d["tp"].mesh, unflatten_mesh["tp"].mesh)
self.assertEqual(mesh_3d["tp"].get_group(), unflatten_mesh["tp"].get_group())
self.assertEqual(mesh_3d["cp"].mesh, unflatten_mesh["cp"].mesh)
self.assertEqual(mesh_3d["cp"].get_group(), unflatten_mesh["cp"].get_group())

# Test unflatten with backend override set.
if not _NCCL_AVAILABLE:
return
opts = dist.ProcessGroupNCCL.Options()
opts._timeout = timedelta(seconds=30)
mesh_2d = global_mesh._unflatten(
0,
(1, 8),
("pp", "spmd"),
backend_override={"pp": "fake", "spmd": ("nccl", opts)},
)
opts = dist.ProcessGroupNCCL.Options()
opts._timeout = timedelta(seconds=60)
mesh_4d = mesh_2d._unflatten(
1,
(2, 2, 2),
("dp", "cp", "tp"),
backend_override={"dp": "nccl", "cp": "nccl", "tp": ("nccl", opts)},
)
self.assertEqual(mesh_4d["pp"].get_group()._get_backend_name(), "custom")
spmd_pg = mesh_2d["spmd"].get_group()
self.assertEqual(spmd_pg._get_backend_name(), "nccl")
w = spmd_pg.allreduce(torch.rand(10).cuda(self.rank))
self.assertTrue(
spmd_pg._get_backend(
torch.device(f"cuda:{self.rank}")
)._verify_work_timeout(w, timedelta(seconds=30))
)
w.wait()
tp_pg = mesh_4d["tp"].get_group()
self.assertEqual(tp_pg._get_backend_name(), "nccl")
w = tp_pg.allreduce(torch.rand(10).cuda(self.rank))
self.assertTrue(
tp_pg._get_backend(torch.device(f"cuda:{self.rank}"))._verify_work_timeout(
w, timedelta(seconds=60)
)
)
w.wait()

@with_comms
def test_reconstruct_mesh_with_flatten_dim(self):
mesh_3d = init_device_mesh(

@@ -273,12 +273,7 @@ class TestFakePG(TestCase):
kwargs = {}
return func(*args, **kwargs)

with self.assertRaisesRegex(
RuntimeError,
r"FakeProcessGroup cannot be constructed directly\. "
r"Use torch\.distributed\.init_process_group\(backend='fake'\) instead to ensure "
r"proper dispatch system integration\.",
):
with self.assertRaisesRegex(TypeError, r"No constructor defined"):
fake_pg = FakeProcessGroup(rank=0, world_size=3)

with SimpleTensorMode():

@@ -12,6 +12,7 @@ import torch.distributed._symmetric_memory as symm_mem
import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem
from torch._inductor.runtime.triton_compat import triton
from torch.distributed._symmetric_memory._nvshmem_triton import requires_nvshmem
from torch.testing._internal.common_cuda import SM100OrLater
from torch.testing._internal.common_distributed import MultiProcContinuousTest
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,

@@ -264,6 +265,10 @@ def my_reduce_kernel(
nvshmem.reduce(team_handle, dest_tensor, source_tensor, nreduce, operation)

@skip_but_pass_in_sandcastle_if(
SM100OrLater,
"Skipping all NVSHMEM Triton tests due to https://github.com/pytorch/pytorch/issues/162897",
)
@instantiate_parametrized_tests
class NVSHMEMTritonTest(MultiProcContinuousTest):
def _init_device(self) -> None:

@@ -52,6 +52,9 @@ from torch.testing._internal.common_utils import (

test_contexts = [nullcontext, _test_mode]

# Set environment variable to disable multicast for all tests in this module
os.environ["TORCH_SYMM_MEM_DISABLE_MULTICAST"] = "1"

# So that tests are written in device-agnostic way
device_type = "cuda"
device_module = torch.get_device_module(device_type)

@@ -546,6 +549,10 @@ class AsyncTPTest(MultiProcContinuousTest):
@skipUnless(SM89OrLater, "Requires compute capability >= 8.9")
@parametrize("scatter_dim", [0, 1])
@parametrize("rowwise", [True, False])
@skipIf(
SM100OrLater,
"https://github.com/pytorch/pytorch/issues/162940",
)
def test_fused_scaled_matmul_reduce_scatter(
self, scatter_dim: int, rowwise: bool
) -> None:

@@ -1,5 +1,5 @@
diff --git a/test/dynamo/cpython/3_13/test_baseexception.py b/test/dynamo/cpython/3_13/test_baseexception.py
index e599b02c17d..750d7a84fb4 100644
index e599b02c17d..057b6ec01b9 100644
--- a/test/dynamo/cpython/3_13/test_baseexception.py
+++ b/test/dynamo/cpython/3_13/test_baseexception.py
@@ -1,10 +1,64 @@

@@ -78,7 +78,27 @@ index e599b02c17d..750d7a84fb4 100644
self.assertEqual(len(exc_set), 0, "%s not accounted for" % exc_set)

interface_tests = ("length", "args", "str", "repr")

@@ -142,7 +193,7 @@ class ExceptionClassTests(unittest.TestCase):
@@ -122,12 +173,13 @@ class ExceptionClassTests(unittest.TestCase):
# in PyObject_SetAttr.
import gc
d = {}
- class HashThisKeyWillClearTheDict(str):
- def __hash__(self) -> int:
- d.clear()
- return super().__hash__()
- class Value(str):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class HashThisKeyWillClearTheDict(str):
+ def __hash__(self) -> int:
+ d.clear()
+ return super().__hash__()
+ class Value(str):
+ pass
exc = Exception()

d[HashThisKeyWillClearTheDict()] = Value() # refcount of Value() is 1 now

@@ -142,7 +194,7 @@ class ExceptionClassTests(unittest.TestCase):
gc.collect()

@@ -87,7 +107,31 @@ index e599b02c17d..750d7a84fb4 100644

"""Test usage of exceptions"""

@@ -208,5 +259,5 @@ class UsageTests(unittest.TestCase):
@@ -182,8 +234,9 @@ class UsageTests(unittest.TestCase):
# BaseException; the ability was not possible until BaseException's
# introduction so no need to support new-style objects that do not
# inherit from it.
- class NewStyleClass(object):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class NewStyleClass(object):
+ pass
self.raise_fails(NewStyleClass)
self.raise_fails(NewStyleClass())

@@ -194,8 +247,9 @@ class UsageTests(unittest.TestCase):
def test_catch_non_BaseException(self):
# Trying to catch an object that does not inherit from BaseException
# is not allowed.
- class NonBaseException(object):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class NonBaseException(object):
+ pass
self.catch_fails(NonBaseException)
self.catch_fails(NonBaseException())

@@ -208,5 +262,5 @@ class UsageTests(unittest.TestCase):
self.catch_fails("spam")

@@ -173,12 +173,13 @@ class ExceptionClassTests(__TestCase):
# in PyObject_SetAttr.
import gc
d = {}
class HashThisKeyWillClearTheDict(str):
def __hash__(self) -> int:
d.clear()
return super().__hash__()
class Value(str):
pass
with torch._dynamo.error_on_graph_break(False):
class HashThisKeyWillClearTheDict(str):
def __hash__(self) -> int:
d.clear()
return super().__hash__()
class Value(str):
pass
exc = Exception()

d[HashThisKeyWillClearTheDict()] = Value() # refcount of Value() is 1 now

@@ -233,8 +234,9 @@ class UsageTests(__TestCase):
# BaseException; the ability was not possible until BaseException's
# introduction so no need to support new-style objects that do not
# inherit from it.
class NewStyleClass(object):
pass
with torch._dynamo.error_on_graph_break(False):
class NewStyleClass(object):
pass
self.raise_fails(NewStyleClass)
self.raise_fails(NewStyleClass())

@@ -245,8 +247,9 @@ class UsageTests(__TestCase):
def test_catch_non_BaseException(self):
# Trying to catch an object that does not inherit from BaseException
# is not allowed.
class NonBaseException(object):
pass
with torch._dynamo.error_on_graph_break(False):
class NonBaseException(object):
pass
self.catch_fails(NonBaseException)
self.catch_fails(NonBaseException())

@@ -1,5 +1,5 @@
diff --git a/test/dynamo/cpython/3_13/test_exceptions.py b/test/dynamo/cpython/3_13/test_exceptions.py
index c91f6662948..0ded70db3c7 100644
index c91f6662948..3a62dec411c 100644
--- a/test/dynamo/cpython/3_13/test_exceptions.py
+++ b/test/dynamo/cpython/3_13/test_exceptions.py
@@ -1,3 +1,59 @@

@@ -71,7 +71,305 @@ index c91f6662948..0ded70db3c7 100644

def raise_catch(self, exc, excname):
with self.subTest(exc=exc, excname=excname):

@@ -1844,7 +1900,7 @@ class ExceptionTests(unittest.TestCase):
@@ -343,12 +399,13 @@ class ExceptionTests(unittest.TestCase):
# test that setting an exception at the C level works even if the
# exception object can't be constructed.

- class BadException(Exception):
- def __init__(self_):
- raise RuntimeError("can't instantiate BadException")
+ with torch._dynamo.error_on_graph_break(False):
+ class BadException(Exception):
+ def __init__(self_):
+ raise RuntimeError("can't instantiate BadException")

- class InvalidException:
- pass
+ class InvalidException:
+ pass

@unittest.skipIf(_testcapi is None, "requires _testcapi")
def test_capi1():

@@ -636,8 +693,9 @@ class ExceptionTests(unittest.TestCase):
self.assertIsInstance(e, IndexError)
self.assertEqual(e.__traceback__, tb)

- class MyException(Exception):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class MyException(Exception):
+ pass

e = MyException().with_traceback(tb)
self.assertIsInstance(e, MyException)

@@ -696,8 +754,9 @@ class ExceptionTests(unittest.TestCase):
self.assertIsNone(e.__context__)
self.assertIsNone(e.__cause__)

- class MyException(OSError):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class MyException(OSError):
+ pass

e = MyException()
self.assertIsNone(e.__context__)

@@ -726,10 +785,11 @@ class ExceptionTests(unittest.TestCase):
# but user-defined subclasses can if they want
self.assertRaises(TypeError, BaseException, a=1)

- class DerivedException(BaseException):
- def __init__(self, fancy_arg):
- BaseException.__init__(self)
- self.fancy_arg = fancy_arg
+ with torch._dynamo.error_on_graph_break(False):
+ class DerivedException(BaseException):
+ def __init__(self, fancy_arg):
+ BaseException.__init__(self)
+ self.fancy_arg = fancy_arg

x = DerivedException(fancy_arg=42)
self.assertEqual(x.fancy_arg, 42)

@@ -779,11 +839,12 @@ class ExceptionTests(unittest.TestCase):
# Make sure exception state is cleaned up as soon as the except
# block is left. See #2507

- class MyException(Exception):
- def __init__(self, obj):
- self.obj = obj
- class MyObj:
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class MyException(Exception):
+ def __init__(self, obj):
+ self.obj = obj
+ class MyObj:
+ pass

def inner_raising_func():
# Create some references in exception value and traceback

@@ -881,11 +942,12 @@ class ExceptionTests(unittest.TestCase):
self.assertIsNone(obj)

# Inside an exception-silencing "with" block
- class Context:
- def __enter__(self):
- return self
- def __exit__ (self, exc_type, exc_value, exc_tb):
- return True
+ with torch._dynamo.error_on_graph_break(False):
+ class Context:
+ def __enter__(self):
+ return self
+ def __exit__ (self, exc_type, exc_value, exc_tb):
+ return True
obj = MyObj()
wr = weakref.ref(obj)
with Context():

@@ -1027,11 +1089,12 @@ class ExceptionTests(unittest.TestCase):
def _check_generator_cleanup_exc_state(self, testfunc):
# Issue #12791: exception state is cleaned up as soon as a generator
# is closed (reference cycles are broken).
- class MyException(Exception):
- def __init__(self, obj):
- self.obj = obj
- class MyObj:
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class MyException(Exception):
+ def __init__(self, obj):
+ self.obj = obj
+ class MyObj:
+ pass

def raising_gen():
try:

@@ -1090,10 +1153,11 @@ class ExceptionTests(unittest.TestCase):
def test_3114(self):
# Bug #3114: in its destructor, MyObject retrieves a pointer to
# obsolete and/or deallocated objects.
- class MyObject:
- def __del__(self):
- nonlocal e
- e = sys.exception()
+ with torch._dynamo.error_on_graph_break(False):
+ class MyObject:
+ def __del__(self):
+ nonlocal e
+ e = sys.exception()
e = ()
try:
raise Exception(MyObject())

@@ -1103,12 +1167,13 @@ class ExceptionTests(unittest.TestCase):
self.assertIsNone(e)

def test_raise_does_not_create_context_chain_cycle(self):
- class A(Exception):
- pass
- class B(Exception):
- pass
- class C(Exception):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class A(Exception):
+ pass
+ class B(Exception):
+ pass
+ class C(Exception):
+ pass

# Create a context chain:
# C -> B -> A

@@ -1164,12 +1229,13 @@ class ExceptionTests(unittest.TestCase):
def test_no_hang_on_context_chain_cycle2(self):
# See issue 25782. Cycle at head of context chain.

- class A(Exception):
- pass
- class B(Exception):
- pass
- class C(Exception):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class A(Exception):
+ pass
+ class B(Exception):
+ pass
+ class C(Exception):
+ pass

# Context cycle:
# +-----------+

@@ -1200,16 +1266,17 @@ class ExceptionTests(unittest.TestCase):
def test_no_hang_on_context_chain_cycle3(self):
# See issue 25782. Longer context chain with cycle.

- class A(Exception):
- pass
- class B(Exception):
- pass
- class C(Exception):
- pass
- class D(Exception):
- pass
- class E(Exception):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class A(Exception):
+ pass
+ class B(Exception):
+ pass
+ class C(Exception):
+ pass
+ class D(Exception):
+ pass
+ class E(Exception):
+ pass

# Context cycle:
# +-----------+

@@ -1364,11 +1431,12 @@ class ExceptionTests(unittest.TestCase):
def test_badisinstance(self):
# Bug #2542: if issubclass(e, MyException) raises an exception,
# it should be ignored
- class Meta(type):
- def __subclasscheck__(cls, subclass):
- raise ValueError()
- class MyException(Exception, metaclass=Meta):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class Meta(type):
+ def __subclasscheck__(cls, subclass):
+ raise ValueError()
+ class MyException(Exception, metaclass=Meta):
+ pass

with captured_stderr() as stderr:
try:

@@ -1602,8 +1670,9 @@ class ExceptionTests(unittest.TestCase):
self.assertTrue(issubclass(error3, error2))

# test with explicit base tuple
- class C(object):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class C(object):
+ pass
error4 = _testcapi.make_exception_with_doc("_testcapi.error4", doc4,
(error3, C))
self.assertTrue(issubclass(error4, error3))

@@ -1623,8 +1692,9 @@ class ExceptionTests(unittest.TestCase):
# Issue #5437: preallocated MemoryError instances should not keep
# traceback objects alive.
from _testcapi import raise_memoryerror
- class C:
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ pass
wr = None
def inner():
nonlocal wr

@@ -1644,8 +1714,9 @@ class ExceptionTests(unittest.TestCase):
@no_tracing
def test_recursion_error_cleanup(self):
# Same test as above, but with "recursion exceeded" errors
- class C:
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ pass
wr = None
def inner():
nonlocal wr

@@ -1670,11 +1741,12 @@ class ExceptionTests(unittest.TestCase):

def test_unraisable(self):
# Issue #22836: PyErr_WriteUnraisable() should give sensible reports
- class BrokenDel:
- def __del__(self):
- exc = ValueError("del is broken")
- # The following line is included in the traceback report:
- raise exc
+ with torch._dynamo.error_on_graph_break(False):
+ class BrokenDel:
+ def __del__(self):
+ exc = ValueError("del is broken")
+ # The following line is included in the traceback report:
+ raise exc

obj = BrokenDel()
with support.catch_unraisable_exception() as cm:

@@ -1728,11 +1800,12 @@ class ExceptionTests(unittest.TestCase):

def test_yield_in_nested_try_excepts(self):
#Issue #25612
- class MainError(Exception):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class MainError(Exception):
+ pass

- class SubError(Exception):
- pass
+ class SubError(Exception):
+ pass

def main():
try:

@@ -1807,8 +1880,9 @@ class ExceptionTests(unittest.TestCase):
# subclass object. Finally, it checks that creating a new MemoryError
# succeeds, proving that the freelist is not corrupted.

- class TestException(MemoryError):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class TestException(MemoryError):
+ pass

try:
raise MemoryError

@@ -1844,7 +1918,7 @@ class ExceptionTests(unittest.TestCase):
self.assertIn(b'MemoryError', err)

@@ -80,7 +378,18 @@ index c91f6662948..0ded70db3c7 100644
def test_name_error_has_name(self):
try:
bluch

@@ -1894,7 +1950,7 @@ class NameErrorTests(unittest.TestCase):
@@ -1886,15 +1960,16 @@ class NameErrorTests(unittest.TestCase):

def test_gh_111654(self):
def f():
- class TestClass:
- TestClass
+ with torch._dynamo.error_on_graph_break(False):
+ class TestClass:
+ TestClass

self.assertRaises(NameError, f)

# Note: name suggestion tests live in `test_traceback`.

@@ -89,7 +398,33 @@ index c91f6662948..0ded70db3c7 100644
def test_attributes(self):
# Setting 'attr' should not be a problem.
exc = AttributeError('Ouch!')

@@ -1937,7 +1993,7 @@ class AttributeErrorTests(unittest.TestCase):
@@ -1907,8 +1982,9 @@ class AttributeErrorTests(unittest.TestCase):
self.assertIs(exc.obj, sentinel)

def test_getattr_has_name_and_obj(self):
- class A:
- blech = None
+ with torch._dynamo.error_on_graph_break(False):
+ class A:
+ blech = None

obj = A()
try:

@@ -1923,9 +1999,10 @@ class AttributeErrorTests(unittest.TestCase):
self.assertEqual(obj, exc.obj)

def test_getattr_has_name_and_obj_for_method(self):
- class A:
- def blech(self):
- return
+ with torch._dynamo.error_on_graph_break(False):
+ class A:
+ def blech(self):
+ return

obj = A()
try:

@@ -1937,7 +2014,7 @@ class AttributeErrorTests(unittest.TestCase):
# Note: name suggestion tests live in `test_traceback`.

@@ -98,7 +433,7 @@ index c91f6662948..0ded70db3c7 100644

def test_attributes(self):
# Setting 'name' and 'path' should not be a problem.

@@ -2024,7 +2080,7 @@ def run_script(source):
@@ -2024,7 +2101,7 @@ def run_script(source):
_rc, _out, err = script_helper.assert_python_failure('-Wd', '-X', 'utf8', TESTFN)
return err.decode('utf-8').splitlines()

@@ -107,7 +442,7 @@ index c91f6662948..0ded70db3c7 100644
def tearDown(self):
unlink(TESTFN)

@@ -2159,7 +2215,7 @@ class AssertionErrorTests(unittest.TestCase):
@@ -2159,7 +2236,7 @@ class AssertionErrorTests(unittest.TestCase):

@support.force_not_colorized_test_class

@@ -116,7 +451,19 @@ index c91f6662948..0ded70db3c7 100644
maxDiff = None

@force_not_colorized

@@ -2290,6 +2346,7 @@ class SyntaxErrorTests(unittest.TestCase):
@@ -2254,8 +2331,9 @@ class SyntaxErrorTests(unittest.TestCase):
the_exception = exc

def test_subclass(self):
- class MySyntaxError(SyntaxError):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class MySyntaxError(SyntaxError):
+ pass

try:
raise MySyntaxError("bad bad", ("bad.py", 1, 2, "abcdefg", 1, 7))

@@ -2290,6 +2368,7 @@ class SyntaxErrorTests(unittest.TestCase):
err = run_script(b"\x89")
self.assertIn("SyntaxError: Non-UTF-8 code starting with '\\x89' in file", err[-1])

@@ -124,7 +471,7 @@ index c91f6662948..0ded70db3c7 100644
def test_string_source(self):
def try_compile(source):
with self.assertRaises(SyntaxError) as cm:

@@ -2405,7 +2462,7 @@ class SyntaxErrorTests(unittest.TestCase):
@@ -2405,7 +2484,7 @@ class SyntaxErrorTests(unittest.TestCase):
self.assertRaises(TypeError, SyntaxError, "bad bad", args)

@@ -133,7 +480,7 @@ index c91f6662948..0ded70db3c7 100644
def test_except_star_invalid_exception_type(self):
with self.assertRaises(TypeError):
try:

@@ -2420,7 +2477,7 @@ class TestInvalidExceptionMatcher(unittest.TestCase):
@@ -2420,7 +2499,7 @@ class TestInvalidExceptionMatcher(unittest.TestCase):
pass

@@ -142,7 +489,42 @@ index c91f6662948..0ded70db3c7 100644

def lineno_after_raise(self, f, *expected):
try:

@@ -2529,5 +2586,5 @@ class PEP626Tests(unittest.TestCase):
@@ -2499,11 +2578,12 @@ class PEP626Tests(unittest.TestCase):
self.lineno_after_raise(in_finally_except, 4)

def test_lineno_after_with(self):
- class Noop:
- def __enter__(self):
- return self
- def __exit__(self, *args):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class Noop:
+ def __enter__(self):
+ return self
+ def __exit__(self, *args):
+ pass
def after_with():
with Noop():
1/0

@@ -2518,16 +2598,17 @@ class PEP626Tests(unittest.TestCase):
self.lineno_after_raise(f, None)

def test_lineno_after_raise_in_with_exit(self):
- class ExitFails:
- def __enter__(self):
- return self
- def __exit__(self, *args):
- raise ValueError
+ with torch._dynamo.error_on_graph_break(False):
+ class ExitFails:
+ def __enter__(self):
+ return self
+ def __exit__(self, *args):
+ raise ValueError

def after_with():
with ExitFails():
1/0
self.lineno_after_raise(after_with, 1, 1)

@@ -399,12 +399,13 @@ class ExceptionTests(__TestCase):
# test that setting an exception at the C level works even if the
# exception object can't be constructed.

class BadException(Exception):
def __init__(self_):
raise RuntimeError("can't instantiate BadException")
with torch._dynamo.error_on_graph_break(False):
class BadException(Exception):
def __init__(self_):
raise RuntimeError("can't instantiate BadException")

class InvalidException:
pass
class InvalidException:
pass

@unittest.skipIf(_testcapi is None, "requires _testcapi")
def test_capi1():

@@ -692,8 +693,9 @@ class ExceptionTests(__TestCase):
self.assertIsInstance(e, IndexError)
self.assertEqual(e.__traceback__, tb)

class MyException(Exception):
pass
with torch._dynamo.error_on_graph_break(False):
class MyException(Exception):
pass

e = MyException().with_traceback(tb)
self.assertIsInstance(e, MyException)

@@ -752,8 +754,9 @@ class ExceptionTests(__TestCase):
self.assertIsNone(e.__context__)
self.assertIsNone(e.__cause__)

class MyException(OSError):
pass
with torch._dynamo.error_on_graph_break(False):
class MyException(OSError):
pass

e = MyException()
self.assertIsNone(e.__context__)

@@ -782,10 +785,11 @@ class ExceptionTests(__TestCase):
# but user-defined subclasses can if they want
self.assertRaises(TypeError, BaseException, a=1)

class DerivedException(BaseException):
def __init__(self, fancy_arg):
BaseException.__init__(self)
self.fancy_arg = fancy_arg
with torch._dynamo.error_on_graph_break(False):
class DerivedException(BaseException):
def __init__(self, fancy_arg):
BaseException.__init__(self)
self.fancy_arg = fancy_arg

x = DerivedException(fancy_arg=42)
self.assertEqual(x.fancy_arg, 42)

@@ -835,11 +839,12 @@ class ExceptionTests(__TestCase):
# Make sure exception state is cleaned up as soon as the except
# block is left. See #2507

class MyException(Exception):
def __init__(self, obj):
self.obj = obj
class MyObj:
pass
with torch._dynamo.error_on_graph_break(False):
class MyException(Exception):
def __init__(self, obj):
self.obj = obj
class MyObj:
pass

def inner_raising_func():
# Create some references in exception value and traceback

@@ -937,11 +942,12 @@ class ExceptionTests(__TestCase):
self.assertIsNone(obj)

# Inside an exception-silencing "with" block
class Context:
def __enter__(self):
return self
def __exit__ (self, exc_type, exc_value, exc_tb):
return True
with torch._dynamo.error_on_graph_break(False):
class Context:
def __enter__(self):
return self
def __exit__ (self, exc_type, exc_value, exc_tb):
return True
obj = MyObj()
wr = weakref.ref(obj)
with Context():

@@ -1083,11 +1089,12 @@ class ExceptionTests(__TestCase):
def _check_generator_cleanup_exc_state(self, testfunc):
# Issue #12791: exception state is cleaned up as soon as a generator
# is closed (reference cycles are broken).
class MyException(Exception):
def __init__(self, obj):
self.obj = obj
class MyObj:
pass
with torch._dynamo.error_on_graph_break(False):
class MyException(Exception):
def __init__(self, obj):
self.obj = obj
class MyObj:
pass

def raising_gen():
try:

@@ -1146,10 +1153,11 @@ class ExceptionTests(__TestCase):
def test_3114(self):
# Bug #3114: in its destructor, MyObject retrieves a pointer to
# obsolete and/or deallocated objects.
class MyObject:
def __del__(self):
nonlocal e
e = sys.exception()
with torch._dynamo.error_on_graph_break(False):
class MyObject:
def __del__(self):
nonlocal e
e = sys.exception()
e = ()
try:
raise Exception(MyObject())

@@ -1159,12 +1167,13 @@ class ExceptionTests(__TestCase):
self.assertIsNone(e)

def test_raise_does_not_create_context_chain_cycle(self):
class A(Exception):
pass
class B(Exception):
pass
class C(Exception):
pass
with torch._dynamo.error_on_graph_break(False):
class A(Exception):
pass
class B(Exception):
pass
class C(Exception):
pass

# Create a context chain:
# C -> B -> A

@@ -1220,12 +1229,13 @@ class ExceptionTests(__TestCase):
def test_no_hang_on_context_chain_cycle2(self):
# See issue 25782. Cycle at head of context chain.

class A(Exception):
pass
class B(Exception):
pass
class C(Exception):
pass
with torch._dynamo.error_on_graph_break(False):
class A(Exception):
pass
class B(Exception):
pass
class C(Exception):
pass

# Context cycle:
# +-----------+

@@ -1256,16 +1266,17 @@ class ExceptionTests(__TestCase):
def test_no_hang_on_context_chain_cycle3(self):
# See issue 25782. Longer context chain with cycle.

class A(Exception):
pass
class B(Exception):
pass
class C(Exception):
pass
class D(Exception):
pass
class E(Exception):
pass
with torch._dynamo.error_on_graph_break(False):
class A(Exception):
pass
class B(Exception):
pass
class C(Exception):
pass
class D(Exception):
pass
class E(Exception):
pass

# Context cycle:
# +-----------+

@@ -1420,11 +1431,12 @@ class ExceptionTests(__TestCase):
def test_badisinstance(self):
# Bug #2542: if issubclass(e, MyException) raises an exception,
# it should be ignored
class Meta(type):
def __subclasscheck__(cls, subclass):
raise ValueError()
class MyException(Exception, metaclass=Meta):
pass
with torch._dynamo.error_on_graph_break(False):
class Meta(type):
def __subclasscheck__(cls, subclass):
raise ValueError()
class MyException(Exception, metaclass=Meta):
pass

with captured_stderr() as stderr:
try:

@@ -1658,8 +1670,9 @@ class ExceptionTests(__TestCase):
self.assertTrue(issubclass(error3, error2))

# test with explicit base tuple
class C(object):
pass
with torch._dynamo.error_on_graph_break(False):
class C(object):
pass
error4 = _testcapi.make_exception_with_doc("_testcapi.error4", doc4,
(error3, C))
self.assertTrue(issubclass(error4, error3))

@@ -1679,8 +1692,9 @@ class ExceptionTests(__TestCase):
# Issue #5437: preallocated MemoryError instances should not keep
# traceback objects alive.
from _testcapi import raise_memoryerror
class C:
pass
with torch._dynamo.error_on_graph_break(False):
class C:
pass
wr = None
def inner():
nonlocal wr

@@ -1700,8 +1714,9 @@ class ExceptionTests(__TestCase):
@no_tracing
def test_recursion_error_cleanup(self):
# Same test as above, but with "recursion exceeded" errors
class C:
pass
with torch._dynamo.error_on_graph_break(False):
class C:
pass
wr = None
def inner():
nonlocal wr

@@ -1726,11 +1741,12 @@ class ExceptionTests(__TestCase):

def test_unraisable(self):
# Issue #22836: PyErr_WriteUnraisable() should give sensible reports
class BrokenDel:
def __del__(self):
exc = ValueError("del is broken")
# The following line is included in the traceback report:
raise exc
with torch._dynamo.error_on_graph_break(False):
class BrokenDel:
def __del__(self):
exc = ValueError("del is broken")
# The following line is included in the traceback report:
raise exc

obj = BrokenDel()
with support.catch_unraisable_exception() as cm:

@@ -1784,11 +1800,12 @@ class ExceptionTests(__TestCase):

def test_yield_in_nested_try_excepts(self):
#Issue #25612
class MainError(Exception):
pass
with torch._dynamo.error_on_graph_break(False):
class MainError(Exception):
pass

class SubError(Exception):
pass
class SubError(Exception):
pass

def main():
try:

@@ -1863,8 +1880,9 @@ class ExceptionTests(__TestCase):
# subclass object. Finally, it checks that creating a new MemoryError
# succeeds, proving that the freelist is not corrupted.

class TestException(MemoryError):
pass
with torch._dynamo.error_on_graph_break(False):
class TestException(MemoryError):
pass

try:
raise MemoryError

@@ -1942,8 +1960,9 @@ class NameErrorTests(__TestCase):

def test_gh_111654(self):
def f():
class TestClass:
TestClass
with torch._dynamo.error_on_graph_break(False):
class TestClass:
TestClass

self.assertRaises(NameError, f)

@@ -1963,8 +1982,9 @@ class AttributeErrorTests(__TestCase):
self.assertIs(exc.obj, sentinel)

def test_getattr_has_name_and_obj(self):
class A:
blech = None
with torch._dynamo.error_on_graph_break(False):
class A:
blech = None

obj = A()
try:

@@ -1979,9 +1999,10 @@ class AttributeErrorTests(__TestCase):
self.assertEqual(obj, exc.obj)

def test_getattr_has_name_and_obj_for_method(self):
class A:
def blech(self):
return
with torch._dynamo.error_on_graph_break(False):
class A:
def blech(self):
return

obj = A()
try:

@@ -2310,8 +2331,9 @@ class SyntaxErrorTests(__TestCase):
the_exception = exc

def test_subclass(self):
class MySyntaxError(SyntaxError):
pass
with torch._dynamo.error_on_graph_break(False):
class MySyntaxError(SyntaxError):
pass

try:
raise MySyntaxError("bad bad", ("bad.py", 1, 2, "abcdefg", 1, 7))

@@ -2556,11 +2578,12 @@ class PEP626Tests(__TestCase):
self.lineno_after_raise(in_finally_except, 4)

def test_lineno_after_with(self):
class Noop:
def __enter__(self):
return self
def __exit__(self, *args):
pass
with torch._dynamo.error_on_graph_break(False):
class Noop:
def __enter__(self):
return self
def __exit__(self, *args):
pass
def after_with():
with Noop():
1/0

@@ -2575,11 +2598,12 @@ class PEP626Tests(__TestCase):
self.lineno_after_raise(f, None)

def test_lineno_after_raise_in_with_exit(self):
class ExitFails:
def __enter__(self):
return self
def __exit__(self, *args):
raise ValueError
with torch._dynamo.error_on_graph_break(False):
class ExitFails:
def __enter__(self):
return self
def __exit__(self, *args):
raise ValueError

def after_with():
with ExitFails():

@@ -1,5 +1,5 @@
diff --git a/test/dynamo/cpython/3_13/test_raise.py b/test/dynamo/cpython/3_13/test_raise.py
index 6d26a61bee4..042d1ae3d7c 100644
index 6d26a61bee4..ce748433d28 100644
--- a/test/dynamo/cpython/3_13/test_raise.py
+++ b/test/dynamo/cpython/3_13/test_raise.py
@@ -1,3 +1,58 @@

@@ -70,7 +70,35 @@ index 6d26a61bee4..042d1ae3d7c 100644
def test_invalid_reraise(self):
try:
raise

@@ -148,7 +203,7 @@ class TestRaise(unittest.TestCase):
@@ -120,9 +175,10 @@ class TestRaise(unittest.TestCase):
self.assertRaises(StopIteration, lambda: next(g))

def test_erroneous_exception(self):
- class MyException(Exception):
- def __init__(self):
- raise RuntimeError()
+ with torch._dynamo.error_on_graph_break(False):
+ class MyException(Exception):
+ def __init__(self):
+ raise RuntimeError()

try:
raise MyException

@@ -133,9 +189,10 @@ class TestRaise(unittest.TestCase):

def test_new_returns_invalid_instance(self):
# See issue #11627.
- class MyException(Exception):
- def __new__(cls, *args):
- return object()
+ with torch._dynamo.error_on_graph_break(False):
+ class MyException(Exception):
+ def __new__(cls, *args):
+ return object()

with self.assertRaises(TypeError):
raise MyException

@@ -148,7 +205,7 @@ class TestRaise(unittest.TestCase):

@@ -79,7 +107,37 @@ index 6d26a61bee4..042d1ae3d7c 100644

def testCauseSyntax(self):
try:

@@ -221,7 +276,7 @@ class TestCause(unittest.TestCase):
@@ -186,10 +243,11 @@ class TestCause(unittest.TestCase):
self.fail("No exception raised")

def test_class_cause_nonexception_result(self):
- class ConstructsNone(BaseException):
- @classmethod
- def __new__(*args, **kwargs):
- return None
+ with torch._dynamo.error_on_graph_break(False):
+ class ConstructsNone(BaseException):
+ @classmethod
+ def __new__(*args, **kwargs):
+ return None
try:
raise IndexError from ConstructsNone
except TypeError as e:

@@ -209,9 +267,10 @@ class TestCause(unittest.TestCase):
self.fail("No exception raised")

def test_erroneous_cause(self):
- class MyException(Exception):
- def __init__(self):
- raise RuntimeError()
+ with torch._dynamo.error_on_graph_break(False):
+ class MyException(Exception):
+ def __init__(self):
+ raise RuntimeError()

try:
raise IndexError from MyException

@@ -221,7 +280,7 @@ class TestCause(unittest.TestCase):
self.fail("No exception raised")

@@ -88,7 +146,7 @@ index 6d26a61bee4..042d1ae3d7c 100644

def test_sets_traceback(self):
try:

@@ -242,7 +297,7 @@ class TestTraceback(unittest.TestCase):
@@ -242,7 +301,7 @@ class TestTraceback(unittest.TestCase):
self.fail("No exception raised")

@@ -97,7 +155,7 @@ index 6d26a61bee4..042d1ae3d7c 100644

def raiser(self):
raise ValueError

@@ -308,7 +363,7 @@ class TestTracebackType(unittest.TestCase):
@@ -308,7 +367,7 @@ class TestTracebackType(unittest.TestCase):
types.TracebackType(other_tb, frame, 1, "nuh-uh")

@@ -106,7 +164,45 @@ index 6d26a61bee4..042d1ae3d7c 100644
def test_instance_context_instance_raise(self):
context = IndexError()
try:

@@ -498,7 +553,7 @@ class TestContext(unittest.TestCase):
@@ -392,11 +451,12 @@ class TestContext(unittest.TestCase):
self.fail("No exception raised")

def test_context_manager(self):
- class ContextManager:
- def __enter__(self):
- pass
- def __exit__(self, t, v, tb):
- xyzzy
+ with torch._dynamo.error_on_graph_break(False):
+ class ContextManager:
+ def __enter__(self):
+ pass
+ def __exit__(self, t, v, tb):
+ xyzzy
try:
with ContextManager():
1/0

@@ -471,12 +531,13 @@ class TestContext(unittest.TestCase):
import gc
# A re-raised exception in a __del__ caused the __context__
# to be cleared
- class C:
- def __del__(self):
- try:
- 1/0
- except:
- raise
+ with torch._dynamo.error_on_graph_break(False):
+ class C:
+ def __del__(self):
+ try:
+ 1/0
+ except:
+ raise

def f():
x = C()

@@ -498,7 +559,7 @@ class TestContext(unittest.TestCase):
self.assertEqual(ZeroDivisionError, cm.unraisable.exc_type)

@@ -115,7 +211,7 @@ index 6d26a61bee4..042d1ae3d7c 100644
def test_tuples(self):
try:
raise (IndexError, KeyError) # This should be a tuple!

@@ -517,4 +572,4 @@ class TestRemovedFunctionality(unittest.TestCase):
@@ -517,4 +578,4 @@ class TestRemovedFunctionality(unittest.TestCase):

if __name__ == "__main__":

@@ -175,9 +175,10 @@ class TestRaise(__TestCase):
self.assertRaises(StopIteration, lambda: next(g))

def test_erroneous_exception(self):
class MyException(Exception):
def __init__(self):
raise RuntimeError()
with torch._dynamo.error_on_graph_break(False):
class MyException(Exception):
def __init__(self):
raise RuntimeError()

try:
raise MyException

@@ -188,9 +189,10 @@ class TestRaise(__TestCase):

def test_new_returns_invalid_instance(self):
# See issue #11627.
class MyException(Exception):
def __new__(cls, *args):
return object()
with torch._dynamo.error_on_graph_break(False):
class MyException(Exception):
def __new__(cls, *args):
return object()

with self.assertRaises(TypeError):
raise MyException

@@ -241,10 +243,11 @@ class TestCause(__TestCase):
self.fail("No exception raised")

def test_class_cause_nonexception_result(self):
class ConstructsNone(BaseException):
@classmethod
def __new__(*args, **kwargs):
return None
with torch._dynamo.error_on_graph_break(False):
class ConstructsNone(BaseException):
@classmethod
def __new__(*args, **kwargs):
return None
try:
raise IndexError from ConstructsNone
except TypeError as e:

@@ -264,9 +267,10 @@ class TestCause(__TestCase):
self.fail("No exception raised")

def test_erroneous_cause(self):
class MyException(Exception):
def __init__(self):
raise RuntimeError()
with torch._dynamo.error_on_graph_break(False):
class MyException(Exception):
def __init__(self):
raise RuntimeError()

try:
raise IndexError from MyException

@@ -447,11 +451,12 @@ class TestContext(__TestCase):
self.fail("No exception raised")

def test_context_manager(self):
class ContextManager:
def __enter__(self):
pass
def __exit__(self, t, v, tb):
xyzzy
with torch._dynamo.error_on_graph_break(False):
class ContextManager:
def __enter__(self):
pass
def __exit__(self, t, v, tb):
xyzzy
try:
with ContextManager():
1/0

@@ -526,12 +531,13 @@ class TestContext(__TestCase):
import gc
# A re-raised exception in a __del__ caused the __context__
# to be cleared
class C:
def __del__(self):
try:
1/0
except:
raise
with torch._dynamo.error_on_graph_break(False):
class C:
def __del__(self):
try:
1/0
except:
raise

def f():
x = C()

@@ -916,43 +916,41 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
dedent(
"""\
SeqNr|OrigAten|SrcFn|FwdSrcFn
0|aten.convolution.default|l__self___conv1|
0|aten.add.Tensor|l__self___bn1|
1|aten._native_batch_norm_legit_functional.default|l__self___bn1|
2|aten.relu.default|l__self___relu1|
2|aten.detach.default|l__self___relu1|
2|aten.detach.default|l__self___relu1|
0|aten.convolution.default|conv2d|
0|aten.add.Tensor|add_|
1|aten._native_batch_norm_legit_functional.default|batch_norm|
2|aten.relu.default|relu|
2|aten.detach.default|relu|
3|aten.add.Tensor|add|
4|aten.view.default|flatten|
5|aten.view.default|l__self___fc1|
6|aten.t.default|l__self___fc1|
7|aten.addmm.default|l__self___fc1|
8|aten.view.default|l__self___fc1|
9|aten.sub.Tensor|l__self___loss_fn|
10|aten.abs.default|l__self___loss_fn|
11|aten.mean.default|l__self___loss_fn|
11|aten.ones_like.default||l__self___loss_fn
11|aten.expand.default||l__self___loss_fn
11|aten.div.Scalar||l__self___loss_fn
10|aten.sgn.default||l__self___loss_fn
10|aten.mul.Tensor||l__self___loss_fn
8|aten.view.default||l__self___fc1
7|aten.t.default||l__self___fc1
7|aten.mm.default||l__self___fc1
7|aten.t.default||l__self___fc1
7|aten.mm.default||l__self___fc1
7|aten.t.default||l__self___fc1
7|aten.sum.dim_IntList||l__self___fc1
7|aten.view.default||l__self___fc1
6|aten.t.default||l__self___fc1
5|aten.view.default||l__self___fc1
5|aten.view.default|linear|
6|aten.t.default|linear|
7|aten.addmm.default|linear|
8|aten.view.default|linear|
9|aten.sub.Tensor|l1_loss|
10|aten.abs.default|l1_loss|
11|aten.mean.default|l1_loss|
11|aten.ones_like.default||l1_loss
11|aten.expand.default||l1_loss
11|aten.div.Scalar||l1_loss
10|aten.sgn.default||l1_loss
10|aten.mul.Tensor||l1_loss
8|aten.view.default||linear
7|aten.t.default||linear
7|aten.mm.default||linear
7|aten.t.default||linear
7|aten.mm.default||linear
7|aten.t.default||linear
7|aten.sum.dim_IntList||linear
7|aten.view.default||linear
6|aten.t.default||linear
5|aten.view.default||linear
4|aten.view.default||flatten
2|aten.detach.default||l__self___relu1
2|aten.detach.default||l__self___relu1
2|aten.threshold_backward.default||l__self___relu1
1|aten.native_batch_norm_backward.default||l__self___bn1
0|aten.convolution_backward.default||l__self___conv1
11|aten.add.Tensor||l__self___loss_fn
2|aten.detach.default||relu
2|aten.threshold_backward.default||relu
1|aten.native_batch_norm_backward.default||batch_norm
0|aten.convolution_backward.default||conv2d
11|aten.add.Tensor||l1_loss
"""
),
)