Compare commits

..

4 Commits

SHA1 Message Date
e1d39f3249 decouple fx graph partition with cudagraph wrapper 2025-10-21 23:21:30 -07:00
e43984ddf5 lint 2025-10-20 17:41:41 -07:00
c1498ebb0d nit 2025-10-20 17:37:28 -07:00
f63ef9d7d8 init 2025-10-20 16:28:49 -07:00
432 changed files with 4832 additions and 9691 deletions

View File

@ -19,7 +19,7 @@ pip_install \
transformers==4.36.2
pip_install coloredlogs packaging
pip_install onnxruntime==1.23.1
pip_install onnxruntime==1.23.0
pip_install onnxscript==0.5.4
# Cache the transformers model to be used later by ONNX tests. We need to run the transformers

View File

@ -334,12 +334,12 @@ sympy==1.13.3
#Pinned versions:
#test that import:
onnx==1.19.1
onnx==1.18.0
#Description: Required by onnx tests, and mypy and test_public_bindings.py when checking torch.onnx._internal
#Pinned versions:
#test that import:
onnxscript==0.5.4
onnxscript==0.5.3
#Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal
#Pinned versions:
#test that import:

View File

@ -6,7 +6,7 @@ dependencies = [
"GitPython==3.1.45",
"docker==7.1.0",
"pytest==7.3.2",
"uv==0.9.5"
"uv==0.8.6"
]
[tool.setuptools]

View File

@ -163,13 +163,8 @@ if [[ "$(uname)" != Darwin ]]; then
MEMORY_LIMIT_MAX_JOBS=12
NUM_CPUS=$(( $(nproc) - 2 ))
if [[ "$(uname)" == Linux ]]; then
# Defaults here for **binary** linux builds so they can be changed in one place
export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))}
else
# For other builds
export MAX_JOBS=${NUM_CPUS}
fi
# Defaults here for **binary** linux builds so they can be changed in one place
export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))}
cat >>"$envfile" <<EOL
export MAX_JOBS="${MAX_JOBS}"

View File

@ -124,10 +124,3 @@ runs:
id: login-ecr
continue-on-error: true
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
- name: Preserve github env variables for use in docker
shell: bash
run: |
env | grep '^GITHUB' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
env | grep '^CI' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
env | grep '^RUNNER' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"

View File

@ -1 +1 @@
1752fe6809b74921644866275ab80244b96e80bc
faffd5cf673615583da6517275e361cb3dbc77e6

View File

@ -15,11 +15,6 @@
- "module: reinplacing"
then:
- "module: pt2-dispatcher"
- any:
- "vllm-compile"
then:
- "module: vllm"
- "oncall: pt2"
- any:
- "module: vmap"
then:
@ -32,6 +27,10 @@
- "module: pt2 optimizer"
then:
- "module: dynamo"
- any:
- "module: flex attention"
then:
- "module: higher order operators"
- any:
- "module: aotinductor"
then:

View File

@ -33,7 +33,6 @@ ciflow_push_tags:
- ciflow/rocm
- ciflow/rocm-mi300
- ciflow/rocm-mi355
- ciflow/rocm-navi31
- ciflow/s390
- ciflow/slow
- ciflow/torchbench

View File

@ -79,9 +79,9 @@ jobs:
runs-on: "windows-11-arm64-preview"
{%- else %}
{%- if branches == "nightly" %}
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
{%- else %}
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge.nonephemeral"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
{%- endif %}
{%- endif %}
timeout-minutes: !{{ common.timeout_minutes_windows_binary }}

View File

@ -44,7 +44,7 @@ jobs:
libtorch-cpu-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -291,7 +291,7 @@ jobs:
libtorch-cuda12_6-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -541,7 +541,7 @@ jobs:
libtorch-cuda12_8-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -791,7 +791,7 @@ jobs:
libtorch-cuda13_0-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch

View File

@ -44,7 +44,7 @@ jobs:
libtorch-cpu-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -291,7 +291,7 @@ jobs:
libtorch-cuda12_6-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -541,7 +541,7 @@ jobs:
libtorch-cuda12_8-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -791,7 +791,7 @@ jobs:
libtorch-cuda13_0-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch

View File

@ -44,7 +44,7 @@ jobs:
wheel-py3_10-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -279,7 +279,7 @@ jobs:
wheel-py3_10-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -517,7 +517,7 @@ jobs:
wheel-py3_10-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -755,7 +755,7 @@ jobs:
wheel-py3_10-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -993,7 +993,7 @@ jobs:
wheel-py3_10-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -1229,7 +1229,7 @@ jobs:
wheel-py3_11-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -1464,7 +1464,7 @@ jobs:
wheel-py3_11-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -1702,7 +1702,7 @@ jobs:
wheel-py3_11-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -1940,7 +1940,7 @@ jobs:
wheel-py3_11-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -2178,7 +2178,7 @@ jobs:
wheel-py3_11-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -2414,7 +2414,7 @@ jobs:
wheel-py3_12-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -2649,7 +2649,7 @@ jobs:
wheel-py3_12-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -2887,7 +2887,7 @@ jobs:
wheel-py3_12-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -3125,7 +3125,7 @@ jobs:
wheel-py3_12-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -3363,7 +3363,7 @@ jobs:
wheel-py3_12-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -3599,7 +3599,7 @@ jobs:
wheel-py3_13-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -3834,7 +3834,7 @@ jobs:
wheel-py3_13-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -4072,7 +4072,7 @@ jobs:
wheel-py3_13-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -4310,7 +4310,7 @@ jobs:
wheel-py3_13-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -4548,7 +4548,7 @@ jobs:
wheel-py3_13-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -4784,7 +4784,7 @@ jobs:
wheel-py3_13t-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5019,7 +5019,7 @@ jobs:
wheel-py3_13t-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5257,7 +5257,7 @@ jobs:
wheel-py3_13t-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5495,7 +5495,7 @@ jobs:
wheel-py3_13t-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5733,7 +5733,7 @@ jobs:
wheel-py3_13t-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -5969,7 +5969,7 @@ jobs:
wheel-py3_14-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -6204,7 +6204,7 @@ jobs:
wheel-py3_14-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -6442,7 +6442,7 @@ jobs:
wheel-py3_14-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -6680,7 +6680,7 @@ jobs:
wheel-py3_14-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -6918,7 +6918,7 @@ jobs:
wheel-py3_14-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -7154,7 +7154,7 @@ jobs:
wheel-py3_14t-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -7389,7 +7389,7 @@ jobs:
wheel-py3_14t-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -7627,7 +7627,7 @@ jobs:
wheel-py3_14t-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -7865,7 +7865,7 @@ jobs:
wheel-py3_14t-cuda13_0-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -8103,7 +8103,7 @@ jobs:
wheel-py3_14t-xpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch

View File

@ -88,6 +88,7 @@ jobs:
with:
build-environment: linux-jammy-rocm-py3_10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },

View File

@ -147,16 +147,15 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9
cuda-arch-list: 8.9
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
]}
secrets: inherit

View File

@ -347,8 +347,7 @@ jobs:
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
# This should sync with the build in xpu.yml but xpu uses a larger runner
# sync-tag: linux-xpu-n-build
sync-tag: linux-xpu-n-build
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
build-environment: linux-jammy-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3

View File

@ -45,6 +45,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-noble-rocm-py3.12-mi300
docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" },

View File

@ -42,6 +42,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-noble-rocm-py3.12-mi355
docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },

View File

@ -1,75 +0,0 @@
name: rocm-navi31
on:
push:
tags:
- ciflow/rocm-navi31/*
workflow_dispatch:
schedule:
# We have several schedules so jobs can check github.event.schedule to activate only for a fraction of the runs.
# Also run less frequently on weekends.
- cron: 45 */2 * * 1-5
- cron: 45 4,12 * * 0,6
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true
permissions: read-all
jobs:
target-determination:
if: github.repository_owner == 'pytorch'
name: before-test
uses: ./.github/workflows/target_determination.yml
permissions:
id-token: write
contents: read
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-jammy-rocm-py3_10-build:
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-rocm-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
]}
secrets: inherit
linux-jammy-rocm-py3_10-test:
permissions:
id-token: write
contents: read
name: linux-jammy-rocm-py3_10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-jammy-rocm-py3_10-build
- target-determination
with:
build-environment: linux-jammy-rocm-py3.10
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
tests-to-include: >-
${{ github.event_name == 'schedule' && 'test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs
test_autograd inductor/test_torchinductor inductor/test_kernel_benchmark
inductor/test_pad_mm inductor/test_benchmark_fusion inductor/test_aot_inductor
inductor/test_torchinductor inductor/test_decompose_mem_bound_mm
inductor/test_flex_attention inductor/test_max_autotune' || '' }}
secrets: inherit

View File

@ -26,23 +26,11 @@ jobs:
id-token: write
contents: read
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-jammy-rocm-py3_10-build:
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-rocm-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
sync-tag: rocm-build
@ -71,3 +59,29 @@ jobs:
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-rocm-py3_10-gfx1100-test:
if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
permissions:
id-token: write
contents: read
name: linux-jammy-rocm-py3_10-gfx1100
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-jammy-rocm-py3_10-build
- target-determination
with:
build-environment: linux-jammy-rocm-py3.10
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
]}
tests-to-include: >
test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs
test_autograd inductor/test_torchinductor inductor/test_kernel_benchmark
inductor/test_pad_mm inductor/test_benchmark_fusion inductor/test_aot_inductor
inductor/test_torchinductor inductor/test_decompose_mem_bound_mm
inductor/test_flex_attention inductor/test_max_autotune
secrets: inherit

View File

@ -58,10 +58,8 @@ jobs:
else
COMMIT_SHA="${{ github.sha }}"
fi
{
echo "sha=${COMMIT_SHA}"
echo "tag_name=trunk/${COMMIT_SHA}"
} >> "${GITHUB_OUTPUT}"
echo "sha=${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
echo "tag_name=trunk/${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
- name: Validate commit SHA
run: |
@ -89,7 +87,7 @@ jobs:
echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)"
fi
- name: Create and push tag(s) with retry
- name: Create and push tag with retry
id: check_tag
env:
TAG_NAME: ${{ steps.commit.outputs.tag_name }}
@ -114,23 +112,14 @@ jobs:
return 1
}
# Counters for summary reporting
created_count=0
skipped_count=0
failed_count=0
# Exit early if tag already exists
if check_tag_exists; then
echo "✅ Tag already exists - no action needed"
echo "exists=true" >> "${GITHUB_OUTPUT}"
exit 0
fi
# Always write outputs once on exit
finish() {
set +e
if [ -n "${GITHUB_OUTPUT:-}" ]; then
{
echo "created_count=${created_count}"
echo "skipped_count=${skipped_count}"
echo "failed_count=${failed_count}"
} >> "${GITHUB_OUTPUT}"
fi
}
trap finish EXIT
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
# Retry configuration
MAX_RETRIES=5
@ -205,111 +194,31 @@ jobs:
}
}
# New behavior for push events: enumerate commits in the push and tag each one.
# For workflow_dispatch, retain existing single-SHA behavior.
# Always fetch tags once up front to improve idempotency in loops
git fetch origin --tags --quiet || true
if [ "${{ github.event_name }}" = "push" ]; then
BEFORE_SHA="${{ github.event.before }}"
AFTER_SHA="${{ github.sha }}" # same as event.after
# List commits introduced by this push (old..new), oldest first for stable ordering
commits_file="$(mktemp)"
git rev-list --reverse "${BEFORE_SHA}..${AFTER_SHA}" > "${commits_file}"
if [ ! -s "${commits_file}" ]; then
echo "No new commits found between ${BEFORE_SHA}..${AFTER_SHA}; nothing to tag."
rm -f "${commits_file}"
exit 0
fi
commit_count="$(wc -l < "${commits_file}" | tr -d ' ')"
echo "Found ${commit_count} commit(s) to tag for push:"
while IFS= read -r sha; do
printf ' %s\n' "${sha}"
done < "${commits_file}"
while IFS= read -r sha; do
TAG_NAME="trunk/${sha}"
COMMIT_SHA="${sha}"
# If tag already exists locally or remotely, skip (idempotent)
if check_tag_exists; then
echo "✅ Tag ${TAG_NAME} already exists - skipping"
skipped_count=$((skipped_count + 1))
continue
fi
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
created_count=$((created_count + 1))
else
echo "Tag creation failed after all retry attempts for ${TAG_NAME}"
failed_count=$((failed_count + 1))
fi
done < "${commits_file}"
rm -f "${commits_file}"
if [ "${failed_count}" -gt 0 ]; then
exit 1
fi
# Execute with retry
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
echo "exists=false" >> "${GITHUB_OUTPUT}"
exit 0
else
# workflow_dispatch path (single SHA tagging preserved)
# Exit early if tag already exists
if check_tag_exists; then
echo "✅ Tag already exists - no action needed"
skipped_count=1
exit 0
fi
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
created_count=1
exit 0
else
echo "Tag creation failed after all retry attempts"
failed_count=1
exit 1
fi
echo "Tag creation failed after all retry attempts"
exit 1
fi
- name: Tag creation summary
if: always()
run: |
if [ "${{ github.event_name }}" = "push" ]; then
echo "Trigger: push on main"
echo "Created: ${{ steps.check_tag.outputs.created_count }}"
echo "Skipped (already existed): ${{ steps.check_tag.outputs.skipped_count }}"
echo "Failed: ${{ steps.check_tag.outputs.failed_count }}"
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
echo "✅ Completed tagging for push range ${{ github.event.before }}..${{ github.sha }}"
else
echo "❌ Some tags failed to create for push range ${{ github.event.before }}..${{ github.sha }}"
fi
if [ "${{ steps.check_tag.outputs.exists }}" = "true" ]; then
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
elif [ "${{ job.status }}" = "success" ]; then
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
else
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
if [ "${{ steps.check_tag.outputs.created_count }}" = "0" ]; then
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
else
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
else
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
echo ""
echo "Tag details:"
echo " Name: ${{ steps.commit.outputs.tag_name }}"
echo " Commit: ${{ steps.commit.outputs.sha }}"
echo " Trigger: ${{ github.event_name }}"
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
fi
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
fi
echo ""
echo "Tag details:"
echo " Name: ${{ steps.commit.outputs.tag_name }}"
echo " Commit: ${{ steps.commit.outputs.sha }}"
echo " Trigger: ${{ github.event_name }}"
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
fi

View File

@ -1138,8 +1138,11 @@ command = [
[[linter]]
code = 'WORKFLOWSYNC'
include_patterns = [
'.github/workflows/*.yml',
'.github/workflows/*.yaml',
'.github/workflows/pull.yml',
'.github/workflows/trunk.yml',
'.github/workflows/periodic.yml',
'.github/workflows/mac-mps.yml',
'.github/workflows/slow.yml',
]
command = [
'python3',

View File

@ -289,15 +289,14 @@ IF(USE_FBGEMM_GENAI)
set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
set(fbgemm_genai_cuh
set(fbgemm_genai_mx8mx8bf16_grouped
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
"${FBGEMM_GENAI_SRCS}/"
)
target_include_directories(fbgemm_genai PRIVATE
${FBGEMM_THIRD_PARTY}/cutlass/include
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
${fbgemm_genai_cuh}
${fbgemm_genai_mx8mx8bf16_grouped}
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
)
@ -314,14 +313,13 @@ IF(USE_FBGEMM_GENAI)
# Add additional HIPCC compiler flags for performance
set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
-mllvm
-amdgpu-coerce-illegal-types=1
-mllvm
-enable-post-misched=0
-mllvm
-greedy-reverse-local-assignment=1
-fhip-new-launch-api)
if(DEFINED ROCM_VERSION_DEV AND ROCM_VERSION_DEV VERSION_LESS "7.2.0")
list(PREPEND FBGEMM_GENAI_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-coerce-illegal-types=1)
endif()
# Only compile for gfx942 for now.
# This is rather hacky, I could not figure out a clean solution :(

View File

@ -19,7 +19,6 @@
#include <ATen/detail/MPSHooksInterface.h>
#include <ATen/detail/MTIAHooksInterface.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <ATen/detail/XLAHooksInterface.h>
#include <ATen/detail/XPUHooksInterface.h>
#include <c10/core/QEngine.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
@ -89,8 +88,6 @@ class TORCH_API Context {
return at::detail::getHIPHooks();
} else if (opt_device_type == at::kHPU) {
return at::detail::getHPUHooks();
} else if (opt_device_type == at::kXLA) {
return at::detail::getXLAHooks();
} else {
TORCH_CHECK(
false,
@ -199,7 +196,7 @@ class TORCH_API Context {
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
}
static bool hasXLA() {
return detail::getXLAHooks().hasXLA();
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA);
}
static bool hasXPU() {
return detail::getXPUHooks().hasXPU();

View File

@ -39,7 +39,7 @@ struct HostBlock {
};
template <typename B>
struct alignas(hardware_destructive_interference_size) FreeBlockList {
struct alignas(64) FreeBlockList {
std::mutex mutex_;
std::deque<B*> list_;
};
@ -122,7 +122,7 @@ struct TORCH_API HostStats {
// Struct containing memory allocator summary statistics for host, as they
// are staged for reporting. This is a temporary struct that is used to
// avoid locking the allocator while collecting stats.
struct alignas(hardware_destructive_interference_size) HostStatsStaged {
struct alignas(64) HostStatsStaged {
std::mutex timing_mutex_;
// COUNT: total allocations (active + free)
// LOCK: access to this stat is protected by the allocator's blocks_mutex_
@ -669,7 +669,7 @@ struct CachingHostAllocatorImpl {
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event");
}
alignas(hardware_destructive_interference_size) std::mutex blocks_mutex_;
alignas(64) std::mutex blocks_mutex_;
ska::flat_hash_set<B*> blocks_; // block list
ska::flat_hash_map<void*, B*> ptr_to_block_;
@ -677,17 +677,17 @@ struct CachingHostAllocatorImpl {
// size. This allows us to quickly find a free block of the right size.
// We use deque to store per size free list and guard the list with its own
// mutex.
alignas(hardware_destructive_interference_size) std::vector<FreeBlockList<B>> free_list_ =
alignas(64) std::vector<FreeBlockList<B>> free_list_ =
std::vector<FreeBlockList<B>>(MAX_SIZE_INDEX);
alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
alignas(64) std::mutex events_mutex_;
std::deque<std::pair<E, B*>> events_; // event queue paired with block
// Indicates whether the object is active.
// Set to false in the destructor to signal background threads to stop.
std::atomic<bool> active_{true};
protected:
alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
alignas(64) HostStatsStaged stats_;
};
struct TORCH_API HostAllocator : public at::Allocator {
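Note: the hunk above toggles between alignas(hardware_destructive_interference_size) and alignas(64) on the allocator's mutexes, free lists, and staged stats. Both spellings do the same thing: pad hot members onto their own cache lines so threads taking different locks do not false-share a line. A minimal standalone C++ sketch follows; it assumes the constant in question behaves like the C++17 std::hardware_destructive_interference_size from <new> (the repository may define its own), and the struct and member names are illustrative only.

#include <mutex>
#include <new>  // std::hardware_destructive_interference_size (C++17; needs a recent standard library)

// Each mutex starts on its own cache line, so two threads taking different
// locks do not ping-pong the same line between cores.
struct PaddedLocksPortable {
  alignas(std::hardware_destructive_interference_size) std::mutex blocks_mutex_;
  alignas(std::hardware_destructive_interference_size) std::mutex events_mutex_;
};

// The other spelling in the hunk hard-codes 64 bytes, the common cache-line
// size on x86-64 and many AArch64 parts; some compilers also warn that the
// library constant may change across versions when it affects struct layout.
struct PaddedLocksFixed {
  alignas(64) std::mutex blocks_mutex_;
  alignas(64) std::mutex events_mutex_;
};

static_assert(alignof(PaddedLocksFixed) == 64, "members begin on 64-byte boundaries");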

View File

@ -59,7 +59,9 @@ struct TORCH_API Generator {
explicit Generator(c10::intrusive_ptr<c10::GeneratorImpl> gen_impl)
: impl_(std::move(gen_impl)) {
TORCH_CHECK(impl_.get(), "GeneratorImpl with nullptr is not supported");
if (impl_.get() == nullptr) {
throw std::runtime_error("GeneratorImpl with nullptr is not supported");
}
}
bool operator==(const Generator& rhs) const {
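Note: this hunk, and several that follow in TensorBase, InternedStrings, IValue, and the JIT type headers, toggle between the TORCH_CHECK macro and an explicit null test that throws std::runtime_error. A rough sketch of why the two spellings are interchangeable for this guard is below; MY_CHECK is a hypothetical stand-in, not a real PyTorch name, and the real TORCH_CHECK throws a richer error type — only the control flow is mirrored here.

#include <memory>
#include <stdexcept>
#include <utility>

// Hypothetical stand-in for a CHECK-style macro: evaluate the condition and
// throw with the supplied message when it fails.
#define MY_CHECK(cond, msg)            \
  do {                                 \
    if (!(cond)) {                     \
      throw std::runtime_error(msg);   \
    }                                  \
  } while (0)

struct Impl {};

struct Holder {
  explicit Holder(std::shared_ptr<Impl> impl) : impl_(std::move(impl)) {
    // Macro form of the guard:
    MY_CHECK(impl_.get() != nullptr, "Impl with nullptr is not supported");
    // Equivalent spelled-out form, as in the hunk above:
    // if (impl_.get() == nullptr) {
    //   throw std::runtime_error("Impl with nullptr is not supported");
    // }
  }
  std::shared_ptr<Impl> impl_;
};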

View File

@ -111,7 +111,9 @@ class TORCH_API TensorBase {
explicit TensorBase(
c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
: impl_(std::move(tensor_impl)) {
TORCH_CHECK(impl_.get(), "TensorImpl with nullptr is not supported");
if (impl_.get() == nullptr) {
throw std::runtime_error("TensorImpl with nullptr is not supported");
}
}
TensorBase(const TensorBase&) = default;
TensorBase(TensorBase&&) noexcept = default;

View File

@ -68,7 +68,11 @@ Symbol InternedStrings::_symbol(const std::string& s) {
return it->second;
auto pos = s.find("::");
TORCH_CHECK(pos != std::string::npos, "all symbols must have a namespace, <namespace>::<string>, but found: ", s);
if (pos == std::string::npos) {
std::stringstream ss;
ss << "all symbols must have a namespace, <namespace>::<string>, but found: " << s;
throw std::runtime_error(ss.str());
}
Symbol ns = _symbol("namespaces::" + s.substr(0, pos));
Symbol sym(sym_to_info_.size());
@ -117,7 +121,12 @@ std::string Symbol::domainString() const {
}
Symbol Symbol::fromDomainAndUnqualString(const std::string & d, const std::string & s) {
TORCH_CHECK(d.compare(0, domain_prefix().size(), domain_prefix()) == 0, "Symbol: domain string is expected to be prefixed with '", domain_prefix(), "', e.g. 'org.pytorch.aten'");
if (d.compare(0, domain_prefix().size(), domain_prefix()) != 0) {
std::ostringstream ss;
ss << "Symbol: domain string is expected to be prefixed with '"
<< domain_prefix() << "', e.g. 'org.pytorch.aten'";
throw std::runtime_error(ss.str());
}
std::string qualString = d.substr(domain_prefix().size()) + "::" + s;
return fromQualString(qualString);
}
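Note: the guard above enforces that every symbol string carries a namespace of the form <namespace>::<string>. A small hedged C++ sketch of that parse follows; split_qualified is an illustrative helper, not part of the codebase.

#include <stdexcept>
#include <string>
#include <utility>

// Split "ns::name" into its namespace and unqualified parts, rejecting
// strings that lack the "::" separator, mirroring the check in the hunk.
static std::pair<std::string, std::string> split_qualified(const std::string& s) {
  const auto pos = s.find("::");
  if (pos == std::string::npos) {
    throw std::runtime_error(
        "all symbols must have a namespace, <namespace>::<string>, but found: " + s);
  }
  return {s.substr(0, pos), s.substr(pos + 2)};
}

// Example: split_qualified("aten::add") yields {"aten", "add"}.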

View File

@ -7,7 +7,6 @@
#include <ATen/core/jit_type.h>
#include <ATen/core/stack.h>
#include <ATen/core/type_factory.h>
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <c10/util/hash.h>
#include <c10/util/irange.h>
@ -413,7 +412,7 @@ size_t IValue::hash(const IValue& v) {
case Tag::Enum:
case Tag::Stream:
case Tag::Uninitialized:
TORCH_CHECK(false,
throw std::runtime_error(
"unhashable type: '" + v.type()->repr_str() + "'");
}
// the above switch should be exhaustive

View File

@ -8,7 +8,6 @@
#include <ATen/core/type_factory.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/TypeList.h>
#include <c10/util/Exception.h>
#include <optional>
#include <c10/core/SymFloat.h>
#include <c10/core/SymBool.h>
@ -117,8 +116,10 @@ struct SingleElementType : public SharedType {
protected:
SingleElementType(TypePtr elem) : SharedType(Kind), elem(std::move(elem)) {
TORCH_CHECK(this->elem, c10::str(
if (!this->elem) {
throw std::runtime_error(c10::str(
"Can not create ", typeKindToString(Kind), " with None type"));
}
}
private:
@ -415,12 +416,16 @@ struct TORCH_API SymbolicShape {
}
ShapeSymbol operator[](size_t i) const {
TORCH_CHECK(dims_, "Rank isn't fixed");
if (!dims_) {
throw std::runtime_error("Rank isn't fixed");
}
return (*dims_).at(i);
}
ShapeSymbol at(size_t i) const {
TORCH_CHECK(dims_, "Rank isn't fixed");
if (!dims_) {
throw std::runtime_error("Rank isn't fixed");
}
return (*dims_).at(i);
}
@ -515,7 +520,9 @@ struct VaryingShape {
}
const std::optional<T> &operator[](size_t i) const {
TORCH_CHECK(dims_, "Rank isn't fixed");
if (!dims_) {
throw std::runtime_error("Rank isn't fixed");
}
return (*dims_).at(i);
}
@ -950,7 +957,9 @@ struct TORCH_API DictType : public SharedType {
TypePtr createWithContained(
std::vector<TypePtr> contained_types) const override {
TORCH_CHECK(contained_types.size() == 2, "Expected 2 contained types");
if (contained_types.size() != 2) {
throw std::runtime_error("Expected 2 contained types");
}
return create(std::move(contained_types.at(0)), std::move(contained_types.at(1)));
}

View File

@ -8,7 +8,6 @@
#include <ATen/core/jit_type.h>
#include <c10/macros/Macros.h>
#include <c10/util/env.h>
#include <c10/util/Exception.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <array>
@ -827,7 +826,9 @@ TupleType::TupleType(
: NamedType(TypeKind::TupleType, std::move(name)),
elements_(std::move(elements)),
has_free_variables_(std::any_of(elements_.begin(), elements_.end(), [](const TypePtr& v) {
TORCH_CHECK(v, "Can not create tuple with None type");
if (!v) {
throw std::runtime_error("Can not create tuple with None type");
}
return v->hasFreeVariables();
})), schema_(std::move(schema)) {

View File

@ -6,11 +6,9 @@
#ifdef __aarch64__
#if !defined(CPU_CAPABILITY_SVE)
#include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h>
#include <ATen/cpu/vec/vec128/vec128_double_neon.h>
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
#include <ATen/cpu/vec/vec128/vec128_int_aarch64.h>
#include <ATen/cpu/vec/vec128/vec128_uint_aarch64.h>
#endif
#include <ATen/cpu/vec/vec128/vec128_convert.h>

View File

@ -354,47 +354,9 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
Vectorized frac() const;
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc)
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt)
#ifdef __ARM_FEATURE_BF16
Vectorized<c10::BFloat16> neg() const {
return -values;
}
Vectorized<c10::BFloat16> reciprocal() const {
return 1.0f / values;
}
Vectorized<c10::BFloat16> operator==(
const Vectorized<c10::BFloat16>& other) const {
return values == other.values;
}
Vectorized<c10::BFloat16> operator!=(
const Vectorized<c10::BFloat16>& other) const {
return values != other.values;
}
Vectorized<c10::BFloat16> operator<(
const Vectorized<c10::BFloat16>& other) const {
return values < other.values;
}
Vectorized<c10::BFloat16> operator<=(
const Vectorized<c10::BFloat16>& other) const {
return values <= other.values;
}
Vectorized<c10::BFloat16> operator>(
const Vectorized<c10::BFloat16>& other) const {
return values > other.values;
}
Vectorized<c10::BFloat16> operator>=(
const Vectorized<c10::BFloat16>& other) const {
return values >= other.values;
}
#else
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=)
@ -402,7 +364,6 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=)
#endif
#undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
#undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
@ -451,52 +412,28 @@ template <>
Vectorized<c10::BFloat16> inline operator+(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
return x + y;
#else
return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
#endif
}
template <>
Vectorized<c10::BFloat16> inline operator-(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
return x - y;
#else
return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
#endif
}
template <>
Vectorized<c10::BFloat16> inline operator*(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
return x * y;
#else
return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
#endif
}
template <>
Vectorized<c10::BFloat16> inline operator/(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
return x / y;
#else
return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
#endif
}
// frac. Implement this here so we can use subtraction
@ -607,19 +544,12 @@ Vectorized<c10::BFloat16> inline fmadd(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
bfloat16x8_t z = c;
return x * y + z;
#else
// NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16! Also,
// vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered
// elements, not the bottom and top half, so they don't seem
// particularly useful here. Ideally we would include dot product in
// the Vectorized interface...
return a * b + c;
#endif
}
template <>
@ -627,15 +557,8 @@ Vectorized<c10::BFloat16> inline fnmadd(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
bfloat16x8_t z = c;
return (-x) * y + z;
#else
// See NOTE [BF16 FMA] above.
return -a * b + c;
#endif
}
template <>
@ -643,15 +566,8 @@ Vectorized<c10::BFloat16> inline fmsub(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
bfloat16x8_t z = c;
return x * y - z;
#else
// See NOTE [BF16 FMA] above.
return a * b - c;
#endif
}
template <>
@ -659,15 +575,8 @@ Vectorized<c10::BFloat16> inline fnmsub(
const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) {
#ifdef __ARM_FEATURE_BF16
bfloat16x8_t x = a;
bfloat16x8_t y = b;
bfloat16x8_t z = c;
return (-x) * y - z;
#else
// See NOTE [BF16 FMA] above.
return -a * b - c;
#endif
}
#endif // !defined(C10_MOBILE) && defined(__aarch64__)
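Note: the removed #ifdef __ARM_FEATURE_BF16 branches above used native bfloat16 NEON arithmetic; the remaining path widens to float, operates there, and narrows back. A minimal scalar sketch of that widen/narrow fallback follows; the helpers are illustrative only — the real code goes through Vectorized<float> and c10::BFloat16, which rounds to nearest-even rather than truncating.

#include <cstdint>
#include <cstring>

// bfloat16 is the top 16 bits of an IEEE-754 float32.
static float bf16_to_float(uint16_t b) {
  const uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

static uint16_t float_to_bf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);  // truncation; production code rounds to nearest-even
}

// "Do the math in float" fallback, e.g. for fmadd: a * b + c.
static uint16_t bf16_fmadd(uint16_t a, uint16_t b, uint16_t c) {
  return float_to_bf16(bf16_to_float(a) * bf16_to_float(b) + bf16_to_float(c));
}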

View File

@ -1,586 +0,0 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>
#include <cmath>
namespace at::vec {
// Note [CPU_CAPABILITY namespace]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This header, and all of its subheaders, will be compiled with
// different architecture flags for each supported set of vector
// intrinsics. So we need to make sure they aren't inadvertently
// linked together. We do this by declaring objects in an `inline
// namespace` which changes the name mangling, but can still be
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {
template <>
struct is_vec_specialized_for<double> : std::bool_constant<true> {};
template <>
class Vectorized<double> {
private:
float64x2_t values;
public:
using value_type = double;
using size_type = int;
static constexpr size_type size() {
return 2;
}
Vectorized() {
values = vdupq_n_f64(0.0);
}
Vectorized(float64x2_t v) : values(v) {}
Vectorized(double val) {
values = vdupq_n_f64(val);
}
template <
typename... Args,
typename = std::enable_if_t<(sizeof...(Args) == size())>>
Vectorized(Args... vals) {
__at_align__ double buffer[size()] = {vals...};
values = vld1q_f64(buffer);
}
operator float64x2_t() const {
return values;
}
template <int64_t mask>
static Vectorized<double> blend(
const Vectorized<double>& a,
const Vectorized<double>& b) {
// Build an array of flags: each bit of element is 1 if the corresponding
// bit in 'mask' is set, 0 otherwise.
uint64x2_t maskArray = {
(mask & 1ULL) ? 0xFFFFFFFFFFFFFFFF : 0,
(mask & 2ULL) ? 0xFFFFFFFFFFFFFFFF : 0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_f64(maskArray, b.values, a.values);
}
static Vectorized<double> blendv(
const Vectorized<double>& a,
const Vectorized<double>& b,
const Vectorized<double>& mask_) {
return vbslq_f64(vreinterpretq_u64_f64(mask_.values), b.values, a.values);
}
template <typename step_t>
static Vectorized<double> arange(
double base = 0.,
step_t step = static_cast<step_t>(1)) {
return {base, base + static_cast<double>(step)};
}
static inline Vectorized<double> set(
const Vectorized<double>& a,
const Vectorized<double>& b,
int64_t count = size()) {
if (count == 0) {
return a;
} else if (count >= 2) {
return b;
} else {
float64x2_t c = {b.values[0], a.values[1]};
return c;
}
}
static Vectorized<double> loadu(const void* ptr, int64_t count = size()) {
if (count == size()) {
return vld1q_f64(reinterpret_cast<const double*>(ptr));
} else if (count == 1) {
float64x1_t x = vld1_f64(reinterpret_cast<const double*>(ptr));
float64x1_t z = {0.0};
return vcombine_f64(x, z);
} else {
return vdupq_n_f64(0.0);
}
}
void store(void* ptr, int64_t count = size()) const {
if (count == size()) {
vst1q_f64(reinterpret_cast<double*>(ptr), values);
} else if (count == 1) {
vst1_f64(reinterpret_cast<double*>(ptr), vget_low_f64(values));
}
}
const double& operator[](int idx) const = delete;
double& operator[](int idx) = delete;
int64_t zero_mask() const {
// returns an integer mask where all zero elements are translated to 1-bit
// and others are translated to 0-bit
uint64x2_t cmpReg = vceqzq_f64(values);
uint64x2_t mask = {1, 2};
uint64x2_t res = vandq_u64(cmpReg, mask);
return res[0] | res[1];
}
Vectorized<double> isnan() const {
// NaN check
return vreinterpretq_f64_u32(
vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, values))));
}
bool has_inf_nan() const {
Vectorized<double> x = vsubq_f64(values, values);
float64x2_t r = x.isnan();
uint64x2_t u = vreinterpretq_u64_f64(r);
return u[0] | u[1];
}
Vectorized<double> map(double (*f)(double)) const {
float64x2_t result;
result[0] = f(values[0]);
result[1] = f(values[1]);
return result;
}
Vectorized<double> map2(
const Vectorized<double>& second,
double (*const f)(double, double)) const {
float64x2_t result;
result[0] = f(values[0], second.values[0]);
result[1] = f(values[1], second.values[1]);
return result;
}
Vectorized<double> abs() const {
return vabsq_f64(values);
}
Vectorized<double> angle() const {
auto zero = Vectorized<double>(0.0);
auto pi = Vectorized<double>(c10::pi<double>);
auto tmp = blendv(zero, pi, vreinterpretq_f64_u64(vcltzq_f64(values)));
return blendv(tmp, *this, isnan());
}
Vectorized<double> real() const {
return *this;
}
Vectorized<double> imag() const {
return Vectorized<double>(0.0);
}
Vectorized<double> conj() const {
return *this;
}
Vectorized<double> acos() const {
return USE_SLEEF(
Vectorized<double>(Sleef_acosd2_u10(values)), map(std::acos));
}
Vectorized<double> acosh() const {
return USE_SLEEF(
Vectorized<double>(Sleef_acoshd2_u10(values)), map(std::acosh));
}
Vectorized<double> asin() const {
return USE_SLEEF(
Vectorized<double>(Sleef_asind2_u10(values)), map(std::asin));
}
Vectorized<double> asinh() const {
return USE_SLEEF(
Vectorized<double>(Sleef_asinhd2_u10(values)), map(std::asinh));
}
Vectorized<double> atan() const {
return USE_SLEEF(
Vectorized<double>(Sleef_atand2_u10(values)), map(std::atan));
}
Vectorized<double> atanh() const {
return USE_SLEEF(
Vectorized<double>(Sleef_atanhd2_u10(values)), map(std::atanh));
}
Vectorized<double> atan2(const Vectorized<double>& b) const {USE_SLEEF(
{ return Vectorized<double>(Sleef_atan2d2_u10(values, b)); },
{
__at_align__ double tmp[size()];
__at_align__ double tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = std::atan2(tmp[i], tmp_b[i]);
}
return loadu(tmp);
})} Vectorized<double> copysign(const Vectorized<double>& sign) const {
USE_SLEEF(
{ return Vectorized<double>(Sleef_copysignd2(values, sign)); },
{
__at_align__ double tmp[size()];
__at_align__ double tmp_sign[size()];
store(tmp);
sign.store(tmp_sign);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = std::copysign(tmp[i], tmp_sign[i]);
}
return loadu(tmp);
})} Vectorized<double> erf() const {
return USE_SLEEF(
Vectorized<double>(Sleef_erfd2_u10(values)), map(std::erf));
}
Vectorized<double> erfc() const {
return USE_SLEEF(
Vectorized<double>(Sleef_erfcd2_u15(values)), map(std::erfc));
}
Vectorized<double> exp() const {
return USE_SLEEF(
Vectorized<double>(Sleef_expd2_u10(values)), map(std::exp));
}
Vectorized<double> exp2() const {
return USE_SLEEF(
Vectorized<double>(Sleef_exp2d2_u10(values)), map(std::exp2));
}
Vectorized<double> expm1() const {
return USE_SLEEF(
Vectorized<double>(Sleef_expm1d2_u10(values)), map(std::expm1));
}
Vectorized<double> fmod(const Vectorized<double>& q) const {USE_SLEEF(
{ return Vectorized<double>(Sleef_fmodd2(values, q)); },
{
__at_align__ double tmp[size()];
__at_align__ double tmp_q[size()];
store(tmp);
q.store(tmp_q);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = std::fmod(tmp[i], tmp_q[i]);
}
return loadu(tmp);
})} Vectorized<double> hypot(const Vectorized<double>& b) const {
USE_SLEEF(
{ return Vectorized<double>(Sleef_hypotd2_u05(values, b)); },
{
__at_align__ double tmp[size()];
__at_align__ double tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = std::hypot(tmp[i], tmp_b[i]);
}
return loadu(tmp);
})} Vectorized<double> i0() const {
return map(calc_i0);
}
Vectorized<double> nextafter(const Vectorized<double>& b) const {USE_SLEEF(
{ return Vectorized<double>(Sleef_nextafterd2(values, b)); },
{
__at_align__ double tmp[size()];
__at_align__ double tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (int64_t i = 0; i < size(); ++i) {
tmp[i] = std::nextafter(tmp[i], tmp_b[i]);
}
return loadu(tmp);
})} Vectorized<double> log() const {
return USE_SLEEF(
Vectorized<double>(Sleef_logd2_u10(values)), map(std::log));
}
Vectorized<double> log2() const {
return USE_SLEEF(
Vectorized<double>(Sleef_log2d2_u10(values)), map(std::log2));
}
Vectorized<double> log10() const {
return USE_SLEEF(
Vectorized<double>(Sleef_log10d2_u10(values)), map(std::log10));
}
Vectorized<double> log1p() const {
return USE_SLEEF(
Vectorized<double>(Sleef_log1pd2_u10(values)), map(std::log1p));
}
Vectorized<double> frac() const;
Vectorized<double> sin() const {
return USE_SLEEF(
Vectorized<double>(Sleef_sind2_u10(values)), map(std::sin));
}
Vectorized<double> sinh() const {
return USE_SLEEF(
Vectorized<double>(Sleef_sinhd2_u10(values)), map(std::sinh));
}
Vectorized<double> cos() const {
return USE_SLEEF(
Vectorized<double>(Sleef_cosd2_u10(values)), map(std::cos));
}
Vectorized<double> cosh() const {
return USE_SLEEF(
Vectorized<double>(Sleef_coshd2_u10(values)), map(std::cosh));
}
Vectorized<double> pow(const Vectorized<double>& b) const {USE_SLEEF(
{ return Vectorized<double>(Sleef_powd2_u10(values, b)); },
{
__at_align__ double tmp[size()];
__at_align__ double tmp_b[size()];
store(tmp);
b.store(tmp_b);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = std::pow(tmp[i], tmp_b[i]);
}
return loadu(tmp);
})} // Comparison using the _CMP_**_OQ predicate.
// `O`: get false if an operand is NaN
// `Q`: do not raise if an operand is NaN
Vectorized<double> tan() const {
return USE_SLEEF(
Vectorized<double>(Sleef_tand2_u10(values)), map(std::tan));
}
Vectorized<double> tanh() const {
return USE_SLEEF(
Vectorized<double>(Sleef_tanhd2_u10(values)), map(std::tanh));
}
Vectorized<double> lgamma() const {
return USE_SLEEF(
Vectorized<double>(Sleef_lgammad2_u10(values)), map(std::lgamma));
}
Vectorized<double> erfinv() const {
return map(calc_erfinv);
}
Vectorized<double> exp_u20() const {
return exp();
}
Vectorized<double> fexp_u20() const {
return exp();
}
Vectorized<double> i0e() const {
return map(calc_i0e);
}
Vectorized<double> digamma() const {
return map(calc_digamma);
}
Vectorized<double> igamma(const Vectorized<double>& x) const {
__at_align__ double tmp[size()];
__at_align__ double tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vectorized<double> igammac(const Vectorized<double>& x) const {
__at_align__ double tmp[size()];
__at_align__ double tmp_x[size()];
store(tmp);
x.store(tmp_x);
for (int64_t i = 0; i < size(); i++) {
tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
}
return loadu(tmp);
}
Vectorized<double> ceil() const {
return vrndpq_f64(values);
}
Vectorized<double> floor() const {
return vrndmq_f64(values);
}
Vectorized<double> neg() const {
return vnegq_f64(values);
}
Vectorized<double> round() const {
return vrndiq_f64(values);
}
Vectorized<double> trunc() const {
return vrndq_f64(values);
}
Vectorized<double> sqrt() const {
return vsqrtq_f64(values);
}
Vectorized<double> reciprocal() const {
return vdivq_f64(vdupq_n_f64(1.0), values);
}
Vectorized<double> rsqrt() const {
return vdivq_f64(vdupq_n_f64(1.0), vsqrtq_f64(values));
}
double reduce_add() const {
return vaddvq_f64(values);
}
double reduce_max() const {
return vmaxvq_f64(values);
}
Vectorized<double> operator==(const Vectorized<double>& other) const {
return Vectorized<double>(
vreinterpretq_f64_u64(vceqq_f64(values, other.values)));
}
Vectorized<double> operator!=(const Vectorized<double>& other) const {
float64x2_t r0 = vreinterpretq_f64_u32(
vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, other.values))));
return Vectorized<double>(r0);
}
Vectorized<double> operator<(const Vectorized<double>& other) const {
return Vectorized<double>(
vreinterpretq_f64_u64(vcltq_f64(values, other.values)));
}
Vectorized<double> operator<=(const Vectorized<double>& other) const {
return Vectorized<double>(
vreinterpretq_f64_u64(vcleq_f64(values, other.values)));
}
Vectorized<double> operator>(const Vectorized<double>& other) const {
return Vectorized<double>(
vreinterpretq_f64_u64(vcgtq_f64(values, other.values)));
}
Vectorized<double> operator>=(const Vectorized<double>& other) const {
return Vectorized<double>(
vreinterpretq_f64_u64(vcgeq_f64(values, other.values)));
}
Vectorized<double> eq(const Vectorized<double>& other) const;
Vectorized<double> ne(const Vectorized<double>& other) const;
Vectorized<double> gt(const Vectorized<double>& other) const;
Vectorized<double> ge(const Vectorized<double>& other) const;
Vectorized<double> lt(const Vectorized<double>& other) const;
Vectorized<double> le(const Vectorized<double>& other) const;
};
template <>
Vectorized<double> inline operator+(
const Vectorized<double>& a,
const Vectorized<double>& b) {
return vaddq_f64(a, b);
}
template <>
Vectorized<double> inline operator-(
const Vectorized<double>& a,
const Vectorized<double>& b) {
return vsubq_f64(a, b);
}
template <>
Vectorized<double> inline operator*(
const Vectorized<double>& a,
const Vectorized<double>& b) {
return vmulq_f64(a, b);
}
template <>
Vectorized<double> inline operator/(
const Vectorized<double>& a,
const Vectorized<double>& b) {
return vdivq_f64(a, b);
}
// frac. Implement this here so we can use subtraction
Vectorized<double> inline Vectorized<double>::frac() const {
return *this - this->trunc();
}
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<double> inline maximum(
const Vectorized<double>& a,
const Vectorized<double>& b) {
return vmaxq_f64(a, b);
}
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<double> inline minimum(
const Vectorized<double>& a,
const Vectorized<double>& b) {
return vminq_f64(a, b);
}
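The NaN propagation these comments describe is stricter than std::fmax, which returns the non-NaN operand. A hedged scalar sketch of the intended semantics (the helper name propagating_max is ours, not part of this header):

#include <cmath>
#include <cstdio>
#include <limits>
// Scalar model of the propagating `maximum` above: a NaN in either input
// poisons the result, whereas std::fmax drops the NaN operand.
static double propagating_max(double a, double b) {
  if (std::isnan(a) || std::isnan(b)) {
    return std::numeric_limits<double>::quiet_NaN();
  }
  return a > b ? a : b;
}
int main() {
  const double nan = std::nan("");
  std::printf("std::fmax(nan, 1.0)       = %f\n", std::fmax(nan, 1.0));        // prints 1.000000
  std::printf("propagating_max(nan, 1.0) = %f\n", propagating_max(nan, 1.0));  // prints nan
  return 0;
}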
template <>
Vectorized<double> inline clamp(
const Vectorized<double>& a,
const Vectorized<double>& min,
const Vectorized<double>& max) {
return vminq_f64(max, vmaxq_f64(min, a));
}
template <>
Vectorized<double> inline clamp_max(
const Vectorized<double>& a,
const Vectorized<double>& max) {
return vminq_f64(max, a);
}
template <>
Vectorized<double> inline clamp_min(
const Vectorized<double>& a,
const Vectorized<double>& min) {
return vmaxq_f64(min, a);
}
template <>
Vectorized<double> inline operator&(
const Vectorized<double>& a,
const Vectorized<double>& b) {
return vreinterpretq_f64_u64(
vandq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
}
template <>
Vectorized<double> inline operator|(
const Vectorized<double>& a,
const Vectorized<double>& b) {
return vreinterpretq_f64_u64(
vorrq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
}
template <>
Vectorized<double> inline operator^(
const Vectorized<double>& a,
const Vectorized<double>& b) {
return vreinterpretq_f64_u64(
veorq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
}
inline Vectorized<double> Vectorized<double>::eq(
const Vectorized<double>& other) const {
return (*this == other) & Vectorized<double>(1.0);
}
inline Vectorized<double> Vectorized<double>::ne(
const Vectorized<double>& other) const {
return (*this != other) & Vectorized<double>(1.0);
}
inline Vectorized<double> Vectorized<double>::gt(
const Vectorized<double>& other) const {
return (*this > other) & Vectorized<double>(1.0);
}
inline Vectorized<double> Vectorized<double>::ge(
const Vectorized<double>& other) const {
return (*this >= other) & Vectorized<double>(1.0);
}
inline Vectorized<double> Vectorized<double>::lt(
const Vectorized<double>& other) const {
return (*this < other) & Vectorized<double>(1.0);
}
inline Vectorized<double> Vectorized<double>::le(
const Vectorized<double>& other) const {
return (*this <= other) & Vectorized<double>(1.0);
}
template <>
Vectorized<double> inline fmadd(
const Vectorized<double>& a,
const Vectorized<double>& b,
const Vectorized<double>& c) {
return vfmaq_f64(c, a, b);
}
template <>
Vectorized<double> inline fnmadd(
const Vectorized<double>& a,
const Vectorized<double>& b,
const Vectorized<double>& c) {
return vfmsq_f64(c, a, b);
}
template <>
Vectorized<double> inline fmsub(
const Vectorized<double>& a,
const Vectorized<double>& b,
const Vectorized<double>& c) {
return vfmaq_f64(vnegq_f64(c), a, b);
}
template <>
Vectorized<double> inline fnmsub(
const Vectorized<double>& a,
const Vectorized<double>& b,
const Vectorized<double>& c) {
return vfmsq_f64(vnegq_f64(c), a, b);
}
} // namespace CPU_CAPABILITY
} // namespace at::vec
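For reference, a minimal standalone sketch of the non-SLEEF fallback idiom used by fmod/hypot/nextafter/pow above: spill the two lanes to an aligned buffer, apply the scalar libm routine per lane, then reload. The buffer names and the kLanes constant are illustrative, not taken from this header.

#include <cmath>
#include <cstdio>
int main() {
  constexpr int kLanes = 2;                          // Vectorized<double> on NEON holds 2 lanes
  alignas(16) double tmp[kLanes]   = {7.5, -3.25};   // stands in for store(tmp)
  alignas(16) double tmp_q[kLanes] = {2.0,  1.5};    // stands in for q.store(tmp_q)
  for (int i = 0; i < kLanes; ++i) {
    tmp[i] = std::fmod(tmp[i], tmp_q[i]);            // same scalar kernel the fallback applies per lane
  }
  std::printf("%f %f\n", tmp[0], tmp[1]);            // 1.5 -0.25; loadu(tmp) would then reload the lanes
  return 0;
}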

View File

@ -1,378 +0,0 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>
namespace at::vec {
// Note [CPU_CAPABILITY namespace]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This header, and all of its subheaders, will be compiled with
// different architecture flags for each supported set of vector
// intrinsics. So we need to make sure they aren't inadvertently
// linked together. We do this by declaring objects in an `inline
// namespace` which changes the name mangling, but can still be
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {
#define VEC_UINT_NEON_TEMPLATE(vl, bit) \
template <> \
struct is_vec_specialized_for<uint##bit##_t> : std::bool_constant<true> {}; \
\
template <> \
class Vectorized<uint##bit##_t> { \
using neon_type = uint##bit##x##vl##_t; \
\
private: \
neon_type values; \
\
public: \
using value_type = uint##bit##_t; \
using size_type = int; \
static constexpr size_type size() { \
return vl; \
} \
Vectorized() { \
values = vdupq_n_u##bit(0); \
} \
Vectorized(neon_type v) : values(v) {} \
Vectorized(uint##bit##_t val); \
template < \
typename... Args, \
typename = std::enable_if_t<(sizeof...(Args) == size())>> \
Vectorized(Args... vals) { \
__at_align__ uint##bit##_t buffer[size()] = {vals...}; \
values = vld1q_u##bit(buffer); \
} \
operator neon_type() const { \
return values; \
} \
static Vectorized<uint##bit##_t> loadu( \
const void* ptr, \
uint64_t count = size()); \
void store(void* ptr, uint64_t count = size()) const; \
template <uint64_t mask> \
static Vectorized<uint##bit##_t> blend( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b); \
static Vectorized<uint##bit##_t> blendv( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b, \
const Vectorized<uint##bit##_t>& mask_) { \
return vbslq_u##bit(mask_.values, b, a); \
} \
template <typename step_t> \
static Vectorized<uint##bit##_t> arange( \
value_type base = 0, \
step_t step = static_cast<step_t>(1)); \
static Vectorized<uint##bit##_t> set( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b, \
uint64_t count = size()); \
const uint##bit##_t& operator[](uint idx) const = delete; \
uint##bit##_t& operator[](uint idx) = delete; \
Vectorized<uint##bit##_t> abs() const { \
return values; \
} \
Vectorized<uint##bit##_t> real() const { \
return values; \
} \
Vectorized<uint##bit##_t> imag() const { \
return vdupq_n_u##bit(0); \
} \
Vectorized<uint##bit##_t> conj() const { \
return values; \
} \
Vectorized<uint##bit##_t> neg() const { \
return vreinterpretq_u##bit##_s##bit( \
vnegq_s##bit(vreinterpretq_s##bit##_u##bit(values))); \
} \
uint##bit##_t reduce_add() const { \
return vaddvq_u##bit(values); \
} \
uint##bit##_t reduce_max() const; \
Vectorized<uint##bit##_t> operator==( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vceqq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> operator!=( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> operator<( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vcltq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> operator<=( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vcleq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> operator>( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vcgtq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> operator>=( \
const Vectorized<uint##bit##_t>& other) const { \
return Vectorized<value_type>(vcgeq_u##bit(values, other.values)); \
} \
Vectorized<uint##bit##_t> eq( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> ne( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> gt( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> ge( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> lt( \
const Vectorized<uint##bit##_t>& other) const; \
Vectorized<uint##bit##_t> le( \
const Vectorized<uint##bit##_t>& other) const; \
}; \
template <> \
Vectorized<uint##bit##_t> inline operator+( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return vaddq_u##bit(a, b); \
} \
template <> \
Vectorized<uint##bit##_t> inline operator-( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return vsubq_u##bit(a, b); \
} \
template <> \
Vectorized<uint##bit##_t> inline operator&( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return vandq_u##bit(a, b); \
} \
template <> \
Vectorized<uint##bit##_t> inline operator|( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return vorrq_u##bit(a, b); \
} \
template <> \
Vectorized<uint##bit##_t> inline operator^( \
const Vectorized<uint##bit##_t>& a, \
const Vectorized<uint##bit##_t>& b) { \
return veorq_u##bit(a, b); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::eq( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this == other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ne( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this != other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::gt( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this > other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ge( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this >= other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::lt( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this < other) & Vectorized<uint##bit##_t>(1); \
} \
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::le( \
const Vectorized<uint##bit##_t>& other) const { \
return (*this <= other) & Vectorized<uint##bit##_t>(1); \
}
VEC_UINT_NEON_TEMPLATE(16, 8)
inline uint8_t Vectorized<uint8_t>::reduce_max() const {
return vmaxvq_u8(values);
}
template <>
Vectorized<uint8_t> inline operator*(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
return vmulq_u8(a, b);
}
template <>
inline Vectorized<uint8_t> operator~(const Vectorized<uint8_t>& a) {
return vmvnq_u8(a);
}
inline Vectorized<uint8_t> Vectorized<uint8_t>::operator!=(
const Vectorized<uint8_t>& other) const {
return ~(*this == other);
}
template <>
Vectorized<uint8_t> inline minimum(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
return vminq_u8(a, b);
}
template <>
Vectorized<uint8_t> inline maximum(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
return vmaxq_u8(a, b);
}
template <uint64_t mask>
Vectorized<uint8_t> Vectorized<uint8_t>::blend(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
// Build an array of lane flags: element i is all ones (0xFF) if bit i of
// 'mask' is set, 0 otherwise.
uint8x16_t maskArray = {
(mask & 1LL) ? 0xFF : 0,
(mask & 2LL) ? 0xFF : 0,
(mask & 4LL) ? 0xFF : 0,
(mask & 8LL) ? 0xFF : 0,
(mask & 16LL) ? 0xFF : 0,
(mask & 32LL) ? 0xFF : 0,
(mask & 64LL) ? 0xFF : 0,
(mask & 128LL) ? 0xFF : 0,
(mask & 256LL) ? 0xFF : 0,
(mask & 512LL) ? 0xFF : 0,
(mask & 1024LL) ? 0xFF : 0,
(mask & 2048LL) ? 0xFF : 0,
(mask & 4096LL) ? 0xFF : 0,
(mask & 8192LL) ? 0xFF : 0,
(mask & 16384LL) ? 0xFF : 0,
(mask & 32768LL) ? 0xFF : 0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_u8(maskArray, b.values, a.values);
}
#define VEC_UINT_NEON_OPS(vl, bit) \
inline Vectorized<uint##bit##_t>::Vectorized(uint##bit##_t val) { \
values = vdupq_n_u##bit(val); \
} \
inline Vectorized<uint##bit##_t> Vectorized<uint##bit##_t>::loadu( \
const void* ptr, uint64_t count) { \
if (count == size()) { \
return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(ptr)); \
} else { \
__at_align__ uint##bit##_t tmp_values[size()]; \
for (const auto i : c10::irange(size())) { \
tmp_values[i] = 0; \
} \
std::memcpy( \
tmp_values, \
reinterpret_cast<const uint##bit##_t*>(ptr), \
count * sizeof(uint##bit##_t)); \
return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(tmp_values)); \
} \
} \
inline void Vectorized<uint##bit##_t>::store(void* ptr, uint64_t count) \
const { \
if (count == size()) { \
vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(ptr), values); \
} else { \
uint##bit##_t tmp_values[size()]; \
vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(tmp_values), values); \
std::memcpy(ptr, tmp_values, count * sizeof(uint##bit##_t)); \
} \
}
VEC_UINT_NEON_OPS(16, 8)
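A scalar sketch of the partial-count contract that VEC_UINT_NEON_OPS encodes for loadu/store: loading count < size() elements zero-fills the remaining lanes, and storing count elements leaves the destination tail untouched. Array names below are ours.

#include <cstdint>
#include <cstdio>
#include <cstring>
int main() {
  const uint8_t src[3] = {7, 8, 9};
  uint8_t lanes[16] = {};                    // loadu with count = 3: tail lanes stay zero
  std::memcpy(lanes, src, 3);
  uint8_t dst[16];
  std::memset(dst, 0xAA, sizeof(dst));
  std::memcpy(dst, lanes, 3);                // store with count = 3: dst[3..15] keep 0xAA
  std::printf("%d %d %d %d\n", dst[0], dst[2], dst[3], dst[15]);  // 7 9 170 170
  return 0;
}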
template <typename step_t>
inline Vectorized<uint8_t> Vectorized<uint8_t>::arange(
uint8_t base,
step_t step) {
const Vectorized<uint8_t> base_vec(base);
const Vectorized<uint8_t> step_vec(step);
const uint8x16_t step_sizes = {
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
return vmlaq_u8(base_vec, step_sizes, step_vec);
}
template <>
Vectorized<uint8_t> inline operator>>(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
uint8x16_t x = a;
uint8x16_t bound = vdupq_n_u8(8);
uint8x16_t z = vminq_u8(b, bound);
return x >> z;
}
template <>
Vectorized<uint8_t> inline operator<<(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
uint8x16_t bound = vdupq_n_u8(8);
uint8x16_t z = vminq_u8(b, bound);
return vshlq_u8(a, vreinterpretq_s8_u8(z));
}
inline Vectorized<uint8_t> Vectorized<uint8_t>::set(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b,
uint64_t count) {
if (count == 0) {
return a;
} else if (count >= 16) {
return b;
} else {
// Build an array of lane flags: element i is 0xFF if i < count, 0 otherwise.
uint8x16_t maskArray = {
static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0),
0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_u8(maskArray, b.values, a.values);
}
}
template <>
Vectorized<uint8_t> inline operator/(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& b) {
uint8x16_t x = a;
uint8x16_t y = b;
return x / y;
}
template <>
Vectorized<uint8_t> inline clamp(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& min,
const Vectorized<uint8_t>& max) {
return minimum(max, maximum(min, a));
}
template <>
Vectorized<uint8_t> inline clamp_max(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& max) {
return minimum(max, a);
}
template <>
Vectorized<uint8_t> inline clamp_min(
const Vectorized<uint8_t>& a,
const Vectorized<uint8_t>& min) {
return maximum(min, a);
}
} // namespace CPU_CAPABILITY
} // namespace at::vec
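A standalone sketch of the blend() selection rule above, with the bit-select done in scalar code instead of vbslq_u8: lane i of the result comes from b when bit i of the compile-time mask is set, otherwise from a. Variable names are illustrative.

#include <cstdint>
#include <cstdio>
int main() {
  constexpr uint64_t mask = 0b1010;          // pick lanes 1 and 3 from b
  uint8_t a[16], b[16], out[16];
  for (int i = 0; i < 16; ++i) { a[i] = static_cast<uint8_t>(i); b[i] = static_cast<uint8_t>(100 + i); }
  for (int i = 0; i < 16; ++i) {
    const uint8_t lane_mask = ((mask >> i) & 1) ? 0xFF : 0x00;                // what maskArray holds per lane
    out[i] = static_cast<uint8_t>((lane_mask & b[i]) | (~lane_mask & a[i]));  // scalar vbslq_u8
  }
  std::printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]);               // 0 101 2 103
  return 0;
}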

View File

@ -1390,7 +1390,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
at::vec::Vectorized<uint8_t> src) {
auto u8x8 = vget_low_u8(src);
auto u8x8 = vld1_u8(src.operator const uint8_t*());
auto u16x8 = vmovl_u8(u8x8);
auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8));
auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
@ -1412,7 +1412,7 @@ Vectorized<float> inline convert_int8_half_register_to_float(
Vectorized<float> inline convert_int8_half_register_to_float(
at::vec::Vectorized<uint8_t> src) {
auto u8x8 = vget_low_u8(src);
auto u8x8 = vld1_u8(src.operator const uint8_t*());
auto u16x8 = vmovl_u8(u8x8);
auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));

View File

@ -1,192 +0,0 @@
#include <ATen/cuda/CUDAGreenContext.h>
namespace at::cuda {
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
#if CUDA_HAS_GREEN_CONTEXT
int driver_version;
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
TORCH_CHECK(
driver_version >= 12080, "cuda driver too old to use green context!");
CUcontext pctx = nullptr;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
if (C10_UNLIKELY(!pctx)) {
TORCH_WARN(
"Attempted to create a green context but"
" there was no primary context! Creating a primary context...");
cudaFree(0);
}
CUdevice device;
device_id_ = device_id;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
// Get device resources
CUdevResource device_resource;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
// Split resources
std::vector<CUdevResource> result(1);
auto result_data = result.data();
unsigned int nb_groups = 1;
CUdevResource remaining;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
result_data,
&nb_groups,
&device_resource,
&remaining,
0, // default flags
num_sms));
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
// Generate resource descriptor
CUdevResourceDesc desc;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
&desc, result_data, 1));
// Create green context
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
// Convert to regular context
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
std::unique_ptr<GreenContext> GreenContext::create(
uint32_t num_sms,
std::optional<uint32_t> device_id) {
#if CUDA_HAS_GREEN_CONTEXT
if (!device_id.has_value()) {
device_id = at::cuda::current_device();
}
return std::make_unique<GreenContext>(device_id.value(), num_sms);
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
// Implement move operations
GreenContext::GreenContext(GreenContext&& other) noexcept{
#if CUDA_HAS_GREEN_CONTEXT
device_id_ = std::exchange(other.device_id_, -1);
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
context_ = std::exchange(other.context_, nullptr);
parent_stream_ = std::exchange(other.parent_stream_, nullptr);
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{
#if CUDA_HAS_GREEN_CONTEXT
if (this != &other) {
// Clean up current resources
if (green_ctx_) {
CUcontext current = nullptr;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&current));
if (current == context_) {
TORCH_CHECK(
false,
"attempting to overwrite current green ctx "
"when it is active!");
}
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
}
// Take ownership of other's resources
device_id_ = std::exchange(other.device_id_, -1);
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
context_ = std::exchange(other.context_, nullptr);
parent_stream_ = std::exchange(other.parent_stream_, nullptr);
}
return *this;
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
GreenContext::~GreenContext() noexcept{
#if CUDA_HAS_GREEN_CONTEXT
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
// Get the underlying CUDA context
CUcontext GreenContext::getContext() const {
#if CUDA_HAS_GREEN_CONTEXT
return context_;
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
// Get the underlying green context
#if CUDA_HAS_GREEN_CONTEXT
CUgreenCtx GreenContext::getGreenContext() const {
return green_ctx_;
}
#endif
// Make this context current
void GreenContext::setContext() {
#if CUDA_HAS_GREEN_CONTEXT
auto current_stream = c10::cuda::getCurrentCUDAStream();
parent_stream_ = current_stream.stream();
at::cuda::CUDAEvent ev;
ev.record(current_stream);
CUcontext current = nullptr;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&current));
if (!current) {
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_));
} else {
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_));
}
// currently hardcodes the new green context to use the default stream
// TODO(eqy): consider creating a new stream if e.g., it allows interop
// with CUDA Graph captures etc.
auto default_stream = c10::cuda::getDefaultCUDAStream();
ev.block(default_stream);
c10::cuda::setCurrentCUDAStream(default_stream);
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
void GreenContext::popContext() {
#if CUDA_HAS_GREEN_CONTEXT
// see above note about stream being hardcoded to the default stream
at::cuda::CUDAEvent ev;
ev.record(c10::cuda::getCurrentCUDAStream());
CUcontext popped;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped));
TORCH_INTERNAL_ASSERT(
popped == context_, "expected popped context to be the current ctx");
ev.block(c10::cuda::getStreamFromExternal(parent_stream_, device_id_));
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
} // namespace at::cuda
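A hedged usage sketch for the wrapper above, using only the calls visible in this file (create, setContext, popContext); it assumes a PyTorch CUDA build on a 12.8+ driver and omits error handling.

#include <ATen/cuda/CUDAGreenContext.h>
#include <optional>
void run_on_sm_subset() {
  // Carve a 16-SM green context out of the current device.
  auto gctx = at::cuda::GreenContext::create(/*num_sms=*/16, /*device_id=*/std::nullopt);
  gctx->setContext();    // records an event on the parent stream and makes the green context current
  // ... launch kernels here; they execute on the carved-out SM group ...
  gctx->popContext();    // pops the green context and resynchronizes with the parent stream
}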

View File

@ -1,53 +0,0 @@
#pragma once
#include <ATen/cuda/CUDAEvent.h>
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#include <cuda.h>
#include <memory>
#include <stdexcept>
#include <vector>
#define CUDA_HAS_GREEN_CONTEXT 1
#else
#define CUDA_HAS_GREEN_CONTEXT 0
#endif
namespace at::cuda {
class TORCH_CUDA_CPP_API GreenContext {
public:
GreenContext(uint32_t device_id, uint32_t num_sms);
static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id);
// Delete copy constructor and assignment
GreenContext(const GreenContext&) = delete;
GreenContext& operator=(const GreenContext&) = delete;
// Implement move operations
GreenContext(GreenContext&& other) noexcept;
GreenContext& operator=(GreenContext&& other) noexcept;
~GreenContext() noexcept;
// Get the underlying CUDA context
CUcontext getContext() const;
// Get the underlying green context
#if CUDA_HAS_GREEN_CONTEXT
CUgreenCtx getGreenContext() const;
#endif
// Make this context current
void setContext();
void popContext();
private:
#if CUDA_HAS_GREEN_CONTEXT
int32_t device_id_ = -1;
CUgreenCtx green_ctx_ = nullptr;
CUcontext context_ = nullptr;
cudaStream_t parent_stream_ = nullptr;
#endif
};
} // namespace at::cuda

View File

@ -70,7 +70,11 @@
#define ATEN_CUB_MAXIMUM() NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max()
#endif
#if defined(USE_ROCM)
#if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || defined(USE_ROCM)
#if !defined(USE_ROCM)
namespace at_cuda_detail {
#endif
// backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16
@ -92,6 +96,10 @@ template <>
struct ROCM_HIPCUB(cub)::NumericTraits<c10::BFloat16>:
ROCM_HIPCUB(cub)::BaseTraits<ROCM_HIPCUB(cub)::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {};
#if !defined(USE_ROCM)
} // namespace at_cuda_detail
#endif
#endif
#if !defined(USE_ROCM)
@ -113,7 +121,7 @@ struct cuda_type<c10::Half> {
using type = __half;
};
#if !defined(USE_ROCM)
#if !defined(USE_ROCM) && CUB_SUPPORTS_NV_BFLOAT16()
template<>
struct cuda_type<c10::BFloat16> {
@ -195,6 +203,36 @@ __global__ void transform_vals(InputIteratorT1 a, InputIteratorT2 b, OutputItera
*out = scan_op(static_cast<acc_t>(*a), static_cast<acc_t>(*b));
}
#if !CUB_SUPPORTS_FUTURE_VALUE()
template<typename ValueT, typename InputIteratorT>
struct chained_iterator {
using iterator_category = std::random_access_iterator_tag;
using difference_type = std::ptrdiff_t;
using value_type = ValueT;
using pointer = ValueT*;
using reference = ValueT&;
InputIteratorT iter;
ValueT *first;
difference_type offset = 0;
__device__ ValueT operator[](difference_type i) {
i += offset;
if (i == 0) {
return *first;
} else {
return ValueT(iter[i - 1]);
}
}
__device__ chained_iterator operator+(difference_type i) {
return chained_iterator{iter, first, i};
}
__device__ ValueT operator*() {
return (*this)[0];
}
};
#endif
// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
// so split at int_max/2
constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30
@ -239,6 +277,25 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
first_elem_ptr,
scan_op);
C10_CUDA_KERNEL_LAUNCH_CHECK();
#if !CUB_SUPPORTS_FUTURE_VALUE()
using ArgIndexInputIterator = NO_ROCM(at_cuda_detail)::cub::ArgIndexInputIterator<InputIteratorT>;
using tuple = typename ArgIndexInputIterator::value_type;
auto input_iter_transform = [=] __device__ (const tuple &x)->input_t {
if (x.key == 0) {
return *first_elem_ptr;
} else {
return x.value;
}
};
auto input_ = ATEN_CUB_TRANSFORM_ITERATOR(input_t, decltype(input_iter_transform), ArgIndexInputIterator)(
ArgIndexInputIterator(input + i), input_iter_transform);
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
input_,
output + i,
scan_op,
size_cub,
at::cuda::getCurrentCUDAStream());
#else
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
input + i + 1,
output + i,
@ -246,6 +303,7 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
::at_cuda_detail::cub::FutureValue<input_t>(first_elem_ptr),
size_cub,
at::cuda::getCurrentCUDAStream());
#endif
}
#endif
}
@ -497,6 +555,16 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
first_elem_ptr,
scan_op);
C10_CUDA_KERNEL_LAUNCH_CHECK();
#if !CUB_SUPPORTS_FUTURE_VALUE()
auto input_ = impl::chained_iterator<InitValueT, InputIteratorT>{
input + i, first_elem_ptr};
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
input_,
output + i,
scan_op,
size_cub,
at::cuda::getCurrentCUDAStream());
#else
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
input + i,
output + i,
@ -504,6 +572,7 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
::at_cuda_detail::cub::FutureValue<InitValueT>(first_elem_ptr),
size_cub,
at::cuda::getCurrentCUDAStream());
#endif
}
#endif
}
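A CPU sketch of what the chained_iterator fallback computes in the exclusive_scan path: an inclusive scan over the virtual sequence [*first_elem_ptr, input[0], input[1], ...], which reproduces the FutureValue-style exclusive scan for a sum scan_op. The values below are made up.

#include <cstdio>
int main() {
  const int in[4] = {1, 2, 3, 4};
  const int first = 10;                          // plays the role of *first_elem_ptr
  int out[4];
  int acc = 0;
  for (int i = 0; i < 4; ++i) {
    const int v = (i == 0) ? first : in[i - 1];  // chained_iterator::operator[]
    acc += v;                                    // scan_op = sum
    out[i] = acc;
  }
  std::printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]);  // 10 11 13 16
  return 0;
}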

View File

@ -10,6 +10,14 @@
#define CUB_VERSION 200001
#endif
// cub sort support for __nv_bfloat16 is added to cub 1.13 in:
// https://github.com/NVIDIA/cub/pull/306
#if CUB_VERSION >= 101300
#define CUB_SUPPORTS_NV_BFLOAT16() true
#else
#define CUB_SUPPORTS_NV_BFLOAT16() false
#endif
// cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in:
// https://github.com/NVIDIA/cub/pull/326
// CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake
@ -20,6 +28,14 @@
#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false
#endif
// cub support for cub::FutureValue is added to cub 1.15 in:
// https://github.com/NVIDIA/cub/pull/305
#if CUB_VERSION >= 101500
#define CUB_SUPPORTS_FUTURE_VALUE() true
#else
#define CUB_SUPPORTS_FUTURE_VALUE() false
#endif
// There were many bc-breaking changes in major version release of CCCL v3.0.0
// Please see https://nvidia.github.io/cccl/cccl/3.0_migration_guide.html
#if CUB_VERSION >= 200800

View File

@ -1,23 +0,0 @@
#include <ATen/detail/XLAHooksInterface.h>
namespace at {
namespace detail {
const XLAHooksInterface& getXLAHooks() {
auto create_impl = [] {
// Create XLA hooks using the registry
auto hooks = XLAHooksRegistry()->Create("torch_xla::detail::XLAHooks", XLAHooksArgs{});
if (hooks) {
return hooks;
}
// If hooks creation fails, fall back to default implementation
return std::make_unique<XLAHooksInterface>();
};
static auto hooks = create_impl();
return *hooks;
}
} // namespace detail
C10_DEFINE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs)
} // namespace at

View File

@ -1,79 +0,0 @@
#pragma once
#include <c10/core/Device.h>
#include <c10/util/Exception.h>
#include <c10/util/Registry.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
namespace at {
constexpr const char* XLA_HELP =
"This error has occurred because you are trying "
"to use some XLA functionality, but the XLA library has not been "
"loaded by the dynamic linker. You must load xla libraries by `import torch_xla`";
struct TORCH_API XLAHooksInterface : AcceleratorHooksInterface {
~XLAHooksInterface() override = default;
void init() const override {
TORCH_CHECK(false, "Cannot initialize XLA without torch_xla library. ", XLA_HELP);
}
virtual bool hasXLA() const {
return false;
}
virtual std::string showConfig() const {
TORCH_CHECK(
false,
"Cannot query detailed XLA version without torch_xla library. ",
XLA_HELP);
}
const Generator& getDefaultGenerator(
[[maybe_unused]] DeviceIndex device_index = -1) const override {
TORCH_CHECK(
false, "Cannot get default XLA generator without torch_xla library. ", XLA_HELP);
}
Generator getNewGenerator(
[[maybe_unused]] DeviceIndex device_index = -1) const override {
TORCH_CHECK(false, "Cannot get XLA generator without torch_xla library. ", XLA_HELP);
}
virtual DeviceIndex getCurrentDevice() const override {
TORCH_CHECK(false, "Cannot get current XLA device without torch_xla library. ", XLA_HELP);
}
Device getDeviceFromPtr(void* /*data*/) const override {
TORCH_CHECK(false, "Cannot get device of pointer on XLA without torch_xla library. ", XLA_HELP);
}
Allocator* getPinnedMemoryAllocator() const override {
TORCH_CHECK(false, "Cannot get XLA pinned memory allocator without torch_xla library. ", XLA_HELP);
}
bool isPinnedPtr(const void* data) const override {
return false;
}
bool hasPrimaryContext(DeviceIndex device_index) const override {
TORCH_CHECK(false, "Cannot query primary context without torch_xla library. ", XLA_HELP);
}
};
struct TORCH_API XLAHooksArgs {};
TORCH_DECLARE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs);
#define REGISTER_XLA_HOOKS(clsname) \
C10_REGISTER_CLASS(XLAHooksRegistry, clsname, clsname)
namespace detail {
TORCH_API const XLAHooksInterface& getXLAHooks();
} // namespace detail
} // namespace at
C10_DIAGNOSTIC_POP()
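A compact model of the registry-with-fallback pattern that getXLAHooks() relies on: try the registered factory first, otherwise hand back the stub interface whose methods fail with the XLA_HELP message. All names below (HooksInterface, registered_factory, getHooks) are illustrative stand-ins, not c10 API.

#include <cstdio>
#include <functional>
#include <memory>
struct HooksInterface {
  virtual ~HooksInterface() = default;
  virtual bool hasBackend() const { return false; }             // stub: backend library not loaded
};
std::function<std::unique_ptr<HooksInterface>()>& registered_factory() {
  static std::function<std::unique_ptr<HooksInterface>()> f;    // set by the backend when it loads
  return f;
}
const HooksInterface& getHooks() {
  static auto hooks = [] {
    if (registered_factory()) {
      return registered_factory()();                            // registered implementation wins
    }
    return std::make_unique<HooksInterface>();                  // fall back to the stub
  }();
  return *hooks;
}
int main() {
  std::printf("backend loaded: %d\n", getHooks().hasBackend()); // 0 until a factory registers
  return 0;
}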

View File

@ -3620,7 +3620,7 @@ Tensor& _int_mm_out_cpu(const Tensor& self, const Tensor& mat2, Tensor& result)
try {
mkldnn_matmul_i8i8i32(self, mat2, result);
dispatched = true;
} catch ([[maybe_unused]] const std::exception& e) {
} catch (const std::exception& e) {
TORCH_WARN(func_name, " failed, switching to BLAS gemm: ", e.what());
}
}

View File

@ -11,8 +11,6 @@ inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_facto
"pixel_shuffle expects a positive upscale_factor, but got ",
upscale_factor);
int64_t c = self.size(-3);
TORCH_CHECK_VALUE(upscale_factor <= std::numeric_limits<decltype(upscale_factor)>::max() / upscale_factor,
"upscale factor is too large, (upscale_factor)^2 overflowed: upscale_factor=", upscale_factor);
int64_t upscale_factor_squared = upscale_factor * upscale_factor;
TORCH_CHECK(c % upscale_factor_squared == 0,
"pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "

View File

@ -259,20 +259,11 @@ inline void winograd_f2k3_input_transform_inplace__rvv(
const vfloat32m1_t wd1 = __riscv_vfadd_vv_f32m1(d1, d2, 4);
const vfloat32m1_t wd2 = __riscv_vfsub_vv_f32m1(d2, d1, 4);
const vfloat32m1_t wd3 = __riscv_vfsub_vv_f32m1(d1, d3, 4);
/* GCC 14.2 (RISC-V RVV) ICE workaround:
* Avoid single-statement read-modify-write on MEM_REF like:
* *input_tile_val =
* __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, idx, val);
* This triggers an ICE during GIMPLE lower (gsi_replace / riscv_gimple_fold_builtin)
* with -march=rv64gcv. Use a temporary then write back.
* Do NOT refactor into the single-statement form. Clang is unaffected.
*/
vfloat32m1x4_t tmp_input_tile_val = *input_tile_val;
tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 0, wd0);
tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 1, wd1);
tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 2, wd2);
tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 3, wd3);
*input_tile_val = tmp_input_tile_val;
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wd0);
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wd1);
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 2, wd2);
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 3, wd3);
}
inline void winograd_f2k3_output_transform_inplace__rvv(
@ -286,15 +277,9 @@ inline void winograd_f2k3_output_transform_inplace__rvv(
const vfloat32m1_t wm0 = __riscv_vfadd_vv_f32m1(m0_plus_m1, m2, 4);
const vfloat32m1_t m1_sub_m2 = __riscv_vfsub_vv_f32m1(m1, m2, 4);
const vfloat32m1_t wm1 = __riscv_vfsub_vv_f32m1(m1_sub_m2, m3, 4);
/* GCC 14.2 (RISC-V RVV) ICE workaround — see note above.
* Keep the temporary + write-back pattern to avoid ICE.
* Do NOT rewrite into:
* *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, idx, val);
*/
vfloat32m1x4_t tmp_output_tile_val = *input_tile_val;
tmp_output_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_output_tile_val, 0, wm0);
tmp_output_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_output_tile_val, 1, wm1);
*input_tile_val = tmp_output_tile_val;
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wm0);
*input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wm1);
}
inline vfloat32m1_t
@ -315,17 +300,11 @@ inline void winograd_f2k3_kernel_transform__rvv(
const vfloat32m1_t const_half = __riscv_vfmv_v_f_f32m1(0.5f, 4);
const vfloat32m1_t g0_plus_g2 = __riscv_vfadd_vv_f32m1(g0, g2, 4);
vfloat32m1_t half_g0_plus_g2 = __riscv_vfmul_vv_f32m1(const_half, g0_plus_g2, 4);
/* GCC 14.2 (RISC-V RVV) ICE workaround — see note above.
* Keep the temporary + write-back pattern to avoid ICE.
* Do NOT rewrite into:
* *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, idx, val);
*/
vfloat32m1x4_t tmp_transform = *transform;
tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 0, g0);
tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 1, vmuladdq_f32(half_g0_plus_g2, const_half, g1));
tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 2, vmulsubq_f32(half_g0_plus_g2, const_half, g1));
tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 3, g2);
*transform = tmp_transform;
*transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 0, g0);
*transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 1, vmuladdq_f32(half_g0_plus_g2, const_half, g1));
*transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 2, vmulsubq_f32(half_g0_plus_g2, const_half, g1));
*transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 3, g2);
}
inline vfloat32m1x4_t v4f_transpose4x4__rvv(const vfloat32m1x4_t m) {

View File

@ -120,7 +120,7 @@ static void pow_tensor_scalar_kernel(
} else if (dtype == ScalarType::Half) {
[&]() {
using scalar_t =
c10::impl::ScalarTypeToCPPTypeT<ScalarType::Half>;
decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t);
const auto exp = exp_scalar.to<scalar_t>();
using Vec = Vectorized<scalar_t>;
cpu_kernel_vec(iter,

View File

@ -272,110 +272,28 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa
}
}
/*
* Checks whether DISABLE_ADDMM_CUDA_LT is set.
* Additionally, for ROCM we test whether the architecture supports the Lt.
*/
static bool isGloballyDisabledAddmmCudaLt(const at::Device& device) {
// When hipBLASLt is not supported on the architecture, return true
#ifdef USE_ROCM
static const std::vector<std::string> archs = {
static bool getDisableAddmmCudaLt() {
static const auto env_value = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
if (env_value == "1") {
return true;
}
return false;
}
#ifdef USE_ROCM
static bool isSupportedHipLtROCmArch(int index) {
static const std::vector<std::string> archs = {
"gfx90a", "gfx942",
#if ROCM_VERSION >= 60300
#if ROCM_VERSION >= 60300
"gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
#endif
#if ROCM_VERSION >= 70000
#endif
#if ROCM_VERSION >= 70000
"gfx950", "gfx1150", "gfx1151"
#endif
};
const auto is_hipblas_lt_arch_supported = at::detail::getCUDAHooks().isGPUArch(archs, device.index());
if (!is_hipblas_lt_arch_supported) {
return true;
}
#endif
// Check whether it is disabled in the env
static const auto is_addmm_cuda_lt_disabled = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT");
if (is_addmm_cuda_lt_disabled == "1") {
return true;
}
return false;
}
/*
* Check whether for the given input we want to enable the Lt interface
*/
static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
// Implies a 2D bias, which we currently do not send through Lt.
// TODO: this check is done pre col-major input preparation,
// so, this condition can be relaxed in cases when a col-major
// copy of result is needed.
if (result.is_same(self)) {
return false;
}
#if defined(USE_ROCM) && ROCM_VERSION == 60400
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
const auto args = cublasCommonArgs(mat1, mat2, result);
if (args.transa == 't' && args.transb == 't') {
return false;
}
#endif
const auto mat1_sizes = mat1.sizes();
const auto mat2_sizes = mat2.sizes();
#if defined(CUDA_VERSION) || defined(USE_ROCM)
const auto scalar_type = mat1.scalar_type();
return (beta.toComplexDouble() == 1.0
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
// is to use lt interface only when self is bias.
&& self.dim() == 1 && self.sizes()[0] == mat2_sizes[1] && self.is_contiguous()
&& result.dim() == 2 && result.is_contiguous()
&& ( // some dtype restrictions
#ifndef USE_ROCM
scalar_type == at::ScalarType::Double ||
#endif
scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16
)
&& ( // some shape/stride restrictions
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
// NOTE: extension to mat1 because mat1/mat2 can be swapped based off
// their row-/col-majorness.
mat1_sizes[0] > 1 && mat1_sizes[1] > 1 &&
mat2_sizes[0] > 1 && mat2_sizes[1] > 1
// The last conditions are to skip 16b transA and non-trans-B having
// leading dim >> rows when they are sliced from a large tensor
// see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
#if !(defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
// Related to avoiding the leading stride >> leading dim problematic case
// with 16b dtypes described above. For such dtypes we only allow inputs
// which are either row- or col-major (i.e. non-overlapping, compact memory layout).
// In that case the leading stride will be equal to the outer dim len.
// Why do we catch this case here? The following `prepare_matrix_for_cublas` method
// does not modify inputs as long as there is a stride of length 1
// and the leading stride is at least max(1, other dim length), so we might
// end up with contiguous cols but not rows (i.e. holes between different rows)
// and vice versa.
&& mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
&& (
// filter by dtype
(scalar_type != at::ScalarType::Half && scalar_type != at::ScalarType::BFloat16) ||
// check mat1/mat2 is row-/col-major
(mat1.is_non_overlapping_and_dense() && mat2.is_non_overlapping_and_dense())
)
#endif
)
);
#endif
// no compliance by default
return false;
#endif
};
return at::detail::getCUDAHooks().isGPUArch(archs, index);
}
#endif
template <typename scalar_t>
void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const scalar_t* bias, cuda::blas::GEMMAndBiasActivationEpilogue activation) {
@ -417,70 +335,7 @@ void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const
}
}
template <typename scalar_t, typename res_scalar_t = scalar_t>
bool launchGemmAndBiasCublasLt(
// args contains result which is modified
cublasCommonArgs& args,
const Tensor& self,
const Scalar& alpha,
Activation activation = Activation::None
) {
const auto* self_ptr = self.const_data_ptr<scalar_t>();
const auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
// TODO: maybe also return some success state?
launchTunableGemmAndBias<scalar_t>(
args, alpha, self_ptr, activation_to_gemm_and_blas_arg(activation)
);
return true;
}
return at::cuda::blas::gemm_and_bias<scalar_t, res_scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self_ptr,
args.result->data_ptr<res_scalar_t>(),
args.result_ld,
activation_to_gemm_and_blas_arg(activation)
);
}
template <typename scalar_t, typename res_scalar_t = scalar_t>
bool launchGemmCublas(
// args contains result which is modified
cublasCommonArgs& args,
const Scalar& alpha,
const Scalar& beta
) {
at::cuda::blas::gemm<scalar_t, res_scalar_t>(
args.transa,
args.transb,
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
beta.to<at::opmath_type<scalar_t>>(),
args.result->data_ptr<res_scalar_t>(),
args.result_ld
);
return true; // success!
}
Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None, bool disable_addmm_cuda_lt_override=false) {
// Shape checks {
// Make sure to keep addmm_cuda below in sync with this code; it
// preflights a check to try to avoid actually needing to call
// expand().
@ -490,62 +345,105 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
"expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype()
)
if (result.is_same(self)) {
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
TORCH_CHECK(self.sizes()[0] == mat1.sizes()[0], "self dim 0 must match mat1 dim 0");
TORCH_CHECK(self.sizes()[1] == mat2.sizes()[1], "self dim 1 must match mat2 dim 1");
}
// } Shape checks
// NOLINTNEXTLINE(*c-array*)
TensorArg targs[]{{result, "out", 0}, {self, "self", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}};
checkAllSameGPU(__func__, targs);
// Handle whether to use the Lt interface {
static bool persistent_disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device());
IntArrayRef mat1_sizes = mat1.sizes();
IntArrayRef mat2_sizes = mat2.sizes();
IntArrayRef self__sizes;
bool useLtInterface = false;
#if defined(USE_ROCM)
// When hipBLASLt is not supported on the architecture,
// disable_addmm_cuda_lt will always be set to true
static bool disable_addmm_cuda_lt =
!isSupportedHipLtROCmArch(self.device().index()) || getDisableAddmmCudaLt();
#else
static bool disable_addmm_cuda_lt = getDisableAddmmCudaLt();
#endif
// if the lt path fails, we recurse back into this function here and force the lt path off
// we cannot update the variable disable_addmm_cuda_lt from above since it is static and would be permanent
bool disable_addmm_cuda_lt = persistent_disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
#ifdef USE_ROCM
// Conditioned on the device index, which is not persistent
disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt;
#endif
// Condition on the input
disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha) || disable_addmm_cuda_lt;
// }
bool disable_addmm_cuda_lt_final = disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
#if defined(USE_ROCM) && ROCM_VERSION == 60400
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
cublasCommonArgs _args(mat1, mat2, result);
if (_args.transa == 't' && _args.transb == 't') {
disable_addmm_cuda_lt_final = true;
}
#endif
at::ScalarType scalar_type = mat1.scalar_type();
bool is_float_output_with_half_input = (scalar_type == at::ScalarType::Half || scalar_type == at::ScalarType::BFloat16) && result.scalar_type() == at::ScalarType::Float;
c10::MaybeOwned<Tensor> self_;
if (&result != &self) {
#if defined(CUDA_VERSION) || defined(USE_ROCM)
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
// is to use lt interface only when self is bias.
// for cuda 11.4, cublasLtMatmul is activated
// the last two conditions are to skip 16b transA and non-trans-B having
// leading dim >> rows when they are sliced from a large tensor
// see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
if (!disable_addmm_cuda_lt_final) {
useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
self.is_contiguous() && result.is_contiguous() &&
#ifdef USE_ROCM
(scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16) &&
#else
(scalar_type == at::ScalarType::Double ||
scalar_type == at::ScalarType::Float ||
scalar_type == at::ScalarType::Half ||
scalar_type == at::ScalarType::BFloat16) &&
#endif
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM))
mat2_sizes[0] > 1 && mat2_sizes[1] > 1;
#else
mat2_sizes[0] > 1 && mat2_sizes[1] > 1 &&
mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
// avoid leading dim >> rows bugs
((mat1.strides()[0] == 1 && mat1.strides()[1] == mat1_sizes[0]) ||
(mat1.strides()[1] == 1 && mat1.strides()[0] == mat1_sizes[1]) ||
(scalar_type != at::ScalarType::Half &&
scalar_type != at::ScalarType::BFloat16)) &&
((mat2.strides()[0] == 1 && mat2.strides()[1] == mat2_sizes[0]) ||
(mat2.strides()[1] == 1 && mat2.strides()[0] == mat2_sizes[1]) ||
(scalar_type != at::ScalarType::Half &&
scalar_type != at::ScalarType::BFloat16));
#endif
}
#endif
if (!useLtInterface) {
self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
}
self__sizes = self_->sizes();
} else {
self_ = c10::MaybeOwned<Tensor>::borrowed(self);
self__sizes = self_->sizes();
TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
TORCH_CHECK(self__sizes[0] == mat1_sizes[0], "self_ dim 0 must match mat1 dim 0");
TORCH_CHECK(self__sizes[1] == mat2_sizes[1], "self_ dim 1 must match mat2 dim 1");
}
// Handle result/self shapes
if (!result.is_same(self)) {
at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]});
const auto self_maybe_expanded = [&]() -> c10::MaybeOwned<Tensor> {
if (disable_addmm_cuda_lt) {
// When in the non-Lt path we expand self even before the
// check for beta != 0.0 to make sure that
// test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_errors_*
// runs green.
return expand_size(self, result.sizes(), "addmm");
}
// copy next, should broadcast
return c10::MaybeOwned<Tensor>::borrowed(self);
}();
// We copy bias when in the non-Lt path
if (beta.toComplexDouble() != 0.0 && disable_addmm_cuda_lt) {
// NOTE: self should broadcast over result
at::native::copy_(result, *self_maybe_expanded);
if (&result != &self) {
at::native::resize_output(result, {mat1_sizes[0], mat2_sizes[1]});
if (beta.toComplexDouble() != 0.0 && !useLtInterface) {
at::native::copy_(result, *self_);
}
}
// Short circuit on empty result
if (result.numel() == 0) {
IntArrayRef result_sizes = result.sizes();
if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
return result;
}
// Short circuit if the reduction dim is empty
if (mat1.sizes()[1] == 0) {
cublasCommonArgs args(mat1, mat2, result);
if (mat1.numel() == 0) {
// By definition, when beta==0, values in self should be ignored. nans and infs
// should not propagate
if (beta.toComplexDouble() == 0.) {
@ -557,64 +455,158 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
result,
self.expand(result.sizes()),
at::native::scalar_tensor(
beta,
self.scalar_type(),
std::nullopt /* layout */,
at::kCPU,
std::nullopt /* pin_memory */
)
);
beta,
self.scalar_type(),
std::nullopt /* layout */,
at::kCPU,
std::nullopt /* pin_memory */));
}
cublasCommonArgs args(mat1, mat2, result);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
// The Lt path
if (!disable_addmm_cuda_lt) {
bool lt_success = false;
if (useLtInterface) {
#if defined(USE_ROCM)
bool okay = true;
if (is_float_output_with_half_input) {
#ifdef USE_ROCM
TORCH_CHECK(false, "float output with half input is not enabled for ROCm");
#else
if (at::cuda::tunable::getTuningContext()->IsTunableOpEnabled()) {
TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(
scalar_type,
"addmm_cuda_lt",
[&] {
lt_success = launchGemmAndBiasCublasLt<scalar_t, float>(args, self, alpha, activation);
}
);
#endif
} else {
// !is_float_output_with_half_input
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
scalar_type,
"addmm_cuda_lt",
[&] {
lt_success = launchGemmAndBiasCublasLt<scalar_t>(args, self, alpha, activation);
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
launchTunableGemmAndBias<scalar_t>(
args,
alpha,
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
activation_to_gemm_and_blas_arg(activation));
} else {
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
// This condition is needed for mm case on ROCm for hipblasLt path.
// Passing the bias ptr as null to avoid accuracy issues for mm case.
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
args.result->data_ptr<scalar_t>(),
args.result_ld,
activation_to_gemm_and_blas_arg(activation)
);
}
);
} // end is_float_output_with_half_input
if (!lt_success) {
// lt path failed; recurse but disable lt path
});
}
if (!okay) {
// lt path failed; recurse but disable lt path
return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
}
// end Lt path
} else {
// No Lt, we use a GEMM instead
#else
auto activation_epilogue = activation_to_gemm_and_blas_arg(activation);
bool okay = true;
if (is_float_output_with_half_input) {
AT_DISPATCH_REDUCED_FLOATING_TYPES(
scalar_type,
"addmm_cuda_lt",
[&] {
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input");
}
else {
okay = at::cuda::blas::gemm_and_bias<scalar_t, float>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self.const_data_ptr<scalar_t>(),
args.result->data_ptr<float>(),
args.result_ld,
activation_epilogue
);
}});
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
scalar_type,
"addmm_cuda_lt",
[&] {
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
launchTunableGemmAndBias<scalar_t>(
args,
alpha,
self.const_data_ptr<scalar_t>(),
activation_epilogue);
}
else {
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
args.n,
args.k,
alpha.to<at::opmath_type<scalar_t>>(),
args.mata->const_data_ptr<scalar_t>(),
args.lda,
args.matb->const_data_ptr<scalar_t>(),
args.ldb,
self.const_data_ptr<scalar_t>(),
args.result->data_ptr<scalar_t>(),
args.result_ld,
activation_epilogue
);
}});
}
if (!okay) {
// lt path failed; recurse but disable lt path
return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true);
}
#endif
} else
{
if (is_float_output_with_half_input) {
AT_DISPATCH_REDUCED_FLOATING_TYPES(
scalar_type,
"addmm_cuda",
[&] {
launchGemmCublas<scalar_t, float>(args, alpha, beta);
}
);
using opmath_t = at::opmath_type<scalar_t>;
opmath_t alpha_val = alpha.to<opmath_t>();
opmath_t beta_val = beta.to<opmath_t>();
const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
float* result_ptr = args.result->mutable_data_ptr<float>();
at::cuda::blas::gemm<scalar_t, float>(
args.transa,
args.transb,
args.m,
args.n,
args.k,
alpha_val,
mat1_ptr,
args.lda,
mat2_ptr,
args.ldb,
beta_val,
result_ptr,
args.result_ld);
});
} else {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::ScalarType::Half,
@ -622,12 +614,28 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
scalar_type,
"addmm_cuda",
[&] {
launchGemmCublas<scalar_t>(args, alpha, beta);
}
);
using opmath_t = at::opmath_type<scalar_t>;
opmath_t alpha_val = alpha.to<opmath_t>();
opmath_t beta_val = beta.to<opmath_t>();
const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
scalar_t* result_ptr = args.result->mutable_data_ptr<scalar_t>();
at::cuda::blas::gemm<scalar_t>(
args.transa,
args.transb,
args.m,
args.n,
args.k,
alpha_val,
mat1_ptr,
args.lda,
mat2_ptr,
args.ldb,
beta_val,
result_ptr,
args.result_ld);
});
}
// Apply epilogue
switch (activation) {
case Activation::RELU:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
@ -639,14 +647,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
break;
default: break;
}
} // end GEMM path
}
// Preprocessor gate here needs to match the inverse of the check
// gating activation_to_gemm_and_blas_arg above; here we are manually
// performing a post-GELU because we weren't able to use the GELU
// epilogue above.
#if !defined(CUDA_VERSION) && !defined(USE_ROCM)
if (!disable_addmm_cuda_lt && activation == Activation::GELU) {
if (useLtInterface && activation == Activation::GELU) {
at::gelu_(const_cast<Tensor&>(*args.result), "tanh");
}
#endif
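
For reference, a minimal stand-alone sketch of the fallback pattern in this hunk, with hypothetical stand-in functions (not the ATen/cuBLASLt API): the fused gemm_and_bias path is attempted first and, if it declines, the routine re-enters with the Lt path disabled so the plain GEMM plus separate activation epilogue runs instead.

#include <cstdio>

// Sketch only -- the two stubs stand in for the real cuBLASLt / cuBLAS calls (hypothetical names).
bool try_fused_gemm_and_bias() { return false; }   // pretend Lt declined this shape/dtype
void plain_gemm_plus_epilogue() { std::puts("fallback: plain GEMM + activation epilogue"); }

void addmm_sketch(bool disable_lt_path) {
  if (!disable_lt_path && try_fused_gemm_and_bias()) {
    return;  // fused path already applied bias and activation
  }
  // Lt disabled up front, or it reported failure: run the unfused GEMM and
  // apply the activation epilogue afterwards, mirroring the recursion above.
  plain_gemm_plus_epilogue();
}

int main() { addmm_sketch(/*disable_lt_path=*/false); }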

View File

@ -856,13 +856,9 @@ struct type_specialized_kernel_launcher {
out_calc_t output_offset_calculator,
loader_t loader,
storer_t storer) {
constexpr ScalarType sret_t = rt_binary_specializations[arg_index][0];
constexpr ScalarType sarg0_t = rt_binary_specializations[arg_index][1];
constexpr ScalarType sarg1_t = rt_binary_specializations[arg_index][2];
if (ret_t == sret_t && arg0_t == sarg0_t && arg1_t == sarg1_t) {
using cret_t = c10::impl::ScalarTypeToCPPTypeT<sret_t>;
using carg0_t = c10::impl::ScalarTypeToCPPTypeT<sarg0_t>;
using carg1_t = c10::impl::ScalarTypeToCPPTypeT<sarg1_t>;
if (ret_t == rt_binary_specializations[arg_index][0] &&
arg0_t == rt_binary_specializations[arg_index][1] &&
arg1_t == rt_binary_specializations[arg_index][2])
launch_vectorized_templated_kernel<
func_t,
array_t,
@ -870,9 +866,12 @@ struct type_specialized_kernel_launcher {
out_calc_t,
loader_t,
storer_t,
cret_t,
carg0_t,
carg1_t>(
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][0]>::t),
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][1]>::t),
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][2]>::t)>(
numel,
f,
data,
@ -880,7 +879,6 @@ struct type_specialized_kernel_launcher {
output_offset_calculator,
loader,
storer);
}
}
};
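
An illustrative, self-contained analogue of the constexpr dispatch above (the names below are invented for the sketch and are not the c10 implementation): an enum value is mapped to a concrete C++ type at compile time, so the runtime comparison only selects which pre-instantiated branch executes.

#include <cstdint>

enum class ScalarKind { Float, Half, Int };

template <ScalarKind K> struct KindToType;                     // primary template left undefined
template <> struct KindToType<ScalarKind::Float> { using type = float; };
template <> struct KindToType<ScalarKind::Int>   { using type = int32_t; };

template <ScalarKind K>
using KindToTypeT = typename KindToType<K>::type;              // analogue of ScalarTypeToCPPTypeT

static_assert(sizeof(KindToTypeT<ScalarKind::Float>) == 4,
              "the enum-to-type mapping is resolved entirely at compile time");

int main() { return 0; }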

View File

@ -1,17 +1,18 @@
#pragma once
#include <ATen/OpMathType.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TensorIteratorDynamicCasting.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/OpMathType.h>
#include <ATen/native/cuda/thread_constants.h>
#include <thrust/tuple.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <tuple>
namespace at::native {
template<int N>
@ -61,11 +62,7 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
#pragma unroll
for (int i = 0; i < elems_per_thread; i++) {
if (policy.check_inbounds(i)) {
#if defined(__HIP__)
results[i] = c10::guts::apply(f, args[i]);
#else
results[i] = std::apply(f, args[i]);
#endif
}
}
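
A small host-side illustration of the std::apply call in this hunk: the per-thread argument tuple is expanded into a direct functor invocation (c10::guts::apply on the HIP branch plays the same role).

#include <cstdio>
#include <tuple>

int main() {
  auto f = [](int a, float b) { return a + b; };
  std::tuple<int, float> args{2, 0.5f};
  float r = std::apply(f, args);   // equivalent to calling f(2, 0.5f)
  std::printf("%g\n", r);          // prints 2.5
}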

View File

@ -23,7 +23,7 @@ namespace at::native {
// The maximum number of threads in a block
#if defined(USE_ROCM)
constexpr int MAX_BLOCK_SIZE = 1024;
constexpr int MAX_BLOCK_SIZE = 256;
#else
constexpr int MAX_BLOCK_SIZE = 512;
#endif
@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u;
// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
#if defined(USE_ROCM)
int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
#else
int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
#endif
@ -115,23 +115,9 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) {
// first the reductions each thread does separately
scalar_t sum = static_cast<scalar_t>(0);
for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) {
#if defined(USE_ROCM)
constexpr int UNRL = 4; // load deserialization (unroll) factor
scalar_t tmp[UNRL];
for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x*UNRL) {
#pragma unroll
for (int u = 0; u < UNRL; u++)
tmp[u] = op(batch, plane, min((int)tensor.size(2)-1, (int)(x+u*blockDim.x)));
#pragma unroll
for (int u = 0; u < UNRL; u++)
if (x+u*blockDim.x < tensor.size(2))
sum += tmp[u];
}
#else
for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
sum += op(batch, plane, x);
}
#endif
}
__shared__ scalar_t shared[C10_WARP_SIZE];
SumReduceOp<scalar_t> reduce_op;
@ -306,22 +292,6 @@ __global__ void batch_norm_collect_statistics_kernel(
stat_accscalar_t var_n = 0;
int n = 0;
for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) {
#if defined(USE_ROCM)
constexpr int UNRL = 4;
stat_accscalar_t v_[UNRL];
for (int x = threadIdx.x; x < input.size(2); x += blockDim.x*UNRL) {
for (int u = 0; u < UNRL; u++)
v_[u] = input[batch][plane][min(x+u*blockDim.x, input.size(2)-1)];
for (int u = 0; u < UNRL; u++) {
if (x+u*blockDim.x < input.size(2)) {
stat_accscalar_t d1 = v_[u] - avg;
n++;
avg += d1 / n;
var_n += d1 * (v_[u] - avg);
}
}
}
#else
for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) {
stat_accscalar_t v = input[batch][plane][x];
stat_accscalar_t d1 = v - avg;
@ -329,7 +299,6 @@ __global__ void batch_norm_collect_statistics_kernel(
avg += d1 / n;
var_n += d1 * (v - avg);
}
#endif
}
// first warpSum to get one value per thread to
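
The statistics loop above is a per-thread instance of Welford's running mean/variance update; a host-side sketch of the same arithmetic (plain C++, no CUDA) for reference:

#include <cstdio>

int main() {
  float data[6] = {2, 4, 4, 4, 5, 7};
  float avg = 0, var_n = 0;
  int n = 0;
  for (float v : data) {
    float d1 = v - avg;   // same update order as the kernel: d1, then avg, then var_n
    n++;
    avg += d1 / n;
    var_n += d1 * (v - avg);
  }
  std::printf("mean=%f biased_var=%f\n", avg, var_n / n);  // mean=4.333333 biased_var=2.222222
}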

View File

@ -653,14 +653,8 @@ struct ReduceOp {
}
__syncthreads();
// Intra-warp reduction, fix CUDA to have offset decreasing for better numerics
// matching Triton, etc.
// TODO(PaulZhang12): AMD and internal
#if defined(USE_ROCM) || defined(FBCODE_CAFFE2)
for (int offset = 1; offset < dim_x; offset <<= 1) {
#else
for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
#endif
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
arg_t other = ops.warp_shfl_down(value[i], offset);

View File

@ -92,16 +92,6 @@ inline thrust::pair<int64_t, int64_t> get_index_mapping2d(
output_offset + output_y * output_dim_x + output_x);
}
__device__ __forceinline__ int64_t reflect_index(int64_t x, int64_t len) {
const int64_t two = (len - 1) * 2;
if (two <= 0) {
return 0;
}
int64_t m = x % two;
if (m < 0) m += two;
return (m < len) ? m : (two - m);
}
template<typename scalar_t>
__global__ void reflection_pad1d_out_kernel(
const scalar_t * input, scalar_t * output,
@ -116,28 +106,6 @@ __global__ void reflection_pad1d_out_kernel(
}
}
template <typename scalar_t>
__global__ void reflection_pad1d_flat(
const scalar_t* __restrict__ input,
scalar_t* __restrict__ output,
int64_t input_w, int64_t pad_l, int64_t pad_r,
int64_t out_w, int64_t plane_count) {
const int64_t bx = blockDim.x;
const int64_t tx = threadIdx.x;
const int64_t total = plane_count * out_w;
const int64_t grid_stride = static_cast<int64_t>(bx) * gridDim.x;
int64_t linear = static_cast<int64_t>(blockIdx.x) * bx + tx;
for (; linear < total; linear += grid_stride) {
const int64_t plane = linear / out_w;
const int64_t x = linear - plane * out_w;
const int64_t j = reflect_index(x - pad_l, input_w);
output[plane * out_w + x] = input[plane * input_w + j];
}
}
template <typename scalar_t>
__global__ void reflection_pad1d_backward_out_kernel(
scalar_t * grad_input, const scalar_t * grad_output,
@ -742,44 +710,25 @@ TORCH_IMPL_FUNC(reflection_pad1d_out_cuda)
int64_t input_w = input_.size(dim_w);
int64_t output_w = input_w + pad_l + pad_r;
dim3 block_size(output_w > 256 ? 256 : output_w);
dim3 grid_size((int)::ceil(output_w / 256.0), nplane, nbatch);
Tensor input = input_.contiguous();
const int block_x = static_cast<int>(std::min<int64_t>(256, std::max<int64_t>(1, output_w)));
const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
const int max_x = prop->maxGridSize[0];
const int max_y = prop->maxGridSize[1];
const int max_z = prop->maxGridSize[2];
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out", [&] {
auto stream = at::cuda::getCurrentCUDAStream();
const int64_t gx = at::ceil_div(output_w, static_cast<int64_t>(block_x));
const bool fits3d = (nplane <= max_y) && (nbatch <= max_z) && (gx <= max_x);
if (fits3d) {
dim3 block(block_x, 1, 1);
dim3 grid(gx, static_cast<unsigned>(nplane), static_cast<unsigned>(nbatch));
reflection_pad1d_out_kernel<scalar_t><<<grid, block, 0, stream>>>(
input.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
input_w, pad_l, pad_r);
} else {
dim3 block(block_x, 1, 1);
const int64_t plane_count = nplane * nbatch;
const int64_t total_blocks = at::ceil_div(plane_count * output_w, static_cast<int64_t>(block_x));
const int grid_x = static_cast<int>(std::min<int64_t>(max_x, std::max<int64_t>(1, total_blocks)));
dim3 grid(grid_x, 1, 1);
reflection_pad1d_flat<scalar_t><<<grid, block, 0, stream>>>(
input.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
input_w, pad_l, pad_r, output_w, plane_count);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
kHalf, kBFloat16, input.scalar_type(), "reflection_pad1d_out_template", [&] {
reflection_pad1d_out_kernel<<<
grid_size,
block_size,
0,
at::cuda::getCurrentCUDAStream()>>>(
input.const_data_ptr<scalar_t>(),
output.mutable_data_ptr<scalar_t>(),
input_w,
pad_l,
pad_r);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
TORCH_IMPL_FUNC(reflection_pad1d_backward_out_cuda)(const Tensor& grad_output_,
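
The reflect_index helper introduced above can be checked in isolation; a host-side copy of its arithmetic, where len = 5 yields the mirror pattern ... 2 1 | 0 1 2 3 4 | 3 2 ...:

#include <cstdint>
#include <cstdio>

static int64_t reflect_index_host(int64_t x, int64_t len) {
  const int64_t two = (len - 1) * 2;
  if (two <= 0) return 0;          // len == 1: every coordinate maps to index 0
  int64_t m = x % two;
  if (m < 0) m += two;             // fold negative coordinates into [0, two)
  return (m < len) ? m : (two - m);
}

int main() {
  for (int64_t x = -2; x <= 6; ++x)   // expected output: 2 1 0 1 2 3 4 3 2
    std::printf("%lld ", (long long)reflect_index_host(x, 5));
  std::printf("\n");
}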

View File

@ -52,7 +52,7 @@ struct FusedAdagradMathFunctor {
using opmath_t = at::opmath_type<scalar_t>;
C10_DEVICE __forceinline__ void operator()(
int64_t chunk_size,
int chunk_size,
FusedOptimizerTensorListMetadata<3>& tl,
const float* lr_ptr,
const double& lr,
@ -133,4 +133,4 @@ struct FusedAdagradMathFunctor {
} // namespace
} // namespace at::native
} // namespace at::native

View File

@ -466,11 +466,7 @@ struct ReduceJitOp {
__syncthreads();
#if defined(USE_ROCM) || defined(FBCODE_CAFFE2)
for (int offset = 1; offset < dim_x; offset <<= 1) {
#else
for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
#endif
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
arg_t other = reducer::warp_shfl_down(value[i], offset);

View File

@ -92,8 +92,13 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
}
// upcasting to float32 if needed to improve precision when multiplying by the scale factor
maskedMM = castMPSTensor(mpsGraph, maskedMM, MPSDataTypeFloat32);
if ([maskedMM dataType] != MPSDataTypeFloat32) {
maskedMM = [mpsGraph castTensor:maskedMM toType:MPSDataTypeFloat32 name:nil];
}
maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil];
if ([maskedMM dataType] != qTensor.dataType) {
maskedMM = [mpsGraph castTensor:maskedMM toType:qTensor.dataType name:nil];
}
if (is_causal) {
auto causalMask = [mpsGraph constantWithScalar:1.0f
@ -107,9 +112,7 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
name:nil];
} else if (attn_mask) {
graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM
secondaryTensor:castMPSTensor(mpsGraph, graph->maskTensor, maskedMM.dataType)
name:nil];
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:graph->maskTensor name:nil];
}
// Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
@ -130,8 +133,8 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
graph->qTensor = qTensor;
graph->kTensor = kTensor;
graph->vTensor = vTensor;
graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
graph->outputTensor = output;
graph->attnTensor = sm;
});
auto qPlaceholder = Placeholder(cachedGraph->qTensor, query);
auto kPlaceholder = Placeholder(cachedGraph->kTensor, key);

View File

@ -338,8 +338,6 @@ static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A,
". See https://developer.apple.com/documentation/metalperformanceshaders/mpsmatrixdecompositionstatus for details.");
}
}
map_mps_decomposition_error_code_to_blas(info);
}
static void linalg_solve_out_mps_impl(const Tensor& A,
@ -1450,6 +1448,20 @@ TORCH_IMPL_FUNC(_linalg_solve_ex_out_mps)
mps::linalg_solve_out_mps_impl(A, B, left, check_errors, result, LU, pivots, info);
}
std::tuple<Tensor&, Tensor&> linalg_lu_factor_out_mps(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) {
Tensor info = at::empty({}, A.options().dtype(kInt));
mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, false);
return std::tie(LU, pivots);
}
std::tuple<Tensor, Tensor> linalg_lu_factor_mps(const Tensor& A, bool pivot) {
Tensor LU = at::empty({0}, A.options());
Tensor pivots = at::empty({0}, A.options().dtype(kInt));
Tensor info = at::empty({}, A.options().dtype(kInt));
mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, false);
return std::make_tuple(std::move(LU), std::move(pivots));
}
TORCH_IMPL_FUNC(lu_unpack_out_mps)
(const Tensor& LU_data,
const Tensor& LU_pivots,

View File

@ -706,7 +706,6 @@
variants: function, method
dispatch:
NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_all
tags: reduction
- func: all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor
@ -716,7 +715,6 @@
cpp_no_default_args: ['dim']
dispatch:
CompositeExplicitAutograd: all_dims_default
tags: reduction
- func: all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -725,7 +723,6 @@
CPU, CUDA: all_out
MPS: all_out_mps
MTIA: all_out_mtia
tags: reduction
- func: all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -734,16 +731,13 @@
CPU, CUDA: all_dims_out
CompositeExplicitAutograd: all_dims_out_default
cpp_no_default_args: ['dim']
tags: reduction
- func: all.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: all.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
tags: reduction
- func: allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool
variants: function, method
@ -755,14 +749,14 @@
device_check: NoCheck # TensorIterator
structured_delegate: any.out
variants: function, method
tags: [core, reduction]
tags: core
- func: any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
structured_delegate: any.dims_out
variants: function, method
cpp_no_default_args: ['dim']
tags: [core, reduction]
tags: core
dispatch:
CompositeExplicitAutograd: any_dims_default
@ -772,7 +766,6 @@
dispatch:
CPU, CUDA: any_out
MPS: any_out_mps
tags: reduction
- func: any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -781,16 +774,13 @@
CPU, CUDA: any_dims_out
CompositeExplicitAutograd: any_dims_out_default
cpp_no_default_args: ['dim']
tags: reduction
- func: any.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: any.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
tags: reduction
- func: arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
dispatch:
@ -836,27 +826,25 @@
structured_delegate: argmax.out
device_check: NoCheck # TensorIterator
variants: function, method
tags: [core, reduction]
tags: core
- func: argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
structured: True
dispatch:
CPU, CUDA: argmax_out
MPS: argmax_out_mps
tags: reduction
- func: argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor
structured_delegate: argmin.out
device_check: NoCheck # TensorIterator
variants: function, method
tags: [core, reduction]
tags: core
- func: argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
structured: True
dispatch:
CPU, CUDA: argmin_out
MPS: argmin_out_mps
tags: reduction
- func: acosh(Tensor self) -> Tensor
variants: function, method
@ -1881,14 +1869,12 @@
CUDA: count_nonzero_cuda
MPS: count_nonzero_mps
autogen: count_nonzero.dim_IntList_out
tags: reduction
- func: count_nonzero(Tensor self, int? dim=None) -> Tensor
variants: function, method
dispatch:
CompositeExplicitAutograd: count_nonzero
autogen: count_nonzero.out
tags: reduction
- func: cov(Tensor self, *, int correction=1, Tensor? fweights=None, Tensor? aweights=None) -> Tensor
variants: function, method
@ -3809,23 +3795,19 @@
variants: function, method
dispatch:
CompositeExplicitAutograd: logsumexp
tags: reduction
- func: logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
dispatch:
# calls squeeze
CompositeExplicitAutogradNonFunctional: logsumexp_out
tags: reduction
- func: logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: logsumexp.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
tags: reduction
- func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor
@ -3875,7 +3857,6 @@
device_check: NoCheck # TensorIterator
structured_delegate: aminmax.out
variants: function, method
tags: reduction
- func: aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max)
device_check: NoCheck # TensorIterator
@ -3883,7 +3864,6 @@
dispatch:
CPU, CUDA, MTIA: aminmax_out
MPS: aminmax_out_mps
tags: reduction
- func: _compute_linear_combination(Tensor input, Tensor coefficients) -> Tensor
dispatch:
@ -3899,7 +3879,7 @@
variants: function, method
dispatch:
QuantizedCPU, QuantizedCUDA: qmax
tags: [core, reduction]
tags: core
- func: max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
device_check: NoCheck # TensorIterator
@ -3909,16 +3889,13 @@
dispatch:
CPU, CUDA, MTIA: max_out
MPS: max_out_mps
tags: reduction
- func: max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
device_check: NoCheck # TensorIterator
tags: reduction
- func: value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, SymInt[] sizes, bool keepdim) -> Tensor
variants: function
@ -3931,14 +3908,13 @@
- func: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
variants: function, method
structured_delegate: amax.out
tags: [core, reduction]
tags: core
- func: amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
structured: True
dispatch:
CPU, CUDA, MTIA: amax_out
MPS: amax_out_mps
tags: reduction
# Return: (Tensor output, Tensor indices)
- func: max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
@ -4000,14 +3976,13 @@
variants: function, method
dispatch:
CompositeExplicitAutograd: mean
tags: [core, reduction]
tags: core
# For normal naming convention this should be `mean.out`. However since we already have `mean.out` we have to rename this.
- func: mean.dtype_out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
dispatch:
CompositeExplicitAutograd: mean_dtype_out
tags: reduction
- func: mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
structured_delegate: mean.out
@ -4015,7 +3990,7 @@
variants: function, method
dispatch:
QuantizedCPU: mean_quantized_cpu
tags: [core, reduction]
tags: core
- func: mean.out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
structured: True
@ -4024,16 +3999,13 @@
CPU, CUDA: mean_out
MPS: mean_out_mps
QuantizedCPU: mean_out_quantized_cpu
tags: reduction
- func: mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: mean.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
tags: reduction
- func: nanmean(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
device_check: NoCheck # Composite
@ -4096,7 +4068,7 @@
variants: function, method
dispatch:
QuantizedCPU, QuantizedCUDA: qmin
tags: [core, reduction]
tags: core
- func: min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)
device_check: NoCheck # TensorIterator
@ -4106,28 +4078,24 @@
dispatch:
CPU, CUDA, MTIA: min_out
MPS: min_out_mps
tags: reduction
- func: min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: min.names_dim_min(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)
device_check: NoCheck # TensorIterator
tags: reduction
- func: amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
variants: function, method
structured_delegate: amin.out
tags: [core, reduction]
tags: core
- func: amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
structured: True
dispatch:
CPU, CUDA, MTIA: amin_out
MPS: amin_out_mps
tags: reduction
# TODO: Add this function to MPS dispatch key so that we avoid declaring it in
# native_functions.yaml
@ -5892,7 +5860,6 @@
SparseCPU, SparseCUDA, SparseMPS, SparseMeta: sum_coo
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sum_csr
autogen: sum.out
tags: reduction
- func: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
# TODO: Align the signature of sum.dim_IntList and _sparse_csr_sum.dim_dtype
@ -5903,12 +5870,11 @@
NestedTensorCPU: NestedTensor_sum_dim_CPU
SparseCPU, SparseCUDA, SparseMPS: sum_sparse_coo
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sum_sparse_compressed
tags: [core, reduction]
tags: core
- func: sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
structured: True
@ -5916,11 +5882,9 @@
dispatch:
CPU, CUDA: sum_out
MPS: sum_out_mps
tags: reduction
- func: sum.DimnameList_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
tags: reduction
# TODO: this function will be replaced once nested expand semantics have been settled on
- func: _nested_sum_backward(Tensor grad, Tensor self, int[1]? dim, bool keepdim=False) -> Tensor
@ -5932,13 +5896,11 @@
dispatch:
CPU, CUDA: nansum
MPS: nansum_mps
tags: reduction
- func: nansum.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA: nansum_out
MPS: nansum_out_mps
tags: reduction
- func: hash_tensor(Tensor self, int[1] dim=[], *, bool keepdim=False, int mode=0) -> Tensor
variants: function, method
@ -6002,13 +5964,11 @@
device_check: NoCheck # TensorIterator
variants: function, method
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
@ -6017,19 +5977,16 @@
CPU, CUDA: std
MPS: std_mps
QuantizedCPU: std_quantized_cpu
tags: reduction
- func: std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
variants: function
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: std_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
variants: function
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
@ -6038,51 +5995,42 @@
CPU, CUDA: std_mean
MPS: std_mean_mps
autogen: std_mean.correction_out
tags: reduction
- func: std_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
variants: function
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: std_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
variants: function
tags: reduction
- func: std.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: std.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA: std_out
QuantizedCPU: std_out_quantized_cpu
tags: reduction
- func: std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: std.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: std.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: std.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: function
tags: reduction
- func: prod(Tensor self, *, ScalarType? dtype=None) -> Tensor
device_check: NoCheck # TensorIterator
@ -6091,13 +6039,13 @@
CPU, CUDA: prod
MPS: prod_mps
autogen: prod.out
tags: [core, reduction]
tags: core
- func: prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
structured_delegate: prod.int_out
device_check: NoCheck # TensorIterator
variants: function, method
tags: [core, reduction]
tags: core
- func: prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
structured: True
@ -6105,16 +6053,13 @@
dispatch:
CPU, CUDA: prod_out
MPS: prod_out_mps
tags: reduction
- func: prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: prod.Dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
tags: reduction
- func: t(Tensor(a) self) -> Tensor(a)
device_check: NoCheck
@ -6575,12 +6520,11 @@
device_check: NoCheck # TensorIterator
variants: function, method
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: [core, reduction]
tags: core
cpp_no_default_args: ["unbiased"]
- func: var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor
@ -6590,51 +6534,43 @@
CPU, CUDA: var
MPS: var_mps
MTIA: var_mtia
tags: [core, reduction]
tags: core
- func: var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: var.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA: var_out
tags: reduction
- func: var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: var.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: var.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: var.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: function
tags: reduction
- func: var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
variants: function
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: var_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
variants: function
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: var_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
@ -6643,18 +6579,15 @@
CPU, CUDA: var_mean
MPS: var_mean_mps
autogen: var_mean.correction_out
tags: reduction
- func: var_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
variants: function
cpp_no_default_args: ["unbiased"]
tags: reduction
- func: var_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)
device_check: NoCheck # TensorIterator
variants: function
tags: reduction
- func: view_as(Tensor(a) self, Tensor other) -> Tensor(a)
variants: method
@ -6914,7 +6847,6 @@
dispatch:
CompositeExplicitAutograd: norm
autogen: norm.ScalarOpt_dtype_out
tags: reduction
- func: norm.Scalar(Tensor self, Scalar p=2) -> Tensor
device_check: NoCheck # TensorIterator
@ -6922,7 +6854,6 @@
dispatch:
CompositeExplicitAutograd: norm
autogen: norm.Scalar_out
tags: reduction
- func: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
structured_delegate: norm.dtype_out
@ -6930,7 +6861,6 @@
variants: function, method
dispatch:
SparseCPU, SparseCUDA, SparseMPS: sparse_dtype_norm
tags: reduction
- func: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor
structured_delegate: norm.out
@ -6938,7 +6868,6 @@
variants: function, method
dispatch:
SparseCPU, SparseCUDA, SparseMPS: sparse_norm
tags: reduction
- func: norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
structured: True
@ -6946,7 +6875,6 @@
dispatch:
CPU, CUDA: norm_dtype_out
MPS: norm_dtype_out_mps
tags: reduction
- func: norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
structured: True
@ -6954,26 +6882,21 @@
dispatch:
CPU, CUDA: norm_out
MPS: norm_out_mps
tags: reduction
# These four redispatch in their implementation, so OK to be CompositeImplicitAutograd
- func: norm.names_ScalarOpt_dim_dtype(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: norm.names_ScalarOpt_dim(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
tags: reduction
- func: norm.names_dtype_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
tags: reduction
- func: norm.names_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
tags: reduction
- func: frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent)
variants: method, function
@ -10159,14 +10082,12 @@
CPU, CUDA: min
MPS: min_mps
QuantizedCPU: min_quantized_cpu
tags: [reduction]
- func: min.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA: min_unary_out
QuantizedCPU: min_quantized_unary_out
tags: [reduction]
- func: fmin(Tensor self, Tensor other) -> Tensor
structured_delegate: fmin.out
@ -10189,7 +10110,6 @@
CPU, CUDA: max
MPS: max_mps
QuantizedCPU: max_quantized_cpu
tags: [reduction]
- func: fmax(Tensor self, Tensor other) -> Tensor
structured_delegate: fmax.out
@ -10236,7 +10156,6 @@
dispatch:
CPU, CUDA: max_unary_out
QuantizedCPU: max_quantized_unary_out
tags: [reduction]
- func: minimum(Tensor self, Tensor other) -> Tensor
structured_delegate: minimum.out
@ -10356,7 +10275,6 @@
device_check: NoCheck # TensorIterator
structured_delegate: all.all_out
variants: method, function
tags: reduction
- func: all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck
@ -10365,7 +10283,6 @@
CPU, CUDA: all_all_out
MTIA: all_all_out_mtia
MPS: all_all_out_mps
tags: reduction
- func: any(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
@ -10373,7 +10290,7 @@
variants: method, function
dispatch:
SparseCPU, SparseCUDA, SparseMPS: any_sparse
tags: [core, reduction]
tags: core
- func: any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck
@ -10381,7 +10298,6 @@
dispatch:
CPU, CUDA: any_all_out
MPS: any_all_out_mps
tags: reduction
- func: renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -14157,10 +14073,16 @@
- func: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots)
python_module: linalg
variants: function
dispatch:
CompositeImplicitAutograd: linalg_lu_factor
MPS: linalg_lu_factor_mps
- func: linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots)
python_module: linalg
variants: function
dispatch:
CompositeImplicitAutograd: linalg_lu_factor_out
MPS: linalg_lu_factor_out_mps
- func: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info)
python_module: linalg
@ -14427,7 +14349,6 @@
python_module: linalg
variants: function
structured_delegate: linalg_vector_norm.out
tags: reduction
- func: linalg_vector_norm.out(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
python_module: linalg
@ -14435,7 +14356,6 @@
dispatch:
CPU, CUDA: linalg_vector_norm_out
MPS: linalg_vector_norm_out_mps
tags: reduction
- func: linalg_matrix_norm(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
python_module: linalg

View File

@ -40,7 +40,15 @@
#include <thrust/iterator/discard_iterator.h>
#if defined(__CUDACC__) && (defined(CUSPARSE_VERSION) || (defined(USE_ROCM) && ROCM_VERSION >= 60300))
#define IS_CUSPARSE11_AVAILABLE() 1
#else
#define IS_CUSPARSE11_AVAILABLE() 0
#endif
#if IS_CUSPARSE11_AVAILABLE()
#include <library_types.h>
#endif
namespace at::native {
@ -95,9 +103,17 @@ struct csrMatrixRef {
int nnz_{0};
std::vector<int> size_{};
cusparseSpMatDescr_t description_{0};
#if IS_CUSPARSE11_AVAILABLE()
cusparseSpMatDescr_t description_{0};
#else
cusparseMatDescr_t description_{0};
#endif
csrMatrixRef() = default;
csrMatrixRef() {
#if !IS_CUSPARSE11_AVAILABLE()
create_general_description_(description_);
#endif
}
csrMatrixRef(
int* csr_indices,
@ -110,6 +126,7 @@ struct csrMatrixRef {
csr_values_{csr_values},
nnz_{nnz},
size_{size} {
#if IS_CUSPARSE11_AVAILABLE()
cudaDataType cuda_data_type = at::cuda::getCudaDataType<scalar_t>();
TORCH_CUDASPARSE_CHECK(cusparseCreateCsr(
&description_,
@ -123,10 +140,17 @@ struct csrMatrixRef {
CUSPARSE_INDEX_32I,
CUSPARSE_INDEX_BASE_ZERO,
cuda_data_type));
#else
create_general_description_(description_);
#endif
}
~csrMatrixRef() {
cusparseDestroySpMat(description_);
#if IS_CUSPARSE11_AVAILABLE()
cusparseDestroySpMat(description_);
#else
cusparseDestroyMatDescr(description_);
#endif
}
int size(int index) const {
@ -172,6 +196,8 @@ struct csrOutput {
}
};
#if IS_CUSPARSE11_AVAILABLE()
// RAII guard helps to support cuSparse 11 API for `A @ B` operation
// This generic template exists because with cuSparse the `scalar_t` type could be a double or float
template <class scalar_t>
@ -370,6 +396,284 @@ template struct CusparseMatrixMultiplyOp<float>;
template struct CusparseMatrixMultiplyOp<double>;
#else // if not IS_CUSPARSE11_AVAILABLE()
using DcsrMatrixRef = csrMatrixRef<double>;
using ScsrMatrixRef = csrMatrixRef<float>;
// RAII guard helps to support cuSparse 10 API for `A @ B` operation
// This generic template exists because with cuSparse the `scalar_t` type could be a double or float
template <class scalar_t>
struct CusparseMatrixMultiplyOp {
csrOutput operator()(
const csrMatrixRef<scalar_t>& lhs,
const csrMatrixRef<scalar_t>& rhs,
Tensor &output_values,
Tensor &output_indices)
{
static_assert(false&&sizeof(scalar_t), "cusparse csr sparse-sparse MM only supports data type of float and double.");
}
};
// Specialization for `A @ B` operation for double values with cuSparse
template<> struct CusparseMatrixMultiplyOp<double> {
csrgemm2Info_t gemm2Info_;
CusparseMatrixMultiplyOp() {
TORCH_CUDASPARSE_CHECK(cusparseCreateCsrgemm2Info(&gemm2Info_));
}
~CusparseMatrixMultiplyOp() {
cusparseDestroyCsrgemm2Info(gemm2Info_);
}
csrOutput operator ()(
const DcsrMatrixRef& lhs,
const DcsrMatrixRef& rhs,
Tensor &output_values,
Tensor &output_indices) {
double alpha = 1.0;
DcsrMatrixRef empty;
return Dgemm2(lhs, rhs, empty, &alpha, nullptr, output_values, output_indices);
}
csrOutput Dgemm2(
const DcsrMatrixRef& A,
const DcsrMatrixRef& B,
const DcsrMatrixRef& C,
const double* alpha,
const double* beta,
Tensor &output_values,
Tensor &output_indices) {
void* buffer_{nullptr};
cusparseHandle_t cusparseHandle_ = at::cuda::getCurrentCUDASparseHandle();
TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(cusparseHandle_, CUSPARSE_POINTER_MODE_HOST));
csrOutput out({A.size(0), B.size(1)});
int innerSize = confirm_mult_size(A.size_, B.size_);
out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt));
// Compute needed buffer size
size_t new_bubber_sz;
TORCH_CUDASPARSE_CHECK(cusparseDcsrgemm2_bufferSizeExt(
cusparseHandle_,
out.size(0),
out.size(1),
innerSize,
alpha,
A.description_,
A.nnz_,
A.csr_pointers_,
A.csr_indices_,
B.description_,
B.nnz_,
B.csr_pointers_,
B.csr_indices_,
beta,
C.description_,
C.nnz_,
C.csr_pointers_,
C.csr_indices_,
gemm2Info_,
&new_bubber_sz));
// (Re)allocate buffer if needed
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
at::DataPtr data_ptr = allocator.allocate(new_bubber_sz);
buffer_ = data_ptr.get();
// Find the resulting non-zero pattern.
TORCH_CUDASPARSE_CHECK(cusparseXcsrgemm2Nnz(
cusparseHandle_,
out.size(0),
out.size(1),
innerSize,
A.description_,
A.nnz_,
A.csr_pointers_,
A.csr_indices_,
B.description_,
B.nnz_,
B.csr_pointers_,
B.csr_indices_,
C.description_,
C.nnz_,
C.csr_pointers_,
C.csr_indices_,
out.description_,
out.csr_pointers_.data_ptr<int>(),
&out.nnz_,
gemm2Info_,
buffer_));
out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt));
out.csr_values_ = at::empty({out.nnz_}, output_values.options());
// Perform the gemm2 operation for doubles
// out = alpha A B + beta C
TORCH_CUDASPARSE_CHECK(cusparseDcsrgemm2(
cusparseHandle_,
out.size(0),
out.size(1),
innerSize,
alpha,
A.description_,
A.nnz_,
A.csr_values_,
A.csr_pointers_,
A.csr_indices_,
B.description_,
B.nnz_,
B.csr_values_,
B.csr_pointers_,
B.csr_indices_,
beta,
C.description_,
C.nnz_,
C.csr_values_,
C.csr_pointers_,
C.csr_indices_,
out.description_,
out.csr_values_.data_ptr<double>(),
out.csr_pointers_.data_ptr<int>(),
out.csr_indices_.data_ptr<int>(),
gemm2Info_,
buffer_));
return out;
}
};
// Specialization for `A @ B` operation for float values with cuSparse
template<> struct CusparseMatrixMultiplyOp<float> {
csrgemm2Info_t gemm2Info_;
CusparseMatrixMultiplyOp() {
TORCH_CUDASPARSE_CHECK(cusparseCreateCsrgemm2Info(&gemm2Info_));
}
~CusparseMatrixMultiplyOp() {
cusparseDestroyCsrgemm2Info(gemm2Info_);
}
csrOutput operator()(
const ScsrMatrixRef& lhs,
const ScsrMatrixRef& rhs,
Tensor &output_values,
Tensor &output_indices) {
float alpha = 1.0;
ScsrMatrixRef empty;
return Sgemm2(lhs, rhs, empty, &alpha, nullptr, output_values, output_indices);
}
csrOutput Sgemm2(
const ScsrMatrixRef& A,
const ScsrMatrixRef& B,
const ScsrMatrixRef& C,
const float* alpha,
const float* beta,
Tensor &output_values,
Tensor &output_indices) {
void* buffer_{nullptr};
cusparseHandle_t cusparseHandle_ = at::cuda::getCurrentCUDASparseHandle();
TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(cusparseHandle_, CUSPARSE_POINTER_MODE_HOST));
csrOutput out({A.size(0), B.size(1)});
int innerSize = confirm_mult_size(A.size_, B.size_);
out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt));
// Compute needed buffer size
size_t new_bubber_sz;
TORCH_CUDASPARSE_CHECK(cusparseScsrgemm2_bufferSizeExt(
cusparseHandle_,
out.size(0),
out.size(1),
innerSize,
alpha,
A.description_,
A.nnz_,
A.csr_pointers_,
A.csr_indices_,
B.description_,
B.nnz_,
B.csr_pointers_,
B.csr_indices_,
beta,
C.description_,
C.nnz_,
C.csr_pointers_,
C.csr_indices_,
gemm2Info_,
&new_bubber_sz));
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
at::DataPtr data_ptr = allocator.allocate(new_bubber_sz);
buffer_ = data_ptr.get();
// Find the resulting non-zero pattern.
TORCH_CUDASPARSE_CHECK(cusparseXcsrgemm2Nnz(
cusparseHandle_,
out.size(0),
out.size(1),
innerSize,
A.description_,
A.nnz_,
A.csr_pointers_,
A.csr_indices_,
B.description_,
B.nnz_,
B.csr_pointers_,
B.csr_indices_,
C.description_,
C.nnz_,
C.csr_pointers_,
C.csr_indices_,
out.description_,
out.csr_pointers_.data_ptr<int>(),
&out.nnz_,
gemm2Info_,
buffer_));
out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt));
out.csr_values_ = at::empty({out.nnz_}, output_values.options());
// Perform the gemm2 operation for floats
// out = alpha A B + beta C
TORCH_CUDASPARSE_CHECK(cusparseScsrgemm2(
cusparseHandle_,
out.size(0),
out.size(1),
innerSize,
alpha,
A.description_,
A.nnz_,
A.csr_values_,
A.csr_pointers_,
A.csr_indices_,
B.description_,
B.nnz_,
B.csr_values_,
B.csr_pointers_,
B.csr_indices_,
beta,
C.description_,
C.nnz_,
C.csr_values_,
C.csr_pointers_,
C.csr_indices_,
out.description_,
out.csr_values_.data_ptr<float>(),
out.csr_pointers_.data_ptr<int>(),
out.csr_indices_.data_ptr<int>(),
gemm2Info_,
buffer_));
return out;
}
};
#endif // IS_CUSPARSE11_AVAILABLE()
template <typename scalar_t>
void sparse_sparse_matmul_cuda_kernel(
Tensor& result,
@ -511,15 +815,19 @@ Tensor sparse_sparse_matmul_cuda(const Tensor& mat1_, const Tensor& mat2_) {
auto output = at::native::empty_like(mat1_);
output.sparse_resize_and_clear_({mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0);
#if !defined(USE_ROCM)
#if IS_CUSPARSE11_AVAILABLE() && !defined(USE_ROCM)
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, mat1_.scalar_type(), "sparse_matmul", [&] {
sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
});
#else
#elif IS_CUSPARSE11_AVAILABLE() && defined(USE_ROCM)
// ROCm does not support half and bfloat16 types for sparse_matmul
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
});
#else
AT_DISPATCH_FLOATING_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
});
#endif
return output;
}
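
Both legacy branches above follow the same three-step cuSPARSE calling convention: query the workspace size, allocate it, then run the computation with that scratch buffer. A generic sketch of that shape, with hypothetical stand-in functions rather than cuSPARSE calls (the real code uses cusparse*csrgemm2_bufferSizeExt, cusparseXcsrgemm2Nnz, and cusparse*csrgemm2 together with the CUDA caching allocator):

#include <cstdio>
#include <cstdlib>

size_t query_workspace_bytes() { return 1024; }                     // 1) ask the library how much scratch it needs
void run_with_workspace(void* ws) { std::printf("ws=%p\n", ws); }   // 3) run the actual computation

int main() {
  size_t bytes = query_workspace_bytes();
  void* workspace = std::malloc(bytes);                             // 2) allocate (ATen uses c10::cuda::CUDACachingAllocator)
  run_with_workspace(workspace);
  std::free(workspace);
}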

View File

@ -62,6 +62,7 @@ kernel void build_row_ptr_from_sorted_rows_by_batch(
template <typename T>
kernel void spmm_bmm_coo_rows_grouped(
device const long* rows [[buffer(0)]],
device const long* cols [[buffer(1)]],
device const T* vals [[buffer(2)]],
device const T* dense [[buffer(3)]],
@ -72,6 +73,7 @@ kernel void spmm_bmm_coo_rows_grouped(
uint3 ltid [[thread_position_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]])
{
const uint B = dims.x;
const uint I = dims.y;
const uint J = dims.z;
const uint K = dims.w;
@ -319,6 +321,7 @@ INSTANTIATE_FOR_FLOAT_TYPES(INSTANTIATE_FUSED_GATHER_MUL);
#define INSTANTIATE_SPMM_BMM_COO_ROWS_GROUPED(DTYPE) \
template [[host_name("spmm_bmm_coo_rows_grouped_" #DTYPE)]] kernel void \
spmm_bmm_coo_rows_grouped<DTYPE>( \
device const long* rows [[buffer(0)]], \
device const long* cols [[buffer(1)]], \
device const DTYPE* vals [[buffer(2)]], \
device const DTYPE* dense [[buffer(3)]], \

View File

@ -93,7 +93,3 @@
This operator does not support cudagraphs. The presence of this tag on an operator will cause
Inductor to split the graph around this operator. Note that operators without this tag may still
not support CUDAGraphs. Inductor may have other hardcoded lists around that.
- tag: reduction
desc: |
This tag indicates that an operator performs a reduction operation, computing aggregate values
(sum, mean, max, min, etc.) across one or more dimensions of the input tensor(s).

View File

@ -202,6 +202,7 @@ supported:
- select_backward
- _trilinear
- linalg_pinv.atol_rtol_tensor
- svd
- logsumexp.out
symint:
- empty.memory_format

View File

@ -1,8 +1,8 @@
add_loop_eager,compile_time_instruction_count,3184000000,0.1
add_loop_eager,compile_time_instruction_count,3070000000,0.1
add_loop_eager_dynamic,compile_time_instruction_count,4595000000,0.1
add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.1
@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1096000000,0.1
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1048000000,0.1
@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000,
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17720000000,0.1
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.1
@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000
update_hint_regression,compile_time_instruction_count,1645000000,0.1
update_hint_regression,compile_time_instruction_count,1719000000,0.1
sum_floordiv_regression,compile_time_instruction_count,3813000000,0.1
sum_floordiv_regression,compile_time_instruction_count,3686995725,0.1
@ -50,31 +50,31 @@ symint_sum_loop,compile_time_instruction_count,4299000000,0.1
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1793000000,0.1
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1869000000,0.1
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5120000000,0.1
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5281000000,0.1
aotdispatcher_partitioner_cpu,compile_time_instruction_count,7936000000,0.1
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8333000000,0.1
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1848000000,0.1
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1909000000,0.1
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3152000000,0.1
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3442000000,0.1
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,8301000000,0.1
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9239000000,0.1
mm_loop_inductor_gpu,compile_time_instruction_count,4958000000,0.1
mm_loop_inductor_gpu,compile_time_instruction_count,4820968837,0.1
@ -82,8 +82,8 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,9051000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,9990000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,9554000000,0.1
basic_InlineMod_eager,compile_time_instruction_count,8126000000,0.1
basic_InlineMod_eager,compile_time_instruction_count,7618000000,0.1


View File

@ -43,7 +43,6 @@ tolerance:
- doctr_reco_predictor
- drq
- phlippe_resnet
- pytorch_CycleGAN_and_pix2pix
higher_bf16:
- doctr_reco_predictor

View File

@ -127,7 +127,7 @@ def trainbench(
bwd_time = bwd_start_event.elapsed_time(bwd_end_event)
return fwd_time, bwd_time
creator_args = {
creator_args = creator_args = {
"seqLength": seqLength,
"numLayers": numLayers,
"inputSize": inputSize,

View File

@ -12,7 +12,7 @@ def modeldef(request, net_name, executor, fuser):
# Given a 'net_name' provided by generate_tests, build the thing
name, rnn_creator, context = get_nn_runners(net_name)[0]
creator_args = {
creator_args = creator_args = {
"seqLength": 100,
"numLayers": 1,
"inputSize": 512,

View File

@ -44,101 +44,21 @@ PyTorch,div_,div__M1_N1_K1_cpu_dtype_onetorch.float32_dtype_twotorch.float32,sho
PyTorch,div_,div__M64_N64_K64_cpu_dtype_onetorch.float32_dtype_twotorch.float32,short,False,59.241161,0.000000
PyTorch,div_,div__M64_N64_K128_cpu_dtype_onetorch.float32_dtype_twotorch.float32,short,False,59.852816,0.000000
PyTorch,add,"add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,57.006677,0.000000
PyTorch,add,"add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,88.167000,0.000000
PyTorch,add,"add_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,57.519000,0.000000
PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,55.606088,0.000000
PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,86.551000,0.000000
PyTorch,sub,"sub_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,57.864088,0.000000
PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,58.529255,0.000000
PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,71.641000,0.000000
PyTorch,div,"div_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,83.073000,0.000000
PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float32",short,False,54.645077,0.000000
PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bfloat16",short,False,67.570000,0.000000
PyTorch,mul,"mul_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.float64",short,False,57.895000,0.000000
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,4.397014,0.000000
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.739000,0.000000
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.786000,0.000000
PyTorch,add,add_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.911000,0.000000
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,59.243500,0.000000
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.066000,0.000000
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.076000,0.000000
PyTorch,add,add_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.225000,0.000000
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.947691,0.000000
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.291000,0.000000
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.224000,0.000000
PyTorch,add,add_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.912000,0.000000
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.925851,0.000000
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,8.0240000,0.000000
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,8.069000,0.000000
PyTorch,sub,sub_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.938000,0.000000
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.308320,0.000000
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.091000,0.000000
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,108.710000,0.000000
PyTorch,sub,sub_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.502000,0.000000
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,57.787743,0.000000
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,108.863000,0.000000
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,108.939000,0.000000
PyTorch,sub,sub_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.603000,0.000000
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,7.978539,0.000000
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,8.741000,0.000000
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,8.757000,0.000000
PyTorch,div,div_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,8.774000,0.000000
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,159.754860,0.000000
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,165.552000,0.000000
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,165.755000,0.000000
PyTorch,div,div_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,165.714000,0.000000
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,165.360235,0.000000
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,168.376000,0.000000
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,169.604000,0.000000
PyTorch,div,div_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,168.428000,0.000000
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,3.928136,0.000000
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.402000,0.000000
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.567000,0.000000
PyTorch,mul,mul_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,4.020000,0.000000
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,56.413499,0.000000
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,104.638000,0.000000
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,104.335000,0.000000
PyTorch,mul,mul_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.612000,0.000000
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.925090,0.000000
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,106.110000,0.000000
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.389000,0.000000
PyTorch,mul,mul_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.195000,0.000000
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.989000,0.000000
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.999000,0.000000
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.939000,0.000000
PyTorch,asr,asr_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.980000,0.000000
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.408000,0.000000
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.647000,0.000000
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.476000,0.000000
PyTorch,asr,asr_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.784000,0.000000
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.583000,0.000000
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,108.083000,0.000000
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.663000,0.000000
PyTorch,asr,asr_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.283000,0.000000
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.986000,0.000000
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.676000,0.000000
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.618000,0.000000
PyTorch,lsl,lsl_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.982000,0.000000
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.698000,0.000000
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.899000,0.000000
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.741000,0.000000
PyTorch,lsl,lsl_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,51.182000,0.000000
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.290000,0.000000
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,107.744000,0.000000
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,107.820000,0.000000
PyTorch,lsl,lsl_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,51.298000,0.000000
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,1.988000,0.000000
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,7.689000,0.000000
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.695000,0.000000
PyTorch,xor,xor_M1_N1_K1_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,1.978000,0.000000
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,54.934000,0.000000
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,105.217000,0.000000
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,104.215000,0.000000
PyTorch,xor,xor_M64_N64_K64_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,47.115000,0.000000
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,55.974000,0.000000
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,106.828000,0.000000
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,106.879000,0.000000
PyTorch,xor,xor_M64_N64_K128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,48.197000,0.000000
PyTorch,logical_and,"logical_and_in_one[64,1,64]_in_two[1,64,1]_cpu_dtypetorch.bool",short,False,78.404254,0.000000
PyTorch,logical_and,logical_and_M1_N1_K1_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,5.354032,0.000000
PyTorch,logical_and,logical_and_M64_N64_K64_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,54.072783,0.000000
@ -151,9 +71,6 @@ PyTorch,baddbmm,baddbmm_B2_M1_N8_K2_cpu_dtypetorch.float32,short,False,6.631313,
PyTorch,baddbmm,baddbmm_B2_M1_N8_K2_cpu_dtypetorch.bfloat16,short,False,6.476986,0.000000
PyTorch,baddbmm,baddbmm_B128_M64_N32_K64_cpu_dtypetorch.float32,short,False,266.065131,0.000000
PyTorch,baddbmm,baddbmm_B128_M64_N32_K64_cpu_dtypetorch.bfloat16,short,False,295.503063,0.000000
PyTorch,all,all_M1_N1_K1_cpu,short,False,5.773000,0.000000
PyTorch,all,all_M64_N64_K64_cpu,short,False,89.427000,0.000000
PyTorch,all,all_M64_N64_K128_cpu,short,False,120.119000,0.000000
PyTorch,cat,"cat_sizes(1,1,1)_N2_dim0_cpu",short,False,4.301950,0.000000
PyTorch,cat,"cat_sizes(512,512,2)_N2_dim1_cpu",short,False,99.093415,0.000000
PyTorch,cat,"cat_sizes(128,1024,2)_N2_dim1_cpu",short,False,96.771578,0.000000

View File

@ -580,9 +580,6 @@ class BenchmarkRunner:
else "unknown"
)
# Extract operator name from test_name
operator_name = test_name.split("_")[0]
# Create the record
@dataclass
class BenchmarkInfo:
@ -596,7 +593,6 @@ class BenchmarkRunner:
name: str
type: str
origins: list[str]
extra_info: dict[str, Any]
@dataclass
class MetricInfo:
@ -622,14 +618,10 @@ class BenchmarkRunner:
"device": device,
"arch": device_arch,
"use_compile": use_compile,
"operator_name": operator_name,
},
),
model=ModelInfo(
name=test_name,
type="micro-benchmark",
origins=["pytorch"],
extra_info={"operator_name": operator_name},
name=test_name, type="micro-benchmark", origins=["pytorch"]
),
metric=MetricInfo(
name="latency",

View File

@ -25,7 +25,7 @@ binary_configs_broadcast = op_bench.config_list(
],
cross_product_configs={
"device": ["cpu"],
"dtype": [torch.float, torch.bfloat16, torch.float64],
"dtype": [torch.float],
},
tags=["short"],
)
@ -71,8 +71,8 @@ binary_short_configs = op_bench.config_list(
],
cross_product_configs={
"device": ["cpu", "cuda"],
"dtype_one": [torch.int32, torch.uint8],
"dtype_two": [torch.int32, torch.uint8],
"dtype_one": [torch.int32],
"dtype_two": [torch.int32],
},
tags=["short"],
)
@ -82,8 +82,8 @@ binary_long_configs = op_bench.cross_product_configs(
N=[32, 64],
K=[256, 512],
device=["cpu", "cuda"],
dtype_one=[torch.int8, torch.int32, torch.uint8],
dtype_two=[torch.int8, torch.int32, torch.uint8],
dtype_one=[torch.int8, torch.int32],
dtype_two=[torch.int8, torch.int32],
tags=["long"],
)
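
The `cross_product_configs` above fan out into one benchmark case per combination of shape, device, and dtype pair, which is why trimming the dtype lists shrinks the short run so sharply. A self-contained sketch of that expansion using plain `itertools` (the shapes are inferred from the case names in the CSV diff; the real `op_bench` helpers are not imported here):

```python
# Self-contained sketch of how cross_product_configs fan out into individual
# benchmark cases: every (M, N, K) shape is combined with every
# device/dtype_one/dtype_two value, so cutting each dtype list from two
# entries to one reduces the case count by 4x.
import itertools

shapes = [(1, 1, 1), (64, 64, 64), (64, 64, 128)]  # from the short case names
devices = ["cpu", "cuda"]
dtype_one = ["torch.int32", "torch.uint8"]
dtype_two = ["torch.int32", "torch.uint8"]

cases = [
    {"M": m, "N": n, "K": k, "device": d, "dtype_one": d1, "dtype_two": d2}
    for (m, n, k), d, d1, d2 in itertools.product(shapes, devices, dtype_one, dtype_two)
]
print(len(cases))                  # 3 * 2 * 2 * 2 = 24 cases per operator
print(len(shapes) * len(devices))  # 6 cases per operator with single-entry dtype lists
```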

View File

@ -176,8 +176,8 @@ THIRD_PARTY_LIBS = {
"omp": ["//xplat/third-party/linker_lib:omp", "//third_party:no-op"],
"pocketfft": ["//third-party/pocket_fft:pocketfft", "//third_party:pocketfft_header"],
"psimd": ["//xplat/third-party/psimd:psimd", "//third_party:psimd"],
"pthreadpool": ["fbsource//xplat/third-party/pthreadpool:pthreadpool", "//third_party:pthreadpool"],
"pthreadpool_header": ["fbsource//xplat/third-party/pthreadpool:pthreadpool_header", "//third_party:pthreadpool_header"],
"pthreadpool": ["//xplat/third-party/pthreadpool:pthreadpool", "//third_party:pthreadpool"],
"pthreadpool_header": ["//xplat/third-party/pthreadpool:pthreadpool_header", "//third_party:pthreadpool_header"],
"moodycamel": ["//third-party/moodycamel:moodycamel", "//third_party:moodycamel"],
"pyyaml": ["//third-party/pypi/pyyaml:pyyaml", "//third_party:pyyaml"],
"rt": ["//xplat/third-party/linker_lib:rt", "//third_party:rt"],
@ -1729,10 +1729,8 @@ def define_buck_targets(
"torch/csrc/jit/backends/backend_debug_info.cpp",
"torch/csrc/jit/backends/backend_interface.cpp",
],
compiler_flags = get_pt_compiler_flags() + select({
"DEFAULT": [],
"ovr_config//os:android": c2_fbandroid_xplat_compiler_flags
}),
compiler_flags = get_pt_compiler_flags(),
fbandroid_compiler_flags = c2_fbandroid_xplat_compiler_flags,
# @lint-ignore BUCKLINT link_whole
link_whole = True,
linker_flags = get_no_as_needed_linker_flag(),
@ -2025,9 +2023,6 @@ def define_buck_targets(
"ovr_config//os:android-x86_64": [
"-mssse3",
],
}) + select({
"DEFAULT": [],
"ovr_config//os:android": c2_fbandroid_xplat_compiler_flags,
}),
exported_preprocessor_flags = get_aten_preprocessor_flags(),
exported_deps = [

View File

@ -855,7 +855,6 @@ libtorch_python_cuda_core_sources = [
"torch/csrc/cuda/Stream.cpp",
"torch/csrc/cuda/Graph.cpp",
"torch/csrc/cuda/MemPool.cpp",
"torch/csrc/cuda/GreenContext.cpp",
"torch/csrc/cuda/shared/cudart.cpp",
"torch/csrc/cuda/shared/nvtx.cpp",
"torch/csrc/cuda/utils.cpp",

View File

@ -9,7 +9,6 @@
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/alignment.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

View File

@ -13,17 +13,7 @@
namespace c10::CachingAllocator {
// "large" allocations may be packed in 20 MiB blocks
constexpr size_t kLargeBuffer = 20971520;
// "small" allocations are packed in 2 MiB blocks
constexpr size_t kSmallBuffer = 2097152;
// all sizes are rounded to at least 512 bytes
constexpr size_t kMinBlockSize = 512;
// largest "small" allocation is 1 MiB
constexpr size_t kSmallSize = 1048576;
// allocations between 1 and 10 MiB may use kLargeBuffer
constexpr size_t kMinLargeAlloc = 10485760;
// round up large allocations to 2 MiB
constexpr size_t kRoundLarge = 2097152;
const size_t kLargeBuffer = 20971520;
// A utility class for tokenizing allocator configuration strings into discrete
// parts. For example, the config string:
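
The constants in this header encode the caching allocator's size tiers. Below is a simplified, runnable model of the rounding policy those comments describe; it is a sketch for intuition under the stated constants, not the allocator's actual code path:

```python
# Simplified model of the sizing policy described by the constants above.
KB, MB = 1024, 1024 * 1024
kMinBlockSize = 512          # all sizes are rounded to at least 512 bytes
kSmallSize = 1 * MB          # largest "small" allocation is 1 MiB
kSmallBuffer = 2 * MB        # "small" allocations are packed in 2 MiB blocks
kMinLargeAlloc = 10 * MB     # allocations between 1 and 10 MiB may use kLargeBuffer
kLargeBuffer = 20 * MB       # "large" allocations may be packed in 20 MiB blocks
kRoundLarge = 2 * MB         # round up large allocations to 2 MiB


def round_size(size: int) -> int:
    """Round a request up to a multiple of kMinBlockSize (at least one block)."""
    if size < kMinBlockSize:
        return kMinBlockSize
    return kMinBlockSize * ((size + kMinBlockSize - 1) // kMinBlockSize)


def segment_size(size: int) -> int:
    """Pick the backing-segment size a rounded request would fall into."""
    if size <= kSmallSize:
        return kSmallBuffer
    if size < kMinLargeAlloc:
        return kLargeBuffer
    return kRoundLarge * ((size + kRoundLarge - 1) // kRoundLarge)


for request in (100, 600 * KB, 4 * MB, 50 * MB):
    rounded = round_size(request)
    print(request, rounded, segment_size(rounded))
```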

View File

@ -223,7 +223,7 @@ inline DispatchKey backendToDispatchKey(Backend b) {
case Backend::PrivateUse1:
return DispatchKey::PrivateUse1;
default:
TORCH_CHECK(false, "Unknown backend");
throw std::runtime_error("Unknown backend");
}
}

View File

@ -52,9 +52,7 @@ constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
// where we would like to support composite implicit kernels but not
// explicit kernels therefore we manually add the key to the
// math_dispatch_keyset
DispatchKeySet{DispatchKey::NestedTensor} |
// Functionalize should always reuse CompositeImplicit decomps.
DispatchKeySet{DispatchKey::Functionalize};
DispatchKeySet{DispatchKey::NestedTensor};
constexpr DispatchKeySet nested_dispatch_keyset =
DispatchKeySet(

View File

@ -102,7 +102,7 @@ uint64_t getNonDeterministicRandom(bool is_cuda) {
} else {
std::random_device rd;
// limit to 53 bits to ensure unique representation in double
s = (((static_cast<uint64_t>(rd())) << 32) + rd()) & 0x1FFFFFFFFFFFFF;
s = ((((uint64_t)rd()) << 32) + rd()) & 0x1FFFFFFFFFFFFF;
}
return s;
}

View File

@ -20,8 +20,7 @@ void maybeApplyRefcountedDeleter(const c10::Storage& storage) {
std::lock_guard<std::mutex> guard(replace_data_ptr_mutex);
c10::DataPtr& data_ptr = storage.mutable_data_ptr();
if (reinterpret_cast<const void*>(data_ptr.get_deleter()) ==
reinterpret_cast<const void*>(&c10::refcounted_deleter)) {
if ((void*)data_ptr.get_deleter() == (void*)&c10::refcounted_deleter) {
// Data pointer is already shared
return;
}

View File

@ -336,7 +336,7 @@ class C10_API Scalar {
} else if (isBoolean()) {
return ScalarType::Bool;
} else {
TORCH_CHECK(false, "Unknown scalar type.");
throw std::runtime_error("Unknown scalar type.");
}
}

View File

@ -228,7 +228,7 @@ std::pair<std::string, std::string> getDtypeNames(c10::ScalarType scalarType) {
case c10::ScalarType::Float4_e2m1fn_x2:
return std::make_pair("float4_e2m1fn_x2", "");
default:
TORCH_CHECK(false, "Unimplemented scalar type");
throw std::runtime_error("Unimplemented scalar type");
}
}

View File

@ -52,6 +52,19 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
#undef DEFINE_CONSTANT
inline const char* toString(ScalarType t) {
#define DEFINE_CASE(_, name) \
case ScalarType::name: \
return #name;
switch (t) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
default:
return "UNKNOWN_SCALAR";
}
#undef DEFINE_CASE
}
inline size_t elementSize(ScalarType t) {
#define CASE_ELEMENTSIZE_CASE(ctype, name) \
case ScalarType::name: \
@ -137,6 +150,22 @@ inline ScalarType toQIntType(ScalarType t) {
}
}
inline ScalarType toUnderlying(ScalarType t) {
switch (t) {
case ScalarType::QUInt8:
case ScalarType::QUInt4x2:
[[fallthrough]];
case ScalarType::QUInt2x4:
return ScalarType::Byte;
case ScalarType::QInt8:
return ScalarType::Char;
case ScalarType::QInt32:
return ScalarType::Int;
default:
return t;
}
}
inline bool isSignedType(ScalarType t) {
#define CASE_ISSIGNED(name) \
case ScalarType::name: \
@ -279,6 +308,12 @@ inline bool canCast(const ScalarType from, const ScalarType to) {
C10_API ScalarType promoteTypes(ScalarType a, ScalarType b);
inline std::ostream& operator<<(
std::ostream& stream,
at::ScalarType scalar_type) {
return stream << toString(scalar_type);
}
// Returns a pair of strings representing the names for each dtype.
// The returned pair is (name, legacy_name_if_applicable)
C10_API std::pair<std::string, std::string> getDtypeNames(

View File

@ -83,7 +83,7 @@ DEFINE_BINARY(max_slow_path, sym_max, SymInt)
SymInt::operator SymFloat() const {
if (auto ma = maybe_as_int()) {
return SymFloat(static_cast<double>(*ma));
return SymFloat(double(*ma));
} else {
return SymFloat(toSymNodeImplUnowned()->sym_float());
}

View File

@ -1,7 +1,6 @@
#pragma once
#include <cstddef>
#include <new>
namespace c10 {
@ -19,12 +18,4 @@ constexpr size_t gPagesize = 4096;
// since the default thp pagesize is 2MB, enable thp only
// for buffers of size 2MB or larger to avoid memory bloating
constexpr size_t gAlloc_threshold_thp = static_cast<size_t>(2) * 1024 * 1024;
// Cache line size used to avoid false sharing between threads. Falls back to 64
// bytes if C++17 feature is unavailable.
#ifdef __cpp_lib_hardware_interference_size
using std::hardware_destructive_interference_size;
#else
constexpr std::size_t hardware_destructive_interference_size = 64;
#endif
} // namespace c10

View File

@ -44,8 +44,7 @@ bool has_simple_data_ptr(const c10::StorageImpl& storage) {
}
bool is_cow_data_ptr(const c10::DataPtr& data_ptr) {
return reinterpret_cast<const void*>(data_ptr.get_deleter()) ==
reinterpret_cast<const void*>(&cow::cow_deleter);
return (void*)data_ptr.get_deleter() == (void*)&cow::cow_deleter;
}
c10::intrusive_ptr<StorageImpl> lazy_clone_storage(StorageImpl& storage) {

View File

@ -87,7 +87,9 @@ bool ThreadPool::inThreadPool() const {
}
void ThreadPool::run(std::function<void()> func) {
TORCH_CHECK(threads_.size() > 0, "No threads to run a task");
if (threads_.empty()) {
throw std::runtime_error("No threads to run a task");
}
std::unique_lock<std::mutex> lock(mutex_);
// Set task and signal condition variable so that a worker thread will

View File

@ -131,6 +131,15 @@ namespace Native {
* notifyCaptureDestroy.
*/
constexpr size_t kMinBlockSize =
512; // all sizes are rounded to at least 512 bytes
constexpr size_t kSmallSize = 1048576; // largest "small" allocation is 1 MiB
constexpr size_t kSmallBuffer =
2097152; // "small" allocations are packed in 2 MiB blocks
constexpr size_t kMinLargeAlloc =
10485760; // allocations between 1 and 10 MiB may use kLargeBuffer
constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB
static char SHAREABLE_HANDLE_VERSION = 2;
enum ShareableHandleType : char {
SHAREABLE_CUDA_MALLOC = 'c',
@ -503,7 +512,7 @@ struct ExpandableSegment {
header.segment_size = segment_size_;
header.num_handles = end - begin;
buf.write(reinterpret_cast<const char*>(&header), sizeof(ShareHeader));
buf.write((const char*)&header, sizeof(ShareHeader));
for (auto i : c10::irange(begin, end)) {
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
auto& handle = handles_.at(i).value();
@ -519,9 +528,7 @@ struct ExpandableSegment {
TORCH_CHECK(
handle.shareable_handle != std::nullopt,
"shareable_handle is null");
buf.write(
reinterpret_cast<const char*>(&*handle.shareable_handle),
sizeof(int));
buf.write((const char*)&*handle.shareable_handle, sizeof(int));
} else {
if (!handle.shareable_handle) {
CUmemFabricHandle fabric_handle;
@ -534,8 +541,7 @@ struct ExpandableSegment {
handle.shareable_handle != std::nullopt,
"shareable_handle is null");
buf.write(
reinterpret_cast<const char*>(&*handle.shareable_handle),
sizeof(CUmemFabricHandle));
(const char*)&*handle.shareable_handle, sizeof(CUmemFabricHandle));
}
}
return rangeFromHandles(begin, end);
@ -546,7 +552,7 @@ struct ExpandableSegment {
std::vector<c10::DeviceIndex> peers,
std::istream& buf) {
ShareHeader header{};
buf.read(reinterpret_cast<char*>(&header), sizeof(ShareHeader));
buf.read((char*)&header, sizeof(ShareHeader));
auto segment = std::make_unique<ExpandableSegment>(
device, std::nullopt, header.segment_size, std::move(peers));
// older build setups (e.g. multiwheels) do not have this syscall, added 2020
@ -568,11 +574,11 @@ struct ExpandableSegment {
for (auto i : c10::irange(header.num_handles)) {
(void)i;
int fd = 0;
buf.read(reinterpret_cast<char*>(&fd), sizeof(int));
buf.read((char*)&fd, sizeof(int));
auto myfd = syscall(SYS_pidfd_getfd, pidfd, fd, 0);
if (myfd == -1) {
auto err = errno;
close(static_cast<int>(pidfd));
close((int)pidfd);
for (auto& h : segment->handles_) {
C10_CUDA_DRIVER_CHECK(
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
@ -592,16 +598,15 @@ struct ExpandableSegment {
(void*)(uintptr_t)myfd,
CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
LOG(INFO) << "use posix fd to import expandable segments.";
close(static_cast<int>(myfd));
close((int)myfd);
segment->handles_.emplace_back(Handle{handle, std::nullopt});
}
close(static_cast<int>(pidfd));
close((int)pidfd);
} else {
for (auto i : c10::irange(header.num_handles)) {
(void)i;
CUmemFabricHandle fabric_handle;
buf.read(
reinterpret_cast<char*>(&fabric_handle), sizeof(CUmemFabricHandle));
buf.read((char*)&fabric_handle, sizeof(CUmemFabricHandle));
CUmemGenericAllocationHandle handle = 0;
C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemImportFromShareableHandle_(
&handle,
@ -932,7 +937,7 @@ class EventPool {
private:
struct PerDevicePool {
alignas(hardware_destructive_interference_size) std::mutex mutex_;
alignas(64) std::mutex mutex_;
std::vector<std::unique_ptr<cudaEvent_t>> event_pool_;
};
std::vector<PerDevicePool> pools_;
@ -1054,7 +1059,7 @@ class RingBuffer {
void setMaxEntries(size_t size) {
std::lock_guard<std::mutex> lk(alloc_trace_lock);
alloc_trace_max_entries_ = std::max(static_cast<size_t>(1), size);
alloc_trace_max_entries_ = std::max(size_t(1), size);
}
void insertEntries(const T& entry) {
@ -1986,16 +1991,15 @@ class DeviceCachingAllocator {
while (base_block->prev) {
base_block = base_block->prev;
}
offset = static_cast<const char*>(block->ptr) -
static_cast<const char*>(base_block->ptr);
offset = (char*)block->ptr - (char*)base_block->ptr;
cudaIpcMemHandle_t handle;
C10_CUDA_CHECK(cudaIpcGetMemHandle(&handle, base_block->ptr));
ss.write(reinterpret_cast<const char*>(&handle), CUDA_IPC_HANDLE_SIZE);
ss.write((char*)&handle, CUDA_IPC_HANDLE_SIZE);
} else {
ss.put(SHAREABLE_CUDA_EXPANDABLE_SEGMENT);
auto full_range = block->expandable_segment_->share(
SegmentRange(block->ptr, block->size), ss);
offset = static_cast<const char*>(block->ptr) - full_range.ptr;
offset = (char*)block->ptr - full_range.ptr;
}
return ShareableHandle{offset, ss.str()};
}
@ -3225,8 +3229,7 @@ class DeviceCachingAllocator {
}
total_allocated_memory += size;
p.block = new Block(
p.device(), p.stream(), size, p.pool, static_cast<char*>(ptr));
p.block = new Block(p.device(), p.stream(), size, p.pool, (char*)ptr);
for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
stats.segment[stat_type].increase(1);
stats.reserved_bytes[stat_type].increase(size);
@ -3749,6 +3752,11 @@ static void uncached_delete(void* ptr) {
static void local_raw_delete(void* ptr);
thread_local std::stack<std::string> DeviceCachingAllocator::compile_context;
thread_local std::string DeviceCachingAllocator::user_metadata;
#ifdef __cpp_lib_hardware_interference_size
using std::hardware_destructive_interference_size;
#else
static constexpr std::size_t hardware_destructive_interference_size = 64;
#endif
class NativeCachingAllocator : public CUDAAllocator {
private:
@ -3769,7 +3777,7 @@ class NativeCachingAllocator : public CUDAAllocator {
allocated_blocks;
static size_t get_mutex_shard_id(void* ptr) {
return twang_mix64(reinterpret_cast<uintptr_t>(ptr)) % kNumMutexShard;
return twang_mix64((size_t)ptr) % kNumMutexShard;
}
void add_allocated_block(Block* block) {
@ -3806,8 +3814,8 @@ class NativeCachingAllocator : public CUDAAllocator {
if (size < device_count) {
device_allocator.resize(device_count);
for (const auto i : c10::irange(size, device_count)) {
device_allocator[i] = std::make_unique<DeviceCachingAllocator>(
static_cast<c10::DeviceIndex>(i));
device_allocator[i] =
std::make_unique<DeviceCachingAllocator>(c10::DeviceIndex(i));
}
}
}
@ -4336,7 +4344,7 @@ class NativeCachingAllocator : public CUDAAllocator {
// SHARABLE_CUDA_MALLOC
if (type == SHAREABLE_CUDA_MALLOC) {
cudaIpcMemHandle_t cuda_handle;
ss.read(reinterpret_cast<char*>(&cuda_handle), CUDA_IPC_HANDLE_SIZE);
ss.read((char*)&cuda_handle, CUDA_IPC_HANDLE_SIZE);
C10_CUDA_CHECK(cudaIpcOpenMemHandle(
&cuda_ipc_ptr_, cuda_handle, cudaIpcMemLazyEnablePeerAccess));
} else if (type == SHAREABLE_CUDA_EXPANDABLE_SEGMENT) {
@ -4469,10 +4477,7 @@ struct BackendStaticInitializer {
if (key == "backend") {
tokenizer.checkToken(++i, ":");
i++; // Move to the value after the colon
// break up token to trick hipify
if (tokenizer[i] ==
"c"
"udaMallocAsync"
if (tokenizer[i] == "cudaMallocAsync"
#ifdef USE_ROCM
// convenience for ROCm users to allow either CUDA or HIP env var
|| tokenizer[i] == "hipMallocAsync"

View File

@ -46,7 +46,7 @@ bool operator==(const UsageStream& lhs, const UsageStream& rhs) {
struct UsageStreamHash {
size_t operator()(const UsageStream& us) const noexcept {
return std::hash<void*>{}(us.stream) + static_cast<size_t>(us.device);
return std::hash<void*>{}(us.stream) + size_t(us.device);
}
};
@ -913,9 +913,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
}
}
std::string name() override {
// break up token to trick hipify
return "c"
"udaMallocAsync";
return "cudaMallocAsync";
}
void copy_data(void* dest, const void* src, std::size_t count) const final {
C10_CUDA_CHECK(

View File

@ -128,7 +128,7 @@ std::ostream& operator<<(std::ostream& stream, StreamIdType s) {
} else if (s.isExt()) {
stream << "EXT";
} else {
stream << "PRIORITY " << static_cast<int>(s.getStreamType());
stream << "PRIORITY " << int(s.getStreamType());
}
return stream;
}

View File

@ -51,17 +51,6 @@
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030)
#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \
_(cuCtxFromGreenCtx, 12080) \
_(cuCtxGetCurrent, 12080) \
_(cuCtxPopCurrent, 12080) \
_(cuCtxPushCurrent, 12080) \
_(cuCtxSetCurrent, 12080) \
_(cuGreenCtxCreate, 12080) \
_(cuGreenCtxDestroy, 12080) \
_(cuDevSmResourceSplitByCount, 12080) \
_(cuDeviceGet, 12080) \
_(cuDeviceGetDevResource, 12080) \
_(cuDevResourceGenerateDesc, 12080) \
_(cuMulticastAddDevice, 12030) \
_(cuMulticastBindMem, 12030) \
_(cuMulticastCreate, 12030) \

View File

@ -46,8 +46,7 @@ std::function<time_t(approx_time_t)> ApproximateClockToUnixTimeConverter::
for (const auto i : c10::irange(replicates)) {
auto delta_ns = end_times[i].t_ - start_times_[i].t_;
auto delta_approx = end_times[i].approx_t_ - start_times_[i].approx_t_;
scale_factors[i] =
static_cast<double>(delta_ns) / static_cast<double>(delta_approx);
scale_factors[i] = (double)delta_ns / (double)delta_approx;
}
std::sort(scale_factors.begin(), scale_factors.end());
long double scale_factor = scale_factors[replicates / 2 + 1];
@ -65,8 +64,7 @@ std::function<time_t(approx_time_t)> ApproximateClockToUnixTimeConverter::
for (const auto i : c10::irange(replicates)) {
auto dt = start_times_[i].t_ - t0;
auto dt_approx =
static_cast<double>(start_times_[i].approx_t_ - t0_approx) *
scale_factor;
(double)(start_times_[i].approx_t_ - t0_approx) * scale_factor;
t0_correction[i] = dt - (time_t)dt_approx; // NOLINT
}
t0 += t0_correction[t0_correction.size() / 2 + 1]; // NOLINT
@ -74,9 +72,7 @@ std::function<time_t(approx_time_t)> ApproximateClockToUnixTimeConverter::
return [=](approx_time_t t_approx) {
// See above for why this is more stable than `A * t_approx + B`.
return t_approx > t0_approx
? static_cast<time_t>(
static_cast<double>(t_approx - t0_approx) * scale_factor) +
t0
? (time_t)((double)(t_approx - t0_approx) * scale_factor) + t0
: 0;
};
}

View File

@ -45,7 +45,14 @@ constexpr bool is_pod_v = is_pod<T>::value;
namespace guts {
#if defined(__HIP__)
#if defined(__cpp_lib_apply) && !defined(__CUDA_ARCH__) && !defined(__HIP__)
template <class F, class Tuple>
C10_HOST_DEVICE inline constexpr decltype(auto) apply(F&& f, Tuple&& t) {
return std::apply(std::forward<F>(f), std::forward<Tuple>(t));
}
#else
// Implementation from http://en.cppreference.com/w/cpp/utility/apply (but
// modified)

View File

@ -132,15 +132,15 @@ std::ostream& operator<<(std::ostream& o, const uint128& b) {
int div_base_log = 0;
switch (flags & std::ios::basefield) {
case std::ios::hex:
div = static_cast<uint64_t>(0x1000000000000000u); // 16^15
div = (uint64_t)0x1000000000000000u; // 16^15
div_base_log = 15;
break;
case std::ios::oct:
div = static_cast<uint64_t>(01000000000000000000000u); // 8^21
div = (uint64_t)01000000000000000000000u; // 8^21
div_base_log = 21;
break;
default: // std::ios::dec
div = static_cast<uint64_t>(10000000000000000000u); // 10^19
div = (uint64_t)10000000000000000000u; // 10^19
div_base_log = 19;
break;
}

View File

@ -14,6 +14,16 @@ using namespace c10::CachingDeviceAllocator;
// newly allocated memory with 512-byte alignment.
constexpr size_t kDeviceAlignment = 512;
// all sizes are rounded to at least 512 bytes
constexpr size_t kMinBlockSize = 512;
// largest "small" allocation is 1 MiB
constexpr size_t kSmallSize = 1048576;
// "small" allocations are packed in 2 MiB blocks
constexpr size_t kSmallBuffer = 2097152;
// allocations between 1 and 10 MiB may use kLargeBuffer
constexpr size_t kMinLargeAlloc = 10485760;
// round up large allocations to 2 MiB
constexpr size_t kRoundLarge = 2097152;
namespace {
using stream_set = ska::flat_hash_set<xpu::XPUStream>;
@ -544,7 +554,7 @@ static void local_raw_delete(void* ptr);
class XPUAllocator : public DeviceAllocator {
private:
alignas(hardware_destructive_interference_size) std::mutex mutex;
std::mutex mutex;
ska::flat_hash_map<void*, Block*> allocated_blocks;
void add_allocated_block(Block* block) {

View File

@ -607,12 +607,6 @@ if(USE_CUDA)
set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a")
endif()
endif()
if(NOT WIN32)
set_source_files_properties(
${TORCH_ROOT}/aten/src/ATen/cuda/CUDAGreenContext.cpp
PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
)
endif()
set_source_files_properties(
${TORCH_ROOT}/aten/src/ATen/cuda/detail/LazyNVRTC.cpp
PROPERTIES COMPILE_DEFINITIONS "NVRTC_SHORTHASH=${CUDA_NVRTC_SHORTHASH}"

View File

@ -16,7 +16,7 @@ find_path(vecLib_INCLUDE_DIR vecLib.h
DOC "vecLib include directory"
PATHS /System/Library/Frameworks/Accelerate.framework/Versions/Current/${__veclib_include_suffix}
/System/Library/${__veclib_include_suffix}
/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/System/Library/Frameworks/Accelerate.framework/Versions/Current/Frameworks/vecLib.framework/Headers/
/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX10.9.sdk/System/Library/Frameworks/Accelerate.framework/Versions/Current/Frameworks/vecLib.framework/Headers/
${CMAKE_OSX_SYSROOT}/System/Library/Frameworks/Accelerate.framework/Versions/Current/${__veclib_include_suffix}
NO_DEFAULT_PATH)

View File

@ -224,12 +224,6 @@ AMD/ROCm/HIP
- Jithun Nair (`jithunnair-amd <https://github.com/jithunnair-amd>`__)
- (emeritus) Junjie Bai (`bddppq <https://github.com/bddppq>`__)
XPU/Intel GPU
~~~~~~~~~~~~~
- Eikan Wang (`EikanWang <https://github.com/EikanWang>`__)
- Guangye Yu (`guangyey <https://github.com/guangyey>`__)
Build + CI
~~~~~~~~~~

View File

@ -258,28 +258,6 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t
```
## Green Contexts (experimental)
`torch.cuda.green_contexts` provides thin wrappers around the CUDA Green Context APIs
to enable more general carveout of SM resources for CUDA kernels.
These APIs can be used in PyTorch with CUDA versions greater than or equal to 12.8.
See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example of how to use these.
```{eval-rst}
.. currentmodule:: torch.cuda.green_contexts
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
GreenContext
```
% This module needs to be documented. Adding here in the meantime
% for tracking purposes
@ -292,10 +270,6 @@ See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example
.. py:module:: torch.cuda.gds
```
```{eval-rst}
.. py:module:: torch.cuda.green_contexts
```
```{eval-rst}
.. py:module:: torch.cuda.jiterator
```
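
The removed documentation above describes `torch.cuda.green_contexts` as a thin wrapper over the CUDA Green Context APIs for carving out SM resources. The defensive sketch below only illustrates how such a wrapper might be exercised; the `GreenContext.create` / `make_current` / `pop_current` names and the `create(num_sms, device_id)` signature are assumptions not confirmed by this diff, so every call is probed before use:

```python
# Hypothetical usage probe for the SM-carveout wrapper described above.  The
# attribute names and the create(num_sms, device_id) signature are assumptions,
# so everything is looked up defensively and skipped when unavailable
# (the removed docs say CUDA >= 12.8 is required).
import torch


def try_green_context(num_sms: int = 16, device_id: int = 0) -> None:
    green = getattr(torch.cuda, "green_contexts", None)
    green_context_cls = getattr(green, "GreenContext", None) if green else None
    create = getattr(green_context_cls, "create", None) if green_context_cls else None
    if create is None or not torch.cuda.is_available():
        print("green contexts not available in this build")
        return
    try:
        ctx = create(num_sms, device_id)  # carve out num_sms SMs (assumed signature)
        ctx.make_current()                # subsequent kernels use the carveout
        a = torch.randn(1024, 1024, device="cuda")
        (a @ a).sum().item()
        ctx.pop_current()
    except (AttributeError, TypeError, RuntimeError) as exc:
        print(f"green context probe failed: {exc}")


try_green_context()
```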

View File

@ -44,9 +44,9 @@ following invariants. More specifications about the IR can be found
- **Normalized**: There are no Python semantics within the graph. Submodules
from the original programs are inlined to form one fully flattened
computational graph.
- **Graph properties**: By default, the graph may contain both functional and
non-functional operators (including mutations). To obtain a purely functional
graph, use `run_decompositions()` which removes mutations and aliasing.
- **Graph properties**: The graph is purely functional, meaning it does not
contain operations with side effects such as mutations or aliasing. It does
not mutate any intermediate values, parameters, or buffers.
- **Metadata**: The graph contains metadata captured during tracing, such as a
stacktrace from user's code.
@ -56,8 +56,8 @@ Under the hood, `torch.export` leverages the following latest technologies:
called the Frame Evaluation API to safely trace PyTorch graphs. This
provides a massively improved graph capturing experience, with much fewer
rewrites needed in order to fully trace the PyTorch code.
- **AOT Autograd** ensures the graph is decomposed/lowered to the ATen operator
set. When using `run_decompositions()`, it can also provide functionalization.
- **AOT Autograd** provides a functionalized PyTorch graph and ensures the graph
is decomposed/lowered to the ATen operator set.
- **Torch FX (torch.fx)** is the underlying representation of the graph,
allowing flexible Python-based transformations.
@ -444,31 +444,23 @@ saved_exported_program = torch.export.load('exported_program.pt2')
(training-export)=
## Export IR: Training vs Inference
## Export IR, Decompositions
The graph produced by `torch.export` returns a graph containing only
[ATen operators](https://pytorch.org/cppdocs/#aten), which are the basic unit of
computation in PyTorch. Export provides different IR levels based on your use case:
computation in PyTorch. As there are over
3000 ATen operators, export provides a way to narrow down the operator set used
in the graph based on certain characteristics, creating different IRs.
| IR Type | How to Obtain | Properties | Operator Count | Use Case |
|---------|---------------|------------|----------------|----------|
| Training IR | `torch.export.export()` (default) | May contain mutations | ~3000 | Training with autograd |
| Inference IR | `ep.run_decompositions(decomp_table={})` | Purely functional | ~2000 | Inference deployment |
| Core ATen IR | `ep.run_decompositions(decomp_table=None)` | Purely functional, highly decomposed | ~180 | Minimal backend support |
### Training IR (Default)
By default, export produces a **Training IR** which contains all ATen
operators, including both functional and non-functional (mutating) operators.
A functional operator is one that does not contain any mutations or aliasing
of the inputs, while non-functional operators may modify their inputs in-place.
By default, export produces the most generic IR which contains all ATen
operators, including both functional and non-functional operators. A functional
operator is one that does not contain any mutations or aliasing of the inputs.
You can find a list of all ATen operators
[here](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml)
and you can inspect if an operator is functional by checking
`op._schema.is_mutable`.
This Training IR, which may contain mutations, is designed for training use
cases and can be used with eager PyTorch Autograd.
This generic IR can be used to train in eager PyTorch Autograd.
```{code-cell}
import torch
@ -488,18 +480,15 @@ ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
print(ep_for_training.graph_module.print_readable(print_output=False))
```
### Inference IR (via run_decompositions)
However, if you want to use the IR for inference, or decrease the amount of
operators being used, you can lower the graph through the
{func}`ExportedProgram.run_decompositions` API. This method decomposes the
ATen operators into the ones specified in the decomposition table, and
functionalizes the graph.
To obtain an **Inference IR** suitable for deployment, use the
{func}`ExportedProgram.run_decompositions` API. This method automatically:
1. Functionalizes the graph (removes all mutations and converts them to functional equivalents)
2. Optionally decomposes ATen operators based on the provided decomposition table
This produces a purely functional graph ideal for inference scenarios.
By specifying an empty decomposition table (`decomp_table={}`), you get just
the functionalization without additional decompositions. This produces an
Inference IR with ~2000 functional operators (compared to 3000+ in Training IR).
By specifying an empty set, we're only performing functionalization, and does
not do any additional decompositions. This results in an IR which contains ~2000
operators (instead of the 3000 operators above), and is ideal for inference cases.
```{code-cell}
import torch
@ -525,14 +514,11 @@ As we can see, the previously in-place operator,
`torch.ops.aten.add_.default` has now been replaced with
`torch.ops.aten.add.default`, a functional operator.
### Core ATen IR
We can further lower the Inference IR to the
We can also further lower this exported program to an operator set which only
contains the
`Core ATen Operator Set <https://pytorch.org/docs/main/torch.compiler_ir.html#core-aten-ir>`__,
which contains only ~180 operators. This is achieved by passing `decomp_table=None`
(which uses the default decomposition table) to `run_decompositions()`. This IR
is optimal for backends who want to minimize the number of operators they need
to implement.
which is a collection of only ~180 operators. This IR is optimal for backends
who do not want to reimplement all ATen operators.
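
A condensed, runnable sketch contrasting the three IR levels discussed in this file: the default export, the functional Inference IR from an empty decomposition table, and the Core ATen IR from the default table. The toy module mirrors the docs' conv + batch-norm example, and the mutation check reuses the `op._schema.is_mutable` hint from the text; exact operator sets vary across PyTorch versions:

```python
# Compare the default export with the two run_decompositions() variants by
# listing which call_function nodes still have mutable schemas.
import torch


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 3)
        self.bn = torch.nn.BatchNorm2d(1)  # running-stat updates are in-place

    def forward(self, x):
        return self.bn(self.conv(x))


def mutable_ops(ep):
    return [
        n.target
        for n in ep.graph.nodes
        if n.op == "call_function"
        and getattr(n.target, "_schema", None) is not None
        and n.target._schema.is_mutable
    ]


ep_default = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
print("default export mutations:", mutable_ops(ep_default))

ep_inference = ep_default.run_decompositions(decomp_table={})  # functionalize only
print("inference IR mutations:", mutable_ops(ep_inference))

ep_core_aten = ep_default.run_decompositions()  # default table -> Core ATen IR
print("core ATen IR mutations:", mutable_ops(ep_core_aten))
```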
```{code-cell}
import torch

View File

@ -208,13 +208,11 @@ select = [
"PLC1802", # len({expression}) used as condition without comparison
"PLC0205", # string as __slots__
"PLC3002", # unnecessary-direct-lambda-call
"PLC0414", # Import alias does not rename original package
"PLE",
"PLR0133", # constant comparison
"PLR0206", # property with params
"PLR1722", # use sys exit
"PLR1736", # unnecessary list index
"PLW0127", # Self-assignment of variable
"PLW0129", # assert on string literal
"PLW0131", # named expr without context
"PLW0133", # useless exception statement

View File

@ -23,12 +23,10 @@ project-includes = [
project-excludes = [
# ==== below will be enabled directory by directory ====
# ==== to test Pyrefly on a specific directory, simply comment it out ====
"torch/_inductor/runtime",
"torch/_inductor/codegen/triton.py",
"tools/linter/adapters/test_device_bias_linter.py",
"tools/code_analyzer/gen_operators_yaml.py",
"torch/_inductor/runtime/triton_heuristics.py",
"torch/_inductor/runtime/triton_helpers.py",
"torch/_inductor/runtime/halide_helpers.py",
# formatting issues, will turn on after adjusting where suppressions can be
# in import statements
"tools/flight_recorder/components/types.py",

View File

@ -53,40 +53,3 @@ TEST_FORALL(AT_FORALL_COMPLEX_TYPES, 2)
#undef DEFINE_CHECK
#undef TEST_FORALL
TEST(TestScalarType, toString) {
using torch::headeronly::ScalarType;
#define DEFINE_CHECK(_, name) EXPECT_EQ(toString(ScalarType::name), #name);
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
#undef DEFINE_CHECK
}
TEST(TestScalarType, operator_left_shift) {
using torch::headeronly::ScalarType;
#define DEFINE_CHECK(_, name) \
{ \
std::stringstream ss; \
ss << ScalarType::name; \
EXPECT_EQ(ss.str(), #name); \
}
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
#undef DEFINE_CHECK
}
TEST(TestScalarType, toUnderlying) {
using torch::headeronly::ScalarType;
using torch::headeronly::toUnderlying;
EXPECT_EQ(toUnderlying(ScalarType::QUInt8), ScalarType::Byte);
EXPECT_EQ(toUnderlying(ScalarType::QUInt4x2), ScalarType::Byte);
EXPECT_EQ(toUnderlying(ScalarType::QUInt2x4), ScalarType::Byte);
EXPECT_EQ(toUnderlying(ScalarType::QInt8), ScalarType::Char);
EXPECT_EQ(toUnderlying(ScalarType::QInt32), ScalarType::Int);
#define DEFINE_CHECK(_, name) \
EXPECT_EQ(toUnderlying(ScalarType::name), ScalarType::name);
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CHECK);
AT_FORALL_FLOAT8_TYPES(DEFINE_CHECK);
#undef DEFINE_CHECK
}

View File

@ -1166,7 +1166,7 @@ class TestFullyShardPrefetch(FSDPTest):
loss = model(inp)
events.clear()
loss.sum().backward()
expected_backward_events = [
expected_backward_events = expected_backward_events = [
("unshard", "norm, output", TrainingState.PRE_BACKWARD),
# root explicit prefetch layers.2
("unshard", "layers.2", TrainingState.PRE_BACKWARD),

Some files were not shown because too many files have changed in this diff.