mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-27 00:54:52 +08:00 
			
		
		
		
	Compare commits
	
		
			1 Commits
		
	
	
		
			ciflow/tru
			...
			annotate_f
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 98826fd37b | 
| @ -113,7 +113,6 @@ case "$tag" in | ||||
|     UCX_COMMIT=${_UCX_COMMIT} | ||||
|     UCC_COMMIT=${_UCC_COMMIT} | ||||
|     TRITON=yes | ||||
|     INSTALL_MINGW=yes | ||||
|     ;; | ||||
|   pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11) | ||||
|     CUDA_VERSION=13.0.0 | ||||
| @ -362,7 +361,6 @@ docker build \ | ||||
|        --build-arg "OPENBLAS=${OPENBLAS:-}" \ | ||||
|        --build-arg "SKIP_SCCACHE_INSTALL=${SKIP_SCCACHE_INSTALL:-}" \ | ||||
|        --build-arg "SKIP_LLVM_SRC_BUILD_INSTALL=${SKIP_LLVM_SRC_BUILD_INSTALL:-}" \ | ||||
|        --build-arg "INSTALL_MINGW=${INSTALL_MINGW:-}" \ | ||||
|        -f $(dirname ${DOCKERFILE})/Dockerfile \ | ||||
|        -t "$tmp_tag" \ | ||||
|        "$@" \ | ||||
|  | ||||
| @ -83,6 +83,10 @@ function build_cpython { | ||||
|         py_suffix=${py_ver::-1} | ||||
|         py_folder=$py_suffix | ||||
|     fi | ||||
|     # Update to rc2 due to https://github.com/python/cpython/commit/c72699086fe4 | ||||
|     if [ "$py_suffix" == "3.14.0" ]; then | ||||
|         py_suffix="3.14.0rc2" | ||||
|     fi | ||||
|     wget -q $PYTHON_DOWNLOAD_URL/$py_folder/Python-$py_suffix.tgz -O Python-$py_ver.tgz | ||||
|     do_cpython_build $py_ver Python-$py_suffix | ||||
|  | ||||
|  | ||||
| @ -1,10 +0,0 @@ | ||||
| #!/bin/bash | ||||
|  | ||||
| set -ex | ||||
|  | ||||
| # Install MinGW-w64 for Windows cross-compilation | ||||
| apt-get update | ||||
| apt-get install -y g++-mingw-w64-x86-64-posix | ||||
|  | ||||
| echo "MinGW-w64 installed successfully" | ||||
| x86_64-w64-mingw32-g++ --version | ||||
| @ -19,8 +19,8 @@ pip_install \ | ||||
|   transformers==4.36.2 | ||||
|  | ||||
| pip_install coloredlogs packaging | ||||
| pip_install onnxruntime==1.23.1 | ||||
| pip_install onnxscript==0.5.4 | ||||
| pip_install onnxruntime==1.23.0 | ||||
| pip_install onnxscript==0.5.3 | ||||
|  | ||||
| # Cache the transformers model to be used later by ONNX tests. We need to run the transformers | ||||
| # package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/ | ||||
|  | ||||
| @ -39,13 +39,9 @@ case ${DOCKER_TAG_PREFIX} in | ||||
|         DOCKER_GPU_BUILD_ARG="" | ||||
|         ;; | ||||
|     rocm*) | ||||
|         # we want the patch version of 7.0 instead | ||||
|         if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then | ||||
|             GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2" | ||||
|         fi | ||||
|         # we want the patch version of 6.4 instead | ||||
|         if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then | ||||
|             GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.4" | ||||
|             GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2" | ||||
|         fi | ||||
|         BASE_TARGET=rocm | ||||
|         GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete | ||||
|  | ||||
| @ -75,13 +75,9 @@ case ${image} in | ||||
|         DOCKERFILE_SUFFIX="_cuda_aarch64" | ||||
|         ;; | ||||
|     manylinux2_28-builder:rocm*) | ||||
|         # we want the patch version of 7.0 instead | ||||
|         if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then | ||||
|             GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2" | ||||
|         fi | ||||
|         # we want the patch version of 6.4 instead | ||||
|         if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then | ||||
|             GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.4" | ||||
|             GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2" | ||||
|         fi | ||||
|         TARGET=rocm_final | ||||
|         MANY_LINUX_VERSION="2_28" | ||||
|  | ||||
| @ -334,12 +334,12 @@ sympy==1.13.3 | ||||
| #Pinned versions: | ||||
| #test that import: | ||||
|  | ||||
| onnx==1.19.1 | ||||
| onnx==1.18.0 | ||||
| #Description: Required by onnx tests, and mypy and test_public_bindings.py when checking torch.onnx._internal | ||||
| #Pinned versions: | ||||
| #test that import: | ||||
|  | ||||
| onnxscript==0.5.4 | ||||
| onnxscript==0.5.3 | ||||
| #Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal | ||||
| #Pinned versions: | ||||
| #test that import: | ||||
|  | ||||
| @ -103,11 +103,6 @@ COPY ci_commit_pins/torchbench.txt torchbench.txt | ||||
| RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi | ||||
| RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt | ||||
|  | ||||
| ARG INSTALL_MINGW | ||||
| COPY ./common/install_mingw.sh install_mingw.sh | ||||
| RUN if [ -n "${INSTALL_MINGW}" ]; then bash ./install_mingw.sh; fi | ||||
| RUN rm install_mingw.sh | ||||
|  | ||||
| ARG TRITON | ||||
| ARG TRITON_CPU | ||||
|  | ||||
|  | ||||
| @ -57,8 +57,8 @@ def clone_external_repo(target: str, repo: str, dst: str = "", update_submodules | ||||
|         logger.info("Successfully cloned %s", target) | ||||
|         return r, commit | ||||
|  | ||||
|     except GitCommandError: | ||||
|         logger.exception("Git operation failed") | ||||
|     except GitCommandError as e: | ||||
|         logger.error("Git operation failed: %s", e) | ||||
|         raise | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -6,7 +6,7 @@ dependencies = [ | ||||
|     "GitPython==3.1.45", | ||||
|     "docker==7.1.0", | ||||
|     "pytest==7.3.2", | ||||
|     "uv==0.9.5" | ||||
|     "uv==0.8.6" | ||||
| ] | ||||
|  | ||||
| [tool.setuptools] | ||||
|  | ||||
| @ -485,22 +485,6 @@ test_inductor_aoti() { | ||||
|   /usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference cpp/test_vec_half_AVX2 -dist=loadfile | ||||
| } | ||||
|  | ||||
| test_inductor_aoti_cross_compile_for_windows() { | ||||
|  | ||||
|   TEST_REPORTS_DIR=$(pwd)/test/test-reports | ||||
|   mkdir -p "$TEST_REPORTS_DIR" | ||||
|  | ||||
|   # Set WINDOWS_CUDA_HOME environment variable | ||||
|   WINDOWS_CUDA_HOME="$(pwd)/win-torch-wheel-extracted" | ||||
|   export WINDOWS_CUDA_HOME | ||||
|  | ||||
|   echo "WINDOWS_CUDA_HOME is set to: $WINDOWS_CUDA_HOME" | ||||
|   echo "Contents:" | ||||
|   ls -lah "$(pwd)/win-torch-wheel-extracted/lib/x64/" || true | ||||
|  | ||||
|   python test/inductor/test_aoti_cross_compile_windows.py -k compile --package-dir "$TEST_REPORTS_DIR" --win-torch-lib-dir "$(pwd)/win-torch-wheel-extracted/torch/lib" | ||||
| } | ||||
|  | ||||
| test_inductor_cpp_wrapper_shard() { | ||||
|   if [[ -z "$NUM_TEST_SHARDS" ]]; then | ||||
|     echo "NUM_TEST_SHARDS must be defined to run a Python test shard" | ||||
| @ -916,7 +900,7 @@ test_inductor_set_cpu_affinity(){ | ||||
|   export LD_PRELOAD="$JEMALLOC_LIB":"$LD_PRELOAD" | ||||
|   export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1" | ||||
|  | ||||
|   if [[ "$(uname -m)" != "aarch64" ]]; then | ||||
|   if [[ "${TEST_CONFIG}" != *aarch64* ]]; then | ||||
|     # Use Intel OpenMP for x86 | ||||
|     IOMP_LIB="$(dirname "$(which python)")/../lib/libiomp5.so" | ||||
|     export LD_PRELOAD="$IOMP_LIB":"$LD_PRELOAD" | ||||
| @ -930,7 +914,7 @@ test_inductor_set_cpu_affinity(){ | ||||
|   cores=$((cpus / thread_per_core)) | ||||
|  | ||||
|   # Set number of cores to 16 on aarch64 for performance runs | ||||
|   if [[ "$(uname -m)" == "aarch64" && $cores -gt 16 ]]; then | ||||
|   if [[ "${TEST_CONFIG}" == *aarch64* && $cores -gt 16 ]]; then | ||||
|     cores=16 | ||||
|   fi | ||||
|   export OMP_NUM_THREADS=$cores | ||||
| @ -1683,7 +1667,7 @@ if [[ "${TEST_CONFIG}" == *numpy_2* ]]; then | ||||
|     python -m pip install --pre numpy==2.0.2 scipy==1.13.1 numba==0.60.0 | ||||
|   fi | ||||
|   python test/run_test.py --include dynamo/test_functions.py dynamo/test_unspec.py test_binary_ufuncs.py test_fake_tensor.py test_linalg.py test_numpy_interop.py test_tensor_creation_ops.py test_torch.py torch_np/test_basic.py | ||||
| elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" == 'default' ]]; then | ||||
| elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" != *perf_cpu_aarch64* ]]; then | ||||
|   test_linux_aarch64 | ||||
| elif [[ "${TEST_CONFIG}" == *backward* ]]; then | ||||
|   test_forward_backward_compatibility | ||||
| @ -1734,8 +1718,6 @@ elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then | ||||
|   test_inductor_triton_cpu | ||||
| elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then | ||||
|   test_inductor_micro_benchmark | ||||
| elif [[ "${TEST_CONFIG}" == *aoti_cross_compile_for_windows* ]]; then | ||||
|   test_inductor_aoti_cross_compile_for_windows | ||||
| elif [[ "${TEST_CONFIG}" == *huggingface* ]]; then | ||||
|   install_torchvision | ||||
|   id=$((SHARD_NUMBER-1)) | ||||
|  | ||||
| @ -163,13 +163,8 @@ if [[ "$(uname)" != Darwin ]]; then | ||||
|   MEMORY_LIMIT_MAX_JOBS=12 | ||||
|   NUM_CPUS=$(( $(nproc) - 2 )) | ||||
|  | ||||
|   if [[ "$(uname)" == Linux ]]; then | ||||
|     # Defaults here for **binary** linux builds so they can be changed in one place | ||||
|     export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))} | ||||
|   else | ||||
|     # For other builds | ||||
|     export MAX_JOBS=${NUM_CPUS} | ||||
|   fi | ||||
|   # Defaults here for **binary** linux builds so they can be changed in one place | ||||
|   export MAX_JOBS=${MAX_JOBS:-$(( ${NUM_CPUS} > ${MEMORY_LIMIT_MAX_JOBS} ? ${MEMORY_LIMIT_MAX_JOBS} : ${NUM_CPUS} ))} | ||||
|  | ||||
|   cat >>"$envfile" <<EOL | ||||
|   export MAX_JOBS="${MAX_JOBS}" | ||||
|  | ||||
| @ -1,354 +0,0 @@ | ||||
| # PyTorch Docstring Writing Guide | ||||
|  | ||||
| This skill describes how to write docstrings for functions and methods in the PyTorch project, following the conventions in `torch/_tensor_docs.py` and `torch/nn/functional.py`. | ||||
|  | ||||
| ## General Principles | ||||
|  | ||||
| - Use **raw strings** (`r"""..."""`) for all docstrings to avoid issues with LaTeX/math backslashes | ||||
| - Follow **Sphinx/reStructuredText** (reST) format for documentation | ||||
| - Be **concise but complete** - include all essential information | ||||
| - Always include **examples** when possible | ||||
| - Use **cross-references** to related functions/classes | ||||
|  | ||||
| ## Docstring Structure | ||||
|  | ||||
| ### 1. Function Signature (First Line) | ||||
|  | ||||
| Start with the function signature showing all parameters: | ||||
|  | ||||
| ```python | ||||
| r"""function_name(param1, param2, *, kwarg1=default1, kwarg2=default2) -> ReturnType | ||||
| ``` | ||||
|  | ||||
| **Notes:** | ||||
| - Include the function name | ||||
| - Show positional and keyword-only arguments (use `*` separator) | ||||
| - Include default values | ||||
| - Show return type annotation | ||||
| - This line should NOT end with a period | ||||
|  | ||||
| ### 2. Brief Description | ||||
|  | ||||
| Provide a one-line description of what the function does: | ||||
|  | ||||
| ```python | ||||
| r"""conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor | ||||
|  | ||||
| Applies a 2D convolution over an input image composed of several input | ||||
| planes. | ||||
| ``` | ||||
|  | ||||
| ### 3. Mathematical Formulas (if applicable) | ||||
|  | ||||
| Use Sphinx math directives for mathematical expressions: | ||||
|  | ||||
| ```python | ||||
| .. math:: | ||||
|     \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} | ||||
| ``` | ||||
|  | ||||
| Or inline math: `:math:\`x^2\`` | ||||
|  | ||||
| ### 4. Cross-References | ||||
|  | ||||
| Link to related classes and functions using Sphinx roles: | ||||
|  | ||||
| - `:class:\`~torch.nn.ModuleName\`` - Link to a class | ||||
| - `:func:\`torch.function_name\`` - Link to a function | ||||
| - `:meth:\`~Tensor.method_name\`` - Link to a method | ||||
| - `:attr:\`attribute_name\`` - Reference an attribute | ||||
| - The `~` prefix shows only the last component (e.g., `Conv2d` instead of `torch.nn.Conv2d`) | ||||
|  | ||||
| **Example:** | ||||
| ```python | ||||
| See :class:`~torch.nn.Conv2d` for details and output shape. | ||||
| ``` | ||||
|  | ||||
| ### 5. Notes and Warnings | ||||
|  | ||||
| Use admonitions for important information: | ||||
|  | ||||
| ```python | ||||
| .. note:: | ||||
|     This function doesn't work directly with NLLLoss, | ||||
|     which expects the Log to be computed between the Softmax and itself. | ||||
|     Use log_softmax instead (it's faster and has better numerical properties). | ||||
|  | ||||
| .. warning:: | ||||
|     :func:`new_tensor` always copies :attr:`data`. If you have a Tensor | ||||
|     ``data`` and want to avoid a copy, use :func:`torch.Tensor.requires_grad_` | ||||
|     or :func:`torch.Tensor.detach`. | ||||
| ``` | ||||
|  | ||||
| ### 6. Args Section | ||||
|  | ||||
| Document all parameters with type annotations and descriptions: | ||||
|  | ||||
| ```python | ||||
| Args: | ||||
|     input (Tensor): input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` | ||||
|     weight (Tensor): filters of shape :math:`(\text{out\_channels} , kH , kW)` | ||||
|     bias (Tensor, optional): optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None`` | ||||
|     stride (int or tuple): the stride of the convolving kernel. Can be a single number or a | ||||
|       tuple `(sH, sW)`. Default: 1 | ||||
| ``` | ||||
|  | ||||
| **Formatting rules:** | ||||
| - Parameter name in **lowercase** | ||||
| - Type in parentheses: `(Type)`, `(Type, optional)` for optional parameters | ||||
| - Description follows the type | ||||
| - For optional parameters, include "Default: ``value``" at the end | ||||
| - Use double backticks for inline code: ``` ``None`` ``` | ||||
| - Indent continuation lines by 2 spaces | ||||
|  | ||||
| ### 7. Keyword Args Section (if applicable) | ||||
|  | ||||
| Sometimes keyword arguments are documented separately: | ||||
|  | ||||
| ```python | ||||
| Keyword args: | ||||
|     dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. | ||||
|         Default: if None, same :class:`torch.dtype` as this tensor. | ||||
|     device (:class:`torch.device`, optional): the desired device of returned tensor. | ||||
|         Default: if None, same :class:`torch.device` as this tensor. | ||||
|     requires_grad (bool, optional): If autograd should record operations on the | ||||
|         returned tensor. Default: ``False``. | ||||
| ``` | ||||
|  | ||||
| ### 8. Returns Section (if needed) | ||||
|  | ||||
| Document the return value: | ||||
|  | ||||
| ```python | ||||
| Returns: | ||||
|     Tensor: Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution. | ||||
|         If ``hard=True``, the returned samples will be one-hot, otherwise they will | ||||
|         be probability distributions that sum to 1 across `dim`. | ||||
| ``` | ||||
|  | ||||
| Or simply include it in the function signature line if obvious from context. | ||||
|  | ||||
| ### 9. Examples Section | ||||
|  | ||||
| Always include examples when possible: | ||||
|  | ||||
| ```python | ||||
| Examples:: | ||||
|  | ||||
|     >>> inputs = torch.randn(33, 16, 30) | ||||
|     >>> filters = torch.randn(20, 16, 5) | ||||
|     >>> F.conv1d(inputs, filters) | ||||
|  | ||||
|     >>> # With square kernels and equal stride | ||||
|     >>> filters = torch.randn(8, 4, 3, 3) | ||||
|     >>> inputs = torch.randn(1, 4, 5, 5) | ||||
|     >>> F.conv2d(inputs, filters, padding=1) | ||||
| ``` | ||||
|  | ||||
| **Formatting rules:** | ||||
| - Use `Examples::` with double colon | ||||
| - Use `>>>` prompt for Python code | ||||
| - Include comments with `#` when helpful | ||||
| - Show actual output when it helps understanding (indent without `>>>`) | ||||
|  | ||||
| ### 10. External References | ||||
|  | ||||
| Link to papers or external documentation: | ||||
|  | ||||
| ```python | ||||
| .. _Link Name: | ||||
|     https://arxiv.org/abs/1611.00712 | ||||
| ``` | ||||
|  | ||||
| Reference them in text: ```See `Link Name`_``` | ||||
|  | ||||
| ## Method Types | ||||
|  | ||||
| ### Native Python Functions | ||||
|  | ||||
| For regular Python functions, use a standard docstring: | ||||
|  | ||||
| ```python | ||||
| def relu(input: Tensor, inplace: bool = False) -> Tensor: | ||||
|     r"""relu(input, inplace=False) -> Tensor | ||||
|  | ||||
|     Applies the rectified linear unit function element-wise. See | ||||
|     :class:`~torch.nn.ReLU` for more details. | ||||
|     """ | ||||
|     # implementation | ||||
| ``` | ||||
|  | ||||
| ### C-Bound Functions (using add_docstr) | ||||
|  | ||||
| For C-bound functions, use `_add_docstr`: | ||||
|  | ||||
| ```python | ||||
| conv1d = _add_docstr( | ||||
|     torch.conv1d, | ||||
|     r""" | ||||
| conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor | ||||
|  | ||||
| Applies a 1D convolution over an input signal composed of several input | ||||
| planes. | ||||
|  | ||||
| See :class:`~torch.nn.Conv1d` for details and output shape. | ||||
|  | ||||
| Args: | ||||
|     input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)` | ||||
|     weight: filters of shape :math:`(\text{out\_channels} , kW)` | ||||
|     ... | ||||
| """, | ||||
| ) | ||||
| ``` | ||||
|  | ||||
| ### In-Place Variants | ||||
|  | ||||
| For in-place operations (ending with `_`), reference the original: | ||||
|  | ||||
| ```python | ||||
| add_docstr_all( | ||||
|     "abs_", | ||||
|     r""" | ||||
| abs_() -> Tensor | ||||
|  | ||||
| In-place version of :meth:`~Tensor.abs` | ||||
| """, | ||||
| ) | ||||
| ``` | ||||
|  | ||||
| ### Alias Functions | ||||
|  | ||||
| For aliases, simply reference the original: | ||||
|  | ||||
| ```python | ||||
| add_docstr_all( | ||||
|     "absolute", | ||||
|     r""" | ||||
| absolute() -> Tensor | ||||
|  | ||||
| Alias for :func:`abs` | ||||
| """, | ||||
| ) | ||||
| ``` | ||||
|  | ||||
| ## Common Patterns | ||||
|  | ||||
| ### Shape Documentation | ||||
|  | ||||
| Use LaTeX math notation for tensor shapes: | ||||
|  | ||||
| ```python | ||||
| :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` | ||||
| ``` | ||||
|  | ||||
| ### Reusable Argument Definitions | ||||
|  | ||||
| For commonly used arguments, define them once and reuse: | ||||
|  | ||||
| ```python | ||||
| common_args = parse_kwargs( | ||||
|     """ | ||||
|     dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. | ||||
|         Default: if None, same as this tensor. | ||||
| """ | ||||
| ) | ||||
|  | ||||
| # Then use with .format(): | ||||
| r""" | ||||
| ... | ||||
|  | ||||
| Keyword args: | ||||
|     {dtype} | ||||
|     {device} | ||||
| """.format(**common_args) | ||||
| ``` | ||||
|  | ||||
| ### Template Insertion | ||||
|  | ||||
| Insert reproducibility notes or other common text: | ||||
|  | ||||
| ```python | ||||
| r""" | ||||
| {tf32_note} | ||||
|  | ||||
| {cudnn_reproducibility_note} | ||||
| """.format(**reproducibility_notes, **tf32_notes) | ||||
| ``` | ||||
|  | ||||
| ## Complete Example | ||||
|  | ||||
| Here's a complete example showing all elements: | ||||
|  | ||||
| ```python | ||||
| def gumbel_softmax( | ||||
|     logits: Tensor, | ||||
|     tau: float = 1, | ||||
|     hard: bool = False, | ||||
|     eps: float = 1e-10, | ||||
|     dim: int = -1, | ||||
| ) -> Tensor: | ||||
|     r""" | ||||
|     Sample from the Gumbel-Softmax distribution and optionally discretize. | ||||
|  | ||||
|     Args: | ||||
|         logits (Tensor): `[..., num_features]` unnormalized log probabilities | ||||
|         tau (float): non-negative scalar temperature | ||||
|         hard (bool): if ``True``, the returned samples will be discretized as one-hot vectors, | ||||
|               but will be differentiated as if it is the soft sample in autograd. Default: ``False`` | ||||
|         dim (int): A dimension along which softmax will be computed. Default: -1 | ||||
|  | ||||
|     Returns: | ||||
|         Tensor: Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution. | ||||
|             If ``hard=True``, the returned samples will be one-hot, otherwise they will | ||||
|             be probability distributions that sum to 1 across `dim`. | ||||
|  | ||||
|     .. note:: | ||||
|         This function is here for legacy reasons, may be removed from nn.Functional in the future. | ||||
|  | ||||
|     Examples:: | ||||
|         >>> logits = torch.randn(20, 32) | ||||
|         >>> # Sample soft categorical using reparametrization trick: | ||||
|         >>> F.gumbel_softmax(logits, tau=1, hard=False) | ||||
|         >>> # Sample hard categorical using "Straight-through" trick: | ||||
|         >>> F.gumbel_softmax(logits, tau=1, hard=True) | ||||
|  | ||||
|     .. _Link 1: | ||||
|         https://arxiv.org/abs/1611.00712 | ||||
|     """ | ||||
|     # implementation | ||||
| ``` | ||||
|  | ||||
| ## Quick Checklist | ||||
|  | ||||
| When writing a PyTorch docstring, ensure: | ||||
|  | ||||
| - [ ] Use raw string (`r"""`) | ||||
| - [ ] Include function signature on first line | ||||
| - [ ] Provide brief description | ||||
| - [ ] Document all parameters in Args section with types | ||||
| - [ ] Include default values for optional parameters | ||||
| - [ ] Use Sphinx cross-references (`:func:`, `:class:`, `:meth:`) | ||||
| - [ ] Add mathematical formulas if applicable | ||||
| - [ ] Include at least one example in Examples section | ||||
| - [ ] Add warnings/notes for important caveats | ||||
| - [ ] Link to related module class with `:class:` | ||||
| - [ ] Use proper math notation for tensor shapes | ||||
| - [ ] Follow consistent formatting and indentation | ||||
|  | ||||
| ## Common Sphinx Roles Reference | ||||
|  | ||||
| - `:class:\`~torch.nn.Module\`` - Class reference | ||||
| - `:func:\`torch.function\`` - Function reference | ||||
| - `:meth:\`~Tensor.method\`` - Method reference | ||||
| - `:attr:\`attribute\`` - Attribute reference | ||||
| - `:math:\`equation\`` - Inline math | ||||
| - `:ref:\`label\`` - Internal reference | ||||
| - ``` ``code`` ``` - Inline code (use double backticks) | ||||
|  | ||||
| ## Additional Notes | ||||
|  | ||||
| - **Indentation**: Use 4 spaces for code, 2 spaces for continuation of parameter descriptions | ||||
| - **Line length**: Try to keep lines under 100 characters when possible | ||||
| - **Periods**: End sentences with periods, but not the signature line | ||||
| - **Backticks**: Use double backticks for code: ``` ``True`` ``None`` ``False`` ``` | ||||
| - **Types**: Common types are `Tensor`, `int`, `float`, `bool`, `str`, `tuple`, `list`, etc. | ||||
							
								
								
									
										6
									
								
								.flake8
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								.flake8
									
									
									
									
									
								
							| @ -7,12 +7,16 @@ max-line-length = 120 | ||||
| # C408 ignored because we like the dict keyword argument syntax | ||||
| # E501 is not flexible enough, we're using B950 instead | ||||
| ignore = | ||||
|     E203,E305,E402,E501,E704,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,F824, | ||||
|     E203,E305,E402,E501,E704,E721,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,F824, | ||||
|     # shebang has extra meaning in fbcode lints, so I think it's not worth trying | ||||
|     # to line this up with executable bit | ||||
|     EXE001, | ||||
|     # these ignores are from flake8-bugbear; please fix! | ||||
|     B007,B008,B017,B019,B023,B028,B903,B905,B906,B907,B908,B910 | ||||
|     # these ignores are from flake8-comprehensions; please fix! | ||||
|     C407, | ||||
|     # these ignores are from flake8-logging-format; please fix! | ||||
|     G100,G101,G200 | ||||
|     # these ignores are from flake8-simplify. please fix or ignore with commented reason | ||||
|     SIM105,SIM108,SIM110,SIM111,SIM113,SIM114,SIM115,SIM116,SIM117,SIM118,SIM119,SIM12, | ||||
|     # SIM104 is already covered by pyupgrade ruff | ||||
|  | ||||
							
								
								
									
										7
									
								
								.github/actions/setup-rocm/action.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										7
									
								
								.github/actions/setup-rocm/action.yml
									
									
									
									
										vendored
									
									
								
							| @ -124,10 +124,3 @@ runs: | ||||
|       id: login-ecr | ||||
|       continue-on-error: true | ||||
|       uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1 | ||||
|  | ||||
|     - name: Preserve github env variables for use in docker | ||||
|       shell: bash | ||||
|       run: | | ||||
|         env | grep '^GITHUB' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}" | ||||
|         env | grep '^CI' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}" | ||||
|         env | grep '^RUNNER' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}" | ||||
|  | ||||
							
								
								
									
										2
									
								
								.github/ci_commit_pins/audio.txt
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/ci_commit_pins/audio.txt
									
									
									
									
										vendored
									
									
								
							| @ -1 +1 @@ | ||||
| 69bbe7363897764f9e758d851cd0340147d27f94 | ||||
| 1b013f5b5a87a1882eb143c26d79d091150d6a37 | ||||
|  | ||||
							
								
								
									
										2
									
								
								.github/ci_commit_pins/vision.txt
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/ci_commit_pins/vision.txt
									
									
									
									
										vendored
									
									
								
							| @ -1 +1 @@ | ||||
| 1752fe6809b74921644866275ab80244b96e80bc | ||||
| faffd5cf673615583da6517275e361cb3dbc77e6 | ||||
|  | ||||
							
								
								
									
										5
									
								
								.github/ci_configs/vllm/Dockerfile
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										5
									
								
								.github/ci_configs/vllm/Dockerfile
									
									
									
									
										vendored
									
									
								
							| @ -283,9 +283,6 @@ RUN --mount=type=bind,source=${TORCH_WHEELS_PATH},target=/dist \ | ||||
|         uv pip install --system $(cat torch_build_versions.txt | xargs) --index-url https://download.pytorch.org/whl/nightly/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.'); \ | ||||
|     fi | ||||
|  | ||||
| RUN --mount=type=cache,target=/root/.cache/uv \ | ||||
|     uv pip install --system --pre apache-tvm-ffi==0.1.0b15 | ||||
|  | ||||
| # Install the vllm wheel from previous stage | ||||
| RUN --mount=type=cache,target=/root/.cache/uv \ | ||||
|     uv pip install --system /wheels/vllm/*.whl --verbose | ||||
| @ -298,8 +295,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \ | ||||
| ARG torch_cuda_arch_list='8.0;8.9;9.0a;10.0a;12.0' | ||||
| ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} | ||||
|  | ||||
| # TODO(elainewy): remove this once vllm commit is updated, and install flashinfer from pip | ||||
| # see https://github.com/pytorch/pytorch/pull/165274#issuecomment-3408531784 | ||||
| ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git" | ||||
| ARG FLASHINFER_GIT_REF="v0.2.14.post1" | ||||
|  | ||||
|  | ||||
							
								
								
									
										9
									
								
								.github/label_to_label.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										9
									
								
								.github/label_to_label.yml
									
									
									
									
										vendored
									
									
								
							| @ -15,11 +15,6 @@ | ||||
|   - "module: reinplacing" | ||||
|   then: | ||||
|   - "module: pt2-dispatcher" | ||||
| - any: | ||||
|   - "vllm-compile" | ||||
|   then: | ||||
|   - "module: vllm" | ||||
|   - "oncall: pt2" | ||||
| - any: | ||||
|   - "module: vmap" | ||||
|   then: | ||||
| @ -32,6 +27,10 @@ | ||||
|   - "module: pt2 optimizer" | ||||
|   then: | ||||
|   - "module: dynamo" | ||||
| - any: | ||||
|   - "module: flex attention" | ||||
|   then: | ||||
|   - "module: higher order operators" | ||||
| - any: | ||||
|   - "module: aotinductor" | ||||
|   then: | ||||
|  | ||||
							
								
								
									
										29
									
								
								.github/labeler.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										29
									
								
								.github/labeler.yml
									
									
									
									
										vendored
									
									
								
							| @ -133,32 +133,3 @@ | ||||
|  | ||||
| "ciflow/vllm": | ||||
| - .github/ci_commit_pins/vllm.txt | ||||
|  | ||||
| "ciflow/b200": | ||||
| - test/test_matmul_cuda.py | ||||
| - test/test_scaled_matmul_cuda.py | ||||
| - test/inductor/test_fp8.py | ||||
| - aten/src/ATen/native/cuda/Blas.cpp | ||||
| - torch/**/*cublas* | ||||
| - torch/_inductor/kernel/mm.py | ||||
| - test/inductor/test_max_autotune.py | ||||
| - third_party/fbgemm | ||||
|  | ||||
| "ciflow/h100": | ||||
| - test/test_matmul_cuda.py | ||||
| - test/test_scaled_matmul_cuda.py | ||||
| - test/inductor/test_fp8.py | ||||
| - aten/src/ATen/native/cuda/Blas.cpp | ||||
| - torch/**/*cublas* | ||||
| - torch/_inductor/kernel/mm.py | ||||
| - test/inductor/test_max_autotune.py | ||||
| - third_party/fbgemm | ||||
|  | ||||
| "ciflow/rocm": | ||||
| - test/test_matmul_cuda.py | ||||
| - test/test_scaled_matmul_cuda.py | ||||
| - test/inductor/test_fp8.py | ||||
| - aten/src/ATen/native/cuda/Blas.cpp | ||||
| - torch/_inductor/kernel/mm.py | ||||
| - test/inductor/test_max_autotune.py | ||||
| - third_party/fbgemm | ||||
|  | ||||
							
								
								
									
										2
									
								
								.github/pytorch-probot.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/pytorch-probot.yml
									
									
									
									
										vendored
									
									
								
							| @ -3,7 +3,6 @@ ciflow_tracking_issue: 64124 | ||||
| ciflow_push_tags: | ||||
| - ciflow/b200 | ||||
| - ciflow/b200-symm-mem | ||||
| - ciflow/b200-distributed | ||||
| - ciflow/binaries | ||||
| - ciflow/binaries_libtorch | ||||
| - ciflow/binaries_wheel | ||||
| @ -33,7 +32,6 @@ ciflow_push_tags: | ||||
| - ciflow/rocm | ||||
| - ciflow/rocm-mi300 | ||||
| - ciflow/rocm-mi355 | ||||
| - ciflow/rocm-navi31 | ||||
| - ciflow/s390 | ||||
| - ciflow/slow | ||||
| - ciflow/torchbench | ||||
|  | ||||
							
								
								
									
										30
									
								
								.github/scripts/generate_binary_build_matrix.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										30
									
								
								.github/scripts/generate_binary_build_matrix.py
									
									
									
									
										vendored
									
									
								
							| @ -79,21 +79,21 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = { | ||||
|         "nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux'" | ||||
|     ), | ||||
|     "12.9": ( | ||||
|         "nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | " | ||||
|         "nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | " | ||||
|         "nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | " | ||||
|         "nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | " | ||||
|         "nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | " | ||||
|         "nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | " | ||||
|         "nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | " | ||||
|         "nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | " | ||||
|         "nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | " | ||||
|         "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | " | ||||
|         "nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | " | ||||
|         "nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | " | ||||
|         "nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | " | ||||
|         "nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | " | ||||
|         "nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'" | ||||
|         "nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | " | ||||
|         "nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'" | ||||
|     ), | ||||
|     "13.0": ( | ||||
|         "nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | " | ||||
|  | ||||
							
								
								
									
										2
									
								
								.github/scripts/trymerge.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/scripts/trymerge.py
									
									
									
									
										vendored
									
									
								
							| @ -1092,7 +1092,7 @@ class GitHubPR: | ||||
|         editor = node["editor"] | ||||
|         return GitHubComment( | ||||
|             body_text=node["bodyText"], | ||||
|             created_at=node.get("createdAt", ""), | ||||
|             created_at=node["createdAt"] if "createdAt" in node else "", | ||||
|             author_login=node["author"]["login"], | ||||
|             author_url=node["author"].get("url", None), | ||||
|             author_association=node["authorAssociation"], | ||||
|  | ||||
| @ -26,8 +26,9 @@ name: !{{ build_environment }} | ||||
|       - name: Setup Python | ||||
|         uses: actions/setup-python@v6 | ||||
|         with: | ||||
|           # TODO: Removeme once 3.14 is out | ||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||
|           python-version: "!{{ py_ver.strip('t') + ('.4' if '3.14' not in py_ver else '.0') }}" | ||||
|           python-version: "!{{ (py_ver.strip('t') + '.4') if '3.14' not in py_ver else '3.14.0-rc.2' }}" | ||||
|           freethreaded: !{{ "true" if py_ver.endswith('t') else "false" }} | ||||
| {%- endmacro %} | ||||
|  | ||||
|  | ||||
| @ -79,9 +79,9 @@ jobs: | ||||
|     runs-on: "windows-11-arm64-preview" | ||||
|     {%- else %} | ||||
|     {%- if branches == "nightly" %} | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     {%- else %} | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge.nonephemeral" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" | ||||
|     {%- endif %} | ||||
|     {%- endif %} | ||||
|     timeout-minutes: !{{ common.timeout_minutes_windows_binary }} | ||||
|  | ||||
							
								
								
									
										40
									
								
								.github/workflows/_linux-test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										40
									
								
								.github/workflows/_linux-test.yml
									
									
									
									
										vendored
									
									
								
							| @ -224,46 +224,6 @@ jobs: | ||||
|         continue-on-error: true | ||||
|         uses: ./.github/actions/download-td-artifacts | ||||
|  | ||||
|       - name: Download Windows torch wheel for cross-compilation | ||||
|         if: matrix.win_torch_wheel_artifact != '' | ||||
|         uses: seemethere/download-artifact-s3@1da556a7aa0a088e3153970611f6c432d58e80e6 # v4.2.0 | ||||
|         with: | ||||
|           name: ${{ matrix.win_torch_wheel_artifact }} | ||||
|           path: win-torch-wheel | ||||
|  | ||||
|       - name: Extract Windows wheel and setup CUDA libraries | ||||
|         if: matrix.win_torch_wheel_artifact != '' | ||||
|         shell: bash | ||||
|         run: | | ||||
|           set -x | ||||
|  | ||||
|           # Find the wheel file | ||||
|           WHEEL_FILE=$(find win-torch-wheel -name "*.whl" -type f | head -n 1) | ||||
|           if [ -z "$WHEEL_FILE" ]; then | ||||
|             echo "Error: No wheel file found in win-torch-wheel directory" | ||||
|             exit 1 | ||||
|           fi | ||||
|           echo "Found wheel file: $WHEEL_FILE" | ||||
|  | ||||
|           # Unzip the wheel file | ||||
|           unzip -q "$WHEEL_FILE" -d win-torch-wheel-extracted | ||||
|           echo "Extracted wheel contents" | ||||
|  | ||||
|           # Setup CUDA libraries (cuda.lib and cudart.lib) directory | ||||
|           mkdir -p win-torch-wheel-extracted/lib/x64 | ||||
|           if [ -f "win-torch-wheel/cuda.lib" ]; then | ||||
|             mv win-torch-wheel/cuda.lib win-torch-wheel-extracted/lib/x64/ | ||||
|             echo "Moved cuda.lib to win-torch-wheel-extracted/lib/x64/" | ||||
|           fi | ||||
|           if [ -f "win-torch-wheel/cudart.lib" ]; then | ||||
|             mv win-torch-wheel/cudart.lib win-torch-wheel-extracted/lib/x64/ | ||||
|             echo "Moved cudart.lib to win-torch-wheel-extracted/lib/x64/" | ||||
|           fi | ||||
|  | ||||
|           # Verify CUDA libraries are present | ||||
|           echo "CUDA libraries:" | ||||
|           ls -la win-torch-wheel-extracted/lib/x64/ || echo "No CUDA libraries found" | ||||
|  | ||||
|       - name: Parse ref | ||||
|         id: parse-ref | ||||
|         run: .github/scripts/parse_ref.py | ||||
|  | ||||
							
								
								
									
										25
									
								
								.github/workflows/_win-build.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										25
									
								
								.github/workflows/_win-build.yml
									
									
									
									
										vendored
									
									
								
							| @ -168,31 +168,6 @@ jobs: | ||||
|         run: | | ||||
|           .ci/pytorch/win-build.sh | ||||
|  | ||||
|       # Collect Windows torch libs and CUDA libs for cross-compilation | ||||
|       - name: Collect Windows CUDA libs for cross-compilation | ||||
|         if: steps.build.outcome != 'skipped' && inputs.cuda-version != 'cpu' | ||||
|         shell: bash | ||||
|         run: | | ||||
|           set -ex | ||||
|  | ||||
|           # Create directory structure if does not exist | ||||
|           mkdir -p /c/${{ github.run_id }}/build-results | ||||
|  | ||||
|           # Copy CUDA libs | ||||
|           CUDA_PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${{ inputs.cuda-version }}" | ||||
|  | ||||
|           if [ -f "${CUDA_PATH}/lib/x64/cuda.lib" ]; then | ||||
|             cp "${CUDA_PATH}/lib/x64/cuda.lib" /c/${{ github.run_id }}/build-results/ | ||||
|           fi | ||||
|  | ||||
|           if [ -f "${CUDA_PATH}/lib/x64/cudart.lib" ]; then | ||||
|             cp "${CUDA_PATH}/lib/x64/cudart.lib" /c/${{ github.run_id }}/build-results/ | ||||
|           fi | ||||
|  | ||||
|           # List collected files | ||||
|           echo "Collected CUDA libs:" | ||||
|           ls -lah /c/${{ github.run_id }}/build-results/*.lib | ||||
|  | ||||
|       # Upload to github so that people can click and download artifacts | ||||
|       - name: Upload artifacts to s3 | ||||
|         if: steps.build.outcome != 'skipped' | ||||
|  | ||||
							
								
								
									
										62
									
								
								.github/workflows/b200-distributed.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										62
									
								
								.github/workflows/b200-distributed.yml
									
									
									
									
										vendored
									
									
								
							| @ -1,62 +0,0 @@ | ||||
| name: CI for distributed tests on B200 | ||||
|  | ||||
| on: | ||||
|   pull_request: | ||||
|     paths: | ||||
|       - .github/workflows/b200-distributed.yml | ||||
|   workflow_dispatch: | ||||
|   push: | ||||
|     tags: | ||||
|       - ciflow/b200-distributed/* | ||||
|   schedule: | ||||
|     - cron: 46 8 * * *  # about 1:46am PDT | ||||
|  | ||||
| concurrency: | ||||
|   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} | ||||
|   cancel-in-progress: true | ||||
|  | ||||
| permissions: | ||||
|   id-token: write | ||||
|   contents: read | ||||
|  | ||||
| jobs: | ||||
|  | ||||
|   get-label-type: | ||||
|     if: github.repository_owner == 'pytorch' | ||||
|     name: get-label-type | ||||
|     uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main | ||||
|     with: | ||||
|       triggering_actor: ${{ github.triggering_actor }} | ||||
|       issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} | ||||
|       curr_branch: ${{ github.head_ref || github.ref_name }} | ||||
|       curr_ref_type: ${{ github.ref_type }} | ||||
|  | ||||
|   linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200: | ||||
|     name: linux-jammy-cuda12.8-py3.10-gcc11-build-distributed-b200 | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|     needs: get-label-type | ||||
|     with: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       runner: linux.12xlarge.memory | ||||
|       build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200 | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 | ||||
|       cuda-arch-list: '10.0' | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "distributed", shard: 1, num_shards: 2, runner: "linux.dgx.b200.8" }, | ||||
|           { config: "distributed", shard: 2, num_shards: 2, runner: "linux.dgx.b200.8" }, | ||||
|         ]} | ||||
|     secrets: inherit | ||||
|  | ||||
|   linux-jammy-cuda12_8-py3_10-gcc11-test-distributed-b200: | ||||
|     name: linux-jammy-cuda12.8-py3.10-gcc11-test-b200 | ||||
|     uses: ./.github/workflows/_linux-test.yml | ||||
|     needs: | ||||
|       - linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200 | ||||
|     with: | ||||
|       timeout-minutes: 1200 | ||||
|       build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200 | ||||
|       docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.docker-image }} | ||||
|       test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.test-matrix }} | ||||
|       aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only | ||||
|     secrets: inherit | ||||
							
								
								
									
										14
									
								
								.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										14
									
								
								.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -224,7 +224,7 @@ jobs: | ||||
|       ALPINE_IMAGE: "arm64v8/alpine" | ||||
|       build_name: manywheel-py3_10-cuda-aarch64-12_9 | ||||
|       build_environment: linux-aarch64-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       timeout-minutes: 420 | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
| @ -473,7 +473,7 @@ jobs: | ||||
|       ALPINE_IMAGE: "arm64v8/alpine" | ||||
|       build_name: manywheel-py3_11-cuda-aarch64-12_9 | ||||
|       build_environment: linux-aarch64-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       timeout-minutes: 420 | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
| @ -722,7 +722,7 @@ jobs: | ||||
|       ALPINE_IMAGE: "arm64v8/alpine" | ||||
|       build_name: manywheel-py3_12-cuda-aarch64-12_9 | ||||
|       build_environment: linux-aarch64-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       timeout-minutes: 420 | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
| @ -971,7 +971,7 @@ jobs: | ||||
|       ALPINE_IMAGE: "arm64v8/alpine" | ||||
|       build_name: manywheel-py3_13-cuda-aarch64-12_9 | ||||
|       build_environment: linux-aarch64-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       timeout-minutes: 420 | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
| @ -1220,7 +1220,7 @@ jobs: | ||||
|       ALPINE_IMAGE: "arm64v8/alpine" | ||||
|       build_name: manywheel-py3_13t-cuda-aarch64-12_9 | ||||
|       build_environment: linux-aarch64-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       timeout-minutes: 420 | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
| @ -1469,7 +1469,7 @@ jobs: | ||||
|       ALPINE_IMAGE: "arm64v8/alpine" | ||||
|       build_name: manywheel-py3_14-cuda-aarch64-12_9 | ||||
|       build_environment: linux-aarch64-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       timeout-minutes: 420 | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
| @ -1718,7 +1718,7 @@ jobs: | ||||
|       ALPINE_IMAGE: "arm64v8/alpine" | ||||
|       build_name: manywheel-py3_14t-cuda-aarch64-12_9 | ||||
|       build_environment: linux-aarch64-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|       timeout-minutes: 420 | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|  | ||||
							
								
								
									
										14
									
								
								.github/workflows/generated-linux-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										14
									
								
								.github/workflows/generated-linux-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -259,7 +259,7 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build_name: manywheel-py3_10-cuda12_9 | ||||
|       build_environment: linux-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|   manywheel-py3_10-cuda12_9-test:  # Testing | ||||
| @ -925,7 +925,7 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build_name: manywheel-py3_11-cuda12_9 | ||||
|       build_environment: linux-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|   manywheel-py3_11-cuda12_9-test:  # Testing | ||||
| @ -1591,7 +1591,7 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build_name: manywheel-py3_12-cuda12_9 | ||||
|       build_environment: linux-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|   manywheel-py3_12-cuda12_9-test:  # Testing | ||||
| @ -2257,7 +2257,7 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build_name: manywheel-py3_13-cuda12_9 | ||||
|       build_environment: linux-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|   manywheel-py3_13-cuda12_9-test:  # Testing | ||||
| @ -2923,7 +2923,7 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build_name: manywheel-py3_13t-cuda12_9 | ||||
|       build_environment: linux-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|   manywheel-py3_13t-cuda12_9-test:  # Testing | ||||
| @ -3589,7 +3589,7 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build_name: manywheel-py3_14-cuda12_9 | ||||
|       build_environment: linux-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|   manywheel-py3_14-cuda12_9-test:  # Testing | ||||
| @ -4255,7 +4255,7 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build_name: manywheel-py3_14t-cuda12_9 | ||||
|       build_environment: linux-binary-manywheel | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' | ||||
|       PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' | ||||
|     secrets: | ||||
|       github-token: ${{ secrets.GITHUB_TOKEN }} | ||||
|   manywheel-py3_14t-cuda12_9-test:  # Testing | ||||
|  | ||||
							
								
								
									
										1
									
								
								.github/workflows/generated-macos-arm64-binary-libtorch-release-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/generated-macos-arm64-binary-libtorch-release-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -63,6 +63,7 @@ jobs: | ||||
|       - name: Setup Python | ||||
|         uses: actions/setup-python@v6 | ||||
|         with: | ||||
|           # TODO: Removeme once 3.14 is out | ||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||
|           python-version: "3.10.4" | ||||
|           freethreaded: false | ||||
|  | ||||
							
								
								
									
										11
									
								
								.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										11
									
								
								.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -59,6 +59,7 @@ jobs: | ||||
|       - name: Setup Python | ||||
|         uses: actions/setup-python@v6 | ||||
|         with: | ||||
|           # TODO: Removeme once 3.14 is out | ||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||
|           python-version: "3.10.4" | ||||
|           freethreaded: false | ||||
| @ -168,6 +169,7 @@ jobs: | ||||
|       - name: Setup Python | ||||
|         uses: actions/setup-python@v6 | ||||
|         with: | ||||
|           # TODO: Removeme once 3.14 is out | ||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||
|           python-version: "3.11.4" | ||||
|           freethreaded: false | ||||
| @ -277,6 +279,7 @@ jobs: | ||||
|       - name: Setup Python | ||||
|         uses: actions/setup-python@v6 | ||||
|         with: | ||||
|           # TODO: Removeme once 3.14 is out | ||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||
|           python-version: "3.12.4" | ||||
|           freethreaded: false | ||||
| @ -386,6 +389,7 @@ jobs: | ||||
|       - name: Setup Python | ||||
|         uses: actions/setup-python@v6 | ||||
|         with: | ||||
|           # TODO: Removeme once 3.14 is out | ||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||
|           python-version: "3.13.4" | ||||
|           freethreaded: false | ||||
| @ -495,6 +499,7 @@ jobs: | ||||
|       - name: Setup Python | ||||
|         uses: actions/setup-python@v6 | ||||
|         with: | ||||
|           # TODO: Removeme once 3.14 is out | ||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||
|           python-version: "3.13.4" | ||||
|           freethreaded: true | ||||
| @ -604,8 +609,9 @@ jobs: | ||||
|       - name: Setup Python | ||||
|         uses: actions/setup-python@v6 | ||||
|         with: | ||||
|           # TODO: Removeme once 3.14 is out | ||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||
|           python-version: "3.14.0" | ||||
|           python-version: "3.14.0-rc.2" | ||||
|           freethreaded: false | ||||
|       - name: Checkout PyTorch | ||||
|         uses: actions/checkout@v4 | ||||
| @ -713,8 +719,9 @@ jobs: | ||||
|       - name: Setup Python | ||||
|         uses: actions/setup-python@v6 | ||||
|         with: | ||||
|           # TODO: Removeme once 3.14 is out | ||||
|           # .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3 | ||||
|           python-version: "3.14.0" | ||||
|           python-version: "3.14.0-rc.2" | ||||
|           freethreaded: true | ||||
|       - name: Checkout PyTorch | ||||
|         uses: actions/checkout@v4 | ||||
|  | ||||
							
								
								
									
										8
									
								
								.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -44,7 +44,7 @@ jobs: | ||||
|   libtorch-cpu-shared-with-deps-debug-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -291,7 +291,7 @@ jobs: | ||||
|   libtorch-cuda12_6-shared-with-deps-debug-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -541,7 +541,7 @@ jobs: | ||||
|   libtorch-cuda12_8-shared-with-deps-debug-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -791,7 +791,7 @@ jobs: | ||||
|   libtorch-cuda13_0-shared-with-deps-debug-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
|  | ||||
							
								
								
									
										8
									
								
								.github/workflows/generated-windows-binary-libtorch-release-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/generated-windows-binary-libtorch-release-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -44,7 +44,7 @@ jobs: | ||||
|   libtorch-cpu-shared-with-deps-release-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -291,7 +291,7 @@ jobs: | ||||
|   libtorch-cuda12_6-shared-with-deps-release-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -541,7 +541,7 @@ jobs: | ||||
|   libtorch-cuda12_8-shared-with-deps-release-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -791,7 +791,7 @@ jobs: | ||||
|   libtorch-cuda13_0-shared-with-deps-release-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
|  | ||||
							
								
								
									
										70
									
								
								.github/workflows/generated-windows-binary-wheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										70
									
								
								.github/workflows/generated-windows-binary-wheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							| @ -44,7 +44,7 @@ jobs: | ||||
|   wheel-py3_10-cpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -279,7 +279,7 @@ jobs: | ||||
|   wheel-py3_10-cuda12_6-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -517,7 +517,7 @@ jobs: | ||||
|   wheel-py3_10-cuda12_8-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -755,7 +755,7 @@ jobs: | ||||
|   wheel-py3_10-cuda13_0-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -993,7 +993,7 @@ jobs: | ||||
|   wheel-py3_10-xpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -1229,7 +1229,7 @@ jobs: | ||||
|   wheel-py3_11-cpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -1464,7 +1464,7 @@ jobs: | ||||
|   wheel-py3_11-cuda12_6-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -1702,7 +1702,7 @@ jobs: | ||||
|   wheel-py3_11-cuda12_8-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -1940,7 +1940,7 @@ jobs: | ||||
|   wheel-py3_11-cuda13_0-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -2178,7 +2178,7 @@ jobs: | ||||
|   wheel-py3_11-xpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -2414,7 +2414,7 @@ jobs: | ||||
|   wheel-py3_12-cpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -2649,7 +2649,7 @@ jobs: | ||||
|   wheel-py3_12-cuda12_6-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -2887,7 +2887,7 @@ jobs: | ||||
|   wheel-py3_12-cuda12_8-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -3125,7 +3125,7 @@ jobs: | ||||
|   wheel-py3_12-cuda13_0-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -3363,7 +3363,7 @@ jobs: | ||||
|   wheel-py3_12-xpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -3599,7 +3599,7 @@ jobs: | ||||
|   wheel-py3_13-cpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -3834,7 +3834,7 @@ jobs: | ||||
|   wheel-py3_13-cuda12_6-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -4072,7 +4072,7 @@ jobs: | ||||
|   wheel-py3_13-cuda12_8-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -4310,7 +4310,7 @@ jobs: | ||||
|   wheel-py3_13-cuda13_0-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -4548,7 +4548,7 @@ jobs: | ||||
|   wheel-py3_13-xpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -4784,7 +4784,7 @@ jobs: | ||||
|   wheel-py3_13t-cpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -5019,7 +5019,7 @@ jobs: | ||||
|   wheel-py3_13t-cuda12_6-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -5257,7 +5257,7 @@ jobs: | ||||
|   wheel-py3_13t-cuda12_8-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -5495,7 +5495,7 @@ jobs: | ||||
|   wheel-py3_13t-cuda13_0-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -5733,7 +5733,7 @@ jobs: | ||||
|   wheel-py3_13t-xpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -5969,7 +5969,7 @@ jobs: | ||||
|   wheel-py3_14-cpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -6204,7 +6204,7 @@ jobs: | ||||
|   wheel-py3_14-cuda12_6-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -6442,7 +6442,7 @@ jobs: | ||||
|   wheel-py3_14-cuda12_8-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -6680,7 +6680,7 @@ jobs: | ||||
|   wheel-py3_14-cuda13_0-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -6918,7 +6918,7 @@ jobs: | ||||
|   wheel-py3_14-xpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -7154,7 +7154,7 @@ jobs: | ||||
|   wheel-py3_14t-cpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -7389,7 +7389,7 @@ jobs: | ||||
|   wheel-py3_14t-cuda12_6-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -7627,7 +7627,7 @@ jobs: | ||||
|   wheel-py3_14t-cuda12_8-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -7865,7 +7865,7 @@ jobs: | ||||
|   wheel-py3_14t-cuda13_0-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
| @ -8103,7 +8103,7 @@ jobs: | ||||
|   wheel-py3_14t-xpu-build: | ||||
|     if: ${{ github.repository_owner == 'pytorch' }} | ||||
|     needs: get-label-type | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" | ||||
|     runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" | ||||
|     timeout-minutes: 360 | ||||
|     env: | ||||
|       PYTORCH_ROOT: ${{ github.workspace }}/pytorch | ||||
|  | ||||
| @ -88,27 +88,27 @@ jobs: | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|         ]} | ||||
|     secrets: inherit | ||||
|  | ||||
|  | ||||
							
								
								
									
										1
									
								
								.github/workflows/inductor-periodic.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/inductor-periodic.yml
									
									
									
									
										vendored
									
									
								
							| @ -88,6 +88,7 @@ jobs: | ||||
|     with: | ||||
|       build-environment: linux-jammy-rocm-py3_10 | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks | ||||
|       sync-tag: rocm-build | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, | ||||
|  | ||||
							
								
								
									
										24
									
								
								.github/workflows/operator_benchmark.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										24
									
								
								.github/workflows/operator_benchmark.yml
									
									
									
									
										vendored
									
									
								
							| @ -52,27 +52,3 @@ jobs: | ||||
|       docker-image: ${{ needs.x86-opbenchmark-build.outputs.docker-image }} | ||||
|       test-matrix: ${{ needs.x86-opbenchmark-build.outputs.test-matrix }} | ||||
|     secrets: inherit | ||||
|  | ||||
|   aarch64-opbenchmark-build: | ||||
|     if: github.repository_owner == 'pytorch' | ||||
|     name: aarch64-opbenchmark-build | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|     with: | ||||
|       build-environment: linux-jammy-aarch64-py3.10 | ||||
|       runner: linux.arm64.m7g.4xlarge | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11 | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.arm64.m8g.4xlarge" }, | ||||
|         ]} | ||||
|     secrets: inherit | ||||
|  | ||||
|   aarch64-opbenchmark-test: | ||||
|     name: aarch64-opbenchmark-test | ||||
|     uses: ./.github/workflows/_linux-test.yml | ||||
|     needs: aarch64-opbenchmark-build | ||||
|     with: | ||||
|       build-environment: linux-jammy-aarch64-py3.10 | ||||
|       docker-image: ${{ needs.aarch64-opbenchmark-build.outputs.docker-image }} | ||||
|       test-matrix: ${{ needs.aarch64-opbenchmark-build.outputs.test-matrix }} | ||||
|     secrets: inherit | ||||
|  | ||||
							
								
								
									
										15
									
								
								.github/workflows/periodic.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										15
									
								
								.github/workflows/periodic.yml
									
									
									
									
										vendored
									
									
								
							| @ -147,16 +147,15 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9 | ||||
|       cuda-arch-list: 8.9 | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 4, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|           { config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] }, | ||||
|         ]} | ||||
|     secrets: inherit | ||||
|  | ||||
|  | ||||
							
								
								
									
										3
									
								
								.github/workflows/pull.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/workflows/pull.yml
									
									
									
									
										vendored
									
									
								
							| @ -347,8 +347,7 @@ jobs: | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|     needs: get-label-type | ||||
|     with: | ||||
|       # This should sync with the build in xpu.yml but xpu uses a larger runner | ||||
|       # sync-tag: linux-xpu-n-build | ||||
|       sync-tag: linux-xpu-n-build | ||||
|       runner_prefix: ${{ needs.get-label-type.outputs.label-type }} | ||||
|       build-environment: linux-jammy-xpu-n-py3.10 | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3 | ||||
|  | ||||
							
								
								
									
										1
									
								
								.github/workflows/rocm-mi300.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/rocm-mi300.yml
									
									
									
									
										vendored
									
									
								
							| @ -45,6 +45,7 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build-environment: linux-noble-rocm-py3.12-mi300 | ||||
|       docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3 | ||||
|       sync-tag: rocm-build | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, | ||||
|  | ||||
							
								
								
									
										13
									
								
								.github/workflows/rocm-mi355.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										13
									
								
								.github/workflows/rocm-mi355.yml
									
									
									
									
										vendored
									
									
								
							| @ -42,14 +42,15 @@ jobs: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build-environment: linux-noble-rocm-py3.12-mi355 | ||||
|       docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3 | ||||
|       sync-tag: rocm-build | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" }, | ||||
|           { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|           { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" }, | ||||
|         ]} | ||||
|     secrets: inherit | ||||
|  | ||||
|  | ||||
							
								
								
									
										75
									
								
								.github/workflows/rocm-navi31.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										75
									
								
								.github/workflows/rocm-navi31.yml
									
									
									
									
										vendored
									
									
								
							| @ -1,75 +0,0 @@ | ||||
| name: rocm-navi31 | ||||
|  | ||||
| on: | ||||
|   push: | ||||
|     tags: | ||||
|       - ciflow/rocm-navi31/* | ||||
|   workflow_dispatch: | ||||
|   schedule: | ||||
|     # We have several schedules so jobs can check github.event.schedule to activate only for a fraction of the runs. | ||||
|     # Also run less frequently on weekends. | ||||
|     - cron: 45 */2 * * 1-5 | ||||
|     - cron: 45 4,12 * * 0,6 | ||||
|  | ||||
| concurrency: | ||||
|   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} | ||||
|   cancel-in-progress: true | ||||
|  | ||||
| permissions: read-all | ||||
|  | ||||
| jobs: | ||||
|   target-determination: | ||||
|     if: github.repository_owner == 'pytorch' | ||||
|     name: before-test | ||||
|     uses: ./.github/workflows/target_determination.yml | ||||
|     permissions: | ||||
|       id-token: write | ||||
|       contents: read | ||||
|  | ||||
|   get-label-type: | ||||
|     name: get-label-type | ||||
|     uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main | ||||
|     if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} | ||||
|     with: | ||||
|       triggering_actor: ${{ github.triggering_actor }} | ||||
|       issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} | ||||
|       curr_branch: ${{ github.head_ref || github.ref_name }} | ||||
|       curr_ref_type: ${{ github.ref_type }} | ||||
|  | ||||
|   linux-jammy-rocm-py3_10-build: | ||||
|     if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} | ||||
|     name: linux-jammy-rocm-py3.10 | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|     needs: get-label-type | ||||
|     with: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build-environment: linux-jammy-rocm-py3.10 | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 | ||||
|       sync-tag: rocm-build | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" }, | ||||
|           { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" }, | ||||
|         ]} | ||||
|     secrets: inherit | ||||
|  | ||||
|   linux-jammy-rocm-py3_10-test: | ||||
|     permissions: | ||||
|       id-token: write | ||||
|       contents: read | ||||
|     name: linux-jammy-rocm-py3_10 | ||||
|     uses: ./.github/workflows/_rocm-test.yml | ||||
|     needs: | ||||
|       - linux-jammy-rocm-py3_10-build | ||||
|       - target-determination | ||||
|     with: | ||||
|       build-environment: linux-jammy-rocm-py3.10 | ||||
|       docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} | ||||
|       test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} | ||||
|       tests-to-include: >- | ||||
|          ${{ github.event_name == 'schedule' && 'test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs | ||||
|          test_autograd inductor/test_torchinductor inductor/test_kernel_benchmark | ||||
|          inductor/test_pad_mm inductor/test_benchmark_fusion inductor/test_aot_inductor | ||||
|          inductor/test_torchinductor inductor/test_decompose_mem_bound_mm | ||||
|          inductor/test_flex_attention inductor/test_max_autotune' || '' }} | ||||
|     secrets: inherit | ||||
							
								
								
									
										38
									
								
								.github/workflows/rocm.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										38
									
								
								.github/workflows/rocm.yml
									
									
									
									
										vendored
									
									
								
							| @ -26,23 +26,11 @@ jobs: | ||||
|       id-token: write | ||||
|       contents: read | ||||
|  | ||||
|   get-label-type: | ||||
|     name: get-label-type | ||||
|     uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main | ||||
|     if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} | ||||
|     with: | ||||
|       triggering_actor: ${{ github.triggering_actor }} | ||||
|       issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} | ||||
|       curr_branch: ${{ github.head_ref || github.ref_name }} | ||||
|       curr_ref_type: ${{ github.ref_type }} | ||||
|  | ||||
|   linux-jammy-rocm-py3_10-build: | ||||
|     if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} | ||||
|     name: linux-jammy-rocm-py3.10 | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|     needs: get-label-type | ||||
|     with: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build-environment: linux-jammy-rocm-py3.10 | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 | ||||
|       sync-tag: rocm-build | ||||
| @ -71,3 +59,29 @@ jobs: | ||||
|       docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} | ||||
|       test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} | ||||
|     secrets: inherit | ||||
|  | ||||
|   linux-jammy-rocm-py3_10-gfx1100-test: | ||||
|     if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} | ||||
|     permissions: | ||||
|       id-token: write | ||||
|       contents: read | ||||
|     name: linux-jammy-rocm-py3_10-gfx1100 | ||||
|     uses: ./.github/workflows/_rocm-test.yml | ||||
|     needs: | ||||
|       - linux-jammy-rocm-py3_10-build | ||||
|       - target-determination | ||||
|     with: | ||||
|       build-environment: linux-jammy-rocm-py3.10 | ||||
|       docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" }, | ||||
|           { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx1100" }, | ||||
|         ]} | ||||
|       tests-to-include: > | ||||
|          test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs | ||||
|          test_autograd inductor/test_torchinductor inductor/test_kernel_benchmark | ||||
|          inductor/test_pad_mm inductor/test_benchmark_fusion inductor/test_aot_inductor | ||||
|          inductor/test_torchinductor inductor/test_decompose_mem_bound_mm | ||||
|          inductor/test_flex_attention inductor/test_max_autotune | ||||
|     secrets: inherit | ||||
|  | ||||
							
								
								
									
										149
									
								
								.github/workflows/trunk-tagging.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										149
									
								
								.github/workflows/trunk-tagging.yml
									
									
									
									
										vendored
									
									
								
							| @ -58,10 +58,8 @@ jobs: | ||||
|           else | ||||
|             COMMIT_SHA="${{ github.sha }}" | ||||
|           fi | ||||
|           { | ||||
|             echo "sha=${COMMIT_SHA}" | ||||
|             echo "tag_name=trunk/${COMMIT_SHA}" | ||||
|           } >> "${GITHUB_OUTPUT}" | ||||
|           echo "sha=${COMMIT_SHA}" >> "${GITHUB_OUTPUT}" | ||||
|           echo "tag_name=trunk/${COMMIT_SHA}" >> "${GITHUB_OUTPUT}" | ||||
|  | ||||
|       - name: Validate commit SHA | ||||
|         run: | | ||||
| @ -89,7 +87,7 @@ jobs: | ||||
|             echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)" | ||||
|           fi | ||||
|  | ||||
|       - name: Create and push tag(s) with retry | ||||
|       - name: Create and push tag with retry | ||||
|         id: check_tag | ||||
|         env: | ||||
|           TAG_NAME: ${{ steps.commit.outputs.tag_name }} | ||||
| @ -114,23 +112,14 @@ jobs: | ||||
|             return 1 | ||||
|           } | ||||
|  | ||||
|           # Counters for summary reporting | ||||
|           created_count=0 | ||||
|           skipped_count=0 | ||||
|           failed_count=0 | ||||
|           # Exit early if tag already exists | ||||
|           if check_tag_exists; then | ||||
|             echo "✅ Tag already exists - no action needed" | ||||
|             echo "exists=true" >> "${GITHUB_OUTPUT}" | ||||
|             exit 0 | ||||
|           fi | ||||
|  | ||||
|           # Always write outputs once on exit | ||||
|           finish() { | ||||
|             set +e | ||||
|             if [ -n "${GITHUB_OUTPUT:-}" ]; then | ||||
|               { | ||||
|                 echo "created_count=${created_count}" | ||||
|                 echo "skipped_count=${skipped_count}" | ||||
|                 echo "failed_count=${failed_count}" | ||||
|               } >> "${GITHUB_OUTPUT}" | ||||
|             fi | ||||
|           } | ||||
|           trap finish EXIT | ||||
|           echo "Tag ${TAG_NAME} does not exist, proceeding with creation" | ||||
|  | ||||
|           # Retry configuration | ||||
|           MAX_RETRIES=5 | ||||
| @ -205,111 +194,31 @@ jobs: | ||||
|             } | ||||
|           } | ||||
|  | ||||
|           # New behavior for push events: enumerate commits in the push and tag each one. | ||||
|           # For workflow_dispatch, retain existing single-SHA behavior. | ||||
|  | ||||
|           # Always fetch tags once up front to improve idempotency in loops | ||||
|           git fetch origin --tags --quiet || true | ||||
|  | ||||
|           if [ "${{ github.event_name }}" = "push" ]; then | ||||
|             BEFORE_SHA="${{ github.event.before }}" | ||||
|             AFTER_SHA="${{ github.sha }}"  # same as event.after | ||||
|  | ||||
|             # List commits introduced by this push (old..new), oldest first for stable ordering | ||||
|             commits_file="$(mktemp)" | ||||
|             git rev-list --reverse "${BEFORE_SHA}..${AFTER_SHA}" > "${commits_file}" | ||||
|  | ||||
|             if [ ! -s "${commits_file}" ]; then | ||||
|               echo "No new commits found between ${BEFORE_SHA}..${AFTER_SHA}; nothing to tag." | ||||
|               rm -f "${commits_file}" | ||||
|               exit 0 | ||||
|             fi | ||||
|  | ||||
|             commit_count="$(wc -l < "${commits_file}" | tr -d ' ')" | ||||
|             echo "Found ${commit_count} commit(s) to tag for push:" | ||||
|             while IFS= read -r sha; do | ||||
|               printf '  %s\n' "${sha}" | ||||
|             done < "${commits_file}" | ||||
|  | ||||
|             while IFS= read -r sha; do | ||||
|               TAG_NAME="trunk/${sha}" | ||||
|               COMMIT_SHA="${sha}" | ||||
|  | ||||
|               # If tag already exists locally or remotely, skip (idempotent) | ||||
|               if check_tag_exists; then | ||||
|                 echo "✅ Tag ${TAG_NAME} already exists - skipping" | ||||
|                 skipped_count=$((skipped_count + 1)) | ||||
|                 continue | ||||
|               fi | ||||
|  | ||||
|               echo "Tag ${TAG_NAME} does not exist, proceeding with creation" | ||||
|  | ||||
|               if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then | ||||
|                 created_count=$((created_count + 1)) | ||||
|               else | ||||
|                 echo "Tag creation failed after all retry attempts for ${TAG_NAME}" | ||||
|                 failed_count=$((failed_count + 1)) | ||||
|               fi | ||||
|             done < "${commits_file}" | ||||
|  | ||||
|             rm -f "${commits_file}" | ||||
|  | ||||
|             if [ "${failed_count}" -gt 0 ]; then | ||||
|               exit 1 | ||||
|             fi | ||||
|           # Execute with retry | ||||
|           if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then | ||||
|             echo "exists=false" >> "${GITHUB_OUTPUT}" | ||||
|             exit 0 | ||||
|           else | ||||
|             # workflow_dispatch path (single SHA tagging preserved) | ||||
|  | ||||
|             # Exit early if tag already exists | ||||
|             if check_tag_exists; then | ||||
|               echo "✅ Tag already exists - no action needed" | ||||
|               skipped_count=1 | ||||
|               exit 0 | ||||
|             fi | ||||
|  | ||||
|             echo "Tag ${TAG_NAME} does not exist, proceeding with creation" | ||||
|  | ||||
|             if retry_with_backoff "tag_with_retry" "Creating tag ${TAG_NAME} for commit ${COMMIT_SHA}"; then | ||||
|               created_count=1 | ||||
|               exit 0 | ||||
|             else | ||||
|               echo "Tag creation failed after all retry attempts" | ||||
|               failed_count=1 | ||||
|               exit 1 | ||||
|             fi | ||||
|             echo "Tag creation failed after all retry attempts" | ||||
|             exit 1 | ||||
|           fi | ||||
|  | ||||
|       - name: Tag creation summary | ||||
|         if: always() | ||||
|         run: | | ||||
|           if [ "${{ github.event_name }}" = "push" ]; then | ||||
|             echo "Trigger: push on main" | ||||
|             echo "Created: ${{ steps.check_tag.outputs.created_count }}" | ||||
|             echo "Skipped (already existed): ${{ steps.check_tag.outputs.skipped_count }}" | ||||
|             echo "Failed: ${{ steps.check_tag.outputs.failed_count }}" | ||||
|             if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then | ||||
|               echo "✅ Completed tagging for push range ${{ github.event.before }}..${{ github.sha }}" | ||||
|             else | ||||
|               echo "❌ Some tags failed to create for push range ${{ github.event.before }}..${{ github.sha }}" | ||||
|             fi | ||||
|           if [ "${{ steps.check_tag.outputs.exists }}" = "true" ]; then | ||||
|             echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed" | ||||
|           elif [ "${{ job.status }}" = "success" ]; then | ||||
|             echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}" | ||||
|           else | ||||
|             if [ "${{ steps.check_tag.outputs.failed_count }}" = "0" ]; then | ||||
|               if [ "${{ steps.check_tag.outputs.created_count }}" = "0" ]; then | ||||
|                 echo "✅ Tag ${{ steps.commit.outputs.tag_name }} already existed - no action needed" | ||||
|               else | ||||
|                 echo "✅ Successfully created tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}" | ||||
|               fi | ||||
|             else | ||||
|               echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}" | ||||
|             fi | ||||
|  | ||||
|             echo "" | ||||
|             echo "Tag details:" | ||||
|             echo "  Name: ${{ steps.commit.outputs.tag_name }}" | ||||
|             echo "  Commit: ${{ steps.commit.outputs.sha }}" | ||||
|             echo "  Trigger: ${{ github.event_name }}" | ||||
|             if [ -n "${{ github.event.inputs.commit_sha }}" ]; then | ||||
|               echo "  Manual commit: ${{ github.event.inputs.commit_sha }}" | ||||
|             fi | ||||
|             echo "❌ Failed to create tag ${{ steps.commit.outputs.tag_name }} for commit ${{ steps.commit.outputs.sha }}" | ||||
|           fi | ||||
|  | ||||
|           echo "" | ||||
|           echo "Tag details:" | ||||
|           echo "  Name: ${{ steps.commit.outputs.tag_name }}" | ||||
|           echo "  Commit: ${{ steps.commit.outputs.sha }}" | ||||
|           echo "  Trigger: ${{ github.event_name }}" | ||||
|           if [ -n "${{ github.event.inputs.commit_sha }}" ]; then | ||||
|             echo "  Manual commit: ${{ github.event.inputs.commit_sha }}" | ||||
|           fi | ||||
|  | ||||
							
								
								
									
										51
									
								
								.github/workflows/trunk.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										51
									
								
								.github/workflows/trunk.yml
									
									
									
									
										vendored
									
									
								
							| @ -190,40 +190,6 @@ jobs: | ||||
|       runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" | ||||
|     secrets: inherit | ||||
|  | ||||
|   linux-jammy-rocm-py3_10-build: | ||||
|     if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }} | ||||
|     name: linux-jammy-rocm-py3.10 | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|     needs: get-label-type | ||||
|     with: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build-environment: linux-jammy-rocm-py3.10 | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 | ||||
|       sync-tag: rocm-build | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, | ||||
|           { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, | ||||
|         ]} | ||||
|     secrets: inherit | ||||
|  | ||||
|   linux-jammy-rocm-py3_10-test: | ||||
|     if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }} | ||||
|     permissions: | ||||
|       id-token: write | ||||
|       contents: read | ||||
|     name: linux-jammy-rocm-py3.10 | ||||
|     uses: ./.github/workflows/_rocm-test.yml | ||||
|     needs: | ||||
|       - linux-jammy-rocm-py3_10-build | ||||
|       - target-determination | ||||
|     with: | ||||
|       build-environment: linux-jammy-rocm-py3.10 | ||||
|       docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} | ||||
|       test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} | ||||
|       tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor" | ||||
|     secrets: inherit | ||||
|  | ||||
|   inductor-build: | ||||
|     name: inductor-build | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
| @ -234,23 +200,6 @@ jobs: | ||||
|       cuda-arch-list: '8.0' | ||||
|     secrets: inherit | ||||
|  | ||||
|   # Test cross-compiled models with Windows libs extracted from wheel | ||||
|   cross-compile-linux-test: | ||||
|     name: cross-compile-linux-test | ||||
|     uses: ./.github/workflows/_linux-test.yml | ||||
|     needs: | ||||
|       - linux-jammy-cuda12_8-py3_10-gcc11-build | ||||
|       - get-label-type | ||||
|       - win-vs2022-cuda12_8-py3-build | ||||
|     with: | ||||
|       build-environment: linux-jammy-cuda12.8-py3.10-gcc11 | ||||
|       docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }} | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "aoti_cross_compile_for_windows", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", win_torch_wheel_artifact: "win-vs2022-cuda12.8-py3" }, | ||||
|         ]} | ||||
|     secrets: inherit | ||||
|  | ||||
|   verify-cachebench-cpu-build: | ||||
|     name: verify-cachebench-cpu-build | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|  | ||||
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @ -374,7 +374,6 @@ third_party/ruy/ | ||||
| third_party/glog/ | ||||
|  | ||||
| # Virtualenv | ||||
| .venv/ | ||||
| venv/ | ||||
|  | ||||
| # Log files | ||||
|  | ||||
| @ -1138,8 +1138,11 @@ command = [ | ||||
| [[linter]] | ||||
| code = 'WORKFLOWSYNC' | ||||
| include_patterns = [ | ||||
|     '.github/workflows/*.yml', | ||||
|     '.github/workflows/*.yaml', | ||||
|     '.github/workflows/pull.yml', | ||||
|     '.github/workflows/trunk.yml', | ||||
|     '.github/workflows/periodic.yml', | ||||
|     '.github/workflows/mac-mps.yml', | ||||
|     '.github/workflows/slow.yml', | ||||
| ] | ||||
| command = [ | ||||
|     'python3', | ||||
|  | ||||
							
								
								
									
										14
									
								
								CODEOWNERS
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								CODEOWNERS
									
									
									
									
									
								
							| @ -201,17 +201,3 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A | ||||
| /torch/csrc/stable/ @janeyx99 @mikaylagawarecki | ||||
| /torch/headeronly/ @janeyx99 | ||||
| /torch/header_only_apis.txt @janeyx99 | ||||
|  | ||||
| # FlexAttention | ||||
| /torch/nn/attention/flex_attention.py @drisspg | ||||
| /torch/_higher_order_ops/flex_attention.py @drisspg | ||||
| /torch/_inductor/kernel/flex/ @drisspg | ||||
| /torch/_inductor/codegen/cpp_flex_attention_template.py @drisspg | ||||
| /test/inductor/test_flex_attention.py @drisspg | ||||
| /test/inductor/test_flex_decoding.py @drisspg | ||||
|  | ||||
| # Low Precision GEMMs | ||||
| /aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58 | ||||
| /aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58 | ||||
| /aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58 | ||||
| /test/test_scaled_matmul_cuda.py @drisspg @slayton58 | ||||
|  | ||||
| @ -289,15 +289,14 @@ IF(USE_FBGEMM_GENAI) | ||||
|  | ||||
|     set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON) | ||||
|  | ||||
|     set(fbgemm_genai_cuh | ||||
|     set(fbgemm_genai_mx8mx8bf16_grouped | ||||
|       "${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/" | ||||
|       "${FBGEMM_GENAI_SRCS}/" | ||||
|     ) | ||||
|  | ||||
|     target_include_directories(fbgemm_genai PRIVATE | ||||
|       ${FBGEMM_THIRD_PARTY}/cutlass/include | ||||
|       ${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include | ||||
|       ${fbgemm_genai_cuh} | ||||
|       ${fbgemm_genai_mx8mx8bf16_grouped} | ||||
|       ${FBGEMM_GENAI_SRCS}/common/include/   # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp | ||||
|       ${FBGEMM_GENAI_SRCS}/include/          # includes fbgemm_gpu/torch_ops.h | ||||
|     ) | ||||
| @ -314,14 +313,13 @@ IF(USE_FBGEMM_GENAI) | ||||
|  | ||||
|     # Add additional HIPCC compiler flags for performance | ||||
|     set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS | ||||
|       -mllvm | ||||
|       -amdgpu-coerce-illegal-types=1 | ||||
|       -mllvm | ||||
|       -enable-post-misched=0 | ||||
|       -mllvm | ||||
|       -greedy-reverse-local-assignment=1 | ||||
|       -fhip-new-launch-api) | ||||
|     if(DEFINED ROCM_VERSION_DEV AND ROCM_VERSION_DEV VERSION_LESS "7.2.0") | ||||
|         list(PREPEND FBGEMM_GENAI_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-coerce-illegal-types=1) | ||||
|       endif() | ||||
|  | ||||
|     # Only compile for gfx942 for now. | ||||
|     # This is rather hacky, I could not figure out a clean solution :( | ||||
|  | ||||
| @ -19,7 +19,6 @@ | ||||
| #include <ATen/detail/MPSHooksInterface.h> | ||||
| #include <ATen/detail/MTIAHooksInterface.h> | ||||
| #include <ATen/detail/PrivateUse1HooksInterface.h> | ||||
| #include <ATen/detail/XLAHooksInterface.h> | ||||
| #include <ATen/detail/XPUHooksInterface.h> | ||||
| #include <c10/core/QEngine.h> | ||||
| #include <c10/core/impl/DeviceGuardImplInterface.h> | ||||
| @ -89,8 +88,6 @@ class TORCH_API Context { | ||||
|       return at::detail::getHIPHooks(); | ||||
|     } else if (opt_device_type == at::kHPU) { | ||||
|       return at::detail::getHPUHooks(); | ||||
|     } else if (opt_device_type == at::kXLA) { | ||||
|       return at::detail::getXLAHooks(); | ||||
|     } else { | ||||
|       TORCH_CHECK( | ||||
|           false, | ||||
| @ -199,7 +196,7 @@ class TORCH_API Context { | ||||
|     return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU); | ||||
|   } | ||||
|   static bool hasXLA() { | ||||
|     return detail::getXLAHooks().hasXLA(); | ||||
|     return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA); | ||||
|   } | ||||
|   static bool hasXPU() { | ||||
|     return detail::getXPUHooks().hasXPU(); | ||||
|  | ||||
| @ -2,6 +2,7 @@ | ||||
|  | ||||
| #include <mutex> | ||||
| #include <ATen/CachedTensorUtils.h> | ||||
| #include <c10/core/GradMode.h> | ||||
| #include <c10/util/flat_hash_map.h> | ||||
|  | ||||
| namespace at::autocast { | ||||
| @ -36,10 +37,29 @@ namespace { | ||||
| using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>; | ||||
| using val_type = std::tuple<weakref_type, Tensor>; | ||||
|  | ||||
| ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() { | ||||
|   static ska::flat_hash_map<TensorImpl*, val_type> cached_casts; | ||||
|   return cached_casts; | ||||
| // We maintain separate caches for gradient-enabled and gradient-disabled modes. | ||||
| // This ensures that tensors cached in torch.no_grad() (with requires_grad=False) | ||||
| // are not incorrectly reused in gradient-enabled contexts. | ||||
| // This fixes issue #158232 while maintaining optimal performance for both modes. | ||||
| static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts_grad_enabled() { | ||||
|   static ska::flat_hash_map<TensorImpl*, val_type> cached_casts_grad_enabled; | ||||
|   return cached_casts_grad_enabled; | ||||
| } | ||||
|  | ||||
| static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts_grad_disabled() { | ||||
|   static ska::flat_hash_map<TensorImpl*, val_type> cached_casts_grad_disabled; | ||||
|   return cached_casts_grad_disabled; | ||||
| } | ||||
|  | ||||
| // Helper function to get the appropriate cache based on current gradient mode. | ||||
| // This allows us to cache tensors separately for grad-enabled and grad-disabled contexts, | ||||
| // preventing incorrect cache hits when gradient mode changes. | ||||
| static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() { | ||||
|   return at::GradMode::is_enabled() ? | ||||
|     get_cached_casts_grad_enabled() : | ||||
|     get_cached_casts_grad_disabled(); | ||||
| } | ||||
|  | ||||
| std::mutex cached_casts_mutex; | ||||
|  | ||||
|  | ||||
| @ -86,7 +106,9 @@ thread_local bool cache_enabled = true; | ||||
|  | ||||
| void clear_cache() { | ||||
|   const std::lock_guard<std::mutex> lock(cached_casts_mutex); | ||||
|   get_cached_casts().clear(); | ||||
|   // Clear both caches to ensure consistent behavior regardless of current gradient mode | ||||
|   get_cached_casts_grad_enabled().clear(); | ||||
|   get_cached_casts_grad_disabled().clear(); | ||||
| } | ||||
|  | ||||
| int increment_nesting() { | ||||
| @ -121,6 +143,11 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_ | ||||
|   if (is_eligible(arg, device_type) && (arg.scalar_type() != to_type)) { | ||||
|     // Heuristic:  Do what Apex does, and cache lower_precision_fp casts of fp32 model weights (leaves). | ||||
|     // See cached_casts declaration above for detailed strategy. | ||||
|     // | ||||
|     // We maintain separate caches for gradient-enabled and gradient-disabled modes | ||||
|     // (see get_cached_casts() above). This ensures correctness when mixing torch.no_grad() | ||||
|     // with torch.autocast(), while maintaining optimal performance for both training and inference. | ||||
|     // This fixes issue #158232 without any performance regression. | ||||
|     bool can_try_cache = (to_type == get_lower_precision_fp_from_device_type(device_type) && | ||||
|                          arg.scalar_type() == at::kFloat && arg.requires_grad() && | ||||
|                          arg.is_leaf() && !arg.is_view() && cache_enabled && | ||||
|  | ||||
| @ -39,7 +39,7 @@ struct HostBlock { | ||||
| }; | ||||
|  | ||||
| template <typename B> | ||||
| struct alignas(hardware_destructive_interference_size) FreeBlockList { | ||||
| struct alignas(64) FreeBlockList { | ||||
|   std::mutex mutex_; | ||||
|   std::deque<B*> list_; | ||||
| }; | ||||
| @ -122,7 +122,7 @@ struct TORCH_API HostStats { | ||||
| // Struct containing memory allocator summary statistics for host, as they | ||||
| // are staged for reporting. This is a temporary struct that is used to | ||||
| // avoid locking the allocator while collecting stats. | ||||
| struct alignas(hardware_destructive_interference_size) HostStatsStaged { | ||||
| struct alignas(64) HostStatsStaged { | ||||
|   std::mutex timing_mutex_; | ||||
|   // COUNT: total allocations (active + free) | ||||
|   // LOCK: access to this stat is protected by the allocator's blocks_mutex_ | ||||
| @ -669,7 +669,7 @@ struct CachingHostAllocatorImpl { | ||||
|     TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event"); | ||||
|   } | ||||
|  | ||||
|   alignas(hardware_destructive_interference_size) std::mutex blocks_mutex_; | ||||
|   alignas(64) std::mutex blocks_mutex_; | ||||
|   ska::flat_hash_set<B*> blocks_; // block list | ||||
|   ska::flat_hash_map<void*, B*> ptr_to_block_; | ||||
|  | ||||
| @ -677,17 +677,17 @@ struct CachingHostAllocatorImpl { | ||||
|   // size. This allows us to quickly find a free block of the right size. | ||||
|   // We use deque to store per size free list and guard the list with its own | ||||
|   // mutex. | ||||
|   alignas(hardware_destructive_interference_size) std::vector<FreeBlockList<B>> free_list_ = | ||||
|   alignas(64) std::vector<FreeBlockList<B>> free_list_ = | ||||
|       std::vector<FreeBlockList<B>>(MAX_SIZE_INDEX); | ||||
|  | ||||
|   alignas(hardware_destructive_interference_size) std::mutex events_mutex_; | ||||
|   alignas(64) std::mutex events_mutex_; | ||||
|   std::deque<std::pair<E, B*>> events_; // event queue paired with block | ||||
|  | ||||
|   // Indicates whether the object is active. | ||||
|   // Set to false in the destructor to signal background threads to stop. | ||||
|   std::atomic<bool> active_{true}; | ||||
| protected: | ||||
|   alignas(hardware_destructive_interference_size) HostStatsStaged stats_; | ||||
|   alignas(64) HostStatsStaged stats_; | ||||
| }; | ||||
|  | ||||
| struct TORCH_API HostAllocator : public at::Allocator { | ||||
|  | ||||
| @ -59,7 +59,9 @@ struct TORCH_API Generator { | ||||
|  | ||||
|   explicit Generator(c10::intrusive_ptr<c10::GeneratorImpl> gen_impl) | ||||
|    : impl_(std::move(gen_impl)) { | ||||
|     TORCH_CHECK(impl_.get(), "GeneratorImpl with nullptr is not supported"); | ||||
|     if (impl_.get() == nullptr) { | ||||
|       throw std::runtime_error("GeneratorImpl with nullptr is not supported"); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   bool operator==(const Generator& rhs) const { | ||||
|  | ||||
| @ -229,10 +229,10 @@ private: | ||||
|   } | ||||
|  | ||||
|  | ||||
|   static constexpr uint32_t kPhilox10A = 0x9E3779B9; | ||||
|   static constexpr uint32_t kPhilox10B = 0xBB67AE85; | ||||
|   static constexpr uint32_t kPhiloxSA = 0xD2511F53; | ||||
|   static constexpr uint32_t kPhiloxSB = 0xCD9E8D57; | ||||
|   static const uint32_t kPhilox10A = 0x9E3779B9; | ||||
|   static const uint32_t kPhilox10B = 0xBB67AE85; | ||||
|   static const uint32_t kPhiloxSA = 0xD2511F53; | ||||
|   static const uint32_t kPhiloxSB = 0xCD9E8D57; | ||||
| }; | ||||
|  | ||||
| typedef philox_engine Philox4_32; | ||||
|  | ||||
| @ -111,7 +111,9 @@ class TORCH_API TensorBase { | ||||
|   explicit TensorBase( | ||||
|       c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl) | ||||
|       : impl_(std::move(tensor_impl)) { | ||||
|     TORCH_CHECK(impl_.get(), "TensorImpl with nullptr is not supported"); | ||||
|     if (impl_.get() == nullptr) { | ||||
|       throw std::runtime_error("TensorImpl with nullptr is not supported"); | ||||
|     } | ||||
|   } | ||||
|   TensorBase(const TensorBase&) = default; | ||||
|   TensorBase(TensorBase&&) noexcept = default; | ||||
|  | ||||
| @ -68,7 +68,11 @@ Symbol InternedStrings::_symbol(const std::string& s) { | ||||
|     return it->second; | ||||
|  | ||||
|   auto pos = s.find("::"); | ||||
|   TORCH_CHECK(pos != std::string::npos, "all symbols must have a namespace, <namespace>::<string>, but found: ", s); | ||||
|   if (pos == std::string::npos) { | ||||
|     std::stringstream ss; | ||||
|     ss << "all symbols must have a namespace, <namespace>::<string>, but found: " << s; | ||||
|     throw std::runtime_error(ss.str()); | ||||
|   } | ||||
|   Symbol ns = _symbol("namespaces::" + s.substr(0, pos)); | ||||
|  | ||||
|   Symbol sym(sym_to_info_.size()); | ||||
| @ -117,7 +121,12 @@ std::string Symbol::domainString() const { | ||||
| } | ||||
|  | ||||
| Symbol Symbol::fromDomainAndUnqualString(const std::string & d, const std::string & s) { | ||||
|   TORCH_CHECK(d.compare(0, domain_prefix().size(), domain_prefix()) == 0, "Symbol: domain string is expected to be prefixed with '", domain_prefix(), "', e.g. 'org.pytorch.aten'"); | ||||
|   if (d.compare(0, domain_prefix().size(), domain_prefix()) != 0) { | ||||
|     std::ostringstream ss; | ||||
|     ss << "Symbol: domain string is expected to be prefixed with '" | ||||
|        << domain_prefix() << "', e.g. 'org.pytorch.aten'"; | ||||
|     throw std::runtime_error(ss.str()); | ||||
|   } | ||||
|   std::string qualString = d.substr(domain_prefix().size()) + "::" + s; | ||||
|   return fromQualString(qualString); | ||||
| } | ||||
|  | ||||
| @ -7,7 +7,6 @@ | ||||
| #include <ATen/core/jit_type.h> | ||||
| #include <ATen/core/stack.h> | ||||
| #include <ATen/core/type_factory.h> | ||||
| #include <c10/util/Exception.h> | ||||
| #include <c10/util/StringUtil.h> | ||||
| #include <c10/util/hash.h> | ||||
| #include <c10/util/irange.h> | ||||
| @ -413,7 +412,7 @@ size_t IValue::hash(const IValue& v) { | ||||
|     case Tag::Enum: | ||||
|     case Tag::Stream: | ||||
|     case Tag::Uninitialized: | ||||
|       TORCH_CHECK(false, | ||||
|       throw std::runtime_error( | ||||
|           "unhashable type: '" + v.type()->repr_str() + "'"); | ||||
|   } | ||||
|   // the above switch should be exhaustive | ||||
|  | ||||
| @ -8,7 +8,6 @@ | ||||
| #include <ATen/core/type_factory.h> | ||||
| #include <ATen/core/qualified_name.h> | ||||
| #include <c10/util/TypeList.h> | ||||
| #include <c10/util/Exception.h> | ||||
| #include <optional> | ||||
| #include <c10/core/SymFloat.h> | ||||
| #include <c10/core/SymBool.h> | ||||
| @ -117,8 +116,10 @@ struct SingleElementType : public SharedType { | ||||
|  | ||||
|  protected: | ||||
|   SingleElementType(TypePtr elem) : SharedType(Kind), elem(std::move(elem)) { | ||||
|     TORCH_CHECK(this->elem, c10::str( | ||||
|     if (!this->elem) { | ||||
|       throw std::runtime_error(c10::str( | ||||
|             "Can not create ", typeKindToString(Kind), " with None type")); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|  private: | ||||
| @ -415,12 +416,16 @@ struct TORCH_API SymbolicShape { | ||||
|   } | ||||
|  | ||||
|   ShapeSymbol operator[](size_t i) const { | ||||
|     TORCH_CHECK(dims_, "Rank isn't fixed"); | ||||
|     if (!dims_) { | ||||
|       throw std::runtime_error("Rank isn't fixed"); | ||||
|     } | ||||
|     return (*dims_).at(i); | ||||
|   } | ||||
|  | ||||
|   ShapeSymbol at(size_t i) const { | ||||
|     TORCH_CHECK(dims_, "Rank isn't fixed"); | ||||
|     if (!dims_) { | ||||
|       throw std::runtime_error("Rank isn't fixed"); | ||||
|     } | ||||
|     return (*dims_).at(i); | ||||
|   } | ||||
|  | ||||
| @ -515,7 +520,9 @@ struct VaryingShape { | ||||
|   } | ||||
|  | ||||
|   const std::optional<T> &operator[](size_t i) const { | ||||
|     TORCH_CHECK(dims_, "Rank isn't fixed"); | ||||
|     if (!dims_) { | ||||
|       throw std::runtime_error("Rank isn't fixed"); | ||||
|     } | ||||
|     return (*dims_).at(i); | ||||
|   } | ||||
|  | ||||
| @ -950,7 +957,9 @@ struct TORCH_API DictType : public SharedType { | ||||
|  | ||||
|   TypePtr createWithContained( | ||||
|       std::vector<TypePtr> contained_types) const override { | ||||
|     TORCH_CHECK(contained_types.size() == 2, "Expected 2 contained types"); | ||||
|     if (contained_types.size() != 2) { | ||||
|       throw std::runtime_error("Expected 2 contained types"); | ||||
|     } | ||||
|     return create(std::move(contained_types.at(0)), std::move(contained_types.at(1))); | ||||
|   } | ||||
|  | ||||
|  | ||||
| @ -8,7 +8,6 @@ | ||||
| #include <ATen/core/jit_type.h> | ||||
| #include <c10/macros/Macros.h> | ||||
| #include <c10/util/env.h> | ||||
| #include <c10/util/Exception.h> | ||||
| #include <c10/util/flat_hash_map.h> | ||||
| #include <c10/util/irange.h> | ||||
| #include <array> | ||||
| @ -827,7 +826,9 @@ TupleType::TupleType( | ||||
|     : NamedType(TypeKind::TupleType, std::move(name)), | ||||
|       elements_(std::move(elements)), | ||||
|       has_free_variables_(std::any_of(elements_.begin(), elements_.end(), [](const TypePtr& v) { | ||||
|         TORCH_CHECK(v, "Can not create tuple with None type"); | ||||
|         if (!v) { | ||||
|           throw std::runtime_error("Can not create tuple with None type"); | ||||
|         } | ||||
|         return v->hasFreeVariables(); | ||||
|       })), schema_(std::move(schema)) { | ||||
|  | ||||
|  | ||||
| @ -6,11 +6,8 @@ | ||||
| #ifdef __aarch64__ | ||||
| #if !defined(CPU_CAPABILITY_SVE) | ||||
| #include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h> | ||||
| #include <ATen/cpu/vec/vec128/vec128_double_neon.h> | ||||
| #include <ATen/cpu/vec/vec128/vec128_float_neon.h> | ||||
| #include <ATen/cpu/vec/vec128/vec128_half_neon.h> | ||||
| #include <ATen/cpu/vec/vec128/vec128_int_aarch64.h> | ||||
| #include <ATen/cpu/vec/vec128/vec128_uint_aarch64.h> | ||||
| #endif | ||||
|  | ||||
| #include <ATen/cpu/vec/vec128/vec128_convert.h> | ||||
|  | ||||
| @ -354,47 +354,9 @@ class Vectorized<c10::BFloat16> : public Vectorized16< | ||||
|  | ||||
|   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs) | ||||
|   Vectorized frac() const; | ||||
|   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg) | ||||
|   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc) | ||||
|   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt) | ||||
|  | ||||
| #ifdef __ARM_FEATURE_BF16 | ||||
|   Vectorized<c10::BFloat16> neg() const { | ||||
|     return -values; | ||||
|   } | ||||
|   Vectorized<c10::BFloat16> reciprocal() const { | ||||
|     return 1.0f / values; | ||||
|   } | ||||
|   Vectorized<c10::BFloat16> operator==( | ||||
|       const Vectorized<c10::BFloat16>& other) const { | ||||
|     return values == other.values; | ||||
|   } | ||||
|  | ||||
|   Vectorized<c10::BFloat16> operator!=( | ||||
|       const Vectorized<c10::BFloat16>& other) const { | ||||
|     return values != other.values; | ||||
|   } | ||||
|  | ||||
|   Vectorized<c10::BFloat16> operator<( | ||||
|       const Vectorized<c10::BFloat16>& other) const { | ||||
|     return values < other.values; | ||||
|   } | ||||
|  | ||||
|   Vectorized<c10::BFloat16> operator<=( | ||||
|       const Vectorized<c10::BFloat16>& other) const { | ||||
|     return values <= other.values; | ||||
|   } | ||||
|  | ||||
|   Vectorized<c10::BFloat16> operator>( | ||||
|       const Vectorized<c10::BFloat16>& other) const { | ||||
|     return values > other.values; | ||||
|   } | ||||
|  | ||||
|   Vectorized<c10::BFloat16> operator>=( | ||||
|       const Vectorized<c10::BFloat16>& other) const { | ||||
|     return values >= other.values; | ||||
|   } | ||||
| #else | ||||
|   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg) | ||||
|   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal) | ||||
|   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==) | ||||
|   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=) | ||||
| @ -402,7 +364,6 @@ class Vectorized<c10::BFloat16> : public Vectorized16< | ||||
|   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=) | ||||
|   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>) | ||||
|   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=) | ||||
| #endif | ||||
|  | ||||
| #undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD | ||||
| #undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD | ||||
| @ -451,52 +412,28 @@ template <> | ||||
| Vectorized<c10::BFloat16> inline operator+( | ||||
|     const Vectorized<c10::BFloat16>& a, | ||||
|     const Vectorized<c10::BFloat16>& b) { | ||||
| #ifdef __ARM_FEATURE_BF16 | ||||
|   bfloat16x8_t x = a; | ||||
|   bfloat16x8_t y = b; | ||||
|   return x + y; | ||||
| #else | ||||
|   return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<c10::BFloat16> inline operator-( | ||||
|     const Vectorized<c10::BFloat16>& a, | ||||
|     const Vectorized<c10::BFloat16>& b) { | ||||
| #ifdef __ARM_FEATURE_BF16 | ||||
|   bfloat16x8_t x = a; | ||||
|   bfloat16x8_t y = b; | ||||
|   return x - y; | ||||
| #else | ||||
|   return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<c10::BFloat16> inline operator*( | ||||
|     const Vectorized<c10::BFloat16>& a, | ||||
|     const Vectorized<c10::BFloat16>& b) { | ||||
| #ifdef __ARM_FEATURE_BF16 | ||||
|   bfloat16x8_t x = a; | ||||
|   bfloat16x8_t y = b; | ||||
|   return x * y; | ||||
| #else | ||||
|   return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<c10::BFloat16> inline operator/( | ||||
|     const Vectorized<c10::BFloat16>& a, | ||||
|     const Vectorized<c10::BFloat16>& b) { | ||||
| #ifdef __ARM_FEATURE_BF16 | ||||
|   bfloat16x8_t x = a; | ||||
|   bfloat16x8_t y = b; | ||||
|   return x / y; | ||||
| #else | ||||
|   return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| // frac. Implement this here so we can use subtraction | ||||
| @ -607,19 +544,12 @@ Vectorized<c10::BFloat16> inline fmadd( | ||||
|     const Vectorized<c10::BFloat16>& a, | ||||
|     const Vectorized<c10::BFloat16>& b, | ||||
|     const Vectorized<c10::BFloat16>& c) { | ||||
| #ifdef __ARM_FEATURE_BF16 | ||||
|   bfloat16x8_t x = a; | ||||
|   bfloat16x8_t y = b; | ||||
|   bfloat16x8_t z = c; | ||||
|   return x * y + z; | ||||
| #else | ||||
|   // NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16!  Also, | ||||
|   // vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered | ||||
|   // elements, not the bottom and top half, so they don't seem | ||||
|   // particularly useful here. Ideally we would include dot product in | ||||
|   // the Vectorized interface... | ||||
|   return a * b + c; | ||||
| #endif | ||||
| } | ||||
|  | ||||
| template <> | ||||
| @ -627,15 +557,8 @@ Vectorized<c10::BFloat16> inline fnmadd( | ||||
|     const Vectorized<c10::BFloat16>& a, | ||||
|     const Vectorized<c10::BFloat16>& b, | ||||
|     const Vectorized<c10::BFloat16>& c) { | ||||
| #ifdef __ARM_FEATURE_BF16 | ||||
|   bfloat16x8_t x = a; | ||||
|   bfloat16x8_t y = b; | ||||
|   bfloat16x8_t z = c; | ||||
|   return (-x) * y + z; | ||||
| #else | ||||
|   // See NOTE [BF16 FMA] above. | ||||
|   return -a * b + c; | ||||
| #endif | ||||
| } | ||||
|  | ||||
| template <> | ||||
| @ -643,15 +566,8 @@ Vectorized<c10::BFloat16> inline fmsub( | ||||
|     const Vectorized<c10::BFloat16>& a, | ||||
|     const Vectorized<c10::BFloat16>& b, | ||||
|     const Vectorized<c10::BFloat16>& c) { | ||||
| #ifdef __ARM_FEATURE_BF16 | ||||
|   bfloat16x8_t x = a; | ||||
|   bfloat16x8_t y = b; | ||||
|   bfloat16x8_t z = c; | ||||
|   return x * y - z; | ||||
| #else | ||||
|   // See NOTE [BF16 FMA] above. | ||||
|   return a * b - c; | ||||
| #endif | ||||
| } | ||||
|  | ||||
| template <> | ||||
| @ -659,15 +575,8 @@ Vectorized<c10::BFloat16> inline fnmsub( | ||||
|     const Vectorized<c10::BFloat16>& a, | ||||
|     const Vectorized<c10::BFloat16>& b, | ||||
|     const Vectorized<c10::BFloat16>& c) { | ||||
| #ifdef __ARM_FEATURE_BF16 | ||||
|   bfloat16x8_t x = a; | ||||
|   bfloat16x8_t y = b; | ||||
|   bfloat16x8_t z = c; | ||||
|   return (-x) * y - z; | ||||
| #else | ||||
|   // See NOTE [BF16 FMA] above. | ||||
|   return -a * b - c; | ||||
| #endif | ||||
| } | ||||
|  | ||||
| #endif // !defined(C10_MOBILE) && defined(__aarch64__) | ||||
|  | ||||
| @ -1,586 +0,0 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <ATen/cpu/vec/intrinsics.h> | ||||
| #include <ATen/cpu/vec/vec_base.h> | ||||
| #include <c10/macros/Macros.h> | ||||
| #include <c10/util/irange.h> | ||||
| #include <cmath> | ||||
|  | ||||
| namespace at::vec { | ||||
| // Note [CPU_CAPABILITY namespace] | ||||
| // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
| // This header, and all of its subheaders, will be compiled with | ||||
| // different architecture flags for each supported set of vector | ||||
| // intrinsics. So we need to make sure they aren't inadvertently | ||||
| // linked together. We do this by declaring objects in an `inline | ||||
| // namespace` which changes the name mangling, but can still be | ||||
| // accessed as `at::vec`. | ||||
| inline namespace CPU_CAPABILITY { | ||||
|  | ||||
| template <> | ||||
| struct is_vec_specialized_for<double> : std::bool_constant<true> {}; | ||||
|  | ||||
| template <> | ||||
| class Vectorized<double> { | ||||
|  private: | ||||
|   float64x2_t values; | ||||
|  | ||||
|  public: | ||||
|   using value_type = double; | ||||
|   using size_type = int; | ||||
|   static constexpr size_type size() { | ||||
|     return 2; | ||||
|   } | ||||
|   Vectorized() { | ||||
|     values = vdupq_n_f64(0.0); | ||||
|   } | ||||
|   Vectorized(float64x2_t v) : values(v) {} | ||||
|   Vectorized(double val) { | ||||
|     values = vdupq_n_f64(val); | ||||
|   } | ||||
|   template < | ||||
|       typename... Args, | ||||
|       typename = std::enable_if_t<(sizeof...(Args) == size())>> | ||||
|   Vectorized(Args... vals) { | ||||
|     __at_align__ double buffer[size()] = {vals...}; | ||||
|     values = vld1q_f64(buffer); | ||||
|   } | ||||
|   operator float64x2_t() const { | ||||
|     return values; | ||||
|   } | ||||
|   template <int64_t mask> | ||||
|   static Vectorized<double> blend( | ||||
|       const Vectorized<double>& a, | ||||
|       const Vectorized<double>& b) { | ||||
|     // Build an array of flags: each bit of element is 1 if the corresponding | ||||
|     // bit in 'mask' is set, 0 otherwise. | ||||
|     uint64x2_t maskArray = { | ||||
|         (mask & 1ULL) ? 0xFFFFFFFFFFFFFFFF : 0, | ||||
|         (mask & 2ULL) ? 0xFFFFFFFFFFFFFFFF : 0}; | ||||
|     // Use BSL to select elements from b where the mask is 1, else from a | ||||
|     return vbslq_f64(maskArray, b.values, a.values); | ||||
|   } | ||||
|   static Vectorized<double> blendv( | ||||
|       const Vectorized<double>& a, | ||||
|       const Vectorized<double>& b, | ||||
|       const Vectorized<double>& mask_) { | ||||
|     return vbslq_f64(vreinterpretq_u64_f64(mask_.values), b.values, a.values); | ||||
|   } | ||||
|   template <typename step_t> | ||||
|   static Vectorized<double> arange( | ||||
|       double base = 0., | ||||
|       step_t step = static_cast<step_t>(1)) { | ||||
|     return {base, base + static_cast<double>(step)}; | ||||
|   } | ||||
|   static inline Vectorized<double> set( | ||||
|       const Vectorized<double>& a, | ||||
|       const Vectorized<double>& b, | ||||
|       int64_t count = size()) { | ||||
|     if (count == 0) { | ||||
|       return a; | ||||
|     } else if (count >= 2) { | ||||
|       return b; | ||||
|     } else { | ||||
|       float64x2_t c = {b.values[0], a.values[1]}; | ||||
|       return c; | ||||
|     } | ||||
|   } | ||||
|   static Vectorized<double> loadu(const void* ptr, int64_t count = size()) { | ||||
|     if (count == size()) { | ||||
|       return vld1q_f64(reinterpret_cast<const double*>(ptr)); | ||||
|     } else if (count == 1) { | ||||
|       float64x1_t x = vld1_f64(reinterpret_cast<const double*>(ptr)); | ||||
|       float64x1_t z = {0.0}; | ||||
|       return vcombine_f64(x, z); | ||||
|     } else { | ||||
|       return vdupq_n_f64(0.0); | ||||
|     } | ||||
|   } | ||||
|   void store(void* ptr, int64_t count = size()) const { | ||||
|     if (count == size()) { | ||||
|       vst1q_f64(reinterpret_cast<double*>(ptr), values); | ||||
|     } else if (count == 1) { | ||||
|       vst1_f64(reinterpret_cast<double*>(ptr), vget_low_f64(values)); | ||||
|     } | ||||
|   } | ||||
|   const double& operator[](int idx) const = delete; | ||||
|   double& operator[](int idx) = delete; | ||||
|   int64_t zero_mask() const { | ||||
|     // returns an integer mask where all zero elements are translated to 1-bit | ||||
|     // and others are translated to 0-bit | ||||
|     uint64x2_t cmpReg = vceqzq_f64(values); | ||||
|     uint64x2_t mask = {1, 2}; | ||||
|     uint64x2_t res = vandq_u64(cmpReg, mask); | ||||
|     return res[0] | res[1]; | ||||
|   } | ||||
|   Vectorized<double> isnan() const { | ||||
|     // NaN check | ||||
|     return vreinterpretq_f64_u32( | ||||
|         vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, values)))); | ||||
|   } | ||||
|   bool has_inf_nan() const { | ||||
|     Vectorized<double> x = vsubq_f64(values, values); | ||||
|     float64x2_t r = x.isnan(); | ||||
|     uint64x2_t u = vreinterpretq_u64_f64(r); | ||||
|     return u[0] | u[1]; | ||||
|   } | ||||
|   Vectorized<double> map(double (*f)(double)) const { | ||||
|     float64x2_t result; | ||||
|     result[0] = f(values[0]); | ||||
|     result[1] = f(values[1]); | ||||
|     return result; | ||||
|   } | ||||
|   Vectorized<double> map2( | ||||
|       const Vectorized<double>& second, | ||||
|       double (*const f)(double, double)) const { | ||||
|     float64x2_t result; | ||||
|     result[0] = f(values[0], second.values[0]); | ||||
|     result[1] = f(values[1], second.values[1]); | ||||
|     return result; | ||||
|   } | ||||
|   Vectorized<double> abs() const { | ||||
|     return vabsq_f64(values); | ||||
|   } | ||||
|   Vectorized<double> angle() const { | ||||
|     auto zero = Vectorized<double>(0.0); | ||||
|     auto pi = Vectorized<double>(c10::pi<double>); | ||||
|     auto tmp = blendv(zero, pi, vreinterpretq_f64_u64(vcltzq_f64(values))); | ||||
|     return blendv(tmp, *this, isnan()); | ||||
|   } | ||||
|   Vectorized<double> real() const { | ||||
|     return *this; | ||||
|   } | ||||
|   Vectorized<double> imag() const { | ||||
|     return Vectorized<double>(0.0); | ||||
|   } | ||||
|   Vectorized<double> conj() const { | ||||
|     return *this; | ||||
|   } | ||||
|   Vectorized<double> acos() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_acosd2_u10(values)), map(std::acos)); | ||||
|   } | ||||
|   Vectorized<double> acosh() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_acoshd2_u10(values)), map(std::acosh)); | ||||
|   } | ||||
|   Vectorized<double> asin() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_asind2_u10(values)), map(std::asin)); | ||||
|   } | ||||
|   Vectorized<double> asinh() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_asinhd2_u10(values)), map(std::asinh)); | ||||
|   } | ||||
|   Vectorized<double> atan() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_atand2_u10(values)), map(std::atan)); | ||||
|   } | ||||
|   Vectorized<double> atanh() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_atanhd2_u10(values)), map(std::atanh)); | ||||
|   } | ||||
|   Vectorized<double> atan2(const Vectorized<double>& b) const {USE_SLEEF( | ||||
|       { return Vectorized<double>(Sleef_atan2d2_u10(values, b)); }, | ||||
|       { | ||||
|         __at_align__ double tmp[size()]; | ||||
|         __at_align__ double tmp_b[size()]; | ||||
|         store(tmp); | ||||
|         b.store(tmp_b); | ||||
|         for (int64_t i = 0; i < size(); i++) { | ||||
|           tmp[i] = std::atan2(tmp[i], tmp_b[i]); | ||||
|         } | ||||
|         return loadu(tmp); | ||||
|       })} Vectorized<double> copysign(const Vectorized<double>& sign) const { | ||||
|       USE_SLEEF( | ||||
|           { return Vectorized<double>(Sleef_copysignd2(values, sign)); }, | ||||
|           { | ||||
|             __at_align__ double tmp[size()]; | ||||
|             __at_align__ double tmp_sign[size()]; | ||||
|             store(tmp); | ||||
|             sign.store(tmp_sign); | ||||
|             for (int64_t i = 0; i < size(); i++) { | ||||
|               tmp[i] = std::copysign(tmp[i], tmp_sign[i]); | ||||
|             } | ||||
|             return loadu(tmp); | ||||
|           })} Vectorized<double> erf() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_erfd2_u10(values)), map(std::erf)); | ||||
|   } | ||||
|   Vectorized<double> erfc() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_erfcd2_u15(values)), map(std::erfc)); | ||||
|   } | ||||
|   Vectorized<double> exp() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_expd2_u10(values)), map(std::exp)); | ||||
|   } | ||||
|   Vectorized<double> exp2() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_exp2d2_u10(values)), map(std::exp2)); | ||||
|   } | ||||
|   Vectorized<double> expm1() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_expm1d2_u10(values)), map(std::expm1)); | ||||
|   } | ||||
|   Vectorized<double> fmod(const Vectorized<double>& q) const {USE_SLEEF( | ||||
|       { return Vectorized<double>(Sleef_fmodd2(values, q)); }, | ||||
|       { | ||||
|         __at_align__ double tmp[size()]; | ||||
|         __at_align__ double tmp_q[size()]; | ||||
|         store(tmp); | ||||
|         q.store(tmp_q); | ||||
|         for (int64_t i = 0; i < size(); i++) { | ||||
|           tmp[i] = std::fmod(tmp[i], tmp_q[i]); | ||||
|         } | ||||
|         return loadu(tmp); | ||||
|       })} Vectorized<double> hypot(const Vectorized<double>& b) const { | ||||
|       USE_SLEEF( | ||||
|           { return Vectorized<double>(Sleef_hypotd2_u05(values, b)); }, | ||||
|           { | ||||
|             __at_align__ double tmp[size()]; | ||||
|             __at_align__ double tmp_b[size()]; | ||||
|             store(tmp); | ||||
|             b.store(tmp_b); | ||||
|             for (int64_t i = 0; i < size(); i++) { | ||||
|               tmp[i] = std::hypot(tmp[i], tmp_b[i]); | ||||
|             } | ||||
|             return loadu(tmp); | ||||
|           })} Vectorized<double> i0() const { | ||||
|     return map(calc_i0); | ||||
|   } | ||||
|   Vectorized<double> nextafter(const Vectorized<double>& b) const {USE_SLEEF( | ||||
|       { return Vectorized<double>(Sleef_nextafterd2(values, b)); }, | ||||
|       { | ||||
|         __at_align__ double tmp[size()]; | ||||
|         __at_align__ double tmp_b[size()]; | ||||
|         store(tmp); | ||||
|         b.store(tmp_b); | ||||
|         for (int64_t i = 0; i < size(); ++i) { | ||||
|           tmp[i] = std::nextafter(tmp[i], tmp_b[i]); | ||||
|         } | ||||
|         return loadu(tmp); | ||||
|       })} Vectorized<double> log() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_logd2_u10(values)), map(std::log)); | ||||
|   } | ||||
|   Vectorized<double> log2() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_log2d2_u10(values)), map(std::log2)); | ||||
|   } | ||||
|   Vectorized<double> log10() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_log10d2_u10(values)), map(std::log10)); | ||||
|   } | ||||
|   Vectorized<double> log1p() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_log1pd2_u10(values)), map(std::log1p)); | ||||
|   } | ||||
|   Vectorized<double> frac() const; | ||||
|   Vectorized<double> sin() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_sind2_u10(values)), map(std::sin)); | ||||
|   } | ||||
|   Vectorized<double> sinh() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_sinhd2_u10(values)), map(std::sinh)); | ||||
|   } | ||||
|   Vectorized<double> cos() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_cosd2_u10(values)), map(std::cos)); | ||||
|   } | ||||
|   Vectorized<double> cosh() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_coshd2_u10(values)), map(std::cosh)); | ||||
|   } | ||||
|   Vectorized<double> pow(const Vectorized<double>& b) const {USE_SLEEF( | ||||
|       { return Vectorized<double>(Sleef_powd2_u10(values, b)); }, | ||||
|       { | ||||
|         __at_align__ double tmp[size()]; | ||||
|         __at_align__ double tmp_b[size()]; | ||||
|         store(tmp); | ||||
|         b.store(tmp_b); | ||||
|         for (int64_t i = 0; i < size(); i++) { | ||||
|           tmp[i] = std::pow(tmp[i], tmp_b[i]); | ||||
|         } | ||||
|         return loadu(tmp); | ||||
|       })} // Comparison using the _CMP_**_OQ predicate. | ||||
|           //   `O`: get false if an operand is NaN | ||||
|           //   `Q`: do not raise if an operand is NaN | ||||
|   Vectorized<double> tan() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_tand2_u10(values)), map(std::tan)); | ||||
|   } | ||||
|   Vectorized<double> tanh() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_tanhd2_u10(values)), map(std::tanh)); | ||||
|   } | ||||
|   Vectorized<double> lgamma() const { | ||||
|     return USE_SLEEF( | ||||
|         Vectorized<double>(Sleef_lgammad2_u10(values)), map(std::lgamma)); | ||||
|   } | ||||
|   Vectorized<double> erfinv() const { | ||||
|     return map(calc_erfinv); | ||||
|   } | ||||
|   Vectorized<double> exp_u20() const { | ||||
|     return exp(); | ||||
|   } | ||||
|   Vectorized<double> fexp_u20() const { | ||||
|     return exp(); | ||||
|   } | ||||
|   Vectorized<double> i0e() const { | ||||
|     return map(calc_i0e); | ||||
|   } | ||||
|   Vectorized<double> digamma() const { | ||||
|     return map(calc_digamma); | ||||
|   } | ||||
|   Vectorized<double> igamma(const Vectorized<double>& x) const { | ||||
|     __at_align__ double tmp[size()]; | ||||
|     __at_align__ double tmp_x[size()]; | ||||
|     store(tmp); | ||||
|     x.store(tmp_x); | ||||
|     for (int64_t i = 0; i < size(); i++) { | ||||
|       tmp[i] = calc_igamma(tmp[i], tmp_x[i]); | ||||
|     } | ||||
|     return loadu(tmp); | ||||
|   } | ||||
|   Vectorized<double> igammac(const Vectorized<double>& x) const { | ||||
|     __at_align__ double tmp[size()]; | ||||
|     __at_align__ double tmp_x[size()]; | ||||
|     store(tmp); | ||||
|     x.store(tmp_x); | ||||
|     for (int64_t i = 0; i < size(); i++) { | ||||
|       tmp[i] = calc_igammac(tmp[i], tmp_x[i]); | ||||
|     } | ||||
|     return loadu(tmp); | ||||
|   } | ||||
|   Vectorized<double> ceil() const { | ||||
|     return vrndpq_f64(values); | ||||
|   } | ||||
|   Vectorized<double> floor() const { | ||||
|     return vrndmq_f64(values); | ||||
|   } | ||||
|   Vectorized<double> neg() const { | ||||
|     return vnegq_f64(values); | ||||
|   } | ||||
|   Vectorized<double> round() const { | ||||
|     return vrndiq_f64(values); | ||||
|   } | ||||
|   Vectorized<double> trunc() const { | ||||
|     return vrndq_f64(values); | ||||
|   } | ||||
|   Vectorized<double> sqrt() const { | ||||
|     return vsqrtq_f64(values); | ||||
|   } | ||||
|   Vectorized<double> reciprocal() const { | ||||
|     return vdivq_f64(vdupq_n_f64(1.0), values); | ||||
|   } | ||||
|   Vectorized<double> rsqrt() const { | ||||
|     return vdivq_f64(vdupq_n_f64(1.0), vsqrtq_f64(values)); | ||||
|   } | ||||
|   double reduce_add() const { | ||||
|     return vaddvq_f64(values); | ||||
|   } | ||||
|   double reduce_max() const { | ||||
|     return vmaxvq_f64(values); | ||||
|   } | ||||
|   Vectorized<double> operator==(const Vectorized<double>& other) const { | ||||
|     return Vectorized<double>( | ||||
|         vreinterpretq_f64_u64(vceqq_f64(values, other.values))); | ||||
|   } | ||||
|  | ||||
|   Vectorized<double> operator!=(const Vectorized<double>& other) const { | ||||
|     float64x2_t r0 = vreinterpretq_f64_u32( | ||||
|         vmvnq_u32(vreinterpretq_u32_u64(vceqq_f64(values, other.values)))); | ||||
|     return Vectorized<double>(r0); | ||||
|   } | ||||
|  | ||||
|   Vectorized<double> operator<(const Vectorized<double>& other) const { | ||||
|     return Vectorized<double>( | ||||
|         vreinterpretq_f64_u64(vcltq_f64(values, other.values))); | ||||
|   } | ||||
|  | ||||
|   Vectorized<double> operator<=(const Vectorized<double>& other) const { | ||||
|     return Vectorized<double>( | ||||
|         vreinterpretq_f64_u64(vcleq_f64(values, other.values))); | ||||
|   } | ||||
|  | ||||
|   Vectorized<double> operator>(const Vectorized<double>& other) const { | ||||
|     return Vectorized<double>( | ||||
|         vreinterpretq_f64_u64(vcgtq_f64(values, other.values))); | ||||
|   } | ||||
|  | ||||
|   Vectorized<double> operator>=(const Vectorized<double>& other) const { | ||||
|     return Vectorized<double>( | ||||
|         vreinterpretq_f64_u64(vcgeq_f64(values, other.values))); | ||||
|   } | ||||
|  | ||||
|   Vectorized<double> eq(const Vectorized<double>& other) const; | ||||
|   Vectorized<double> ne(const Vectorized<double>& other) const; | ||||
|   Vectorized<double> gt(const Vectorized<double>& other) const; | ||||
|   Vectorized<double> ge(const Vectorized<double>& other) const; | ||||
|   Vectorized<double> lt(const Vectorized<double>& other) const; | ||||
|   Vectorized<double> le(const Vectorized<double>& other) const; | ||||
| }; | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline operator+( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b) { | ||||
|   return vaddq_f64(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline operator-( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b) { | ||||
|   return vsubq_f64(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline operator*( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b) { | ||||
|   return vmulq_f64(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline operator/( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b) { | ||||
|   return vdivq_f64(a, b); | ||||
| } | ||||
|  | ||||
| // frac. Implement this here so we can use subtraction | ||||
| Vectorized<double> inline Vectorized<double>::frac() const { | ||||
|   return *this - this->trunc(); | ||||
| } | ||||
|  | ||||
| // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if | ||||
| // either input is a NaN. | ||||
| template <> | ||||
| Vectorized<double> inline maximum( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b) { | ||||
|   return vmaxq_f64(a, b); | ||||
| } | ||||
|  | ||||
| // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if | ||||
| // either input is a NaN. | ||||
| template <> | ||||
| Vectorized<double> inline minimum( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b) { | ||||
|   return vminq_f64(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline clamp( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& min, | ||||
|     const Vectorized<double>& max) { | ||||
|   return vminq_f64(max, vmaxq_f64(min, a)); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline clamp_max( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& max) { | ||||
|   return vminq_f64(max, a); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline clamp_min( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& min) { | ||||
|   return vmaxq_f64(min, a); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline operator&( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b) { | ||||
|   return vreinterpretq_f64_u64( | ||||
|       vandq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b))); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline operator|( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b) { | ||||
|   return vreinterpretq_f64_u64( | ||||
|       vorrq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b))); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline operator^( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b) { | ||||
|   return vreinterpretq_f64_u64( | ||||
|       veorq_u64(vreinterpretq_u64_f64(a), vreinterpretq_u64_f64(b))); | ||||
| } | ||||
|  | ||||
| inline Vectorized<double> Vectorized<double>::eq( | ||||
|     const Vectorized<double>& other) const { | ||||
|   return (*this == other) & Vectorized<double>(1.0); | ||||
| } | ||||
|  | ||||
| inline Vectorized<double> Vectorized<double>::ne( | ||||
|     const Vectorized<double>& other) const { | ||||
|   return (*this != other) & Vectorized<double>(1.0); | ||||
| } | ||||
|  | ||||
| inline Vectorized<double> Vectorized<double>::gt( | ||||
|     const Vectorized<double>& other) const { | ||||
|   return (*this > other) & Vectorized<double>(1.0); | ||||
| } | ||||
|  | ||||
| inline Vectorized<double> Vectorized<double>::ge( | ||||
|     const Vectorized<double>& other) const { | ||||
|   return (*this >= other) & Vectorized<double>(1.0); | ||||
| } | ||||
|  | ||||
| inline Vectorized<double> Vectorized<double>::lt( | ||||
|     const Vectorized<double>& other) const { | ||||
|   return (*this < other) & Vectorized<double>(1.0); | ||||
| } | ||||
|  | ||||
| inline Vectorized<double> Vectorized<double>::le( | ||||
|     const Vectorized<double>& other) const { | ||||
|   return (*this <= other) & Vectorized<double>(1.0); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline fmadd( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b, | ||||
|     const Vectorized<double>& c) { | ||||
|   return vfmaq_f64(c, a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline fnmadd( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b, | ||||
|     const Vectorized<double>& c) { | ||||
|   return vfmsq_f64(c, a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline fmsub( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b, | ||||
|     const Vectorized<double>& c) { | ||||
|   return vfmaq_f64(vnegq_f64(c), a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<double> inline fnmsub( | ||||
|     const Vectorized<double>& a, | ||||
|     const Vectorized<double>& b, | ||||
|     const Vectorized<double>& c) { | ||||
|   return vfmsq_f64(vnegq_f64(c), a, b); | ||||
| } | ||||
|  | ||||
| } // namespace CPU_CAPABILITY | ||||
| } // namespace at::vec | ||||
| @ -1,794 +0,0 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <ATen/cpu/vec/intrinsics.h> | ||||
| #include <ATen/cpu/vec/vec_base.h> | ||||
| #include <c10/macros/Macros.h> | ||||
| #include <c10/util/irange.h> | ||||
|  | ||||
| namespace at::vec { | ||||
| // Note [CPU_CAPABILITY namespace] | ||||
| // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
| // This header, and all of its subheaders, will be compiled with | ||||
| // different architecture flags for each supported set of vector | ||||
| // intrinsics. So we need to make sure they aren't inadvertently | ||||
| // linked together. We do this by declaring objects in an `inline | ||||
| // namespace` which changes the name mangling, but can still be | ||||
| // accessed as `at::vec`. | ||||
| inline namespace CPU_CAPABILITY { | ||||
|  | ||||
| #define VEC_INT_NEON_TEMPLATE(vl, bit)                                        \ | ||||
|   template <>                                                                 \ | ||||
|   struct is_vec_specialized_for<int##bit##_t> : std::bool_constant<true> {};  \ | ||||
|                                                                               \ | ||||
|   template <>                                                                 \ | ||||
|   class Vectorized<int##bit##_t> {                                            \ | ||||
|     using neon_type = int##bit##x##vl##_t;                                    \ | ||||
|                                                                               \ | ||||
|    private:                                                                   \ | ||||
|     neon_type values;                                                         \ | ||||
|                                                                               \ | ||||
|    public:                                                                    \ | ||||
|     using value_type = int##bit##_t;                                          \ | ||||
|     using size_type = int;                                                    \ | ||||
|     static constexpr size_type size() {                                       \ | ||||
|       return vl;                                                              \ | ||||
|     }                                                                         \ | ||||
|     Vectorized() {                                                            \ | ||||
|       values = vdupq_n_s##bit(0);                                             \ | ||||
|     }                                                                         \ | ||||
|     Vectorized(neon_type v) : values(v) {}                                    \ | ||||
|     Vectorized(int##bit##_t val);                                             \ | ||||
|     template <                                                                \ | ||||
|         typename... Args,                                                     \ | ||||
|         typename = std::enable_if_t<(sizeof...(Args) == size())>>             \ | ||||
|     Vectorized(Args... vals) {                                                \ | ||||
|       __at_align__ int##bit##_t buffer[size()] = {vals...};                   \ | ||||
|       values = vld1q_s##bit(buffer);                                          \ | ||||
|     }                                                                         \ | ||||
|     operator neon_type() const {                                              \ | ||||
|       return values;                                                          \ | ||||
|     }                                                                         \ | ||||
|     static Vectorized<int##bit##_t> loadu(                                    \ | ||||
|         const void* ptr,                                                      \ | ||||
|         int64_t count = size());                                              \ | ||||
|     void store(void* ptr, int64_t count = size()) const;                      \ | ||||
|     template <int64_t mask>                                                   \ | ||||
|     static Vectorized<int##bit##_t> blend(                                    \ | ||||
|         const Vectorized<int##bit##_t>& a,                                    \ | ||||
|         const Vectorized<int##bit##_t>& b);                                   \ | ||||
|     static Vectorized<int##bit##_t> blendv(                                   \ | ||||
|         const Vectorized<int##bit##_t>& a,                                    \ | ||||
|         const Vectorized<int##bit##_t>& b,                                    \ | ||||
|         const Vectorized<int##bit##_t>& mask_) {                              \ | ||||
|       return vbslq_s##bit(vreinterpretq_u##bit##_s##bit(mask_.values), b, a); \ | ||||
|     }                                                                         \ | ||||
|     template <typename step_t>                                                \ | ||||
|     static Vectorized<int##bit##_t> arange(                                   \ | ||||
|         value_type base = 0,                                                  \ | ||||
|         step_t step = static_cast<step_t>(1));                                \ | ||||
|     static Vectorized<int##bit##_t> set(                                      \ | ||||
|         const Vectorized<int##bit##_t>& a,                                    \ | ||||
|         const Vectorized<int##bit##_t>& b,                                    \ | ||||
|         int64_t count = size());                                              \ | ||||
|     const int##bit##_t& operator[](int idx) const = delete;                   \ | ||||
|     int##bit##_t& operator[](int idx) = delete;                               \ | ||||
|     Vectorized<int##bit##_t> abs() const {                                    \ | ||||
|       return vabsq_s##bit(values);                                            \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<int##bit##_t> real() const {                                   \ | ||||
|       return values;                                                          \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<int##bit##_t> imag() const {                                   \ | ||||
|       return vdupq_n_s##bit(0);                                               \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<int##bit##_t> conj() const {                                   \ | ||||
|       return values;                                                          \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<int##bit##_t> neg() const {                                    \ | ||||
|       return vnegq_s##bit(values);                                            \ | ||||
|     }                                                                         \ | ||||
|     int##bit##_t reduce_add() const {                                         \ | ||||
|       return vaddvq_s##bit(values);                                           \ | ||||
|     }                                                                         \ | ||||
|     int##bit##_t reduce_max() const;                                          \ | ||||
|     Vectorized<int##bit##_t> operator==(                                      \ | ||||
|         const Vectorized<int##bit##_t>& other) const {                        \ | ||||
|       return Vectorized<value_type>(                                          \ | ||||
|           vreinterpretq_s##bit##_u##bit(vceqq_s##bit(values, other.values))); \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<int##bit##_t> operator!=(                                      \ | ||||
|         const Vectorized<int##bit##_t>& other) const;                         \ | ||||
|     Vectorized<int##bit##_t> operator<(                                       \ | ||||
|         const Vectorized<int##bit##_t>& other) const {                        \ | ||||
|       return Vectorized<value_type>(                                          \ | ||||
|           vreinterpretq_s##bit##_u##bit(vcltq_s##bit(values, other.values))); \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<int##bit##_t> operator<=(                                      \ | ||||
|         const Vectorized<int##bit##_t>& other) const {                        \ | ||||
|       return Vectorized<value_type>(                                          \ | ||||
|           vreinterpretq_s##bit##_u##bit(vcleq_s##bit(values, other.values))); \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<int##bit##_t> operator>(                                       \ | ||||
|         const Vectorized<int##bit##_t>& other) const {                        \ | ||||
|       return Vectorized<value_type>(                                          \ | ||||
|           vreinterpretq_s##bit##_u##bit(vcgtq_s##bit(values, other.values))); \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<int##bit##_t> operator>=(                                      \ | ||||
|         const Vectorized<int##bit##_t>& other) const {                        \ | ||||
|       return Vectorized<value_type>(                                          \ | ||||
|           vreinterpretq_s##bit##_u##bit(vcgeq_s##bit(values, other.values))); \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<int##bit##_t> eq(const Vectorized<int##bit##_t>& other) const; \ | ||||
|     Vectorized<int##bit##_t> ne(const Vectorized<int##bit##_t>& other) const; \ | ||||
|     Vectorized<int##bit##_t> gt(const Vectorized<int##bit##_t>& other) const; \ | ||||
|     Vectorized<int##bit##_t> ge(const Vectorized<int##bit##_t>& other) const; \ | ||||
|     Vectorized<int##bit##_t> lt(const Vectorized<int##bit##_t>& other) const; \ | ||||
|     Vectorized<int##bit##_t> le(const Vectorized<int##bit##_t>& other) const; \ | ||||
|   };                                                                          \ | ||||
|   template <>                                                                 \ | ||||
|   Vectorized<int##bit##_t> inline operator+(                                  \ | ||||
|       const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \ | ||||
|     return vaddq_s##bit(a, b);                                                \ | ||||
|   }                                                                           \ | ||||
|   template <>                                                                 \ | ||||
|   Vectorized<int##bit##_t> inline operator-(                                  \ | ||||
|       const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \ | ||||
|     return vsubq_s##bit(a, b);                                                \ | ||||
|   }                                                                           \ | ||||
|   template <>                                                                 \ | ||||
|   Vectorized<int##bit##_t> inline operator&(                                  \ | ||||
|       const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \ | ||||
|     return vandq_s##bit(a, b);                                                \ | ||||
|   }                                                                           \ | ||||
|   template <>                                                                 \ | ||||
|   Vectorized<int##bit##_t> inline operator|(                                  \ | ||||
|       const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \ | ||||
|     return vorrq_s##bit(a, b);                                                \ | ||||
|   }                                                                           \ | ||||
|   template <>                                                                 \ | ||||
|   Vectorized<int##bit##_t> inline operator^(                                  \ | ||||
|       const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \ | ||||
|     return veorq_s##bit(a, b);                                                \ | ||||
|   }                                                                           \ | ||||
|   Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::eq(               \ | ||||
|       const Vectorized<int##bit##_t>& other) const {                          \ | ||||
|     return (*this == other) & Vectorized<int##bit##_t>(1);                    \ | ||||
|   }                                                                           \ | ||||
|   Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::ne(               \ | ||||
|       const Vectorized<int##bit##_t>& other) const {                          \ | ||||
|     return (*this != other) & Vectorized<int##bit##_t>(1);                    \ | ||||
|   }                                                                           \ | ||||
|   Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::gt(               \ | ||||
|       const Vectorized<int##bit##_t>& other) const {                          \ | ||||
|     return (*this > other) & Vectorized<int##bit##_t>(1);                     \ | ||||
|   }                                                                           \ | ||||
|   Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::ge(               \ | ||||
|       const Vectorized<int##bit##_t>& other) const {                          \ | ||||
|     return (*this >= other) & Vectorized<int##bit##_t>(1);                    \ | ||||
|   }                                                                           \ | ||||
|   Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::lt(               \ | ||||
|       const Vectorized<int##bit##_t>& other) const {                          \ | ||||
|     return (*this < other) & Vectorized<int##bit##_t>(1);                     \ | ||||
|   }                                                                           \ | ||||
|   Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::le(               \ | ||||
|       const Vectorized<int##bit##_t>& other) const {                          \ | ||||
|     return (*this <= other) & Vectorized<int##bit##_t>(1);                    \ | ||||
|   } | ||||
|  | ||||
| VEC_INT_NEON_TEMPLATE(2, 64) | ||||
| VEC_INT_NEON_TEMPLATE(4, 32) | ||||
| VEC_INT_NEON_TEMPLATE(8, 16) | ||||
| VEC_INT_NEON_TEMPLATE(16, 8) | ||||
|  | ||||
| inline int32_t Vectorized<int32_t>::reduce_max() const { | ||||
|   return vmaxvq_s32(values); | ||||
| } | ||||
|  | ||||
| inline int16_t Vectorized<int16_t>::reduce_max() const { | ||||
|   return vmaxvq_s16(values); | ||||
| } | ||||
|  | ||||
| inline int8_t Vectorized<int8_t>::reduce_max() const { | ||||
|   return vmaxvq_s8(values); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int32_t> inline operator*( | ||||
|     const Vectorized<int32_t>& a, | ||||
|     const Vectorized<int32_t>& b) { | ||||
|   return vmulq_s32(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int16_t> inline operator*( | ||||
|     const Vectorized<int16_t>& a, | ||||
|     const Vectorized<int16_t>& b) { | ||||
|   return vmulq_s16(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int8_t> inline operator*( | ||||
|     const Vectorized<int8_t>& a, | ||||
|     const Vectorized<int8_t>& b) { | ||||
|   return vmulq_s8(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| inline Vectorized<int64_t> operator~(const Vectorized<int64_t>& a) { | ||||
|   int64x2_t val = a; | ||||
|   return ~val; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| inline Vectorized<int32_t> operator~(const Vectorized<int32_t>& a) { | ||||
|   return vmvnq_s32(a); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| inline Vectorized<int16_t> operator~(const Vectorized<int16_t>& a) { | ||||
|   return vmvnq_s16(a); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| inline Vectorized<int8_t> operator~(const Vectorized<int8_t>& a) { | ||||
|   return vmvnq_s8(a); | ||||
| } | ||||
|  | ||||
| inline Vectorized<int64_t> Vectorized<int64_t>::operator!=( | ||||
|     const Vectorized<int64_t>& other) const { | ||||
|   return ~(*this == other); | ||||
| } | ||||
|  | ||||
| inline Vectorized<int32_t> Vectorized<int32_t>::operator!=( | ||||
|     const Vectorized<int32_t>& other) const { | ||||
|   return ~(*this == other); | ||||
| } | ||||
|  | ||||
| inline Vectorized<int16_t> Vectorized<int16_t>::operator!=( | ||||
|     const Vectorized<int16_t>& other) const { | ||||
|   return ~(*this == other); | ||||
| } | ||||
|  | ||||
| inline Vectorized<int8_t> Vectorized<int8_t>::operator!=( | ||||
|     const Vectorized<int8_t>& other) const { | ||||
|   return ~(*this == other); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int32_t> inline minimum( | ||||
|     const Vectorized<int32_t>& a, | ||||
|     const Vectorized<int32_t>& b) { | ||||
|   return vminq_s32(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int16_t> inline minimum( | ||||
|     const Vectorized<int16_t>& a, | ||||
|     const Vectorized<int16_t>& b) { | ||||
|   return vminq_s16(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int8_t> inline minimum( | ||||
|     const Vectorized<int8_t>& a, | ||||
|     const Vectorized<int8_t>& b) { | ||||
|   return vminq_s8(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int32_t> inline maximum( | ||||
|     const Vectorized<int32_t>& a, | ||||
|     const Vectorized<int32_t>& b) { | ||||
|   return vmaxq_s32(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int16_t> inline maximum( | ||||
|     const Vectorized<int16_t>& a, | ||||
|     const Vectorized<int16_t>& b) { | ||||
|   return vmaxq_s16(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int8_t> inline maximum( | ||||
|     const Vectorized<int8_t>& a, | ||||
|     const Vectorized<int8_t>& b) { | ||||
|   return vmaxq_s8(a, b); | ||||
| } | ||||
|  | ||||
| template <int64_t mask> | ||||
| Vectorized<int64_t> Vectorized<int64_t>::blend( | ||||
|     const Vectorized<int64_t>& a, | ||||
|     const Vectorized<int64_t>& b) { | ||||
|   // Build an array of flags: each bit of element is 1 if the corresponding bit | ||||
|   // in 'mask' is set, 0 otherwise. | ||||
|   uint64x2_t maskArray = { | ||||
|       (mask & 1LL) ? 0xFFFFFFFFFFFFFFFF : 0, | ||||
|       (mask & 2LL) ? 0xFFFFFFFFFFFFFFFF : 0}; | ||||
|   // Use BSL to select elements from b where the mask is 1, else from a | ||||
|   return vbslq_s64(maskArray, b.values, a.values); | ||||
| } | ||||
|  | ||||
| template <int64_t mask> | ||||
| Vectorized<int32_t> Vectorized<int32_t>::blend( | ||||
|     const Vectorized<int32_t>& a, | ||||
|     const Vectorized<int32_t>& b) { | ||||
|   // Build an array of flags: each bit of element is 1 if the corresponding bit | ||||
|   // in 'mask' is set, 0 otherwise. | ||||
|   uint32x4_t maskArray = { | ||||
|       (mask & 1LL) ? 0xFFFFFFFF : 0, | ||||
|       (mask & 2LL) ? 0xFFFFFFFF : 0, | ||||
|       (mask & 4LL) ? 0xFFFFFFFF : 0, | ||||
|       (mask & 8LL) ? 0xFFFFFFFF : 0}; | ||||
|   // Use BSL to select elements from b where the mask is 1, else from a | ||||
|   return vbslq_s32(maskArray, b.values, a.values); | ||||
| } | ||||
|  | ||||
| template <int64_t mask> | ||||
| Vectorized<int16_t> Vectorized<int16_t>::blend( | ||||
|     const Vectorized<int16_t>& a, | ||||
|     const Vectorized<int16_t>& b) { | ||||
|   // Build an array of flags: each bit of element is 1 if the corresponding bit | ||||
|   // in 'mask' is set, 0 otherwise. | ||||
|   uint16x8_t maskArray = { | ||||
|       (mask & 1LL) ? 0xFFFF : 0, | ||||
|       (mask & 2LL) ? 0xFFFF : 0, | ||||
|       (mask & 4LL) ? 0xFFFF : 0, | ||||
|       (mask & 8LL) ? 0xFFFF : 0, | ||||
|       (mask & 16LL) ? 0xFFFF : 0, | ||||
|       (mask & 32LL) ? 0xFFFF : 0, | ||||
|       (mask & 64LL) ? 0xFFFF : 0, | ||||
|       (mask & 128LL) ? 0xFFFF : 0}; | ||||
|   // Use BSL to select elements from b where the mask is 1, else from a | ||||
|   return vbslq_s16(maskArray, b.values, a.values); | ||||
| } | ||||
|  | ||||
| template <int64_t mask> | ||||
| Vectorized<int8_t> Vectorized<int8_t>::blend( | ||||
|     const Vectorized<int8_t>& a, | ||||
|     const Vectorized<int8_t>& b) { | ||||
|   // Build an array of flags: each bit of element is 1 if the corresponding bit | ||||
|   // in 'mask' is set, 0 otherwise. | ||||
|   uint8x16_t maskArray = { | ||||
|       (mask & 1LL) ? 0xFF : 0, | ||||
|       (mask & 2LL) ? 0xFF : 0, | ||||
|       (mask & 4LL) ? 0xFF : 0, | ||||
|       (mask & 8LL) ? 0xFF : 0, | ||||
|       (mask & 16LL) ? 0xFF : 0, | ||||
|       (mask & 32LL) ? 0xFF : 0, | ||||
|       (mask & 64LL) ? 0xFF : 0, | ||||
|       (mask & 128LL) ? 0xFF : 0, | ||||
|       (mask & 256LL) ? 0xFF : 0, | ||||
|       (mask & 512LL) ? 0xFF : 0, | ||||
|       (mask & 1024LL) ? 0xFF : 0, | ||||
|       (mask & 2048LL) ? 0xFF : 0, | ||||
|       (mask & 4096LL) ? 0xFF : 0, | ||||
|       (mask & 8192LL) ? 0xFF : 0, | ||||
|       (mask & 16384LL) ? 0xFF : 0, | ||||
|       (mask & 32768LL) ? 0xFF : 0}; | ||||
|   // Use BSL to select elements from b where the mask is 1, else from a | ||||
|   return vbslq_s8(maskArray, b.values, a.values); | ||||
| } | ||||
|  | ||||
| #define VEC_INT_NEON_OPS(vl, bit)                                             \ | ||||
|   inline Vectorized<int##bit##_t>::Vectorized(int##bit##_t val) {             \ | ||||
|     values = vdupq_n_s##bit(val);                                             \ | ||||
|   }                                                                           \ | ||||
|   inline Vectorized<int##bit##_t> Vectorized<int##bit##_t>::loadu(            \ | ||||
|       const void* ptr, int64_t count) {                                       \ | ||||
|     if (count == size()) {                                                    \ | ||||
|       return vld1q_s##bit(reinterpret_cast<const int##bit##_t*>(ptr));        \ | ||||
|     } else {                                                                  \ | ||||
|       __at_align__ int##bit##_t tmp_values[size()];                           \ | ||||
|       for (const auto i : c10::irange(size())) {                              \ | ||||
|         tmp_values[i] = 0;                                                    \ | ||||
|       }                                                                       \ | ||||
|       std::memcpy(                                                            \ | ||||
|           tmp_values,                                                         \ | ||||
|           reinterpret_cast<const int##bit##_t*>(ptr),                         \ | ||||
|           count * sizeof(int##bit##_t));                                      \ | ||||
|       return vld1q_s##bit(reinterpret_cast<const int##bit##_t*>(tmp_values)); \ | ||||
|     }                                                                         \ | ||||
|   }                                                                           \ | ||||
|   inline void Vectorized<int##bit##_t>::store(void* ptr, int64_t count)       \ | ||||
|       const {                                                                 \ | ||||
|     if (count == size()) {                                                    \ | ||||
|       vst1q_s##bit(reinterpret_cast<int##bit##_t*>(ptr), values);             \ | ||||
|     } else {                                                                  \ | ||||
|       int##bit##_t tmp_values[size()];                                        \ | ||||
|       vst1q_s##bit(reinterpret_cast<int##bit##_t*>(tmp_values), values);      \ | ||||
|       std::memcpy(ptr, tmp_values, count * sizeof(int##bit##_t));             \ | ||||
|     }                                                                         \ | ||||
|   } | ||||
|  | ||||
| VEC_INT_NEON_OPS(2, 64) | ||||
| VEC_INT_NEON_OPS(4, 32) | ||||
| VEC_INT_NEON_OPS(8, 16) | ||||
| VEC_INT_NEON_OPS(16, 8) | ||||
|  | ||||
| template <> | ||||
| Vectorized<int64_t> inline operator*( | ||||
|     const Vectorized<int64_t>& a, | ||||
|     const Vectorized<int64_t>& b) { | ||||
|   int64x2_t x = a; | ||||
|   int64x2_t y = b; | ||||
|   return x * y; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int64_t> inline operator/( | ||||
|     const Vectorized<int64_t>& a, | ||||
|     const Vectorized<int64_t>& b) { | ||||
|   int64x2_t x = a; | ||||
|   int64x2_t y = b; | ||||
|   return x / y; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int32_t> inline operator/( | ||||
|     const Vectorized<int32_t>& a, | ||||
|     const Vectorized<int32_t>& b) { | ||||
|   int32x4_t x = a; | ||||
|   int32x4_t y = b; | ||||
|   return x / y; | ||||
| } | ||||
|  | ||||
| inline int64_t Vectorized<int64_t>::reduce_max() const { | ||||
|   return std::max(values[0], values[1]); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int64_t> inline minimum( | ||||
|     const Vectorized<int64_t>& a, | ||||
|     const Vectorized<int64_t>& b) { | ||||
|   int64x2_t x = a; | ||||
|   int64x2_t y = b; | ||||
|   return {std::min(x[0], y[0]), std::min(x[1], y[1])}; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int64_t> inline maximum( | ||||
|     const Vectorized<int64_t>& a, | ||||
|     const Vectorized<int64_t>& b) { | ||||
|   int64x2_t x = a; | ||||
|   int64x2_t y = b; | ||||
|   return {std::max(x[0], y[0]), std::max(x[1], y[1])}; | ||||
| } | ||||
|  | ||||
| template <typename step_t> | ||||
| inline Vectorized<int64_t> Vectorized<int64_t>::arange( | ||||
|     int64_t base, | ||||
|     step_t step) { | ||||
|   const Vectorized<int64_t> base_vec(base); | ||||
|   const Vectorized<int64_t> step_vec(step); | ||||
|   const int64x2_t step_sizes = {0, 1}; | ||||
|   return base_vec.values + step_sizes * step_vec.values; | ||||
| } | ||||
|  | ||||
| template <typename step_t> | ||||
| inline Vectorized<int32_t> Vectorized<int32_t>::arange( | ||||
|     int32_t base, | ||||
|     step_t step) { | ||||
|   const Vectorized<int32_t> base_vec(base); | ||||
|   const Vectorized<int32_t> step_vec(step); | ||||
|   const int32x4_t step_sizes = {0, 1, 2, 3}; | ||||
|   return vmlaq_s32(base_vec, step_sizes, step_vec); | ||||
| } | ||||
|  | ||||
| template <typename step_t> | ||||
| inline Vectorized<int16_t> Vectorized<int16_t>::arange( | ||||
|     int16_t base, | ||||
|     step_t step) { | ||||
|   const Vectorized<int16_t> base_vec(base); | ||||
|   const Vectorized<int16_t> step_vec(step); | ||||
|   const int16x8_t step_sizes = {0, 1, 2, 3, 4, 5, 6, 7}; | ||||
|   return vmlaq_s16(base_vec, step_sizes, step_vec); | ||||
| } | ||||
|  | ||||
| template <typename step_t> | ||||
| inline Vectorized<int8_t> Vectorized<int8_t>::arange(int8_t base, step_t step) { | ||||
|   const Vectorized<int8_t> base_vec(base); | ||||
|   const Vectorized<int8_t> step_vec(step); | ||||
|   const int8x16_t step_sizes = { | ||||
|       0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; | ||||
|   return vmlaq_s8(base_vec, step_sizes, step_vec); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int64_t> inline operator>>( | ||||
|     const Vectorized<int64_t>& a, | ||||
|     const Vectorized<int64_t>& b) { | ||||
|   int64x2_t x = a; | ||||
|   int64x2_t y = b; | ||||
|   uint64x2_t u = vreinterpretq_u64_s64(y); | ||||
|   uint64x2_t z = {std::min(u[0], (uint64_t)63), std::min(u[1], (uint64_t)63)}; | ||||
|   return x >> vreinterpretq_s64_u64(z); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int32_t> inline operator>>( | ||||
|     const Vectorized<int32_t>& a, | ||||
|     const Vectorized<int32_t>& b) { | ||||
|   int32x4_t x = a; | ||||
|   int32x4_t y = b; | ||||
|   uint32x4_t bound = vdupq_n_u32(31); | ||||
|   uint32x4_t z = vminq_u32(vreinterpretq_u32_s32(y), bound); | ||||
|   return x >> vreinterpretq_s32_u32(z); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int16_t> inline operator>>( | ||||
|     const Vectorized<int16_t>& a, | ||||
|     const Vectorized<int16_t>& b) { | ||||
|   int16x8_t x = a; | ||||
|   int16x8_t y = b; | ||||
|   uint16x8_t bound = vdupq_n_u16(15); | ||||
|   uint16x8_t z = vminq_u16(vreinterpretq_u16_s16(y), bound); | ||||
|   return x >> vreinterpretq_s16_u16(z); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int8_t> inline operator>>( | ||||
|     const Vectorized<int8_t>& a, | ||||
|     const Vectorized<int8_t>& b) { | ||||
|   int8x16_t x = a; | ||||
|   int8x16_t y = b; | ||||
|   uint8x16_t bound = vdupq_n_u8(7); | ||||
|   int8x16_t z = vreinterpretq_s8_u8(vminq_u8(vreinterpretq_u8_s8(y), bound)); | ||||
|   return x >> z; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int64_t> inline operator<<( | ||||
|     const Vectorized<int64_t>& a, | ||||
|     const Vectorized<int64_t>& b) { | ||||
|   int64x2_t y = b; | ||||
|   uint64x2_t u = vreinterpretq_u64_s64(y); | ||||
|   uint64x2_t z = {std::min(u[0], (uint64_t)64), std::min(u[1], (uint64_t)64)}; | ||||
|   return vshlq_s64(a, vreinterpretq_s64_u64(z)); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int32_t> inline operator<<( | ||||
|     const Vectorized<int32_t>& a, | ||||
|     const Vectorized<int32_t>& b) { | ||||
|   int32x4_t y = b; | ||||
|   uint32x4_t bound = vdupq_n_u32(32); | ||||
|   uint32x4_t z = vminq_u32(vreinterpretq_u32_s32(y), bound); | ||||
|   return vshlq_s32(a, vreinterpretq_s32_u32(z)); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int16_t> inline operator<<( | ||||
|     const Vectorized<int16_t>& a, | ||||
|     const Vectorized<int16_t>& b) { | ||||
|   int16x8_t y = b; | ||||
|   uint16x8_t bound = vdupq_n_u16(16); | ||||
|   uint16x8_t z = vminq_u16(vreinterpretq_u16_s16(y), bound); | ||||
|   return vshlq_s16(a, vreinterpretq_s16_u16(z)); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int8_t> inline operator<<( | ||||
|     const Vectorized<int8_t>& a, | ||||
|     const Vectorized<int8_t>& b) { | ||||
|   int8x16_t y = b; | ||||
|   uint8x16_t bound = vdupq_n_u8(8); | ||||
|   int8x16_t z = vreinterpretq_s8_u8(vminq_u8(vreinterpretq_u8_s8(y), bound)); | ||||
|   return vshlq_s8(a, z); | ||||
| } | ||||
|  | ||||
| inline Vectorized<int64_t> Vectorized<int64_t>::set( | ||||
|     const Vectorized<int64_t>& a, | ||||
|     const Vectorized<int64_t>& b, | ||||
|     int64_t count) { | ||||
|   if (count == 0) { | ||||
|     return a; | ||||
|   } else if (count >= 2) { | ||||
|     return b; | ||||
|   } else { | ||||
|     int64x2_t c = {b.values[0], a.values[1]}; | ||||
|     return c; | ||||
|   } | ||||
| } | ||||
|  | ||||
| inline Vectorized<int32_t> Vectorized<int32_t>::set( | ||||
|     const Vectorized<int32_t>& a, | ||||
|     const Vectorized<int32_t>& b, | ||||
|     int64_t count) { | ||||
|   if (count == 0) { | ||||
|     return a; | ||||
|   } else if (count >= 4) { | ||||
|     return b; | ||||
|   } else { | ||||
|     // Build an array of flags: each bit of element is 1 if the corresponding | ||||
|     // bit in 'mask' is set, 0 otherwise. | ||||
|     uint32x4_t maskArray = { | ||||
|         (count >= 1LL) ? 0xFFFFFFFF : 0, | ||||
|         (count >= 2LL) ? 0xFFFFFFFF : 0, | ||||
|         (count >= 3LL) ? 0xFFFFFFFF : 0, | ||||
|         0}; | ||||
|     // Use BSL to select elements from b where the mask is 1, else from a | ||||
|     return vbslq_s32(maskArray, b.values, a.values); | ||||
|   } | ||||
| } | ||||
|  | ||||
| inline Vectorized<int16_t> Vectorized<int16_t>::set( | ||||
|     const Vectorized<int16_t>& a, | ||||
|     const Vectorized<int16_t>& b, | ||||
|     int64_t count) { | ||||
|   if (count == 0) { | ||||
|     return a; | ||||
|   } else if (count >= 8) { | ||||
|     return b; | ||||
|   } else { | ||||
|     // Build an array of flags: each bit of element is 1 if the corresponding | ||||
|     // bit in 'mask' is set, 0 otherwise. | ||||
|     uint16x8_t maskArray = { | ||||
|         static_cast<uint16_t>((count >= 1LL) ? 0xFFFF : 0), | ||||
|         static_cast<uint16_t>((count >= 2LL) ? 0xFFFF : 0), | ||||
|         static_cast<uint16_t>((count >= 3LL) ? 0xFFFF : 0), | ||||
|         static_cast<uint16_t>((count >= 4LL) ? 0xFFFF : 0), | ||||
|         static_cast<uint16_t>((count >= 5LL) ? 0xFFFF : 0), | ||||
|         static_cast<uint16_t>((count >= 6LL) ? 0xFFFF : 0), | ||||
|         static_cast<uint16_t>((count >= 7LL) ? 0xFFFF : 0), | ||||
|         0}; | ||||
|     // Use BSL to select elements from b where the mask is 1, else from a | ||||
|     return vbslq_s16(maskArray, b.values, a.values); | ||||
|   } | ||||
| } | ||||
|  | ||||
| inline Vectorized<int8_t> Vectorized<int8_t>::set( | ||||
|     const Vectorized<int8_t>& a, | ||||
|     const Vectorized<int8_t>& b, | ||||
|     int64_t count) { | ||||
|   if (count == 0) { | ||||
|     return a; | ||||
|   } else if (count >= 16) { | ||||
|     return b; | ||||
|   } else { | ||||
|     // Build an array of flags: each bit of element is 1 if the corresponding | ||||
|     // bit in 'mask' is set, 0 otherwise. | ||||
|     uint8x16_t maskArray = { | ||||
|         static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0), | ||||
|         0}; | ||||
|  | ||||
|     // Use BSL to select elements from b where the mask is 1, else from a | ||||
|     return vbslq_s8(maskArray, b.values, a.values); | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int16_t> inline operator/( | ||||
|     const Vectorized<int16_t>& a, | ||||
|     const Vectorized<int16_t>& b) { | ||||
|   Vectorized<int32_t> highBitsA = vmovl_high_s16(a); | ||||
|   Vectorized<int32_t> highBitsB = vmovl_high_s16(b); | ||||
|   Vectorized<int32_t> lowBitsA = vmovl_s16(vget_low_s16(a)); | ||||
|   Vectorized<int32_t> lowBitsB = vmovl_s16(vget_low_s16(b)); | ||||
|   int32x4_t highBitsResult = highBitsA / highBitsB; | ||||
|   int32x4_t lowBitsResult = lowBitsA / lowBitsB; | ||||
|   return vuzp1q_s16( | ||||
|       vreinterpretq_s16_s32(lowBitsResult), | ||||
|       vreinterpretq_s16_s32(highBitsResult)); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int8_t> inline operator/( | ||||
|     const Vectorized<int8_t>& a, | ||||
|     const Vectorized<int8_t>& b) { | ||||
|   Vectorized<int16_t> highBitsA = vmovl_high_s8(a); | ||||
|   Vectorized<int16_t> highBitsB = vmovl_high_s8(b); | ||||
|   Vectorized<int16_t> lowBitsA = vmovl_s8(vget_low_s8(a)); | ||||
|   Vectorized<int16_t> lowBitsB = vmovl_s8(vget_low_s8(b)); | ||||
|   int16x8_t highBitsResult = highBitsA / highBitsB; | ||||
|   int16x8_t lowBitsResult = lowBitsA / lowBitsB; | ||||
|   return vuzp1q_s8( | ||||
|       vreinterpretq_s8_s16(lowBitsResult), | ||||
|       vreinterpretq_s8_s16(highBitsResult)); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int64_t> inline clamp( | ||||
|     const Vectorized<int64_t>& a, | ||||
|     const Vectorized<int64_t>& min, | ||||
|     const Vectorized<int64_t>& max) { | ||||
|   return minimum(max, maximum(min, a)); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int32_t> inline clamp( | ||||
|     const Vectorized<int32_t>& a, | ||||
|     const Vectorized<int32_t>& min, | ||||
|     const Vectorized<int32_t>& max) { | ||||
|   return minimum(max, maximum(min, a)); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int16_t> inline clamp( | ||||
|     const Vectorized<int16_t>& a, | ||||
|     const Vectorized<int16_t>& min, | ||||
|     const Vectorized<int16_t>& max) { | ||||
|   return minimum(max, maximum(min, a)); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int8_t> inline clamp( | ||||
|     const Vectorized<int8_t>& a, | ||||
|     const Vectorized<int8_t>& min, | ||||
|     const Vectorized<int8_t>& max) { | ||||
|   return minimum(max, maximum(min, a)); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int64_t> inline clamp_max( | ||||
|     const Vectorized<int64_t>& a, | ||||
|     const Vectorized<int64_t>& max) { | ||||
|   return minimum(max, a); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int32_t> inline clamp_max( | ||||
|     const Vectorized<int32_t>& a, | ||||
|     const Vectorized<int32_t>& max) { | ||||
|   return minimum(max, a); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int16_t> inline clamp_max( | ||||
|     const Vectorized<int16_t>& a, | ||||
|     const Vectorized<int16_t>& max) { | ||||
|   return minimum(max, a); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int8_t> inline clamp_max( | ||||
|     const Vectorized<int8_t>& a, | ||||
|     const Vectorized<int8_t>& max) { | ||||
|   return minimum(max, a); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int64_t> inline clamp_min( | ||||
|     const Vectorized<int64_t>& a, | ||||
|     const Vectorized<int64_t>& min) { | ||||
|   return maximum(min, a); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int32_t> inline clamp_min( | ||||
|     const Vectorized<int32_t>& a, | ||||
|     const Vectorized<int32_t>& min) { | ||||
|   return maximum(min, a); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int16_t> inline clamp_min( | ||||
|     const Vectorized<int16_t>& a, | ||||
|     const Vectorized<int16_t>& min) { | ||||
|   return maximum(min, a); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<int8_t> inline clamp_min( | ||||
|     const Vectorized<int8_t>& a, | ||||
|     const Vectorized<int8_t>& min) { | ||||
|   return maximum(min, a); | ||||
| } | ||||
|  | ||||
| } // namespace CPU_CAPABILITY | ||||
| } // namespace at::vec | ||||
| @ -1,378 +0,0 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <ATen/cpu/vec/intrinsics.h> | ||||
| #include <ATen/cpu/vec/vec_base.h> | ||||
| #include <c10/macros/Macros.h> | ||||
| #include <c10/util/irange.h> | ||||
|  | ||||
| namespace at::vec { | ||||
| // Note [CPU_CAPABILITY namespace] | ||||
| // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
| // This header, and all of its subheaders, will be compiled with | ||||
| // different architecture flags for each supported set of vector | ||||
| // intrinsics. So we need to make sure they aren't inadvertently | ||||
| // linked together. We do this by declaring objects in an `inline | ||||
| // namespace` which changes the name mangling, but can still be | ||||
| // accessed as `at::vec`. | ||||
| inline namespace CPU_CAPABILITY { | ||||
|  | ||||
| #define VEC_UINT_NEON_TEMPLATE(vl, bit)                                       \ | ||||
|   template <>                                                                 \ | ||||
|   struct is_vec_specialized_for<uint##bit##_t> : std::bool_constant<true> {}; \ | ||||
|                                                                               \ | ||||
|   template <>                                                                 \ | ||||
|   class Vectorized<uint##bit##_t> {                                           \ | ||||
|     using neon_type = uint##bit##x##vl##_t;                                   \ | ||||
|                                                                               \ | ||||
|    private:                                                                   \ | ||||
|     neon_type values;                                                         \ | ||||
|                                                                               \ | ||||
|    public:                                                                    \ | ||||
|     using value_type = uint##bit##_t;                                         \ | ||||
|     using size_type = int;                                                    \ | ||||
|     static constexpr size_type size() {                                       \ | ||||
|       return vl;                                                              \ | ||||
|     }                                                                         \ | ||||
|     Vectorized() {                                                            \ | ||||
|       values = vdupq_n_u##bit(0);                                             \ | ||||
|     }                                                                         \ | ||||
|     Vectorized(neon_type v) : values(v) {}                                    \ | ||||
|     Vectorized(uint##bit##_t val);                                            \ | ||||
|     template <                                                                \ | ||||
|         typename... Args,                                                     \ | ||||
|         typename = std::enable_if_t<(sizeof...(Args) == size())>>             \ | ||||
|     Vectorized(Args... vals) {                                                \ | ||||
|       __at_align__ uint##bit##_t buffer[size()] = {vals...};                  \ | ||||
|       values = vld1q_u##bit(buffer);                                          \ | ||||
|     }                                                                         \ | ||||
|     operator neon_type() const {                                              \ | ||||
|       return values;                                                          \ | ||||
|     }                                                                         \ | ||||
|     static Vectorized<uint##bit##_t> loadu(                                   \ | ||||
|         const void* ptr,                                                      \ | ||||
|         uint64_t count = size());                                             \ | ||||
|     void store(void* ptr, uint64_t count = size()) const;                     \ | ||||
|     template <uint64_t mask>                                                  \ | ||||
|     static Vectorized<uint##bit##_t> blend(                                   \ | ||||
|         const Vectorized<uint##bit##_t>& a,                                   \ | ||||
|         const Vectorized<uint##bit##_t>& b);                                  \ | ||||
|     static Vectorized<uint##bit##_t> blendv(                                  \ | ||||
|         const Vectorized<uint##bit##_t>& a,                                   \ | ||||
|         const Vectorized<uint##bit##_t>& b,                                   \ | ||||
|         const Vectorized<uint##bit##_t>& mask_) {                             \ | ||||
|       return vbslq_u##bit(mask_.values, b, a);                                \ | ||||
|     }                                                                         \ | ||||
|     template <typename step_t>                                                \ | ||||
|     static Vectorized<uint##bit##_t> arange(                                  \ | ||||
|         value_type base = 0,                                                  \ | ||||
|         step_t step = static_cast<step_t>(1));                                \ | ||||
|     static Vectorized<uint##bit##_t> set(                                     \ | ||||
|         const Vectorized<uint##bit##_t>& a,                                   \ | ||||
|         const Vectorized<uint##bit##_t>& b,                                   \ | ||||
|         uint64_t count = size());                                             \ | ||||
|     const uint##bit##_t& operator[](uint idx) const = delete;                 \ | ||||
|     uint##bit##_t& operator[](uint idx) = delete;                             \ | ||||
|     Vectorized<uint##bit##_t> abs() const {                                   \ | ||||
|       return values;                                                          \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<uint##bit##_t> real() const {                                  \ | ||||
|       return values;                                                          \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<uint##bit##_t> imag() const {                                  \ | ||||
|       return vdupq_n_u##bit(0);                                               \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<uint##bit##_t> conj() const {                                  \ | ||||
|       return values;                                                          \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<uint##bit##_t> neg() const {                                   \ | ||||
|       return vreinterpretq_u##bit##_s##bit(                                   \ | ||||
|           vnegq_s##bit(vreinterpretq_s##bit##_u##bit(values)));               \ | ||||
|     }                                                                         \ | ||||
|     uint##bit##_t reduce_add() const {                                        \ | ||||
|       return vaddvq_u##bit(values);                                           \ | ||||
|     }                                                                         \ | ||||
|     uint##bit##_t reduce_max() const;                                         \ | ||||
|     Vectorized<uint##bit##_t> operator==(                                     \ | ||||
|         const Vectorized<uint##bit##_t>& other) const {                       \ | ||||
|       return Vectorized<value_type>(vceqq_u##bit(values, other.values));      \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<uint##bit##_t> operator!=(                                     \ | ||||
|         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||
|     Vectorized<uint##bit##_t> operator<(                                      \ | ||||
|         const Vectorized<uint##bit##_t>& other) const {                       \ | ||||
|       return Vectorized<value_type>(vcltq_u##bit(values, other.values));      \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<uint##bit##_t> operator<=(                                     \ | ||||
|         const Vectorized<uint##bit##_t>& other) const {                       \ | ||||
|       return Vectorized<value_type>(vcleq_u##bit(values, other.values));      \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<uint##bit##_t> operator>(                                      \ | ||||
|         const Vectorized<uint##bit##_t>& other) const {                       \ | ||||
|       return Vectorized<value_type>(vcgtq_u##bit(values, other.values));      \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<uint##bit##_t> operator>=(                                     \ | ||||
|         const Vectorized<uint##bit##_t>& other) const {                       \ | ||||
|       return Vectorized<value_type>(vcgeq_u##bit(values, other.values));      \ | ||||
|     }                                                                         \ | ||||
|     Vectorized<uint##bit##_t> eq(                                             \ | ||||
|         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||
|     Vectorized<uint##bit##_t> ne(                                             \ | ||||
|         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||
|     Vectorized<uint##bit##_t> gt(                                             \ | ||||
|         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||
|     Vectorized<uint##bit##_t> ge(                                             \ | ||||
|         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||
|     Vectorized<uint##bit##_t> lt(                                             \ | ||||
|         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||
|     Vectorized<uint##bit##_t> le(                                             \ | ||||
|         const Vectorized<uint##bit##_t>& other) const;                        \ | ||||
|   };                                                                          \ | ||||
|   template <>                                                                 \ | ||||
|   Vectorized<uint##bit##_t> inline operator+(                                 \ | ||||
|       const Vectorized<uint##bit##_t>& a,                                     \ | ||||
|       const Vectorized<uint##bit##_t>& b) {                                   \ | ||||
|     return vaddq_u##bit(a, b);                                                \ | ||||
|   }                                                                           \ | ||||
|   template <>                                                                 \ | ||||
|   Vectorized<uint##bit##_t> inline operator-(                                 \ | ||||
|       const Vectorized<uint##bit##_t>& a,                                     \ | ||||
|       const Vectorized<uint##bit##_t>& b) {                                   \ | ||||
|     return vsubq_u##bit(a, b);                                                \ | ||||
|   }                                                                           \ | ||||
|   template <>                                                                 \ | ||||
|   Vectorized<uint##bit##_t> inline operator&(                                 \ | ||||
|       const Vectorized<uint##bit##_t>& a,                                     \ | ||||
|       const Vectorized<uint##bit##_t>& b) {                                   \ | ||||
|     return vandq_u##bit(a, b);                                                \ | ||||
|   }                                                                           \ | ||||
|   template <>                                                                 \ | ||||
|   Vectorized<uint##bit##_t> inline operator|(                                 \ | ||||
|       const Vectorized<uint##bit##_t>& a,                                     \ | ||||
|       const Vectorized<uint##bit##_t>& b) {                                   \ | ||||
|     return vorrq_u##bit(a, b);                                                \ | ||||
|   }                                                                           \ | ||||
|   template <>                                                                 \ | ||||
|   Vectorized<uint##bit##_t> inline operator^(                                 \ | ||||
|       const Vectorized<uint##bit##_t>& a,                                     \ | ||||
|       const Vectorized<uint##bit##_t>& b) {                                   \ | ||||
|     return veorq_u##bit(a, b);                                                \ | ||||
|   }                                                                           \ | ||||
|   Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::eq(             \ | ||||
|       const Vectorized<uint##bit##_t>& other) const {                         \ | ||||
|     return (*this == other) & Vectorized<uint##bit##_t>(1);                   \ | ||||
|   }                                                                           \ | ||||
|   Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ne(             \ | ||||
|       const Vectorized<uint##bit##_t>& other) const {                         \ | ||||
|     return (*this != other) & Vectorized<uint##bit##_t>(1);                   \ | ||||
|   }                                                                           \ | ||||
|   Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::gt(             \ | ||||
|       const Vectorized<uint##bit##_t>& other) const {                         \ | ||||
|     return (*this > other) & Vectorized<uint##bit##_t>(1);                    \ | ||||
|   }                                                                           \ | ||||
|   Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::ge(             \ | ||||
|       const Vectorized<uint##bit##_t>& other) const {                         \ | ||||
|     return (*this >= other) & Vectorized<uint##bit##_t>(1);                   \ | ||||
|   }                                                                           \ | ||||
|   Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::lt(             \ | ||||
|       const Vectorized<uint##bit##_t>& other) const {                         \ | ||||
|     return (*this < other) & Vectorized<uint##bit##_t>(1);                    \ | ||||
|   }                                                                           \ | ||||
|   Vectorized<uint##bit##_t> inline Vectorized<uint##bit##_t>::le(             \ | ||||
|       const Vectorized<uint##bit##_t>& other) const {                         \ | ||||
|     return (*this <= other) & Vectorized<uint##bit##_t>(1);                   \ | ||||
|   } | ||||
|  | ||||
| VEC_UINT_NEON_TEMPLATE(16, 8) | ||||
|  | ||||
| inline uint8_t Vectorized<uint8_t>::reduce_max() const { | ||||
|   return vmaxvq_u8(values); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<uint8_t> inline operator*( | ||||
|     const Vectorized<uint8_t>& a, | ||||
|     const Vectorized<uint8_t>& b) { | ||||
|   return vmulq_u8(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| inline Vectorized<uint8_t> operator~(const Vectorized<uint8_t>& a) { | ||||
|   return vmvnq_u8(a); | ||||
| } | ||||
|  | ||||
| inline Vectorized<uint8_t> Vectorized<uint8_t>::operator!=( | ||||
|     const Vectorized<uint8_t>& other) const { | ||||
|   return ~(*this == other); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<uint8_t> inline minimum( | ||||
|     const Vectorized<uint8_t>& a, | ||||
|     const Vectorized<uint8_t>& b) { | ||||
|   return vminq_u8(a, b); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<uint8_t> inline maximum( | ||||
|     const Vectorized<uint8_t>& a, | ||||
|     const Vectorized<uint8_t>& b) { | ||||
|   return vmaxq_u8(a, b); | ||||
| } | ||||
|  | ||||
| template <uint64_t mask> | ||||
| Vectorized<uint8_t> Vectorized<uint8_t>::blend( | ||||
|     const Vectorized<uint8_t>& a, | ||||
|     const Vectorized<uint8_t>& b) { | ||||
|   // Build an array of flags: each bit of element is 1 if the corresponding bit | ||||
|   // in 'mask' is set, 0 otherwise. | ||||
|   uint8x16_t maskArray = { | ||||
|       (mask & 1LL) ? 0xFF : 0, | ||||
|       (mask & 2LL) ? 0xFF : 0, | ||||
|       (mask & 4LL) ? 0xFF : 0, | ||||
|       (mask & 8LL) ? 0xFF : 0, | ||||
|       (mask & 16LL) ? 0xFF : 0, | ||||
|       (mask & 32LL) ? 0xFF : 0, | ||||
|       (mask & 64LL) ? 0xFF : 0, | ||||
|       (mask & 128LL) ? 0xFF : 0, | ||||
|       (mask & 256LL) ? 0xFF : 0, | ||||
|       (mask & 512LL) ? 0xFF : 0, | ||||
|       (mask & 1024LL) ? 0xFF : 0, | ||||
|       (mask & 2048LL) ? 0xFF : 0, | ||||
|       (mask & 4096LL) ? 0xFF : 0, | ||||
|       (mask & 8192LL) ? 0xFF : 0, | ||||
|       (mask & 16384LL) ? 0xFF : 0, | ||||
|       (mask & 32768LL) ? 0xFF : 0}; | ||||
|   // Use BSL to select elements from b where the mask is 1, else from a | ||||
|   return vbslq_u8(maskArray, b.values, a.values); | ||||
| } | ||||
|  | ||||
| #define VEC_UINT_NEON_OPS(vl, bit)                                             \ | ||||
|   inline Vectorized<uint##bit##_t>::Vectorized(uint##bit##_t val) {            \ | ||||
|     values = vdupq_n_u##bit(val);                                              \ | ||||
|   }                                                                            \ | ||||
|   inline Vectorized<uint##bit##_t> Vectorized<uint##bit##_t>::loadu(           \ | ||||
|       const void* ptr, uint64_t count) {                                       \ | ||||
|     if (count == size()) {                                                     \ | ||||
|       return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(ptr));        \ | ||||
|     } else {                                                                   \ | ||||
|       __at_align__ uint##bit##_t tmp_values[size()];                           \ | ||||
|       for (const auto i : c10::irange(size())) {                               \ | ||||
|         tmp_values[i] = 0;                                                     \ | ||||
|       }                                                                        \ | ||||
|       std::memcpy(                                                             \ | ||||
|           tmp_values,                                                          \ | ||||
|           reinterpret_cast<const uint##bit##_t*>(ptr),                         \ | ||||
|           count * sizeof(uint##bit##_t));                                      \ | ||||
|       return vld1q_u##bit(reinterpret_cast<const uint##bit##_t*>(tmp_values)); \ | ||||
|     }                                                                          \ | ||||
|   }                                                                            \ | ||||
|   inline void Vectorized<uint##bit##_t>::store(void* ptr, uint64_t count)      \ | ||||
|       const {                                                                  \ | ||||
|     if (count == size()) {                                                     \ | ||||
|       vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(ptr), values);             \ | ||||
|     } else {                                                                   \ | ||||
|       uint##bit##_t tmp_values[size()];                                        \ | ||||
|       vst1q_u##bit(reinterpret_cast<uint##bit##_t*>(tmp_values), values);      \ | ||||
|       std::memcpy(ptr, tmp_values, count * sizeof(uint##bit##_t));             \ | ||||
|     }                                                                          \ | ||||
|   } | ||||
|  | ||||
| VEC_UINT_NEON_OPS(16, 8) | ||||
|  | ||||
| template <typename step_t> | ||||
| inline Vectorized<uint8_t> Vectorized<uint8_t>::arange( | ||||
|     uint8_t base, | ||||
|     step_t step) { | ||||
|   const Vectorized<uint8_t> base_vec(base); | ||||
|   const Vectorized<uint8_t> step_vec(step); | ||||
|   const uint8x16_t step_sizes = { | ||||
|       0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; | ||||
|   return vmlaq_u8(base_vec, step_sizes, step_vec); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<uint8_t> inline operator>>( | ||||
|     const Vectorized<uint8_t>& a, | ||||
|     const Vectorized<uint8_t>& b) { | ||||
|   uint8x16_t x = a; | ||||
|   uint8x16_t bound = vdupq_n_u8(8); | ||||
|   uint8x16_t z = vminq_u8(b, bound); | ||||
|   return x >> z; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<uint8_t> inline operator<<( | ||||
|     const Vectorized<uint8_t>& a, | ||||
|     const Vectorized<uint8_t>& b) { | ||||
|   uint8x16_t bound = vdupq_n_u8(8); | ||||
|   uint8x16_t z = vminq_u8(b, bound); | ||||
|   return vshlq_u8(a, vreinterpretq_s8_u8(z)); | ||||
| } | ||||
|  | ||||
| inline Vectorized<uint8_t> Vectorized<uint8_t>::set( | ||||
|     const Vectorized<uint8_t>& a, | ||||
|     const Vectorized<uint8_t>& b, | ||||
|     uint64_t count) { | ||||
|   if (count == 0) { | ||||
|     return a; | ||||
|   } else if (count >= 16) { | ||||
|     return b; | ||||
|   } else { | ||||
|     // Build an array of flags: each bit of element is 1 if the corresponding | ||||
|     // bit in 'mask' is set, 0 otherwise. | ||||
|     uint8x16_t maskArray = { | ||||
|         static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0), | ||||
|         static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0), | ||||
|         0}; | ||||
|  | ||||
|     // Use BSL to select elements from b where the mask is 1, else from a | ||||
|     return vbslq_u8(maskArray, b.values, a.values); | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<uint8_t> inline operator/( | ||||
|     const Vectorized<uint8_t>& a, | ||||
|     const Vectorized<uint8_t>& b) { | ||||
|   uint8x16_t x = a; | ||||
|   uint8x16_t y = b; | ||||
|   return x / y; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<uint8_t> inline clamp( | ||||
|     const Vectorized<uint8_t>& a, | ||||
|     const Vectorized<uint8_t>& min, | ||||
|     const Vectorized<uint8_t>& max) { | ||||
|   return minimum(max, maximum(min, a)); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<uint8_t> inline clamp_max( | ||||
|     const Vectorized<uint8_t>& a, | ||||
|     const Vectorized<uint8_t>& max) { | ||||
|   return minimum(max, a); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| Vectorized<uint8_t> inline clamp_min( | ||||
|     const Vectorized<uint8_t>& a, | ||||
|     const Vectorized<uint8_t>& min) { | ||||
|   return maximum(min, a); | ||||
| } | ||||
|  | ||||
| } // namespace CPU_CAPABILITY | ||||
| } // namespace at::vec | ||||
| @ -1377,7 +1377,7 @@ Vectorized<c10::quint8> inline maximum( | ||||
| #if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)) | ||||
| std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float( | ||||
|     at::vec::Vectorized<int8_t> src) { | ||||
|   auto s8x8 = vget_low_s8(src); | ||||
|   auto s8x8 = vld1_s8(src.operator const int8_t*()); | ||||
|   auto s16x8 = vmovl_s8(s8x8); | ||||
|  | ||||
|   auto s32x4_hi = vmovl_s16(vget_high_s16(s16x8)); | ||||
| @ -1390,7 +1390,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float( | ||||
|  | ||||
| std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float( | ||||
|     at::vec::Vectorized<uint8_t> src) { | ||||
|   auto u8x8 = vget_low_u8(src); | ||||
|   auto u8x8 = vld1_u8(src.operator const uint8_t*()); | ||||
|   auto u16x8 = vmovl_u8(u8x8); | ||||
|   auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8)); | ||||
|   auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8)); | ||||
| @ -1402,7 +1402,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float( | ||||
|  | ||||
| Vectorized<float> inline convert_int8_half_register_to_float( | ||||
|     at::vec::Vectorized<int8_t> src) { | ||||
|   auto s8x8 = vget_low_s8(src); | ||||
|   auto s8x8 = vld1_s8(src.operator const int8_t*()); | ||||
|   auto s16x8 = vmovl_s8(s8x8); | ||||
|  | ||||
|   auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8)); | ||||
| @ -1412,7 +1412,7 @@ Vectorized<float> inline convert_int8_half_register_to_float( | ||||
|  | ||||
| Vectorized<float> inline convert_int8_half_register_to_float( | ||||
|     at::vec::Vectorized<uint8_t> src) { | ||||
|   auto u8x8 = vget_low_u8(src); | ||||
|   auto u8x8 = vld1_u8(src.operator const uint8_t*()); | ||||
|   auto u16x8 = vmovl_u8(u8x8); | ||||
|   auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8)); | ||||
|  | ||||
|  | ||||
| @ -16,8 +16,6 @@ | ||||
| #include <c10/util/irange.h> | ||||
| #include <c10/core/ScalarType.h> | ||||
|  | ||||
| #include <ATen/cuda/detail/BLASConstants.h> | ||||
|  | ||||
| #ifdef USE_ROCM | ||||
| #include <c10/cuda/CUDAStream.h> | ||||
| #include <hipblaslt/hipblaslt-ext.hpp> | ||||
| @ -1956,15 +1954,13 @@ void scaled_gemm( | ||||
|     const void *result_scale_ptr, | ||||
|     int64_t result_ld, | ||||
|     ScalarType result_dtype, | ||||
|     bool use_fast_accum, | ||||
|     const std::optional<Tensor>& alpha) { | ||||
|     bool use_fast_accum) { | ||||
|   // Note: see `cublasCommonArgs` for various non-intuitive manupulations | ||||
|   // of input arguments to this function. | ||||
|   const auto computeType = CUBLAS_COMPUTE_32F; | ||||
|   const auto scaleType = CUDA_R_32F; | ||||
|   // Note: alpha_val may change later depending on user-passed argument | ||||
|   float alpha_val = 1.0; | ||||
|   float beta_val = 0.0; | ||||
|   const float alpha_val = 1.0; | ||||
|   const float beta_val = 0.0; | ||||
|   CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType); | ||||
|   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa)); | ||||
|   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb)); | ||||
| @ -2035,33 +2031,6 @@ void scaled_gemm( | ||||
|     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_BIAS); | ||||
|     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype)); | ||||
|   } | ||||
|  | ||||
|   // Handle user-passed alpha | ||||
|   float *alpha_ptr = &alpha_val; | ||||
|   float *beta_ptr = &beta_val; | ||||
|  | ||||
|   if (alpha.has_value()) { | ||||
|     auto& a = alpha.value(); | ||||
|  | ||||
|     // if device-tensor | ||||
|     if (a.is_cuda()) { | ||||
|       // NOTE: there are lifetime requirements on device-side pointers for alpha/beta -- the value must be | ||||
|       //       valid & correct until the cublas call finishes (not is scheduled like host-side values). Thus | ||||
|       //       we need to use allocations for alpha/beta that have some guarantees on lifetime - a statically | ||||
|       //       managed 4B buffer for alpha that we'll copy the passed alpha value into, and constant memory | ||||
|       //       for beta respectively. | ||||
|       float *user_alpha_ptr = at::cuda::detail::get_user_alpha_ptr(); | ||||
|       at::Tensor user_alpha = at::from_blob(user_alpha_ptr, {1}, TensorOptions().device(kCUDA).dtype(kFloat)); | ||||
|       user_alpha.copy_(a); | ||||
|       // Tell cublasLt we're using device-side pointers for alpha/beta | ||||
|       auto pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; | ||||
|       computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_POINTER_MODE, pointer_mode); | ||||
|       alpha_ptr = user_alpha.data_ptr<float>(); | ||||
|       beta_ptr = at::cuda::detail::get_cublas_device_zero(); | ||||
|     } else { | ||||
|       alpha_val = a.item<float>(); | ||||
|     } | ||||
|   } | ||||
|     // For other data types, use the get_scale_mode function based on scaling type | ||||
|     // The SCALE_MODE attrs only exist in cuBLAS 12.8+/ROCm 7.0 or in recent hipblaslt, | ||||
|     // but we must invoke get_scale_mode anyways to trigger the version checks. | ||||
| @ -2079,7 +2048,6 @@ void scaled_gemm( | ||||
|   cublasLtMatmulHeuristicResult_t heuristicResult = {}; | ||||
|   int returnedResult = 0; | ||||
|   cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle(); | ||||
|  | ||||
|   TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( | ||||
|       ltHandle, | ||||
|       computeDesc.descriptor(), | ||||
| @ -2120,10 +2088,10 @@ void scaled_gemm( | ||||
|         auto is_valid_status = hipblaslt_ext::matmulIsAlgoSupported( | ||||
|                 ltHandle, | ||||
|                 computeDesc.descriptor(), | ||||
|                 alpha_ptr, | ||||
|                 &alpha_val, | ||||
|                 Adesc.descriptor(), | ||||
|                 Bdesc.descriptor(), | ||||
|                 beta_ptr, | ||||
|                 &beta_val, | ||||
|                 Cdesc.descriptor(), | ||||
|                 Ddesc.descriptor(), | ||||
|                 all_algos[i].algo, | ||||
| @ -2142,14 +2110,17 @@ void scaled_gemm( | ||||
|   cublasStatus_t cublasStatus = cublasLtMatmul( | ||||
|       ltHandle, | ||||
|       computeDesc.descriptor(), | ||||
|       alpha_ptr, | ||||
|       &alpha_val, | ||||
|       mat1_ptr, | ||||
|       Adesc.descriptor(), | ||||
|       mat2_ptr, | ||||
|       Bdesc.descriptor(), | ||||
|       beta_ptr, | ||||
|       // NOTE: always use result_ptr here, because cuBLASLt w/device beta=0 can't handle nullptr either | ||||
|       &beta_val, | ||||
| #ifdef USE_ROCM | ||||
|       result_ptr, // unused, since beta_val is 0, but hipblaslt can't handle nullptr | ||||
| #else | ||||
|       nullptr, | ||||
| #endif // ifdef USE_ROCM | ||||
|       Cdesc.descriptor(), | ||||
|       result_ptr, | ||||
|       Ddesc.descriptor(), | ||||
|  | ||||
| @ -161,8 +161,7 @@ void scaled_gemm( | ||||
|     const void* result_scale_ptr, | ||||
|     int64_t result_ld, | ||||
|     ScalarType result_dtype, | ||||
|     bool use_fast_accum, | ||||
|     const std::optional<Tensor>& alpha); | ||||
|     bool use_fast_accum); | ||||
|  | ||||
| #define CUDABLAS_BGEMM_ARGTYPES(Dtype)  CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype) | ||||
|  | ||||
|  | ||||
| @ -325,9 +325,9 @@ uint64_t CUDAGeneratorImpl::seed() { | ||||
|  */ | ||||
| c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const { | ||||
|   // The RNG state comprises the seed, and an offset used for Philox. | ||||
|   constexpr size_t seed_size = sizeof(uint64_t); | ||||
|   constexpr size_t offset_size = sizeof(int64_t); | ||||
|   constexpr size_t total_size = seed_size + offset_size; | ||||
|   static const size_t seed_size = sizeof(uint64_t); | ||||
|   static const size_t offset_size = sizeof(int64_t); | ||||
|   static const size_t total_size = seed_size + offset_size; | ||||
|  | ||||
|   auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt); | ||||
|   auto rng_state = state_tensor.data_ptr<uint8_t>(); | ||||
| @ -346,9 +346,9 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const { | ||||
|  * and size of the internal state. | ||||
|  */ | ||||
| void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) { | ||||
|   constexpr size_t seed_size = sizeof(uint64_t); | ||||
|   constexpr size_t offset_size = sizeof(int64_t); | ||||
|   constexpr size_t total_size = seed_size + offset_size; | ||||
|   static const size_t seed_size = sizeof(uint64_t); | ||||
|   static const size_t offset_size = sizeof(int64_t); | ||||
|   static const size_t total_size = seed_size + offset_size; | ||||
|  | ||||
|   detail::check_rng_state(new_state); | ||||
|  | ||||
|  | ||||
| @ -1,192 +0,0 @@ | ||||
| #include <ATen/cuda/CUDAGreenContext.h> | ||||
|  | ||||
| namespace at::cuda { | ||||
|   GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) { | ||||
| #if CUDA_HAS_GREEN_CONTEXT | ||||
|     int driver_version; | ||||
|     C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version)); | ||||
|     TORCH_CHECK( | ||||
|         driver_version >= 12080, "cuda driver too old to use green context!"); | ||||
|     CUcontext pctx = nullptr; | ||||
|     C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx)); | ||||
|     if (C10_UNLIKELY(!pctx)) { | ||||
|       TORCH_WARN( | ||||
|           "Attempted to create a green context but" | ||||
|           " there was no primary context! Creating a primary context..."); | ||||
|  | ||||
|       cudaFree(0); | ||||
|     } | ||||
|  | ||||
|     CUdevice device; | ||||
|     device_id_ = device_id; | ||||
|     C10_CUDA_DRIVER_CHECK( | ||||
|         c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id)); | ||||
|  | ||||
|     // Get device resources | ||||
|     CUdevResource device_resource; | ||||
|     C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_( | ||||
|         device, &device_resource, CU_DEV_RESOURCE_TYPE_SM)); | ||||
|  | ||||
|     // Split resources | ||||
|     std::vector<CUdevResource> result(1); | ||||
|     auto result_data = result.data(); | ||||
|     unsigned int nb_groups = 1; | ||||
|     CUdevResource remaining; | ||||
|  | ||||
|     C10_CUDA_DRIVER_CHECK( | ||||
|         c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_( | ||||
|             result_data, | ||||
|             &nb_groups, | ||||
|             &device_resource, | ||||
|             &remaining, | ||||
|             0, // default flags | ||||
|             num_sms)); | ||||
|  | ||||
|     TORCH_CHECK(nb_groups == 1, "Failed to create single resource group"); | ||||
|  | ||||
|     // Generate resource descriptor | ||||
|     CUdevResourceDesc desc; | ||||
|     C10_CUDA_DRIVER_CHECK( | ||||
|         c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_( | ||||
|             &desc, result_data, 1)); | ||||
|  | ||||
|     // Create green context | ||||
|     // CU_GREEN_CTX_DEFAULT_STREAM is required per docs: | ||||
|     // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html | ||||
|     C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_( | ||||
|         &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM)); | ||||
|  | ||||
|     // Convert to regular context | ||||
|     C10_CUDA_DRIVER_CHECK( | ||||
|         c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_)); | ||||
|     TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!"); | ||||
| #else | ||||
|     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   std::unique_ptr<GreenContext> GreenContext::create( | ||||
|       uint32_t num_sms, | ||||
|       std::optional<uint32_t> device_id) { | ||||
| #if CUDA_HAS_GREEN_CONTEXT | ||||
|     if (!device_id.has_value()) { | ||||
|       device_id = at::cuda::current_device(); | ||||
|     } | ||||
|     return std::make_unique<GreenContext>(device_id.value(), num_sms); | ||||
| #else | ||||
|     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   // Implement move operations | ||||
|   GreenContext::GreenContext(GreenContext&& other) noexcept{ | ||||
| #if CUDA_HAS_GREEN_CONTEXT | ||||
|     device_id_ = std::exchange(other.device_id_, -1); | ||||
|     green_ctx_ = std::exchange(other.green_ctx_, nullptr); | ||||
|     context_ = std::exchange(other.context_, nullptr); | ||||
|     parent_stream_ = std::exchange(other.parent_stream_, nullptr); | ||||
| #else | ||||
|     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{ | ||||
| #if CUDA_HAS_GREEN_CONTEXT | ||||
|     if (this != &other) { | ||||
|       // Clean up current resources | ||||
|       if (green_ctx_) { | ||||
|         CUcontext current = nullptr; | ||||
|         C10_CUDA_DRIVER_CHECK( | ||||
|             c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t)); | ||||
|         if (current == context_) { | ||||
|           TORCH_CHECK( | ||||
|               false, | ||||
|               "attempting to overwrite current green ctx " | ||||
|               "when it is active!"); | ||||
|         } | ||||
|         C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_)); | ||||
|       } | ||||
|  | ||||
|       // Take ownership of other's resources | ||||
|       device_id_ = std::exchange(other.device_id_, -1); | ||||
|       green_ctx_ = std::exchange(other.green_ctx_, nullptr); | ||||
|       context_ = std::exchange(other.context_, nullptr); | ||||
|       parent_stream_ = std::exchange(other.parent_stream_, nullptr); | ||||
|     } | ||||
|     return *this; | ||||
| #else | ||||
|     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   GreenContext::~GreenContext() noexcept{ | ||||
| #if CUDA_HAS_GREEN_CONTEXT | ||||
|     C10_CUDA_DRIVER_CHECK( | ||||
|         c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_)); | ||||
| #else | ||||
|     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   // Get the underlying CUDA context | ||||
|   CUcontext GreenContext::getContext() const { | ||||
| #if CUDA_HAS_GREEN_CONTEXT | ||||
|     return context_; | ||||
| #else | ||||
|     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   // Get the underlying green context | ||||
| #if CUDA_HAS_GREEN_CONTEXT | ||||
|   CUgreenCtx GreenContext::getGreenContext() const { | ||||
|     return green_ctx_; | ||||
|   } | ||||
| #endif | ||||
|  | ||||
|   // Make this context current | ||||
|   void GreenContext::setContext() { | ||||
| #if CUDA_HAS_GREEN_CONTEXT | ||||
|     auto current_stream = c10::cuda::getCurrentCUDAStream(); | ||||
|     parent_stream_ = current_stream.stream(); | ||||
|  | ||||
|     at::cuda::CUDAEvent ev; | ||||
|     ev.record(current_stream); | ||||
|  | ||||
|     CUcontext current = nullptr; | ||||
|     C10_CUDA_DRIVER_CHECK( | ||||
|         c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t)); | ||||
|     if (!current) { | ||||
|       C10_CUDA_DRIVER_CHECK( | ||||
|           c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_)); | ||||
|     } else { | ||||
|       C10_CUDA_DRIVER_CHECK( | ||||
|           c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_)); | ||||
|     } | ||||
|     // currently hardcodes the new green context to use the default stream | ||||
|     // TODO(eqy): consider creating a new stream if e.g., it allows interop | ||||
|     // with CUDA Graph captures etc. | ||||
|     auto default_stream = c10::cuda::getDefaultCUDAStream(); | ||||
|     ev.block(default_stream); | ||||
|     c10::cuda::setCurrentCUDAStream(default_stream); | ||||
| #else | ||||
|     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   void GreenContext::popContext() { | ||||
| #if CUDA_HAS_GREEN_CONTEXT | ||||
|     // see above note about stream being hardcoded to the default stream | ||||
|     at::cuda::CUDAEvent ev; | ||||
|     ev.record(c10::cuda::getCurrentCUDAStream()); | ||||
|     CUcontext popped; | ||||
|     C10_CUDA_DRIVER_CHECK( | ||||
|         c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped)); | ||||
|     TORCH_INTERNAL_ASSERT( | ||||
|         popped == context_, "expected popped context to be the current ctx"); | ||||
|     ev.block(c10::cuda::getStreamFromExternal(parent_stream_, device_id_)); | ||||
| #else | ||||
|     TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); | ||||
| #endif | ||||
|   } | ||||
| } // namespace at::cuda | ||||
| @ -1,53 +0,0 @@ | ||||
| #pragma once | ||||
| #include <ATen/cuda/CUDAEvent.h> | ||||
|  | ||||
| #if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) | ||||
| #include <c10/cuda/driver_api.h> | ||||
| #include <cuda.h> | ||||
| #include <memory> | ||||
| #include <stdexcept> | ||||
| #include <vector> | ||||
| #define CUDA_HAS_GREEN_CONTEXT 1 | ||||
| #else | ||||
| #define CUDA_HAS_GREEN_CONTEXT 0 | ||||
| #endif | ||||
|  | ||||
| namespace at::cuda { | ||||
|  | ||||
| class TORCH_CUDA_CPP_API GreenContext { | ||||
|  public: | ||||
|   GreenContext(uint32_t device_id, uint32_t num_sms); | ||||
|  | ||||
|   static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id); | ||||
|  | ||||
|   // Delete copy constructor and assignment | ||||
|   GreenContext(const GreenContext&) = delete; | ||||
|   GreenContext& operator=(const GreenContext&) = delete; | ||||
|  | ||||
|   // Implement move operations | ||||
|   GreenContext(GreenContext&& other) noexcept; | ||||
|   GreenContext& operator=(GreenContext&& other) noexcept; | ||||
|   ~GreenContext() noexcept; | ||||
|  | ||||
|   // Get the underlying CUDA context | ||||
|   CUcontext getContext() const; | ||||
|  | ||||
|   // Get the underlying green context | ||||
| #if CUDA_HAS_GREEN_CONTEXT | ||||
|   CUgreenCtx getGreenContext() const; | ||||
| #endif | ||||
|  | ||||
|   // Make this context current | ||||
|   void setContext(); | ||||
|  | ||||
|   void popContext(); | ||||
|  | ||||
|  private: | ||||
| #if CUDA_HAS_GREEN_CONTEXT | ||||
|   int32_t device_id_ = -1; | ||||
|   CUgreenCtx green_ctx_ = nullptr; | ||||
|   CUcontext context_ = nullptr; | ||||
|   cudaStream_t parent_stream_ = nullptr; | ||||
| #endif | ||||
| }; | ||||
| } // namespace at::cuda | ||||
| @ -183,6 +183,11 @@ struct CUDACachingHostAllocatorImpl | ||||
|     return true; | ||||
|   } | ||||
|  | ||||
|   bool pinned_use_background_threads() override { | ||||
|     return c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig:: | ||||
|         pinned_use_background_threads(); | ||||
|   } | ||||
|  | ||||
|   EventPool::Event create_event_internal(DeviceIndex idx) { | ||||
|     // Leak the event pool to avoid shutdown issue. | ||||
|     static auto* event_pool = new EventPool(); | ||||
|  | ||||
| @ -70,7 +70,11 @@ | ||||
| #define ATEN_CUB_MAXIMUM() NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max() | ||||
| #endif | ||||
|  | ||||
| #if defined(USE_ROCM) | ||||
| #if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || defined(USE_ROCM) | ||||
|  | ||||
| #if !defined(USE_ROCM) | ||||
| namespace at_cuda_detail { | ||||
| #endif | ||||
|  | ||||
| // backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16 | ||||
|  | ||||
| @ -92,6 +96,10 @@ template <> | ||||
| struct ROCM_HIPCUB(cub)::NumericTraits<c10::BFloat16>: | ||||
|        ROCM_HIPCUB(cub)::BaseTraits<ROCM_HIPCUB(cub)::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {}; | ||||
|  | ||||
| #if !defined(USE_ROCM) | ||||
| } // namespace at_cuda_detail | ||||
| #endif | ||||
|  | ||||
| #endif | ||||
|  | ||||
| #if !defined(USE_ROCM) | ||||
| @ -113,7 +121,7 @@ struct cuda_type<c10::Half> { | ||||
|   using type = __half; | ||||
| }; | ||||
|  | ||||
| #if !defined(USE_ROCM) | ||||
| #if !defined(USE_ROCM) && CUB_SUPPORTS_NV_BFLOAT16() | ||||
|  | ||||
| template<> | ||||
| struct cuda_type<c10::BFloat16> { | ||||
| @ -169,6 +177,7 @@ inline void segmented_sort_pairs( | ||||
|   } | ||||
| } | ||||
|  | ||||
| #if CUB_SUPPORTS_UNIQUE_BY_KEY() | ||||
| template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename NumSelectedIteratorT> | ||||
| inline void unique_by_key( | ||||
|   KeysInputIteratorT keys_in, ValuesInputIteratorT values_in, | ||||
| @ -184,6 +193,7 @@ inline void unique_by_key( | ||||
|   CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::UniqueByKey, | ||||
|     keys_in, values_in, keys_out_, values_out, num_selected, num_input_items, c10::cuda::getCurrentCUDAStream()); | ||||
| } | ||||
| #endif | ||||
|  | ||||
| namespace impl { | ||||
|  | ||||
| @ -195,6 +205,36 @@ __global__ void transform_vals(InputIteratorT1 a, InputIteratorT2 b, OutputItera | ||||
|   *out = scan_op(static_cast<acc_t>(*a), static_cast<acc_t>(*b)); | ||||
| } | ||||
|  | ||||
| #if !CUB_SUPPORTS_FUTURE_VALUE() | ||||
| template<typename ValueT, typename InputIteratorT> | ||||
| struct chained_iterator { | ||||
|   using iterator_category = std::random_access_iterator_tag; | ||||
|   using difference_type   = std::ptrdiff_t; | ||||
|   using value_type        = ValueT; | ||||
|   using pointer           = ValueT*; | ||||
|   using reference         = ValueT&; | ||||
|  | ||||
|   InputIteratorT iter; | ||||
|   ValueT *first; | ||||
|   difference_type offset = 0; | ||||
|  | ||||
|   __device__ ValueT operator[](difference_type i) { | ||||
|     i +=  offset; | ||||
|     if (i == 0) { | ||||
|       return *first; | ||||
|     } else { | ||||
|       return ValueT(iter[i - 1]); | ||||
|     } | ||||
|   } | ||||
|   __device__ chained_iterator operator+(difference_type i) { | ||||
|     return chained_iterator{iter, first, i}; | ||||
|   } | ||||
|   __device__ ValueT operator*() { | ||||
|     return (*this)[0]; | ||||
|   } | ||||
| }; | ||||
| #endif | ||||
|  | ||||
| // even though cub is supposed to support tensors with int_max elements, in reality it doesn't, | ||||
| // so split at int_max/2 | ||||
| constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30 | ||||
| @ -239,6 +279,25 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT | ||||
|         first_elem_ptr, | ||||
|         scan_op); | ||||
|     C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||
| #if !CUB_SUPPORTS_FUTURE_VALUE() | ||||
|     using ArgIndexInputIterator = NO_ROCM(at_cuda_detail)::cub::ArgIndexInputIterator<InputIteratorT>; | ||||
|     using tuple = typename ArgIndexInputIterator::value_type; | ||||
|     auto input_iter_transform = [=] __device__ (const tuple &x)->input_t  { | ||||
|       if (x.key == 0) { | ||||
|         return *first_elem_ptr; | ||||
|       } else { | ||||
|         return x.value; | ||||
|       } | ||||
|     }; | ||||
|     auto input_ = ATEN_CUB_TRANSFORM_ITERATOR(input_t, decltype(input_iter_transform), ArgIndexInputIterator)( | ||||
|       ArgIndexInputIterator(input + i), input_iter_transform); | ||||
|     CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan, | ||||
|         input_, | ||||
|         output + i, | ||||
|         scan_op, | ||||
|         size_cub, | ||||
|         at::cuda::getCurrentCUDAStream()); | ||||
| #else | ||||
|     CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan, | ||||
|         input + i + 1, | ||||
|         output + i, | ||||
| @ -246,6 +305,7 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT | ||||
|         ::at_cuda_detail::cub::FutureValue<input_t>(first_elem_ptr), | ||||
|         size_cub, | ||||
|         at::cuda::getCurrentCUDAStream()); | ||||
| #endif | ||||
|   } | ||||
| #endif | ||||
| } | ||||
| @ -497,6 +557,16 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT | ||||
|         first_elem_ptr, | ||||
|         scan_op); | ||||
|     C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||
| #if !CUB_SUPPORTS_FUTURE_VALUE() | ||||
|     auto input_ = impl::chained_iterator<InitValueT, InputIteratorT>{ | ||||
|       input + i, first_elem_ptr}; | ||||
|     CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan, | ||||
|         input_, | ||||
|         output + i, | ||||
|         scan_op, | ||||
|         size_cub, | ||||
|         at::cuda::getCurrentCUDAStream()); | ||||
| #else | ||||
|     CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan, | ||||
|         input + i, | ||||
|         output + i, | ||||
| @ -504,10 +574,12 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT | ||||
|         ::at_cuda_detail::cub::FutureValue<InitValueT>(first_elem_ptr), | ||||
|         size_cub, | ||||
|         at::cuda::getCurrentCUDAStream()); | ||||
| #endif | ||||
|   } | ||||
| #endif | ||||
| } | ||||
|  | ||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() | ||||
|  | ||||
| template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT> | ||||
| inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) { | ||||
| @ -535,6 +607,7 @@ inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT | ||||
| #endif | ||||
| } | ||||
|  | ||||
| #endif | ||||
|  | ||||
| template <typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT> | ||||
| void unique(InputIteratorT input, OutputIteratorT output, | ||||
|  | ||||
| @ -10,6 +10,14 @@ | ||||
| #define CUB_VERSION 200001 | ||||
| #endif | ||||
|  | ||||
| // cub sort support for __nv_bfloat16 is added to cub 1.13 in: | ||||
| // https://github.com/NVIDIA/cub/pull/306 | ||||
| #if CUB_VERSION >= 101300 | ||||
| #define CUB_SUPPORTS_NV_BFLOAT16() true | ||||
| #else | ||||
| #define CUB_SUPPORTS_NV_BFLOAT16() false | ||||
| #endif | ||||
|  | ||||
| // cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in: | ||||
| // https://github.com/NVIDIA/cub/pull/326 | ||||
| // CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake | ||||
| @ -20,6 +28,30 @@ | ||||
| #define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false | ||||
| #endif | ||||
|  | ||||
| // cub support for UniqueByKey is added to cub 1.16 in: | ||||
| // https://github.com/NVIDIA/cub/pull/405 | ||||
| #if CUB_VERSION >= 101600 | ||||
| #define CUB_SUPPORTS_UNIQUE_BY_KEY() true | ||||
| #else | ||||
| #define CUB_SUPPORTS_UNIQUE_BY_KEY() false | ||||
| #endif | ||||
|  | ||||
| // cub support for scan by key is added to cub 1.15 | ||||
| // in https://github.com/NVIDIA/cub/pull/376 | ||||
| #if CUB_VERSION >= 101500 | ||||
| #define CUB_SUPPORTS_SCAN_BY_KEY() 1 | ||||
| #else | ||||
| #define CUB_SUPPORTS_SCAN_BY_KEY() 0 | ||||
| #endif | ||||
|  | ||||
| // cub support for cub::FutureValue is added to cub 1.15 in: | ||||
| // https://github.com/NVIDIA/cub/pull/305 | ||||
| #if CUB_VERSION >= 101500 | ||||
| #define CUB_SUPPORTS_FUTURE_VALUE() true | ||||
| #else | ||||
| #define CUB_SUPPORTS_FUTURE_VALUE() false | ||||
| #endif | ||||
|  | ||||
| // There were many bc-breaking changes in major version release of CCCL v3.0.0 | ||||
| // Please see https://nvidia.github.io/cccl/cccl/3.0_migration_guide.html | ||||
| #if CUB_VERSION >= 200800 | ||||
|  | ||||
| @ -1,54 +0,0 @@ | ||||
| #include <ATen/Functions.h> | ||||
| #include <ATen/Tensor.h> | ||||
| #include <ATen/cuda/Exceptions.h> | ||||
|  | ||||
| #include <mutex> | ||||
|  | ||||
| namespace at { | ||||
| namespace cuda { | ||||
| namespace detail { | ||||
|  | ||||
| __device__ __constant__ float cublas_one_device; | ||||
| __device__ __constant__ float cublas_zero_device; | ||||
|  | ||||
| float *get_cublas_device_one() { | ||||
|   static c10::once_flag init_flag; | ||||
|  | ||||
|   c10::call_once(init_flag, []() { | ||||
|     const float one = 1.f; | ||||
|     AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_one_device, &one, sizeof(float))); | ||||
|   }); | ||||
|  | ||||
|   float *ptr; | ||||
|   AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_one_device)); | ||||
|   return ptr; | ||||
| } | ||||
|  | ||||
| float *get_cublas_device_zero() { | ||||
|   static c10::once_flag init_flag; | ||||
|  | ||||
|   c10::call_once(init_flag, []() { | ||||
|     const float zero = 0.f; | ||||
|     AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_zero_device, &zero, sizeof(float))); | ||||
|   }); | ||||
|  | ||||
|   float *ptr; | ||||
|   AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_zero_device)); | ||||
|   return ptr; | ||||
| } | ||||
|  | ||||
| float *get_user_alpha_ptr() { | ||||
|   static float *alpha_ptr; | ||||
|  | ||||
|   static c10::once_flag init_flag; | ||||
|  | ||||
|   c10::call_once(init_flag, []() { | ||||
|     AT_CUDA_CHECK(cudaMalloc(&alpha_ptr, sizeof(float))); | ||||
|   }); | ||||
|  | ||||
|   return alpha_ptr; | ||||
| } | ||||
|  | ||||
| } // namespace detail | ||||
| } // namespace cuda | ||||
| } // namespace at | ||||
| @ -1,11 +0,0 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <ATen/core/TensorBase.h> | ||||
|  | ||||
| namespace at::cuda::detail { | ||||
|  | ||||
| float *get_cublas_device_one(); | ||||
| float *get_cublas_device_zero(); | ||||
| float *get_user_alpha_ptr(); | ||||
|  | ||||
| } // namespace at::cuda::detail | ||||
| @ -109,8 +109,7 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> { | ||||
|           params->c_scale_ptr, | ||||
|           params->ldc, | ||||
|           params->c_dtype, | ||||
|           params->use_fast_accum, | ||||
|           std::nullopt /* alpha */); | ||||
|           params->use_fast_accum); | ||||
|       return OK; | ||||
|     } | ||||
| }; | ||||
|  | ||||
| @ -1,23 +0,0 @@ | ||||
| #include <ATen/detail/XLAHooksInterface.h> | ||||
|  | ||||
| namespace at { | ||||
| namespace detail { | ||||
|  | ||||
| const XLAHooksInterface& getXLAHooks() { | ||||
|   auto create_impl = [] { | ||||
|     // Create XLA hooks using the registry | ||||
|     auto hooks = XLAHooksRegistry()->Create("torch_xla::detail::XLAHooks", XLAHooksArgs{}); | ||||
|     if (hooks) { | ||||
|       return hooks; | ||||
|     } | ||||
|     // If hooks creation fails, fall back to default implementation | ||||
|     return std::make_unique<XLAHooksInterface>(); | ||||
|   }; | ||||
|   static auto hooks = create_impl(); | ||||
|   return *hooks; | ||||
| } | ||||
| } // namespace detail | ||||
|  | ||||
| C10_DEFINE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs) | ||||
|  | ||||
| } // namespace at | ||||
| @ -1,79 +0,0 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <c10/core/Device.h> | ||||
| #include <c10/util/Exception.h> | ||||
| #include <c10/util/Registry.h> | ||||
|  | ||||
| #include <ATen/detail/AcceleratorHooksInterface.h> | ||||
|  | ||||
| C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter") | ||||
|  | ||||
| namespace at { | ||||
|  | ||||
| constexpr const char* XLA_HELP = | ||||
|   "This error has occurred because you are trying " | ||||
|   "to use some XLA functionality, but the XLA library has not been " | ||||
|   "loaded by the dynamic linker. You must load xla libraries by `import torch_xla`"; | ||||
|  | ||||
| struct TORCH_API XLAHooksInterface : AcceleratorHooksInterface { | ||||
|   ~XLAHooksInterface() override = default; | ||||
|  | ||||
|   void init() const override { | ||||
|     TORCH_CHECK(false, "Cannot initialize XLA without torch_xla library. ", XLA_HELP); | ||||
|   } | ||||
|  | ||||
|   virtual bool hasXLA() const { | ||||
|     return false; | ||||
|   } | ||||
|  | ||||
|   virtual std::string showConfig() const { | ||||
|     TORCH_CHECK( | ||||
|         false, | ||||
|         "Cannot query detailed XLA version without torch_xla library. ", | ||||
|         XLA_HELP); | ||||
|   } | ||||
|  | ||||
|   const Generator& getDefaultGenerator( | ||||
|       [[maybe_unused]] DeviceIndex device_index = -1) const override { | ||||
|     TORCH_CHECK( | ||||
|         false, "Cannot get default XLA generator without torch_xla library. ", XLA_HELP); | ||||
|   } | ||||
|  | ||||
|   Generator getNewGenerator( | ||||
|       [[maybe_unused]] DeviceIndex device_index = -1) const override { | ||||
|     TORCH_CHECK(false, "Cannot get XLA generator without torch_xla library. ", XLA_HELP); | ||||
|   } | ||||
|  | ||||
|   virtual DeviceIndex getCurrentDevice() const override { | ||||
|     TORCH_CHECK(false, "Cannot get current XLA device without torch_xla library. ", XLA_HELP); | ||||
|   } | ||||
|  | ||||
|   Device getDeviceFromPtr(void* /*data*/) const override { | ||||
|     TORCH_CHECK(false, "Cannot get device of pointer on XLA without torch_xla library. ", XLA_HELP); | ||||
|   } | ||||
|  | ||||
|   Allocator* getPinnedMemoryAllocator() const override { | ||||
|     TORCH_CHECK(false, "Cannot get XLA pinned memory allocator without torch_xla library. ", XLA_HELP); | ||||
|   } | ||||
|  | ||||
|   bool isPinnedPtr(const void* data) const override { | ||||
|     return false; | ||||
|   } | ||||
|  | ||||
|   bool hasPrimaryContext(DeviceIndex device_index) const override { | ||||
|     TORCH_CHECK(false, "Cannot query primary context without torch_xla library. ", XLA_HELP); | ||||
|   } | ||||
|  | ||||
| }; | ||||
|  | ||||
| struct TORCH_API XLAHooksArgs {}; | ||||
|  | ||||
| TORCH_DECLARE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs); | ||||
| #define REGISTER_XLA_HOOKS(clsname) \ | ||||
|   C10_REGISTER_CLASS(XLAHooksRegistry, clsname, clsname) | ||||
|  | ||||
| namespace detail { | ||||
| TORCH_API const XLAHooksInterface& getXLAHooks(); | ||||
| } // namespace detail | ||||
| } // namespace at | ||||
| C10_DIAGNOSTIC_POP() | ||||
| @ -160,10 +160,6 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({ | ||||
|   DispatchKey::CUDA, | ||||
|   DispatchKey::CPU, | ||||
|   DispatchKey::PrivateUse1, | ||||
|   DispatchKey::SparseCPU, | ||||
|   DispatchKey::SparseCUDA, | ||||
|   DispatchKey::SparseCsrCPU, | ||||
|   DispatchKey::SparseCsrCUDA, | ||||
| }); | ||||
|  | ||||
| inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) { | ||||
|  | ||||
| @ -240,8 +240,8 @@ TORCH_META_FUNC(gelu_backward) ( | ||||
|  | ||||
| namespace at::native { | ||||
|  | ||||
| static constexpr double SELU_ALPHA = 1.6732632423543772848170429916717; | ||||
| static constexpr double SELU_SCALE = 1.0507009873554804934193349852946; | ||||
| static const double SELU_ALPHA = 1.6732632423543772848170429916717; | ||||
| static const double SELU_SCALE = 1.0507009873554804934193349852946; | ||||
|  | ||||
| DEFINE_DISPATCH(elu_stub); | ||||
| DEFINE_DISPATCH(elu_backward_stub); | ||||
|  | ||||
| @ -286,7 +286,7 @@ template void scal_fast_path<scalar_t>(int *n, scalar_t *a, scalar_t *x, int *in | ||||
| #if AT_BUILD_WITH_BLAS() | ||||
| template <> | ||||
| bool scal_use_fast_path<double>(int64_t n, int64_t incx) { | ||||
|   auto constexpr intmax = std::numeric_limits<int>::max(); | ||||
|   auto intmax = std::numeric_limits<int>::max(); | ||||
|   return n <= intmax && incx <= intmax; | ||||
| } | ||||
|  | ||||
| @ -315,7 +315,7 @@ bool gemv_use_fast_path<float>( | ||||
|     int64_t incx, | ||||
|     [[maybe_unused]] float beta, | ||||
|     int64_t incy) { | ||||
|   auto constexpr intmax = std::numeric_limits<int>::max(); | ||||
|   auto intmax = std::numeric_limits<int>::max(); | ||||
|   return (m <= intmax) && (n <= intmax) && (lda <= intmax) && | ||||
|          (incx > 0) && (incx <= intmax) && (incy > 0) && (incy <= intmax); | ||||
| } | ||||
|  | ||||
| @ -658,7 +658,6 @@ static void check_shape_forward(const at::Tensor& input, | ||||
|   TORCH_CHECK(!params.is_output_padding_neg(), "negative output_padding is not supported"); | ||||
|   TORCH_CHECK(!params.is_stride_nonpos(), "non-positive stride is not supported"); | ||||
|   TORCH_CHECK(!params.is_dilation_neg(), "dilation should be greater than zero"); | ||||
|   TORCH_CHECK(groups > 0, "expected groups to be greater than 0, but got groups=", groups); | ||||
|  | ||||
|   TORCH_CHECK(weight_dim == k, | ||||
|            "Expected ", weight_dim, "-dimensional input for ", weight_dim, | ||||
|  | ||||
| @ -1,6 +1,5 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <array> | ||||
| #include <ATen/native/Math.h> | ||||
| #include <c10/macros/Macros.h> | ||||
| #include <c10/util/MathConstants.h> | ||||
| @ -128,7 +127,7 @@ C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler<accscalar_t, unifor | ||||
|  | ||||
| template<typename scalar_t> | ||||
| C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) { | ||||
|   constexpr static scalar_t kTailValues[] = { | ||||
|   const static scalar_t kTailValues[] = { | ||||
|     0.0810614667953272, | ||||
|     0.0413406959554092, | ||||
|     0.0276779256849983, | ||||
| @ -140,7 +139,7 @@ C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) { | ||||
|     0.00925546218271273, | ||||
|     0.00833056343336287 | ||||
|   }; | ||||
|   if (k < std::size(kTailValues)) { | ||||
|   if (k <= 9) { | ||||
|     return kTailValues[static_cast<size_t>(k)]; | ||||
|   } | ||||
|   scalar_t kp1sq = (k + 1) * (k + 1); | ||||
|  | ||||
| @ -3620,7 +3620,7 @@ Tensor& _int_mm_out_cpu(const Tensor& self, const Tensor& mat2, Tensor& result) | ||||
|     try { | ||||
|       mkldnn_matmul_i8i8i32(self, mat2, result); | ||||
|       dispatched = true; | ||||
|     } catch ([[maybe_unused]] const std::exception& e) { | ||||
|     } catch (const std::exception& e) { | ||||
|       TORCH_WARN(func_name, " failed, switching to BLAS gemm: ", e.what()); | ||||
|     } | ||||
|   } | ||||
|  | ||||
| @ -581,7 +581,7 @@ scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M, | ||||
| template <typename scalar_t> | ||||
| static scalar_t lanczos_sum_expg_scaled(scalar_t x) { | ||||
|   // lanczos approximation | ||||
|   static constexpr scalar_t lanczos_sum_expg_scaled_num[13] = { | ||||
|   static const scalar_t lanczos_sum_expg_scaled_num[13] = { | ||||
|     0.006061842346248906525783753964555936883222, | ||||
|     0.5098416655656676188125178644804694509993, | ||||
|     19.51992788247617482847860966235652136208, | ||||
| @ -596,7 +596,7 @@ static scalar_t lanczos_sum_expg_scaled(scalar_t x) { | ||||
|     103794043.1163445451906271053616070238554, | ||||
|     56906521.91347156388090791033559122686859 | ||||
|   }; | ||||
|   static constexpr scalar_t lanczos_sum_expg_scaled_denom[13] = { | ||||
|   static const scalar_t lanczos_sum_expg_scaled_denom[13] = { | ||||
|     1., | ||||
|     66., | ||||
|     1925., | ||||
| @ -712,7 +712,7 @@ static scalar_t _igamc_helper_series(scalar_t a, scalar_t x) { | ||||
| template <typename scalar_t> | ||||
| static scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) { | ||||
|   // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1] | ||||
|   static constexpr scalar_t d[25][25] = | ||||
|   static const scalar_t d[25][25] = | ||||
|     {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, | ||||
|       1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, | ||||
|       3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, | ||||
|  | ||||
| @ -62,7 +62,7 @@ | ||||
| #include <utility> | ||||
| #include <vector> | ||||
|  | ||||
| static constexpr int MIOPEN_DIM_MAX = 5; | ||||
| static const int MIOPEN_DIM_MAX = 5; | ||||
|  | ||||
| namespace at::meta { | ||||
|  | ||||
|  | ||||
| @ -11,8 +11,6 @@ inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_facto | ||||
|               "pixel_shuffle expects a positive upscale_factor, but got ", | ||||
|               upscale_factor); | ||||
|   int64_t c = self.size(-3); | ||||
|   TORCH_CHECK_VALUE(upscale_factor <= std::numeric_limits<decltype(upscale_factor)>::max() / upscale_factor, | ||||
|         "upscale factor is too large, (upscale_factor)^2 overflowed: upscale_factor=", upscale_factor); | ||||
|   int64_t upscale_factor_squared = upscale_factor * upscale_factor; | ||||
|   TORCH_CHECK(c % upscale_factor_squared == 0, | ||||
|               "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of " | ||||
|  | ||||
| @ -77,7 +77,7 @@ inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) { | ||||
|   // next broadcast all index tensors together | ||||
|   try { | ||||
|     indices = expand_outplace(indices); | ||||
|   } catch (std::exception&) { | ||||
|   } catch (std::exception& e) { | ||||
|     TORCH_CHECK_INDEX( | ||||
|         false, | ||||
|         "shape mismatch: indexing tensors could not be broadcast together" | ||||
|  | ||||
| @ -259,20 +259,11 @@ inline void winograd_f2k3_input_transform_inplace__rvv( | ||||
|   const vfloat32m1_t wd1 = __riscv_vfadd_vv_f32m1(d1, d2, 4); | ||||
|   const vfloat32m1_t wd2 = __riscv_vfsub_vv_f32m1(d2, d1, 4); | ||||
|   const vfloat32m1_t wd3 = __riscv_vfsub_vv_f32m1(d1, d3, 4); | ||||
|   /* GCC 14.2 (RISC-V RVV) ICE workaround: | ||||
|    * Avoid single-statement read-modify-write on MEM_REF like: | ||||
|    *   *input_tile_val = | ||||
|    *     __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, idx, val); | ||||
|    * This triggers an ICE during GIMPLE lower (gsi_replace / riscv_gimple_fold_builtin) | ||||
|    * with -march=rv64gcv. Use a temporary then write back. | ||||
|    * Do NOT refactor into the single-statement form. Clang is unaffected. | ||||
|    */ | ||||
|   vfloat32m1x4_t tmp_input_tile_val = *input_tile_val; | ||||
|   tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 0, wd0); | ||||
|   tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 1, wd1); | ||||
|   tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 2, wd2); | ||||
|   tmp_input_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_input_tile_val, 3, wd3); | ||||
|   *input_tile_val = tmp_input_tile_val; | ||||
|  | ||||
|   *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wd0); | ||||
|   *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wd1); | ||||
|   *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 2, wd2); | ||||
|   *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 3, wd3); | ||||
| } | ||||
|  | ||||
| inline void winograd_f2k3_output_transform_inplace__rvv( | ||||
| @ -286,15 +277,9 @@ inline void winograd_f2k3_output_transform_inplace__rvv( | ||||
|   const vfloat32m1_t wm0 = __riscv_vfadd_vv_f32m1(m0_plus_m1, m2, 4); | ||||
|   const vfloat32m1_t m1_sub_m2 = __riscv_vfsub_vv_f32m1(m1, m2, 4); | ||||
|   const vfloat32m1_t wm1 = __riscv_vfsub_vv_f32m1(m1_sub_m2, m3, 4); | ||||
|   /* GCC 14.2 (RISC-V RVV) ICE workaround — see note above. | ||||
|    * Keep the temporary + write-back pattern to avoid ICE. | ||||
|    * Do NOT rewrite into: | ||||
|    *   *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, idx, val); | ||||
|    */ | ||||
|   vfloat32m1x4_t tmp_output_tile_val = *input_tile_val; | ||||
|   tmp_output_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_output_tile_val, 0, wm0); | ||||
|   tmp_output_tile_val = __riscv_vset_v_f32m1_f32m1x4(tmp_output_tile_val, 1, wm1); | ||||
|   *input_tile_val = tmp_output_tile_val; | ||||
|  | ||||
|   *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wm0); | ||||
|   *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wm1); | ||||
| } | ||||
|  | ||||
| inline vfloat32m1_t | ||||
| @ -315,17 +300,11 @@ inline void winograd_f2k3_kernel_transform__rvv( | ||||
|   const vfloat32m1_t const_half = __riscv_vfmv_v_f_f32m1(0.5f, 4); | ||||
|   const vfloat32m1_t g0_plus_g2 = __riscv_vfadd_vv_f32m1(g0, g2, 4); | ||||
|   vfloat32m1_t half_g0_plus_g2 =  __riscv_vfmul_vv_f32m1(const_half, g0_plus_g2, 4); | ||||
|   /* GCC 14.2 (RISC-V RVV) ICE workaround — see note above. | ||||
|    * Keep the temporary + write-back pattern to avoid ICE. | ||||
|    * Do NOT rewrite into: | ||||
|    *   *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, idx, val); | ||||
|    */ | ||||
|   vfloat32m1x4_t tmp_transform = *transform; | ||||
|   tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 0, g0); | ||||
|   tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 1, vmuladdq_f32(half_g0_plus_g2, const_half, g1)); | ||||
|   tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 2, vmulsubq_f32(half_g0_plus_g2, const_half, g1)); | ||||
|   tmp_transform = __riscv_vset_v_f32m1_f32m1x4(tmp_transform, 3, g2); | ||||
|   *transform = tmp_transform; | ||||
|  | ||||
|   *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 0, g0); | ||||
|   *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 1, vmuladdq_f32(half_g0_plus_g2, const_half, g1)); | ||||
|   *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 2, vmulsubq_f32(half_g0_plus_g2, const_half, g1)); | ||||
|   *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 3, g2); | ||||
| } | ||||
|  | ||||
| inline vfloat32m1x4_t v4f_transpose4x4__rvv(const vfloat32m1x4_t m) { | ||||
|  | ||||
| @ -1038,7 +1038,7 @@ struct HelperInterpNearest : public HelperInterpBase { | ||||
|   // We keep this structure for BC and consider as deprecated. | ||||
|   // See HelperInterpNearestExact as replacement | ||||
|  | ||||
|   static constexpr int interp_size = 1; | ||||
|   static const int interp_size = 1; | ||||
|  | ||||
|   static inline void init_indices_weights( | ||||
|     at::ScalarType output_type, | ||||
| @ -1155,7 +1155,7 @@ struct HelperInterpNearestExact : public HelperInterpNearest { | ||||
|  | ||||
| struct HelperInterpLinear : public HelperInterpBase { | ||||
|  | ||||
|   static constexpr int interp_size = 2; | ||||
|   static const int interp_size = 2; | ||||
|  | ||||
|   // Compute indices and weights for each interpolated dimension | ||||
|   // indices_weights = { | ||||
| @ -1275,7 +1275,7 @@ struct HelperInterpLinear : public HelperInterpBase { | ||||
|  | ||||
| struct HelperInterpCubic : public HelperInterpBase { | ||||
|  | ||||
|   static constexpr int interp_size = 4; | ||||
|   static const int interp_size = 4; | ||||
|  | ||||
|   // Compute indices and weights for each interpolated dimension | ||||
|   // indices_weights = { | ||||
|  | ||||
| @ -272,110 +272,28 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa | ||||
|   } | ||||
| } | ||||
|  | ||||
| /* | ||||
|  * Checks whether DISABLE_ADDMM_CUDA_LT is set. | ||||
|  * Additionally, for ROCM we test whether the architecture supports the Lt. | ||||
|  */ | ||||
| static bool isGloballyDisabledAddmmCudaLt(const at::Device& device) { | ||||
|   // When hipBLASLt is not supported on the architecture, return true | ||||
|   #ifdef USE_ROCM | ||||
|   static const std::vector<std::string> archs = { | ||||
| static bool getDisableAddmmCudaLt() { | ||||
|     static const auto env_value = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT"); | ||||
|     if (env_value == "1") { | ||||
|       return true; | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
|  | ||||
| #ifdef USE_ROCM | ||||
| static bool isSupportedHipLtROCmArch(int index) { | ||||
|     static const std::vector<std::string> archs = { | ||||
|         "gfx90a", "gfx942", | ||||
|     #if ROCM_VERSION >= 60300 | ||||
| #if ROCM_VERSION >= 60300 | ||||
|         "gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908", | ||||
|     #endif | ||||
|     #if ROCM_VERSION >= 70000 | ||||
| #endif | ||||
| #if ROCM_VERSION >= 70000 | ||||
|         "gfx950", "gfx1150", "gfx1151" | ||||
|     #endif | ||||
|   }; | ||||
|   const auto is_hipblas_lt_arch_supported = at::detail::getCUDAHooks().isGPUArch(archs, device.index()); | ||||
|   if (!is_hipblas_lt_arch_supported) { | ||||
|     return true; | ||||
|   } | ||||
|   #endif | ||||
|  | ||||
|   // Check whether it is disabled in the env | ||||
|   static const auto is_addmm_cuda_lt_disabled = c10::utils::get_env("DISABLE_ADDMM_CUDA_LT"); | ||||
|   if (is_addmm_cuda_lt_disabled == "1") { | ||||
|     return true; | ||||
|   } | ||||
|  | ||||
|   return false; | ||||
| } | ||||
|  | ||||
| /* | ||||
|  * Check whether for the given input we want to enable the Lt interface | ||||
|  */ | ||||
| static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) { | ||||
|   // Implies 2D bias which we currently not send through Lt. | ||||
|   // TODO: this check is done pre col-major input preparation, | ||||
|   // so, this condition can be ralexed in cases when a col-major | ||||
|   // copy of result is needed. | ||||
|   if (result.is_same(self)) { | ||||
|     return false; | ||||
|   } | ||||
|  | ||||
|   #if defined(USE_ROCM) && ROCM_VERSION == 60400 | ||||
|   // hipblaslt TT fp32 regression on ROCm 6.4, cannot use | ||||
|   const auto args = cublasCommonArgs(mat1, mat2, result); | ||||
|   if (args.transa == 't' && args.transb == 't') { | ||||
|     return false; | ||||
|   } | ||||
|   #endif | ||||
|  | ||||
|   const auto mat1_sizes = mat1.sizes(); | ||||
|   const auto mat2_sizes = mat2.sizes(); | ||||
|   #if defined(CUDA_VERSION) || defined(USE_ROCM) | ||||
|   const auto scalar_type = mat1.scalar_type(); | ||||
|   return (beta.toComplexDouble() == 1.0 | ||||
|     // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] | ||||
|     // is to use lt interface only when self is bias. | ||||
|     && self.dim() == 1 && self.sizes()[0] == mat2_sizes[1] && self.is_contiguous() | ||||
|     && result.dim() == 2 && result.is_contiguous() | ||||
|     && ( // some dtype restrictions | ||||
|       #ifndef USE_ROCM | ||||
|       scalar_type == at::ScalarType::Double || | ||||
|       #endif | ||||
|       scalar_type == at::ScalarType::Float || | ||||
|       scalar_type == at::ScalarType::Half || | ||||
|       scalar_type == at::ScalarType::BFloat16 | ||||
|     ) | ||||
|     && ( // some shape/stride restrictions | ||||
|       // Strangely, if mat2 has only 1 row or column, we get | ||||
|       // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic. | ||||
|       // NOTE: extension to mat1 because mat1/mat2 can be swapped based off | ||||
|       // their row-/col-majorness. | ||||
|       mat1_sizes[0] > 1 && mat1_sizes[1] > 1 && | ||||
|       mat2_sizes[0] > 1 && mat2_sizes[1] > 1 | ||||
|       // The last conditions is to skip 16b transA and non-trans-B having | ||||
|       // leading dim >> rows when they are sliced from a large tensor | ||||
|       // see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul | ||||
|       #if !(defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM)) | ||||
|       // Related to avoiding the leading stride >> leading dim problematic case | ||||
|       // with 16b dtypes described above. For such dtypes we only allow inputs | ||||
|       // which are either row- or col-major (i.e. non-overlapping, compact memory layout). | ||||
|       // In that case the leading stride will be equal to the outer dim len. | ||||
|       // Why do we catch this case here? The following `prepare_matrix_for_cublas` method | ||||
|       // does not modify inputs as long as there is a stride of length 1 | ||||
|       // and the leading stride is at least max(1, other dim length), so we might | ||||
|       // end up with contiguous cols but not rows (i.e. holes between different rows) | ||||
|       // and vice versa. | ||||
|       && mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 && | ||||
|       mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 && | ||||
|       && ( | ||||
|         // filter by dtype | ||||
|         (scalar_type != at::ScalarType::Half && scalar_type != at::ScalarType::BFloat16) || | ||||
|         // check mat1/mat2 is row-/col-major | ||||
|         (mat1.is_non_overlapping_and_dense() && mat2.is_non_overlapping_and_dense()) | ||||
|       ) | ||||
|       #endif | ||||
|     ) | ||||
|   ); | ||||
|   #endif | ||||
|  | ||||
|   // no compliance by default | ||||
|   return false; | ||||
| #endif | ||||
|     }; | ||||
|     return at::detail::getCUDAHooks().isGPUArch(archs, index); | ||||
| } | ||||
| #endif | ||||
|  | ||||
| template <typename scalar_t> | ||||
| void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const scalar_t* bias, cuda::blas::GEMMAndBiasActivationEpilogue activation) { | ||||
| @ -417,70 +335,7 @@ void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha, const | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename scalar_t, typename res_scalar_t = scalar_t> | ||||
| bool launchGemmAndBiasCublasLt( | ||||
|     // args contains result which is modified | ||||
|     cublasCommonArgs& args, | ||||
|     const Tensor& self, | ||||
|     const Scalar& alpha, | ||||
|     Activation activation = Activation::None | ||||
| ) { | ||||
|   const auto* self_ptr = self.const_data_ptr<scalar_t>(); | ||||
|  | ||||
|   const auto tuning_ctx = at::cuda::tunable::getTuningContext(); | ||||
|   if (tuning_ctx->IsTunableOpEnabled()) { | ||||
|     // TODO: maybe also return some success state? | ||||
|     launchTunableGemmAndBias<scalar_t>( | ||||
|       args, alpha, self_ptr, activation_to_gemm_and_blas_arg(activation) | ||||
|     ); | ||||
|     return true; | ||||
|   } | ||||
|  | ||||
|   return at::cuda::blas::gemm_and_bias<scalar_t, res_scalar_t>( | ||||
|     args.transa == 't', | ||||
|     args.transb == 't', | ||||
|     args.m, | ||||
|     args.n, | ||||
|     args.k, | ||||
|     alpha.to<at::opmath_type<scalar_t>>(), | ||||
|     args.mata->const_data_ptr<scalar_t>(), | ||||
|     args.lda, | ||||
|     args.matb->const_data_ptr<scalar_t>(), | ||||
|     args.ldb, | ||||
|     self_ptr, | ||||
|     args.result->data_ptr<res_scalar_t>(), | ||||
|     args.result_ld, | ||||
|     activation_to_gemm_and_blas_arg(activation) | ||||
|   ); | ||||
| } | ||||
|  | ||||
| template <typename scalar_t, typename res_scalar_t = scalar_t> | ||||
| bool launchGemmCublas( | ||||
|     // args contains result which is modified | ||||
|     cublasCommonArgs& args, | ||||
|     const Scalar& alpha, | ||||
|     const Scalar& beta | ||||
| ) { | ||||
|   at::cuda::blas::gemm<scalar_t, res_scalar_t>( | ||||
|     args.transa, | ||||
|     args.transb, | ||||
|     args.m, | ||||
|     args.n, | ||||
|     args.k, | ||||
|     alpha.to<at::opmath_type<scalar_t>>(), | ||||
|     args.mata->const_data_ptr<scalar_t>(), | ||||
|     args.lda, | ||||
|     args.matb->const_data_ptr<scalar_t>(), | ||||
|     args.ldb, | ||||
|     beta.to<at::opmath_type<scalar_t>>(), | ||||
|     args.result->data_ptr<res_scalar_t>(), | ||||
|     args.result_ld | ||||
|   ); | ||||
|   return true; // success! | ||||
| } | ||||
|  | ||||
| Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None, bool disable_addmm_cuda_lt_override=false) { | ||||
|   // Shape checks { | ||||
|   // Make sure to keep addmm_cuda below in sync with this code; it | ||||
|   // preflights a check to try to avoid actually needing to call | ||||
|   // expand(). | ||||
| @ -490,62 +345,105 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma | ||||
|     "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype() | ||||
|   ) | ||||
|  | ||||
|   if (result.is_same(self)) { | ||||
|     TORCH_CHECK(result.dim() == 2, "tensors must be 2-D"); | ||||
|     TORCH_CHECK(self.sizes()[0] == mat1.sizes()[0], "self dim 0 must match mat1 dim 0"); | ||||
|     TORCH_CHECK(self.sizes()[1] == mat2.sizes()[1], "self dim 1 must match mat2 dim 1"); | ||||
|   } | ||||
|   // } Shape checks | ||||
|  | ||||
|   // NOLINTNEXTLINE(*c-array*) | ||||
|   TensorArg targs[]{{result, "out", 0}, {self, "self", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}}; | ||||
|   checkAllSameGPU(__func__, targs); | ||||
|  | ||||
|   // Handle whether to use the Lt interface { | ||||
|   static bool persistent_disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()); | ||||
|   IntArrayRef mat1_sizes = mat1.sizes(); | ||||
|   IntArrayRef mat2_sizes = mat2.sizes(); | ||||
|   IntArrayRef self__sizes; | ||||
|   bool useLtInterface = false; | ||||
| #if defined(USE_ROCM) | ||||
|   // When hipBLASLt is not supported on the architecture, | ||||
|   // disable_addmm_cuda_lt will always be to set to true | ||||
|   static bool disable_addmm_cuda_lt = | ||||
|     !isSupportedHipLtROCmArch(self.device().index()) || getDisableAddmmCudaLt(); | ||||
| #else | ||||
|   static bool disable_addmm_cuda_lt = getDisableAddmmCudaLt(); | ||||
| #endif | ||||
|   // if lt path fails, we recurse back into this function here and force the lt path to off | ||||
|   // we cannot update varible disable_addmm_cuda_lt from above since it is static and would be permanent | ||||
|   bool disable_addmm_cuda_lt = persistent_disable_addmm_cuda_lt || disable_addmm_cuda_lt_override; | ||||
|   #ifdef USE_ROCM | ||||
|   // Conditioned on the device index, which is not persistent | ||||
|   disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt; | ||||
|   #endif | ||||
|   // Condition on the input | ||||
|   disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha) || disable_addmm_cuda_lt; | ||||
|   // } | ||||
|  | ||||
|   bool disable_addmm_cuda_lt_final = disable_addmm_cuda_lt || disable_addmm_cuda_lt_override; | ||||
| #if defined(USE_ROCM) && ROCM_VERSION == 60400 | ||||
|   // hipblaslt TT fp32 regression on ROCm 6.4, cannot use | ||||
|   cublasCommonArgs _args(mat1, mat2, result); | ||||
|   if (_args.transa == 't' && _args.transb == 't') { | ||||
|     disable_addmm_cuda_lt_final = true; | ||||
|   } | ||||
| #endif | ||||
|   at::ScalarType scalar_type = mat1.scalar_type(); | ||||
|   bool is_float_output_with_half_input = (scalar_type == at::ScalarType::Half || scalar_type == at::ScalarType::BFloat16) && result.scalar_type() == at::ScalarType::Float; | ||||
|   c10::MaybeOwned<Tensor> self_; | ||||
|   if (&result != &self) { | ||||
| #if defined(CUDA_VERSION) || defined(USE_ROCM) | ||||
|     // Strangely, if mat2 has only 1 row or column, we get | ||||
|     // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic. | ||||
|     // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] | ||||
|     // is to use lt interface only when self is bias. | ||||
|     // for cuda 11.4, cublasLtMatmul is activated | ||||
|     // the last two conditions is to skip 16b transA and non-trans-B having | ||||
|     // leading dim >> rows when they are sliced from a large tensor | ||||
|     // see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul | ||||
|     if (!disable_addmm_cuda_lt_final) { | ||||
|       useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 && | ||||
|           result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] && | ||||
|           self.is_contiguous() && result.is_contiguous() && | ||||
| #ifdef USE_ROCM | ||||
|           (scalar_type == at::ScalarType::Float || | ||||
|            scalar_type == at::ScalarType::Half || | ||||
|            scalar_type == at::ScalarType::BFloat16) && | ||||
| #else | ||||
|           (scalar_type == at::ScalarType::Double || | ||||
|            scalar_type == at::ScalarType::Float || | ||||
|            scalar_type == at::ScalarType::Half || | ||||
|            scalar_type == at::ScalarType::BFloat16) && | ||||
| #endif | ||||
| #if (defined(CUDA_VERSION) && CUDA_VERSION >= 12010 || defined(USE_ROCM)) | ||||
|           mat2_sizes[0] > 1 && mat2_sizes[1] > 1; | ||||
| #else | ||||
|           mat2_sizes[0] > 1 && mat2_sizes[1] > 1 && | ||||
|           mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 && | ||||
|           mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 && | ||||
|           // avoid leading dim >> rows bugs | ||||
|           ((mat1.strides()[0] == 1 && mat1.strides()[1] == mat1_sizes[0]) || | ||||
|            (mat1.strides()[1] == 1 && mat1.strides()[0] == mat1_sizes[1]) || | ||||
|            (scalar_type != at::ScalarType::Half && | ||||
|             scalar_type != at::ScalarType::BFloat16)) && | ||||
|           ((mat2.strides()[0] == 1 && mat2.strides()[1] == mat2_sizes[0]) || | ||||
|            (mat2.strides()[1] == 1 && mat2.strides()[0] == mat2_sizes[1]) || | ||||
|            (scalar_type != at::ScalarType::Half && | ||||
|             scalar_type != at::ScalarType::BFloat16)); | ||||
| #endif | ||||
|     } | ||||
| #endif | ||||
|     if (!useLtInterface) { | ||||
|       self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm"); | ||||
|     } | ||||
|     self__sizes = self_->sizes(); | ||||
|   } else { | ||||
|     self_ = c10::MaybeOwned<Tensor>::borrowed(self); | ||||
|     self__sizes = self_->sizes(); | ||||
|     TORCH_CHECK(result.dim() == 2, "tensors must be 2-D"); | ||||
|     TORCH_CHECK(self__sizes[0] == mat1_sizes[0], "self_ dim 0 must match mat1 dim 0"); | ||||
|     TORCH_CHECK(self__sizes[1] == mat2_sizes[1], "self_ dim 1 must match mat2 dim 1"); | ||||
|   } | ||||
|  | ||||
|   // Handle result/self shapes | ||||
|   if (!result.is_same(self)) { | ||||
|     at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]}); | ||||
|  | ||||
|     const auto self_maybe_expanded = [&]() -> c10::MaybeOwned<Tensor> { | ||||
|       if (disable_addmm_cuda_lt) { | ||||
|         // When in non-Lt path we do expand self even before | ||||
|         // check for beta != 0.0 to make sure that | ||||
|         // test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_errors_* | ||||
|         // runs green. | ||||
|         return expand_size(self, result.sizes(), "addmm"); | ||||
|       } | ||||
|       // copy next, should broadcast | ||||
|       return c10::MaybeOwned<Tensor>::borrowed(self); | ||||
|     }(); | ||||
|     // We copy bias when in the non-Lt path | ||||
|     if (beta.toComplexDouble() != 0.0 && disable_addmm_cuda_lt) { | ||||
|       // NOTE: self should broadcast over result | ||||
|       at::native::copy_(result, *self_maybe_expanded); | ||||
|   if (&result != &self) { | ||||
|     at::native::resize_output(result, {mat1_sizes[0], mat2_sizes[1]}); | ||||
|     if (beta.toComplexDouble() != 0.0 && !useLtInterface) { | ||||
|       at::native::copy_(result, *self_); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   // Short circuit on empty result | ||||
|   if (result.numel() == 0) { | ||||
|  | ||||
|   IntArrayRef result_sizes = result.sizes(); | ||||
|   if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) { | ||||
|     return result; | ||||
|   } | ||||
|  | ||||
|   // Short circuit if the reduction dim is empty | ||||
|   if (mat1.sizes()[1] == 0) { | ||||
|   cublasCommonArgs args(mat1, mat2, result); | ||||
|  | ||||
|   if (mat1.numel() == 0) { | ||||
|     // By definition, when beta==0, values in self should be ignored. nans and infs | ||||
|     // should not propagate | ||||
|     if (beta.toComplexDouble() == 0.) { | ||||
| @ -557,64 +455,158 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma | ||||
|         result, | ||||
|         self.expand(result.sizes()), | ||||
|         at::native::scalar_tensor( | ||||
|           beta, | ||||
|           self.scalar_type(), | ||||
|           std::nullopt /* layout */, | ||||
|           at::kCPU, | ||||
|           std::nullopt /* pin_memory */ | ||||
|         ) | ||||
|     ); | ||||
|             beta, | ||||
|             self.scalar_type(), | ||||
|             std::nullopt /* layout */, | ||||
|             at::kCPU, | ||||
|             std::nullopt /* pin_memory */)); | ||||
|   } | ||||
|  | ||||
|   cublasCommonArgs args(mat1, mat2, result); | ||||
|   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj()); | ||||
|  | ||||
|   // The Lt path | ||||
|   if (!disable_addmm_cuda_lt) { | ||||
|     bool lt_success = false; | ||||
|   if (useLtInterface) { | ||||
| #if defined(USE_ROCM) | ||||
|     bool okay = true; | ||||
|     if (is_float_output_with_half_input) { | ||||
|       #ifdef USE_ROCM | ||||
|       TORCH_CHECK(false, "float output with half input is not enabled for ROCm"); | ||||
|       #else | ||||
|       if (at::cuda::tunable::getTuningContext()->IsTunableOpEnabled()) { | ||||
|        TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input"); | ||||
|       } | ||||
|       AT_DISPATCH_REDUCED_FLOATING_TYPES( | ||||
|         scalar_type, | ||||
|         "addmm_cuda_lt", | ||||
|         [&] { | ||||
|           lt_success = launchGemmAndBiasCublasLt<scalar_t, float>(args, self, alpha, activation); | ||||
|         } | ||||
|       ); | ||||
|       #endif | ||||
|     } else { | ||||
|       // !is_float_output_with_half_input | ||||
|       AT_DISPATCH_FLOATING_TYPES_AND2( | ||||
|         at::ScalarType::Half, | ||||
|         at::ScalarType::BFloat16, | ||||
|         scalar_type, | ||||
|         "addmm_cuda_lt", | ||||
|         [&] { | ||||
|           lt_success = launchGemmAndBiasCublasLt<scalar_t>(args, self, alpha, activation); | ||||
|         auto tuning_ctx = at::cuda::tunable::getTuningContext(); | ||||
|         if (tuning_ctx->IsTunableOpEnabled()) { | ||||
|           launchTunableGemmAndBias<scalar_t>( | ||||
|               args, | ||||
|               alpha, | ||||
|               (&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr, | ||||
|               activation_to_gemm_and_blas_arg(activation)); | ||||
|         } else { | ||||
|           okay = at::cuda::blas::gemm_and_bias<scalar_t>( | ||||
|             args.transa == 't', | ||||
|             args.transb == 't', | ||||
|             args.m, | ||||
|             args.n, | ||||
|             args.k, | ||||
|             alpha.to<at::opmath_type<scalar_t>>(), | ||||
|             args.mata->const_data_ptr<scalar_t>(), | ||||
|             args.lda, | ||||
|             args.matb->const_data_ptr<scalar_t>(), | ||||
|             args.ldb, | ||||
|             // This condition is needed for mm case on ROCm for hipblasLt path. | ||||
|             // Passing the bias ptr as null to avoid accuracy issues for mm case. | ||||
|             (&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr, | ||||
|             args.result->data_ptr<scalar_t>(), | ||||
|             args.result_ld, | ||||
|             activation_to_gemm_and_blas_arg(activation) | ||||
|           ); | ||||
|         } | ||||
|       ); | ||||
|     } // end is_float_output_with_half_input | ||||
|  | ||||
|     if (!lt_success) { | ||||
|     // lt path failed; recurse but disable lt path | ||||
|       }); | ||||
|     } | ||||
|     if (!okay) { | ||||
|       // lt path failed; recurse but disable lt path | ||||
|       return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true); | ||||
|     } | ||||
|     // end Lt path | ||||
|   } else { | ||||
|     // No Lt, we use a GEMM instead | ||||
| #else | ||||
|     auto activation_epilogue = activation_to_gemm_and_blas_arg(activation); | ||||
|     bool okay = true; | ||||
|     if (is_float_output_with_half_input) { | ||||
|       AT_DISPATCH_REDUCED_FLOATING_TYPES( | ||||
|         scalar_type, | ||||
|         "addmm_cuda_lt", | ||||
|         [&] { | ||||
|         auto tuning_ctx = at::cuda::tunable::getTuningContext(); | ||||
|         if (tuning_ctx->IsTunableOpEnabled()) { | ||||
|           TORCH_CHECK(false, "Tunable GEMM is not supported for float output with reduced float input"); | ||||
|         } | ||||
|         else { | ||||
|           okay = at::cuda::blas::gemm_and_bias<scalar_t, float>( | ||||
|               args.transa == 't', | ||||
|               args.transb == 't', | ||||
|               args.m, | ||||
|               args.n, | ||||
|               args.k, | ||||
|               alpha.to<at::opmath_type<scalar_t>>(), | ||||
|               args.mata->const_data_ptr<scalar_t>(), | ||||
|               args.lda, | ||||
|               args.matb->const_data_ptr<scalar_t>(), | ||||
|               args.ldb, | ||||
|               self.const_data_ptr<scalar_t>(), | ||||
|               args.result->data_ptr<float>(), | ||||
|               args.result_ld, | ||||
|               activation_epilogue | ||||
|           ); | ||||
|         }}); | ||||
|     } else { | ||||
|       AT_DISPATCH_FLOATING_TYPES_AND2( | ||||
|         at::ScalarType::Half, | ||||
|         at::ScalarType::BFloat16, | ||||
|         scalar_type, | ||||
|         "addmm_cuda_lt", | ||||
|         [&] { | ||||
|         auto tuning_ctx = at::cuda::tunable::getTuningContext(); | ||||
|         if (tuning_ctx->IsTunableOpEnabled()) { | ||||
|           launchTunableGemmAndBias<scalar_t>( | ||||
|               args, | ||||
|               alpha, | ||||
|               self.const_data_ptr<scalar_t>(), | ||||
|               activation_epilogue); | ||||
|         } | ||||
|         else { | ||||
|           okay = at::cuda::blas::gemm_and_bias<scalar_t>( | ||||
|               args.transa == 't', | ||||
|               args.transb == 't', | ||||
|               args.m, | ||||
|               args.n, | ||||
|               args.k, | ||||
|               alpha.to<at::opmath_type<scalar_t>>(), | ||||
|               args.mata->const_data_ptr<scalar_t>(), | ||||
|               args.lda, | ||||
|               args.matb->const_data_ptr<scalar_t>(), | ||||
|               args.ldb, | ||||
|               self.const_data_ptr<scalar_t>(), | ||||
|               args.result->data_ptr<scalar_t>(), | ||||
|               args.result_ld, | ||||
|               activation_epilogue | ||||
|           ); | ||||
|       }}); | ||||
|     } | ||||
|     if (!okay) { | ||||
|       // lt path failed; recurse but disable lt path | ||||
|       return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true); | ||||
|     } | ||||
| #endif | ||||
|   } else | ||||
|   { | ||||
|     if (is_float_output_with_half_input) { | ||||
|       AT_DISPATCH_REDUCED_FLOATING_TYPES( | ||||
|         scalar_type, | ||||
|         "addmm_cuda", | ||||
|         [&] { | ||||
|           launchGemmCublas<scalar_t, float>(args, alpha, beta); | ||||
|         } | ||||
|       ); | ||||
|           using opmath_t = at::opmath_type<scalar_t>; | ||||
|           opmath_t alpha_val = alpha.to<opmath_t>(); | ||||
|           opmath_t beta_val = beta.to<opmath_t>(); | ||||
|           const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>(); | ||||
|           const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>(); | ||||
|  | ||||
|           float* result_ptr = args.result->mutable_data_ptr<float>(); | ||||
|           at::cuda::blas::gemm<scalar_t, float>( | ||||
|               args.transa, | ||||
|               args.transb, | ||||
|               args.m, | ||||
|               args.n, | ||||
|               args.k, | ||||
|               alpha_val, | ||||
|               mat1_ptr, | ||||
|               args.lda, | ||||
|               mat2_ptr, | ||||
|               args.ldb, | ||||
|               beta_val, | ||||
|               result_ptr, | ||||
|               args.result_ld); | ||||
|         }); | ||||
|     } else { | ||||
|       AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( | ||||
|         at::ScalarType::Half, | ||||
| @ -622,12 +614,28 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma | ||||
|         scalar_type, | ||||
|         "addmm_cuda", | ||||
|         [&] { | ||||
|           launchGemmCublas<scalar_t>(args, alpha, beta); | ||||
|         } | ||||
|       ); | ||||
|           using opmath_t = at::opmath_type<scalar_t>; | ||||
|           opmath_t alpha_val = alpha.to<opmath_t>(); | ||||
|           opmath_t beta_val = beta.to<opmath_t>(); | ||||
|           const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>(); | ||||
|           const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>(); | ||||
|           scalar_t* result_ptr = args.result->mutable_data_ptr<scalar_t>(); | ||||
|           at::cuda::blas::gemm<scalar_t>( | ||||
|               args.transa, | ||||
|               args.transb, | ||||
|               args.m, | ||||
|               args.n, | ||||
|               args.k, | ||||
|               alpha_val, | ||||
|               mat1_ptr, | ||||
|               args.lda, | ||||
|               mat2_ptr, | ||||
|               args.ldb, | ||||
|               beta_val, | ||||
|               result_ptr, | ||||
|               args.result_ld); | ||||
|         }); | ||||
|     } | ||||
|  | ||||
|     // Apply epilogue | ||||
|     switch (activation) { | ||||
|       case Activation::RELU: | ||||
|         // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) | ||||
| @ -639,14 +647,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma | ||||
|         break; | ||||
|       default: break; | ||||
|     } | ||||
|   } // end GEMM path | ||||
|   } | ||||
|  | ||||
| // Preprocessor gate here needs to match the inverse of the check | ||||
| // gating activation_to_gemm_and_blas_arg above; here we are manually | ||||
| // performing a post-GELU because we weren't able to use the GELU | ||||
| // epilogue above. | ||||
| #if !defined(CUDA_VERSION) && !defined(USE_ROCM) | ||||
|   if (!disable_addmm_cuda_lt && activation == Activation::GELU) { | ||||
|   if (useLtInterface && activation == Activation::GELU) { | ||||
|     at::gelu_(const_cast<Tensor&>(*args.result), "tanh"); | ||||
|   } | ||||
| #endif | ||||
| @ -1351,8 +1359,7 @@ _scaled_gemm( | ||||
|           const ScalingType scaling_choice_a, const ScalingType scaling_choice_b, | ||||
|           const std::optional<Tensor>& bias, | ||||
|           const bool use_fast_accum, | ||||
|           Tensor& out, | ||||
|           const std::optional<Tensor>& alpha = std::nullopt) { | ||||
|           Tensor& out) { | ||||
|   cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b); | ||||
|   const auto out_dtype_ = args.result->scalar_type(); | ||||
|   TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); | ||||
| @ -1403,8 +1410,7 @@ _scaled_gemm( | ||||
|           args.scale_result_ptr, | ||||
|           args.result_ld, | ||||
|           out_dtype_, | ||||
|           use_fast_accum, | ||||
|           alpha); | ||||
|           use_fast_accum); | ||||
|       return out; | ||||
|   } | ||||
| } | ||||
| @ -2314,23 +2320,12 @@ _scaled_nvfp4_nvfp4( | ||||
|           const Tensor& scale_b, const SwizzleType swizzle_b, | ||||
|           const std::optional<Tensor>& bias, | ||||
|           const c10::ScalarType out_dtype, | ||||
|           Tensor& out, | ||||
|           const std::optional<Tensor>& global_scale_a = std::nullopt, | ||||
|           const std::optional<Tensor>& global_scale_b = std::nullopt) { | ||||
|           const bool single_scale, | ||||
|           Tensor& out) { | ||||
| #ifdef USE_ROCM | ||||
|   TORCH_CHECK_NOT_IMPLEMENTED(false, "NVFP4 scaling not supported on ROCM"); | ||||
| #endif | ||||
|   std::optional<Tensor> alpha = std::nullopt; | ||||
|   // Note: "Or" here means that if only one scale is passed, we check for the other. Otherwise, | ||||
|   //       if this is "And" we would silently do nothing in the case where one global scale is | ||||
|   //       passed and not the other. | ||||
|   if (global_scale_a.has_value() || global_scale_b.has_value()) { | ||||
|     TORCH_CHECK_VALUE(global_scale_a.has_value(), | ||||
|         "For two-level-scaled NVFP4, global_scale_a must have a value"); | ||||
|     TORCH_CHECK_VALUE(global_scale_b.has_value(), | ||||
|         "For two-level-scaled NVFP4, global_scale_b must have a value"); | ||||
|     alpha = global_scale_a.value().mul(global_scale_b.value()); | ||||
|   } | ||||
|   TORCH_CHECK_VALUE(single_scale, "Only single-scaled NVFP4 currently supported"); | ||||
|   // Restrictions: | ||||
|   // A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32 | ||||
|   // Scales must be swizzled | ||||
| @ -2352,7 +2347,7 @@ _scaled_nvfp4_nvfp4( | ||||
|  | ||||
|   auto scaling_choice_a = ScalingType::BlockWise1x16; | ||||
|   auto scaling_choice_b = ScalingType::BlockWise1x16; | ||||
|   return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out, alpha); | ||||
|   return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out); | ||||
| } | ||||
|  | ||||
|  | ||||
| @ -2558,10 +2553,9 @@ _scaled_mm_cuda_v2_out( | ||||
|   } else if (gemm_impl == ScaledGemmImplementation::MXFP8_MXFP8) { | ||||
|     return _scaled_mxfp8_mxfp8(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out); | ||||
|   } else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4) { | ||||
|     return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out, | ||||
|                                scale_a[1], scale_b[1]); | ||||
|     TORCH_CHECK_NOT_IMPLEMENTED(false, "Only single-scale NVFP4 currently supported"); | ||||
|   } else if (gemm_impl == ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE) { | ||||
|     return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out); | ||||
|     return _scaled_nvfp4_nvfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, true /* single_scale */, out); | ||||
|   } else if (gemm_impl == ScaledGemmImplementation::MXFP4_MXFP4) { | ||||
|     return _scaled_mxfp4_mxfp4(mat_a, mat_b, scale_a[0], swizzle_a_enum[0], scale_b[0], swizzle_b_enum[0], bias, out_dtype_, out); | ||||
|   } else { | ||||
|  | ||||
| @ -249,7 +249,7 @@ __global__ void max_pool_forward_nhwc( | ||||
| } | ||||
|  | ||||
|  | ||||
| static constexpr int BLOCK_THREADS = 256; | ||||
| static const int BLOCK_THREADS = 256; | ||||
|  | ||||
| template <typename scalar_t, typename accscalar_t> | ||||
| #if defined (USE_ROCM) | ||||
|  | ||||
| @ -15,7 +15,9 @@ | ||||
| #include <ATen/native/cuda/block_reduce.cuh> | ||||
| #include <ATen/native/cuda/thread_constants.h> | ||||
|  | ||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() | ||||
| #include <thrust/iterator/reverse_iterator.h> | ||||
| #endif | ||||
|  | ||||
| #ifndef AT_PER_OPERATOR_HEADERS | ||||
| #include <ATen/Functions.h> | ||||
| @ -34,9 +36,9 @@ namespace at::native { | ||||
| namespace { | ||||
|  | ||||
| #if defined(USE_ROCM) | ||||
| static constexpr int BLOCKDIMY = 16; | ||||
| static const int BLOCKDIMY = 16; | ||||
| #else | ||||
| static constexpr int BLOCKDIMY = 32; | ||||
| static const int BLOCKDIMY = 32; | ||||
| #endif | ||||
|  | ||||
| template | ||||
| @ -238,6 +240,10 @@ __global__ void renorm_kernel( | ||||
|  | ||||
| } // anonymous namespace | ||||
|  | ||||
| #if !CUB_SUPPORTS_SCAN_BY_KEY() | ||||
| template<typename index_t> | ||||
| void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count); | ||||
| #endif | ||||
|  | ||||
| Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices_, | ||||
|                                int64_t num_weights, int64_t padding_idx, | ||||
| @ -300,6 +306,7 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice | ||||
|  | ||||
|   if (scale_grad_by_freq) { | ||||
|     count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); | ||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() | ||||
|     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () { | ||||
|       cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|  | ||||
| @ -326,6 +333,11 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice | ||||
|         num_indices | ||||
|       ); | ||||
|     }); | ||||
| #else | ||||
|     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () { | ||||
|       embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count); | ||||
|     }); | ||||
| #endif | ||||
|   } | ||||
|  | ||||
|   return embedding_backward_cuda_kernel(grad, orig_indices, | ||||
|  | ||||
| @ -10,7 +10,9 @@ | ||||
|  | ||||
| #include <c10/macros/Macros.h> | ||||
|  | ||||
| #if CUB_SUPPORTS_UNIQUE_BY_KEY() | ||||
| #include <thrust/iterator/counting_iterator.h> | ||||
| #endif | ||||
|  | ||||
| #ifndef AT_PER_OPERATOR_HEADERS | ||||
| #include <ATen/Functions.h> | ||||
| @ -194,9 +196,18 @@ __global__ void compute_num_of_partial_segments(const index_t *partials_per_segm | ||||
|             partials_per_segment_offset[num_of_segments-1]; | ||||
| } | ||||
|  | ||||
| #if !CUB_SUPPORTS_UNIQUE_BY_KEY() | ||||
| __global__ void write_num_of_segments_for_legacy_thrust_path(int64_t *num_of_segments_ptr, int64_t num_of_segments) { | ||||
|   *num_of_segments_ptr = num_of_segments; | ||||
| } | ||||
| #endif | ||||
|  | ||||
| } // anon namespace | ||||
|  | ||||
| #if !CUB_SUPPORTS_UNIQUE_BY_KEY() | ||||
| template<typename index_t> | ||||
| int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets); | ||||
| #endif | ||||
|  | ||||
| Tensor embedding_backward_cuda_kernel( | ||||
|         const Tensor &grad, | ||||
| @ -223,12 +234,20 @@ Tensor embedding_backward_cuda_kernel( | ||||
|   auto segment_offsets = at::empty({numel}, orig_indices.options()); | ||||
|   auto num_of_segments_tensor = at::empty({}, grad.options().dtype(kLong)); | ||||
|   int64_t *num_of_segments_ptr = num_of_segments_tensor.mutable_data_ptr<int64_t>(); | ||||
| #if !CUB_SUPPORTS_UNIQUE_BY_KEY() | ||||
|   AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () { | ||||
|     int64_t num_of_segments = embedding_backward_cuda_kernel_unique_by_key<index_t>(sorted_indices, segment_offsets); | ||||
|     write_num_of_segments_for_legacy_thrust_path<<<1, 1, 0, c10::cuda::getCurrentCUDAStream()>>>(num_of_segments_ptr, num_of_segments); | ||||
|     C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||
|   }); | ||||
| #else | ||||
|   AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () { | ||||
|     cuda::cub::unique_by_key( | ||||
|       sorted_indices.const_data_ptr<index_t>(), thrust::make_counting_iterator(0), | ||||
|       segment_offsets.mutable_data_ptr<index_t>(), | ||||
|       num_of_segments_ptr, sorted_indices.numel()); | ||||
|   }); | ||||
| #endif | ||||
|  | ||||
|   int64_t max_segments = std::min<int64_t>(numel, num_weights); | ||||
|  | ||||
|  | ||||
| @ -31,10 +31,16 @@ | ||||
|  | ||||
| #include <c10/macros/Macros.h> | ||||
|  | ||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() | ||||
| #include <thrust/iterator/reverse_iterator.h> | ||||
| #endif | ||||
|  | ||||
| namespace at::native { | ||||
|  | ||||
| #if !CUB_SUPPORTS_SCAN_BY_KEY() | ||||
| template<typename index_t> | ||||
| void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count); | ||||
| #endif | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| @ -193,6 +199,7 @@ Tensor embedding_bag_backward_cuda_sum_avg( | ||||
|  | ||||
|   if (scale_grad_by_freq) { | ||||
|     count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); | ||||
| #if CUB_SUPPORTS_SCAN_BY_KEY() | ||||
|     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () { | ||||
|       cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|  | ||||
| @ -219,6 +226,11 @@ Tensor embedding_bag_backward_cuda_sum_avg( | ||||
|         num_indices | ||||
|       ); | ||||
|     }); | ||||
| #else | ||||
|     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () { | ||||
|       embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count); | ||||
|     }); | ||||
| #endif | ||||
|   } | ||||
|   return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices, | ||||
|       count, num_weights, padding_idx, mode == EmbeddingBagMode::MEAN, offset2bag, | ||||
|  | ||||
| @ -82,7 +82,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) { | ||||
|   // lanczos approximation | ||||
|   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; | ||||
|  | ||||
|   constexpr accscalar_t lanczos_sum_expg_scaled_num[13] = { | ||||
|   static const accscalar_t lanczos_sum_expg_scaled_num[13] = { | ||||
|     0.006061842346248906525783753964555936883222, | ||||
|     0.5098416655656676188125178644804694509993, | ||||
|     19.51992788247617482847860966235652136208, | ||||
| @ -97,7 +97,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) { | ||||
|     103794043.1163445451906271053616070238554, | ||||
|     56906521.91347156388090791033559122686859 | ||||
|   }; | ||||
|   constexpr accscalar_t lanczos_sum_expg_scaled_denom[13] = { | ||||
|   static const accscalar_t lanczos_sum_expg_scaled_denom[13] = { | ||||
|     1., | ||||
|     66., | ||||
|     1925., | ||||
| @ -126,10 +126,10 @@ __host__ __device__ scalar_t _igam_helper_fac(scalar_t a, scalar_t x) { | ||||
|  | ||||
|   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; | ||||
|   accscalar_t ax, fac, res, num, numfac; | ||||
|   constexpr accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ? | ||||
|   static const accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ? | ||||
|     7.09782712893383996843E2 : 88.72283905206835; | ||||
|   constexpr accscalar_t EXP1 = 2.718281828459045; | ||||
|   constexpr accscalar_t lanczos_g = 6.024680040776729583740234375; | ||||
|   static const accscalar_t EXP1 = 2.718281828459045; | ||||
|   static const accscalar_t lanczos_g = 6.024680040776729583740234375; | ||||
|  | ||||
|   if (::fabs(a - x) > 0.4 * ::fabs(a)) { | ||||
|     ax = a * ::log(x) - x - ::lgamma(a); | ||||
| @ -158,9 +158,9 @@ __host__ __device__ scalar_t _igam_helper_series(scalar_t a, scalar_t x) { | ||||
|   // Compute igam using DLMF 8.11.4. [igam1] | ||||
|  | ||||
|   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; | ||||
|   constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? | ||||
|   static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? | ||||
|     1.11022302462515654042E-16 : 5.9604644775390625E-8; | ||||
|   constexpr int MAXITER = 2000; | ||||
|   static const int MAXITER = 2000; | ||||
|  | ||||
|   int i; | ||||
|   accscalar_t ans, ax, c, r; | ||||
| @ -196,8 +196,8 @@ __host__ __device__ scalar_t _igamc_helper_series(scalar_t a, scalar_t x) { | ||||
|   accscalar_t fac = 1; | ||||
|   accscalar_t sum = 0; | ||||
|   accscalar_t term, logx; | ||||
|   constexpr int MAXITER = 2000; | ||||
|   constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? | ||||
|   static const int MAXITER = 2000; | ||||
|   static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? | ||||
|     1.11022302462515654042E-16 : 5.9604644775390625E-8; | ||||
|  | ||||
|   for (n = 1; n < MAXITER; n++) { | ||||
| @ -219,7 +219,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t | ||||
|   // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1] | ||||
|  | ||||
|   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; | ||||
|   constexpr accscalar_t d[25][25] = | ||||
|   static const accscalar_t d[25][25] = | ||||
|     {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, -1.9752288294349443e-15}, | ||||
|     {-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, -4.13125571381061e-15}, | ||||
|     {4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, 8.8592218725911273e-15}, | ||||
| @ -248,7 +248,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t | ||||
|  | ||||
|   int k, n, sgn; | ||||
|   int maxpow = 0; | ||||
|   constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? | ||||
|   static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? | ||||
|     1.11022302462515654042E-16 : 5.9604644775390625E-8; | ||||
|   accscalar_t lambda = x / a; | ||||
|   accscalar_t sigma = (x - a) / a; | ||||
| @ -314,12 +314,12 @@ __host__ __device__ scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar | ||||
|   int i; | ||||
|   accscalar_t ans, ax, c, yc, r, t, y, z; | ||||
|   accscalar_t pk, pkm1, pkm2, qk, qkm1, qkm2; | ||||
|   constexpr int MAXITER = 2000; | ||||
|   constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? | ||||
|   static const int MAXITER = 2000; | ||||
|   static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? | ||||
|     1.11022302462515654042E-16 : 5.9604644775390625E-8; | ||||
|   constexpr accscalar_t BIG = std::is_same_v<accscalar_t,double> ? | ||||
|   static const accscalar_t BIG = std::is_same_v<accscalar_t,double> ? | ||||
|     4.503599627370496e15 : 16777216.; | ||||
|   constexpr accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ? | ||||
|   static const accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ? | ||||
|     2.22044604925031308085e-16 : 5.9604644775390625E-8; | ||||
|  | ||||
|   ax = _igam_helper_fac(a, x); | ||||
| @ -385,10 +385,10 @@ __noinline__ __host__ __device__ scalar_t calc_igammac(scalar_t a, scalar_t x) { | ||||
|   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; | ||||
|   accscalar_t absxma_a; | ||||
|  | ||||
|   constexpr accscalar_t SMALL = 20.0; | ||||
|   constexpr accscalar_t LARGE = 200.0; | ||||
|   constexpr accscalar_t SMALLRATIO = 0.3; | ||||
|   constexpr accscalar_t LARGERATIO = 4.5; | ||||
|   static const accscalar_t SMALL = 20.0; | ||||
|   static const accscalar_t LARGE = 200.0; | ||||
|   static const accscalar_t SMALLRATIO = 0.3; | ||||
|   static const accscalar_t LARGERATIO = 4.5; | ||||
|  | ||||
|   if ((x < 0) || (a < 0)) { | ||||
|     // out of defined-region of the function | ||||
| @ -467,10 +467,10 @@ __noinline__ __host__ __device__ scalar_t calc_igamma(scalar_t a, scalar_t x) { | ||||
|  | ||||
|   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; | ||||
|   accscalar_t absxma_a; | ||||
|   constexpr accscalar_t SMALL = 20.0; | ||||
|   constexpr accscalar_t LARGE = 200.0; | ||||
|   constexpr accscalar_t SMALLRATIO = 0.3; | ||||
|   constexpr accscalar_t LARGERATIO = 4.5; | ||||
|   static const accscalar_t SMALL = 20.0; | ||||
|   static const accscalar_t LARGE = 200.0; | ||||
|   static const accscalar_t SMALLRATIO = 0.3; | ||||
|   static const accscalar_t LARGERATIO = 4.5; | ||||
|  | ||||
|   // boundary values following SciPy | ||||
|   if ((x < 0) || (a < 0)) { | ||||
|  | ||||
							
								
								
									
										90
									
								
								aten/src/ATen/native/cuda/LegacyThrustHelpers.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										90
									
								
								aten/src/ATen/native/cuda/LegacyThrustHelpers.cu
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,90 @@ | ||||
| #define TORCH_ASSERT_ONLY_METHOD_OPERATORS | ||||
| #include <ATen/core/Tensor.h> | ||||
| #include <ATen/native/cuda/SortingCommon.cuh> | ||||
| #include <ATen/cuda/cub_definitions.cuh> | ||||
|  | ||||
| #ifndef AT_PER_OPERATOR_HEADERS | ||||
| #include <ATen/Functions.h> | ||||
| #else | ||||
| #include <ATen/ops/empty_like.h> | ||||
| #endif | ||||
|  | ||||
| #include <ATen/cuda/ThrustAllocator.h> | ||||
| #include <thrust/device_ptr.h> | ||||
| #include <thrust/execution_policy.h> | ||||
| #include <thrust/sort.h> | ||||
| #include <thrust/unique.h> | ||||
| #include <thrust/device_ptr.h> | ||||
| #include <thrust/iterator/constant_iterator.h> | ||||
|  | ||||
| namespace at::native { | ||||
|  | ||||
| #if !CUB_SUPPORTS_SCAN_BY_KEY() | ||||
|  | ||||
| template<typename index_t> | ||||
| void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count) { | ||||
|   cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|   at::cuda::ThrustAllocator allocator; | ||||
|   auto policy = thrust::cuda::par(allocator).on(stream); | ||||
|  | ||||
|   auto num_indices = count.numel(); | ||||
|  | ||||
|   // Compute an increasing sequence per unique item in sortedIndices: | ||||
|   // sorted: 2 5 5 5 7 7 8 9 9 | ||||
|   //  count: 1 1 2 3 1 2 1 1 2 | ||||
|   auto sorted_data = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>()); | ||||
|   auto count_data = thrust::device_ptr<index_t>(count.mutable_data_ptr<index_t>()); | ||||
|   thrust::inclusive_scan_by_key( | ||||
|     policy, | ||||
|     sorted_data, | ||||
|     sorted_data + num_indices, | ||||
|     thrust::make_constant_iterator(1), | ||||
|     count_data | ||||
|   ); | ||||
|  | ||||
|   // Take the maximum of each count per unique key in reverse: | ||||
|   // sorted: 2 5 5 5 7 7 8 9 9 | ||||
|   //  count: 1 3 3 3 2 2 1 2 2 | ||||
|   thrust::inclusive_scan_by_key( | ||||
|     policy, | ||||
|     thrust::make_reverse_iterator(sorted_data + num_indices), | ||||
|     thrust::make_reverse_iterator(sorted_data), | ||||
|     thrust::make_reverse_iterator(count_data + num_indices), | ||||
|     thrust::make_reverse_iterator(count_data + num_indices), | ||||
|     thrust::equal_to<index_t>(), | ||||
|     thrust::maximum<index_t>() | ||||
|   ); | ||||
| } | ||||
|  | ||||
| template | ||||
| void embedding_dense_backward_cuda_scan<int>(Tensor &sorted_indices, Tensor &count); | ||||
| template | ||||
| void embedding_dense_backward_cuda_scan<int64_t>(Tensor &sorted_indices, Tensor &count); | ||||
|  | ||||
| #endif | ||||
|  | ||||
| template<typename index_t> | ||||
| int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets) { | ||||
|   auto stream = at::cuda::getCurrentCUDAStream(); | ||||
|   at::cuda::ThrustAllocator allocator; | ||||
|   auto policy = thrust::cuda::par(allocator).on(stream); | ||||
|   const ptrdiff_t numel = sorted_indices.numel(); | ||||
|   auto sorted_indices_dev = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>()); | ||||
|   auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); | ||||
|   auto dummy_dev = thrust::device_ptr<index_t>(dummy.mutable_data_ptr<index_t>()); | ||||
|   auto ends = thrust::unique_by_key_copy( | ||||
|           policy, | ||||
|           sorted_indices_dev, | ||||
|           sorted_indices_dev + numel, | ||||
|           thrust::make_counting_iterator(0), | ||||
|           dummy_dev, | ||||
|           thrust::device_ptr<index_t>(segment_offsets.mutable_data_ptr<index_t>())); | ||||
|   return thrust::get<0>(ends) - dummy_dev; | ||||
| } | ||||
|  | ||||
| template | ||||
| int64_t embedding_backward_cuda_kernel_unique_by_key<int>(const Tensor &sorted_indices, Tensor &segment_offsets); | ||||
| template | ||||
| int64_t embedding_backward_cuda_kernel_unique_by_key<int64_t>(const Tensor &sorted_indices, Tensor &segment_offsets); | ||||
|  | ||||
| } // namespace at::native | ||||
| @ -1,17 +1,18 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <ATen/OpMathType.h> | ||||
| #include <ATen/cuda/detail/OffsetCalculator.cuh> | ||||
| #include <ATen/detail/FunctionTraits.h> | ||||
| #include <ATen/native/TensorIterator.h> | ||||
| #include <ATen/native/TensorIteratorDynamicCasting.h> | ||||
| #include <ATen/cuda/detail/OffsetCalculator.cuh> | ||||
| #include <ATen/OpMathType.h> | ||||
| #include <ATen/native/cuda/thread_constants.h> | ||||
|  | ||||
| #include <thrust/tuple.h> | ||||
|  | ||||
| #include <ATen/native/cuda/MemoryAccess.cuh> | ||||
|  | ||||
| #include <tuple> | ||||
|  | ||||
|  | ||||
|  | ||||
| namespace at::native { | ||||
|  | ||||
| template<int N> | ||||
| @ -61,11 +62,7 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) { | ||||
|   #pragma unroll | ||||
|   for (int i = 0; i < elems_per_thread; i++) { | ||||
|     if (policy.check_inbounds(i)) { | ||||
| #if defined(__HIP__) | ||||
|       results[i] = c10::guts::apply(f, args[i]); | ||||
| #else | ||||
|       results[i] = std::apply(f, args[i]); | ||||
| #endif | ||||
|     } | ||||
|   } | ||||
|  | ||||
|  | ||||
| @ -231,7 +231,7 @@ const auto lcm_string = jiterator_stringify( | ||||
| const auto digamma_string = jiterator_stringify( | ||||
|   template <typename T> | ||||
|   T digamma(T x) { | ||||
|     static constexpr double PI_f64 = 3.14159265358979323846; | ||||
|     static const double PI_f64 = 3.14159265358979323846; | ||||
|  | ||||
|     // Short-circuits if x is +/- 0 and returns -/+ ∞ per the C++ standard | ||||
|     if (x == 0) { | ||||
| @ -3072,9 +3072,9 @@ template <typename scalar_t> | ||||
| static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) { | ||||
|   // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma | ||||
|   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; | ||||
|   static constexpr double PI_f64 = 3.14159265358979323846; | ||||
|   constexpr accscalar_t PSI_10 = 2.25175258906672110764; | ||||
|   constexpr accscalar_t A[] = { | ||||
|   static const double PI_f64 = 3.14159265358979323846; | ||||
|   const accscalar_t PSI_10 = 2.25175258906672110764; | ||||
|   const accscalar_t A[] = { | ||||
|       8.33333333333333333333E-2, | ||||
|       -2.10927960927960927961E-2, | ||||
|       7.57575757575757575758E-3, | ||||
|  | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	