mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-28 02:04:53 +08:00
Compare commits
1 Commits
ciflow/tru
...
annotate_f
| Author | SHA1 | Date | |
|---|---|---|---|
| 98826fd37b |
@ -113,7 +113,6 @@ case "$tag" in
|
|||||||
UCX_COMMIT=${_UCX_COMMIT}
|
UCX_COMMIT=${_UCX_COMMIT}
|
||||||
UCC_COMMIT=${_UCC_COMMIT}
|
UCC_COMMIT=${_UCC_COMMIT}
|
||||||
TRITON=yes
|
TRITON=yes
|
||||||
INSTALL_MINGW=yes
|
|
||||||
;;
|
;;
|
||||||
pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11)
|
pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11)
|
||||||
CUDA_VERSION=13.0.0
|
CUDA_VERSION=13.0.0
|
||||||
@ -362,7 +361,6 @@ docker build \
|
|||||||
--build-arg "OPENBLAS=${OPENBLAS:-}" \
|
--build-arg "OPENBLAS=${OPENBLAS:-}" \
|
||||||
--build-arg "SKIP_SCCACHE_INSTALL=${SKIP_SCCACHE_INSTALL:-}" \
|
--build-arg "SKIP_SCCACHE_INSTALL=${SKIP_SCCACHE_INSTALL:-}" \
|
||||||
--build-arg "SKIP_LLVM_SRC_BUILD_INSTALL=${SKIP_LLVM_SRC_BUILD_INSTALL:-}" \
|
--build-arg "SKIP_LLVM_SRC_BUILD_INSTALL=${SKIP_LLVM_SRC_BUILD_INSTALL:-}" \
|
||||||
--build-arg "INSTALL_MINGW=${INSTALL_MINGW:-}" \
|
|
||||||
-f $(dirname ${DOCKERFILE})/Dockerfile \
|
-f $(dirname ${DOCKERFILE})/Dockerfile \
|
||||||
-t "$tmp_tag" \
|
-t "$tmp_tag" \
|
||||||
"$@" \
|
"$@" \
|
||||||
|
|||||||
@ -83,6 +83,10 @@ function build_cpython {
|
|||||||
py_suffix=${py_ver::-1}
|
py_suffix=${py_ver::-1}
|
||||||
py_folder=$py_suffix
|
py_folder=$py_suffix
|
||||||
fi
|
fi
|
||||||
|
# Update to rc2 due to https://github.com/python/cpython/commit/c72699086fe4
|
||||||
|
if [ "$py_suffix" == "3.14.0" ]; then
|
||||||
|
py_suffix="3.14.0rc2"
|
||||||
|
fi
|
||||||
wget -q $PYTHON_DOWNLOAD_URL/$py_folder/Python-$py_suffix.tgz -O Python-$py_ver.tgz
|
wget -q $PYTHON_DOWNLOAD_URL/$py_folder/Python-$py_suffix.tgz -O Python-$py_ver.tgz
|
||||||
do_cpython_build $py_ver Python-$py_suffix
|
do_cpython_build $py_ver Python-$py_suffix
|
||||||
|
|
||||||
|
|||||||
@ -1,10 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
set -ex
|
|
||||||
|
|
||||||
# Install MinGW-w64 for Windows cross-compilation
|
|
||||||
apt-get update
|
|
||||||
apt-get install -y g++-mingw-w64-x86-64-posix
|
|
||||||
|
|
||||||
echo "MinGW-w64 installed successfully"
|
|
||||||
x86_64-w64-mingw32-g++ --version
|
|
||||||
@ -19,8 +19,8 @@ pip_install \
|
|||||||
transformers==4.36.2
|
transformers==4.36.2
|
||||||
|
|
||||||
pip_install coloredlogs packaging
|
pip_install coloredlogs packaging
|
||||||
pip_install onnxruntime==1.23.1
|
pip_install onnxruntime==1.23.0
|
||||||
pip_install onnxscript==0.5.4
|
pip_install onnxscript==0.5.3
|
||||||
|
|
||||||
# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
|
# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
|
||||||
# package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
|
# package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
|
||||||
|
|||||||
@ -39,13 +39,9 @@ case ${DOCKER_TAG_PREFIX} in
|
|||||||
DOCKER_GPU_BUILD_ARG=""
|
DOCKER_GPU_BUILD_ARG=""
|
||||||
;;
|
;;
|
||||||
rocm*)
|
rocm*)
|
||||||
# we want the patch version of 7.0 instead
|
|
||||||
if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
|
|
||||||
GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
|
|
||||||
fi
|
|
||||||
# we want the patch version of 6.4 instead
|
# we want the patch version of 6.4 instead
|
||||||
if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then
|
if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then
|
||||||
GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.4"
|
GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
|
||||||
fi
|
fi
|
||||||
BASE_TARGET=rocm
|
BASE_TARGET=rocm
|
||||||
GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete
|
GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete
|
||||||
|
|||||||
@ -75,13 +75,9 @@ case ${image} in
|
|||||||
DOCKERFILE_SUFFIX="_cuda_aarch64"
|
DOCKERFILE_SUFFIX="_cuda_aarch64"
|
||||||
;;
|
;;
|
||||||
manylinux2_28-builder:rocm*)
|
manylinux2_28-builder:rocm*)
|
||||||
# we want the patch version of 7.0 instead
|
|
||||||
if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
|
|
||||||
GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
|
|
||||||
fi
|
|
||||||
# we want the patch version of 6.4 instead
|
# we want the patch version of 6.4 instead
|
||||||
if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then
|
if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then
|
||||||
GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.4"
|
GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
|
||||||
fi
|
fi
|
||||||
TARGET=rocm_final
|
TARGET=rocm_final
|
||||||
MANY_LINUX_VERSION="2_28"
|
MANY_LINUX_VERSION="2_28"
|
||||||
|
|||||||
@ -334,12 +334,12 @@ sympy==1.13.3
|
|||||||
#Pinned versions:
|
#Pinned versions:
|
||||||
#test that import:
|
#test that import:
|
||||||
|
|
||||||
onnx==1.19.1
|
onnx==1.18.0
|
||||||
#Description: Required by onnx tests, and mypy and test_public_bindings.py when checking torch.onnx._internal
|
#Description: Required by onnx tests, and mypy and test_public_bindings.py when checking torch.onnx._internal
|
||||||
#Pinned versions:
|
#Pinned versions:
|
||||||
#test that import:
|
#test that import:
|
||||||
|
|
||||||
onnxscript==0.5.4
|
onnxscript==0.5.3
|
||||||
#Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal
|
#Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal
|
||||||
#Pinned versions:
|
#Pinned versions:
|
||||||
#test that import:
|
#test that import:
|
||||||
|
|||||||
@ -103,11 +103,6 @@ COPY ci_commit_pins/torchbench.txt torchbench.txt
|
|||||||
RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
|
RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
|
||||||
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt
|
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt
|
||||||
|
|
||||||
ARG INSTALL_MINGW
|
|
||||||
COPY ./common/install_mingw.sh install_mingw.sh
|
|
||||||
RUN if [ -n "${INSTALL_MINGW}" ]; then bash ./install_mingw.sh; fi
|
|
||||||
RUN rm install_mingw.sh
|
|
||||||
|
|
||||||
ARG TRITON
|
ARG TRITON
|
||||||
ARG TRITON_CPU
|
ARG TRITON_CPU
|
||||||
|
|
||||||
|
|||||||
@ -57,8 +57,8 @@ def clone_external_repo(target: str, repo: str, dst: str = "", update_submodules
|
|||||||
logger.info("Successfully cloned %s", target)
|
logger.info("Successfully cloned %s", target)
|
||||||
return r, commit
|
return r, commit
|
||||||
|
|
||||||
except GitCommandError:
|
except GitCommandError as e:
|
||||||
logger.exception("Git operation failed")
|
logger.error("Git operation failed: %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -6,7 +6,7 @@ dependencies = [
|
|||||||
"GitPython==3.1.45",
|
"GitPython==3.1.45",
|
||||||
"docker==7.1.0",
|
"docker==7.1.0",
|
||||||
"pytest==7.3.2",
|
"pytest==7.3.2",
|
||||||
"uv==0.9.5"
|
"uv==0.8.6"
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.setuptools]
|
[tool.setuptools]
|
||||||
|
|||||||
@ -485,22 +485,6 @@ test_inductor_aoti() {
|
|||||||
/usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference cpp/test_vec_half_AVX2 -dist=loadfile
|
/usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference cpp/test_vec_half_AVX2 -dist=loadfile
|
||||||
}
|
}
|
||||||
|
|
||||||
test_inductor_aoti_cross_compile_for_windows() {
|
|
||||||
|
|
||||||
TEST_REPORTS_DIR=$(pwd)/test/test-reports
|
|
||||||
mkdir -p "$TEST_REPORTS_DIR"
|
|
||||||
|
|
||||||
# Set WINDOWS_CUDA_HOME environment variable
|
|
||||||
WINDOWS_CUDA_HOME="$(pwd)/win-torch-wheel-extracted"
|
|
||||||
export WINDOWS_CUDA_HOME
|
|
||||||
|
|
||||||
echo "WINDOWS_CUDA_HOME is set to: $WINDOWS_CUDA_HOME"
|
|
||||||
echo "Contents:"
|
|
||||||
ls -lah "$(pwd)/win-torch-wheel-extracted/lib/x64/" || true
|
|
||||||
|
|
||||||
python test/inductor/test_aoti_cross_compile_windows.py -k compile --package-dir "$TEST_REPORTS_DIR" --win-torch-lib-dir "$(pwd)/win-torch-wheel-extracted/torch/lib"
|
|
||||||
}
|
|
||||||
|
|
||||||
test_inductor_cpp_wrapper_shard() {
|
test_inductor_cpp_wrapper_shard() {
|
||||||
if [[ -z "$NUM_TEST_SHARDS" ]]; then
|
if [[ -z "$NUM_TEST_SHARDS" ]]; then
|
||||||
echo "NUM_TEST_SHARDS must be defined to run a Python test shard"
|
echo "NUM_TEST_SHARDS must be defined to run a Python test shard"
|
||||||
@ -916,7 +900,7 @@ test_inductor_set_cpu_affinity(){
|
|||||||
export LD_PRELOAD="$JEMALLOC_LIB":"$LD_PRELOAD"
|
export LD_PRELOAD="$JEMALLOC_LIB":"$LD_PRELOAD"
|
||||||
export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1"
|
export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1"
|
||||||
|
|
||||||
if [[ "$(uname -m)" != "aarch64" ]]; then
|
if [[ "${TEST_CONFIG}" != *aarch64* ]]; then
|
||||||
# Use Intel OpenMP for x86
|
# Use Intel OpenMP for x86
|
||||||
IOMP_LIB="$(dirname "$(which python)")/../lib/libiomp5.so"
|
IOMP_LIB="$(dirname "$(which python)")/../lib/libiomp5.so"
|
||||||
export LD_PRELOAD="$IOMP_LIB":"$LD_PRELOAD"
|
export LD_PRELOAD="$IOMP_LIB":"$LD_PRELOAD"
|
||||||
@ -930,7 +914,7 @@ test_inductor_set_cpu_affinity(){
|
|||||||
cores=$((cpus / thread_per_core))
|
cores=$((cpus / thread_per_core))
|
||||||
|
|
||||||
# Set number of cores to 16 on aarch64 for performance runs
|
# Set number of cores to 16 on aarch64 for performance runs
|
||||||
if [[ "$(uname -m)" == "aarch64" && $cores -gt 16 ]]; then
|
if [[ "${TEST_CONFIG}" == *aarch64* && $cores -gt 16 ]]; then
|
||||||
cores=16
|
cores=16
|
||||||
fi
|
fi
|
||||||
export OMP_NUM_THREADS=$cores
|
export OMP_NUM_THREADS=$cores
|
||||||
@ -1683,7 +1667,7 @@ if [[ "${TEST_CONFIG}" == *numpy_2* ]]; then
|
|||||||
python -m pip install --pre numpy==2.0.2 scipy==1.13.1 numba==0.60.0
|
python -m pip install --pre numpy==2.0.2 scipy==1.13.1 numba==0.60.0
|
||||||
fi
|
fi
|
||||||
python test/run_test.py --include dynamo/test_functions.py dynamo/test_unspec.py test_binary_ufuncs.py test_fake_tensor.py test_linalg.py test_numpy_interop.py test_tensor_creation_ops.py test_torch.py torch_np/test_basic.py
|
python test/run_test.py --include dynamo/test_functions.py dynamo/test_unspec.py test_binary_ufuncs.py test_fake_tensor.py test_linalg.py test_numpy_interop.py test_tensor_creation_ops.py test_torch.py torch_np/test_basic.py
|
||||||
elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" == 'default' ]]; then
|
elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" != *perf_cpu_aarch64* ]]; then
|
||||||
test_linux_aarch64
|
test_linux_aarch64
|
||||||
elif [[ "${TEST_CONFIG}" == *backward* ]]; then
|
elif [[ "${TEST_CONFIG}" == *backward* ]]; then
|
||||||
test_forward_backward_compatibility
|
test_forward_backward_compatibility
|
||||||
@ -1734,8 +1718,6 @@ elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then
|
|||||||
test_inductor_triton_cpu
|
test_inductor_triton_cpu
|
||||||
elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then
|
elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then
|
||||||
test_inductor_micro_benchmark
|
test_inductor_micro_benchmark
|
||||||
elif [[ "${TEST_CONFIG}" == *aoti_cross_compile_for_windows* ]]; then
|
|
||||||
test_inductor_aoti_cross_compile_for_windows
|
|
||||||
elif [[ "${TEST_CONFIG}" == *huggingface* ]]; then
|
elif [[ "${TEST_CONFIG}" == *huggingface* ]]; then
|
||||||
install_torchvision
|
install_torchvision
|
||||||
id=$((SHARD_NUMBER-1))
|
id=$((SHARD_NUMBER-1))
|
||||||
|
|||||||
@ -163,13 +163,8 @@ if [[ "$(uname)" != Darwin ]]; then
|
|||||||
MEMORY_LIMIT_MAX_JOBS=12
|
MEMORY_LIMIT_MAX_JOBS=12
|
||||||
NUM_CPUS=$(( $(nproc) - 2 ))
|
NUM_CPUS=$(( $(nproc) - 2 ))
|
||||||
|
|
||||||
if [[ "$(uname)" == Linux ]]; then
|
# Defaults here for **binary** linux builds so they can be changed in one place
|
||||||
# Defaults here for **binary** linux builds so they can be changed in one place
|
export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))}
|
||||||
export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))}
|
|
||||||
else
|
|
||||||
# For other builds
|
|
||||||
export MAX_JOBS=${NUM_CPUS}
|
|
||||||
fi
|
|
||||||
|
|
||||||
cat >>"$envfile" <<EOL
|
cat >>"$envfile" <<EOL
|
||||||
export MAX_JOBS="${MAX_JOBS}"
|
export MAX_JOBS="${MAX_JOBS}"
|
||||||
|
|||||||
@ -1,354 +0,0 @@
|
|||||||
# PyTorch Docstring Writing Guide
|
|
||||||
|
|
||||||
This skill describes how to write docstrings for functions and methods in the PyTorch project, following the conventions in `torch/_tensor_docs.py` and `torch/nn/functional.py`.
|
|
||||||
|
|
||||||
## General Principles
|
|
||||||
|
|
||||||
- Use **raw strings** (`r"""..."""`) for all docstrings to avoid issues with LaTeX/math backslashes
|
|
||||||
- Follow **Sphinx/reStructuredText** (reST) format for documentation
|
|
||||||
- Be **concise but complete** - include all essential information
|
|
||||||
- Always include **examples** when possible
|
|
||||||
- Use **cross-references** to related functions/classes
|
|
||||||
|
|
||||||
## Docstring Structure
|
|
||||||
|
|
||||||
### 1. Function Signature (First Line)
|
|
||||||
|
|
||||||
Start with the function signature showing all parameters:
|
|
||||||
|
|
||||||
```python
|
|
||||||
r"""function_name(param1, param2, *, kwarg1=default1, kwarg2=default2) -> ReturnType
|
|
||||||
```
|
|
||||||
|
|
||||||
**Notes:**
|
|
||||||
- Include the function name
|
|
||||||
- Show positional and keyword-only arguments (use `*` separator)
|
|
||||||
- Include default values
|
|
||||||
- Show return type annotation
|
|
||||||
- This line should NOT end with a period
|
|
||||||
|
|
||||||
### 2. Brief Description
|
|
||||||
|
|
||||||
Provide a one-line description of what the function does:
|
|
||||||
|
|
||||||
```python
|
|
||||||
r"""conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor
|
|
||||||
|
|
||||||
Applies a 2D convolution over an input image composed of several input
|
|
||||||
planes.
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. Mathematical Formulas (if applicable)
|
|
||||||
|
|
||||||
Use Sphinx math directives for mathematical expressions:
|
|
||||||
|
|
||||||
```python
|
|
||||||
.. math::
|
|
||||||
\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
|
|
||||||
```
|
|
||||||
|
|
||||||
Or inline math: `:math:\`x^2\``
|
|
||||||
|
|
||||||
### 4. Cross-References
|
|
||||||
|
|
||||||
Link to related classes and functions using Sphinx roles:
|
|
||||||
|
|
||||||
- `:class:\`~torch.nn.ModuleName\`` - Link to a class
|
|
||||||
- `:func:\`torch.function_name\`` - Link to a function
|
|
||||||
- `:meth:\`~Tensor.method_name\`` - Link to a method
|
|
||||||
- `:attr:\`attribute_name\`` - Reference an attribute
|
|
||||||
- The `~` prefix shows only the last component (e.g., `Conv2d` instead of `torch.nn.Conv2d`)
|
|
||||||
|
|
||||||
**Example:**
|
|
||||||
```python
|
|
||||||
See :class:`~torch.nn.Conv2d` for details and output shape.
|
|
||||||
```
|
|
||||||
|
|
||||||
### 5. Notes and Warnings
|
|
||||||
|
|
||||||
Use admonitions for important information:
|
|
||||||
|
|
||||||
```python
|
|
||||||
.. note::
|
|
||||||
This function doesn't work directly with NLLLoss,
|
|
||||||
which expects the Log to be computed between the Softmax and itself.
|
|
||||||
Use log_softmax instead (it's faster and has better numerical properties).
|
|
||||||
|
|
||||||
.. warning::
|
|
||||||
:func:`new_tensor` always copies :attr:`data`. If you have a Tensor
|
|
||||||
``data`` and want to avoid a copy, use :func:`torch.Tensor.requires_grad_`
|
|
||||||
or :func:`torch.Tensor.detach`.
|
|
||||||
```
|
|
||||||
|
|
||||||
### 6. Args Section
|
|
||||||
|
|
||||||
Document all parameters with type annotations and descriptions:
|
|
||||||
|
|
||||||
```python
|
|
||||||
Args:
|
|
||||||
input (Tensor): input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
|
|
||||||
weight (Tensor): filters of shape :math:`(\text{out\_channels} , kH , kW)`
|
|
||||||
bias (Tensor, optional): optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None``
|
|
||||||
stride (int or tuple): the stride of the convolving kernel. Can be a single number or a
|
|
||||||
tuple `(sH, sW)`. Default: 1
|
|
||||||
```
|
|
||||||
|
|
||||||
**Formatting rules:**
|
|
||||||
- Parameter name in **lowercase**
|
|
||||||
- Type in parentheses: `(Type)`, `(Type, optional)` for optional parameters
|
|
||||||
- Description follows the type
|
|
||||||
- For optional parameters, include "Default: ``value``" at the end
|
|
||||||
- Use double backticks for inline code: ``` ``None`` ```
|
|
||||||
- Indent continuation lines by 2 spaces
|
|
||||||
|
|
||||||
### 7. Keyword Args Section (if applicable)
|
|
||||||
|
|
||||||
Sometimes keyword arguments are documented separately:
|
|
||||||
|
|
||||||
```python
|
|
||||||
Keyword args:
|
|
||||||
dtype (:class:`torch.dtype`, optional): the desired type of returned tensor.
|
|
||||||
Default: if None, same :class:`torch.dtype` as this tensor.
|
|
||||||
device (:class:`torch.device`, optional): the desired device of returned tensor.
|
|
||||||
Default: if None, same :class:`torch.device` as this tensor.
|
|
||||||
requires_grad (bool, optional): If autograd should record operations on the
|
|
||||||
returned tensor. Default: ``False``.
|
|
||||||
```
|
|
||||||
|
|
||||||
### 8. Returns Section (if needed)
|
|
||||||
|
|
||||||
Document the return value:
|
|
||||||
|
|
||||||
```python
|
|
||||||
Returns:
|
|
||||||
Tensor: Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
|
|
||||||
If ``hard=True``, the returned samples will be one-hot, otherwise they will
|
|
||||||
be probability distributions that sum to 1 across `dim`.
|
|
||||||
```
|
|
||||||
|
|
||||||
Or simply include it in the function signature line if obvious from context.
|
|
||||||
|
|
||||||
### 9. Examples Section
|
|
||||||
|
|
||||||
Always include examples when possible:
|
|
||||||
|
|
||||||
```python
|
|
||||||
Examples::
|
|
||||||
|
|
||||||
>>> inputs = torch.randn(33, 16, 30)
|
|
||||||
>>> filters = torch.randn(20, 16, 5)
|
|
||||||
>>> F.conv1d(inputs, filters)
|
|
||||||
|
|
||||||
>>> # With square kernels and equal stride
|
|
||||||
>>> filters = torch.randn(8, 4, 3, 3)
|
|
||||||
>>> inputs = torch.randn(1, 4, 5, 5)
|
|
||||||
>>> F.conv2d(inputs, filters, padding=1)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Formatting rules:**
|
|
||||||
- Use `Examples::` with double colon
|
|
||||||
- Use `>>>` prompt for Python code
|
|
||||||
- Include comments with `#` when helpful
|
|
||||||
- Show actual output when it helps understanding (indent without `>>>`)
|
|
||||||
|
|
||||||
### 10. External References
|
|
||||||
|
|
||||||
Link to papers or external documentation:
|
|
||||||
|
|
||||||
```python
|
|
||||||
.. _Link Name:
|
|
||||||
https://arxiv.org/abs/1611.00712
|
|
||||||
```
|
|
||||||
|
|
||||||
Reference them in text: ```See `Link Name`_```
|
|
||||||
|
|
||||||
## Method Types
|
|
||||||
|
|
||||||
### Native Python Functions
|
|
||||||
|
|
||||||
For regular Python functions, use a standard docstring:
|
|
||||||
|
|
||||||
```python
|
|
||||||
def relu(input: Tensor, inplace: bool = False) -> Tensor:
|
|
||||||
r"""relu(input, inplace=False) -> Tensor
|
|
||||||
|
|
||||||
Applies the rectified linear unit function element-wise. See
|
|
||||||
:class:`~torch.nn.ReLU` for more details.
|
|
||||||
"""
|
|
||||||
# implementation
|
|
||||||
```
|
|
||||||
|
|
||||||
### C-Bound Functions (using add_docstr)
|
|
||||||
|
|
||||||
For C-bound functions, use `_add_docstr`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
conv1d = _add_docstr(
|
|
||||||
torch.conv1d,
|
|
||||||
r"""
|
|
||||||
conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor
|
|
||||||
|
|
||||||
Applies a 1D convolution over an input signal composed of several input
|
|
||||||
planes.
|
|
||||||
|
|
||||||
See :class:`~torch.nn.Conv1d` for details and output shape.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
|
|
||||||
weight: filters of shape :math:`(\text{out\_channels} , kW)`
|
|
||||||
...
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### In-Place Variants
|
|
||||||
|
|
||||||
For in-place operations (ending with `_`), reference the original:
|
|
||||||
|
|
||||||
```python
|
|
||||||
add_docstr_all(
|
|
||||||
"abs_",
|
|
||||||
r"""
|
|
||||||
abs_() -> Tensor
|
|
||||||
|
|
||||||
In-place version of :meth:`~Tensor.abs`
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Alias Functions
|
|
||||||
|
|
||||||
For aliases, simply reference the original:
|
|
||||||
|
|
||||||
```python
|
|
||||||
add_docstr_all(
|
|
||||||
"absolute",
|
|
||||||
r"""
|
|
||||||
absolute() -> Tensor
|
|
||||||
|
|
||||||
Alias for :func:`abs`
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Common Patterns
|
|
||||||
|
|
||||||
### Shape Documentation
|
|
||||||
|
|
||||||
Use LaTeX math notation for tensor shapes:
|
|
||||||
|
|
||||||
```python
|
|
||||||
:math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
|
|
||||||
```
|
|
||||||
|
|
||||||
### Reusable Argument Definitions
|
|
||||||
|
|
||||||
For commonly used arguments, define them once and reuse:
|
|
||||||
|
|
||||||
```python
|
|
||||||
common_args = parse_kwargs(
|
|
||||||
"""
|
|
||||||
dtype (:class:`torch.dtype`, optional): the desired type of returned tensor.
|
|
||||||
Default: if None, same as this tensor.
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Then use with .format():
|
|
||||||
r"""
|
|
||||||
...
|
|
||||||
|
|
||||||
Keyword args:
|
|
||||||
{dtype}
|
|
||||||
{device}
|
|
||||||
""".format(**common_args)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Template Insertion
|
|
||||||
|
|
||||||
Insert reproducibility notes or other common text:
|
|
||||||
|
|
||||||
```python
|
|
||||||
r"""
|
|
||||||
{tf32_note}
|
|
||||||
|
|
||||||
{cudnn_reproducibility_note}
|
|
||||||
""".format(**reproducibility_notes, **tf32_notes)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Complete Example
|
|
||||||
|
|
||||||
Here's a complete example showing all elements:
|
|
||||||
|
|
||||||
```python
|
|
||||||
def gumbel_softmax(
|
|
||||||
logits: Tensor,
|
|
||||||
tau: float = 1,
|
|
||||||
hard: bool = False,
|
|
||||||
eps: float = 1e-10,
|
|
||||||
dim: int = -1,
|
|
||||||
) -> Tensor:
|
|
||||||
r"""
|
|
||||||
Sample from the Gumbel-Softmax distribution and optionally discretize.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
logits (Tensor): `[..., num_features]` unnormalized log probabilities
|
|
||||||
tau (float): non-negative scalar temperature
|
|
||||||
hard (bool): if ``True``, the returned samples will be discretized as one-hot vectors,
|
|
||||||
but will be differentiated as if it is the soft sample in autograd. Default: ``False``
|
|
||||||
dim (int): A dimension along which softmax will be computed. Default: -1
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
|
|
||||||
If ``hard=True``, the returned samples will be one-hot, otherwise they will
|
|
||||||
be probability distributions that sum to 1 across `dim`.
|
|
||||||
|
|
||||||
.. note::
|
|
||||||
This function is here for legacy reasons, may be removed from nn.Functional in the future.
|
|
||||||
|
|
||||||
Examples::
|
|
||||||
>>> logits = torch.randn(20, 32)
|
|
||||||
>>> # Sample soft categorical using reparametrization trick:
|
|
||||||
>>> F.gumbel_softmax(logits, tau=1, hard=False)
|
|
||||||
>>> # Sample hard categorical using "Straight-through" trick:
|
|
||||||
>>> F.gumbel_softmax(logits, tau=1, hard=True)
|
|
||||||
|
|
||||||
.. _Link 1:
|
|
||||||
https://arxiv.org/abs/1611.00712
|
|
||||||
"""
|
|
||||||
# implementation
|
|
||||||
```
|
|
||||||
|
|
||||||
## Quick Checklist
|
|
||||||
|
|
||||||
When writing a PyTorch docstring, ensure:
|
|
||||||
|
|
||||||
- [ ] Use raw string (`r"""`)
|
|
||||||
- [ ] Include function signature on first line
|
|
||||||
- [ ] Provide brief description
|
|
||||||
- [ ] Document all parameters in Args section with types
|
|
||||||
- [ ] Include default values for optional parameters
|
|
||||||
- [ ] Use Sphinx cross-references (`:func:`, `:class:`, `:meth:`)
|
|
||||||
- [ ] Add mathematical formulas if applicable
|
|
||||||
- [ ] Include at least one example in Examples section
|
|
||||||
- [ ] Add warnings/notes for important caveats
|
|
||||||
- [ ] Link to related module class with `:class:`
|
|
||||||
- [ ] Use proper math notation for tensor shapes
|
|
||||||
- [ ] Follow consistent formatting and indentation
|
|
||||||
|
|
||||||
## Common Sphinx Roles Reference
|
|
||||||
|
|
||||||
- `:class:\`~torch.nn.Module\`` - Class reference
|
|
||||||
- `:func:\`torch.function\`` - Function reference
|
|
||||||
- `:meth:\`~Tensor.method\`` - Method reference
|
|
||||||
- `:attr:\`attribute\`` - Attribute reference
|
|
||||||
- `:math:\`equation\`` - Inline math
|
|
||||||
- `:ref:\`label\`` - Internal reference
|
|
||||||
- ``` ``code`` ``` - Inline code (use double backticks)
|
|
||||||
|
|
||||||
## Additional Notes
|
|
||||||
|
|
||||||
- **Indentation**: Use 4 spaces for code, 2 spaces for continuation of parameter descriptions
|
|
||||||
- **Line length**: Try to keep lines under 100 characters when possible
|
|
||||||
- **Periods**: End sentences with periods, but not the signature line
|
|
||||||
- **Backticks**: Use double backticks for code: ``` ``True`` ``None`` ``False`` ```
|
|
||||||
- **Types**: Common types are `Tensor`, `int`, `float`, `bool`, `str`, `tuple`, `list`, etc.
|
|
||||||
6
.flake8
6
.flake8
@ -7,12 +7,16 @@ max-line-length = 120
|
|||||||
# C408 ignored because we like the dict keyword argument syntax
|
# C408 ignored because we like the dict keyword argument syntax
|
||||||
# E501 is not flexible enough, we're using B950 instead
|
# E501 is not flexible enough, we're using B950 instead
|
||||||
ignore =
|
ignore =
|
||||||
E203,E305,E402,E501,E704,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,F824,
|
E203,E305,E402,E501,E704,E721,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,F824,
|
||||||
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
|
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
|
||||||
# to line this up with executable bit
|
# to line this up with executable bit
|
||||||
EXE001,
|
EXE001,
|
||||||
# these ignores are from flake8-bugbear; please fix!
|
# these ignores are from flake8-bugbear; please fix!
|
||||||
B007,B008,B017,B019,B023,B028,B903,B905,B906,B907,B908,B910
|
B007,B008,B017,B019,B023,B028,B903,B905,B906,B907,B908,B910
|
||||||
|
# these ignores are from flake8-comprehensions; please fix!
|
||||||
|
C407,
|
||||||
|
# these ignores are from flake8-logging-format; please fix!
|
||||||
|
G100,G101,G200
|
||||||
# these ignores are from flake8-simplify. please fix or ignore with commented reason
|
# these ignores are from flake8-simplify. please fix or ignore with commented reason
|
||||||
SIM105,SIM108,SIM110,SIM111,SIM113,SIM114,SIM115,SIM116,SIM117,SIM118,SIM119,SIM12,
|
SIM105,SIM108,SIM110,SIM111,SIM113,SIM114,SIM115,SIM116,SIM117,SIM118,SIM119,SIM12,
|
||||||
# SIM104 is already covered by pyupgrade ruff
|
# SIM104 is already covered by pyupgrade ruff
|
||||||
|
|||||||
7
.github/actions/setup-rocm/action.yml
vendored
7
.github/actions/setup-rocm/action.yml
vendored
@ -124,10 +124,3 @@ runs:
|
|||||||
id: login-ecr
|
id: login-ecr
|
||||||
continue-on-error: true
|
continue-on-error: true
|
||||||
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
|
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
|
||||||
|
|
||||||
- name: Preserve github env variables for use in docker
|
|
||||||
shell: bash
|
|
||||||
run: |
|
|
||||||
env | grep '^GITHUB' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
|
|
||||||
env | grep '^CI' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
|
|
||||||
env | grep '^RUNNER' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}"
|
|
||||||
|
|||||||
2
.github/ci_commit_pins/audio.txt
vendored
2
.github/ci_commit_pins/audio.txt
vendored
@ -1 +1 @@
|
|||||||
69bbe7363897764f9e758d851cd0340147d27f94
|
1b013f5b5a87a1882eb143c26d79d091150d6a37
|
||||||
|
|||||||
2
.github/ci_commit_pins/vision.txt
vendored
2
.github/ci_commit_pins/vision.txt
vendored
@ -1 +1 @@
|
|||||||
1752fe6809b74921644866275ab80244b96e80bc
|
faffd5cf673615583da6517275e361cb3dbc77e6
|
||||||
|
|||||||
2
.github/ci_commit_pins/xla.txt
vendored
2
.github/ci_commit_pins/xla.txt
vendored
@ -1 +1 @@
|
|||||||
df6798dfb931ce7c7fe5bed2447cd1092a5981af
|
0fa6e3129e61143224663e1ec67980d12b7ec4eb
|
||||||
|
|||||||
5
.github/ci_configs/vllm/Dockerfile
vendored
5
.github/ci_configs/vllm/Dockerfile
vendored
@ -283,9 +283,6 @@ RUN --mount=type=bind,source=${TORCH_WHEELS_PATH},target=/dist \
|
|||||||
uv pip install --system $(cat torch_build_versions.txt | xargs) --index-url https://download.pytorch.org/whl/nightly/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.'); \
|
uv pip install --system $(cat torch_build_versions.txt | xargs) --index-url https://download.pytorch.org/whl/nightly/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.'); \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
|
||||||
uv pip install --system --pre apache-tvm-ffi==0.1.0b15
|
|
||||||
|
|
||||||
# Install the vllm wheel from previous stage
|
# Install the vllm wheel from previous stage
|
||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
uv pip install --system /wheels/vllm/*.whl --verbose
|
uv pip install --system /wheels/vllm/*.whl --verbose
|
||||||
@ -298,8 +295,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
|||||||
ARG torch_cuda_arch_list='8.0;8.9;9.0a;10.0a;12.0'
|
ARG torch_cuda_arch_list='8.0;8.9;9.0a;10.0a;12.0'
|
||||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||||
|
|
||||||
# TODO(elainewy): remove this once vllm commit is updated, and install flashinfer from pip
|
|
||||||
# see https://github.com/pytorch/pytorch/pull/165274#issuecomment-3408531784
|
|
||||||
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
|
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
|
||||||
ARG FLASHINFER_GIT_REF="v0.2.14.post1"
|
ARG FLASHINFER_GIT_REF="v0.2.14.post1"
|
||||||
|
|
||||||
|
|||||||
9
.github/label_to_label.yml
vendored
9
.github/label_to_label.yml
vendored
@ -15,11 +15,6 @@
|
|||||||
- "module: reinplacing"
|
- "module: reinplacing"
|
||||||
then:
|
then:
|
||||||
- "module: pt2-dispatcher"
|
- "module: pt2-dispatcher"
|
||||||
- any:
|
|
||||||
- "vllm-compile"
|
|
||||||
then:
|
|
||||||
- "module: vllm"
|
|
||||||
- "oncall: pt2"
|
|
||||||
- any:
|
- any:
|
||||||
- "module: vmap"
|
- "module: vmap"
|
||||||
then:
|
then:
|
||||||
@ -32,6 +27,10 @@
|
|||||||
- "module: pt2 optimizer"
|
- "module: pt2 optimizer"
|
||||||
then:
|
then:
|
||||||
- "module: dynamo"
|
- "module: dynamo"
|
||||||
|
- any:
|
||||||
|
- "module: flex attention"
|
||||||
|
then:
|
||||||
|
- "module: higher order operators"
|
||||||
- any:
|
- any:
|
||||||
- "module: aotinductor"
|
- "module: aotinductor"
|
||||||
then:
|
then:
|
||||||
|
|||||||
29
.github/labeler.yml
vendored
29
.github/labeler.yml
vendored
@ -133,32 +133,3 @@
|
|||||||
|
|
||||||
"ciflow/vllm":
|
"ciflow/vllm":
|
||||||
- .github/ci_commit_pins/vllm.txt
|
- .github/ci_commit_pins/vllm.txt
|
||||||
|
|
||||||
"ciflow/b200":
|
|
||||||
- test/test_matmul_cuda.py
|
|
||||||
- test/test_scaled_matmul_cuda.py
|
|
||||||
- test/inductor/test_fp8.py
|
|
||||||
- aten/src/ATen/native/cuda/Blas.cpp
|
|
||||||
- torch/**/*cublas*
|
|
||||||
- torch/_inductor/kernel/mm.py
|
|
||||||
- test/inductor/test_max_autotune.py
|
|
||||||
- third_party/fbgemm
|
|
||||||
|
|
||||||
"ciflow/h100":
|
|
||||||
- test/test_matmul_cuda.py
|
|
||||||
- test/test_scaled_matmul_cuda.py
|
|
||||||
- test/inductor/test_fp8.py
|
|
||||||
- aten/src/ATen/native/cuda/Blas.cpp
|
|
||||||
- torch/**/*cublas*
|
|
||||||
- torch/_inductor/kernel/mm.py
|
|
||||||
- test/inductor/test_max_autotune.py
|
|
||||||
- third_party/fbgemm
|
|
||||||
|
|
||||||
"ciflow/rocm":
|
|
||||||
- test/test_matmul_cuda.py
|
|
||||||
- test/test_scaled_matmul_cuda.py
|
|
||||||
- test/inductor/test_fp8.py
|
|
||||||
- aten/src/ATen/native/cuda/Blas.cpp
|
|
||||||
- torch/_inductor/kernel/mm.py
|
|
||||||
- test/inductor/test_max_autotune.py
|
|
||||||
- third_party/fbgemm
|
|
||||||
|
|||||||
2
.github/pytorch-probot.yml
vendored
2
.github/pytorch-probot.yml
vendored
@ -3,7 +3,6 @@ ciflow_tracking_issue: 64124
|
|||||||
ciflow_push_tags:
|
ciflow_push_tags:
|
||||||
- ciflow/b200
|
- ciflow/b200
|
||||||
- ciflow/b200-symm-mem
|
- ciflow/b200-symm-mem
|
||||||
- ciflow/b200-distributed
|
|
||||||
- ciflow/binaries
|
- ciflow/binaries
|
||||||
- ciflow/binaries_libtorch
|
- ciflow/binaries_libtorch
|
||||||
- ciflow/binaries_wheel
|
- ciflow/binaries_wheel
|
||||||
@ -33,7 +32,6 @@ ciflow_push_tags:
|
|||||||
- ciflow/rocm
|
- ciflow/rocm
|
||||||
- ciflow/rocm-mi300
|
- ciflow/rocm-mi300
|
||||||
- ciflow/rocm-mi355
|
- ciflow/rocm-mi355
|
||||||
- ciflow/rocm-navi31
|
|
||||||
- ciflow/s390
|
- ciflow/s390
|
||||||
- ciflow/slow
|
- ciflow/slow
|
||||||
- ciflow/torchbench
|
- ciflow/torchbench
|
||||||
|
|||||||
30
.github/scripts/generate_binary_build_matrix.py
vendored
30
.github/scripts/generate_binary_build_matrix.py
vendored
@ -79,21 +79,21 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = {
|
|||||||
"nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux'"
|
"nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux'"
|
||||||
),
|
),
|
||||||
"12.9": (
|
"12.9": (
|
||||||
"nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | "
|
"nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | "
|
||||||
"nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | "
|
"nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | "
|
||||||
"nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | "
|
"nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | "
|
||||||
"nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | "
|
"nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | "
|
||||||
"nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | "
|
"nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | "
|
||||||
"nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | "
|
"nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | "
|
||||||
"nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | "
|
"nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | "
|
||||||
"nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | "
|
"nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | "
|
||||||
"nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | "
|
"nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | "
|
||||||
"nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | "
|
"nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | "
|
||||||
"nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | "
|
"nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | "
|
||||||
"nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | "
|
"nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | "
|
||||||
"nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | "
|
"nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | "
|
||||||
"nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | "
|
"nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | "
|
||||||
"nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'"
|
"nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'"
|
||||||
),
|
),
|
||||||
"13.0": (
|
"13.0": (
|
||||||
"nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | "
|
"nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | "
|
||||||
|
|||||||
2
.github/scripts/trymerge.py
vendored
2
.github/scripts/trymerge.py
vendored
@ -1092,7 +1092,7 @@ class GitHubPR:
|
|||||||
editor = node["editor"]
|
editor = node["editor"]
|
||||||
return GitHubComment(
|
return GitHubComment(
|
||||||
body_text=node["bodyText"],
|
body_text=node["bodyText"],
|
||||||
created_at=node.get("createdAt", ""),
|
created_at=node["createdAt"] if "createdAt" in node else "",
|
||||||
author_login=node["author"]["login"],
|
author_login=node["author"]["login"],
|
||||||
author_url=node["author"].get("url", None),
|
author_url=node["author"].get("url", None),
|
||||||
author_association=node["authorAssociation"],
|
author_association=node["authorAssociation"],
|
||||||
|
|||||||
@ -26,8 +26,9 @@ name: !{{ build_environment }}
|
|||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v6
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
|
# TODO: Removeme once 3.14 is out
|
||||||
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
|
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
|
||||||
python-version: "!{{ py_ver.strip('t') + ('.4' if '3.14' not in py_ver else '.0') }}"
|
python-version: "!{{ (py_ver.strip('t') + '.4') if '3.14' not in py_ver else '3.14.0-rc.2' }}"
|
||||||
freethreaded: !{{ "true" if py_ver.endswith('t') else "false" }}
|
freethreaded: !{{ "true" if py_ver.endswith('t') else "false" }}
|
||||||
{%- endmacro %}
|
{%- endmacro %}
|
||||||
|
|
||||||
|
|||||||
@ -79,9 +79,9 @@ jobs:
|
|||||||
runs-on: "windows-11-arm64-preview"
|
runs-on: "windows-11-arm64-preview"
|
||||||
{%- else %}
|
{%- else %}
|
||||||
{%- if branches == "nightly" %}
|
{%- if branches == "nightly" %}
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
{%- else %}
|
{%- else %}
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge.nonephemeral"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
|
||||||
{%- endif %}
|
{%- endif %}
|
||||||
{%- endif %}
|
{%- endif %}
|
||||||
timeout-minutes: !{{ common.timeout_minutes_windows_binary }}
|
timeout-minutes: !{{ common.timeout_minutes_windows_binary }}
|
||||||
|
|||||||
40
.github/workflows/_linux-test.yml
vendored
40
.github/workflows/_linux-test.yml
vendored
@ -224,46 +224,6 @@ jobs:
|
|||||||
continue-on-error: true
|
continue-on-error: true
|
||||||
uses: ./.github/actions/download-td-artifacts
|
uses: ./.github/actions/download-td-artifacts
|
||||||
|
|
||||||
- name: Download Windows torch wheel for cross-compilation
|
|
||||||
if: matrix.win_torch_wheel_artifact != ''
|
|
||||||
uses: seemethere/download-artifact-s3@1da556a7aa0a088e3153970611f6c432d58e80e6 # v4.2.0
|
|
||||||
with:
|
|
||||||
name: ${{ matrix.win_torch_wheel_artifact }}
|
|
||||||
path: win-torch-wheel
|
|
||||||
|
|
||||||
- name: Extract Windows wheel and setup CUDA libraries
|
|
||||||
if: matrix.win_torch_wheel_artifact != ''
|
|
||||||
shell: bash
|
|
||||||
run: |
|
|
||||||
set -x
|
|
||||||
|
|
||||||
# Find the wheel file
|
|
||||||
WHEEL_FILE=$(find win-torch-wheel -name "*.whl" -type f | head -n 1)
|
|
||||||
if [ -z "$WHEEL_FILE" ]; then
|
|
||||||
echo "Error: No wheel file found in win-torch-wheel directory"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
echo "Found wheel file: $WHEEL_FILE"
|
|
||||||
|
|
||||||
# Unzip the wheel file
|
|
||||||
unzip -q "$WHEEL_FILE" -d win-torch-wheel-extracted
|
|
||||||
echo "Extracted wheel contents"
|
|
||||||
|
|
||||||
# Setup CUDA libraries (cuda.lib and cudart.lib) directory
|
|
||||||
mkdir -p win-torch-wheel-extracted/lib/x64
|
|
||||||
if [ -f "win-torch-wheel/cuda.lib" ]; then
|
|
||||||
mv win-torch-wheel/cuda.lib win-torch-wheel-extracted/lib/x64/
|
|
||||||
echo "Moved cuda.lib to win-torch-wheel-extracted/lib/x64/"
|
|
||||||
fi
|
|
||||||
if [ -f "win-torch-wheel/cudart.lib" ]; then
|
|
||||||
mv win-torch-wheel/cudart.lib win-torch-wheel-extracted/lib/x64/
|
|
||||||
echo "Moved cudart.lib to win-torch-wheel-extracted/lib/x64/"
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Verify CUDA libraries are present
|
|
||||||
echo "CUDA libraries:"
|
|
||||||
ls -la win-torch-wheel-extracted/lib/x64/ || echo "No CUDA libraries found"
|
|
||||||
|
|
||||||
- name: Parse ref
|
- name: Parse ref
|
||||||
id: parse-ref
|
id: parse-ref
|
||||||
run: .github/scripts/parse_ref.py
|
run: .github/scripts/parse_ref.py
|
||||||
|
|||||||
25
.github/workflows/_win-build.yml
vendored
25
.github/workflows/_win-build.yml
vendored
@ -168,31 +168,6 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
.ci/pytorch/win-build.sh
|
.ci/pytorch/win-build.sh
|
||||||
|
|
||||||
# Collect Windows torch libs and CUDA libs for cross-compilation
|
|
||||||
- name: Collect Windows CUDA libs for cross-compilation
|
|
||||||
if: steps.build.outcome != 'skipped' && inputs.cuda-version != 'cpu'
|
|
||||||
shell: bash
|
|
||||||
run: |
|
|
||||||
set -ex
|
|
||||||
|
|
||||||
# Create directory structure if does not exist
|
|
||||||
mkdir -p /c/${{ github.run_id }}/build-results
|
|
||||||
|
|
||||||
# Copy CUDA libs
|
|
||||||
CUDA_PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${{ inputs.cuda-version }}"
|
|
||||||
|
|
||||||
if [ -f "${CUDA_PATH}/lib/x64/cuda.lib" ]; then
|
|
||||||
cp "${CUDA_PATH}/lib/x64/cuda.lib" /c/${{ github.run_id }}/build-results/
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ -f "${CUDA_PATH}/lib/x64/cudart.lib" ]; then
|
|
||||||
cp "${CUDA_PATH}/lib/x64/cudart.lib" /c/${{ github.run_id }}/build-results/
|
|
||||||
fi
|
|
||||||
|
|
||||||
# List collected files
|
|
||||||
echo "Collected CUDA libs:"
|
|
||||||
ls -lah /c/${{ github.run_id }}/build-results/*.lib
|
|
||||||
|
|
||||||
# Upload to github so that people can click and download artifacts
|
# Upload to github so that people can click and download artifacts
|
||||||
- name: Upload artifacts to s3
|
- name: Upload artifacts to s3
|
||||||
if: steps.build.outcome != 'skipped'
|
if: steps.build.outcome != 'skipped'
|
||||||
|
|||||||
62
.github/workflows/b200-distributed.yml
vendored
62
.github/workflows/b200-distributed.yml
vendored
@ -1,62 +0,0 @@
|
|||||||
name: CI for distributed tests on B200
|
|
||||||
|
|
||||||
on:
|
|
||||||
pull_request:
|
|
||||||
paths:
|
|
||||||
- .github/workflows/b200-distributed.yml
|
|
||||||
workflow_dispatch:
|
|
||||||
push:
|
|
||||||
tags:
|
|
||||||
- ciflow/b200-distributed/*
|
|
||||||
schedule:
|
|
||||||
- cron: 46 8 * * * # about 1:46am PDT
|
|
||||||
|
|
||||||
concurrency:
|
|
||||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
|
|
||||||
cancel-in-progress: true
|
|
||||||
|
|
||||||
permissions:
|
|
||||||
id-token: write
|
|
||||||
contents: read
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
|
|
||||||
get-label-type:
|
|
||||||
if: github.repository_owner == 'pytorch'
|
|
||||||
name: get-label-type
|
|
||||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
|
||||||
with:
|
|
||||||
triggering_actor: ${{ github.triggering_actor }}
|
|
||||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
|
||||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
|
||||||
curr_ref_type: ${{ github.ref_type }}
|
|
||||||
|
|
||||||
linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200:
|
|
||||||
name: linux-jammy-cuda12.8-py3.10-gcc11-build-distributed-b200
|
|
||||||
uses: ./.github/workflows/_linux-build.yml
|
|
||||||
needs: get-label-type
|
|
||||||
with:
|
|
||||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
|
||||||
runner: linux.12xlarge.memory
|
|
||||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
|
|
||||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
|
|
||||||
cuda-arch-list: '10.0'
|
|
||||||
test-matrix: |
|
|
||||||
{ include: [
|
|
||||||
{ config: "distributed", shard: 1, num_shards: 2, runner: "linux.dgx.b200.8" },
|
|
||||||
{ config: "distributed", shard: 2, num_shards: 2, runner: "linux.dgx.b200.8" },
|
|
||||||
]}
|
|
||||||
secrets: inherit
|
|
||||||
|
|
||||||
linux-jammy-cuda12_8-py3_10-gcc11-test-distributed-b200:
|
|
||||||
name: linux-jammy-cuda12.8-py3.10-gcc11-test-b200
|
|
||||||
uses: ./.github/workflows/_linux-test.yml
|
|
||||||
needs:
|
|
||||||
- linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200
|
|
||||||
with:
|
|
||||||
timeout-minutes: 1200
|
|
||||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
|
|
||||||
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.docker-image }}
|
|
||||||
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.test-matrix }}
|
|
||||||
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
|
|
||||||
secrets: inherit
|
|
||||||
14
.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml
generated
vendored
14
.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml
generated
vendored
@ -224,7 +224,7 @@ jobs:
|
|||||||
ALPINE_IMAGE: "arm64v8/alpine"
|
ALPINE_IMAGE: "arm64v8/alpine"
|
||||||
build_name: manywheel-py3_10-cuda-aarch64-12_9
|
build_name: manywheel-py3_10-cuda-aarch64-12_9
|
||||||
build_environment: linux-aarch64-binary-manywheel
|
build_environment: linux-aarch64-binary-manywheel
|
||||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
|
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
|
||||||
timeout-minutes: 420
|
timeout-minutes: 420
|
||||||
secrets:
|
secrets:
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
@ -473,7 +473,7 @@ jobs:
|
|||||||
ALPINE_IMAGE: "arm64v8/alpine"
|
ALPINE_IMAGE: "arm64v8/alpine"
|
||||||
build_name: manywheel-py3_11-cuda-aarch64-12_9
|
build_name: manywheel-py3_11-cuda-aarch64-12_9
|
||||||
build_environment: linux-aarch64-binary-manywheel
|
build_environment: linux-aarch64-binary-manywheel
|
||||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
|
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
|
||||||
timeout-minutes: 420
|
timeout-minutes: 420
|
||||||
secrets:
|
secrets:
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
@ -722,7 +722,7 @@ jobs:
|
|||||||
ALPINE_IMAGE: "arm64v8/alpine"
|
ALPINE_IMAGE: "arm64v8/alpine"
|
||||||
build_name: manywheel-py3_12-cuda-aarch64-12_9
|
build_name: manywheel-py3_12-cuda-aarch64-12_9
|
||||||
build_environment: linux-aarch64-binary-manywheel
|
build_environment: linux-aarch64-binary-manywheel
|
||||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
|
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
|
||||||
timeout-minutes: 420
|
timeout-minutes: 420
|
||||||
secrets:
|
secrets:
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
@ -971,7 +971,7 @@ jobs:
|
|||||||
ALPINE_IMAGE: "arm64v8/alpine"
|
ALPINE_IMAGE: "arm64v8/alpine"
|
||||||
build_name: manywheel-py3_13-cuda-aarch64-12_9
|
build_name: manywheel-py3_13-cuda-aarch64-12_9
|
||||||
build_environment: linux-aarch64-binary-manywheel
|
build_environment: linux-aarch64-binary-manywheel
|
||||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
|
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
|
||||||
timeout-minutes: 420
|
timeout-minutes: 420
|
||||||
secrets:
|
secrets:
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
@ -1220,7 +1220,7 @@ jobs:
|
|||||||
ALPINE_IMAGE: "arm64v8/alpine"
|
ALPINE_IMAGE: "arm64v8/alpine"
|
||||||
build_name: manywheel-py3_13t-cuda-aarch64-12_9
|
build_name: manywheel-py3_13t-cuda-aarch64-12_9
|
||||||
build_environment: linux-aarch64-binary-manywheel
|
build_environment: linux-aarch64-binary-manywheel
|
||||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
|
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
|
||||||
timeout-minutes: 420
|
timeout-minutes: 420
|
||||||
secrets:
|
secrets:
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
@ -1469,7 +1469,7 @@ jobs:
|
|||||||
ALPINE_IMAGE: "arm64v8/alpine"
|
ALPINE_IMAGE: "arm64v8/alpine"
|
||||||
build_name: manywheel-py3_14-cuda-aarch64-12_9
|
build_name: manywheel-py3_14-cuda-aarch64-12_9
|
||||||
build_environment: linux-aarch64-binary-manywheel
|
build_environment: linux-aarch64-binary-manywheel
|
||||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
|
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
|
||||||
timeout-minutes: 420
|
timeout-minutes: 420
|
||||||
secrets:
|
secrets:
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
@ -1718,7 +1718,7 @@ jobs:
|
|||||||
ALPINE_IMAGE: "arm64v8/alpine"
|
ALPINE_IMAGE: "arm64v8/alpine"
|
||||||
build_name: manywheel-py3_14t-cuda-aarch64-12_9
|
build_name: manywheel-py3_14t-cuda-aarch64-12_9
|
||||||
build_environment: linux-aarch64-binary-manywheel
|
build_environment: linux-aarch64-binary-manywheel
|
||||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
|
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
|
||||||
timeout-minutes: 420
|
timeout-minutes: 420
|
||||||
secrets:
|
secrets:
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|||||||
14
.github/workflows/generated-linux-binary-manywheel-nightly.yml
generated
vendored
14
.github/workflows/generated-linux-binary-manywheel-nightly.yml
generated
vendored
@ -259,7 +259,7 @@ jobs:
|
|||||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||||
build_name: manywheel-py3_10-cuda12_9
|
build_name: manywheel-py3_10-cuda12_9
|
||||||
build_environment: linux-binary-manywheel
|
build_environment: linux-binary-manywheel
|
||||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
|
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
|
||||||
secrets:
|
secrets:
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
manywheel-py3_10-cuda12_9-test: # Testing
|
manywheel-py3_10-cuda12_9-test: # Testing
|
||||||
@ -925,7 +925,7 @@ jobs:
|
|||||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||||
build_name: manywheel-py3_11-cuda12_9
|
build_name: manywheel-py3_11-cuda12_9
|
||||||
build_environment: linux-binary-manywheel
|
build_environment: linux-binary-manywheel
|
||||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
|
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
|
||||||
secrets:
|
secrets:
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
manywheel-py3_11-cuda12_9-test: # Testing
|
manywheel-py3_11-cuda12_9-test: # Testing
|
||||||
@ -1591,7 +1591,7 @@ jobs:
|
|||||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||||
build_name: manywheel-py3_12-cuda12_9
|
build_name: manywheel-py3_12-cuda12_9
|
||||||
build_environment: linux-binary-manywheel
|
build_environment: linux-binary-manywheel
|
||||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
|
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
|
||||||
secrets:
|
secrets:
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
manywheel-py3_12-cuda12_9-test: # Testing
|
manywheel-py3_12-cuda12_9-test: # Testing
|
||||||
@ -2257,7 +2257,7 @@ jobs:
|
|||||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||||
build_name: manywheel-py3_13-cuda12_9
|
build_name: manywheel-py3_13-cuda12_9
|
||||||
build_environment: linux-binary-manywheel
|
build_environment: linux-binary-manywheel
|
||||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
|
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
|
||||||
secrets:
|
secrets:
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
manywheel-py3_13-cuda12_9-test: # Testing
|
manywheel-py3_13-cuda12_9-test: # Testing
|
||||||
@ -2923,7 +2923,7 @@ jobs:
|
|||||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||||
build_name: manywheel-py3_13t-cuda12_9
|
build_name: manywheel-py3_13t-cuda12_9
|
||||||
build_environment: linux-binary-manywheel
|
build_environment: linux-binary-manywheel
|
||||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
|
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
|
||||||
secrets:
|
secrets:
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
manywheel-py3_13t-cuda12_9-test: # Testing
|
manywheel-py3_13t-cuda12_9-test: # Testing
|
||||||
@ -3589,7 +3589,7 @@ jobs:
|
|||||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||||
build_name: manywheel-py3_14-cuda12_9
|
build_name: manywheel-py3_14-cuda12_9
|
||||||
build_environment: linux-binary-manywheel
|
build_environment: linux-binary-manywheel
|
||||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
|
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
|
||||||
secrets:
|
secrets:
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
manywheel-py3_14-cuda12_9-test: # Testing
|
manywheel-py3_14-cuda12_9-test: # Testing
|
||||||
@ -4255,7 +4255,7 @@ jobs:
|
|||||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||||
build_name: manywheel-py3_14t-cuda12_9
|
build_name: manywheel-py3_14t-cuda12_9
|
||||||
build_environment: linux-binary-manywheel
|
build_environment: linux-binary-manywheel
|
||||||
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
|
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
|
||||||
secrets:
|
secrets:
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
manywheel-py3_14t-cuda12_9-test: # Testing
|
manywheel-py3_14t-cuda12_9-test: # Testing
|
||||||
|
|||||||
1
.github/workflows/generated-macos-arm64-binary-libtorch-release-nightly.yml
generated
vendored
1
.github/workflows/generated-macos-arm64-binary-libtorch-release-nightly.yml
generated
vendored
@ -63,6 +63,7 @@ jobs:
|
|||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v6
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
|
# TODO: Removeme once 3.14 is out
|
||||||
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
|
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
|
||||||
python-version: "3.10.4"
|
python-version: "3.10.4"
|
||||||
freethreaded: false
|
freethreaded: false
|
||||||
|
|||||||
11
.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml
generated
vendored
11
.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml
generated
vendored
@ -59,6 +59,7 @@ jobs:
|
|||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v6
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
|
# TODO: Removeme once 3.14 is out
|
||||||
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
|
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
|
||||||
python-version: "3.10.4"
|
python-version: "3.10.4"
|
||||||
freethreaded: false
|
freethreaded: false
|
||||||
@ -168,6 +169,7 @@ jobs:
|
|||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v6
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
|
# TODO: Removeme once 3.14 is out
|
||||||
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
|
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
|
||||||
python-version: "3.11.4"
|
python-version: "3.11.4"
|
||||||
freethreaded: false
|
freethreaded: false
|
||||||
@ -277,6 +279,7 @@ jobs:
|
|||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v6
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
|
# TODO: Removeme once 3.14 is out
|
||||||
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
|
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
|
||||||
python-version: "3.12.4"
|
python-version: "3.12.4"
|
||||||
freethreaded: false
|
freethreaded: false
|
||||||
@ -386,6 +389,7 @@ jobs:
|
|||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v6
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
|
# TODO: Removeme once 3.14 is out
|
||||||
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
|
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
|
||||||
python-version: "3.13.4"
|
python-version: "3.13.4"
|
||||||
freethreaded: false
|
freethreaded: false
|
||||||
@ -495,6 +499,7 @@ jobs:
|
|||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v6
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
|
# TODO: Removeme once 3.14 is out
|
||||||
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
|
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
|
||||||
python-version: "3.13.4"
|
python-version: "3.13.4"
|
||||||
freethreaded: true
|
freethreaded: true
|
||||||
@ -604,8 +609,9 @@ jobs:
|
|||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v6
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
|
# TODO: Removeme once 3.14 is out
|
||||||
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
|
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
|
||||||
python-version: "3.14.0"
|
python-version: "3.14.0-rc.2"
|
||||||
freethreaded: false
|
freethreaded: false
|
||||||
- name: Checkout PyTorch
|
- name: Checkout PyTorch
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@ -713,8 +719,9 @@ jobs:
|
|||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v6
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
|
# TODO: Removeme once 3.14 is out
|
||||||
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
|
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
|
||||||
python-version: "3.14.0"
|
python-version: "3.14.0-rc.2"
|
||||||
freethreaded: true
|
freethreaded: true
|
||||||
- name: Checkout PyTorch
|
- name: Checkout PyTorch
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|||||||
8
.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml
generated
vendored
8
.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml
generated
vendored
@ -44,7 +44,7 @@ jobs:
|
|||||||
libtorch-cpu-shared-with-deps-debug-build:
|
libtorch-cpu-shared-with-deps-debug-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -291,7 +291,7 @@ jobs:
|
|||||||
libtorch-cuda12_6-shared-with-deps-debug-build:
|
libtorch-cuda12_6-shared-with-deps-debug-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -541,7 +541,7 @@ jobs:
|
|||||||
libtorch-cuda12_8-shared-with-deps-debug-build:
|
libtorch-cuda12_8-shared-with-deps-debug-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -791,7 +791,7 @@ jobs:
|
|||||||
libtorch-cuda13_0-shared-with-deps-debug-build:
|
libtorch-cuda13_0-shared-with-deps-debug-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
|
|||||||
8
.github/workflows/generated-windows-binary-libtorch-release-nightly.yml
generated
vendored
8
.github/workflows/generated-windows-binary-libtorch-release-nightly.yml
generated
vendored
@ -44,7 +44,7 @@ jobs:
|
|||||||
libtorch-cpu-shared-with-deps-release-build:
|
libtorch-cpu-shared-with-deps-release-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -291,7 +291,7 @@ jobs:
|
|||||||
libtorch-cuda12_6-shared-with-deps-release-build:
|
libtorch-cuda12_6-shared-with-deps-release-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -541,7 +541,7 @@ jobs:
|
|||||||
libtorch-cuda12_8-shared-with-deps-release-build:
|
libtorch-cuda12_8-shared-with-deps-release-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -791,7 +791,7 @@ jobs:
|
|||||||
libtorch-cuda13_0-shared-with-deps-release-build:
|
libtorch-cuda13_0-shared-with-deps-release-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
|
|||||||
70
.github/workflows/generated-windows-binary-wheel-nightly.yml
generated
vendored
70
.github/workflows/generated-windows-binary-wheel-nightly.yml
generated
vendored
@ -44,7 +44,7 @@ jobs:
|
|||||||
wheel-py3_10-cpu-build:
|
wheel-py3_10-cpu-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -279,7 +279,7 @@ jobs:
|
|||||||
wheel-py3_10-cuda12_6-build:
|
wheel-py3_10-cuda12_6-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -517,7 +517,7 @@ jobs:
|
|||||||
wheel-py3_10-cuda12_8-build:
|
wheel-py3_10-cuda12_8-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -755,7 +755,7 @@ jobs:
|
|||||||
wheel-py3_10-cuda13_0-build:
|
wheel-py3_10-cuda13_0-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -993,7 +993,7 @@ jobs:
|
|||||||
wheel-py3_10-xpu-build:
|
wheel-py3_10-xpu-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -1229,7 +1229,7 @@ jobs:
|
|||||||
wheel-py3_11-cpu-build:
|
wheel-py3_11-cpu-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -1464,7 +1464,7 @@ jobs:
|
|||||||
wheel-py3_11-cuda12_6-build:
|
wheel-py3_11-cuda12_6-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -1702,7 +1702,7 @@ jobs:
|
|||||||
wheel-py3_11-cuda12_8-build:
|
wheel-py3_11-cuda12_8-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -1940,7 +1940,7 @@ jobs:
|
|||||||
wheel-py3_11-cuda13_0-build:
|
wheel-py3_11-cuda13_0-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -2178,7 +2178,7 @@ jobs:
|
|||||||
wheel-py3_11-xpu-build:
|
wheel-py3_11-xpu-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -2414,7 +2414,7 @@ jobs:
|
|||||||
wheel-py3_12-cpu-build:
|
wheel-py3_12-cpu-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -2649,7 +2649,7 @@ jobs:
|
|||||||
wheel-py3_12-cuda12_6-build:
|
wheel-py3_12-cuda12_6-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -2887,7 +2887,7 @@ jobs:
|
|||||||
wheel-py3_12-cuda12_8-build:
|
wheel-py3_12-cuda12_8-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -3125,7 +3125,7 @@ jobs:
|
|||||||
wheel-py3_12-cuda13_0-build:
|
wheel-py3_12-cuda13_0-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -3363,7 +3363,7 @@ jobs:
|
|||||||
wheel-py3_12-xpu-build:
|
wheel-py3_12-xpu-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -3599,7 +3599,7 @@ jobs:
|
|||||||
wheel-py3_13-cpu-build:
|
wheel-py3_13-cpu-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -3834,7 +3834,7 @@ jobs:
|
|||||||
wheel-py3_13-cuda12_6-build:
|
wheel-py3_13-cuda12_6-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -4072,7 +4072,7 @@ jobs:
|
|||||||
wheel-py3_13-cuda12_8-build:
|
wheel-py3_13-cuda12_8-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -4310,7 +4310,7 @@ jobs:
|
|||||||
wheel-py3_13-cuda13_0-build:
|
wheel-py3_13-cuda13_0-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -4548,7 +4548,7 @@ jobs:
|
|||||||
wheel-py3_13-xpu-build:
|
wheel-py3_13-xpu-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -4784,7 +4784,7 @@ jobs:
|
|||||||
wheel-py3_13t-cpu-build:
|
wheel-py3_13t-cpu-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -5019,7 +5019,7 @@ jobs:
|
|||||||
wheel-py3_13t-cuda12_6-build:
|
wheel-py3_13t-cuda12_6-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -5257,7 +5257,7 @@ jobs:
|
|||||||
wheel-py3_13t-cuda12_8-build:
|
wheel-py3_13t-cuda12_8-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -5495,7 +5495,7 @@ jobs:
|
|||||||
wheel-py3_13t-cuda13_0-build:
|
wheel-py3_13t-cuda13_0-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -5733,7 +5733,7 @@ jobs:
|
|||||||
wheel-py3_13t-xpu-build:
|
wheel-py3_13t-xpu-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -5969,7 +5969,7 @@ jobs:
|
|||||||
wheel-py3_14-cpu-build:
|
wheel-py3_14-cpu-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -6204,7 +6204,7 @@ jobs:
|
|||||||
wheel-py3_14-cuda12_6-build:
|
wheel-py3_14-cuda12_6-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -6442,7 +6442,7 @@ jobs:
|
|||||||
wheel-py3_14-cuda12_8-build:
|
wheel-py3_14-cuda12_8-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -6680,7 +6680,7 @@ jobs:
|
|||||||
wheel-py3_14-cuda13_0-build:
|
wheel-py3_14-cuda13_0-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -6918,7 +6918,7 @@ jobs:
|
|||||||
wheel-py3_14-xpu-build:
|
wheel-py3_14-xpu-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -7154,7 +7154,7 @@ jobs:
|
|||||||
wheel-py3_14t-cpu-build:
|
wheel-py3_14t-cpu-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -7389,7 +7389,7 @@ jobs:
|
|||||||
wheel-py3_14t-cuda12_6-build:
|
wheel-py3_14t-cuda12_6-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -7627,7 +7627,7 @@ jobs:
|
|||||||
wheel-py3_14t-cuda12_8-build:
|
wheel-py3_14t-cuda12_8-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -7865,7 +7865,7 @@ jobs:
|
|||||||
wheel-py3_14t-cuda13_0-build:
|
wheel-py3_14t-cuda13_0-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
@ -8103,7 +8103,7 @@ jobs:
|
|||||||
wheel-py3_14t-xpu-build:
|
wheel-py3_14t-xpu-build:
|
||||||
if: ${{ github.repository_owner == 'pytorch' }}
|
if: ${{ github.repository_owner == 'pytorch' }}
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge"
|
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
|
||||||
timeout-minutes: 360
|
timeout-minutes: 360
|
||||||
env:
|
env:
|
||||||
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
|
||||||
|
|||||||
@ -88,27 +88,27 @@ jobs:
|
|||||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
|
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
|
||||||
test-matrix: |
|
test-matrix: |
|
||||||
{ include: [
|
{ include: [
|
||||||
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
]}
|
]}
|
||||||
secrets: inherit
|
secrets: inherit
|
||||||
|
|
||||||
|
|||||||
1
.github/workflows/inductor-periodic.yml
vendored
1
.github/workflows/inductor-periodic.yml
vendored
@ -88,6 +88,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
build-environment: linux-jammy-rocm-py3_10
|
build-environment: linux-jammy-rocm-py3_10
|
||||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
|
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
|
||||||
|
sync-tag: rocm-build
|
||||||
test-matrix: |
|
test-matrix: |
|
||||||
{ include: [
|
{ include: [
|
||||||
{ config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
|
{ config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
|
||||||
|
|||||||
24
.github/workflows/operator_benchmark.yml
vendored
24
.github/workflows/operator_benchmark.yml
vendored
@ -52,27 +52,3 @@ jobs:
|
|||||||
docker-image: ${{ needs.x86-opbenchmark-build.outputs.docker-image }}
|
docker-image: ${{ needs.x86-opbenchmark-build.outputs.docker-image }}
|
||||||
test-matrix: ${{ needs.x86-opbenchmark-build.outputs.test-matrix }}
|
test-matrix: ${{ needs.x86-opbenchmark-build.outputs.test-matrix }}
|
||||||
secrets: inherit
|
secrets: inherit
|
||||||
|
|
||||||
aarch64-opbenchmark-build:
|
|
||||||
if: github.repository_owner == 'pytorch'
|
|
||||||
name: aarch64-opbenchmark-build
|
|
||||||
uses: ./.github/workflows/_linux-build.yml
|
|
||||||
with:
|
|
||||||
build-environment: linux-jammy-aarch64-py3.10
|
|
||||||
runner: linux.arm64.m7g.4xlarge
|
|
||||||
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11
|
|
||||||
test-matrix: |
|
|
||||||
{ include: [
|
|
||||||
{ config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.arm64.m8g.4xlarge" },
|
|
||||||
]}
|
|
||||||
secrets: inherit
|
|
||||||
|
|
||||||
aarch64-opbenchmark-test:
|
|
||||||
name: aarch64-opbenchmark-test
|
|
||||||
uses: ./.github/workflows/_linux-test.yml
|
|
||||||
needs: aarch64-opbenchmark-build
|
|
||||||
with:
|
|
||||||
build-environment: linux-jammy-aarch64-py3.10
|
|
||||||
docker-image: ${{ needs.aarch64-opbenchmark-build.outputs.docker-image }}
|
|
||||||
test-matrix: ${{ needs.aarch64-opbenchmark-build.outputs.test-matrix }}
|
|
||||||
secrets: inherit
|
|
||||||
|
|||||||
15
.github/workflows/periodic.yml
vendored
15
.github/workflows/periodic.yml
vendored
@ -147,16 +147,15 @@ jobs:
|
|||||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug
|
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug
|
||||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9
|
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9
|
||||||
cuda-arch-list: 8.9
|
|
||||||
test-matrix: |
|
test-matrix: |
|
||||||
{ include: [
|
{ include: [
|
||||||
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
{ config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||||
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
{ config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||||
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
{ config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||||
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
{ config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||||
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
{ config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||||
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
{ config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||||
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] },
|
{ config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] },
|
||||||
]}
|
]}
|
||||||
secrets: inherit
|
secrets: inherit
|
||||||
|
|
||||||
|
|||||||
3
.github/workflows/pull.yml
vendored
3
.github/workflows/pull.yml
vendored
@ -347,8 +347,7 @@ jobs:
|
|||||||
uses: ./.github/workflows/_linux-build.yml
|
uses: ./.github/workflows/_linux-build.yml
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
with:
|
with:
|
||||||
# This should sync with the build in xpu.yml but xpu uses a larger runner
|
sync-tag: linux-xpu-n-build
|
||||||
# sync-tag: linux-xpu-n-build
|
|
||||||
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
|
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
|
||||||
build-environment: linux-jammy-xpu-n-py3.10
|
build-environment: linux-jammy-xpu-n-py3.10
|
||||||
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
|
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
|
||||||
|
|||||||
1
.github/workflows/rocm-mi300.yml
vendored
1
.github/workflows/rocm-mi300.yml
vendored
@ -45,6 +45,7 @@ jobs:
|
|||||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||||
build-environment: linux-noble-rocm-py3.12-mi300
|
build-environment: linux-noble-rocm-py3.12-mi300
|
||||||
docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3
|
docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3
|
||||||
|
sync-tag: rocm-build
|
||||||
test-matrix: |
|
test-matrix: |
|
||||||
{ include: [
|
{ include: [
|
||||||
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" },
|
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" },
|
||||||
|
|||||||
13
.github/workflows/rocm-mi355.yml
vendored
13
.github/workflows/rocm-mi355.yml
vendored
@ -42,14 +42,15 @@ jobs:
|
|||||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||||
build-environment: linux-noble-rocm-py3.12-mi355
|
build-environment: linux-noble-rocm-py3.12-mi355
|
||||||
docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3
|
docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3
|
||||||
|
sync-tag: rocm-build
|
||||||
test-matrix: |
|
test-matrix: |
|
||||||
{ include: [
|
{ include: [
|
||||||
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
|
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
|
||||||
]}
|
]}
|
||||||
secrets: inherit
|
secrets: inherit
|
||||||
|
|
||||||
|
|||||||
75
.github/workflows/rocm-navi31.yml
vendored
75
.github/workflows/rocm-navi31.yml
vendored
@ -1,75 +0,0 @@
|
|||||||
name: rocm-navi31
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
tags:
|
|
||||||
- ciflow/rocm-navi31/*
|
|
||||||
workflow_dispatch:
|
|
||||||
schedule:
|
|
||||||
# We have several schedules so jobs can check github.event.schedule to activate only for a fraction of the runs.
|
|
||||||
# Also run less frequently on weekends.
|
|
||||||
- cron: 45 */2 * * 1-5
|
|
||||||
- cron: 45 4,12 * * 0,6
|
|
||||||
|
|
||||||
concurrency:
|
|
||||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
|
|
||||||
cancel-in-progress: true
|
|
||||||
|
|
||||||
permissions: read-all
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
target-determination:
|
|
||||||
if: github.repository_owner == 'pytorch'
|
|
||||||
name: before-test
|
|
||||||
uses: ./.github/workflows/target_determination.yml
|
|
||||||
permissions:
|
|
||||||
id-token: write
|
|
||||||
contents: read
|
|
||||||
|
|
||||||
get-label-type:
|
|
||||||
name: get-label-type
|
|
||||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
|
||||||
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
|
||||||
with:
|
|
||||||
triggering_actor: ${{ github.triggering_actor }}
|
|
||||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
|
||||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
|
||||||
curr_ref_type: ${{ github.ref_type }}
|
|
||||||
|
|
||||||
linux-jammy-rocm-py3_10-build:
|
|
||||||
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
|
||||||
name: linux-jammy-rocm-py3.10
|
|
||||||
uses: ./.github/workflows/_linux-build.yml
|
|
||||||
needs: get-label-type
|
|
||||||
with:
|
|
||||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
|
||||||
build-environment: linux-jammy-rocm-py3.10
|
|
||||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
|
|
||||||
sync-tag: rocm-build
|
|
||||||
test-matrix: |
|
|
||||||
{ include: [
|
|
||||||
{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
|
|
||||||
{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
|
|
||||||
]}
|
|
||||||
secrets: inherit
|
|
||||||
|
|
||||||
linux-jammy-rocm-py3_10-test:
|
|
||||||
permissions:
|
|
||||||
id-token: write
|
|
||||||
contents: read
|
|
||||||
name: linux-jammy-rocm-py3_10
|
|
||||||
uses: ./.github/workflows/_rocm-test.yml
|
|
||||||
needs:
|
|
||||||
- linux-jammy-rocm-py3_10-build
|
|
||||||
- target-determination
|
|
||||||
with:
|
|
||||||
build-environment: linux-jammy-rocm-py3.10
|
|
||||||
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
|
|
||||||
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
|
|
||||||
tests-to-include: >-
|
|
||||||
${{ github.event_name == 'schedule' && 'test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs
|
|
||||||
test_autograd inductor/test_torchinductor inductor/test_kernel_benchmark
|
|
||||||
inductor/test_pad_mm inductor/test_benchmark_fusion inductor/test_aot_inductor
|
|
||||||
inductor/test_torchinductor inductor/test_decompose_mem_bound_mm
|
|
||||||
inductor/test_flex_attention inductor/test_max_autotune' || '' }}
|
|
||||||
secrets: inherit
|
|
||||||
38
.github/workflows/rocm.yml
vendored
38
.github/workflows/rocm.yml
vendored
@ -26,23 +26,11 @@ jobs:
|
|||||||
id-token: write
|
id-token: write
|
||||||
contents: read
|
contents: read
|
||||||
|
|
||||||
get-label-type:
|
|
||||||
name: get-label-type
|
|
||||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
|
||||||
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
|
||||||
with:
|
|
||||||
triggering_actor: ${{ github.triggering_actor }}
|
|
||||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
|
||||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
|
||||||
curr_ref_type: ${{ github.ref_type }}
|
|
||||||
|
|
||||||
linux-jammy-rocm-py3_10-build:
|
linux-jammy-rocm-py3_10-build:
|
||||||
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
||||||
name: linux-jammy-rocm-py3.10
|
name: linux-jammy-rocm-py3.10
|
||||||
uses: ./.github/workflows/_linux-build.yml
|
uses: ./.github/workflows/_linux-build.yml
|
||||||
needs: get-label-type
|
|
||||||
with:
|
with:
|
||||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
|
||||||
build-environment: linux-jammy-rocm-py3.10
|
build-environment: linux-jammy-rocm-py3.10
|
||||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
|
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
|
||||||
sync-tag: rocm-build
|
sync-tag: rocm-build
|
||||||
@ -71,3 +59,29 @@ jobs:
|
|||||||
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
|
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
|
||||||
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
|
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
|
||||||
secrets: inherit
|
secrets: inherit
|
||||||
|
|
||||||
|
linux-jammy-rocm-py3_10-gfx1100-test:
|
||||||
|
if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
|
||||||
|
permissions:
|
||||||
|
id-token: write
|
||||||
|
contents: read
|
||||||
|
name: linux-jammy-rocm-py3_10-gfx1100
|
||||||
|
uses: ./.github/workflows/_rocm-test.yml
|
||||||
|
needs:
|
||||||
|
- linux-jammy-rocm-py3_10-build
|
||||||
|
- target-determination
|
||||||
|
with:
|
||||||
|
build-environment: linux-jammy-rocm-py3.10
|
||||||
|
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
|
||||||
|
test-matrix: |
|
||||||
|
{ include: [
|
||||||
|
{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
|
||||||
|
{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" },
|
||||||
|
]}
|
||||||
|
tests-to-include: >
|
||||||
|
test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs
|
||||||
|
test_autograd inductor/test_torchinductor inductor/test_kernel_benchmark
|
||||||
|
inductor/test_pad_mm inductor/test_benchmark_fusion inductor/test_aot_inductor
|
||||||
|
inductor/test_torchinductor inductor/test_decompose_mem_bound_mm
|
||||||
|
inductor/test_flex_attention inductor/test_max_autotune
|
||||||
|
secrets: inherit
|
||||||
|
|||||||
149
.github/workflows/trunk-tagging.yml
vendored
149
.github/workflows/trunk-tagging.yml
vendored
@ -58,10 +58,8 @@ jobs:
|
|||||||
else
|
else
|
||||||
COMMIT_SHA="${{ github.sha }}"
|
COMMIT_SHA="${{ github.sha }}"
|
||||||
fi
|
fi
|
||||||
{
|
echo "sha=${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
|
||||||
echo "sha=${COMMIT_SHA}"
|
echo "tag_name=trunk/${COMMIT_SHA}" >> "${GITHUB_OUTPUT}"
|
||||||
echo "tag_name=trunk/${COMMIT_SHA}"
|
|
||||||
} >> "${GITHUB_OUTPUT}"
|
|
||||||
|
|
||||||
- name: Validate commit SHA
|
- name: Validate commit SHA
|
||||||
run: |
|
run: |
|
||||||
@ -89,7 +87,7 @@ jobs:
|
|||||||
echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)"
|
echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Create and push tag(s) with retry
|
- name: Create and push tag with retry
|
||||||
id: check_tag
|
id: check_tag
|
||||||
env:
|
env:
|
||||||
TAG_NAME: ${{ steps.commit.outputs.tag_name }}
|
TAG_NAME: ${{ steps.commit.outputs.tag_name }}
|
||||||
@ -114,23 +112,14 @@ jobs:
|
|||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
# Counters for summary reporting
|
# Exit early if tag already exists
|
||||||
created_count=0
|
if check_tag_exists; then
|
||||||
skipped_count=0
|
echo "✅ Tag already exists - no action needed"
|
||||||
failed_count=0
|
echo "exists=true" >> "${GITHUB_OUTPUT}"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
# Always write outputs once on exit
|
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
|
||||||
finish() {
|
|
||||||
set +e
|
|
||||||
if [ -n "${GITHUB_OUTPUT:-}" ]; then
|
|
||||||
{
|
|
||||||
echo "created_count=${created_count}"
|
|
||||||
echo "skipped_count=${skipped_count}"
|
|
||||||
echo "failed_count=${failed_count}"
|
|
||||||
} >> "${GITHUB_OUTPUT}"
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
trap finish EXIT
|
|
||||||
|
|
||||||
# Retry configuration
|
# Retry configuration
|
||||||
MAX_RETRIES=5
|
MAX_RETRIES=5
|
||||||
@ -205,111 +194,31 @@ jobs:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# New behavior for push events: enumerate commits in the push and tag each one.
|
# Execute with retry
|
||||||
# For workflow_dispatch, retain existing single-SHA behavior.
|
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
|
||||||
|
echo "exists=false" >> "${GITHUB_OUTPUT}"
|
||||||
# Always fetch tags once up front to improve idempotency in loops
|
|
||||||
git fetch origin --tags --quiet || true
|
|
||||||
|
|
||||||
if [ "${{ github.event_name }}" = "push" ]; then
|
|
||||||
BEFORE_SHA="${{ github.event.before }}"
|
|
||||||
AFTER_SHA="${{ github.sha }}" # same as event.after
|
|
||||||
|
|
||||||
# List commits introduced by this push (old..new), oldest first for stable ordering
|
|
||||||
commits_file="$(mktemp)"
|
|
||||||
git rev-list --reverse "${BEFORE_SHA}..${AFTER_SHA}" > "${commits_file}"
|
|
||||||
|
|
||||||
if [ ! -s "${commits_file}" ]; then
|
|
||||||
echo "No new commits found between ${BEFORE_SHA}..${AFTER_SHA}; nothing to tag."
|
|
||||||
rm -f "${commits_file}"
|
|
||||||
exit 0
|
|
||||||
fi
|
|
||||||
|
|
||||||
commit_count="$(wc -l < "${commits_file}" | tr -d ' ')"
|
|
||||||
echo "Found ${commit_count} commit(s) to tag for push:"
|
|
||||||
while IFS= read -r sha; do
|
|
||||||
printf ' %s\n' "${sha}"
|
|
||||||
done < "${commits_file}"
|
|
||||||
|
|
||||||
while IFS= read -r sha; do
|
|
||||||
TAG_NAME="trunk/${sha}"
|
|
||||||
COMMIT_SHA="${sha}"
|
|
||||||
|
|
||||||
# If tag already exists locally or remotely, skip (idempotent)
|
|
||||||
if check_tag_exists; then
|
|
||||||
echo "✅ Tag ${TAG_NAME} already exists - skipping"
|
|
||||||
skipped_count=$((skipped_count + 1))
|
|
||||||
continue
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
|
|
||||||
|
|
||||||
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
|
|
||||||
created_count=$((created_count + 1))
|
|
||||||
else
|
|
||||||
echo "Tag creation failed after all retry attempts for ${TAG_NAME}"
|
|
||||||
failed_count=$((failed_count + 1))
|
|
||||||
fi
|
|
||||||
done < "${commits_file}"
|
|
||||||
|
|
||||||
rm -f "${commits_file}"
|
|
||||||
|
|
||||||
if [ "${failed_count}" -gt 0 ]; then
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
exit 0
|
exit 0
|
||||||
else
|
else
|
||||||
# workflow_dispatch path (single SHA tagging preserved)
|
echo "Tag creation failed after all retry attempts"
|
||||||
|
exit 1
|
||||||
# Exit early if tag already exists
|
|
||||||
if check_tag_exists; then
|
|
||||||
echo "✅ Tag already exists - no action needed"
|
|
||||||
skipped_count=1
|
|
||||||
exit 0
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo "Tag ${TAG_NAME} does not exist, proceeding with creation"
|
|
||||||
|
|
||||||
if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then
|
|
||||||
created_count=1
|
|
||||||
exit 0
|
|
||||||
else
|
|
||||||
echo "Tag creation failed after all retry attempts"
|
|
||||||
failed_count=1
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Tag creation summary
|
- name: Tag creation summary
|
||||||
if: always()
|
if: always()
|
||||||
run: |
|
run: |
|
||||||
if [ "${{ github.event_name }}" = "push" ]; then
|
if [ "${{ steps.check_tag.outputs.exists }}" = "true" ]; then
|
||||||
echo "Trigger: push on main"
|
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
|
||||||
echo "Created: ${{ steps.check_tag.outputs.created_count }}"
|
elif [ "${{ job.status }}" = "success" ]; then
|
||||||
echo "Skipped (already existed): ${{ steps.check_tag.outputs.skipped_count }}"
|
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
|
||||||
echo "Failed: ${{ steps.check_tag.outputs.failed_count }}"
|
|
||||||
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
|
|
||||||
echo "✅ Completed tagging for push range ${{ github.event.before }}..${{ github.sha }}"
|
|
||||||
else
|
|
||||||
echo "❌ Some tags failed to create for push range ${{ github.event.before }}..${{ github.sha }}"
|
|
||||||
fi
|
|
||||||
else
|
else
|
||||||
if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then
|
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
|
||||||
if [ "${{ steps.check_tag.outputs.created_count }}" = "0" ]; then
|
fi
|
||||||
echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed"
|
|
||||||
else
|
echo ""
|
||||||
echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
|
echo "Tag details:"
|
||||||
fi
|
echo " Name: ${{ steps.commit.outputs.tag_name }}"
|
||||||
else
|
echo " Commit: ${{ steps.commit.outputs.sha }}"
|
||||||
echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}"
|
echo " Trigger: ${{ github.event_name }}"
|
||||||
fi
|
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
|
||||||
|
echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
|
||||||
echo ""
|
|
||||||
echo "Tag details:"
|
|
||||||
echo " Name: ${{ steps.commit.outputs.tag_name }}"
|
|
||||||
echo " Commit: ${{ steps.commit.outputs.sha }}"
|
|
||||||
echo " Trigger: ${{ github.event_name }}"
|
|
||||||
if [ -n "${{ github.event.inputs.commit_sha }}" ]; then
|
|
||||||
echo " Manual commit: ${{ github.event.inputs.commit_sha }}"
|
|
||||||
fi
|
|
||||||
fi
|
fi
|
||||||
|
|||||||
51
.github/workflows/trunk.yml
vendored
51
.github/workflows/trunk.yml
vendored
@ -190,40 +190,6 @@ jobs:
|
|||||||
runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
|
runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
|
||||||
secrets: inherit
|
secrets: inherit
|
||||||
|
|
||||||
linux-jammy-rocm-py3_10-build:
|
|
||||||
if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }}
|
|
||||||
name: linux-jammy-rocm-py3.10
|
|
||||||
uses: ./.github/workflows/_linux-build.yml
|
|
||||||
needs: get-label-type
|
|
||||||
with:
|
|
||||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
|
||||||
build-environment: linux-jammy-rocm-py3.10
|
|
||||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
|
|
||||||
sync-tag: rocm-build
|
|
||||||
test-matrix: |
|
|
||||||
{ include: [
|
|
||||||
{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
|
|
||||||
{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
|
|
||||||
]}
|
|
||||||
secrets: inherit
|
|
||||||
|
|
||||||
linux-jammy-rocm-py3_10-test:
|
|
||||||
if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }}
|
|
||||||
permissions:
|
|
||||||
id-token: write
|
|
||||||
contents: read
|
|
||||||
name: linux-jammy-rocm-py3.10
|
|
||||||
uses: ./.github/workflows/_rocm-test.yml
|
|
||||||
needs:
|
|
||||||
- linux-jammy-rocm-py3_10-build
|
|
||||||
- target-determination
|
|
||||||
with:
|
|
||||||
build-environment: linux-jammy-rocm-py3.10
|
|
||||||
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
|
|
||||||
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
|
|
||||||
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor"
|
|
||||||
secrets: inherit
|
|
||||||
|
|
||||||
inductor-build:
|
inductor-build:
|
||||||
name: inductor-build
|
name: inductor-build
|
||||||
uses: ./.github/workflows/_linux-build.yml
|
uses: ./.github/workflows/_linux-build.yml
|
||||||
@ -234,23 +200,6 @@ jobs:
|
|||||||
cuda-arch-list: '8.0'
|
cuda-arch-list: '8.0'
|
||||||
secrets: inherit
|
secrets: inherit
|
||||||
|
|
||||||
# Test cross-compiled models with Windows libs extracted from wheel
|
|
||||||
cross-compile-linux-test:
|
|
||||||
name: cross-compile-linux-test
|
|
||||||
uses: ./.github/workflows/_linux-test.yml
|
|
||||||
needs:
|
|
||||||
- linux-jammy-cuda12_8-py3_10-gcc11-build
|
|
||||||
- get-label-type
|
|
||||||
- win-vs2022-cuda12_8-py3-build
|
|
||||||
with:
|
|
||||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc11
|
|
||||||
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }}
|
|
||||||
test-matrix: |
|
|
||||||
{ include: [
|
|
||||||
{ config: "aoti_cross_compile_for_windows", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", win_torch_wheel_artifact: "win-vs2022-cuda12.8-py3" },
|
|
||||||
]}
|
|
||||||
secrets: inherit
|
|
||||||
|
|
||||||
verify-cachebench-cpu-build:
|
verify-cachebench-cpu-build:
|
||||||
name: verify-cachebench-cpu-build
|
name: verify-cachebench-cpu-build
|
||||||
uses: ./.github/workflows/_linux-build.yml
|
uses: ./.github/workflows/_linux-build.yml
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@ -374,7 +374,6 @@ third_party/ruy/
|
|||||||
third_party/glog/
|
third_party/glog/
|
||||||
|
|
||||||
# Virtualenv
|
# Virtualenv
|
||||||
.venv/
|
|
||||||
venv/
|
venv/
|
||||||
|
|
||||||
# Log files
|
# Log files
|
||||||
|
|||||||
@ -833,7 +833,8 @@ exclude_patterns = [
|
|||||||
command = [
|
command = [
|
||||||
'python3',
|
'python3',
|
||||||
'tools/linter/adapters/grep_linter.py',
|
'tools/linter/adapters/grep_linter.py',
|
||||||
'--pattern=(cudaSetDevice|cudaGetDevice)\\(',
|
'--pattern=cudaSetDevice(',
|
||||||
|
'--pattern=cudaGetDevice(',
|
||||||
'--linter-name=RAWCUDADEVICE',
|
'--linter-name=RAWCUDADEVICE',
|
||||||
'--error-name=raw CUDA API usage',
|
'--error-name=raw CUDA API usage',
|
||||||
"""--error-description=\
|
"""--error-description=\
|
||||||
@ -1137,8 +1138,11 @@ command = [
|
|||||||
[[linter]]
|
[[linter]]
|
||||||
code = 'WORKFLOWSYNC'
|
code = 'WORKFLOWSYNC'
|
||||||
include_patterns = [
|
include_patterns = [
|
||||||
'.github/workflows/*.yml',
|
'.github/workflows/pull.yml',
|
||||||
'.github/workflows/*.yaml',
|
'.github/workflows/trunk.yml',
|
||||||
|
'.github/workflows/periodic.yml',
|
||||||
|
'.github/workflows/mac-mps.yml',
|
||||||
|
'.github/workflows/slow.yml',
|
||||||
]
|
]
|
||||||
command = [
|
command = [
|
||||||
'python3',
|
'python3',
|
||||||
|
|||||||
14
CODEOWNERS
14
CODEOWNERS
@ -201,17 +201,3 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
|
|||||||
/torch/csrc/stable/ @janeyx99 @mikaylagawarecki
|
/torch/csrc/stable/ @janeyx99 @mikaylagawarecki
|
||||||
/torch/headeronly/ @janeyx99
|
/torch/headeronly/ @janeyx99
|
||||||
/torch/header_only_apis.txt @janeyx99
|
/torch/header_only_apis.txt @janeyx99
|
||||||
|
|
||||||
# FlexAttention
|
|
||||||
/torch/nn/attention/flex_attention.py @drisspg
|
|
||||||
/torch/_higher_order_ops/flex_attention.py @drisspg
|
|
||||||
/torch/_inductor/kernel/flex/ @drisspg
|
|
||||||
/torch/_inductor/codegen/cpp_flex_attention_template.py @drisspg
|
|
||||||
/test/inductor/test_flex_attention.py @drisspg
|
|
||||||
/test/inductor/test_flex_decoding.py @drisspg
|
|
||||||
|
|
||||||
# Low Precision GEMMs
|
|
||||||
/aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58
|
|
||||||
/aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58
|
|
||||||
/aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58
|
|
||||||
/test/test_scaled_matmul_cuda.py @drisspg @slayton58
|
|
||||||
|
|||||||
@ -38,7 +38,7 @@ set_bool(AT_HIPSPARSELT_ENABLED CAFFE2_USE_HIPSPARSELT)
|
|||||||
|
|
||||||
configure_file(Config.h.in "${CMAKE_CURRENT_SOURCE_DIR}/Config.h")
|
configure_file(Config.h.in "${CMAKE_CURRENT_SOURCE_DIR}/Config.h")
|
||||||
# TODO: Do not generate CUDAConfig.h for ROCm BUILDS
|
# TODO: Do not generate CUDAConfig.h for ROCm BUILDS
|
||||||
# At the moment, `jit_macros.h` include CUDAConfig.h for both CUDA and HIP builds
|
# At the moment, `jit_macors.h` include CUDAConfig.h for both CUDA and HIP builds
|
||||||
if(USE_CUDA OR USE_ROCM)
|
if(USE_CUDA OR USE_ROCM)
|
||||||
configure_file(cuda/CUDAConfig.h.in "${CMAKE_CURRENT_SOURCE_DIR}/cuda/CUDAConfig.h")
|
configure_file(cuda/CUDAConfig.h.in "${CMAKE_CURRENT_SOURCE_DIR}/cuda/CUDAConfig.h")
|
||||||
endif()
|
endif()
|
||||||
@ -289,15 +289,14 @@ IF(USE_FBGEMM_GENAI)
|
|||||||
|
|
||||||
set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||||
|
|
||||||
set(fbgemm_genai_cuh
|
set(fbgemm_genai_mx8mx8bf16_grouped
|
||||||
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
|
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
|
||||||
"${FBGEMM_GENAI_SRCS}/"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
target_include_directories(fbgemm_genai PRIVATE
|
target_include_directories(fbgemm_genai PRIVATE
|
||||||
${FBGEMM_THIRD_PARTY}/cutlass/include
|
${FBGEMM_THIRD_PARTY}/cutlass/include
|
||||||
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
|
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
|
||||||
${fbgemm_genai_cuh}
|
${fbgemm_genai_mx8mx8bf16_grouped}
|
||||||
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
|
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
|
||||||
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
|
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
|
||||||
)
|
)
|
||||||
@ -314,14 +313,13 @@ IF(USE_FBGEMM_GENAI)
|
|||||||
|
|
||||||
# Add additional HIPCC compiler flags for performance
|
# Add additional HIPCC compiler flags for performance
|
||||||
set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
|
set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
|
||||||
|
-mllvm
|
||||||
|
-amdgpu-coerce-illegal-types=1
|
||||||
-mllvm
|
-mllvm
|
||||||
-enable-post-misched=0
|
-enable-post-misched=0
|
||||||
-mllvm
|
-mllvm
|
||||||
-greedy-reverse-local-assignment=1
|
-greedy-reverse-local-assignment=1
|
||||||
-fhip-new-launch-api)
|
-fhip-new-launch-api)
|
||||||
if(DEFINED ROCM_VERSION_DEV AND ROCM_VERSION_DEV VERSION_LESS "7.2.0")
|
|
||||||
list(PREPEND FBGEMM_GENAI_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-coerce-illegal-types=1)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# Only compile for gfx942 for now.
|
# Only compile for gfx942 for now.
|
||||||
# This is rather hacky, I could not figure out a clean solution :(
|
# This is rather hacky, I could not figure out a clean solution :(
|
||||||
|
|||||||
@ -19,7 +19,6 @@
|
|||||||
#include <ATen/detail/MPSHooksInterface.h>
|
#include <ATen/detail/MPSHooksInterface.h>
|
||||||
#include <ATen/detail/MTIAHooksInterface.h>
|
#include <ATen/detail/MTIAHooksInterface.h>
|
||||||
#include <ATen/detail/PrivateUse1HooksInterface.h>
|
#include <ATen/detail/PrivateUse1HooksInterface.h>
|
||||||
#include <ATen/detail/XLAHooksInterface.h>
|
|
||||||
#include <ATen/detail/XPUHooksInterface.h>
|
#include <ATen/detail/XPUHooksInterface.h>
|
||||||
#include <c10/core/QEngine.h>
|
#include <c10/core/QEngine.h>
|
||||||
#include <c10/core/impl/DeviceGuardImplInterface.h>
|
#include <c10/core/impl/DeviceGuardImplInterface.h>
|
||||||
@ -89,8 +88,6 @@ class TORCH_API Context {
|
|||||||
return at::detail::getHIPHooks();
|
return at::detail::getHIPHooks();
|
||||||
} else if (opt_device_type == at::kHPU) {
|
} else if (opt_device_type == at::kHPU) {
|
||||||
return at::detail::getHPUHooks();
|
return at::detail::getHPUHooks();
|
||||||
} else if (opt_device_type == at::kXLA) {
|
|
||||||
return at::detail::getXLAHooks();
|
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
false,
|
false,
|
||||||
@ -199,7 +196,7 @@ class TORCH_API Context {
|
|||||||
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
|
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
|
||||||
}
|
}
|
||||||
static bool hasXLA() {
|
static bool hasXLA() {
|
||||||
return detail::getXLAHooks().hasXLA();
|
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA);
|
||||||
}
|
}
|
||||||
static bool hasXPU() {
|
static bool hasXPU() {
|
||||||
return detail::getXPUHooks().hasXPU();
|
return detail::getXPUHooks().hasXPU();
|
||||||
|
|||||||
@ -122,7 +122,7 @@ void FunctionalTensorWrapper::freeze_storage() const {
|
|||||||
// | have their own storages, but backends like functorch |
|
// | have their own storages, but backends like functorch |
|
||||||
// \/ are allowed to re-alias underneath the pass \/
|
// \/ are allowed to re-alias underneath the pass \/
|
||||||
// . - - - - - - - - - - - - - . . - - - - - - - - - - - - - - - .
|
// . - - - - - - - - - - - - - . . - - - - - - - - - - - - - - - .
|
||||||
// | underlying_storage | | underlying_storage |
|
// | underyling_storage | | underyling_storage |
|
||||||
// . - - - - - - - - - - - - - . . - - - - - - - - - - - - - - - .
|
// . - - - - - - - - - - - - - . . - - - - - - - - - - - - - - - .
|
||||||
//
|
//
|
||||||
// This constructor is only used by view ops.
|
// This constructor is only used by view ops.
|
||||||
|
|||||||
@ -1534,7 +1534,7 @@ void TensorIteratorBase::build(TensorIteratorConfig& config) {
|
|||||||
|
|
||||||
// XLA and lazy tensors don't have storage, so they don't have an underlying data pointer.
|
// XLA and lazy tensors don't have storage, so they don't have an underlying data pointer.
|
||||||
// Nothing beyond this point is important for meta functions, so it's fine to exit early here.
|
// Nothing beyond this point is important for meta functions, so it's fine to exit early here.
|
||||||
// Extend the condition to MAIA tensors as MAIA tensors also don't have storage.
|
// Extend the condition to MAIA tesnors as MAIA tensors also don't have storage.
|
||||||
if (privateuse1_without_storage ||
|
if (privateuse1_without_storage ||
|
||||||
common_device_.type() == DeviceType::XLA ||
|
common_device_.type() == DeviceType::XLA ||
|
||||||
common_device_.type() == DeviceType::IPU ||
|
common_device_.type() == DeviceType::IPU ||
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <ATen/CachedTensorUtils.h>
|
#include <ATen/CachedTensorUtils.h>
|
||||||
|
#include <c10/core/GradMode.h>
|
||||||
#include <c10/util/flat_hash_map.h>
|
#include <c10/util/flat_hash_map.h>
|
||||||
|
|
||||||
namespace at::autocast {
|
namespace at::autocast {
|
||||||
@ -36,10 +37,29 @@ namespace {
|
|||||||
using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
|
using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
|
||||||
using val_type = std::tuple<weakref_type, Tensor>;
|
using val_type = std::tuple<weakref_type, Tensor>;
|
||||||
|
|
||||||
ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
|
// We maintain separate caches for gradient-enabled and gradient-disabled modes.
|
||||||
static ska::flat_hash_map<TensorImpl*, val_type> cached_casts;
|
// This ensures that tensors cached in torch.no_grad() (with requires_grad=False)
|
||||||
return cached_casts;
|
// are not incorrectly reused in gradient-enabled contexts.
|
||||||
|
// This fixes issue #158232 while maintaining optimal performance for both modes.
|
||||||
|
static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts_grad_enabled() {
|
||||||
|
static ska::flat_hash_map<TensorImpl*, val_type> cached_casts_grad_enabled;
|
||||||
|
return cached_casts_grad_enabled;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts_grad_disabled() {
|
||||||
|
static ska::flat_hash_map<TensorImpl*, val_type> cached_casts_grad_disabled;
|
||||||
|
return cached_casts_grad_disabled;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to get the appropriate cache based on current gradient mode.
|
||||||
|
// This allows us to cache tensors separately for grad-enabled and grad-disabled contexts,
|
||||||
|
// preventing incorrect cache hits when gradient mode changes.
|
||||||
|
static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
|
||||||
|
return at::GradMode::is_enabled() ?
|
||||||
|
get_cached_casts_grad_enabled() :
|
||||||
|
get_cached_casts_grad_disabled();
|
||||||
|
}
|
||||||
|
|
||||||
std::mutex cached_casts_mutex;
|
std::mutex cached_casts_mutex;
|
||||||
|
|
||||||
|
|
||||||
@ -86,7 +106,9 @@ thread_local bool cache_enabled = true;
|
|||||||
|
|
||||||
void clear_cache() {
|
void clear_cache() {
|
||||||
const std::lock_guard<std::mutex> lock(cached_casts_mutex);
|
const std::lock_guard<std::mutex> lock(cached_casts_mutex);
|
||||||
get_cached_casts().clear();
|
// Clear both caches to ensure consistent behavior regardless of current gradient mode
|
||||||
|
get_cached_casts_grad_enabled().clear();
|
||||||
|
get_cached_casts_grad_disabled().clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
int increment_nesting() {
|
int increment_nesting() {
|
||||||
@ -121,6 +143,11 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_
|
|||||||
if (is_eligible(arg, device_type) && (arg.scalar_type() != to_type)) {
|
if (is_eligible(arg, device_type) && (arg.scalar_type() != to_type)) {
|
||||||
// Heuristic: Do what Apex does, and cache lower_precision_fp casts of fp32 model weights (leaves).
|
// Heuristic: Do what Apex does, and cache lower_precision_fp casts of fp32 model weights (leaves).
|
||||||
// See cached_casts declaration above for detailed strategy.
|
// See cached_casts declaration above for detailed strategy.
|
||||||
|
//
|
||||||
|
// We maintain separate caches for gradient-enabled and gradient-disabled modes
|
||||||
|
// (see get_cached_casts() above). This ensures correctness when mixing torch.no_grad()
|
||||||
|
// with torch.autocast(), while maintaining optimal performance for both training and inference.
|
||||||
|
// This fixes issue #158232 without any performance regression.
|
||||||
bool can_try_cache = (to_type == get_lower_precision_fp_from_device_type(device_type) &&
|
bool can_try_cache = (to_type == get_lower_precision_fp_from_device_type(device_type) &&
|
||||||
arg.scalar_type() == at::kFloat && arg.requires_grad() &&
|
arg.scalar_type() == at::kFloat && arg.requires_grad() &&
|
||||||
arg.is_leaf() && !arg.is_view() && cache_enabled &&
|
arg.is_leaf() && !arg.is_view() && cache_enabled &&
|
||||||
|
|||||||
@ -39,7 +39,7 @@ struct HostBlock {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename B>
|
template <typename B>
|
||||||
struct alignas(hardware_destructive_interference_size) FreeBlockList {
|
struct alignas(64) FreeBlockList {
|
||||||
std::mutex mutex_;
|
std::mutex mutex_;
|
||||||
std::deque<B*> list_;
|
std::deque<B*> list_;
|
||||||
};
|
};
|
||||||
@ -94,11 +94,11 @@ struct PinnedReserveSegment {
|
|||||||
struct TORCH_API HostStats {
|
struct TORCH_API HostStats {
|
||||||
// COUNT: total allocations (active)
|
// COUNT: total allocations (active)
|
||||||
Stat active_requests;
|
Stat active_requests;
|
||||||
// SUM: bytes allocated/reserved by this memory allocator. (active)
|
// SUM: bytes allocated/reserved by this memory alocator. (active)
|
||||||
Stat active_bytes;
|
Stat active_bytes;
|
||||||
// COUNT: total allocations (active + free)
|
// COUNT: total allocations (active + free)
|
||||||
Stat allocations;
|
Stat allocations;
|
||||||
// SUM: bytes allocated/reserved by this memory allocator. This accounts
|
// SUM: bytes allocated/reserved by this memory alocator. This accounts
|
||||||
// for both free and in-use blocks.
|
// for both free and in-use blocks.
|
||||||
Stat allocated_bytes;
|
Stat allocated_bytes;
|
||||||
|
|
||||||
@ -122,12 +122,12 @@ struct TORCH_API HostStats {
|
|||||||
// Struct containing memory allocator summary statistics for host, as they
|
// Struct containing memory allocator summary statistics for host, as they
|
||||||
// are staged for reporting. This is a temporary struct that is used to
|
// are staged for reporting. This is a temporary struct that is used to
|
||||||
// avoid locking the allocator while collecting stats.
|
// avoid locking the allocator while collecting stats.
|
||||||
struct alignas(hardware_destructive_interference_size) HostStatsStaged {
|
struct alignas(64) HostStatsStaged {
|
||||||
std::mutex timing_mutex_;
|
std::mutex timing_mutex_;
|
||||||
// COUNT: total allocations (active + free)
|
// COUNT: total allocations (active + free)
|
||||||
// LOCK: access to this stat is protected by the allocator's blocks_mutex_
|
// LOCK: access to this stat is protected by the allocator's blocks_mutex_
|
||||||
Stat allocations;
|
Stat allocations;
|
||||||
// SUM: bytes allocated/reserved by this memory allocator. This accounts
|
// SUM: bytes allocated/reserved by this memory alocator. This accounts
|
||||||
// for both free and in-use blocks.
|
// for both free and in-use blocks.
|
||||||
Stat allocated_bytes;
|
Stat allocated_bytes;
|
||||||
// COUNT: number of allocations per bucket (active)
|
// COUNT: number of allocations per bucket (active)
|
||||||
@ -455,7 +455,7 @@ struct CachingHostAllocatorImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void resetAccumulatedStats() {
|
void resetAccumulatedStats() {
|
||||||
// Resetting accumulated memory stats requires concurrently holding both the
|
// Reseting accumulated memory stats requires concurrently holding both the
|
||||||
// free list mutexes and the blocks mutex. Previously, this was only done in
|
// free list mutexes and the blocks mutex. Previously, this was only done in
|
||||||
// empty_cache function.
|
// empty_cache function.
|
||||||
for (size_t i = 0; i < free_list_.size(); ++i) {
|
for (size_t i = 0; i < free_list_.size(); ++i) {
|
||||||
@ -482,7 +482,7 @@ struct CachingHostAllocatorImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void resetPeakStats() {
|
void resetPeakStats() {
|
||||||
// Resetting peak memory stats requires concurrently holding both the
|
// Reseting peak memory stats requires concurrently holding both the
|
||||||
// free list mutexes and the blocks mutex. Previously, this was only done in
|
// free list mutexes and the blocks mutex. Previously, this was only done in
|
||||||
// empty_cache function.
|
// empty_cache function.
|
||||||
for (size_t i = 0; i < free_list_.size(); ++i) {
|
for (size_t i = 0; i < free_list_.size(); ++i) {
|
||||||
@ -669,7 +669,7 @@ struct CachingHostAllocatorImpl {
|
|||||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event");
|
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event");
|
||||||
}
|
}
|
||||||
|
|
||||||
alignas(hardware_destructive_interference_size) std::mutex blocks_mutex_;
|
alignas(64) std::mutex blocks_mutex_;
|
||||||
ska::flat_hash_set<B*> blocks_; // block list
|
ska::flat_hash_set<B*> blocks_; // block list
|
||||||
ska::flat_hash_map<void*, B*> ptr_to_block_;
|
ska::flat_hash_map<void*, B*> ptr_to_block_;
|
||||||
|
|
||||||
@ -677,17 +677,17 @@ struct CachingHostAllocatorImpl {
|
|||||||
// size. This allows us to quickly find a free block of the right size.
|
// size. This allows us to quickly find a free block of the right size.
|
||||||
// We use deque to store per size free list and guard the list with its own
|
// We use deque to store per size free list and guard the list with its own
|
||||||
// mutex.
|
// mutex.
|
||||||
alignas(hardware_destructive_interference_size) std::vector<FreeBlockList<B>> free_list_ =
|
alignas(64) std::vector<FreeBlockList<B>> free_list_ =
|
||||||
std::vector<FreeBlockList<B>>(MAX_SIZE_INDEX);
|
std::vector<FreeBlockList<B>>(MAX_SIZE_INDEX);
|
||||||
|
|
||||||
alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
|
alignas(64) std::mutex events_mutex_;
|
||||||
std::deque<std::pair<E, B*>> events_; // event queue paired with block
|
std::deque<std::pair<E, B*>> events_; // event queue paired with block
|
||||||
|
|
||||||
// Indicates whether the object is active.
|
// Indicates whether the object is active.
|
||||||
// Set to false in the destructor to signal background threads to stop.
|
// Set to false in the destructor to signal background threads to stop.
|
||||||
std::atomic<bool> active_{true};
|
std::atomic<bool> active_{true};
|
||||||
protected:
|
protected:
|
||||||
alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
|
alignas(64) HostStatsStaged stats_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TORCH_API HostAllocator : public at::Allocator {
|
struct TORCH_API HostAllocator : public at::Allocator {
|
||||||
|
|||||||
@ -59,7 +59,9 @@ struct TORCH_API Generator {
|
|||||||
|
|
||||||
explicit Generator(c10::intrusive_ptr<c10::GeneratorImpl> gen_impl)
|
explicit Generator(c10::intrusive_ptr<c10::GeneratorImpl> gen_impl)
|
||||||
: impl_(std::move(gen_impl)) {
|
: impl_(std::move(gen_impl)) {
|
||||||
TORCH_CHECK(impl_.get(), "GeneratorImpl with nullptr is not supported");
|
if (impl_.get() == nullptr) {
|
||||||
|
throw std::runtime_error("GeneratorImpl with nullptr is not supported");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool operator==(const Generator& rhs) const {
|
bool operator==(const Generator& rhs) const {
|
||||||
|
|||||||
@ -229,10 +229,10 @@ private:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
static constexpr uint32_t kPhilox10A = 0x9E3779B9;
|
static const uint32_t kPhilox10A = 0x9E3779B9;
|
||||||
static constexpr uint32_t kPhilox10B = 0xBB67AE85;
|
static const uint32_t kPhilox10B = 0xBB67AE85;
|
||||||
static constexpr uint32_t kPhiloxSA = 0xD2511F53;
|
static const uint32_t kPhiloxSA = 0xD2511F53;
|
||||||
static constexpr uint32_t kPhiloxSB = 0xCD9E8D57;
|
static const uint32_t kPhiloxSB = 0xCD9E8D57;
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef philox_engine Philox4_32;
|
typedef philox_engine Philox4_32;
|
||||||
|
|||||||
@ -111,7 +111,9 @@ class TORCH_API TensorBase {
|
|||||||
explicit TensorBase(
|
explicit TensorBase(
|
||||||
c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
|
c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
|
||||||
: impl_(std::move(tensor_impl)) {
|
: impl_(std::move(tensor_impl)) {
|
||||||
TORCH_CHECK(impl_.get(), "TensorImpl with nullptr is not supported");
|
if (impl_.get() == nullptr) {
|
||||||
|
throw std::runtime_error("TensorImpl with nullptr is not supported");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
TensorBase(const TensorBase&) = default;
|
TensorBase(const TensorBase&) = default;
|
||||||
TensorBase(TensorBase&&) noexcept = default;
|
TensorBase(TensorBase&&) noexcept = default;
|
||||||
|
|||||||
@ -109,10 +109,6 @@ TORCH_LIBRARY_IMPL(_, AutogradHPU, m) {
|
|||||||
m.fallback(AUTOGRAD_FALLBACK);
|
m.fallback(AUTOGRAD_FALLBACK);
|
||||||
}
|
}
|
||||||
|
|
||||||
TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) {
|
|
||||||
m.fallback(AUTOGRAD_FALLBACK);
|
|
||||||
}
|
|
||||||
|
|
||||||
#undef AUTOGRAD_FALLBACK
|
#undef AUTOGRAD_FALLBACK
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|||||||
@ -148,7 +148,7 @@ struct TORCH_API ClassType : public NamedType {
|
|||||||
|
|
||||||
void checkNotExist(const std::string& name, const std::string& what) const;
|
void checkNotExist(const std::string& name, const std::string& what) const;
|
||||||
|
|
||||||
// Attributes are stored in a specific slot at runtime for efficiency.
|
// Attributes are stored in a specific slot at runtime for effiency.
|
||||||
// When emitting instructions we specify the slot so that attribute access is
|
// When emitting instructions we specify the slot so that attribute access is
|
||||||
// a constant lookup
|
// a constant lookup
|
||||||
std::optional<size_t> findAttributeSlot(const std::string& name) const {
|
std::optional<size_t> findAttributeSlot(const std::string& name) const {
|
||||||
@ -412,7 +412,7 @@ struct TORCH_API ClassType : public NamedType {
|
|||||||
// Holds method attributes
|
// Holds method attributes
|
||||||
std::weak_ptr<CompilationUnit> compilation_unit_;
|
std::weak_ptr<CompilationUnit> compilation_unit_;
|
||||||
|
|
||||||
// Holds all attributes, attribute details are found on ClassAttribute
|
// Holds all atrributes, attribute details are found on ClassAttribute
|
||||||
std::vector<ClassAttribute> attributes_;
|
std::vector<ClassAttribute> attributes_;
|
||||||
// Construct mirroring attributes_, only around due to the fact that `containedTypes()` method returns an ArrayRef.
|
// Construct mirroring attributes_, only around due to the fact that `containedTypes()` method returns an ArrayRef.
|
||||||
// Never fill this without using the appropriate provideNewClassAttribute method
|
// Never fill this without using the appropriate provideNewClassAttribute method
|
||||||
|
|||||||
@ -442,17 +442,11 @@ RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, Ker
|
|||||||
|
|
||||||
auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
|
auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
|
||||||
TORCH_CHECK(idx >= 0 && static_cast<uint64_t>(idx) < backendFallbackKernels_.size(), "idx=", idx);
|
TORCH_CHECK(idx >= 0 && static_cast<uint64_t>(idx) < backendFallbackKernels_.size(), "idx=", idx);
|
||||||
// NB: Perserve BC for registering fallback for AutogradPrivateUse1 multiple time,
|
|
||||||
// refer to https://github.com/pytorch/pytorch/issues/163979 for more informations.
|
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
dispatchKey == DispatchKey::AutogradPrivateUse1 ||
|
!backendFallbackKernels_[idx].kernel.isValid(),
|
||||||
!backendFallbackKernels_[idx].kernel.isValid(),
|
"Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ",
|
||||||
"Tried to register multiple backend fallbacks for the same dispatch key ",
|
backendFallbackKernels_[idx].debug, ", new registration ", debug
|
||||||
dispatchKey,
|
);
|
||||||
"; previous registration ",
|
|
||||||
backendFallbackKernels_[idx].debug,
|
|
||||||
", new registration ",
|
|
||||||
debug);
|
|
||||||
// NB: inferred function schema is always nullptr for fallbacks, as fallbacks
|
// NB: inferred function schema is always nullptr for fallbacks, as fallbacks
|
||||||
// cannot be unboxed
|
// cannot be unboxed
|
||||||
backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));
|
backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));
|
||||||
@ -537,7 +531,7 @@ int64_t Dispatcher::sequenceNumberForRunningRecordFunction(DispatchKey dispatchK
|
|||||||
|
|
||||||
// Note: this records a sequence number for both Autograd keys, and for
|
// Note: this records a sequence number for both Autograd keys, and for
|
||||||
// non-Autograd keys where the dispatchKeySet still contains an autograd key.
|
// non-Autograd keys where the dispatchKeySet still contains an autograd key.
|
||||||
// This means that we might collect the same sequence number two different
|
// This means that we might collect the same sequence nubmer two different
|
||||||
// events if they all occurred above Autograd and still had the Autograd
|
// events if they all occurred above Autograd and still had the Autograd
|
||||||
// dispatch key in the dispatch key set.
|
// dispatch key in the dispatch key set.
|
||||||
// However, this usually doesn't happen: normally the first call will
|
// However, this usually doesn't happen: normally the first call will
|
||||||
|
|||||||
@ -585,7 +585,7 @@ class TORCH_API OperatorHandle {
|
|||||||
|
|
||||||
// We need to store this iterator in order to make
|
// We need to store this iterator in order to make
|
||||||
// Dispatcher::cleanup() fast -- it runs a lot on program
|
// Dispatcher::cleanup() fast -- it runs a lot on program
|
||||||
// termination (and presumably library unloading).
|
// termination (and presuambly library unloading).
|
||||||
std::list<Dispatcher::OperatorDef>::iterator operatorIterator_;
|
std::list<Dispatcher::OperatorDef>::iterator operatorIterator_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -365,7 +365,7 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
|
|||||||
// For autograd keys, we only use kernel from CompositeImplicitAutograd when there's no direct registration
|
// For autograd keys, we only use kernel from CompositeImplicitAutograd when there's no direct registration
|
||||||
// to its corresponding backend key or CompositeExplicitAutograd. See Note [CompositeExplicitAutograd and CompositeImplicitAutograd].
|
// to its corresponding backend key or CompositeExplicitAutograd. See Note [CompositeExplicitAutograd and CompositeImplicitAutograd].
|
||||||
// For AutogradOther, we eagerly return ambiguousAutogradOtherKernel() if there's registration to any of
|
// For AutogradOther, we eagerly return ambiguousAutogradOtherKernel() if there's registration to any of
|
||||||
// its backends and ask backend extender to request a dedicated Autograd key for the backend.
|
// its backends and ask backend extender to request a decicated Autograd key for the backend.
|
||||||
// See Note [Ambiguity in AutogradOther kernel] for more details.
|
// See Note [Ambiguity in AutogradOther kernel] for more details.
|
||||||
// A CompositeExplicitAutograd kernel prevents CompositeImplicitAutograd kernel being used for Autograd keys, but it doesn't
|
// A CompositeExplicitAutograd kernel prevents CompositeImplicitAutograd kernel being used for Autograd keys, but it doesn't
|
||||||
// cause confusion for AutogradOther. It's pretty straightforward to use Autograd (if available)
|
// cause confusion for AutogradOther. It's pretty straightforward to use Autograd (if available)
|
||||||
|
|||||||
@ -261,7 +261,7 @@ std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
|
|||||||
//
|
//
|
||||||
// There are 2 cases
|
// There are 2 cases
|
||||||
// 1. something like 'aten::items.str(Dict(str, t) self) -> ((str, t)[])'.
|
// 1. something like 'aten::items.str(Dict(str, t) self) -> ((str, t)[])'.
|
||||||
// without the extra parenthesis, the c++ scheme parser can not parse it.
|
// without the extra parenthesis, the c++ schem parser can not parse it.
|
||||||
// 2. something like '-> ((str, str))'. Need extra parenthesis so the return
|
// 2. something like '-> ((str, str))'. Need extra parenthesis so the return
|
||||||
// type is a single tuple rather than two strings.
|
// type is a single tuple rather than two strings.
|
||||||
// PR (https://github.com/pytorch/pytorch/pull/23204) has more context about
|
// PR (https://github.com/pytorch/pytorch/pull/23204) has more context about
|
||||||
|
|||||||
@ -68,7 +68,11 @@ Symbol InternedStrings::_symbol(const std::string& s) {
|
|||||||
return it->second;
|
return it->second;
|
||||||
|
|
||||||
auto pos = s.find("::");
|
auto pos = s.find("::");
|
||||||
TORCH_CHECK(pos != std::string::npos, "all symbols must have a namespace, <namespace>::<string>, but found: ", s);
|
if (pos == std::string::npos) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "all symbols must have a namespace, <namespace>::<string>, but found: " << s;
|
||||||
|
throw std::runtime_error(ss.str());
|
||||||
|
}
|
||||||
Symbol ns = _symbol("namespaces::" + s.substr(0, pos));
|
Symbol ns = _symbol("namespaces::" + s.substr(0, pos));
|
||||||
|
|
||||||
Symbol sym(sym_to_info_.size());
|
Symbol sym(sym_to_info_.size());
|
||||||
@ -117,7 +121,12 @@ std::string Symbol::domainString() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Symbol Symbol::fromDomainAndUnqualString(const std::string & d, const std::string & s) {
|
Symbol Symbol::fromDomainAndUnqualString(const std::string & d, const std::string & s) {
|
||||||
TORCH_CHECK(d.compare(0, domain_prefix().size(), domain_prefix()) == 0, "Symbol: domain string is expected to be prefixed with '", domain_prefix(), "', e.g. 'org.pytorch.aten'");
|
if (d.compare(0, domain_prefix().size(), domain_prefix()) != 0) {
|
||||||
|
std::ostringstream ss;
|
||||||
|
ss << "Symbol: domain string is expected to be prefixed with '"
|
||||||
|
<< domain_prefix() << "', e.g. 'org.pytorch.aten'";
|
||||||
|
throw std::runtime_error(ss.str());
|
||||||
|
}
|
||||||
std::string qualString = d.substr(domain_prefix().size()) + "::" + s;
|
std::string qualString = d.substr(domain_prefix().size()) + "::" + s;
|
||||||
return fromQualString(qualString);
|
return fromQualString(qualString);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -7,7 +7,6 @@
|
|||||||
#include <ATen/core/jit_type.h>
|
#include <ATen/core/jit_type.h>
|
||||||
#include <ATen/core/stack.h>
|
#include <ATen/core/stack.h>
|
||||||
#include <ATen/core/type_factory.h>
|
#include <ATen/core/type_factory.h>
|
||||||
#include <c10/util/Exception.h>
|
|
||||||
#include <c10/util/StringUtil.h>
|
#include <c10/util/StringUtil.h>
|
||||||
#include <c10/util/hash.h>
|
#include <c10/util/hash.h>
|
||||||
#include <c10/util/irange.h>
|
#include <c10/util/irange.h>
|
||||||
@ -413,7 +412,7 @@ size_t IValue::hash(const IValue& v) {
|
|||||||
case Tag::Enum:
|
case Tag::Enum:
|
||||||
case Tag::Stream:
|
case Tag::Stream:
|
||||||
case Tag::Uninitialized:
|
case Tag::Uninitialized:
|
||||||
TORCH_CHECK(false,
|
throw std::runtime_error(
|
||||||
"unhashable type: '" + v.type()->repr_str() + "'");
|
"unhashable type: '" + v.type()->repr_str() + "'");
|
||||||
}
|
}
|
||||||
// the above switch should be exhaustive
|
// the above switch should be exhaustive
|
||||||
|
|||||||
@ -1176,7 +1176,7 @@ struct TORCH_API IValue final {
|
|||||||
using HashIdentityIValueMap =
|
using HashIdentityIValueMap =
|
||||||
std::unordered_map<IValue, IValue, HashIdentityIValue, CompIdentityIValues>;
|
std::unordered_map<IValue, IValue, HashIdentityIValue, CompIdentityIValues>;
|
||||||
|
|
||||||
// Checks if this and rhs has a subvalues in common.
|
// Chechs if this and rhs has a subvalues in common.
|
||||||
// [t1,t2] and [t2, t3] returns true.
|
// [t1,t2] and [t2, t3] returns true.
|
||||||
bool overlaps(const IValue& rhs) const;
|
bool overlaps(const IValue& rhs) const;
|
||||||
|
|
||||||
|
|||||||
@ -1501,7 +1501,7 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
|
|||||||
// However, the CompilationUnit holds ownership of the type's graphs, so
|
// However, the CompilationUnit holds ownership of the type's graphs, so
|
||||||
// inserting a constant object into a Graph would create a reference cycle if
|
// inserting a constant object into a Graph would create a reference cycle if
|
||||||
// that constant object held a shared_ptr to its CU. For these objects we
|
// that constant object held a shared_ptr to its CU. For these objects we
|
||||||
// instantiate them with non-owning references to its CU
|
// instatiate them with non-owning references to its CU
|
||||||
Object(WeakOrStrongTypePtr type, size_t numSlots) : type_(std::move(type)) {
|
Object(WeakOrStrongTypePtr type, size_t numSlots) : type_(std::move(type)) {
|
||||||
slots_.resize(numSlots);
|
slots_.resize(numSlots);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -8,7 +8,6 @@
|
|||||||
#include <ATen/core/type_factory.h>
|
#include <ATen/core/type_factory.h>
|
||||||
#include <ATen/core/qualified_name.h>
|
#include <ATen/core/qualified_name.h>
|
||||||
#include <c10/util/TypeList.h>
|
#include <c10/util/TypeList.h>
|
||||||
#include <c10/util/Exception.h>
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <c10/core/SymFloat.h>
|
#include <c10/core/SymFloat.h>
|
||||||
#include <c10/core/SymBool.h>
|
#include <c10/core/SymBool.h>
|
||||||
@ -117,8 +116,10 @@ struct SingleElementType : public SharedType {
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
SingleElementType(TypePtr elem) : SharedType(Kind), elem(std::move(elem)) {
|
SingleElementType(TypePtr elem) : SharedType(Kind), elem(std::move(elem)) {
|
||||||
TORCH_CHECK(this->elem, c10::str(
|
if (!this->elem) {
|
||||||
|
throw std::runtime_error(c10::str(
|
||||||
"Can not create ", typeKindToString(Kind), " with None type"));
|
"Can not create ", typeKindToString(Kind), " with None type"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -373,7 +374,7 @@ struct TORCH_API SymbolicShape {
|
|||||||
// Unranked shape constructor.
|
// Unranked shape constructor.
|
||||||
SymbolicShape() : dims_(std::nullopt) {}
|
SymbolicShape() : dims_(std::nullopt) {}
|
||||||
|
|
||||||
// Known rank but unknown dimensions.
|
// Known rank but unknown dimentions.
|
||||||
SymbolicShape(std::optional<size_t> rank) : dims_(std::nullopt) {
|
SymbolicShape(std::optional<size_t> rank) : dims_(std::nullopt) {
|
||||||
if(!rank) {
|
if(!rank) {
|
||||||
return;
|
return;
|
||||||
@ -415,12 +416,16 @@ struct TORCH_API SymbolicShape {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ShapeSymbol operator[](size_t i) const {
|
ShapeSymbol operator[](size_t i) const {
|
||||||
TORCH_CHECK(dims_, "Rank isn't fixed");
|
if (!dims_) {
|
||||||
|
throw std::runtime_error("Rank isn't fixed");
|
||||||
|
}
|
||||||
return (*dims_).at(i);
|
return (*dims_).at(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
ShapeSymbol at(size_t i) const {
|
ShapeSymbol at(size_t i) const {
|
||||||
TORCH_CHECK(dims_, "Rank isn't fixed");
|
if (!dims_) {
|
||||||
|
throw std::runtime_error("Rank isn't fixed");
|
||||||
|
}
|
||||||
return (*dims_).at(i);
|
return (*dims_).at(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -515,7 +520,9 @@ struct VaryingShape {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const std::optional<T> &operator[](size_t i) const {
|
const std::optional<T> &operator[](size_t i) const {
|
||||||
TORCH_CHECK(dims_, "Rank isn't fixed");
|
if (!dims_) {
|
||||||
|
throw std::runtime_error("Rank isn't fixed");
|
||||||
|
}
|
||||||
return (*dims_).at(i);
|
return (*dims_).at(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -884,9 +891,9 @@ struct TORCH_API ListType
|
|||||||
|
|
||||||
// global singleton
|
// global singleton
|
||||||
// Given an inner type T and an identifier,
|
// Given an inner type T and an identifier,
|
||||||
// this function will return the global singleton type pointer
|
// this function wil return the global singleton type pointer
|
||||||
// the type List<T>.
|
// the type List<T>.
|
||||||
// The extra "identifier" argument is needed because we have multiple container types
|
// The extra "identifier" argument is needed beccause we have multiple container types
|
||||||
// that all re-use this function (List<T>, array<T, N>, etc.)
|
// that all re-use this function (List<T>, array<T, N>, etc.)
|
||||||
static TypePtr get(const std::string& identifier, TypePtr inner);
|
static TypePtr get(const std::string& identifier, TypePtr inner);
|
||||||
|
|
||||||
@ -950,7 +957,9 @@ struct TORCH_API DictType : public SharedType {
|
|||||||
|
|
||||||
TypePtr createWithContained(
|
TypePtr createWithContained(
|
||||||
std::vector<TypePtr> contained_types) const override {
|
std::vector<TypePtr> contained_types) const override {
|
||||||
TORCH_CHECK(contained_types.size() == 2, "Expected 2 contained types");
|
if (contained_types.size() != 2) {
|
||||||
|
throw std::runtime_error("Expected 2 contained types");
|
||||||
|
}
|
||||||
return create(std::move(contained_types.at(0)), std::move(contained_types.at(1)));
|
return create(std::move(contained_types.at(0)), std::move(contained_types.at(1)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -185,11 +185,11 @@ struct TORCH_API Type {
|
|||||||
: repr_(nullptr) {}
|
: repr_(nullptr) {}
|
||||||
|
|
||||||
/* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<T> p)
|
/* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<T> p)
|
||||||
: repr_(makeSingletonSharedPtr(p.get())) {}
|
: repr_(p) {}
|
||||||
|
|
||||||
template <typename U, std::enable_if_t<std::is_convertible_v<U*, T*>, bool> = true>
|
template <typename U, std::enable_if_t<std::is_convertible_v<U*, T*>, bool> = true>
|
||||||
/* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<U> p)
|
/* implicit */ SingletonOrSharedTypePtr(SingletonTypePtr<U> p)
|
||||||
: repr_(makeSingletonSharedPtr(static_cast<T*>(p.get()))) {}
|
: repr_(SingletonTypePtr<T>(p.get())) {}
|
||||||
|
|
||||||
|
|
||||||
// We need to support construction from T* for pybind. The problem
|
// We need to support construction from T* for pybind. The problem
|
||||||
@ -202,8 +202,8 @@ struct TORCH_API Type {
|
|||||||
// Case 2: if T is exactly Type, we need to do a dynamic_cast to
|
// Case 2: if T is exactly Type, we need to do a dynamic_cast to
|
||||||
// check if it's a SharedType and do the right thing.
|
// check if it's a SharedType and do the right thing.
|
||||||
//
|
//
|
||||||
// Case 3: Otherwise, T is not a SharedType. Use a singleton
|
// Case 3: Otherwise, T is not a SharedType. (debug-check this
|
||||||
// pointer.
|
// assumption!) Use a singleton pointer.
|
||||||
|
|
||||||
template <typename U = T, std::enable_if_t<std::is_base_of_v<SharedType, U>, bool> = true>
|
template <typename U = T, std::enable_if_t<std::is_base_of_v<SharedType, U>, bool> = true>
|
||||||
/* implicit */ SingletonOrSharedTypePtr(T* p) : SingletonOrSharedTypePtr(static_cast<typename detail::as_shared_type<U>::type>(p)->shared_from_this()) {}
|
/* implicit */ SingletonOrSharedTypePtr(T* p) : SingletonOrSharedTypePtr(static_cast<typename detail::as_shared_type<U>::type>(p)->shared_from_this()) {}
|
||||||
@ -211,15 +211,15 @@ struct TORCH_API Type {
|
|||||||
template <typename U = T, std::enable_if_t<std::is_same_v<Type, U>, bool> = true>
|
template <typename U = T, std::enable_if_t<std::is_same_v<Type, U>, bool> = true>
|
||||||
/* implicit */ SingletonOrSharedTypePtr(T* p) {
|
/* implicit */ SingletonOrSharedTypePtr(T* p) {
|
||||||
if (auto* shared_p = dynamic_cast<typename detail::as_shared_type<U>::type>(p)) {
|
if (auto* shared_p = dynamic_cast<typename detail::as_shared_type<U>::type>(p)) {
|
||||||
repr_ = shared_p->shared_from_this();
|
repr_ = Repr(shared_p->shared_from_this());
|
||||||
} else {
|
} else {
|
||||||
repr_ = makeSingletonSharedPtr(p);
|
repr_ = Repr(p);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename U = T, std::enable_if_t<!std::is_same_v<Type, U> && !std::is_base_of_v<SharedType, U>, bool> = true>
|
template <typename U = T, std::enable_if_t<!std::is_same_v<Type, U> && !std::is_base_of_v<SharedType, U>, bool> = true>
|
||||||
/* implicit */ SingletonOrSharedTypePtr(T* p)
|
/* implicit */ SingletonOrSharedTypePtr(T* p)
|
||||||
: repr_(makeSingletonSharedPtr(p)) {
|
: repr_(p) {
|
||||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dynamic_cast<typename detail::as_shared_type<U>::type>(p) == nullptr);
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dynamic_cast<typename detail::as_shared_type<U>::type>(p) == nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -230,19 +230,19 @@ struct TORCH_API Type {
|
|||||||
~SingletonOrSharedTypePtr() = default;
|
~SingletonOrSharedTypePtr() = default;
|
||||||
|
|
||||||
T* get() const {
|
T* get() const {
|
||||||
return repr_.get();
|
return repr_.isSharedAndNonNull() ? repr_.shared_.repr_.get() : static_cast<T*>(repr_.rawRepr().first);
|
||||||
}
|
}
|
||||||
|
|
||||||
operator bool() const {
|
operator bool() const {
|
||||||
return repr_ != nullptr;
|
return repr_.isNonNull();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool operator==(std::nullptr_t) const {
|
bool operator==(std::nullptr_t) const {
|
||||||
return repr_ == nullptr;
|
return !repr_.isNonNull();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool operator!=(std::nullptr_t) const {
|
bool operator!=(std::nullptr_t) const {
|
||||||
return repr_ != nullptr;
|
return repr_.isNonNull();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename U = T, std::enable_if_t<!std::is_same_v<std::remove_const_t<U>, void>, bool> = true>
|
template <typename U = T, std::enable_if_t<!std::is_same_v<std::remove_const_t<U>, void>, bool> = true>
|
||||||
@ -255,14 +255,138 @@ struct TORCH_API Type {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Use shared_ptr's aliasing constructor to create a non-owning pointer
|
// NOTE: SharedPtrWrapper exists to work around a baffling bug in
|
||||||
// to a singleton. The lifetime is tied to the null shared_ptr, so there's
|
// nvcc; see comment in destroy() below.
|
||||||
// no reference counting overhead for the singleton itself.
|
struct SharedPtrWrapper {
|
||||||
static std::shared_ptr<T> makeSingletonSharedPtr(T* ptr) {
|
SharedPtrWrapper(std::shared_ptr<T> &&x)
|
||||||
return std::shared_ptr<T>(std::shared_ptr<T>(), ptr);
|
: repr_(std::move(x)) {}
|
||||||
}
|
std::shared_ptr<T> repr_;
|
||||||
|
};
|
||||||
|
union Repr {
|
||||||
|
Repr() : Repr(nullptr) {}
|
||||||
|
|
||||||
std::shared_ptr<T> repr_;
|
explicit Repr(std::shared_ptr<T> x)
|
||||||
|
: shared_(std::move(x)) {}
|
||||||
|
|
||||||
|
explicit Repr(std::nullptr_t)
|
||||||
|
: singletonRepr_(nullptr) {}
|
||||||
|
|
||||||
|
explicit Repr(SingletonTypePtr<T> p)
|
||||||
|
: singletonRepr_(p.get()) {}
|
||||||
|
|
||||||
|
~Repr() {
|
||||||
|
destroy();
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: the only non-UB way to access our null state is through
|
||||||
|
// rawRepr(), because our copy operation doesn't preserve which
|
||||||
|
// union member is active for null pointers.
|
||||||
|
Repr(const Repr& rhs) {
|
||||||
|
if (rhs.isSharedAndNonNull()) {
|
||||||
|
new (&shared_) SharedPtrWrapper(rhs.shared_);
|
||||||
|
} else {
|
||||||
|
singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
|
||||||
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr);
|
||||||
|
singletonRepr_.unused_ = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Repr(Repr&& rhs) noexcept {
|
||||||
|
if (rhs.isSharedAndNonNull()) {
|
||||||
|
new (&shared_) SharedPtrWrapper(std::move(rhs.shared_));
|
||||||
|
} else {
|
||||||
|
singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
|
||||||
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.singletonRepr_.unused_ == nullptr);
|
||||||
|
singletonRepr_.unused_ = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Repr& operator=(const Repr& rhs) {
|
||||||
|
if (&rhs == this) {
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
if (rhs.isSharedAndNonNull()) {
|
||||||
|
if (isSharedAndNonNull()) {
|
||||||
|
shared_ = rhs.shared_;
|
||||||
|
} else {
|
||||||
|
new (&shared_) SharedPtrWrapper(rhs.shared_);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (isSharedAndNonNull()) {
|
||||||
|
destroy();
|
||||||
|
}
|
||||||
|
singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
|
||||||
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr);
|
||||||
|
singletonRepr_.unused_ = nullptr;
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
Repr& operator=(Repr&& rhs) noexcept {
|
||||||
|
if (&rhs == this) {
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
if (rhs.isSharedAndNonNull()) {
|
||||||
|
if (isSharedAndNonNull()) {
|
||||||
|
shared_ = std::move(rhs.shared_);
|
||||||
|
} else {
|
||||||
|
new (&shared_) SharedPtrWrapper(std::move(rhs.shared_));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (isSharedAndNonNull()) {
|
||||||
|
destroy();
|
||||||
|
}
|
||||||
|
singletonRepr_.singleton_ = static_cast<T*>(rhs.rawRepr().first);
|
||||||
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.rawRepr().nullIfSingleton_ == nullptr);
|
||||||
|
singletonRepr_.unused_ = nullptr;
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
SharedPtrWrapper shared_;
|
||||||
|
|
||||||
|
struct SingletonRepr {
|
||||||
|
explicit SingletonRepr(T* s) : singleton_(s) {}
|
||||||
|
T* singleton_;
|
||||||
|
void* unused_ = nullptr;
|
||||||
|
} singletonRepr_;
|
||||||
|
struct RawRepr {
|
||||||
|
void* first;
|
||||||
|
void* nullIfSingleton_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// It is UB to read the singleton part of Repr if it was
|
||||||
|
// constructed as a shared_ptr and vice versa, but memcpying out
|
||||||
|
// the representation is always OK, so here's an accessor to obey
|
||||||
|
// the letter of the law.
|
||||||
|
RawRepr rawRepr() const {
|
||||||
|
RawRepr repr{};
|
||||||
|
memcpy(&repr, reinterpret_cast<const char *>(this), sizeof(RawRepr));
|
||||||
|
return repr;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isNonNull() const {
|
||||||
|
auto repr = rawRepr();
|
||||||
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(repr.nullIfSingleton_ == nullptr || repr.first != nullptr);
|
||||||
|
return repr.first != nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isSharedAndNonNull() const {
|
||||||
|
return rawRepr().nullIfSingleton_ != nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void destroy() {
|
||||||
|
if (isSharedAndNonNull()) {
|
||||||
|
// Without SharedPtrWrapper, this line would read
|
||||||
|
// `shared_.~shared_ptr()` and nvcc would complain with
|
||||||
|
// "error: expected primary-expression before '>' token"
|
||||||
|
// referring to the "t" in "shared_ptr". SharedPtrWrapper
|
||||||
|
// exists to work around this compiler bug.
|
||||||
|
shared_.~SharedPtrWrapper();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} repr_;
|
||||||
};
|
};
|
||||||
|
|
||||||
using TypePtr = SingletonOrSharedTypePtr<Type>;
|
using TypePtr = SingletonOrSharedTypePtr<Type>;
|
||||||
|
|||||||
@ -21,7 +21,7 @@ namespace c10 {
|
|||||||
|
|
||||||
namespace detail {
|
namespace detail {
|
||||||
// The first argument of the schema might be of type DispatchKeySet, in which case we remove it.
|
// The first argument of the schema might be of type DispatchKeySet, in which case we remove it.
|
||||||
// We do this because every argument in a function schema is expected to be convertible
|
// We do this because every argument in a function schema is expected to be convertable
|
||||||
// to an ivalue, but DispatchKeySet is not a type we want the jit to be aware of.
|
// to an ivalue, but DispatchKeySet is not a type we want the jit to be aware of.
|
||||||
// See Note [Plumbing Keys Through The Dispatcher]
|
// See Note [Plumbing Keys Through The Dispatcher]
|
||||||
template<class KernelFunctor>
|
template<class KernelFunctor>
|
||||||
|
|||||||
@ -251,7 +251,7 @@ TEST(OperatorRegistrationTest, whenRegisteringCPUTensorType_thenCanOnlyCallUnbox
|
|||||||
callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, c10::DispatchKeySet(c10::DispatchKey::CPU), dummyTensor(c10::DispatchKey::CUDA));
|
callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, c10::DispatchKeySet(c10::DispatchKey::CPU), dummyTensor(c10::DispatchKey::CUDA));
|
||||||
EXPECT_TRUE(called_kernel_cpu);
|
EXPECT_TRUE(called_kernel_cpu);
|
||||||
|
|
||||||
// Ensure that dispatch key from tensor is not used here.
|
// Ensure that disptach key from tensor is not used here.
|
||||||
called_kernel_cpu = false;
|
called_kernel_cpu = false;
|
||||||
expectThrows<c10::Error>([&] {
|
expectThrows<c10::Error>([&] {
|
||||||
callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, c10::DispatchKeySet(c10::DispatchKey::CUDA), dummyTensor(c10::DispatchKey::CPU));
|
callOpUnboxedWithPrecomputedDispatchKeySet<void, Tensor>(*op, c10::DispatchKeySet(c10::DispatchKey::CUDA), dummyTensor(c10::DispatchKey::CPU));
|
||||||
|
|||||||
@ -172,7 +172,7 @@ VaryingShape<Stride> TensorType::computeStrideProps(
|
|||||||
// The logic below follows what TensorIterator uses in its logic:
|
// The logic below follows what TensorIterator uses in its logic:
|
||||||
// 1. Fast_set_up is the short-cut to identify a. channels_last and
|
// 1. Fast_set_up is the short-cut to identify a. channels_last and
|
||||||
// b. contiguous format, which is what we have in the below logic.
|
// b. contiguous format, which is what we have in the below logic.
|
||||||
// 2. In more general cases, it does best effort to preserve permutatoin.
|
// 2. In more generla cases, it does best effort to preserve permutatoin.
|
||||||
if (is_channels_last_strides_2d(sizes, strides) || is_channels_last_strides_3d(sizes, strides)) {
|
if (is_channels_last_strides_2d(sizes, strides) || is_channels_last_strides_3d(sizes, strides)) {
|
||||||
// case 1.a. short cut channels last
|
// case 1.a. short cut channels last
|
||||||
std::iota(stride_indices.rbegin() + 1, stride_indices.rend() - 1, 2);
|
std::iota(stride_indices.rbegin() + 1, stride_indices.rend() - 1, 2);
|
||||||
|
|||||||
@ -8,7 +8,6 @@
|
|||||||
#include <ATen/core/jit_type.h>
|
#include <ATen/core/jit_type.h>
|
||||||
#include <c10/macros/Macros.h>
|
#include <c10/macros/Macros.h>
|
||||||
#include <c10/util/env.h>
|
#include <c10/util/env.h>
|
||||||
#include <c10/util/Exception.h>
|
|
||||||
#include <c10/util/flat_hash_map.h>
|
#include <c10/util/flat_hash_map.h>
|
||||||
#include <c10/util/irange.h>
|
#include <c10/util/irange.h>
|
||||||
#include <array>
|
#include <array>
|
||||||
@ -827,7 +826,9 @@ TupleType::TupleType(
|
|||||||
: NamedType(TypeKind::TupleType, std::move(name)),
|
: NamedType(TypeKind::TupleType, std::move(name)),
|
||||||
elements_(std::move(elements)),
|
elements_(std::move(elements)),
|
||||||
has_free_variables_(std::any_of(elements_.begin(), elements_.end(), [](const TypePtr& v) {
|
has_free_variables_(std::any_of(elements_.begin(), elements_.end(), [](const TypePtr& v) {
|
||||||
TORCH_CHECK(v, "Can not create tuple with None type");
|
if (!v) {
|
||||||
|
throw std::runtime_error("Can not create tuple with None type");
|
||||||
|
}
|
||||||
return v->hasFreeVariables();
|
return v->hasFreeVariables();
|
||||||
})), schema_(std::move(schema)) {
|
})), schema_(std::move(schema)) {
|
||||||
|
|
||||||
|
|||||||
@ -104,6 +104,71 @@ class Vectorized<float> {
|
|||||||
}
|
}
|
||||||
return b;
|
return b;
|
||||||
}
|
}
|
||||||
|
// Implementation is picked from
|
||||||
|
// https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L105
|
||||||
|
inline svfloat32_t svexp_f32_z(svbool_t pg, svfloat32_t x) const {
|
||||||
|
const auto c1 =
|
||||||
|
svreinterpret_f32_u32(svdup_n_u32(0x3f7ffff6)); // x^1: 0x1.ffffecp-1f
|
||||||
|
const auto c2 =
|
||||||
|
svreinterpret_f32_u32(svdup_n_u32(0x3efffedb)); // x^2: 0x1.fffdb6p-2f
|
||||||
|
const auto c3 =
|
||||||
|
svreinterpret_f32_u32(svdup_n_u32(0x3e2aaf33)); // x^3: 0x1.555e66p-3f
|
||||||
|
const auto c4 =
|
||||||
|
svreinterpret_f32_u32(svdup_n_u32(0x3d2b9f17)); // x^4: 0x1.573e2ep-5f
|
||||||
|
const auto c5 =
|
||||||
|
svreinterpret_f32_u32(svdup_n_u32(0x3c072010)); // x^5: 0x1.0e4020p-7f
|
||||||
|
const auto shift = svreinterpret_f32_u32(
|
||||||
|
svdup_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f
|
||||||
|
const auto inv_ln2 = svreinterpret_f32_u32(
|
||||||
|
svdup_n_u32(0x3fb8aa3b)); // 1 / ln(2) = 0x1.715476p+0f
|
||||||
|
const auto neg_ln2_hi = svreinterpret_f32_u32(svdup_n_u32(
|
||||||
|
0xbf317200)); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f
|
||||||
|
const auto neg_ln2_lo = svreinterpret_f32_u32(svdup_n_u32(
|
||||||
|
0xb5bfbe8e)); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f
|
||||||
|
const auto inf = svdup_n_f32(std::numeric_limits<float>::infinity());
|
||||||
|
const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5)
|
||||||
|
const auto zero = svdup_n_f32(0.f);
|
||||||
|
const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125)
|
||||||
|
// Range reduction:
|
||||||
|
// e^x = 2^n * e^r
|
||||||
|
// where:
|
||||||
|
// n = floor(x / ln(2))
|
||||||
|
// r = x - n * ln(2)
|
||||||
|
//
|
||||||
|
// By adding x / ln(2) with 2^23 + 127 (shift):
|
||||||
|
// * As FP32 fraction part only has 23-bits, the addition of 2^23 + 127
|
||||||
|
// forces decimal part
|
||||||
|
// of x / ln(2) out of the result. The integer part of x / ln(2) (i.e.
|
||||||
|
// n) + 127 will occupy the whole fraction part of z in FP32 format.
|
||||||
|
// Subtracting 2^23 + 127 (shift) from z will result in the integer part
|
||||||
|
// of x / ln(2) (i.e. n) because the decimal part has been pushed out
|
||||||
|
// and lost.
|
||||||
|
// * The addition of 127 makes the FP32 fraction part of z ready to be
|
||||||
|
// used as the exponent
|
||||||
|
// in FP32 format. Left shifting z by 23 bits will result in 2^n.
|
||||||
|
const auto z = svmla_f32_z(pg, shift, x, inv_ln2);
|
||||||
|
const auto n = svsub_f32_z(pg, z, shift);
|
||||||
|
const auto scale = svreinterpret_f32_u32(
|
||||||
|
svlsl_n_u32_z(pg, svreinterpret_u32_f32(z), 23)); // 2^n
|
||||||
|
// The calculation of n * ln(2) is done using 2 steps to achieve accuracy
|
||||||
|
// beyond FP32. This outperforms longer Taylor series (3-4 tabs) both in
|
||||||
|
// term of accuracy and performance.
|
||||||
|
const auto r_hi = svmla_f32_z(pg, x, n, neg_ln2_hi);
|
||||||
|
const auto r = svmla_f32_z(pg, r_hi, n, neg_ln2_lo);
|
||||||
|
// Compute the truncated Taylor series of e^r.
|
||||||
|
// poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5)
|
||||||
|
const auto r2 = svmul_f32_z(pg, r, r);
|
||||||
|
const auto p1 = svmul_f32_z(pg, c1, r);
|
||||||
|
const auto p23 = svmla_f32_z(pg, c2, c3, r);
|
||||||
|
const auto p45 = svmla_f32_z(pg, c4, c5, r);
|
||||||
|
const auto p2345 = svmla_f32_z(pg, p23, p45, r2);
|
||||||
|
const auto p12345 = svmla_f32_z(pg, p1, p2345, r2);
|
||||||
|
auto poly = svmla_f32_z(pg, scale, p12345, scale);
|
||||||
|
// Handle underflow and overflow.
|
||||||
|
poly = svsel_f32(svcmplt_f32(pg, x, min_input), zero, poly);
|
||||||
|
poly = svsel_f32(svcmpgt_f32(pg, x, max_input), inf, poly);
|
||||||
|
return poly;
|
||||||
|
}
|
||||||
static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
|
static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
|
||||||
if (count == size())
|
if (count == size())
|
||||||
return svld1_f32(ptrue, reinterpret_cast<const float*>(ptr));
|
return svld1_f32(ptrue, reinterpret_cast<const float*>(ptr));
|
||||||
@ -248,41 +313,11 @@ class Vectorized<float> {
|
|||||||
return USE_SLEEF(
|
return USE_SLEEF(
|
||||||
Vectorized<float>(Sleef_expm1fx_u10sve(values)), map(std::expm1));
|
Vectorized<float>(Sleef_expm1fx_u10sve(values)), map(std::expm1));
|
||||||
}
|
}
|
||||||
// Implementation copied from Arm Optimized Routines:
|
|
||||||
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/sve/expf.c
|
|
||||||
Vectorized<float> exp_u20() const {
|
Vectorized<float> exp_u20() const {
|
||||||
// special case to handle special inputs that are too large or too small
|
return exp();
|
||||||
// i.e. where there's at least one element x, s.t. |x| >= 87.3...
|
|
||||||
svbool_t is_special_case = svacgt(svptrue_b32(), values, 0x1.5d5e2ap+6f);
|
|
||||||
if (svptest_any(svptrue_b32(), is_special_case)) {
|
|
||||||
return exp();
|
|
||||||
}
|
|
||||||
const svfloat32_t ln2_hi = svdup_n_f32(0x1.62e4p-1f);
|
|
||||||
const svfloat32_t ln2_lo = svdup_n_f32(0x1.7f7d1cp-20f);
|
|
||||||
const svfloat32_t c1 = svdup_n_f32(0.5f);
|
|
||||||
const svfloat32_t inv_ln2 = svdup_n_f32(0x1.715476p+0f);
|
|
||||||
|
|
||||||
const float shift = 0x1.803f8p17f;
|
|
||||||
|
|
||||||
/* n = round(x/(ln2/N)). */
|
|
||||||
svfloat32_t z = svmad_x(svptrue_b32(), inv_ln2, values, shift);
|
|
||||||
svfloat32_t n = svsub_x(svptrue_b32(), z, shift);
|
|
||||||
|
|
||||||
/* r = x - n*ln2/N. */
|
|
||||||
svfloat32_t r = values;
|
|
||||||
r = svmls_x(svptrue_b32(), r, n, ln2_hi);
|
|
||||||
r = svmls_x(svptrue_b32(), r, n, ln2_lo);
|
|
||||||
|
|
||||||
/* scale = 2^(n/N). */
|
|
||||||
svfloat32_t scale = svexpa(svreinterpret_u32(z));
|
|
||||||
|
|
||||||
/* poly(r) = exp(r) - 1 ~= r + 0.5 r^2. */
|
|
||||||
svfloat32_t r2 = svmul_x(svptrue_b32(), r, r);
|
|
||||||
svfloat32_t poly = svmla_x(svptrue_b32(), r, r2, c1);
|
|
||||||
return svmla_x(svptrue_b32(), scale, scale, poly);
|
|
||||||
}
|
}
|
||||||
Vectorized<float> fexp_u20() const {
|
Vectorized<float> fexp_u20() const {
|
||||||
return exp_u20();
|
return exp();
|
||||||
}
|
}
|
||||||
Vectorized<float> fmod(const Vectorized<float>& q) const {USE_SLEEF(
|
Vectorized<float> fmod(const Vectorized<float>& q) const {USE_SLEEF(
|
||||||
{ return Vectorized<float>(Sleef_fmodfx_sve(values, q)); },
|
{ return Vectorized<float>(Sleef_fmodfx_sve(values, q)); },
|
||||||
@ -418,11 +453,9 @@ class Vectorized<float> {
|
|||||||
ptrue, svmax_f32_z(ptrue, values, CONST_MIN_TANH), CONST_MAX_TANH);
|
ptrue, svmax_f32_z(ptrue, values, CONST_MIN_TANH), CONST_MAX_TANH);
|
||||||
|
|
||||||
// Step 2: Calculate exp(2 * x), where x is the clamped value.
|
// Step 2: Calculate exp(2 * x), where x is the clamped value.
|
||||||
// svmul_f32_z computes 2 * x, and exp_u20() computes the exponential of
|
// svmul_f32_z computes 2 * x, and svexp_f32_z computes the exponential of
|
||||||
// the result (via Vectorized<float>, then auto-converts back to
|
// the result.
|
||||||
// svfloat32_t).
|
svfloat32_t exp2x = svexp_f32_z(ptrue, svmul_f32_z(ptrue, CONST_2, x));
|
||||||
svfloat32_t exp2x =
|
|
||||||
Vectorized<float>(svmul_f32_z(ptrue, CONST_2, x)).exp_u20();
|
|
||||||
|
|
||||||
// Step 3: Calculate the numerator of the tanh function, which is exp(2x)
|
// Step 3: Calculate the numerator of the tanh function, which is exp(2x)
|
||||||
// - 1.
|
// - 1.
|
||||||
|
|||||||
@ -6,11 +6,8 @@
|
|||||||
#ifdef __aarch64__
|
#ifdef __aarch64__
|
||||||
#if !defined(CPU_CAPABILITY_SVE)
|
#if !defined(CPU_CAPABILITY_SVE)
|
||||||
#include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h>
|
#include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h>
|
||||||
#include <ATen/cpu/vec/vec128/vec128_double_neon.h>
|
|
||||||
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
|
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
|
||||||
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
|
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
|
||||||
#include <ATen/cpu/vec/vec128/vec128_int_aarch64.h>
|
|
||||||
#include <ATen/cpu/vec/vec128/vec128_uint_aarch64.h>
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include <ATen/cpu/vec/vec128/vec128_convert.h>
|
#include <ATen/cpu/vec/vec128/vec128_convert.h>
|
||||||
|
|||||||
@ -354,47 +354,9 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
|
|||||||
|
|
||||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
|
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
|
||||||
Vectorized frac() const;
|
Vectorized frac() const;
|
||||||
|
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
|
||||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc)
|
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc)
|
||||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt)
|
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt)
|
||||||
|
|
||||||
#ifdef __ARM_FEATURE_BF16
|
|
||||||
Vectorized<c10::BFloat16> neg() const {
|
|
||||||
return -values;
|
|
||||||
}
|
|
||||||
Vectorized<c10::BFloat16> reciprocal() const {
|
|
||||||
return 1.0f / values;
|
|
||||||
}
|
|
||||||
Vectorized<c10::BFloat16> operator==(
|
|
||||||
const Vectorized<c10::BFloat16>& other) const {
|
|
||||||
return values == other.values;
|
|
||||||
}
|
|
||||||
|
|
||||||
Vectorized<c10::BFloat16> operator!=(
|
|
||||||
const Vectorized<c10::BFloat16>& other) const {
|
|
||||||
return values != other.values;
|
|
||||||
}
|
|
||||||
|
|
||||||
Vectorized<c10::BFloat16> operator<(
|
|
||||||
const Vectorized<c10::BFloat16>& other) const {
|
|
||||||
return values < other.values;
|
|
||||||
}
|
|
||||||
|
|
||||||
Vectorized<c10::BFloat16> operator<=(
|
|
||||||
const Vectorized<c10::BFloat16>& other) const {
|
|
||||||
return values <= other.values;
|
|
||||||
}
|
|
||||||
|
|
||||||
Vectorized<c10::BFloat16> operator>(
|
|
||||||
const Vectorized<c10::BFloat16>& other) const {
|
|
||||||
return values > other.values;
|
|
||||||
}
|
|
||||||
|
|
||||||
Vectorized<c10::BFloat16> operator>=(
|
|
||||||
const Vectorized<c10::BFloat16>& other) const {
|
|
||||||
return values >= other.values;
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
|
|
||||||
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal)
|
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal)
|
||||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==)
|
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==)
|
||||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=)
|
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=)
|
||||||
@ -402,7 +364,6 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
|
|||||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=)
|
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=)
|
||||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>)
|
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>)
|
||||||
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=)
|
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=)
|
||||||
#endif
|
|
||||||
|
|
||||||
#undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
|
#undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
|
||||||
#undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
|
#undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
|
||||||
@ -451,52 +412,28 @@ template <>
|
|||||||
Vectorized<c10::BFloat16> inline operator+(
|
Vectorized<c10::BFloat16> inline operator+(
|
||||||
const Vectorized<c10::BFloat16>& a,
|
const Vectorized<c10::BFloat16>& a,
|
||||||
const Vectorized<c10::BFloat16>& b) {
|
const Vectorized<c10::BFloat16>& b) {
|
||||||
#ifdef __ARM_FEATURE_BF16
|
|
||||||
bfloat16x8_t x = a;
|
|
||||||
bfloat16x8_t y = b;
|
|
||||||
return x + y;
|
|
||||||
#else
|
|
||||||
return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
|
return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
Vectorized<c10::BFloat16> inline operator-(
|
Vectorized<c10::BFloat16> inline operator-(
|
||||||
const Vectorized<c10::BFloat16>& a,
|
const Vectorized<c10::BFloat16>& a,
|
||||||
const Vectorized<c10::BFloat16>& b) {
|
const Vectorized<c10::BFloat16>& b) {
|
||||||
#ifdef __ARM_FEATURE_BF16
|
|
||||||
bfloat16x8_t x = a;
|
|
||||||
bfloat16x8_t y = b;
|
|
||||||
return x - y;
|
|
||||||
#else
|
|
||||||
return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
|
return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
Vectorized<c10::BFloat16> inline operator*(
|
Vectorized<c10::BFloat16> inline operator*(
|
||||||
const Vectorized<c10::BFloat16>& a,
|
const Vectorized<c10::BFloat16>& a,
|
||||||
const Vectorized<c10::BFloat16>& b) {
|
const Vectorized<c10::BFloat16>& b) {
|
||||||
#ifdef __ARM_FEATURE_BF16
|
|
||||||
bfloat16x8_t x = a;
|
|
||||||
bfloat16x8_t y = b;
|
|
||||||
return x * y;
|
|
||||||
#else
|
|
||||||
return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
|
return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
Vectorized<c10::BFloat16> inline operator/(
|
Vectorized<c10::BFloat16> inline operator/(
|
||||||
const Vectorized<c10::BFloat16>& a,
|
const Vectorized<c10::BFloat16>& a,
|
||||||
const Vectorized<c10::BFloat16>& b) {
|
const Vectorized<c10::BFloat16>& b) {
|
||||||
#ifdef __ARM_FEATURE_BF16
|
|
||||||
bfloat16x8_t x = a;
|
|
||||||
bfloat16x8_t y = b;
|
|
||||||
return x / y;
|
|
||||||
#else
|
|
||||||
return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
|
return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// frac. Implement this here so we can use subtraction
|
// frac. Implement this here so we can use subtraction
|
||||||
@ -607,19 +544,12 @@ Vectorized<c10::BFloat16> inline fmadd(
|
|||||||
const Vectorized<c10::BFloat16>& a,
|
const Vectorized<c10::BFloat16>& a,
|
||||||
const Vectorized<c10::BFloat16>& b,
|
const Vectorized<c10::BFloat16>& b,
|
||||||
const Vectorized<c10::BFloat16>& c) {
|
const Vectorized<c10::BFloat16>& c) {
|
||||||
#ifdef __ARM_FEATURE_BF16
|
|
||||||
bfloat16x8_t x = a;
|
|
||||||
bfloat16x8_t y = b;
|
|
||||||
bfloat16x8_t z = c;
|
|
||||||
return x * y + z;
|
|
||||||
#else
|
|
||||||
// NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16! Also,
|
// NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16! Also,
|
||||||
// vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered
|
// vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered
|
||||||
// elements, not the bottom and top half, so they don't seem
|
// elements, not the bottom and top half, so they don't seem
|
||||||
// particularly useful here. Ideally we would include dot product in
|
// particularly useful here. Ideally we would include dot product in
|
||||||
// the Vectorized interface...
|
// the Vectorized interface...
|
||||||
return a * b + c;
|
return a * b + c;
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
@ -627,15 +557,8 @@ Vectorized<c10::BFloat16> inline fnmadd(
|
|||||||
const Vectorized<c10::BFloat16>& a,
|
const Vectorized<c10::BFloat16>& a,
|
||||||
const Vectorized<c10::BFloat16>& b,
|
const Vectorized<c10::BFloat16>& b,
|
||||||
const Vectorized<c10::BFloat16>& c) {
|
const Vectorized<c10::BFloat16>& c) {
|
||||||
#ifdef __ARM_FEATURE_BF16
|
|
||||||
bfloat16x8_t x = a;
|
|
||||||
bfloat16x8_t y = b;
|
|
||||||
bfloat16x8_t z = c;
|
|
||||||
return (-x) * y + z;
|
|
||||||
#else
|
|
||||||
// See NOTE [BF16 FMA] above.
|
// See NOTE [BF16 FMA] above.
|
||||||
return -a * b + c;
|
return -a * b + c;
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
@ -643,15 +566,8 @@ Vectorized<c10::BFloat16> inline fmsub(
|
|||||||
const Vectorized<c10::BFloat16>& a,
|
const Vectorized<c10::BFloat16>& a,
|
||||||
const Vectorized<c10::BFloat16>& b,
|
const Vectorized<c10::BFloat16>& b,
|
||||||
const Vectorized<c10::BFloat16>& c) {
|
const Vectorized<c10::BFloat16>& c) {
|
||||||
#ifdef __ARM_FEATURE_BF16
|
|
||||||
bfloat16x8_t x = a;
|
|
||||||
bfloat16x8_t y = b;
|
|
||||||
bfloat16x8_t z = c;
|
|
||||||
return x * y - z;
|
|
||||||
#else
|
|
||||||
// See NOTE [BF16 FMA] above.
|
// See NOTE [BF16 FMA] above.
|
||||||
return a * b - c;
|
return a * b - c;
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
@ -659,15 +575,8 @@ Vectorized<c10::BFloat16> inline fnmsub(
|
|||||||
const Vectorized<c10::BFloat16>& a,
|
const Vectorized<c10::BFloat16>& a,
|
||||||
const Vectorized<c10::BFloat16>& b,
|
const Vectorized<c10::BFloat16>& b,
|
||||||
const Vectorized<c10::BFloat16>& c) {
|
const Vectorized<c10::BFloat16>& c) {
|
||||||
#ifdef __ARM_FEATURE_BF16
|
|
||||||
bfloat16x8_t x = a;
|
|
||||||
bfloat16x8_t y = b;
|
|
||||||
bfloat16x8_t z = c;
|
|
||||||
return (-x) * y - z;
|
|
||||||
#else
|
|
||||||
// See NOTE [BF16 FMA] above.
|
// See NOTE [BF16 FMA] above.
|
||||||
return -a * b - c;
|
return -a * b - c;
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // !defined(C10_MOBILE) && defined(__aarch64__)
|
#endif // !defined(C10_MOBILE) && defined(__aarch64__)
|
||||||
|
|||||||
@ -5,114 +5,6 @@
|
|||||||
namespace at::vec {
|
namespace at::vec {
|
||||||
inline namespace CPU_CAPABILITY {
|
inline namespace CPU_CAPABILITY {
|
||||||
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
|
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
|
||||||
|
|
||||||
// Enable auto-vectorization for GCC-13+ and clang-17+
|
|
||||||
// GCC-12 has a bug: gcc.gnu.org/bugzilla/show_bug.cgi?id=117001
|
|
||||||
#if __GNUC__ > 12 || (defined(__clang__) && (__clang_major__ >= 17))
|
|
||||||
|
|
||||||
template <typename from_type, typename to_type>
|
|
||||||
inline void convertImpl(
|
|
||||||
const from_type* __restrict src,
|
|
||||||
to_type* __restrict dst,
|
|
||||||
int64_t n) {
|
|
||||||
uint64_t len = static_cast<uint64_t>(n);
|
|
||||||
for (uint64_t i = 0; i < len; i++) {
|
|
||||||
dst[i] = static_cast<to_type>(src[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#define CONVERT_TEMPLATE(from_type, to_type) \
|
|
||||||
template <> \
|
|
||||||
inline void convert(const from_type* src, to_type* dst, int64_t n) { \
|
|
||||||
return convertImpl<from_type, to_type>(src, dst, n); \
|
|
||||||
}
|
|
||||||
|
|
||||||
CONVERT_TEMPLATE(uint8_t, uint8_t)
|
|
||||||
CONVERT_TEMPLATE(uint8_t, int8_t)
|
|
||||||
CONVERT_TEMPLATE(uint8_t, int16_t)
|
|
||||||
CONVERT_TEMPLATE(uint8_t, int32_t)
|
|
||||||
CONVERT_TEMPLATE(uint8_t, int64_t)
|
|
||||||
CONVERT_TEMPLATE(uint8_t, float)
|
|
||||||
CONVERT_TEMPLATE(uint8_t, double)
|
|
||||||
CONVERT_TEMPLATE(int8_t, uint8_t)
|
|
||||||
CONVERT_TEMPLATE(int8_t, int8_t)
|
|
||||||
CONVERT_TEMPLATE(int8_t, int16_t)
|
|
||||||
CONVERT_TEMPLATE(int8_t, int32_t)
|
|
||||||
CONVERT_TEMPLATE(int8_t, int64_t)
|
|
||||||
CONVERT_TEMPLATE(int8_t, float)
|
|
||||||
CONVERT_TEMPLATE(int8_t, double)
|
|
||||||
CONVERT_TEMPLATE(int16_t, uint8_t)
|
|
||||||
CONVERT_TEMPLATE(int16_t, int8_t)
|
|
||||||
CONVERT_TEMPLATE(int16_t, int16_t)
|
|
||||||
CONVERT_TEMPLATE(int16_t, int32_t)
|
|
||||||
CONVERT_TEMPLATE(int16_t, int64_t)
|
|
||||||
CONVERT_TEMPLATE(int16_t, float)
|
|
||||||
CONVERT_TEMPLATE(int16_t, double)
|
|
||||||
CONVERT_TEMPLATE(int32_t, uint8_t)
|
|
||||||
CONVERT_TEMPLATE(int32_t, int8_t)
|
|
||||||
CONVERT_TEMPLATE(int32_t, int16_t)
|
|
||||||
CONVERT_TEMPLATE(int32_t, int32_t)
|
|
||||||
CONVERT_TEMPLATE(int32_t, int64_t)
|
|
||||||
CONVERT_TEMPLATE(int32_t, float)
|
|
||||||
CONVERT_TEMPLATE(int32_t, double)
|
|
||||||
CONVERT_TEMPLATE(int64_t, uint8_t)
|
|
||||||
CONVERT_TEMPLATE(int64_t, int8_t)
|
|
||||||
CONVERT_TEMPLATE(int64_t, int16_t)
|
|
||||||
CONVERT_TEMPLATE(int64_t, int32_t)
|
|
||||||
CONVERT_TEMPLATE(int64_t, int64_t)
|
|
||||||
CONVERT_TEMPLATE(int64_t, float)
|
|
||||||
CONVERT_TEMPLATE(int64_t, double)
|
|
||||||
CONVERT_TEMPLATE(float, uint8_t)
|
|
||||||
CONVERT_TEMPLATE(float, int8_t)
|
|
||||||
CONVERT_TEMPLATE(float, int16_t)
|
|
||||||
CONVERT_TEMPLATE(float, int32_t)
|
|
||||||
CONVERT_TEMPLATE(float, int64_t)
|
|
||||||
CONVERT_TEMPLATE(float, float)
|
|
||||||
CONVERT_TEMPLATE(float, double)
|
|
||||||
CONVERT_TEMPLATE(double, uint8_t)
|
|
||||||
CONVERT_TEMPLATE(double, int8_t)
|
|
||||||
CONVERT_TEMPLATE(double, int16_t)
|
|
||||||
CONVERT_TEMPLATE(double, int32_t)
|
|
||||||
CONVERT_TEMPLATE(double, int64_t)
|
|
||||||
CONVERT_TEMPLATE(double, float)
|
|
||||||
CONVERT_TEMPLATE(double, double)
|
|
||||||
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
|
||||||
CONVERT_TEMPLATE(float16_t, uint8_t)
|
|
||||||
CONVERT_TEMPLATE(float16_t, int8_t)
|
|
||||||
CONVERT_TEMPLATE(float16_t, int16_t)
|
|
||||||
CONVERT_TEMPLATE(float16_t, int32_t)
|
|
||||||
CONVERT_TEMPLATE(float16_t, int64_t)
|
|
||||||
CONVERT_TEMPLATE(float16_t, float16_t)
|
|
||||||
CONVERT_TEMPLATE(float16_t, float)
|
|
||||||
CONVERT_TEMPLATE(float16_t, double)
|
|
||||||
CONVERT_TEMPLATE(uint8_t, float16_t)
|
|
||||||
CONVERT_TEMPLATE(int8_t, float16_t)
|
|
||||||
CONVERT_TEMPLATE(int16_t, float16_t)
|
|
||||||
CONVERT_TEMPLATE(int32_t, float16_t)
|
|
||||||
CONVERT_TEMPLATE(int64_t, float16_t)
|
|
||||||
CONVERT_TEMPLATE(float, float16_t)
|
|
||||||
CONVERT_TEMPLATE(double, float16_t)
|
|
||||||
#endif
|
|
||||||
#ifdef __ARM_FEATURE_BF16
|
|
||||||
CONVERT_TEMPLATE(bfloat16_t, uint8_t)
|
|
||||||
CONVERT_TEMPLATE(bfloat16_t, int8_t)
|
|
||||||
CONVERT_TEMPLATE(bfloat16_t, int16_t)
|
|
||||||
CONVERT_TEMPLATE(bfloat16_t, int32_t)
|
|
||||||
CONVERT_TEMPLATE(bfloat16_t, int64_t)
|
|
||||||
CONVERT_TEMPLATE(bfloat16_t, bfloat16_t)
|
|
||||||
CONVERT_TEMPLATE(bfloat16_t, float)
|
|
||||||
CONVERT_TEMPLATE(bfloat16_t, double)
|
|
||||||
CONVERT_TEMPLATE(uint8_t, bfloat16_t)
|
|
||||||
CONVERT_TEMPLATE(int8_t, bfloat16_t)
|
|
||||||
CONVERT_TEMPLATE(int16_t, bfloat16_t)
|
|
||||||
CONVERT_TEMPLATE(int32_t, bfloat16_t)
|
|
||||||
CONVERT_TEMPLATE(int64_t, bfloat16_t)
|
|
||||||
CONVERT_TEMPLATE(float, bfloat16_t)
|
|
||||||
CONVERT_TEMPLATE(double, bfloat16_t)
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
template <typename src_t>
|
template <typename src_t>
|
||||||
struct VecConvert<
|
struct VecConvert<
|
||||||
float,
|
float,
|
||||||
|
|||||||
@ -1,586 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include <ATen/cpu/vec/intrinsics.h>
|
|
||||||
#include <ATen/cpu/vec/vec_base.h>
|
|
||||||
#include <c10/macros/Macros.h>
|
|
||||||
#include <c10/util/irange.h>
|
|
||||||
#include <cmath>
|
|
||||||
|
|
||||||
namespace at::vec {
|
|
||||||
// Note [CPU_CAPABILITY namespace]
|
|
||||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
||||||
// This header, and all of its subheaders, will be compiled with
|
|
||||||
// different architecture flags for each supported set of vector
|
|
||||||
// intrinsics. So we need to make sure they aren't inadvertently
|
|
||||||
// linked together. We do this by declaring objects in an `inline
|
|
||||||
// namespace` which changes the name mangling, but can still be
|
|
||||||
// accessed as `at::vec`.
|
|
||||||
inline namespace CPU_CAPABILITY {
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct is_vec_specialized_for<double> : std::bool_constant<true> {};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
class Vectorized<double> {
|
|
||||||
private:
|
|
||||||
float64x2_t values;
|
|
||||||
|
|
||||||
public:
|
|
||||||
using value_type = double;
|
|
||||||
using size_type = int;
|
|
||||||
static constexpr size_type size() {
|
|
||||||
return 2;
|
|
||||||
}
|
|
||||||
Vectorized() {
|
|
||||||
values = vdupq_n_f64(0.0);
|
|
||||||
}
|
|
||||||
Vectorized(float64x2_t v) : values(v) {}
|
|
||||||
Vectorized(double val) {
|
|
||||||
values = vdupq_n_f64(val);
|
|
||||||
}
|
|
||||||
template <
|
|
||||||
typename... Args,
|
|
||||||
typename = std::enable_if_t<(sizeof...(Args) == size())>>
|
|
||||||
Vectorized(Args... vals) {
|
|
||||||
__at_align__ double buffer[size()] = {vals...};
|
|
||||||
values = vld1q_f64(buffer);
|
|
||||||
}
|
|
||||||
operator float64x2_t() const {
|
|
||||||
return values;
|
|
||||||
}
|
|
||||||
template <int64_t mask>
|
|
||||||
static Vectorized<double> blend(
|
|
||||||
const Vectorized<double>& a,
|
|
||||||
const Vectorized<double>& b) {
|
|
||||||
// Build an array of flags: each bit of element is 1 if the corresponding
|
|
||||||
// bit in 'mask' is set, 0 otherwise.
|
|
||||||
uint64x2_t maskArray = {
|
|
||||||
(mask & 1ULL) ? 0xFFFFFFFFFFFFFFFF : 0,
|
|
||||||
(mask & 2ULL) ? 0xFFFFFFFFFFFFFFFF : 0};
|
|
||||||
// Use BSL to select elements from b where the mask is 1, else from a
|
|
||||||
return vbslq_f64(maskArray, b.values, a.values);
|
|
||||||
}
|
|
||||||
static Vectorized<double> blendv(
|
|
||||||
const Vectorized<double>& a,
|
|
||||||
const Vectorized<double>& b,
|
|
||||||
const Vectorized<double>& mask_) {
|
|
||||||
return vbslq_f64(vreinterpretq_u64_f64(mask_.values), b.values, a.values);
|
|
||||||
}
|
|
||||||
template <typename step_t>
|
|
||||||
static Vectorized<double> arange(
|
|
||||||
double base = 0.,
|
|
||||||
step_t step = static_cast<step_t>(1)) {
|
|
||||||
return {base, base + static_cast<double>(step)};
|
|
||||||
}
|
|
||||||
static inline Vectorized<double> set(
|
|
||||||
const Vectorized<double>& a,
|
|
||||||
const Vectorized<double>& b,
|
|
||||||
int64_t count = size()) {
|
|
||||||
if (count == 0) {
|
|
||||||
return a;
|
|
||||||
} else if (count >= 2) {
|
|
||||||
return b;
|
|
||||||
} else {
|
|
||||||
float64x2_t c = {b.values[0], a.values[1]};
|
|
||||||
return c;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
static Vectorized<double> loadu(const void* ptr, int64_t count = size()) {
|
|
||||||
if (count == size()) {
|
|
||||||
return vld1q_f64(reinterpret_cast<const double*>(ptr));
|
|
||||||
} else if (count == 1) {
|
|
||||||
float64x1_t x = vld1_f64(reinterpret_cast<const double*>(ptr));
|
|
||||||
float64x1_t z = {0.0};
|
|
||||||
return vcombine_f64(x, z);
|
|
||||||
} else {
|
|
||||||
return vdupq_n_f64(0.0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
void store(void* ptr, int64_t count = size()) const {
|
|
||||||
if (count == size()) {
|
|
||||||
vst1q_f64(reinterpret_cast<double*>(ptr), values);
|
|
||||||
} else if (count == 1) {
|
|
||||||
vst1_f64(reinterpret_cast<double*>(ptr), vget_low_f64(values));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
const double& operator[](int idx) const = delete;
|
|
||||||
double& operator[](int idx) = delete;
|
|
||||||
int64_t zero_mask() const {
|
|
||||||
// returns an integer mask where all zero elements are translated to 1-bit
|
|
||||||
// and others are translated to 0-bit
|
|
||||||
uint64x2_t cmpReg = vceqzq_f64(values);
|
|
||||||
uint64x2_t mask = {1, 2};
|
|
||||||
uint64x2_t res = vandq_u64(cmpReg, mask);
|
|
||||||
return res[0] | res[1];
|
|
||||||
}
|
|
||||||
Vectorized<double> isnan() const {
|
|
||||||
// NaN check
|
|
||||||
return vreinterpretq_f64_u32(
|
|
||||||
vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, values))));
|
|
||||||
}
|
|
||||||
bool has_inf_nan() const {
|
|
||||||
Vectorized<double> x = vsubq_f64(values, values);
|
|
||||||
float64x2_t r = x.isnan();
|
|
||||||
uint64x2_t u = vreinterpretq_u64_f64(r);
|
|
||||||
return u[0] | u[1];
|
|
||||||
}
|
|
||||||
Vectorized<double> map(double (*f)(double)) const {
|
|
||||||
float64x2_t result;
|
|
||||||
result[0] = f(values[0]);
|
|
||||||
result[1] = f(values[1]);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
Vectorized<double> map2(
|
|
||||||
const Vectorized<double>& second,
|
|
||||||
double (*const f)(double, double)) const {
|
|
||||||
float64x2_t result;
|
|
||||||
result[0] = f(values[0], second.values[0]);
|
|
||||||
result[1] = f(values[1], second.values[1]);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
Vectorized<double> abs() const {
|
|
||||||
return vabsq_f64(values);
|
|
||||||
}
|
|
||||||
Vectorized<double> angle() const {
|
|
||||||
auto zero = Vectorized<double>(0.0);
|
|
||||||
auto pi = Vectorized<double>(c10::pi<double>);
|
|
||||||
auto tmp = blendv(zero, pi, vreinterpretq_f64_u64(vcltzq_f64(values)));
|
|
||||||
return blendv(tmp, *this, isnan());
|
|
||||||
}
|
|
||||||
Vectorized<double> real() const {
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
Vectorized<double> imag() const {
|
|
||||||
return Vectorized<double>(0.0);
|
|
||||||
}
|
|
||||||
Vectorized<double> conj() const {
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
Vectorized<double> acos() const {
|
|
||||||
return USE_SLEEF(
|
|
||||||
Vectorized<double>(Sleef_acosd2_u10(values)), map(std::acos));
|
|
||||||
}
|
|
||||||
Vectorized<double> acosh() const {
|
|
||||||
return USE_SLEEF(
|
|
||||||
Vectorized<double>(Sleef_acoshd2_u10(values)), map(std::acosh));
|
|
||||||
}
|
|
||||||
Vectorized<double> asin() const {
|
|
||||||
return USE_SLEEF(
|
|
||||||
Vectorized<double>(Sleef_asind2_u10(values)), map(std::asin));
|
|
||||||
}
|
|
||||||
Vectorized<double> asinh() const {
|
|
||||||
return USE_SLEEF(
|
|
||||||
Vectorized<double>(Sleef_asinhd2_u10(values)), map(std::asinh));
|
|
||||||
}
|
|
||||||
Vectorized<double> atan() const {
|
|
||||||
return USE_SLEEF(
|
|
||||||
Vectorized<double>(Sleef_atand2_u10(values)), map(std::atan));
|
|
||||||
}
|
|
||||||
Vectorized<double> atanh() const {
|
|
||||||
return USE_SLEEF(
|
|
||||||
Vectorized<double>(Sleef_atanhd2_u10(values)), map(std::atanh));
|
|
||||||
}
|
|
||||||
Vectorized<double> atan2(const Vectorized<double>& b) const {USE_SLEEF(
|
|
||||||
{ return Vectorized<double>(Sleef_atan2d2_u10(values, b)); },
|
|
||||||
{
|
|
||||||
__at_align__ double tmp[size()];
|
|
||||||
__at_align__ double tmp_b[size()];
|
|
||||||
store(tmp);
|
|
||||||
b.store(tmp_b);
|
|
||||||
for (int64_t i = 0; i < size(); i++) {
|
|
||||||
tmp[i] = std::atan2(tmp[i], tmp_b[i]);
|
|
||||||
}
|
|
||||||
return loadu(tmp);
|
|
||||||
})} Vectorized<double> copysign(const Vectorized<double>& sign) const {
|
|
||||||
USE_SLEEF(
|
|
||||||
{ return Vectorized<double>(Sleef_copysignd2(values, sign)); },
|
|
||||||
{
|
|
||||||
__at_align__ double tmp[size()];
|
|
||||||
__at_align__ double tmp_sign[size()];
|
|
||||||
store(tmp);
|
|
||||||
sign.store(tmp_sign);
|
|
||||||
for (int64_t i = 0; i < size(); i++) {
|
|
||||||
tmp[i] = std::copysign(tmp[i], tmp_sign[i]);
|
|
||||||
}
|
|
||||||
return loadu(tmp);
|
|
||||||
})} Vectorized<double> erf() const {
|
|
||||||
return USE_SLEEF(
|
|
||||||
Vectorized<double>(Sleef_erfd2_u10(values)), map(std::erf));
|
|
||||||
}
|
|
||||||
Vectorized<double> erfc() const {
|
|
||||||
return USE_SLEEF(
|
|
||||||
Vectorized<double>(Sleef_erfcd2_u15(values)), map(std::erfc));
|
|
||||||
}
|
|
||||||
Vectorized<double> exp() const {
|
|
||||||
return USE_SLEEF(
|
|
||||||
Vectorized<double>(Sleef_expd2_u10(values)), map(std::exp));
|
|
||||||
}
|
|
||||||
Vectorized<double> exp2() const {
|
|
||||||
return USE_SLEEF(
|
|
||||||
Vectorized<double>(Sleef_exp2d2_u10(values)), map(std::exp2));
|
|
||||||
}
|
|
||||||
Vectorized<double> expm1() const {
|
|
||||||
return USE_SLEEF(
|
|
||||||
Vectorized<double>(Sleef_expm1d2_u10(values)), map(std::expm1));
|
|
||||||
}
|
|
||||||
Vectorized<double> fmod(const Vectorized<double>& q) const {USE_SLEEF(
|
|
||||||
{ return Vectorized<double>(Sleef_fmodd2(values, q)); },
|
|
||||||
{
|
|
||||||
__at_align__ double tmp[size()];
|
|
||||||
__at_align__ double tmp_q[size()];
|
|
||||||
store(tmp);
|
|
||||||
q.store(tmp_q);
|
|
||||||
for (int64_t i = 0; i < size(); i++) {
|
|
||||||
tmp[i] = std::fmod(tmp[i], tmp_q[i]);
|
|
||||||
}
|
|
||||||
return loadu(tmp);
|
|
||||||
})} Vectorized<double> hypot(const Vectorized<double>& b) const {
|
|
||||||
USE_SLEEF(
|
|
||||||
{ return Vectorized<double>(Sleef_hypotd2_u05(values, b)); },
|
|
||||||
{
|
|
||||||
__at_align__ double tmp[size()];
|
|
||||||
__at_align__ double tmp_b[size()];
|
|
||||||
store(tmp);
|
|
||||||
b.store(tmp_b);
|
|
||||||
for (int64_t i = 0; i < size(); i++) {
|
|
||||||
tmp[i] = std::hypot(tmp[i], tmp_b[i]);
|
|
||||||
}
|
|
||||||
return loadu(tmp);
|
|
||||||
})} Vectorized<double> i0() const {
|
|
||||||
return map(calc_i0);
|
|
||||||
}
|
|
||||||
Vectorized<double> nextafter(const Vectorized<double>& b) const {USE_SLEEF(
|
|
||||||
{ return Vectorized<double>(Sleef_nextafterd2(values, b)); },
|
|
||||||
{
|
|
||||||
__at_align__ double tmp[size()];
|
|
||||||
__at_align__ double tmp_b[size()];
|
|
||||||
store(tmp);
|
|
||||||
b.store(tmp_b);
|
|
||||||
for (int64_t i = 0; i < size(); ++i) {
|
|
||||||
tmp[i] = std::nextafter(tmp[i], tmp_b[i]);
|
|
||||||
}
|
|
||||||
return loadu(tmp);
|
|
||||||
})} Vectorized<double> log() const {
|
|
||||||
return USE_SLEEF(
|
|
||||||
Vectorized<double>(Sleef_logd2_u10(values)), map(std::log));
|
|
||||||
}
|
|
||||||
Vectorized<double> log2() const {
|
|
||||||
return USE_SLEEF(
|
|
||||||
Vectorized<double>(Sleef_log2d2_u10(values)), map(std::log2));
|
|
||||||
}
|
|
||||||
Vectorized<double> log10() const {
|
|
||||||
return USE_SLEEF(
|
|
||||||
Vectorized<double>(Sleef_log10d2_u10(values)), map(std::log10));
|
|
||||||
}
|
|
||||||
Vectorized<double> log1p() const {
|
|
||||||
return USE_SLEEF(
|
|
||||||
Vectorized<double>(Sleef_log1pd2_u10(values)), map(std::log1p));
|
|
||||||
}
|
|
||||||
Vectorized<double> frac() const;
|
|
||||||
Vectorized<double> sin() const {
|
|
||||||
return USE_SLEEF(
|
|
||||||
Vectorized<double>(Sleef_sind2_u10(values)), map(std::sin));
|
|
||||||
}
|
|
||||||
Vectorized<double> sinh() const {
|
|
||||||
return USE_SLEEF(
|
|
||||||
Vectorized<double>(Sleef_sinhd2_u10(values)), map(std::sinh));
|
|
||||||
}
|
|
||||||
Vectorized<double> cos() const {
|
|
||||||
return USE_SLEEF(
|
|
||||||
Vectorized<double>(Sleef_cosd2_u10(values)), map(std::cos));
|
|
||||||
}
|
|
||||||
Vectorized<double> cosh() const {
|
|
||||||
return USE_SLEEF(
|
|
||||||
Vectorized<double>(Sleef_coshd2_u10(values)), map(std::cosh));
|
|
||||||
}
|
|
||||||
Vectorized<double> pow(const Vectorized<double>& b) const {USE_SLEEF(
|
|
||||||
{ return Vectorized<double>(Sleef_powd2_u10(values, b)); },
|
|
||||||
{
|
|
||||||
__at_align__ double tmp[size()];
|
|
||||||
__at_align__ double tmp_b[size()];
|
|
||||||
store(tmp);
|
|
||||||
b.store(tmp_b);
|
|
||||||
for (int64_t i = 0; i < size(); i++) {
|
|
||||||
tmp[i] = std::pow(tmp[i], tmp_b[i]);
|
|
||||||
}
|
|
||||||
return loadu(tmp);
|
|
||||||
})} // Comparison using the _CMP_**_OQ predicate.
|
|
||||||
// `O`: get false if an operand is NaN
|
|
||||||
// `Q`: do not raise if an operand is NaN
|
|
||||||
Vectorized<double> tan() const {
|
|
||||||
return USE_SLEEF(
|
|
||||||
Vectorized<double>(Sleef_tand2_u10(values)), map(std::tan));
|
|
||||||
}
|
|
||||||
Vectorized<double> tanh() const {
|
|
||||||
return USE_SLEEF(
|
|
||||||
Vectorized<double>(Sleef_tanhd2_u10(values)), map(std::tanh));
|
|
||||||
}
|
|
||||||
Vectorized<double> lgamma() const {
|
|
||||||
return USE_SLEEF(
|
|
||||||
Vectorized<double>(Sleef_lgammad2_u10(values)), map(std::lgamma));
|
|
||||||
}
|
|
||||||
Vectorized<double> erfinv() const {
|
|
||||||
return map(calc_erfinv);
|
|
||||||
}
|
|
||||||
Vectorized<double> exp_u20() const {
|
|
||||||
return exp();
|
|
||||||
}
|
|
||||||
Vectorized<double> fexp_u20() const {
|
|
||||||
return exp();
|
|
||||||
}
|
|
||||||
Vectorized<double> i0e() const {
|
|
||||||
return map(calc_i0e);
|
|
||||||
}
|
|
||||||
Vectorized<double> digamma() const {
|
|
||||||
return map(calc_digamma);
|
|
||||||
}
|
|
||||||
Vectorized<double> igamma(const Vectorized<double>& x) const {
|
|
||||||
__at_align__ double tmp[size()];
|
|
||||||
__at_align__ double tmp_x[size()];
|
|
||||||
store(tmp);
|
|
||||||
x.store(tmp_x);
|
|
||||||
for (int64_t i = 0; i < size(); i++) {
|
|
||||||
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
|
|
||||||
}
|
|
||||||
return loadu(tmp);
|
|
||||||
}
|
|
||||||
Vectorized<double> igammac(const Vectorized<double>& x) const {
|
|
||||||
__at_align__ double tmp[size()];
|
|
||||||
__at_align__ double tmp_x[size()];
|
|
||||||
store(tmp);
|
|
||||||
x.store(tmp_x);
|
|
||||||
for (int64_t i = 0; i < size(); i++) {
|
|
||||||
tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
|
|
||||||
}
|
|
||||||
return loadu(tmp);
|
|
||||||
}
|
|
||||||
Vectorized<double> ceil() const {
|
|
||||||
return vrndpq_f64(values);
|
|
||||||
}
|
|
||||||
Vectorized<double> floor() const {
|
|
||||||
return vrndmq_f64(values);
|
|
||||||
}
|
|
||||||
Vectorized<double> neg() const {
|
|
||||||
return vnegq_f64(values);
|
|
||||||
}
|
|
||||||
Vectorized<double> round() const {
|
|
||||||
return vrndiq_f64(values);
|
|
||||||
}
|
|
||||||
Vectorized<double> trunc() const {
|
|
||||||
return vrndq_f64(values);
|
|
||||||
}
|
|
||||||
Vectorized<double> sqrt() const {
|
|
||||||
return vsqrtq_f64(values);
|
|
||||||
}
|
|
||||||
Vectorized<double> reciprocal() const {
|
|
||||||
return vdivq_f64(vdupq_n_f64(1.0), values);
|
|
||||||
}
|
|
||||||
Vectorized<double> rsqrt() const {
|
|
||||||
return vdivq_f64(vdupq_n_f64(1.0), vsqrtq_f64(values));
|
|
||||||
}
|
|
||||||
double reduce_add() const {
|
|
||||||
return vaddvq_f64(values);
|
|
||||||
}
|
|
||||||
double reduce_max() const {
|
|
||||||
return vmaxvq_f64(values);
|
|
||||||
}
|
|
||||||
Vectorized<double> operator==(const Vectorized<double>& other) const {
|
|
||||||
return Vectorized<double>(
|
|
||||||
vreinterpretq_f64_u64(vceqq_f64(values, other.values)));
|
|
||||||
}
|
|
||||||
|
|
||||||
Vectorized<double> operator!=(const Vectorized<double>& other) const {
|
|
||||||
float64x2_t r0 = vreinterpretq_f64_u32(
|
|
||||||
vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, other.values))));
|
|
||||||
return Vectorized<double>(r0);
|
|
||||||
}
|
|
||||||
|
|
||||||
Vectorized<double> operator<(const Vectorized<double>& other) const {
|
|
||||||
return Vectorized<double>(
|
|
||||||
vreinterpretq_f64_u64(vcltq_f64(values, other.values)));
|
|
||||||
}
|
|
||||||
|
|
||||||
Vectorized<double> operator<=(const Vectorized<double>& other) const {
|
|
||||||
return Vectorized<double>(
|
|
||||||
vreinterpretq_f64_u64(vcleq_f64(values, other.values)));
|
|
||||||
}
|
|
||||||
|
|
||||||
Vectorized<double> operator>(const Vectorized<double>& other) const {
|
|
||||||
return Vectorized<double>(
|
|
||||||
vreinterpretq_f64_u64(vcgtq_f64(values, other.values)));
|
|
||||||
}
|
|
||||||
|
|
||||||
Vectorized<double> operator>=(const Vectorized<double>& other) const {
|
|
||||||
return Vectorized<double>(
|
|
||||||
vreinterpretq_f64_u64(vcgeq_f64(values, other.values)));
|
|
||||||
}
|
|
||||||
|
|
||||||
Vectorized<double> eq(const Vectorized<double>& other) const;
|
|
||||||
Vectorized<double> ne(const Vectorized<double>& other) const;
|
|
||||||
Vectorized<double> gt(const Vectorized<double>& other) const;
|
|
||||||
Vectorized<double> ge(const Vectorized<double>& other) const;
|
|
||||||
Vectorized<double> lt(const Vectorized<double>& other) const;
|
|
||||||
Vectorized<double> le(const Vectorized<double>& other) const;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<double> inline operator+(
|
|
||||||
const Vectorized<double>& a,
|
|
||||||
const Vectorized<double>& b) {
|
|
||||||
return vaddq_f64(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<double> inline operator-(
|
|
||||||
const Vectorized<double>& a,
|
|
||||||
const Vectorized<double>& b) {
|
|
||||||
return vsubq_f64(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<double> inline operator*(
|
|
||||||
const Vectorized<double>& a,
|
|
||||||
const Vectorized<double>& b) {
|
|
||||||
return vmulq_f64(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<double> inline operator/(
|
|
||||||
const Vectorized<double>& a,
|
|
||||||
const Vectorized<double>& b) {
|
|
||||||
return vdivq_f64(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
// frac. Implement this here so we can use subtraction
|
|
||||||
Vectorized<double> inline Vectorized<double>::frac() const {
|
|
||||||
return *this - this->trunc();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
|
|
||||||
// either input is a NaN.
|
|
||||||
template <>
|
|
||||||
Vectorized<double> inline maximum(
|
|
||||||
const Vectorized<double>& a,
|
|
||||||
const Vectorized<double>& b) {
|
|
||||||
return vmaxq_f64(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
|
|
||||||
// either input is a NaN.
|
|
||||||
template <>
|
|
||||||
Vectorized<double> inline minimum(
|
|
||||||
const Vectorized<double>& a,
|
|
||||||
const Vectorized<double>& b) {
|
|
||||||
return vminq_f64(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<double> inline clamp(
|
|
||||||
const Vectorized<double>& a,
|
|
||||||
const Vectorized<double>& min,
|
|
||||||
const Vectorized<double>& max) {
|
|
||||||
return vminq_f64(max, vmaxq_f64(min, a));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<double> inline clamp_max(
|
|
||||||
const Vectorized<double>& a,
|
|
||||||
const Vectorized<double>& max) {
|
|
||||||
return vminq_f64(max, a);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<double> inline clamp_min(
|
|
||||||
const Vectorized<double>& a,
|
|
||||||
const Vectorized<double>& min) {
|
|
||||||
return vmaxq_f64(min, a);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<double> inline operator&(
|
|
||||||
const Vectorized<double>& a,
|
|
||||||
const Vectorized<double>& b) {
|
|
||||||
return vreinterpretq_f64_u64(
|
|
||||||
vandq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<double> inline operator|(
|
|
||||||
const Vectorized<double>& a,
|
|
||||||
const Vectorized<double>& b) {
|
|
||||||
return vreinterpretq_f64_u64(
|
|
||||||
vorrq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<double> inline operator^(
|
|
||||||
const Vectorized<double>& a,
|
|
||||||
const Vectorized<double>& b) {
|
|
||||||
return vreinterpretq_f64_u64(
|
|
||||||
veorq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b)));
|
|
||||||
}
|
|
||||||
|
|
||||||
inline Vectorized<double> Vectorized<double>::eq(
|
|
||||||
const Vectorized<double>& other) const {
|
|
||||||
return (*this == other) & Vectorized<double>(1.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline Vectorized<double> Vectorized<double>::ne(
|
|
||||||
const Vectorized<double>& other) const {
|
|
||||||
return (*this != other) & Vectorized<double>(1.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline Vectorized<double> Vectorized<double>::gt(
|
|
||||||
const Vectorized<double>& other) const {
|
|
||||||
return (*this > other) & Vectorized<double>(1.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline Vectorized<double> Vectorized<double>::ge(
|
|
||||||
const Vectorized<double>& other) const {
|
|
||||||
return (*this >= other) & Vectorized<double>(1.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline Vectorized<double> Vectorized<double>::lt(
|
|
||||||
const Vectorized<double>& other) const {
|
|
||||||
return (*this < other) & Vectorized<double>(1.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline Vectorized<double> Vectorized<double>::le(
|
|
||||||
const Vectorized<double>& other) const {
|
|
||||||
return (*this <= other) & Vectorized<double>(1.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<double> inline fmadd(
|
|
||||||
const Vectorized<double>& a,
|
|
||||||
const Vectorized<double>& b,
|
|
||||||
const Vectorized<double>& c) {
|
|
||||||
return vfmaq_f64(c, a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<double> inline fnmadd(
|
|
||||||
const Vectorized<double>& a,
|
|
||||||
const Vectorized<double>& b,
|
|
||||||
const Vectorized<double>& c) {
|
|
||||||
return vfmsq_f64(c, a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<double> inline fmsub(
|
|
||||||
const Vectorized<double>& a,
|
|
||||||
const Vectorized<double>& b,
|
|
||||||
const Vectorized<double>& c) {
|
|
||||||
return vfmaq_f64(vnegq_f64(c), a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<double> inline fnmsub(
|
|
||||||
const Vectorized<double>& a,
|
|
||||||
const Vectorized<double>& b,
|
|
||||||
const Vectorized<double>& c) {
|
|
||||||
return vfmsq_f64(vnegq_f64(c), a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace CPU_CAPABILITY
|
|
||||||
} // namespace at::vec
|
|
||||||
@ -307,49 +307,11 @@ class Vectorized<float> {
|
|||||||
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp)
|
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp)
|
||||||
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp2)
|
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp2)
|
||||||
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1)
|
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1)
|
||||||
// Implementation copied from Arm Optimized Routine
|
|
||||||
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
|
|
||||||
Vectorized<float> exp_u20() const {
|
Vectorized<float> exp_u20() const {
|
||||||
// bail out to sleef if it's a special case:
|
return exp();
|
||||||
// i.e. there's an input s.t. |input| > 87.3....
|
|
||||||
const float32x4_t special_bound = vdupq_n_f32(0x1.5d5e2ap+6f);
|
|
||||||
uint32x4_t cmp = vcagtq_f32(values, special_bound);
|
|
||||||
if (vpaddd_u64(vreinterpretq_u64_u32(cmp)) != 0) {
|
|
||||||
return exp();
|
|
||||||
}
|
|
||||||
|
|
||||||
const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f);
|
|
||||||
const float ln2_hi = 0x1.62e4p-1f;
|
|
||||||
const float ln2_lo = 0x1.7f7d1cp-20f;
|
|
||||||
const float c0 = 0x1.0e4020p-7f;
|
|
||||||
const float c2 = 0x1.555e66p-3f;
|
|
||||||
const float32x4_t ln2_c02 = {ln2_hi, ln2_lo, c0, c2};
|
|
||||||
|
|
||||||
const uint32x4_t exponent_bias = vdupq_n_u32(0x3f800000);
|
|
||||||
const float32x4_t c1 = vdupq_n_f32(0x1.573e2ep-5f);
|
|
||||||
const float32x4_t c3 = vdupq_n_f32(0x1.fffdb6p-2f);
|
|
||||||
const float32x4_t c4 = vdupq_n_f32(0x1.ffffecp-1f);
|
|
||||||
|
|
||||||
/* exp(x) = 2^n (1 + poly(r)), with 1 + poly(r) in [1/sqrt(2),sqrt(2)]
|
|
||||||
x = ln2*n + r, with r in [-ln2/2, ln2/2]. */
|
|
||||||
|
|
||||||
float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2));
|
|
||||||
float32x4_t r = vfmsq_laneq_f32(values, n, ln2_c02, 0);
|
|
||||||
r = vfmsq_laneq_f32(r, n, ln2_c02, 1);
|
|
||||||
uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_s32(vcvtq_s32_f32(n)), 23);
|
|
||||||
float32x4_t scale = vreinterpretq_f32_u32(vaddq_u32(e, exponent_bias));
|
|
||||||
|
|
||||||
float32x4_t r2 = vmulq_f32(r, r);
|
|
||||||
float32x4_t p = vfmaq_laneq_f32(c1, r, ln2_c02, 2);
|
|
||||||
float32x4_t q = vfmaq_laneq_f32(c3, r, ln2_c02, 3);
|
|
||||||
q = vfmaq_f32(q, p, r2);
|
|
||||||
p = vmulq_f32(c4, r);
|
|
||||||
float32x4_t poly = vfmaq_f32(p, q, r2);
|
|
||||||
|
|
||||||
return vfmaq_f32(scale, poly, scale);
|
|
||||||
}
|
}
|
||||||
Vectorized<float> fexp_u20() const {
|
Vectorized<float> fexp_u20() const {
|
||||||
return exp_u20();
|
return exp();
|
||||||
}
|
}
|
||||||
DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(
|
DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(
|
||||||
fmod,
|
fmod,
|
||||||
@ -578,6 +540,42 @@ inline Vectorized<float> Vectorized<float>::le(
|
|||||||
return (*this <= other) & Vectorized<float>(1.0f);
|
return (*this <= other) & Vectorized<float>(1.0f);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline void convert(const float* src, int32_t* dst, int64_t n) {
|
||||||
|
int64_t i;
|
||||||
|
#ifndef __msvc_cl__
|
||||||
|
#pragma unroll
|
||||||
|
#endif
|
||||||
|
for (i = 0; i <= (n - Vectorized<float>::size());
|
||||||
|
i += Vectorized<float>::size()) {
|
||||||
|
vst1q_s32(dst + i, vcvtq_s32_f32(vld1q_f32(src + i)));
|
||||||
|
}
|
||||||
|
#ifndef __msvc_cl__
|
||||||
|
#pragma unroll
|
||||||
|
#endif
|
||||||
|
for (; i < n; i++) {
|
||||||
|
dst[i] = static_cast<int32_t>(src[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline void convert(const int32_t* src, float* dst, int64_t n) {
|
||||||
|
int64_t i;
|
||||||
|
#ifndef __msvc_cl__
|
||||||
|
#pragma unroll
|
||||||
|
#endif
|
||||||
|
for (i = 0; i <= (n - Vectorized<float>::size());
|
||||||
|
i += Vectorized<float>::size()) {
|
||||||
|
vst1q_f32(dst + i, vcvtq_f32_s32(vld1q_s32(src + i)));
|
||||||
|
}
|
||||||
|
#ifndef __msvc_cl__
|
||||||
|
#pragma unroll
|
||||||
|
#endif
|
||||||
|
for (; i < n; i++) {
|
||||||
|
dst[i] = static_cast<float>(src[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
Vectorized<float> inline fmadd(
|
Vectorized<float> inline fmadd(
|
||||||
const Vectorized<float>& a,
|
const Vectorized<float>& a,
|
||||||
@ -634,7 +632,8 @@ inline Vectorized<float> Vectorized<float>::erf() const {
|
|||||||
// - exp(- x * x)
|
// - exp(- x * x)
|
||||||
auto pow_2 = (*this) * (*this);
|
auto pow_2 = (*this) * (*this);
|
||||||
auto neg_pow_2 = pow_2 ^ neg_zero_vec;
|
auto neg_pow_2 = pow_2 ^ neg_zero_vec;
|
||||||
auto tmp4 = neg_pow_2.exp();
|
auto tmp4 = neg_pow_2.map(
|
||||||
|
std::exp); // This can be swapped for a faster implementation of exp.
|
||||||
auto tmp5 = tmp4 ^ neg_zero_vec;
|
auto tmp5 = tmp4 ^ neg_zero_vec;
|
||||||
// erf(x) = sign(x) * (1 - r * t * exp(- x * x))
|
// erf(x) = sign(x) * (1 - r * t * exp(- x * x))
|
||||||
auto tmp6 = t * tmp5;
|
auto tmp6 = t * tmp5;
|
||||||
|
|||||||
@ -234,7 +234,7 @@ class Vectorized<c10::Half> : public Vectorized16<
|
|||||||
vshlq_u16(vandq_u16(is_zero_vec, vdupq_n_u16(1)), shift);
|
vshlq_u16(vandq_u16(is_zero_vec, vdupq_n_u16(1)), shift);
|
||||||
return vaddvq_u16(bits_vec);
|
return vaddvq_u16(bits_vec);
|
||||||
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||||
// use known working implementation.
|
// use known working implmentation.
|
||||||
__at_align__ value_type tmp[size()];
|
__at_align__ value_type tmp[size()];
|
||||||
store(tmp);
|
store(tmp);
|
||||||
int mask = 0;
|
int mask = 0;
|
||||||
@ -569,6 +569,46 @@ inline Vectorized<c10::Half> Vectorized<c10::Half>::le(
|
|||||||
return (*this <= other) & Vectorized<c10::Half>(1);
|
return (*this <= other) & Vectorized<c10::Half>(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// These are global functions, so the defaults in vec_base.h should
|
||||||
|
// work fine if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is not available.
|
||||||
|
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||||
|
template <>
|
||||||
|
inline void convert(const float16_t* src, int16_t* dst, int64_t n) {
|
||||||
|
int64_t i;
|
||||||
|
#ifndef __msvc_cl__
|
||||||
|
#pragma unroll
|
||||||
|
#endif
|
||||||
|
for (i = 0; i <= (n - Vectorized<c10::Half>::size());
|
||||||
|
i += Vectorized<c10::Half>::size()) {
|
||||||
|
vst1q_s16(dst + i, vcvtq_s16_f16(vld1q_f16(src + i)));
|
||||||
|
}
|
||||||
|
#ifndef __msvc_cl__
|
||||||
|
#pragma unroll
|
||||||
|
#endif
|
||||||
|
for (; i < n; i++) {
|
||||||
|
dst[i] = static_cast<int16_t>(src[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline void convert(const int16_t* src, float16_t* dst, int64_t n) {
|
||||||
|
int64_t i;
|
||||||
|
#ifndef __msvc_cl__
|
||||||
|
#pragma unroll
|
||||||
|
#endif
|
||||||
|
for (i = 0; i <= (n - Vectorized<c10::Half>::size());
|
||||||
|
i += Vectorized<c10::Half>::size()) {
|
||||||
|
vst1q_f16(dst + i, vcvtq_f16_s16(vld1q_s16(src + i)));
|
||||||
|
}
|
||||||
|
#ifndef __msvc_cl__
|
||||||
|
#pragma unroll
|
||||||
|
#endif
|
||||||
|
for (; i < n; i++) {
|
||||||
|
dst[i] = static_cast<float16_t>(src[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
Vectorized<c10::Half> inline fmadd(
|
Vectorized<c10::Half> inline fmadd(
|
||||||
const Vectorized<c10::Half>& a,
|
const Vectorized<c10::Half>& a,
|
||||||
|
|||||||
@ -1,794 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include <ATen/cpu/vec/intrinsics.h>
|
|
||||||
#include <ATen/cpu/vec/vec_base.h>
|
|
||||||
#include <c10/macros/Macros.h>
|
|
||||||
#include <c10/util/irange.h>
|
|
||||||
|
|
||||||
namespace at::vec {
|
|
||||||
// Note [CPU_CAPABILITY namespace]
|
|
||||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
||||||
// This header, and all of its subheaders, will be compiled with
|
|
||||||
// different architecture flags for each supported set of vector
|
|
||||||
// intrinsics. So we need to make sure they aren't inadvertently
|
|
||||||
// linked together. We do this by declaring objects in an `inline
|
|
||||||
// namespace` which changes the name mangling, but can still be
|
|
||||||
// accessed as `at::vec`.
|
|
||||||
inline namespace CPU_CAPABILITY {
|
|
||||||
|
|
||||||
#define VEC_INT_NEON_TEMPLATE(vl, bit) \
|
|
||||||
template <> \
|
|
||||||
struct is_vec_specialized_for<int##bit##_t> : std::bool_constant<true> {}; \
|
|
||||||
\
|
|
||||||
template <> \
|
|
||||||
class Vectorized<int##bit##_t> { \
|
|
||||||
using neon_type = int##bit##x##vl##_t; \
|
|
||||||
\
|
|
||||||
private: \
|
|
||||||
neon_type values; \
|
|
||||||
\
|
|
||||||
public: \
|
|
||||||
using value_type = int##bit##_t; \
|
|
||||||
using size_type = int; \
|
|
||||||
static constexpr size_type size() { \
|
|
||||||
return vl; \
|
|
||||||
} \
|
|
||||||
Vectorized() { \
|
|
||||||
values = vdupq_n_s##bit(0); \
|
|
||||||
} \
|
|
||||||
Vectorized(neon_type v) : values(v) {} \
|
|
||||||
Vectorized(int##bit##_t val); \
|
|
||||||
template < \
|
|
||||||
typename... Args, \
|
|
||||||
typename = std::enable_if_t<(sizeof...(Args) == size())>> \
|
|
||||||
Vectorized(Args... vals) { \
|
|
||||||
__at_align__ int##bit##_t buffer[size()] = {vals...}; \
|
|
||||||
values = vld1q_s##bit(buffer); \
|
|
||||||
} \
|
|
||||||
operator neon_type() const { \
|
|
||||||
return values; \
|
|
||||||
} \
|
|
||||||
static Vectorized<int##bit##_t> loadu( \
|
|
||||||
const void* ptr, \
|
|
||||||
int64_t count = size()); \
|
|
||||||
void store(void* ptr, int64_t count = size()) const; \
|
|
||||||
template <int64_t mask> \
|
|
||||||
static Vectorized<int##bit##_t> blend( \
|
|
||||||
const Vectorized<int##bit##_t>& a, \
|
|
||||||
const Vectorized<int##bit##_t>& b); \
|
|
||||||
static Vectorized<int##bit##_t> blendv( \
|
|
||||||
const Vectorized<int##bit##_t>& a, \
|
|
||||||
const Vectorized<int##bit##_t>& b, \
|
|
||||||
const Vectorized<int##bit##_t>& mask_) { \
|
|
||||||
return vbslq_s##bit(vreinterpretq_u##bit##_s##bit(mask_.values), b, a); \
|
|
||||||
} \
|
|
||||||
template <typename step_t> \
|
|
||||||
static Vectorized<int##bit##_t> arange( \
|
|
||||||
value_type base = 0, \
|
|
||||||
step_t step = static_cast<step_t>(1)); \
|
|
||||||
static Vectorized<int##bit##_t> set( \
|
|
||||||
const Vectorized<int##bit##_t>& a, \
|
|
||||||
const Vectorized<int##bit##_t>& b, \
|
|
||||||
int64_t count = size()); \
|
|
||||||
const int##bit##_t& operator[](int idx) const = delete; \
|
|
||||||
int##bit##_t& operator[](int idx) = delete; \
|
|
||||||
Vectorized<int##bit##_t> abs() const { \
|
|
||||||
return vabsq_s##bit(values); \
|
|
||||||
} \
|
|
||||||
Vectorized<int##bit##_t> real() const { \
|
|
||||||
return values; \
|
|
||||||
} \
|
|
||||||
Vectorized<int##bit##_t> imag() const { \
|
|
||||||
return vdupq_n_s##bit(0); \
|
|
||||||
} \
|
|
||||||
Vectorized<int##bit##_t> conj() const { \
|
|
||||||
return values; \
|
|
||||||
} \
|
|
||||||
Vectorized<int##bit##_t> neg() const { \
|
|
||||||
return vnegq_s##bit(values); \
|
|
||||||
} \
|
|
||||||
int##bit##_t reduce_add() const { \
|
|
||||||
return vaddvq_s##bit(values); \
|
|
||||||
} \
|
|
||||||
int##bit##_t reduce_max() const; \
|
|
||||||
Vectorized<int##bit##_t> operator==( \
|
|
||||||
const Vectorized<int##bit##_t>& other) const { \
|
|
||||||
return Vectorized<value_type>( \
|
|
||||||
vreinterpretq_s##bit##_u##bit(vceqq_s##bit(values, other.values))); \
|
|
||||||
} \
|
|
||||||
Vectorized<int##bit##_t> operator!=( \
|
|
||||||
const Vectorized<int##bit##_t>& other) const; \
|
|
||||||
Vectorized<int##bit##_t> operator<( \
|
|
||||||
const Vectorized<int##bit##_t>& other) const { \
|
|
||||||
return Vectorized<value_type>( \
|
|
||||||
vreinterpretq_s##bit##_u##bit(vcltq_s##bit(values, other.values))); \
|
|
||||||
} \
|
|
||||||
Vectorized<int##bit##_t> operator<=( \
|
|
||||||
const Vectorized<int##bit##_t>& other) const { \
|
|
||||||
return Vectorized<value_type>( \
|
|
||||||
vreinterpretq_s##bit##_u##bit(vcleq_s##bit(values, other.values))); \
|
|
||||||
} \
|
|
||||||
Vectorized<int##bit##_t> operator>( \
|
|
||||||
const Vectorized<int##bit##_t>& other) const { \
|
|
||||||
return Vectorized<value_type>( \
|
|
||||||
vreinterpretq_s##bit##_u##bit(vcgtq_s##bit(values, other.values))); \
|
|
||||||
} \
|
|
||||||
Vectorized<int##bit##_t> operator>=( \
|
|
||||||
const Vectorized<int##bit##_t>& other) const { \
|
|
||||||
return Vectorized<value_type>( \
|
|
||||||
vreinterpretq_s##bit##_u##bit(vcgeq_s##bit(values, other.values))); \
|
|
||||||
} \
|
|
||||||
Vectorized<int##bit##_t> eq(const Vectorized<int##bit##_t>& other) const; \
|
|
||||||
Vectorized<int##bit##_t> ne(const Vectorized<int##bit##_t>& other) const; \
|
|
||||||
Vectorized<int##bit##_t> gt(const Vectorized<int##bit##_t>& other) const; \
|
|
||||||
Vectorized<int##bit##_t> ge(const Vectorized<int##bit##_t>& other) const; \
|
|
||||||
Vectorized<int##bit##_t> lt(const Vectorized<int##bit##_t>& other) const; \
|
|
||||||
Vectorized<int##bit##_t> le(const Vectorized<int##bit##_t>& other) const; \
|
|
||||||
}; \
|
|
||||||
template <> \
|
|
||||||
Vectorized<int##bit##_t> inline operator+( \
|
|
||||||
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
|
|
||||||
return vaddq_s##bit(a, b); \
|
|
||||||
} \
|
|
||||||
template <> \
|
|
||||||
Vectorized<int##bit##_t> inline operator-( \
|
|
||||||
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
|
|
||||||
return vsubq_s##bit(a, b); \
|
|
||||||
} \
|
|
||||||
template <> \
|
|
||||||
Vectorized<int##bit##_t> inline operator&( \
|
|
||||||
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
|
|
||||||
return vandq_s##bit(a, b); \
|
|
||||||
} \
|
|
||||||
template <> \
|
|
||||||
Vectorized<int##bit##_t> inline operator|( \
|
|
||||||
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
|
|
||||||
return vorrq_s##bit(a, b); \
|
|
||||||
} \
|
|
||||||
template <> \
|
|
||||||
Vectorized<int##bit##_t> inline operator^( \
|
|
||||||
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
|
|
||||||
return veorq_s##bit(a, b); \
|
|
||||||
} \
|
|
||||||
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::eq( \
|
|
||||||
const Vectorized<int##bit##_t>& other) const { \
|
|
||||||
return (*this == other) & Vectorized<int##bit##_t>(1); \
|
|
||||||
} \
|
|
||||||
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::ne( \
|
|
||||||
const Vectorized<int##bit##_t>& other) const { \
|
|
||||||
return (*this != other) & Vectorized<int##bit##_t>(1); \
|
|
||||||
} \
|
|
||||||
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::gt( \
|
|
||||||
const Vectorized<int##bit##_t>& other) const { \
|
|
||||||
return (*this > other) & Vectorized<int##bit##_t>(1); \
|
|
||||||
} \
|
|
||||||
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::ge( \
|
|
||||||
const Vectorized<int##bit##_t>& other) const { \
|
|
||||||
return (*this >= other) & Vectorized<int##bit##_t>(1); \
|
|
||||||
} \
|
|
||||||
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::lt( \
|
|
||||||
const Vectorized<int##bit##_t>& other) const { \
|
|
||||||
return (*this < other) & Vectorized<int##bit##_t>(1); \
|
|
||||||
} \
|
|
||||||
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::le( \
|
|
||||||
const Vectorized<int##bit##_t>& other) const { \
|
|
||||||
return (*this <= other) & Vectorized<int##bit##_t>(1); \
|
|
||||||
}
|
|
||||||
|
|
||||||
VEC_INT_NEON_TEMPLATE(2, 64)
|
|
||||||
VEC_INT_NEON_TEMPLATE(4, 32)
|
|
||||||
VEC_INT_NEON_TEMPLATE(8, 16)
|
|
||||||
VEC_INT_NEON_TEMPLATE(16, 8)
|
|
||||||
|
|
||||||
inline int32_t Vectorized<int32_t>::reduce_max() const {
|
|
||||||
return vmaxvq_s32(values);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline int16_t Vectorized<int16_t>::reduce_max() const {
|
|
||||||
return vmaxvq_s16(values);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline int8_t Vectorized<int8_t>::reduce_max() const {
|
|
||||||
return vmaxvq_s8(values);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int32_t> inline operator*(
|
|
||||||
const Vectorized<int32_t>& a,
|
|
||||||
const Vectorized<int32_t>& b) {
|
|
||||||
return vmulq_s32(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int16_t> inline operator*(
|
|
||||||
const Vectorized<int16_t>& a,
|
|
||||||
const Vectorized<int16_t>& b) {
|
|
||||||
return vmulq_s16(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int8_t> inline operator*(
|
|
||||||
const Vectorized<int8_t>& a,
|
|
||||||
const Vectorized<int8_t>& b) {
|
|
||||||
return vmulq_s8(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
inline Vectorized<int64_t> operator~(const Vectorized<int64_t>& a) {
|
|
||||||
int64x2_t val = a;
|
|
||||||
return ~val;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
inline Vectorized<int32_t> operator~(const Vectorized<int32_t>& a) {
|
|
||||||
return vmvnq_s32(a);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
inline Vectorized<int16_t> operator~(const Vectorized<int16_t>& a) {
|
|
||||||
return vmvnq_s16(a);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
inline Vectorized<int8_t> operator~(const Vectorized<int8_t>& a) {
|
|
||||||
return vmvnq_s8(a);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline Vectorized<int64_t> Vectorized<int64_t>::operator!=(
|
|
||||||
const Vectorized<int64_t>& other) const {
|
|
||||||
return ~(*this == other);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline Vectorized<int32_t> Vectorized<int32_t>::operator!=(
|
|
||||||
const Vectorized<int32_t>& other) const {
|
|
||||||
return ~(*this == other);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline Vectorized<int16_t> Vectorized<int16_t>::operator!=(
|
|
||||||
const Vectorized<int16_t>& other) const {
|
|
||||||
return ~(*this == other);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline Vectorized<int8_t> Vectorized<int8_t>::operator!=(
|
|
||||||
const Vectorized<int8_t>& other) const {
|
|
||||||
return ~(*this == other);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int32_t> inline minimum(
|
|
||||||
const Vectorized<int32_t>& a,
|
|
||||||
const Vectorized<int32_t>& b) {
|
|
||||||
return vminq_s32(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int16_t> inline minimum(
|
|
||||||
const Vectorized<int16_t>& a,
|
|
||||||
const Vectorized<int16_t>& b) {
|
|
||||||
return vminq_s16(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int8_t> inline minimum(
|
|
||||||
const Vectorized<int8_t>& a,
|
|
||||||
const Vectorized<int8_t>& b) {
|
|
||||||
return vminq_s8(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int32_t> inline maximum(
|
|
||||||
const Vectorized<int32_t>& a,
|
|
||||||
const Vectorized<int32_t>& b) {
|
|
||||||
return vmaxq_s32(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int16_t> inline maximum(
|
|
||||||
const Vectorized<int16_t>& a,
|
|
||||||
const Vectorized<int16_t>& b) {
|
|
||||||
return vmaxq_s16(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int8_t> inline maximum(
|
|
||||||
const Vectorized<int8_t>& a,
|
|
||||||
const Vectorized<int8_t>& b) {
|
|
||||||
return vmaxq_s8(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <int64_t mask>
|
|
||||||
Vectorized<int64_t> Vectorized<int64_t>::blend(
|
|
||||||
const Vectorized<int64_t>& a,
|
|
||||||
const Vectorized<int64_t>& b) {
|
|
||||||
// Build an array of flags: each bit of element is 1 if the corresponding bit
|
|
||||||
// in 'mask' is set, 0 otherwise.
|
|
||||||
uint64x2_t maskArray = {
|
|
||||||
(mask & 1LL) ? 0xFFFFFFFFFFFFFFFF : 0,
|
|
||||||
(mask & 2LL) ? 0xFFFFFFFFFFFFFFFF : 0};
|
|
||||||
// Use BSL to select elements from b where the mask is 1, else from a
|
|
||||||
return vbslq_s64(maskArray, b.values, a.values);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <int64_t mask>
|
|
||||||
Vectorized<int32_t> Vectorized<int32_t>::blend(
|
|
||||||
const Vectorized<int32_t>& a,
|
|
||||||
const Vectorized<int32_t>& b) {
|
|
||||||
// Build an array of flags: each bit of element is 1 if the corresponding bit
|
|
||||||
// in 'mask' is set, 0 otherwise.
|
|
||||||
uint32x4_t maskArray = {
|
|
||||||
(mask & 1LL) ? 0xFFFFFFFF : 0,
|
|
||||||
(mask & 2LL) ? 0xFFFFFFFF : 0,
|
|
||||||
(mask & 4LL) ? 0xFFFFFFFF : 0,
|
|
||||||
(mask & 8LL) ? 0xFFFFFFFF : 0};
|
|
||||||
// Use BSL to select elements from b where the mask is 1, else from a
|
|
||||||
return vbslq_s32(maskArray, b.values, a.values);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <int64_t mask>
|
|
||||||
Vectorized<int16_t> Vectorized<int16_t>::blend(
|
|
||||||
const Vectorized<int16_t>& a,
|
|
||||||
const Vectorized<int16_t>& b) {
|
|
||||||
// Build an array of flags: each bit of element is 1 if the corresponding bit
|
|
||||||
// in 'mask' is set, 0 otherwise.
|
|
||||||
uint16x8_t maskArray = {
|
|
||||||
(mask & 1LL) ? 0xFFFF : 0,
|
|
||||||
(mask & 2LL) ? 0xFFFF : 0,
|
|
||||||
(mask & 4LL) ? 0xFFFF : 0,
|
|
||||||
(mask & 8LL) ? 0xFFFF : 0,
|
|
||||||
(mask & 16LL) ? 0xFFFF : 0,
|
|
||||||
(mask & 32LL) ? 0xFFFF : 0,
|
|
||||||
(mask & 64LL) ? 0xFFFF : 0,
|
|
||||||
(mask & 128LL) ? 0xFFFF : 0};
|
|
||||||
// Use BSL to select elements from b where the mask is 1, else from a
|
|
||||||
return vbslq_s16(maskArray, b.values, a.values);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <int64_t mask>
|
|
||||||
Vectorized<int8_t> Vectorized<int8_t>::blend(
|
|
||||||
const Vectorized<int8_t>& a,
|
|
||||||
const Vectorized<int8_t>& b) {
|
|
||||||
// Build an array of flags: each bit of element is 1 if the corresponding bit
|
|
||||||
// in 'mask' is set, 0 otherwise.
|
|
||||||
uint8x16_t maskArray = {
|
|
||||||
(mask & 1LL) ? 0xFF : 0,
|
|
||||||
(mask & 2LL) ? 0xFF : 0,
|
|
||||||
(mask & 4LL) ? 0xFF : 0,
|
|
||||||
(mask & 8LL) ? 0xFF : 0,
|
|
||||||
(mask & 16LL) ? 0xFF : 0,
|
|
||||||
(mask & 32LL) ? 0xFF : 0,
|
|
||||||
(mask & 64LL) ? 0xFF : 0,
|
|
||||||
(mask & 128LL) ? 0xFF : 0,
|
|
||||||
(mask & 256LL) ? 0xFF : 0,
|
|
||||||
(mask & 512LL) ? 0xFF : 0,
|
|
||||||
(mask & 1024LL) ? 0xFF : 0,
|
|
||||||
(mask & 2048LL) ? 0xFF : 0,
|
|
||||||
(mask & 4096LL) ? 0xFF : 0,
|
|
||||||
(mask & 8192LL) ? 0xFF : 0,
|
|
||||||
(mask & 16384LL) ? 0xFF : 0,
|
|
||||||
(mask & 32768LL) ? 0xFF : 0};
|
|
||||||
// Use BSL to select elements from b where the mask is 1, else from a
|
|
||||||
return vbslq_s8(maskArray, b.values, a.values);
|
|
||||||
}
|
|
||||||
|
|
||||||
#define VEC_INT_NEON_OPS(vl, bit) \
|
|
||||||
inline Vectorized<int##bit##_t>::Vectorized(int##bit##_t val) { \
|
|
||||||
values = vdupq_n_s##bit(val); \
|
|
||||||
} \
|
|
||||||
inline Vectorized<int##bit##_t> Vectorized<int##bit##_t>::loadu( \
|
|
||||||
const void* ptr, int64_t count) { \
|
|
||||||
if (count == size()) { \
|
|
||||||
return vld1q_s##bit(reinterpret_cast<const int##bit##_t*>(ptr)); \
|
|
||||||
} else { \
|
|
||||||
__at_align__ int##bit##_t tmp_values[size()]; \
|
|
||||||
for (const auto i : c10::irange(size())) { \
|
|
||||||
tmp_values[i] = 0; \
|
|
||||||
} \
|
|
||||||
std::memcpy( \
|
|
||||||
tmp_values, \
|
|
||||||
reinterpret_cast<const int##bit##_t*>(ptr), \
|
|
||||||
count * sizeof(int##bit##_t)); \
|
|
||||||
return vld1q_s##bit(reinterpret_cast<const int##bit##_t*>(tmp_values)); \
|
|
||||||
} \
|
|
||||||
} \
|
|
||||||
inline void Vectorized<int##bit##_t>::store(void* ptr, int64_t count) \
|
|
||||||
const { \
|
|
||||||
if (count == size()) { \
|
|
||||||
vst1q_s##bit(reinterpret_cast<int##bit##_t*>(ptr), values); \
|
|
||||||
} else { \
|
|
||||||
int##bit##_t tmp_values[size()]; \
|
|
||||||
vst1q_s##bit(reinterpret_cast<int##bit##_t*>(tmp_values), values); \
|
|
||||||
std::memcpy(ptr, tmp_values, count * sizeof(int##bit##_t)); \
|
|
||||||
} \
|
|
||||||
}
|
|
||||||
|
|
||||||
VEC_INT_NEON_OPS(2, 64)
|
|
||||||
VEC_INT_NEON_OPS(4, 32)
|
|
||||||
VEC_INT_NEON_OPS(8, 16)
|
|
||||||
VEC_INT_NEON_OPS(16, 8)
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int64_t> inline operator*(
|
|
||||||
const Vectorized<int64_t>& a,
|
|
||||||
const Vectorized<int64_t>& b) {
|
|
||||||
int64x2_t x = a;
|
|
||||||
int64x2_t y = b;
|
|
||||||
return x * y;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int64_t> inline operator/(
|
|
||||||
const Vectorized<int64_t>& a,
|
|
||||||
const Vectorized<int64_t>& b) {
|
|
||||||
int64x2_t x = a;
|
|
||||||
int64x2_t y = b;
|
|
||||||
return x / y;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int32_t> inline operator/(
|
|
||||||
const Vectorized<int32_t>& a,
|
|
||||||
const Vectorized<int32_t>& b) {
|
|
||||||
int32x4_t x = a;
|
|
||||||
int32x4_t y = b;
|
|
||||||
return x / y;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline int64_t Vectorized<int64_t>::reduce_max() const {
|
|
||||||
return std::max(values[0], values[1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int64_t> inline minimum(
|
|
||||||
const Vectorized<int64_t>& a,
|
|
||||||
const Vectorized<int64_t>& b) {
|
|
||||||
int64x2_t x = a;
|
|
||||||
int64x2_t y = b;
|
|
||||||
return {std::min(x[0], y[0]), std::min(x[1], y[1])};
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int64_t> inline maximum(
|
|
||||||
const Vectorized<int64_t>& a,
|
|
||||||
const Vectorized<int64_t>& b) {
|
|
||||||
int64x2_t x = a;
|
|
||||||
int64x2_t y = b;
|
|
||||||
return {std::max(x[0], y[0]), std::max(x[1], y[1])};
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename step_t>
|
|
||||||
inline Vectorized<int64_t> Vectorized<int64_t>::arange(
|
|
||||||
int64_t base,
|
|
||||||
step_t step) {
|
|
||||||
const Vectorized<int64_t> base_vec(base);
|
|
||||||
const Vectorized<int64_t> step_vec(step);
|
|
||||||
const int64x2_t step_sizes = {0, 1};
|
|
||||||
return base_vec.values + step_sizes * step_vec.values;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename step_t>
|
|
||||||
inline Vectorized<int32_t> Vectorized<int32_t>::arange(
|
|
||||||
int32_t base,
|
|
||||||
step_t step) {
|
|
||||||
const Vectorized<int32_t> base_vec(base);
|
|
||||||
const Vectorized<int32_t> step_vec(step);
|
|
||||||
const int32x4_t step_sizes = {0, 1, 2, 3};
|
|
||||||
return vmlaq_s32(base_vec, step_sizes, step_vec);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename step_t>
|
|
||||||
inline Vectorized<int16_t> Vectorized<int16_t>::arange(
|
|
||||||
int16_t base,
|
|
||||||
step_t step) {
|
|
||||||
const Vectorized<int16_t> base_vec(base);
|
|
||||||
const Vectorized<int16_t> step_vec(step);
|
|
||||||
const int16x8_t step_sizes = {0, 1, 2, 3, 4, 5, 6, 7};
|
|
||||||
return vmlaq_s16(base_vec, step_sizes, step_vec);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename step_t>
|
|
||||||
inline Vectorized<int8_t> Vectorized<int8_t>::arange(int8_t base, step_t step) {
|
|
||||||
const Vectorized<int8_t> base_vec(base);
|
|
||||||
const Vectorized<int8_t> step_vec(step);
|
|
||||||
const int8x16_t step_sizes = {
|
|
||||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
|
|
||||||
return vmlaq_s8(base_vec, step_sizes, step_vec);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int64_t> inline operator>>(
|
|
||||||
const Vectorized<int64_t>& a,
|
|
||||||
const Vectorized<int64_t>& b) {
|
|
||||||
int64x2_t x = a;
|
|
||||||
int64x2_t y = b;
|
|
||||||
uint64x2_t u = vreinterpretq_u64_s64(y);
|
|
||||||
uint64x2_t z = {std::min(u[0], (uint64_t)63), std::min(u[1], (uint64_t)63)};
|
|
||||||
return x >> vreinterpretq_s64_u64(z);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int32_t> inline operator>>(
|
|
||||||
const Vectorized<int32_t>& a,
|
|
||||||
const Vectorized<int32_t>& b) {
|
|
||||||
int32x4_t x = a;
|
|
||||||
int32x4_t y = b;
|
|
||||||
uint32x4_t bound = vdupq_n_u32(31);
|
|
||||||
uint32x4_t z = vminq_u32(vreinterpretq_u32_s32(y), bound);
|
|
||||||
return x >> vreinterpretq_s32_u32(z);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int16_t> inline operator>>(
|
|
||||||
const Vectorized<int16_t>& a,
|
|
||||||
const Vectorized<int16_t>& b) {
|
|
||||||
int16x8_t x = a;
|
|
||||||
int16x8_t y = b;
|
|
||||||
uint16x8_t bound = vdupq_n_u16(15);
|
|
||||||
uint16x8_t z = vminq_u16(vreinterpretq_u16_s16(y), bound);
|
|
||||||
return x >> vreinterpretq_s16_u16(z);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int8_t> inline operator>>(
|
|
||||||
const Vectorized<int8_t>& a,
|
|
||||||
const Vectorized<int8_t>& b) {
|
|
||||||
int8x16_t x = a;
|
|
||||||
int8x16_t y = b;
|
|
||||||
uint8x16_t bound = vdupq_n_u8(7);
|
|
||||||
int8x16_t z = vreinterpretq_s8_u8(vminq_u8(vreinterpretq_u8_s8(y), bound));
|
|
||||||
return x >> z;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int64_t> inline operator<<(
|
|
||||||
const Vectorized<int64_t>& a,
|
|
||||||
const Vectorized<int64_t>& b) {
|
|
||||||
int64x2_t y = b;
|
|
||||||
uint64x2_t u = vreinterpretq_u64_s64(y);
|
|
||||||
uint64x2_t z = {std::min(u[0], (uint64_t)64), std::min(u[1], (uint64_t)64)};
|
|
||||||
return vshlq_s64(a, vreinterpretq_s64_u64(z));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int32_t> inline operator<<(
|
|
||||||
const Vectorized<int32_t>& a,
|
|
||||||
const Vectorized<int32_t>& b) {
|
|
||||||
int32x4_t y = b;
|
|
||||||
uint32x4_t bound = vdupq_n_u32(32);
|
|
||||||
uint32x4_t z = vminq_u32(vreinterpretq_u32_s32(y), bound);
|
|
||||||
return vshlq_s32(a, vreinterpretq_s32_u32(z));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int16_t> inline operator<<(
|
|
||||||
const Vectorized<int16_t>& a,
|
|
||||||
const Vectorized<int16_t>& b) {
|
|
||||||
int16x8_t y = b;
|
|
||||||
uint16x8_t bound = vdupq_n_u16(16);
|
|
||||||
uint16x8_t z = vminq_u16(vreinterpretq_u16_s16(y), bound);
|
|
||||||
return vshlq_s16(a, vreinterpretq_s16_u16(z));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int8_t> inline operator<<(
|
|
||||||
const Vectorized<int8_t>& a,
|
|
||||||
const Vectorized<int8_t>& b) {
|
|
||||||
int8x16_t y = b;
|
|
||||||
uint8x16_t bound = vdupq_n_u8(8);
|
|
||||||
int8x16_t z = vreinterpretq_s8_u8(vminq_u8(vreinterpretq_u8_s8(y), bound));
|
|
||||||
return vshlq_s8(a, z);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline Vectorized<int64_t> Vectorized<int64_t>::set(
|
|
||||||
const Vectorized<int64_t>& a,
|
|
||||||
const Vectorized<int64_t>& b,
|
|
||||||
int64_t count) {
|
|
||||||
if (count == 0) {
|
|
||||||
return a;
|
|
||||||
} else if (count >= 2) {
|
|
||||||
return b;
|
|
||||||
} else {
|
|
||||||
int64x2_t c = {b.values[0], a.values[1]};
|
|
||||||
return c;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inline Vectorized<int32_t> Vectorized<int32_t>::set(
|
|
||||||
const Vectorized<int32_t>& a,
|
|
||||||
const Vectorized<int32_t>& b,
|
|
||||||
int64_t count) {
|
|
||||||
if (count == 0) {
|
|
||||||
return a;
|
|
||||||
} else if (count >= 4) {
|
|
||||||
return b;
|
|
||||||
} else {
|
|
||||||
// Build an array of flags: each bit of element is 1 if the corresponding
|
|
||||||
// bit in 'mask' is set, 0 otherwise.
|
|
||||||
uint32x4_t maskArray = {
|
|
||||||
(count >= 1LL) ? 0xFFFFFFFF : 0,
|
|
||||||
(count >= 2LL) ? 0xFFFFFFFF : 0,
|
|
||||||
(count >= 3LL) ? 0xFFFFFFFF : 0,
|
|
||||||
0};
|
|
||||||
// Use BSL to select elements from b where the mask is 1, else from a
|
|
||||||
return vbslq_s32(maskArray, b.values, a.values);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inline Vectorized<int16_t> Vectorized<int16_t>::set(
|
|
||||||
const Vectorized<int16_t>& a,
|
|
||||||
const Vectorized<int16_t>& b,
|
|
||||||
int64_t count) {
|
|
||||||
if (count == 0) {
|
|
||||||
return a;
|
|
||||||
} else if (count >= 8) {
|
|
||||||
return b;
|
|
||||||
} else {
|
|
||||||
// Build an array of flags: each bit of element is 1 if the corresponding
|
|
||||||
// bit in 'mask' is set, 0 otherwise.
|
|
||||||
uint16x8_t maskArray = {
|
|
||||||
static_cast<uint16_t>((count >= 1LL) ? 0xFFFF : 0),
|
|
||||||
static_cast<uint16_t>((count >= 2LL) ? 0xFFFF : 0),
|
|
||||||
static_cast<uint16_t>((count >= 3LL) ? 0xFFFF : 0),
|
|
||||||
static_cast<uint16_t>((count >= 4LL) ? 0xFFFF : 0),
|
|
||||||
static_cast<uint16_t>((count >= 5LL) ? 0xFFFF : 0),
|
|
||||||
static_cast<uint16_t>((count >= 6LL) ? 0xFFFF : 0),
|
|
||||||
static_cast<uint16_t>((count >= 7LL) ? 0xFFFF : 0),
|
|
||||||
0};
|
|
||||||
// Use BSL to select elements from b where the mask is 1, else from a
|
|
||||||
return vbslq_s16(maskArray, b.values, a.values);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inline Vectorized<int8_t> Vectorized<int8_t>::set(
|
|
||||||
const Vectorized<int8_t>& a,
|
|
||||||
const Vectorized<int8_t>& b,
|
|
||||||
int64_t count) {
|
|
||||||
if (count == 0) {
|
|
||||||
return a;
|
|
||||||
} else if (count >= 16) {
|
|
||||||
return b;
|
|
||||||
} else {
|
|
||||||
// Build an array of flags: each bit of element is 1 if the corresponding
|
|
||||||
// bit in 'mask' is set, 0 otherwise.
|
|
||||||
uint8x16_t maskArray = {
|
|
||||||
static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0),
|
|
||||||
0};
|
|
||||||
|
|
||||||
// Use BSL to select elements from b where the mask is 1, else from a
|
|
||||||
return vbslq_s8(maskArray, b.values, a.values);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int16_t> inline operator/(
|
|
||||||
const Vectorized<int16_t>& a,
|
|
||||||
const Vectorized<int16_t>& b) {
|
|
||||||
Vectorized<int32_t> highBitsA = vmovl_high_s16(a);
|
|
||||||
Vectorized<int32_t> highBitsB = vmovl_high_s16(b);
|
|
||||||
Vectorized<int32_t> lowBitsA = vmovl_s16(vget_low_s16(a));
|
|
||||||
Vectorized<int32_t> lowBitsB = vmovl_s16(vget_low_s16(b));
|
|
||||||
int32x4_t highBitsResult = highBitsA / highBitsB;
|
|
||||||
int32x4_t lowBitsResult = lowBitsA / lowBitsB;
|
|
||||||
return vuzp1q_s16(
|
|
||||||
vreinterpretq_s16_s32(lowBitsResult),
|
|
||||||
vreinterpretq_s16_s32(highBitsResult));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int8_t> inline operator/(
|
|
||||||
const Vectorized<int8_t>& a,
|
|
||||||
const Vectorized<int8_t>& b) {
|
|
||||||
Vectorized<int16_t> highBitsA = vmovl_high_s8(a);
|
|
||||||
Vectorized<int16_t> highBitsB = vmovl_high_s8(b);
|
|
||||||
Vectorized<int16_t> lowBitsA = vmovl_s8(vget_low_s8(a));
|
|
||||||
Vectorized<int16_t> lowBitsB = vmovl_s8(vget_low_s8(b));
|
|
||||||
int16x8_t highBitsResult = highBitsA / highBitsB;
|
|
||||||
int16x8_t lowBitsResult = lowBitsA / lowBitsB;
|
|
||||||
return vuzp1q_s8(
|
|
||||||
vreinterpretq_s8_s16(lowBitsResult),
|
|
||||||
vreinterpretq_s8_s16(highBitsResult));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int64_t> inline clamp(
|
|
||||||
const Vectorized<int64_t>& a,
|
|
||||||
const Vectorized<int64_t>& min,
|
|
||||||
const Vectorized<int64_t>& max) {
|
|
||||||
return minimum(max, maximum(min, a));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int32_t> inline clamp(
|
|
||||||
const Vectorized<int32_t>& a,
|
|
||||||
const Vectorized<int32_t>& min,
|
|
||||||
const Vectorized<int32_t>& max) {
|
|
||||||
return minimum(max, maximum(min, a));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int16_t> inline clamp(
|
|
||||||
const Vectorized<int16_t>& a,
|
|
||||||
const Vectorized<int16_t>& min,
|
|
||||||
const Vectorized<int16_t>& max) {
|
|
||||||
return minimum(max, maximum(min, a));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int8_t> inline clamp(
|
|
||||||
const Vectorized<int8_t>& a,
|
|
||||||
const Vectorized<int8_t>& min,
|
|
||||||
const Vectorized<int8_t>& max) {
|
|
||||||
return minimum(max, maximum(min, a));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int64_t> inline clamp_max(
|
|
||||||
const Vectorized<int64_t>& a,
|
|
||||||
const Vectorized<int64_t>& max) {
|
|
||||||
return minimum(max, a);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int32_t> inline clamp_max(
|
|
||||||
const Vectorized<int32_t>& a,
|
|
||||||
const Vectorized<int32_t>& max) {
|
|
||||||
return minimum(max, a);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int16_t> inline clamp_max(
|
|
||||||
const Vectorized<int16_t>& a,
|
|
||||||
const Vectorized<int16_t>& max) {
|
|
||||||
return minimum(max, a);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int8_t> inline clamp_max(
|
|
||||||
const Vectorized<int8_t>& a,
|
|
||||||
const Vectorized<int8_t>& max) {
|
|
||||||
return minimum(max, a);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int64_t> inline clamp_min(
|
|
||||||
const Vectorized<int64_t>& a,
|
|
||||||
const Vectorized<int64_t>& min) {
|
|
||||||
return maximum(min, a);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int32_t> inline clamp_min(
|
|
||||||
const Vectorized<int32_t>& a,
|
|
||||||
const Vectorized<int32_t>& min) {
|
|
||||||
return maximum(min, a);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int16_t> inline clamp_min(
|
|
||||||
const Vectorized<int16_t>& a,
|
|
||||||
const Vectorized<int16_t>& min) {
|
|
||||||
return maximum(min, a);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<int8_t> inline clamp_min(
|
|
||||||
const Vectorized<int8_t>& a,
|
|
||||||
const Vectorized<int8_t>& min) {
|
|
||||||
return maximum(min, a);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace CPU_CAPABILITY
|
|
||||||
} // namespace at::vec
|
|
||||||
@ -1,378 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include <ATen/cpu/vec/intrinsics.h>
|
|
||||||
#include <ATen/cpu/vec/vec_base.h>
|
|
||||||
#include <c10/macros/Macros.h>
|
|
||||||
#include <c10/util/irange.h>
|
|
||||||
|
|
||||||
namespace at::vec {
|
|
||||||
// Note [CPU_CAPABILITY namespace]
|
|
||||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
||||||
// This header, and all of its subheaders, will be compiled with
|
|
||||||
// different architecture flags for each supported set of vector
|
|
||||||
// intrinsics. So we need to make sure they aren't inadvertently
|
|
||||||
// linked together. We do this by declaring objects in an `inline
|
|
||||||
// namespace` which changes the name mangling, but can still be
|
|
||||||
// accessed as `at::vec`.
|
|
||||||
inline namespace CPU_CAPABILITY {
|
|
||||||
|
|
||||||
#define VEC_UINT_NEON_TEMPLATE(vl, bit) \
|
|
||||||
template <> \
|
|
||||||
struct is_vec_specialized_for<uint##bit##_t> : std::bool_constant<true> {}; \
|
|
||||||
\
|
|
||||||
template <> \
|
|
||||||
class Vectorized<uint##bit##_t> { \
|
|
||||||
using neon_type = uint##bit##x##vl##_t; \
|
|
||||||
\
|
|
||||||
private: \
|
|
||||||
neon_type values; \
|
|
||||||
\
|
|
||||||
public: \
|
|
||||||
using value_type = uint##bit##_t; \
|
|
||||||
using size_type = int; \
|
|
||||||
static constexpr size_type size() { \
|
|
||||||
return vl; \
|
|
||||||
} \
|
|
||||||
Vectorized() { \
|
|
||||||
values = vdupq_n_u##bit(0); \
|
|
||||||
} \
|
|
||||||
Vectorized(neon_type v) : values(v) {} \
|
|
||||||
Vectorized(uint##bit##_t val); \
|
|
||||||
template < \
|
|
||||||
typename... Args, \
|
|
||||||
typename = std::enable_if_t<(sizeof...(Args) == size())>> \
|
|
||||||
Vectorized(Args... vals) { \
|
|
||||||
__at_align__ uint##bit##_t buffer[size()] = {vals...}; \
|
|
||||||
values = vld1q_u##bit(buffer); \
|
|
||||||
} \
|
|
||||||
operator neon_type() const { \
|
|
||||||
return values; \
|
|
||||||
} \
|
|
||||||
static Vectorized<uint##bit##_t> loadu( \
|
|
||||||
const void* ptr, \
|
|
||||||
uint64_t count = size()); \
|
|
||||||
void store(void* ptr, uint64_t count = size()) const; \
|
|
||||||
template <uint64_t mask> \
|
|
||||||
static Vectorized<uint##bit##_t> blend( \
|
|
||||||
const Vectorized<uint##bit##_t>& a, \
|
|
||||||
const Vectorized<uint##bit##_t>& b); \
|
|
||||||
static Vectorized<uint##bit##_t> blendv( \
|
|
||||||
const Vectorized<uint##bit##_t>& a, \
|
|
||||||
const Vectorized<uint##bit##_t>& b, \
|
|
||||||
const Vectorized<uint##bit##_t>& mask_) { \
|
|
||||||
return vbslq_u##bit(mask_.values, b, a); \
|
|
||||||
} \
|
|
||||||
template <typename step_t> \
|
|
||||||
static Vectorized<uint##bit##_t> arange( \
|
|
||||||
value_type base = 0, \
|
|
||||||
step_t step = static_cast<step_t>(1)); \
|
|
||||||
static Vectorized<uint##bit##_t> set( \
|
|
||||||
const Vectorized<uint##bit##_t>& a, \
|
|
||||||
const Vectorized<uint##bit##_t>& b, \
|
|
||||||
uint64_t count = size()); \
|
|
||||||
const uint##bit##_t& operator[](uint idx) const = delete; \
|
|
||||||
uint##bit##_t& operator[](uint idx) = delete; \
|
|
||||||
Vectorized<uint##bit##_t> abs() const { \
|
|
||||||
return values; \
|
|
||||||
} \
|
|
||||||
Vectorized<uint##bit##_t> real() const { \
|
|
||||||
return values; \
|
|
||||||
} \
|
|
||||||
Vectorized<uint##bit##_t> imag() const { \
|
|
||||||
return vdupq_n_u##bit(0); \
|
|
||||||
} \
|
|
||||||
Vectorized<uint##bit##_t> conj() const { \
|
|
||||||
return values; \
|
|
||||||
} \
|
|
||||||
Vectorized<uint##bit##_t> neg() const { \
|
|
||||||
return vreinterpretq_u##bit##_s##bit( \
|
|
||||||
vnegq_s##bit(vreinterpretq_s##bit##_u##bit(values))); \
|
|
||||||
} \
|
|
||||||
uint##bit##_t reduce_add() const { \
|
|
||||||
return vaddvq_u##bit(values); \
|
|
||||||
} \
|
|
||||||
uint##bit##_t reduce_max() const; \
|
|
||||||
Vectorized<uint##bit##_t> operator==( \
|
|
||||||
const Vectorized<uint##bit##_t>& other) const { \
|
|
||||||
return Vectorized<value_type>(vceqq_u##bit(values, other.values)); \
|
|
||||||
} \
|
|
||||||
Vectorized<uint##bit##_t> operator!=( \
|
|
||||||
const Vectorized<uint##bit##_t>& other) const; \
|
|
||||||
Vectorized<uint##bit##_t> operator<( \
|
|
||||||
const Vectorized<uint##bit##_t>& other) const { \
|
|
||||||
return Vectorized<value_type>(vcltq_u##bit(values, other.values)); \
|
|
||||||
} \
|
|
||||||
Vectorized<uint##bit##_t> operator<=( \
|
|
||||||
const Vectorized<uint##bit##_t>& other) const { \
|
|
||||||
return Vectorized<value_type>(vcleq_u##bit(values, other.values)); \
|
|
||||||
} \
|
|
||||||
Vectorized<uint##bit##_t> operator>( \
|
|
||||||
const Vectorized<uint##bit##_t>& other) const { \
|
|
||||||
return Vectorized<value_type>(vcgtq_u##bit(values, other.values)); \
|
|
||||||
} \
|
|
||||||
Vectorized<uint##bit##_t> operator>=( \
|
|
||||||
const Vectorized<uint##bit##_t>& other) const { \
|
|
||||||
return Vectorized<value_type>(vcgeq_u##bit(values, other.values)); \
|
|
||||||
} \
|
|
||||||
Vectorized<uint##bit##_t> eq( \
|
|
||||||
const Vectorized<uint##bit##_t>& other) const; \
|
|
||||||
Vectorized<uint##bit##_t> ne( \
|
|
||||||
const Vectorized<uint##bit##_t>& other) const; \
|
|
||||||
Vectorized<uint##bit##_t> gt( \
|
|
||||||
const Vectorized<uint##bit##_t>& other) const; \
|
|
||||||
Vectorized<uint##bit##_t> ge( \
|
|
||||||
const Vectorized<uint##bit##_t>& other) const; \
|
|
||||||
Vectorized<uint##bit##_t> lt( \
|
|
||||||
const Vectorized<uint##bit##_t>& other) const; \
|
|
||||||
Vectorized<uint##bit##_t> le( \
|
|
||||||
const Vectorized<uint##bit##_t>& other) const; \
|
|
||||||
}; \
|
|
||||||
template <> \
|
|
||||||
Vectorized<uint##bit##_t> inline operator+( \
|
|
||||||
const Vectorized<uint##bit##_t>& a, \
|
|
||||||
const Vectorized<uint##bit##_t>& b) { \
|
|
||||||
return vaddq_u##bit(a, b); \
|
|
||||||
} \
|
|
||||||
template <> \
|
|
||||||
Vectorized<uint##bit##_t> inline operator-( \
|
|
||||||
const Vectorized<uint##bit##_t>& a, \
|
|
||||||
const Vectorized<uint##bit##_t>& b) { \
|
|
||||||
return vsubq_u##bit(a, b); \
|
|
||||||
} \
|
|
||||||
template <> \
|
|
||||||
Vectorized<uint##bit##_t> inline operator&( \
|
|
||||||
const Vectorized<uint##bit##_t>& a, \
|
|
||||||
const Vectorized<uint##bit##_t>& b) { \
|
|
||||||
return vandq_u##bit(a, b); \
|
|
||||||
} \
|
|
||||||
template <> \
|
|
||||||
Vectorized<uint##bit##_t> inline operator|( \
|
|
||||||
const Vectorized<uint##bit##_t>& a, \
|
|
||||||
const Vectorized<uint##bit##_t>& b) { \
|
|
||||||
return vorrq_u##bit(a, b); \
|
|
||||||
} \
|
|
||||||
template <> \
|
|
||||||
Vectorized<uint##bit##_t> inline operator^( \
|
|
||||||
const Vectorized<uint##bit##_t>& a, \
|
|
||||||
const Vectorized<uint##bit##_t>& b) { \
|
|
||||||
return veorq_u##bit(a, b); \
|
|
||||||
} \
|
|
||||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::eq( \
|
|
||||||
const Vectorized<uint##bit##_t>& other) const { \
|
|
||||||
return (*this == other) & Vectorized<uint##bit##_t>(1); \
|
|
||||||
} \
|
|
||||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ne( \
|
|
||||||
const Vectorized<uint##bit##_t>& other) const { \
|
|
||||||
return (*this != other) & Vectorized<uint##bit##_t>(1); \
|
|
||||||
} \
|
|
||||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::gt( \
|
|
||||||
const Vectorized<uint##bit##_t>& other) const { \
|
|
||||||
return (*this > other) & Vectorized<uint##bit##_t>(1); \
|
|
||||||
} \
|
|
||||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ge( \
|
|
||||||
const Vectorized<uint##bit##_t>& other) const { \
|
|
||||||
return (*this >= other) & Vectorized<uint##bit##_t>(1); \
|
|
||||||
} \
|
|
||||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::lt( \
|
|
||||||
const Vectorized<uint##bit##_t>& other) const { \
|
|
||||||
return (*this < other) & Vectorized<uint##bit##_t>(1); \
|
|
||||||
} \
|
|
||||||
Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::le( \
|
|
||||||
const Vectorized<uint##bit##_t>& other) const { \
|
|
||||||
return (*this <= other) & Vectorized<uint##bit##_t>(1); \
|
|
||||||
}
|
|
||||||
|
|
||||||
VEC_UINT_NEON_TEMPLATE(16, 8)
|
|
||||||
|
|
||||||
inline uint8_t Vectorized<uint8_t>::reduce_max() const {
|
|
||||||
return vmaxvq_u8(values);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<uint8_t> inline operator*(
|
|
||||||
const Vectorized<uint8_t>& a,
|
|
||||||
const Vectorized<uint8_t>& b) {
|
|
||||||
return vmulq_u8(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
inline Vectorized<uint8_t> operator~(const Vectorized<uint8_t>& a) {
|
|
||||||
return vmvnq_u8(a);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline Vectorized<uint8_t> Vectorized<uint8_t>::operator!=(
|
|
||||||
const Vectorized<uint8_t>& other) const {
|
|
||||||
return ~(*this == other);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<uint8_t> inline minimum(
|
|
||||||
const Vectorized<uint8_t>& a,
|
|
||||||
const Vectorized<uint8_t>& b) {
|
|
||||||
return vminq_u8(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<uint8_t> inline maximum(
|
|
||||||
const Vectorized<uint8_t>& a,
|
|
||||||
const Vectorized<uint8_t>& b) {
|
|
||||||
return vmaxq_u8(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <uint64_t mask>
|
|
||||||
Vectorized<uint8_t> Vectorized<uint8_t>::blend(
|
|
||||||
const Vectorized<uint8_t>& a,
|
|
||||||
const Vectorized<uint8_t>& b) {
|
|
||||||
// Build an array of flags: each bit of element is 1 if the corresponding bit
|
|
||||||
// in 'mask' is set, 0 otherwise.
|
|
||||||
uint8x16_t maskArray = {
|
|
||||||
(mask & 1LL) ? 0xFF : 0,
|
|
||||||
(mask & 2LL) ? 0xFF : 0,
|
|
||||||
(mask & 4LL) ? 0xFF : 0,
|
|
||||||
(mask & 8LL) ? 0xFF : 0,
|
|
||||||
(mask & 16LL) ? 0xFF : 0,
|
|
||||||
(mask & 32LL) ? 0xFF : 0,
|
|
||||||
(mask & 64LL) ? 0xFF : 0,
|
|
||||||
(mask & 128LL) ? 0xFF : 0,
|
|
||||||
(mask & 256LL) ? 0xFF : 0,
|
|
||||||
(mask & 512LL) ? 0xFF : 0,
|
|
||||||
(mask & 1024LL) ? 0xFF : 0,
|
|
||||||
(mask & 2048LL) ? 0xFF : 0,
|
|
||||||
(mask & 4096LL) ? 0xFF : 0,
|
|
||||||
(mask & 8192LL) ? 0xFF : 0,
|
|
||||||
(mask & 16384LL) ? 0xFF : 0,
|
|
||||||
(mask & 32768LL) ? 0xFF : 0};
|
|
||||||
// Use BSL to select elements from b where the mask is 1, else from a
|
|
||||||
return vbslq_u8(maskArray, b.values, a.values);
|
|
||||||
}
|
|
||||||
|
|
||||||
#define VEC_UINT_NEON_OPS(vl, bit) \
|
|
||||||
inline Vectorized<uint##bit##_t>::Vectorized(uint##bit##_t val) { \
|
|
||||||
values = vdupq_n_u##bit(val); \
|
|
||||||
} \
|
|
||||||
inline Vectorized<uint##bit##_t> Vectorized<uint##bit##_t>::loadu( \
|
|
||||||
const void* ptr, uint64_t count) { \
|
|
||||||
if (count == size()) { \
|
|
||||||
return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(ptr)); \
|
|
||||||
} else { \
|
|
||||||
__at_align__ uint##bit##_t tmp_values[size()]; \
|
|
||||||
for (const auto i : c10::irange(size())) { \
|
|
||||||
tmp_values[i] = 0; \
|
|
||||||
} \
|
|
||||||
std::memcpy( \
|
|
||||||
tmp_values, \
|
|
||||||
reinterpret_cast<const uint##bit##_t*>(ptr), \
|
|
||||||
count * sizeof(uint##bit##_t)); \
|
|
||||||
return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(tmp_values)); \
|
|
||||||
} \
|
|
||||||
} \
|
|
||||||
inline void Vectorized<uint##bit##_t>::store(void* ptr, uint64_t count) \
|
|
||||||
const { \
|
|
||||||
if (count == size()) { \
|
|
||||||
vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(ptr), values); \
|
|
||||||
} else { \
|
|
||||||
uint##bit##_t tmp_values[size()]; \
|
|
||||||
vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(tmp_values), values); \
|
|
||||||
std::memcpy(ptr, tmp_values, count * sizeof(uint##bit##_t)); \
|
|
||||||
} \
|
|
||||||
}
|
|
||||||
|
|
||||||
VEC_UINT_NEON_OPS(16, 8)
|
|
||||||
|
|
||||||
template <typename step_t>
|
|
||||||
inline Vectorized<uint8_t> Vectorized<uint8_t>::arange(
|
|
||||||
uint8_t base,
|
|
||||||
step_t step) {
|
|
||||||
const Vectorized<uint8_t> base_vec(base);
|
|
||||||
const Vectorized<uint8_t> step_vec(step);
|
|
||||||
const uint8x16_t step_sizes = {
|
|
||||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
|
|
||||||
return vmlaq_u8(base_vec, step_sizes, step_vec);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<uint8_t> inline operator>>(
|
|
||||||
const Vectorized<uint8_t>& a,
|
|
||||||
const Vectorized<uint8_t>& b) {
|
|
||||||
uint8x16_t x = a;
|
|
||||||
uint8x16_t bound = vdupq_n_u8(8);
|
|
||||||
uint8x16_t z = vminq_u8(b, bound);
|
|
||||||
return x >> z;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<uint8_t> inline operator<<(
|
|
||||||
const Vectorized<uint8_t>& a,
|
|
||||||
const Vectorized<uint8_t>& b) {
|
|
||||||
uint8x16_t bound = vdupq_n_u8(8);
|
|
||||||
uint8x16_t z = vminq_u8(b, bound);
|
|
||||||
return vshlq_u8(a, vreinterpretq_s8_u8(z));
|
|
||||||
}
|
|
||||||
|
|
||||||
inline Vectorized<uint8_t> Vectorized<uint8_t>::set(
|
|
||||||
const Vectorized<uint8_t>& a,
|
|
||||||
const Vectorized<uint8_t>& b,
|
|
||||||
uint64_t count) {
|
|
||||||
if (count == 0) {
|
|
||||||
return a;
|
|
||||||
} else if (count >= 16) {
|
|
||||||
return b;
|
|
||||||
} else {
|
|
||||||
// Build an array of flags: each bit of element is 1 if the corresponding
|
|
||||||
// bit in 'mask' is set, 0 otherwise.
|
|
||||||
uint8x16_t maskArray = {
|
|
||||||
static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0),
|
|
||||||
static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0),
|
|
||||||
0};
|
|
||||||
|
|
||||||
// Use BSL to select elements from b where the mask is 1, else from a
|
|
||||||
return vbslq_u8(maskArray, b.values, a.values);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<uint8_t> inline operator/(
|
|
||||||
const Vectorized<uint8_t>& a,
|
|
||||||
const Vectorized<uint8_t>& b) {
|
|
||||||
uint8x16_t x = a;
|
|
||||||
uint8x16_t y = b;
|
|
||||||
return x / y;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<uint8_t> inline clamp(
|
|
||||||
const Vectorized<uint8_t>& a,
|
|
||||||
const Vectorized<uint8_t>& min,
|
|
||||||
const Vectorized<uint8_t>& max) {
|
|
||||||
return minimum(max, maximum(min, a));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<uint8_t> inline clamp_max(
|
|
||||||
const Vectorized<uint8_t>& a,
|
|
||||||
const Vectorized<uint8_t>& max) {
|
|
||||||
return minimum(max, a);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Vectorized<uint8_t> inline clamp_min(
|
|
||||||
const Vectorized<uint8_t>& a,
|
|
||||||
const Vectorized<uint8_t>& min) {
|
|
||||||
return maximum(min, a);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace CPU_CAPABILITY
|
|
||||||
} // namespace at::vec
|
|
||||||
@ -1740,7 +1740,7 @@ Vectorized<int16_t> inline shift_256_16(
|
|||||||
|
|
||||||
// Control masks for shuffle operation, treating 256 bits as an
|
// Control masks for shuffle operation, treating 256 bits as an
|
||||||
// array of 16-bit elements, and considering pairs of neighboring
|
// array of 16-bit elements, and considering pairs of neighboring
|
||||||
// elements. Specifically, a mask named "ctl_M_N" (M,N in [0,1], and
|
// elements. Specifially, a mask named "ctl_M_N" (M,N in [0,1], and
|
||||||
// M!=N) is set so that shuffle will move element with index M from
|
// M!=N) is set so that shuffle will move element with index M from
|
||||||
// input pair into element with index N in output pair, and element
|
// input pair into element with index N in output pair, and element
|
||||||
// with index M in output pair will be set to all 0s.
|
// with index M in output pair will be set to all 0s.
|
||||||
@ -1875,7 +1875,7 @@ Vectorized<T> inline shift_256_8(
|
|||||||
|
|
||||||
// Control masks for shuffle operation, treating 256 bits as an
|
// Control masks for shuffle operation, treating 256 bits as an
|
||||||
// array of 8-bit elements, and considering quadruples of
|
// array of 8-bit elements, and considering quadruples of
|
||||||
// neighboring elements. Specifically, a mask named "ctl_M_N" (M,N
|
// neighboring elements. Specifially, a mask named "ctl_M_N" (M,N
|
||||||
// in [0,1,2,3], and M!=N) is set so that shuffle will move element
|
// in [0,1,2,3], and M!=N) is set so that shuffle will move element
|
||||||
// with index M from input quadruple into element with index N in
|
// with index M from input quadruple into element with index N in
|
||||||
// output quadruple, and other elements in output quadruple will be
|
// output quadruple, and other elements in output quadruple will be
|
||||||
|
|||||||
@ -1377,7 +1377,7 @@ Vectorized<c10::quint8> inline maximum(
|
|||||||
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
|
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
|
||||||
std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
|
std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
|
||||||
at::vec::Vectorized<int8_t> src) {
|
at::vec::Vectorized<int8_t> src) {
|
||||||
auto s8x8 = vget_low_s8(src);
|
auto s8x8 = vld1_s8(src.operator const int8_t*());
|
||||||
auto s16x8 = vmovl_s8(s8x8);
|
auto s16x8 = vmovl_s8(s8x8);
|
||||||
|
|
||||||
auto s32x4_hi = vmovl_s16(vget_high_s16(s16x8));
|
auto s32x4_hi = vmovl_s16(vget_high_s16(s16x8));
|
||||||
@ -1390,7 +1390,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
|
|||||||
|
|
||||||
std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
|
std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
|
||||||
at::vec::Vectorized<uint8_t> src) {
|
at::vec::Vectorized<uint8_t> src) {
|
||||||
auto u8x8 = vget_low_u8(src);
|
auto u8x8 = vld1_u8(src.operator const uint8_t*());
|
||||||
auto u16x8 = vmovl_u8(u8x8);
|
auto u16x8 = vmovl_u8(u8x8);
|
||||||
auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8));
|
auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8));
|
||||||
auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
|
auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
|
||||||
@ -1402,7 +1402,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
|
|||||||
|
|
||||||
Vectorized<float> inline convert_int8_half_register_to_float(
|
Vectorized<float> inline convert_int8_half_register_to_float(
|
||||||
at::vec::Vectorized<int8_t> src) {
|
at::vec::Vectorized<int8_t> src) {
|
||||||
auto s8x8 = vget_low_s8(src);
|
auto s8x8 = vld1_s8(src.operator const int8_t*());
|
||||||
auto s16x8 = vmovl_s8(s8x8);
|
auto s16x8 = vmovl_s8(s8x8);
|
||||||
|
|
||||||
auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8));
|
auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8));
|
||||||
@ -1412,7 +1412,7 @@ Vectorized<float> inline convert_int8_half_register_to_float(
|
|||||||
|
|
||||||
Vectorized<float> inline convert_int8_half_register_to_float(
|
Vectorized<float> inline convert_int8_half_register_to_float(
|
||||||
at::vec::Vectorized<uint8_t> src) {
|
at::vec::Vectorized<uint8_t> src) {
|
||||||
auto u8x8 = vget_low_u8(src);
|
auto u8x8 = vld1_u8(src.operator const uint8_t*());
|
||||||
auto u16x8 = vmovl_u8(u8x8);
|
auto u16x8 = vmovl_u8(u8x8);
|
||||||
auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
|
auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
|
||||||
|
|
||||||
|
|||||||
@ -143,7 +143,7 @@ class Vectorized<double> {
|
|||||||
const Vectorized<double>& a,
|
const Vectorized<double>& a,
|
||||||
const Vectorized<double>& b,
|
const Vectorized<double>& b,
|
||||||
const Vectorized<double>& mask) {
|
const Vectorized<double>& mask) {
|
||||||
// the mask used here returned by comparison of vec256
|
// the mask used here returned by comparision of vec256
|
||||||
|
|
||||||
return {
|
return {
|
||||||
vec_sel(a._vec0, b._vec0, mask._vecb0),
|
vec_sel(a._vec0, b._vec0, mask._vecb0),
|
||||||
|
|||||||
@ -142,7 +142,7 @@ class Vectorized<float> {
|
|||||||
const Vectorized<float>& a,
|
const Vectorized<float>& a,
|
||||||
const Vectorized<float>& b,
|
const Vectorized<float>& b,
|
||||||
const Vectorized<float>& mask) {
|
const Vectorized<float>& mask) {
|
||||||
// the mask used here returned by comparison of vec256
|
// the mask used here returned by comparision of vec256
|
||||||
// assuming this we can use the same mask directly with vec_sel
|
// assuming this we can use the same mask directly with vec_sel
|
||||||
return {
|
return {
|
||||||
vec_sel(a._vec0, b._vec0, mask._vecb0),
|
vec_sel(a._vec0, b._vec0, mask._vecb0),
|
||||||
|
|||||||
@ -202,7 +202,7 @@ class Vectorized<int16_t> {
|
|||||||
const Vectorized<int16_t>& a,
|
const Vectorized<int16_t>& a,
|
||||||
const Vectorized<int16_t>& b,
|
const Vectorized<int16_t>& b,
|
||||||
const Vectorized<int16_t>& mask) {
|
const Vectorized<int16_t>& mask) {
|
||||||
// the mask used here returned by comparison of vec256
|
// the mask used here returned by comparision of vec256
|
||||||
// assuming this we can use the same mask directly with vec_sel
|
// assuming this we can use the same mask directly with vec_sel
|
||||||
// warning intel style mask will not work properly
|
// warning intel style mask will not work properly
|
||||||
return {
|
return {
|
||||||
|
|||||||
@ -155,7 +155,7 @@ class Vectorized<int32_t> {
|
|||||||
const Vectorized<int32_t>& a,
|
const Vectorized<int32_t>& a,
|
||||||
const Vectorized<int32_t>& b,
|
const Vectorized<int32_t>& b,
|
||||||
const Vectorized<int32_t>& mask) {
|
const Vectorized<int32_t>& mask) {
|
||||||
// the mask used here returned by comparison of vec256
|
// the mask used here returned by comparision of vec256
|
||||||
// assuming this we can use the same mask directly with vec_sel
|
// assuming this we can use the same mask directly with vec_sel
|
||||||
// warning intel style mask will not work properly
|
// warning intel style mask will not work properly
|
||||||
return {
|
return {
|
||||||
|
|||||||
@ -119,7 +119,7 @@ class Vectorized<int64_t> {
|
|||||||
const Vectorized<int64_t>& a,
|
const Vectorized<int64_t>& a,
|
||||||
const Vectorized<int64_t>& b,
|
const Vectorized<int64_t>& b,
|
||||||
const Vectorized<int64_t>& mask) {
|
const Vectorized<int64_t>& mask) {
|
||||||
// the mask used here returned by comparison of vec256
|
// the mask used here returned by comparision of vec256
|
||||||
|
|
||||||
return {
|
return {
|
||||||
vec_sel(a._vec0, b._vec0, mask._vecb0),
|
vec_sel(a._vec0, b._vec0, mask._vecb0),
|
||||||
|
|||||||
@ -397,7 +397,7 @@ inline Vectorized<bool> operator&&(
|
|||||||
const __m512i* other_ = reinterpret_cast<const __m512i*>(other.as_bytes());
|
const __m512i* other_ = reinterpret_cast<const __m512i*>(other.as_bytes());
|
||||||
__m512i out = _mm512_and_si512(*self_, *other_);
|
__m512i out = _mm512_and_si512(*self_, *other_);
|
||||||
Vectorized<bool> ret;
|
Vectorized<bool> ret;
|
||||||
// We do not have a constructor that takes __m512i, so we need to memcpy
|
// We do not have a constructer that takes __m512i, so we need to memcpy
|
||||||
std::memcpy(ret, &out, ret.size() * sizeof(bool));
|
std::memcpy(ret, &out, ret.size() * sizeof(bool));
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1852,7 +1852,7 @@ Vectorized<T> inline shift_512_8(
|
|||||||
|
|
||||||
// Control masks for shuffle operation, treating 512 bits as an
|
// Control masks for shuffle operation, treating 512 bits as an
|
||||||
// array of 8-bit elements, and considering pairs of neighboring
|
// array of 8-bit elements, and considering pairs of neighboring
|
||||||
// elements. Specifically, a mask named "ctl_M_N" (M,N in [0,1], and
|
// elements. Specifially, a mask named "ctl_M_N" (M,N in [0,1], and
|
||||||
// M!=N) is set so that shuffle will move element with index M from
|
// M!=N) is set so that shuffle will move element with index M from
|
||||||
// input pair into element with index N in output pair, and element
|
// input pair into element with index N in output pair, and element
|
||||||
// with index M in output pair will be set to all 0s.
|
// with index M in output pair will be set to all 0s.
|
||||||
|
|||||||
@ -634,7 +634,7 @@ struct Vectorized {
|
|||||||
}
|
}
|
||||||
Vectorized<T> neg() const {
|
Vectorized<T> neg() const {
|
||||||
// NB: the trailing return type is needed because we need to coerce the
|
// NB: the trailing return type is needed because we need to coerce the
|
||||||
// return value back to T in the case of unary operator- incurring a
|
// return value back to T in the case of unary operator- incuring a
|
||||||
// promotion
|
// promotion
|
||||||
return map([](T x) -> T { return -x; });
|
return map([](T x) -> T { return -x; });
|
||||||
}
|
}
|
||||||
|
|||||||
@ -16,8 +16,6 @@
|
|||||||
#include <c10/util/irange.h>
|
#include <c10/util/irange.h>
|
||||||
#include <c10/core/ScalarType.h>
|
#include <c10/core/ScalarType.h>
|
||||||
|
|
||||||
#include <ATen/cuda/detail/BLASConstants.h>
|
|
||||||
|
|
||||||
#ifdef USE_ROCM
|
#ifdef USE_ROCM
|
||||||
#include <c10/cuda/CUDAStream.h>
|
#include <c10/cuda/CUDAStream.h>
|
||||||
#include <hipblaslt/hipblaslt-ext.hpp>
|
#include <hipblaslt/hipblaslt-ext.hpp>
|
||||||
@ -1956,15 +1954,13 @@ void scaled_gemm(
|
|||||||
const void *result_scale_ptr,
|
const void *result_scale_ptr,
|
||||||
int64_t result_ld,
|
int64_t result_ld,
|
||||||
ScalarType result_dtype,
|
ScalarType result_dtype,
|
||||||
bool use_fast_accum,
|
bool use_fast_accum) {
|
||||||
const std::optional<Tensor>& alpha) {
|
// Note: see `cublasCommonArgs` for various non-intuitive manupulations
|
||||||
// Note: see `cublasCommonArgs` for various non-intuitive manipulations
|
|
||||||
// of input arguments to this function.
|
// of input arguments to this function.
|
||||||
const auto computeType = CUBLAS_COMPUTE_32F;
|
const auto computeType = CUBLAS_COMPUTE_32F;
|
||||||
const auto scaleType = CUDA_R_32F;
|
const auto scaleType = CUDA_R_32F;
|
||||||
// Note: alpha_val may change later depending on user-passed argument
|
const float alpha_val = 1.0;
|
||||||
float alpha_val = 1.0;
|
const float beta_val = 0.0;
|
||||||
float beta_val = 0.0;
|
|
||||||
CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
|
CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
|
||||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa));
|
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa));
|
||||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
|
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
|
||||||
@ -2035,33 +2031,6 @@ void scaled_gemm(
|
|||||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_BIAS);
|
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_BIAS);
|
||||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
|
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle user-passed alpha
|
|
||||||
float *alpha_ptr = &alpha_val;
|
|
||||||
float *beta_ptr = &beta_val;
|
|
||||||
|
|
||||||
if (alpha.has_value()) {
|
|
||||||
auto& a = alpha.value();
|
|
||||||
|
|
||||||
// if device-tensor
|
|
||||||
if (a.is_cuda()) {
|
|
||||||
// NOTE: there are lifetime requirements on device-side pointers for alpha/beta -- the value must be
|
|
||||||
// valid & correct until the cublas call finishes (not is scheduled like host-side values). Thus
|
|
||||||
// we need to use allocations for alpha/beta that have some guarantees on lifetime - a statically
|
|
||||||
// managed 4B buffer for alpha that we'll copy the passed alpha value into, and constant memory
|
|
||||||
// for beta respectively.
|
|
||||||
float *user_alpha_ptr = at::cuda::detail::get_user_alpha_ptr();
|
|
||||||
at::Tensor user_alpha = at::from_blob(user_alpha_ptr, {1}, TensorOptions().device(kCUDA).dtype(kFloat));
|
|
||||||
user_alpha.copy_(a);
|
|
||||||
// Tell cublasLt we're using device-side pointers for alpha/beta
|
|
||||||
auto pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
|
|
||||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_POINTER_MODE, pointer_mode);
|
|
||||||
alpha_ptr = user_alpha.data_ptr<float>();
|
|
||||||
beta_ptr = at::cuda::detail::get_cublas_device_zero();
|
|
||||||
} else {
|
|
||||||
alpha_val = a.item<float>();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// For other data types, use the get_scale_mode function based on scaling type
|
// For other data types, use the get_scale_mode function based on scaling type
|
||||||
// The SCALE_MODE attrs only exist in cuBLAS 12.8+/ROCm 7.0 or in recent hipblaslt,
|
// The SCALE_MODE attrs only exist in cuBLAS 12.8+/ROCm 7.0 or in recent hipblaslt,
|
||||||
// but we must invoke get_scale_mode anyways to trigger the version checks.
|
// but we must invoke get_scale_mode anyways to trigger the version checks.
|
||||||
@ -2079,7 +2048,6 @@ void scaled_gemm(
|
|||||||
cublasLtMatmulHeuristicResult_t heuristicResult = {};
|
cublasLtMatmulHeuristicResult_t heuristicResult = {};
|
||||||
int returnedResult = 0;
|
int returnedResult = 0;
|
||||||
cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
|
cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
|
||||||
|
|
||||||
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
|
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
|
||||||
ltHandle,
|
ltHandle,
|
||||||
computeDesc.descriptor(),
|
computeDesc.descriptor(),
|
||||||
@ -2120,10 +2088,10 @@ void scaled_gemm(
|
|||||||
auto is_valid_status = hipblaslt_ext::matmulIsAlgoSupported(
|
auto is_valid_status = hipblaslt_ext::matmulIsAlgoSupported(
|
||||||
ltHandle,
|
ltHandle,
|
||||||
computeDesc.descriptor(),
|
computeDesc.descriptor(),
|
||||||
alpha_ptr,
|
&alpha_val,
|
||||||
Adesc.descriptor(),
|
Adesc.descriptor(),
|
||||||
Bdesc.descriptor(),
|
Bdesc.descriptor(),
|
||||||
beta_ptr,
|
&beta_val,
|
||||||
Cdesc.descriptor(),
|
Cdesc.descriptor(),
|
||||||
Ddesc.descriptor(),
|
Ddesc.descriptor(),
|
||||||
all_algos[i].algo,
|
all_algos[i].algo,
|
||||||
@ -2142,14 +2110,17 @@ void scaled_gemm(
|
|||||||
cublasStatus_t cublasStatus = cublasLtMatmul(
|
cublasStatus_t cublasStatus = cublasLtMatmul(
|
||||||
ltHandle,
|
ltHandle,
|
||||||
computeDesc.descriptor(),
|
computeDesc.descriptor(),
|
||||||
alpha_ptr,
|
&alpha_val,
|
||||||
mat1_ptr,
|
mat1_ptr,
|
||||||
Adesc.descriptor(),
|
Adesc.descriptor(),
|
||||||
mat2_ptr,
|
mat2_ptr,
|
||||||
Bdesc.descriptor(),
|
Bdesc.descriptor(),
|
||||||
beta_ptr,
|
&beta_val,
|
||||||
// NOTE: always use result_ptr here, because cuBLASLt w/device beta=0 can't handle nullptr either
|
#ifdef USE_ROCM
|
||||||
result_ptr, // unused, since beta_val is 0, but hipblaslt can't handle nullptr
|
result_ptr, // unused, since beta_val is 0, but hipblaslt can't handle nullptr
|
||||||
|
#else
|
||||||
|
nullptr,
|
||||||
|
#endif // ifdef USE_ROCM
|
||||||
Cdesc.descriptor(),
|
Cdesc.descriptor(),
|
||||||
result_ptr,
|
result_ptr,
|
||||||
Ddesc.descriptor(),
|
Ddesc.descriptor(),
|
||||||
|
|||||||
@ -161,8 +161,7 @@ void scaled_gemm(
|
|||||||
const void* result_scale_ptr,
|
const void* result_scale_ptr,
|
||||||
int64_t result_ld,
|
int64_t result_ld,
|
||||||
ScalarType result_dtype,
|
ScalarType result_dtype,
|
||||||
bool use_fast_accum,
|
bool use_fast_accum);
|
||||||
const std::optional<Tensor>& alpha);
|
|
||||||
|
|
||||||
#define CUDABLAS_BGEMM_ARGTYPES(Dtype) CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype)
|
#define CUDABLAS_BGEMM_ARGTYPES(Dtype) CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype)
|
||||||
|
|
||||||
|
|||||||
@ -325,9 +325,9 @@ uint64_t CUDAGeneratorImpl::seed() {
|
|||||||
*/
|
*/
|
||||||
c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
|
c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
|
||||||
// The RNG state comprises the seed, and an offset used for Philox.
|
// The RNG state comprises the seed, and an offset used for Philox.
|
||||||
constexpr size_t seed_size = sizeof(uint64_t);
|
static const size_t seed_size = sizeof(uint64_t);
|
||||||
constexpr size_t offset_size = sizeof(int64_t);
|
static const size_t offset_size = sizeof(int64_t);
|
||||||
constexpr size_t total_size = seed_size + offset_size;
|
static const size_t total_size = seed_size + offset_size;
|
||||||
|
|
||||||
auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
|
auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
|
||||||
auto rng_state = state_tensor.data_ptr<uint8_t>();
|
auto rng_state = state_tensor.data_ptr<uint8_t>();
|
||||||
@ -346,9 +346,9 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
|
|||||||
* and size of the internal state.
|
* and size of the internal state.
|
||||||
*/
|
*/
|
||||||
void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
|
void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
|
||||||
constexpr size_t seed_size = sizeof(uint64_t);
|
static const size_t seed_size = sizeof(uint64_t);
|
||||||
constexpr size_t offset_size = sizeof(int64_t);
|
static const size_t offset_size = sizeof(int64_t);
|
||||||
constexpr size_t total_size = seed_size + offset_size;
|
static const size_t total_size = seed_size + offset_size;
|
||||||
|
|
||||||
detail::check_rng_state(new_state);
|
detail::check_rng_state(new_state);
|
||||||
|
|
||||||
|
|||||||
@ -168,9 +168,11 @@ void CUDAGraph::instantiate() {
|
|||||||
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597
|
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597
|
||||||
// cudaGraphInstantiateWithFlags
|
// cudaGraphInstantiateWithFlags
|
||||||
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1ga2c652a24ba93e52b99a47bec0888233
|
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1ga2c652a24ba93e52b99a47bec0888233
|
||||||
|
#if !defined(USE_ROCM) || ROCM_VERSION >= 60200
|
||||||
int version = 0;
|
int version = 0;
|
||||||
AT_CUDA_CHECK(cudaDriverGetVersion(&version));
|
AT_CUDA_CHECK(cudaDriverGetVersion(&version));
|
||||||
if (version < 11040) {
|
if (version < 11040) {
|
||||||
|
#endif
|
||||||
// Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people,
|
// Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people,
|
||||||
// who prefer not to report error message through these arguments moving forward
|
// who prefer not to report error message through these arguments moving forward
|
||||||
// (they prefer return value, or errors on api calls internal to the capture)
|
// (they prefer return value, or errors on api calls internal to the capture)
|
||||||
@ -181,11 +183,13 @@ void CUDAGraph::instantiate() {
|
|||||||
#endif
|
#endif
|
||||||
//Since ROCm 6.2, we want to go down this path as hipGraphExecDestroy in the destructor will not immediately free the memory.
|
//Since ROCm 6.2, we want to go down this path as hipGraphExecDestroy in the destructor will not immediately free the memory.
|
||||||
//It will wait for the next sync operation. cudaGraphInstantiateFlagAutoFreeOnLaunch will add async frees after graph launch.
|
//It will wait for the next sync operation. cudaGraphInstantiateFlagAutoFreeOnLaunch will add async frees after graph launch.
|
||||||
|
#if !defined(USE_ROCM) || ROCM_VERSION >= 60200
|
||||||
} else {
|
} else {
|
||||||
AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
|
AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
|
||||||
graph_,
|
graph_,
|
||||||
cudaGraphInstantiateFlagAutoFreeOnLaunch));
|
cudaGraphInstantiateFlagAutoFreeOnLaunch));
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
has_graph_exec_ = true;
|
has_graph_exec_ = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -307,7 +311,7 @@ CUDAGraph::~CUDAGraph() {
|
|||||||
// There are recent HIP changes where hipGraphExecDestroy doesn't immediately free memory.
|
// There are recent HIP changes where hipGraphExecDestroy doesn't immediately free memory.
|
||||||
// They wait for next sync point in order to free the memory, this is to ensure that all
|
// They wait for next sync point in order to free the memory, this is to ensure that all
|
||||||
// hipGraphLaunch are finished before we release any memory. This feature was enabled in rocm6.2.
|
// hipGraphLaunch are finished before we release any memory. This feature was enabled in rocm6.2.
|
||||||
// We need to ensure all async operations finish before deleting the object.
|
// We need to ensure all async opreations finish before deleting the object.
|
||||||
#if (defined(USE_ROCM) && ROCM_VERSION >= 60200)
|
#if (defined(USE_ROCM) && ROCM_VERSION >= 60200)
|
||||||
if (capture_dev_ != UNDEFINED_DEVICE) // check if capture_dev_ contains the real device id
|
if (capture_dev_ != UNDEFINED_DEVICE) // check if capture_dev_ contains the real device id
|
||||||
{
|
{
|
||||||
|
|||||||
@ -1,192 +0,0 @@
|
|||||||
#include <ATen/cuda/CUDAGreenContext.h>
|
|
||||||
|
|
||||||
namespace at::cuda {
|
|
||||||
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
|
|
||||||
#if CUDA_HAS_GREEN_CONTEXT
|
|
||||||
int driver_version;
|
|
||||||
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
|
|
||||||
TORCH_CHECK(
|
|
||||||
driver_version >= 12080, "cuda driver too old to use green context!");
|
|
||||||
CUcontext pctx = nullptr;
|
|
||||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
|
|
||||||
if (C10_UNLIKELY(!pctx)) {
|
|
||||||
TORCH_WARN(
|
|
||||||
"Attempted to create a green context but"
|
|
||||||
" there was no primary context! Creating a primary context...");
|
|
||||||
|
|
||||||
cudaFree(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
CUdevice device;
|
|
||||||
device_id_ = device_id;
|
|
||||||
C10_CUDA_DRIVER_CHECK(
|
|
||||||
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
|
|
||||||
|
|
||||||
// Get device resources
|
|
||||||
CUdevResource device_resource;
|
|
||||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
|
|
||||||
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
|
|
||||||
|
|
||||||
// Split resources
|
|
||||||
std::vector<CUdevResource> result(1);
|
|
||||||
auto result_data = result.data();
|
|
||||||
unsigned int nb_groups = 1;
|
|
||||||
CUdevResource remaining;
|
|
||||||
|
|
||||||
C10_CUDA_DRIVER_CHECK(
|
|
||||||
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
|
|
||||||
result_data,
|
|
||||||
&nb_groups,
|
|
||||||
&device_resource,
|
|
||||||
&remaining,
|
|
||||||
0, // default flags
|
|
||||||
num_sms));
|
|
||||||
|
|
||||||
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
|
|
||||||
|
|
||||||
// Generate resource descriptor
|
|
||||||
CUdevResourceDesc desc;
|
|
||||||
C10_CUDA_DRIVER_CHECK(
|
|
||||||
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
|
|
||||||
&desc, result_data, 1));
|
|
||||||
|
|
||||||
// Create green context
|
|
||||||
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
|
|
||||||
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
|
|
||||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
|
|
||||||
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
|
|
||||||
|
|
||||||
// Convert to regular context
|
|
||||||
C10_CUDA_DRIVER_CHECK(
|
|
||||||
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
|
|
||||||
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
|
|
||||||
#else
|
|
||||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<GreenContext> GreenContext::create(
|
|
||||||
uint32_t num_sms,
|
|
||||||
std::optional<uint32_t> device_id) {
|
|
||||||
#if CUDA_HAS_GREEN_CONTEXT
|
|
||||||
if (!device_id.has_value()) {
|
|
||||||
device_id = at::cuda::current_device();
|
|
||||||
}
|
|
||||||
return std::make_unique<GreenContext>(device_id.value(), num_sms);
|
|
||||||
#else
|
|
||||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
// Implement move operations
|
|
||||||
GreenContext::GreenContext(GreenContext&& other) noexcept{
|
|
||||||
#if CUDA_HAS_GREEN_CONTEXT
|
|
||||||
device_id_ = std::exchange(other.device_id_, -1);
|
|
||||||
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
|
|
||||||
context_ = std::exchange(other.context_, nullptr);
|
|
||||||
parent_stream_ = std::exchange(other.parent_stream_, nullptr);
|
|
||||||
#else
|
|
||||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{
|
|
||||||
#if CUDA_HAS_GREEN_CONTEXT
|
|
||||||
if (this != &other) {
|
|
||||||
// Clean up current resources
|
|
||||||
if (green_ctx_) {
|
|
||||||
CUcontext current = nullptr;
|
|
||||||
C10_CUDA_DRIVER_CHECK(
|
|
||||||
c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t));
|
|
||||||
if (current == context_) {
|
|
||||||
TORCH_CHECK(
|
|
||||||
false,
|
|
||||||
"attempting to overwrite current green ctx "
|
|
||||||
"when it is active!");
|
|
||||||
}
|
|
||||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Take ownership of other's resources
|
|
||||||
device_id_ = std::exchange(other.device_id_, -1);
|
|
||||||
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
|
|
||||||
context_ = std::exchange(other.context_, nullptr);
|
|
||||||
parent_stream_ = std::exchange(other.parent_stream_, nullptr);
|
|
||||||
}
|
|
||||||
return *this;
|
|
||||||
#else
|
|
||||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
GreenContext::~GreenContext() noexcept{
|
|
||||||
#if CUDA_HAS_GREEN_CONTEXT
|
|
||||||
C10_CUDA_DRIVER_CHECK(
|
|
||||||
c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
|
|
||||||
#else
|
|
||||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the underlying CUDA context
|
|
||||||
CUcontext GreenContext::getContext() const {
|
|
||||||
#if CUDA_HAS_GREEN_CONTEXT
|
|
||||||
return context_;
|
|
||||||
#else
|
|
||||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the underlying green context
|
|
||||||
#if CUDA_HAS_GREEN_CONTEXT
|
|
||||||
CUgreenCtx GreenContext::getGreenContext() const {
|
|
||||||
return green_ctx_;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Make this context current
|
|
||||||
void GreenContext::setContext() {
|
|
||||||
#if CUDA_HAS_GREEN_CONTEXT
|
|
||||||
auto current_stream = c10::cuda::getCurrentCUDAStream();
|
|
||||||
parent_stream_ = current_stream.stream();
|
|
||||||
|
|
||||||
at::cuda::CUDAEvent ev;
|
|
||||||
ev.record(current_stream);
|
|
||||||
|
|
||||||
CUcontext current = nullptr;
|
|
||||||
C10_CUDA_DRIVER_CHECK(
|
|
||||||
c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t));
|
|
||||||
if (!current) {
|
|
||||||
C10_CUDA_DRIVER_CHECK(
|
|
||||||
c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_));
|
|
||||||
} else {
|
|
||||||
C10_CUDA_DRIVER_CHECK(
|
|
||||||
c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_));
|
|
||||||
}
|
|
||||||
// currently hardcodes the new green context to use the default stream
|
|
||||||
// TODO(eqy): consider creating a new stream if e.g., it allows interop
|
|
||||||
// with CUDA Graph captures etc.
|
|
||||||
auto default_stream = c10::cuda::getDefaultCUDAStream();
|
|
||||||
ev.block(default_stream);
|
|
||||||
c10::cuda::setCurrentCUDAStream(default_stream);
|
|
||||||
#else
|
|
||||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
void GreenContext::popContext() {
|
|
||||||
#if CUDA_HAS_GREEN_CONTEXT
|
|
||||||
// see above note about stream being hardcoded to the default stream
|
|
||||||
at::cuda::CUDAEvent ev;
|
|
||||||
ev.record(c10::cuda::getCurrentCUDAStream());
|
|
||||||
CUcontext popped;
|
|
||||||
C10_CUDA_DRIVER_CHECK(
|
|
||||||
c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped));
|
|
||||||
TORCH_INTERNAL_ASSERT(
|
|
||||||
popped == context_, "expected popped context to be the current ctx");
|
|
||||||
ev.block(c10::cuda::getStreamFromExternal(parent_stream_, device_id_));
|
|
||||||
#else
|
|
||||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
} // namespace at::cuda
|
|
||||||
@ -1,53 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
#include <ATen/cuda/CUDAEvent.h>
|
|
||||||
|
|
||||||
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
|
|
||||||
#include <c10/cuda/driver_api.h>
|
|
||||||
#include <cuda.h>
|
|
||||||
#include <memory>
|
|
||||||
#include <stdexcept>
|
|
||||||
#include <vector>
|
|
||||||
#define CUDA_HAS_GREEN_CONTEXT 1
|
|
||||||
#else
|
|
||||||
#define CUDA_HAS_GREEN_CONTEXT 0
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace at::cuda {
|
|
||||||
|
|
||||||
class TORCH_CUDA_CPP_API GreenContext {
|
|
||||||
public:
|
|
||||||
GreenContext(uint32_t device_id, uint32_t num_sms);
|
|
||||||
|
|
||||||
static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id);
|
|
||||||
|
|
||||||
// Delete copy constructor and assignment
|
|
||||||
GreenContext(const GreenContext&) = delete;
|
|
||||||
GreenContext& operator=(const GreenContext&) = delete;
|
|
||||||
|
|
||||||
// Implement move operations
|
|
||||||
GreenContext(GreenContext&& other) noexcept;
|
|
||||||
GreenContext& operator=(GreenContext&& other) noexcept;
|
|
||||||
~GreenContext() noexcept;
|
|
||||||
|
|
||||||
// Get the underlying CUDA context
|
|
||||||
CUcontext getContext() const;
|
|
||||||
|
|
||||||
// Get the underlying green context
|
|
||||||
#if CUDA_HAS_GREEN_CONTEXT
|
|
||||||
CUgreenCtx getGreenContext() const;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Make this context current
|
|
||||||
void setContext();
|
|
||||||
|
|
||||||
void popContext();
|
|
||||||
|
|
||||||
private:
|
|
||||||
#if CUDA_HAS_GREEN_CONTEXT
|
|
||||||
int32_t device_id_ = -1;
|
|
||||||
CUgreenCtx green_ctx_ = nullptr;
|
|
||||||
CUcontext context_ = nullptr;
|
|
||||||
cudaStream_t parent_stream_ = nullptr;
|
|
||||||
#endif
|
|
||||||
};
|
|
||||||
} // namespace at::cuda
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user