Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-19 10:04:58 +08:00)

Compare commits: 61 commits (sy_invoke_ ... ciflow/tru)
| SHA1 |
|---|
| d25559423f |
| 654c149d07 |
| 39ebab1dd9 |
| 4c152a71ad |
| 1b43d6cd4e |
| 2b69673bbf |
| 2f74916e36 |
| 2b5eabc74b |
| 9ff95f6835 |
| 6fdb974f4a |
| 661d1653aa |
| 53809f9640 |
| 93ddd38ecd |
| 5804408f1b |
| 99117c1238 |
| b9bccec3bc |
| ca3aaef66e |
| f2e6f94081 |
| aa504d4d2a |
| d8ce6f8df9 |
| 4322354770 |
| 363385ad3e |
| e2e10753d7 |
| 5d99a795f5 |
| 2245d7d3b9 |
| 98b94b90dd |
| 5cdbda140c |
| 0ec53beaeb |
| 79fc0a9141 |
| d01a7b0241 |
| deabb3e36d |
| 79d2397b6b |
| 6ef3a62c36 |
| 530e782239 |
| c66a6c432e |
| 3d7a8b7e61 |
| de0d69b2c4 |
| bc60b86066 |
| d7782ddde7 |
| fb04e9ad03 |
| cfe799b4aa |
| b7f52773e6 |
| f6b54d8899 |
| da91bf5262 |
| 1c1638297e |
| ee0b5b4b1c |
| fcfb213c5a |
| 08042bbb9c |
| e20ca3bc2e |
| 4ed26f7382 |
| 4c79305b87 |
| f4b8c4f907 |
| d629b7a459 |
| 0922ba5f42 |
| c87295c044 |
| 7aa210d215 |
| 5a368b8010 |
| 602102be50 |
| 200156e385 |
| 80a0fd3f4d |
| 4c4a6b3644 |
@@ -75,9 +75,11 @@ if [[ "$ARCH" == "aarch64" ]]; then
    # ARM system libraries
    DEPS_LIST+=(
        "/usr/lib64/libgfortran.so.5"
        "/opt/OpenBLAS/lib/libopenblas.so.0"
    )
    DEPS_SONAME+=(
        "libgfortran.so.5"
        "libopenblas.so.0"
    )
fi
@@ -100,337 +100,6 @@ def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None:
    )


def _compile_and_extract_symbols(
    cpp_content: str, compile_flags: list[str], exclude_list: list[str] | None = None
) -> list[str]:
    """
    Helper to compile a C++ file and extract all symbols.

    Args:
        cpp_content: C++ source code to compile
        compile_flags: Compilation flags
        exclude_list: List of symbol names to exclude. Defaults to ["main"].

    Returns:
        List of all symbols found in the object file (excluding those in exclude_list).
    """
    import subprocess
    import tempfile

    if exclude_list is None:
        exclude_list = ["main"]

    with tempfile.TemporaryDirectory() as tmpdir:
        tmppath = Path(tmpdir)
        cpp_file = tmppath / "test.cpp"
        obj_file = tmppath / "test.o"

        cpp_file.write_text(cpp_content)

        result = subprocess.run(
            compile_flags + [str(cpp_file), "-o", str(obj_file)],
            capture_output=True,
            text=True,
            timeout=60,
        )

        if result.returncode != 0:
            raise RuntimeError(f"Compilation failed: {result.stderr}")

        symbols = get_symbols(str(obj_file))

        # Return all symbol names, excluding those in the exclude list
        return [name for _addr, _stype, name in symbols if name not in exclude_list]


def check_stable_only_symbols(install_root: Path) -> None:
    """
    Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code and comparing symbol counts.

    This approach tests:
    1. WITHOUT macros -> many torch symbols exposed
    2. WITH TORCH_STABLE_ONLY -> zero torch symbols (all hidden)
    3. WITH TORCH_TARGET_VERSION -> zero torch symbols (all hidden)
    4. WITH both macros -> zero torch symbols (all hidden)
    """
    include_dir = install_root / "include"
    assert include_dir.exists(), f"Expected {include_dir} to be present"

    test_cpp_content = """
// Main torch C++ API headers
#include <torch/torch.h>
#include <torch/all.h>

// ATen tensor library
#include <ATen/ATen.h>

// Core c10 headers (commonly used)
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/ScalarType.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Optional.h>

int main() { return 0; }
"""

    base_compile_flags = [
        "g++",
        "-std=c++17",
        f"-I{include_dir}",
        f"-I{include_dir}/torch/csrc/api/include",
        "-c",  # Compile only, don't link
    ]

    # Compile WITHOUT any macros
    symbols_without = _compile_and_extract_symbols(
        cpp_content=test_cpp_content,
        compile_flags=base_compile_flags,
    )

    # We expect constexpr symbols, inline functions used by other headers etc.
    # to produce symbols
    num_symbols_without = len(symbols_without)
    print(f"Found {num_symbols_without} symbols without any macros defined")
    assert num_symbols_without != 0, (
        "Expected a non-zero number of symbols without any macros"
    )

    # Compile WITH TORCH_STABLE_ONLY (expect 0 symbols)
    compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"]

    symbols_with_stable_only = _compile_and_extract_symbols(
        cpp_content=test_cpp_content,
        compile_flags=compile_flags_with_stable_only,
    )

    num_symbols_with_stable_only = len(symbols_with_stable_only)
    assert num_symbols_with_stable_only == 0, (
        f"Expected no symbols with TORCH_STABLE_ONLY macro, but found {num_symbols_with_stable_only}"
    )

    # Compile WITH TORCH_TARGET_VERSION (expect 0 symbols)
    compile_flags_with_target_version = base_compile_flags + [
        "-DTORCH_TARGET_VERSION=1"
    ]

    symbols_with_target_version = _compile_and_extract_symbols(
        cpp_content=test_cpp_content,
        compile_flags=compile_flags_with_target_version,
    )

    num_symbols_with_target_version = len(symbols_with_target_version)
    assert num_symbols_with_target_version == 0, (
        f"Expected no symbols with TORCH_TARGET_VERSION macro, but found {num_symbols_with_target_version}"
    )

    # Compile WITH both macros (expect 0 symbols)
    compile_flags_with_both = base_compile_flags + [
        "-DTORCH_STABLE_ONLY",
        "-DTORCH_TARGET_VERSION=1",
    ]

    symbols_with_both = _compile_and_extract_symbols(
        cpp_content=test_cpp_content,
        compile_flags=compile_flags_with_both,
    )

    num_symbols_with_both = len(symbols_with_both)
    assert num_symbols_with_both == 0, (
        f"Expected no symbols with both macros, but found {num_symbols_with_both}"
    )


def check_stable_api_symbols(install_root: Path) -> None:
    """
    Test that stable API headers still expose symbols with TORCH_STABLE_ONLY.
    The torch/csrc/stable/c/shim.h header is tested in check_stable_c_shim_symbols
    """
    include_dir = install_root / "include"
    assert include_dir.exists(), f"Expected {include_dir} to be present"

    stable_dir = include_dir / "torch" / "csrc" / "stable"
    assert stable_dir.exists(), f"Expected {stable_dir} to be present"

    stable_headers = list(stable_dir.rglob("*.h"))
    if not stable_headers:
        raise RuntimeError("Could not find any stable headers")

    includes = []
    for header in stable_headers:
        rel_path = header.relative_to(include_dir)
        includes.append(f"#include <{rel_path.as_posix()}>")

    includes_str = "\n".join(includes)
    test_stable_content = f"""
{includes_str}
int main() {{ return 0; }}
"""

    compile_flags = [
        "g++",
        "-std=c++17",
        f"-I{include_dir}",
        f"-I{include_dir}/torch/csrc/api/include",
        "-c",
        "-DTORCH_STABLE_ONLY",
    ]

    symbols_stable = _compile_and_extract_symbols(
        cpp_content=test_stable_content,
        compile_flags=compile_flags,
    )
    num_symbols_stable = len(symbols_stable)
    print(f"Found {num_symbols_stable} symbols in torch/csrc/stable")
    assert num_symbols_stable > 0, (
        f"Expected stable headers to expose symbols with TORCH_STABLE_ONLY, "
        f"but found {num_symbols_stable} symbols"
    )


def check_headeronly_symbols(install_root: Path) -> None:
    """
    Test that header-only utility headers still expose symbols with TORCH_STABLE_ONLY.
    """
    include_dir = install_root / "include"
    assert include_dir.exists(), f"Expected {include_dir} to be present"

    # Find all headers in torch/headeronly
    headeronly_dir = include_dir / "torch" / "headeronly"
    assert headeronly_dir.exists(), f"Expected {headeronly_dir} to be present"
    headeronly_headers = list(headeronly_dir.rglob("*.h"))
    if not headeronly_headers:
        raise RuntimeError("Could not find any headeronly headers")

    # Filter out platform-specific headers that may not compile everywhere
    platform_specific_keywords = [
        "cpu/vec",
    ]

    filtered_headers = []
    for header in headeronly_headers:
        rel_path = header.relative_to(include_dir).as_posix()
        if not any(
            keyword in rel_path.lower() for keyword in platform_specific_keywords
        ):
            filtered_headers.append(header)

    includes = []
    for header in filtered_headers:
        rel_path = header.relative_to(include_dir)
        includes.append(f"#include <{rel_path.as_posix()}>")

    includes_str = "\n".join(includes)
    test_headeronly_content = f"""
{includes_str}
int main() {{ return 0; }}
"""

    compile_flags = [
        "g++",
        "-std=c++17",
        f"-I{include_dir}",
        f"-I{include_dir}/torch/csrc/api/include",
        "-c",
        "-DTORCH_STABLE_ONLY",
    ]

    symbols_headeronly = _compile_and_extract_symbols(
        cpp_content=test_headeronly_content,
        compile_flags=compile_flags,
    )
    num_symbols_headeronly = len(symbols_headeronly)
    print(f"Found {num_symbols_headeronly} symbols in torch/headeronly")
    assert num_symbols_headeronly > 0, (
        f"Expected headeronly headers to expose symbols with TORCH_STABLE_ONLY, "
        f"but found {num_symbols_headeronly} symbols"
    )


def check_aoti_shim_symbols(install_root: Path) -> None:
    """
    Test that AOTI shim headers still expose symbols with TORCH_STABLE_ONLY.
    """
    include_dir = install_root / "include"
    assert include_dir.exists(), f"Expected {include_dir} to be present"

    # There are no constexpr symbols etc., so we need to actually use functions
    # so that some symbols are found.
    test_shim_content = """
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
int main() {
    int32_t (*fp1)() = &aoti_torch_device_type_cpu;
    int32_t (*fp2)() = &aoti_torch_dtype_float32;
    (void)fp1; (void)fp2;
    return 0;
}
"""

    compile_flags = [
        "g++",
        "-std=c++17",
        f"-I{include_dir}",
        f"-I{include_dir}/torch/csrc/api/include",
        "-c",
        "-DTORCH_STABLE_ONLY",
    ]

    symbols_shim = _compile_and_extract_symbols(
        cpp_content=test_shim_content,
        compile_flags=compile_flags,
    )
    num_symbols_shim = len(symbols_shim)
    assert num_symbols_shim > 0, (
        f"Expected shim headers to expose symbols with TORCH_STABLE_ONLY, "
        f"but found {num_symbols_shim} symbols"
    )


def check_stable_c_shim_symbols(install_root: Path) -> None:
    """
    Test that stable C shim headers still expose symbols with TORCH_STABLE_ONLY.
    """
    include_dir = install_root / "include"
    assert include_dir.exists(), f"Expected {include_dir} to be present"

    # Check if the stable C shim exists
    stable_shim = include_dir / "torch" / "csrc" / "stable" / "c" / "shim.h"
    if not stable_shim.exists():
        raise RuntimeError("Could not find stable c shim")

    # There are no constexpr symbols etc., so we need to actually use functions
    # so that some symbols are found.
    test_stable_shim_content = """
#include <torch/csrc/stable/c/shim.h>
int main() {
    // Reference stable C API functions to create undefined symbols
    AOTITorchError (*fp1)(const char*, uint32_t*, int32_t*) = &torch_parse_device_string;
    AOTITorchError (*fp2)(uint32_t*) = &torch_get_num_threads;
    (void)fp1; (void)fp2;
    return 0;
}
"""

    compile_flags = [
        "g++",
        "-std=c++17",
        f"-I{include_dir}",
        f"-I{include_dir}/torch/csrc/api/include",
        "-c",
        "-DTORCH_STABLE_ONLY",
    ]

    symbols_stable_shim = _compile_and_extract_symbols(
        cpp_content=test_stable_shim_content,
        compile_flags=compile_flags,
    )
    num_symbols_stable_shim = len(symbols_stable_shim)
    assert num_symbols_stable_shim > 0, (
        f"Expected stable C shim headers to expose symbols with TORCH_STABLE_ONLY, "
        f"but found {num_symbols_stable_shim} symbols"
    )


def check_lib_symbols_for_abi_correctness(lib: str) -> None:
    print(f"lib: {lib}")
    cxx11_symbols = grep_symbols(lib, LIBTORCH_CXX11_PATTERNS)

@@ -460,13 +129,6 @@ def main() -> None:
    check_lib_symbols_for_abi_correctness(libtorch_cpu_path)
    check_lib_statically_linked_libstdc_cxx_abi_symbols(libtorch_cpu_path)

    # Check symbols when TORCH_STABLE_ONLY is defined
    check_stable_only_symbols(install_root)
    check_stable_api_symbols(install_root)
    check_headeronly_symbols(install_root)
    check_aoti_shim_symbols(install_root)
    check_stable_c_shim_symbols(install_root)


if __name__ == "__main__":
    main()
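The functions in this hunk depend on a `get_symbols()` helper that is defined elsewhere in the script and not shown in this excerpt. As a rough, illustrative sketch only (assuming it wraps `nm`; the repository's actual implementation may differ), such a helper could look like:

```python
import subprocess


def get_symbols(obj_file: str) -> list[tuple[str, str, str]]:
    """Illustrative sketch: return (address, type, name) tuples parsed from `nm` output."""
    result = subprocess.run(
        ["nm", obj_file], capture_output=True, text=True, check=True
    )
    symbols = []
    for line in result.stdout.splitlines():
        parts = line.split(maxsplit=2)
        if len(parts) == 3:
            # defined symbol: "<address> <type> <name>"
            symbols.append((parts[0], parts[1], parts[2]))
        elif len(parts) == 2:
            # undefined symbol: no address column, e.g. "U printf"
            symbols.append(("", parts[0], parts[1]))
    return symbols
```

This matches the shape expected by `_compile_and_extract_symbols`, which unpacks each entry as `(_addr, _stype, name)`.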
@@ -389,6 +389,13 @@ test_lazy_tensor_meta_reference_disabled() {
  export -n TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE
}

test_dynamo_core() {
  time python test/run_test.py \
    --include-dynamo-core-tests \
    --verbose \
    --upload-artifacts-while-running
  assert_git_not_dirty
}

test_dynamo_wrapped_shard() {
  if [[ -z "$NUM_TEST_SHARDS" ]]; then

@@ -1814,6 +1821,8 @@ elif [[ "${TEST_CONFIG}" == *inductor* ]]; then
  test_inductor_shard "${SHARD_NUMBER}"
elif [[ "${TEST_CONFIG}" == *einops* ]]; then
  test_einops
elif [[ "${TEST_CONFIG}" == *dynamo_core* ]]; then
  test_dynamo_core
elif [[ "${TEST_CONFIG}" == *dynamo_wrapped* ]]; then
  install_torchvision
  test_dynamo_wrapped_shard "${SHARD_NUMBER}"
.github/ci_commit_pins/audio.txt (vendored): 2 changes
@@ -1 +1 @@
-07b6cbde121417a70e4dc871adb6d27030e0ce3f
+ee1a1350eb37804b94334768f328144f058f14e9

.github/ci_commit_pins/vision.txt (vendored): 2 changes
@@ -1 +1 @@
-acccf86477759b2d3500f1ae1be065f7b1e409ec
+2d82dc5caa336d179d9b46ac4a0fb8c43d84c5cc

.github/ci_commit_pins/xla.txt (vendored): 2 changes
@@ -1 +1 @@
-e4d25697f9dc5eedaf8f0a5bf085c62c5455a53a
+94631807d22c09723dd006f7be5beb649d5f88d0
.github/pytorch-probot.yml (vendored): 1 change
@@ -7,6 +7,7 @@ ciflow_push_tags:
- ciflow/binaries
- ciflow/binaries_libtorch
- ciflow/binaries_wheel
+- ciflow/dynamo
- ciflow/h100
- ciflow/h100-cutlass-backend
- ciflow/h100-distributed
.github/workflows/_linux-test.yml (vendored): 2 changes
@@ -326,7 +326,7 @@ jobs:
  SCCACHE_BUCKET: ${{ !contains(matrix.runner, 'b200') && 'ossci-compiler-cache-circleci-v2' || '' }}
  SCCACHE_REGION: ${{ !contains(matrix.runner, 'b200') && 'us-east-1' || '' }}
  SHM_SIZE: ${{ contains(inputs.build-environment, 'cuda') && '2g' || '1g' }}
-  DOCKER_IMAGE: ${{ inputs.docker-image }}
+  DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
  XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }}
  XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla
  PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }}
.github/workflows/dynamo-unittest.yml (vendored, new file): 70 additions
@@ -0,0 +1,70 @@
# Workflow: Dynamo Unit Test
# runs unit tests for dynamo.
name: dynamo-unittest

on:
  push:
    tags:
      - ciflow/dynamo/*
  workflow_call:
  schedule:
    - cron: 29 8 * * *  # about 1:29am PDT

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
  cancel-in-progress: true

permissions:
  id-token: write
  contents: read

jobs:
  get-label-type:
    name: get-label-type
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
    with:
      triggering_actor: ${{ github.triggering_actor }}
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
      curr_branch: ${{ github.head_ref || github.ref_name }}
      curr_ref_type: ${{ github.ref_type }}
      opt_out_experiments: lf

  dynamo-build:
    name: dynamo-build
    uses: ./.github/workflows/_linux-build.yml
    needs: get-label-type
    strategy:
      matrix:
        python-version: ['3.11', '3.12']
    with:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build-environment: linux-jammy-py${{ matrix.python-version }}-clang12
      docker-image-name: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang12
      test-matrix: |
        { include: [
          { config: "dynamo_core", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
          { config: "dynamo_wrapped", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
          { config: "dynamo_wrapped", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
          { config: "dynamo_wrapped", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
        ]}
    secrets: inherit

  dynamo-test:
    name: dynamo-test
    uses: ./.github/workflows/_linux-test.yml
    needs: [get-label-type, dynamo-build]
    strategy:
      matrix:
        python-version: ['3.11', '3.12']
    with:
      build-environment: linux-jammy-py${{ matrix.python-version }}-clang12
      docker-image: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang12
      test-matrix: |
        { include: [
          { config: "dynamo_core", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
          { config: "dynamo_wrapped", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
          { config: "dynamo_wrapped", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
          { config: "dynamo_wrapped", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
        ]}
    secrets: inherit
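Both jobs above pass the same inline test-matrix to the reusable build/test workflows; each entry selects a TEST_CONFIG (dynamo_core or dynamo_wrapped) and a shard, which test.sh then dispatches to test_dynamo_core or test_dynamo_wrapped_shard. As a small illustrative check only (not part of the repository's CI tooling, and assuming a JSON-quoted copy of the matrix), the entries can be expanded like this:

```python
import json

# Hypothetical, simplified copy of the matrix above with JSON-quoted keys;
# the workflow embeds an equivalent structure as a multi-line string.
test_matrix = """
{ "include": [
    { "config": "dynamo_core",    "shard": 1, "num_shards": 1 },
    { "config": "dynamo_wrapped", "shard": 1, "num_shards": 3 },
    { "config": "dynamo_wrapped", "shard": 2, "num_shards": 3 },
    { "config": "dynamo_wrapped", "shard": 3, "num_shards": 3 }
] }
"""

for entry in json.loads(test_matrix)["include"]:
    # Each entry becomes one CI job keyed by TEST_CONFIG and shard number.
    print(f"{entry['config']}: shard {entry['shard']} of {entry['num_shards']}")
```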
@@ -144,7 +144,7 @@ inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) {
}

inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
-  out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")";
+  out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ')';
  return out;
}

@@ -9,7 +9,7 @@ namespace indexing {
const EllipsisIndexType Ellipsis = EllipsisIndexType();

std::ostream& operator<<(std::ostream& stream, const Slice& slice) {
-  stream << slice.start() << ":" << slice.stop() << ":" << slice.step();
+  stream << slice.start() << ':' << slice.stop() << ':' << slice.step();
  return stream;
}

@@ -31,12 +31,12 @@ std::ostream& operator<<(std::ostream& stream, const TensorIndex& tensor_index)
}

std::ostream& operator<<(std::ostream& stream, const std::vector<TensorIndex>& tensor_indices) {
-  stream << "(";
+  stream << '(';
  for (const auto i : c10::irange(tensor_indices.size())) {
    stream << tensor_indices[i];
    if (i < tensor_indices.size() - 1) stream << ", ";
  }
-  stream << ")";
+  stream << ')';
  return stream;
}

@@ -113,7 +113,7 @@ void TensorNames::checkUnique(const char* op_name) const {
std::ostream& operator<<(std::ostream& out, const TensorName& tensorname) {
  out << tensorname.name_ << " (index ";
  out << tensorname.origin_idx_ << " of ";
-  out << tensorname.origin_ << ")";
+  out << tensorname.origin_ << ')';
  return out;
}

@@ -13,9 +13,9 @@ std::ostream& operator<<(std::ostream & out, const TensorGeometryArg& t) {
  if (t.pos == 0) {
    // 0 is distinguished; it usually indicates 'self' or the return
    // tensor
-    out << "'" << t.name << "'";
+    out << '\'' << t.name << '\'';
  } else {
-    out << "argument #" << t.pos << " '" << t.name << "'";
+    out << "argument #" << t.pos << " '" << t.name << '\'';
  }
  return out;
}

@@ -154,7 +154,7 @@ void checkSameGPU(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
      oss << "Tensor for " << t2 << " is on CPU, ";
    }
    oss << "but expected " << ((!t1->is_cpu() && !t2->is_cpu()) ? "them" : "it")
-        << " to be on GPU (while checking arguments for " << c << ")";
+        << " to be on GPU (while checking arguments for " << c << ')';
    TORCH_CHECK(false, oss.str());
  }
  TORCH_CHECK(

@@ -199,7 +199,7 @@ void checkScalarTypes(CheckedFrom c, const TensorArg& t,
      i++;
    }
    oss << "; but got " << t->toString()
-        << " instead (while checking arguments for " << c << ")";
+        << " instead (while checking arguments for " << c << ')';
    TORCH_CHECK(false, oss.str());
  }
}

@@ -43,8 +43,8 @@ std::string get_mkldnn_version() {
  // https://github.com/intel/ideep/issues/29
  {
    const dnnl_version_t* ver = dnnl_version();
-    ss << "Intel(R) MKL-DNN v" << ver->major << "." << ver->minor << "." << ver->patch
-       << " (Git Hash " << ver->hash << ")";
+    ss << "Intel(R) MKL-DNN v" << ver->major << '.' << ver->minor << '.' << ver->patch
+       << " (Git Hash " << ver->hash << ')';
  }
#else
  ss << "MKLDNN not found";

@@ -81,7 +81,7 @@ std::string get_openmp_version() {
        break;
      }
      if (ver_str) {
-        ss << " (a.k.a. OpenMP " << ver_str << ")";
+        ss << " (a.k.a. OpenMP " << ver_str << ')';
      }
    }
#else

@@ -135,38 +135,38 @@ std::string show_config() {

#if defined(__GNUC__)
  {
-    ss << " - GCC " << __GNUC__ << "." << __GNUC_MINOR__ << "\n";
+    ss << " - GCC " << __GNUC__ << '.' << __GNUC_MINOR__ << '\n';
  }
#endif

#if defined(__cplusplus)
  {
-    ss << " - C++ Version: " << __cplusplus << "\n";
+    ss << " - C++ Version: " << __cplusplus << '\n';
  }
#endif

#if defined(__clang_major__)
  {
-    ss << " - clang " << __clang_major__ << "." << __clang_minor__ << "." << __clang_patchlevel__ << "\n";
+    ss << " - clang " << __clang_major__ << '.' << __clang_minor__ << '.' << __clang_patchlevel__ << '\n';
  }
#endif

#if defined(_MSC_VER)
  {
-    ss << " - MSVC " << _MSC_FULL_VER << "\n";
+    ss << " - MSVC " << _MSC_FULL_VER << '\n';
  }
#endif

#if AT_MKL_ENABLED()
-  ss << " - " << get_mkl_version() << "\n";
+  ss << " - " << get_mkl_version() << '\n';
#endif

#if AT_MKLDNN_ENABLED()
-  ss << " - " << get_mkldnn_version() << "\n";
+  ss << " - " << get_mkldnn_version() << '\n';
#endif

#ifdef _OPENMP
-  ss << " - " << get_openmp_version() << "\n";
+  ss << " - " << get_openmp_version() << '\n';
#endif

#if AT_BUILD_WITH_LAPACK()

@@ -183,7 +183,7 @@ std::string show_config() {
  ss << " - Cross compiling on MacOSX\n";
#endif

-  ss << " - "<< used_cpu_capability() << "\n";
+  ss << " - "<< used_cpu_capability() << '\n';

  if (hasCUDA()) {
    ss << detail::getCUDAHooks().showConfig();

@@ -200,10 +200,10 @@ std::string show_config() {
  ss << " - Build settings: ";
  for (const auto& pair : caffe2::GetBuildOptions()) {
    if (!pair.second.empty()) {
-      ss << pair.first << "=" << pair.second << ", ";
+      ss << pair.first << '=' << pair.second << ", ";
    }
  }
-  ss << "\n";
+  ss << '\n';

  // TODO: do HIP
  // TODO: do XLA

@@ -209,7 +209,7 @@ struct CodeTemplate {
  // to indent correctly in the context.
  void emitIndent(std::ostream& out, size_t indent) const {
    for ([[maybe_unused]] const auto i : c10::irange(indent)) {
-      out << " ";
+      out << ' ';
    }
  }
  void emitStringWithIndents(

@@ -10,7 +10,7 @@ std::ostream& operator<<(std::ostream& out, const Dimname& dimname) {
  if (dimname.type() == NameType::WILDCARD) {
    out << "None";
  } else {
-    out << "'" << dimname.symbol().toUnqualString() << "'";
+    out << '\'' << dimname.symbol().toUnqualString() << '\'';
  }
  return out;
}

@@ -5,7 +5,7 @@
namespace at {

std::ostream& operator<<(std::ostream& out, const Range& range) {
-  out << "Range[" << range.begin << ", " << range.end << "]";
+  out << "Range[" << range.begin << ", " << range.end << ']';
  return out;
}

@@ -71,7 +71,7 @@ void TensorBase::enforce_invariants() {

void TensorBase::print() const {
  if (defined()) {
-    std::cerr << "[" << toString() << " " << sizes() << "]" << '\n';
+    std::cerr << '[' << toString() << ' ' << sizes() << ']' << '\n';
  } else {
    std::cerr << "[UndefinedTensor]" << '\n';
  }
@@ -1,5 +1,6 @@
#pragma once

+#include <torch/headeronly/core/TensorAccessor.h>
#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Deprecated.h>
@ -11,252 +12,37 @@
|
||||
|
||||
namespace at {
|
||||
|
||||
// The PtrTraits argument to the TensorAccessor/GenericPackedTensorAccessor
|
||||
// is used to enable the __restrict__ keyword/modifier for the data
|
||||
// passed to cuda.
|
||||
template <typename T>
|
||||
struct DefaultPtrTraits {
|
||||
typedef T* PtrType;
|
||||
};
|
||||
|
||||
using torch::headeronly::DefaultPtrTraits;
|
||||
#if defined(__CUDACC__) || defined(__HIPCC__)
|
||||
template <typename T>
|
||||
struct RestrictPtrTraits {
|
||||
typedef T* __restrict__ PtrType;
|
||||
};
|
||||
using torch::headeronly::RestrictPtrTraits;
|
||||
#endif
|
||||
|
||||
// TensorAccessorBase and TensorAccessor are used for both CPU and CUDA tensors.
|
||||
// For CUDA tensors it is used in device code (only). This means that we restrict ourselves
|
||||
// to functions and types available there (e.g. IntArrayRef isn't).
|
||||
|
||||
// The PtrTraits argument is only relevant to cuda to support `__restrict__` pointers.
|
||||
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
|
||||
class TensorAccessorBase {
|
||||
public:
|
||||
typedef typename PtrTraits<T>::PtrType PtrType;
|
||||
using TensorAccessorBase = torch::headeronly::detail::TensorAccessorBase<c10::IntArrayRef, T, N, PtrTraits, index_t>;
|
||||
|
||||
C10_HOST_DEVICE TensorAccessorBase(
|
||||
PtrType data_,
|
||||
const index_t* sizes_,
|
||||
const index_t* strides_)
|
||||
: data_(data_), sizes_(sizes_), strides_(strides_) {}
|
||||
C10_HOST IntArrayRef sizes() const {
|
||||
return IntArrayRef(sizes_,N);
|
||||
}
|
||||
C10_HOST IntArrayRef strides() const {
|
||||
return IntArrayRef(strides_,N);
|
||||
}
|
||||
C10_HOST_DEVICE index_t stride(index_t i) const {
|
||||
return strides_[i];
|
||||
}
|
||||
C10_HOST_DEVICE index_t size(index_t i) const {
|
||||
return sizes_[i];
|
||||
}
|
||||
C10_HOST_DEVICE PtrType data() {
|
||||
return data_;
|
||||
}
|
||||
C10_HOST_DEVICE const PtrType data() const {
|
||||
return data_;
|
||||
}
|
||||
protected:
|
||||
PtrType data_;
|
||||
const index_t* sizes_;
|
||||
const index_t* strides_;
|
||||
};
|
||||
|
||||
// The `TensorAccessor` is typically instantiated for CPU `Tensor`s using
|
||||
// `Tensor.accessor<T, N>()`.
|
||||
// For CUDA `Tensor`s, `GenericPackedTensorAccessor` is used on the host and only
|
||||
// indexing on the device uses `TensorAccessor`s.
|
||||
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
|
||||
class TensorAccessor : public TensorAccessorBase<T,N,PtrTraits,index_t> {
|
||||
public:
|
||||
typedef typename PtrTraits<T>::PtrType PtrType;
|
||||
using TensorAccessor = torch::headeronly::detail::TensorAccessor<c10::IntArrayRef, T, N, PtrTraits, index_t>;
|
||||
|
||||
C10_HOST_DEVICE TensorAccessor(
|
||||
PtrType data_,
|
||||
const index_t* sizes_,
|
||||
const index_t* strides_)
|
||||
: TensorAccessorBase<T, N, PtrTraits, index_t>(data_,sizes_,strides_) {}
|
||||
namespace detail {
|
||||
|
||||
C10_HOST_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
|
||||
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1);
|
||||
}
|
||||
|
||||
C10_HOST_DEVICE const TensorAccessor<T, N-1, PtrTraits, index_t> operator[](index_t i) const {
|
||||
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, template <typename U> class PtrTraits, typename index_t>
|
||||
class TensorAccessor<T,1,PtrTraits,index_t> : public TensorAccessorBase<T,1,PtrTraits,index_t> {
|
||||
public:
|
||||
typedef typename PtrTraits<T>::PtrType PtrType;
|
||||
|
||||
C10_HOST_DEVICE TensorAccessor(
|
||||
PtrType data_,
|
||||
const index_t* sizes_,
|
||||
const index_t* strides_)
|
||||
: TensorAccessorBase<T, 1, PtrTraits, index_t>(data_,sizes_,strides_) {}
|
||||
C10_HOST_DEVICE T & operator[](index_t i) {
|
||||
// NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
|
||||
return this->data_[this->strides_[0]*i];
|
||||
}
|
||||
C10_HOST_DEVICE const T & operator[](index_t i) const {
|
||||
return this->data_[this->strides_[0]*i];
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// GenericPackedTensorAccessorBase and GenericPackedTensorAccessor are used on for CUDA `Tensor`s on the host
|
||||
// and as
|
||||
// In contrast to `TensorAccessor`s, they copy the strides and sizes on instantiation (on the host)
|
||||
// in order to transfer them on the device when calling kernels.
|
||||
// On the device, indexing of multidimensional tensors gives to `TensorAccessor`s.
|
||||
// Use RestrictPtrTraits as PtrTraits if you want the tensor's data pointer to be marked as __restrict__.
|
||||
// Instantiation from data, sizes, strides is only needed on the host and std::copy isn't available
|
||||
// on the device, so those functions are host only.
|
||||
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
|
||||
class GenericPackedTensorAccessorBase {
|
||||
public:
|
||||
typedef typename PtrTraits<T>::PtrType PtrType;
|
||||
C10_HOST GenericPackedTensorAccessorBase(
|
||||
PtrType data_,
|
||||
const index_t* sizes_,
|
||||
const index_t* strides_)
|
||||
: data_(data_) {
|
||||
std::copy(sizes_, sizes_ + N, std::begin(this->sizes_));
|
||||
std::copy(strides_, strides_ + N, std::begin(this->strides_));
|
||||
}
|
||||
|
||||
// if index_t is not int64_t, we want to have an int64_t constructor
|
||||
template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
|
||||
C10_HOST GenericPackedTensorAccessorBase(
|
||||
PtrType data_,
|
||||
const source_index_t* sizes_,
|
||||
const source_index_t* strides_)
|
||||
: data_(data_) {
|
||||
for (const auto i : c10::irange(N)) {
|
||||
this->sizes_[i] = sizes_[i];
|
||||
this->strides_[i] = strides_[i];
|
||||
}
|
||||
}
|
||||
|
||||
C10_HOST_DEVICE index_t stride(index_t i) const {
|
||||
return strides_[i];
|
||||
}
|
||||
C10_HOST_DEVICE index_t size(index_t i) const {
|
||||
return sizes_[i];
|
||||
}
|
||||
C10_HOST_DEVICE PtrType data() {
|
||||
return data_;
|
||||
}
|
||||
C10_HOST_DEVICE const PtrType data() const {
|
||||
return data_;
|
||||
}
|
||||
protected:
|
||||
PtrType data_;
|
||||
// NOLINTNEXTLINE(*c-arrays*)
|
||||
index_t sizes_[N];
|
||||
// NOLINTNEXTLINE(*c-arrays*)
|
||||
index_t strides_[N];
|
||||
C10_HOST void bounds_check_(index_t i) const {
|
||||
TORCH_CHECK_INDEX(
|
||||
template <size_t N, typename index_t>
|
||||
struct IndexBoundsCheck {
|
||||
IndexBoundsCheck(index_t i) {
|
||||
TORCH_CHECK_INDEX(
|
||||
0 <= i && i < index_t{N},
|
||||
"Index ",
|
||||
i,
|
||||
" is not within bounds of a tensor of dimension ",
|
||||
N);
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
|
||||
class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase<T,N,PtrTraits,index_t> {
|
||||
public:
|
||||
typedef typename PtrTraits<T>::PtrType PtrType;
|
||||
|
||||
C10_HOST GenericPackedTensorAccessor(
|
||||
PtrType data_,
|
||||
const index_t* sizes_,
|
||||
const index_t* strides_)
|
||||
: GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
|
||||
|
||||
// if index_t is not int64_t, we want to have an int64_t constructor
|
||||
template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
|
||||
C10_HOST GenericPackedTensorAccessor(
|
||||
PtrType data_,
|
||||
const source_index_t* sizes_,
|
||||
const source_index_t* strides_)
|
||||
: GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
|
||||
|
||||
C10_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
|
||||
index_t* new_sizes = this->sizes_ + 1;
|
||||
index_t* new_strides = this->strides_ + 1;
|
||||
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i, new_sizes, new_strides);
|
||||
}
|
||||
|
||||
C10_DEVICE const TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) const {
|
||||
const index_t* new_sizes = this->sizes_ + 1;
|
||||
const index_t* new_strides = this->strides_ + 1;
|
||||
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i, new_sizes, new_strides);
|
||||
}
|
||||
|
||||
/// Returns a PackedTensorAccessor of the same dimension after transposing the
|
||||
/// two dimensions given. Does not actually move elements; transposition is
|
||||
/// made by permuting the size/stride arrays. If the dimensions are not valid,
|
||||
/// asserts.
|
||||
C10_HOST GenericPackedTensorAccessor<T, N, PtrTraits, index_t> transpose(
|
||||
index_t dim1,
|
||||
index_t dim2) const {
|
||||
this->bounds_check_(dim1);
|
||||
this->bounds_check_(dim2);
|
||||
GenericPackedTensorAccessor<T, N, PtrTraits, index_t> result(
|
||||
this->data_, this->sizes_, this->strides_);
|
||||
std::swap(result.strides_[dim1], result.strides_[dim2]);
|
||||
std::swap(result.sizes_[dim1], result.sizes_[dim2]);
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, template <typename U> class PtrTraits, typename index_t>
|
||||
class GenericPackedTensorAccessor<T,1,PtrTraits,index_t> : public GenericPackedTensorAccessorBase<T,1,PtrTraits,index_t> {
|
||||
public:
|
||||
typedef typename PtrTraits<T>::PtrType PtrType;
|
||||
C10_HOST GenericPackedTensorAccessor(
|
||||
PtrType data_,
|
||||
const index_t* sizes_,
|
||||
const index_t* strides_)
|
||||
: GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
|
||||
|
||||
// if index_t is not int64_t, we want to have an int64_t constructor
|
||||
template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
|
||||
C10_HOST GenericPackedTensorAccessor(
|
||||
PtrType data_,
|
||||
const source_index_t* sizes_,
|
||||
const source_index_t* strides_)
|
||||
: GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
|
||||
|
||||
C10_DEVICE T & operator[](index_t i) {
|
||||
return this->data_[this->strides_[0] * i];
|
||||
}
|
||||
C10_DEVICE const T& operator[](index_t i) const {
|
||||
return this->data_[this->strides_[0]*i];
|
||||
}
|
||||
|
||||
// Same as in the general N-dimensional case, but note that in the
|
||||
// 1-dimensional case the returned PackedTensorAccessor will always be an
|
||||
// identical copy of the original
|
||||
C10_HOST GenericPackedTensorAccessor<T, 1, PtrTraits, index_t> transpose(
|
||||
index_t dim1,
|
||||
index_t dim2) const {
|
||||
this->bounds_check_(dim1);
|
||||
this->bounds_check_(dim2);
|
||||
return GenericPackedTensorAccessor<T, 1, PtrTraits, index_t>(
|
||||
this->data_, this->sizes_, this->strides_);
|
||||
}
|
||||
};
|
||||
using GenericPackedTensorAccessorBase = torch::headeronly::detail::GenericPackedTensorAccessorBase<detail::IndexBoundsCheck<N, index_t>, T, N, PtrTraits, index_t>;
|
||||
|
||||
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
|
||||
using GenericPackedTensorAccessor = torch::headeronly::detail::GenericPackedTensorAccessor<TensorAccessor<T, N-1, PtrTraits, index_t>, detail::IndexBoundsCheck<N, index_t>, T, N, PtrTraits, index_t>;
|
||||
|
||||
// Can't put this directly into the macro function args because of commas
|
||||
#define AT_X GenericPackedTensorAccessor<T, N, PtrTraits, index_t>
|
||||
|
||||
@ -245,6 +245,9 @@ class TORCH_API TensorBase {
|
||||
size_t weak_use_count() const noexcept {
|
||||
return impl_.weak_use_count();
|
||||
}
|
||||
bool is_uniquely_owned() const noexcept {
|
||||
return impl_.is_uniquely_owned();
|
||||
}
|
||||
|
||||
std::string toString() const;
|
||||
|
||||
|
||||
@ -9,8 +9,8 @@ APIVitals VitalsAPI;
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, TorchVital const& tv) {
|
||||
for (const auto& m : tv.attrs) {
|
||||
os << "[TORCH_VITAL] " << tv.name << "." << m.first << "\t\t "
|
||||
<< m.second.value << "\n";
|
||||
os << "[TORCH_VITAL] " << tv.name << '.' << m.first << "\t\t "
|
||||
<< m.second.value << '\n';
|
||||
}
|
||||
return os;
|
||||
}
|
||||
|
||||
@ -100,18 +100,18 @@ inline bool operator==(const AliasInfo& lhs, const AliasInfo& rhs) {
|
||||
|
||||
// this does match the way things are represented in the schema
|
||||
inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) {
|
||||
out << "(";
|
||||
out << '(';
|
||||
bool first = true;
|
||||
for (const auto& set : aliasInfo.beforeSets()) {
|
||||
if (first) {
|
||||
first = false;
|
||||
} else {
|
||||
out << "|";
|
||||
out << '|';
|
||||
}
|
||||
out << set.toUnqualString();
|
||||
}
|
||||
if (aliasInfo.isWrite()) {
|
||||
out << "!";
|
||||
out << '!';
|
||||
}
|
||||
if (aliasInfo.beforeSets() != aliasInfo.afterSets()) {
|
||||
out << " -> ";
|
||||
@ -120,12 +120,12 @@ inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) {
|
||||
if (first) {
|
||||
first = false;
|
||||
} else {
|
||||
out << "|";
|
||||
out << '|';
|
||||
}
|
||||
out << set.toUnqualString();
|
||||
}
|
||||
}
|
||||
out << ")";
|
||||
out << ')';
|
||||
return out;
|
||||
}
|
||||
} // namespace c10
|
||||
|
||||
@ -198,7 +198,7 @@ inline void swap(Blob& lhs, Blob& rhs) noexcept {
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const Blob& v) {
|
||||
return out << "Blob[" << v.TypeName() << "]";
|
||||
return out << "Blob[" << v.TypeName() << ']';
|
||||
}
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
@ -456,8 +456,8 @@ bool ClassType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
|
||||
*why_not << "Method on class '" << repr_str()
|
||||
<< "' (1) is not compatible with interface '"
|
||||
<< rhs.repr_str() << "' (2)\n"
|
||||
<< " (1) " << self_method->getSchema() << "\n"
|
||||
<< " (2) " << schema << "\n";
|
||||
<< " (1) " << self_method->getSchema() << '\n'
|
||||
<< " (2) " << schema << '\n';
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -100,7 +100,7 @@ struct TORCH_API ClassType : public NamedType {
|
||||
std::string repr_str() const override {
|
||||
std::stringstream ss;
|
||||
ss << str()
|
||||
<< " (of Python compilation unit at: " << compilation_unit().get() << ")";
|
||||
<< " (of Python compilation unit at: " << compilation_unit().get() << ')';
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
@ -58,12 +58,12 @@ std::string DispatchKeyExtractor::dumpState() const {
|
||||
std::ostringstream oss;
|
||||
for (const auto i : c10::irange(c10::utils::bitset::NUM_BITS())) {
|
||||
if (dispatch_arg_indices_reverse_.get(i)) {
|
||||
oss << "1";
|
||||
oss << '1';
|
||||
} else {
|
||||
oss << "0";
|
||||
oss << '0';
|
||||
}
|
||||
}
|
||||
oss << " " << nonFallthroughKeys_ << "\n";
|
||||
oss << ' ' << nonFallthroughKeys_ << '\n';
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
|
||||
@ -69,8 +69,8 @@ private:
|
||||
|
||||
void _print_dispatch_trace(const std::string& label, const std::string& op_name, const DispatchKeySet& dispatchKeySet) {
|
||||
auto nesting_value = dispatch_trace_nesting_value();
|
||||
for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " ";
|
||||
std::cerr << label << " op=[" << op_name << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
|
||||
for (int64_t i = 0; i < nesting_value; ++i) std::cerr << ' ';
|
||||
std::cerr << label << " op=[" << op_name << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << ']' << std::endl;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
|
||||
@ -570,7 +570,7 @@ void OperatorEntry::checkInvariants() const {
|
||||
|
||||
std::string OperatorEntry::listAllDispatchKeys() const {
|
||||
std::ostringstream str;
|
||||
str << "[";
|
||||
str << '[';
|
||||
|
||||
bool has_kernels = false;
|
||||
for (auto k : allDispatchKeysInFullSet()) {
|
||||
@ -584,7 +584,7 @@ std::string OperatorEntry::listAllDispatchKeys() const {
|
||||
str << k;
|
||||
has_kernels = true;
|
||||
}
|
||||
str << "]";
|
||||
str << ']';
|
||||
return str.str();
|
||||
}
|
||||
|
||||
@ -683,12 +683,12 @@ void OperatorEntry::setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> c
|
||||
// This WON'T report backend fallbacks.
|
||||
std::string OperatorEntry::dumpState() const {
|
||||
std::ostringstream oss;
|
||||
oss << "name: " << name_ << "\n";
|
||||
oss << "name: " << name_ << '\n';
|
||||
if (schema_) {
|
||||
oss << "schema: " << schema_->schema << "\n";
|
||||
oss << "debug: " << schema_->debug << "\n";
|
||||
oss << "schema: " << schema_->schema << '\n';
|
||||
oss << "debug: " << schema_->debug << '\n';
|
||||
oss << "alias analysis kind: " << toString(schema_->schema.aliasAnalysis())
|
||||
<< (schema_->schema.isDefaultAliasAnalysisKind() ? " (default)" : "") << "\n";
|
||||
<< (schema_->schema.isDefaultAliasAnalysisKind() ? " (default)" : "") << '\n';
|
||||
} else {
|
||||
oss << "schema: (none)\n";
|
||||
}
|
||||
|
||||
@ -7,7 +7,7 @@
|
||||
namespace c10 {
|
||||
|
||||
void FunctionSchema::dump() const {
|
||||
std::cout << *this << "\n";
|
||||
std::cout << *this << '\n';
|
||||
}
|
||||
|
||||
const std::vector<Argument>& FunctionSchema::getCorrectList(SchemaArgType type) const {
|
||||
@ -210,9 +210,9 @@ std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
|
||||
|
||||
out << schema.name();
|
||||
if (!schema.overload_name().empty()) {
|
||||
out << "." << schema.overload_name();
|
||||
out << '.' << schema.overload_name();
|
||||
}
|
||||
out << "(";
|
||||
out << '(';
|
||||
|
||||
bool seen_kwarg_only = false;
|
||||
for (const auto i : c10::irange(schema.arguments().size())) {
|
||||
@ -273,7 +273,7 @@ std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
|
||||
}
|
||||
|
||||
if (need_paren) {
|
||||
out << "(";
|
||||
out << '(';
|
||||
}
|
||||
for (const auto i : c10::irange(returns.size())) {
|
||||
if (i > 0) {
|
||||
@ -288,7 +288,7 @@ std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
|
||||
out << "...";
|
||||
}
|
||||
if (need_paren) {
|
||||
out << ")";
|
||||
out << ')';
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@ -471,7 +471,7 @@ bool FunctionSchema::isForwardCompatibleWith(
|
||||
if (!arguments().at(i).isForwardCompatibleWith(old.arguments().at(i))) {
|
||||
if (why_not) {
|
||||
why_not
|
||||
<< "'" << arguments().at(i).name() << "'"
|
||||
<< '\'' << arguments().at(i).name() << '\''
|
||||
<< " is not forward compatible with the older version of the schema";
|
||||
}
|
||||
return false;
|
||||
@ -511,7 +511,7 @@ bool FunctionSchema::isForwardCompatibleWith(
|
||||
.isForwardCompatibleWith(old.arguments().at(i))) {
|
||||
if (why_not) {
|
||||
why_not << "Out argument '"
|
||||
<< "'" << arguments().at(i).name()
|
||||
<< '\'' << arguments().at(i).name()
|
||||
<< " is not FC with the older version of the schema";
|
||||
}
|
||||
return false;
|
||||
|
||||
@ -571,7 +571,7 @@ inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
|
||||
if (arg.N()) {
|
||||
N = std::to_string(*arg.N());
|
||||
}
|
||||
out << "[" << N << "]";
|
||||
out << '[' << N << ']';
|
||||
} else {
|
||||
out << unopt_type->str();
|
||||
}
|
||||
@ -582,15 +582,15 @@ inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
|
||||
}
|
||||
|
||||
if (is_opt) {
|
||||
out << "?";
|
||||
out << '?';
|
||||
}
|
||||
|
||||
if (!arg.name().empty()) {
|
||||
out << " " << arg.name();
|
||||
out << ' ' << arg.name();
|
||||
}
|
||||
|
||||
if (arg.default_value()) {
|
||||
out << "=";
|
||||
out << '=';
|
||||
if ((type->kind() == c10::TypeKind::StringType ||
|
||||
unopt_type->kind() == c10::TypeKind::StringType) &&
|
||||
arg.default_value().value().isString()) {
|
||||
|
||||
@ -66,7 +66,7 @@ bool operator==(const ivalue::Tuple& lhs, const ivalue::Tuple& rhs) {
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const ivalue::EnumHolder& v) {
|
||||
out << v.qualifiedClassName() << "." << v.name();
|
||||
out << v.qualifiedClassName() << '.' << v.name();
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -526,7 +526,7 @@ std::ostream& printMaybeAnnotatedList(
|
||||
!elementTypeCanBeInferredFromMembers(list_elem_type)) {
|
||||
out << "annotate(" << the_list.type<c10::Type>()->annotation_str() << ", ";
|
||||
printList(out, the_list.toListRef(), "[", "]", formatter);
|
||||
out << ")";
|
||||
out << ')';
|
||||
return out;
|
||||
} else {
|
||||
return printList(out, the_list.toListRef(), "[", "]", formatter);
|
||||
@ -538,7 +538,7 @@ std::ostream& printDict(
|
||||
std::ostream& out,
|
||||
const Dict& v,
|
||||
const IValueFormatter& formatter) {
|
||||
out << "{";
|
||||
out << '{';
|
||||
|
||||
bool first = true;
|
||||
for (const auto& pair : v) {
|
||||
@ -552,7 +552,7 @@ std::ostream& printDict(
|
||||
first = false;
|
||||
}
|
||||
|
||||
out << "}";
|
||||
out << '}';
|
||||
return out;
|
||||
}
|
||||
}
|
||||
@ -565,8 +565,8 @@ static std::ostream& printMaybeAnnotatedDict(
|
||||
auto value_type = the_dict.type()->castRaw<DictType>()->getValueType();
|
||||
if (the_dict.toGenericDict().empty() ||
|
||||
!elementTypeCanBeInferredFromMembers(value_type)) {
|
||||
out << "annotate(" << the_dict.type<c10::Type>()->annotation_str() << ",";
|
||||
printDict(out, the_dict.toGenericDict(), formatter) << ")";
|
||||
out << "annotate(" << the_dict.type<c10::Type>()->annotation_str() << ',';
|
||||
printDict(out, the_dict.toGenericDict(), formatter) << ')';
|
||||
} else {
|
||||
return printDict(out, the_dict.toGenericDict(), formatter);
|
||||
}
|
||||
@ -577,7 +577,7 @@ static std::ostream& printComplex(std::ostream & out, const IValue & v) {
|
||||
c10::complex<double> d = v.toComplexDouble();
|
||||
IValue real(d.real()), imag(std::abs(d.imag()));
|
||||
auto sign = d.imag() >= 0 ? '+' : '-';
|
||||
return out << real << sign << imag << "j";
|
||||
return out << real << sign << imag << 'j';
|
||||
}
|
||||
|
||||
std::ostream& IValue::repr(
|
||||
@ -605,9 +605,9 @@ std::ostream& IValue::repr(
|
||||
if (static_cast<double>(i) == d) {
|
||||
// -0.0 (signed zero) needs to be parsed as -0.
|
||||
if (i == 0 && std::signbit(d)) {
|
||||
return out << "-" << i << ".";
|
||||
return out << '-' << i << '.';
|
||||
}
|
||||
return out << i << ".";
|
||||
return out << i << '.';
|
||||
}
|
||||
}
|
||||
auto orig_prec = out.precision();
|
||||
@ -643,20 +643,20 @@ std::ostream& IValue::repr(
|
||||
device_stream << v.toDevice();
|
||||
out << "torch.device(";
|
||||
c10::printQuotedString(out, device_stream.str());
|
||||
return out << ")";
|
||||
return out << ')';
|
||||
}
|
||||
case IValue::Tag::Generator: {
|
||||
auto generator = v.toGenerator();
|
||||
out << "torch.Generator(device=";
|
||||
c10::printQuotedString(out, generator.device().str());
|
||||
out << ", seed=" << generator.current_seed() << ")";
|
||||
out << ", seed=" << generator.current_seed() << ')';
|
||||
return out;
|
||||
}
|
||||
case IValue::Tag::GenericDict:
|
||||
return printMaybeAnnotatedDict(out, v, formatter);
|
||||
case IValue::Tag::Enum: {
|
||||
auto enum_holder = v.toEnumHolder();
|
||||
return out << enum_holder->qualifiedClassName() << "." <<
|
||||
return out << enum_holder->qualifiedClassName() << '.' <<
|
||||
enum_holder->name();
|
||||
}
|
||||
case IValue::Tag::Object: {
|
||||
@ -801,7 +801,7 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
|
||||
if (c == FP_NORMAL || c == FP_ZERO) {
|
||||
int64_t i = static_cast<int64_t>(d);
|
||||
if (static_cast<double>(i) == d) {
|
||||
return out << i << ".";
|
||||
return out << i << '.';
|
||||
}
|
||||
}
|
||||
auto orig_prec = out.precision();
|
||||
@ -852,7 +852,7 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
|
||||
return printDict(out, v.toGenericDict(), formatter);
|
||||
case IValue::Tag::PyObject: {
|
||||
auto py_obj = v.toPyObject();
|
||||
return out << "<PyObject at" << py_obj << ">";
|
||||
return out << "<PyObject at" << py_obj << '>';
|
||||
}
|
||||
case IValue::Tag::Generator:
|
||||
return out << "Generator";
|
||||
@ -862,22 +862,22 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
|
||||
// TODO we should attempt to call __str__ if the object defines it.
|
||||
auto obj = v.toObject();
|
||||
// print this out the way python would do it
|
||||
return out << "<" << obj->name() << " object at " << obj.get() << ">";
|
||||
return out << '<' << obj->name() << " object at " << obj.get() << '>';
|
||||
}
|
||||
case IValue::Tag::Enum: {
|
||||
auto enum_holder = v.toEnumHolder();
|
||||
return out << "Enum<" << enum_holder->unqualifiedClassName() << "." <<
|
||||
enum_holder->name() << ">";
|
||||
return out << "Enum<" << enum_holder->unqualifiedClassName() << '.' <<
|
||||
enum_holder->name() << '>';
|
||||
}
|
||||
|
||||
}
|
||||
return out << "<Invalid IValue tag=" << std::to_string(static_cast<uint32_t>(v.tag)) << ">";
|
||||
return out << "<Invalid IValue tag=" << std::to_string(static_cast<uint32_t>(v.tag)) << '>';
|
||||
}
|
||||
|
||||
#undef TORCH_FORALL_TAGS
|
||||
|
||||
void IValue::dump() const {
|
||||
std::cout << *this << "\n";
|
||||
std::cout << *this << '\n';
|
||||
}
|
||||
|
||||
std::shared_ptr<ClassType> ivalue::Object::type() const {
|
||||
@ -1050,7 +1050,7 @@ c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy(
|
||||
std::stringstream err;
|
||||
err << "Cannot serialize custom bound C++ class";
|
||||
if (auto qualname = type()->name()) {
|
||||
err << " " << qualname->qualifiedName();
|
||||
err << ' ' << qualname->qualifiedName();
|
||||
}
|
||||
err << ". Please define serialization methods via def_pickle() for "
|
||||
"this class.";
|
||||
|
||||
@ -211,7 +211,7 @@ struct TORCH_API OptionalType : public UnionType {
|
||||
|
||||
std::string str() const override {
|
||||
std::stringstream ss;
|
||||
ss << getElementType()->str() << "?";
|
||||
ss << getElementType()->str() << '?';
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
@ -240,7 +240,7 @@ struct TORCH_API OptionalType : public UnionType {
|
||||
|
||||
std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
|
||||
std::stringstream ss;
|
||||
ss << "Optional[" << getElementType()->annotation_str(printer) << "]";
|
||||
ss << "Optional[" << getElementType()->annotation_str(printer) << ']';
|
||||
return ss.str();
|
||||
}
|
||||
};
|
||||
@ -906,7 +906,7 @@ struct TORCH_API ListType
|
||||
|
||||
std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
|
||||
std::stringstream ss;
|
||||
ss << "List[" << getElementType()->annotation_str(printer) << "]";
|
||||
ss << "List[" << getElementType()->annotation_str(printer) << ']';
|
||||
return ss.str();
|
||||
}
|
||||
};
|
||||
@ -946,7 +946,7 @@ struct TORCH_API DictType : public SharedType {
|
||||
std::string str() const override {
|
||||
std::stringstream ss;
|
||||
ss << "Dict(" << getKeyType()->str() << ", " << getValueType()->str()
|
||||
<< ")";
|
||||
<< ')';
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
@ -1018,7 +1018,7 @@ struct TORCH_API FutureType
|
||||
|
||||
std::string str() const override {
|
||||
std::stringstream ss;
|
||||
ss << "Future(" << getElementType()->str() << ")";
|
||||
ss << "Future(" << getElementType()->str() << ')';
|
||||
return ss.str();
|
||||
}
|
||||
TypePtr createWithContained(
|
||||
@ -1041,7 +1041,7 @@ struct TORCH_API FutureType
|
||||
|
||||
std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
|
||||
std::stringstream ss;
|
||||
ss << "Future[" << getElementType()->annotation_str(printer) << "]";
|
||||
ss << "Future[" << getElementType()->annotation_str(printer) << ']';
|
||||
return ss.str();
|
||||
}
|
||||
};
|
||||
@ -1060,7 +1060,7 @@ struct TORCH_API AwaitType
|
||||
|
||||
std::string str() const override {
|
||||
std::stringstream ss;
|
||||
ss << "Await(" << getElementType()->str() << ")";
|
||||
ss << "Await(" << getElementType()->str() << ')';
|
||||
return ss.str();
|
||||
}
|
||||
TypePtr createWithContained(
|
||||
@ -1083,7 +1083,7 @@ struct TORCH_API AwaitType
|
||||
|
||||
std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
|
||||
std::stringstream ss;
|
||||
ss << "Await[" << getElementType()->annotation_str(printer) << "]";
|
||||
ss << "Await[" << getElementType()->annotation_str(printer) << ']';
|
||||
return ss.str();
|
||||
}
|
||||
};
|
||||
@ -1102,7 +1102,7 @@ struct TORCH_API RRefType
|
||||
|
||||
std::string str() const override {
|
||||
std::stringstream ss;
|
||||
ss << "RRef(" << getElementType()->str() << ")";
|
||||
ss << "RRef(" << getElementType()->str() << ')';
|
||||
return ss.str();
|
||||
}
|
||||
TypePtr createWithContained(
|
||||
@ -1115,7 +1115,7 @@ struct TORCH_API RRefType
|
||||
|
||||
std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
|
||||
std::stringstream ss;
|
||||
ss << "RRef[" << getElementType()->annotation_str(printer) << "]";
|
||||
ss << "RRef[" << getElementType()->annotation_str(printer) << ']';
|
||||
return ss.str();
|
||||
}
|
||||
};
|
||||
|
||||
@ -11,7 +11,7 @@ std::string toString(const OperatorName& opName) {
|
||||
std::ostream& operator<<(std::ostream& os, const OperatorName& opName) {
|
||||
os << opName.name;
|
||||
if (!opName.overload_name.empty()) {
|
||||
os << "." << opName.overload_name;
|
||||
os << '.' << opName.overload_name;
|
||||
}
|
||||
return os;
|
||||
}
|
||||
|
||||
@ -65,7 +65,7 @@ VaryingShape<T> VaryingShape<T>::merge(const VaryingShape<T>& other) const {
|
||||
|
||||
template <typename T>
|
||||
std::ostream& operator<<(std::ostream& out, const VaryingShape<T>& vs) {
|
||||
out << "(";
|
||||
out << '(';
|
||||
if (!vs.size()) {
|
||||
out << "*)";
|
||||
return out;
|
||||
@ -79,10 +79,10 @@ std::ostream& operator<<(std::ostream& out, const VaryingShape<T>& vs) {
|
||||
if (v.has_value()) {
|
||||
out << v.value();
|
||||
} else {
|
||||
out << "*";
|
||||
out << '*';
|
||||
}
|
||||
}
|
||||
out << ")";
|
||||
out << ')';
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -105,7 +105,7 @@ std::ostream& operator<<(
|
||||
}
|
||||
auto sizes_opt = ss.sizes();
|
||||
|
||||
os << "(";
|
||||
os << '(';
|
||||
for (size_t i = 0; i < rank_opt.value(); i++) {
|
||||
if (i > 0) {
|
||||
os << ", ";
|
||||
@ -113,10 +113,10 @@ std::ostream& operator<<(
|
||||
if(sizes_opt.has_value() && sizes_opt.value()[i].is_static()) {
|
||||
os << sizes_opt.value()[i];
|
||||
} else {
|
||||
os << "*";
|
||||
os << '*';
|
||||
}
|
||||
}
|
||||
os << ")";
|
||||
os << ')';
|
||||
|
||||
return os;
|
||||
}
|
||||
@ -131,17 +131,17 @@ std::ostream& operator<<(std::ostream& os, const ShapeSymbol& s) {
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const Stride& s) {
|
||||
os << "{";
|
||||
os << '{';
|
||||
if (s.stride_index_.has_value()) {
|
||||
os << *s.stride_index_;
|
||||
} else {
|
||||
os << "*";
|
||||
os << '*';
|
||||
}
|
||||
os << ":";
|
||||
os << ':';
|
||||
if (s.stride_.has_value()) {
|
||||
os << *s.stride_;
|
||||
} else {
|
||||
os << "*";
|
||||
os << '*';
|
||||
}
|
||||
os << '}';
|
||||
return os;
|
||||
|
||||
@ -67,7 +67,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
|
||||
bool has_valid_strides_info = ndim > 0 &&
|
||||
value->strides().isComplete() && value->strides().size() == ndim;
|
||||
|
||||
out << "(";
|
||||
out << '(';
|
||||
size_t i = 0;
|
||||
bool symbolic = type_verbosity() == TypeVerbosity::Symbolic;
|
||||
for (i = 0; i < *ndim; ++i) {
|
||||
@ -79,7 +79,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
|
||||
} else if (symbolic) {
|
||||
out << value->symbolic_sizes().at(i);
|
||||
} else {
|
||||
out << "*";
|
||||
out << '*';
|
||||
}
|
||||
}
|
||||
if (has_valid_strides_info &&
|
||||
@ -91,7 +91,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
|
||||
}
|
||||
out << value->strides()[i].value();
|
||||
}
|
||||
out << "]";
|
||||
out << ']';
|
||||
}
|
||||
if (type_verbosity() >= TypeVerbosity::Full) {
|
||||
if (value->requiresGrad()) {
|
||||
@ -107,12 +107,12 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
|
||||
out << "device=" << *value->device();
|
||||
}
|
||||
}
|
||||
out << ")";
|
||||
out << ')';
|
||||
} else {
|
||||
if (type_verbosity() >= TypeVerbosity::Full) {
|
||||
size_t i = 0;
|
||||
if (value->requiresGrad()) {
|
||||
out << "("
|
||||
out << '('
|
||||
<< "requires_grad=" << *value->requiresGrad();
|
||||
i++;
|
||||
}
|
||||
@ -120,7 +120,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
|
||||
out << ((i++ > 0) ? ", " : "(") << "device=" << *value->device();
|
||||
}
|
||||
if (i > 0) {
|
||||
out << ")";
|
||||
out << ')';
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -133,18 +133,18 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
|
||||
out << *prim << "[]";
|
||||
} else if (t.kind() == TypeKind::OptionalType) {
|
||||
auto prim = t.castRaw<OptionalType>()->getElementType();
|
||||
out << *prim << "?";
|
||||
out << *prim << '?';
|
||||
} else if(t.kind() == TypeKind::FutureType) {
|
||||
auto elem = t.castRaw<FutureType>()->getElementType();
|
||||
out << "Future[" << *elem << "]";
|
||||
out << "Future[" << *elem << ']';
|
||||
} else if(t.kind() == TypeKind::RRefType) {
|
||||
auto elem = t.castRaw<RRefType>()->getElementType();
|
||||
out << "RRef[" << *elem << "]";
|
||||
out << "RRef[" << *elem << ']';
|
||||
} else if(auto tup = t.cast<TupleType>()) {
|
||||
if (tup->schema()) {
|
||||
out << "NamedTuple";
|
||||
}
|
||||
out << "(";
|
||||
out << '(';
|
||||
for(size_t i = 0; i < tup->elements().size(); ++i) {
|
||||
if(i > 0)
|
||||
out << ", ";
|
||||
@ -160,7 +160,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
|
||||
out << *(tup->elements()[i]);
|
||||
}
|
||||
}
|
||||
out << ")";
|
||||
out << ')';
|
||||
} else if (t.kind() == TypeKind::FunctionType) {
|
||||
out << "Function";
|
||||
} else {
|
||||
@ -475,7 +475,7 @@ std::optional<TypePtr> unifyTypeList(
|
||||
why_not << "Could not unify type list since element " << i << " of type "
|
||||
<< elements.at(i)->repr_str()
|
||||
<< " did not match the types before it ("
|
||||
<< ret_type->repr_str() << ")";
|
||||
<< ret_type->repr_str() << ')';
|
||||
return std::nullopt;
|
||||
}
|
||||
ret_type = *maybe_unified;
|
||||
@ -907,13 +907,13 @@ std::string TupleType::str() const {
|
||||
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
||||
ss << name()->qualifiedName();
|
||||
} else {
|
||||
ss << "(";
|
||||
ss << '(';
|
||||
for(size_t i = 0; i < elements().size(); ++i) {
|
||||
if(i > 0)
|
||||
ss << ", ";
|
||||
ss << elements()[i]->str();
|
||||
}
|
||||
ss << ")";
|
||||
ss << ')';
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
@ -1003,8 +1003,8 @@ bool InterfaceType::isSubTypeImpl(
|
||||
*why_not << "Method on interface '" << lhs.repr_str()
|
||||
<< "' (1) is not compatible with interface '"
|
||||
<< rhs.repr_str() << "' (2)\n"
|
||||
<< " (1) " << *self_schema << "\n"
|
||||
<< " (2) " << schema << "\n";
|
||||
<< " (1) " << *self_schema << '\n'
|
||||
<< " (2) " << schema << '\n';
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
@ -1078,7 +1078,7 @@ SymbolicShape SymbolicShape::merge(const SymbolicShape& other) const {
|
||||
}
|
||||
|
||||
void SymbolicShape::dump() const {
|
||||
std::cout << *this << "\n";
|
||||
std::cout << *this << '\n';
|
||||
}
|
||||
|
||||
bool EnumType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
|
||||
|
||||
@ -205,9 +205,9 @@ UnionType::UnionType(std::vector<TypePtr> reference, TypeKind kind) : SharedType
|
||||
for (const auto i : c10::irange(reference.size())) {
|
||||
msg << reference[i]->repr_str();
|
||||
if (i > 0) {
|
||||
msg << ",";
|
||||
msg << ',';
|
||||
}
|
||||
msg << " ";
|
||||
msg << ' ';
|
||||
}
|
||||
msg << "} has the single type " << types_[0]->repr_str()
|
||||
<< ". Use the common supertype instead of creating a Union"
|
||||
|
||||
@ -80,7 +80,7 @@ std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
|
||||
}
|
||||
stream << buf[i];
|
||||
}
|
||||
stream << "]";
|
||||
stream << ']';
|
||||
return stream;
|
||||
}
|
||||
|
||||
|
||||
@ -55,7 +55,7 @@ std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
|
||||
}
|
||||
stream << buf[i];
|
||||
}
|
||||
stream << "]";
|
||||
stream << ']';
|
||||
return stream;
|
||||
}
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@

#include <cstdint>
#include <map>
#include <shared_mutex>

#include <cuda_runtime_api.h>
#include <cusparse.h>
@ -88,8 +89,13 @@ TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();

TORCH_CUDA_CPP_API void clearCublasWorkspaces();
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace();
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace();
struct WorkspaceMapWithMutex {
std::map<std::tuple<void*, void*>, at::DataPtr> map;
std::shared_mutex mutex;
};

TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublas_handle_stream_to_workspace();
TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace();
TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize();
TORCH_CUDA_CPP_API size_t getCUDABlasLtWorkspaceSize();
TORCH_CUDA_CPP_API void* getCUDABlasLtWorkspace();

@ -175,17 +175,24 @@ void CUDAGraph::instantiate() {
// Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people,
// who prefer not to report error message through these arguments moving forward
// (they prefer return value, or errors on api calls internal to the capture)
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000)
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, 0));
// ROCM appears to fail with HIP error: invalid argument
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && !defined(USE_ROCM)
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, cudaGraphInstantiateFlagUseNodePriority));
#else
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0));
#endif
//Since ROCm 6.2, we want to go down this path as hipGraphExecDestroy in the destructor will not immediately free the memory.
//It will wait for the next sync operation. cudaGraphInstantiateFlagAutoFreeOnLaunch will add async frees after graph launch.
} else {
#if !defined(USE_ROCM)
AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
graph_,
cudaGraphInstantiateFlagAutoFreeOnLaunch | cudaGraphInstantiateFlagUseNodePriority));
#else
AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
graph_,
cudaGraphInstantiateFlagAutoFreeOnLaunch));
#endif
}
has_graph_exec_ = true;
}

@ -99,7 +99,7 @@ void destroyCublasHandle(cublasHandle_t handle) {
// - Comments of @soumith copied from cuDNN handle pool implementation
#ifdef NO_CUDNN_DESTROY_HANDLE
#else
cublasDestroy(handle);
cublasDestroy(handle);
#endif
}

@ -107,19 +107,27 @@ using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle

} // namespace

std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
WorkspaceMapWithMutex& cublas_handle_stream_to_workspace() {
static auto& instance = *new WorkspaceMapWithMutex;
return instance;
}

std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace() {
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace() {
static auto& instance = *new WorkspaceMapWithMutex;
return instance;
}

void clearCublasWorkspaces() {
cublas_handle_stream_to_workspace().clear();
cublaslt_handle_stream_to_workspace().clear();
{
auto& workspace = cublas_handle_stream_to_workspace();
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
workspace.map.clear();
}
{
auto& workspace = cublaslt_handle_stream_to_workspace();
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
workspace.map.clear();
}
}

size_t parseChosenWorkspaceSize() {
@ -233,6 +241,38 @@ at::DataPtr getNewCUDABlasLtWorkspace() {
return c10::cuda::CUDACachingAllocator::get()->allocate(getCUDABlasLtWorkspaceSize());
}

void setWorkspaceForHandle(cublasHandle_t handle, c10::cuda::CUDAStream stream) {
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));

auto& workspace = cublas_handle_stream_to_workspace();

size_t workspace_size = getChosenWorkspaceSize();

// Fast path: check if workspace already exists
{
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.find(key);
if (workspace_it != workspace.map.end()) {
TORCH_CUDABLAS_CHECK(cublasSetWorkspace(
handle, workspace_it->second.get(), workspace_size));
return;
}
}

// Slow path: allocate workspace outside the lock
auto new_workspace = getNewWorkspace();

// Insert with lock (double-check in case another thread inserted while we
// were allocating)
{
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.try_emplace(key, std::move(new_workspace)).first;
TORCH_CUDABLAS_CHECK(
cublasSetWorkspace(handle, workspace_it->second.get(), workspace_size));
}
}

void* getCUDABlasLtWorkspace() {
#ifndef USE_ROCM
static bool unified = c10::utils::check_env(TORCH_CUBLASLT_UNIFIED_WORKSPACE) == true;
@ -241,8 +281,10 @@ void* getCUDABlasLtWorkspace() {
auto stream = c10::cuda::getCurrentCUDAStream();
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key);
TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end());
auto& workspace = at::cuda::cublas_handle_stream_to_workspace();
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.find(key);
TORCH_INTERNAL_ASSERT(workspace_it != workspace.map.end());
return workspace_it->second.mutable_get();
}
#endif
@ -250,11 +292,29 @@ void* getCUDABlasLtWorkspace() {
auto stream = c10::cuda::getCurrentCUDAStream();
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto workspace_it = cublaslt_handle_stream_to_workspace().find(key);
if (workspace_it == cublaslt_handle_stream_to_workspace().end()) {
workspace_it = cublaslt_handle_stream_to_workspace().insert(workspace_it, {key, getNewCUDABlasLtWorkspace()});

auto& workspace = cublaslt_handle_stream_to_workspace();

// Fast path: check if workspace already exists
{
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.find(key);
if (workspace_it != workspace.map.end()) {
return workspace_it->second.mutable_get();
}
}

// Slow path: allocate workspace outside the lock
auto new_workspace = getNewCUDABlasLtWorkspace();

// Insert with lock (double-check in case another thread inserted while we
// were allocating)
{
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it =
workspace.map.try_emplace(key, std::move(new_workspace)).first;
return workspace_it->second.mutable_get();
}
return workspace_it->second.mutable_get();
}

cublasHandle_t getCurrentCUDABlasHandle() {
@ -298,13 +358,8 @@ cublasHandle_t getCurrentCUDABlasHandle() {
// will allocate memory dynamically (even if they're cheap) outside
// PyTorch's CUDA caching allocator. It's possible that CCA used up
// all the memory and cublas's cudaMallocAsync will return OOM
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto workspace_it = cublas_handle_stream_to_workspace().find(key);
if (workspace_it == cublas_handle_stream_to_workspace().end()) {
workspace_it = cublas_handle_stream_to_workspace().insert(workspace_it, {key, getNewWorkspace()});
}
TORCH_CUDABLAS_CHECK(cublasSetWorkspace(handle, workspace_it->second.get(), getChosenWorkspaceSize()));
setWorkspaceForHandle(handle, stream);

#if !defined(USE_ROCM)
// On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
// FP32 data type calculations based on the value of the allow_tf32 flag.

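The workspace hunks above replace a bare std::map with a map paired with a std::shared_mutex: lookups take a shared (reader) lock, and insertion happens under an exclusive lock with a try_emplace double-check. The sketch below is a minimal standalone illustration of that pattern, not the PyTorch code itself; MapWithMutex, Workspace, and get_or_create are names made up for the example.

#include <map>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <string>

// Hypothetical stand-in for the cached payload (at::DataPtr in the real code).
using Workspace = std::shared_ptr<std::string>;

struct MapWithMutex {
  std::map<int, Workspace> map;
  std::shared_mutex mutex;
};

Workspace get_or_create(MapWithMutex& ws, int key) {
  // Fast path: existing entries are looked up under a shared (reader) lock,
  // so concurrent callers do not serialize on the common case.
  {
    std::shared_lock<std::shared_mutex> lock(ws.mutex);
    auto it = ws.map.find(key);
    if (it != ws.map.end()) {
      return it->second;
    }
  }
  // Slow path: build the value outside any lock, then insert under an
  // exclusive (writer) lock. try_emplace is the double-check: if another
  // thread inserted while we were allocating, its value wins and ours is
  // discarded.
  auto fresh = std::make_shared<std::string>("freshly allocated workspace");
  std::unique_lock<std::shared_mutex> lock(ws.mutex);
  auto it = ws.map.try_emplace(key, std::move(fresh)).first;
  return it->second;
}

clearCublasWorkspaces() above takes the exclusive lock for the same reason: clearing the map must not race with a reader that is still dereferencing an entry.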
@ -411,16 +411,16 @@ std::string CUDAHooks::showConfig() const {
|
||||
// HIP_VERSION value format was changed after ROCm v4.2 to include the patch number
|
||||
if(v < 500) {
|
||||
// If major=xx, minor=yy then format -> xxyy
|
||||
oss << (v / 100) << "." << (v % 10);
|
||||
oss << (v / 100) << '.' << (v % 10);
|
||||
}
|
||||
else {
|
||||
// If major=xx, minor=yy & patch=zzzzz then format -> xxyyzzzzz
|
||||
oss << (v / 10000000) << "." << (v / 100000 % 100) << "." << (v % 100000);
|
||||
oss << (v / 10000000) << '.' << (v / 100000 % 100) << '.' << (v % 100000);
|
||||
}
|
||||
#else
|
||||
oss << (v / 1000) << "." << (v / 10 % 100);
|
||||
oss << (v / 1000) << '.' << (v / 10 % 100);
|
||||
if (v % 10 != 0) {
|
||||
oss << "." << (v % 10);
|
||||
oss << '.' << (v % 10);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
@ -431,16 +431,16 @@ std::string CUDAHooks::showConfig() const {
|
||||
oss << " - HIP Runtime ";
|
||||
#endif
|
||||
printCudaStyleVersion(runtimeVersion);
|
||||
oss << "\n";
|
||||
oss << '\n';
|
||||
|
||||
// TODO: Make HIPIFY understand CUDART_VERSION macro
|
||||
#if !defined(USE_ROCM)
|
||||
if (runtimeVersion != CUDART_VERSION) {
|
||||
oss << " - Built with CUDA Runtime ";
|
||||
printCudaStyleVersion(CUDART_VERSION);
|
||||
oss << "\n";
|
||||
oss << '\n';
|
||||
}
|
||||
oss << " - NVCC architecture flags: " << NVCC_FLAGS_EXTRA << "\n";
|
||||
oss << " - NVCC architecture flags: " << NVCC_FLAGS_EXTRA << '\n';
|
||||
#endif
|
||||
|
||||
#if !defined(USE_ROCM)
|
||||
@ -448,9 +448,9 @@ std::string CUDAHooks::showConfig() const {
|
||||
|
||||
|
||||
auto printCudnnStyleVersion = [&](size_t v) {
|
||||
oss << (v / 1000) << "." << (v / 100 % 10);
|
||||
oss << (v / 1000) << '.' << (v / 100 % 10);
|
||||
if (v % 100 != 0) {
|
||||
oss << "." << (v % 100);
|
||||
oss << '.' << (v % 100);
|
||||
}
|
||||
};
|
||||
|
||||
@ -461,22 +461,22 @@ std::string CUDAHooks::showConfig() const {
|
||||
if (cudnnCudartVersion != CUDART_VERSION) {
|
||||
oss << " (built against CUDA ";
|
||||
printCudaStyleVersion(cudnnCudartVersion);
|
||||
oss << ")";
|
||||
oss << ')';
|
||||
}
|
||||
oss << "\n";
|
||||
oss << '\n';
|
||||
if (cudnnVersion != CUDNN_VERSION) {
|
||||
oss << " - Built with CuDNN ";
|
||||
printCudnnStyleVersion(CUDNN_VERSION);
|
||||
oss << "\n";
|
||||
oss << '\n';
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
// TODO: Check if miopen has the functions above and unify
|
||||
oss << " - MIOpen " << MIOPEN_VERSION_MAJOR << "." << MIOPEN_VERSION_MINOR << "." << MIOPEN_VERSION_PATCH << "\n";
|
||||
oss << " - MIOpen " << MIOPEN_VERSION_MAJOR << '.' << MIOPEN_VERSION_MINOR << '.' << MIOPEN_VERSION_PATCH << '\n';
|
||||
#endif
|
||||
|
||||
#if AT_MAGMA_ENABLED()
|
||||
oss << " - Magma " << MAGMA_VERSION_MAJOR << "." << MAGMA_VERSION_MINOR << "." << MAGMA_VERSION_MICRO << "\n";
|
||||
oss << " - Magma " << MAGMA_VERSION_MAJOR << '.' << MAGMA_VERSION_MINOR << '.' << MAGMA_VERSION_MICRO << '\n';
|
||||
#endif
|
||||
|
||||
return oss.str();
|
||||
|
||||
@ -42,7 +42,7 @@ static inline void launch_jitted_vectorized_kernel_dynamic(
|
||||
|
||||
// The cache key includes all the parameters to generate_code + vec_size + dev_idx
|
||||
std::stringstream ss;
|
||||
ss << nInputs << "_" << nOutputs << f;
|
||||
ss << nInputs << '_' << nOutputs << f;
|
||||
ss << f_inputs_type_str << compute_type_str << result_type_str;
|
||||
ss << static_cast<int>(at::cuda::jit::BinaryFuncVariant::NoScalar);
|
||||
ss << extra_args_types;
|
||||
@ -144,7 +144,7 @@ static inline void launch_jitted_unrolled_kernel_dynamic(
|
||||
|
||||
// The cache key includes all the parameters to generate_code + dev_idx
|
||||
std::stringstream ss;
|
||||
ss << nInputs << "_" << nOutputs << f;
|
||||
ss << nInputs << '_' << nOutputs << f;
|
||||
ss << f_inputs_type_str << compute_type_str << result_type_str;
|
||||
ss << contiguous << dynamic_casting;
|
||||
ss << static_cast<int>(at::cuda::jit::BinaryFuncVariant::NoScalar);
|
||||
|
||||
@ -52,10 +52,10 @@ TuningContext* getTuningContext() {
|
||||
std::ostream& operator<<(std::ostream& stream, const ResultEntry& entry) {
|
||||
static const bool blaslog = c10::utils::get_env("PYTORCH_TUNABLEOP_BLAS_LOG") == "1";
|
||||
if (!blaslog) {
|
||||
return stream << entry.key_ << "," << entry.time_;
|
||||
return stream << entry.key_ << ',' << entry.time_;
|
||||
}
|
||||
else {
|
||||
return stream << entry.key_ << "," << entry.time_ << ",BLAS_PARAMS: " << entry.blas_sig_;
|
||||
return stream << entry.key_ << ',' << entry.time_ << ",BLAS_PARAMS: " << entry.blas_sig_;
|
||||
}
|
||||
}
|
||||
|
||||
@ -156,10 +156,10 @@ void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std
|
||||
if (isNew) {
|
||||
static const bool blaslog = c10::utils::get_env("PYTORCH_TUNABLEOP_BLAS_LOG") == "1";
|
||||
if (!blaslog) {
|
||||
untuned_file << op_signature << "," << params_signature << std::endl;
|
||||
untuned_file << op_signature << ',' << params_signature << std::endl;
|
||||
}
|
||||
else {
|
||||
untuned_file << op_signature << "," << params_signature << ",BLAS_PARAMS: " << blas_signature << std::endl;
|
||||
untuned_file << op_signature << ',' << params_signature << ",BLAS_PARAMS: " << blas_signature << std::endl;
|
||||
}
|
||||
TUNABLE_LOG3("Untuned,", op_signature, ",", params_signature);
|
||||
}
|
||||
@ -201,7 +201,7 @@ void TuningResultsManager::InitRealtimeAppend(const std::string& filename, const
|
||||
|
||||
if(!file_exists || file_empty) {
|
||||
for(const auto& [key, val] : validators) {
|
||||
(*realtime_out_) << "Validator," << key << "," << val << std::endl;
|
||||
(*realtime_out_) << "Validator," << key << ',' << val << std::endl;
|
||||
realtime_out_->flush();
|
||||
}
|
||||
validators_written_ = true;
|
||||
@ -219,7 +219,7 @@ void TuningResultsManager::AppendResultLine(const std::string& op_sig, const std
|
||||
return;
|
||||
}
|
||||
|
||||
(*realtime_out_) << op_sig << "," << param_sig << "," << result << std::endl;
|
||||
(*realtime_out_) << op_sig << ',' << param_sig << ',' << result << std::endl;
|
||||
realtime_out_->flush(); //ensure immediate write to disk
|
||||
|
||||
TUNABLE_LOG3("Realtime append: ", op_sig, "(", param_sig, ") -> ", result);
|
||||
|
||||
@ -93,31 +93,31 @@ std::string cudnnTypeToString(cudnnDataType_t dtype) {
|
||||
return "CUDNN_DATA_UINT8x4";
|
||||
default:
|
||||
std::ostringstream oss;
|
||||
oss << "(unknown data-type " << static_cast<int>(dtype) << ")";
|
||||
oss << "(unknown data-type " << static_cast<int>(dtype) << ')';
|
||||
return oss.str();
|
||||
}
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d) {
|
||||
out << "TensorDescriptor " << static_cast<void*>(d.desc()) << "\n";
|
||||
out << "TensorDescriptor " << static_cast<void*>(d.desc()) << '\n';
|
||||
int nbDims = 0;
|
||||
int dimA[CUDNN_DIM_MAX];
|
||||
int strideA[CUDNN_DIM_MAX];
|
||||
cudnnDataType_t dtype{};
|
||||
cudnnGetTensorNdDescriptor(d.desc(), CUDNN_DIM_MAX, &dtype, &nbDims, dimA, strideA);
|
||||
out << " type = " << cudnnTypeToString(dtype) << "\n";
|
||||
out << " nbDims = " << nbDims << "\n";
|
||||
out << " type = " << cudnnTypeToString(dtype) << '\n';
|
||||
out << " nbDims = " << nbDims << '\n';
|
||||
// Read out only nbDims of the arrays!
|
||||
out << " dimA = ";
|
||||
for (auto i : ArrayRef<int>{dimA, static_cast<size_t>(nbDims)}) {
|
||||
out << i << ", ";
|
||||
}
|
||||
out << "\n";
|
||||
out << '\n';
|
||||
out << " strideA = ";
|
||||
for (auto i : ArrayRef<int>{strideA, static_cast<size_t>(nbDims)}) {
|
||||
out << i << ", ";
|
||||
}
|
||||
out << "\n";
|
||||
out << '\n';
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -168,27 +168,27 @@ std::string cudnnMemoryFormatToString(cudnnTensorFormat_t tformat) {
|
||||
return "CUDNN_TENSOR_NHWC";
|
||||
default:
|
||||
std::ostringstream oss;
|
||||
oss << "(unknown cudnn tensor format " << static_cast<int>(tformat) << ")";
|
||||
oss << "(unknown cudnn tensor format " << static_cast<int>(tformat) << ')';
|
||||
return oss.str();
|
||||
}
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream & out, const FilterDescriptor& d) {
|
||||
out << "FilterDescriptor " << static_cast<void*>(d.desc()) << "\n";
|
||||
out << "FilterDescriptor " << static_cast<void*>(d.desc()) << '\n';
|
||||
int nbDims = 0;
|
||||
int dimA[CUDNN_DIM_MAX];
|
||||
cudnnDataType_t dtype{};
|
||||
cudnnTensorFormat_t tformat{};
|
||||
cudnnGetFilterNdDescriptor(d.desc(), CUDNN_DIM_MAX, &dtype, &tformat, &nbDims, dimA);
|
||||
out << " type = " << cudnnTypeToString(dtype) << "\n";
|
||||
out << " tensor_format = " << cudnnMemoryFormatToString(tformat) << "\n";
|
||||
out << " nbDims = " << nbDims << "\n";
|
||||
out << " type = " << cudnnTypeToString(dtype) << '\n';
|
||||
out << " tensor_format = " << cudnnMemoryFormatToString(tformat) << '\n';
|
||||
out << " nbDims = " << nbDims << '\n';
|
||||
// Read out only nbDims of the arrays!
|
||||
out << " dimA = ";
|
||||
for (auto i : ArrayRef<int>{dimA, static_cast<size_t>(nbDims)}) {
|
||||
out << i << ", ";
|
||||
}
|
||||
out << "\n";
|
||||
out << '\n';
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
@ -346,15 +346,15 @@ void foreachTensorInplaceWithFlag(std::vector<IValue>& args, int64_t begin, int6
|
||||
}
|
||||
|
||||
std::ostream& operator<< (std::ostream& os, const DynamicLayer& layer) {
|
||||
os << layer.layerId() << ":" << layer.key();
|
||||
os << layer.layerId() << ':' << layer.key();
|
||||
return os;
|
||||
}
|
||||
std::ostream& operator<< (std::ostream& os, const std::vector<DynamicLayer>& dls) {
|
||||
os << "DynamicLayerStack[ ";
|
||||
for (const auto& layer : dls) {
|
||||
os << layer << " ";
|
||||
os << layer << ' ';
|
||||
}
|
||||
os << "]";
|
||||
os << ']';
|
||||
return os;
|
||||
}
|
||||
|
||||
|
||||
@ -22,7 +22,7 @@ void dumpTensor(std::ostream& ss, const Tensor& tensor) {
|
||||
if (batched) {
|
||||
ss << "Batched[lvl=" << batched->level() << " dim=" << batched->bdim() << ", ";
|
||||
dumpTensor(ss, batched->value());
|
||||
ss << "]";
|
||||
ss << ']';
|
||||
return;
|
||||
}
|
||||
ss << "Tensor" << tensor.sizes();
|
||||
@ -36,7 +36,7 @@ void dumpTensor(std::ostream& ss, const Tensor& tensor) {
|
||||
ss << "dead, ";
|
||||
}
|
||||
dumpTensor(ss, wrapped->value());
|
||||
ss << "]";
|
||||
ss << ']';
|
||||
}
|
||||
|
||||
void TensorWrapper::refreshMetadata() {
|
||||
|
||||
@ -73,32 +73,32 @@ std::string miopenTypeToString(miopenDataType_t dtype) {
|
||||
return "miopenBFloat16";
|
||||
default:
|
||||
std::ostringstream oss;
|
||||
oss << "(unknown data-type " << static_cast<int>(dtype) << ")";
|
||||
oss << "(unknown data-type " << static_cast<int>(dtype) << ')';
|
||||
return oss.str();
|
||||
}
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d) {
|
||||
out << "TensorDescriptor " << static_cast<void*>(d.desc()) << "\n";
|
||||
out << "TensorDescriptor " << static_cast<void*>(d.desc()) << '\n';
|
||||
int nbDims = 0;
|
||||
int dimA[MIOPEN_DIM_MAX];
|
||||
int strideA[MIOPEN_DIM_MAX];
|
||||
miopenDataType_t dtype;
|
||||
miopenGetTensorDescriptorSize(d.desc(), &nbDims);
|
||||
miopenGetTensorDescriptor(d.desc(), &dtype, dimA, strideA);
|
||||
out << " type = " << miopenTypeToString(dtype) << "\n";
|
||||
out << " nbDims = " << nbDims << "\n";
|
||||
out << " type = " << miopenTypeToString(dtype) << '\n';
|
||||
out << " nbDims = " << nbDims << '\n';
|
||||
// Read out only nbDims of the arrays!
|
||||
out << " dimA = ";
|
||||
for (auto i : ArrayRef<int>{dimA, static_cast<size_t>(nbDims)}) {
|
||||
out << i << ", ";
|
||||
}
|
||||
out << "\n";
|
||||
out << '\n';
|
||||
out << " strideA = ";
|
||||
for (auto i : ArrayRef<int>{strideA, static_cast<size_t>(nbDims)}) {
|
||||
out << i << ", ";
|
||||
}
|
||||
out << "\n";
|
||||
out << '\n';
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
@ -91,7 +91,7 @@ struct OperationInfo : BaseInfo {
|
||||
std::stringstream kernelStr;
|
||||
kernelStr << kernelName;
|
||||
for (const Tensor& tensor : tensors) {
|
||||
kernelStr << ":" << BaseInfo::buildTensorString(tensor, includeBufferId);
|
||||
kernelStr << ':' << BaseInfo::buildTensorString(tensor, includeBufferId);
|
||||
}
|
||||
return kernelStr.str();
|
||||
}
|
||||
|
||||
@ -39,9 +39,9 @@ std::string BaseInfo::buildTensorString(const Tensor& tensor, bool includeBuffer
|
||||
// see comments for INCLUDE_BUFFER_ID
|
||||
if (includeBufferId && deviceType == at::kMPS) {
|
||||
id<MTLBuffer> buffer = __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
|
||||
tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer)) << ":" << buffer.retainCount << ")";
|
||||
tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer)) << ':' << buffer.retainCount << ')';
|
||||
}
|
||||
tensorStr << ":" << tensor.scalar_type() << tensor.sizes();
|
||||
tensorStr << ':' << tensor.scalar_type() << tensor.sizes();
|
||||
return tensorStr.str();
|
||||
} else {
|
||||
return "undefined";
|
||||
|
||||
@ -167,7 +167,7 @@ static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, co
|
||||
std::stringstream ss;
|
||||
ss << arg_name << " should be greater than zero but got (";
|
||||
std::copy(args.begin(), args.end() - 1, std::ostream_iterator<int>(ss,", "));
|
||||
ss << args.back() << ")" << " (while checking arguments for " << c << ")";
|
||||
ss << args.back() << ")" << " (while checking arguments for " << c << ')';
|
||||
TORCH_CHECK(false, ss.str());
|
||||
}
|
||||
}
|
||||
|
||||
@ -639,7 +639,7 @@ static std::ostream& operator<<(std::ostream & out, const ConvParams<T>& params)
|
||||
<< " deterministic = " << params.deterministic
|
||||
<< " cudnn_enabled = " << params.cudnn_enabled
|
||||
<< " allow_tf32 = " << params.allow_tf32
|
||||
<< "}";
|
||||
<< '}';
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
@ -1936,7 +1936,7 @@ static bool should_fold(const Tensor& tensor1, const Tensor& tensor2, bool has_o

// We order the tensors. t1 will be the larger tensor
// We can always transpose tensor2 as the dimensions are always >= 1 (precondition from matmul)
// and tensor1_larger iff tensor2.dim() > tensor1.dim()
// and tensor1_larger iff tensor2.dim() > tensor1.dim()
const auto t1 = tensor1_larger ? MaybeOwned<Tensor>::borrowed(tensor1)
: MaybeOwned<Tensor>::owned(tensor2.mT());
const int64_t dim_t1 = t1->dim();
@ -1948,11 +1948,20 @@ static bool should_fold(const Tensor& tensor1, const Tensor& tensor2, bool has_o
return false;
}

// If we require a gradient, we should fold to minimize backward memory usage - even if this
// leads to a copy in forward because is needed in backward,
// only time we avoid this strict pre-allocated memory usage (has_out = True)
bool requires_grad = tensor1.requires_grad() || tensor2.requires_grad();
if (requires_grad && !has_out) {
// In this case we *do* incur in an extra copy to avoid creating an unnecessary large tensor in the backward
// Suppose we don't fold here. Let t1.shape = [b, m, n] t2.shape = [n, k] like in a transformer
// t2 will be expanded to a tensor of shape [b, n, k] and then we do t1.bmm(t2_expanded)
// The issue appears in the backward.
// The output gradient g of this operation would have shape [b, m, k]
// The backward wrt. t2 of bmm would be given by t1.mH @ g, which has shape [b, n, k]
// Then, the backward of expand is simply `sum(0)`. As such, we are instantiating a tensor
// of shape [b, n, k] unnecessarily, which may cause a large memory footprint, and in the
// worst case, an OOM
bool t2_requires_grad = tensor1_larger ? tensor2.requires_grad() : tensor1.requires_grad();
if (t2_requires_grad && !has_out) {
// We should be checking !at::GradMode::is_enabled(), but apparently
// this regresses performance in some cases:
// https://github.com/pytorch/pytorch/issues/118548#issuecomment-1916022394
return true;
}

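The shape reasoning in the comment above is easier to see with concrete numbers. The sketch below uses made-up transformer-like sizes for illustration only; they are not values taken from this change.

#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative shapes only: t1 is [b, m, n], t2 is [n, k], fp32 elements.
  const int64_t b = 32, m = 1024, n = 4096, k = 4096;
  // Without folding, the backward w.r.t. t2 materializes t1.mH @ g with
  // shape [b, n, k] before expand's backward reduces it with sum(0).
  const int64_t unfolded_bytes = b * n * k * 4;
  // With folding (t1 reshaped to [b*m, n]), the gradient w.r.t. t2 is [n, k].
  const int64_t folded_bytes = n * k * 4;
  std::printf("unfolded grad buffer: %lld MiB, folded: %lld MiB\n",
              static_cast<long long>(unfolded_bytes >> 20),
              static_cast<long long>(folded_bytes >> 20));
  return 0;
}

With these sizes the unfolded path briefly holds a 2048 MiB gradient buffer where the folded path needs 64 MiB, which is the footprint difference the comment is guarding against.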
@ -3532,9 +3541,9 @@ Tensor _dyn_quant_matmul_4bit_cpu(
|
||||
const int64_t out_features) {
|
||||
auto M = inp.size(0);
|
||||
TORCH_CHECK(
|
||||
inp.dtype() == kFloat,
|
||||
inp.dtype() == kFloat || (inp.dtype() == kBFloat16 && block_size == in_features),
|
||||
__func__,
|
||||
" : expect input to be 32-bit float tensor.");
|
||||
" : expect input to be float32 or bfloat16 tensor.");
|
||||
TORCH_CHECK(
|
||||
block_size == in_features ||
|
||||
(!(block_size % 32) && !(in_features % block_size)),
|
||||
|
||||
@ -847,7 +847,7 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const std::optional<int64_t
|
||||
<< ", hop_length=" << hop_length << ", win_length=" << win_length \
|
||||
<< ", window="; \
|
||||
if (window.defined()) { \
|
||||
SS << window.toString() << "{" << window.sizes() << "}"; \
|
||||
SS << window.toString() << '{' << window.sizes() << '}'; \
|
||||
} else { \
|
||||
SS << "None"; \
|
||||
} \
|
||||
@ -1046,7 +1046,7 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const std::optional<int64_
|
||||
<< ", hop_length=" << hop_length << ", win_length=" << win_length \
|
||||
<< ", window="; \
|
||||
if (window.defined()) { \
|
||||
SS << window.toString() << "{" << window.sizes() << "}"; \
|
||||
SS << window.toString() << '{' << window.sizes() << '}'; \
|
||||
} else { \
|
||||
SS << "None"; \
|
||||
} \
|
||||
|
||||
@ -1087,7 +1087,8 @@ TORCH_IMPL_FUNC(index_copy_out)
|
||||
result.copy_(self);
|
||||
|
||||
// See Note [Enabling Deterministic Operations]
|
||||
if (result.is_cuda() && globalContext().deterministicAlgorithms()) {
|
||||
if ((result.is_cuda() || result.is_xpu()) &&
|
||||
globalContext().deterministicAlgorithms()) {
|
||||
torch::List<std::optional<Tensor>> indices;
|
||||
indices.resize(dim + 1);
|
||||
indices.set(dim, index);
|
||||
|
||||
@ -523,7 +523,7 @@ Tensor _functional_assert_async_msg_cpu(
|
||||
}
|
||||
|
||||
void _print(std::string_view s) {
|
||||
std::cout << s << "\n";
|
||||
std::cout << s << '\n';
|
||||
}
|
||||
|
||||
// Sorting-based algorithm for isin(); used when the number of test elements is
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
#include <ATen/cpu/vec/vec.h>
|
||||
#include <ATen/native/cpu/int_mm_kernel.h>
|
||||
#include <ATen/native/cpu/utils.h>
|
||||
#include <cmath>
|
||||
#include <c10/util/Unroll.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
@ -793,6 +794,139 @@ bool can_use_kleidiai(
|
||||
}
|
||||
#endif
|
||||
|
||||
static void ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
|
||||
size_t m,
|
||||
size_t n,
|
||||
size_t k,
|
||||
const uint16_t* lhs_bf16,
|
||||
const uint8_t* rhs_qs4cx,
|
||||
const float* rhs_scales,
|
||||
uint16_t* dst_bf16,
|
||||
float scalar_min,
|
||||
float scalar_max,
|
||||
const float* bias) {
|
||||
// Roundup lambda for internal stride calculations
|
||||
auto roundup = [](size_t a, size_t b) { return ((a + b - 1) / b) * b; };
|
||||
|
||||
// Cast bfloat16 to float32 inline
|
||||
auto cast_bf16_to_f32 = [](uint16_t bf16_val) {
|
||||
uint32_t tmp = static_cast<uint32_t>(bf16_val) << 16;
|
||||
float f;
|
||||
std::memcpy(&f, &tmp, sizeof(f));
|
||||
return f;
|
||||
};
|
||||
|
||||
// Cast float32 to bfloat16 inline
|
||||
auto cast_f32_to_bf16 = [](float f) {
|
||||
uint32_t bits;
|
||||
std::memcpy(&bits, &f, sizeof(bits));
|
||||
return static_cast<uint16_t>(bits >> 16);
|
||||
};
|
||||
|
||||
// Quantization pack lambda (channelwise QA8DX)
|
||||
auto quant_pack_8bit_channelwise =
|
||||
[&](size_t M, size_t K, const uint16_t* src_bf16, int8_t* dst_qa8dx) {
|
||||
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
|
||||
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
|
||||
|
||||
const size_t dst_stride =
|
||||
K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
|
||||
for (size_t i = 0; i < M; ++i) {
|
||||
const uint16_t* row_ptr = src_bf16 + i * K;
|
||||
// find min/max
|
||||
float mn = FLT_MAX, mx = -FLT_MAX;
|
||||
for (size_t j = 0; j < K; ++j) {
|
||||
float v = cast_bf16_to_f32(row_ptr[j]);
|
||||
mn = std::min(mn, v);
|
||||
mx = std::max(mx, v);
|
||||
}
|
||||
float rmin = std::min(0.0f, mn);
|
||||
float rmax = std::max(0.0f, mx);
|
||||
constexpr float qmin = static_cast<float>(kI8Min);
|
||||
constexpr float qmax = static_cast<float>(kI8Max);
|
||||
float scale = (rmin == rmax) ? 1.f : (qmax - qmin) / (rmax - rmin);
|
||||
float recip = scale ? 1.0f / scale : 0.0f;
|
||||
int32_t zp;
|
||||
float des_min = rmin * scale;
|
||||
float des_max = rmax * scale;
|
||||
float err_min = qmin + des_min;
|
||||
float err_max = qmax + des_max;
|
||||
float zp_f =
|
||||
(err_min + err_max) > 0 ? qmin - des_min : qmax - des_max;
|
||||
zp_f = std::clamp(zp_f, qmin, qmax);
|
||||
zp = std::lrintf(zp_f);
|
||||
int8_t* out_ptr = dst_qa8dx + i * dst_stride;
|
||||
// store header
|
||||
*reinterpret_cast<float*>(out_ptr) = recip;
|
||||
*reinterpret_cast<int32_t*>(out_ptr + sizeof(float)) = -zp;
|
||||
out_ptr += sizeof(float) + sizeof(int32_t);
|
||||
// quantize
|
||||
for (size_t j = 0; j < K; ++j) {
|
||||
float v = cast_bf16_to_f32(row_ptr[j]);
|
||||
int32_t q = static_cast<int32_t>(std::round(v * scale)) + zp;
|
||||
q = std::clamp(
|
||||
q, static_cast<int32_t>(kI8Min), static_cast<int32_t>(kI8Max));
|
||||
*out_ptr++ = static_cast<int8_t>(q);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// MatMul lambda (MXN x MXK -> MNXK BF16)
|
||||
auto matmul_kernel = [&](size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const int8_t* lhs,
|
||||
const uint8_t* rhs,
|
||||
const float* scales,
|
||||
uint16_t* dst,
|
||||
float lo,
|
||||
float hi) {
|
||||
const size_t lhs_stride =
|
||||
K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
|
||||
const size_t rhs_stride = roundup(K, 2) / 2;
|
||||
for (size_t i = 0; i < M; ++i) {
|
||||
const int8_t* lhs_row = lhs + i * lhs_stride;
|
||||
for (size_t j = 0; j < N; ++j) {
|
||||
int32_t acc = 0;
|
||||
const int8_t* lptr = lhs_row;
|
||||
const uint8_t* rptr = rhs + j * rhs_stride;
|
||||
float lhs_scale = *reinterpret_cast<const float*>(lptr);
|
||||
int32_t lhs_off =
|
||||
*reinterpret_cast<const int32_t*>(lptr + sizeof(float));
|
||||
lptr += sizeof(float) + sizeof(int32_t);
|
||||
for (size_t t = 0; t < K; ++t) {
|
||||
int32_t lv = static_cast<int32_t>(lptr[t]);
|
||||
uint8_t bv = rptr[t / 2];
|
||||
int32_t rv = ((t & 1) == 0) ? (static_cast<int32_t>(bv & 0xF) - 8)
|
||||
: (static_cast<int32_t>(bv >> 4) - 8);
|
||||
acc += lv * rv + lhs_off * rv;
|
||||
}
|
||||
float res = static_cast<float>(acc) * scales[j] * lhs_scale;
|
||||
if (bias) {
|
||||
res += bias[j];
|
||||
}
|
||||
res = std::clamp(res, lo, hi);
|
||||
*dst++ = cast_f32_to_bf16(res);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// allocate and run
|
||||
std::unique_ptr<int8_t[]> packed(
|
||||
new int8_t[m * (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t))]);
|
||||
quant_pack_8bit_channelwise(m, k, lhs_bf16, packed.get());
|
||||
matmul_kernel(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
packed.get(),
|
||||
rhs_qs4cx,
|
||||
rhs_scales,
|
||||
dst_bf16,
|
||||
scalar_min,
|
||||
scalar_max);
|
||||
}
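The two casting lambdas at the top of the kernel above rely on bfloat16 being the upper half of an IEEE-754 float32 encoding. Below is a standalone sketch of the same bit-level trick, truncating on the way down exactly as the reference kernel does (production converters usually round to nearest-even instead); the function names are illustrative.

#include <cstdint>
#include <cstdio>
#include <cstring>

// float32 -> bfloat16 by truncation: keep the top 16 bits of the encoding.
static uint16_t f32_to_bf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

// bfloat16 -> float32 is exact: pad the low 16 mantissa bits with zeros.
static float bf16_to_f32(uint16_t h) {
  uint32_t bits = static_cast<uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  const float x = 3.14159f;
  const uint16_t h = f32_to_bf16(x);
  std::printf("%.6f -> 0x%04x -> %.6f\n", x, static_cast<unsigned>(h), bf16_to_f32(h));
  return 0;
}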
|
||||
|
||||
/**
|
||||
* The Int4 quantized weights must be represented as a uint8 tensor
|
||||
* For matrix multiplication with a weight shape of (N x K)
|
||||
@ -819,21 +953,21 @@ void dyn_quant_pack_4bit_weight_kernel(
|
||||
#if AT_KLEIDIAI_ENABLED()
|
||||
if (can_use_kleidiai(scales_zeros, K, block_size)) {
|
||||
const int64_t weight_packed_size =
|
||||
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
|
||||
kleidiai::kai_pack_rhs_int4_size(N, K, block_size, weights.scalar_type());
|
||||
packed_weights.resize_({weight_packed_size});
|
||||
kleidiai::kai_pack_int4_rhs(
|
||||
packed_weights, weights, scales_zeros, bias, N, K, block_size);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
TORCH_CHECK(
|
||||
bias.has_value() == 0,
|
||||
__func__,
|
||||
" : Bias is unsupported in reference implementation");
|
||||
packed_weights = packed_weights.to(kFloat);
|
||||
auto weight_reshaped = weights.view({-1}).to(kFloat);
|
||||
auto scales_zeros_reshaped = scales_zeros.view({-1}).to(kFloat);
|
||||
auto res = at::cat({weight_reshaped, scales_zeros_reshaped}, 0);
|
||||
auto weight_reshaped = weights.reshape({-1}).to(kFloat);
|
||||
auto scales_zeros_reshaped = scales_zeros.reshape({-1}).to(kFloat);
|
||||
std::vector<at::Tensor> tensors_to_cat = {weight_reshaped, scales_zeros_reshaped};
|
||||
if (bias.has_value()) {
|
||||
tensors_to_cat.push_back(bias.value().view({-1}).to(kFloat));
|
||||
}
|
||||
auto res = at::cat(tensors_to_cat, 0);
|
||||
packed_weights.resize_(res.sizes()).copy_(res);
|
||||
}
|
||||
}
|
||||
@ -847,7 +981,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
const float* rhs_scales_f32,
|
||||
float* dst_f32,
|
||||
float scalar_min,
|
||||
float scalar_max) {
|
||||
float scalar_max,
|
||||
const float* bias) {
|
||||
const size_t input_size_8bit = m * (k + sizeof(int32_t) + sizeof(float));
|
||||
|
||||
auto lhs_qa8dx_buffer = std::make_unique<uint8_t[]>(input_size_8bit);
|
||||
@ -857,6 +992,9 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
// required format for matmul
|
||||
auto input_quant_pack_8bit_channelwise =
|
||||
[&](size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) {
|
||||
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
|
||||
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
|
||||
|
||||
const size_t dst_stride =
|
||||
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
|
||||
|
||||
@ -877,8 +1015,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
}
|
||||
|
||||
// Maximum/minimum int8 values
|
||||
const float qmin = (float)INT8_MIN;
|
||||
const float qmax = (float)INT8_MAX;
|
||||
constexpr float qmin = static_cast<float>(kI8Min);
|
||||
constexpr float qmax = static_cast<float>(kI8Max);
|
||||
|
||||
const float rmin0 = std::min(0.0f, min0);
|
||||
const float rmax0 = std::max(0.0f, max0);
|
||||
@ -904,7 +1042,7 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
zero_point0 = std::min(zero_point0, qmax);
|
||||
|
||||
// Round to nearest integer
|
||||
const int32_t nudged_zero_point0 = lrintf(zero_point0);
|
||||
const int32_t nudged_zero_point0 = std::lrintf(zero_point0);
|
||||
|
||||
int8_t* dst_ptr = lhs_qa8dx + m_idx * dst_stride;
|
||||
|
||||
@ -922,8 +1060,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
|
||||
|
||||
v0_s32 = v0_s32 + nudged_zero_point0;
|
||||
v0_s32 = std::max(v0_s32, static_cast<int32_t>(INT8_MIN));
|
||||
v0_s32 = std::min(v0_s32, static_cast<int32_t>(INT8_MAX));
|
||||
v0_s32 = std::max(v0_s32, static_cast<int32_t>(kI8Min));
|
||||
v0_s32 = std::min(v0_s32, static_cast<int32_t>(kI8Max));
|
||||
dst_ptr[0] = (int8_t)v0_s32;
|
||||
dst_ptr += sizeof(int8_t);
|
||||
}
|
||||
@ -987,6 +1125,10 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
|
||||
main_acc = main_acc * lhs_scale;
|
||||
|
||||
if (bias) {
|
||||
main_acc += bias[n_idx];
|
||||
}
|
||||
|
||||
// Clamp (min-max) operation
|
||||
main_acc = std::max(main_acc, scalar_min);
|
||||
main_acc = std::min(main_acc, scalar_max);
|
||||
@ -1007,12 +1149,16 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
const float* rhs_scales_fp32,
|
||||
float* dst_f32,
|
||||
float scalar_min,
|
||||
float scalar_max) {
|
||||
float scalar_max,
|
||||
const float* bias) {
|
||||
// Lambda for LHS quantization
|
||||
auto lhs_quant_pack = [&](size_t m,
|
||||
size_t k,
|
||||
const float* lhs_f32,
|
||||
int8_t* lhs_qa8dx) {
|
||||
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
|
||||
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
|
||||
|
||||
const size_t dst_stride =
|
||||
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
|
||||
|
||||
@ -1028,8 +1174,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
min0 = std::min(src0_0, min0);
|
||||
}
|
||||
|
||||
const float qmin = (float)INT8_MIN;
|
||||
const float qmax = (float)INT8_MAX;
|
||||
constexpr float qmin = static_cast<float>(kI8Min);
|
||||
constexpr float qmax = static_cast<float>(kI8Max);
|
||||
|
||||
const float rmin0 = std::min(0.0f, min0);
|
||||
const float rmax0 = std::max(0.0f, max0);
|
||||
@ -1046,7 +1192,7 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
|
||||
zero_point0 = std::max(zero_point0, qmin);
|
||||
zero_point0 = std::min(zero_point0, qmax);
|
||||
const int32_t nudged_zero_point0 = lrintf(zero_point0);
|
||||
const int32_t nudged_zero_point0 = std::lrintf(zero_point0);
|
||||
|
||||
int8_t* dst_ptr = lhs_qa8dx + row_idx * dst_stride;
|
||||
|
||||
@ -1059,9 +1205,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
const float src0_0 = src_ptr[k_idx];
|
||||
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
|
||||
v0_s32 = std::max(
|
||||
std::min(
|
||||
v0_s32 + nudged_zero_point0, static_cast<int32_t>(INT8_MAX)),
|
||||
static_cast<int32_t>(INT8_MIN));
|
||||
std::min(v0_s32 + nudged_zero_point0, static_cast<int32_t>(kI8Max)),
|
||||
static_cast<int32_t>(kI8Min));
|
||||
dst_ptr[0] = (int8_t)v0_s32;
|
||||
dst_ptr += sizeof(int8_t);
|
||||
}
|
||||
@ -1118,6 +1263,11 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
}
|
||||
|
||||
main_acc = main_acc * lhs_scale;
|
||||
|
||||
if (bias) {
|
||||
main_acc += bias[col_idx];
|
||||
}
|
||||
|
||||
main_acc = std::max(main_acc, scalar_min);
|
||||
main_acc = std::min(main_acc, scalar_max);
|
||||
|
||||
@ -1128,28 +1278,27 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
}
|
||||
|
||||
/**
|
||||
* Dynamic Input Quant 4 bit weights matmul execution flow
|
||||
(INT4 Weights + FP scales + FP32 Bias)
|
||||
FP32 Input Packed Buffer
|
||||
| |
|
||||
Quantize Cast
|
||||
to INT8 to INT8
|
||||
| |
|
||||
v v
|
||||
INT8 Input INT8 Weights
|
||||
\ /
|
||||
\ /
|
||||
\ /
|
||||
INT8 Matrix Multiplication
|
||||
|
|
||||
v
|
||||
FP32 Dequantized and Accumulate in FP32
|
||||
|
|
||||
v
|
||||
FP32 Final Output
|
||||
|
||||
* The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires
|
||||
* Float32 Scales. If not provided, we will use fallback implementation.
|
||||
* Dynamic INT4 weight-only MatMul with per-row input quantization.
|
||||
*
|
||||
* Execution Flow:
|
||||
*
|
||||
* (INT4 Weights + FP Scales [+ optional Bias])
|
||||
*
|
||||
* Input (FP32 or BF16) Packed Weight Buffer
|
||||
* | |
|
||||
* Row-wise Quantization (INT8) |
|
||||
* | |
|
||||
* INT8 Input Activation INT4 Quantized Weights + Scales
|
||||
* \ /
|
||||
* \ /
|
||||
* Quantized Matrix Multiply
|
||||
* |
|
||||
* Output Tensor (BF16 or FP32)
|
||||
*
|
||||
* Notes:
|
||||
* - Groupwise kernels expect BF16 scales
|
||||
* - Channelwise kernels expect FP32 scales
|
||||
* - Bias is currently unsupported in fallback path
|
||||
*/
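The row-wise INT8 quantization step in the flow above amounts to picking a per-row scale and zero point from the row's min/max. The sketch below is a simplified, standalone version with illustrative names; the reference kernels additionally balance rounding error when choosing the zero point and store 1/scale plus the negated zero point in a small per-row header ahead of the quantized bytes, which is what the dst_stride bookkeeping accounts for.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Illustrative names; simplified relative to the reference kernels above.
struct RowQuantParams {
  float scale;         // float -> int8 multiplier
  float inv_scale;     // stored per row for dequantization
  int32_t zero_point;  // quantized value representing 0.0f
};

inline RowQuantParams choose_row_params(const std::vector<float>& row) {
  float rmin = 0.0f, rmax = 0.0f;  // the range always includes 0
  for (float v : row) {
    rmin = std::min(rmin, v);
    rmax = std::max(rmax, v);
  }
  constexpr float qmin = -128.0f, qmax = 127.0f;
  const float scale = (rmin == rmax) ? 1.0f : (qmax - qmin) / (rmax - rmin);
  const float zp = std::clamp(qmin - rmin * scale, qmin, qmax);
  return {scale, scale != 0.0f ? 1.0f / scale : 0.0f,
          static_cast<int32_t>(std::lrintf(zp))};
}

inline int8_t quantize_value(float v, const RowQuantParams& p) {
  const int32_t q =
      static_cast<int32_t>(std::round(v * p.scale)) + p.zero_point;
  return static_cast<int8_t>(std::clamp(q, -128, 127));
}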
|
||||
void dyn_quant_matmul_4bit_kernel(
|
||||
const Tensor& output,
|
||||
@ -1161,65 +1310,75 @@ void dyn_quant_matmul_4bit_kernel(
|
||||
const int64_t block_size) {
|
||||
#if AT_KLEIDIAI_ENABLED()
|
||||
const int64_t weight_packed_size =
|
||||
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
|
||||
kleidiai::kai_pack_rhs_int4_size(N, K, block_size, inp.scalar_type());
|
||||
if (weight_packed_size == packed_weights.numel()) {
|
||||
// KleidiAI interface internally handles the Channelwise and groupwise
|
||||
// distinction
|
||||
kleidiai::kai_quant_pack_lhs_int4_mm(
|
||||
output, inp, packed_weights, M, N, K, block_size);
|
||||
kleidiai::kai_quant_pack_lhs_int4_mm(output, inp, packed_weights, M, N, K, block_size);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
float* lhs_f32 = reinterpret_cast<float*>(inp.data_ptr());
|
||||
const auto weights_size = N * K / 2;
|
||||
// The weights needs to be in uint8_t data type after quantization
|
||||
auto extracted_weights =
|
||||
(packed_weights.narrow(0, 0, weights_size)).to(kByte);
|
||||
auto float32_scales =
|
||||
(packed_weights.narrow(
|
||||
0, weights_size, packed_weights.size(0) - weights_size))
|
||||
.to(kFloat);
|
||||
uint8_t* rhs_4bit =
|
||||
reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
|
||||
float* rhs_scales_f32 = reinterpret_cast<float*>(float32_scales.data_ptr());
|
||||
float* dst_f32 = reinterpret_cast<float*>(output.data_ptr());
|
||||
if (block_size == K) {
|
||||
ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
lhs_f32,
|
||||
rhs_4bit,
|
||||
rhs_scales_f32,
|
||||
dst_f32,
|
||||
-FLT_MAX,
|
||||
FLT_MAX);
|
||||
} else if (!(block_size % 32) && !(K % block_size)) {
|
||||
ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
block_size,
|
||||
lhs_f32,
|
||||
rhs_4bit,
|
||||
rhs_scales_f32,
|
||||
dst_f32,
|
||||
-FLT_MAX,
|
||||
FLT_MAX);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
block_size == K || (!(block_size % 32) && !(K % block_size)),
|
||||
__func__,
|
||||
": Group size should be multiple 32 or in_features [",
|
||||
K,
|
||||
"]. Provided ",
|
||||
block_size);
|
||||
{
|
||||
void* input = inp.data_ptr();
|
||||
void* dst = output.data_ptr();
|
||||
|
||||
// Extract weights, scales and biases from the packed tensor
|
||||
const int weights_elements = N * K / 2;
|
||||
const int scale_elements = N * (K / block_size);
|
||||
TORCH_CHECK(packed_weights.numel() >= (weights_elements + scale_elements), "Invalid packed weight tensor size");
|
||||
|
||||
auto extracted_weights = packed_weights.narrow(0, 0, weights_elements).to(kByte);
|
||||
auto extracted_scales_and_bias = packed_weights.narrow(0, weights_elements, packed_weights.size(0) - weights_elements).to(kFloat);
|
||||
auto float32_scales = extracted_scales_and_bias.narrow(0, 0, scale_elements);
|
||||
|
||||
int bias_elements = packed_weights.numel() - (weights_elements + scale_elements);
|
||||
float* weight_scales = float32_scales.data_ptr<float>();
|
||||
|
||||
void* bias_data = nullptr;
|
||||
if (bias_elements) {
|
||||
auto float32_bias = extracted_scales_and_bias.narrow(0, scale_elements, bias_elements);
|
||||
TORCH_CHECK(float32_bias.size(0) == N, "Expected bias length to match output dimension");
|
||||
bias_data = float32_bias.data_ptr();
|
||||
|
||||
}
|
||||
// 2 elements of 4 bit weights are packed into 1 uint8 packet
|
||||
uint8_t* weights_4bit = reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
|
||||
|
||||
// Dispatch to reference kernels
|
||||
if (inp.scalar_type() == at::kBFloat16) {
|
||||
// BF16 input, BF16 output
|
||||
constexpr float BF16_MAX = 3.38953139e+38f;
|
||||
constexpr float BF16_MIN = -BF16_MAX;
|
||||
if (block_size == K) {
|
||||
ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
|
||||
M, N, K,
|
||||
(uint16_t*)input, weights_4bit, weight_scales,
|
||||
(uint16_t*)dst, BF16_MIN, BF16_MAX, (float*)bias_data);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported block size for BF16 fallback");
|
||||
}
|
||||
} else if (inp.scalar_type() == at::kFloat) {
|
||||
// FP32 input, FP32 output
|
||||
if (block_size == K) {
|
||||
ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
M, N, K,
|
||||
(float*)input, weights_4bit, weight_scales,
|
||||
(float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data);
|
||||
} else if (!(block_size % 32) && !(K % block_size)) {
|
||||
ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
M, N, K, block_size,
|
||||
(float*)input, weights_4bit, weight_scales,
|
||||
(float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported block size for FP32 fallback");
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported input/output dtype combination for int4mm kernel");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
}
|
||||
ALSO_REGISTER_AVX512_DISPATCH(weight_to_int4pack_stub, &weight_to_int4pack_kernel)
|
||||
ALSO_REGISTER_AVX512_DISPATCH(int4pack_mm_stub, &int4pack_mm_kernel)
|
||||
REGISTER_DISPATCH(dyn_quant_pack_4bit_weight_stub, &dyn_quant_pack_4bit_weight_kernel)
|
||||
|
||||
@ -78,9 +78,18 @@ __global__ void EmbeddingBag_updateOutputKernel_max(
|
||||
scalar_t weightFeatMax = 0;
|
||||
int64_t bag_size_ = 0;
|
||||
int64_t maxWord = -1;
|
||||
|
||||
// Separate validation loop reduces register pressure in the main loop below.
|
||||
// No early exit (break) on invalid input as benchmarking shows it degrades performance.
|
||||
bool has_invalid_index = false;
|
||||
for (int64_t emb = begin; emb < end; emb++) {
|
||||
index_t input_idx = input[emb];
|
||||
has_invalid_index = has_invalid_index || (input_idx < 0 || input_idx >= numRows);
|
||||
}
|
||||
CUDA_KERNEL_ASSERT(!has_invalid_index && "Invalid input index in EmbeddingBag: index out of range [0, numRows)");
|
||||
|
||||
for (int64_t emb = begin; emb < end; emb++) {
|
||||
bool pad = (input[emb] == padding_idx);
|
||||
CUDA_KERNEL_ASSERT(input[emb] < numRows);
|
||||
const int64_t weightRow = input[emb] * weight_stride0;
|
||||
scalar_t weightValue = weightFeat[weightRow];
|
||||
if (bag_size_ == 0 || weightValue > weightFeatMax) {
|
||||
@ -129,10 +138,19 @@ __global__ void EmbeddingBag_updateOutputKernel_sum_mean(
|
||||
CUDA_KERNEL_ASSERT(end >= begin);
|
||||
accscalar_t weightFeatSum = 0;
|
||||
int64_t bag_size_ = 0;
|
||||
|
||||
// Separate validation loop reduces register pressure in the main loop below.
|
||||
// No early exit (break) on invalid input as benchmarking shows it degrades performance.
|
||||
bool has_invalid_index = false;
|
||||
for (int64_t emb = begin; emb < end; emb++) {
|
||||
index_t input_idx = input[emb];
|
||||
has_invalid_index = has_invalid_index || (input_idx < 0 || input_idx >= numRows);
|
||||
}
|
||||
CUDA_KERNEL_ASSERT(!has_invalid_index && "Invalid input index in EmbeddingBag: index out of range [0, numRows)");
|
||||
|
||||
for (int64_t emb = begin; emb < end; emb++) {
|
||||
index_t input_idx = input[emb];
|
||||
bool pad = (input_idx == padding_idx);
|
||||
CUDA_KERNEL_ASSERT(0 <= input_idx && input_idx < numRows);
|
||||
const int64_t weightRow = input_idx * weight_stride0;
|
||||
scalar_t weightValue = weightFeat[weightRow];
|
||||
weightValue = pad ? static_cast<scalar_t>(0) : weightValue;
|
||||
|
||||
@ -78,9 +78,9 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const SwizzleType& swizzle_a,
|
||||
const SwizzleType swizzle_a,
|
||||
const Tensor& scale_b,
|
||||
const SwizzleType& swizzle_b,
|
||||
const SwizzleType swizzle_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
Tensor& out) {
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
|
||||
@ -5,69 +5,11 @@
#include <cuda_bf16.h>
#endif

// ROCm 6.3 is planned to have these functions, but until then here they are.
#if defined(USE_ROCM)
#include <device_functions.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>

__device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) {
#if (defined(__gfx942__)) && \
__has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2bf16)
typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2;
static_assert(sizeof(vec_short2) == sizeof(__hip_bfloat162_raw));
union {
__hip_bfloat162_raw bf162_raw;
vec_short2 vs2;
} u{static_cast<__hip_bfloat162_raw>(value)};
u.vs2 = __builtin_amdgcn_flat_atomic_fadd_v2bf16((vec_short2*)address, u.vs2);
return static_cast<__hip_bfloat162>(u.bf162_raw);
#else
static_assert(sizeof(unsigned int) == sizeof(__hip_bfloat162_raw));
union u_hold {
__hip_bfloat162_raw h2r;
unsigned int u32;
};
u_hold old_val, new_val;
old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
do {
new_val.h2r = __hadd2(old_val.h2r, value);
} while (!__hip_atomic_compare_exchange_strong(
(unsigned int*)address, &old_val.u32, new_val.u32,
__ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT));
return old_val.h2r;
#endif
}

__device__ inline __half2 preview_unsafeAtomicAdd(__half2* address, __half2 value) {
#if (defined(__gfx942__)) && \
__has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2f16)
// The api expects an ext_vector_type of half
typedef _Float16 __attribute__((ext_vector_type(2))) vec_fp162;
static_assert(sizeof(vec_fp162) == sizeof(__half2_raw));
union {
__half2_raw h2r;
vec_fp162 fp16;
} u {static_cast<__half2_raw>(value)};
u.fp16 = __builtin_amdgcn_flat_atomic_fadd_v2f16((vec_fp162*)address, u.fp16);
return static_cast<__half2>(u.h2r);
#else
static_assert(sizeof(__half2_raw) == sizeof(unsigned int));
union u_hold {
__half2_raw h2r;
unsigned int u32;
};
u_hold old_val, new_val;
old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
do {
new_val.h2r = __hadd2(old_val.h2r, value);
} while (!__hip_atomic_compare_exchange_strong(
(unsigned int*)address, &old_val.u32, new_val.u32,
__ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT));
return old_val.h2r;
#endif
}
#define ATOMICADD preview_unsafeAtomicAdd
#define ATOMICADD unsafeAtomicAdd
#define NATIVE_ZERO_BF16 __float2bfloat16(0.0f)
#else
#define ATOMICADD atomicAdd
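The fallback path being removed here is a standard lock-free read-modify-write: load the 32-bit word that holds both 16-bit lanes, add in registers, then retry with compare-exchange until no other thread has interfered. A host-side C++ analogue of the same pattern, with plain uint16_t lanes standing in for the HIP half/bfloat16 types (illustrative only, not code from the diff):

#include <atomic>
#include <cstdint>
#include <cstring>

// Two packed 16-bit lanes updated through one 32-bit compare-exchange,
// the same shape as the __hip_atomic_compare_exchange_strong fallback above.
struct Packed16x2 { uint16_t lo, hi; };

Packed16x2 atomic_add_packed(std::atomic<uint32_t>* address, Packed16x2 value) {
  uint32_t old_bits = address->load(std::memory_order_relaxed);
  uint32_t new_bits;
  Packed16x2 old_val;
  do {
    std::memcpy(&old_val, &old_bits, sizeof(old_val));
    Packed16x2 new_val{static_cast<uint16_t>(old_val.lo + value.lo),
                       static_cast<uint16_t>(old_val.hi + value.hi)};
    std::memcpy(&new_bits, &new_val, sizeof(new_bits));
  } while (!address->compare_exchange_strong(old_bits, new_bits,
                                             std::memory_order_relaxed));
  return old_val;  // value observed just before the successful update
}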
@ -11,7 +11,7 @@ static inline std::ostream& operator<<(std::ostream& out, dim3 dim) {
if (dim.y == 1 && dim.z == 1) {
out << dim.x;
} else {
out << "[" << dim.x << "," << dim.y << "," << dim.z << "]";
out << '[' << dim.x << ',' << dim.y << ',' << dim.z << ']';
}
return out;
}
@ -27,7 +27,7 @@ std::ostream& operator<<(std::ostream& out, const ReduceConfig& config) {
out << "input_mult=[";
for (int i = 0; i < 3; i++) {
if (i != 0) {
out << ",";
out << ',';
}
out << config.input_mult[i];
}
@ -35,7 +35,7 @@ std::ostream& operator<<(std::ostream& out, const ReduceConfig& config) {
out << "output_mult=[";
for (int i = 0; i < 2; i++) {
if (i != 0) {
out << ",";
out << ',';
}
out << config.output_mult[i];
}
@ -49,7 +49,7 @@ std::ostream& operator<<(std::ostream& out, const ReduceConfig& config) {
out << "block=" << config.block() << ", ";
out << "grid=" << config.grid() << ", ";
out << "global_memory_size=" << config.global_memory_size();
out << ")";
out << ')';
return out;
}

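Most hunks in this change replace one-character string literals with character literals when streaming. The output is identical; the char overload of operator<< simply skips the null-terminated-string path. A small stand-alone example (not taken from the diff):

#include <iostream>
#include <sstream>

int main() {
  std::ostringstream a, b;
  a << "[" << 1 << "," << 2 << "]";   // const char* overload
  b << '[' << 1 << ',' << 2 << ']';   // char overload, same output
  std::cout << (a.str() == b.str()) << '\n';  // prints 1
  return 0;
}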
@ -740,7 +740,12 @@ _scaled_rowwise_rowwise(
TORCH_CHECK_VALUE(scale_a.numel() == mat_a.size(0) && scale_a.scalar_type() == kFloat, "scale_a must have ", mat_a.size(0), " Float elements, got ", scale_a.numel())
TORCH_CHECK_VALUE(scale_b.numel() == mat_b.size(1) && scale_b.scalar_type() == kFloat, "scale_b must have ", mat_b.size(1), " Float elements, got ", scale_b.numel())

TORCH_CHECK_VALUE(scale_a.stride(1) == 1, "expected scale_a.stride(1) to be 1, but got ", scale_a.stride(1));
// if we have a scale of shape [256, 1] (say), then stride can be [1, 0] - handle this case
TORCH_CHECK_VALUE(
scale_a.stride(1) == 1 ||
scale_a.size(1) == 1,
"expected scale_a.stride(1) to be 1, but got ", scale_a.stride(1)
);
TORCH_CHECK_VALUE(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1));

auto scaling_choice_a = ScalingType::RowWise;

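The relaxed check above exists because a dimension of size 1 can carry any stride (PyTorch frequently reports 0 for it), so requiring stride(1) == 1 would reject a perfectly usable [N, 1] scale tensor. A one-line helper capturing the rule, shown only for illustration (the function name is invented):

#include <cstdint>

// A dimension of extent 1 is never actually strided over, so its recorded
// stride may legally be 0 (or anything else); treat it as contiguous.
bool inner_dim_is_contiguous(int64_t stride, int64_t size) {
  return stride == 1 || size == 1;
}
// e.g. a [256, 1] row-wise scale stored with strides {1, 0} passes the check.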
@ -364,9 +364,9 @@ void f8f8bf16_grouped_gemm_impl_sm90(
|
||||
// reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(
|
||||
// stride_output_h + group_count);
|
||||
|
||||
// std::cout << "PTRS " << mat_a.data_ptr() << " " << mat_b.data_ptr() << "
|
||||
// std::cout << "PTRS " << mat_a.data_ptr() << ' ' << mat_b.data_ptr() << "
|
||||
// "
|
||||
// << out.data_ptr() << " " << scale_a.data_ptr() << " "
|
||||
// << out.data_ptr() << ' ' << scale_a.data_ptr() << ' '
|
||||
// << scale_b.data_ptr() << "\n";
|
||||
// for (int i = 0; i < group_count; i++) {
|
||||
// std::cout << "A " << (void*)inputA_ptrs_h[i] << "\n";
|
||||
|
||||
@ -1057,14 +1057,14 @@ std::string generate_code(
|
||||
// TODO these arrays are potentially of the different types, use function
|
||||
// traits to determine the types
|
||||
declare_load_arrays << f_inputs_type << " arg" << std::to_string(i)
|
||||
<< "[" << std::to_string(thread_work_size) << "];\n";
|
||||
<< '[' << std::to_string(thread_work_size) << "];\n";
|
||||
}
|
||||
env.s("declare_load_arrays", declare_load_arrays.str());
|
||||
|
||||
std::stringstream declare_store_arrays;
|
||||
for (int i = 0; i < nOutputs; i++) {
|
||||
declare_store_arrays << result_type << " out" << std::to_string(i)
|
||||
<< "[" << std::to_string(thread_work_size) << "];\n";
|
||||
<< '[' << std::to_string(thread_work_size) << "];\n";
|
||||
}
|
||||
env.s("declare_store_arrays", declare_store_arrays.str());
|
||||
|
||||
@ -1217,7 +1217,7 @@ std::string generate_code(
|
||||
for (const auto i : c10::irange(nInputs)){
|
||||
auto i_string = std::to_string(i);
|
||||
vector_inputs << "auto * input" << i_string <<
|
||||
" = reinterpret_cast<const scalar_t*>(data[" << i_string << "+" << nOutputs << "])" <<
|
||||
" = reinterpret_cast<const scalar_t*>(data[" << i_string << '+' << nOutputs << "])" <<
|
||||
" + block_work_size * idx;\n";
|
||||
}
|
||||
env.s("vector_inputs", vector_inputs.str());
|
||||
@ -1543,17 +1543,17 @@ NvrtcFunction jit_pwise_function(
|
||||
|
||||
// Constructs file path by appending constructed cubin name to cache path
|
||||
std::stringstream ss;
|
||||
ss << *cache_dir << "/";
|
||||
ss << *cache_dir << '/';
|
||||
ss << kernel_name;
|
||||
#ifdef USE_ROCM
|
||||
ss << "_arch" << prop->gcnArchName;
|
||||
#else
|
||||
ss << "_arch" << cuda_major << "." << cuda_minor;
|
||||
ss << "_arch" << cuda_major << '.' << cuda_minor;
|
||||
#endif
|
||||
ss << "_nvrtc" << nvrtc_major << "." << nvrtc_minor;
|
||||
ss << "_nvrtc" << nvrtc_major << '.' << nvrtc_minor;
|
||||
ss << (compile_to_sass ? "_sass" : "_ptx");
|
||||
ss << "_" << code.length();
|
||||
ss << "_" << hash_code;
|
||||
ss << '_' << code.length();
|
||||
ss << '_' << hash_code;
|
||||
file_path = ss.str();
|
||||
|
||||
std::ifstream readin{file_path, std::ios::in | std::ifstream::binary};
|
||||
|
||||
@ -82,15 +82,15 @@ namespace native {
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const ConvolutionParams& params) {
|
||||
out << "ConvolutionParams \n"
|
||||
<< " memory_format = " << params.memory_format << "\n"
|
||||
<< " data_type = " << cudnnTypeToString(params.dataType) << "\n"
|
||||
<< " padding = " << ArrayRef<int>{params.padding} << "\n"
|
||||
<< " stride = " << ArrayRef<int>{params.stride} << "\n"
|
||||
<< " dilation = " << ArrayRef<int>{params.dilation} << "\n"
|
||||
<< " groups = " << params.groups << "\n"
|
||||
<< " memory_format = " << params.memory_format << '\n'
|
||||
<< " data_type = " << cudnnTypeToString(params.dataType) << '\n'
|
||||
<< " padding = " << ArrayRef<int>{params.padding} << '\n'
|
||||
<< " stride = " << ArrayRef<int>{params.stride} << '\n'
|
||||
<< " dilation = " << ArrayRef<int>{params.dilation} << '\n'
|
||||
<< " groups = " << params.groups << '\n'
|
||||
<< " deterministic = " << (params.deterministic ? "true" : "false")
|
||||
<< "\n"
|
||||
<< " allow_tf32 = " << (params.allow_tf32 ? "true" : "false") << "\n";
|
||||
<< '\n'
|
||||
<< " allow_tf32 = " << (params.allow_tf32 ? "true" : "false") << '\n';
|
||||
|
||||
return out;
|
||||
}
|
||||
@ -173,16 +173,16 @@ std::string repro_from_args(const ConvolutionParams& params) {
|
||||
at::globalContext().float32Precision(
|
||||
at::Float32Backend::CUDA, at::Float32Op::MATMUL) ==
|
||||
at::Float32Precision::TF32)
|
||||
<< "\n";
|
||||
<< '\n';
|
||||
ss << "torch.backends.cudnn.benchmark = "
|
||||
<< pybool(at::globalContext().benchmarkCuDNN()) << "\n";
|
||||
<< pybool(at::globalContext().benchmarkCuDNN()) << '\n';
|
||||
ss << "torch.backends.cudnn.deterministic = " << pybool(params.deterministic)
|
||||
<< "\n";
|
||||
<< '\n';
|
||||
ss << "torch.backends.cudnn.allow_tf32 = " << pybool(params.allow_tf32)
|
||||
<< "\n";
|
||||
<< '\n';
|
||||
ss << "data = torch.randn(" << ArrayRef<int>(params.input_size, dim)
|
||||
<< ", dtype=" << full_dtype << ", ";
|
||||
ss << "device='cuda', requires_grad=True)" << to_channels_last << "\n";
|
||||
ss << "device='cuda', requires_grad=True)" << to_channels_last << '\n';
|
||||
ss << "net = torch.nn.Conv" << dim - 2 << "d(" << in_channels << ", "
|
||||
<< out_channels << ", ";
|
||||
ss << "kernel_size=" << ArrayRef<int>(¶ms.weight_size[2], dim - 2)
|
||||
@ -192,7 +192,7 @@ std::string repro_from_args(const ConvolutionParams& params) {
|
||||
ss << "dilation=" << ArrayRef<int>(params.dilation, dim - 2) << ", ";
|
||||
ss << "groups=" << params.groups << ")\n";
|
||||
ss << "net = net.cuda()." << partial_dtype << "()" << to_channels_last
|
||||
<< "\n";
|
||||
<< '\n';
|
||||
ss << "out = net(data)\n";
|
||||
ss << "out.backward(torch.randn_like(out))\n";
|
||||
ss << "torch.cuda.synchronize()\n\n";
|
||||
|
||||
@ -93,11 +93,10 @@ std::ostream& operator<<(std::ostream& out, const ConvolutionArgs& args) {
|
||||
<< "input: " << args.idesc // already has a trailing newline
|
||||
<< "output: " << args.odesc // already has a trailing newline
|
||||
<< "weight: " << args.wdesc // already has a trailing newline
|
||||
<< "Pointer addresses: "
|
||||
<< "\n"
|
||||
<< " input: " << args.input.const_data_ptr() << "\n"
|
||||
<< " output: " << args.output.const_data_ptr() << "\n"
|
||||
<< " weight: " << args.weight.const_data_ptr() << "\n";
|
||||
<< "Pointer addresses: " << '\n'
|
||||
<< " input: " << args.input.const_data_ptr() << '\n'
|
||||
<< " output: " << args.output.const_data_ptr() << '\n'
|
||||
<< " weight: " << args.weight.const_data_ptr() << '\n';
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -21,18 +21,27 @@ void kai_pack_int4_rhs(
|
||||
const int64_t n,
|
||||
const int64_t k,
|
||||
const int64_t bl) {
|
||||
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
|
||||
if (bl == k) {
|
||||
// Channelwise
|
||||
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
|
||||
auto& params = kernel_packet.rhs_pack_params;
|
||||
params.lhs_zero_point = 1;
|
||||
params.rhs_zero_point = 8;
|
||||
|
||||
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
|
||||
kernel_packet, weight_packed, weight, scales, bias, n, k);
|
||||
if (weight.scalar_type() == at::kBFloat16) {
|
||||
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod);
|
||||
auto& params = kernel_packet.rhs_pack_params;
|
||||
params.lhs_zero_point = 1;
|
||||
params.rhs_zero_point = 8;
|
||||
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_bf16_qa8dxp_qs4cxp>(
|
||||
kernel_packet, weight_packed, weight, scales, bias, n, k);
|
||||
} else {
|
||||
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
|
||||
auto& params = kernel_packet.rhs_pack_params;
|
||||
params.lhs_zero_point = 1;
|
||||
params.rhs_zero_point = 8;
|
||||
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
|
||||
kernel_packet, weight_packed, weight, scales, bias, n, k);
|
||||
}
|
||||
} else if (!(bl % 32) && !(k % bl)) {
|
||||
// Groupwise
|
||||
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
|
||||
@ -63,19 +72,29 @@ void kai_pack_int4_rhs(
|
||||
size_t kai_pack_rhs_int4_size(
|
||||
const int64_t n,
|
||||
const int64_t k,
|
||||
const int64_t bl) {
|
||||
const int64_t bl,
|
||||
at::ScalarType tensor_dtype) {
|
||||
size_t packed_size = n * k;
|
||||
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
|
||||
if (bl == k) {
|
||||
// Channelwise
|
||||
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
|
||||
const auto& ukernel = kernel_packet.ukernel;
|
||||
const size_t nr = ukernel.get_nr();
|
||||
const size_t kr = ukernel.get_kr();
|
||||
const size_t sr = ukernel.get_sr();
|
||||
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
|
||||
if (tensor_dtype == at::kBFloat16) {
|
||||
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod);
|
||||
const auto& ukernel = kernel_packet.ukernel;
|
||||
const size_t nr = ukernel.get_nr();
|
||||
const size_t kr = ukernel.get_kr();
|
||||
const size_t sr = ukernel.get_sr();
|
||||
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
|
||||
} else {
|
||||
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
|
||||
const auto& ukernel = kernel_packet.ukernel;
|
||||
const size_t nr = ukernel.get_nr();
|
||||
const size_t kr = ukernel.get_kr();
|
||||
const size_t sr = ukernel.get_sr();
|
||||
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
|
||||
}
|
||||
} else if (!(bl % 32) && !(k % bl)) {
|
||||
// Groupwise
|
||||
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
|
||||
@ -148,8 +167,7 @@ static void kai_quant_pack_lhs_int4_mm_groupwise(
|
||||
const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride;
|
||||
const int64_t m_idx = thread_id * vec_per_thread;
|
||||
auto lhs_packed_ptr = lhs_packed_base +
|
||||
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
|
||||
m_idx, k, mr, kr, sr);
|
||||
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
|
||||
const int64_t vec_num = (thread_id == num_threads - 1)
|
||||
? (m - vec_per_thread * thread_id)
|
||||
: vec_per_thread;
|
||||
@ -259,8 +277,7 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
|
||||
const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride;
|
||||
const int64_t m_idx = thread_id * vec_per_thread;
|
||||
auto lhs_packed_ptr = lhs_packed_base +
|
||||
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
|
||||
m_idx, k, mr, kr, sr);
|
||||
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
|
||||
const int64_t vec_num = (thread_id == num_threads - 1)
|
||||
? (m - vec_per_thread * thread_id)
|
||||
: vec_per_thread;
|
||||
@ -320,19 +337,144 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
|
||||
});
|
||||
}
|
||||
|
||||
void kai_quant_pack_lhs_int4_mm(
|
||||
static void kai_quant_pack_lhs_int4_mm_bf16_channelwise(
|
||||
const Tensor& output,
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
const int64_t m,
|
||||
const int64_t n,
|
||||
const int64_t k) {
|
||||
// Kernel IDs for GEMM and GEMV
|
||||
constexpr kai_kernel_id gemm_id =
|
||||
kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm;
|
||||
constexpr kai_kernel_id gemv_id =
|
||||
kai_kernel_id::matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod;
|
||||
|
||||
// Get total threads and select kernel
|
||||
const int64_t total_threads = at::get_num_threads();
|
||||
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemv_id);
|
||||
if (cpuinfo_has_arm_i8mm() && m > 1) {
|
||||
kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemm_id);
|
||||
}
|
||||
|
||||
// Thread blocking parameters
|
||||
const int64_t n_step = kernel_packet.ukernel.get_n_step();
|
||||
const size_t mr = kernel_packet.ukernel.get_mr();
|
||||
const size_t kr = kernel_packet.ukernel.get_kr();
|
||||
const size_t sr = kernel_packet.ukernel.get_sr();
|
||||
|
||||
const size_t lhs_packed_size =
|
||||
kernel_packet.kai_get_lhs_packed_size(m, k, mr, kr, sr);
|
||||
auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size);
|
||||
uint8_t* dst_act_mtx_bf16 = reinterpret_cast<uint8_t*>(output.data_ptr());
|
||||
const uint8_t* lhs_native_mtx_bf16 =
|
||||
reinterpret_cast<const uint8_t*>(input.data_ptr());
|
||||
const uint8_t* rhs_packed_mtx_qs4cx =
|
||||
reinterpret_cast<const uint8_t*>(weight.data_ptr());
|
||||
uint8_t* lhs_packed_base = lhs_packed.get();
|
||||
|
||||
constexpr int32_t element_size = sizeof(uint16_t);
|
||||
const size_t lhs_stride = k * element_size;
|
||||
const size_t dst_stride = n * element_size;
|
||||
|
||||
// LHS quantization packing
|
||||
int64_t vec_per_thread = get_vec_per_thread(m, total_threads, mr);
|
||||
int64_t num_threads = (m + vec_per_thread - 1) / vec_per_thread;
|
||||
const size_t src_stride = vec_per_thread * lhs_stride;
|
||||
|
||||
auto lhs_quant_pack = [=, &kernel_packet](int64_t thread_id) {
|
||||
const auto lhs_src_ptr = lhs_native_mtx_bf16 + thread_id * src_stride;
|
||||
const int64_t m_idx = thread_id * vec_per_thread;
|
||||
auto lhs_packed_ptr = lhs_packed_base +
|
||||
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
|
||||
const int64_t vec_num = (thread_id == num_threads - 1)
|
||||
? (m - vec_per_thread * thread_id)
|
||||
: vec_per_thread;
|
||||
|
||||
kernel_packet.kai_run_lhs_quant_pack(
|
||||
vec_num,
|
||||
k,
|
||||
mr,
|
||||
kr,
|
||||
sr,
|
||||
0,
|
||||
(const uint16_t*)lhs_src_ptr,
|
||||
lhs_stride,
|
||||
lhs_packed_ptr);
|
||||
};
|
||||
|
||||
at::parallel_for(
|
||||
0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t thread_id = begin; thread_id < end; ++thread_id) {
|
||||
lhs_quant_pack(thread_id);
|
||||
}
|
||||
});
|
||||
|
||||
// Matrix multiplication
|
||||
vec_per_thread = get_vec_per_thread(n, total_threads, n_step);
|
||||
num_threads = (n + vec_per_thread - 1) / vec_per_thread;
|
||||
|
||||
auto mm = [=, &kernel_packet](int64_t thread_id) {
|
||||
const auto rhs_packed_ptr = rhs_packed_mtx_qs4cx +
|
||||
kernel_packet.ukernel.get_rhs_packed_offset(
|
||||
thread_id * vec_per_thread, k);
|
||||
auto dst_ptr = dst_act_mtx_bf16 +
|
||||
kernel_packet.ukernel.get_dst_offset(
|
||||
0, thread_id * vec_per_thread, dst_stride);
|
||||
const int64_t vec_num = (thread_id == num_threads - 1)
|
||||
? (n - vec_per_thread * thread_id)
|
||||
: vec_per_thread;
|
||||
|
||||
kernel_packet.ukernel.run_matmul(
|
||||
m,
|
||||
vec_num,
|
||||
k,
|
||||
lhs_packed_base,
|
||||
rhs_packed_ptr,
|
||||
(uint16_t*)dst_ptr,
|
||||
dst_stride,
|
||||
element_size, // dst_stride_col
|
||||
-FLT_MAX,
|
||||
FLT_MAX);
|
||||
};
|
||||
|
||||
at::parallel_for(
|
||||
0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t thread_id = begin; thread_id < end; ++thread_id) {
|
||||
mm(thread_id);
|
||||
}
|
||||
});
|
||||
}
|
||||
void kai_quant_pack_lhs_int4_mm(
const at::Tensor& output,
const at::Tensor& input,
const at::Tensor& weight,
const int64_t m,
const int64_t n,
const int64_t k,
const int64_t bl) {
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
output, input, weight, m, n, k);
} else if (!(bl % 32) && !(k % bl)) {
const auto input_dtype = input.dtype();

if (input_dtype == at::kBFloat16) {
if (cpuinfo_has_arm_bf16()) {
kleidiai::kai_quant_pack_lhs_int4_mm_bf16_channelwise(
output, input, weight, m, n, k);
} else {
TORCH_CHECK(
false,
"BF16 Unsupported: CPU does not support BF16. Please use a CPU with BF16 support.");
}
} else if (input_dtype == at::kFloat) {
kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
output, input, weight, m, n, k);
} else {
TORCH_CHECK(
false,
"Unsupported input data type: Only Bfloat16 and Float inputs are supported.");
}
} else if ((bl % 32 == 0) && (k % bl == 0)) {
kleidiai::kai_quant_pack_lhs_int4_mm_groupwise(
output, input, weight, m, n, k, bl);
}

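Because the +/- markers of the diff were lost in extraction, the exact nesting above is ambiguous, but the dispatch rule it implements appears to be: a block length equal to k selects the channelwise kernels (with the activation dtype choosing the F32 or BF16 family), while a block length that is a multiple of 32 and divides k selects the groupwise kernels. A hedged sketch of that rule with placeholder names:

#include <stdexcept>

enum class KernelFamily { ChannelwiseF32, ChannelwiseBF16, Groupwise };

// Illustrative restatement of the dispatch rule; one plausible reading of
// the hunk above, not a verbatim copy of it.
KernelFamily pick_kernel(long k, long bl, bool input_is_bf16) {
  if (bl == k) {
    return input_is_bf16 ? KernelFamily::ChannelwiseBF16
                         : KernelFamily::ChannelwiseF32;
  }
  if (bl % 32 == 0 && k % bl == 0) {
    return KernelFamily::Groupwise;
  }
  throw std::invalid_argument("unsupported block length");
}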
@ -25,7 +25,8 @@ void kai_pack_int4_rhs(
|
||||
size_t kai_pack_rhs_int4_size(
|
||||
const int64_t n,
|
||||
const int64_t k,
|
||||
const int64_t bl);
|
||||
const int64_t bl,
|
||||
at::ScalarType tensor_dtype = at::kFloat);
|
||||
|
||||
/**
|
||||
* @brief Run 2 operations ( Input quantize and pack -> 4 bit Matmul )
|
||||
|
||||
@ -36,7 +36,8 @@ void kai_pack_rhs_groupwise_int4(
|
||||
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
|
||||
}
|
||||
|
||||
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
|
||||
float* bias_ptr =
|
||||
bias.has_value() ? bias.value().to(kFloat).data_ptr<float>() : NULL;
|
||||
auto& params = kernel.rhs_pack_params;
|
||||
|
||||
kernel.kai_run_rhs_pack(
|
||||
@ -73,7 +74,8 @@ void kai_pack_rhs_channelwise_int4(
|
||||
auto weight_packed_data =
|
||||
reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
|
||||
const auto weight_data = weight.data_ptr<uint8_t>();
|
||||
const auto scales_data = scales.data_ptr<float>();
|
||||
|
||||
const auto scales_data = scales.to(kFloat).data_ptr<float>();
|
||||
|
||||
if (weight_data == nullptr) {
|
||||
AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null");
|
||||
@ -83,7 +85,8 @@ void kai_pack_rhs_channelwise_int4(
|
||||
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
|
||||
}
|
||||
|
||||
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
|
||||
float* bias_ptr =
|
||||
bias.has_value() ? bias.value().to(kFloat).data_ptr<float>() : NULL;
|
||||
auto& params = kernel.rhs_pack_params;
|
||||
|
||||
kernel.kai_run_rhs_pack(
|
||||
|
||||
@ -68,5 +68,39 @@ kai_matmul_ukernel_f32_qa8dxp_qs4cxp kai_select_channelwise_matmul_ukernel(
|
||||
const kai_kernel_id id) {
|
||||
return channelwise_8bit_4bit_kernels.at(id);
|
||||
}
|
||||
|
||||
// Kernel Mapping - BF16 Channelwise
|
||||
std::unordered_map<kai_kernel_id, kai_matmul_ukernel_bf16_qa8dxp_qs4cxp>
|
||||
bf16_channelwise_8bit_4bit_kernels = {
|
||||
{kai_kernel_id::
|
||||
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
{{kai_get_m_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_n_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_mr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_nr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_kr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_sr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_dst_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_dst_size_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_run_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod}}},
|
||||
{kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
{{kai_get_m_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_n_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_mr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_nr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_kr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_sr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_dst_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_dst_size_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_run_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm}}}};
|
||||
|
||||
kai_matmul_ukernel_bf16_qa8dxp_qs4cxp kai_select_bf16_channelwise_matmul_ukernel(
|
||||
const kai_kernel_id id) {
|
||||
return bf16_channelwise_8bit_4bit_kernels.at(id);
|
||||
}
|
||||
} // namespace at::native::kleidiai
|
||||
#endif
|
||||
|
||||
@ -10,21 +10,32 @@
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h>
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.h>
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.h>
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_interface.h>
#include <kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h>
#include <kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h>
#include <kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h>
#include <kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h>

namespace at::native::kleidiai {

enum class kai_kernel_id {
// FP32 inputs, 4-bit weights, FP32 output
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod =
0, // Groupwise 4 bit GEMV
0, // Groupwise 4-bit GEMV (per-group scales, NEON DOTPROD)
matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm =
1, // Groupwise 4 bit GEMM
1, // Groupwise 4-bit GEMM (per-group scales, NEON I8MM)
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod =
2, // Channelwise 4 bit GEMV
2, // Channelwise 4-bit GEMV (per-channel scales, NEON DOTPROD)
matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm =
3 // Channelwise 4 bit GEMM
3, // Channelwise 4-bit GEMM (per-channel scales, NEON I8MM)

// BF16 inputs, 4-bit weights, BF16 output
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod =
4, // Channelwise 4-bit GEMV with BF16 input/output
matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm =
5 // Channelwise 4-bit GEMM with BF16 input/output
};

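The new BF16 entries mirror the existing FP32 channelwise pair: a DOTPROD GEMV kernel and an I8MM GEMM kernel. Elsewhere in this change the GEMM variant is chosen only when the CPU reports I8MM and m > 1. A small illustrative selector of that pattern (the names below are placeholders, not from the diff):

// Illustrative selection between the GEMV and GEMM entries of one family.
enum class KernelId { GemvDotprod, GemmI8mm };

KernelId select_kernel(long m, bool cpu_has_i8mm) {
  // A single-row activation gives the GEMM tile nothing to batch over,
  // so fall back to the dot-product GEMV path.
  if (cpu_has_i8mm && m > 1) {
    return KernelId::GemmI8mm;
  }
  return KernelId::GemvDotprod;
}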
// Channelwise Kernel mapping
|
||||
@ -66,6 +77,9 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
|
||||
void* rhs_packed,
|
||||
size_t extra_bytes,
|
||||
const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
|
||||
size_t(*kai_get_lhs_quant_pack_offset)(
|
||||
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
|
||||
);
|
||||
|
||||
kai_matmul_ukernel_f32_qa8dxp_qs4cxp(
|
||||
const kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel& kernel)
|
||||
@ -75,12 +89,71 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
|
||||
kai_get_rhs_packed_size(
|
||||
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
|
||||
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
|
||||
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0) {}
|
||||
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
|
||||
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32){}
|
||||
};
|
||||
|
||||
struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp
|
||||
kai_select_channelwise_matmul_ukernel(const kai_kernel_id id);
|
||||
|
||||
// bf16 Channelwise Kernel mapping
|
||||
struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp {
|
||||
struct kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel ukernel;
|
||||
struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params rhs_pack_params;
|
||||
size_t (*kai_get_lhs_packed_size)(
|
||||
size_t m,
|
||||
size_t k,
|
||||
size_t mr,
|
||||
size_t kr,
|
||||
size_t sr);
|
||||
size_t (*kai_get_rhs_packed_size)(
|
||||
size_t n,
|
||||
size_t k,
|
||||
size_t nr,
|
||||
size_t kr,
|
||||
size_t sr);
|
||||
void (*kai_run_lhs_quant_pack)(
|
||||
size_t m,
|
||||
size_t k,
|
||||
size_t mr,
|
||||
size_t kr,
|
||||
size_t sr,
|
||||
size_t m_idx_start,
|
||||
const void* lhs,
|
||||
size_t lhs_stride,
|
||||
void* lhs_packed);
|
||||
void (*kai_run_rhs_pack)(
|
||||
size_t num_groups,
|
||||
size_t n,
|
||||
size_t k,
|
||||
size_t nr,
|
||||
size_t kr,
|
||||
size_t sr,
|
||||
const uint8_t* rhs,
|
||||
const float* bias,
|
||||
const float* scale,
|
||||
void* rhs_packed,
|
||||
size_t extra_bytes,
|
||||
const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
|
||||
size_t(*kai_get_lhs_quant_pack_offset)(
|
||||
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
|
||||
);
|
||||
|
||||
kai_matmul_ukernel_bf16_qa8dxp_qs4cxp(
|
||||
const kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel& kernel)
|
||||
: ukernel(kernel),
|
||||
kai_get_lhs_packed_size(
|
||||
&kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon),
|
||||
kai_get_rhs_packed_size(
|
||||
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
|
||||
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_bf16_neon),
|
||||
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
|
||||
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon){}
|
||||
};
|
||||
|
||||
struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp
|
||||
kai_select_bf16_channelwise_matmul_ukernel(const kai_kernel_id id);
|
||||
|
||||
// Groupwise Kernel mapping
|
||||
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
|
||||
struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel;
|
||||
@ -125,6 +198,9 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
|
||||
void* rhs_packed,
|
||||
size_t extra_bytes,
|
||||
const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params);
|
||||
size_t(*kai_get_lhs_quant_pack_offset)(
|
||||
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
|
||||
);
|
||||
|
||||
kai_matmul_ukernel_f32_qa8dxp_qs4c32p(
|
||||
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& kernel)
|
||||
@ -134,7 +210,8 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
|
||||
kai_get_rhs_packed_size(
|
||||
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
|
||||
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
|
||||
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0) {}
|
||||
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
|
||||
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32) {}
|
||||
};
|
||||
|
||||
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel(
|
||||
|
||||
@ -115,7 +115,7 @@ std::ostream& operator<<(
|
||||
std::copy(
|
||||
strides.begin(), strides.end() - 1, std::ostream_iterator<int>(oss, ","));
|
||||
oss << sizes.back();
|
||||
output << oss.str() << "}";
|
||||
output << oss.str() << '}';
|
||||
return output;
|
||||
}
|
||||
|
||||
|
||||
@ -53,7 +53,7 @@ std::ostream& operator<<(std::ostream& out, const ConvParams& params) {
|
||||
<< " transposed = " << params.transposed
|
||||
<< " output_padding = " << IntArrayRef{params.output_padding}
|
||||
<< " groups = " << params.groups << " benchmark = " << params.benchmark
|
||||
<< " deterministic = " << params.deterministic << "}";
|
||||
<< " deterministic = " << params.deterministic << '}';
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
@ -82,6 +82,7 @@ NSArray<NSNumber*>* getTensorAxes(const TensorBase& t);
|
||||
NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim);
|
||||
std::string getMPSShapeString(MPSShape* shape);
|
||||
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = false);
|
||||
std::string to_hex_key(float);
|
||||
std::string getArrayRefString(const IntArrayRef s);
|
||||
// use has_storage() on the returned tensor to determine if src actually is a view
|
||||
Tensor gatherViewTensor(const Tensor& src, Tensor& dst);
|
||||
|
||||
@ -301,6 +301,10 @@ std::string getArrayRefString(const IntArrayRef s) {
|
||||
return fmt::to_string(fmt::join(s, ","));
|
||||
}
|
||||
|
||||
std::string to_hex_key(float f) {
|
||||
return fmt::format("{:a}", f);
|
||||
}
|
||||
|
||||
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype, bool exclude_shape) {
|
||||
fmt::basic_memory_buffer<char, 100> buffer;
|
||||
auto buf_iterator = std::back_inserter(buffer);
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <ATen/native/TensorCompare.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <algorithm>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
@ -89,13 +90,21 @@ static void check_min_max_dims(const OptionalTensorRef clamp_opt, const Tensor&
|
||||
auto clamp_shape = clamp_opt->sizes();
|
||||
auto input_shape = input_t.sizes();
|
||||
|
||||
TORCH_CHECK(num_clamp_dims <= num_input_dims,
|
||||
op_name + ": clamp tensor number of dims must not be greater than that of input tensor")
|
||||
if (num_clamp_dims > num_input_dims) {
|
||||
auto leading_dims = num_clamp_dims - num_input_dims;
|
||||
for (int64_t i = 0; i < leading_dims; ++i) {
|
||||
TORCH_CHECK(clamp_shape[i] == 1,
|
||||
op_name + ": clamp tensor leading shape must be 1 to broadcast with input tensor");
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_clamp_dims; i++)
|
||||
auto clamp_idx = num_clamp_dims - 1;
|
||||
auto input_idx = num_input_dims - 1;
|
||||
auto common_dims = std::min(num_clamp_dims, num_input_dims);
|
||||
for (int64_t i = 0; i < common_dims; ++i)
|
||||
// One of the indices is allowed to be 1; will be handled by broadcast
|
||||
TORCH_CHECK(clamp_shape[num_clamp_dims - 1 - i] == input_shape[num_input_dims - 1 - i] ||
|
||||
clamp_shape[num_clamp_dims - 1 - i] == 1 || input_shape[num_input_dims - 1 - i] == 1,
|
||||
TORCH_CHECK(clamp_shape[clamp_idx - i] == input_shape[input_idx - i] || clamp_shape[clamp_idx - i] == 1 ||
|
||||
input_shape[input_idx - i] == 1,
|
||||
op_name + ": clamp tensor trailing shape must match input tensor")
|
||||
}
|
||||
}
|
||||
@ -136,9 +145,6 @@ static void clamp_tensor_out_mps(const Tensor& input_t,
|
||||
|
||||
auto result_type = output_t.scalar_type();
|
||||
|
||||
IntArrayRef new_min_shape;
|
||||
IntArrayRef new_max_shape;
|
||||
|
||||
auto num_min_dims = min_opt->dim();
|
||||
auto num_max_dims = max_opt->dim();
|
||||
auto num_input_dims = input_t.dim();
|
||||
@ -146,24 +152,32 @@ static void clamp_tensor_out_mps(const Tensor& input_t,
|
||||
std::vector<int64_t> new_min_arr(num_input_dims);
|
||||
std::vector<int64_t> new_max_arr(num_input_dims);
|
||||
|
||||
if (has_min && num_min_dims < num_input_dims) {
|
||||
fill_new_shape(num_input_dims, num_min_dims, new_min_arr.data(), min_opt->sizes());
|
||||
new_min_shape = IntArrayRef(new_min_arr);
|
||||
}
|
||||
|
||||
if (has_max && num_max_dims < num_input_dims) {
|
||||
fill_new_shape(num_input_dims, num_max_dims, new_max_arr.data(), max_opt->sizes());
|
||||
new_max_shape = IntArrayRef(new_max_arr);
|
||||
}
|
||||
|
||||
Tensor min_opt_tensor;
|
||||
Tensor max_opt_tensor;
|
||||
|
||||
auto reshape_clamp_tensor = [&](const OptionalTensorRef clamp_tensor_ref,
|
||||
int64_t num_clamp_dims,
|
||||
std::vector<int64_t>& new_shape_storage) -> Tensor {
|
||||
IntArrayRef clamp_shape = clamp_tensor_ref->sizes();
|
||||
bool requires_view = false;
|
||||
|
||||
if (num_clamp_dims > num_input_dims) {
|
||||
clamp_shape = clamp_shape.slice(num_clamp_dims - num_input_dims);
|
||||
requires_view = true;
|
||||
} else if (num_clamp_dims < num_input_dims) {
|
||||
fill_new_shape(num_input_dims, num_clamp_dims, new_shape_storage.data(), clamp_shape);
|
||||
clamp_shape = IntArrayRef(new_shape_storage);
|
||||
requires_view = true;
|
||||
}
|
||||
|
||||
return requires_view ? (*clamp_tensor_ref).view(clamp_shape) : *clamp_tensor_ref;
|
||||
};
|
||||
|
||||
if (has_min) {
|
||||
min_opt_tensor = (num_min_dims < num_input_dims) ? (*min_opt).view(new_min_shape) : *min_opt;
|
||||
min_opt_tensor = reshape_clamp_tensor(min_opt, num_min_dims, new_min_arr);
|
||||
}
|
||||
if (has_max) {
|
||||
max_opt_tensor = (num_max_dims < num_input_dims) ? (*max_opt).view(new_max_shape) : *max_opt;
|
||||
max_opt_tensor = reshape_clamp_tensor(max_opt, num_max_dims, new_max_arr);
|
||||
}
|
||||
|
||||
@autoreleasepool {
|
||||
@ -244,8 +258,8 @@ static void clamp_scalar_out_mps(const Tensor& input_t,
|
||||
|
||||
@autoreleasepool {
|
||||
// the optional min/max refs could affect how we build the cached graph
|
||||
std::string key = op_name + (has_min ? ("_min:" + std::to_string(min_scalar)) : "") +
|
||||
(has_max ? ("_max:" + std::to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
|
||||
std::string key = op_name + (has_min ? ("_min:" + to_hex_key(min_scalar)) : "") +
|
||||
(has_max ? ("_max:" + to_hex_key(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
if (has_min)
|
||||
newCachedGraph->minTensor = [mpsGraph constantWithScalar:min_scalar
|
||||
|
||||
@ -301,12 +301,12 @@ class AvgPoolMicrokernelTester {
|
||||
ASSERT_NEAR(
|
||||
float(int32_t(y[i * yStride() + k])), yFP[i * kc() + k], 0.5001f)
|
||||
<< "at pixel " << i << ", channel " << k << ", n = " << n()
|
||||
<< ", ks = " << kh() << "x" << kw() << " (" << ks()
|
||||
<< ", ks = " << kh() << 'x' << kw() << " (" << ks()
|
||||
<< "), kc = " << kc() << ", acc = " << yAcc[i * kc() + k];
|
||||
ASSERT_EQ(
|
||||
uint32_t(yRef[i * kc() + k]), uint32_t(y[i * yStride() + k]))
|
||||
<< "at pixel " << i << ", channel " << k << ", n = " << n()
|
||||
<< ", ks = " << kh() << "x" << kw() << " (" << ks()
|
||||
<< ", ks = " << kh() << 'x' << kw() << " (" << ks()
|
||||
<< "), kc = " << kc() << ", acc = " << yAcc[i * kc() + k];
|
||||
}
|
||||
}
|
||||
@ -396,12 +396,12 @@ class AvgPoolMicrokernelTester {
|
||||
ASSERT_NEAR(
|
||||
float(int32_t(y[i * yStride() + k])), yFP[i * kc() + k], 0.5001f)
|
||||
<< "at pixel " << i << ", channel " << k << ", n = " << n()
|
||||
<< ", ks = " << kh() << "x" << kw() << " (" << ks()
|
||||
<< ", ks = " << kh() << 'x' << kw() << " (" << ks()
|
||||
<< "), kc = " << kc() << ", acc = " << yAcc[i * kc() + k];
|
||||
ASSERT_EQ(
|
||||
uint32_t(yRef[i * kc() + k]), uint32_t(y[i * yStride() + k]))
|
||||
<< "at pixel " << i << ", channel " << k << ", n = " << n()
|
||||
<< ", ks = " << kh() << "x" << kw() << " (" << ks()
|
||||
<< ", ks = " << kh() << 'x' << kw() << " (" << ks()
|
||||
<< "), kc = " << kc() << ", acc = " << yAcc[i * kc() + k];
|
||||
}
|
||||
}
|
||||
|
||||
@ -232,7 +232,7 @@ class MaxPoolMicrokernelTester {
|
||||
ASSERT_EQ(
|
||||
uint32_t(yRef[i * kc() + k]), uint32_t(y[i * yStride() + k]))
|
||||
<< "at pixel " << i << ", channel " << k << ", n = " << n()
|
||||
<< ", ks = " << kh() << "x" << kw() << " (" << ks()
|
||||
<< ", ks = " << kh() << 'x' << kw() << " (" << ks()
|
||||
<< "), kc = " << kc();
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,7 +17,7 @@ inline std::vector<T> _expand_param_if_needed(
|
||||
std::ostringstream ss;
|
||||
ss << "expected " << param_name << " to be a single integer value or a "
|
||||
<< "list of " << expected_dim << " values to match the convolution "
|
||||
<< "dimensions, but got " << param_name << "=" << list_param;
|
||||
<< "dimensions, but got " << param_name << '=' << list_param;
|
||||
TORCH_CHECK(false, ss.str());
|
||||
} else {
|
||||
return list_param.vec();
|
||||
|
||||
@ -358,9 +358,9 @@ std::string Adapter::stringize() const {
|
||||
std::string device_type = get_device_type_str(properties.deviceType);
|
||||
VkPhysicalDeviceLimits limits = properties.limits;
|
||||
|
||||
ss << "{" << std::endl;
|
||||
ss << '{' << std::endl;
|
||||
ss << " Physical Device Info {" << std::endl;
|
||||
ss << " apiVersion: " << v_major << "." << v_minor << std::endl;
|
||||
ss << " apiVersion: " << v_major << '.' << v_minor << std::endl;
|
||||
ss << " driverversion: " << properties.driverVersion << std::endl;
|
||||
ss << " deviceType: " << device_type << std::endl;
|
||||
ss << " deviceName: " << properties.deviceName << std::endl;
|
||||
@ -371,7 +371,7 @@ std::string Adapter::stringize() const {
|
||||
|
||||
#define PRINT_LIMIT_PROP_VEC3(name) \
|
||||
ss << " " << std::left << std::setw(36) << #name << limits.name[0] \
|
||||
<< "," << limits.name[1] << "," << limits.name[2] << std::endl;
|
||||
<< ',' << limits.name[1] << ',' << limits.name[2] << std::endl;
|
||||
|
||||
ss << " Physical Device Limits {" << std::endl;
|
||||
PRINT_LIMIT_PROP(maxImageDimension1D);
|
||||
@ -425,7 +425,7 @@ std::string Adapter::stringize() const {
|
||||
;
|
||||
}
|
||||
ss << " ]" << std::endl;
|
||||
ss << "}";
|
||||
ss << '}';
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
@ -33,7 +33,7 @@ std::ostream& operator<<(std::ostream& out, const VkResult result) {
|
||||
VK_RESULT_CASE(VK_ERROR_FORMAT_NOT_SUPPORTED)
|
||||
VK_RESULT_CASE(VK_ERROR_FRAGMENTED_POOL)
|
||||
default:
|
||||
out << "VK_ERROR_UNKNOWN (VkResult " << result << ")";
|
||||
out << "VK_ERROR_UNKNOWN (VkResult " << result << ')';
|
||||
break;
|
||||
}
|
||||
return out;
|
||||
@ -46,7 +46,7 @@ std::ostream& operator<<(std::ostream& out, const VkResult result) {
|
||||
//
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const SourceLocation& loc) {
|
||||
out << loc.function << " at " << loc.file << ":" << loc.line;
|
||||
out << loc.function << " at " << loc.file << ':' << loc.line;
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -66,7 +66,7 @@ Error::Error(SourceLocation source_location, const char* cond, std::string msg)
|
||||
: msg_(std::move(msg)), source_location_{source_location} {
|
||||
std::ostringstream oss;
|
||||
oss << "Exception raised from " << source_location_ << ": ";
|
||||
oss << "(" << cond << ") is false! ";
|
||||
oss << '(' << cond << ") is false! ";
|
||||
oss << msg_;
|
||||
what_ = oss.str();
|
||||
}
|
||||
|
||||
@ -173,8 +173,8 @@ void QueryPool::extract_results() {
|
||||
|
||||
static std::string stringize(const VkExtent3D& extents) {
|
||||
std::stringstream ss;
|
||||
ss << "{" << extents.width << ", " << extents.height << ", " << extents.depth
|
||||
<< "}";
|
||||
ss << '{' << extents.width << ", " << extents.height << ", " << extents.depth
|
||||
<< '}';
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
@ -149,7 +149,7 @@ VKAPI_ATTR VkBool32 VKAPI_CALL debug_report_callback_fn(
|
||||
(void)flags;
|
||||
|
||||
std::stringstream stream;
|
||||
stream << layer_prefix << " " << message_code << " " << message << std::endl;
|
||||
stream << layer_prefix << ' ' << message_code << ' ' << message << std::endl;
|
||||
const std::string log = stream.str();
|
||||
|
||||
std::cout << log;
|
||||
|
||||
@ -253,7 +253,7 @@ using vec4 = vec<4u>;
|
||||
|
||||
// uvec3 is the type representing tensor extents. Useful for debugging.
|
||||
inline std::ostream& operator<<(std::ostream& os, const uvec3& v) {
|
||||
os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ")";
|
||||
os << '(' << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ')';
|
||||
return os;
|
||||
}
|
||||
|
||||
|
||||
@ -61,6 +61,7 @@ list(APPEND ATen_CUDA_TEST_SRCS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_math_test.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_test.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_cub_test.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_cublas_handle_pool_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_device_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_distributions_test.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_dlconvertor_test.cpp
|
||||
|
||||
@ -246,7 +246,7 @@ void TestToCFloat() {
|
||||
void TestToString() {
|
||||
Tensor b = ones({3, 7}) * .0000001f;
|
||||
std::stringstream s;
|
||||
s << b << "\n";
|
||||
s << b << '\n';
|
||||
std::string expect = "1e-07 *";
|
||||
ASSERT_EQ_RESOLVED(s.str().substr(0, expect.size()), expect);
|
||||
}
|
||||
|
||||
aten/src/ATen/test/cuda_cublas_handle_pool_test.cpp (new file, +77 lines)
@ -0,0 +1,77 @@
#include <gtest/gtest.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>

#include <atomic>
#include <thread>
#include <vector>

// Test concurrent access to getCurrentCUDABlasHandle and getCUDABlasLtWorkspace
// to verify that the data race fix is working correctly

TEST(CUDABlasHandlePoolTest, ConcurrentGetAndClearWorkspaces) {
if (!at::cuda::is_available()) {
return;
}

constexpr int num_accessor_threads = 15;
constexpr int num_clear_threads = 5;
constexpr int iterations_per_thread = 50;

std::atomic<bool> stop{false};
std::atomic<int> error_count{0};
std::vector<std::thread> threads;
threads.reserve(num_accessor_threads + num_clear_threads);

// Launch accessor threads
for (int i = 0; i < num_accessor_threads; ++i) {
threads.emplace_back([&stop, &error_count]() {
try {
at::cuda::CUDAGuard device_guard(0);

while (!stop.load(std::memory_order_relaxed)) {
const auto handle = at::cuda::getCurrentCUDABlasHandle();
const auto workspace = at::cuda::getCUDABlasLtWorkspace();

if (handle == nullptr || workspace == nullptr) {
error_count++;
}
}
} catch (const std::exception& e) {
error_count++;
}
});
}

// Launch threads that clear workspaces
for (int i = 0; i < num_clear_threads; ++i) {
threads.emplace_back([&error_count]() {
try {
for (int j = 0; j < iterations_per_thread; ++j) {
at::cuda::clearCublasWorkspaces();
std::this_thread::yield();
}
} catch (const std::exception& e) {
error_count++;
}
});
}

// Let them run for a bit
std::this_thread::sleep_for(std::chrono::milliseconds(100));
stop.store(true, std::memory_order_relaxed);

for (auto& thread : threads) {
thread.join();
}

EXPECT_EQ(error_count.load(), 0);
}

int main(int argc, char* argv[]) {
::testing::InitGoogleTest(&argc, argv);
c10::cuda::CUDACachingAllocator::init(1);
return RUN_ALL_TESTS();
}

@ -33,7 +33,7 @@ struct Foo {
|
||||
static void apply(Tensor a, Tensor b) {
|
||||
scalar_type s = 1;
|
||||
std::stringstream ss;
|
||||
ss << "hello, dispatch: " << a.toString() << s << "\n";
|
||||
ss << "hello, dispatch: " << a.toString() << s << '\n';
|
||||
auto data = (scalar_type*)a.data_ptr();
|
||||
(void)data;
|
||||
}
|
||||
@ -73,8 +73,8 @@ TEST(TestScalar, TestScalar) {
|
||||
Scalar bar = 3.0;
|
||||
Half h = bar.toHalf();
|
||||
Scalar h2 = h;
|
||||
cout << "H2: " << h2.toDouble() << " " << what.toFloat() << " "
|
||||
<< bar.toDouble() << " " << what.isIntegral(false) << "\n";
|
||||
cout << "H2: " << h2.toDouble() << ' ' << what.toFloat() << ' '
|
||||
<< bar.toDouble() << ' ' << what.isIntegral(false) << '\n';
|
||||
auto gen = at::detail::getDefaultCPUGenerator();
|
||||
{
|
||||
// See Note [Acquire lock when using random generators]
|
||||
@ -84,7 +84,7 @@ TEST(TestScalar, TestScalar) {
|
||||
}
|
||||
if (at::hasCUDA()) {
|
||||
auto t2 = zeros({4, 4}, at::kCUDA);
|
||||
cout << &t2 << "\n";
|
||||
cout << &t2 << '\n';
|
||||
}
|
||||
auto t = ones({4, 4});
|
||||
|
||||
@ -129,7 +129,7 @@ TEST(TestScalar, TestScalar) {
|
||||
std::stringstream ss;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
|
||||
ASSERT_NO_THROW(
|
||||
ss << "hello, dispatch" << x.toString() << s << "\n");
|
||||
ss << "hello, dispatch" << x.toString() << s << '\n');
|
||||
auto data = (scalar_t*)x.data_ptr();
|
||||
(void)data;
|
||||
});
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
int main() {
|
||||
std::cout << at::ones({3,4}, at::CPU(at::kFloat)) << "\n";
|
||||
std::cout << at::ones({3,4}, at::CPU(at::kFloat)) << '\n';
|
||||
}
|
||||
|
||||
@ -1828,9 +1828,9 @@ namespace {
|
||||
#endif
|
||||
|
||||
EXPECT_EQ(u16, c10::detail::fp16_ieee_from_fp32_value(f32s[i]))
|
||||
<< "Test failed for float to uint16 " << f32s[i] << "\n";
|
||||
<< "Test failed for float to uint16 " << f32s[i] << '\n';
|
||||
EXPECT_EQ(x, c10::detail::fp16_ieee_to_fp32_value(u16))
|
||||
<< "Test failed for uint16 to float " << u16 << "\n";
|
||||
<< "Test failed for uint16 to float " << u16 << '\n';
|
||||
}
|
||||
}
|
||||
TEST(FP8E4M3Test, FP8E4M3ConversionFloat) {
|
||||
@ -1848,10 +1848,10 @@ namespace {
|
||||
EXPECT_TRUE(std::isnan(f32));
|
||||
} else {
|
||||
EXPECT_EQ(f32, c10::detail::fp8e4m3fn_to_fp32_value(input))
|
||||
<< "Test failed for u8 to float " << input << "\n";
|
||||
<< "Test failed for u8 to float " << input << '\n';
|
||||
}
|
||||
EXPECT_EQ(u8, c10::detail::fp8e4m3fn_from_fp32_value(f32))
|
||||
<< "Test failed for float to u8 " << f32 << "\n";
|
||||
<< "Test failed for float to u8 " << f32 << '\n';
|
||||
}
|
||||
}
|
||||
TEST(FP8E4M3Test, FP8E4M3BinaryAdd) {
|
||||
@ -2015,10 +2015,10 @@ namespace {
|
||||
EXPECT_TRUE(std::isnan(f32));
|
||||
} else {
|
||||
EXPECT_EQ(f32, c10::detail::fp8e5m2_to_fp32_value(input))
|
||||
<< "Test failed for u8 to float " << input << "\n";
|
||||
<< "Test failed for u8 to float " << input << '\n';
|
||||
}
|
||||
EXPECT_EQ(u8, c10::detail::fp8e5m2_from_fp32_value(f32))
|
||||
<< "Test failed for float to u8 " << f32 << "\n";
|
||||
<< "Test failed for float to u8 " << f32 << '\n';
|
||||
}
|
||||
}
|
||||
TEST(FP8E5M2Test, FP8E5M2BinaryAdd) {
|
||||
|
||||
@ -19,7 +19,7 @@ TEST(Vitals, Basic) {
|
||||
c10::utils::set_env("TORCH_VITAL", "1");
|
||||
TORCH_VITAL_DEFINE(Testing);
|
||||
TORCH_VITAL(Testing, Attribute0) << 1;
|
||||
TORCH_VITAL(Testing, Attribute1) << "1";
|
||||
TORCH_VITAL(Testing, Attribute1) << '1';
|
||||
TORCH_VITAL(Testing, Attribute2) << 1.0f;
|
||||
TORCH_VITAL(Testing, Attribute3) << 1.0;
|
||||
auto t = at::ones({1, 1});
|
||||
|
||||
@ -129,14 +129,14 @@ void showRtol(const at::Tensor& a, const at::Tensor& b) {
|
||||
std::cout << "Max Diff allowed: " << maxDiff << std::endl;
|
||||
if (diff.sizes().size() == 2) {
|
||||
for (const auto y : c10::irange(diff.sizes()[0])) {
|
||||
std::cout << y << ":";
|
||||
std::cout << y << ':';
|
||||
for (const auto x : c10::irange(diff.sizes()[1])) {
|
||||
float diff_xy = diff[y][x].item<float>();
|
||||
if (diff_xy > maxDiff) {
|
||||
std::cout << std::setw(5) << x;
|
||||
}
|
||||
else {
|
||||
std::cout << std::setw(5) << " ";
|
||||
std::cout << std::setw(5) << ' ';
|
||||
}
|
||||
}
|
||||
std::cout << std::endl;
|
||||
@ -3276,7 +3276,7 @@ TEST_F(VulkanAPITest, masked_fill_invalidinputs_exceptions) {
|
||||
|
||||
void print_shape(const std::vector<int64_t>& shape) {
|
||||
for (const auto& num : shape) {
|
||||
std::cout << num << " ";
|
||||
std::cout << num << ' ';
|
||||
}
|
||||
}
|
||||
|
||||
@ -3367,7 +3367,7 @@ void test_masked_fill_scalar(
|
||||
print_shape(tmp_curr_input_shape);
|
||||
std::cout << "], and mask of shape [";
|
||||
print_shape(tmp_curr_mask_shape);
|
||||
std::cout << "]" << std::endl;
|
||||
std::cout << ']' << std::endl;
|
||||
}
|
||||
|
||||
ASSERT_TRUE(check);
|
||||
@ -4542,9 +4542,9 @@ void test_softmax(const at::IntArrayRef shape, bool log_softmax = false) {
|
||||
if (!check) {
|
||||
std::cout << "Softmax test failed on axis " << dim << "for tensor dims {";
|
||||
for (uint32_t place = 0; place < shape.size() - 1; place++) {
|
||||
std::cout << shape[place] << " ";
|
||||
std::cout << shape[place] << ' ';
|
||||
}
|
||||
std::cout << shape.back() << "}" << std::endl;
|
||||
std::cout << shape.back() << '}' << std::endl;
|
||||
showRtol(out_cpu, out_vulkan.cpu());
|
||||
}
|
||||
ASSERT_TRUE(check);
|
||||
|
||||
@ -95,7 +95,7 @@ void showRtol(
|
||||
std::cout << "Max Diff found is: " << diff.max().item<double>() << std::endl;
|
||||
if (diff.sizes().size() == 2) {
|
||||
for (const auto y : c10::irange(diff.sizes()[0])) {
|
||||
std::cout << y << ":";
|
||||
std::cout << y << ':';
|
||||
for (const auto x : c10::irange(diff.sizes()[1])) {
|
||||
double diff_xy = diff[y][x].item<double>();
|
||||
if (diff_xy > maxDiff) {
|
||||
@ -109,7 +109,7 @@ void showRtol(
|
||||
}
|
||||
}
|
||||
} else {
|
||||
std::cout << std::setw(5) << " ";
|
||||
std::cout << std::setw(5) << ' ';
|
||||
}
|
||||
}
|
||||
std::cout << std::endl;
|
||||
@ -148,19 +148,19 @@ using at::native::vulkan::api::utils::ivec4;
|
||||
using at::native::vulkan::api::utils::vec4;
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const vec4& v) {
|
||||
os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
|
||||
<< v.data[3u] << ")";
|
||||
os << '(' << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
|
||||
<< v.data[3u] << ')';
|
||||
return os;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const ivec3& v) {
|
||||
os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ")";
|
||||
os << '(' << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ')';
|
||||
return os;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const ivec4& v) {
|
||||
os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
|
||||
<< v.data[3u] << ")";
|
||||
os << '(' << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
|
||||
<< v.data[3u] << ')';
|
||||
return os;
|
||||
}
|
||||
|
||||
@@ -3379,51 +3379,51 @@ bool _test_quantized_linear(
       showRtol(out_cpu_dequant, out_vk_to_cpu_dequant);
     }
     if (xpos != -1 && ypos != -1) {
-      std::cout << "\nFailure caused on row/col: " << ypos << "/" << xpos
-                << "\n";
+      std::cout << "\nFailure caused on row/col: " << ypos << '/' << xpos
+                << '\n';
       std::cout << "Input tensor scale: " << scale << " zerop: " << zero_point
-                << "\n";
-      std::cout << "Input tensor row " << ypos << "\n";
+                << '\n';
+      std::cout << "Input tensor row " << ypos << '\n';
       for (int i = 0; i < input_cpu.sizes()[1]; i++) {
        std::cout << input_cpu[ypos][i].item<double>() << ", ";
      }
-      std::cout << "\n";
+      std::cout << '\n';
 
       std::cout << "Weight tensor scale: " << w_scale
-                << " zerop: " << w_zero_point << "\n";
-      std::cout << "Weight tensor col " << xpos << "\n";
+                << " zerop: " << w_zero_point << '\n';
+      std::cout << "Weight tensor col " << xpos << '\n';
       for (int i = 0; i < weight.sizes()[1]; i++) {
        std::cout << weight[xpos][i].item<double>() << ", ";
      }
-      std::cout << "\n";
+      std::cout << '\n';
 
       std::cout << "Input tensor quantized row " << ypos << " with dtype "
-                << (input_quant_dtype_int8 ? "QInt8" : "QUInt8") << "\n";
+                << (input_quant_dtype_int8 ? "QInt8" : "QUInt8") << '\n';
       for (int i = 0; i < input_cpu.sizes()[1]; i++) {
        std::cout << input_cpu_quantized[ypos][i].item<double>() << ", ";
      }
-      std::cout << "\n";
+      std::cout << '\n';
 
       std::cout << "Weight tensor quantized col " << xpos << " with dtype "
-                << (weight_quant_dtype_int8 ? "QInt8" : "QUInt8") << "\n";
+                << (weight_quant_dtype_int8 ? "QInt8" : "QUInt8") << '\n';
       for (int i = 0; i < weight.sizes()[1]; i++) {
        std::cout << weight_cpu_quantized[xpos][i].item<double>() << ", ";
      }
-      std::cout << "\n";
+      std::cout << '\n';
 
       std::cout << "bias tensor\n";
       for (int i = 0; i < bias.sizes()[0]; i++) {
        std::cout << bias[i].item<double>() << ", ";
      }
-      std::cout << "\n";
+      std::cout << '\n';
 
       std::cout << "out_scale: " << out_scale
-                << " out_zero_point: " << out_zero_point << "\n";
+                << " out_zero_point: " << out_zero_point << '\n';
 
       std::cout << "cpu unmatched output: "
-                << out_cpu_dequant[ypos][xpos].item<double>() << "\n";
+                << out_cpu_dequant[ypos][xpos].item<double>() << '\n';
       std::cout << "vk unmatched output: "
-                << out_vk_to_cpu_dequant[ypos][xpos].item<double>() << "\n";
+                << out_vk_to_cpu_dequant[ypos][xpos].item<double>() << '\n';
     }
   }
   return check;
 
@@ -10,6 +10,13 @@
    ...
 }
 
+{
+   ignore_empty_generic_uninitialised_conditional_jump
+   Memcheck:Cond
+   fun:_ZN2at6detail13empty_genericEN3c108ArrayRefIlEEPNS1_9AllocatorENS1_14DispatchKeySetENS1_10ScalarTypeESt8optionalINS1_12MemoryFormatEE
+   ...
+}
+
 {
    Cond_cuda
    Memcheck:Cond
 
@@ -189,6 +189,10 @@ skip:
   - hf_Whisper
   - hf_distil_whisper
   - timm_vision_transformer_large
+  # https://github.com/pytorch/pytorch/issues/167895
+  - stable_diffusion
+  - stable_diffusion_text_encoder
+  - stable_diffusion_unet
 
 device:
   cpu:
 
@@ -2,6 +2,7 @@
 # These load paths point to different files in internal and OSS environment
 
 load("@bazel_skylib//lib:paths.bzl", "paths")
+load("//tools/build_defs:cell_defs.bzl", "get_fbsource_cell")
 load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
 load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library")
 load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")

@@ -590,6 +591,9 @@ def pt_operator_query_codegen(
         pt_allow_forced_schema_registration = True,
         compatible_with = [],
         apple_sdks = None):
+    if get_fbsource_cell() == "fbcode":
+        return
+
     oplist_dir_name = name + "_pt_oplist"
 
     # @lint-ignore BUCKLINT

@@ -865,6 +869,9 @@ def define_buck_targets(
         pt_xplat_cxx_library = fb_xplat_cxx_library,
         c2_fbandroid_xplat_compiler_flags = [],
         labels = []):
+    if get_fbsource_cell() == "fbcode":
+        return
+
     # @lint-ignore BUCKLINT
     fb_native.filegroup(
         name = "metal_build_srcs",
 
@@ -176,7 +176,7 @@ std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) {
     os << k;
     first = false;
   }
-  os << ")";
+  os << ')';
   return os;
 }
 
@@ -44,7 +44,7 @@ struct C10_API SafePyObject {
       (*other.pyinterpreter_)->incref(other.data_);
     }
     if (data_ != nullptr) {
-      (*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
+      (*pyinterpreter_)->decref(data_);
     }
     data_ = other.data_;
     pyinterpreter_ = other.pyinterpreter_;

@@ -53,7 +53,7 @@ struct C10_API SafePyObject {
 
   ~SafePyObject() {
     if (data_ != nullptr) {
-      (*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
+      (*pyinterpreter_)->decref(data_);
     }
   }
 
@@ -34,20 +34,6 @@ namespace c10 {
 // See [dtype Macros note] in torch/headeronly/core/ScalarType.h
 // regarding macros.
 
-template <typename T>
-struct CppTypeToScalarType;
-
-#define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type)                  \
-  template <>                                                                  \
-  struct CppTypeToScalarType<cpp_type>                                         \
-      : std::                                                                  \
-            integral_constant<c10::ScalarType, c10::ScalarType::scalar_type> { \
-  };
-
-AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
-
-#undef SPECIALIZE_CppTypeToScalarType
-
 #define DEFINE_CONSTANT(_, name) \
   constexpr ScalarType k##name = ScalarType::name;
 
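For readers unfamiliar with the block being removed above: each expansion of SPECIALIZE_CppTypeToScalarType produced a trait specialization mapping a C++ type to its ScalarType enumerator. A self-contained analogue using a stand-in two-value enum (illustrative only; the real enumerators and the AT_FORALL_* list live in the c10/torch headers):

    #include <type_traits>

    enum class ScalarType { Float, Double };

    template <typename T>
    struct CppTypeToScalarType;

    // Same shape as the removed macro, minus the c10:: qualification.
    #define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type)          \
      template <>                                                          \
      struct CppTypeToScalarType<cpp_type>                                 \
          : std::integral_constant<ScalarType, ScalarType::scalar_type> {};

    SPECIALIZE_CppTypeToScalarType(float, Float)
    SPECIALIZE_CppTypeToScalarType(double, Double)
    #undef SPECIALIZE_CppTypeToScalarType

    // The trait is consumed at compile time, e.g.:
    static_assert(CppTypeToScalarType<float>::value == ScalarType::Float,
                  "float maps to ScalarType::Float");

    int main() { return 0; }
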
@@ -48,6 +48,30 @@ void warnDeprecatedDataPtr() {
   TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid.");
 }
 
+void StorageImpl::incref_pyobject() const {
+  // Because intrusive_ptr incref uses relaxed memory order, we need to
+  // do an acquire fence to ensure that the kHasPyObject bit was
+  // observed before the load of the PyObject* below.
+  // NB: This is a no-op on x86/x86-64
+  std::atomic_thread_fence(std::memory_order_acquire);
+
+  PyObject* obj = pyobj_slot_.load_pyobj();
+  (*pyobj_slot_.pyobj_interpreter())->incref(obj);
+}
+
+void StorageImpl::decref_pyobject() const {
+  PyObject* obj = pyobj_slot_.load_pyobj();
+  (*pyobj_slot_.pyobj_interpreter())->decref(obj);
+}
+
+bool StorageImpl::try_incref_pyobject() const {
+  c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
+  if (C10_UNLIKELY(!interp)) {
+    return false;
+  }
+  return (*interp)->try_incref(pyobj_slot_);
+}
+
 void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) {
   // Allowlist verification.
   // Only if the devicetype is in the allowlist,
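The acquire fence added in incref_pyobject above pairs with the relaxed reference-count increment done by intrusive_ptr: once that increment has observed the kHasPyObject bit, the fence guarantees that the PyObject* stored before the bit was set is also visible. A minimal standalone sketch of this fence pattern, with hypothetical names that are not part of the PyTorch API:

    #include <atomic>
    #include <cstdint>
    #include <cstdio>

    struct Slot {
      // Low bit plays the role of kHasPyObject; the rest acts as a refcount.
      std::atomic<std::uint64_t> refcount_and_flags{0};
      void* payload = nullptr;  // stands in for the PyObject*
    };

    void attach_payload(Slot& s, void* obj) {
      s.payload = obj;
      // Release: publishes the payload together with the flag bit.
      s.refcount_and_flags.fetch_or(1, std::memory_order_release);
    }

    void incref_like(Slot& s) {
      // Relaxed increment, as intrusive_ptr does on the hot path.
      std::uint64_t prev =
          s.refcount_and_flags.fetch_add(2, std::memory_order_relaxed);
      if (prev & 1) {
        // Acquire fence upgrades the relaxed read above: the payload written
        // before the flag was set is now guaranteed to be visible.
        std::atomic_thread_fence(std::memory_order_acquire);
        std::printf("payload: %p\n", s.payload);
      }
    }

    int main() {
      static int dummy = 0;
      Slot s;
      attach_payload(s, &dummy);
      incref_like(s);
      return 0;
    }

On x86/x86-64 the fence compiles to no instruction, which is why the comment in the diff calls it a no-op there; on architectures with weaker memory models it emits the required barrier.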
Some files were not shown because too many files have changed in this diff.