Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-18 17:45:09 +08:00)

Compare commits: 91 commits, whc/pp_fix ... ciflow/tru
@ -75,9 +75,11 @@ if [[ "$ARCH" == "aarch64" ]]; then
# ARM system libraries
DEPS_LIST+=(
"/usr/lib64/libgfortran.so.5"
"/opt/OpenBLAS/lib/libopenblas.so.0"
)
DEPS_SONAME+=(
"libgfortran.so.5"
"libopenblas.so.0"
)
fi

@ -100,337 +100,6 @@ def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _compile_and_extract_symbols(
|
||||
cpp_content: str, compile_flags: list[str], exclude_list: list[str] | None = None
|
||||
) -> list[str]:
|
||||
"""
|
||||
Helper to compile a C++ file and extract all symbols.
|
||||
|
||||
Args:
|
||||
cpp_content: C++ source code to compile
|
||||
compile_flags: Compilation flags
|
||||
exclude_list: List of symbol names to exclude. Defaults to ["main"].
|
||||
|
||||
Returns:
|
||||
List of all symbols found in the object file (excluding those in exclude_list).
|
||||
"""
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
if exclude_list is None:
|
||||
exclude_list = ["main"]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmppath = Path(tmpdir)
|
||||
cpp_file = tmppath / "test.cpp"
|
||||
obj_file = tmppath / "test.o"
|
||||
|
||||
cpp_file.write_text(cpp_content)
|
||||
|
||||
result = subprocess.run(
|
||||
compile_flags + [str(cpp_file), "-o", str(obj_file)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Compilation failed: {result.stderr}")
|
||||
|
||||
symbols = get_symbols(str(obj_file))
|
||||
|
||||
# Return all symbol names, excluding those in the exclude list
|
||||
return [name for _addr, _stype, name in symbols if name not in exclude_list]
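For reference, `get_symbols` used above is defined elsewhere in this script and is not shown in this hunk. A minimal sketch of the same compile-and-inspect step, assuming GNU `g++` and binutils `nm` are on PATH (the helper name and the parsing below are illustrative, not the script's actual implementation):

```python
import subprocess
import tempfile
from pathlib import Path

def list_object_symbols(cpp_source: str) -> list[str]:
    # Compile a translation unit to an object file, then list its symbols via `nm`.
    with tempfile.TemporaryDirectory() as tmpdir:
        cpp = Path(tmpdir) / "probe.cpp"
        obj = Path(tmpdir) / "probe.o"
        cpp.write_text(cpp_source)
        subprocess.run(["g++", "-std=c++17", "-c", str(cpp), "-o", str(obj)], check=True)
        out = subprocess.run(["nm", str(obj)], capture_output=True, text=True, check=True)
        # Each `nm` line ends with "<type> <name>"; the address column may be empty.
        return [line.split()[-1] for line in out.stdout.splitlines() if line.strip()]

print(list_object_symbols("int main() { return 0; }"))  # typically just ['main']
```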
|
||||
|
||||
|
||||
def check_stable_only_symbols(install_root: Path) -> None:
|
||||
"""
|
||||
Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code and comparing symbol counts.
|
||||
|
||||
This approach tests:
|
||||
1. WITHOUT macros -> many torch symbols exposed
|
||||
2. WITH TORCH_STABLE_ONLY -> zero torch symbols (all hidden)
|
||||
3. WITH TORCH_TARGET_VERSION -> zero torch symbols (all hidden)
|
||||
4. WITH both macros -> zero torch symbols (all hidden)
|
||||
"""
|
||||
include_dir = install_root / "include"
|
||||
assert include_dir.exists(), f"Expected {include_dir} to be present"
|
||||
|
||||
test_cpp_content = """
|
||||
// Main torch C++ API headers
|
||||
#include <torch/torch.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
// ATen tensor library
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
// Core c10 headers (commonly used)
|
||||
#include <c10/core/Device.h>
|
||||
#include <c10/core/DeviceType.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/core/TensorOptions.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
int main() { return 0; }
|
||||
"""
|
||||
|
||||
base_compile_flags = [
|
||||
"g++",
|
||||
"-std=c++17",
|
||||
f"-I{include_dir}",
|
||||
f"-I{include_dir}/torch/csrc/api/include",
|
||||
"-c", # Compile only, don't link
|
||||
]
|
||||
|
||||
# Compile WITHOUT any macros
|
||||
symbols_without = _compile_and_extract_symbols(
|
||||
cpp_content=test_cpp_content,
|
||||
compile_flags=base_compile_flags,
|
||||
)
|
||||
|
||||
# We expect constexpr symbols, inline functions used by other headers etc.
|
||||
# to produce symbols
|
||||
num_symbols_without = len(symbols_without)
|
||||
print(f"Found {num_symbols_without} symbols without any macros defined")
|
||||
assert num_symbols_without != 0, (
|
||||
"Expected a non-zero number of symbols without any macros"
|
||||
)
|
||||
|
||||
# Compile WITH TORCH_STABLE_ONLY (expect 0 symbols)
|
||||
compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"]
|
||||
|
||||
symbols_with_stable_only = _compile_and_extract_symbols(
|
||||
cpp_content=test_cpp_content,
|
||||
compile_flags=compile_flags_with_stable_only,
|
||||
)
|
||||
|
||||
num_symbols_with_stable_only = len(symbols_with_stable_only)
|
||||
assert num_symbols_with_stable_only == 0, (
|
||||
f"Expected no symbols with TORCH_STABLE_ONLY macro, but found {num_symbols_with_stable_only}"
|
||||
)
|
||||
|
||||
# Compile WITH TORCH_TARGET_VERSION (expect 0 symbols)
|
||||
compile_flags_with_target_version = base_compile_flags + [
|
||||
"-DTORCH_TARGET_VERSION=1"
|
||||
]
|
||||
|
||||
symbols_with_target_version = _compile_and_extract_symbols(
|
||||
cpp_content=test_cpp_content,
|
||||
compile_flags=compile_flags_with_target_version,
|
||||
)
|
||||
|
||||
num_symbols_with_target_version = len(symbols_with_target_version)
|
||||
assert num_symbols_with_target_version == 0, (
|
||||
f"Expected no symbols with TORCH_TARGET_VERSION macro, but found {num_symbols_with_target_version}"
|
||||
)
|
||||
|
||||
# Compile WITH both macros (expect 0 symbols)
|
||||
compile_flags_with_both = base_compile_flags + [
|
||||
"-DTORCH_STABLE_ONLY",
|
||||
"-DTORCH_TARGET_VERSION=1",
|
||||
]
|
||||
|
||||
symbols_with_both = _compile_and_extract_symbols(
|
||||
cpp_content=test_cpp_content,
|
||||
compile_flags=compile_flags_with_both,
|
||||
)
|
||||
|
||||
num_symbols_with_both = len(symbols_with_both)
|
||||
assert num_symbols_with_both == 0, (
|
||||
f"Expected no symbols with both macros, but found {num_symbols_with_both}"
|
||||
)
|
||||
|
||||
|
||||
def check_stable_api_symbols(install_root: Path) -> None:
|
||||
"""
|
||||
Test that stable API headers still expose symbols with TORCH_STABLE_ONLY.
|
||||
The torch/csrc/stable/c/shim.h header is tested in check_stable_c_shim_symbols
|
||||
"""
|
||||
include_dir = install_root / "include"
|
||||
assert include_dir.exists(), f"Expected {include_dir} to be present"
|
||||
|
||||
stable_dir = include_dir / "torch" / "csrc" / "stable"
|
||||
assert stable_dir.exists(), f"Expected {stable_dir} to be present"
|
||||
|
||||
stable_headers = list(stable_dir.rglob("*.h"))
|
||||
if not stable_headers:
|
||||
raise RuntimeError("Could not find any stable headers")
|
||||
|
||||
includes = []
|
||||
for header in stable_headers:
|
||||
rel_path = header.relative_to(include_dir)
|
||||
includes.append(f"#include <{rel_path.as_posix()}>")
|
||||
|
||||
includes_str = "\n".join(includes)
|
||||
test_stable_content = f"""
|
||||
{includes_str}
|
||||
int main() {{ return 0; }}
|
||||
"""
|
||||
|
||||
compile_flags = [
|
||||
"g++",
|
||||
"-std=c++17",
|
||||
f"-I{include_dir}",
|
||||
f"-I{include_dir}/torch/csrc/api/include",
|
||||
"-c",
|
||||
"-DTORCH_STABLE_ONLY",
|
||||
]
|
||||
|
||||
symbols_stable = _compile_and_extract_symbols(
|
||||
cpp_content=test_stable_content,
|
||||
compile_flags=compile_flags,
|
||||
)
|
||||
num_symbols_stable = len(symbols_stable)
|
||||
print(f"Found {num_symbols_stable} symbols in torch/csrc/stable")
|
||||
assert num_symbols_stable > 0, (
|
||||
f"Expected stable headers to expose symbols with TORCH_STABLE_ONLY, "
|
||||
f"but found {num_symbols_stable} symbols"
|
||||
)
|
||||
|
||||
|
||||
def check_headeronly_symbols(install_root: Path) -> None:
|
||||
"""
|
||||
Test that header-only utility headers still expose symbols with TORCH_STABLE_ONLY.
|
||||
"""
|
||||
include_dir = install_root / "include"
|
||||
assert include_dir.exists(), f"Expected {include_dir} to be present"
|
||||
|
||||
# Find all headers in torch/headeronly
|
||||
headeronly_dir = include_dir / "torch" / "headeronly"
|
||||
assert headeronly_dir.exists(), f"Expected {headeronly_dir} to be present"
|
||||
headeronly_headers = list(headeronly_dir.rglob("*.h"))
|
||||
if not headeronly_headers:
|
||||
raise RuntimeError("Could not find any headeronly headers")
|
||||
|
||||
# Filter out platform-specific headers that may not compile everywhere
|
||||
platform_specific_keywords = [
|
||||
"cpu/vec",
|
||||
]
|
||||
|
||||
filtered_headers = []
|
||||
for header in headeronly_headers:
|
||||
rel_path = header.relative_to(include_dir).as_posix()
|
||||
if not any(
|
||||
keyword in rel_path.lower() for keyword in platform_specific_keywords
|
||||
):
|
||||
filtered_headers.append(header)
|
||||
|
||||
includes = []
|
||||
for header in filtered_headers:
|
||||
rel_path = header.relative_to(include_dir)
|
||||
includes.append(f"#include <{rel_path.as_posix()}>")
|
||||
|
||||
includes_str = "\n".join(includes)
|
||||
test_headeronly_content = f"""
|
||||
{includes_str}
|
||||
int main() {{ return 0; }}
|
||||
"""
|
||||
|
||||
compile_flags = [
|
||||
"g++",
|
||||
"-std=c++17",
|
||||
f"-I{include_dir}",
|
||||
f"-I{include_dir}/torch/csrc/api/include",
|
||||
"-c",
|
||||
"-DTORCH_STABLE_ONLY",
|
||||
]
|
||||
|
||||
symbols_headeronly = _compile_and_extract_symbols(
|
||||
cpp_content=test_headeronly_content,
|
||||
compile_flags=compile_flags,
|
||||
)
|
||||
num_symbols_headeronly = len(symbols_headeronly)
|
||||
print(f"Found {num_symbols_headeronly} symbols in torch/headeronly")
|
||||
assert num_symbols_headeronly > 0, (
|
||||
f"Expected headeronly headers to expose symbols with TORCH_STABLE_ONLY, "
|
||||
f"but found {num_symbols_headeronly} symbols"
|
||||
)
|
||||
|
||||
|
||||
def check_aoti_shim_symbols(install_root: Path) -> None:
|
||||
"""
|
||||
Test that AOTI shim headers still expose symbols with TORCH_STABLE_ONLY.
|
||||
"""
|
||||
include_dir = install_root / "include"
|
||||
assert include_dir.exists(), f"Expected {include_dir} to be present"
|
||||
|
||||
# There are no constexpr symbols etc., so we need to actually use functions
|
||||
# so that some symbols are found.
|
||||
test_shim_content = """
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
int main() {
|
||||
int32_t (*fp1)() = &aoti_torch_device_type_cpu;
|
||||
int32_t (*fp2)() = &aoti_torch_dtype_float32;
|
||||
(void)fp1; (void)fp2;
|
||||
return 0;
|
||||
}
|
||||
"""
|
||||
|
||||
compile_flags = [
|
||||
"g++",
|
||||
"-std=c++17",
|
||||
f"-I{include_dir}",
|
||||
f"-I{include_dir}/torch/csrc/api/include",
|
||||
"-c",
|
||||
"-DTORCH_STABLE_ONLY",
|
||||
]
|
||||
|
||||
symbols_shim = _compile_and_extract_symbols(
|
||||
cpp_content=test_shim_content,
|
||||
compile_flags=compile_flags,
|
||||
)
|
||||
num_symbols_shim = len(symbols_shim)
|
||||
assert num_symbols_shim > 0, (
|
||||
f"Expected shim headers to expose symbols with TORCH_STABLE_ONLY, "
|
||||
f"but found {num_symbols_shim} symbols"
|
||||
)
|
||||
|
||||
|
||||
def check_stable_c_shim_symbols(install_root: Path) -> None:
|
||||
"""
|
||||
Test that stable C shim headers still expose symbols with TORCH_STABLE_ONLY.
|
||||
"""
|
||||
include_dir = install_root / "include"
|
||||
assert include_dir.exists(), f"Expected {include_dir} to be present"
|
||||
|
||||
# Check if the stable C shim exists
|
||||
stable_shim = include_dir / "torch" / "csrc" / "stable" / "c" / "shim.h"
|
||||
if not stable_shim.exists():
|
||||
raise RuntimeError("Could not find stable c shim")
|
||||
|
||||
# There are no constexpr symbols etc., so we need to actually use functions
|
||||
# so that some symbols are found.
|
||||
test_stable_shim_content = """
|
||||
#include <torch/csrc/stable/c/shim.h>
|
||||
int main() {
|
||||
// Reference stable C API functions to create undefined symbols
|
||||
AOTITorchError (*fp1)(const char*, uint32_t*, int32_t*) = &torch_parse_device_string;
|
||||
AOTITorchError (*fp2)(uint32_t*) = &torch_get_num_threads;
|
||||
(void)fp1; (void)fp2;
|
||||
return 0;
|
||||
}
|
||||
"""
|
||||
|
||||
compile_flags = [
|
||||
"g++",
|
||||
"-std=c++17",
|
||||
f"-I{include_dir}",
|
||||
f"-I{include_dir}/torch/csrc/api/include",
|
||||
"-c",
|
||||
"-DTORCH_STABLE_ONLY",
|
||||
]
|
||||
|
||||
symbols_stable_shim = _compile_and_extract_symbols(
|
||||
cpp_content=test_stable_shim_content,
|
||||
compile_flags=compile_flags,
|
||||
)
|
||||
num_symbols_stable_shim = len(symbols_stable_shim)
|
||||
assert num_symbols_stable_shim > 0, (
|
||||
f"Expected stable C shim headers to expose symbols with TORCH_STABLE_ONLY, "
|
||||
f"but found {num_symbols_stable_shim} symbols"
|
||||
)
|
||||
|
||||
|
||||
def check_lib_symbols_for_abi_correctness(lib: str) -> None:
|
||||
print(f"lib: {lib}")
|
||||
cxx11_symbols = grep_symbols(lib, LIBTORCH_CXX11_PATTERNS)
|
||||
@ -460,13 +129,6 @@ def main() -> None:
|
||||
check_lib_symbols_for_abi_correctness(libtorch_cpu_path)
|
||||
check_lib_statically_linked_libstdc_cxx_abi_symbols(libtorch_cpu_path)
|
||||
|
||||
# Check symbols when TORCH_STABLE_ONLY is defined
|
||||
check_stable_only_symbols(install_root)
|
||||
check_stable_api_symbols(install_root)
|
||||
check_headeronly_symbols(install_root)
|
||||
check_aoti_shim_symbols(install_root)
|
||||
check_stable_c_shim_symbols(install_root)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -389,6 +389,13 @@ test_lazy_tensor_meta_reference_disabled() {
|
||||
export -n TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE
|
||||
}
|
||||
|
||||
test_dynamo_core() {
|
||||
time python test/run_test.py \
|
||||
--include-dynamo-core-tests \
|
||||
--verbose \
|
||||
--upload-artifacts-while-running
|
||||
assert_git_not_dirty
|
||||
}
|
||||
|
||||
test_dynamo_wrapped_shard() {
|
||||
if [[ -z "$NUM_TEST_SHARDS" ]]; then
|
||||
@ -1814,6 +1821,8 @@ elif [[ "${TEST_CONFIG}" == *inductor* ]]; then
|
||||
test_inductor_shard "${SHARD_NUMBER}"
|
||||
elif [[ "${TEST_CONFIG}" == *einops* ]]; then
|
||||
test_einops
|
||||
elif [[ "${TEST_CONFIG}" == *dynamo_core* ]]; then
|
||||
test_dynamo_core
|
||||
elif [[ "${TEST_CONFIG}" == *dynamo_wrapped* ]]; then
|
||||
install_torchvision
|
||||
test_dynamo_wrapped_shard "${SHARD_NUMBER}"
|
||||
|
||||
.github/ci_commit_pins/audio.txt (2 changes, vendored)
@ -1 +1 @@
07b6cbde121417a70e4dc871adb6d27030e0ce3f
ee1a1350eb37804b94334768f328144f058f14e9

.github/ci_commit_pins/vision.txt (2 changes, vendored)
@ -1 +1 @@
acccf86477759b2d3500f1ae1be065f7b1e409ec
2d82dc5caa336d179d9b46ac4a0fb8c43d84c5cc

.github/ci_commit_pins/xla.txt (2 changes, vendored)
@ -1 +1 @@
e4d25697f9dc5eedaf8f0a5bf085c62c5455a53a
94631807d22c09723dd006f7be5beb649d5f88d0

.github/pytorch-probot.yml (1 change, vendored)
@ -7,6 +7,7 @@ ciflow_push_tags:
|
||||
- ciflow/binaries
|
||||
- ciflow/binaries_libtorch
|
||||
- ciflow/binaries_wheel
|
||||
- ciflow/dynamo
|
||||
- ciflow/h100
|
||||
- ciflow/h100-cutlass-backend
|
||||
- ciflow/h100-distributed
|
||||
|
||||
.github/workflows/_linux-test.yml (2 changes, vendored)
@ -326,7 +326,7 @@ jobs:
|
||||
SCCACHE_BUCKET: ${{ !contains(matrix.runner, 'b200') && 'ossci-compiler-cache-circleci-v2' || '' }}
|
||||
SCCACHE_REGION: ${{ !contains(matrix.runner, 'b200') && 'us-east-1' || '' }}
|
||||
SHM_SIZE: ${{ contains(inputs.build-environment, 'cuda') && '2g' || '1g' }}
|
||||
DOCKER_IMAGE: ${{ inputs.docker-image }}
|
||||
DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
|
||||
XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }}
|
||||
XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla
|
||||
PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }}
|
||||
|
||||
.github/workflows/dynamo-unittest.yml (new file, 70 lines, vendored)
@ -0,0 +1,70 @@
|
||||
# Workflow: Dynamo Unit Test
|
||||
# runs unit tests for dynamo.
|
||||
name: dynamo-unittest
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- ciflow/dynamo/*
|
||||
workflow_call:
|
||||
schedule:
|
||||
- cron: 29 8 * * * # about 1:29am PDT
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
get-label-type:
|
||||
name: get-label-type
|
||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
||||
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
||||
with:
|
||||
triggering_actor: ${{ github.triggering_actor }}
|
||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
||||
curr_ref_type: ${{ github.ref_type }}
|
||||
opt_out_experiments: lf
|
||||
|
||||
dynamo-build:
|
||||
name: dynamo-build
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.11', '3.12']
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-py${{ matrix.python-version }}-clang12
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang12
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "dynamo_core", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
{ config: "dynamo_wrapped", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
{ config: "dynamo_wrapped", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
{ config: "dynamo_wrapped", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
dynamo-test:
|
||||
name: dynamo-test
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: [get-label-type, dynamo-build]
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.11', '3.12']
|
||||
with:
|
||||
build-environment: linux-jammy-py${{ matrix.python-version }}-clang12
|
||||
docker-image: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang12
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "dynamo_core", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
{ config: "dynamo_wrapped", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
{ config: "dynamo_wrapped", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
{ config: "dynamo_wrapped", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
]}
|
||||
secrets: inherit
|
||||
@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/core/TensorAccessor.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/ArrayRef.h>
|
||||
#include <c10/util/Deprecated.h>
|
||||
@ -11,252 +12,37 @@
|
||||
|
||||
namespace at {
|
||||
|
||||
// The PtrTraits argument to the TensorAccessor/GenericPackedTensorAccessor
|
||||
// is used to enable the __restrict__ keyword/modifier for the data
|
||||
// passed to cuda.
|
||||
template <typename T>
|
||||
struct DefaultPtrTraits {
|
||||
typedef T* PtrType;
|
||||
};
|
||||
|
||||
using torch::headeronly::DefaultPtrTraits;
|
||||
#if defined(__CUDACC__) || defined(__HIPCC__)
|
||||
template <typename T>
|
||||
struct RestrictPtrTraits {
|
||||
typedef T* __restrict__ PtrType;
|
||||
};
|
||||
using torch::headeronly::RestrictPtrTraits;
|
||||
#endif
|
||||
|
||||
// TensorAccessorBase and TensorAccessor are used for both CPU and CUDA tensors.
|
||||
// For CUDA tensors it is used in device code (only). This means that we restrict ourselves
|
||||
// to functions and types available there (e.g. IntArrayRef isn't).
|
||||
|
||||
// The PtrTraits argument is only relevant to cuda to support `__restrict__` pointers.
|
||||
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
|
||||
class TensorAccessorBase {
|
||||
public:
|
||||
typedef typename PtrTraits<T>::PtrType PtrType;
|
||||
using TensorAccessorBase = torch::headeronly::detail::TensorAccessorBase<c10::IntArrayRef, T, N, PtrTraits, index_t>;
|
||||
|
||||
C10_HOST_DEVICE TensorAccessorBase(
|
||||
PtrType data_,
|
||||
const index_t* sizes_,
|
||||
const index_t* strides_)
|
||||
: data_(data_), sizes_(sizes_), strides_(strides_) {}
|
||||
C10_HOST IntArrayRef sizes() const {
|
||||
return IntArrayRef(sizes_,N);
|
||||
}
|
||||
C10_HOST IntArrayRef strides() const {
|
||||
return IntArrayRef(strides_,N);
|
||||
}
|
||||
C10_HOST_DEVICE index_t stride(index_t i) const {
|
||||
return strides_[i];
|
||||
}
|
||||
C10_HOST_DEVICE index_t size(index_t i) const {
|
||||
return sizes_[i];
|
||||
}
|
||||
C10_HOST_DEVICE PtrType data() {
|
||||
return data_;
|
||||
}
|
||||
C10_HOST_DEVICE const PtrType data() const {
|
||||
return data_;
|
||||
}
|
||||
protected:
|
||||
PtrType data_;
|
||||
const index_t* sizes_;
|
||||
const index_t* strides_;
|
||||
};
|
||||
|
||||
// The `TensorAccessor` is typically instantiated for CPU `Tensor`s using
|
||||
// `Tensor.accessor<T, N>()`.
|
||||
// For CUDA `Tensor`s, `GenericPackedTensorAccessor` is used on the host and only
|
||||
// indexing on the device uses `TensorAccessor`s.
|
||||
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
|
||||
class TensorAccessor : public TensorAccessorBase<T,N,PtrTraits,index_t> {
|
||||
public:
|
||||
typedef typename PtrTraits<T>::PtrType PtrType;
|
||||
using TensorAccessor = torch::headeronly::detail::TensorAccessor<c10::IntArrayRef, T, N, PtrTraits, index_t>;
|
||||
|
||||
C10_HOST_DEVICE TensorAccessor(
|
||||
PtrType data_,
|
||||
const index_t* sizes_,
|
||||
const index_t* strides_)
|
||||
: TensorAccessorBase<T, N, PtrTraits, index_t>(data_,sizes_,strides_) {}
|
||||
namespace detail {
|
||||
|
||||
C10_HOST_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
|
||||
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1);
|
||||
}
|
||||
|
||||
C10_HOST_DEVICE const TensorAccessor<T, N-1, PtrTraits, index_t> operator[](index_t i) const {
|
||||
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, template <typename U> class PtrTraits, typename index_t>
|
||||
class TensorAccessor<T,1,PtrTraits,index_t> : public TensorAccessorBase<T,1,PtrTraits,index_t> {
|
||||
public:
|
||||
typedef typename PtrTraits<T>::PtrType PtrType;
|
||||
|
||||
C10_HOST_DEVICE TensorAccessor(
|
||||
PtrType data_,
|
||||
const index_t* sizes_,
|
||||
const index_t* strides_)
|
||||
: TensorAccessorBase<T, 1, PtrTraits, index_t>(data_,sizes_,strides_) {}
|
||||
C10_HOST_DEVICE T & operator[](index_t i) {
|
||||
// NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
|
||||
return this->data_[this->strides_[0]*i];
|
||||
}
|
||||
C10_HOST_DEVICE const T & operator[](index_t i) const {
|
||||
return this->data_[this->strides_[0]*i];
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// GenericPackedTensorAccessorBase and GenericPackedTensorAccessor are used on for CUDA `Tensor`s on the host
|
||||
// and as
|
||||
// In contrast to `TensorAccessor`s, they copy the strides and sizes on instantiation (on the host)
|
||||
// in order to transfer them on the device when calling kernels.
|
||||
// On the device, indexing of multidimensional tensors gives to `TensorAccessor`s.
|
||||
// Use RestrictPtrTraits as PtrTraits if you want the tensor's data pointer to be marked as __restrict__.
|
||||
// Instantiation from data, sizes, strides is only needed on the host and std::copy isn't available
|
||||
// on the device, so those functions are host only.
|
||||
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
|
||||
class GenericPackedTensorAccessorBase {
|
||||
public:
|
||||
typedef typename PtrTraits<T>::PtrType PtrType;
|
||||
C10_HOST GenericPackedTensorAccessorBase(
|
||||
PtrType data_,
|
||||
const index_t* sizes_,
|
||||
const index_t* strides_)
|
||||
: data_(data_) {
|
||||
std::copy(sizes_, sizes_ + N, std::begin(this->sizes_));
|
||||
std::copy(strides_, strides_ + N, std::begin(this->strides_));
|
||||
}
|
||||
|
||||
// if index_t is not int64_t, we want to have an int64_t constructor
|
||||
template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
|
||||
C10_HOST GenericPackedTensorAccessorBase(
|
||||
PtrType data_,
|
||||
const source_index_t* sizes_,
|
||||
const source_index_t* strides_)
|
||||
: data_(data_) {
|
||||
for (const auto i : c10::irange(N)) {
|
||||
this->sizes_[i] = sizes_[i];
|
||||
this->strides_[i] = strides_[i];
|
||||
}
|
||||
}
|
||||
|
||||
C10_HOST_DEVICE index_t stride(index_t i) const {
|
||||
return strides_[i];
|
||||
}
|
||||
C10_HOST_DEVICE index_t size(index_t i) const {
|
||||
return sizes_[i];
|
||||
}
|
||||
C10_HOST_DEVICE PtrType data() {
|
||||
return data_;
|
||||
}
|
||||
C10_HOST_DEVICE const PtrType data() const {
|
||||
return data_;
|
||||
}
|
||||
protected:
|
||||
PtrType data_;
|
||||
// NOLINTNEXTLINE(*c-arrays*)
|
||||
index_t sizes_[N];
|
||||
// NOLINTNEXTLINE(*c-arrays*)
|
||||
index_t strides_[N];
|
||||
C10_HOST void bounds_check_(index_t i) const {
|
||||
TORCH_CHECK_INDEX(
|
||||
template <size_t N, typename index_t>
|
||||
struct IndexBoundsCheck {
|
||||
IndexBoundsCheck(index_t i) {
|
||||
TORCH_CHECK_INDEX(
|
||||
0 <= i && i < index_t{N},
|
||||
"Index ",
|
||||
i,
|
||||
" is not within bounds of a tensor of dimension ",
|
||||
N);
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
|
||||
class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase<T,N,PtrTraits,index_t> {
|
||||
public:
|
||||
typedef typename PtrTraits<T>::PtrType PtrType;
|
||||
|
||||
C10_HOST GenericPackedTensorAccessor(
|
||||
PtrType data_,
|
||||
const index_t* sizes_,
|
||||
const index_t* strides_)
|
||||
: GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
|
||||
|
||||
// if index_t is not int64_t, we want to have an int64_t constructor
|
||||
template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
|
||||
C10_HOST GenericPackedTensorAccessor(
|
||||
PtrType data_,
|
||||
const source_index_t* sizes_,
|
||||
const source_index_t* strides_)
|
||||
: GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
|
||||
|
||||
C10_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
|
||||
index_t* new_sizes = this->sizes_ + 1;
|
||||
index_t* new_strides = this->strides_ + 1;
|
||||
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i, new_sizes, new_strides);
|
||||
}
|
||||
|
||||
C10_DEVICE const TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) const {
|
||||
const index_t* new_sizes = this->sizes_ + 1;
|
||||
const index_t* new_strides = this->strides_ + 1;
|
||||
return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i, new_sizes, new_strides);
|
||||
}
|
||||
|
||||
/// Returns a PackedTensorAccessor of the same dimension after transposing the
|
||||
/// two dimensions given. Does not actually move elements; transposition is
|
||||
/// made by permuting the size/stride arrays. If the dimensions are not valid,
|
||||
/// asserts.
|
||||
C10_HOST GenericPackedTensorAccessor<T, N, PtrTraits, index_t> transpose(
|
||||
index_t dim1,
|
||||
index_t dim2) const {
|
||||
this->bounds_check_(dim1);
|
||||
this->bounds_check_(dim2);
|
||||
GenericPackedTensorAccessor<T, N, PtrTraits, index_t> result(
|
||||
this->data_, this->sizes_, this->strides_);
|
||||
std::swap(result.strides_[dim1], result.strides_[dim2]);
|
||||
std::swap(result.sizes_[dim1], result.sizes_[dim2]);
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, template <typename U> class PtrTraits, typename index_t>
|
||||
class GenericPackedTensorAccessor<T,1,PtrTraits,index_t> : public GenericPackedTensorAccessorBase<T,1,PtrTraits,index_t> {
|
||||
public:
|
||||
typedef typename PtrTraits<T>::PtrType PtrType;
|
||||
C10_HOST GenericPackedTensorAccessor(
|
||||
PtrType data_,
|
||||
const index_t* sizes_,
|
||||
const index_t* strides_)
|
||||
: GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
|
||||
|
||||
// if index_t is not int64_t, we want to have an int64_t constructor
|
||||
template <typename source_index_t, class = std::enable_if_t<std::is_same_v<source_index_t, int64_t>>>
|
||||
C10_HOST GenericPackedTensorAccessor(
|
||||
PtrType data_,
|
||||
const source_index_t* sizes_,
|
||||
const source_index_t* strides_)
|
||||
: GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
|
||||
|
||||
C10_DEVICE T & operator[](index_t i) {
|
||||
return this->data_[this->strides_[0] * i];
|
||||
}
|
||||
C10_DEVICE const T& operator[](index_t i) const {
|
||||
return this->data_[this->strides_[0]*i];
|
||||
}
|
||||
|
||||
// Same as in the general N-dimensional case, but note that in the
|
||||
// 1-dimensional case the returned PackedTensorAccessor will always be an
|
||||
// identical copy of the original
|
||||
C10_HOST GenericPackedTensorAccessor<T, 1, PtrTraits, index_t> transpose(
|
||||
index_t dim1,
|
||||
index_t dim2) const {
|
||||
this->bounds_check_(dim1);
|
||||
this->bounds_check_(dim2);
|
||||
return GenericPackedTensorAccessor<T, 1, PtrTraits, index_t>(
|
||||
this->data_, this->sizes_, this->strides_);
|
||||
}
|
||||
};
|
||||
using GenericPackedTensorAccessorBase = torch::headeronly::detail::GenericPackedTensorAccessorBase<detail::IndexBoundsCheck<N, index_t>, T, N, PtrTraits, index_t>;
|
||||
|
||||
template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
|
||||
using GenericPackedTensorAccessor = torch::headeronly::detail::GenericPackedTensorAccessor<TensorAccessor<T, N-1, PtrTraits, index_t>, detail::IndexBoundsCheck<N, index_t>, T, N, PtrTraits, index_t>;
|
||||
|
||||
// Can't put this directly into the macro function args because of commas
|
||||
#define AT_X GenericPackedTensorAccessor<T, N, PtrTraits, index_t>
|
||||
|
||||
@ -245,6 +245,9 @@ class TORCH_API TensorBase {
|
||||
size_t weak_use_count() const noexcept {
|
||||
return impl_.weak_use_count();
|
||||
}
|
||||
bool is_uniquely_owned() const noexcept {
|
||||
return impl_.is_uniquely_owned();
|
||||
}
|
||||
|
||||
std::string toString() const;
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
#include <cstdint>
|
||||
#include <map>
|
||||
#include <shared_mutex>
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cusparse.h>
|
||||
@ -88,8 +89,13 @@ TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
|
||||
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
|
||||
|
||||
TORCH_CUDA_CPP_API void clearCublasWorkspaces();
|
||||
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace();
|
||||
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace();
|
||||
struct WorkspaceMapWithMutex {
|
||||
std::map<std::tuple<void*, void*>, at::DataPtr> map;
|
||||
std::shared_mutex mutex;
|
||||
};
|
||||
|
||||
TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublas_handle_stream_to_workspace();
|
||||
TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace();
|
||||
TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize();
|
||||
TORCH_CUDA_CPP_API size_t getCUDABlasLtWorkspaceSize();
|
||||
TORCH_CUDA_CPP_API void* getCUDABlasLtWorkspace();
|
||||
|
||||
@ -175,17 +175,24 @@ void CUDAGraph::instantiate() {
|
||||
// Trailing NULL, NULL, 0 arguments were recommended by Cuda driver people,
|
||||
// who prefer not to report error message through these arguments moving forward
|
||||
// (they prefer return value, or errors on api calls internal to the capture)
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000)
|
||||
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, 0));
|
||||
// ROCM appears to fail with HIP error: invalid argument
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && !defined(USE_ROCM)
|
||||
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, cudaGraphInstantiateFlagUseNodePriority));
|
||||
#else
|
||||
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0));
|
||||
#endif
|
||||
//Since ROCm 6.2, we want to go down this path as hipGraphExecDestroy in the destructor will not immediately free the memory.
|
||||
//It will wait for the next sync operation. cudaGraphInstantiateFlagAutoFreeOnLaunch will add async frees after graph launch.
|
||||
} else {
|
||||
#if !defined(USE_ROCM)
|
||||
AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
|
||||
graph_,
|
||||
cudaGraphInstantiateFlagAutoFreeOnLaunch | cudaGraphInstantiateFlagUseNodePriority));
|
||||
#else
|
||||
AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
|
||||
graph_,
|
||||
cudaGraphInstantiateFlagAutoFreeOnLaunch));
|
||||
#endif
|
||||
}
|
||||
has_graph_exec_ = true;
|
||||
}
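For context only (not part of this change): these instantiation flags sit behind the Python-level graph capture API. A small usage sketch, which requires a CUDA-enabled build and a GPU:

```python
import torch

x = torch.zeros(8, device="cuda")
g = torch.cuda.CUDAGraph()

# Warm up on a side stream before capture, as the docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    y = x * 2
torch.cuda.current_stream().wait_stream(s)

with torch.cuda.graph(g):
    y = x * 2          # captured; instantiation happens when capture ends

x.fill_(3.0)
g.replay()             # re-runs the captured kernels against the updated x
print(y)               # tensor of 6s
```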
|
||||
|
||||
@ -99,7 +99,7 @@ void destroyCublasHandle(cublasHandle_t handle) {
|
||||
// - Comments of @soumith copied from cuDNN handle pool implementation
|
||||
#ifdef NO_CUDNN_DESTROY_HANDLE
|
||||
#else
|
||||
cublasDestroy(handle);
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -107,19 +107,27 @@ using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle
|
||||
|
||||
} // namespace
|
||||
|
||||
std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
|
||||
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
|
||||
WorkspaceMapWithMutex& cublas_handle_stream_to_workspace() {
|
||||
static auto& instance = *new WorkspaceMapWithMutex;
|
||||
return instance;
|
||||
}
|
||||
|
||||
std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace() {
|
||||
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
|
||||
WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace() {
|
||||
static auto& instance = *new WorkspaceMapWithMutex;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void clearCublasWorkspaces() {
|
||||
cublas_handle_stream_to_workspace().clear();
|
||||
cublaslt_handle_stream_to_workspace().clear();
|
||||
{
|
||||
auto& workspace = cublas_handle_stream_to_workspace();
|
||||
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
|
||||
workspace.map.clear();
|
||||
}
|
||||
{
|
||||
auto& workspace = cublaslt_handle_stream_to_workspace();
|
||||
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
|
||||
workspace.map.clear();
|
||||
}
|
||||
}
|
||||
|
||||
size_t parseChosenWorkspaceSize() {
|
||||
@ -233,6 +241,38 @@ at::DataPtr getNewCUDABlasLtWorkspace() {
|
||||
return c10::cuda::CUDACachingAllocator::get()->allocate(getCUDABlasLtWorkspaceSize());
|
||||
}
|
||||
|
||||
void setWorkspaceForHandle(cublasHandle_t handle, c10::cuda::CUDAStream stream) {
|
||||
cudaStream_t _stream = stream;
|
||||
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
|
||||
|
||||
auto& workspace = cublas_handle_stream_to_workspace();
|
||||
|
||||
size_t workspace_size = getChosenWorkspaceSize();
|
||||
|
||||
// Fast path: check if workspace already exists
|
||||
{
|
||||
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
|
||||
auto workspace_it = workspace.map.find(key);
|
||||
if (workspace_it != workspace.map.end()) {
|
||||
TORCH_CUDABLAS_CHECK(cublasSetWorkspace(
|
||||
handle, workspace_it->second.get(), workspace_size));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Slow path: allocate workspace outside the lock
|
||||
auto new_workspace = getNewWorkspace();
|
||||
|
||||
// Insert with lock (double-check in case another thread inserted while we
|
||||
// were allocating)
|
||||
{
|
||||
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
|
||||
auto workspace_it = workspace.map.try_emplace(key, std::move(new_workspace)).first;
|
||||
TORCH_CUDABLAS_CHECK(
|
||||
cublasSetWorkspace(handle, workspace_it->second.get(), workspace_size));
|
||||
}
|
||||
}
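The fast path/slow path split above is a read-mostly, double-checked insert: look up under a shared lock, allocate outside any lock, then insert under an exclusive lock and keep whichever entry won the race. The same ordering, sketched in Python with a plain lock standing in for the `shared_mutex` (names are illustrative):

```python
import threading

_lock = threading.Lock()
_workspaces: dict = {}   # (handle, stream) -> workspace buffer

def workspace_for(handle, stream, alloc):
    key = (handle, stream)
    # Fast path: the workspace usually already exists.
    ws = _workspaces.get(key)
    if ws is not None:
        return ws
    # Slow path: allocate outside the lock, then insert with a double-check
    # so that at most one allocation is kept if two threads race.
    new_ws = alloc()
    with _lock:
        return _workspaces.setdefault(key, new_ws)
```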
|
||||
|
||||
void* getCUDABlasLtWorkspace() {
|
||||
#ifndef USE_ROCM
|
||||
static bool unified = c10::utils::check_env(TORCH_CUBLASLT_UNIFIED_WORKSPACE) == true;
|
||||
@ -241,8 +281,10 @@ void* getCUDABlasLtWorkspace() {
|
||||
auto stream = c10::cuda::getCurrentCUDAStream();
|
||||
cudaStream_t _stream = stream;
|
||||
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
|
||||
auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key);
|
||||
TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end());
|
||||
auto& workspace = at::cuda::cublas_handle_stream_to_workspace();
|
||||
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
|
||||
auto workspace_it = workspace.map.find(key);
|
||||
TORCH_INTERNAL_ASSERT(workspace_it != workspace.map.end());
|
||||
return workspace_it->second.mutable_get();
|
||||
}
|
||||
#endif
|
||||
@ -250,11 +292,29 @@ void* getCUDABlasLtWorkspace() {
|
||||
auto stream = c10::cuda::getCurrentCUDAStream();
|
||||
cudaStream_t _stream = stream;
|
||||
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
|
||||
auto workspace_it = cublaslt_handle_stream_to_workspace().find(key);
|
||||
if (workspace_it == cublaslt_handle_stream_to_workspace().end()) {
|
||||
workspace_it = cublaslt_handle_stream_to_workspace().insert(workspace_it, {key, getNewCUDABlasLtWorkspace()});
|
||||
|
||||
auto& workspace = cublaslt_handle_stream_to_workspace();
|
||||
|
||||
// Fast path: check if workspace already exists
|
||||
{
|
||||
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
|
||||
auto workspace_it = workspace.map.find(key);
|
||||
if (workspace_it != workspace.map.end()) {
|
||||
return workspace_it->second.mutable_get();
|
||||
}
|
||||
}
|
||||
|
||||
// Slow path: allocate workspace outside the lock
|
||||
auto new_workspace = getNewCUDABlasLtWorkspace();
|
||||
|
||||
// Insert with lock (double-check in case another thread inserted while we
|
||||
// were allocating)
|
||||
{
|
||||
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
|
||||
auto workspace_it =
|
||||
workspace.map.try_emplace(key, std::move(new_workspace)).first;
|
||||
return workspace_it->second.mutable_get();
|
||||
}
|
||||
return workspace_it->second.mutable_get();
|
||||
}
|
||||
|
||||
cublasHandle_t getCurrentCUDABlasHandle() {
|
||||
@ -298,13 +358,8 @@ cublasHandle_t getCurrentCUDABlasHandle() {
|
||||
// will allocate memory dynamically (even if they're cheap) outside
|
||||
// PyTorch's CUDA caching allocator. It's possible that CCA used up
|
||||
// all the memory and cublas's cudaMallocAsync will return OOM
|
||||
cudaStream_t _stream = stream;
|
||||
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
|
||||
auto workspace_it = cublas_handle_stream_to_workspace().find(key);
|
||||
if (workspace_it == cublas_handle_stream_to_workspace().end()) {
|
||||
workspace_it = cublas_handle_stream_to_workspace().insert(workspace_it, {key, getNewWorkspace()});
|
||||
}
|
||||
TORCH_CUDABLAS_CHECK(cublasSetWorkspace(handle, workspace_it->second.get(), getChosenWorkspaceSize()));
|
||||
setWorkspaceForHandle(handle, stream);
|
||||
|
||||
#if !defined(USE_ROCM)
|
||||
// On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
|
||||
// FP32 data type calculations based on the value of the allow_tf32 flag.
|
||||
|
||||
@ -1936,7 +1936,7 @@ static bool should_fold(const Tensor& tensor1, const Tensor& tensor2, bool has_o
|
||||
|
||||
// We order the tensors. t1 will be the larger tensor
|
||||
// We can always transpose tensor2 as the dimensions are always >= 1 (precondition from matmul)
|
||||
// and tensor1_larger iff tensor2.dim() > tensor1.dim()
|
||||
// and tensor1_larger iff tensor2.dim() > tensor1.dim()
|
||||
const auto t1 = tensor1_larger ? MaybeOwned<Tensor>::borrowed(tensor1)
|
||||
: MaybeOwned<Tensor>::owned(tensor2.mT());
|
||||
const int64_t dim_t1 = t1->dim();
|
||||
@ -1948,11 +1948,20 @@ static bool should_fold(const Tensor& tensor1, const Tensor& tensor2, bool has_o
|
||||
return false;
|
||||
}
|
||||
|
||||
// If we require a gradient, we should fold to minimize backward memory usage - even if this
|
||||
// leads to a copy in forward because is needed in backward,
|
||||
// only time we avoid this strict pre-allocated memory usage (has_out = True)
|
||||
bool requires_grad = tensor1.requires_grad() || tensor2.requires_grad();
|
||||
if (requires_grad && !has_out) {
|
||||
// In this case we *do* incur in an extra copy to avoid creating an unnecessary large tensor in the backward
|
||||
// Suppose we don't fold here. Let t1.shape = [b, m, n] t2.shape = [n, k] like in a transformer
|
||||
// t2 will be expanded to a tensor of shape [b, n, k] and then we do t1.bmm(t2_expanded)
|
||||
// The issue appears in the backward.
|
||||
// The output gradient g of this operation would have shape [b, m, k]
|
||||
// The backward wrt. t2 of bmm would be given by t1.mH @ g, which has shape [b, n, k]
|
||||
// Then, the backward of expand is simply `sum(0)`. As such, we are instantiating a tensor
|
||||
// of shape [b, n, k] unnecessarily, which may cause a large memory footprint, and in the
|
||||
// worst case, an OOM
|
||||
bool t2_requires_grad = tensor1_larger ? tensor2.requires_grad() : tensor1.requires_grad();
|
||||
if (t2_requires_grad && !has_out) {
|
||||
// We should be checking !at::GradMode::is_enabled(), but apparently
|
||||
// this regresses performance in some cases:
|
||||
// https://github.com/pytorch/pytorch/issues/118548#issuecomment-1916022394
|
||||
return true;
|
||||
}
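A small illustration of the shape argument in the comment above (sizes are arbitrary, `torch` is assumed to be installed). Per the comment, folding lets autograd produce the [n, k] gradient for t2 directly instead of materializing an expanded [b, n, k] intermediate:

```python
import torch

b, m, n, k = 8, 128, 64, 32
t1 = torch.randn(b, m, n, requires_grad=True)   # [b, m, n]
t2 = torch.randn(n, k, requires_grad=True)      # [n, k], e.g. a linear layer weight

out = torch.matmul(t1, t2)                      # [b, m, k]
out.sum().backward()
print(t2.grad.shape)                            # torch.Size([64, 32]), not [b, n, k]
```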
|
||||
|
||||
@ -3532,9 +3541,9 @@ Tensor _dyn_quant_matmul_4bit_cpu(
|
||||
const int64_t out_features) {
|
||||
auto M = inp.size(0);
|
||||
TORCH_CHECK(
|
||||
inp.dtype() == kFloat,
|
||||
inp.dtype() == kFloat || (inp.dtype() == kBFloat16 && block_size == in_features),
|
||||
__func__,
|
||||
" : expect input to be 32-bit float tensor.");
|
||||
" : expect input to be float32 or bfloat16 tensor.");
|
||||
TORCH_CHECK(
|
||||
block_size == in_features ||
|
||||
(!(block_size % 32) && !(in_features % block_size)),
|
||||
|
||||
@ -1087,7 +1087,8 @@ TORCH_IMPL_FUNC(index_copy_out)
|
||||
result.copy_(self);
|
||||
|
||||
// See Note [Enabling Deterministic Operations]
|
||||
if (result.is_cuda() && globalContext().deterministicAlgorithms()) {
|
||||
if ((result.is_cuda() || result.is_xpu()) &&
|
||||
globalContext().deterministicAlgorithms()) {
|
||||
torch::List<std::optional<Tensor>> indices;
|
||||
indices.resize(dim + 1);
|
||||
indices.set(dim, index);
|
||||
|
||||
@ -904,19 +904,11 @@ Tensor mvlgamma(const Tensor& self, int64_t p) {
|
||||
return args.lgamma_().sum(-1).add_(p2_sub_p * std::log(c10::pi<double>) * QUARTER);
|
||||
}
|
||||
|
||||
// since mvlgamma_ has different signature from its
|
||||
// out and functional variant, we explicitly
|
||||
// define it (instead of using structured kernel).
|
||||
Tensor& mvlgamma_(Tensor& self, int64_t p) {
|
||||
mvlgamma_check(self, p);
|
||||
Tensor args = native::arange(
|
||||
-p *HALF + HALF,
|
||||
HALF,
|
||||
HALF,
|
||||
optTypeMetaToScalarType(self.options().dtype_opt()),
|
||||
self.options().layout_opt(),
|
||||
self.options().device_opt(),
|
||||
self.options().pinned_memory_opt());
|
||||
args = args.add(self.unsqueeze(-1));
|
||||
const auto p2_sub_p = static_cast<double>(p * (p - 1));
|
||||
return self.copy_(args.lgamma_().sum(-1).add_(p2_sub_p * std::log(c10::pi<double>) * QUARTER));
|
||||
return at::mvlgamma_out(self, self, p);
|
||||
}
|
||||
|
||||
Tensor& mvlgamma_out(const Tensor& self, int64_t p, Tensor& result) {
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
#include <ATen/cpu/vec/vec.h>
|
||||
#include <ATen/native/cpu/int_mm_kernel.h>
|
||||
#include <ATen/native/cpu/utils.h>
|
||||
#include <cmath>
|
||||
#include <c10/util/Unroll.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
@ -793,6 +794,139 @@ bool can_use_kleidiai(
|
||||
}
|
||||
#endif
|
||||
|
||||
static void ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
|
||||
size_t m,
|
||||
size_t n,
|
||||
size_t k,
|
||||
const uint16_t* lhs_bf16,
|
||||
const uint8_t* rhs_qs4cx,
|
||||
const float* rhs_scales,
|
||||
uint16_t* dst_bf16,
|
||||
float scalar_min,
|
||||
float scalar_max,
|
||||
const float* bias) {
|
||||
// Roundup lambda for internal stride calculations
|
||||
auto roundup = [](size_t a, size_t b) { return ((a + b - 1) / b) * b; };
|
||||
|
||||
// Cast bfloat16 to float32 inline
|
||||
auto cast_bf16_to_f32 = [](uint16_t bf16_val) {
|
||||
uint32_t tmp = static_cast<uint32_t>(bf16_val) << 16;
|
||||
float f;
|
||||
std::memcpy(&f, &tmp, sizeof(f));
|
||||
return f;
|
||||
};
|
||||
|
||||
// Cast float32 to bfloat16 inline
|
||||
auto cast_f32_to_bf16 = [](float f) {
|
||||
uint32_t bits;
|
||||
std::memcpy(&bits, &f, sizeof(bits));
|
||||
return static_cast<uint16_t>(bits >> 16);
|
||||
};
|
||||
|
||||
// Quantization pack lambda (channelwise QA8DX)
|
||||
auto quant_pack_8bit_channelwise =
|
||||
[&](size_t M, size_t K, const uint16_t* src_bf16, int8_t* dst_qa8dx) {
|
||||
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
|
||||
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
|
||||
|
||||
const size_t dst_stride =
|
||||
K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
|
||||
for (size_t i = 0; i < M; ++i) {
|
||||
const uint16_t* row_ptr = src_bf16 + i * K;
|
||||
// find min/max
|
||||
float mn = FLT_MAX, mx = -FLT_MAX;
|
||||
for (size_t j = 0; j < K; ++j) {
|
||||
float v = cast_bf16_to_f32(row_ptr[j]);
|
||||
mn = std::min(mn, v);
|
||||
mx = std::max(mx, v);
|
||||
}
|
||||
float rmin = std::min(0.0f, mn);
|
||||
float rmax = std::max(0.0f, mx);
|
||||
constexpr float qmin = static_cast<float>(kI8Min);
|
||||
constexpr float qmax = static_cast<float>(kI8Max);
|
||||
float scale = (rmin == rmax) ? 1.f : (qmax - qmin) / (rmax - rmin);
|
||||
float recip = scale ? 1.0f / scale : 0.0f;
|
||||
int32_t zp;
|
||||
float des_min = rmin * scale;
|
||||
float des_max = rmax * scale;
|
||||
float err_min = qmin + des_min;
|
||||
float err_max = qmax + des_max;
|
||||
float zp_f =
|
||||
(err_min + err_max) > 0 ? qmin - des_min : qmax - des_max;
|
||||
zp_f = std::clamp(zp_f, qmin, qmax);
|
||||
zp = std::lrintf(zp_f);
|
||||
int8_t* out_ptr = dst_qa8dx + i * dst_stride;
|
||||
// store header
|
||||
*reinterpret_cast<float*>(out_ptr) = recip;
|
||||
*reinterpret_cast<int32_t*>(out_ptr + sizeof(float)) = -zp;
|
||||
out_ptr += sizeof(float) + sizeof(int32_t);
|
||||
// quantize
|
||||
for (size_t j = 0; j < K; ++j) {
|
||||
float v = cast_bf16_to_f32(row_ptr[j]);
|
||||
int32_t q = static_cast<int32_t>(std::round(v * scale)) + zp;
|
||||
q = std::clamp(
|
||||
q, static_cast<int32_t>(kI8Min), static_cast<int32_t>(kI8Max));
|
||||
*out_ptr++ = static_cast<int8_t>(q);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// MatMul lambda (MXN x MXK -> MNXK BF16)
|
||||
auto matmul_kernel = [&](size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const int8_t* lhs,
|
||||
const uint8_t* rhs,
|
||||
const float* scales,
|
||||
uint16_t* dst,
|
||||
float lo,
|
||||
float hi) {
|
||||
const size_t lhs_stride =
|
||||
K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
|
||||
const size_t rhs_stride = roundup(K, 2) / 2;
|
||||
for (size_t i = 0; i < M; ++i) {
|
||||
const int8_t* lhs_row = lhs + i * lhs_stride;
|
||||
for (size_t j = 0; j < N; ++j) {
|
||||
int32_t acc = 0;
|
||||
const int8_t* lptr = lhs_row;
|
||||
const uint8_t* rptr = rhs + j * rhs_stride;
|
||||
float lhs_scale = *reinterpret_cast<const float*>(lptr);
|
||||
int32_t lhs_off =
|
||||
*reinterpret_cast<const int32_t*>(lptr + sizeof(float));
|
||||
lptr += sizeof(float) + sizeof(int32_t);
|
||||
for (size_t t = 0; t < K; ++t) {
|
||||
int32_t lv = static_cast<int32_t>(lptr[t]);
|
||||
uint8_t bv = rptr[t / 2];
|
||||
int32_t rv = ((t & 1) == 0) ? (static_cast<int32_t>(bv & 0xF) - 8)
|
||||
: (static_cast<int32_t>(bv >> 4) - 8);
|
||||
acc += lv * rv + lhs_off * rv;
|
||||
}
|
||||
float res = static_cast<float>(acc) * scales[j] * lhs_scale;
|
||||
if (bias) {
|
||||
res += bias[j];
|
||||
}
|
||||
res = std::clamp(res, lo, hi);
|
||||
*dst++ = cast_f32_to_bf16(res);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// allocate and run
|
||||
std::unique_ptr<int8_t[]> packed(
|
||||
new int8_t[m * (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t))]);
|
||||
quant_pack_8bit_channelwise(m, k, lhs_bf16, packed.get());
|
||||
matmul_kernel(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
packed.get(),
|
||||
rhs_qs4cx,
|
||||
rhs_scales,
|
||||
dst_bf16,
|
||||
scalar_min,
|
||||
scalar_max);
|
||||
}
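The scale and zero-point selection inside `quant_pack_8bit_channelwise` can be mirrored for a single row in a few lines of NumPy (illustrative only; the returned header values follow the layout written by the code above: reciprocal scale and negated zero point):

```python
import numpy as np

def rowwise_int8_qparams(row: np.ndarray):
    # Mirrors the per-row scale / zero-point math in the quantization lambda above.
    qmin, qmax = -128.0, 127.0
    rmin = min(0.0, float(row.min()))
    rmax = max(0.0, float(row.max()))
    scale = 1.0 if rmin == rmax else (qmax - qmin) / (rmax - rmin)
    err_min = qmin + rmin * scale
    err_max = qmax + rmax * scale
    zp_f = (qmin - rmin * scale) if (err_min + err_max) > 0 else (qmax - rmax * scale)
    zp = int(round(float(np.clip(zp_f, qmin, qmax))))
    q = np.clip(np.round(row * scale) + zp, qmin, qmax).astype(np.int8)
    recip = 1.0 / scale if scale else 0.0
    return q, recip, -zp

print(rowwise_int8_qparams(np.array([-0.5, 0.25, 1.0], dtype=np.float32)))
```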
|
||||
|
||||
/**
|
||||
* The Int4 quantized weights must be represented as a uint8 tensor
|
||||
* For matrix multiplication with a weight shape of (N x K)
|
||||
@ -819,21 +953,21 @@ void dyn_quant_pack_4bit_weight_kernel(
|
||||
#if AT_KLEIDIAI_ENABLED()
|
||||
if (can_use_kleidiai(scales_zeros, K, block_size)) {
|
||||
const int64_t weight_packed_size =
|
||||
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
|
||||
kleidiai::kai_pack_rhs_int4_size(N, K, block_size, weights.scalar_type());
|
||||
packed_weights.resize_({weight_packed_size});
|
||||
kleidiai::kai_pack_int4_rhs(
|
||||
packed_weights, weights, scales_zeros, bias, N, K, block_size);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
TORCH_CHECK(
|
||||
bias.has_value() == 0,
|
||||
__func__,
|
||||
" : Bias is unsupported in reference implementation");
|
||||
packed_weights = packed_weights.to(kFloat);
|
||||
auto weight_reshaped = weights.view({-1}).to(kFloat);
|
||||
auto scales_zeros_reshaped = scales_zeros.view({-1}).to(kFloat);
|
||||
auto res = at::cat({weight_reshaped, scales_zeros_reshaped}, 0);
|
||||
auto weight_reshaped = weights.reshape({-1}).to(kFloat);
|
||||
auto scales_zeros_reshaped = scales_zeros.reshape({-1}).to(kFloat);
|
||||
std::vector<at::Tensor> tensors_to_cat = {weight_reshaped, scales_zeros_reshaped};
|
||||
if (bias.has_value()) {
|
||||
tensors_to_cat.push_back(bias.value().view({-1}).to(kFloat));
|
||||
}
|
||||
auto res = at::cat(tensors_to_cat, 0);
|
||||
packed_weights.resize_(res.sizes()).copy_(res);
|
||||
}
|
||||
}
|
||||
@ -847,7 +981,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
const float* rhs_scales_f32,
|
||||
float* dst_f32,
|
||||
float scalar_min,
|
||||
float scalar_max) {
|
||||
float scalar_max,
|
||||
const float* bias) {
|
||||
const size_t input_size_8bit = m * (k + sizeof(int32_t) + sizeof(float));
|
||||
|
||||
auto lhs_qa8dx_buffer = std::make_unique<uint8_t[]>(input_size_8bit);
|
||||
@ -857,6 +992,9 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
// required format for matmul
|
||||
auto input_quant_pack_8bit_channelwise =
|
||||
[&](size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) {
|
||||
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
|
||||
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
|
||||
|
||||
const size_t dst_stride =
|
||||
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
|
||||
|
||||
@ -877,8 +1015,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
}
|
||||
|
||||
// Maximum/minimum int8 values
|
||||
const float qmin = (float)INT8_MIN;
|
||||
const float qmax = (float)INT8_MAX;
|
||||
constexpr float qmin = static_cast<float>(kI8Min);
|
||||
constexpr float qmax = static_cast<float>(kI8Max);
|
||||
|
||||
const float rmin0 = std::min(0.0f, min0);
|
||||
const float rmax0 = std::max(0.0f, max0);
|
||||
@ -904,7 +1042,7 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
zero_point0 = std::min(zero_point0, qmax);
|
||||
|
||||
// Round to nearest integer
|
||||
const int32_t nudged_zero_point0 = lrintf(zero_point0);
|
||||
const int32_t nudged_zero_point0 = std::lrintf(zero_point0);
|
||||
|
||||
int8_t* dst_ptr = lhs_qa8dx + m_idx * dst_stride;
|
||||
|
||||
@ -922,8 +1060,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
|
||||
|
||||
v0_s32 = v0_s32 + nudged_zero_point0;
|
||||
v0_s32 = std::max(v0_s32, static_cast<int32_t>(INT8_MIN));
|
||||
v0_s32 = std::min(v0_s32, static_cast<int32_t>(INT8_MAX));
|
||||
v0_s32 = std::max(v0_s32, static_cast<int32_t>(kI8Min));
|
||||
v0_s32 = std::min(v0_s32, static_cast<int32_t>(kI8Max));
|
||||
dst_ptr[0] = (int8_t)v0_s32;
|
||||
dst_ptr += sizeof(int8_t);
|
||||
}
|
||||
@ -987,6 +1125,10 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
|
||||
main_acc = main_acc * lhs_scale;
|
||||
|
||||
if (bias) {
|
||||
main_acc += bias[n_idx];
|
||||
}
|
||||
|
||||
// Clamp (min-max) operation
|
||||
main_acc = std::max(main_acc, scalar_min);
|
||||
main_acc = std::min(main_acc, scalar_max);
|
||||
@ -1007,12 +1149,16 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
const float* rhs_scales_fp32,
|
||||
float* dst_f32,
|
||||
float scalar_min,
|
||||
float scalar_max) {
|
||||
float scalar_max,
|
||||
const float* bias) {
|
||||
// Lambda for LHS quantization
|
||||
auto lhs_quant_pack = [&](size_t m,
|
||||
size_t k,
|
||||
const float* lhs_f32,
|
||||
int8_t* lhs_qa8dx) {
|
||||
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
|
||||
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
|
||||
|
||||
const size_t dst_stride =
|
||||
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
|
||||
|
||||
@ -1028,8 +1174,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
min0 = std::min(src0_0, min0);
|
||||
}
|
||||
|
||||
const float qmin = (float)INT8_MIN;
|
||||
const float qmax = (float)INT8_MAX;
|
||||
constexpr float qmin = static_cast<float>(kI8Min);
|
||||
constexpr float qmax = static_cast<float>(kI8Max);
|
||||
|
||||
const float rmin0 = std::min(0.0f, min0);
|
||||
const float rmax0 = std::max(0.0f, max0);
|
||||
@ -1046,7 +1192,7 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
|
||||
zero_point0 = std::max(zero_point0, qmin);
|
||||
zero_point0 = std::min(zero_point0, qmax);
|
||||
const int32_t nudged_zero_point0 = lrintf(zero_point0);
|
||||
const int32_t nudged_zero_point0 = std::lrintf(zero_point0);
|
||||
|
||||
int8_t* dst_ptr = lhs_qa8dx + row_idx * dst_stride;
|
||||
|
||||
@ -1059,9 +1205,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
const float src0_0 = src_ptr[k_idx];
|
||||
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
|
||||
v0_s32 = std::max(
|
||||
std::min(
|
||||
v0_s32 + nudged_zero_point0, static_cast<int32_t>(INT8_MAX)),
|
||||
static_cast<int32_t>(INT8_MIN));
|
||||
std::min(v0_s32 + nudged_zero_point0, static_cast<int32_t>(kI8Max)),
|
||||
static_cast<int32_t>(kI8Min));
|
||||
dst_ptr[0] = (int8_t)v0_s32;
|
||||
dst_ptr += sizeof(int8_t);
|
||||
}
|
||||
@ -1118,6 +1263,11 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
}
|
||||
|
||||
main_acc = main_acc * lhs_scale;
|
||||
|
||||
if (bias) {
|
||||
main_acc += bias[col_idx];
|
||||
}
|
||||
|
||||
main_acc = std::max(main_acc, scalar_min);
|
||||
main_acc = std::min(main_acc, scalar_max);
|
||||
|
||||
@ -1128,28 +1278,27 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
}
|
||||
|
||||
/**
|
||||
* Dynamic Input Quant 4 bit weights matmul execution flow
|
||||
(INT4 Weights + FP scales + FP32 Bias)
|
||||
FP32 Input Packed Buffer
|
||||
| |
|
||||
Quantize Cast
|
||||
to INT8 to INT8
|
||||
| |
|
||||
v v
|
||||
INT8 Input INT8 Weights
|
||||
\ /
|
||||
\ /
|
||||
\ /
|
||||
INT8 Matrix Multiplication
|
||||
|
|
||||
v
|
||||
FP32 Dequantized and Accumulate in FP32
|
||||
|
|
||||
v
|
||||
FP32 Final Output
|
||||
|
||||
* The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires
|
||||
* Float32 Scales. If not provided, we will use fallback implementation.
|
||||
* Dynamic INT4 weight-only MatMul with per-row input quantization.
|
||||
*
|
||||
* Execution Flow:
|
||||
*
|
||||
* (INT4 Weights + FP Scales [+ optional Bias])
|
||||
*
|
||||
* Input (FP32 or BF16) Packed Weight Buffer
|
||||
* | |
|
||||
* Row-wise Quantization (INT8) |
|
||||
* | |
|
||||
* INT8 Input Activation INT4 Quantized Weights + Scales
|
||||
* \ /
|
||||
* \ /
|
||||
* Quantized Matrix Multiply
|
||||
* |
|
||||
* Output Tensor (BF16 or FP32)
|
||||
*
|
||||
* Notes:
|
||||
* - Groupwise kernels expect BF16 scales
|
||||
* - Channelwise kernels expect FP32 scales
|
||||
* - Bias is currently unsupported in fallback path
|
||||
*/
|
||||
void dyn_quant_matmul_4bit_kernel(
|
||||
const Tensor& output,
|
||||
@ -1161,65 +1310,75 @@ void dyn_quant_matmul_4bit_kernel(
|
||||
const int64_t block_size) {
|
||||
#if AT_KLEIDIAI_ENABLED()
|
||||
const int64_t weight_packed_size =
|
||||
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
|
||||
kleidiai::kai_pack_rhs_int4_size(N, K, block_size, inp.scalar_type());
|
||||
if (weight_packed_size == packed_weights.numel()) {
|
||||
// KleidiAI interface internally handles the Channelwise and groupwise
|
||||
// distinction
|
||||
kleidiai::kai_quant_pack_lhs_int4_mm(
|
||||
output, inp, packed_weights, M, N, K, block_size);
|
||||
kleidiai::kai_quant_pack_lhs_int4_mm(output, inp, packed_weights, M, N, K, block_size);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
float* lhs_f32 = reinterpret_cast<float*>(inp.data_ptr());
|
||||
const auto weights_size = N * K / 2;
|
||||
// The weights needs to be in uint8_t data type after quantization
|
||||
auto extracted_weights =
|
||||
(packed_weights.narrow(0, 0, weights_size)).to(kByte);
|
||||
auto float32_scales =
|
||||
(packed_weights.narrow(
|
||||
0, weights_size, packed_weights.size(0) - weights_size))
|
||||
.to(kFloat);
|
||||
uint8_t* rhs_4bit =
|
||||
reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
|
||||
float* rhs_scales_f32 = reinterpret_cast<float*>(float32_scales.data_ptr());
|
||||
float* dst_f32 = reinterpret_cast<float*>(output.data_ptr());
|
||||
if (block_size == K) {
|
||||
ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
lhs_f32,
|
||||
rhs_4bit,
|
||||
rhs_scales_f32,
|
||||
dst_f32,
|
||||
-FLT_MAX,
|
||||
FLT_MAX);
|
||||
} else if (!(block_size % 32) && !(K % block_size)) {
|
||||
ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
block_size,
|
||||
lhs_f32,
|
||||
rhs_4bit,
|
||||
rhs_scales_f32,
|
||||
dst_f32,
|
||||
-FLT_MAX,
|
||||
FLT_MAX);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
block_size == K || (!(block_size % 32) && !(K % block_size)),
|
||||
__func__,
|
||||
": Group size should be multiple 32 or in_features [",
|
||||
K,
|
||||
"]. Provided ",
|
||||
block_size);
|
||||
{
|
||||
void* input = inp.data_ptr();
|
||||
void* dst = output.data_ptr();
|
||||
|
||||
// Extract weights, sclaes and biases form from packed tensor
|
||||
const int weights_elements = N * K / 2;
|
||||
const int scale_elements = N * (K / block_size);
|
||||
TORCH_CHECK(packed_weights.numel() >= (weights_elements + scale_elements), "Invalid packed weight tensor size");
|
||||
|
||||
auto extracted_weights = packed_weights.narrow(0, 0, weights_elements).to(kByte);
|
||||
auto extracted_scales_and_bias = packed_weights.narrow(0, weights_elements, packed_weights.size(0) - weights_elements).to(kFloat);
|
||||
auto float32_scales = extracted_scales_and_bias.narrow(0, 0, scale_elements);
|
||||
|
||||
int bias_elements = packed_weights.numel() - (weights_elements + scale_elements);
|
||||
float* weight_scales = float32_scales.data_ptr<float>();
|
||||
|
||||
void* bias_data = nullptr;
|
||||
if (bias_elements) {
|
||||
auto float32_bias = extracted_scales_and_bias.narrow(0, scale_elements, bias_elements);
|
||||
TORCH_CHECK(float32_bias.size(0) == N, "Expected bias length to match output dimension");
|
||||
bias_data = float32_bias.data_ptr();
|
||||
|
||||
}
|
||||
// 2 elements of 4 bit weights are packed into 1 uint8 packet
|
||||
uint8_t* weights_4bit = reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
|
||||
|
||||
// Dispatch to reference kernels
|
||||
if (inp.scalar_type() == at::kBFloat16) {
|
||||
// BF16 input, BF16 output
|
||||
constexpr float BF16_MAX = 3.38953139e+38f;
|
||||
constexpr float BF16_MIN = -BF16_MAX;
|
||||
if (block_size == K) {
|
||||
ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
|
||||
M, N, K,
|
||||
(uint16_t*)input, weights_4bit, weight_scales,
|
||||
(uint16_t*)dst, BF16_MIN, BF16_MAX, (float*)bias_data);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported block size for BF16 fallback");
|
||||
}
|
||||
} else if (inp.scalar_type() == at::kFloat) {
|
||||
// FP32 input, FP32 output
|
||||
if (block_size == K) {
|
||||
ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
M, N, K,
|
||||
(float*)input, weights_4bit, weight_scales,
|
||||
(float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data);
|
||||
} else if (!(block_size % 32) && !(K % block_size)) {
|
||||
ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
M, N, K, block_size,
|
||||
(float*)input, weights_4bit, weight_scales,
|
||||
(float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported block size for FP32 fallback");
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported input/output dtype combination for int4mm kernel");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
}
|
||||
ALSO_REGISTER_AVX512_DISPATCH(weight_to_int4pack_stub, &weight_to_int4pack_kernel)
|
||||
ALSO_REGISTER_AVX512_DISPATCH(int4pack_mm_stub, &int4pack_mm_kernel)
|
||||
REGISTER_DISPATCH(dyn_quant_pack_4bit_weight_stub, &dyn_quant_pack_4bit_weight_kernel)
|
||||
|
||||
@ -78,9 +78,18 @@ __global__ void EmbeddingBag_updateOutputKernel_max(
|
||||
scalar_t weightFeatMax = 0;
|
||||
int64_t bag_size_ = 0;
|
||||
int64_t maxWord = -1;
|
||||
|
||||
// Separate validation loop reduces register pressure in the main loop below.
|
||||
// No early exit (break) on invalid input as benchmarking shows it degrades performance.
|
||||
bool has_invalid_index = false;
|
||||
for (int64_t emb = begin; emb < end; emb++) {
|
||||
index_t input_idx = input[emb];
|
||||
has_invalid_index = has_invalid_index || (input_idx < 0 || input_idx >= numRows);
|
||||
}
|
||||
CUDA_KERNEL_ASSERT(!has_invalid_index && "Invalid input index in EmbeddingBag: index out of range [0, numRows)");
|
||||
|
||||
for (int64_t emb = begin; emb < end; emb++) {
|
||||
bool pad = (input[emb] == padding_idx);
|
||||
CUDA_KERNEL_ASSERT(input[emb] < numRows);
|
||||
const int64_t weightRow = input[emb] * weight_stride0;
|
||||
scalar_t weightValue = weightFeat[weightRow];
|
||||
if (bag_size_ == 0 || weightValue > weightFeatMax) {
|
||||
@ -129,10 +138,19 @@ __global__ void EmbeddingBag_updateOutputKernel_sum_mean(
|
||||
CUDA_KERNEL_ASSERT(end >= begin);
|
||||
accscalar_t weightFeatSum = 0;
|
||||
int64_t bag_size_ = 0;
|
||||
|
||||
// Separate validation loop reduces register pressure in the main loop below.
|
||||
// No early exit (break) on invalid input as benchmarking shows it degrades performance.
|
||||
bool has_invalid_index = false;
|
||||
for (int64_t emb = begin; emb < end; emb++) {
|
||||
index_t input_idx = input[emb];
|
||||
has_invalid_index = has_invalid_index || (input_idx < 0 || input_idx >= numRows);
|
||||
}
|
||||
CUDA_KERNEL_ASSERT(!has_invalid_index && "Invalid input index in EmbeddingBag: index out of range [0, numRows)");
|
||||
|
||||
for (int64_t emb = begin; emb < end; emb++) {
|
||||
index_t input_idx = input[emb];
|
||||
bool pad = (input_idx == padding_idx);
|
||||
CUDA_KERNEL_ASSERT(0 <= input_idx && input_idx < numRows);
|
||||
const int64_t weightRow = input_idx * weight_stride0;
|
||||
scalar_t weightValue = weightFeat[weightRow];
|
||||
weightValue = pad ? static_cast<scalar_t>(0) : weightValue;
|
||||
|
||||
@ -78,9 +78,9 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const SwizzleType& swizzle_a,
|
||||
const SwizzleType swizzle_a,
|
||||
const Tensor& scale_b,
|
||||
const SwizzleType& swizzle_b,
|
||||
const SwizzleType swizzle_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
Tensor& out) {
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
|
||||
@ -5,69 +5,11 @@
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
|
||||
// ROCm 6.3 is planned to have these functions, but until then here they are.
|
||||
#if defined(USE_ROCM)
|
||||
#include <device_functions.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_bf16.h>
|
||||
|
||||
__device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) {
|
||||
#if (defined(__gfx942__)) && \
|
||||
__has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2bf16)
|
||||
typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2;
|
||||
static_assert(sizeof(vec_short2) == sizeof(__hip_bfloat162_raw));
|
||||
union {
|
||||
__hip_bfloat162_raw bf162_raw;
|
||||
vec_short2 vs2;
|
||||
} u{static_cast<__hip_bfloat162_raw>(value)};
|
||||
u.vs2 = __builtin_amdgcn_flat_atomic_fadd_v2bf16((vec_short2*)address, u.vs2);
|
||||
return static_cast<__hip_bfloat162>(u.bf162_raw);
|
||||
#else
|
||||
static_assert(sizeof(unsigned int) == sizeof(__hip_bfloat162_raw));
|
||||
union u_hold {
|
||||
__hip_bfloat162_raw h2r;
|
||||
unsigned int u32;
|
||||
};
|
||||
u_hold old_val, new_val;
|
||||
old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
|
||||
do {
|
||||
new_val.h2r = __hadd2(old_val.h2r, value);
|
||||
} while (!__hip_atomic_compare_exchange_strong(
|
||||
(unsigned int*)address, &old_val.u32, new_val.u32,
|
||||
__ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT));
|
||||
return old_val.h2r;
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ inline __half2 preview_unsafeAtomicAdd(__half2* address, __half2 value) {
|
||||
#if (defined(__gfx942__)) && \
|
||||
__has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2f16)
|
||||
// The api expects an ext_vector_type of half
|
||||
typedef _Float16 __attribute__((ext_vector_type(2))) vec_fp162;
|
||||
static_assert(sizeof(vec_fp162) == sizeof(__half2_raw));
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
vec_fp162 fp16;
|
||||
} u {static_cast<__half2_raw>(value)};
|
||||
u.fp16 = __builtin_amdgcn_flat_atomic_fadd_v2f16((vec_fp162*)address, u.fp16);
|
||||
return static_cast<__half2>(u.h2r);
|
||||
#else
|
||||
static_assert(sizeof(__half2_raw) == sizeof(unsigned int));
|
||||
union u_hold {
|
||||
__half2_raw h2r;
|
||||
unsigned int u32;
|
||||
};
|
||||
u_hold old_val, new_val;
|
||||
old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
|
||||
do {
|
||||
new_val.h2r = __hadd2(old_val.h2r, value);
|
||||
} while (!__hip_atomic_compare_exchange_strong(
|
||||
(unsigned int*)address, &old_val.u32, new_val.u32,
|
||||
__ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT));
|
||||
return old_val.h2r;
|
||||
#endif
|
||||
}
|
||||
#define ATOMICADD preview_unsafeAtomicAdd
|
||||
#define ATOMICADD unsafeAtomicAdd
|
||||
#define NATIVE_ZERO_BF16 __float2bfloat16(0.0f)
|
||||
#else
|
||||
#define ATOMICADD atomicAdd
|
||||
|
||||
@ -740,7 +740,12 @@ _scaled_rowwise_rowwise(
|
||||
TORCH_CHECK_VALUE(scale_a.numel() == mat_a.size(0) && scale_a.scalar_type() == kFloat, "scale_a must have ", mat_a.size(0), " Float elements, got ", scale_a.numel())
|
||||
TORCH_CHECK_VALUE(scale_b.numel() == mat_b.size(1) && scale_b.scalar_type() == kFloat, "scale_b must have ", mat_b.size(1), " Float elements, got ", scale_b.numel())
|
||||
|
||||
TORCH_CHECK_VALUE(scale_a.stride(1) == 1, "expected scale_a.stride(1) to be 1, but got ", scale_a.stride(1));
|
||||
// if we have a scale of shape [256, 1] (say), then stride can be [1, 0] - handle this case
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.stride(1) == 1 ||
|
||||
scale_a.size(1) == 1,
|
||||
"expected scale_a.stride(1) to be 1, but got ", scale_a.stride(1)
|
||||
);
|
||||
TORCH_CHECK_VALUE(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1));
|
||||
|
||||
auto scaling_choice_a = ScalingType::RowWise;
|
||||
@ -1096,6 +1101,19 @@ _scaled_mxfp8_mxfp8(
|
||||
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
|
||||
}
|
||||
|
||||
void
|
||||
_check_mxfp4_support() {
|
||||
#ifndef USE_ROCM
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
// Only on B200 GPUs
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
// B200 = 10.0, B300 = 10.3
|
||||
dprops->major == 10,
|
||||
"MXFP4 scaling only supported in CUDA for B200/B300"
|
||||
);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
Tensor&
|
||||
_scaled_mxfp4_mxfp4(
|
||||
@ -1108,6 +1126,7 @@ _scaled_mxfp4_mxfp4(
|
||||
#if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI))
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
|
||||
#else
|
||||
_check_mxfp4_support();
|
||||
// Restrictions:
|
||||
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
|
||||
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",
|
||||
|
||||
@ -21,18 +21,27 @@ void kai_pack_int4_rhs(
|
||||
const int64_t n,
|
||||
const int64_t k,
|
||||
const int64_t bl) {
|
||||
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
|
||||
if (bl == k) {
|
||||
// Channelwise
|
||||
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
|
||||
auto& params = kernel_packet.rhs_pack_params;
|
||||
params.lhs_zero_point = 1;
|
||||
params.rhs_zero_point = 8;
|
||||
|
||||
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
|
||||
kernel_packet, weight_packed, weight, scales, bias, n, k);
|
||||
if (weight.scalar_type() == at::kBFloat16) {
|
||||
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod);
|
||||
auto& params = kernel_packet.rhs_pack_params;
|
||||
params.lhs_zero_point = 1;
|
||||
params.rhs_zero_point = 8;
|
||||
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_bf16_qa8dxp_qs4cxp>(
|
||||
kernel_packet, weight_packed, weight, scales, bias, n, k);
|
||||
} else {
|
||||
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
|
||||
auto& params = kernel_packet.rhs_pack_params;
|
||||
params.lhs_zero_point = 1;
|
||||
params.rhs_zero_point = 8;
|
||||
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
|
||||
kernel_packet, weight_packed, weight, scales, bias, n, k);
|
||||
}
|
||||
} else if (!(bl % 32) && !(k % bl)) {
|
||||
// Groupwise
|
||||
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
|
||||
@ -63,19 +72,29 @@ void kai_pack_int4_rhs(
|
||||
size_t kai_pack_rhs_int4_size(
|
||||
const int64_t n,
|
||||
const int64_t k,
|
||||
const int64_t bl) {
|
||||
const int64_t bl,
|
||||
at::ScalarType tensor_dtype) {
|
||||
size_t packed_size = n * k;
|
||||
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
|
||||
if (bl == k) {
|
||||
// Channelwise
|
||||
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
|
||||
const auto& ukernel = kernel_packet.ukernel;
|
||||
const size_t nr = ukernel.get_nr();
|
||||
const size_t kr = ukernel.get_kr();
|
||||
const size_t sr = ukernel.get_sr();
|
||||
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
|
||||
if (tensor_dtype == at::kBFloat16) {
|
||||
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod);
|
||||
const auto& ukernel = kernel_packet.ukernel;
|
||||
const size_t nr = ukernel.get_nr();
|
||||
const size_t kr = ukernel.get_kr();
|
||||
const size_t sr = ukernel.get_sr();
|
||||
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
|
||||
} else {
|
||||
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
|
||||
const auto& ukernel = kernel_packet.ukernel;
|
||||
const size_t nr = ukernel.get_nr();
|
||||
const size_t kr = ukernel.get_kr();
|
||||
const size_t sr = ukernel.get_sr();
|
||||
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
|
||||
}
|
||||
} else if (!(bl % 32) && !(k % bl)) {
|
||||
// Groupwise
|
||||
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
|
||||
@ -148,8 +167,7 @@ static void kai_quant_pack_lhs_int4_mm_groupwise(
|
||||
const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride;
|
||||
const int64_t m_idx = thread_id * vec_per_thread;
|
||||
auto lhs_packed_ptr = lhs_packed_base +
|
||||
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
|
||||
m_idx, k, mr, kr, sr);
|
||||
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
|
||||
const int64_t vec_num = (thread_id == num_threads - 1)
|
||||
? (m - vec_per_thread * thread_id)
|
||||
: vec_per_thread;
|
||||
@ -259,8 +277,7 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
|
||||
const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride;
|
||||
const int64_t m_idx = thread_id * vec_per_thread;
|
||||
auto lhs_packed_ptr = lhs_packed_base +
|
||||
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
|
||||
m_idx, k, mr, kr, sr);
|
||||
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
|
||||
const int64_t vec_num = (thread_id == num_threads - 1)
|
||||
? (m - vec_per_thread * thread_id)
|
||||
: vec_per_thread;
|
||||
@ -320,19 +337,144 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
|
||||
});
|
||||
}
|
||||
|
||||
void kai_quant_pack_lhs_int4_mm(
|
||||
static void kai_quant_pack_lhs_int4_mm_bf16_channelwise(
|
||||
const Tensor& output,
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
const int64_t m,
|
||||
const int64_t n,
|
||||
const int64_t k) {
|
||||
// Kernel IDs for GEMM and GEMV
|
||||
constexpr kai_kernel_id gemm_id =
|
||||
kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm;
|
||||
constexpr kai_kernel_id gemv_id =
|
||||
kai_kernel_id::matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod;
|
||||
|
||||
// Get total threads and select kernel
|
||||
const int64_t total_threads = at::get_num_threads();
|
||||
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemv_id);
|
||||
if (cpuinfo_has_arm_i8mm() && m > 1) {
|
||||
kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemm_id);
|
||||
}
|
||||
|
||||
// Thread blocking parameters
|
||||
const int64_t n_step = kernel_packet.ukernel.get_n_step();
|
||||
const size_t mr = kernel_packet.ukernel.get_mr();
|
||||
const size_t kr = kernel_packet.ukernel.get_kr();
|
||||
const size_t sr = kernel_packet.ukernel.get_sr();
|
||||
|
||||
const size_t lhs_packed_size =
|
||||
kernel_packet.kai_get_lhs_packed_size(m, k, mr, kr, sr);
|
||||
auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size);
|
||||
uint8_t* dst_act_mtx_bf16 = reinterpret_cast<uint8_t*>(output.data_ptr());
|
||||
const uint8_t* lhs_native_mtx_bf16 =
|
||||
reinterpret_cast<const uint8_t*>(input.data_ptr());
|
||||
const uint8_t* rhs_packed_mtx_qs4cx =
|
||||
reinterpret_cast<const uint8_t*>(weight.data_ptr());
|
||||
uint8_t* lhs_packed_base = lhs_packed.get();
|
||||
|
||||
constexpr int32_t element_size = sizeof(uint16_t);
|
||||
const size_t lhs_stride = k * element_size;
|
||||
const size_t dst_stride = n * element_size;
|
||||
|
||||
// LHS quantization packing
|
||||
int64_t vec_per_thread = get_vec_per_thread(m, total_threads, mr);
|
||||
int64_t num_threads = (m + vec_per_thread - 1) / vec_per_thread;
|
||||
const size_t src_stride = vec_per_thread * lhs_stride;
|
||||
|
||||
auto lhs_quant_pack = [=, &kernel_packet](int64_t thread_id) {
|
||||
const auto lhs_src_ptr = lhs_native_mtx_bf16 + thread_id * src_stride;
|
||||
const int64_t m_idx = thread_id * vec_per_thread;
|
||||
auto lhs_packed_ptr = lhs_packed_base +
|
||||
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
|
||||
const int64_t vec_num = (thread_id == num_threads - 1)
|
||||
? (m - vec_per_thread * thread_id)
|
||||
: vec_per_thread;
|
||||
|
||||
kernel_packet.kai_run_lhs_quant_pack(
|
||||
vec_num,
|
||||
k,
|
||||
mr,
|
||||
kr,
|
||||
sr,
|
||||
0,
|
||||
(const uint16_t*)lhs_src_ptr,
|
||||
lhs_stride,
|
||||
lhs_packed_ptr);
|
||||
};
|
||||
|
||||
at::parallel_for(
|
||||
0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t thread_id = begin; thread_id < end; ++thread_id) {
|
||||
lhs_quant_pack(thread_id);
|
||||
}
|
||||
});
|
||||
|
||||
// Matrix multiplication
|
||||
vec_per_thread = get_vec_per_thread(n, total_threads, n_step);
|
||||
num_threads = (n + vec_per_thread - 1) / vec_per_thread;
|
||||
|
||||
auto mm = [=, &kernel_packet](int64_t thread_id) {
|
||||
const auto rhs_packed_ptr = rhs_packed_mtx_qs4cx +
|
||||
kernel_packet.ukernel.get_rhs_packed_offset(
|
||||
thread_id * vec_per_thread, k);
|
||||
auto dst_ptr = dst_act_mtx_bf16 +
|
||||
kernel_packet.ukernel.get_dst_offset(
|
||||
0, thread_id * vec_per_thread, dst_stride);
|
||||
const int64_t vec_num = (thread_id == num_threads - 1)
|
||||
? (n - vec_per_thread * thread_id)
|
||||
: vec_per_thread;
|
||||
|
||||
kernel_packet.ukernel.run_matmul(
|
||||
m,
|
||||
vec_num,
|
||||
k,
|
||||
lhs_packed_base,
|
||||
rhs_packed_ptr,
|
||||
(uint16_t*)dst_ptr,
|
||||
dst_stride,
|
||||
element_size, // dst_stride_col
|
||||
-FLT_MAX,
|
||||
FLT_MAX);
|
||||
};
|
||||
|
||||
at::parallel_for(
|
||||
0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t thread_id = begin; thread_id < end; ++thread_id) {
|
||||
mm(thread_id);
|
||||
}
|
||||
});
|
||||
}
|
||||
void kai_quant_pack_lhs_int4_mm(
|
||||
const at::Tensor& output,
|
||||
const at::Tensor& input,
|
||||
const at::Tensor& weight,
|
||||
const int64_t m,
|
||||
const int64_t n,
|
||||
const int64_t k,
|
||||
const int64_t bl) {
|
||||
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
|
||||
if (bl == k) {
|
||||
kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
|
||||
output, input, weight, m, n, k);
|
||||
} else if (!(bl % 32) && !(k % bl)) {
|
||||
const auto input_dtype = input.dtype();
|
||||
|
||||
if (input_dtype == at::kBFloat16) {
|
||||
if (cpuinfo_has_arm_bf16()) {
|
||||
kleidiai::kai_quant_pack_lhs_int4_mm_bf16_channelwise(
|
||||
output, input, weight, m, n, k);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"BF16 Unsupported: CPU does not support BF16. Please use a CPU with BF16 support.");
|
||||
}
|
||||
} else if (input_dtype == at::kFloat) {
|
||||
kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
|
||||
output, input, weight, m, n, k);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Unsupported input data type: Only Bfloat16 and Float inputs are supported.");
|
||||
}
|
||||
} else if ((bl % 32 == 0) && (k % bl == 0)) {
|
||||
kleidiai::kai_quant_pack_lhs_int4_mm_groupwise(
|
||||
output, input, weight, m, n, k, bl);
|
||||
}
|
||||
|
||||
@ -25,7 +25,8 @@ void kai_pack_int4_rhs(
|
||||
size_t kai_pack_rhs_int4_size(
|
||||
const int64_t n,
|
||||
const int64_t k,
|
||||
const int64_t bl);
|
||||
const int64_t bl,
|
||||
at::ScalarType tensor_dtype = at::kFloat);
|
||||
|
||||
/**
|
||||
* @brief Run 2 operations ( Input quantize and pack -> 4 bit Matmul )
|
||||
|
||||
@ -36,7 +36,8 @@ void kai_pack_rhs_groupwise_int4(
|
||||
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
|
||||
}
|
||||
|
||||
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
|
||||
float* bias_ptr =
|
||||
bias.has_value() ? bias.value().to(kFloat).data_ptr<float>() : NULL;
|
||||
auto& params = kernel.rhs_pack_params;
|
||||
|
||||
kernel.kai_run_rhs_pack(
|
||||
@ -73,7 +74,8 @@ void kai_pack_rhs_channelwise_int4(
|
||||
auto weight_packed_data =
|
||||
reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
|
||||
const auto weight_data = weight.data_ptr<uint8_t>();
|
||||
const auto scales_data = scales.data_ptr<float>();
|
||||
|
||||
const auto scales_data = scales.to(kFloat).data_ptr<float>();
|
||||
|
||||
if (weight_data == nullptr) {
|
||||
AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null");
|
||||
@ -83,7 +85,8 @@ void kai_pack_rhs_channelwise_int4(
|
||||
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
|
||||
}
|
||||
|
||||
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
|
||||
float* bias_ptr =
|
||||
bias.has_value() ? bias.value().to(kFloat).data_ptr<float>() : NULL;
|
||||
auto& params = kernel.rhs_pack_params;
|
||||
|
||||
kernel.kai_run_rhs_pack(
|
||||
|
||||
@ -68,5 +68,39 @@ kai_matmul_ukernel_f32_qa8dxp_qs4cxp kai_select_channelwise_matmul_ukernel(
|
||||
const kai_kernel_id id) {
|
||||
return channelwise_8bit_4bit_kernels.at(id);
|
||||
}
|
||||
|
||||
// Kernel Mapping - BF16 Channelwise
|
||||
std::unordered_map<kai_kernel_id, kai_matmul_ukernel_bf16_qa8dxp_qs4cxp>
|
||||
bf16_channelwise_8bit_4bit_kernels = {
|
||||
{kai_kernel_id::
|
||||
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
{{kai_get_m_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_n_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_mr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_nr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_kr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_sr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_dst_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_dst_size_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_run_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod}}},
|
||||
{kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
{{kai_get_m_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_n_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_mr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_nr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_kr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_sr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_dst_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_dst_size_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_run_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm}}}};
|
||||
|
||||
kai_matmul_ukernel_bf16_qa8dxp_qs4cxp kai_select_bf16_channelwise_matmul_ukernel(
|
||||
const kai_kernel_id id) {
|
||||
return bf16_channelwise_8bit_4bit_kernels.at(id);
|
||||
}
|
||||
} // namespace at::native::kleidiai
|
||||
#endif
|
||||
|
||||
@ -10,21 +10,32 @@
|
||||
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h>
|
||||
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h>
|
||||
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h>
|
||||
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.h>
|
||||
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.h>
|
||||
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_interface.h>
|
||||
#include <kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h>
|
||||
#include <kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h>
|
||||
#include <kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h>
|
||||
#include <kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h>
|
||||
|
||||
namespace at::native::kleidiai {
|
||||
|
||||
enum class kai_kernel_id {
|
||||
// FP32 inputs, 4-bit weights, FP32 output
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod =
|
||||
0, // Groupwise 4 bit GEMV
|
||||
0, // Groupwise 4-bit GEMV (per-group scales, NEON DOTPROD)
|
||||
matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm =
|
||||
1, // Groupwise 4 bit GEMM
|
||||
1, // Groupwise 4-bit GEMM (per-group scales, NEON I8MM)
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod =
|
||||
2, // Channelwise 4 bit GEMV
|
||||
2, // Channelwise 4-bit GEMV (per-channel scales, NEON DOTPROD)
|
||||
matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm =
|
||||
3 // Channelwise 4 bit GEMM
|
||||
3, // Channelwise 4-bit GEMM (per-channel scales, NEON I8MM)
|
||||
|
||||
// BF16 inputs, 4-bit weights, BF16 output
|
||||
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod =
|
||||
4, // Channelwise 4-bit GEMV with BF16 input/output
|
||||
matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm =
|
||||
5 // Channelwise 4-bit GEMM with BF16 input/output
|
||||
};
|
||||
|
||||
// Channelwise Kernel mapping
|
||||
@ -66,6 +77,9 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
|
||||
void* rhs_packed,
|
||||
size_t extra_bytes,
|
||||
const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
|
||||
size_t(*kai_get_lhs_quant_pack_offset)(
|
||||
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
|
||||
);
|
||||
|
||||
kai_matmul_ukernel_f32_qa8dxp_qs4cxp(
|
||||
const kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel& kernel)
|
||||
@ -75,12 +89,71 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
|
||||
kai_get_rhs_packed_size(
|
||||
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
|
||||
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
|
||||
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0) {}
|
||||
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
|
||||
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32){}
|
||||
};
|
||||
|
||||
struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp
|
||||
kai_select_channelwise_matmul_ukernel(const kai_kernel_id id);
|
||||
|
||||
// bf16 Channelwise Kernel mapping
|
||||
struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp {
|
||||
struct kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel ukernel;
|
||||
struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params rhs_pack_params;
|
||||
size_t (*kai_get_lhs_packed_size)(
|
||||
size_t m,
|
||||
size_t k,
|
||||
size_t mr,
|
||||
size_t kr,
|
||||
size_t sr);
|
||||
size_t (*kai_get_rhs_packed_size)(
|
||||
size_t n,
|
||||
size_t k,
|
||||
size_t nr,
|
||||
size_t kr,
|
||||
size_t sr);
|
||||
void (*kai_run_lhs_quant_pack)(
|
||||
size_t m,
|
||||
size_t k,
|
||||
size_t mr,
|
||||
size_t kr,
|
||||
size_t sr,
|
||||
size_t m_idx_start,
|
||||
const void* lhs,
|
||||
size_t lhs_stride,
|
||||
void* lhs_packed);
|
||||
void (*kai_run_rhs_pack)(
|
||||
size_t num_groups,
|
||||
size_t n,
|
||||
size_t k,
|
||||
size_t nr,
|
||||
size_t kr,
|
||||
size_t sr,
|
||||
const uint8_t* rhs,
|
||||
const float* bias,
|
||||
const float* scale,
|
||||
void* rhs_packed,
|
||||
size_t extra_bytes,
|
||||
const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
|
||||
size_t(*kai_get_lhs_quant_pack_offset)(
|
||||
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
|
||||
);
|
||||
|
||||
kai_matmul_ukernel_bf16_qa8dxp_qs4cxp(
|
||||
const kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel& kernel)
|
||||
: ukernel(kernel),
|
||||
kai_get_lhs_packed_size(
|
||||
&kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon),
|
||||
kai_get_rhs_packed_size(
|
||||
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
|
||||
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_bf16_neon),
|
||||
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
|
||||
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon){}
|
||||
};
|
||||
|
||||
struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp
|
||||
kai_select_bf16_channelwise_matmul_ukernel(const kai_kernel_id id);
|
||||
|
||||
// Groupwise Kernel mapping
|
||||
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
|
||||
struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel;
|
||||
@ -125,6 +198,9 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
|
||||
void* rhs_packed,
|
||||
size_t extra_bytes,
|
||||
const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params);
|
||||
size_t(*kai_get_lhs_quant_pack_offset)(
|
||||
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
|
||||
);
|
||||
|
||||
kai_matmul_ukernel_f32_qa8dxp_qs4c32p(
|
||||
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& kernel)
|
||||
@ -134,7 +210,8 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
|
||||
kai_get_rhs_packed_size(
|
||||
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
|
||||
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
|
||||
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0) {}
|
||||
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
|
||||
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32) {}
|
||||
};
|
||||
|
||||
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel(
|
||||
|
||||
@ -82,6 +82,7 @@ NSArray<NSNumber*>* getTensorAxes(const TensorBase& t);
|
||||
NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim);
|
||||
std::string getMPSShapeString(MPSShape* shape);
|
||||
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = false);
|
||||
std::string to_hex_key(float);
|
||||
std::string getArrayRefString(const IntArrayRef s);
|
||||
// use has_storage() on the returned tensor to determine if src actually is a view
|
||||
Tensor gatherViewTensor(const Tensor& src, Tensor& dst);
|
||||
|
||||
@ -301,6 +301,10 @@ std::string getArrayRefString(const IntArrayRef s) {
|
||||
return fmt::to_string(fmt::join(s, ","));
|
||||
}
|
||||
|
||||
std::string to_hex_key(float f) {
|
||||
return fmt::format("{:a}", f);
|
||||
}
|
||||
|
||||
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype, bool exclude_shape) {
|
||||
fmt::basic_memory_buffer<char, 100> buffer;
|
||||
auto buf_iterator = std::back_inserter(buffer);
|
||||
|
||||
@ -96,7 +96,9 @@ kernel void addmm(
|
||||
auto bias =
|
||||
biasData[thread_id.y * strides[3].x + thread_id.x * strides[3].y];
|
||||
outputData[thread_id.y * strides[2].x + thread_id.x * strides[2].y] =
|
||||
static_cast<T>(alpha_beta[0] * sum + alpha_beta[1] * bias);
|
||||
static_cast<T>(
|
||||
c10::metal::mul(alpha_beta[0], sum) +
|
||||
c10::metal::mul(alpha_beta[1], bias));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -121,7 +121,7 @@ Tensor& do_metal_addmm(const Tensor& self,
|
||||
const Scalar& alpha,
|
||||
const Scalar& beta,
|
||||
const Tensor& bias) {
|
||||
if (beta.toDouble() == 0 && alpha.toDouble() == 1) {
|
||||
if (beta.isFloatingPoint() && alpha.isFloatingPoint() && beta.toDouble() == 0 && alpha.toDouble() == 1) {
|
||||
return do_metal_mm(self, other, output);
|
||||
}
|
||||
auto stream = getCurrentMPSStream();
|
||||
@ -147,13 +147,15 @@ Tensor& do_metal_addmm(const Tensor& self,
|
||||
std::array<int64_t, 2> i64;
|
||||
std::array<int32_t, 2> i32;
|
||||
std::array<float, 2> f32;
|
||||
} alpha_beta;
|
||||
std::array<c10::complex<float>, 2> c64;
|
||||
} alpha_beta{};
|
||||
if (output.scalar_type() == kLong) {
|
||||
alpha_beta.i64 = {alpha.toLong(), beta.toLong()};
|
||||
} else if (c10::isIntegralType(output.scalar_type(), true)) {
|
||||
alpha_beta.i32 = {alpha.toInt(), beta.toInt()};
|
||||
} else if (c10::isComplexType(output.scalar_type())) {
|
||||
alpha_beta.c64 = {alpha.toComplexFloat(), beta.toComplexFloat()};
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(c10::isFloatingType(output.scalar_type()));
|
||||
alpha_beta.f32 = {alpha.toFloat(), beta.toFloat()};
|
||||
}
|
||||
constexpr uint32_t TILE_DIM = 16; // fastest performance from tests on multiple macs
|
||||
|
||||
@ -91,25 +91,30 @@ static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
|
||||
#include <ATen/native/mps/Repeat_metallib.h>
|
||||
#endif
|
||||
|
||||
template <typename index_t>
|
||||
void computeRepeatIndices(const index_t* repeat_ptr,
|
||||
const int64_t* cumsum_ptr,
|
||||
index_t* result_ptr,
|
||||
int64_t size,
|
||||
int64_t result_size) {
|
||||
id<MTLBuffer> repeatBuffer = reinterpret_cast<id<MTLBuffer>>(repeat_ptr);
|
||||
id<MTLBuffer> cumsumBuffer = reinterpret_cast<id<MTLBuffer>>(cumsum_ptr);
|
||||
id<MTLBuffer> resultBuffer = reinterpret_cast<id<MTLBuffer>>(result_ptr);
|
||||
TORCH_CHECK(repeatBuffer && cumsumBuffer && resultBuffer);
|
||||
|
||||
Tensor repeat_interleave_mps(const Tensor& repeat, std::optional<int64_t> output_size) {
|
||||
TORCH_CHECK(repeat.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
|
||||
std::string scalar_type;
|
||||
if constexpr (std::is_same_v<index_t, int32_t>) {
|
||||
if (repeat.scalar_type() == kInt) {
|
||||
scalar_type = "int32_t";
|
||||
} else if constexpr (std::is_same_v<index_t, int64_t>) {
|
||||
} else if (repeat.scalar_type() == kLong) {
|
||||
scalar_type = "int64_t";
|
||||
} else {
|
||||
TORCH_CHECK(false, "repeat_interleave: unsupported indexing data type");
|
||||
TORCH_CHECK(false, "repeats has to be Long or Int tensor");
|
||||
}
|
||||
if (repeat.size(0) == 0) {
|
||||
return at::empty_like(repeat, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
}
|
||||
Tensor repeat_ = repeat.contiguous();
|
||||
Tensor cumsum = repeat.cumsum(0);
|
||||
int64_t total = 0;
|
||||
if (output_size.has_value()) {
|
||||
total = output_size.value();
|
||||
} else {
|
||||
total = cumsum[-1].item<int64_t>();
|
||||
TORCH_CHECK((repeat >= 0).all().item<uint8_t>(), "repeats can not be negative");
|
||||
}
|
||||
|
||||
auto result = at::empty({total}, repeat.options());
|
||||
|
||||
MPSStream* mpsStream = getCurrentMPSStream();
|
||||
dispatch_sync(mpsStream->queue(), ^() {
|
||||
@ -121,20 +126,13 @@ void computeRepeatIndices(const index_t* repeat_ptr,
|
||||
getMPSProfiler().beginProfileKernel(pipelineState, "repeat_interleave:" + scalar_type, false);
|
||||
|
||||
[computeEncoder setComputePipelineState:pipelineState];
|
||||
mps::mtl_setArgs(computeEncoder, repeatBuffer, cumsumBuffer, resultBuffer, size);
|
||||
mps::mtl_dispatch1DJob(computeEncoder, pipelineState, size);
|
||||
mps::mtl_setArgs(computeEncoder, repeat_, cumsum, result, repeat.size(0));
|
||||
mps::mtl_dispatch1DJob(computeEncoder, pipelineState, repeat.size(0));
|
||||
|
||||
getMPSProfiler().endProfileKernel(pipelineState);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Tensor repeat_interleave_mps(const Tensor& repeat, std::optional<int64_t> output_size) {
|
||||
Tensor output;
|
||||
AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_mps", [&]() {
|
||||
output = repeat_interleave_common<index_t, computeRepeatIndices<index_t>>(repeat, output_size);
|
||||
});
|
||||
return output;
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <ATen/native/TensorCompare.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <algorithm>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
@ -89,13 +90,21 @@ static void check_min_max_dims(const OptionalTensorRef clamp_opt, const Tensor&
|
||||
auto clamp_shape = clamp_opt->sizes();
|
||||
auto input_shape = input_t.sizes();
|
||||
|
||||
TORCH_CHECK(num_clamp_dims <= num_input_dims,
|
||||
op_name + ": clamp tensor number of dims must not be greater than that of input tensor")
|
||||
if (num_clamp_dims > num_input_dims) {
|
||||
auto leading_dims = num_clamp_dims - num_input_dims;
|
||||
for (int64_t i = 0; i < leading_dims; ++i) {
|
||||
TORCH_CHECK(clamp_shape[i] == 1,
|
||||
op_name + ": clamp tensor leading shape must be 1 to broadcast with input tensor");
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_clamp_dims; i++)
|
||||
auto clamp_idx = num_clamp_dims - 1;
|
||||
auto input_idx = num_input_dims - 1;
|
||||
auto common_dims = std::min(num_clamp_dims, num_input_dims);
|
||||
for (int64_t i = 0; i < common_dims; ++i)
|
||||
// One of the indices is allowed to be 1; will be handled by broadcast
|
||||
TORCH_CHECK(clamp_shape[num_clamp_dims - 1 - i] == input_shape[num_input_dims - 1 - i] ||
|
||||
clamp_shape[num_clamp_dims - 1 - i] == 1 || input_shape[num_input_dims - 1 - i] == 1,
|
||||
TORCH_CHECK(clamp_shape[clamp_idx - i] == input_shape[input_idx - i] || clamp_shape[clamp_idx - i] == 1 ||
|
||||
input_shape[input_idx - i] == 1,
|
||||
op_name + ": clamp tensor trailing shape must match input tensor")
|
||||
}
|
||||
}
|
||||
@ -136,9 +145,6 @@ static void clamp_tensor_out_mps(const Tensor& input_t,
|
||||
|
||||
auto result_type = output_t.scalar_type();
|
||||
|
||||
IntArrayRef new_min_shape;
|
||||
IntArrayRef new_max_shape;
|
||||
|
||||
auto num_min_dims = min_opt->dim();
|
||||
auto num_max_dims = max_opt->dim();
|
||||
auto num_input_dims = input_t.dim();
|
||||
@ -146,24 +152,32 @@ static void clamp_tensor_out_mps(const Tensor& input_t,
|
||||
std::vector<int64_t> new_min_arr(num_input_dims);
|
||||
std::vector<int64_t> new_max_arr(num_input_dims);
|
||||
|
||||
if (has_min && num_min_dims < num_input_dims) {
|
||||
fill_new_shape(num_input_dims, num_min_dims, new_min_arr.data(), min_opt->sizes());
|
||||
new_min_shape = IntArrayRef(new_min_arr);
|
||||
}
|
||||
|
||||
if (has_max && num_max_dims < num_input_dims) {
|
||||
fill_new_shape(num_input_dims, num_max_dims, new_max_arr.data(), max_opt->sizes());
|
||||
new_max_shape = IntArrayRef(new_max_arr);
|
||||
}
|
||||
|
||||
Tensor min_opt_tensor;
|
||||
Tensor max_opt_tensor;
|
||||
|
||||
auto reshape_clamp_tensor = [&](const OptionalTensorRef clamp_tensor_ref,
|
||||
int64_t num_clamp_dims,
|
||||
std::vector<int64_t>& new_shape_storage) -> Tensor {
|
||||
IntArrayRef clamp_shape = clamp_tensor_ref->sizes();
|
||||
bool requires_view = false;
|
||||
|
||||
if (num_clamp_dims > num_input_dims) {
|
||||
clamp_shape = clamp_shape.slice(num_clamp_dims - num_input_dims);
|
||||
requires_view = true;
|
||||
} else if (num_clamp_dims < num_input_dims) {
|
||||
fill_new_shape(num_input_dims, num_clamp_dims, new_shape_storage.data(), clamp_shape);
|
||||
clamp_shape = IntArrayRef(new_shape_storage);
|
||||
requires_view = true;
|
||||
}
|
||||
|
||||
return requires_view ? (*clamp_tensor_ref).view(clamp_shape) : *clamp_tensor_ref;
|
||||
};
|
||||
|
||||
if (has_min) {
|
||||
min_opt_tensor = (num_min_dims < num_input_dims) ? (*min_opt).view(new_min_shape) : *min_opt;
|
||||
min_opt_tensor = reshape_clamp_tensor(min_opt, num_min_dims, new_min_arr);
|
||||
}
|
||||
if (has_max) {
|
||||
max_opt_tensor = (num_max_dims < num_input_dims) ? (*max_opt).view(new_max_shape) : *max_opt;
|
||||
max_opt_tensor = reshape_clamp_tensor(max_opt, num_max_dims, new_max_arr);
|
||||
}
|
||||
|
||||
@autoreleasepool {
|
||||
@ -244,8 +258,8 @@ static void clamp_scalar_out_mps(const Tensor& input_t,
|
||||
|
||||
@autoreleasepool {
|
||||
// the optional min/max refs could affect how we build the cached graph
|
||||
std::string key = op_name + (has_min ? ("_min:" + std::to_string(min_scalar)) : "") +
|
||||
(has_max ? ("_max:" + std::to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
|
||||
std::string key = op_name + (has_min ? ("_min:" + to_hex_key(min_scalar)) : "") +
|
||||
(has_max ? ("_max:" + to_hex_key(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
if (has_min)
|
||||
newCachedGraph->minTensor = [mpsGraph constantWithScalar:min_scalar
|
||||
|
||||
@ -61,6 +61,7 @@ list(APPEND ATen_CUDA_TEST_SRCS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_math_test.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_test.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_cub_test.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_cublas_handle_pool_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_device_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_distributions_test.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_dlconvertor_test.cpp
|
||||
|
||||
77
aten/src/ATen/test/cuda_cublas_handle_pool_test.cpp
Normal file
77
aten/src/ATen/test/cuda_cublas_handle_pool_test.cpp
Normal file
@ -0,0 +1,77 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDACachingAllocator.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
// Test concurrent access to getCurrentCUDABlasHandle and getCUDABlasLtWorkspace
|
||||
// to verify that the data race fix is working correctly
|
||||
|
||||
TEST(CUDABlasHandlePoolTest, ConcurrentGetAndClearWorkspaces) {
|
||||
if (!at::cuda::is_available()) {
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int num_accessor_threads = 15;
|
||||
constexpr int num_clear_threads = 5;
|
||||
constexpr int iterations_per_thread = 50;
|
||||
|
||||
std::atomic<bool> stop{false};
|
||||
std::atomic<int> error_count{0};
|
||||
std::vector<std::thread> threads;
|
||||
threads.reserve(num_accessor_threads + num_clear_threads);
|
||||
|
||||
// Launch accessor threads
|
||||
for (int i = 0; i < num_accessor_threads; ++i) {
|
||||
threads.emplace_back([&stop, &error_count]() {
|
||||
try {
|
||||
at::cuda::CUDAGuard device_guard(0);
|
||||
|
||||
while (!stop.load(std::memory_order_relaxed)) {
|
||||
const auto handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
const auto workspace = at::cuda::getCUDABlasLtWorkspace();
|
||||
|
||||
if (handle == nullptr || workspace == nullptr) {
|
||||
error_count++;
|
||||
}
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
error_count++;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Launch threads that clear workspaces
|
||||
for (int i = 0; i < num_clear_threads; ++i) {
|
||||
threads.emplace_back([&error_count]() {
|
||||
try {
|
||||
for (int j = 0; j < iterations_per_thread; ++j) {
|
||||
at::cuda::clearCublasWorkspaces();
|
||||
std::this_thread::yield();
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
error_count++;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Let them run for a bit
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
stop.store(true, std::memory_order_relaxed);
|
||||
|
||||
for (auto& thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
|
||||
EXPECT_EQ(error_count.load(), 0);
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
c10::cuda::CUDACachingAllocator::init(1);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
@ -10,6 +10,13 @@
|
||||
...
|
||||
}
|
||||
|
||||
{
|
||||
ignore_empty_generic_uninitialised_conditional_jump
|
||||
Memcheck:Cond
|
||||
fun:_ZN2at6detail13empty_genericEN3c108ArrayRefIlEEPNS1_9AllocatorENS1_14DispatchKeySetENS1_10ScalarTypeESt8optionalINS1_12MemoryFormatEE
|
||||
...
|
||||
}
|
||||
|
||||
{
|
||||
Cond_cuda
|
||||
Memcheck:Cond
|
||||
|
||||
@ -9,28 +9,61 @@ def check_perf_csv(filename, threshold, threshold_scale):
|
||||
"""
|
||||
Basic performance checking.
|
||||
"""
|
||||
try:
|
||||
df = pd.read_csv(filename)
|
||||
except FileNotFoundError:
|
||||
print(f"Error: File {filename} not found")
|
||||
sys.exit(1)
|
||||
|
||||
df = pd.read_csv(filename)
|
||||
effective_threshold = threshold * threshold_scale
|
||||
print(f"Checking {filename} (speedup threshold >= {effective_threshold:.2f}x)\n")
|
||||
|
||||
failed = []
|
||||
for _, row in df.iterrows():
|
||||
model_name = row["name"]
|
||||
speedup = row["speedup"]
|
||||
if speedup < threshold * threshold_scale:
|
||||
failed.append(model_name)
|
||||
speedup = float(row["speedup"])
|
||||
abs_latency = float(row["abs_latency"])
|
||||
compilation_latency = float(row["compilation_latency"])
|
||||
compression_ratio = float(row["compression_ratio"])
|
||||
eager_peak_mem = float(row["eager_peak_mem"])
|
||||
dynamo_peak_mem = float(row["dynamo_peak_mem"])
|
||||
|
||||
print(f"{model_name:34} {speedup}")
|
||||
perf_summary = f"{model_name:34} speedup={speedup:.3f}x"
|
||||
if pd.notna(abs_latency):
|
||||
perf_summary += f", latency={abs_latency:.1f} ms/iter"
|
||||
if pd.notna(compilation_latency):
|
||||
perf_summary += f", compile={compilation_latency:.3f}s"
|
||||
if pd.notna(compression_ratio):
|
||||
perf_summary += f", mem_ratio={1 / compression_ratio:.2f}x"
|
||||
if pd.notna(eager_peak_mem) and pd.notna(dynamo_peak_mem):
|
||||
perf_summary += (
|
||||
f" (eager={eager_peak_mem:.1f} GB, dynamo={dynamo_peak_mem:.1f} GB)"
|
||||
)
|
||||
|
||||
if speedup < effective_threshold:
|
||||
failed.append((model_name, speedup))
|
||||
|
||||
print(perf_summary)
|
||||
|
||||
if failed:
|
||||
print(
|
||||
textwrap.dedent(
|
||||
f"""
|
||||
Error {len(failed)} models performance regressed
|
||||
{" ".join(failed)}
|
||||
Error {len(failed)} model(s) performance regressed
|
||||
{" ".join([name for name, _ in failed])}
|
||||
"""
|
||||
)
|
||||
)
|
||||
for name, sp in sorted(failed, key=lambda x: x[1]):
|
||||
pct_from_target = (sp / effective_threshold - 1.0) * 100.0
|
||||
print(
|
||||
f" - {name}: {sp:.3f}x (< {effective_threshold:.2f}x; {pct_from_target:.1f}% from target)"
|
||||
)
|
||||
sys.exit(1)
|
||||
else:
|
||||
print(
|
||||
f"\nAll {len(df)} model(s) passed threshold check (>= {effective_threshold:.2f}x)"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -44,7 +77,7 @@ if __name__ == "__main__":
|
||||
"-s",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="multiple threshold by this value to relax the check",
|
||||
help="multiply threshold by this value to relax the check",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
check_perf_csv(args.file, args.threshold, args.threshold_scale)
|
||||
|
||||
@ -189,6 +189,10 @@ skip:
|
||||
- hf_Whisper
|
||||
- hf_distil_whisper
|
||||
- timm_vision_transformer_large
|
||||
# https://github.com/pytorch/pytorch/issues/167895
|
||||
- stable_diffusion
|
||||
- stable_diffusion_text_encoder
|
||||
- stable_diffusion_unet
|
||||
|
||||
device:
|
||||
cpu:
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
# These load paths point to different files in internal and OSS environment
|
||||
|
||||
load("@bazel_skylib//lib:paths.bzl", "paths")
|
||||
load("//tools/build_defs:cell_defs.bzl", "get_fbsource_cell")
|
||||
load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
|
||||
load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library")
|
||||
load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
|
||||
@ -590,6 +591,9 @@ def pt_operator_query_codegen(
|
||||
pt_allow_forced_schema_registration = True,
|
||||
compatible_with = [],
|
||||
apple_sdks = None):
|
||||
if get_fbsource_cell() == "fbcode":
|
||||
return
|
||||
|
||||
oplist_dir_name = name + "_pt_oplist"
|
||||
|
||||
# @lint-ignore BUCKLINT
|
||||
@ -865,6 +869,9 @@ def define_buck_targets(
|
||||
pt_xplat_cxx_library = fb_xplat_cxx_library,
|
||||
c2_fbandroid_xplat_compiler_flags = [],
|
||||
labels = []):
|
||||
if get_fbsource_cell() == "fbcode":
|
||||
return
|
||||
|
||||
# @lint-ignore BUCKLINT
|
||||
fb_native.filegroup(
|
||||
name = "metal_build_srcs",
|
||||
|
||||
@ -44,7 +44,7 @@ struct C10_API SafePyObject {
|
||||
(*other.pyinterpreter_)->incref(other.data_);
|
||||
}
|
||||
if (data_ != nullptr) {
|
||||
(*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
|
||||
(*pyinterpreter_)->decref(data_);
|
||||
}
|
||||
data_ = other.data_;
|
||||
pyinterpreter_ = other.pyinterpreter_;
|
||||
@ -53,7 +53,7 @@ struct C10_API SafePyObject {
|
||||
|
||||
~SafePyObject() {
|
||||
if (data_ != nullptr) {
|
||||
(*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
|
||||
(*pyinterpreter_)->decref(data_);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -34,20 +34,6 @@ namespace c10 {
|
||||
// See [dtype Macros note] in torch/headeronly/core/ScalarType.h
|
||||
// regarding macros.
|
||||
|
||||
template <typename T>
|
||||
struct CppTypeToScalarType;
|
||||
|
||||
#define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \
|
||||
template <> \
|
||||
struct CppTypeToScalarType<cpp_type> \
|
||||
: std:: \
|
||||
integral_constant<c10::ScalarType, c10::ScalarType::scalar_type> { \
|
||||
};
|
||||
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
|
||||
|
||||
#undef SPECIALIZE_CppTypeToScalarType
|
||||
|
||||
#define DEFINE_CONSTANT(_, name) \
|
||||
constexpr ScalarType k##name = ScalarType::name;
|
||||
|
||||
@ -106,13 +92,6 @@ inline bool isComplexType(ScalarType t) {
|
||||
t == ScalarType::ComplexDouble);
|
||||
}
|
||||
|
||||
inline bool isQIntType(ScalarType t) {
|
||||
// Don't forget to extend this when adding new QInt types
|
||||
return t == ScalarType::QInt8 || t == ScalarType::QUInt8 ||
|
||||
t == ScalarType::QInt32 || t == ScalarType::QUInt4x2 ||
|
||||
t == ScalarType::QUInt2x4;
|
||||
}
|
||||
|
||||
inline bool isBitsType(ScalarType t) {
|
||||
return t == ScalarType::Bits1x8 || t == ScalarType::Bits2x4 ||
|
||||
t == ScalarType::Bits4x2 || t == ScalarType::Bits8 ||
|
||||
|
||||
@ -48,6 +48,30 @@ void warnDeprecatedDataPtr() {
|
||||
TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid.");
|
||||
}
|
||||
|
||||
void StorageImpl::incref_pyobject() const {
|
||||
// Because intrusive_ptr incref uses relaxed memory order, we need to
|
||||
// do an acquire fence to ensure that the kHasPyObject bit was
|
||||
// observed before the load of the PyObject* below.
|
||||
// NB: This is a no-op on x86/x86-64
|
||||
std::atomic_thread_fence(std::memory_order_acquire);
|
||||
|
||||
PyObject* obj = pyobj_slot_.load_pyobj();
|
||||
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
|
||||
}
|
||||
|
||||
void StorageImpl::decref_pyobject() const {
|
||||
PyObject* obj = pyobj_slot_.load_pyobj();
|
||||
(*pyobj_slot_.pyobj_interpreter())->decref(obj);
|
||||
}
|
||||
|
||||
bool StorageImpl::try_incref_pyobject() const {
|
||||
c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
|
||||
if (C10_UNLIKELY(!interp)) {
|
||||
return false;
|
||||
}
|
||||
return (*interp)->try_incref(pyobj_slot_);
|
||||
}
|
||||
|
||||
void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) {
|
||||
// Allowlist verification.
|
||||
// Only if the devicetype is in the allowlist,
|
||||
|
||||
@ -105,6 +105,12 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
|
||||
data_ptr_.clear();
|
||||
}
|
||||
|
||||
void incref_pyobject() const override final;
|
||||
|
||||
void decref_pyobject() const override final;
|
||||
|
||||
bool try_incref_pyobject() const override final;
|
||||
|
||||
size_t nbytes() const {
|
||||
// OK to do this instead of maybe_as_int as nbytes is guaranteed positive
|
||||
TORCH_CHECK(!size_bytes_is_heap_allocated_);
|
||||
@ -370,4 +376,18 @@ C10_API c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
|
||||
bool resizable,
|
||||
std::optional<at::Device> device_opt);
|
||||
|
||||
namespace detail {
|
||||
|
||||
#ifndef C10_MOBILE
|
||||
template <class T>
|
||||
struct TargetTraits<
|
||||
T,
|
||||
std::enable_if_t<
|
||||
std::is_base_of_v<c10::StorageImpl, std::remove_cv_t<T>>>> {
|
||||
static constexpr bool can_have_pyobject = true;
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace detail
|
||||
|
||||
} // namespace c10
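The TargetTraits opt-in added above uses a standard partial-specialization idiom. A minimal, self-contained sketch under stated assumptions — Base and StorageLike are stand-in types, not the real c10 classes:

#include <type_traits>

struct Base {};
struct StorageLike : Base {};

// Default: a target type cannot carry a PyObject wrapper.
template <class T, class = void>
struct TargetTraits {
  static constexpr bool can_have_pyobject = false;
};

// Opt-in: any type derived from StorageLike (cv-qualifiers ignored) can.
template <class T>
struct TargetTraits<
    T,
    std::enable_if_t<std::is_base_of_v<StorageLike, std::remove_cv_t<T>>>> {
  static constexpr bool can_have_pyobject = true;
};

static_assert(!TargetTraits<Base>::can_have_pyobject);
static_assert(TargetTraits<const StorageLike>::can_have_pyobject);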
|
||||
|
||||
@ -277,7 +277,6 @@ void TensorImpl::release_resources() {
|
||||
if (storage_) {
|
||||
storage_ = {};
|
||||
}
|
||||
pyobj_slot_.maybe_destroy_pyobj();
|
||||
}
|
||||
|
||||
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
|
||||
@ -989,6 +988,30 @@ void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) {
|
||||
}
|
||||
}
|
||||
|
||||
void TensorImpl::incref_pyobject() const {
|
||||
// Because intrusive_ptr incref uses relaxed memory order, we need to
|
||||
// do an acquire fence to ensure that the kHasPyObject bit was
|
||||
// observed before the load of the PyObject* below.
|
||||
// NB: This is a no-op on x86/x86-64
|
||||
std::atomic_thread_fence(std::memory_order_acquire);
|
||||
|
||||
PyObject* obj = pyobj_slot_.load_pyobj();
|
||||
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
|
||||
}
|
||||
|
||||
void TensorImpl::decref_pyobject() const {
|
||||
PyObject* obj = pyobj_slot_.load_pyobj();
|
||||
(*pyobj_slot_.pyobj_interpreter())->decref(obj);
|
||||
}
|
||||
|
||||
bool TensorImpl::try_incref_pyobject() const {
|
||||
c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
|
||||
if (C10_UNLIKELY(!interp)) {
|
||||
return false;
|
||||
}
|
||||
return (*interp)->try_incref(pyobj_slot_);
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
|
||||
namespace {
|
||||
|
||||
@ -2178,6 +2178,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
return &pyobj_slot_;
|
||||
}
|
||||
|
||||
void incref_pyobject() const override final;
|
||||
|
||||
void decref_pyobject() const override final;
|
||||
|
||||
bool try_incref_pyobject() const override final;
|
||||
|
||||
private:
|
||||
// See NOTE [std::optional operator usage in CUDA]
|
||||
// We probably don't want to expose this publicly until
|
||||
@ -3079,6 +3085,19 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
friend class C10_TensorImpl_Size_Check_Dummy_Class;
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
#ifndef C10_MOBILE
|
||||
template <class T>
|
||||
struct TargetTraits<
|
||||
T,
|
||||
std::enable_if_t<std::is_base_of_v<c10::TensorImpl, std::remove_cv_t<T>>>> {
|
||||
static constexpr bool can_have_pyobject = true;
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Note [TensorImpl size constraints]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// Changed the size of TensorImpl? If the size went down, good for
|
||||
|
||||
@ -11,8 +11,11 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
|
||||
|
||||
void incref(PyObject* pyobj) const override {} // do nothing
|
||||
|
||||
void decref(PyObject* pyobj, bool has_pyobj_slot) const override {
|
||||
} // do nothing
|
||||
void decref(PyObject* pyobj) const override {} // do nothing
|
||||
|
||||
bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const override {
|
||||
return false;
|
||||
}
|
||||
|
||||
#define PANIC(m) \
|
||||
TORCH_INTERNAL_ASSERT( \
|
||||
@ -20,6 +23,10 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
|
||||
"attempted to call " #m \
|
||||
" on a Tensor with nontrivial PyObject after corresponding interpreter died")
|
||||
|
||||
size_t refcnt(PyObject* pyobj) const override {
|
||||
PANIC(refcnt);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<TensorImpl> detach(const TensorImpl* self) const override {
|
||||
PANIC(detach);
|
||||
}
|
||||
|
||||
@ -18,6 +18,9 @@ namespace c10 {
|
||||
struct IValue;
|
||||
class OperatorHandle;
|
||||
struct TensorImpl;
|
||||
namespace impl {
|
||||
struct PyObjectSlot;
|
||||
} // namespace impl
|
||||
} // namespace c10
|
||||
|
||||
namespace torch::jit {
|
||||
@ -126,9 +129,12 @@ struct C10_API PyInterpreterVTable {
|
||||
|
||||
// Run Py_INCREF on a PyObject.
|
||||
virtual void incref(PyObject* pyobj) const = 0;
|
||||
// Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call
|
||||
// See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
|
||||
virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0;
|
||||
// Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call.
|
||||
virtual void decref(PyObject* pyobj) const = 0;
|
||||
// Run PyUnstable_TryIncRef on a PyObject if it's not NULL.
|
||||
virtual bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const = 0;
|
||||
// Run Py_REFCNT on a PyObject.
|
||||
virtual size_t refcnt(PyObject* pyobj) const = 0;
|
||||
|
||||
// Perform a detach by deferring to the __torch_dispatch__ implementation of
|
||||
// detach, which will also arrange for the PyObject to get copied in this
|
||||
|
||||
@ -1,56 +0,0 @@
|
||||
#include <c10/core/impl/PyObjectSlot.h>
|
||||
|
||||
namespace c10::impl {
|
||||
|
||||
PyObjectSlot::PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
|
||||
|
||||
PyObjectSlot::~PyObjectSlot() {
|
||||
maybe_destroy_pyobj();
|
||||
}
|
||||
|
||||
void PyObjectSlot::maybe_destroy_pyobj() {
|
||||
if (owns_pyobj()) {
|
||||
TORCH_INTERNAL_ASSERT(pyobj_interpreter_ != nullptr);
|
||||
TORCH_INTERNAL_ASSERT(pyobj_ != nullptr);
|
||||
(*pyobj_interpreter_.load(std::memory_order_acquire))
|
||||
->decref(_unchecked_untagged_pyobj(), /*has_pyobj_slot*/ true);
|
||||
// NB: this destructor can only be entered when there are no
|
||||
// references to this C++ object (obviously), NOR any references
|
||||
// to the PyObject (if there are references to the PyObject,
|
||||
// then the PyObject holds an owning reference to the tensor).
|
||||
// So it is OK to clear pyobj_ here as it is impossible for it to
|
||||
// be used again (modulo weak reference races)
|
||||
pyobj_ = nullptr; // for safety
|
||||
}
|
||||
}
|
||||
|
||||
PyInterpreter* PyObjectSlot::pyobj_interpreter() {
|
||||
return pyobj_interpreter_.load(std::memory_order_acquire);
|
||||
}
|
||||
|
||||
PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const {
|
||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||
return reinterpret_cast<PyObject*>(
|
||||
reinterpret_cast<uintptr_t>(pyobj_) & ~0x1ULL);
|
||||
}
|
||||
|
||||
PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const {
|
||||
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
|
||||
if (interpreter) {
|
||||
return *interpreter;
|
||||
}
|
||||
TORCH_CHECK(false, "cannot access PyObject for Tensor - no interpreter set");
|
||||
}
|
||||
|
||||
bool PyObjectSlot::owns_pyobj() {
|
||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||
return reinterpret_cast<uintptr_t>(pyobj_) & 1;
|
||||
}
|
||||
|
||||
void PyObjectSlot::set_owns_pyobj(bool b) {
|
||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||
pyobj_ = reinterpret_cast<PyObject*>(
|
||||
reinterpret_cast<uintptr_t>(_unchecked_untagged_pyobj()) | b);
|
||||
}
|
||||
|
||||
} // namespace c10::impl
|
||||
@ -8,117 +8,58 @@
|
||||
|
||||
#include <atomic>
|
||||
|
||||
namespace torch::utils {
|
||||
class PyObjectPreservation;
|
||||
}
|
||||
|
||||
namespace c10::impl {
|
||||
|
||||
struct C10_API PyObjectSlot {
|
||||
public:
|
||||
PyObjectSlot();
|
||||
|
||||
~PyObjectSlot();
|
||||
|
||||
void maybe_destroy_pyobj();
|
||||
|
||||
// Associate the TensorImpl with the specified PyObject, and, if necessary,
|
||||
// also tag the interpreter.
|
||||
//
|
||||
// NB: This lives in a header so that we can inline away the switch on status
|
||||
//
|
||||
// NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after
|
||||
// PyObject if necessary!
|
||||
void init_pyobj(PyObject* pyobj) {
|
||||
pyobj_interpreter_.store(
|
||||
getGlobalPyInterpreter(), std::memory_order_relaxed);
|
||||
pyobj_ = pyobj;
|
||||
}
|
||||
PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
|
||||
|
||||
// Query the PyObject interpreter. This may return null if there is no
|
||||
// interpreter. This is racy!
|
||||
PyInterpreter* pyobj_interpreter();
|
||||
|
||||
PyObject* _unchecked_untagged_pyobj() const;
|
||||
|
||||
// Test the interpreter tag. If tagged for the current interpreter, return
|
||||
// a non-nullopt (but possibly null) PyObject. If (possibly) untagged,
|
||||
// returns a nullopt. If it is definitely invalid, raises an error.
|
||||
//
|
||||
// If `ignore_hermetic_tls` is false and this function is called from a
|
||||
// hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then
|
||||
// nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic
|
||||
// context is ignored, allowing you to check the interpreter tag of a
|
||||
// nonhermetic PyObject from within a hermetic context. This is necessary
|
||||
// because there are some cases where the deallocator function of a
|
||||
// nonhermetic PyObject is called from within a hermetic context, so it must
|
||||
// be properly treated as a nonhermetic PyObject.
|
||||
//
|
||||
// NB: this lives in header so that we can avoid actually creating the
|
||||
// std::optional
|
||||
|
||||
// @todo alban: I'm not too sure what's going on here, we can probably delete
|
||||
// it but it's worthwhile making sure
|
||||
std::optional<PyObject*> check_pyobj(bool ignore_hermetic_tls = false) const {
|
||||
impl::PyInterpreter* interpreter =
|
||||
pyobj_interpreter_.load(std::memory_order_acquire);
|
||||
if (interpreter == nullptr) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
|
||||
return std::nullopt;
|
||||
} else {
|
||||
return _unchecked_untagged_pyobj();
|
||||
}
|
||||
// interpreter.
|
||||
PyInterpreter* pyobj_interpreter() const {
|
||||
return pyobj_interpreter_.load(std::memory_order_acquire);
|
||||
}
|
||||
|
||||
PyInterpreter& load_pyobj_interpreter() const;
|
||||
PyInterpreter& load_pyobj_interpreter() const {
|
||||
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
interpreter, "cannot access PyObject for Tensor - no interpreter set");
|
||||
return *interpreter;
|
||||
}
|
||||
|
||||
bool owns_pyobj();
|
||||
PyObject* load_pyobj() const {
|
||||
return pyobj_.load(std::memory_order_acquire);
|
||||
}
|
||||
|
||||
void set_owns_pyobj(bool b);
|
||||
void store_pyobj(PyObject* obj) {
|
||||
pyobj_.store(obj, std::memory_order_release);
|
||||
}
|
||||
|
||||
bool has_unique_reference() const {
|
||||
PyObject* pyobj = load_pyobj();
|
||||
return pyobj != nullptr && load_pyobj_interpreter()->refcnt(pyobj) == 1;
|
||||
}
|
||||
|
||||
void clear() {
|
||||
pyobj_.store(nullptr, std::memory_order_relaxed);
|
||||
pyobj_interpreter_.store(nullptr, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
private:
|
||||
// This field contains the interpreter tag for this object. See
|
||||
// Note [Python interpreter tag] for general context
|
||||
//
|
||||
// Note [Memory ordering on Python interpreter tag]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// What memory_order do we need when accessing this atomic? We don't
|
||||
// need a single total modification order (as provided by
|
||||
// memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only
|
||||
// transition from -1 to some positive integer and never changes afterwards.
|
||||
// Because there is only one modification, it trivially already has a total
|
||||
// modification order (e.g., we don't need fences or locked instructions on
|
||||
// x86)
|
||||
//
|
||||
// In fact, one could make a reasonable argument that relaxed reads are OK,
|
||||
// due to the presence of external locking (GIL) to ensure that interactions
|
||||
// with other data structures are still correctly synchronized, so that
|
||||
// we fall in the "Single-Location Data Structures" case as described in
|
||||
// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
|
||||
// However, on x86, it doesn't matter if I use acquire or relaxed on the load
|
||||
// as I get the same assembly in both cases. So I just use the more
|
||||
// conservative acquire (which will impede compiler optimizations but I don't
|
||||
// care)
|
||||
// This is now always the global interpreter if the PyObject is set.
|
||||
// Maybe we can remove this field some day...
|
||||
std::atomic<PyInterpreter*> pyobj_interpreter_;
|
||||
|
||||
// This field contains a reference to a PyObject representing this Tensor.
|
||||
// If pyobj is nullptr, when we transfer Tensor to Python, we allocate a new
|
||||
// PyObject for it and set this field. This field does not have to be
|
||||
// protected by an atomic as it is only allowed to be accessed when you hold
|
||||
// the GIL, or during destruction of the tensor.
|
||||
//
|
||||
// When a PyObject dies, you are obligated to clear this field
|
||||
// (otherwise, you will try to use-after-free the pyobj); this currently
|
||||
// occurs in THPVariable_clear in torch/csrc/autograd/python_variable.cpp
|
||||
//
|
||||
// NB: Ordinarily, this should not be a strong reference, as if the
|
||||
// PyObject owns the Tensor, this would create a reference cycle.
|
||||
// However, sometimes this ownership flips. To track who owns
|
||||
// who, this has a single pointer tag indicating whether or not the
|
||||
// C++ object owns the PyObject (the common case, zero, means PyObject
|
||||
// owns the C++ object); see _unchecked_untagged_pyobj for raw access
|
||||
// or check_pyobj for checked access. See references to PyObject
|
||||
// resurrection in torch/csrc/autograd/python_variable.cpp
|
||||
PyObject* pyobj_;
|
||||
// The PyObject representing this Tensor or nullptr. Ownership is managed
|
||||
// by intrusive_ptr. By the time the PyObjectSlot is destroyed, this
|
||||
// reference is already dead.
|
||||
std::atomic<PyObject*> pyobj_;
|
||||
|
||||
friend class torch::utils::PyObjectPreservation;
|
||||
};
|
||||
|
||||
} // namespace c10::impl
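A minimal sketch of the new slot shape introduced above, assuming only that the two pointer fields are plain atomics; void* stands in for PyInterpreter* and PyObject*. The point is the release/acquire pairing that load_pyobj/store_pyobj rely on:

#include <atomic>

struct SlotSketch {
  std::atomic<void*> interpreter{nullptr};  // stand-in for PyInterpreter*
  std::atomic<void*> pyobj{nullptr};        // stand-in for PyObject*

  // Publish with release so a reader that observes the pointer also
  // observes the fully constructed wrapper it points to.
  void store_pyobj(void* obj) { pyobj.store(obj, std::memory_order_release); }
  void* load_pyobj() const { return pyobj.load(std::memory_order_acquire); }

  void clear() {
    pyobj.store(nullptr, std::memory_order_relaxed);
    interpreter.store(nullptr, std::memory_order_relaxed);
  }
};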
|
||||
|
||||
@ -50,7 +50,13 @@ namespace c10 {
|
||||
/// However, you should prefer to use ArrayRef when possible, because its use
|
||||
/// of TORCH_CHECK will lead to better user-facing error messages.
|
||||
template <typename T>
|
||||
class ArrayRef final : public HeaderOnlyArrayRef<T> {
|
||||
// ArrayRef cannot be derived from. Normally, we would use `final`
|
||||
// specifier to force this constraint at compile time. However, Intel
|
||||
// compiler does not recognize ArrayRef as a class template (which is
|
||||
// required in the definition of at::TensorAccessor, for instance)
|
||||
// when `final` specifier is used. So, we cannot define ArrayRef as
|
||||
// final because of the Intel compiler issue.
|
||||
class ArrayRef : public HeaderOnlyArrayRef<T> {
|
||||
public:
|
||||
/// @name Constructors, all inherited from HeaderOnlyArrayRef except for
|
||||
/// SmallVector. As inherited constructors won't work with class template
|
||||
|
||||
@ -379,7 +379,11 @@ C10_API std::string GetExceptionString(const std::exception& e);
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
#ifdef STRIP_ERROR_MESSAGES
|
||||
#define TORCH_RETHROW(e, ...) throw
|
||||
#define TORCH_RETHROW(e, ...) \
|
||||
do { \
|
||||
(void)e; /* Suppress unused variable warning */ \
|
||||
throw; \
|
||||
} while (false)
|
||||
#else
|
||||
#define TORCH_RETHROW(e, ...) \
|
||||
do { \
|
||||
|
||||
@ -12,6 +12,10 @@ template <typename, typename...>
|
||||
class class_;
|
||||
}
|
||||
|
||||
namespace torch::utils {
|
||||
class PyObjectPreservation;
|
||||
}
|
||||
|
||||
namespace c10 {
|
||||
class intrusive_ptr_target;
|
||||
namespace raw {
|
||||
@ -33,6 +37,8 @@ constexpr uint64_t kImpracticallyHugeWeakReferenceCount =
|
||||
constexpr uint64_t kReferenceCountOne = 1;
|
||||
constexpr uint64_t kWeakReferenceCountOne = (kReferenceCountOne << 32);
|
||||
constexpr uint64_t kUniqueRef = (kReferenceCountOne | kWeakReferenceCountOne);
|
||||
// Indicates whether the object has a PyObject wrapper.
|
||||
constexpr uint64_t kHasPyObject = (uint64_t(1) << 63);
|
||||
|
||||
template <class TTarget>
|
||||
struct intrusive_target_default_null_type final {
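The helpers in the next hunk decode this packing. A self-contained sketch of the layout they assume — constants are renamed here to make clear this is an illustration, not the real header: strong count in bits 0-31, weak count in bits 32-62, PyObject flag in bit 63.

#include <cassert>
#include <cstdint>

constexpr uint64_t kRefOne = 1;
constexpr uint64_t kWeakOne = kRefOne << 32;
constexpr uint64_t kPyBit = uint64_t(1) << 63;

constexpr uint32_t refcount(uint64_t c) { return static_cast<uint32_t>(c); }
constexpr uint32_t weakcount(uint64_t c) {
  return static_cast<uint32_t>((c & ~kPyBit) >> 32);
}
constexpr bool has_pyobject(uint64_t c) { return (c & kPyBit) != 0; }

int main() {
  uint64_t c = kRefOne | kWeakOne;  // freshly constructed: unique reference
  assert(refcount(c) == 1 && weakcount(c) == 1 && !has_pyobject(c));
  c |= kPyBit;  // a Python wrapper was attached
  assert(has_pyobject(c) && weakcount(c) == 1);  // flag stays out of the counts
}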
|
||||
@ -55,7 +61,11 @@ inline uint32_t refcount(uint64_t combined_refcount) {
|
||||
}
|
||||
|
||||
inline uint32_t weakcount(uint64_t combined_refcount) {
|
||||
return static_cast<uint32_t>(combined_refcount >> 32);
|
||||
return static_cast<uint32_t>((combined_refcount & ~kHasPyObject) >> 32);
|
||||
}
|
||||
|
||||
inline bool has_pyobject(uint64_t combined_refcount) {
|
||||
return (combined_refcount & kHasPyObject) != 0;
|
||||
}
|
||||
|
||||
// The only requirement for refcount increment is that it happens-before
|
||||
@ -66,12 +76,6 @@ inline uint64_t atomic_combined_refcount_increment(
|
||||
return combined_refcount.fetch_add(inc, std::memory_order_relaxed) + inc;
|
||||
}
|
||||
|
||||
inline uint32_t atomic_refcount_increment(
|
||||
std::atomic<uint64_t>& combined_refcount) {
|
||||
return detail::refcount(atomic_combined_refcount_increment(
|
||||
combined_refcount, kReferenceCountOne));
|
||||
}
|
||||
|
||||
inline uint32_t atomic_weakcount_increment(
|
||||
std::atomic<uint64_t>& combined_refcount) {
|
||||
return detail::weakcount(atomic_combined_refcount_increment(
|
||||
@ -99,6 +103,11 @@ inline uint32_t atomic_weakcount_decrement(
|
||||
combined_refcount, kWeakReferenceCountOne));
|
||||
}
|
||||
|
||||
template <class T, class = void>
|
||||
struct TargetTraits {
|
||||
static constexpr bool can_have_pyobject = false;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/**
|
||||
@ -155,6 +164,23 @@ class C10_API intrusive_ptr_target {
|
||||
// we can atomically operate on both at the same time for performance
|
||||
// and defined behaviors.
|
||||
//
|
||||
// Note [PyObject preservation for Tensor and Storages]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// intrusive_ptr has special support for preserving PyObject wrappers
|
||||
// for TensorImpl and StorageImpl. The most significant bit (kHasPyObject) of
|
||||
// the combined_refcount_ is used to indicate whether the object has a
|
||||
// PyObject wrapper.
|
||||
//
|
||||
// - The PyObject, if it exists, holds a strong reference to the
|
||||
// intrusive_ptr_target.
|
||||
//
|
||||
// - When the refcount goes from 1 to 2, we incref the PyObject.
|
||||
//
|
||||
// - When the refcount goes from 2 to 1, we decref the PyObject.
|
||||
//
|
||||
// In other words, the intrusive_ptr keeps the PyObject alive as long as there
|
||||
// are other C++ references to the intrusive_ptr_target.
|
||||
|
||||
mutable std::atomic<uint64_t> combined_refcount_;
|
||||
static_assert(sizeof(std::atomic<uint64_t>) == 8);
|
||||
static_assert(alignof(std::atomic<uint64_t>) == 8);
|
||||
@ -172,6 +198,8 @@ class C10_API intrusive_ptr_target {
|
||||
template <typename T>
|
||||
friend struct ExclusivelyOwnedTensorTraits;
|
||||
|
||||
friend class torch::utils::PyObjectPreservation;
|
||||
|
||||
protected:
|
||||
// protected destructor. We never want to destruct intrusive_ptr_target*
|
||||
// directly.
|
||||
@ -255,6 +283,16 @@ class C10_API intrusive_ptr_target {
|
||||
*/
|
||||
virtual void release_resources() {}
|
||||
|
||||
/**
|
||||
* These two methods are called when the refcount transitions between one
|
||||
* and two and the object has a PyObject wrapper.
|
||||
*/
|
||||
virtual void incref_pyobject() const {}
|
||||
virtual void decref_pyobject() const {}
|
||||
virtual bool try_incref_pyobject() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t refcount(std::memory_order order = std::memory_order_relaxed) const {
|
||||
return detail::refcount(combined_refcount_.load(order));
|
||||
}
|
||||
@ -265,6 +303,19 @@ class C10_API intrusive_ptr_target {
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
#ifndef C10_MOBILE
|
||||
template <>
|
||||
struct TargetTraits<c10::intrusive_ptr_target> {
|
||||
// A generic intrusive_ptr<intrusive_ptr_target> may actually be a TensorImpl
|
||||
// or StorageImpl, so we have to allow for PyObject support.
|
||||
static constexpr bool can_have_pyobject = true;
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <class TTarget, class NullType>
|
||||
class weak_intrusive_ptr;
|
||||
|
||||
@ -314,18 +365,34 @@ class intrusive_ptr final {
|
||||
|
||||
void retain_() {
|
||||
if (target_ != NullType::singleton()) {
|
||||
uint32_t new_refcount =
|
||||
detail::atomic_refcount_increment(target_->combined_refcount_);
|
||||
uint64_t combined = detail::atomic_combined_refcount_increment(
|
||||
target_->combined_refcount_, detail::kReferenceCountOne);
|
||||
uint32_t new_refcount = detail::refcount(combined);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
||||
new_refcount != 1,
|
||||
"intrusive_ptr: Cannot increase refcount after it reached zero.");
|
||||
|
||||
if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
|
||||
// If the refcount transitioned from 1 to 2, we need to incref the
|
||||
// PyObject. In other words, we need to ensure that the PyObject stays
|
||||
// alive now that we have a C++ reference to this object in addition to
|
||||
// the PyObject itself.
|
||||
if (C10_UNLIKELY(
|
||||
detail::has_pyobject(combined) &&
|
||||
detail::refcount(combined) == 2)) {
|
||||
target_->incref_pyobject();
|
||||
}
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
||||
!detail::has_pyobject(combined),
|
||||
"TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void reset_() noexcept {
|
||||
if (target_ != NullType::singleton()) {
|
||||
if (target_->combined_refcount_.load(std::memory_order_acquire) ==
|
||||
detail::kUniqueRef) {
|
||||
if (is_uniquely_owned()) {
|
||||
// Both counts are 1, so there are no weak references and
|
||||
// we are releasing the last strong reference. No other
|
||||
// threads can observe the effects of this target_ deletion
|
||||
@ -337,9 +404,10 @@ class intrusive_ptr final {
|
||||
|
||||
auto combined_refcount = detail::atomic_combined_refcount_decrement(
|
||||
target_->combined_refcount_, detail::kReferenceCountOne);
|
||||
if (detail::refcount(combined_refcount) == 0) {
|
||||
bool should_delete =
|
||||
(combined_refcount == detail::kWeakReferenceCountOne);
|
||||
uint32_t new_refcount = detail::refcount(combined_refcount);
|
||||
bool has_pyobject = detail::has_pyobject(combined_refcount);
|
||||
if (new_refcount == 0) {
|
||||
bool should_delete = detail::weakcount(combined_refcount) == 1;
|
||||
// See comment above about weakcount. As long as refcount>0,
|
||||
// weakcount is one larger than the actual number of weak references.
|
||||
// So we need to decrement it here.
|
||||
@ -356,6 +424,18 @@ class intrusive_ptr final {
|
||||
if (should_delete) {
|
||||
delete target_;
|
||||
}
|
||||
} else if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
|
||||
// If the refcount transitioned from 2 to 1, we need to decref the
|
||||
// PyObject. In other words, we don't want to keep the PyObject alive if
|
||||
// there are no C++ references to this object other than the PyObject
|
||||
// itself.
|
||||
if (C10_UNLIKELY(has_pyobject && new_refcount == 1)) {
|
||||
target_->decref_pyobject();
|
||||
}
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
||||
!has_pyobject,
|
||||
"TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set.");
|
||||
}
|
||||
}
|
||||
}
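A hedged, self-contained model of the two transitions driven by retain_/reset_ above. TargetSketch and the renamed constants are stand-ins; the real logic lives on intrusive_ptr_target and also handles the weakcount and deletion paths shown in the hunks:

#include <atomic>
#include <cstdint>

struct TargetSketch {
  std::atomic<uint64_t> combined{0};
  void incref_pyobject() {}
  void decref_pyobject() {}
};

constexpr uint64_t kRefOne = 1;
constexpr uint64_t kPyBit = uint64_t(1) << 63;
inline uint32_t refcount(uint64_t c) { return static_cast<uint32_t>(c); }
inline bool has_pyobject(uint64_t c) { return (c & kPyBit) != 0; }

void retain_sketch(TargetSketch* t) {
  uint64_t combined =
      t->combined.fetch_add(kRefOne, std::memory_order_relaxed) + kRefOne;
  // 1 -> 2: a second C++ owner appeared, so pin the Python wrapper.
  if (has_pyobject(combined) && refcount(combined) == 2) {
    t->incref_pyobject();
  }
}

void release_sketch(TargetSketch* t) {
  uint64_t combined =
      t->combined.fetch_sub(kRefOne, std::memory_order_acq_rel) - kRefOne;
  if (refcount(combined) == 0) {
    // weakcount bookkeeping and delete, as in the hunk above
  } else if (has_pyobject(combined) && refcount(combined) == 1) {
    // 2 -> 1: only the wrapper's own reference remains, release the pin.
    t->decref_pyobject();
  }
}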
|
||||
@ -522,6 +602,16 @@ class intrusive_ptr final {
|
||||
return use_count() == 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Stronger than unique() in that it must not have any weakrefs as well.
|
||||
*/
|
||||
bool is_uniquely_owned() const noexcept {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(target_ != NullType::singleton());
|
||||
uint64_t combined =
|
||||
target_->combined_refcount_.load(std::memory_order_acquire);
|
||||
return (combined & ~detail::kHasPyObject) == detail::kUniqueRef;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an owning (!) pointer to the underlying object and makes the
|
||||
* intrusive_ptr instance invalid. That means the refcount is not decreased.
|
||||
@ -932,6 +1022,7 @@ class weak_intrusive_ptr final {
|
||||
if (target_ == NullType::singleton()) {
|
||||
return intrusive_ptr<TTarget, NullType>();
|
||||
} else {
|
||||
bool increfed = false;
|
||||
auto combined_refcount =
|
||||
target_->combined_refcount_.load(std::memory_order_relaxed);
|
||||
do {
|
||||
@ -940,12 +1031,31 @@ class weak_intrusive_ptr final {
|
||||
// Return nullptr.
|
||||
return intrusive_ptr<TTarget, NullType>();
|
||||
}
|
||||
if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
|
||||
if (detail::has_pyobject(combined_refcount) &&
|
||||
detail::refcount(combined_refcount) == 1 && !increfed) {
|
||||
// Object has a python wrapper with no other C++ references.
|
||||
// We need to incref the Python object before we acquire a
|
||||
// strong reference to the C++ object to avoid a situation
|
||||
// where the Python object is deallocated concurrently.
|
||||
if (!target_->try_incref_pyobject()) {
|
||||
return intrusive_ptr<TTarget, NullType>();
|
||||
}
|
||||
increfed = true;
|
||||
}
|
||||
}
|
||||
} while (!target_->combined_refcount_.compare_exchange_weak(
|
||||
combined_refcount,
|
||||
combined_refcount + detail::kReferenceCountOne,
|
||||
std::memory_order_acquire,
|
||||
std::memory_order_relaxed));
|
||||
|
||||
if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
|
||||
if (increfed && detail::refcount(combined_refcount) != 1) {
|
||||
target_->decref_pyobject();
|
||||
}
|
||||
}
|
||||
|
||||
return intrusive_ptr<TTarget, NullType>(
|
||||
target_, raw::DontIncreaseRefcount{});
|
||||
}
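A hedged, self-contained model of the lock() protocol above. TargetSketch and its members are stand-ins; the real loop also folds the kHasPyObject bit and the weak count into the same 64-bit word:

#include <atomic>
#include <cstdint>

struct TargetSketch {
  std::atomic<uint32_t> refcount{1};
  bool has_pyobject = true;
  bool try_incref_pyobject() { return true; }  // model: wrapper still alive
  void decref_pyobject() {}
};

TargetSketch* lock_sketch(TargetSketch* t) {
  bool increfed = false;
  uint32_t rc = t->refcount.load(std::memory_order_relaxed);
  do {
    if (rc == 0) {
      return nullptr;  // object already dead
    }
    if (t->has_pyobject && rc == 1 && !increfed) {
      // Only the wrapper holds a strong reference: pin the PyObject first,
      // otherwise it could be deallocated while we race for the refcount.
      if (!t->try_incref_pyobject()) {
        return nullptr;
      }
      increfed = true;
    }
  } while (!t->refcount.compare_exchange_weak(
      rc, rc + 1, std::memory_order_acquire, std::memory_order_relaxed));
  if (increfed && rc != 1) {
    t->decref_pyobject();  // someone else already keeps the wrapper alive
  }
  return t;
}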
|
||||
@ -1060,7 +1170,18 @@ namespace intrusive_ptr {
|
||||
// NullType::singleton to this function
|
||||
inline void incref(intrusive_ptr_target* self) {
|
||||
if (self) {
|
||||
detail::atomic_refcount_increment(self->combined_refcount_);
|
||||
uint64_t combined = detail::atomic_combined_refcount_increment(
|
||||
self->combined_refcount_, detail::kReferenceCountOne);
|
||||
|
||||
#ifndef C10_MOBILE
|
||||
if (C10_UNLIKELY(
|
||||
detail::has_pyobject(combined) &&
|
||||
detail::refcount(combined) == 2)) {
|
||||
self->incref_pyobject();
|
||||
}
|
||||
#else
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!detail::has_pyobject(combined));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -15,6 +15,8 @@ using namespace c10::CachingDeviceAllocator;
|
||||
// newly allocated memory with 512-byte alignment.
|
||||
constexpr size_t kDeviceAlignment = 512;
|
||||
|
||||
class XPUAllocator;
|
||||
|
||||
namespace {
|
||||
using stream_set = ska::flat_hash_set<xpu::XPUStream>;
|
||||
|
||||
@ -23,14 +25,19 @@ typedef bool (*Comparison)(const Block*, const Block*);
|
||||
bool BlockComparatorSize(const Block* a, const Block* b);
|
||||
bool BlockComparatorAddress(const Block* a, const Block* b);
|
||||
|
||||
struct PrivatePool;
|
||||
|
||||
struct BlockPool {
|
||||
BlockPool(bool small)
|
||||
BlockPool(bool small, PrivatePool* private_pool = nullptr)
|
||||
: blocks(BlockComparatorSize),
|
||||
unmapped(BlockComparatorAddress),
|
||||
is_small(small) {}
|
||||
is_small(small),
|
||||
owner_PrivatePool(private_pool) {}
|
||||
|
||||
std::set<Block*, Comparison> blocks;
|
||||
std::set<Block*, Comparison> unmapped;
|
||||
const bool is_small;
|
||||
PrivatePool* owner_PrivatePool;
|
||||
};
|
||||
|
||||
struct ExpandableSegment;
|
||||
@ -349,6 +356,43 @@ struct AllocParams {
|
||||
StatTypes stat_types = {};
|
||||
};
|
||||
|
||||
// Internal implementation that manages actual memory blocks.
|
||||
// high level MemPool interface wraps PrivatePool via MempoolId.
|
||||
struct PrivatePool {
|
||||
PrivatePool(MempoolId_t id, XPUAllocator* allocator = nullptr)
|
||||
: id(std::move(id)),
|
||||
allocator_(allocator),
|
||||
large_blocks(/*small=*/false, this),
|
||||
small_blocks(/*small=*/true, this) {}
|
||||
PrivatePool(const PrivatePool&) = delete;
|
||||
PrivatePool(PrivatePool&&) = delete;
|
||||
PrivatePool& operator=(const PrivatePool&) = delete;
|
||||
PrivatePool& operator=(PrivatePool&&) = delete;
|
||||
~PrivatePool() = default;
|
||||
|
||||
// default Mempool when no Mempool is specified
|
||||
MempoolId_t id{0, 0};
|
||||
// Number of live graphs using this pool
|
||||
int use_count{1};
|
||||
// Number of unfreed allocations made for this pool. When use_count and
|
||||
// allocation_count drop to zero, we can delete this PrivatePool from
|
||||
// graph_pools.
|
||||
int allocation_count{0};
|
||||
XPUAllocator* allocator_;
|
||||
BlockPool large_blocks;
|
||||
BlockPool small_blocks;
|
||||
|
||||
public:
|
||||
XPUAllocator* allocator() {
|
||||
return allocator_;
|
||||
}
|
||||
};
|
||||
struct MempoolIdHash {
|
||||
std::size_t operator()(const MempoolId_t& mempool_id) const noexcept {
|
||||
return mempool_id.first != 0 ? mempool_id.first : mempool_id.second;
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
class DeviceCachingAllocator {
|
||||
@ -365,6 +409,13 @@ class DeviceCachingAllocator {
|
||||
bool set_fraction = false;
|
||||
std::vector<ExpandableSegment*> expandable_segments;
|
||||
std::vector<c10::DeviceIndex> devices_with_peer_access; // reserved
|
||||
std::vector<std::pair<MempoolId_t, std::function<bool(sycl::queue*)>>>
|
||||
captures_underway;
|
||||
ska::flat_hash_map<MempoolId_t, std::unique_ptr<PrivatePool>, MempoolIdHash>
|
||||
graph_pools;
|
||||
// Pools no longer referenced by any graph.
|
||||
ska::flat_hash_map<MempoolId_t, PrivatePool*, MempoolIdHash>
|
||||
graph_pools_freeable;
|
||||
|
||||
size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) {
|
||||
if (!src || src->allocated || src->event_count > 0 ||
|
||||
@ -463,7 +514,22 @@ class DeviceCachingAllocator {
|
||||
}
|
||||
}
|
||||
|
||||
BlockPool& get_pool(size_t size) {
|
||||
BlockPool& get_pool(size_t size, sycl::queue* queue) {
|
||||
if (C10_UNLIKELY(!captures_underway.empty())) {
|
||||
for (auto& entry : captures_underway) {
|
||||
// lookup for mempool id matching current capture graph
|
||||
if (entry.second(queue)) {
|
||||
auto it1 = graph_pools.find(entry.first);
|
||||
// lookup mempool
|
||||
TORCH_INTERNAL_ASSERT(it1 != graph_pools.end());
|
||||
if (size <= kSmallSize) {
|
||||
return it1->second->small_blocks;
|
||||
} else {
|
||||
return it1->second->large_blocks;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (size < kSmallSize) {
|
||||
return small_blocks;
|
||||
} else {
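A hedged, self-contained sketch of the routing that get_pool now performs, with stand-in types instead of the real allocator internals (BlockPool, PrivatePool, the queue type and kSmallSize are all placeholders here): allocations whose queue matches an active capture filter are served from that capture's private pool, everything else falls back to the device-wide pools.

#include <cstddef>
#include <functional>
#include <map>
#include <utility>
#include <vector>

struct BlockPool { bool is_small; };
struct PrivatePool { BlockPool small_blocks{true}, large_blocks{false}; };
using MempoolId = std::pair<int, int>;
using Queue = int;  // stand-in for sycl::queue

constexpr std::size_t kSmallSize = 1 << 20;  // illustrative threshold only

BlockPool& get_pool_sketch(
    std::size_t size,
    Queue* queue,
    std::vector<std::pair<MempoolId, std::function<bool(Queue*)>>>& captures_underway,
    std::map<MempoolId, PrivatePool>& graph_pools,
    BlockPool& small_blocks,
    BlockPool& large_blocks) {
  for (auto& entry : captures_underway) {
    if (entry.second(queue)) {  // queue belongs to this capture
      PrivatePool& pool = graph_pools.at(entry.first);
      return size <= kSmallSize ? pool.small_blocks : pool.large_blocks;
    }
  }
  return size < kSmallSize ? small_blocks : large_blocks;  // device-wide pools
}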
|
||||
@ -669,6 +735,10 @@ class DeviceCachingAllocator {
|
||||
if (!ptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (p.pool->owner_PrivatePool) {
|
||||
p.pool->owner_PrivatePool->allocation_count++;
|
||||
}
|
||||
p.block = new Block(device, p.queue(), size, p.pool, ptr);
|
||||
for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
|
||||
stats.reserved_bytes[stat_type].increase(size);
|
||||
@ -677,11 +747,14 @@ class DeviceCachingAllocator {
|
||||
return true;
|
||||
}
|
||||
|
||||
void synchronize_and_free_events() {
|
||||
void synchronize_and_free_events(PrivatePool* pool = nullptr) {
|
||||
for (auto& xe : xpu_events) {
|
||||
for (auto& e : xe.second) {
|
||||
auto event = e.first;
|
||||
auto* block = e.second;
|
||||
if (pool && block->pool->owner_PrivatePool != pool) {
|
||||
continue;
|
||||
}
|
||||
event.wait();
|
||||
block->event_count--;
|
||||
if (block->event_count == 0) {
|
||||
@ -785,6 +858,13 @@ class DeviceCachingAllocator {
|
||||
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
|
||||
stats.reserved_bytes[stat_type].decrease(unmapped.size);
|
||||
});
|
||||
|
||||
if (block->pool->owner_PrivatePool) {
|
||||
// The freed block belonged to an XPU graph's PrivatePool.
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
block->pool->owner_PrivatePool->allocation_count > 0);
|
||||
block->pool->owner_PrivatePool->allocation_count--;
|
||||
}
|
||||
}
|
||||
|
||||
void release_blocks(BlockPool& pool) {
|
||||
@ -812,13 +892,41 @@ class DeviceCachingAllocator {
|
||||
}
|
||||
}
|
||||
|
||||
bool release_cached_blocks() {
|
||||
synchronize_and_free_events();
|
||||
// See Note [Safe to Free Blocks on BlockPool]
|
||||
c10::xpu::syncStreamsOnDevice(device_index);
|
||||
bool release_cached_blocks(MempoolId_t mempool_id) {
|
||||
if (mempool_id.first == 0 && mempool_id.second == 0 &&
|
||||
captures_underway.empty()) {
|
||||
synchronize_and_free_events();
|
||||
// See Note [Safe to Free Blocks on BlockPool]
|
||||
c10::xpu::syncStreamsOnDevice(device_index);
|
||||
|
||||
release_blocks(large_blocks);
|
||||
release_blocks(small_blocks);
|
||||
release_blocks(large_blocks);
|
||||
release_blocks(small_blocks);
|
||||
}
|
||||
|
||||
for (auto it = graph_pools_freeable.begin();
|
||||
it != graph_pools_freeable.end();) {
|
||||
if (mempool_id.first != 0 || mempool_id.second != 0) {
|
||||
if (it->first == mempool_id) {
|
||||
// If there is an active mempool, we sync only the events
|
||||
// associated with the pool
|
||||
synchronize_and_free_events(it->second);
|
||||
} else {
|
||||
// otherwise we move on
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(it->second->use_count == 0);
|
||||
release_blocks(it->second->small_blocks);
|
||||
release_blocks(it->second->large_blocks);
|
||||
if (it->second->allocation_count == 0) {
|
||||
auto erase_count = graph_pools.erase(it->first);
|
||||
TORCH_INTERNAL_ASSERT(erase_count == 1);
|
||||
it = graph_pools_freeable.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -903,6 +1011,30 @@ class DeviceCachingAllocator {
|
||||
}
|
||||
}
|
||||
|
||||
void create_or_incref_pool(
|
||||
MempoolId_t mempool_id,
|
||||
XPUAllocator* allocator = nullptr) {
|
||||
auto it = graph_pools.find(mempool_id);
|
||||
if (it == graph_pools.end()) {
|
||||
// mempool_id does not reference an existing pool.
|
||||
// Make a new pool for XPU graph capture or memory pool usage.
|
||||
graph_pools.emplace(
|
||||
mempool_id, std::make_unique<PrivatePool>(mempool_id, allocator));
|
||||
} else {
|
||||
// mempool_id references an existing pool, which the current XPU graph
|
||||
// capture will share.
|
||||
TORCH_INTERNAL_ASSERT(it->second->use_count > 0);
|
||||
TORCH_INTERNAL_ASSERT(allocator == nullptr);
|
||||
it->second->use_count++;
|
||||
}
|
||||
}
|
||||
|
||||
PrivatePool* get_private_pool(MempoolId_t mempool_id) {
|
||||
auto it = graph_pools.find(mempool_id);
|
||||
TORCH_INTERNAL_ASSERT(it != graph_pools.end());
|
||||
return it->second.get();
|
||||
}
|
||||
|
||||
public:
|
||||
DeviceCachingAllocator(DeviceIndex device_index)
|
||||
: large_blocks(/* small */ false),
|
||||
@ -911,9 +1043,11 @@ class DeviceCachingAllocator {
|
||||
|
||||
Block* malloc(DeviceIndex device, size_t orig_size, sycl::queue& queue) {
|
||||
std::scoped_lock<std::recursive_mutex> lock(mutex);
|
||||
process_events();
|
||||
if (C10_LIKELY(captures_underway.empty())) {
|
||||
process_events();
|
||||
}
|
||||
size_t size = round_size(orig_size);
|
||||
auto& pool = get_pool(size);
|
||||
auto& pool = get_pool(size, &queue);
|
||||
const size_t alloc_size = get_allocation_size(size);
|
||||
AllocParams params(device, size, &queue, &pool, alloc_size);
|
||||
params.stat_types = get_stat_types_for_pool(pool);
|
||||
@ -923,7 +1057,7 @@ class DeviceCachingAllocator {
|
||||
// Can't reuse an existing block, try to get a new one.
|
||||
if (!block_found) {
|
||||
block_found = alloc_block(params, false) ||
|
||||
(release_cached_blocks() && alloc_block(params, true));
|
||||
(release_cached_blocks({0, 0}) && alloc_block(params, true));
|
||||
}
|
||||
if (!block_found) {
|
||||
const auto& raw_device = c10::xpu::get_raw_device(device);
|
||||
@ -1016,9 +1150,9 @@ class DeviceCachingAllocator {
|
||||
block->stream_uses.insert(stream);
|
||||
}
|
||||
|
||||
void emptyCache() {
|
||||
void emptyCache(MempoolId_t mempool_id) {
|
||||
std::scoped_lock<std::recursive_mutex> lock(mutex);
|
||||
release_cached_blocks();
|
||||
release_cached_blocks(mempool_id);
|
||||
}
|
||||
|
||||
DeviceStats getStats() {
|
||||
@ -1172,9 +1306,9 @@ class XPUAllocator : public DeviceAllocator {
|
||||
}
|
||||
}
|
||||
|
||||
void emptyCache(MempoolId_t mempool_id [[maybe_unused]] = {0, 0}) override {
|
||||
void emptyCache(MempoolId_t mempool_id) override {
|
||||
for (auto& da : device_allocators) {
|
||||
da->emptyCache();
|
||||
da->emptyCache(mempool_id);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1290,8 +1424,8 @@ void init(DeviceIndex device_count) {
|
||||
return allocator.init(device_count);
|
||||
}
|
||||
|
||||
void emptyCache() {
|
||||
return allocator.emptyCache();
|
||||
void emptyCache(MempoolId_t mempool_id) {
|
||||
return allocator.emptyCache(mempool_id);
|
||||
}
|
||||
|
||||
void resetPeakStats(DeviceIndex device) {
|
||||
|
||||
@ -10,7 +10,7 @@ C10_XPU_API Allocator* get();
|
||||
|
||||
C10_XPU_API void init(DeviceIndex device_count);
|
||||
|
||||
C10_XPU_API void emptyCache();
|
||||
C10_XPU_API void emptyCache(MempoolId_t mempool_id = {0, 0});
|
||||
|
||||
C10_XPU_API void resetPeakStats(DeviceIndex device);
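A hedged usage sketch, assuming the declarations above belong to the c10::xpu caching-allocator header (include path and namespace are assumptions, not shown in the hunk): the default {0, 0} id keeps the old "release everything" behavior, while a specific MempoolId_t only releases blocks owned by that private pool.

// Assumed header/namespace; only the emptyCache signature comes from the hunk.
#include <c10/xpu/XPUCachingAllocator.h>

void flush_example() {
  c10::xpu::XPUCachingAllocator::emptyCache();        // all pools
  c10::xpu::XPUCachingAllocator::emptyCache({1, 0});  // one private pool only
}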
|
||||
|
||||
|
||||
@ -1643,8 +1643,6 @@ if(USE_CUDA)
|
||||
target_link_libraries(torch_cuda PUBLIC c10_cuda)
|
||||
if(TARGET torch::nvtx3)
|
||||
target_link_libraries(torch_cuda PRIVATE torch::nvtx3)
|
||||
else()
|
||||
target_link_libraries(torch_cuda PUBLIC torch::nvtoolsext)
|
||||
endif()
|
||||
|
||||
target_include_directories(
|
||||
@ -1741,9 +1739,6 @@ if(BUILD_SHARED_LIBS)
|
||||
if(USE_CUDA)
|
||||
target_link_libraries(torch_global_deps ${Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS})
|
||||
target_link_libraries(torch_global_deps torch::cudart)
|
||||
if(TARGET torch::nvtoolsext)
|
||||
target_link_libraries(torch_global_deps torch::nvtoolsext)
|
||||
endif()
|
||||
endif()
|
||||
install(TARGETS torch_global_deps DESTINATION "${TORCH_INSTALL_LIB_DIR}")
|
||||
endif()
|
||||
|
||||
@ -734,7 +734,7 @@ void PyTorchStreamWriter::setup(const string& file_name) {
|
||||
file_name,
|
||||
std::ofstream::out | std::ofstream::trunc | std::ofstream::binary
|
||||
);
|
||||
} catch (const std::ios_base::failure& e) {
|
||||
} catch (const std::ios_base::failure&) {
|
||||
#ifdef _WIN32
|
||||
// Windows have verbose error code, we prefer to use it than std errno.
|
||||
uint32_t error_code = GetLastError();
|
||||
@ -773,8 +773,20 @@ void PyTorchStreamWriter::writeRecord(
|
||||
bool compress) {
|
||||
AT_ASSERT(!finalized_);
|
||||
AT_ASSERT(!archive_name_plus_slash_.empty());
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
files_written_.count(name) == 0, "Tried to serialize file twice: ", name);
|
||||
if (files_written_.count(name) > 0) {
|
||||
// Allow multiple writes for triton binaries
|
||||
bool is_triton_extension =
|
||||
c10::ends_with(name, ".so") ||
|
||||
c10::ends_with(name, ".cubin") ||
|
||||
c10::ends_with(name, ".hsaco");
|
||||
|
||||
if (is_triton_extension) {
|
||||
LOG(WARNING) << "File '" << name << "' is being serialized multiple times";
|
||||
return;
|
||||
}
|
||||
|
||||
TORCH_INTERNAL_ASSERT(false, "Tried to serialize file twice: ", name);
|
||||
}
|
||||
if (name == kSerializationIdRecordName && serialization_id_.empty()) {
|
||||
// In case of copying records from another file, skip writing a different
|
||||
// serialization_id than the one computed in this writer.
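A hedged sketch of the duplicate-record policy introduced above: only Triton binary artifacts (.so/.cubin/.hsaco) may be written twice, anything else stays a hard error. std::string_view::ends_with (C++20) stands in for c10::ends_with here.

#include <string_view>

bool allow_duplicate_record(std::string_view name) {
  return name.ends_with(".so") || name.ends_with(".cubin") ||
      name.ends_with(".hsaco");
}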
|
||||
|
||||
@ -118,11 +118,6 @@ if(INTERN_BUILD_ATEN_OPS)
|
||||
list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a")
|
||||
endif()
|
||||
endif()
|
||||
if("${_arch}" STREQUAL "121a")
|
||||
if(_existing_arch_flags MATCHES ".*compute_120.*")
|
||||
list(APPEND _file_compile_flags "-gencode;arch=compute_121a,code=sm_121a")
|
||||
endif()
|
||||
endif()
|
||||
endforeach()
|
||||
list(JOIN _file_compile_flags " " _file_compile_flags)
|
||||
|
||||
@ -131,7 +126,7 @@ if(INTERN_BUILD_ATEN_OPS)
|
||||
|
||||
_BUILD_FOR_ADDITIONAL_ARCHS(
|
||||
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
|
||||
"89;90a;100a;103a;120a;121a")
|
||||
"89;90a;100a;103a;120a")
|
||||
_BUILD_FOR_ADDITIONAL_ARCHS(
|
||||
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu"
|
||||
"90a")
|
||||
|
||||
@ -968,11 +968,8 @@ find_package_handle_standard_args(nvtx3 DEFAULT_MSG nvtx3_dir)
|
||||
if(nvtx3_FOUND)
|
||||
add_library(torch::nvtx3 INTERFACE IMPORTED)
|
||||
target_include_directories(torch::nvtx3 INTERFACE "${nvtx3_dir}")
|
||||
target_compile_definitions(torch::nvtx3 INTERFACE TORCH_CUDA_USE_NVTX3)
|
||||
else()
|
||||
message(WARNING "Cannot find NVTX3, find old NVTX instead")
|
||||
add_library(torch::nvtoolsext INTERFACE IMPORTED)
|
||||
set_property(TARGET torch::nvtoolsext PROPERTY INTERFACE_LINK_LIBRARIES CUDA::nvToolsExt)
|
||||
message(FATAL_ERROR "Cannot find NVTX3!")
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
@ -132,9 +132,6 @@ if(@USE_CUDA@)
|
||||
else()
|
||||
set(TORCH_CUDA_LIBRARIES ${CUDA_NVRTC_LIB})
|
||||
endif()
|
||||
if(TARGET torch::nvtoolsext)
|
||||
list(APPEND TORCH_CUDA_LIBRARIES torch::nvtoolsext)
|
||||
endif()
|
||||
|
||||
if(@BUILD_SHARED_LIBS@)
|
||||
find_library(C10_CUDA_LIBRARY c10_cuda PATHS "${TORCH_INSTALL_PREFIX}/lib")
|
||||
|
||||
@ -10,7 +10,7 @@ API. This API can roughly be divided into five parts:
|
||||
- **TorchScript**: An interface to the TorchScript JIT compiler and interpreter.
|
||||
- **C++ Extensions**: A means of extending the Python API with custom C++ and CUDA routines.
|
||||
|
||||
Combining, these building blocks form a research and
|
||||
Combined, these building blocks form a research and
|
||||
production ready C++ library for tensor computation and dynamic neural
|
||||
networks with strong emphasis on GPU acceleration as well as fast CPU
|
||||
performance. It is currently in use at Facebook in research and
|
||||
@ -76,7 +76,7 @@ C++ Frontend
|
||||
------------
|
||||
|
||||
The PyTorch C++ frontend provides a high level, pure C++ modeling interface for
|
||||
neural network and general ML(Machine Learning) research and production use cases,
|
||||
neural networks and general ML (Machine Learning) research and production use cases,
|
||||
largely following the Python API in design and provided functionality. The C++
|
||||
frontend includes the following:
|
||||
|
||||
|
||||
@ -254,7 +254,7 @@ To toggle the reduced precision reduction flags in C++, one can do
|
||||
|
||||
.. _fp16accumulation:
|
||||
|
||||
Full FP16 Accmumulation in FP16 GEMMs
|
||||
Full FP16 Accumulation in FP16 GEMMs
|
||||
-------------------------------------
|
||||
|
||||
Certain GPUs have increased performance when doing _all_ FP16 GEMM accumulation
|
||||
|
||||
@ -30,5 +30,6 @@ For a quick overview of `torch.compiler`, see {ref}`torch.compiler_overview`.
|
||||
skip_guard_on_all_nn_modules_unsafe
|
||||
keep_tensor_guards_unsafe
|
||||
skip_guard_on_globals_unsafe
|
||||
skip_all_guards_unsafe
|
||||
nested_compile_region
|
||||
```
|
||||
|
||||
@ -32,7 +32,7 @@ project-excludes = [
|
||||
"torch/utils/tensorboard/summary.py",
|
||||
# formatting issues, will turn on after adjusting where suppressions can be
|
||||
# in import statements
|
||||
"tools/flight_recorder/components/types.py",
|
||||
"torch/distributed/flight_recorder/components/types.py",
|
||||
"torch/linalg/__init__.py",
|
||||
"torch/package/importer.py",
|
||||
"torch/package/_package_pickler.py",
|
||||
|
||||
setup.py (49 changed lines)
@ -1358,45 +1358,6 @@ class concat_license_files:
|
||||
|
||||
# Need to create the proper LICENSE.txt for the wheel
|
||||
class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
|
||||
def _wrap_headers_with_macro(self, bdist_dir: Path) -> None:
|
||||
"""Wrap all header files with #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION).
|
||||
|
||||
Excludes:
|
||||
- torch/include/torch/headeronly/*
|
||||
- torch/include/torch/csrc/stable/*
|
||||
- torch/include/torch/csrc/inductor/aoti_torch/c/ (only shim headers)
|
||||
- torch/include/torch/csrc/inductor/aoti_torch/generated/
|
||||
"""
|
||||
header_extensions = (".h", ".hpp", ".cuh")
|
||||
header_files = [
|
||||
f for ext in header_extensions for f in bdist_dir.rglob(f"*{ext}")
|
||||
]
|
||||
|
||||
# Paths to exclude from wrapping
|
||||
exclude_dir_patterns = [
|
||||
"torch/include/torch/headeronly/",
|
||||
"torch/include/torch/csrc/stable/",
|
||||
"torch/include/torch/csrc/inductor/aoti_torch/c/",
|
||||
"torch/include/torch/csrc/inductor/aoti_torch/generated/",
|
||||
]
|
||||
|
||||
for header_file in header_files:
|
||||
rel_path = header_file.relative_to(bdist_dir).as_posix()
|
||||
|
||||
if any(rel_path.startswith(pattern) for pattern in exclude_dir_patterns):
|
||||
report(f"Skipping header: {rel_path}")
|
||||
continue
|
||||
|
||||
original_content = header_file.read_text(encoding="utf-8")
|
||||
wrapped_content = (
|
||||
"#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
|
||||
f"{original_content}"
|
||||
"\n#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
|
||||
)
|
||||
|
||||
header_file.write_text(wrapped_content, encoding="utf-8")
|
||||
report(f"Wrapped header: {rel_path}")
|
||||
|
||||
def run(self) -> None:
|
||||
with concat_license_files(include_files=True):
|
||||
super().run()
|
||||
@ -1419,14 +1380,6 @@ class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel):
|
||||
# need an __init__.py file otherwise we wouldn't have a package
|
||||
(bdist_dir / "torch" / "__init__.py").touch()
|
||||
|
||||
# Wrap all header files with TORCH_STABLE_ONLY macro
|
||||
assert self.bdist_dir is not None, "bdist_dir should be set during wheel build"
|
||||
bdist_dir = Path(self.bdist_dir)
|
||||
report(
|
||||
"-- Wrapping header files with if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)"
|
||||
)
|
||||
self._wrap_headers_with_macro(bdist_dir)
|
||||
|
||||
|
||||
class clean(Command):
|
||||
user_options: ClassVar[list[tuple[str, str | None, str]]] = []
|
||||
@ -1632,7 +1585,7 @@ def configure_extension_build() -> tuple[
|
||||
if cmake_cache_vars["USE_DISTRIBUTED"]:
|
||||
# Only enable fr_trace command if distributed is enabled
|
||||
entry_points["console_scripts"].append(
|
||||
"torchfrtrace = tools.flight_recorder.fr_trace:main",
|
||||
"torchfrtrace = torch.distributed.flight_recorder.fr_trace:main",
|
||||
)
|
||||
return ext_modules, cmdclass, packages, entry_points, extra_install_requires
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ set(AOTI_ABI_CHECK_TEST_ROOT ${TORCH_ROOT}/test/cpp/aoti_abi_check)
|
||||
# Build the cpp gtest binary containing the cpp-only tests.
|
||||
set(AOTI_ABI_CHECK_TEST_SRCS
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/main.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_accessor.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_cast.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_devicetype.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_dispatch.cpp
|
||||
|
||||
test/cpp/aoti_abi_check/test_accessor.cpp (new file, 50 lines)
@ -0,0 +1,50 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <torch/headeronly/core/TensorAccessor.h>
|
||||
#include <string>
|
||||
|
||||
TEST(TestAccessor, HeaderOnlyTensorAccessor) {
|
||||
std::vector<int32_t> v = {11, 12, 13, 21, 22, 23};
|
||||
std::vector<int64_t> sizes = {2, 3};
|
||||
std::vector<int64_t> strides = {3, 1};
|
||||
|
||||
auto acc = torch::headeronly::HeaderOnlyTensorAccessor<int32_t, 2>(
|
||||
v.data(), sizes.data(), strides.data());
|
||||
EXPECT_EQ(acc[0][0], 11);
|
||||
EXPECT_EQ(acc[0][1], 12);
|
||||
EXPECT_EQ(acc[0][2], 13);
|
||||
EXPECT_EQ(acc[1][0], 21);
|
||||
EXPECT_EQ(acc[1][1], 22);
|
||||
EXPECT_EQ(acc[1][2], 23);
|
||||
}
|
||||
|
||||
TEST(TestAccessor, HeaderOnlyGenericPackedTensorAccessor) {
|
||||
std::vector<int32_t> v = {11, 12, 13, 21, 22, 23};
|
||||
std::vector<int64_t> sizes = {2, 3};
|
||||
std::vector<int64_t> strides = {3, 1};
|
||||
|
||||
auto acc =
|
||||
torch::headeronly::HeaderOnlyGenericPackedTensorAccessor<int32_t, 2>(
|
||||
v.data(), sizes.data(), strides.data());
|
||||
EXPECT_EQ(acc[0][0], 11);
|
||||
EXPECT_EQ(acc[0][1], 12);
|
||||
EXPECT_EQ(acc[0][2], 13);
|
||||
EXPECT_EQ(acc[1][0], 21);
|
||||
EXPECT_EQ(acc[1][1], 22);
|
||||
EXPECT_EQ(acc[1][2], 23);
|
||||
|
||||
auto tacc = acc.transpose(0, 1);
|
||||
EXPECT_EQ(tacc[0][0], 11);
|
||||
EXPECT_EQ(tacc[0][1], 21);
|
||||
EXPECT_EQ(tacc[1][0], 12);
|
||||
EXPECT_EQ(tacc[1][1], 22);
|
||||
EXPECT_EQ(tacc[2][0], 13);
|
||||
EXPECT_EQ(tacc[2][1], 23);
|
||||
|
||||
try {
|
||||
acc.transpose(0, 2);
|
||||
} catch (const std::exception& e) {
|
||||
EXPECT_TRUE(
|
||||
std::string(e.what()).find("HeaderOnlyIndexBoundsCheck") !=
|
||||
std::string::npos);
|
||||
}
|
||||
}
|
||||
@ -13,6 +13,17 @@ TEST(TestScalarType, ScalarTypeToCPPTypeT) {
|
||||
#undef DEFINE_CHECK
|
||||
}
|
||||
|
||||
TEST(TestScalarType, CppTypeToScalarType) {
|
||||
using torch::headeronly::CppTypeToScalarType;
|
||||
using torch::headeronly::ScalarType;
|
||||
|
||||
#define DEFINE_CHECK(TYPE, SCALARTYPE) \
|
||||
EXPECT_EQ(CppTypeToScalarType<TYPE>::value, ScalarType::SCALARTYPE);
|
||||
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
|
||||
#undef DEFINE_CHECK
|
||||
}
|
||||
|
||||
#define DEFINE_CHECK(TYPE, SCALARTYPE) \
|
||||
{ \
|
||||
EXPECT_EQ( \
|
||||
@ -90,3 +101,14 @@ TEST(TestScalarType, toUnderlying) {
|
||||
AT_FORALL_FLOAT8_TYPES(DEFINE_CHECK);
|
||||
#undef DEFINE_CHECK
|
||||
}
|
||||
|
||||
TEST(TestScalarType, isQIntType) {
|
||||
using torch::headeronly::isQIntType;
|
||||
using torch::headeronly::ScalarType;
|
||||
#define DEFINE_CHECK(_, name) EXPECT_TRUE(isQIntType(ScalarType::name));
|
||||
AT_FORALL_QINT_TYPES(DEFINE_CHECK);
|
||||
#undef DEFINE_CHECK
|
||||
#define DEFINE_CHECK(_, name) EXPECT_FALSE(isQIntType(ScalarType::name));
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CHECK);
|
||||
#undef DEFINE_CHECK
|
||||
}
|
||||
|
||||
@ -15,7 +15,7 @@ namespace jit {
|
||||
TEST(CustomOperatorTest, InferredSchema) {
|
||||
torch::RegisterOperators reg(
|
||||
"foo::bar", [](double a, at::Tensor b) { return a + b; });
|
||||
auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar"));
|
||||
auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar"));
|
||||
ASSERT_EQ(ops.size(), 1);
|
||||
|
||||
auto& op = ops.front();
|
||||
@ -43,8 +43,7 @@ TEST(CustomOperatorTest, ExplicitSchema) {
|
||||
"foo::bar_with_schema(float a, Tensor b) -> Tensor",
|
||||
[](double a, at::Tensor b) { return a + b; });
|
||||
|
||||
auto& ops =
|
||||
getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema"));
|
||||
auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema"));
|
||||
ASSERT_EQ(ops.size(), 1);
|
||||
|
||||
auto& op = ops.front();
|
||||
@ -77,7 +76,7 @@ TEST(CustomOperatorTest, ListParameters) {
|
||||
torch::List<c10::complex<double>> complexdoubles,
|
||||
torch::List<at::Tensor> tensors) { return floats; });
|
||||
|
||||
auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists"));
|
||||
auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists"));
|
||||
ASSERT_EQ(ops.size(), 1);
|
||||
|
||||
auto& op = ops.front();
|
||||
@ -123,7 +122,7 @@ TEST(CustomOperatorTest, ListParameters2) {
|
||||
"foo::lists2(Tensor[] tensors) -> Tensor[]",
|
||||
[](torch::List<at::Tensor> tensors) { return tensors; });
|
||||
|
||||
auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists2"));
|
||||
auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists2"));
|
||||
ASSERT_EQ(ops.size(), 1);
|
||||
|
||||
auto& op = ops.front();
|
||||
@ -213,7 +212,7 @@ TEST(TestCustomOperator, OperatorGeneratorUndeclared) {
|
||||
},
|
||||
aliasAnalysisFromSchema())});
|
||||
|
||||
auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::not_exist"));
|
||||
auto ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::not_exist"));
|
||||
ASSERT_EQ(ops.size(), 0);
|
||||
}
|
||||
|
||||
@ -232,7 +231,7 @@ TEST(TestCustomOperator, OperatorGeneratorBasic) {
|
||||
},
|
||||
aliasAnalysisFromSchema())});
|
||||
|
||||
auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::bar"));
|
||||
auto ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::bar"));
|
||||
ASSERT_EQ(ops.size(), 1);
|
||||
|
||||
auto& op = ops.front();
|
||||
|
||||
@ -0,0 +1,30 @@
|
||||
#include "kernel.h"
|
||||
|
||||
#include <torch/csrc/stable/library.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
using torch::stable::Tensor;
|
||||
|
||||
Tensor mv_tensor_accessor_cuda(Tensor m, Tensor v) {
|
||||
STD_TORCH_CHECK(m.dim() == 2, "m must be 2D");
|
||||
STD_TORCH_CHECK(v.dim() == 1, "v must be 1D");
|
||||
STD_TORCH_CHECK(m.size(1) == v.size(0), "m.shape[1] == v.shape[0] must hold");
|
||||
STD_TORCH_CHECK(m.scalar_type() == v.scalar_type(), "m and v must have the same dtype");
|
||||
STD_TORCH_CHECK(m.device() == v.device(), "m and v must be on the same device");
|
||||
Tensor res = new_empty(m, {m.size(0)});
|
||||
THO_DISPATCH_V2(m.scalar_type(), "mv_tensor_accessor_cuda",
|
||||
AT_WRAP(([&]() {
|
||||
auto resa = Accessor_cuda<scalar_t, 1>(reinterpret_cast<scalar_t*>(res.data_ptr()), res.sizes().data(), res.strides().data());
|
||||
auto ma = Accessor_cuda<scalar_t, 2>(reinterpret_cast<scalar_t*>(m.data_ptr()), m.sizes().data(), m.strides().data());
|
||||
auto va = Accessor_cuda<scalar_t, 1>(reinterpret_cast<scalar_t*>(v.data_ptr()), v.sizes().data(), v.strides().data());
|
||||
mv_tensor_accessor_kernel<Accessor_cuda, scalar_t><<<1, 1, 0, 0>>>(resa, ma, va);
|
||||
})),
|
||||
AT_FLOATING_TYPES);
|
||||
return res;
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CUDA, m) {
|
||||
m.impl("mv_tensor_accessor", TORCH_BOX(&mv_tensor_accessor_cuda));
|
||||
}
|
||||
@@ -1,3 +1,5 @@
#include "kernel.h"

#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/device.h>
@@ -308,7 +310,7 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
  m.def("my_amax(Tensor a) -> Tensor");
  m.def("my_amax_vec(Tensor a) -> Tensor");
  m.def("my_is_cpu(Tensor t) -> bool");
  m.def("test_default_constructor(bool undefined) -> bool");
  m.def("test_default_constructor(bool undefined) -> bool");
}

bool test_default_constructor(bool defined) {
@@ -330,12 +332,47 @@ bool test_default_constructor(bool defined) {
  return out.defined();
}

uint64_t get_any_data_ptr(Tensor t, bool mutable_) {
  if (mutable_) {
    return reinterpret_cast<uint64_t>(t.mutable_data_ptr());
  } else {
    return reinterpret_cast<uint64_t>(t.const_data_ptr());
  }
}

uint64_t get_template_any_data_ptr(Tensor t, c10::ScalarType dtype, bool mutable_) {
#define DEFINE_CASE(T, name)                                      \
  case torch::headeronly::ScalarType::name: {                     \
    if (mutable_) {                                               \
      return reinterpret_cast<uint64_t>(t.mutable_data_ptr<T>()); \
    } else {                                                      \
      return reinterpret_cast<uint64_t>(t.const_data_ptr<T>());   \
    }                                                             \
  }
  switch (dtype) {
    // per aten/src/ATen/templates/TensorMethods.cpp:
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
    DEFINE_CASE(uint16_t, UInt16)
    DEFINE_CASE(uint32_t, UInt32)
    DEFINE_CASE(uint64_t, UInt64)
    default:
      return 0;
  }
#undef DEFINE_CASE
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
  m.def("get_any_data_ptr(Tensor t, bool mutable_) -> int");
  m.def("get_template_any_data_ptr(Tensor t, ScalarType dtype, bool mutable_) -> int");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("my_zero_", TORCH_BOX(&my_zero_));
  m.impl("my_amax", TORCH_BOX(&my_amax));
  m.impl("my_amax_vec", TORCH_BOX(&my_amax_vec));
  m.impl("test_default_constructor", TORCH_BOX(&test_default_constructor));
  m.impl("get_any_data_ptr", TORCH_BOX(&get_any_data_ptr));
  m.impl("get_template_any_data_ptr", TORCH_BOX(&get_template_any_data_ptr));
}
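Editor's note: a minimal usage sketch (not part of the diff) of the two data-pointer ops registered above, assuming the libtorch_agnostic test extension is built and importable; the dtype-mismatch RuntimeError mirrors the expectation in the Python tests further down.

import torch
import libtorch_agnostic  # assumption: the test extension from this diff is built and importable

t = torch.empty(2, 5, dtype=torch.float32)

# Untyped accessor: returns the same address as Tensor.data_ptr().
p = torch.ops.libtorch_agnostic.get_any_data_ptr.default(t, True)
assert p == t.data_ptr()

# Typed accessor: only valid when the requested dtype matches t.dtype.
p32 = torch.ops.libtorch_agnostic.get_template_any_data_ptr.default(t, torch.float32, False)
assert p32 == t.data_ptr()

try:
    torch.ops.libtorch_agnostic.get_template_any_data_ptr.default(t, torch.int64, False)
except RuntimeError:
    pass  # expected: "expected scalar type ... but found ..." per the tests below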
std::vector<Tensor> my__foreach_mul(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
@@ -514,6 +551,32 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("test_device_is_cpu", &boxed_test_device_is_cpu);
}

Tensor mv_tensor_accessor_cpu(Tensor m, Tensor v) {
  STD_TORCH_CHECK(m.dim() == 2, "m must be 2D");
  STD_TORCH_CHECK(v.dim() == 1, "v must be 1D");
  STD_TORCH_CHECK(m.size(1) == v.size(0), "m.shape[1] == v.shape[0] must hold");
  STD_TORCH_CHECK(m.scalar_type() == v.scalar_type(), "m and v must have the same dtype");
  STD_TORCH_CHECK(m.device() == v.device(), "m and v must be on the same device");
  Tensor res = new_empty(m, {m.size(0)});
  THO_DISPATCH_V2(m.scalar_type(), "mv_tensor_accessor_cpu",
                  AT_WRAP(([&]() {
                    auto resa = Accessor_cpu<scalar_t, 1>(reinterpret_cast<scalar_t*>(res.data_ptr()), res.sizes().data(), res.strides().data());
                    auto ma = Accessor_cpu<scalar_t, 2>(reinterpret_cast<scalar_t*>(m.data_ptr()), m.sizes().data(), m.strides().data());
                    auto va = Accessor_cpu<scalar_t, 1>(reinterpret_cast<scalar_t*>(v.data_ptr()), v.sizes().data(), v.strides().data());
                    mv_tensor_accessor_kernel<Accessor_cpu, scalar_t>(resa, ma, va);
                  })),
                  AT_FLOATING_TYPES);
  return res;
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
  m.def("mv_tensor_accessor(Tensor m, Tensor v) -> Tensor");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
  m.impl("mv_tensor_accessor", TORCH_BOX(&mv_tensor_accessor_cpu));
}

// Test functions for torch::stable::accelerator APIs

#ifdef LAE_USE_CUDA
@@ -634,3 +697,38 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("test_parallel_for", &boxed_test_parallel_for);
  m.impl("test_get_num_threads", &boxed_test_get_num_threads);
}

Tensor my_empty(
    torch::headeronly::HeaderOnlyArrayRef<int64_t> size,
    std::optional<torch::headeronly::ScalarType> dtype,
    std::optional<torch::stable::Device> device,
    std::optional<bool> pin_memory) {
  return empty(size, dtype, device, pin_memory);
}

Tensor my_flatten(Tensor t, int64_t start_dim, int64_t end_dim) {
  return flatten(t, start_dim, end_dim);
}

Tensor my_reshape(Tensor t, torch::headeronly::HeaderOnlyArrayRef<int64_t> shape) {
  return reshape(t, shape);
}

Tensor my_view(Tensor t, torch::headeronly::HeaderOnlyArrayRef<int64_t> size) {
  return view(t, size);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
  m.def(
      "my_empty(int[] size, ScalarType? dtype=None, Device? device=None, bool? pin_memory=None) -> Tensor");
  m.def("my_flatten(Tensor t, int start_dim=0, int end_dim=-1) -> Tensor");
  m.def("my_reshape(Tensor t, int[] shape) -> Tensor");
  m.def("my_view(Tensor t, int[] size) -> Tensor");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("my_empty", TORCH_BOX(&my_empty));
  m.impl("my_flatten", TORCH_BOX(&my_flatten));
  m.impl("my_reshape", TORCH_BOX(&my_reshape));
  m.impl("my_view", TORCH_BOX(&my_view));
}
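Editor's note: a short sketch (not part of the diff) showing how the shape ops registered above line up with their torch counterparts, assuming the libtorch_agnostic extension is importable; the same equivalences are asserted in the tests later in this diff.

import torch
import libtorch_agnostic  # assumption: built from this diff's test extension

t = torch.randn(2, 3, 4)
assert torch.equal(libtorch_agnostic.ops.my_flatten(t, 1), torch.flatten(t, 1))
assert torch.equal(libtorch_agnostic.ops.my_reshape(t, [6, 4]), torch.reshape(t, [6, 4]))
assert torch.equal(libtorch_agnostic.ops.my_view(t, [-1]), t.view([-1]))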
@@ -0,0 +1,26 @@
#include <torch/headeronly/core/Dispatch_v2.h>
#include <torch/headeronly/core/TensorAccessor.h>

template <typename T, size_t N>
using Accessor_cpu = torch::headeronly::HeaderOnlyTensorAccessor<T, N>;

#if defined(__CUDACC__) || defined(__HIPCC__)
#define MAYBE_GLOBAL __global__

template <typename T, size_t N>
using Accessor_cuda = torch::headeronly::HeaderOnlyGenericPackedTensorAccessor<T, N, torch::headeronly::RestrictPtrTraits>;

#else
#define MAYBE_GLOBAL
#endif

template <template <typename, size_t> class Accessor, typename scalar_t>
MAYBE_GLOBAL void mv_tensor_accessor_kernel(Accessor<scalar_t, 1> resa, Accessor<scalar_t, 2> ma, Accessor<scalar_t, 1> va) {
  for (int64_t i = 0; i < resa.size(0); i++) {
    scalar_t val = 0;
    for (int64_t j = 0; j < ma.size(1); j++) {
      val += ma[i][j] * va[j];
    }
    resa[i] = val;
  }
}
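Editor's note: a minimal sketch (not part of the diff) of the resulting op in action, assuming the extension is built; the header-only accessors carry sizes and strides, so non-contiguous inputs work without an explicit .contiguous() call, which is exactly what the test below exercises.

import torch
import libtorch_agnostic  # assumption: built from this diff's test extension

m = torch.rand(3, 5)
v = torch.rand(5)
assert torch.allclose(libtorch_agnostic.ops.mv_tensor_accessor(m, v), torch.mv(m, v))

# Non-contiguous (strided) inputs are handled via the strides stored in the accessors.
m_strided = torch.rand(6, 15)[::2, ::3]
v_strided = torch.rand(20)[::4]
assert torch.allclose(
    libtorch_agnostic.ops.mv_tensor_accessor(m_strided, v_strided),
    torch.mv(m_strided, v_strided),
)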
@@ -227,6 +227,37 @@ def test_tensor_device(t):
    return torch.ops.libtorch_agnostic.test_tensor_device.default(t)


def get_any_data_ptr(t, mutable) -> int:
    """
    Return data pointer value of the tensor.

    Args:
        t: Input tensor
        mutable: whether data pointer qualifier is mutable or const

    Returns: int - pointer value
    """
    return torch.ops.libtorch_agnostic.get_any_data_ptr.default(t, mutable)


def get_template_any_data_ptr(t, dtype, mutable) -> int:
    """
    Return data pointer value of the tensor iff it has dtype.

    Args:
        t: Input tensor
        dtype: Input dtype
        mutable: whether data pointer qualifier is mutable or const

    Returns: int - pointer value

    Raises RuntimeError when t.dtype() != dtype.
    """
    return torch.ops.libtorch_agnostic.get_template_any_data_ptr.default(
        t, dtype, mutable
    )


def my_pad(t) -> Tensor:
    """
    Pads the input tensor with hardcoded padding parameters.
@@ -487,3 +518,72 @@ def test_get_num_threads() -> int:
    Returns: int - the number of threads for the parallel backend
    """
    return torch.ops.libtorch_agnostic.test_get_num_threads.default()


def my_empty(size, dtype=None, device=None, pin_memory=None) -> Tensor:
    """
    Creates an empty tensor with the specified size, dtype, device, and pin_memory.

    Args:
        size: list[int] - size of the tensor to create
        dtype: ScalarType or None - data type of the tensor
        device: Device or None - device on which to create the tensor
        pin_memory: bool or None - whether to use pinned memory

    Returns: Tensor - an uninitialized tensor with the specified properties
    """
    return torch.ops.libtorch_agnostic.my_empty.default(size, dtype, device, pin_memory)


def my_flatten(t, start_dim=0, end_dim=-1) -> Tensor:
    """
    Flattens the input tensor from start_dim to end_dim into a single dimension.

    Args:
        t: Tensor - tensor to flatten
        start_dim: int - first dimension to flatten (default: 0)
        end_dim: int - last dimension to flatten (default: -1)

    Returns: Tensor - flattened tensor
    """
    return torch.ops.libtorch_agnostic.my_flatten.default(t, start_dim, end_dim)


def my_reshape(t, shape) -> Tensor:
    """
    Returns a tensor with the same data but different shape.

    Args:
        t: Tensor - tensor to reshape
        shape: list[int] - new shape for the tensor

    Returns: Tensor - reshaped tensor
    """
    return torch.ops.libtorch_agnostic.my_reshape.default(t, shape)


def my_view(t, size) -> Tensor:
    """
    Returns a new tensor with the same data as the input tensor but of a different shape.

    Args:
        t: Tensor - tensor to view
        size: list[int] - new size for the tensor

    Returns: Tensor - tensor with new view
    """
    return torch.ops.libtorch_agnostic.my_view.default(t, size)


def mv_tensor_accessor(m, v) -> Tensor:
    """
    Returns matrix-vector product.

    Args:
        m: any 2-D Tensor with shape (N, M)
        v: any 1-D Tensor with shape (M,)

    Returns:
        a 1-D Tensor with shape (N,)
    """
    return torch.ops.libtorch_agnostic.mv_tensor_accessor.default(m, v)
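Editor's note: a hedged sketch (not part of the diff) for the one wrapper above whose output cannot be compared directly, my_empty: its memory is uninitialized, so the test below enables deterministic algorithms to make the fill value reproducible before comparing against torch.empty.

import torch
import libtorch_agnostic  # assumption: built from this diff's test extension

prev = torch.are_deterministic_algorithms_enabled()
torch.use_deterministic_algorithms(True)  # uninitialized memory gets a deterministic fill
try:
    out = libtorch_agnostic.ops.my_empty([2, 3], torch.float32, None, None)
    ref = torch.empty([2, 3], dtype=torch.float32)
    assert out.shape == ref.shape and out.dtype == ref.dtype
    # Comparable only because of the deterministic fill (NaN for floats, so equal_nan=True).
    torch.testing.assert_close(out, ref, equal_nan=True)
finally:
    torch.use_deterministic_algorithms(prev)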
@@ -33,16 +33,17 @@ class clean(distutils.command.clean.clean):

def get_extension():
    extra_compile_args = {
        "cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"],
        "cxx": ["-fdiagnostics-color=always"],
    }
    sources = list(CSRC_DIR.glob("**/*.cpp"))

    extension = CppExtension
    # allow including <cuda_runtime.h>
    if torch.cuda.is_available():
        extra_compile_args["cxx"].append("-DLAE_USE_CUDA")
        extra_compile_args["nvcc"] = ["-O2"]
        extension = CUDAExtension

    sources = list(CSRC_DIR.glob("**/*.cpp"))
    sources.extend(CSRC_DIR.glob("**/*.cu"))

    return [
        extension(
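Editor's note: a sketch (not part of the diff) of how this extension is typically built and loaded in the test suite; install_cpp_extension is the helper used by the tests below, and the exact extension_root path here is an assumption for illustration.

from pathlib import Path

import torch
from torch.testing._internal.common_utils import install_cpp_extension

# Assumed layout for illustration; the tests resolve the root relative to their own file.
install_cpp_extension(extension_root=Path("test/cpp_extensions/libtorch_agnostic_extension"))

import libtorch_agnostic  # noqa: E402

t = torch.randn(2, 3, 4)
print(libtorch_agnostic.ops.my_flatten(t).shape)  # torch.Size([24])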
@ -14,11 +14,38 @@ from torch.testing._internal.common_utils import (
|
||||
install_cpp_extension,
|
||||
IS_WINDOWS,
|
||||
run_tests,
|
||||
skipIfTorchDynamo,
|
||||
TestCase,
|
||||
xfailIfTorchDynamo,
|
||||
)
|
||||
|
||||
|
||||
def get_supported_dtypes():
|
||||
"""Return a list of dtypes that are supported by torch stable ABI."""
|
||||
return [
|
||||
torch.int8,
|
||||
torch.int16,
|
||||
torch.int32,
|
||||
torch.int64,
|
||||
torch.uint8,
|
||||
torch.uint16,
|
||||
torch.uint32,
|
||||
torch.uint64,
|
||||
torch.bfloat16,
|
||||
torch.float16,
|
||||
torch.float32,
|
||||
torch.float64,
|
||||
torch.float8_e5m2,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2fnuz,
|
||||
torch.float8_e4m3fnuz,
|
||||
torch.complex32,
|
||||
torch.complex64,
|
||||
torch.complex128,
|
||||
torch.bool,
|
||||
]
|
||||
|
||||
|
||||
# TODO: Fix this error in Windows:
|
||||
# LINK : error LNK2001: unresolved external symbol PyInit__C
|
||||
if not IS_WINDOWS:
|
||||
@ -274,6 +301,43 @@ if not IS_WINDOWS:
|
||||
expected0 = torch.narrow(t, dim0, start0, length0)
|
||||
self.assertEqual(out0, expected0)
|
||||
|
||||
@skipIfTorchDynamo("no data pointer defined for FakeTensor, FunctionalTensor")
|
||||
def test_get_any_data_ptr(self, device):
|
||||
import libtorch_agnostic
|
||||
|
||||
t = torch.empty(2, 5, device=device, dtype=torch.float32)
|
||||
expected_p = t.data_ptr()
|
||||
|
||||
for mutable in [True, False]:
|
||||
p = libtorch_agnostic.ops.get_any_data_ptr(t, mutable)
|
||||
self.assertEqual(p, expected_p)
|
||||
|
||||
@skipIfTorchDynamo("no data pointer defined for FakeTensor, FunctionalTensor")
|
||||
def test_get_template_any_data_ptr(self, device):
|
||||
import libtorch_agnostic
|
||||
|
||||
supported_dtypes = get_supported_dtypes()
|
||||
|
||||
for dtype in supported_dtypes:
|
||||
t = torch.empty(2, 5, device=device, dtype=dtype)
|
||||
expected_p = t.data_ptr()
|
||||
|
||||
for rdtype in supported_dtypes:
|
||||
if dtype == rdtype:
|
||||
for mutable in [True, False]:
|
||||
p = libtorch_agnostic.ops.get_template_any_data_ptr(
|
||||
t, rdtype, mutable
|
||||
)
|
||||
self.assertEqual(p, expected_p)
|
||||
else:
|
||||
for mutable in [True, False]:
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "expected scalar type.* but found"
|
||||
):
|
||||
libtorch_agnostic.ops.get_template_any_data_ptr(
|
||||
t, rdtype, mutable
|
||||
)
|
||||
|
||||
@onlyCUDA
|
||||
@deviceCountAtLeast(2)
|
||||
def test_device_guard(self, device):
|
||||
@ -525,6 +589,113 @@ if not IS_WINDOWS:
|
||||
expected_num_threads = torch.get_num_threads()
|
||||
self.assertEqual(num_threads, expected_num_threads)
|
||||
|
||||
def test_my_empty(self, device):
|
||||
import libtorch_agnostic
|
||||
|
||||
deterministic = torch.are_deterministic_algorithms_enabled()
|
||||
try:
|
||||
# set use_deterministic_algorithms to fill uninitialized memory
|
||||
torch.use_deterministic_algorithms(True)
|
||||
|
||||
size = [2, 3]
|
||||
result = libtorch_agnostic.ops.my_empty(size, None, None, None)
|
||||
expected = torch.empty(size)
|
||||
self.assertEqual(result, expected, exact_device=True)
|
||||
|
||||
result_float = libtorch_agnostic.ops.my_empty(
|
||||
size, torch.float32, None, None
|
||||
)
|
||||
expected_float = torch.empty(size, dtype=torch.float32)
|
||||
self.assertEqual(result_float, expected_float, exact_device=True)
|
||||
|
||||
result_with_device = libtorch_agnostic.ops.my_empty(
|
||||
size, torch.float64, device, None
|
||||
)
|
||||
expected_with_device = torch.empty(
|
||||
size, dtype=torch.float64, device=device
|
||||
)
|
||||
self.assertEqual(
|
||||
result_with_device, expected_with_device, exact_device=True
|
||||
)
|
||||
|
||||
if device == "cuda":
|
||||
result_pinned = libtorch_agnostic.ops.my_empty(
|
||||
size, torch.float32, "cpu", True
|
||||
)
|
||||
expected_pinned = torch.empty(
|
||||
size, dtype=torch.float32, device="cpu", pin_memory=True
|
||||
)
|
||||
self.assertEqual(result_pinned, expected_pinned)
|
||||
self.assertTrue(result_pinned.is_pinned())
|
||||
finally:
|
||||
torch.use_deterministic_algorithms(deterministic)
|
||||
|
||||
def test_my_flatten(self, device):
|
||||
import libtorch_agnostic
|
||||
|
||||
t = torch.randn(2, 3, 4, device=device)
|
||||
result = libtorch_agnostic.ops.my_flatten(t)
|
||||
expected = torch.flatten(t)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
result_start = libtorch_agnostic.ops.my_flatten(t, 1)
|
||||
expected_start = torch.flatten(t, 1)
|
||||
self.assertEqual(result_start, expected_start)
|
||||
|
||||
result_range = libtorch_agnostic.ops.my_flatten(t, 2, -1)
|
||||
expected_range = torch.flatten(t, 2, -1)
|
||||
self.assertEqual(result_range, expected_range)
|
||||
|
||||
def test_my_reshape(self, device):
|
||||
import libtorch_agnostic
|
||||
|
||||
t = torch.randn(2, 3, 4, device=device)
|
||||
|
||||
result = libtorch_agnostic.ops.my_reshape(t, [6, 4])
|
||||
expected = torch.reshape(t, [6, 4])
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
result_infer = libtorch_agnostic.ops.my_reshape(t, [-1, 4])
|
||||
expected_infer = torch.reshape(t, [-1, 4])
|
||||
self.assertEqual(result_infer, expected_infer)
|
||||
|
||||
result_flat = libtorch_agnostic.ops.my_reshape(t, [-1])
|
||||
expected_flat = torch.reshape(t, [-1])
|
||||
self.assertEqual(result_flat, expected_flat)
|
||||
|
||||
def test_my_view(self, device):
|
||||
import libtorch_agnostic
|
||||
|
||||
t = torch.randn(2, 3, 4, device=device)
|
||||
|
||||
result = libtorch_agnostic.ops.my_view(t, [6, 4])
|
||||
expected = t.view([6, 4])
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
result_infer = libtorch_agnostic.ops.my_view(t, [-1, 4])
|
||||
expected_infer = t.view([-1, 4])
|
||||
self.assertEqual(result_infer, expected_infer)
|
||||
|
||||
result_flat = libtorch_agnostic.ops.my_view(t, [-1])
|
||||
expected_flat = t.view([-1])
|
||||
self.assertEqual(result_flat, expected_flat)
|
||||
|
||||
def test_mv_tensor_accessor(self, device):
|
||||
import libtorch_agnostic
|
||||
|
||||
m = torch.rand(3, 5, device=device)
|
||||
v = torch.rand(5, device=device)
|
||||
result = libtorch_agnostic.ops.mv_tensor_accessor(m, v)
|
||||
expected = torch.mv(m, v)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
# non-contiguous inputs
|
||||
m = torch.rand(3 * 2, 5 * 3, device=device)[::2, ::3]
|
||||
v = torch.rand(5 * 4, device=device)[::4]
|
||||
result = libtorch_agnostic.ops.mv_tensor_accessor(m, v)
|
||||
expected = torch.mv(m, v)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
test/cpp_extensions/torch_stable_test_extension/setup.py (new file, 67 lines)
@@ -0,0 +1,67 @@
import distutils.command.clean
import shutil
from pathlib import Path

from setuptools import find_packages, setup

from torch.utils.cpp_extension import BuildExtension, CppExtension


ROOT_DIR = Path(__file__).parent
CSRC_DIR = ROOT_DIR / "torch_stable_test" / "csrc"


class clean(distutils.command.clean.clean):
    def run(self):
        # Run default behavior first
        distutils.command.clean.clean.run(self)

        # Remove extension
        for path in (ROOT_DIR / "torch_stable_test").glob("**/*.so"):
            path.unlink()
        # Remove build and dist and egg-info directories
        dirs = [
            ROOT_DIR / "build",
            ROOT_DIR / "dist",
            ROOT_DIR / "torch_stable_test.egg-info",
        ]
        for path in dirs:
            if path.exists():
                shutil.rmtree(str(path), ignore_errors=True)


def get_extension():
    extra_compile_args = {
        "cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"],
    }

    sources = list(CSRC_DIR.glob("**/*.cpp"))

    return [
        CppExtension(
            "torch_stable_test._C",
            sources=sorted(str(s) for s in sources),
            py_limited_api=True,
            extra_compile_args=extra_compile_args,
            extra_link_args=[],
        )
    ]


setup(
    name="torch_stable_test",
    version="0.0",
    author="PyTorch Core Team",
    description="Test extension to verify TORCH_STABLE_ONLY flag",
    packages=find_packages(exclude=("test",)),
    package_data={"torch_stable_test": ["*.dll", "*.dylib", "*.so"]},
    install_requires=[
        "torch",
    ],
    ext_modules=get_extension(),
    cmdclass={
        "build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
        "clean": clean,
    },
    options={"bdist_wheel": {"py_limited_api": "cp39"}},
)
@@ -0,0 +1 @@
#include <ATen/core/TensorBase.h> // This should trigger the TORCH_STABLE_ONLY error
@@ -0,0 +1,22 @@
# Owner(s): ["module: cpp"]

from pathlib import Path

from torch.testing._internal.common_utils import (
    install_cpp_extension,
    IS_WINDOWS,
    run_tests,
    TestCase,
)


if not IS_WINDOWS:

    class TestTorchStable(TestCase):
        def test_setup_fails(self):
            with self.assertRaisesRegex(RuntimeError, "build failed for cpp extension"):
                install_cpp_extension(extension_root=Path(__file__).parent.parent)


if __name__ == "__main__":
    run_tests()
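Editor's note: a sketch (not part of the diff) of the intent behind the two new files above, stated as an assumption: the extension compiles with -DTORCH_STABLE_ONLY while its only source includes ATen/core/TensorBase.h, so the build is expected to fail and the test asserts exactly that. Reproducing it outside the test harness would look roughly like this; the path is illustrative.

from pathlib import Path

from torch.testing._internal.common_utils import install_cpp_extension

# Assumed path for illustration; the test resolves it relative to its own file.
ext_root = Path("test/cpp_extensions/torch_stable_test_extension")

try:
    install_cpp_extension(extension_root=ext_root)
except RuntimeError as err:
    # Expected: "build failed for cpp extension" because a non-stable ATen
    # header is included while TORCH_STABLE_ONLY is defined.
    print(err)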
@@ -22,7 +22,7 @@ void check_all_parameters(

template<class Result, class... Args>
Result get_operator_from_registry_and_execute(const char* op_name, Args&&... args) {
  auto& ops = torch::jit::getAllOperatorsFor(
  auto ops = torch::jit::getAllOperatorsFor(
      torch::jit::Symbol::fromQualString(op_name));
  TORCH_INTERNAL_ASSERT(ops.size() == 1);
@ -65,7 +65,6 @@ from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
|
||||
|
||||
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
curr_backend = dist.get_default_backend_for_device(device_type)
|
||||
|
||||
|
||||
class SimpleModel(nn.Module):
|
||||
@ -423,10 +422,10 @@ class TestFullyShard2DStateDict(DTensorTestBase):
|
||||
@property
|
||||
def backend(self):
|
||||
# need to specify gloo backend for testing cpu offload
|
||||
return f"cpu:gloo,{device_type}:{curr_backend}"
|
||||
return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl"
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_fully_shard_tp_2d_set_full_state_dict(self):
|
||||
dummy_model = SimpleModel().to(device_type)
|
||||
mesh_2d = init_device_mesh(
|
||||
@ -515,8 +514,8 @@ class Test2dFSDP1ParallelIntegration(DTensorTestBase):
|
||||
).to_local()
|
||||
self.assertEqual(param_m2, param_m1)
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_2d_ddp_integration_functionality(self) -> None:
|
||||
model, twod_model, dp_pg = self.init_model(self.device_type)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=3e-5)
|
||||
@ -567,8 +566,8 @@ class TestNew2dParallelTraining(DTensorTestBase):
|
||||
p2 = p2.redistribute(p2.device_mesh, [Replicate()]).to_local()
|
||||
self.assertTrue(torch.allclose(p1, p2), f"{p1} vs {p2}")
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_2d_fsdp_state_enable_extension(self):
|
||||
mesh_2d = init_device_mesh(
|
||||
self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
|
||||
@ -643,18 +642,18 @@ class TestNew2dParallelTraining(DTensorTestBase):
|
||||
# Ensure all params are still the same after optimizer update.
|
||||
self._compare_params(model, model_2d)
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_2d_e2e_training_default(self):
|
||||
self._test_2d_e2e_training()
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_2d_e2e_training_use_orig_params(self):
|
||||
self._test_2d_e2e_training(use_orig_params=True)
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_2d_e2e_training_not_use_orig_params(self):
|
||||
# TODO: need to revisit input_reshard API about why it failed multi-gpu tests.
|
||||
# self._test_2d_e2e_training(recompute_activation=True)
|
||||
@ -667,10 +666,10 @@ class TestNew2dParallelStateDict(DTensorTestBase):
|
||||
@property
|
||||
def backend(self):
|
||||
# need to specify gloo backend for testing cpu offload
|
||||
return f"cpu:gloo,{device_type}:{curr_backend}"
|
||||
return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl"
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_fsdp_2d_extension(self):
|
||||
"""
|
||||
Test whether _fsdp_extension from FSDPstate has been set correctly.
|
||||
@ -701,8 +700,8 @@ class TestNew2dParallelStateDict(DTensorTestBase):
|
||||
model_1d_fsdp_state = _get_module_fsdp_state(model_1d)
|
||||
self.assertEqual(model_1d_fsdp_state._fsdp_extension, None)
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@parametrize("is_even_sharded_model", [True, False])
|
||||
def test_2d_state_dict(self, is_even_sharded_model):
|
||||
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
|
||||
@ -757,8 +756,8 @@ class TestNew2dParallelStateDict(DTensorTestBase):
|
||||
torch.allclose(no_wrap_v, all_gather_two_d_v.to_local()), True
|
||||
)
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@parametrize("is_even_sharded_model", [True, False])
|
||||
def test_2d_load_state_dict(self, is_even_sharded_model):
|
||||
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
|
||||
@ -812,8 +811,8 @@ class TestNew2dParallelStateDict(DTensorTestBase):
|
||||
self.assertEqual(v1.device_mesh, v2.device_mesh)
|
||||
self.assertEqual(v1.placements, v2.placements)
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@parametrize("is_even_sharded_model", [True, False])
|
||||
def test_2d_optim_state_dict(self, is_even_sharded_model):
|
||||
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
|
||||
@ -900,9 +899,9 @@ class TestNew2dParallelStateDict(DTensorTestBase):
|
||||
else:
|
||||
self.assertEqual(new_state, state)
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
@with_temp_dir
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_fsdp1_tp_2d_set_full_state_dict(self):
|
||||
"""
|
||||
This is a workaround for loading full state dict into a FSDP1+TP 2D model.
|
||||
|
||||
@ -29,8 +29,8 @@ from torch.distributed.tensor.parallel import (
|
||||
parallelize_module,
|
||||
RowwiseParallel,
|
||||
)
|
||||
from torch.testing._internal.common_cuda import TEST_MULTIGPU
|
||||
from torch.testing._internal.common_distributed import (
|
||||
at_least_x_gpu,
|
||||
MultiProcessTestCase,
|
||||
requires_accelerator_dist_backend,
|
||||
skip_if_lt_x_gpu,
|
||||
@ -40,6 +40,7 @@ from torch.testing._internal.common_utils import (
|
||||
parametrize,
|
||||
run_tests,
|
||||
skip_but_pass_in_sandcastle_if,
|
||||
TEST_XPU,
|
||||
)
|
||||
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
|
||||
|
||||
@ -106,9 +107,11 @@ class ComposabilityTest(MultiProcessTestCase):
|
||||
def device(self):
|
||||
return self.rank
|
||||
|
||||
@requires_accelerator_dist_backend()
|
||||
@requires_accelerator_dist_backend(["nccl", "xccl"])
|
||||
@skip_if_lt_x_gpu(8)
|
||||
@skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
not TEST_MULTIGPU and not TEST_XPU, "Test requires 4+ GPUs"
|
||||
)
|
||||
def test_pp_and_dcp(self):
|
||||
"""
|
||||
Test that pipeline parallelism and distributed checkpointing can be used together and
|
||||
@ -198,9 +201,11 @@ class ComposabilityTest(MultiProcessTestCase):
|
||||
|
||||
_dcp_test(self)
|
||||
|
||||
@requires_accelerator_dist_backend()
|
||||
@requires_accelerator_dist_backend(["nccl", "xccl"])
|
||||
@skip_if_lt_x_gpu(8)
|
||||
@skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
|
||||
)
|
||||
@parametrize(
|
||||
"ScheduleClass",
|
||||
[
|
||||
@ -350,9 +355,11 @@ class ComposabilityTest(MultiProcessTestCase):
|
||||
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
@requires_accelerator_dist_backend()
|
||||
@requires_accelerator_dist_backend(["nccl", "xccl"])
|
||||
@skip_if_lt_x_gpu(8)
|
||||
@skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
|
||||
)
|
||||
@parametrize(
|
||||
"ScheduleClass",
|
||||
[
|
||||
@ -543,9 +550,11 @@ class ComposabilityTest(MultiProcessTestCase):
|
||||
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
@requires_accelerator_dist_backend()
|
||||
@requires_accelerator_dist_backend(["nccl", "xccl"])
|
||||
@skip_if_lt_x_gpu(8)
|
||||
@skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
|
||||
)
|
||||
@parametrize(
|
||||
"ScheduleClass",
|
||||
[
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
@ -17,8 +18,8 @@ from torch.distributed.algorithms.ddp_comm_hooks import (
|
||||
)
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.testing._internal.common_distributed import (
|
||||
DistributedTestBase,
|
||||
requires_accelerator_dist_backend,
|
||||
MultiProcessTestCase,
|
||||
requires_nccl,
|
||||
skip_if_lt_x_gpu,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
|
||||
@ -29,12 +30,9 @@ if TEST_WITH_DEV_DBG_ASAN:
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
|
||||
|
||||
def gpus_for_rank(world_size):
|
||||
visible_devices = list(range(torch.accelerator.device_count()))
|
||||
gpus_per_process = torch.accelerator.device_count() // world_size
|
||||
visible_devices = list(range(torch.cuda.device_count()))
|
||||
gpus_per_process = torch.cuda.device_count() // world_size
|
||||
gpus_for_rank = []
|
||||
for rank in range(world_size):
|
||||
gpus_for_rank.append(
|
||||
@ -62,7 +60,27 @@ class TestDdpCommHook(nn.Module):
|
||||
return self.t0(x ** (1 + rank))
|
||||
|
||||
|
||||
class DistributedDataParallelCommHookTest(DistributedTestBase):
|
||||
class DistributedDataParallelCommHookTest(MultiProcessTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._spawn_processes()
|
||||
|
||||
def tearDown(self):
|
||||
try:
|
||||
os.remove(self.file_name)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _get_process_group_nccl(self):
|
||||
store = dist.FileStore(self.file_name, self.world_size)
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
store=store,
|
||||
)
|
||||
return dist.distributed_c10d._get_default_group()
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
return 2
|
||||
@ -101,14 +119,14 @@ class DistributedDataParallelCommHookTest(DistributedTestBase):
|
||||
param = next(model.parameters())
|
||||
return param.grad
|
||||
|
||||
@requires_accelerator_dist_backend()
|
||||
@requires_nccl()
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_ddp_comm_hook_allreduce_hook(self):
|
||||
"""
|
||||
This unit test verifies the ``allreduce`` hook registered case gives same result
|
||||
with no hook registered case.
|
||||
"""
|
||||
process_group = self.create_pg(device_type)
|
||||
process_group = self._get_process_group_nccl()
|
||||
|
||||
# No hook registered case, get the reference grads.
|
||||
reference_grads = self._get_grads(process_group, None)
|
||||
@ -117,14 +135,14 @@ class DistributedDataParallelCommHookTest(DistributedTestBase):
|
||||
|
||||
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=0)
|
||||
|
||||
@requires_accelerator_dist_backend()
|
||||
@requires_nccl()
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_ddp_comm_hook_fp16compress_hook(self):
|
||||
"""
|
||||
This unit test verifies the ``fp16 compress`` hook registered case
|
||||
gives close result with no hook registered case.
|
||||
"""
|
||||
process_group = self.create_pg(device_type)
|
||||
process_group = self._get_process_group_nccl()
|
||||
|
||||
# No hook registered case, get the reference grads.
|
||||
reference_grads = self._get_grads(process_group, None)
|
||||
@ -133,14 +151,14 @@ class DistributedDataParallelCommHookTest(DistributedTestBase):
|
||||
|
||||
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
|
||||
|
||||
@requires_accelerator_dist_backend()
|
||||
@requires_nccl()
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_ddp_comm_hook_quantize_per_tensor_hook(self):
|
||||
"""
|
||||
This unit test verifies the ``quantize per tensor`` hook registered case
|
||||
gives close result with no hook registered case.
|
||||
"""
|
||||
process_group = self.create_pg(device_type)
|
||||
process_group = self._get_process_group_nccl()
|
||||
|
||||
# No hook registered case, get the reference grads.
|
||||
reference_grads = self._get_grads(process_group, None)
|
||||
@ -149,14 +167,14 @@ class DistributedDataParallelCommHookTest(DistributedTestBase):
|
||||
|
||||
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
|
||||
|
||||
@requires_accelerator_dist_backend()
|
||||
@requires_nccl()
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_ddp_comm_hook_quantize_per_channel_hook(self):
|
||||
"""
|
||||
This unit test verifies the ``quantize per channel`` hook registered case
|
||||
gives close result with no hook registered case.
|
||||
"""
|
||||
process_group = self.create_pg(device_type)
|
||||
process_group = self._get_process_group_nccl()
|
||||
|
||||
# No hook registered case, get the reference grads.
|
||||
reference_grads = self._get_grads(process_group, None)
|
||||
@ -167,14 +185,14 @@ class DistributedDataParallelCommHookTest(DistributedTestBase):
|
||||
|
||||
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
|
||||
|
||||
@requires_accelerator_dist_backend()
|
||||
@requires_nccl()
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_ddp_comm_hook_noop_hook(self):
|
||||
"""
|
||||
This unit test verifies the ``noop`` hook registered case and a subsequent allreduce
|
||||
gives same result with no hook registered case.
|
||||
"""
|
||||
process_group = self.create_pg(device_type)
|
||||
process_group = self._get_process_group_nccl()
|
||||
|
||||
# No hook registered case, get the reference grads.
|
||||
reference_grads = self._get_grads(process_group, None)
|
||||
@ -186,10 +204,10 @@ class DistributedDataParallelCommHookTest(DistributedTestBase):
|
||||
|
||||
torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=0)
|
||||
|
||||
@requires_accelerator_dist_backend()
|
||||
@requires_nccl()
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_is_last_hook(self):
|
||||
process_group = self.create_pg(device_type)
|
||||
process_group = self._get_process_group_nccl()
|
||||
|
||||
def hook(flags, bucket):
|
||||
flags.append(bucket.is_last())
|
||||
|
||||
@ -32,7 +32,7 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
class TestStateDictUtils(DTensorTestBase):
|
||||
@property
|
||||
def world_size(self):
|
||||
return min(4, torch.accelerator.device_count())
|
||||
return min(4, torch.cuda.device_count())
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@ -49,7 +49,7 @@ class TestStateDictUtils(DTensorTestBase):
|
||||
dist_tensor.to_local(), gather_dim=0, group=(device_mesh, 0)
|
||||
)
|
||||
self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"])
|
||||
self.assertEqual(gathered_state_dict["dtensor"].device.type, self.device_type)
|
||||
self.assertTrue(gathered_state_dict["dtensor"].is_cuda)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@ -69,16 +69,14 @@ class TestStateDictUtils(DTensorTestBase):
|
||||
)
|
||||
if dist.get_rank() in (0, 2):
|
||||
self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"])
|
||||
self.assertNotEqual(
|
||||
gathered_state_dict["dtensor"].device.type, self.device_type
|
||||
)
|
||||
self.assertFalse(gathered_state_dict["dtensor"].is_cuda)
|
||||
else:
|
||||
self.assertEqual(gathered_state_dict, {})
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_cpu_and_ranks_only(self):
|
||||
device = torch.device(self.device_type)
|
||||
device = torch.device("cuda")
|
||||
state_dict = {
|
||||
"tensor1": torch.arange(10, device=device),
|
||||
"tensor2": torch.ones(10, device=device),
|
||||
@ -87,7 +85,7 @@ class TestStateDictUtils(DTensorTestBase):
|
||||
cpu_state_dict = _offload_state_dict_to_cpu(state_dict, ranks_only=(0, 2))
|
||||
if dist.get_rank() in (0, 2):
|
||||
for v in cpu_state_dict.values():
|
||||
self.assertNotEqual(v.device.type, self.device_type)
|
||||
self.assertFalse(v.is_cuda)
|
||||
self.assertEqual(cpu_state_dict["tensor1"], torch.arange(10))
|
||||
self.assertEqual(cpu_state_dict["tensor2"], torch.ones(10))
|
||||
else:
|
||||
@ -111,27 +109,27 @@ class TestStateDictUtils(DTensorTestBase):
|
||||
for _ in range(10):
|
||||
tensor, dtensor = create_dtensor()
|
||||
ltensor.append(tensor)
|
||||
ltensor.append(torch.ones(10, device=torch.device(self.device_type)))
|
||||
ltensor.append(torch.ones(10, device=torch.device("cuda")))
|
||||
ldtensor.append(dtensor)
|
||||
ldtensor.append(torch.ones(10, device=torch.device(self.device_type)))
|
||||
ldtensor.append(torch.ones(10, device=torch.device("cuda")))
|
||||
|
||||
tensor, dtensor = create_dtensor()
|
||||
dist_state_dict = {
|
||||
"local": dtensor,
|
||||
"list": ldtensor,
|
||||
"arange": torch.arange(10, device=torch.device(self.device_type)),
|
||||
"arange": torch.arange(10, device=torch.device("cuda")),
|
||||
}
|
||||
state_dict = {
|
||||
"local": tensor,
|
||||
"list": ltensor,
|
||||
"arange": torch.arange(10, device=torch.device(self.device_type)),
|
||||
"arange": torch.arange(10, device=torch.device("cuda")),
|
||||
}
|
||||
self.assertEqual(state_dict, _gather_state_dict(dist_state_dict))
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_create_cpu_state_dict(self):
|
||||
device = torch.device(self.device_type)
|
||||
device = torch.device("cuda")
|
||||
rank = dist.get_rank()
|
||||
# Scale tensors based on world size
|
||||
# to fit in the tensor shards accurately.
|
||||
@ -151,7 +149,7 @@ class TestStateDictUtils(DTensorTestBase):
|
||||
metadata=ShardMetadata(
|
||||
shard_offsets=[5 * rank, 0],
|
||||
shard_sizes=[5, 10],
|
||||
placement=f"rank:{rank}/{self.device_type}:{rank}",
|
||||
placement=f"rank:{rank}/cuda:{rank}",
|
||||
),
|
||||
)
|
||||
],
|
||||
@ -161,7 +159,7 @@ class TestStateDictUtils(DTensorTestBase):
|
||||
torch.arange(50 * scale_factor, device=device).reshape(
|
||||
5 * scale_factor, 10
|
||||
),
|
||||
init_device_mesh(self.device_type, mesh_shape=(self.world_size,)),
|
||||
init_device_mesh("cuda", mesh_shape=(self.world_size,)),
|
||||
[Shard(0)],
|
||||
),
|
||||
"non_tensor_bytes_io": copy.deepcopy(buffer),
|
||||
@ -247,7 +245,7 @@ class TestStateDictUtils(DTensorTestBase):
|
||||
even_tensor = torch.randn(self.world_size, 2)
|
||||
uneven_tensor = torch.randn(1, 2)
|
||||
|
||||
mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
|
||||
mesh = init_device_mesh("cuda", mesh_shape=(self.world_size,))
|
||||
even_dtensor = distribute_tensor(
|
||||
torch.randn(self.world_size, 2), mesh, [Shard(0)]
|
||||
)
|
||||
@ -275,10 +273,10 @@ class TestStateDictUtils(DTensorTestBase):
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_cpu_offload_for_dtensor(self):
|
||||
device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
|
||||
device_mesh = init_device_mesh("cuda", mesh_shape=(self.world_size,))
|
||||
sd = {
|
||||
"k": DTensor.from_local(
|
||||
torch.ones(8, 8, device=self.device_type), device_mesh, [Shard(0)]
|
||||
torch.ones(8, 8, device="cuda"), device_mesh, [Shard(0)]
|
||||
)
|
||||
}
|
||||
cpu_sd = _create_cpu_state_dict(sd)
|
||||
@ -292,12 +290,12 @@ class TestStateDictUtils(DTensorTestBase):
|
||||
|
||||
self.assertFalse(torch.equal(sd["k"].cpu(), cpu_sd["k"]))
|
||||
_copy_state_dict(sd, cpu_sd, non_blocking=True)
|
||||
torch.accelerator.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
self.assertTrue(torch.equal(sd["k"].cpu(), cpu_sd["k"]))
|
||||
sd["k"] += 1
|
||||
self.assertFalse(torch.equal(sd["k"].cpu(), cpu_sd["k"]))
|
||||
_copy_state_dict(sd, cpu_sd, non_blocking=True)
|
||||
torch.accelerator.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
self.assertTrue(torch.equal(sd["k"].cpu(), cpu_sd["k"]))
|
||||
|
||||
|
||||
|
||||
@@ -2,23 +2,16 @@

import copy
import math
import pathlib
import sys
from typing import Any


REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent

sys.path.insert(0, str(REPO_ROOT))
from tools.flight_recorder.components.builder import build_db
from tools.flight_recorder.components.config_manager import JobConfig
from tools.flight_recorder.components.types import COLLECTIVES, MatchInfo, MatchState
from tools.flight_recorder.components.utils import match_one_event


# Make sure to remove REPO_ROOT after import is done
sys.path.remove(str(REPO_ROOT))

from torch.distributed.flight_recorder.components.builder import build_db
from torch.distributed.flight_recorder.components.config_manager import JobConfig
from torch.distributed.flight_recorder.components.types import (
    COLLECTIVES,
    MatchInfo,
    MatchState,
)
from torch.distributed.flight_recorder.components.utils import match_one_event
from torch.testing._internal.common_utils import run_tests, TestCase
@ -7,7 +7,7 @@
|
||||
|
||||
import copy
|
||||
import sys
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, cast
|
||||
|
||||
import numpy as np
|
||||
@ -40,6 +40,7 @@ from torch.testing._internal.common_distributed import (
|
||||
skip_if_rocm_multiprocess,
|
||||
skip_if_win32,
|
||||
)
|
||||
from torch.testing._internal.common_fsdp import get_devtype
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
@ -56,17 +57,7 @@ except ImportError:
|
||||
HAS_TORCHVISION = False
|
||||
|
||||
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def deterministic_algorithms(enabled=True):
|
||||
prev_state = torch.are_deterministic_algorithms_enabled()
|
||||
torch.use_deterministic_algorithms(enabled)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch.use_deterministic_algorithms(prev_state)
|
||||
device_type = str(get_devtype())
|
||||
|
||||
|
||||
class TestZeroRedundancyOptimizer(DistributedTestBase):
|
||||
@ -1250,7 +1241,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
||||
enabled=True, deterministic=True, benchmark=False
|
||||
)
|
||||
if "cuda" in device
|
||||
else deterministic_algorithms(True)
|
||||
else torch.use_deterministic_algorithms(True)
|
||||
)
|
||||
with det_ctx:
|
||||
device_ids = [rank] if requires_ddp_rank(device) else None
|
||||
|
||||
@ -24,7 +24,7 @@ from torch.distributed._functional_collectives import (
|
||||
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
|
||||
from torch.testing._internal.common_device_type import e4m3_type
|
||||
from torch.testing._internal.common_distributed import (
|
||||
DistributedTestBase,
|
||||
MultiProcessTestCase,
|
||||
requires_accelerator_dist_backend,
|
||||
skip_if_lt_x_gpu,
|
||||
)
|
||||
@ -59,8 +59,12 @@ if not dist.is_available():
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
@requires_accelerator_dist_backend()
|
||||
class TestWithNCCL(DistributedTestBase):
|
||||
@requires_accelerator_dist_backend(["nccl", "xccl"])
|
||||
class TestWithNCCL(MultiProcessTestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self._spawn_processes()
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return 2
|
||||
@ -74,7 +78,16 @@ class TestWithNCCL(DistributedTestBase):
|
||||
return torch.device(self.rank)
|
||||
|
||||
def _init_process_group(self) -> None:
|
||||
self.create_pg(self.device.type)
|
||||
torch.accelerator.set_device_index(self.rank)
|
||||
store = dist.FileStore(self.file_name, self.world_size)
|
||||
backend = dist.get_default_backend_for_device(self.device.type)
|
||||
|
||||
dist.init_process_group(
|
||||
backend=backend,
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
store=store,
|
||||
)
|
||||
torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
|
||||
@ -11,10 +11,13 @@ if not dist.is_available():
|
||||
print("Distributed not available, skipping tests", file=sys.stderr)
|
||||
sys.exit(0)
|
||||
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_distributed import DistributedTestBase, TEST_SKIPS
|
||||
from torch.testing._internal.common_utils import (
|
||||
run_tests,
|
||||
skipIfHpu,
|
||||
TEST_CUDA,
|
||||
TEST_HPU,
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
)
|
||||
|
||||
@ -26,8 +29,16 @@ if TEST_WITH_DEV_DBG_ASAN:
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
device_count = torch.accelerator.device_count()
|
||||
if TEST_HPU:
|
||||
DEVICE = "hpu"
|
||||
elif TEST_CUDA:
|
||||
DEVICE = "cuda"
|
||||
else:
|
||||
DEVICE = "cpu"
|
||||
|
||||
device_module = torch.get_device_module(DEVICE)
|
||||
device_count = device_module.device_count()
|
||||
BACKEND = dist.get_default_backend_for_device(DEVICE)
|
||||
|
||||
|
||||
def with_comms(func=None):
|
||||
@ -38,10 +49,11 @@ def with_comms(func=None):
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if device_type != "cpu" and device_count < self.world_size:
|
||||
if DEVICE != "cpu" and device_count < self.world_size:
|
||||
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
|
||||
|
||||
self.pg = self.create_pg(device=device_type)
|
||||
kwargs["device"] = DEVICE
|
||||
self.pg = self.create_pg(device=DEVICE)
|
||||
try:
|
||||
return func(self, *args, **kwargs)
|
||||
finally:
|
||||
@ -52,7 +64,7 @@ def with_comms(func=None):
|
||||
|
||||
class TestObjectCollectives(DistributedTestBase):
|
||||
@with_comms()
|
||||
def test_all_gather_object(self):
|
||||
def test_all_gather_object(self, device):
|
||||
output = [None] * dist.get_world_size()
|
||||
dist.all_gather_object(object_list=output, obj=self.rank)
|
||||
|
||||
@ -60,7 +72,7 @@ class TestObjectCollectives(DistributedTestBase):
|
||||
self.assertEqual(i, v, f"rank: {self.rank}")
|
||||
|
||||
@with_comms()
|
||||
def test_gather_object(self):
|
||||
def test_gather_object(self, device):
|
||||
output = [None] * dist.get_world_size() if self.rank == 0 else None
|
||||
dist.gather_object(obj=self.rank, object_gather_list=output)
|
||||
|
||||
@ -70,7 +82,7 @@ class TestObjectCollectives(DistributedTestBase):
|
||||
|
||||
@skipIfHpu
|
||||
@with_comms()
|
||||
def test_send_recv_object_list(self):
|
||||
def test_send_recv_object_list(self, device):
|
||||
val = 99 if self.rank == 0 else None
|
||||
object_list = [val] * dist.get_world_size()
|
||||
if self.rank == 0:
|
||||
@ -84,7 +96,7 @@ class TestObjectCollectives(DistributedTestBase):
|
||||
self.assertEqual(None, object_list[0])
|
||||
|
||||
@with_comms()
|
||||
def test_broadcast_object_list(self):
|
||||
def test_broadcast_object_list(self, device):
|
||||
val = 99 if self.rank == 0 else None
|
||||
object_list = [val] * dist.get_world_size()
|
||||
# TODO test with broadcast_object_list's device argument
|
||||
@ -93,7 +105,7 @@ class TestObjectCollectives(DistributedTestBase):
|
||||
self.assertEqual(99, object_list[0])
|
||||
|
||||
@with_comms()
|
||||
def test_scatter_object_list(self):
|
||||
def test_scatter_object_list(self, device):
|
||||
input_list = list(range(dist.get_world_size())) if self.rank == 0 else None
|
||||
output_list = [None]
|
||||
dist.scatter_object_list(
|
||||
@ -111,30 +123,34 @@ class TestObjectCollectives(DistributedTestBase):
|
||||
my_pg = dist.new_group(ranks, use_local_synchronization=True)
|
||||
return rank, ranks, my_pg
|
||||
|
||||
@skipIfHpu
|
||||
@with_comms()
|
||||
def test_subpg_scatter_object(self):
|
||||
def test_subpg_scatter_object(self, device):
|
||||
rank, ranks, my_pg = self.setup_sub_pg()
|
||||
out_list = [None]
|
||||
dist.scatter_object_list(out_list, ranks, src=ranks[0], group=my_pg)
|
||||
self.assertEqual(rank, out_list[0])
|
||||
|
||||
@skipIfHpu
|
||||
@with_comms()
|
||||
def test_subpg_all_gather_object(self):
|
||||
def test_subpg_all_gather_object(self, device):
|
||||
rank, ranks, my_pg = self.setup_sub_pg()
|
||||
out_list = [None] * len(ranks)
|
||||
dist.all_gather_object(out_list, rank, group=my_pg)
|
||||
self.assertEqual(ranks, out_list)
|
||||
|
||||
@skipIfHpu
|
||||
@with_comms()
|
||||
def test_subpg_gather_object(self):
|
||||
def test_subpg_gather_object(self, device):
|
||||
rank, ranks, my_pg = self.setup_sub_pg()
|
||||
out_list = [None] * len(ranks) if rank == ranks[0] else None
|
||||
dist.gather_object(rank, out_list, dst=ranks[0], group=my_pg)
|
||||
if rank == ranks[0]:
|
||||
self.assertEqual(ranks, out_list)
|
||||
|
||||
@skipIfHpu
|
||||
@with_comms()
|
||||
def test_subpg_broadcast_object(self):
|
||||
def test_subpg_broadcast_object(self, device):
|
||||
rank, ranks, my_pg = self.setup_sub_pg()
|
||||
out_list = [None]
|
||||
if rank == ranks[0]:
|
||||
@ -143,5 +159,7 @@ class TestObjectCollectives(DistributedTestBase):
|
||||
self.assertEqual(ranks[0], out_list[0])
|
||||
|
||||
|
||||
devices = ("cpu", "cuda", "hpu")
|
||||
instantiate_device_type_tests(TestObjectCollectives, globals(), only_for=devices)
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -29,7 +29,7 @@ from torch.distributed.tensor._collective_utils import (
|
||||
)
|
||||
from torch.distributed.tensor.placement_types import _Partial, Shard
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_HPU, TEST_XPU, TestCase
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_XPU, TestCase
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorTestBase,
|
||||
with_comms,
|
||||
@ -58,7 +58,7 @@ def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0, local_ran
|
||||
os.environ["LOCAL_RANK"] = f"{local_rank}"
|
||||
|
||||
|
||||
@unittest.skipIf(TEST_XPU or TEST_HPU, "XPU/HPU does not support gloo backend.")
|
||||
@unittest.skipIf(TEST_XPU, "XPU does not support gloo backend.")
|
||||
class DeviceMeshTestGlooBackend(DTensorTestBase):
|
||||
@property
|
||||
def backend(self):
|
||||
|
||||
@ -40,6 +40,7 @@ from torch.testing._internal.common_distributed import (
|
||||
DynamoDistributedSingleProcTestCase,
|
||||
MultiProcessTestCase,
|
||||
requires_accelerator_dist_backend,
|
||||
requires_gloo,
|
||||
skip_if_lt_x_gpu,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
@ -1773,16 +1774,10 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
|
||||
inputs = [x, w, ar_0, ar_1]
|
||||
f(*inputs, **self.get_world_trs())
|
||||
|
||||
def _pass(g):
|
||||
from torch._inductor.fx_passes.bucketing import bucket_all_reduce
|
||||
|
||||
bucket_all_reduce(g.owning_module, lambda _: 2000)
|
||||
|
||||
torch._inductor.config.post_grad_custom_post_pass = _pass
|
||||
|
||||
with torch._inductor.config.patch(
|
||||
{
|
||||
"reorder_for_compute_comm_overlap": False,
|
||||
"bucket_all_reduces_fx": bucket_mode,
|
||||
}
|
||||
):
|
||||
compiled = torch.compile(f)
|
||||
@ -2234,6 +2229,50 @@ class TestSyncDecisionCrossRanks(MultiProcessTestCase):
|
||||
)
|
||||
assert est_ms_nccl > 0
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_gloo()
|
||||
def test_regression_use_nccl_estimate_with_gloo(self):
|
||||
# Test checks that using nccl estimator option does not hard fail
|
||||
# with backends that does not support runtime estimations, e.g. gloo
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
c10d.init_process_group(
|
||||
backend="gloo", store=store, rank=self.rank, world_size=self.world_size
|
||||
)
|
||||
group = c10d.distributed_c10d._get_default_group()
|
||||
group_name = "default"
|
||||
torch._C._distributed_c10d._register_process_group(
|
||||
group_name, torch.distributed.group.WORLD
|
||||
)
|
||||
group_size = group.size()
|
||||
|
||||
def func(inp, group_size, group_name):
|
||||
ag_0_out = torch.ops._c10d_functional.all_gather_into_tensor(
|
||||
inp, group_size, group_name
|
||||
)
|
||||
ag_0_wait = torch.ops.c10d_functional.wait_tensor(ag_0_out)
|
||||
ag_1_out = torch.ops._c10d_functional.all_gather_into_tensor(
|
||||
ag_0_wait, group_size, group_name
|
||||
)
|
||||
ag_1_wait = torch.ops.c10d_functional.wait_tensor(ag_1_out)
|
||||
return ag_1_wait
|
||||
|
||||
gm = make_fx(func)(torch.ones(4, 4), group_size, group_name)
|
||||
g = gm.graph
|
||||
for n in g.nodes:
|
||||
if is_all_gather_into_tensor(n):
|
||||
from torch._inductor.comm_analysis import (
|
||||
estimate_nccl_collective_runtime_from_fx_node,
|
||||
)
|
||||
|
||||
est_ms = estimate_nccl_collective_runtime_from_fx_node(
|
||||
n, use_nccl_estimator=False
|
||||
)
|
||||
assert est_ms > 0
|
||||
est_ms_nccl = estimate_nccl_collective_runtime_from_fx_node(
|
||||
n, use_nccl_estimator=True
|
||||
)
|
||||
assert est_ms_nccl > 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
@ -29,6 +29,8 @@ from torch.testing._internal.common_utils import (
|
||||
|
||||
MY_LAMBDA = lambda x: x + 1 # noqa: E731
|
||||
|
||||
EPS = torch.tensor(1e-7)
|
||||
|
||||
|
||||
class CustomCompiledFunction(torch._dynamo.aot_compile.SerializableCallable):
|
||||
def __init__(self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]):
|
||||
@ -587,6 +589,18 @@ from user code:
|
||||
actual = compiled_fn(fn, *inputs)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_aot_compile_with_global_tensor(self):
|
||||
def fn(x, y):
|
||||
return x + y + EPS
|
||||
|
||||
def make_inputs():
|
||||
return (torch.randn(3, 4), torch.randn(3, 4))
|
||||
|
||||
compiled_fn = torch.compile(fn, fullgraph=True).aot_compile((make_inputs(), {}))
|
||||
|
||||
test_inputs = make_inputs()
|
||||
self.assertEqual(compiled_fn(*test_inputs), fn(*test_inputs))
|
||||
|
||||
def test_aot_compile_with_default_args(self):
|
||||
def fn(x, y=1):
|
||||
return x + x
|
||||
|
||||
@ -330,6 +330,13 @@ y = FakeTensor(..., size=(2,))
|
||||
'obj_weakref': None
|
||||
'guarded_class': None
|
||||
}
|
||||
global '' GLOBAL_STATE
|
||||
{
|
||||
'guard_types': None,
|
||||
'code': None,
|
||||
'obj_weakref': None
|
||||
'guarded_class': None
|
||||
}
|
||||
global '' TORCH_FUNCTION_STATE
|
||||
{
|
||||
'guard_types': None,
|
||||
|
||||
@ -90,12 +90,12 @@ class GraphModule(torch.nn.Module):
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(u0)", primals_2: "Sym(u1)", primals_3: "Sym(u2)", primals_4: "f32[u0, u1, u2]"):
|
||||
ge_1: "Sym(u0 >= 0)" = primals_1 >= 0
|
||||
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
|
||||
ge_3: "Sym(u1 >= 0)" = primals_2 >= 0
|
||||
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
|
||||
ge_5: "Sym(u2 >= 0)" = primals_3 >= 0
|
||||
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u2 >= 0 on node 'ge_2'"); ge_5 = _assert_scalar_2 = None
|
||||
ge: "Sym(u0 >= 0)" = primals_1 >= 0
|
||||
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
|
||||
ge_1: "Sym(u1 >= 0)" = primals_2 >= 0
|
||||
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
|
||||
ge_2: "Sym(u2 >= 0)" = primals_3 >= 0
|
||||
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u2 >= 0 on node 'ge_2'"); ge_2 = _assert_scalar_2 = None
|
||||
|
||||
floordiv: "Sym((u0//2))" = primals_1 // 2
|
||||
|
||||
|
||||
@ -1214,7 +1214,7 @@ class TestGuardSerialization(TestGuardSerializationBase):

x = torch.randn(3, 2)
with torch.enable_grad():
ref, loaded = self._test_serialization("GRAD_MODE", fn, x)
ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
with torch.no_grad():
self._test_check_fn(ref, loaded, {"x": x}, False)
with torch.enable_grad():
@ -1226,7 +1226,7 @@ class TestGuardSerialization(TestGuardSerializationBase):

x = torch.randn(3, 2)
with torch.enable_grad():
ref, _ = self._test_serialization("GRAD_MODE", fn, x)
ref, _ = self._test_serialization("GLOBAL_STATE", fn, x)
with torch.no_grad():
# Ensure guards state loading is not affected by the current global grad mode.
guards_state = pickle.loads(self._cached_guards_state)
@ -1246,7 +1246,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
try:
x = torch.randn(3, 2)
torch.use_deterministic_algorithms(True)
ref, loaded = self._test_serialization("DETERMINISTIC_ALGORITHMS", fn, x)
ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
torch.use_deterministic_algorithms(False)
self._test_check_fn(ref, loaded, {"x": x}, False)
torch.use_deterministic_algorithms(True)
@ -1270,6 +1270,9 @@ class TestGuardSerialization(TestGuardSerializationBase):
ref, loaded = self._test_serialization("TORCH_FUNCTION_STATE", fn, x)
self._test_check_fn(ref, loaded, {"x": x}, True)
self._test_check_fn(ref, loaded, {"x": x}, False)
with GlobalTorchFunctionMode():
ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
self._test_check_fn(ref, loaded, {"x": x}, True)
with GlobalTorchFunctionMode():
with torch._C.DisableTorchFunction():
self._test_check_fn(ref, loaded, {"x": x}, False)
@ -1306,7 +1309,7 @@ class TestGuardSerialization(TestGuardSerializationBase):
x = torch.randn(3, 2)

with torch.enable_grad():
ref, loaded = self._test_serialization("FSDP_TRAINING_STATE", fn, x)
ref, loaded = self._test_serialization("GLOBAL_STATE", fn, x)
with torch.no_grad():
self._test_check_fn(ref, loaded, {"x": x}, False)
with torch.enable_grad():
@ -1690,6 +1693,38 @@ class TestGuardSerialization(TestGuardSerializationBase):
ref, loaded, {"x": x, "d": ModWithDict({"b": 1e-9, "a": 1e9})}, False
)

def test_global_state_guard_filter(self):
def foo(x):
return x + 1

x = torch.randn(3, 2)

with torch.no_grad():
compiled_fn = torch.compile(
foo, options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe}
)
compiled_fn(x)

# Check global guards are gone.
with torch.enable_grad(), torch.compiler.set_stance("fail_on_recompile"):
self.assertEqual(compiled_fn(x), foo(x))

def test_torch_function_state_filter(self):
def foo(x):
return x + 1

x = torch.randn(3, 2)

with GlobalTorchFunctionMode():
compiled_fn = torch.compile(
foo, options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe}
)
compiled_fn(x)

# Check global guards are gone.
with torch.compiler.set_stance("fail_on_recompile"):
self.assertEqual(compiled_fn(x), foo(x))


class SimpleModule(torch.nn.Module):
def __init__(self, c):

@ -727,7 +727,7 @@ class GraphModule(torch.nn.Module):
x = torch.randn(3)
arg_count = ifdynstaticdefault(4, 5)
# when compiled with dynamic, we don't have upper bound runtime assertions for u0
expected_op_count = ifdynstaticdefault(10, 8)
expected_op_count = ifdynstaticdefault(9, 7)
out_graph = self._test_wrap_simple(
f,
default_args_generator((x,)),
@ -747,7 +747,6 @@ class GraphModule(torch.nn.Module):
c: "i64[u0, 1]" = l_x_.nonzero()

sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
_check_is_size = torch._check_is_size(sym_size_int_1); _check_is_size = None

ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
@ -784,7 +783,6 @@ class GraphModule(torch.nn.Module):
c: "i64[u0, 1]" = l_x_.nonzero()

sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
_check_is_size = torch._check_is_size(sym_size_int_1); _check_is_size = None

ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
@ -883,7 +881,7 @@ class GraphModule(torch.nn.Module):
x = torch.randn(3)
arg_count = ifdynstaticdefault(4, 5)
# when compiled with dynamic, we don't have upper bound runtime assertions for u0
expected_op_count = ifdynstaticdefault(10, 8)
expected_op_count = ifdynstaticdefault(9, 7)
out_graph = self._test_wrap_simple(
f,
default_args_generator((x,)),
@ -905,7 +903,6 @@ class GraphModule(torch.nn.Module):
c: "i64[u0, 1]" = l_x_.nonzero()

sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
_check_is_size = torch._check_is_size(sym_size_int); _check_is_size = None

ge: "Sym(u0 >= 0)" = sym_size_int >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
@ -956,7 +953,7 @@ class GraphModule(torch.nn.Module):
y = torch.randn(3)
arg_count = ifdynstaticdefault(5, 6)
# when compiled with dynamic, we don't have upper bound runtime assertions for u0 and u1
expected_op_count = ifdynstaticdefault(17, 13)
expected_op_count = ifdynstaticdefault(15, 11)
out_graph = self._test_wrap_simple(
f,
default_args_generator((x, y)),
@ -977,7 +974,6 @@ class GraphModule(torch.nn.Module):
c: "i64[u0, 1]" = l_x_.nonzero()

sym_size_int_2: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
_check_is_size = torch._check_is_size(sym_size_int_2); _check_is_size = None

ge: "Sym(u0 >= 0)" = sym_size_int_2 >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
@ -987,7 +983,6 @@ class GraphModule(torch.nn.Module):
d: "i64[u1, 1]" = l_y_.nonzero(); l_y_ = None

sym_size_int_3: "Sym(u1)" = torch.ops.aten.sym_size.int(d, 0)
_check_is_size_1 = torch._check_is_size(sym_size_int_3); _check_is_size_1 = None

ge_1: "Sym(u1 >= 0)" = sym_size_int_3 >= 0
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default_2 = None

@ -244,6 +244,61 @@ class MiscTests(torch._inductor.test_case.TestCase):
self.assertTrue(same(val4, correct1))
self.assertEqual(counter.frame_count, 3)

@unittest.skipIf(not TEST_CUDA, "cuda needed")
def test_assume_32_bit_indexing(self):
@torch.compile(backend="inductor")
def func(a, b):
# Multiple concat operations
x = torch.concat([a, b], dim=0)
y = torch.concat([a, b], dim=1)

# Reshape to create indexing patterns
x_flat = x.reshape(-1)
y_flat = y.reshape(-1)

# Take the smaller one and expand
min_size = min(x_flat.shape[0], y_flat.shape[0])
x_trunc = x_flat[:min_size]
y_trunc = y_flat[:min_size]

# Combine and compute
result = (x_trunc + y_trunc) * 10

# Cumulative operations create complex indexing
cumsum = result.cumsum(dim=0)

return cumsum.sum()

a = torch.rand(100, 30, device="cuda")
b = torch.rand(100, 30, device="cuda")

torch._dynamo.decorators.mark_unbacked(a, 0)
torch._dynamo.decorators.mark_unbacked(a, 1)
torch._dynamo.decorators.mark_unbacked(b, 0)
torch._dynamo.decorators.mark_unbacked(b, 1)

source_code = run_and_get_code(func, a, b)[1]

self.assertTrue(
"xindex = xoffset + tl.arange(0, XBLOCK)[:].to(tl.int64)\\n"
in str(source_code)
)
self.assertFalse(
"xindex = xoffset + tl.arange(0, XBLOCK)[:]\\n" in str(source_code)
)

torch._dynamo.reset()

with torch._inductor.config.patch(assume_32bit_indexing=True):
source_code = run_and_get_code(func, a, b)[1]
self.assertFalse(
"xindex = xoffset + tl.arange(0, XBLOCK)[:].to(tl.int64)\\n"
in str(source_code)
)
self.assertTrue(
"xindex = xoffset + tl.arange(0, XBLOCK)[:]\\n" in str(source_code)
)

def test_dynamo_inside_custom_op(self):
cnt = torch._dynamo.testing.InductorAndRecordGraphs()
cnt1 = torch._dynamo.testing.InductorAndRecordGraphs()

@ -562,7 +562,7 @@ class TestDynamoTimed(TestCase):
'graph_node_count': 3,
'graph_node_shapes': None,
'graph_op_count': 1,
'guard_count': 9,
'guard_count': 10,
'has_guarded_code': True,
'inductor_code_gen_cumulative_compile_time_us': 0,
'inductor_compile_time_s': 0.0,
@ -608,7 +608,7 @@ class TestDynamoTimed(TestCase):
'tensorify_float_attempt': None,
'tensorify_float_failure': None,
'tensorify_float_success': None,
'triton_compile_time_us': None,
'triton_compile_time_us': 0,
'triton_kernel_compile_times_us': None,
'triton_version': None}"""
if _IS_WINDOWS
@ -649,7 +649,7 @@ class TestDynamoTimed(TestCase):
'graph_node_count': 3,
'graph_node_shapes': None,
'graph_op_count': 1,
'guard_count': 9,
'guard_count': 10,
'has_guarded_code': True,
'inductor_code_gen_cumulative_compile_time_us': 0,
'inductor_compile_time_s': 0.0,
@ -920,7 +920,7 @@ class TestDynamoTimed(TestCase):
first, second = {
(3, 9): (10, 6),
(3, 10): (10, 6),
(3, 11): (10, 6),
(3, 11): (11, 7),
(3, 12): (11, 7),
(3, 13): (11, 7),
(3, 14): (11, 7),

@ -3,6 +3,7 @@
import copy
import types
import unittest
import warnings
from dataclasses import dataclass
from typing import Dict, List, Tuple

@ -18,6 +19,9 @@ from torch.testing import FileCheck
from torch.testing._internal.common_utils import TEST_CUDA


GLOBAL_LIST = []


@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestExperiment(TestCase):
def test_joint_basic(self) -> None:
@ -585,9 +589,9 @@ def forward(self, args_0):
_tree_leaf_0, _tree_leaf_1, = pytree.tree_leaves((self, args_0,))
L_args_0_ , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1)
l_args_0_ = L_args_0_
add = l_args_0_ + 1
add = l_args_0_ + 1; add = None
mul = l_args_0_ * 2; l_args_0_ = None
return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, mul, add), self._out_spec)""",
return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, mul), self._out_spec)""",
)
self.assertEqual(gm(*test_inputs), foo(*test_inputs))

@ -611,6 +615,34 @@ def forward(self, args_0):
self.assertEqual(len(list(gm.buffers())), len(list(foo.buffers())))
self.assertEqual(len(list(gm.parameters())), len(list(foo.parameters())))

def test_dynamo_graph_capture_side_effects(self):
GLOBAL_LIST.clear()

def foo(x):
z = x + 1
GLOBAL_LIST.append(z)
return z

def make_inputs():
return (torch.randn(2, 3),)

trace_inputs = make_inputs()
with warnings.catch_warnings(record=True) as w:
gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
cnt = 0
for entry in w:
if "While compiling, we found certain side effects happened" in str(
entry.message
):
cnt += 1
self.assertEqual(cnt, 1)
self.assertEqual(len(GLOBAL_LIST), 0)
test_inputs = make_inputs()
gm_results = gm(*test_inputs)
self.assertEqual(len(GLOBAL_LIST), 0)
self.assertEqual(gm_results, foo(*test_inputs))
self.assertEqual(len(GLOBAL_LIST), 1)

@unittest.skipIf(not TEST_CUDA, "CUDA not available")
def test_dynamo_graph_capture_fx_graph_annotate_overlap_pass(self):
class DummyOp(torch.autograd.Function):

@ -740,18 +740,26 @@ class TestExport(TestCase):
dynamic_shapes={"x": {0: Dim("b")}, "y": None},
)

# clean up _torchdynamo related meta data as it could vary depending on the caller
# https://github.com/pytorch/pytorch/issues/167432
for node in ep.graph.nodes:
if "custom" in node.meta:
node.meta["custom"] = {
k: v
for k, v in node.meta["custom"].items()
if "_torchdynamo_disable" not in k
}

custom_metadata = torch.fx.traceback._get_custom_metadata(ep.module())

self.assertExpectedInline(
str(custom_metadata),
"""\
('placeholder', 'x', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})
('placeholder', 'y', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})
('call_function', 'cat', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', 'item', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', 'ge_1', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', '_assert_scalar_default', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', 'mul', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('output', 'output', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})""",
('call_function', 'cat', {'moo': 0})
('call_function', 'item', {'moo': 0})
('call_function', 'ge_1', {'moo': 0})
('call_function', '_assert_scalar_default', {'moo': 0})
('call_function', 'mul', {'moo': 0})""",
)

@requires_gpu
@ -3073,15 +3081,12 @@ def forward(self, x, y):
foo = torch.ops.export.foo.default(x, y); x = None
sym_size_int = torch.ops.aten.sym_size.int(foo, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(foo, 1)
sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int); sym_constrain_range_for_size_default = None
ge = sym_size_int >= 0; sym_size_int = None
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
sym_constrain_range_for_size_default_1 = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_1); sym_constrain_range_for_size_default_1 = None
ge_1 = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default_1 = None
bar = torch.ops.export.bar.default(y); y = None
sym_size_int_2 = torch.ops.aten.sym_size.int(bar, 0)
sym_constrain_range_for_size_default_2 = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_2); sym_constrain_range_for_size_default_2 = None
ge_2 = sym_size_int_2 >= 0; sym_size_int_2 = None
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u2 >= 0 on node 'ge_2'"); ge_2 = _assert_scalar_default_2 = None
return (foo, bar)""",
@ -15303,12 +15308,12 @@ graph():
def forward(self, block):
return block.a + block.b

from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._dynamo.functional_export import dynamo_graph_capture_for_export

with self.assertRaisesRegex(
torch._dynamo.exc.UserError, "It looks like one of the inputs with type"
):
_dynamo_graph_capture_for_export(Foo())(
dynamo_graph_capture_for_export(Foo())(
Block(torch.randn(4, 4), torch.randn(4, 4))
)

@ -17735,7 +17740,6 @@ class TestExportCustomClass(TorchTestCase):
def forward(self, x, mask):
masked_select = torch.ops.aten.masked_select.default(x, mask); x = mask = None
sym_size_int_1 = torch.ops.aten.sym_size.int(masked_select, 0)
sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_1); sym_constrain_range_for_size_default = None
ge = sym_size_int_1 >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
le = sym_size_int_1 <= 1188864

@ -1,71 +0,0 @@
# Owner(s): ["oncall: export"]


from torch._dynamo import config as dynamo_config
from torch._dynamo.testing import make_test_cls_with_patches
from torch._export import config as export_config


try:
from . import test_export, testing
except ImportError:
import test_export # @manual=fbcode//caffe2/test:test_export-library
import testing # @manual=fbcode//caffe2/test:test_export-library

from torch.export import export


test_classes = {}


def mocked_strict_export(*args, **kwargs):
# If user already specified strict, don't make it strict
if "strict" in kwargs:
return export(*args, **kwargs)
return export(*args, **kwargs, strict=True)


def make_dynamic_cls(cls):
# Some test check for ending in suffix; need to make
# the `_strict` for end of string as a result
suffix = test_export.INLINE_AND_INSTALL_STRICT_SUFFIX

cls_prefix = "InlineAndInstall"

cls_a = testing.make_test_cls_with_mocked_export(
cls,
"StrictExport",
suffix,
mocked_strict_export,
xfail_prop="_expected_failure_strict",
)
test_class = make_test_cls_with_patches(
cls_a,
cls_prefix,
"",
(export_config, "use_new_tracer_experimental", True),
(dynamo_config, "install_free_tensors", True),
(dynamo_config, "inline_inbuilt_nn_modules", True),
xfail_prop="_expected_failure_inline_and_install",
)

test_classes[test_class.__name__] = test_class
# REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
globals()[test_class.__name__] = test_class
test_class.__module__ = __name__
return test_class


tests = [
test_export.TestDynamismExpression,
test_export.TestExport,
]
for test in tests:
make_dynamic_cls(test)
del test


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

run_tests()
@ -4,6 +4,7 @@ from unittest.mock import patch

import torch
from torch._dynamo.utils import counters
from torch._functorch.aot_autograd import aot_export_module
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import run_tests, TestCase

@ -90,6 +91,99 @@ def forward(self, arg0_1):

self.assertEqual(printed_output, f"moo 1 2\nmoo {new_inp}\nmoo 1 2\nyeehop 4")

def test_print_with_side_effect(self):
class M(torch.nn.Module):
def forward(self, x):
torch._higher_order_ops.print("moo {x} {y}", x=1, y=2)
res = x + x
torch._higher_order_ops.print("moo {x} {y}", x=1, y=2)
return (res,)

inputs = (torch.randn(3),)

# With functionalization, it should appear wrapped with with_effects()
gm, gs = aot_export_module(M(), inputs, trace_joint=False)
self.assertExpectedInline(
str(gm.code).strip(),
"""\
def forward(self, arg0_1, arg1_1):
with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.higher_order.print, 'moo {x} {y}', x = 1, y = 2); \
arg0_1 = None
getitem = with_effects[0]; with_effects = None
add = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.higher_order.print, 'moo {x} {y}', x = 1, y = 2); \
getitem = None
getitem_2 = with_effects_1[0]; with_effects_1 = None
return (getitem_2, add)""",
)
self.assertEqual(len(gs.input_tokens), 1)
self.assertEqual(len(gs.output_tokens), 1)

def test_print_with_input_mutations(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()

def forward(self, x):
torch._higher_order_ops.print("moo {x} {y}", x=x, y=2)
res = x + x
x.add_(res)
res = x + x
torch._higher_order_ops.print("moo {x} {y}", x=x, y=res)
return (res,)

inputs = (torch.randn(3),)

# With functionalization, it should appear wrapped with with_effects()
gm, gs = aot_export_module(M(), inputs, trace_joint=False)
self.assertEqual(len(gs.input_tokens), 1)
self.assertEqual(len(gs.output_tokens), 1)
self.assertEqual(len(gs.user_inputs_to_mutate), 1)
self.assertExpectedInline(
str(gm.code).strip(),
"""\
def forward(self, arg0_1, arg1_1):
with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.higher_order.print, 'moo {x} {y}', \
x = arg1_1, y = 2); arg0_1 = None
getitem = with_effects[0]; with_effects = None
add = torch.ops.aten.add.Tensor(arg1_1, arg1_1)
add_1 = torch.ops.aten.add.Tensor(arg1_1, add); arg1_1 = add = None
add_2 = torch.ops.aten.add.Tensor(add_1, add_1)
with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.higher_order.print, 'moo {x} {y}', \
x = add_1, y = add_2); getitem = None
getitem_2 = with_effects_1[0]; with_effects_1 = None
return (getitem_2, add_1, add_2)""",
)

def test_print_gen_schema(self):
from torch._higher_order_ops.print import print as print_op

# Test basic schema generation with simple kwargs int
format_str = "Hello {x} {y}"
schema = print_op.gen_schema(format_str, x=1, y=2)
self.assertExpectedInline(
str(schema),
"""print(str format_str, *, int x, int y) -> ()""",
)
# Test schema generation with different types of inputs

# Tensor input
tensor = torch.randn(2, 2)
schema_tensor = print_op.gen_schema("Tensor: {x}", x=tensor)
self.assertExpectedInline(
str(schema_tensor),
"""print(str format_str, *, Tensor x) -> ()""",
)

# TODO: Add schema support with kwargs with value of list type

# No kwargs
schema_no_kwargs = print_op.gen_schema("Simple message")
self.assertExpectedInline(
str(schema_no_kwargs),
"""print(str format_str) -> ()""",
)


if __name__ == "__main__":
run_tests()

Some files were not shown because too many files have changed in this diff