mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-06 00:54:56 +08:00
Compare commits
160 Commits
benchmarki...mlazos/tes
| SHA1 | Author | Date | |
|---|---|---|---|
| 277ff19a55 | |||
| 2b2245d5db | |||
| 206e9d5160 | |||
| 064bb3cebc | |||
| 0350c7e72c | |||
| f7c09f864a | |||
| c2e9115757 | |||
| b90fc2ec27 | |||
| 0cd18ba1ca | |||
| bfae151269 | |||
| 9cbbc2593b | |||
| 5616fa4a68 | |||
| c33fc9dae3 | |||
| 9ce2732b68 | |||
| dbad6d71c7 | |||
| b85c460749 | |||
| 6a781619bf | |||
| c99e91b1d7 | |||
| c014e4bcaa | |||
| daff263062 | |||
| 15e9119a69 | |||
| 7368eeba5e | |||
| 7a79de1c0f | |||
| bd10ea4e6c | |||
| 43390d8b13 | |||
| ad26ec6abe | |||
| 3e71016459 | |||
| 489afa829a | |||
| 472773c7f9 | |||
| f01e628e3b | |||
| 932733e0e6 | |||
| 108422ac26 | |||
| da4aacabac | |||
| 9b5308cd58 | |||
| b019a33f8f | |||
| 0fab32290a | |||
| faf973da5e | |||
| 78624679a8 | |||
| 5f1c3c67b2 | |||
| bbda22e648 | |||
| 0f3db20132 | |||
| eb93c0adb1 | |||
| 1193bf0855 | |||
| 26aa8dcf27 | |||
| 5acb8d5080 | |||
| abc2264e8f | |||
| 22a4cabd19 | |||
| ed1ff7d0fb | |||
| 2f03673ebf | |||
| f57754e815 | |||
| d6edefefbf | |||
| d89d213118 | |||
| 22641f42b6 | |||
| 967937872f | |||
| f9dc20c7a3 | |||
| fb67fa9968 | |||
| 35fc5c49b4 | |||
| b6b9311f4f | |||
| bbdf469f0e | |||
| 2120eeb8de | |||
| 1b569e5490 | |||
| 30ac7f4d4e | |||
| 65d8dba735 | |||
| 3bdceab124 | |||
| 802ffd06c8 | |||
| fc0135ca11 | |||
| 3027051590 | |||
| e7bf72c908 | |||
| 7183f52675 | |||
| 8002d22ce3 | |||
| 31f95b5d2e | |||
| 4b1f047a33 | |||
| ba3f91af97 | |||
| 0f81c7a28d | |||
| 7e8532077f | |||
| 1ece53b157 | |||
| 9d6f0d5991 | |||
| 3c05167489 | |||
| aec3ef1008 | |||
| dc82e911e7 | |||
| 639f459cb6 | |||
| f889dea97d | |||
| 208965a9d6 | |||
| 5a7442b91f | |||
| d66a55def0 | |||
| 382b38ed1b | |||
| bcbd2a22b2 | |||
| 0df96e3921 | |||
| 30f7079c93 | |||
| d173ba5a75 | |||
| 0fdd568b78 | |||
| a4b0023f3b | |||
| ba51f4876d | |||
| 852b99eba0 | |||
| 20ee5f9044 | |||
| 9c06dff1ce | |||
| c3de2c7c6b | |||
| 4a302b5731 | |||
| adfd5b293a | |||
| 0289313551 | |||
| 58ead04ee9 | |||
| 172015fc11 | |||
| 9371491529 | |||
| d6cb0fe576 | |||
| 0134150ebb | |||
| 61bfb3df9f | |||
| 2c1cb38d95 | |||
| 5b6fd277f9 | |||
| 818f76a745 | |||
| dc0f09a478 | |||
| 0c6c7780d9 | |||
| 9ba67e99bb | |||
| d5e0704247 | |||
| 43b18d098b | |||
| b040d63ce4 | |||
| 7d17253af8 | |||
| fdbf314278 | |||
| c7e8e8ee19 | |||
| 1237f271aa | |||
| 08fdc64c86 | |||
| 83a0e4e6f9 | |||
| 2bc8fec744 | |||
| cb56df55dc | |||
| 629fca295e | |||
| 3afbab66f7 | |||
| e8f5c24d17 | |||
| 20ec61a02f | |||
| 5a21d6f982 | |||
| 0db9c64d68 | |||
| 6f992e1b3f | |||
| 634ce22601 | |||
| 8883e494b3 | |||
| 41092cb86c | |||
| 733e684b11 | |||
| 2c6f24c62d | |||
| 53b0f6f543 | |||
| ef1d45b12d | |||
| d6e29bf875 | |||
| 3c74a72ea0 | |||
| cd9ff41282 | |||
| 447b481c79 | |||
| 9c7ed3e46e | |||
| 07343efc15 | |||
| b394c6e89c | |||
| c0864bb389 | |||
| 316e7a9293 | |||
| 2d932a2e01 | |||
| 4613081b72 | |||
| 946a4c2bdc | |||
| ba0a91b3ea | |||
| 22a1b3b5d0 | |||
| 40abb2b403 | |||
| 2b3ac17aa2 | |||
| 81b7c96697 | |||
| 6cda280483 | |||
| bbd45f1f1f | |||
| 0f0d5749a0 | |||
| 65b1aedd09 | |||
| 3e05a48927 | |||
| d865b784e4 |
```diff
@@ -27,6 +27,7 @@ if [ "$DESIRED_CUDA" = "cpu" ]; then
     USE_PRIORITIZED_TEXT_FOR_LD=1 python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn
 else
     echo "BASE_CUDA_VERSION is set to: $DESIRED_CUDA"
+    export USE_SYSTEM_NCCL=1
     #USE_PRIORITIZED_TEXT_FOR_LD for enable linker script optimization https://github.com/pytorch/pytorch/pull/121975/files
     USE_PRIORITIZED_TEXT_FOR_LD=1 python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn --enable-cuda
 fi
```
```diff
@@ -8,16 +8,6 @@ retry () {
   "$@" || (sleep 10 && "$@") || (sleep 20 && "$@") || (sleep 40 && "$@")
 }
 
-# A bunch of custom pip dependencies for ONNX
-pip_install \
-  beartype==0.15.0 \
-  filelock==3.9.0 \
-  flatbuffers==2.0 \
-  mock==5.0.1 \
-  ninja==1.10.2 \
-  networkx==2.5 \
-  numpy==1.24.2
-
 # ONNXRuntime should be installed before installing
 # onnx-weekly. Otherwise, onnx-weekly could be
 # overwritten by onnx.
@@ -29,11 +19,8 @@ pip_install \
   transformers==4.36.2
 
 pip_install coloredlogs packaging
 
 pip_install onnxruntime==1.18.1
-pip_install onnxscript==0.2.6 --no-deps
-# required by onnxscript
-pip_install ml_dtypes
-
+pip_install onnxscript==0.3.0
 # Cache the transformers model to be used later by ONNX tests. We need to run the transformers
 # package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
```
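For readers skimming the hunk above: the `retry` helper shown as context re-runs a failing command after 10, 20, and 40 second sleeps. A standalone sketch of the same escalating-backoff idea (illustration only, not code from the repository; the `retry` name and delays merely mirror the shell helper):

```python
# Illustration only: the escalating-backoff pattern of the CI retry() helper,
# which re-runs a failed command after sleeping 10s, 20s, then 40s.
import subprocess
import time

def retry(cmd: list[str], delays: tuple[int, ...] = (10, 20, 40)) -> bool:
    for delay in (0, *delays):
        time.sleep(delay)              # 0s before the first attempt
        if subprocess.call(cmd) == 0:
            return True                # command succeeded
    return False                       # all attempts failed

# Example: retry(["pip", "install", "onnxruntime==1.18.1"])
```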
```diff
@@ -51,7 +51,12 @@ as_jenkins git clone --recursive ${TRITON_REPO} triton
 cd triton
 as_jenkins git checkout ${TRITON_PINNED_COMMIT}
 as_jenkins git submodule update --init --recursive
-cd python
+# Old versions of python have setup.py in ./python; newer versions have it in ./
+if [ ! -f setup.py ]; then
+  cd python
+fi
 
 pip_install pybind11==2.13.6
 
 # TODO: remove patch setup.py once we have a proper fix for https://github.com/triton-lang/triton/issues/4527
```
```diff
@@ -41,14 +41,11 @@ fbscribelogger==0.1.7
 #Pinned versions: 0.1.6
 #test that import:
 
-flatbuffers==2.0 ; platform_machine != "s390x"
+flatbuffers==24.12.23
 #Description: cross platform serialization library
-#Pinned versions: 2.0
+#Pinned versions: 24.12.23
 #test that import:
 
-flatbuffers ; platform_machine == "s390x"
-#Description: cross platform serialization library; Newer version is required on s390x for new python version
-
 hypothesis==5.35.1
 # Pin hypothesis to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136
 #Description: advanced library for generating parametrized tests
```
```diff
@@ -15,6 +15,9 @@ export INSTALL_TEST=0 # dont install test binaries into site-packages
 export USE_CUPTI_SO=0
 export USE_CUSPARSELT=${USE_CUSPARSELT:-1} # Enable if not disabled by libtorch build
 export USE_CUFILE=${USE_CUFILE:-1}
+export USE_SYSTEM_NCCL=1
+export NCCL_INCLUDE_DIR="/usr/local/cuda/include/"
+export NCCL_LIB_DIR="/usr/local/cuda/lib64/"
 
 # Keep an array of cmake variables to add to
 if [[ -z "$CMAKE_ARGS" ]]; then
@@ -172,12 +175,9 @@ if [[ $CUDA_VERSION == 12* ]]; then
         export LIB_SO_RPATH=$CUDA_RPATHS':$ORIGIN'
         export FORCE_RPATH="--force-rpath"
         export USE_STATIC_NCCL=0
-        export USE_SYSTEM_NCCL=1
         export ATEN_STATIC_CUDA=0
         export USE_CUDA_STATIC_LINK=0
         export USE_CUPTI_SO=1
-        export NCCL_INCLUDE_DIR="/usr/local/cuda/include/"
-        export NCCL_LIB_DIR="/usr/local/cuda/lib64/"
     fi
 elif [[ $CUDA_VERSION == "11.8" ]]; then
     export USE_STATIC_CUDNN=0
@@ -254,12 +254,9 @@ elif [[ $CUDA_VERSION == "11.8" ]]; then
         export LIB_SO_RPATH=$CUDA_RPATHS':$ORIGIN'
         export FORCE_RPATH="--force-rpath"
         export USE_STATIC_NCCL=0
-        export USE_SYSTEM_NCCL=1
         export ATEN_STATIC_CUDA=0
         export USE_CUDA_STATIC_LINK=0
         export USE_CUPTI_SO=1
-        export NCCL_INCLUDE_DIR="/usr/local/cuda/include/"
-        export NCCL_LIB_DIR="/usr/local/cuda/lib64/"
     fi
 else
     echo "Unknown cuda version $CUDA_VERSION"
```
```diff
@@ -324,6 +324,12 @@ test_python_smoke() {
   assert_git_not_dirty
 }
 
+test_h100_distributed() {
+  # Distributed tests at H100
+  time python test/run_test.py --include distributed/_composable/test_composability/test_pp_composability.py $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
+  assert_git_not_dirty
+}
+
 test_lazy_tensor_meta_reference_disabled() {
   export TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE=1
   echo "Testing lazy tensor operations without meta reference"
@@ -595,7 +601,6 @@ test_perf_for_dashboard() {
   elif [[ "${TEST_CONFIG}" == *cpu_aarch64* ]]; then
     device=cpu_aarch64
   fi
-  test_inductor_set_cpu_affinity
 elif [[ "${TEST_CONFIG}" == *cuda_a10g* ]]; then
   device=cuda_a10g
 elif [[ "${TEST_CONFIG}" == *h100* ]]; then
@@ -604,6 +609,9 @@ test_perf_for_dashboard() {
     device=rocm
   fi
 
+  # Always set CPU affinity because metrics like compilation time requires CPU
+  test_inductor_set_cpu_affinity
+
   for mode in "${modes[@]}"; do
     if [[ "$mode" == "inference" ]]; then
       dtype=bfloat16
@@ -1639,7 +1647,7 @@ elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then
     install_torchaudio cuda
   fi
   install_torchvision
-  TORCH_CUDA_ARCH_LIST="8.0;8.6" pip_install git+https://github.com/pytorch/ao.git
+  TORCH_CUDA_ARCH_LIST="8.0;8.6" install_torchao
   id=$((SHARD_NUMBER-1))
   # https://github.com/opencv/opencv-python/issues/885
   pip_install opencv-python==4.8.0.74
@@ -1724,6 +1732,8 @@ elif [[ "${BUILD_ENVIRONMENT}" == *xpu* ]]; then
   test_xpu_bin
 elif [[ "${TEST_CONFIG}" == smoke ]]; then
   test_python_smoke
+elif [[ "${TEST_CONFIG}" == h100_distributed ]]; then
+  test_h100_distributed
 else
   install_torchvision
   install_monkeytype
```
21  .github/actions/reuse-old-whl/reuse_old_whl.py  vendored
```diff
@@ -120,6 +120,23 @@ def ok_changed_file(file: str) -> bool:
 def check_changed_files(sha: str) -> bool:
     # Return true if all the changed files are in the list of allowed files to
     # be changed to reuse the old whl
+
+    # Removing any files is not allowed since rysnc will not remove files
+    removed_files = (
+        subprocess.check_output(
+            ["git", "diff", "--name-only", sha, "HEAD", "--diff-filter=D"],
+            text=True,
+            stderr=subprocess.DEVNULL,
+        )
+        .strip()
+        .split()
+    )
+    if removed_files:
+        print(
+            f"Removed files between {sha} and HEAD: {removed_files}, cannot reuse old whl"
+        )
+        return False
+
     changed_files = (
         subprocess.check_output(
             ["git", "diff", "--name-only", sha, "HEAD"],
@@ -190,6 +207,10 @@ def unzip_artifact_and_replace_files() -> None:
     subprocess.check_output(
         ["unzip", "-o", new_path, "-d", f"artifacts/dist/{new_path.stem}"],
     )
 
+    # Remove the old wheel (which is now a zip file)
+    os.remove(new_path)
+
     # Copy python files into the artifact
     subprocess.check_output(
         ["rsync", "-avz", "torch", f"artifacts/dist/{new_path.stem}"],
```
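The guard added above hinges on `git diff --diff-filter=D`, which limits the reported file list to deletions. A minimal, self-contained demonstration of that invocation (assumes only a git checkout; `HEAD~5` is an arbitrary example base revision):

```python
# Demonstration of the git invocation the new guard relies on:
# --diff-filter=D restricts `git diff --name-only` to deleted files.
import subprocess

def deleted_files(base: str) -> list[str]:
    out = subprocess.check_output(
        ["git", "diff", "--name-only", base, "HEAD", "--diff-filter=D"],
        text=True,
    )
    return out.split()

print(deleted_files("HEAD~5"))
```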
1  .github/pytorch-probot.yml  vendored
```diff
@@ -28,6 +28,7 @@ ciflow_push_tags:
 - ciflow/op-benchmark
 - ciflow/pull
 - ciflow/h100
+- ciflow/h100-distributed
 retryable_workflows:
 - pull
 - trunk
```
15  .github/scripts/build_triton_wheel.py  vendored
```diff
@@ -65,6 +65,7 @@ def build_triton(
     with TemporaryDirectory() as tmpdir:
         triton_basedir = Path(tmpdir) / "triton"
         triton_pythondir = triton_basedir / "python"
 
         triton_repo = "https://github.com/openai/triton"
         if device == "rocm":
             triton_pkg_name = "pytorch-triton-rocm"
@@ -101,11 +102,19 @@ def build_triton(
             )
             print("ROCm libraries setup for triton installation...")
 
-        check_call(
-            [sys.executable, "setup.py", "bdist_wheel"], cwd=triton_pythondir, env=env
-        )
+        # old triton versions have setup.py in the python/ dir,
+        # new versions have it in the root dir.
+        triton_setupdir = (
+            triton_basedir
+            if (triton_basedir / "setup.py").exists()
+            else triton_pythondir
+        )
 
-        whl_path = next(iter((triton_pythondir / "dist").glob("*.whl")))
+        check_call(
+            [sys.executable, "setup.py", "bdist_wheel"], cwd=triton_setupdir, env=env
+        )
+
+        whl_path = next(iter((triton_setupdir / "dist").glob("*.whl")))
         shutil.copy(whl_path, Path.cwd())
 
         if device == "rocm":
```
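The same fallback in isolation, as a sketch (the `find_setup_dir` helper is hypothetical, not part of the diff): old triton checkouts keep `setup.py` under `python/`, newer ones keep it at the repo root, so the build picks whichever directory actually contains it.

```python
# Sketch (not from the repo): probe the repo root first, fall back to python/.
from pathlib import Path

def find_setup_dir(basedir: Path) -> Path:
    return basedir if (basedir / "setup.py").exists() else basedir / "python"
```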
9  .github/workflows/build-triton-wheel.yml  vendored
```diff
@@ -139,6 +139,15 @@ jobs:
 
           docker exec -t "${container_name}" yum install -y zlib-devel zip
           docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -m pip install -U setuptools==78.1.0 pybind11==2.13.1 auditwheel wheel
+          set +e
+          docker exec -t "${container_name}" command -v pip
+          has_pip=$?
+          set -e
+          if [ $has_pip -eq 0 ] ; then
+            docker exec -t "${container_name}" pip install -U cmake --force-reinstall
+          else
+            docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -m pip install -U cmake --force-reinstall
+          fi
+
           if [[ ("${{ matrix.device }}" == "cuda" || "${{ matrix.device }}" == "rocm" || "${{ matrix.device }}" == "aarch64" ) ]]; then
             # With this install, it gets clang 16.0.6.
```
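A sketch of the added probe's logic outside the container (illustration only; `shutil.which` stands in for the `command -v pip` check, and the fallback mirrors the `"${PYTHON_EXECUTABLE}" -m pip` branch):

```python
# Sketch (not from the workflow): prefer bare pip when available, otherwise
# fall back to `python -m pip`, then force-reinstall cmake either way.
import shutil
import subprocess
import sys

pip = ["pip"] if shutil.which("pip") else [sys.executable, "-m", "pip"]
subprocess.check_call([*pip, "install", "-U", "cmake", "--force-reinstall"])
```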
53  .github/workflows/h100-distributed.yml  vendored  Normal file
```diff
@@ -0,0 +1,53 @@
+name: Limited CI for distributed tests on H100
+
+on:
+  pull_request:
+    paths:
+      - .github/workflows/h100-distributed.yml
+  workflow_dispatch:
+  push:
+    tags:
+      - ciflow/h100-distributed/*
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
+  cancel-in-progress: true
+
+jobs:
+
+  get-label-type:
+    if: github.repository_owner == 'pytorch'
+    name: get-label-type
+    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
+    with:
+      triggering_actor: ${{ github.triggering_actor }}
+      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
+      curr_branch: ${{ github.head_ref || github.ref_name }}
+      curr_ref_type: ${{ github.ref_type }}
+
+  linux-focal-cuda12_6-py3_10-gcc11-sm90-build:
+    name: linux-focal-cuda12.6-py3.10-gcc11-sm90
+    uses: ./.github/workflows/_linux-build.yml
+    needs: get-label-type
+    with:
+      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
+      runner: "linux.12xlarge"
+      build-environment: linux-focal-cuda12.6-py3.10-gcc11-sm90
+      docker-image-name: ci-image:pytorch-linux-focal-cuda12.6-cudnn9-py3-gcc11
+      cuda-arch-list: '9.0'
+      test-matrix: |
+        { include: [
+          { config: "h100_distributed", shard: 1, num_shards: 1, runner: "linux.aws.h100.8" },
+        ]}
+    secrets: inherit
+
+  linux-focal-cuda12_6-py3_10-gcc11-sm90-test:
+    name: linux-focal-cuda12.6-py3.10-gcc11-sm90
+    uses: ./.github/workflows/_linux-test.yml
+    needs:
+      - linux-focal-cuda12_6-py3_10-gcc11-sm90-build
+    with:
+      build-environment: linux-focal-cuda12.6-py3.10-gcc11-sm90
+      docker-image: ${{ needs.linux-focal-cuda12_6-py3_10-gcc11-sm90-build.outputs.docker-image }}
+      test-matrix: ${{ needs.linux-focal-cuda12_6-py3_10-gcc11-sm90-build.outputs.test-matrix }}
+    secrets: inherit
```
```diff
@@ -64,6 +64,7 @@ include_patterns = [
     'aten/src/ATen/xpu/**/*.cpp',
     'aten/src/ATen/core/boxing/**/*.h',
     'aten/src/ATen/core/dispatch/**/*.h',
+    'aten/src/ATen/core/Formatting.cpp',
     'aten/src/ATen/native/mps/**/*.metal',
     'aten/src/ATen/native/mps/**/*.mm',
     'aten/src/ATen/native/mps/**/*.h',
```
```diff
@@ -290,6 +290,7 @@ header_template_rule(
     substitutions = {
         "@AT_CUDNN_ENABLED@": "1",
         "@AT_CUSPARSELT_ENABLED@": "0",
+        "@AT_HIPSPARSELT_ENABLED@": "0",
         "@AT_ROCM_ENABLED@": "0",
         "@AT_MAGMA_ENABLED@": "0",
         "@NVCC_FLAGS_EXTRA@": "",
```
```diff
@@ -101,6 +101,13 @@ else()
   set(AT_CUSPARSELT_ENABLED 1)
 endif()
 
+# Add hipSPARSELt support flag
+if(USE_ROCM AND ROCM_VERSION VERSION_GREATER_EQUAL "6.4.0")
+  set(AT_HIPSPARSELT_ENABLED 1)
+else()
+  set(AT_HIPSPARSELT_ENABLED 0)
+endif()
+
 list(APPEND ATen_CPU_INCLUDE
   ${CMAKE_CURRENT_SOURCE_DIR}/src)
 add_subdirectory(src/ATen)
```
```diff
@@ -34,6 +34,7 @@ set_bool(AT_MAGMA_ENABLED USE_MAGMA)
 set_bool(CAFFE2_STATIC_LINK_CUDA_INT CAFFE2_STATIC_LINK_CUDA)
 set_bool(AT_CUDNN_ENABLED CAFFE2_USE_CUDNN)
 set_bool(AT_CUSPARSELT_ENABLED CAFFE2_USE_CUSPARSELT)
+set_bool(AT_HIPSPARSELT_ENABLED CAFFE2_USE_HIPSPARSELT)
 
 configure_file(Config.h.in "${CMAKE_CURRENT_SOURCE_DIR}/Config.h")
 # TODO: Do not generate CUDAConfig.h for ROCm BUILDS
```
```diff
@@ -28,8 +28,7 @@ c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) {
     opt_device_type = at::getAccelerator(false);
   }
   if (opt_device_type.has_value()) {
-    return at::globalContext().getPinnedMemoryAllocator(
-        opt_device_type.value());
+    return at::globalContext().getPinnedMemoryAllocator(opt_device_type);
   } else {
     TORCH_CHECK(
         false, "Need to provide pin_memory allocator to use pin memory.")
```
```diff
@@ -1,18 +1,22 @@
 #include <ATen/core/Formatting.h>
 #include <c10/util/irange.h>
+#include <fmt/compile.h>
+#include <fmt/format.h>
+#include <fmt/ostream.h>
 
 #include <cmath>
 #include <cstdint>
-#include <iomanip>
 #include <iostream>
+#include <iterator>
+#include <string>
 #include <tuple>
 
 namespace c10 {
-std::ostream& operator<<(std::ostream & out, Backend b) {
+std::ostream& operator<<(std::ostream& out, Backend b) {
   return out << toString(b);
 }
 
-std::ostream& operator<<(std::ostream & out, const Scalar& s) {
+std::ostream& operator<<(std::ostream& out, const Scalar& s) {
   if (s.isFloatingPoint()) {
     return out << s.toDouble();
   }
```
```diff
@@ -35,179 +39,189 @@ std::ostream& operator<<(std::ostream & out, const Scalar& s) {
 }
 
 std::string toString(const Scalar& s) {
-  std::stringstream out;
-  out << s;
-  return std::move(out).str();
+  return fmt::format("{}", fmt::streamed(s));
 }
-}
+} // namespace c10
 
 namespace at {
 
-//not all C++ compilers have default float so we define our own here
-inline static std::ios_base& defaultfloat(std::ios_base& __base) {
-  __base.unsetf(std::ios_base::floatfield);
-  return __base;
-}
-//saves/restores number formatting inside scope
-struct FormatGuard {
-  FormatGuard(std::ostream & out)
-  : out(out) {
-    saved.copyfmt(out);
-  }
-  ~FormatGuard() {
-    out.copyfmt(saved);
-  }
-  FormatGuard(const FormatGuard&) = delete;
-  FormatGuard(FormatGuard&&) = delete;
-  FormatGuard& operator=(const FormatGuard&) = delete;
-  FormatGuard& operator=(FormatGuard&&) = delete;
-private:
-  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
-  std::ostream & out;
-  std::ios saved{nullptr};
-};
-
-std::ostream& operator<<(std::ostream & out, const DeprecatedTypeProperties& t) {
+std::ostream& operator<<(std::ostream& out, const DeprecatedTypeProperties& t) {
   return out << t.toString();
 }
 
-static std::tuple<double, int> __printFormat(std::ostream& stream, const Tensor& self) {
+enum class FormatType {
+  Default, // 'g' format (defaultfloat equivalent)
+  Scientific, // 'e' format with precision 4
+  Fixed // 'f' format with precision 4
+};
+
+struct PrintFormat {
+  double scale;
+  int width;
+  FormatType type;
+
+  PrintFormat(double s, int w, FormatType t = FormatType::Default)
+      : scale(s), width(w), type(t) {}
+};
+
+static PrintFormat __printFormat(const Tensor& self) {
   auto size = self.numel();
-  if(size == 0) {
-    return std::make_tuple(1., 0);
+  if (size == 0) {
+    return PrintFormat(1., 0);
   }
 
   bool intMode = true;
   auto self_p = self.const_data_ptr<double>();
   for (const auto i : c10::irange(size)) {
     auto z = self_p[i];
-    if(std::isfinite(z)) {
-      if(z != std::ceil(z)) {
+    if (std::isfinite(z)) {
+      if (z != std::ceil(z)) {
         intMode = false;
         break;
       }
     }
   }
 
   int64_t offset = 0;
-  while(!std::isfinite(self_p[offset])) {
+  while (offset < size && !std::isfinite(self_p[offset])) {
     offset = offset + 1;
-    if(offset == size) {
-      break;
-    }
   }
 
   double expMin = 1;
   double expMax = 1;
-  if(offset != size) {
-    expMin = fabs(self_p[offset]);
-    expMax = fabs(self_p[offset]);
+  if (offset != size) {
+    expMin = std::fabs(self_p[offset]);
+    expMax = std::fabs(self_p[offset]);
     for (const auto i : c10::irange(offset, size)) {
-      double z = fabs(self_p[i]);
-      if(std::isfinite(z)) {
-        if(z < expMin) {
-          expMin = z;
-        }
-        if(self_p[i] > expMax) {
-          expMax = z;
-        }
+      double z = std::fabs(self_p[i]);
+      if (std::isfinite(z)) {
+        expMin = std::min(expMin, z);
+        expMax = std::max(expMax, z);
       }
     }
-    if(expMin != 0) {
+    if (expMin != 0) {
       expMin = std::floor(std::log10(expMin)) + 1;
     } else {
       expMin = 1;
     }
-    if(expMax != 0) {
+    if (expMax != 0) {
       expMax = std::floor(std::log10(expMax)) + 1;
     } else {
       expMax = 1;
     }
   }
 
   double scale = 1;
   int sz = 11;
-  if(intMode) {
-    if(expMax > 9) {
+  if (intMode) {
+    if (expMax > 9) {
       sz = 11;
-      stream << std::scientific << std::setprecision(4);
+      return PrintFormat(scale, sz, FormatType::Scientific);
     } else {
       sz = static_cast<int>(expMax) + 1;
-      stream << defaultfloat;
+      return PrintFormat(scale, sz, FormatType::Default);
     }
   } else {
-    if(expMax-expMin > 4) {
+    if (expMax - expMin > 4) {
       sz = 11;
-      if(std::fabs(expMax) > 99 || std::fabs(expMin) > 99) {
+      if (std::fabs(expMax) > 99 || std::fabs(expMin) > 99) {
         sz = sz + 1;
       }
-      stream << std::scientific << std::setprecision(4);
+      return PrintFormat(scale, sz, FormatType::Scientific);
     } else {
-      if(expMax > 5 || expMax < 0) {
+      if (expMax > 5 || expMax < 0) {
         sz = 7;
-        scale = std::pow(10, expMax-1);
-        stream << std::fixed << std::setprecision(4);
+        scale = std::pow(10, expMax - 1);
+        return PrintFormat(scale, sz, FormatType::Fixed);
       } else {
-        if(expMax == 0) {
+        if (expMax == 0) {
           sz = 7;
         } else {
           sz = static_cast<int>(expMax) + 6;
         }
-        stream << std::fixed << std::setprecision(4);
+        return PrintFormat(scale, sz, FormatType::Fixed);
       }
     }
   }
-  return std::make_tuple(scale, sz);
 }
 
-static void __printIndent(std::ostream &stream, int64_t indent)
-{
-  for ([[maybe_unused]] const auto i : c10::irange(indent)) {
-    stream << " ";
-  }
-}
+// Precompiled format specs
+static constexpr auto FMT_G = FMT_COMPILE("{:>{}g}");
+static constexpr auto FMT_E4 = FMT_COMPILE("{:>{}.4e}");
+static constexpr auto FMT_F4 = FMT_COMPILE("{:>{}.4f}");
 
-static void printScale(std::ostream & stream, double scale) {
-  FormatGuard guard(stream);
-  stream << defaultfloat << scale << " *" << '\n';
-}
+// Print a single value directly into the stream buffer with no temporaries
+static void printValue(std::ostream& stream, double v, const PrintFormat& pf) {
+  auto out_it = std::ostreambuf_iterator<char>(stream);
+  double val = v / pf.scale;
+  switch (pf.type) {
+    case FormatType::Default:
+      fmt::format_to(out_it, FMT_G, val, pf.width);
+      break;
+    case FormatType::Scientific:
+      fmt::format_to(out_it, FMT_E4, val, pf.width);
+      break;
+    case FormatType::Fixed:
+      fmt::format_to(out_it, FMT_F4, val, pf.width);
+      break;
+  }
+}
 
-static void __printMatrix(std::ostream& stream, const Tensor& self, int64_t linesize, int64_t indent)
-{
-  auto [scale, sz] = __printFormat(stream, self);
+static void __printMatrix(
+    std::ostream& stream,
+    const Tensor& self,
+    int64_t linesize,
+    int64_t indent) {
+  auto printFmt = __printFormat(self);
 
-  __printIndent(stream, indent);
-  int64_t nColumnPerLine = (linesize-indent)/(sz+1);
+  int64_t nColumnPerLine = (linesize - indent) / (printFmt.width + 1);
   int64_t firstColumn = 0;
   int64_t lastColumn = -1;
-  while(firstColumn < self.size(1)) {
-    if(firstColumn + nColumnPerLine <= self.size(1)) {
+
+  while (firstColumn < self.size(1)) {
+    if (firstColumn + nColumnPerLine <= self.size(1)) {
       lastColumn = firstColumn + nColumnPerLine - 1;
     } else {
       lastColumn = self.size(1) - 1;
     }
-    if(nColumnPerLine < self.size(1)) {
-      if(firstColumn != 0) {
-        stream << '\n';
+
+    if (nColumnPerLine < self.size(1)) {
+      if (firstColumn != 0) {
+        stream.put('\n');
       }
-      stream << "Columns " << firstColumn+1 << " to " << lastColumn+1;
-      __printIndent(stream, indent);
+      fmt::print(
+          stream,
+          "Columns {} to {}{:>{}s}",
+          firstColumn + 1,
+          lastColumn + 1,
+          "", // empty string to pad
+          indent // width to pad to
+      );
     }
-    if(scale != 1) {
-      printScale(stream,scale);
-      __printIndent(stream, indent);
+
+    if (printFmt.scale != 1) {
+      fmt::print(stream, "{} *\n{:>{}s}", printFmt.scale, "", indent);
     }
+
     for (const auto l : c10::irange(self.size(0))) {
-      Tensor row = self.select(0,l);
-      const double *row_ptr = row.const_data_ptr<double>();
-      for (const auto c : c10::irange(firstColumn, lastColumn+1)) {
-        stream << std::setw(sz) << row_ptr[c]/scale;
-        if(c == lastColumn) {
-          stream << '\n';
-          if(l != self.size(0)-1) {
-            if(scale != 1) {
-              __printIndent(stream, indent);
-              stream << " ";
+      Tensor row = self.select(0, l);
+      const double* row_ptr = row.const_data_ptr<double>();
+
+      for (const auto c : c10::irange(firstColumn, lastColumn + 1)) {
+        printValue(stream, row_ptr[c], printFmt);
+
+        if (c == lastColumn) {
+          stream.put('\n');
+          if (l != self.size(0) - 1) {
+            if (printFmt.scale != 1) {
+              fmt::print(stream, "{:>{}s} ", "", indent);
             } else {
-              __printIndent(stream, indent);
+              fmt::print(stream, "{:>{}s}", "", indent);
             }
           }
         } else {
-          stream << " ";
+          stream.put(' ');
         }
       }
     }
```
```diff
@@ -215,20 +229,21 @@ static void __printMatrix(std::ostream& stream, const Tensor& self, int64_t linesize, int64_t indent)
   }
 }
 
-static void __printTensor(std::ostream& stream, Tensor& self, int64_t linesize)
-{
-  std::vector<int64_t> counter(self.ndimension()-2);
+static void __printTensor(
+    std::ostream& stream,
+    Tensor& self,
+    int64_t linesize) {
+  std::vector<int64_t> counter(self.ndimension() - 2, 0);
+  counter[0] = -1;
+
   bool start = true;
   bool finished = false;
-  counter[0] = -1;
-  for (const auto i : c10::irange(1, counter.size())) {
-    counter[i] = 0;
-  }
-  while(true) {
-    for(int64_t i = 0; self.ndimension()-2; i++) {
+  while (true) {
+    for (int64_t i = 0; self.ndimension() - 2; i++) {
       counter[i] = counter[i] + 1;
-      if(counter[i] >= self.size(i)) {
-        if(i == self.ndimension()-3) {
+      if (counter[i] >= self.size(i)) {
+        if (i == self.ndimension() - 3) {
           finished = true;
           break;
         }
```
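The counter loop above implements an odometer over a tensor's leading (non-matrix) dimensions. A Python paraphrase for illustration only (`leading_indices` is a hypothetical helper, not part of the diff; the first counter advances fastest, matching the C++ loop):

```python
# Illustration only: odometer-style iteration over all leading-index
# combinations, as __printTensor does for an n-d tensor.
def leading_indices(sizes):
    counter = [0] * len(sizes)
    counter[0] = -1
    while True:
        finished = False
        for i in range(len(sizes)):
            counter[i] += 1
            if counter[i] >= sizes[i]:
                if i == len(sizes) - 1:
                    finished = True
                    break
                counter[i] = 0          # carry into the next dimension
            else:
                break
        if finished:
            return
        yield tuple(counter)

print(list(leading_indices([2, 2])))  # [(0, 0), (1, 0), (0, 1), (1, 1)]
```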
```diff
@@ -237,108 +252,133 @@ static void __printTensor(std::ostream& stream, Tensor& self, int64_t linesize)
           break;
         }
       }
     }
-    if(finished) {
+    if (finished) {
       break;
     }
-    if(start) {
+    if (start) {
       start = false;
     } else {
-      stream << '\n';
+      stream.put('\n');
     }
-    stream << "(";
+
+    stream.put('(');
     Tensor tensor = self;
-    for (const auto i : c10::irange(self.ndimension()-2)) {
+    for (const auto i : c10::irange(self.ndimension() - 2)) {
       tensor = tensor.select(0, counter[i]);
-      stream << counter[i]+1 << ",";
+      fmt::print(stream, "{},", counter[i] + 1);
     }
-    stream << ".,.) = " << '\n';
+    fmt::print(stream, ".,.) = \n");
     __printMatrix(stream, tensor, linesize, 1);
   }
 }
 
-void print(const Tensor & t, int64_t linesize) {
-  print(std::cout,t,linesize);
+void print(const Tensor& t, int64_t linesize) {
+  print(std::cout, t, linesize);
 }
-std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesize) {
-  FormatGuard guard(stream);
-  if(!tensor_.defined()) {
-    stream << "[ Tensor (undefined) ]";
-  } else if (tensor_.is_sparse()) {
-    stream << "[ " << tensor_.toString() << "{}\n";
-    stream << "indices:\n" << tensor_._indices() << "\n";
-    stream << "values:\n" << tensor_._values() << "\n";
-    stream << "size:\n" << tensor_.sizes() << "\n";
-    stream << "]";
-  } else {
-    Tensor tensor;
-    if (tensor_.is_quantized()) {
-      tensor = tensor_.dequantize().to(kCPU, kDouble).contiguous();
-    } else if (tensor_.is_mkldnn()) {
-      stream << "MKLDNN Tensor: ";
-      tensor = tensor_.to_dense().to(kCPU, kDouble).contiguous();
-    } else if (tensor_.is_mps()) {
-      // MPS does not support double tensors, so first copy then convert
-      tensor = tensor_.to(kCPU).to(kDouble).contiguous();
-    } else {
-      tensor = tensor_.to(kCPU, kDouble).contiguous();
-    }
-    if(tensor.ndimension() == 0) {
-      stream << defaultfloat << tensor.const_data_ptr<double>()[0] << '\n';
-      stream << "[ " << tensor_.toString() << "{}";
-    } else if(tensor.ndimension() == 1) {
-      if (tensor.numel() > 0) {
-        auto [scale, sz] = __printFormat(stream, tensor);
-        if(scale != 1) {
-          printScale(stream, scale);
-        }
-        const double* tensor_p = tensor.const_data_ptr<double>();
-        for (const auto i : c10::irange(tensor.size(0))) {
-          stream << std::setw(sz) << tensor_p[i]/scale << '\n';
-        }
-      }
-      stream << "[ " << tensor_.toString() << "{" << tensor.size(0) << "}";
-    } else if(tensor.ndimension() == 2) {
-      if (tensor.numel() > 0) {
-        __printMatrix(stream, tensor, linesize, 0);
-      }
-      stream << "[ " << tensor_.toString() << "{" << tensor.size(0) << "," << tensor.size(1) << "}";
-    } else {
-      if (tensor.numel() > 0) {
-        __printTensor(stream, tensor, linesize);
-      }
-      stream << "[ " << tensor_.toString() << "{" << tensor.size(0);
-      for (const auto i : c10::irange(1, tensor.ndimension())) {
-        stream << "," << tensor.size(i);
-      }
-      stream << "}";
-    }
-    if (tensor_.is_quantized()) {
-      stream << ", qscheme: " << toString(tensor_.qscheme());
-      if (tensor_.qscheme() == c10::kPerTensorAffine) {
-        stream << ", scale: " << tensor_.q_scale();
-        stream << ", zero_point: " << tensor_.q_zero_point();
-      } else if (tensor_.qscheme() == c10::kPerChannelAffine ||
-          tensor_.qscheme() == c10::kPerChannelAffineFloatQParams) {
-        stream << ", scales: ";
-        Tensor scales = tensor_.q_per_channel_scales();
-        print(stream, scales, linesize);
-        stream << ", zero_points: ";
-        Tensor zero_points = tensor_.q_per_channel_zero_points();
-        print(stream, zero_points, linesize);
-        stream << ", axis: " << tensor_.q_per_channel_axis();
-      }
-    }
-    // Proxy check for if autograd was built
-    if (tensor.getIntrusivePtr()->autograd_meta()) {
-      auto& fw_grad = tensor._fw_grad(/* level */ 0);
-      if (fw_grad.defined()) {
-        stream << ", tangent:" << '\n' << fw_grad;
-      }
-    }
-    stream << " ]";
+
+std::ostream& print(
+    std::ostream& stream,
+    const Tensor& tensor_,
+    int64_t linesize) {
+  if (!tensor_.defined()) {
+    fmt::print(stream, "[ Tensor (undefined) ]");
+    return stream;
+  }
+
+  if (tensor_.is_sparse()) {
+    fmt::print(stream, "[ {}{{}}\nindices:\n", tensor_.toString());
+    print(stream, tensor_._indices(), linesize);
+    fmt::print(stream, "\nvalues:\n");
+    print(stream, tensor_._values(), linesize);
+    fmt::print(stream, "\nsize:\n{}\n]", fmt::streamed(tensor_.sizes()));
+    return stream;
+  }
+
+  Tensor tensor;
+
+  if (tensor_.is_quantized()) {
+    tensor = tensor_.dequantize().to(kCPU, kDouble).contiguous();
+  } else if (tensor_.is_mkldnn()) {
+    fmt::print(stream, "MKLDNN Tensor: ");
+    tensor = tensor_.to_dense().to(kCPU, kDouble).contiguous();
+  } else if (tensor_.is_mps()) {
+    // MPS does not support double tensors, so first copy then convert
+    tensor = tensor_.to(kCPU).to(kDouble).contiguous();
+  } else {
+    tensor = tensor_.to(kCPU, kDouble).contiguous();
+  }
+
+  if (tensor.ndimension() == 0) {
+    fmt::print(
+        stream,
+        "{}\n[ {}{{}}",
+        tensor.const_data_ptr<double>()[0],
+        tensor_.toString());
+  } else if (tensor.ndimension() == 1) {
+    if (tensor.numel() > 0) {
+      auto printFmt = __printFormat(tensor);
+      if (printFmt.scale != 1) {
+        fmt::print(stream, "{} *\n", printFmt.scale);
+      }
+      const double* tensor_p = tensor.const_data_ptr<double>();
+      for (const auto i : c10::irange(tensor.size(0))) {
+        printValue(stream, tensor_p[i], printFmt);
+        stream.put('\n');
+      }
+    }
+    fmt::print(stream, "[ {}{{{}}}", tensor_.toString(), tensor.size(0));
+  } else if (tensor.ndimension() == 2) {
+    if (tensor.numel() > 0) {
+      __printMatrix(stream, tensor, linesize, 0);
+    }
+    fmt::print(
+        stream,
+        "[ {}{{{},{}}}",
+        tensor_.toString(),
+        tensor.size(0),
+        tensor.size(1));
+  } else {
+    if (tensor.numel() > 0) {
+      __printTensor(stream, tensor, linesize);
+    }
+    fmt::print(stream, "[ {}{{{}", tensor_.toString(), tensor.size(0));
+    for (const auto i : c10::irange(1, tensor.ndimension())) {
+      fmt::print(stream, ",{}", tensor.size(i));
+    }
+    fmt::print(stream, "}}");
+  }
+
+  // Add quantization info
+  if (tensor_.is_quantized()) {
+    fmt::print(stream, ", qscheme: {}", toString(tensor_.qscheme()));
+    if (tensor_.qscheme() == c10::kPerTensorAffine) {
+      fmt::print(
+          stream,
+          ", scale: {}, zero_point: {}",
+          tensor_.q_scale(),
+          tensor_.q_zero_point());
+    } else if (
+        tensor_.qscheme() == c10::kPerChannelAffine ||
+        tensor_.qscheme() == c10::kPerChannelAffineFloatQParams) {
+      fmt::print(stream, ", scales: ");
+      print(stream, tensor_.q_per_channel_scales(), linesize);
+      fmt::print(stream, ", zero_points: ");
+      print(stream, tensor_.q_per_channel_zero_points(), linesize);
+      fmt::print(stream, ", axis: {}", tensor_.q_per_channel_axis());
+    }
+  }
+
+  // Proxy check for if autograd was built
+  if (tensor.getIntrusivePtr()->autograd_meta()) {
+    auto& fw_grad = tensor._fw_grad(/* level */ 0);
+    if (fw_grad.defined()) {
+      fmt::print(stream, ", tangent:\n");
+      print(stream, fw_grad, linesize);
+    }
+  }
+
+  fmt::print(stream, " ]");
   return stream;
 }
 
-}
+} // namespace at
```
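For reference, the three `FormatType` styles introduced above map onto ordinary printf-style specs. A Python paraphrase (illustration only; `print_value` is a hypothetical stand-in for the C++ `printValue`, and `scale`/`width` mirror the `PrintFormat` fields):

```python
# Illustration only: general ('g'), 4-digit scientific ('.4e'), and 4-digit
# fixed ('.4f') output, right-aligned in a field of `width` characters after
# dividing by `scale`, matching the new FormatType enum.
def print_value(v: float, scale: float = 1.0, width: int = 11, kind: str = "e") -> str:
    spec = {"g": f">{width}g", "e": f">{width}.4e", "f": f">{width}.4f"}[kind]
    return format(v / scale, spec)

print(print_value(12345.678))        # ' 1.2346e+04'
print(print_value(3.5, kind="f"))    # '     3.5000'
```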
```diff
@@ -205,7 +205,7 @@ std::pair<Vectorized<double>, Vectorized<double>> inline interleave2<double>(
     const Vectorized<double>& a,
     const Vectorized<double>& b) {
   // inputs:
-  //   a = {a0, a1, a3, a3}
+  //   a = {a0, a1, a2, a3}
   //   b = {b0, b1, b2, b3}
 
   // swap lanes:
```
```diff
@@ -8,6 +8,7 @@
 // only be included from C++ files.
 #define AT_CUDNN_ENABLED() @AT_CUDNN_ENABLED@
 #define AT_CUSPARSELT_ENABLED() @AT_CUSPARSELT_ENABLED@
+#define AT_HIPSPARSELT_ENABLED() @AT_HIPSPARSELT_ENABLED@
 #define AT_ROCM_ENABLED() @AT_ROCM_ENABLED@
 #define AT_MAGMA_ENABLED() @AT_MAGMA_ENABLED@
 
```
```diff
@@ -159,6 +159,7 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
   DispatchKey::XLA,
   DispatchKey::CUDA,
   DispatchKey::CPU,
+  DispatchKey::PrivateUse1,
 });
 
 inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {
```
```diff
@@ -143,7 +143,7 @@ static Tensor make_feature_noise(const Tensor& input) {
 }
 
 static bool is_fused_kernel_acceptable(const Tensor& input, double p) {
-  return (input.is_cuda() || input.is_xpu() || input.is_lazy()) && p > 0 && p < 1 && input.numel() > 0;
+  return (input.is_cuda() || input.is_xpu() || input.is_lazy() || input.is_privateuseone()) && p > 0 && p < 1 && input.numel() > 0;
 }
 
 // NB: sure, we could have used different overloads here, but I would feel insecure
```
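Background for the `PrivateUse1` additions in the hunks above and below (a sketch, not part of the diff): `PrivateUse1` is the dispatch key PyTorch reserves for out-of-tree backends. In recent PyTorch a backend typically claims a name for it via the public hook shown here, and tensors on that device are what the C++ `is_privateuseone()` check matches.

```python
# Context sketch only: claim the PrivateUse1 dispatch key for an out-of-tree
# backend name ("my_device" is a placeholder).
import torch

torch.utils.rename_privateuse1_backend("my_device")
```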
```diff
@@ -56,7 +56,8 @@ void dumpTensorCout(const Tensor& tensor) {
 
 static c10::intrusive_ptr<TensorWrapper> makeTensorWrapperPtr(const Tensor& tensor, int64_t level, const std::shared_ptr<bool>& life_handle) {
   auto keys_to_propagate = kKeysToPropagateToWrapper | DispatchKeySet({
-      DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA});
+      DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA,
+      DispatchKey::AutogradPrivateUse1});
   auto key_set = getKeysToPropagateToWrapper(tensor, keys_to_propagate);
   key_set = key_set.add(DispatchKey::FuncTorchGradWrapper);
   return c10::make_intrusive<TensorWrapper>(key_set, tensor, level, life_handle);
@@ -76,7 +77,8 @@ static Tensor unsafeMakeTensorWrapper(
   }
 
   auto keys_to_propagate = kKeysToPropagateToWrapper | DispatchKeySet({
-      DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA});
+      DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA,
+      DispatchKey::AutogradPrivateUse1});
   auto key_set = getKeysToPropagateToWrapper(tensor, keys_to_propagate);
   key_set = key_set.add(DispatchKey::FuncTorchGradWrapper);
   auto result = at::detail::make_tensor<TensorWrapper>(
```
```diff
@@ -5,6 +5,7 @@
 #include <ATen/miopen/miopen-wrapper.h>
 #include <ATen/core/Tensor.h>
 #include <ATen/TensorUtils.h>
+#include <c10/macros/Export.h>
 
 namespace at { namespace native {
 
@@ -37,9 +38,9 @@ struct DescriptorDeleter {
 // initialized the first time you call set() or any other initializing
 // function.
 template <typename T, miopenStatus_t (*ctor)(T**), miopenStatus_t (*dtor)(T*)>
-class Descriptor
-{
+// NOLINTNEXTLINE(bugprone-exception-escape)
+class TORCH_CUDA_CPP_API Descriptor {
 public:
   // Use desc() to access the underlying descriptor pointer in
   // a read-only fashion. Most client code should use this.
   // If the descriptor was never initialized, this will return
@@ -55,7 +56,7 @@ public:
 protected:
   void init() {
     if (desc_ == nullptr) {
-      T* raw_desc;
+      T* raw_desc = nullptr;
       MIOPEN_CHECK(ctor(&raw_desc));
       desc_.reset(raw_desc);
     }
@@ -64,13 +65,12 @@ private:
   std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
 };
 
-class TensorDescriptor
-  : public Descriptor<miopenTensorDescriptor,
-                      &miopenCreateTensorDescriptor,
-                      &miopenDestroyTensorDescriptor>
-{
-public:
-  TensorDescriptor() {}
+class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor<
+                                                miopenTensorDescriptor,
+                                                &miopenCreateTensorDescriptor,
+                                                &miopenDestroyTensorDescriptor> {
+ public:
+  TensorDescriptor() = default;
   explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) {
     set(t, pad);
   }
@@ -88,11 +88,10 @@ private:
 
 std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d);
 
-class FilterDescriptor
-  : public Descriptor<miopenTensorDescriptor,
-                      &miopenCreateTensorDescriptor,
-                      &miopenDestroyTensorDescriptor>
-{
+class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor<
+                                                miopenTensorDescriptor,
+                                                &miopenCreateTensorDescriptor,
+                                                &miopenDestroyTensorDescriptor> {
 public:
   void set(const at::Tensor &t, int64_t pad = 0) {
     set(t, at::MemoryFormat::Contiguous, pad);
@@ -106,11 +105,11 @@ private:
   }
 };
 
-struct ConvolutionDescriptor
-  : public Descriptor<miopenConvolutionDescriptor,
-                      &miopenCreateConvolutionDescriptor,
-                      &miopenDestroyConvolutionDescriptor>
-{
+struct TORCH_CUDA_CPP_API ConvolutionDescriptor
+    : public Descriptor<
+          miopenConvolutionDescriptor,
+          &miopenCreateConvolutionDescriptor,
+          &miopenDestroyConvolutionDescriptor> {
   void set(miopenDataType_t dataType, miopenConvolutionMode_t c_mode, int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups, bool benchmark, bool deterministic) {
     MIOPEN_CHECK(miopenInitConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale, c_mode));
     MIOPEN_CHECK(miopenSetConvolutionGroupCount(mut_desc(), groups));
@@ -121,11 +120,12 @@ struct ConvolutionDescriptor
   }
 };
 
-struct DropoutDescriptor
-  : public Descriptor<miopenDropoutDescriptor,
-                      &miopenCreateDropoutDescriptor,
-                      &miopenDestroyDropoutDescriptor>
-{
+// NOLINTNEXTLINE(bugprone-exception-escape)
+struct TORCH_CUDA_CPP_API DropoutDescriptor
+    : public Descriptor<
+          miopenDropoutDescriptor,
+          &miopenCreateDropoutDescriptor,
+          &miopenDestroyDropoutDescriptor> {
   void set(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes,
       unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) {
     MIOPEN_CHECK(miopenSetDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode));
@@ -137,7 +137,7 @@ struct DropoutDescriptor
   }
 };
 
-struct RNNDescriptor
+struct TORCH_CUDA_CPP_API RNNDescriptor
   : public Descriptor<miopenRNNDescriptor,
                       &miopenCreateRNNDescriptor,
                       &miopenDestroyRNNDescriptor>
```
@@ -1,9 +1,11 @@
-#include <ATen/miopen/Exceptions.h>
-#include <ATen/miopen/Handle.h>
 #include <ATen/hip/detail/DeviceThreadHandles.h>
+#include <ATen/miopen/Handle.h>
 #include <c10/hip/HIPStream.h>
 
-namespace at { namespace native {
+#include <ATen/hip/Exceptions.h>
+#include <ATen/miopen/Exceptions.h>
+
+namespace at::native {
 namespace {
 
 void createMIOpenHandle(miopenHandle_t *handle) {

@@ -11,30 +13,33 @@ void createMIOpenHandle(miopenHandle_t *handle) {
 }
 
 void destroyMIOpenHandle(miopenHandle_t handle) {
   // this is because of something dumb in the ordering of
   // destruction. Sometimes atexit, the cuda context (or something)
   // would already be destroyed by the time this gets destroyed. It
   // happens in fbcode setting. @colesbury and I decided to not destroy
   // the handle as a workaround.
   //   - @soumith
   //
   // Further note: this is now disabled globally, because we are seeing
   // the same issue as mentioned above in CUDA 11 CI.
   //   - @zasdfgbnm
   //
   // #ifdef NO_MIOPEN_DESTROY_HANDLE
   // #else
   //   miopenDestroy(handle);
   // #endif
 }
 
-using MIOpenPoolType = at::cuda::DeviceThreadHandlePool<miopenHandle_t, createMIOpenHandle, destroyMIOpenHandle>;
+using MIOpenPoolType = at::cuda::DeviceThreadHandlePool<
+    miopenHandle_t,
+    createMIOpenHandle,
+    destroyMIOpenHandle>;
 
 } // namespace
 
 miopenHandle_t getMiopenHandle() {
-  int device;
-  HIP_CHECK(hipGetDevice(&device));
+  c10::DeviceIndex device = 0;
+  AT_CUDA_CHECK(c10::hip::GetDevice(&device));
 
   // Thread local PoolWindows are lazily-initialized
   // to avoid initialization issues that caused hangs on Windows.

@@ -46,8 +51,8 @@ miopenHandle_t getMiopenHandle() {
       pool->newPoolWindow());
 
   auto handle = myPoolWindow->reserve(device);
-  MIOPEN_CHECK(miopenSetStream(handle, at::hip::getCurrentHIPStream()));
+  MIOPEN_CHECK(miopenSetStream(handle, c10::hip::getCurrentHIPStream()));
   return handle;
 }
 
-}} // namespace at::native
+} // namespace at::native
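Note: the pool keeps one MIOpen handle per (thread, device) and rebinds it to the current HIP stream on every fetch. A simplified sketch of that reuse idea, independent of DeviceThreadHandlePool (HandleT/createH/destroyH are placeholders):

// Per-thread, per-device handle cache: no locking on the hot path;
// callers still rebind the stream (miopenSetStream) after each lookup.
#include <unordered_map>

template <typename HandleT, void (*createH)(HandleT*)>
HandleT get_thread_local_handle(int device) {
  thread_local std::unordered_map<int, HandleT> handles;  // one map per thread
  auto it = handles.find(device);
  if (it == handles.end()) {
    HandleT h;
    createH(&h);                           // created once per (thread, device)
    it = handles.emplace(device, h).first;
  }
  return it->second;
}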
@@ -1,9 +1,9 @@
 #pragma once
 
 #include <ATen/miopen/miopen-wrapper.h>
+#include <c10/macros/Export.h>
 
-namespace at { namespace native {
+namespace at::native {
 
-miopenHandle_t getMiopenHandle();
-
-}} // namespace
+TORCH_CUDA_CPP_API miopenHandle_t getMiopenHandle();
+} // namespace at::native
@@ -1,12 +1,13 @@
 #pragma once
 
-#include <ATen/miopen/miopen-wrapper.h>
 #include <ATen/Tensor.h>
+#include <ATen/miopen/miopen-wrapper.h>
+#include <c10/macros/Export.h>
 
-namespace at { namespace native {
+namespace at::native {
 
-miopenDataType_t getMiopenDataType(const at::Tensor& tensor);
+TORCH_CUDA_CPP_API miopenDataType_t getMiopenDataType(const at::Tensor& tensor);
 
 int64_t miopen_version();
 
-}} // namespace at::miopen
+} // namespace at::native
@@ -138,7 +138,7 @@ inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
 
   // storageOffset
   TORCH_CHECK(
-      storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
+      TORCH_GUARD_OR_TRUE(sym_ge(storage_offset, 0)), "Tensor: invalid storage offset ", storage_offset);
 
   // set_storage_{device} (except set_storage_meta__symint)
   // will (unsafely) set the storage offset and then call resize_impl that

@@ -431,7 +431,7 @@ Tensor& set_storage_meta__symint(
       size, stride, storage_offset);
 
   // Matches maybe_resize_storage_cpu no-numel behavior
-  if (TORCH_GUARD_SIZE_OBLIVIOUS(result.sym_numel().sym_ne(0))) {
+  if (TORCH_GUARD_OR_TRUE(result.sym_numel().sym_ne(0))) {
     // maybe_resize_storage_cpu can handle no storage exists at all but
     // that should never be the case here
     TORCH_INTERNAL_ASSERT(storage);

@@ -440,12 +440,7 @@ Tensor& set_storage_meta__symint(
     // All meta data pointers are the same, so we don't have to "re" allocate
     // it. TODO: Actually this might not quite be correct if we use special
     // pointers to track whether or not fake cuda tensors are pinned or not
-    const auto itemsize = result.dtype().itemsize();
-    c10::SymInt new_size_bytes = result.is_contiguous()
-        ? at::detail::computeStorageNbytesContiguous(
-              size, itemsize, std::move(storage_offset))
-        : at::detail::computeStorageNbytes(
-              size, stride, itemsize, std::move(storage_offset));
     // TODO: When there are unbacked SymInts, we unconditionally skip the
     // setter. This is technically wrong, but we cannot conveniently test
     // the real condition in many cases, because a lot of people are using

@@ -454,10 +449,20 @@ Tensor& set_storage_meta__symint(
     //
    // The old behavior was to unconditionally set_nbytes, but I think not
    // setting it is more safe.
-    if (new_size_bytes.has_hint() && storage.sym_nbytes().has_hint() &&
-        TORCH_GUARD_SIZE_OBLIVIOUS(
-            new_size_bytes.sym_gt(storage.sym_nbytes()))) {
-      storage.set_nbytes(std::move(new_size_bytes));
+    if (result.sym_numel().has_hint()) {
+      const auto itemsize = result.dtype().itemsize();
+
+      c10::SymInt new_size_bytes = result.is_contiguous()
+          ? at::detail::computeStorageNbytesContiguous(
+                size, itemsize, std::move(storage_offset))
+          : at::detail::computeStorageNbytes(
+                size, stride, itemsize, std::move(storage_offset));
+
+      if (new_size_bytes.has_hint() && storage.sym_nbytes().has_hint() &&
+          TORCH_GUARD_SIZE_OBLIVIOUS(
+              new_size_bytes.sym_gt(storage.sym_nbytes()))) {
+        storage.set_nbytes(std::move(new_size_bytes));
+      }
     }
   }
   return result;
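Note: the new-size computation that moved inside the has_hint() guard is, for concrete (non-symbolic) shapes, just "bytes needed to reach the largest addressed element". A hedged sketch of what the computeStorageNbytes-style helpers boil down to; the real ATen helpers operate on c10::SymInt:

#include <cstdint>
#include <vector>

int64_t storage_nbytes(const std::vector<int64_t>& sizes,
                       const std::vector<int64_t>& strides,
                       int64_t itemsize, int64_t storage_offset) {
  // largest linear index touched is offset + sum((size_i - 1) * stride_i)
  int64_t max_index = storage_offset;
  for (size_t i = 0; i < sizes.size(); ++i) {
    if (sizes[i] == 0) return 0;          // empty tensor needs no storage
    max_index += (sizes[i] - 1) * strides[i];
  }
  return (max_index + 1) * itemsize;      // +1 because indices are zero-based
}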
@@ -345,8 +345,8 @@ static inline void launch_vectorized_kernel(
   auto output_calc = TrivialOffsetCalculator<1>();
   auto loader = memory::LoadWithoutCast();
   auto storer = memory::StoreWithoutCast();
-  int64_t grid_unrolled = (N + io_block_work_size<io_size>() - 1) / io_block_work_size<io_size>();
-  unrolled_elementwise_kernel<func_t, array_t, elems_per_thread<io_size>()>
+  int64_t grid_unrolled = (N + elementwise_block_work_size() - 1) / elementwise_block_work_size();
+  unrolled_elementwise_kernel<func_t, array_t, elementwise_thread_work_size()>
       <<<grid_unrolled, num_threads(), 0, stream>>>(
           N, f, data, input_calc, output_calc, loader, storer);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
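Note: `(N + work - 1) / work` is a plain ceiling division over the per-block work size. A hedged host-side illustration (block_work_size is a stand-in value, not the real helper):

#include <cstdint>
#include <cstdio>

int64_t ceil_div(int64_t n, int64_t d) { return (n + d - 1) / d; }

int main() {
  const int64_t N = 1'000'000;          // number of elements to process
  const int64_t block_work_size = 512;  // threads_per_block * items_per_thread
  std::printf("grid = %lld\n",
              static_cast<long long>(ceil_div(N, block_work_size)));
}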
@@ -28,9 +28,15 @@ __device__ inline int min(int a, int b) {
   return a <= b ? a : b;
 }
 
+#ifdef USE_ROCM
+#define CUDA_MAX_THREADS 256
+#define BLOCK_STRIDE_FWD 2 // increasing block_stride to lower # of blocks launched
+#define BLOCK_STRIDE_BWD 4 // increasing block_stride to lower # of blocks launched
+#else
 #define CUDA_MAX_THREADS 1024 // this is safe, in reality 256 is our limit
-#define BLOCK_STRIDE 2 // increasing block_stride to lower # of blocks launched
+#define BLOCK_STRIDE_FWD 2 // increasing block_stride to lower # of blocks launched
+#define BLOCK_STRIDE_BWD 2 // increasing block_stride to lower # of blocks launched
+#endif
 
 static __device__ inline int p_start(int size, int pad, int kernel, int dilation, int stride) {
   return (size + pad < ((kernel - 1) * dilation + 1)) ? 0 : (size + pad - ((kernel - 1) * dilation + 1)) / stride + 1;

@@ -464,10 +470,10 @@ const Tensor& indices) {
       int grid_x = nbatch*kernel_stride_C;
       int grid_y = std::min<int>(
           at::cuda::getCurrentDeviceProperties()->maxGridSize[1],
-          ceil_div(safe_downcast<int, int64_t>(outputWidth), block_y*BLOCK_STRIDE));
+          ceil_div(safe_downcast<int, int64_t>(outputWidth), block_y*BLOCK_STRIDE_FWD));
       int grid_z = std::min<int>(
           at::cuda::getCurrentDeviceProperties()->maxGridSize[2],
-          ceil_div(safe_downcast<int, int64_t>(outputHeight), block_z*BLOCK_STRIDE));
+          ceil_div(safe_downcast<int, int64_t>(outputHeight), block_z*BLOCK_STRIDE_FWD));
       const dim3 grid(grid_x, grid_y, grid_z);
 
       size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * (sizeof(int) + sizeof(scalar_t));

@@ -599,10 +605,10 @@ const Tensor& gradInput) {
       int grid_x = nbatch*kernel_stride_C;
       int grid_y = std::min<int>(
           at::cuda::getCurrentDeviceProperties()->maxGridSize[1],
-          ceil_div(safe_downcast<int, int64_t>(inputWidth), block_y*BLOCK_STRIDE));
+          ceil_div(safe_downcast<int, int64_t>(inputWidth), block_y*BLOCK_STRIDE_BWD));
       int grid_z = std::min<int>(
           at::cuda::getCurrentDeviceProperties()->maxGridSize[2],
-          ceil_div(safe_downcast<int, int64_t>(inputHeight), block_z*BLOCK_STRIDE));
+          ceil_div(safe_downcast<int, int64_t>(inputHeight), block_z*BLOCK_STRIDE_BWD));
       const dim3 grid(grid_x, grid_y, grid_z);
 
       size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * sizeof(accscalar_t);
@@ -1159,7 +1159,8 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
       config.ctas_per_output = div_up(num_mp, 2);
     else if (config.ctas_per_output < 16)
       config.ctas_per_output = 1;
-    if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension)
+    bool is_channel_last = iter.tensor_base(1).is_contiguous(at::MemoryFormat::ChannelsLast);
+    if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last)
       config.ctas_per_output = 4;
 #endif
   if (config.ctas_per_output > 1) {
new file: aten/src/ATen/native/mkldnn/xpu/detail/DnnlExt.h (594 lines)
@@ -0,0 +1,594 @@
#pragma once

#include <ATen/ATen.h>

#include <ATen/native/mkldnn/xpu/detail/LRUCache.h>
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>

#include <oneapi/dnnl/dnnl.h>
#include <oneapi/dnnl/dnnl.hpp>

namespace std {

template <>
struct hash<dnnl::memory::dims> {
  size_t operator()(dnnl::memory::dims const& vec) const {
    size_t seed = vec.size();
    for (auto& i : vec) {
      seed ^= i + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    }
    return seed;
  }
};

} // namespace std
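Note: this specialization is the familiar boost-style hash_combine fold, used so dnnl::memory::dims can key the unordered_map inside the LRU cache below. A standalone equivalent over std::vector<int64_t> for clarity:

#include <cstdint>
#include <vector>

size_t hash_dims(const std::vector<int64_t>& vec) {
  size_t seed = vec.size();
  for (auto i : vec) {
    // golden-ratio constant plus shifts mixes each element into the seed
    seed ^= static_cast<size_t>(i) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
  }
  return seed;
}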

using namespace dnnl;

namespace at::native::onednn {

class primitive_ext : public primitive {
  static constexpr int max_args = 12;

 public:
  primitive_ext(const primitive& base) : primitive(base) {}
  primitive_ext(primitive&& base) : primitive(std::move(base)) {}

  /// Returns a memory descriptor.
  ///
  /// @note
  ///     There are also convenience methods
  ///     #dnnl::primitive_desc_base::src_desc(),
  ///     #dnnl::primitive_desc_base::dst_desc(), and others.
  ///
  /// @param what The kind of parameter to query; can be
  ///     #dnnl::query::src_md, #dnnl::query::dst_md, etc.
  /// @param idx Index of the parameter. For example, convolution bias can
  ///     be queried with what = #dnnl::query::weights_md and idx = 1.
  /// @returns The requested memory descriptor.
  /// @returns A zero memory descriptor if the primitive does not have a
  ///     parameter of the specified kind or index.
  const_dnnl_memory_desc_t query_md(query what, int idx = 0) const {
    std::vector<query> valid_q{
        query::src_md,
        query::diff_src_md,
        query::weights_md,
        query::diff_weights_md,
        query::dst_md,
        query::diff_dst_md,
        query::workspace_md,
        query::scratchpad_md,
        query::exec_arg_md};
    if (!std::any_of(valid_q.cbegin(), valid_q.cend(), [=](query q) {
          return what == q;
        }))
      DNNL_THROW_ERROR(
          dnnl_invalid_arguments, "memory descriptor query is invalid");

    const_dnnl_memory_desc_t cdesc = dnnl_primitive_desc_query_md(
        this->get_primitive_desc(), dnnl::convert_to_c(what), idx);

    return cdesc ? cdesc : nullptr;
  }

  /// Returns a source memory descriptor.
  /// @param idx Source index.
  /// @returns Source memory descriptor.
  /// @returns A zero memory descriptor if the primitive does not have a
  ///     source parameter with index @p idx.
  const_dnnl_memory_desc_t src_desc(int idx) const {
    return query_md(query::src_md, idx);
  }

  /// Returns a destination memory descriptor.
  /// @param idx Destination index.
  /// @returns Destination memory descriptor.
  /// @returns A zero memory descriptor if the primitive does not have a
  ///     destination parameter with index @p idx.
  const_dnnl_memory_desc_t dst_desc(int idx) const {
    return query_md(query::dst_md, idx);
  }

  /// Returns a weights memory descriptor.
  /// @param idx Weights index.
  /// @returns Weights memory descriptor.
  /// @returns A zero memory descriptor if the primitive does not have a
  ///     weights parameter with index @p idx.
  const_dnnl_memory_desc_t weights_desc(int idx) const {
    return query_md(query::weights_md, idx);
  }

  /// Returns a diff source memory descriptor.
  /// @param idx Diff source index.
  /// @returns Diff source memory descriptor.
  /// @returns A zero memory descriptor if the primitive does not have a
  ///     diff source parameter with index @p idx.
  const_dnnl_memory_desc_t diff_src_desc(int idx) const {
    return query_md(query::diff_src_md, idx);
  }

  /// Returns a diff destination memory descriptor.
  /// @param idx Diff destination index.
  /// @returns Diff destination memory descriptor.
  /// @returns A zero memory descriptor if the primitive does not have a
  ///     diff destination parameter with index @p idx.
  const_dnnl_memory_desc_t diff_dst_desc(int idx) const {
    return query_md(query::diff_dst_md, idx);
  }

  /// Returns a diff weights memory descriptor.
  /// @param idx Diff weights index.
  /// @returns Diff weights memory descriptor.
  /// @returns A zero memory descriptor if the primitive does not have a
  ///     diff weights parameter with index @p idx.
  const_dnnl_memory_desc_t diff_weights_desc(int idx) const {
    return query_md(query::diff_weights_md, idx);
  }

  const_dnnl_memory_desc_t exec_arg_desc(int idx) const {
    return query_md(query::exec_arg_md, idx);
  }

  // Separate versions without the index argument for documentation
  // purposes.

  /// Returns a source memory descriptor.
  /// @returns Source memory descriptor.
  /// @returns A zero memory descriptor if the primitive does not have a
  ///     source parameter.
  const_dnnl_memory_desc_t src_desc() const {
    return src_desc(0);
  }

  /// Returns a destination memory descriptor.
  /// @returns Destination memory descriptor.
  /// @returns A zero memory descriptor if the primitive does not have a
  ///     destination parameter.
  const_dnnl_memory_desc_t dst_desc() const {
    return dst_desc(0);
  }

  /// Returns a weights memory descriptor.
  /// @returns Weights memory descriptor.
  /// @returns A zero memory descriptor if the primitive does not have a
  ///     weights parameter.
  const_dnnl_memory_desc_t weights_desc() const {
    return weights_desc(0);
  }

  /// Returns a diff source memory descriptor.
  /// @returns Diff source memory descriptor.
  /// @returns A zero memory descriptor if the primitive does not have a
  ///     diff source memory with.
  const_dnnl_memory_desc_t diff_src_desc() const {
    return diff_src_desc(0);
  }

  /// Returns a diff destination memory descriptor.
  /// @returns Diff destination memory descriptor.
  /// @returns A zero memory descriptor if the primitive does not have a
  ///     diff destination parameter.
  const_dnnl_memory_desc_t diff_dst_desc() const {
    return diff_dst_desc(0);
  }

  /// Returns a diff weights memory descriptor.
  /// @returns Diff weights memory descriptor.
  /// @returns A zero memory descriptor if the primitive does not have a
  ///     diff weights parameter.
  const_dnnl_memory_desc_t diff_weights_desc() const {
    return diff_weights_desc(0);
  }

  /// Returns the workspace memory descriptor.
  /// @returns Workspace memory descriptor.
  /// @returns A zero memory descriptor if the primitive does not require
  ///     workspace parameter.
  const_dnnl_memory_desc_t workspace_desc() const {
    return query_md(query::workspace_md, 0);
  }

  /// Returns the scratchpad memory descriptor.
  /// @returns scratchpad memory descriptor.
  /// @returns A zero memory descriptor if the primitive does not require
  ///     scratchpad parameter.
  /// @sa @ref dev_guide_attributes_scratchpad
  const_dnnl_memory_desc_t scratchpad_desc() const {
    return query_md(query::scratchpad_md, 0);
  }

  inline memory make_memory(
      const_dnnl_memory_desc_t md_t,
      const engine& aengine,
      void* handle = DNNL_MEMORY_ALLOCATE) const {
    sycl_interop::memory_kind kind = dnnl::sycl_interop::memory_kind::usm;
    dnnl_memory_t c_memory;
    error::wrap_c_api(
        dnnl_sycl_interop_memory_create(
            &c_memory, md_t, aengine.get(), convert_to_c(kind), handle),
        "could not create a memory");
    return memory(c_memory);
  }

  memory make_src(const engine& aengine, void* handle = DNNL_MEMORY_ALLOCATE)
      const {
    return make_memory(src_desc(), aengine, handle);
  }

  memory make_weight(const engine& aengine, void* handle = DNNL_MEMORY_ALLOCATE)
      const {
    return make_memory(weights_desc(), aengine, handle);
  }

  memory make_bias(const engine& aengine, void* handle = DNNL_MEMORY_ALLOCATE)
      const {
    return make_memory(weights_desc(1), aengine, handle);
  }

  memory make_dst(const engine& aengine, void* handle = DNNL_MEMORY_ALLOCATE)
      const {
    return make_memory(dst_desc(), aengine, handle);
  }

  memory make_scratchpad(
      const engine& aengine,
      void* handle = DNNL_MEMORY_ALLOCATE) const {
    return make_memory(scratchpad_desc(), aengine, handle);
  }

  size_t get_scratchpad_size() const {
    return dnnl_memory_desc_get_size(scratchpad_desc());
  }

  memory make_args(int arg_class, const engine& aengine, void* handle) const {
    switch (arg_class) {
      case DNNL_ARG_SRC:
        return make_src(aengine, handle);
      case DNNL_ARG_WEIGHTS:
        return make_weight(aengine, handle);
      case DNNL_ARG_SCRATCHPAD:
        return make_scratchpad(aengine, handle);
      case DNNL_ARG_DST:
        return make_dst(aengine, handle);
      case DNNL_ARG_BIAS:
        return make_bias(aengine, handle);
      default:
        TORCH_INTERNAL_ASSERT(
            false, "unsupported argument class for primitive_ext");
    }
  }

  template <typename M>
  void set_attribute(int slot, int arg_class, void* handle, M constructor) {
    if (mem_arg_cache[slot])
      mem_arg_cache[slot].set_data_handle(handle);
    else {
      mem_arg_cache[slot] = constructor();
      c_args[slot].arg = arg_class;
      c_args[slot].memory = mem_arg_cache[slot].get();
    }
  }

  sycl::event execute(
      const stream& astream,
      const engine& aengine,
      std::vector<std::pair<int, void*>>&& handles,
      int slot_off = 2) {
    auto off = slot_off;
    for (const auto& p : handles) {
      auto& m_arg = mem_arg_cache[off];
      if (m_arg)
        m_arg.set_data_handle(p.second);
      else {
        m_arg = make_args(p.first, aengine, p.second);
        c_args[off].arg = p.first;
        c_args[off].memory = m_arg.get();
      }
      ++off;
    }

    sycl::event return_event;
    std::vector<sycl::event> deps{};
    error::wrap_c_api(
        dnnl_sycl_interop_primitive_execute(
            this->get(), astream.get(), off, c_args, &deps, &return_event),
        "could not execute a primitive");
    return return_event;
  }

 private:
  memory mem_arg_cache[max_args];
  dnnl_exec_arg_t c_args[max_args];
};
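Note: set_attribute() and execute() above implement a slot cache: a dnnl::memory object is constructed once per argument slot, and later calls only swap the data pointer. A hedged standalone sketch of that idea with plain types (Arg and make_arg are illustrative stand-ins, not oneDNN API):

#include <array>
#include <functional>

struct Arg { int kind = -1; void* data = nullptr; };

class SlotCache {
 public:
  void set(int slot, int kind, void* p, const std::function<Arg()>& make_arg) {
    if (args_[slot].kind != -1) {
      args_[slot].data = p;        // hot path: pointer swap only
    } else {
      args_[slot] = make_arg();    // cold path: build the argument once
      args_[slot].kind = kind;
      args_[slot].data = p;
    }
  }

 private:
  std::array<Arg, 12> args_;       // mirrors primitive_ext::max_args
};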

// Specifies the combined data types of input and weight tensors.
// For example, f32 means both input and weight are FP32,
// bf16_int4 means input is BF16 and weight is INT4.
enum class joint_dtypes_t { f32 = 0, f16, bf16, int8, f16_int4, bf16_int4 };

// Specifies the transposition state of input and weight tensors.
// Convention: first letter = input, second letter = weight.
// 'n' = not transposed, 't' = transposed.
// For example, 'nt' means input is not transposed, weight is transposed.
enum class trans_type_t { nn = 0, nt, tn, tt };

// Specifies the type and placement of bias in the computation.
// 'none' = no bias,
// 'scalar' = a single scalar bias applied to all elements,
// 'm' = per-row bias (typically matched to input rows),
// 'n' = per-column bias (typically matched to output channels),
// 'mn' = full bias matrix matching the output dimensions.
enum class bias_type_t { none = 0, scalar, m, n, mn };

template <typename T>
T concat(const T& t1, at::ScalarType d) {
  T t;
  t.insert(t.end(), t1.begin(), t1.end());
  t.push_back((int64_t)d);

  return t;
}

template <typename T>
T concat(const T& t1, bool b) {
  T t;
  t.insert(t.end(), t1.begin(), t1.end());
  t.push_back(b);

  return t;
}

template <typename T>
T concat(const T& t1, int b) {
  T t;
  t.insert(t.end(), t1.begin(), t1.end());
  t.push_back(b);

  return t;
}

template <typename T>
T concat(const T& t1, const T& t2) {
  T t;
  t.insert(t.end(), t1.begin(), t1.end());
  t.insert(t.end(), t2.begin(), t2.end());

  return t;
}

template <typename T1, typename T2, typename... Ts>
T1 concat(const T1& t1, const T2& t2, const Ts&... ts) {
  return concat(concat(t1, t2), ts...);
}

template <joint_dtypes_t Ts>
struct onednn_types_mapper;

template <>
struct onednn_types_mapper<joint_dtypes_t::f16_int4> {
  static inline std::tuple<dnnl::memory::data_type, dnnl::memory::data_type>
  get() {
    return std::make_tuple(
        dnnl::memory::data_type::f16, dnnl::memory::data_type::u4);
  }
};

template <>
struct onednn_types_mapper<joint_dtypes_t::bf16_int4> {
  static inline std::tuple<dnnl::memory::data_type, dnnl::memory::data_type>
  get() {
    return std::make_tuple(
        dnnl::memory::data_type::bf16, dnnl::memory::data_type::u4);
  }
};

// TODO: bias types maybe not right
static inline dnnl::memory::dims get_bias_type(
    bias_type_t b_dims,
    const int m,
    const int n) {
  switch (b_dims) {
    case bias_type_t::none:
      return {0};
    case bias_type_t::scalar:
      return {1, 1};
    case bias_type_t::m:
      return {m, 1};
    case bias_type_t::n:
      return {1, n};
    case bias_type_t::mn:
      return {m, n};
    default:
      TORCH_INTERNAL_ASSERT(false, "unsupported bias type ...");
  }
}

// TODO: use template specialization on struct
template <trans_type_t Tt>
inline void get_strides(
    memory::dims& src_strides,
    memory::dims& wei_strides,
    memory::dims& dst_strides,
    const int64_t lda,
    const int64_t ldb,
    const int64_t ldc) {}

template <>
inline void get_strides<trans_type_t::nt>(
    memory::dims& src_strides,
    memory::dims& wei_strides,
    memory::dims& dst_strides,
    const int64_t lda,
    const int64_t ldb,
    const int64_t ldc) {
  src_strides = {lda, 1};
  wei_strides = {1, ldb};
  dst_strides = {ldc, 1};
}

using primitive_cache =
    at::native::onednn::lru_cache<memory::dims, primitive_ext>;

template <trans_type_t Tt, joint_dtypes_t Ts, typename F>
struct matmul_primitive_cache_t {
  static inline primitive_ext& get(
      const int m,
      const int n,
      const int k,
      const int64_t lda,
      const int64_t ldb,
      const int64_t ldc,
      const bias_type_t
          b_dims, // for shapeless bias, not put it into template parameter
      const int device_id,
      F f_attr,
      const int64_t scale_group_size,
      const int64_t zp_group_size) {
    auto& cached = get_cache(device_id);
    memory::dims src_strides, wei_strides, dst_strides;
    get_strides<Tt>(src_strides, wei_strides, dst_strides, lda, ldb, ldc);
    auto pri_key = at::native::onednn::concat(
        src_strides,
        wei_strides,
        m,
        n,
        k,
        int(b_dims),
        int(scale_group_size),
        int(zp_group_size));
    auto iter = cached.find(pri_key);
    if (iter == cached.end()) {
      auto [src_dt, wei_dt] = onednn_types_mapper<Ts>::get();
      auto bias_dims = get_bias_type(b_dims, m, n);

      auto src_md = memory::desc({m, k}, src_dt, src_strides);
      auto wei_md = memory::desc({k, n}, wei_dt, wei_strides);
      auto dst_md = memory::desc({m, n}, src_dt, dst_strides);
      auto bias_format = b_dims == bias_type_t::none
          ? dnnl::memory::format_tag::undef
          : dnnl::memory::format_tag::ab;
      auto bias_md =
          memory::desc(bias_dims, src_dt, bias_format); // {m, n} or {1, n}

      primitive_attr pattr;
      f_attr(pattr);

      dnnl::matmul::primitive_desc matmul_pd;
      auto aengine =
          at::native::onednn::GpuEngineManager::Instance().get_engine(
              device_id);
      if (b_dims == bias_type_t::none) {
        matmul_pd = dnnl::matmul::primitive_desc(
            aengine, src_md, wei_md, dst_md, pattr);
      } else {
        matmul_pd = dnnl::matmul::primitive_desc(
            aengine, src_md, wei_md, bias_md, dst_md, pattr);
      }

      return cached.insert({pri_key, primitive_ext(dnnl::matmul(matmul_pd))})
          .first->second;
    } else {
      return iter->second;
    }
  }

 private:
  static constexpr int max_cache_capacity = 512;
  // if default constructor of primitive cache could read the environment
  // variable then it'll save a lot of trouble
  static inline thread_local std::array<primitive_cache, 16> mappings;

  // this won't be needed if primitive_cache have good default constructor
  static inline primitive_cache& get_cache(const int device_id) {
    auto& mapping = mappings[device_id];
    if (mapping.max_size() == 0) {
      mapping.resize(max_cache_capacity);
    }
    return mapping;
  }
};

template <joint_dtypes_t Ts, typename F>
static inline primitive_ext& matmul_primitive_create_and_cache(
    const trans_type_t Tt,
    const bias_type_t b_dims,
    const int m,
    const int n,
    const int k,
    const int64_t lda,
    const int64_t ldb,
    const int64_t ldc,
    const int device_id,
    F attr,
    const int64_t scale_group_size,
    const int64_t zp_group_size) {
  switch (Tt) {
    case trans_type_t::nt:
      return matmul_primitive_cache_t<trans_type_t::nt, Ts, F>::get(
          m,
          n,
          k,
          lda,
          ldb,
          ldc,
          b_dims,
          device_id,
          attr,
          scale_group_size,
          zp_group_size);
    default:
      TORCH_INTERNAL_ASSERT(false, "unsupported trans type ...");
  }
}

template <typename F>
static inline primitive_ext& matmul_primitive_create_and_cache(
    const joint_dtypes_t Ts,
    const trans_type_t Tt,
    const bias_type_t b_dims,
    const int m,
    const int n,
    const int k,
    const int64_t lda,
    const int64_t ldb, // is weight ldb necessary?
    const int64_t ldc,
    const int device_id,
    F attr,
    const int64_t scale_group_size = 0,
    const int64_t zp_group_size = 0) {
  switch (Ts) {
    case joint_dtypes_t::f16_int4:
      return matmul_primitive_create_and_cache<joint_dtypes_t::f16_int4, F>(
          Tt,
          b_dims,
          m,
          n,
          k,
          lda,
          ldb,
          ldc,
          device_id,
          attr,
          scale_group_size,
          zp_group_size);
    case joint_dtypes_t::bf16_int4:
      return matmul_primitive_create_and_cache<joint_dtypes_t::bf16_int4, F>(
          Tt,
          b_dims,
          m,
          n,
          k,
          lda,
          ldb,
          ldc,
          device_id,
          attr,
          scale_group_size,
          zp_group_size);
    default:
      TORCH_INTERNAL_ASSERT(false, "Only support int4 ...");
  }
}

} // namespace at::native::onednn
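Note: the cache key is just every shape/stride/config parameter flattened into one dims vector via the concat() overloads, then hashed by the std::hash specialization above. A hedged illustration of that key construction with made-up values:

#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> src_strides{4096, 1}, wei_strides{1, 4096};
  std::vector<int64_t> key;
  auto push_all = [&](const std::vector<int64_t>& v) {
    key.insert(key.end(), v.begin(), v.end());
  };
  push_all(src_strides);
  push_all(wei_strides);
  for (int64_t v : {int64_t{32}, int64_t{4096}, int64_t{4096}}) // m, n, k
    key.push_back(v);
  key.push_back(0);   // bias type
  key.push_back(128); // scale group size
  key.push_back(128); // zp group size
  // `key` now plays the role of pri_key in matmul_primitive_cache_t::get
}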
new file: aten/src/ATen/native/mkldnn/xpu/detail/LRUCache.h (110 lines)
@@ -0,0 +1,110 @@
#pragma once

#include <iterator>
#include <list>
#include <unordered_map>
#include <utility>

namespace at::native::onednn {

template <
    class key_t,
    class value_t,
    template <typename...> class map_t = std::unordered_map>
class lru_cache {
 public:
  using value_type = std::pair<key_t, value_t>;
  using list_type = std::list<value_type>;
  using list_iter = typename list_type::iterator;
  using map_type = map_t<key_t, list_iter>;
  using const_list_iter = typename list_type::const_iterator;
  using size_type = typename list_type::size_type;

  explicit lru_cache(size_type capacity) : capacity_(capacity) {}
  lru_cache() : capacity_(0) {}

  [[nodiscard]] size_type size() const noexcept {
    return map_.size();
  }
  [[nodiscard]] size_type max_size() const noexcept {
    return capacity_;
  }
  [[nodiscard]] bool empty() const noexcept {
    return vlist_.empty();
  }

  void resize(size_type new_capacity) {
    capacity_ = new_capacity;
    trim();
  }

  list_iter begin() noexcept {
    return vlist_.begin();
  }
  const_list_iter begin() const noexcept {
    return vlist_.begin();
  }
  list_iter end() noexcept {
    return vlist_.end();
  }
  const_list_iter end() const noexcept {
    return vlist_.end();
  }

  void clear() noexcept {
    map_.clear();
    vlist_.clear();
  }

  void swap(lru_cache& other) noexcept {
    using std::swap;
    swap(vlist_, other.vlist_);
    swap(map_, other.map_);
    swap(capacity_, other.capacity_);
  }

  list_iter find(const key_t& key) {
    auto it = map_.find(key);
    if (it == map_.end())
      return end();
    vlist_.splice(vlist_.begin(), vlist_, it->second);
    return it->second;
  }

  std::pair<list_iter, bool> insert(const value_type& value) {
    auto it = map_.find(value.first);
    if (it != map_.end()) {
      // Move existing to front
      vlist_.splice(vlist_.begin(), vlist_, it->second);
      return {it->second, false};
    }

    // Insert new at front
    vlist_.emplace_front(value);
    map_[value.first] = vlist_.begin();

    trim();

    return {vlist_.begin(), true};
  }

  list_iter erase(list_iter pos) {
    map_.erase(pos->first);
    return vlist_.erase(pos);
  }

 private:
  void trim() {
    while (map_.size() > capacity_) {
      auto last = std::prev(vlist_.end());
      map_.erase(last->first);
      vlist_.pop_back();
    }
  }

  list_type vlist_;
  map_type map_;
  size_type capacity_;
};

} // namespace at::native::onednn
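Note: find() splices the hit to the front of the list (refreshing recency) and insert() trims from the back once size exceeds capacity. A short usage sketch, assuming the in-tree header is on the include path:

#include <ATen/native/mkldnn/xpu/detail/LRUCache.h>
#include <cassert>
#include <string>

int main() {
  at::native::onednn::lru_cache<int, std::string> cache(2);
  cache.insert({1, "one"});
  cache.insert({2, "two"});
  cache.find(1);               // touch key 1 so it becomes most-recent
  cache.insert({3, "three"});  // capacity 2: evicts key 2, the LRU entry
  assert(cache.find(2) == cache.end());
  assert(cache.find(1) != cache.end());
}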
@@ -294,6 +294,13 @@ bool is_onednn_matmul_strides(const at::Tensor& tensor) {
   if (tensor.is_contiguous())
     return true;
 
+  if (tensor.storage_offset() > 0) {
+    // currently onednn asks 64 byte alignment
+    constexpr int alignment_byte = 64;
+    if (reinterpret_cast<uintptr_t>(tensor.data_ptr()) % alignment_byte > 0)
+      return false;
+  }
+
   // the overlaped cases are not supported
   dnnl::memory::dims strides = get_onednn_strides(tensor);
   int64_t storage_size = 1;
@@ -1,6 +1,7 @@
 #include <c10/xpu/XPUFunctions.h>
 
 #include <ATen/native/mkldnn/xpu/detail/Attr.h>
+#include <ATen/native/mkldnn/xpu/detail/DnnlExt.h>
 #include <ATen/native/mkldnn/xpu/detail/Utils.h>
 
 #include <oneapi/dnnl/dnnl.hpp>

@@ -8,22 +9,13 @@
 
 namespace at::native::onednn {
 
-void woq_matmul_int4(
-    Tensor& result, // torchao: [M, K], dtype: fp16,bf16
-    const Tensor& mat1_, // torchao: [M, K], dtype: fp16,bf16
-    const Tensor& mat2_, // torchao quantized weight, [K/8, N], dtype: uint4x8
-    const Tensor& scale, // torchao: [K/group_size, N], dtype: fp16,bf16
-    const Tensor& zp, // torchao: [K/group_size, N], dtype: int8
+void woq_matmul_int4_impl(
+    Tensor& result,
+    const Tensor& mat1_,
+    const Tensor& mat2_,
+    const Tensor& scale,
+    const Tensor& zp,
     int64_t group_size) {
-  size_t dims = result.dim();
-  TORCH_CHECK(
-      dims == 2, "INT4 matmul at XPU only works with 2D input, got ", dims);
-  TORCH_CHECK(result.defined(), "oneDNN matmul result should be defined");
-
-  at::Device cur_device = at::Device(at::kXPU, at::xpu::current_device());
-  TORCH_CHECK(
-      cur_device == mat1_.device(),
-      "_weight_int4pack_mm_with_scales_and_zeros input should be on current device.");
   auto& engine = GpuEngineManager::Instance().get_engine();
   auto& stream = GpuStreamManager::Instance().get_stream();
 

@@ -176,4 +168,162 @@ void woq_matmul_int4(
   args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, zp_usr_m});
   dnnl::sycl_interop::execute(matmul_p, stream, args);
 }
+
+static inline void set_quant_primitive_attr(
+    primitive_attr& pattr,
+    const Tensor& scale,
+    const Tensor& zp,
+    const int64_t group_size) {
+  // set scale and zero point for matmul args
+  pattr.set_scales(
+      DNNL_ARG_WEIGHTS,
+      /* mask */ (1 << 0) + (1 << 1),
+      {group_size, 1},
+      get_onednn_dtype(scale));
+  pattr.set_zero_points(
+      DNNL_ARG_WEIGHTS,
+      /* mask */ (1 << 0) + (1 << 1),
+      {group_size, 1},
+      memory::data_type::s8);
+}
+
+void woq_matmul_int4_impl_cache(
+    Tensor& result,
+    const Tensor& mat1,
+    const Tensor& mat2,
+    const Tensor& scale,
+    const Tensor& zp,
+    int64_t group_size) {
+  auto a_sz = mat1.sizes();
+  auto c_sz = result.sizes();
+
+  const int m =
+      std::reduce(a_sz.begin(), a_sz.end() - 1, 1, std::multiplies<int64_t>());
+  const int n = *(c_sz.end() - 1);
+  const int k = *(a_sz.end() - 1);
+
+  const int64_t ldb = mat2.strides()[mat2.dim() - 2] * 8; // for int4 matmul
+  const int64_t lda = mat1.strides()[mat1.dim() - 2];
+  const int64_t ldc = result.strides()[result.dim() - 2];
+
+  bias_type_t b_type = bias_type_t::none;
+  trans_type_t tt = trans_type_t::nt; // only support nt for int4 matmul
+
+  joint_dtypes_t jd;
+  if (mat1.scalar_type() == at::ScalarType::Half) {
+    jd = joint_dtypes_t::f16_int4;
+  } else if (mat1.scalar_type() == at::ScalarType::BFloat16) {
+    jd = joint_dtypes_t::bf16_int4;
+  } else {
+    TORCH_INTERNAL_ASSERT(
+        false, "Unsupported data type for int4 matmul: ", mat1.scalar_type());
+  }
+
+  auto f_attr = [&](primitive_attr& pattr) {
+    pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
+
+    if (jd == joint_dtypes_t::f16_int4) {
+      pattr.set_fpmath_mode(dnnl::fpmath_mode::f16, true);
+    } else if (jd == joint_dtypes_t::bf16_int4) {
+      pattr.set_fpmath_mode(dnnl::fpmath_mode::bf16, true);
+    }
+
+    set_quant_primitive_attr(pattr, scale, zp, group_size);
+
+#if ONEDNN_SUPPORT_DETERMINISTIC
+    if (at::globalContext().deterministicAlgorithms() ||
+        at::globalContext().deterministicMkldnn()) {
+      pattr.set_deterministic(true);
+    }
+#endif
+  };
+
+  int64_t zp_group_size = group_size;
+  auto device_id = c10::xpu::current_device();
+  auto& matmul_ext = matmul_primitive_create_and_cache(
+      jd,
+      tt,
+      b_type,
+      m,
+      n,
+      k,
+      lda,
+      ldb,
+      ldc,
+      device_id,
+      f_attr,
+      group_size,
+      zp_group_size);
+
+  auto& engine = GpuEngineManager::Instance().get_engine();
+
+  int arg_off = 0;
+  // set scale and zero point for matmul args
+  matmul_ext.set_attribute(
+      arg_off++,
+      DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS,
+      scale.data_ptr(),
+      [&]() {
+        return make_onednn_memory(
+            get_onednn_md(scale), engine, scale.data_ptr());
+      });
+
+  // set zp_md for asymmetric quantization
+  matmul_ext.set_attribute(
+      arg_off++,
+      DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS,
+      zp.data_ptr(),
+      [&]() {
+        int num_groups = k / group_size;
+        memory zp_usr_m(
+            {{num_groups, n}, memory::data_type::s8, {n, 1}},
+            engine,
+            zp.data_ptr());
+        return zp_usr_m;
+      });
+
+  // set general args
+  std::vector<std::pair<int, void*>> arg_handles;
+  arg_handles.reserve(8);
+
+  arg_handles.emplace_back(DNNL_ARG_SRC, mat1.data_ptr());
+  arg_handles.emplace_back(DNNL_ARG_WEIGHTS, mat2.data_ptr());
+  arg_handles.emplace_back(DNNL_ARG_DST, result.data_ptr());
+
+  int scratchpad_size = matmul_ext.get_scratchpad_size();
+  Tensor scratchpad_tensor = at::empty(
+      {scratchpad_size}, mat1.options().dtype(at::kByte), std::nullopt);
+  arg_handles.emplace_back(DNNL_ARG_SCRATCHPAD, scratchpad_tensor.data_ptr());
+
+  auto& strm = GpuStreamManager::Instance().get_stream();
+  auto qint4_matmul_event =
+      matmul_ext.execute(strm, engine, std::move(arg_handles), arg_off);
+}
+
+void woq_matmul_int4(
+    Tensor& result, // torchao: [M, K], dtype: fp16,bf16
+    const Tensor& mat1_, // torchao: [M, K], dtype: fp16,bf16
+    const Tensor& mat2_, // torchao quantized weight, [K/8, N], dtype: uint4x8
+    const Tensor& scale, // torchao: [K/group_size, N], dtype: fp16,bf16
+    const Tensor& zp, // torchao: [K/group_size, N], dtype: int8
+    int64_t group_size,
+    bool pri_cache) {
+  size_t dims = result.dim();
+  TORCH_CHECK(
+      dims == 2, "INT4 matmul at XPU only works with 2D input, got ", dims);
+  TORCH_CHECK(result.defined(), "oneDNN matmul result should be defined");
+
+  const int device_id = c10::xpu::current_device();
+  at::Device cur_device = at::Device(at::kXPU, device_id);
+  TORCH_CHECK(
+      cur_device == mat1_.device(),
+      "_weight_int4pack_mm_with_scales_and_zeros input should be on current device.");
+
+  if (pri_cache) {
+    woq_matmul_int4_impl_cache(result, mat1_, mat2_, scale, zp, group_size);
+  } else {
+    woq_matmul_int4_impl(result, mat1_, mat2_, scale, zp, group_size);
+  }
+}
+
 } // namespace at::native::onednn
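Note: a hedged shape sketch for the entry point above (values and dtypes are illustrative assumptions; mat2 packs eight 4-bit values per int32 element along K, per the uint4x8 comment):

#include <ATen/ATen.h>

void example_int4_mm() {
  const int64_t M = 32, K = 4096, N = 4096, group_size = 128;
  auto mat1 = at::randn({M, K}, at::device(at::kXPU).dtype(at::kHalf));
  auto mat2 = at::randint(0, 1 << 30, {K / 8, N},
                          at::device(at::kXPU).dtype(at::kInt));
  auto scale = at::randn({K / group_size, N},
                         at::device(at::kXPU).dtype(at::kHalf));
  auto zp = at::zeros({K / group_size, N},
                      at::device(at::kXPU).dtype(at::kChar));
  auto result = at::empty({M, N}, mat1.options());
  // internal entry point, sketched only:
  // at::native::onednn::woq_matmul_int4(result, mat1, mat2, scale, zp,
  //                                     group_size, /*pri_cache=*/true);
}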
@@ -95,7 +95,8 @@ TORCH_API void woq_matmul_int4(
     const at::Tensor& mat2_, // quantized weight, [K/8, N]
     const at::Tensor& scale, // [K/group_size, N]
     const at::Tensor& zp, // [k/group_size, N]
-    int64_t group_size);
+    int64_t group_size,
+    bool pri_cache = true);
 
 dnnl::memory::dims conv_dst_size(
     int64_t ndim,
@@ -295,6 +295,127 @@ kernel void masked_fill_scalar_strided(
   }
 }
 
+template <typename T, typename index_t>
+kernel void index_copy_dense(
+    device T* output,
+    constant T* input,
+    constant T* source,
+    constant index_t* indices,
+    constant uint& dim,
+    constant long* sizes,
+    constant uint& ndim,
+    constant uint& indices_numel,
+    uint thread_index [[thread_position_in_grid]]) {
+  // first copy input to output
+  output[thread_index] = input[thread_index];
+
+  // calculate pos in the tensor using a signed counter
+  long pos[max_ndim];
+  long linear_idx = thread_index;
+  for (int i = static_cast<int>(ndim) - 1; i >= 0; --i) {
+    pos[i] = linear_idx % sizes[i];
+    linear_idx /= sizes[i];
+  }
+
+  // check if this position's dim coordinate is in the indices
+  long dim_pos = pos[dim];
+
+  // search through indices to see if current dim pos should be updated
+  for (uint i = 0; i < indices_numel; i++) {
+    if (indices[i] == dim_pos) {
+      // this position should be updated from source
+      // calculate source offset where the source tensor has the same shape
+      // except along dim where it has size = indices_numel
+      long source_offset = 0;
+      long stride = 1;
+      for (int j = static_cast<int>(ndim) - 1; j >= 0; --j) {
+        if (j == static_cast<int>(dim)) {
+          // for the indexed dimension, use position i
+          source_offset += i * stride;
+          stride *= indices_numel;
+        } else {
+          // for other dimensions use the same position
+          source_offset += pos[j] * stride;
+          stride *= sizes[j];
+        }
+      }
+
+      output[thread_index] = source[source_offset];
+      break;
+    }
+  }
+}
+
+template <typename T, typename index_t>
+kernel void index_copy_strided(
+    device T* output,
+    constant T* input,
+    constant T* source,
+    constant index_t* indices,
+    constant uint& dim,
+    constant long* sizes,
+    constant uint& ndim,
+    constant uint& indices_numel,
+    constant long* input_strides,
+    constant long* output_strides,
+    constant long* source_strides,
+    uint thread_index [[thread_position_in_grid]]) {
+  int pos[max_ndim];
+  pos_from_thread_index(int(thread_index), pos, sizes, ndim);
+
+  // compute offsets for the output and input tensors
+  long output_offset = offset_from_coord(pos, output_strides, ndim);
+  long input_offset = offset_from_coord(pos, input_strides, ndim);
+
+  output[output_offset] = input[input_offset];
+
+  // save the original coordinate along the dim we're updating
+  int orig_dim = pos[dim];
+
+  // find the last index in the indices array that equals this coordinate
+  int last_matching_index = -1;
+  for (uint i = 0; i < indices_numel; i++) {
+    if (indices[i] == orig_dim) {
+      last_matching_index = int(i);
+    }
+  }
+
+  // if a matching index was found, use it to update the output
+  if (last_matching_index != -1) {
+    pos[dim] = last_matching_index;
+    long source_offset = offset_from_coord(pos, source_strides, ndim);
+    output[output_offset] = source[source_offset];
+  }
+}
+
+#define INSTANTIATE_INDEX_COPY(T, index_t)                       \
+  template [[host_name("index_copy_dense_" #T "_" #index_t)]]   \
+  kernel void index_copy_dense<T, index_t>(                      \
+      device T*,                                                 \
+      constant T*,                                               \
+      constant T*,                                               \
+      constant index_t*,                                         \
+      constant uint&,                                            \
+      constant long*,                                            \
+      constant uint&,                                            \
+      constant uint&,                                            \
+      uint);                                                     \
+                                                                 \
+  template [[host_name("index_copy_strided_" #T "_" #index_t)]]  \
+  kernel void index_copy_strided<T, index_t>(                    \
+      device T*,                                                 \
+      constant T*,                                               \
+      constant T*,                                               \
+      constant index_t*,                                         \
+      constant uint&,                                            \
+      constant long*,                                            \
+      constant uint&,                                            \
+      constant uint&,                                            \
+      constant long*,                                            \
+      constant long*,                                            \
+      constant long*,                                            \
+      uint);
+
 #define REGISTER_MASKED_FILL_SCALAR(SIZE, DTYPE)                          \
   template [[host_name("masked_fill_scalar_strided_" #SIZE)]] kernel void \
   masked_fill_scalar_strided<DTYPE>(                                      \

@@ -317,3 +438,28 @@ REGISTER_MASKED_FILL_SCALAR(64bit, long);
 REGISTER_MASKED_FILL_SCALAR(32bit, int);
 REGISTER_MASKED_FILL_SCALAR(16bit, short);
 REGISTER_MASKED_FILL_SCALAR(8bit, char);
+INSTANTIATE_INDEX_COPY(float, int);
+INSTANTIATE_INDEX_COPY(float, long);
+INSTANTIATE_INDEX_COPY(bool, int);
+INSTANTIATE_INDEX_COPY(bool, long);
+INSTANTIATE_INDEX_COPY(half, int);
+INSTANTIATE_INDEX_COPY(half, long);
+INSTANTIATE_INDEX_COPY(int, int);
+INSTANTIATE_INDEX_COPY(int, long);
+INSTANTIATE_INDEX_COPY(long, int);
+INSTANTIATE_INDEX_COPY(long, long);
+INSTANTIATE_INDEX_COPY(short, int);
+INSTANTIATE_INDEX_COPY(short, long);
+INSTANTIATE_INDEX_COPY(char, int);
+INSTANTIATE_INDEX_COPY(char, long);
+INSTANTIATE_INDEX_COPY(uchar, int);
+INSTANTIATE_INDEX_COPY(uchar, long);
+
+#if __METAL_VERSION__ >= 310
+INSTANTIATE_INDEX_COPY(bfloat, int);
+INSTANTIATE_INDEX_COPY(bfloat, long);
+#endif
+INSTANTIATE_INDEX_COPY(float2, int);
+INSTANTIATE_INDEX_COPY(float2, long);
+INSTANTIATE_INDEX_COPY(half2, int);
+INSTANTIATE_INDEX_COPY(half2, long);
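Note: the dense kernel's index arithmetic decomposes a linear thread index into per-dimension coordinates, row-major with the innermost dimension last. A hedged host-side C++ restatement of that loop for clarity:

#include <cstdint>
#include <vector>

std::vector<int64_t> coords_from_linear(int64_t linear,
                                        const std::vector<int64_t>& sizes) {
  std::vector<int64_t> pos(sizes.size());
  for (int i = static_cast<int>(sizes.size()) - 1; i >= 0; --i) {
    pos[i] = linear % sizes[i];  // coordinate in dimension i
    linear /= sizes[i];          // carry into the next-outer dimension
  }
  return pos;
}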
@@ -34,6 +34,7 @@
 #include <ATen/ops/flip_native.h>
 #include <ATen/ops/index.h>
 #include <ATen/ops/index_add_native.h>
+#include <ATen/ops/index_copy_native.h>
 #include <ATen/ops/index_fill_native.h>
 #include <ATen/ops/index_put.h>
 #include <ATen/ops/index_select_native.h>

@@ -252,6 +253,78 @@ static void index_put_kernel_mps(TensorIterator& iter,
 }
 } // namespace mps

+TORCH_IMPL_FUNC(index_copy_out_mps)(const Tensor& self,
+                                    int64_t dim,
+                                    const Tensor& index,
+                                    const Tensor& source,
+                                    const Tensor& result) {
+  using namespace mps;
+
+  // special-case for 0-dim tensors
+  if (self.dim() == 0) {
+    TORCH_CHECK(index.numel() == 1,
+                "index_copy_(): attempting to index a 0-dim tensor with an index tensor of size ",
+                index.numel());
+    int64_t idx = index.item<int64_t>();
+    TORCH_CHECK(idx == 0, "index_copy_(): the only valid index for a 0-dim tensor is 0, but got ", idx);
+    result.copy_(source);
+    return;
+  }
+
+  dim = maybe_wrap_dim(dim, self.dim());
+
+  // early return for empty index
+  if (index.numel() == 0) {
+    result.copy_(self);
+    return;
+  }
+
+  for (int64_t i = 0; i < self.dim(); i++) {
+    if (i != dim) {
+      TORCH_CHECK(self.size(i) == source.size(i),
+                  "index_copy_(): self and source must have same size at dimension ",
+                  i,
+                  "; self has size ",
+                  self.size(i),
+                  ", source has size ",
+                  source.size(i));
+    }
+  }
+
+  TORCH_CHECK(source.size(dim) == index.numel(),
+              "index_copy_(): Number of indices (",
+              index.numel(),
+              ") should be equal to source.size(dim) (",
+              source.size(dim),
+              ")");
+
+  auto stream = getCurrentMPSStream();
+  auto device = MPSDevice::getInstance()->device();
+
+  const bool is_dense =
+      self.is_contiguous() && source.is_contiguous() && result.is_contiguous() && index.is_contiguous();
+
+  auto dense_or_strided = is_dense ? "dense" : "strided";
+  auto long_or_int = (index.scalar_type() == ScalarType::Long) ? "long" : "int";
+  auto indexCopyPSO = lib.getPipelineStateForFunc(
+      fmt::format("index_copy_{}_{}_{}", dense_or_strided, scalarToMetalTypeString(result), long_or_int));
+
+  dispatch_sync_with_rethrow(stream->queue(), ^() {
+    @autoreleasepool {
+      auto computeEncoder = stream->commandEncoder();
+      uint32_t dim_arg = static_cast<uint32_t>(dim);
+      uint32_t ndim = self.dim();
+      uint32_t indices_numel = index.numel();
+      [computeEncoder setComputePipelineState:indexCopyPSO];
+      mtl_setArgs(computeEncoder, result, self, source, index, dim_arg, self.sizes(), ndim, indices_numel);
+      if (!is_dense) {
+        mtl_setArgs<8>(computeEncoder, self.strides(), result.strides(), source.strides());
+      }
+      mtl_dispatch1DJob(computeEncoder, indexCopyPSO, result.numel());
+    }
+  });
+}
+
 static Tensor nonzero_fallback(const Tensor& self) {
   return at::nonzero(self.to("cpu")).to("mps");
 }
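
For context, a minimal sketch of the user-visible semantics the new MPS kernel implements, written against the public ATen C++ API (illustrative only; on an MPS-capable build the same call dispatches to `index_copy_out_mps` above):

```cpp
#include <ATen/ATen.h>

// index_copy_ writes rows of `source` into `self` at the positions given
// by `index` along dim 0; all other dimensions must match exactly.
void index_copy_sketch() {
  at::Tensor self = at::zeros({5, 3});
  at::Tensor source = at::ones({2, 3});
  at::Tensor index = at::arange(2, at::kLong);  // indices {0, 1}
  self.index_copy_(/*dim=*/0, index, source);   // rows 0 and 1 become ones
}
```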

@@ -35,14 +35,15 @@ static void _mps_linear_nograph(const Tensor& input, const Tensor& weight, const
                                                                      shape:getMPSShape(weight.sizes())];
   weightDesc.preferPackedRows = YES;
   [weightDesc transposeDimension:0 withDimension:1];
-  MPSNDArray* weightNDArray = [[MPSNDArray alloc] initWithBuffer:weightBuf
-                                                          offset:weight.storage_offset() * weight.element_size()
-                                                      descriptor:weightDesc];
+  MPSNDArray* weightNDArray = [[[MPSNDArray alloc] initWithBuffer:weightBuf
+                                                           offset:weight.storage_offset() * weight.element_size()
+                                                       descriptor:weightDesc] autorelease];

   if (is_bias_defined) {
     auto biasNDArray = getMPSNDArray(bias, bias.sizes(), bias.strides());
-    auto cachedKernel = LookUpOrCreateCachedKernel<MPSCachedKernel>(
-        key, [&]() { return [[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:3]; });
+    auto cachedKernel = LookUpOrCreateCachedKernel<MPSCachedKernel>(key, [&]() {
+      return [[[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:3] autorelease];
+    });
     auto kernel = cachedKernel->kernel<MPSNDArrayMatrixMultiplication>();

     getMPSProfiler().beginProfileKernel(kernel, "mps_linear", {input, weight, bias});

@@ -52,8 +53,9 @@ static void _mps_linear_nograph(const Tensor& input, const Tensor& weight, const
                destinationArray:outNDArray];
     getMPSProfiler().endProfileKernel(kernel);
   } else {
-    auto cachedKernel = LookUpOrCreateCachedKernel<MPSCachedKernel>(
-        key, [&]() { return [[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:2]; });
+    auto cachedKernel = LookUpOrCreateCachedKernel<MPSCachedKernel>(key, [&]() {
+      return [[[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:2] autorelease];
+    });
     auto kernel = cachedKernel->kernel<MPSNDArrayMatrixMultiplication>();
     getMPSProfiler().beginProfileKernel(kernel, "mps_linear", {input, weight, bias});
     [kernel encodeToCommandEncoder:computeEncoder

@@ -3110,6 +3110,7 @@
   - dim -> int dim
   dispatch:
     CPU, CUDA: index_copy_out
+    MPS: index_copy_out_mps

 - func: index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!)
   variants: method

@@ -1,5 +1,7 @@
 #include <ATen/native/sparse/cuda/cuSPARSELtOps.h>
+#include <unordered_map>
+#include <mutex>
+#include <string_view>
 #if AT_CUSPARSELT_ENABLED()

 namespace at::native {

@@ -15,6 +17,45 @@ namespace at::native {
 thread_local cusparseLtHandle_t handle;
 thread_local bool handle_initialized = false;
+
+#ifdef USE_ROCM
+// Single global flag for platform-wide hipSparseLt support
+c10::once_flag g_hipSparseLtSupportInitFlag;
+static bool g_hipSparseLtSupported = false;
+
+// Initialize the hipSparseLt support status once for the platform
+static void initHipSparseLtSupport() {
+  // Default to not supported
+  g_hipSparseLtSupported = false;
+
+  // Check only the first available device
+  try {
+    if (at::cuda::device_count() > 0) {
+      g_hipSparseLtSupported = at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx942"}, 0);
+    }
+  } catch (const std::exception&) {
+    // If an exception occurs during device property check, we assume hipSparseLt is not supported
+    // This could happen due to driver issues, device access problems, or other runtime errors
+    g_hipSparseLtSupported = false;
+    TORCH_WARN("Exception occurred while checking hipSparseLt support. Assuming not supported.");
+  }
+}
+
+static bool isHipSparseLtSupported() {
+  // Initialize support check only once
+  c10::call_once(g_hipSparseLtSupportInitFlag, initHipSparseLtSupport);
+
+  // Return cached result (platform-wide)
+  if (!g_hipSparseLtSupported) {
+    TORCH_CHECK(
+        false,
+        "hipSparseLt not supported on this device, supported architectures: "
+        "gfx950, gfx942. "
+        "required ROCM version: 6.4.0 or later.");
+  }
+  return g_hipSparseLtSupported;
+}
+#endif
+
 at::Tensor _cslt_compress(const Tensor& sparse_input) {
   if (!handle_initialized) {
     TORCH_CUDASPARSE_CHECK(cusparseLtInit(&handle));

@@ -25,6 +66,10 @@ at::Tensor _cslt_compress(const Tensor& sparse_input) {
   cudaDataType type;
   auto compression_factor = 9;
+
+#ifdef USE_ROCM
+  TORCH_CHECK(isHipSparseLtSupported());
+#endif
+
   switch (sparse_input.scalar_type()) {
     case at::ScalarType::Char:
       type = CUDA_R_8I;

@@ -36,17 +81,19 @@ at::Tensor _cslt_compress(const Tensor& sparse_input) {
     case at::ScalarType::BFloat16:
       type = CUDA_R_16BF;
       break;
+#ifndef USE_ROCM
     case at::ScalarType::Float:
       type = CUDA_R_32F;
       break;
-#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
+#endif
+#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 && !defined(USE_ROCM)
     case at::ScalarType::Float8_e4m3fn:
       type = CUDA_R_8F_E4M3;
       compression_factor = 10;
       break;
 #endif
     default:
-      TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt compressed matrix");
+      TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt/hipSparseLt compressed matrix");
       break;
   }

@@ -120,6 +167,10 @@ std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl(
   cusparseComputeType compute_type;
   auto compression_factor = 9;
+
+#ifdef USE_ROCM
+  TORCH_CHECK(isHipSparseLtSupported());
+#endif
+
   switch (compressed_A.scalar_type()) {
     case at::ScalarType::Char:
       input_type = CUDA_R_8I;

@@ -131,7 +182,7 @@ std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl(
   // cuSPARSELt v0.5.2 onwards changes CUSPARSE_COMPUTE_TF32, CUSPARSE_COMPUT_16F
   // to CUSPARSE_COMPUTE_32F
-#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 502
+#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 502 || defined(USE_ROCM)
     case at::ScalarType::Half:
       input_type = CUDA_R_16F;
       output_type = CUDA_R_16F;

@@ -144,14 +195,16 @@ std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl(
       C_type = CUDA_R_16BF;
       compute_type = CUSPARSE_COMPUTE_32F;
       break;
+#ifndef USE_ROCM
     case at::ScalarType::Float:
       input_type = CUDA_R_32F;
       output_type = CUDA_R_32F;
       C_type = CUDA_R_32F;
       compute_type = CUSPARSE_COMPUTE_32F;
       break;
+#endif
     // if cuSPARSELt >= 6.2.3, we can add Float8 support
-#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
+#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 && !defined(USE_ROCM)
     case at::ScalarType::Float8_e4m3fn:
       input_type = CUDA_R_8F_E4M3;
       output_type = CUDA_R_8F_E4M3;

@@ -214,7 +267,7 @@ std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl(
     }
   }
   // cslt 0.6.2+: fp8 fp8 -> {fp8, fp16, bf16, fp32} support
-#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
+#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 && !defined(USE_ROCM)
   else if (input_type == CUDA_R_8F_E4M3) {
     switch (out_dtype) {
       case at::ScalarType::Float8_e4m3fn:
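
The ROCm gate above is a standard once-initialized capability probe. A minimal self-contained sketch of the same pattern, using plain `std::call_once` in place of `c10::call_once` (all names here are illustrative):

```cpp
#include <mutex>
#include <stdexcept>

namespace {
std::once_flag g_probe_flag;  // stand-in for the c10::once_flag above
bool g_supported = false;

void probe_support() {
  try {
    // Real code would query the device architecture here; any throw
    // leaves g_supported at its safe default of false.
    g_supported = false;
  } catch (const std::exception&) {
    g_supported = false;
  }
}
}  // namespace

bool is_supported() {
  std::call_once(g_probe_flag, probe_support);  // probe runs exactly once
  if (!g_supported) {
    throw std::runtime_error("feature not supported on this device");
  }
  return g_supported;
}
```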

@@ -968,8 +968,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
   int64_t batch_size = query.size(0);

   if (batch_size > MAX_BATCH_SIZE) {
-    TORCH_CHECK(!compute_log_sumexp && (dropout_p == 0.0),
-                "Efficient attention cannot produce valid seed, logsumexp and offset outputs when "
+    TORCH_CHECK(dropout_p == 0.0,
+                "Efficient attention cannot produce valid seed and offset outputs when "
                 "the batch size exceeds (", MAX_BATCH_SIZE, ").");
   }
   auto process_chunk = [&](const Tensor& q_chunk,

@@ -1030,6 +1030,17 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
   }
   Tensor final_attention = at::empty_strided(sizes, attn.strides(), attn.options());
   final_attention.slice(0, start, end).copy_(attn);
+  Tensor final_log_sumexp;
+  if (compute_log_sumexp && log_sumexp.numel() > 0) {
+    std::vector<int64_t> lse_sizes;
+    lse_sizes.reserve(log_sumexp.dim());
+    lse_sizes.push_back(batch_size);
+    for (int i = 1; i < log_sumexp.dim(); i++) {
+      lse_sizes.push_back(log_sumexp.size(i));
+    }
+    final_log_sumexp = at::empty(std::move(lse_sizes), log_sumexp.options());
+    final_log_sumexp.slice(0, start, end).copy_(log_sumexp);
+  }

   for (start = end; start < batch_size; start += MAX_BATCH_SIZE) {
     end = std::min(start + MAX_BATCH_SIZE, batch_size);

@@ -1045,10 +1056,13 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
     auto [chunk_attn, chunk_log_sumexp, chunk_seed, chunk_offset] =
         process_chunk(query_chunk, key_chunk, value_chunk, bias_chunk);
     final_attention.slice(0, start, end).copy_(chunk_attn);
+    if (compute_log_sumexp && chunk_log_sumexp.numel() > 0) {
+      final_log_sumexp.slice(0, start, end).copy_(chunk_log_sumexp);
+    }
   }

   return std::make_tuple(std::move(final_attention),
-                         std::move(log_sumexp),
+                         std::move(final_log_sumexp),
                          std::move(seed),
                          std::move(offset));
 }
|
|||||||
#include <ATen/Functions.h>
|
#include <ATen/Functions.h>
|
||||||
#include <ATen/NativeFunctions.h>
|
#include <ATen/NativeFunctions.h>
|
||||||
#else
|
#else
|
||||||
|
#include <ATen/ops/zeros_like.h>
|
||||||
|
#include <ATen/ops/empty_strided.h>
|
||||||
#include <ATen/ops/_flash_attention_backward.h>
|
#include <ATen/ops/_flash_attention_backward.h>
|
||||||
#include <ATen/ops/_flash_attention_backward_native.h>
|
#include <ATen/ops/_flash_attention_backward_native.h>
|
||||||
#include <ATen/ops/_efficient_attention_backward.h>
|
#include <ATen/ops/_efficient_attention_backward.h>
|
||||||
@ -905,40 +907,56 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_e
|
|||||||
if (!grad_out_.defined()) {
|
if (!grad_out_.defined()) {
|
||||||
return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
|
return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
|
||||||
}
|
}
|
||||||
auto grad_out = grad_out_.transpose(1, 2);
|
constexpr int64_t MAX_BATCH_SIZE = (1LL << 16) - 1;
|
||||||
auto out_t = out.transpose(1, 2);
|
int64_t batch_size = query.size(0);
|
||||||
auto q_t = query.transpose(1, 2);
|
|
||||||
auto k_t = key.transpose(1, 2);
|
|
||||||
auto v_t = value.transpose(1, 2);
|
|
||||||
|
|
||||||
|
if (batch_size > MAX_BATCH_SIZE) {
|
||||||
|
TORCH_CHECK(dropout_p == 0.0,
|
||||||
|
"Efficient attention backward cannot handle dropout when "
|
||||||
|
"the batch size exceeds (", MAX_BATCH_SIZE, ").");
|
||||||
|
}
|
||||||
|
auto grad_out_t = grad_out_.transpose(1, 2);
|
||||||
|
auto query_t = query.transpose(1, 2);
|
||||||
|
auto key_t = key.transpose(1, 2);
|
||||||
|
auto value_t = value.transpose(1, 2);
|
||||||
|
auto out_t = out.transpose(1, 2);
|
||||||
|
|
||||||
|
auto process_chunk = [&](const Tensor& grad_out_chunk,
|
||||||
|
const Tensor& query_chunk,
|
||||||
|
const Tensor& key_chunk,
|
||||||
|
const Tensor& value_chunk,
|
||||||
|
const std::optional<Tensor>& attn_bias_chunk,
|
||||||
|
const Tensor& out_chunk,
|
||||||
|
const Tensor& logsumexp_chunk)
|
||||||
|
-> std::tuple<Tensor, Tensor, Tensor, Tensor> {
|
||||||
// This is needed because SaveVariable automatically converts
|
// This is needed because SaveVariable automatically converts
|
||||||
// std::optional to undefined tensor
|
// std::optional to undefined tensor
|
||||||
std::optional<Tensor> kernel_bias;
|
std::optional<Tensor> kernel_bias;
|
||||||
if (attn_bias.defined()) {
|
if (attn_bias_chunk.has_value() && attn_bias_chunk.value().defined()) {
|
||||||
kernel_bias = attn_bias;
|
kernel_bias = attn_bias_chunk.value();
|
||||||
}
|
}
|
||||||
// Will add with signauter changes for dropout and bias
|
// Will add with signauter changes for dropout and bias
|
||||||
// We are only handling Dense inputs, but this should be passed
|
// We are only handling Dense inputs, but this should be passed
|
||||||
// from forward to backward
|
// from forward to backward
|
||||||
int64_t max_seqlen_q = q_t.size(1);
|
int64_t max_seqlen_q = query_chunk.size(2);
|
||||||
int64_t max_seqlen_k = k_t.size(1);
|
int64_t max_seqlen_k = key_chunk.size(2);
|
||||||
|
|
||||||
sdp::CustomMaskType custom_mask_type = causal
|
sdp::CustomMaskType custom_mask_type = causal
|
||||||
? sdp::CustomMaskType::CausalFromTopLeft
|
? sdp::CustomMaskType::CausalFromTopLeft
|
||||||
: sdp::CustomMaskType::NoCustomMask;
|
: sdp::CustomMaskType::NoCustomMask;
|
||||||
auto [grad_q, grad_k, grad_v, grad_bias] =
|
auto [grad_q, grad_k, grad_v, grad_bias] =
|
||||||
at::_efficient_attention_backward(
|
at::_efficient_attention_backward(
|
||||||
grad_out,
|
grad_out_chunk,
|
||||||
q_t,
|
query_chunk,
|
||||||
k_t,
|
key_chunk,
|
||||||
v_t,
|
value_chunk,
|
||||||
kernel_bias,
|
kernel_bias,
|
||||||
out_t,
|
out_chunk,
|
||||||
std::nullopt,
|
std::nullopt,
|
||||||
std::nullopt,
|
std::nullopt,
|
||||||
max_seqlen_q,
|
max_seqlen_q,
|
||||||
max_seqlen_k,
|
max_seqlen_k,
|
||||||
logsumexp,
|
logsumexp_chunk,
|
||||||
dropout_p,
|
dropout_p,
|
||||||
philox_seed,
|
philox_seed,
|
||||||
philox_offset,
|
philox_offset,
|
||||||
@ -947,7 +965,90 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_e
|
|||||||
scale,
|
scale,
|
||||||
std::nullopt); // num_split_keys
|
std::nullopt); // num_split_keys
|
||||||
return std::make_tuple(
|
return std::make_tuple(
|
||||||
grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), grad_bias);
|
grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), std::move(grad_bias));
|
||||||
|
};
|
||||||
|
|
||||||
|
// process in chunks if batch size exceeds maximum
|
||||||
|
if (batch_size > MAX_BATCH_SIZE) {
|
||||||
|
Tensor final_grad_q, final_grad_k, final_grad_v, final_grad_bias;
|
||||||
|
|
||||||
|
auto create_strided_output = [batch_size](const Tensor& tensor) -> Tensor {
|
||||||
|
if (!tensor.defined()) {
|
||||||
|
return Tensor{};
|
||||||
|
}
|
||||||
|
int dim = tensor.dim();
|
||||||
|
std::vector<int64_t> sizes;
|
||||||
|
sizes.reserve(dim);
|
||||||
|
sizes.push_back(batch_size);
|
||||||
|
for (int i = 1; i < dim; i++) {
|
||||||
|
sizes.push_back(tensor.size(i));
|
||||||
|
}
|
||||||
|
return at::empty_strided(std::move(sizes), tensor.strides(), tensor.options());
|
||||||
|
};
|
||||||
|
|
||||||
|
if (grad_input_mask[0]) {
|
||||||
|
final_grad_q = create_strided_output(query);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (grad_input_mask[1]) {
|
||||||
|
final_grad_k = create_strided_output(key);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (grad_input_mask[2]) {
|
||||||
|
final_grad_v = create_strided_output(value);
|
||||||
|
}
|
||||||
|
if (grad_input_mask[3] && attn_bias.defined()) {
|
||||||
|
final_grad_bias = at::zeros_like(attn_bias);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int64_t start = 0; start < batch_size; start += MAX_BATCH_SIZE) {
|
||||||
|
int64_t end = std::min(start + MAX_BATCH_SIZE, batch_size);
|
||||||
|
|
||||||
|
Tensor grad_out_chunk = grad_out_t.slice(0, start, end);
|
||||||
|
Tensor query_chunk = query_t.slice(0, start, end);
|
||||||
|
Tensor key_chunk = key_t.slice(0, start, end);
|
||||||
|
Tensor value_chunk = value_t.slice(0, start, end);
|
||||||
|
Tensor attn_bias_chunk;
|
||||||
|
if (attn_bias.defined()) {
|
||||||
|
attn_bias_chunk = attn_bias.slice(0, start, end);
|
||||||
|
} else {
|
||||||
|
attn_bias_chunk.reset();
|
||||||
|
}
|
||||||
|
Tensor out_chunk = out_t.slice(0, start, end);
|
||||||
|
Tensor logsumexp_chunk = logsumexp.numel() > 0 ? logsumexp.slice(0, start, end) : logsumexp;
|
||||||
|
|
||||||
|
auto [chunk_grad_q, chunk_grad_k, chunk_grad_v, chunk_grad_bias] =
|
||||||
|
process_chunk(grad_out_chunk, query_chunk, key_chunk, value_chunk,
|
||||||
|
attn_bias_chunk, out_chunk, logsumexp_chunk);
|
||||||
|
|
||||||
|
if (grad_input_mask[0] && chunk_grad_q.defined()) {
|
||||||
|
final_grad_q.slice(0, start, end).copy_(chunk_grad_q);
|
||||||
|
}
|
||||||
|
if (grad_input_mask[1] && chunk_grad_k.defined()) {
|
||||||
|
final_grad_k.slice(0, start, end).copy_(chunk_grad_k);
|
||||||
|
}
|
||||||
|
if (grad_input_mask[2] && chunk_grad_v.defined()) {
|
||||||
|
final_grad_v.slice(0, start, end).copy_(chunk_grad_v);
|
||||||
|
}
|
||||||
|
if (grad_input_mask[3] && chunk_grad_bias.defined()) {
|
||||||
|
final_grad_bias.add_(chunk_grad_bias);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_tuple(
|
||||||
|
std::move(final_grad_q),
|
||||||
|
std::move(final_grad_k),
|
||||||
|
std::move(final_grad_v),
|
||||||
|
std::move(final_grad_bias));
|
||||||
|
}
|
||||||
|
// when batch size is within allowed size, no chunking needed
|
||||||
|
else {
|
||||||
|
std::optional<Tensor> attn_bias_opt;
|
||||||
|
if (attn_bias.defined()) {
|
||||||
|
attn_bias_opt = attn_bias;
|
||||||
|
}
|
||||||
|
return process_chunk(grad_out_t, query_t, key_t, value_t, attn_bias_opt, out_t, logsumexp);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace at::native
|
} // namespace at::native
|
||||||
|
|||||||
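
The backward counterpart mirrors the forward chunking, with one wrinkle visible above: per-chunk q/k/v gradients land in disjoint batch slices and are copied into place, while the bias gradient is combined with `add_`. A reduced sketch of the two update styles (names are illustrative):

```cpp
#include <ATen/ATen.h>

// Sliced outputs are copied into their disjoint batch range; a gradient
// that is combined across chunks is accumulated with add_ instead.
void accumulate_chunk(const at::Tensor& grad_q_chunk,
                      const at::Tensor& grad_bias_chunk,
                      at::Tensor& final_grad_q,
                      at::Tensor& final_grad_bias,
                      int64_t start,
                      int64_t end) {
  final_grad_q.slice(0, start, end).copy_(grad_q_chunk);  // disjoint slice
  final_grad_bias.add_(grad_bias_chunk);                  // accumulated
}
```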

@@ -1018,9 +1018,6 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):

     Writes to ./speedups.csv
     """
-    # if args.dynamic_shapes:
-    #     return speedup_experiment_ds(args, model_iter_fn, model, example_inputs)
-
     timings = np.zeros((args.repeat, 2), np.float64)
     # if we randomize the input, we should also check the result is correct
     should_randomize_input = args.randomize_input

@@ -1179,82 +1176,6 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
     return msg


-# WARNING: This code is currently dead
-def speedup_experiment_ds(args, model_iter_fn, model, example_inputs):
-    """
-    Run dynamic shapes benchmarks.
-
-    Requires dynamic shape compatible models, which provide a list of example inputs.
-
-    Warms up using the first input example and then iterates the inputs,
-    measuring (and expecting minimal) variance between the runtime for different examples.
-    """
-    timings = np.zeros((args.repeat, len(example_inputs), 2), np.float64)
-
-    if args.repeat > 5:
-        print(
-            f"\ndynamic shapes experiments are slow, consider setting --repeat less than {args.repeat}\n"
-        )
-
-    nwarmup = 4
-    for rep in range(args.repeat):
-        # Start each rep fresh, e.g. only warmup on example 0
-        torch._dynamo.reset()
-        optimized_model_iter_fn = optimize_ctx(model_iter_fn)
-        for _ in range(nwarmup):
-            optimized_model_iter_fn(model, example_inputs[0])
-
-        for input_idx, inputs in enumerate(example_inputs):
-            # interleave the runs to handle frequency scaling and load changes
-            timings[rep, input_idx, 0] = timed(
-                model, model_iter_fn, inputs, return_result=False
-            )
-            # different from regular speedup_experiment, we _DO_ want to allow recompilation
-            timings[rep, input_idx, 1] = timed(
-                model, optimized_model_iter_fn, inputs, return_result=False
-            )
-    medians = np.median(timings, axis=0)
-    speedups = list(medians[:, 0] / medians[:, 1])
-    speedups_mean = np.mean(speedups)
-    speedups_median = np.median(speedups)
-    speedups_var = np.var(speedups)
-
-    # TODO this x[0] is not going to work in general but bert only has 1 input
-    shapes = [x[0].shape for x in example_inputs]
-    shape_keys = sorted(set(shapes))
-    shape_speedups = {
-        shape: [
-            it[1] for it in filter(lambda it: it[0] == shape, zip(shapes, speedups))
-        ]
-        for shape in shape_keys
-    }
-    output_str = (
-        f"mean: {speedups_mean:.3f}, median: {speedups_median:.3f}, var: {speedups_var:.3f}"
-        + "\nSpeedups by shape: "
-        + "\n".join(
-            [
-                f"{shape}: "
-                + ", ".join([f"{speedup: .3g}" for speedup in shape_speedups[shape]])
-                for shape in shape_keys
-            ]
-        )
-    )
-    write_outputs(
-        output_filename,
-        ("dev", "name", "batch_size", "speedup mean", "speedup median", "speedup var"),
-        [
-            current_device,
-            current_name,
-            current_batch_size,
-            speedups_mean,
-            speedups_median,
-            speedups_var,
-        ],
-    )
-    return output_str
-
-
 def overhead_experiment(*args, model_iter_fn):
     """
     Measure overheads of TorchDynamo by running with no backend (only

@@ -54,12 +54,9 @@ class Benchmark(BenchmarkBase):
         torch._dynamo.reset()

     def _work(self):
-        # enable_cpp_symbolic_shape_guards has impact on this benchmark
-        # Keep using False value for consistency.
         with (
             fresh_inductor_cache(),
             torch._inductor.config.patch(force_shape_pad=self._force_shape_pad),
-            torch._dynamo.config.patch("enable_cpp_symbolic_shape_guards", False),
         ):
             opt_m = torch.compile(backend=self.backend(), dynamic=self.is_dynamic())(
                 self.m.cuda() if self._is_gpu else self.m

@@ -247,7 +247,10 @@ class BenchmarkBase(ABC):
                 instruction_count=r,
             )
         if self._enable_compile_time_instruction_count:
-            r = self._count_compile_time_instructions()
+            # enable_cpp_symbolic_shape_guards has impact on these benchmarks
+            # Keep using False value for consistency.
+            with config.patch("enable_cpp_symbolic_shape_guards", False):
+                r = self._count_compile_time_instructions()

             self.results.append(
                 (
|
|||||||
add_loop_eager,compile_time_instruction_count,2953000000,0.015
|
add_loop_eager,compile_time_instruction_count,2937000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
add_loop_eager_dynamic,compile_time_instruction_count,5738000000,0.025
|
add_loop_eager_dynamic,compile_time_instruction_count,4300194436,0.025
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -10,7 +10,7 @@ add_loop_inductor,compile_time_instruction_count,29370000000,0.015
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44490000000,0.025
|
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,38747844521,0.025
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -18,15 +18,15 @@ add_loop_inductor_gpu,compile_time_instruction_count,25900000000,0.015
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,939900000,0.015
|
basic_modules_ListOfLinears_eager,compile_time_instruction_count,952700000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18270000000,0.015
|
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18390000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16310000000,0.015
|
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16450000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10370000
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
update_hint_regression,compile_time_instruction_count,1700000000,0.02
|
update_hint_regression,compile_time_instruction_count,1661000000,0.02
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
float_args,compile_time_instruction_count,452500000,0.015
|
float_args,compile_time_instruction_count,455500000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -62,7 +62,7 @@ aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6022000000,0
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8672000000,0.015
|
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8724000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -70,7 +70,7 @@ aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1917000000,0.015
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3859000000,0.015
|
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3838000000,0.015
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|

@@ -80,7 +80,6 @@ def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) ->

 @dataclass(frozen=True, kw_only=True)
 class ExperimentConfig:
-    autotune_fallback_to_aten: bool = False
     max_autotune: bool = True
     coordinate_descent_tuning: bool = True
     max_autotune_gemm_backends: str = "ATEN"

@@ -91,7 +90,6 @@ class ExperimentConfig:

     def to_options(self) -> dict[str, Any]:
         return {
-            "autotune_fallback_to_aten": self.autotune_fallback_to_aten,
             "max_autotune": self.max_autotune,
             "coordinate_descent_tuning": self.coordinate_descent_tuning,
             "max_autotune_gemm_backends": self.max_autotune_gemm_backends,

@@ -38,8 +38,8 @@ void c10_cuda_check_implementation(
         "Device-side assertions were explicitly omitted for this error check; the error probably arose while initializing the DSA handlers.");
   }
 #endif
-  TORCH_CHECK(false, check_message);
+  throw c10::AcceleratorError(
+      {__func__, __FILE__, int32_t(__LINE__)}, err, check_message);
 }

 } // namespace c10::cuda

@@ -200,6 +200,7 @@ static void initGlobalStreamState() {
 // Init a single CUDA or HIP stream
 // See Note [HIP Lazy Streams]
 static void initSingleStream(int p, DeviceIndex device_index, int i) {
+  CUDAGuard device_guard(device_index);
   auto& stream = streams[p][device_index][i];
   auto pri = -p; // lower number is higher priority
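
The one-line fix above pins the current device before the stream is created. A hedged sketch of the same RAII pattern at an arbitrary call site (the helper `create_stream_on` is illustrative, not a PyTorch API):

```cpp
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

// CUDAGuard switches the current CUDA device for the enclosing scope and
// restores the previous device on destruction, so the stream below is
// guaranteed to be taken from the pool of `device_index`.
c10::cuda::CUDAStream create_stream_on(c10::DeviceIndex device_index) {
  c10::cuda::CUDAGuard device_guard(device_index);
  return c10::cuda::getStreamFromPool(/*isHighPriority=*/false, device_index);
}
```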

@@ -295,6 +295,19 @@ class C10_API SyntaxError : public Error {
   using Error::Error;
 };

+// Raised when accelerator API call hits an error.
+// These turn into AcceleratorError when the cross into Python
+class C10_API AcceleratorError : public Error {
+  int32_t error_code;
+
+ public:
+  AcceleratorError(SourceLocation loc, int32_t code, const std::string& msg)
+      : Error(loc, msg), error_code(code) {}
+  int32_t get_error_code() const {
+    return error_code;
+  }
+};
+
 // Base error type for all distributed errors.
 // These turn into DistError when they cross into Python.
 class C10_API DistError : public Error {
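
A short sketch of how a caller might consume the new error type (the call site is hypothetical; `AcceleratorError` and `get_error_code()` are as declared above):

```cpp
#include <c10/util/Exception.h>

// Hypothetical call site: a failing accelerator API call now surfaces as
// c10::AcceleratorError, which carries the raw vendor error code alongside
// the usual c10::Error message and source location.
void run_and_report(void (*launch)()) {
  try {
    launch();
  } catch (const c10::AcceleratorError& e) {
    int32_t code = e.get_error_code();  // e.g. the underlying cudaError_t
    (void)code;                         // log or map the code, then rethrow
    throw;
  }
}
```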

@@ -133,8 +133,13 @@ inline void initGlobalDevicePoolState() {
 #else
   // The default context is utilized for each Intel GPU device, allowing the
   // retrieval of the context from any GPU device.
+  const auto& platform = gDevicePool.devices[0]->get_platform();
   gDevicePool.context = std::make_unique<sycl::context>(
-      gDevicePool.devices[0]->get_platform().ext_oneapi_get_default_context());
+#if SYCL_COMPILER_VERSION >= 20250200
+      platform.khr_get_default_context());
+#else
+      platform.ext_oneapi_get_default_context());
+#endif
 #endif
 }

@@ -1063,7 +1063,7 @@ if(USE_ROCM)

   # Math libraries
   list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
-    roc::hipblas roc::rocblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver roc::hipblaslt)
+    roc::hipblas roc::rocblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsparselt roc::hipsolver roc::hipblaslt)

   # ---[ Kernel asserts
   # Kernel asserts is disabled for ROCm by default.

@@ -57,7 +57,8 @@ if(NCCL_FOUND) # obtaining NCCL version and some sanity checks
   include(CheckCXXSymbolExists)
   check_cxx_symbol_exists(NCCL_VERSION_CODE nccl.h NCCL_VERSION_DEFINED)

-  if (NCCL_VERSION_DEFINED)
+  # this condition check only works for non static NCCL linking
+  if (NCCL_VERSION_DEFINED AND NOT USE_STATIC_NCCL)
     set(file "${PROJECT_BINARY_DIR}/detect_nccl_version.cc")
     file(WRITE ${file} "
       #include <iostream>

@@ -65,7 +66,6 @@ if(NCCL_FOUND) # obtaining NCCL version and some sanity checks
       int main()
       {
         std::cout << NCCL_MAJOR << '.' << NCCL_MINOR << '.' << NCCL_PATCH << std::endl;

         int x;
         ncclGetVersion(&x);
         return x == NCCL_VERSION_CODE;

@@ -80,11 +80,9 @@ if(NCCL_FOUND) # obtaining NCCL version and some sanity checks
 (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES}) Please set NCCL_INCLUDE_DIR and NCCL_LIB_DIR manually.")
     endif()
     message(STATUS "NCCL version: ${NCCL_VERSION_FROM_HEADER}")
-  else()
-    message(STATUS "NCCL version < 2.3.5-5")
   endif ()
-  set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES})

+  set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES})
   message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
   mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
 endif()

@@ -151,6 +151,7 @@ if(HIP_FOUND)
   find_package_and_print_version(miopen REQUIRED)
   find_package_and_print_version(hipfft REQUIRED)
   find_package_and_print_version(hipsparse REQUIRED)
+  find_package_and_print_version(hipsparselt REQUIRED)
   find_package_and_print_version(rocprim REQUIRED)
   find_package_and_print_version(hipcub REQUIRED)
   find_package_and_print_version(rocthrust REQUIRED)

@@ -26,8 +26,8 @@ As shown in the CPU example section of :class:`torch.autocast`, "automatic mixed
 datatype of ``torch.bfloat16`` only uses :class:`torch.autocast`.

 .. warning::
-    ``torch.cuda.amp.autocast(args...)`` and ``torch.cpu.amp.autocast(args...)`` will be deprecated. Please use ``torch.autocast("cuda", args...)`` or ``torch.autocast("cpu", args...)`` instead.
-    ``torch.cuda.amp.GradScaler(args...)`` and ``torch.cpu.amp.GradScaler(args...)`` will be deprecated. Please use ``torch.GradScaler("cuda", args...)`` or ``torch.GradScaler("cpu", args...)`` instead.
+    ``torch.cuda.amp.autocast(args...)`` and ``torch.cpu.amp.autocast(args...)`` is deprecated. Please use ``torch.amp.autocast("cuda", args...)`` or ``torch.amp.autocast("cpu", args...)`` instead.
+    ``torch.cuda.amp.GradScaler(args...)`` and ``torch.cpu.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cuda", args...)`` or ``torch.amp.GradScaler("cpu", args...)`` instead.

 :class:`torch.autocast` and :class:`torch.cpu.amp.autocast` are new in version `1.10`.

@@ -40,6 +40,7 @@ torch.cuda
     temperature
     power_draw
     clock_rate
+    AcceleratorError
     OutOfMemoryError

 Random Number Generator

@@ -31,6 +31,7 @@ torch.fx.experimental.symbolic_shapes
     PropagateUnbackedSymInts
     DivideByKey
     InnerTensorKey
+    Specialization

     hint_int
     is_concrete_int

@@ -360,8 +360,7 @@ Suppose we want to define a sparse tensor with the entry 3 at location
 Unspecified elements are assumed to have the same value, fill value,
 which is zero by default. We would then write:

-    >>> i = [[0, 1, 1],
-             [2, 0, 2]]
+    >>> i = [[0, 1, 1], [2, 0, 2]]
     >>> v = [3, 4, 5]
     >>> s = torch.sparse_coo_tensor(i, v, (2, 3))
     >>> s

@@ -1,5 +1,7 @@
 #!/bin/bash
 # Updates Triton to the pinned version for this copy of PyTorch
+PYTHON="python3"
+PIP="$PYTHON -m pip"
 BRANCH=$(git rev-parse --abbrev-ref HEAD)
 DOWNLOAD_PYTORCH_ORG="https://download.pytorch.org/whl"

@@ -8,9 +10,9 @@ if [[ -z "${USE_XPU}" ]]; then
     TRITON_VERSION="pytorch-triton==$(cat .ci/docker/triton_version.txt)"
     if [[ "$BRANCH" =~ .*release.* ]]; then
-        pip install --index-url ${DOWNLOAD_PYTORCH_ORG}/test/ $TRITON_VERSION
+        ${PIP} install --index-url ${DOWNLOAD_PYTORCH_ORG}/test/ $TRITON_VERSION
     else
-        pip install --index-url ${DOWNLOAD_PYTORCH_ORG}/nightly/ $TRITON_VERSION+git$(head -c 8 .ci/docker/ci_commit_pins/triton.txt)
+        ${PIP} install --index-url ${DOWNLOAD_PYTORCH_ORG}/nightly/ $TRITON_VERSION+git$(head -c 8 .ci/docker/ci_commit_pins/triton.txt)
     fi
 else
     # The Triton xpu logic is as follows:

@@ -21,11 +23,11 @@ else
     TRITON_VERSION="pytorch-triton-xpu==$(cat .ci/docker/triton_version.txt)"
     TRITON_XPU_COMMIT_ID="$(head -c 8 .ci/docker/ci_commit_pins/triton-xpu.txt)"
     if [[ -z "${TRITON_XPU_BUILD_FROM_SOURCE}" ]]; then
-        pip install --index-url ${DOWNLOAD_PYTORCH_ORG}/nightly/ ${TRITON_VERSION}+git${TRITON_XPU_COMMIT_ID}
+        ${PIP} install --index-url ${DOWNLOAD_PYTORCH_ORG}/nightly/ ${TRITON_VERSION}+git${TRITON_XPU_COMMIT_ID}
     else
         TRITON_XPU_REPO="https://github.com/intel/intel-xpu-backend-for-triton"

         # force-reinstall to ensure the pinned version is installed
-        pip install --force-reinstall "git+${TRITON_XPU_REPO}@${TRITON_XPU_COMMIT_ID}#subdirectory=python"
+        ${PIP} install --force-reinstall "git+${TRITON_XPU_REPO}@${TRITON_XPU_COMMIT_ID}#subdirectory=python"
     fi
 fi

@@ -1,4 +1,4 @@
-# PyTorch Release Scripts
+# PyTorch release scripts performing branch cut and applying release only changes

 These are a collection of scripts that are to be used for release activities.

@@ -7,54 +7,12 @@ These are a collection of scripts that are to be used for release activities.
 > The basic idea being that there should be no potential to do anything dangerous unless
 > `DRY_RUN` is explicitly set to `disabled`.

-## Requirements to actually run these scripts
-* AWS access to pytorch account
-* Access to upload conda packages to the `pytorch` conda channel
-* Access to the PyPI repositories
+### Order of Execution

+1. Run cut-release-branch.sh to cut the release branch
+2. Run tag-docker-images.sh to tag current docker images with release tag and push them to docker.io. These images will be used to build the release.
+3. Run apply-release-changes.sh to apply release only changes to create a PR with release only changes similar to this [PR](https://github.com/pytorch/pytorch/pull/149056)

-## Promote
+#### Promoting packages

-These are scripts related to promotion of release candidates to GA channels, these
-can actually be used to promote pytorch, libtorch, and related domain libraries.
+Scripts for Promotion of PyTorch packages are under test-infra repository. Please follow [README.md](https://github.com/pytorch/test-infra/blob/main/release/README.md)

-### Usage
-
-Usage should be fairly straightforward and should actually require no extra variables
-if you are running from the correct git tags. (i.e. the GA tag to promote is currently
-checked out)
-
-`PACKAGE_TYPE` and `PACKAGE_NAME` can be swapped out to promote other packages.
-
-#### Promoting pytorch wheels
-```bash
-promote/s3_to_s3.sh
-```
-
-#### Promoting libtorch archives
-```bash
-PACKAGE_TYPE=libtorch PACKAGE_NAME=libtorch promote/s3_to_s3.sh
-```
-
-#### Promoting conda packages
-```bash
-promote/conda_to_conda.sh
-```
-
-#### Promoting wheels to PyPI
-**WARNING**: These can only be run once and cannot be undone, run with caution
-```
-promote/wheel_to_pypi.sh
-```
-
-## Restoring backups
-
-All release candidates are currently backed up to `s3://pytorch-backup/${TAG_NAME}` and
-can be restored to the test channels with the `restore-backup.sh` script.
-
-Which backup to restore from is dictated by the `RESTORE_FROM` environment variable.
-
-### Usage
-```bash
-RESTORE_FROM=v1.5.0-rc5 ./restore-backup.sh
-```
|
|||||||
@ -1,61 +0,0 @@
|
|||||||
#!/usr/bin/env bash
|
|
||||||
|
|
||||||
exit_if_not_on_git_tag() {
|
|
||||||
# Have an override for debugging purposes
|
|
||||||
if [[ -n "${TEST_WITHOUT_GIT_TAG-}" ]] ;then
|
|
||||||
>&2 echo "+ WARN: Continuing without being on a git tag"
|
|
||||||
exit 0
|
|
||||||
fi
|
|
||||||
# Exit if we're not currently on a git tag
|
|
||||||
if ! git describe --tags --exact >/dev/null 2>/dev/null; then
|
|
||||||
>&2 echo "- ERROR: Attempting to promote on a non-git tag, must have tagged current commit locally first"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
# Exit if we're currently on an RC
|
|
||||||
if git describe --tags | grep "-rc" >/dev/null 2>/dev/null; then
|
|
||||||
>&2 echo "- ERROR: Attempting to promote on a non GA git tag, current tag must be a GA tag"
|
|
||||||
>&2 echo " Example: v1.5.0"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
|
|
||||||
get_pytorch_version() {
|
|
||||||
if [[ -n "${TEST_WITHOUT_GIT_TAG-}" ]];then
|
|
||||||
if [[ -z "${TEST_PYTORCH_PROMOTE_VERSION-}" ]]; then
|
|
||||||
>&2 echo "- ERROR: Specified TEST_WITHOUT_GIT_TAG without specifying TEST_PYTORCH_PROMOTE_VERSION"
|
|
||||||
>&2 echo "- TEST_PYTORCH_PROMOTE_VERSION must be specified"
|
|
||||||
exit 1
|
|
||||||
else
|
|
||||||
echo "${TEST_PYTORCH_PROMOTE_VERSION}"
|
|
||||||
exit 0
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
exit_if_not_on_git_tag
|
|
||||||
# Echo git tag, strip leading v
|
|
||||||
git describe --tags | sed -e 's/^v//'
|
|
||||||
}
|
|
||||||
|
|
||||||
aws_promote() {
|
|
||||||
package_name=$1
|
|
||||||
pytorch_version=$(get_pytorch_version)
|
|
||||||
# Dry run by default
|
|
||||||
DRY_RUN=${DRY_RUN:-enabled}
|
|
||||||
DRY_RUN_FLAG="--dryrun"
|
|
||||||
if [[ $DRY_RUN = "disabled" ]]; then
|
|
||||||
DRY_RUN_FLAG=""
|
|
||||||
fi
|
|
||||||
AWS=${AWS:-aws}
|
|
||||||
(
|
|
||||||
set -x
|
|
||||||
${AWS} s3 cp ${DRY_RUN_FLAG} \
|
|
||||||
--only-show-errors \
|
|
||||||
--acl public-read \
|
|
||||||
--recursive \
|
|
||||||
--exclude '*' \
|
|
||||||
--include "*${package_name}-${pytorch_version}*" \
|
|
||||||
"${PYTORCH_S3_FROM/\/$//}" \
|
|
||||||
"${PYTORCH_S3_TO/\/$//}"
|
|
||||||
)
|
|
||||||
# ^ We grep for package_name-.*pytorch_version to avoid any situations where domain libraries have
|
|
||||||
# the same version on our S3 buckets
|
|
||||||
}
|
|
||||||
@ -1,45 +0,0 @@
|
|||||||
#!/usr/bin/env bash
|
|
||||||
|
|
||||||
# Preps binaries for publishing to pypi by removing the
|
|
||||||
# version suffix we normally add for all binaries
|
|
||||||
# (outside of default ones, CUDA 10.2 currently)
|
|
||||||
|
|
||||||
# Usage is:
|
|
||||||
# $ prep_binary_for_pypy.sh <path_to_whl_file> <path_to_multiple_whl_files>
|
|
||||||
|
|
||||||
# Will output a whl in your current directory
|
|
||||||
|
|
||||||
set -eou pipefail
|
|
||||||
shopt -s globstar
|
|
||||||
|
|
||||||
OUTPUT_DIR=${OUTPUT_DIR:-$(pwd)}
|
|
||||||
|
|
||||||
tmp_dir="$(mktemp -d)"
|
|
||||||
trap 'rm -rf ${tmp_dir}' EXIT
|
|
||||||
|
|
||||||
for whl_file in "$@"; do
|
|
||||||
whl_file=$(realpath "${whl_file}")
|
|
||||||
whl_dir="${tmp_dir}/$(basename "${whl_file}")_unzipped"
|
|
||||||
mkdir -pv "${whl_dir}"
|
|
||||||
(
|
|
||||||
set -x
|
|
||||||
unzip -q "${whl_file}" -d "${whl_dir}"
|
|
||||||
)
|
|
||||||
version_with_suffix=$(grep '^Version:' "${whl_dir}"/*/METADATA | cut -d' ' -f2)
|
|
||||||
version_with_suffix_escaped=${version_with_suffix/+/%2B}
|
|
||||||
# Remove all suffixed +bleh versions
|
|
||||||
version_no_suffix=${version_with_suffix/+*/}
|
|
||||||
new_whl_file=${OUTPUT_DIR}/$(basename "${whl_file/${version_with_suffix_escaped}/${version_no_suffix}}")
|
|
||||||
dist_info_folder=$(find "${whl_dir}" -type d -name '*.dist-info' | head -1)
|
|
||||||
basename_dist_info_folder=$(basename "${dist_info_folder}")
|
|
||||||
dirname_dist_info_folder=$(dirname "${dist_info_folder}")
|
|
||||||
(
|
|
||||||
set -x
|
|
||||||
find "${dist_info_folder}" -type f -exec sed -i "s!${version_with_suffix}!${version_no_suffix}!" {} \;
|
|
||||||
# Moves distinfo from one with a version suffix to one without
|
|
||||||
# Example: torch-1.8.0+cpu.dist-info => torch-1.8.0.dist-info
|
|
||||||
mv "${dist_info_folder}" "${dirname_dist_info_folder}/${basename_dist_info_folder/${version_with_suffix}/${version_no_suffix}}"
|
|
||||||
cd "${whl_dir}"
|
|
||||||
zip -qr "${new_whl_file}" .
|
|
||||||
)
|
|
||||||
done
|
|
||||||
@ -1,19 +0,0 @@
|
|||||||
#!/usr/bin/env bash
|
|
||||||
|
|
||||||
set -eou pipefail
|
|
||||||
|
|
||||||
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
|
||||||
source "${DIR}/common_utils.sh"
|
|
||||||
|
|
||||||
# Allow for users to pass PACKAGE_NAME
|
|
||||||
# For use with other packages, i.e. torchvision, etc.
|
|
||||||
PACKAGE_NAME=${PACKAGE_NAME:-torch}
|
|
||||||
PACKAGE_TYPE=${PACKAGE_TYPE:-whl}
|
|
||||||
|
|
||||||
PYTORCH_S3_BUCKET=${PYTORCH_S3_BUCKET:-s3://pytorch}
|
|
||||||
FROM=${FROM:-test}
|
|
||||||
PYTORCH_S3_FROM=${PYTORCH_S3_FROM:-${PYTORCH_S3_BUCKET}/${PACKAGE_TYPE}/${FROM}}
|
|
||||||
TO=${TO:-}
|
|
||||||
PYTORCH_S3_TO=${PYTORCH_S3_TO:-${PYTORCH_S3_BUCKET}/${PACKAGE_TYPE}/${TO}}
|
|
||||||
|
|
||||||
aws_promote "${PACKAGE_NAME}"
|
|
||||||
@@ -1,69 +0,0 @@
-#!/usr/bin/env bash
-
-set -eou pipefail
-
-DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
-source "${DIR}/common_utils.sh"
-
-# Allow for users to pass PACKAGE_NAME
-# For use with other packages, i.e. torchvision, etc.
-PACKAGE_NAME=${PACKAGE_NAME:-torch}
-
-pytorch_version="$(get_pytorch_version)"
-# Refers to the specific package we'd like to promote
-# i.e. VERSION_SUFFIX='%2Bcu102'
-#      torch-1.8.0+cu102 -> torch-1.8.0
-VERSION_SUFFIX=${VERSION_SUFFIX:-}
-# Refers to the specific platform we'd like to promote
-# i.e. PLATFORM=linux_x86_64
-# For domains like torchaudio / torchtext this is to be left blank
-PLATFORM=${PLATFORM:-}
-
-pkgs_to_promote=$(\
-  curl -fsSL https://download.pytorch.org/whl/torch_stable.html \
-    | grep "${PACKAGE_NAME}-${pytorch_version}${VERSION_SUFFIX}-" \
-    | grep "${PLATFORM}" \
-    | cut -d '"' -f2
-)
-
-tmp_dir="$(mktemp -d)"
-output_tmp_dir="$(mktemp -d)"
-trap 'rm -rf ${tmp_dir} ${output_tmp_dir}' EXIT
-pushd "${output_tmp_dir}"
-
-# Dry run by default
-DRY_RUN=${DRY_RUN:-enabled}
-# On dry run just echo the commands that are meant to be run
-TWINE_UPLOAD="echo twine upload"
-if [[ $DRY_RUN = "disabled" ]]; then
-  TWINE_UPLOAD="twine upload"
-fi
-
-for pkg in ${pkgs_to_promote}; do
-  pkg_basename="$(basename "${pkg}")"
-  # Don't attempt to change if manylinux2014
-  if [[ "${pkg}" != *manylinux2014* ]]; then
-    pkg_basename="$(basename "${pkg//linux/manylinux1}")"
-  fi
-  orig_pkg="${tmp_dir}/${pkg_basename}"
-  (
-    set -x
-    # Download package, sub out linux for manylinux1
-    curl -fsSL -o "${orig_pkg}" "https://download.pytorch.org/whl/${pkg}"
-  )
-
-  if [[ -n "${VERSION_SUFFIX}" ]]; then
-    OUTPUT_DIR="${output_tmp_dir}" ${DIR}/prep_binary_for_pypi.sh "${orig_pkg}"
-  else
-    mv "${orig_pkg}" "${output_tmp_dir}/"
-  fi
-
-  (
-    set -x
-    ${TWINE_UPLOAD} \
-      --disable-progress-bar \
-      --non-interactive \
-      ./*.whl
-    rm -rf ./*.whl
-  )
-done
@@ -1,31 +0,0 @@
-#!/usr/bin/env bash
-
-set -eou pipefail
-
-DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
-source "${DIR}/promote/common_utils.sh"
-
-if [[ -z "${RESTORE_FROM:-}" ]]; then
-  echo "ERROR: RESTORE_FROM environment variable must be specified"
-  echo "  example: RESTORE_FROM=v1.6.0-rc3 ${0}"
-  exit 1
-fi
-
-DRY_RUN=${DRY_RUN:-enabled}
-
-PYTORCH_S3_BACKUP_BUCKET=${PYTORCH_S3_BACKUP_BUCKET:-s3://pytorch-backup/${RESTORE_FROM}}
-PYTORCH_S3_TEST_BUCKET=${PYTORCH_S3_TEST_BUCKET:-s3://pytorch/}
-PYTORCH_S3_FROM=${PYTORCH_S3_FROM:-${PYTORCH_S3_BACKUP_BUCKET}}
-PYTORCH_S3_TO=${PYTORCH_S3_TO:-s3://pytorch/}
-
-restore_wheels() {
-  aws_promote torch whl
-}
-
-restore_libtorch() {
-  aws_promote libtorch-* libtorch
-}
-
-
-restore_wheels
-restore_libtorch
@@ -1322,5 +1322,33 @@ class TestFullyShardOldImport(FSDPTestMultiThread):
         model(inp).sum().backward()
 
 
+class TestFullyShardMixedDtypeParam(FSDPTestMultiThread):
+    @property
+    def world_size(self) -> int:
+        return 2
+
+    @skip_if_lt_x_gpu(2)
+    def test_mixed_dtypes_no_grad_param(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                # no grad params with different dtypes
+                self.w_fp8 = torch.nn.Parameter(
+                    torch.empty((256, 256), dtype=torch.float8_e4m3fn),
+                    requires_grad=False,
+                )
+                self.w_fp32 = torch.nn.Parameter(
+                    torch.empty((256, 256), dtype=torch.float32)
+                )
+
+            def forward(self, input):
+                return
+
+        mesh = init_device_mesh(device_type.type, (self.world_size,))
+        model = Model()
+        fully_shard(model, mesh=mesh)
+        model(0)
+
+
 if __name__ == "__main__":
     run_tests()
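For context, the hunk above adds a test for fully_shard over frozen parameters of mixed dtypes. A minimal standalone sketch of the same pattern, assuming an FSDP2 build that exposes torch.distributed.fsdp.fully_shard and a 2-GPU torchrun launch (not part of the diff):

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard

class FrozenMixed(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Two no-grad parameters with different dtypes, as in the test above.
        self.w_low = torch.nn.Parameter(
            torch.zeros(64, 64, dtype=torch.bfloat16), requires_grad=False
        )
        self.w_fp32 = torch.nn.Parameter(torch.zeros(64, 64), requires_grad=False)

    def forward(self, x):
        return x

if __name__ == "__main__":
    dist.init_process_group("nccl")
    mesh = init_device_mesh("cuda", (dist.get_world_size(),))
    model = FrozenMixed()
    fully_shard(model, mesh=mesh)  # shards both dtype groups across the mesh
    model(torch.zeros(1))  # forward triggers all-gather of the sharded params
    dist.destroy_process_group()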
@@ -1,135 +0,0 @@
-# Owner(s): ["oncall: distributed_checkpointing"]
-
-import os
-import sys
-
-import torch
-import torch.distributed.checkpoint as dist_cp
-from torch import distributed as dist
-from torch.distributed.checkpoint.scripts._consolidate_hf_safetensors import (
-    consolidate_safetensors_files,
-)
-from torch.distributed.device_mesh import init_device_mesh
-from torch.distributed.tensor import DTensor, Shard
-from torch.testing._internal.common_utils import run_tests
-from torch.testing._internal.distributed._tensor.common_dtensor import (
-    DTensorTestBase,
-    skip_if_lt_x_gpu,
-    with_comms,
-)
-from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
-
-
-class TestConsolidateHFSafeTensors(DTensorTestBase):
-    def _create_d_tensors(self) -> None:
-        global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
-        mesh_shape = (self.world_size,)
-        mesh_1d = init_device_mesh(self.device_type, mesh_shape)
-
-        # Create local tensor with row-wise sharding
-        rows_per_rank = global_tensor.shape[0] // self.world_size
-        start_row = self.rank * rows_per_rank
-        end_row = start_row + rows_per_rank
-        local_tensor = global_tensor[start_row:end_row].clone()
-
-        # Create DTensor with row-wise sharding
-        dtensor = DTensor.from_local(
-            local_tensor,
-            device_mesh=mesh_1d,
-            placements=[Shard(0)],
-            shape=global_tensor.shape,
-            stride=(4, 1),
-        )
-
-        # Create local tensor with column-wise sharding
-        cols_per_rank = global_tensor.shape[1] // self.world_size
-        start_col = self.rank * cols_per_rank
-        end_col = start_col + cols_per_rank
-        local_tensor_col = global_tensor[:, start_col:end_col].clone()
-
-        # Create DTensor with column-wise sharding
-        dtensor_col = DTensor.from_local(
-            local_tensor_col,
-            device_mesh=mesh_1d,
-            placements=[Shard(1)],  # Column-wise sharding
-            shape=global_tensor.shape,
-            stride=(4, 1),
-        )
-
-        state_dict_to_save = {"dtensor": dtensor, "dtensor_col": dtensor_col}
-        dist_cp.save(
-            state_dict=state_dict_to_save,
-            storage_writer=dist_cp._HuggingFaceStorageWriter(
-                path=self.temp_dir, save_sharded=True
-            ),
-        )
-        dist.barrier()
-        os.sync()
-
-    @with_comms
-    @with_temp_dir
-    @skip_if_lt_x_gpu(2)
-    def test_consolidate_to_one_file(self) -> None:
-        try:
-            import safetensors
-        except ImportError:
-            print("safetensors not installed")
-            sys.exit(0)
-
-        checkpoint_dir = self.temp_dir
-        output_dir = os.path.join(checkpoint_dir, "consolidated")
-        os.makedirs(output_dir, exist_ok=True)
-
-        self._create_d_tensors()
-
-        global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
-
-        if self.rank == 0:
-            consolidate_safetensors_files(checkpoint_dir, output_dir)
-
-            file_path = os.path.join(output_dir, "model-00001-of-00001.safetensors")
-            loaded_dict = safetensors.torch.load_file(file_path)
-            self.assertEqual(loaded_dict.keys(), {"dtensor", "dtensor_col"})
-            self.assertTrue(torch.equal(loaded_dict["dtensor"], global_tensor))
-            self.assertTrue(torch.equal(loaded_dict["dtensor_col"], global_tensor))
-        dist.barrier()
-
-    @with_comms
-    @with_temp_dir
-    @skip_if_lt_x_gpu(2)
-    def test_consolidate_to_two_files(self):
-        try:
-            import safetensors
-        except ImportError:
-            print("safetensors not installed")
-            sys.exit(0)
-
-        checkpoint_dir = self.temp_dir
-        output_dir = os.path.join(checkpoint_dir, "consolidated")
-        os.makedirs(output_dir, exist_ok=True)
-
-        self._create_d_tensors()
-
-        global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
-
-        if self.rank == 0:
-            fqn_to_index_mapping = {"dtensor": 1, "dtensor_col": 2}
-            consolidate_safetensors_files(
-                checkpoint_dir, output_dir, fqn_to_index_mapping
-            )
-
-            file1_path = os.path.join(output_dir, "model-00001-of-00002.safetensors")
-            file2_path = os.path.join(output_dir, "model-00002-of-00002.safetensors")
-
-            loaded_dict = safetensors.torch.load_file(file1_path)
-            self.assertEqual(loaded_dict.keys(), {"dtensor"})
-            self.assertTrue(torch.equal(loaded_dict["dtensor"], global_tensor))
-
-            loaded_dict_col = safetensors.torch.load_file(file2_path)
-            self.assertEqual(loaded_dict_col.keys(), {"dtensor_col"})
-            self.assertTrue(torch.equal(loaded_dict_col["dtensor_col"], global_tensor))
-        dist.barrier()
-
-
-if __name__ == "__main__":
-    run_tests()
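The removed test above boils down to the following flow; a condensed sketch, assuming the same private APIs the test used (consolidate_safetensors_files and _HuggingFaceStorageWriter are internal and may differ across releases):

import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.scripts._consolidate_hf_safetensors import (
    consolidate_safetensors_files,
)

def save_and_consolidate(state_dict, checkpoint_dir, output_dir, rank):
    # Each rank writes its own shards in sharded-safetensors form.
    dist_cp.save(
        state_dict=state_dict,
        storage_writer=dist_cp._HuggingFaceStorageWriter(
            path=checkpoint_dir, save_sharded=True
        ),
    )
    if rank == 0:
        # Merge shards into full tensors; an optional fqn-to-index mapping
        # splits the output across several files, as the second test shows.
        consolidate_safetensors_files(checkpoint_dir, output_dir)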
@@ -1,420 +0,0 @@
-# Owner(s): ["oncall: distributed_checkpointing"]
-
-import sys
-
-import torch
-import torch.distributed.checkpoint as dist_cp
-from torch.distributed.checkpoint import _HuggingFaceLoadPlanner
-from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
-from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys
-from torch.distributed.device_mesh import init_device_mesh
-from torch.distributed.tensor import distribute_tensor, Replicate, Shard, zeros
-from torch.testing._internal.common_utils import (
-    instantiate_parametrized_tests,
-    run_tests,
-    TestCase,
-)
-from torch.testing._internal.distributed._tensor.common_dtensor import (
-    DTensorTestBase,
-    skip_if_lt_x_gpu,
-    with_comms,
-)
-from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
-
-
-CHECKPOINT_DIR = "checkpoint"
-
-
-class MyTestModule(torch.nn.Module):
-    def __init__(self) -> None:
-        super().__init__()
-        self.linear_1 = torch.nn.Linear(5, 5)
-        self.linear_2 = torch.nn.Linear(5, 1)
-        self.emb = torch.nn.EmbeddingBag(5, 10)
-
-class TestSingleRankSaveLoad(TestCase):
-    @with_temp_dir
-    def test_save(self) -> None:
-        try:
-            from safetensors.torch import load_file
-        except ImportError:
-            print("safetensors not installed")
-            sys.exit(0)
-
-        CHECKPOINT_DIR = self.temp_dir
-
-        state_dict_to_save = MyTestModule().state_dict()
-        dist_cp.save(
-            state_dict=state_dict_to_save,
-            storage_writer=dist_cp._HuggingFaceStorageWriter(
-                path=CHECKPOINT_DIR
-            ),
-        )
-
-        state_dict_loaded = load_file(CHECKPOINT_DIR + "/model-00001-of-00001.safetensors")
-        self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys()))
-        for key in state_dict_to_save.keys():
-            self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_loaded[key]))
-
-    @with_temp_dir
-    def test_load(self) -> None:
-        try:
-            from safetensors.torch import save_file
-        except ImportError:
-            print("safetensors not installed")
-            sys.exit(0)
-
-        CHECKPOINT_DIR = self.temp_dir
-
-        state_dict_to_save = MyTestModule().state_dict()
-        state_dict_to_load = MyTestModule().state_dict()
-        save_file(state_dict_to_save, CHECKPOINT_DIR + "/model-00001-of-00001.safetensors")
-
-        dist_cp.load(
-            state_dict=state_dict_to_load,
-            storage_reader=dist_cp._HuggingFaceStorageReader(
-                path=CHECKPOINT_DIR
-            ),
-        )
-
-        self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys()))
-        for key in state_dict_to_save.keys():
-            self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_to_load[key]))
-
-    @with_temp_dir
-    def test_load_into_empty_dict(self) -> None:
-        try:
-            from safetensors.torch import save_file
-        except ImportError:
-            print("safetensors not installed")
-            sys.exit(0)
-
-        CHECKPOINT_DIR = self.temp_dir
-
-        state_dict_to_save = MyTestModule().state_dict()
-        save_file(state_dict_to_save, CHECKPOINT_DIR + "/model-00001-of-00001.safetensors")
-
-        state_dict_loaded = _load_state_dict_from_keys(
-            storage_reader=dist_cp._HuggingFaceStorageReader(
-                path=CHECKPOINT_DIR
-            ),
-        )
-
-        self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys()))
-        for key in state_dict_to_save.keys():
-            self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_loaded[key]))
-
-    @with_temp_dir
-    def test_load_allowing_resize(self) -> None:
-        try:
-            from safetensors.torch import save_file
-        except ImportError:
-            print("safetensors not installed")
-            sys.exit(0)
-
-        CHECKPOINT_DIR = self.temp_dir
-
-        state_dict_to_save = MyTestModule().state_dict()
-        save_file(state_dict_to_save, CHECKPOINT_DIR + "/model-00001-of-00001.safetensors")
-
-        state_dict_to_load = {}
-        for key in state_dict_to_save.keys():
-            state_dict_to_load[key] = torch.zeros(1)
-
-        dist_cp.load(
-            state_dict=state_dict_to_load,
-            storage_reader=dist_cp._HuggingFaceStorageReader(
-                path=CHECKPOINT_DIR
-            ),
-            planner=_HuggingFaceLoadPlanner(allow_tensor_resize=True),
-        )
-
-        self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys()))
-        for key in state_dict_to_save.keys():
-            self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_to_load[key]))
-
-ONE_D_PLACEMENTS = [
-    [Shard(0)],
-    [Replicate()],
-]
-ONE_D_TO_ONE_D_PLACEMENTS = [
-    ([Replicate()], [Shard(0)]),
-    ([Shard(0)], [Replicate()]),
-]
-
-TWO_D_PLACEMENTS = [
-    [Replicate(), Replicate()],
-    [Replicate(), Shard(0)],
-    [Shard(0), Replicate()],
-    [Shard(0), Shard(0)],
-]
-TWO_D_TO_TWO_D_PLACEMENTS = []
-for p1 in TWO_D_PLACEMENTS:
-    for p2 in TWO_D_PLACEMENTS:
-        if p1 != p2:
-            TWO_D_TO_TWO_D_PLACEMENTS.append((p1, p2))
-
-
-@instantiate_parametrized_tests
-class TestDTensorReshardPlacementChange(DTensorTestBase):
-    """
-    Test DCP reshard for DTensor with placements changes and without world_size change and mesh_tensor change.
-    """
-
-    @with_comms
-    @skip_if_lt_x_gpu(2)
-    @with_temp_dir
-    def test_1d_to_1d_reshard_placement_change(self) -> None:
-        try:
-            import safetensors
-        except ImportError:
-            print("safetensors not installed")
-            sys.exit(0)
-
-        CHECKPOINT_DIR = self.temp_dir
-
-        for one_d_to_one_d_placements in ONE_D_TO_ONE_D_PLACEMENTS:
-            original_placement, new_placement = one_d_to_one_d_placements
-
-            global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
-            mesh_shape = (self.world_size,)
-            device_mesh = init_device_mesh(self.device_type, mesh_shape)
-            dtensor = distribute_tensor(
-                global_tensor, device_mesh, placements=original_placement
-            )
-            state_dict_to_save = {"dtensor": dtensor}
-
-            dist_cp.save(
-                state_dict=state_dict_to_save,
-                storage_writer=dist_cp._HuggingFaceStorageWriter(
-                    path=CHECKPOINT_DIR,
-                    save_sharded=True,
-                ),
-            )
-
-            zero_dtensor = zeros(
-                [4, 4], device_mesh=device_mesh, placements=new_placement
-            )
-            state_dict_to_load = {"dtensor": zero_dtensor}
-
-            dist_cp.load(
-                state_dict=state_dict_to_load,
-                storage_reader=dist_cp._HuggingFaceStorageReader(
-                    CHECKPOINT_DIR,
-                ),
-            )
-
-            # materialize the whole tensor to compare with the original global_tensor
-            state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
-                device_mesh,
-                placements=[Replicate()],
-            )
-            self.assertEqual(global_tensor, state_dict_to_load["dtensor"].to_local())
-
-            # redistribute the tensor back to its original placement for comparison.
-            state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
-                device_mesh,
-                placements=original_placement,
-            )
-            self.assertEqual(
-                state_dict_to_save["dtensor"].to_local(),
-                state_dict_to_load["dtensor"].to_local(),
-            )
-
-    @with_comms
-    @skip_if_lt_x_gpu(4)
-    @with_temp_dir
-    def test_2d_to_2d_reshard_placement_change(self) -> None:
-        try:
-            import safetensors
-        except ImportError:
-            print("safetensors not installed")
-            sys.exit(0)
-
-        CHECKPOINT_DIR = self.temp_dir
-        for two_d_to_two_d_placements in TWO_D_TO_TWO_D_PLACEMENTS:
-            original_placement, new_placement = two_d_to_two_d_placements
-
-            global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
-            mesh_shape = (2, self.world_size // 2)
-            mesh_2d = init_device_mesh(self.device_type, mesh_shape)
-            dtensor = distribute_tensor(
-                global_tensor,
-                mesh_2d,
-                placements=original_placement,
-            )
-            state_dict_to_save = {"dtensor": dtensor}
-
-            dist_cp.save(
-                state_dict=state_dict_to_save,
-                storage_writer=dist_cp._HuggingFaceStorageWriter(path=CHECKPOINT_DIR, save_sharded=True),
-                planner=dist_cp.DefaultSavePlanner(),
-            )
-
-            zero_dtensor = zeros([4, 4], device_mesh=mesh_2d, placements=new_placement)
-            state_dict_to_load = {"dtensor": zero_dtensor}
-
-            dist_cp.load(
-                state_dict=state_dict_to_load,
-                storage_reader=dist_cp._HuggingFaceStorageReader(CHECKPOINT_DIR),
-            )
-
-            state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
-                mesh_2d,
-                placements=[Replicate(), Replicate()],
-            )
-            self.assertEqual(global_tensor, state_dict_to_load["dtensor"].to_local())
-
-            state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
-                mesh_2d,
-                placements=original_placement,
-            )
-            self.assertEqual(
-                state_dict_to_save["dtensor"].to_local(),
-                state_dict_to_load["dtensor"].to_local(),
-            )
-
-
-class TestDTensorReshardMeshChange(DTensorTestBase):
-    """
-    Test DCP reshard for DTensor with placements changes and mesh_tensor change.
-    """
-
-    @with_comms
-    @with_temp_dir
-    @skip_if_lt_x_gpu(2)
-    def test_1d_to_2d_reshard_mesh_change(self) -> None:
-        try:
-            import safetensors
-        except ImportError:
-            print("safetensors not installed")
-            sys.exit(0)
-
-        CHECKPOINT_DIR = self.temp_dir
-        for placements_1d in ONE_D_PLACEMENTS:
-            global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
-            mesh_shape = (self.world_size,)
-            mesh_1d = init_device_mesh(self.device_type, mesh_shape)
-            dtensor = distribute_tensor(
-                global_tensor, mesh_1d, placements=placements_1d
-            )
-            state_dict_to_save = {"dtensor": dtensor}
-
-            dist_cp.save(
-                state_dict=state_dict_to_save,
-                storage_writer=dist_cp._HuggingFaceStorageWriter(path=CHECKPOINT_DIR, save_sharded=True),
-            )
-
-            for placements_2d in TWO_D_PLACEMENTS:
-                mesh_shape = (2, self.world_size // 2)
-                mesh_2d = init_device_mesh(self.device_type, mesh_shape)
-
-                zero_dtensor = zeros(
-                    [4, 4], device_mesh=mesh_2d, placements=placements_2d
-                )
-                state_dict_to_load = {"dtensor": zero_dtensor}
-
-                dist_cp.load(
-                    state_dict=state_dict_to_load,
-                    storage_reader=dist_cp._HuggingFaceStorageReader(CHECKPOINT_DIR),
-                    planner=dist_cp.DefaultLoadPlanner(),
-                )
-
-                # materialize the whole tensor to compare with the original global_tensor
-                state_dict_to_load["dtensor"] = state_dict_to_load[
-                    "dtensor"
-                ].redistribute(
-                    mesh_2d,
-                    placements=[Replicate(), Replicate()],
-                )
-                self.assertEqual(
-                    global_tensor, state_dict_to_load["dtensor"].to_local()
-                )
-
-    @with_comms
-    @with_temp_dir
-    @skip_if_lt_x_gpu(4)
-    def test_2d_to_1d_reshard_mesh_change(self) -> None:
-        try:
-            import safetensors
-        except ImportError:
-            print("safetensors not installed")
-            sys.exit(0)
-
-        CHECKPOINT_DIR = self.temp_dir
-        for placements_2d in TWO_D_PLACEMENTS:
-            global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
-            mesh_shape = (2, self.world_size // 2)
-            mesh_2d = init_device_mesh(self.device_type, mesh_shape)
-            dtensor = distribute_tensor(
-                global_tensor, mesh_2d, placements=placements_2d
-            )
-            state_dict_to_save = {"dtensor": dtensor}
-
-            dist_cp.save(
-                state_dict=state_dict_to_save,
-                storage_writer=dist_cp._HuggingFaceStorageWriter(path=CHECKPOINT_DIR, save_sharded=True),
-                planner=dist_cp.DefaultSavePlanner(),
-            )
-
-            for placements_1d in ONE_D_PLACEMENTS:
-                mesh_shape = (self.world_size,)
-                mesh_1d = init_device_mesh(self.device_type, mesh_shape)
-
-                zero_dtensor = zeros(
-                    [4, 4], device_mesh=mesh_1d, placements=placements_1d
-                )
-                state_dict_to_load = {"dtensor": zero_dtensor}
-
-                dist_cp.load(
-                    state_dict=state_dict_to_load,
-                    storage_reader=dist_cp._HuggingFaceStorageReader(CHECKPOINT_DIR),
-                    planner=dist_cp.DefaultLoadPlanner(),
-                )
-
-                # materialize the whole tensor to compare with the original global_tensor
-                state_dict_to_load["dtensor"] = state_dict_to_load[
-                    "dtensor"
-                ].redistribute(
-                    mesh_1d,
-                    placements=[Replicate()],
-                )
-                self.assertEqual(
-                    global_tensor, state_dict_to_load["dtensor"].to_local()
-                )
-
-    @with_comms
-    @with_temp_dir
-    @skip_if_lt_x_gpu(2)
-    def test_dtensor_checkpoint_resharding_with_empty_shard(self):
-        """
-        Test dtensor checkpoint resharding with dtensor containing empty shards.
-        """
-        try:
-            import safetensors
-        except ImportError:
-            print("safetensors not installed")
-            sys.exit(0)
-
-        tensor = torch.rand(1).cuda()
-        mesh = init_device_mesh(self.device_type, (self.world_size,))
-        dtensor = distribute_tensor(tensor, mesh, [Shard(0)])
-        ref_state_dict = {"dtensor": dtensor}
-
-        dist_cp.save(
-            state_dict=ref_state_dict,
-            storage_writer=dist_cp._HuggingFaceStorageWriter(path=self.temp_dir, save_sharded=True),
-        )
-
-        tensor = torch.rand(1).cuda()
-        mesh_2 = init_device_mesh(self.device_type, (2, self.world_size // 2))
-        dtensor = distribute_tensor(tensor, mesh_2, [Shard(0), Shard(0)])
-        state_dict = {"dtensor": dtensor}
-        dist_cp.load(
-            state_dict=state_dict,
-            storage_reader=dist_cp._HuggingFaceStorageReader(self.temp_dir),
-        )
-
-
-if __name__ == "__main__":
-    run_tests()
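The removed reshard tests above all reduce to one round trip: checkpoint a DTensor under one placement and load it back under another, letting DCP perform the reshard. A minimal sketch grounded in that test code (the _HuggingFace* storage classes are private APIs and may change):

import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor, zeros

def reshard_roundtrip(device_type, world_size, tmpdir):
    mesh = init_device_mesh(device_type, (world_size,))
    src = distribute_tensor(torch.arange(16.0).view(4, 4), mesh, [Shard(0)])
    dist_cp.save(
        state_dict={"dtensor": src},
        storage_writer=dist_cp._HuggingFaceStorageWriter(path=tmpdir, save_sharded=True),
    )
    # Load into a replicated zero DTensor; placements differ from the save.
    loaded = {"dtensor": zeros([4, 4], device_mesh=mesh, placements=[Replicate()])}
    dist_cp.load(
        state_dict=loaded,
        storage_reader=dist_cp._HuggingFaceStorageReader(tmpdir),
    )
    return loaded["dtensor"]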
@@ -8,7 +8,10 @@ import tempfile
 from unittest.mock import MagicMock
 
 import torch
-from torch.distributed.checkpoint import DefaultLoadPlanner
+from torch.distributed.checkpoint._hf_planner import (
+    _FqnToFileMapping,
+    _HuggingFaceLoadPlanner,
+)
 from torch.distributed.checkpoint._hf_storage import (
     _HuggingFaceStorageReader,
     _HuggingFaceStorageWriter,
@@ -18,19 +21,14 @@ from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
 from torch.distributed.checkpoint.filesystem import _StorageInfo, FileSystem
 from torch.distributed.checkpoint.metadata import (
     BytesStorageMetadata,
-    ChunkStorageMetadata,
     Metadata,
     MetadataIndex,
-    TensorProperties,
-    TensorStorageMetadata,
 )
-from torch.distributed.checkpoint.planner import (
-    LoadItemType,
-    LoadPlan,
-    ReadItem,
-    SavePlan,
-)
-from torch.distributed.checkpoint.planner_helpers import _create_write_item_for_tensor
+from torch.distributed.checkpoint.planner import LoadPlan, SavePlan
+from torch.distributed.checkpoint.planner_helpers import (
+    _create_read_items,
+    _create_write_item_for_tensor,
+)
 from torch.distributed.checkpoint.storage import WriteResult
 from torch.testing._internal.common_utils import run_tests, TestCase
 
@@ -38,66 +36,9 @@ from torch.testing._internal.common_utils import run_tests, TestCase
 class TestHfStorage(TestCase):
     def test_write_data_hf(self) -> None:
         mock_module = MagicMock()
-        mock_module.save.return_value = b""
-        sys.modules["safetensors.torch"] = mock_module
+        sys.modules["safetensors"] = mock_module
+        sys.modules["huggingface_hub"] = mock_module
 
-        with tempfile.TemporaryDirectory() as path:
-            writer = _HuggingFaceStorageWriter(
-                path=path,
-                fqn_to_index_mapping={"tensor_0": 1, "tensor_1": 2},
-            )
-            writer.fs = FileSystem()
-
-            tensor0 = torch.rand(4)
-            tensor1 = torch.rand(10)
-            write_item_1 = _create_write_item_for_tensor("tensor_0", tensor0)
-            write_item_2 = _create_write_item_for_tensor("tensor_1", tensor1)
-
-            state_dict = {"tensor_0": tensor0, "tensor_1": tensor1}
-
-            save_plan = SavePlan(
-                [write_item_1, write_item_2],
-                storage_data={"fqn_to_file_mapping": {"tensor_0": 1, "tensor_1": 2}},
-            )
-            save_planner = DefaultSavePlanner()
-            save_planner.set_up_planner(state_dict=state_dict)
-
-            write_results = writer.write_data(save_plan, save_planner)
-
-            write_results.wait()
-            actual_write_results = write_results.value()
-
-            expected_write_results = [
-                WriteResult(
-                    index=MetadataIndex(
-                        fqn="tensor_0", offset=torch.Size([0]), index=None
-                    ),
-                    size_in_bytes=tensor0.numel() * tensor0.element_size(),
-                    storage_data=_StorageInfo(
-                        relative_path="model-00001-of-00002.safetensors",
-                        offset=0,
-                        length=tensor0.numel() * tensor0.element_size(),
-                    ),
-                ),
-                WriteResult(
-                    index=MetadataIndex(
-                        fqn="tensor_1", offset=torch.Size([0]), index=None
-                    ),
-                    size_in_bytes=tensor1.numel() * tensor1.element_size(),
-                    storage_data=_StorageInfo(
-                        relative_path="model-00002-of-00002.safetensors",
-                        offset=0,
-                        length=tensor1.numel() * tensor1.element_size(),
-                    ),
-                ),
-            ]
-
-            self.assertEqual(
-                actual_write_results,
-                expected_write_results,
-            )
-
-    def test_write_data_with_sharding(self) -> None:
         mock_module = MagicMock()
         mock_module.save.return_value = b""
         sys.modules["safetensors.torch"] = mock_module
@@ -105,7 +46,7 @@ class TestHfStorage(TestCase):
         with tempfile.TemporaryDirectory() as path:
             writer = _HuggingFaceStorageWriter(
                 path=path,
-                save_sharded=True,
+                fqn_to_index_mapping={"tensor_0": 1, "tensor_1": 1},
             )
             writer.fs = FileSystem()
 
@@ -118,7 +59,7 @@ class TestHfStorage(TestCase):
 
             save_plan = SavePlan(
                 [write_item_1, write_item_2],
-                storage_data={"shard_index": 1},
+                storage_data=_FqnToFileMapping({"tensor_0": 1, "tensor_1": 1}),
             )
             save_planner = DefaultSavePlanner()
             save_planner.set_up_planner(state_dict=state_dict)
@@ -135,7 +76,7 @@ class TestHfStorage(TestCase):
                     ),
                     size_in_bytes=tensor0.numel() * tensor0.element_size(),
                     storage_data=_StorageInfo(
-                        relative_path="shard-00001-model-00001-of-00001.safetensors",
+                        relative_path="model-00001-of-00001.safetensors",
                         offset=0,
                         length=tensor0.numel() * tensor0.element_size(),
                     ),
@@ -146,7 +87,7 @@ class TestHfStorage(TestCase):
                     ),
                     size_in_bytes=tensor1.numel() * tensor1.element_size(),
                     storage_data=_StorageInfo(
-                        relative_path="shard-00001-model-00001-of-00001.safetensors",
+                        relative_path="model-00001-of-00001.safetensors",
                         offset=0,
                         length=tensor1.numel() * tensor1.element_size(),
                     ),
@@ -159,84 +100,43 @@ class TestHfStorage(TestCase):
             )
 
     def test_read_data_hf(self) -> None:
-        mock_safetensors = MagicMock()
-        sys.modules["safetensors"] = mock_safetensors
+        mock_module = MagicMock()
+        sys.modules["safetensors"] = mock_module
+        sys.modules["huggingface_hub"] = mock_module
 
-        # Create test tensors
-        tensor_0 = torch.tensor([1.0, 2.0, 3.0, 4.0])
-        # Mock the deserialize function to return our test tensors
-        # The format matches what's expected in the read_data method
-        mock_safetensors.deserialize.return_value = [
-            ("tensor_0", {
-                "data": tensor_0.numpy().tobytes(),
-                "dtype": "F32",
-                "shape": [4]
-            }),
-        ]
+        name = "tensor_0"
+        tensor_0 = torch.rand(4)
+        mock_module = MagicMock()
+        mock_module.load.return_value = {name: tensor_0}
+        sys.modules["safetensors.torch"] = mock_module
 
         with tempfile.TemporaryDirectory() as path:
-            # Create the reader
             reader = _HuggingFaceStorageReader(path=path)
             reader.fs = FileSystem()
+            file_name = "model-00001-of-00001"
 
-            # Create test file
-            file_name = "model-00001-of-00001.safetensors"
-            file_path = os.path.join(path, file_name)
-            pathlib.Path(file_path).touch()
+            pathlib.Path(os.path.join(path, file_name)).touch()
 
-            # Set up storage data with _StorageInfo objects
-            storage_data = {
-                "tensor_0": _StorageInfo(file_path, 0, tensor_0.numel() * tensor_0.element_size()),
-            }
+            reader.set_up_storage_reader(
+                Metadata(
+                    state_dict_metadata={name: BytesStorageMetadata()},
+                    storage_data={name: file_name},
+                ),
+                is_coordinator=True,
+            )
 
-            reader.storage_data = storage_data
-
-            # Create target tensors that will be updated by read_data
-            target_tensor_0 = torch.zeros(4)
-            state_dict = {
-                "tensor_0": target_tensor_0,
-            }
-
-            # Create read items for the load plan
-            read_items = []
-            for name, tensor in state_dict.items():
-                storage_index = MetadataIndex(fqn=name, offset=torch.Size([0]), index=None)
-                dest_index = MetadataIndex(fqn=name, offset=torch.Size([0]), index=None)
-                read_items.append(
-                    ReadItem(
-                        type=LoadItemType.TENSOR,
-                        storage_index=storage_index,
-                        dest_index=dest_index,
-                        storage_offsets=[0, 0],
-                        dest_offsets=[0, 0],
-                        lengths=tensor.size(),
-                    )
-                )
+            read_items = _create_read_items(name, BytesStorageMetadata(), file_name)
 
-            # Create load plan and planner
             load_plan = LoadPlan(read_items)
-            load_planner = DefaultLoadPlanner()
-            load_planner.set_up_planner(
-                state_dict=state_dict,
-                metadata=Metadata(
-                    state_dict_metadata={
-                        "tensor_0": TensorStorageMetadata(
-                            properties=TensorProperties(dtype=torch.float32),
-                            size=torch.Size([4]),
-                            chunks=[ChunkStorageMetadata(offsets=[0], sizes=torch.Size([4]))])},
-                    storage_data=storage_data)
-            )
+            load_planner = _HuggingFaceLoadPlanner()
+            load_planner.set_up_planner(state_dict={name: torch.rand(4)})
 
-            # Call read_data
-            future = reader.read_data(load_plan, load_planner)
-            future.wait()
+            read_data = reader.read_data(load_plan, load_planner)
+            read_data.wait()
 
-            # Verify results - the target tensors should now contain the values from our test tensor
-            self.assertTrue(torch.equal(state_dict["tensor_0"], tensor_0))
+            loaded_tensor = load_planner.original_state_dict[name]
+            self.assertEqual(loaded_tensor, tensor_0)
 
-    def test_write_metadata_hf(self) -> None:
+    def test_metadata_hf(self) -> None:
         mock_module = MagicMock()
         sys.modules["huggingface_hub"] = mock_module
         with tempfile.TemporaryDirectory() as path:
@@ -260,6 +160,7 @@ class TestHfStorage(TestCase):
 
             writer = _HuggingFaceStorageWriter(
                 path=path,
+                fqn_to_index_mapping=_FqnToFileMapping({}),
             )
             writer.fs = FileSystem()
             writer.finish(
@@ -284,16 +185,26 @@ class TestHfStorage(TestCase):
             metadata = json.load(f)
             self.assertEqual(metadata, expected_metadata)
 
-    def test_read_metadata_hf(self):
+            reader = _HuggingFaceStorageReader(path=path)
+            reader.fs = FileSystem()
+            metadata = reader.read_metadata()
+            self.assertEqual(metadata.storage_data, expected_metadata["weight_map"])
+
+    def test_read_metadata_when_metadata_file_does_not_exist(self) -> None:
+        mock_module = MagicMock()
+        sys.modules["huggingface_hub"] = mock_module
+
         with tempfile.TemporaryDirectory() as path:
             reader = _HuggingFaceStorageReader(path=path)
-            key = "tensor_0"
+            reader.fs = FileSystem()
+            # there is one safetensor file, but no metadata file,
+            # so we create metadata from the safetensor file
+            keys = ["tensor_0", "tensor_1"]
             file_name = "test.safetensors"
             with open(os.path.join(path, file_name), "wb") as f:
                 # write metadata the same way it would be in safetensors file
                 metadata_contents = json.dumps(
-                    {'tensor_0': {'dtype': "F32", "shape": [5, 10], "data_offsets": [0, 200]}}
+                    {"tensor_0": "value_0", "tensor_1": "value_1"}
                 )
                 metadata_bytes = metadata_contents.encode("utf-8")
 
@@ -305,16 +216,13 @@ class TestHfStorage(TestCase):
             self.assertEqual(
                 metadata.state_dict_metadata,
                 {
-                    key: TensorStorageMetadata(
-                        properties=TensorProperties(dtype=torch.float32),
-                        size=torch.Size([5, 10]),
-                        chunks=[ChunkStorageMetadata(offsets=[0, 0], sizes=torch.Size([5, 10]))],
-                    ),
+                    keys[0]: BytesStorageMetadata(),
+                    keys[1]: BytesStorageMetadata(),
                 },
             )
             self.assertEqual(
                 metadata.storage_data,
-                {key: _StorageInfo(os.path.join(path, file_name), 0, 200, transform_descriptors=None)},
+                {keys[0]: file_name, keys[1]: file_name},
             )
 
 
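The storage tests above stub optional dependencies in sys.modules before exercising code that imports them; the bare pattern, using only the standard library, looks like this (a sketch, not part of the diff):

import sys
from unittest.mock import MagicMock

mock_module = MagicMock()
mock_module.save.return_value = b""        # canned return for save(...) calls
sys.modules["safetensors"] = mock_module    # later imports resolve to the mock
sys.modules["safetensors.torch"] = mock_module
sys.modules["huggingface_hub"] = mock_module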
test/dynamo/cpython/3_13/list_tests.diff (new file, 67 lines)
@@ -0,0 +1,67 @@
+diff --git a/test/dynamo/cpython/3_13/list_tests.py b/test/dynamo/cpython/3_13/list_tests.py
+index dbc5ef4f9f2..2b9f3b9311f 100644
+--- a/test/dynamo/cpython/3_13/list_tests.py
++++ b/test/dynamo/cpython/3_13/list_tests.py
+@@ -1,3 +1,53 @@
++# ======= BEGIN Dynamo patch =======
++# Owner(s): ["module: dynamo"]
++
++# ruff: noqa
++# flake8: noqa
++
++import sys
++import torch
++import torch._dynamo.test_case
++import unittest
++from torch._dynamo.test_case import CPythonTestCase
++from torch.testing._internal.common_utils import run_tests
++
++__TestCase = CPythonTestCase
++
++# redirect import statements
++import sys
++import importlib.abc
++
++redirect_imports = (
++    "test.mapping_tests",
++    "test.typinganndata",
++    "test.test_grammar",
++    "test.test_math",
++    "test.test_iter",
++    "test.typinganndata.ann_module",
++)
++
++class RedirectImportFinder(importlib.abc.MetaPathFinder):
++    def find_spec(self, fullname, path, target=None):
++        # Check if the import is the problematic one
++        if fullname in redirect_imports:
++            try:
++                # Attempt to import the standalone module
++                name = fullname.removeprefix("test.")
++                r = importlib.import_module(name)
++                # Redirect the module in sys.modules
++                sys.modules[fullname] = r
++                # Return a module spec from the found module
++                return importlib.util.find_spec(name)
++            except ImportError:
++                return None
++        return None
++
++# Add the custom finder to sys.meta_path
++sys.meta_path.insert(0, RedirectImportFinder())
++
++
++# ======= END DYNAMO PATCH =======
++
+ """
+ Tests common to list and UserList.UserList
+ """
+@@ -5,7 +55,7 @@ Tests common to list and UserList.UserList
+ import sys
+ from functools import cmp_to_key
+ 
+-from test import seq_tests
++import seq_tests
+ from test.support import ALWAYS_EQ, NEVER_EQ, get_c_recursion_limit
+ 
+ 
test/dynamo/cpython/3_13/list_tests.py (new file, 627 lines)
@@ -0,0 +1,627 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import run_tests
+
+__TestCase = CPythonTestCase
+
+# redirect import statements
+import sys
+import importlib.abc
+
+redirect_imports = (
+    "test.mapping_tests",
+    "test.typinganndata",
+    "test.test_grammar",
+    "test.test_math",
+    "test.test_iter",
+    "test.typinganndata.ann_module",
+)
+
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+    def find_spec(self, fullname, path, target=None):
+        # Check if the import is the problematic one
+        if fullname in redirect_imports:
+            try:
+                # Attempt to import the standalone module
+                name = fullname.removeprefix("test.")
+                r = importlib.import_module(name)
+                # Redirect the module in sys.modules
+                sys.modules[fullname] = r
+                # Return a module spec from the found module
+                return importlib.util.find_spec(name)
+            except ImportError:
+                return None
+        return None
+
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
+
+"""
+Tests common to list and UserList.UserList
+"""
+
+import sys
+from functools import cmp_to_key
+
+import seq_tests
+from test.support import ALWAYS_EQ, NEVER_EQ, get_c_recursion_limit
+
+
+class CommonTest(seq_tests.CommonTest):
+
+    def test_init(self):
+        # Iterable arg is optional
+        self.assertEqual(self.type2test([]), self.type2test())
+
+        # Init clears previous values
+        a = self.type2test([1, 2, 3])
+        a.__init__()
+        self.assertEqual(a, self.type2test([]))
+
+        # Init overwrites previous values
+        a = self.type2test([1, 2, 3])
+        a.__init__([4, 5, 6])
+        self.assertEqual(a, self.type2test([4, 5, 6]))
+
+        # Mutables always return a new object
+        b = self.type2test(a)
+        self.assertNotEqual(id(a), id(b))
+        self.assertEqual(a, b)
+
+    def test_getitem_error(self):
+        a = []
+        msg = "list indices must be integers or slices"
+        with self.assertRaisesRegex(TypeError, msg):
+            a['a']
+
+    def test_setitem_error(self):
+        a = []
+        msg = "list indices must be integers or slices"
+        with self.assertRaisesRegex(TypeError, msg):
+            a['a'] = "python"
+
+    def test_repr(self):
+        l0 = []
+        l2 = [0, 1, 2]
+        a0 = self.type2test(l0)
+        a2 = self.type2test(l2)
+
+        self.assertEqual(str(a0), str(l0))
+        self.assertEqual(repr(a0), repr(l0))
+        self.assertEqual(repr(a2), repr(l2))
+        self.assertEqual(str(a2), "[0, 1, 2]")
+        self.assertEqual(repr(a2), "[0, 1, 2]")
+
+        a2.append(a2)
+        a2.append(3)
+        self.assertEqual(str(a2), "[0, 1, 2, [...], 3]")
+        self.assertEqual(repr(a2), "[0, 1, 2, [...], 3]")
+
+    def test_repr_deep(self):
+        a = self.type2test([])
+        for i in range(get_c_recursion_limit() + 1):
+            a = self.type2test([a])
+        self.assertRaises(RecursionError, repr, a)
+
+    def test_set_subscript(self):
+        a = self.type2test(range(20))
+        self.assertRaises(ValueError, a.__setitem__, slice(0, 10, 0), [1,2,3])
+        self.assertRaises(TypeError, a.__setitem__, slice(0, 10), 1)
+        self.assertRaises(ValueError, a.__setitem__, slice(0, 10, 2), [1,2])
+        self.assertRaises(TypeError, a.__getitem__, 'x', 1)
+        a[slice(2,10,3)] = [1,2,3]
+        self.assertEqual(a, self.type2test([0, 1, 1, 3, 4, 2, 6, 7, 3,
+                                            9, 10, 11, 12, 13, 14, 15,
+                                            16, 17, 18, 19]))
+
+    def test_reversed(self):
+        a = self.type2test(range(20))
+        r = reversed(a)
+        self.assertEqual(list(r), self.type2test(range(19, -1, -1)))
+        self.assertRaises(StopIteration, next, r)
+        self.assertEqual(list(reversed(self.type2test())),
+                         self.type2test())
+        # Bug 3689: make sure list-reversed-iterator doesn't have __len__
+        self.assertRaises(TypeError, len, reversed([1,2,3]))
+
+    def test_setitem(self):
+        a = self.type2test([0, 1])
+        a[0] = 0
+        a[1] = 100
+        self.assertEqual(a, self.type2test([0, 100]))
+        a[-1] = 200
+        self.assertEqual(a, self.type2test([0, 200]))
+        a[-2] = 100
+        self.assertEqual(a, self.type2test([100, 200]))
+        self.assertRaises(IndexError, a.__setitem__, -3, 200)
+        self.assertRaises(IndexError, a.__setitem__, 2, 200)
+
+        a = self.type2test([])
+        self.assertRaises(IndexError, a.__setitem__, 0, 200)
+        self.assertRaises(IndexError, a.__setitem__, -1, 200)
+        self.assertRaises(TypeError, a.__setitem__)
+
+        a = self.type2test([0,1,2,3,4])
+        a[0] = 1
+        a[1] = 2
+        a[2] = 3
+        self.assertEqual(a, self.type2test([1,2,3,3,4]))
+        a[0] = 5
+        a[1] = 6
+        a[2] = 7
+        self.assertEqual(a, self.type2test([5,6,7,3,4]))
+        a[-2] = 88
+        a[-1] = 99
+        self.assertEqual(a, self.type2test([5,6,7,88,99]))
+        a[-2] = 8
+        a[-1] = 9
+        self.assertEqual(a, self.type2test([5,6,7,8,9]))
+
+        msg = "list indices must be integers or slices"
+        with self.assertRaisesRegex(TypeError, msg):
+            a['a'] = "python"
+
+    def test_delitem(self):
+        a = self.type2test([0, 1])
+        del a[1]
+        self.assertEqual(a, [0])
+        del a[0]
+        self.assertEqual(a, [])
+
+        a = self.type2test([0, 1])
+        del a[-2]
+        self.assertEqual(a, [1])
+        del a[-1]
+        self.assertEqual(a, [])
+
+        a = self.type2test([0, 1])
+        self.assertRaises(IndexError, a.__delitem__, -3)
+        self.assertRaises(IndexError, a.__delitem__, 2)
+
+        a = self.type2test([])
+        self.assertRaises(IndexError, a.__delitem__, 0)
+
+        self.assertRaises(TypeError, a.__delitem__)
+
+    def test_setslice(self):
+        l = [0, 1]
+        a = self.type2test(l)
+
+        for i in range(-3, 4):
+            a[:i] = l[:i]
+            self.assertEqual(a, l)
+            a2 = a[:]
+            a2[:i] = a[:i]
+            self.assertEqual(a2, a)
+            a[i:] = l[i:]
+            self.assertEqual(a, l)
+            a2 = a[:]
+            a2[i:] = a[i:]
+            self.assertEqual(a2, a)
+            for j in range(-3, 4):
+                a[i:j] = l[i:j]
+                self.assertEqual(a, l)
+                a2 = a[:]
+                a2[i:j] = a[i:j]
+                self.assertEqual(a2, a)
+
+        aa2 = a2[:]
+        aa2[:0] = [-2, -1]
+        self.assertEqual(aa2, [-2, -1, 0, 1])
+        aa2[0:] = []
+        self.assertEqual(aa2, [])
+
+        a = self.type2test([1, 2, 3, 4, 5])
+        a[:-1] = a
+        self.assertEqual(a, self.type2test([1, 2, 3, 4, 5, 5]))
+        a = self.type2test([1, 2, 3, 4, 5])
+        a[1:] = a
+        self.assertEqual(a, self.type2test([1, 1, 2, 3, 4, 5]))
+        a = self.type2test([1, 2, 3, 4, 5])
+        a[1:-1] = a
+        self.assertEqual(a, self.type2test([1, 1, 2, 3, 4, 5, 5]))
+
+        a = self.type2test([])
+        a[:] = tuple(range(10))
+        self.assertEqual(a, self.type2test(range(10)))
+
+        self.assertRaises(TypeError, a.__setitem__, slice(0, 1, 5))
+
+        self.assertRaises(TypeError, a.__setitem__)
+
+    def test_slice_assign_iterator(self):
+        x = self.type2test(range(5))
+        x[0:3] = reversed(range(3))
+        self.assertEqual(x, self.type2test([2, 1, 0, 3, 4]))
+
+        x[:] = reversed(range(3))
+        self.assertEqual(x, self.type2test([2, 1, 0]))
+
+    def test_delslice(self):
+        a = self.type2test([0, 1])
+        del a[1:2]
+        del a[0:1]
+        self.assertEqual(a, self.type2test([]))
+
+        a = self.type2test([0, 1])
+        del a[1:2]
+        del a[0:1]
+        self.assertEqual(a, self.type2test([]))
+
+        a = self.type2test([0, 1])
+        del a[-2:-1]
+        self.assertEqual(a, self.type2test([1]))
+
+        a = self.type2test([0, 1])
+        del a[-2:-1]
+        self.assertEqual(a, self.type2test([1]))
+
+        a = self.type2test([0, 1])
+        del a[1:]
+        del a[:1]
+        self.assertEqual(a, self.type2test([]))
+
+        a = self.type2test([0, 1])
+        del a[1:]
+        del a[:1]
+        self.assertEqual(a, self.type2test([]))
+
+        a = self.type2test([0, 1])
+        del a[-1:]
+        self.assertEqual(a, self.type2test([0]))
+
+        a = self.type2test([0, 1])
+        del a[-1:]
+        self.assertEqual(a, self.type2test([0]))
+
+        a = self.type2test([0, 1])
+        del a[:]
+        self.assertEqual(a, self.type2test([]))
+
+    def test_append(self):
+        a = self.type2test([])
+        a.append(0)
+        a.append(1)
+        a.append(2)
+        self.assertEqual(a, self.type2test([0, 1, 2]))
+
+        self.assertRaises(TypeError, a.append)
+
+    def test_extend(self):
+        a1 = self.type2test([0])
+        a2 = self.type2test((0, 1))
+        a = a1[:]
+        a.extend(a2)
+        self.assertEqual(a, a1 + a2)
+
+        a.extend(self.type2test([]))
+        self.assertEqual(a, a1 + a2)
+
+        a.extend(a)
+        self.assertEqual(a, self.type2test([0, 0, 1, 0, 0, 1]))
+
+        a = self.type2test("spam")
+        a.extend("eggs")
+        self.assertEqual(a, list("spameggs"))
+
+        self.assertRaises(TypeError, a.extend, None)
+        self.assertRaises(TypeError, a.extend)
+
+        # overflow test. issue1621
+        class CustomIter:
+            def __iter__(self):
+                return self
+            def __next__(self):
+                raise StopIteration
+            def __length_hint__(self):
+                return sys.maxsize
+        a = self.type2test([1,2,3,4])
+        a.extend(CustomIter())
+        self.assertEqual(a, [1,2,3,4])
+
+
+    def test_insert(self):
+        a = self.type2test([0, 1, 2])
+        a.insert(0, -2)
+        a.insert(1, -1)
+        a.insert(2, 0)
+        self.assertEqual(a, [-2, -1, 0, 0, 1, 2])
+
+        b = a[:]
+        b.insert(-2, "foo")
+        b.insert(-200, "left")
+        b.insert(200, "right")
+        self.assertEqual(b, self.type2test(["left",-2,-1,0,0,"foo",1,2,"right"]))
+
+        self.assertRaises(TypeError, a.insert)
+
+    def test_pop(self):
+        a = self.type2test([-1, 0, 1])
+        a.pop()
+        self.assertEqual(a, [-1, 0])
+        a.pop(0)
+        self.assertEqual(a, [0])
+        self.assertRaises(IndexError, a.pop, 5)
+        a.pop(0)
+        self.assertEqual(a, [])
+        self.assertRaises(IndexError, a.pop)
+        self.assertRaises(TypeError, a.pop, 42, 42)
+        a = self.type2test([0, 10, 20, 30, 40])
+
+    def test_remove(self):
+        a = self.type2test([0, 0, 1])
+        a.remove(1)
+        self.assertEqual(a, [0, 0])
+        a.remove(0)
+        self.assertEqual(a, [0])
+        a.remove(0)
+        self.assertEqual(a, [])
+
+        self.assertRaises(ValueError, a.remove, 0)
+
+        self.assertRaises(TypeError, a.remove)
+
+        a = self.type2test([1, 2])
+        self.assertRaises(ValueError, a.remove, NEVER_EQ)
+        self.assertEqual(a, [1, 2])
+        a.remove(ALWAYS_EQ)
+        self.assertEqual(a, [2])
+        a = self.type2test([ALWAYS_EQ])
+        a.remove(1)
+        self.assertEqual(a, [])
+        a = self.type2test([ALWAYS_EQ])
+        a.remove(NEVER_EQ)
+        self.assertEqual(a, [])
+        a = self.type2test([NEVER_EQ])
+        self.assertRaises(ValueError, a.remove, ALWAYS_EQ)
+
+        class BadExc(Exception):
+            pass
+
+        class BadCmp:
+            def __eq__(self, other):
+                if other == 2:
+                    raise BadExc()
+                return False
+
+        a = self.type2test([0, 1, 2, 3])
+        self.assertRaises(BadExc, a.remove, BadCmp())
+
+        class BadCmp2:
+            def __eq__(self, other):
+                raise BadExc()
+
+        d = self.type2test('abcdefghcij')
+        d.remove('c')
+        self.assertEqual(d, self.type2test('abdefghcij'))
+        d.remove('c')
+        self.assertEqual(d, self.type2test('abdefghij'))
+        self.assertRaises(ValueError, d.remove, 'c')
+        self.assertEqual(d, self.type2test('abdefghij'))
+
+        # Handle comparison errors
+        d = self.type2test(['a', 'b', BadCmp2(), 'c'])
+        e = self.type2test(d)
+        self.assertRaises(BadExc, d.remove, 'c')
+        for x, y in zip(d, e):
+            # verify that original order and values are retained.
+            self.assertIs(x, y)
+
+    def test_index(self):
+        super().test_index()
+        a = self.type2test([-2, -1, 0, 0, 1, 2])
+        a.remove(0)
+        self.assertRaises(ValueError, a.index, 2, 0, 4)
+        self.assertEqual(a, self.type2test([-2, -1, 0, 1, 2]))
+
+        # Test modifying the list during index's iteration
+        class EvilCmp:
+            def __init__(self, victim):
+                self.victim = victim
+            def __eq__(self, other):
+                del self.victim[:]
+                return False
+        a = self.type2test()
+        a[:] = [EvilCmp(a) for _ in range(100)]
+        # This used to seg fault before patch #1005778
+        self.assertRaises(ValueError, a.index, None)
+
+    def test_reverse(self):
+        u = self.type2test([-2, -1, 0, 1, 2])
+        u2 = u[:]
+        u.reverse()
+        self.assertEqual(u, [2, 1, 0, -1, -2])
+        u.reverse()
+        self.assertEqual(u, u2)
+
+        self.assertRaises(TypeError, u.reverse, 42)
+
+    def test_clear(self):
+        u = self.type2test([2, 3, 4])
+        u.clear()
+        self.assertEqual(u, [])
+
+        u = self.type2test([])
+        u.clear()
+        self.assertEqual(u, [])
+
+        u = self.type2test([])
+        u.append(1)
+        u.clear()
+        u.append(2)
+        self.assertEqual(u, [2])
+
+        self.assertRaises(TypeError, u.clear, None)
+
+    def test_copy(self):
+        u = self.type2test([1, 2, 3])
+        v = u.copy()
+        self.assertEqual(v, [1, 2, 3])
+
+        u = self.type2test([])
+        v = u.copy()
+        self.assertEqual(v, [])
+
+        # test that it's indeed a copy and not a reference
+        u = self.type2test(['a', 'b'])
+        v = u.copy()
+        v.append('i')
+        self.assertEqual(u, ['a', 'b'])
+        self.assertEqual(v, u + ['i'])
+
+        # test that it's a shallow, not a deep copy
+        u = self.type2test([1, 2, [3, 4], 5])
+        v = u.copy()
+        self.assertEqual(u, v)
+        self.assertIs(v[3], u[3])
+
+        self.assertRaises(TypeError, u.copy, None)
+
+    def test_sort(self):
+        u = self.type2test([1, 0])
+        u.sort()
+        self.assertEqual(u, [0, 1])
+
+        u = self.type2test([2,1,0,-1,-2])
+        u.sort()
+        self.assertEqual(u, self.type2test([-2,-1,0,1,2]))
+
+        self.assertRaises(TypeError, u.sort, 42, 42)
+
+        def revcmp(a, b):
+            if a == b:
+                return 0
+            elif a < b:
+                return 1
+            else: # a > b
+                return -1
+        u.sort(key=cmp_to_key(revcmp))
+        self.assertEqual(u, self.type2test([2,1,0,-1,-2]))
+
+        # The following dumps core in unpatched Python 1.5:
+        def myComparison(x,y):
+            xmod, ymod = x%3, y%7
+            if xmod == ymod:
+                return 0
+            elif xmod < ymod:
+                return -1
+            else: # xmod > ymod
+                return 1
+        z = self.type2test(range(12))
+        z.sort(key=cmp_to_key(myComparison))
+
+        self.assertRaises(TypeError, z.sort, 2)
+
+        def selfmodifyingComparison(x,y):
+            z.append(1)
+            if x == y:
+                return 0
+            elif x < y:
+                return -1
+            else: # x > y
+                return 1
+        self.assertRaises(ValueError, z.sort,
+                          key=cmp_to_key(selfmodifyingComparison))
+
+        self.assertRaises(TypeError, z.sort, 42, 42, 42, 42)
+
+    def test_slice(self):
+        u = self.type2test("spam")
+        u[:2] = "h"
+        self.assertEqual(u, list("ham"))
+
+    def test_iadd(self):
+        super().test_iadd()
+        u = self.type2test([0, 1])
+        u2 = u
+        u += [2, 3]
+        self.assertIs(u, u2)
+
+        u = self.type2test("spam")
+        u += "eggs"
+        self.assertEqual(u, self.type2test("spameggs"))
+
+        self.assertRaises(TypeError, u.__iadd__, None)
+
+    def test_imul(self):
+        super().test_imul()
+        s = self.type2test([])
+        oldid = id(s)
+        s *= 10
+        self.assertEqual(id(s), oldid)
+
+    def test_extendedslicing(self):
+        # subscript
+        a = self.type2test([0,1,2,3,4])
+
+        # deletion
+        del a[::2]
+        self.assertEqual(a, self.type2test([1,3]))
+        a = self.type2test(range(5))
+        del a[1::2]
+        self.assertEqual(a, self.type2test([0,2,4]))
+        a = self.type2test(range(5))
+        del a[1::-2]
+        self.assertEqual(a, self.type2test([0,2,3,4]))
+        a = self.type2test(range(10))
|
a = self.type2test(range(10))
|
||||||
|
del a[::1000]
|
||||||
|
self.assertEqual(a, self.type2test([1, 2, 3, 4, 5, 6, 7, 8, 9]))
|
||||||
|
# assignment
|
||||||
|
a = self.type2test(range(10))
|
||||||
|
a[::2] = [-1]*5
|
||||||
|
self.assertEqual(a, self.type2test([-1, 1, -1, 3, -1, 5, -1, 7, -1, 9]))
|
||||||
|
a = self.type2test(range(10))
|
||||||
|
a[::-4] = [10]*3
|
||||||
|
self.assertEqual(a, self.type2test([0, 10, 2, 3, 4, 10, 6, 7, 8 ,10]))
|
||||||
|
a = self.type2test(range(4))
|
||||||
|
a[::-1] = a
|
||||||
|
self.assertEqual(a, self.type2test([3, 2, 1, 0]))
|
||||||
|
a = self.type2test(range(10))
|
||||||
|
b = a[:]
|
||||||
|
c = a[:]
|
||||||
|
a[2:3] = self.type2test(["two", "elements"])
|
||||||
|
b[slice(2,3)] = self.type2test(["two", "elements"])
|
||||||
|
c[2:3:] = self.type2test(["two", "elements"])
|
||||||
|
self.assertEqual(a, b)
|
||||||
|
self.assertEqual(a, c)
|
||||||
|
a = self.type2test(range(10))
|
||||||
|
a[::2] = tuple(range(5))
|
||||||
|
self.assertEqual(a, self.type2test([0, 1, 1, 3, 2, 5, 3, 7, 4, 9]))
|
||||||
|
# test issue7788
|
||||||
|
a = self.type2test(range(10))
|
||||||
|
del a[9::1<<333]
|
||||||
|
|
||||||
|
def test_constructor_exception_handling(self):
|
||||||
|
# Bug #1242657
|
||||||
|
class F(object):
|
||||||
|
def __iter__(self):
|
||||||
|
raise KeyboardInterrupt
|
||||||
|
self.assertRaises(KeyboardInterrupt, list, F())
|
||||||
|
|
||||||
|
def test_exhausted_iterator(self):
|
||||||
|
a = self.type2test([1, 2, 3])
|
||||||
|
exhit = iter(a)
|
||||||
|
empit = iter(a)
|
||||||
|
for x in exhit: # exhaust the iterator
|
||||||
|
next(empit) # not exhausted
|
||||||
|
a.append(9)
|
||||||
|
self.assertEqual(list(exhit), [])
|
||||||
|
self.assertEqual(list(empit), [9])
|
||||||
|
self.assertEqual(a, self.type2test([1, 2, 3, 9]))
|
||||||
|
|
||||||
|
# gh-115733: Crash when iterating over exhausted iterator
|
||||||
|
exhit = iter(self.type2test([1, 2, 3]))
|
||||||
|
for _ in exhit:
|
||||||
|
next(exhit, 1)
|
||||||
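test_sort above drives sort() through functools.cmp_to_key, which adapts an old-style three-way comparison function into a key object. A minimal standalone sketch of that pattern (sorted() used here for brevity; nothing beyond the stdlib is assumed):

from functools import cmp_to_key

def revcmp(a, b):
    # negative/zero/positive result, like the pre-Python-3 cmp() protocol
    if a == b:
        return 0
    elif a < b:
        return 1
    else:
        return -1

print(sorted([0, 2, 1], key=cmp_to_key(revcmp)))  # -> [2, 1, 0]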
test/dynamo/cpython/3_13/mapping_tests.diff (new file, 67 lines)
@@ -0,0 +1,67 @@
diff --git a/test/dynamo/cpython/3_13/mapping_tests.py b/test/dynamo/cpython/3_13/mapping_tests.py
index ed89a81a6ea..eed59a68e94 100644
--- a/test/dynamo/cpython/3_13/mapping_tests.py
+++ b/test/dynamo/cpython/3_13/mapping_tests.py
@@ -1,10 +1,61 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import run_tests
+
+__TestCase = CPythonTestCase
+
+
+# redirect import statements
+import sys
+import importlib.abc
+
+redirect_imports = (
+    "test.mapping_tests",
+    "test.typinganndata",
+    "test.test_grammar",
+    "test.test_math",
+    "test.test_iter",
+    "test.typinganndata.ann_module",
+)
+
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+    def find_spec(self, fullname, path, target=None):
+        # Check if the import is the problematic one
+        if fullname in redirect_imports:
+            try:
+                # Attempt to import the standalone module
+                name = fullname.removeprefix("test.")
+                r = importlib.import_module(name)
+                # Redirect the module in sys.modules
+                sys.modules[fullname] = r
+                # Return a module spec from the found module
+                return importlib.util.find_spec(name)
+            except ImportError:
+                return None
+        return None
+
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
+
 # tests common to dict and UserDict
 import unittest
 import collections
 from test.support import get_c_recursion_limit
 
 
-class BasicTestMappingProtocol(unittest.TestCase):
+class BasicTestMappingProtocol(__TestCase):
     # This base class can be used to check that an object conforms to the
     # mapping protocol
 
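The RedirectImportFinder added by this patch is a standard sys.meta_path hook: when a redirected name is imported, it loads the standalone module and aliases it in sys.modules under the requested name. A minimal sketch of the same mechanism, with a hypothetical alias ("fastjson" mapped onto the stdlib json module; the real finder maps "test.<name>" onto the standalone test modules):

import importlib
import importlib.abc
import importlib.util
import sys

class AliasFinder(importlib.abc.MetaPathFinder):
    def find_spec(self, fullname, path, target=None):
        if fullname == "fastjson":                    # hypothetical alias
            real = importlib.import_module("json")    # import the real target
            sys.modules[fullname] = real              # alias it under the requested name
            return importlib.util.find_spec("json")   # hand back the target's spec
        return None

sys.meta_path.insert(0, AliasFinder())
import fastjson  # now served by the stdlib json module via the finder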
test/dynamo/cpython/3_13/mapping_tests.py (new file, 719 lines)
@@ -0,0 +1,719 @@
# ======= BEGIN Dynamo patch =======
# Owner(s): ["module: dynamo"]

# ruff: noqa
# flake8: noqa

import sys
import torch
import torch._dynamo.test_case
import unittest
from torch._dynamo.test_case import CPythonTestCase
from torch.testing._internal.common_utils import run_tests

__TestCase = CPythonTestCase


# redirect import statements
import sys
import importlib.abc

redirect_imports = (
    "test.mapping_tests",
    "test.typinganndata",
    "test.test_grammar",
    "test.test_math",
    "test.test_iter",
    "test.typinganndata.ann_module",
)

class RedirectImportFinder(importlib.abc.MetaPathFinder):
    def find_spec(self, fullname, path, target=None):
        # Check if the import is the problematic one
        if fullname in redirect_imports:
            try:
                # Attempt to import the standalone module
                name = fullname.removeprefix("test.")
                r = importlib.import_module(name)
                # Redirect the module in sys.modules
                sys.modules[fullname] = r
                # Return a module spec from the found module
                return importlib.util.find_spec(name)
            except ImportError:
                return None
        return None

# Add the custom finder to sys.meta_path
sys.meta_path.insert(0, RedirectImportFinder())


# ======= END DYNAMO PATCH =======

# tests common to dict and UserDict
import unittest
import collections
from test.support import get_c_recursion_limit


class BasicTestMappingProtocol(__TestCase):
    # This base class can be used to check that an object conforms to the
    # mapping protocol

    # Functions that can be useful to override to adapt to dictionary
    # semantics
    type2test = None # which class is being tested (overwrite in subclasses)

    def _reference(self):
        """Return a dictionary of values which are invariant by storage
        in the object under test."""
        return {"1": "2", "key1":"value1", "key2":(1,2,3)}
    def _empty_mapping(self):
        """Return an empty mapping object"""
        return self.type2test()
    def _full_mapping(self, data):
        """Return a mapping object with the value contained in data
        dictionary"""
        x = self._empty_mapping()
        for key, value in data.items():
            x[key] = value
        return x

    def __init__(self, *args, **kw):
        unittest.TestCase.__init__(self, *args, **kw)
        self.reference = self._reference().copy()

        # A (key, value) pair not in the mapping
        key, value = self.reference.popitem()
        self.other = {key:value}

        # A (key, value) pair in the mapping
        key, value = self.reference.popitem()
        self.inmapping = {key:value}
        self.reference[key] = value

    def test_read(self):
        # Test for read only operations on mapping
        p = self._empty_mapping()
        p1 = dict(p) #workaround for singleton objects
        d = self._full_mapping(self.reference)
        if d is p:
            p = p1
        #Indexing
        for key, value in self.reference.items():
            self.assertEqual(d[key], value)
        knownkey = list(self.other.keys())[0]
        self.assertRaises(KeyError, lambda:d[knownkey])
        #len
        self.assertEqual(len(p), 0)
        self.assertEqual(len(d), len(self.reference))
        #__contains__
        for k in self.reference:
            self.assertIn(k, d)
        for k in self.other:
            self.assertNotIn(k, d)
        #cmp
        self.assertEqual(p, p)
        self.assertEqual(d, d)
        self.assertNotEqual(p, d)
        self.assertNotEqual(d, p)
        #bool
        if p: self.fail("Empty mapping must compare to False")
        if not d: self.fail("Full mapping must compare to True")
        # keys(), items(), iterkeys() ...
        def check_iterandlist(iter, lst, ref):
            self.assertTrue(hasattr(iter, '__next__'))
            self.assertTrue(hasattr(iter, '__iter__'))
            x = list(iter)
            self.assertTrue(set(x)==set(lst)==set(ref))
        check_iterandlist(iter(d.keys()), list(d.keys()),
                          self.reference.keys())
        check_iterandlist(iter(d), list(d.keys()), self.reference.keys())
        check_iterandlist(iter(d.values()), list(d.values()),
                          self.reference.values())
        check_iterandlist(iter(d.items()), list(d.items()),
                          self.reference.items())
        #get
        key, value = next(iter(d.items()))
        knownkey, knownvalue = next(iter(self.other.items()))
        self.assertEqual(d.get(key, knownvalue), value)
        self.assertEqual(d.get(knownkey, knownvalue), knownvalue)
        self.assertNotIn(knownkey, d)

    def test_write(self):
        # Test for write operations on mapping
        p = self._empty_mapping()
        #Indexing
        for key, value in self.reference.items():
            p[key] = value
            self.assertEqual(p[key], value)
        for key in self.reference.keys():
            del p[key]
            self.assertRaises(KeyError, lambda:p[key])
        p = self._empty_mapping()
        #update
        p.update(self.reference)
        self.assertEqual(dict(p), self.reference)
        items = list(p.items())
        p = self._empty_mapping()
        p.update(items)
        self.assertEqual(dict(p), self.reference)
        d = self._full_mapping(self.reference)
        #setdefault
        key, value = next(iter(d.items()))
        knownkey, knownvalue = next(iter(self.other.items()))
        self.assertEqual(d.setdefault(key, knownvalue), value)
        self.assertEqual(d[key], value)
        self.assertEqual(d.setdefault(knownkey, knownvalue), knownvalue)
        self.assertEqual(d[knownkey], knownvalue)
        #pop
        self.assertEqual(d.pop(knownkey), knownvalue)
        self.assertNotIn(knownkey, d)
        self.assertRaises(KeyError, d.pop, knownkey)
        default = 909
        d[knownkey] = knownvalue
        self.assertEqual(d.pop(knownkey, default), knownvalue)
        self.assertNotIn(knownkey, d)
        self.assertEqual(d.pop(knownkey, default), default)
        #popitem
        key, value = d.popitem()
        self.assertNotIn(key, d)
        self.assertEqual(value, self.reference[key])
        p=self._empty_mapping()
        self.assertRaises(KeyError, p.popitem)

    def test_constructor(self):
        self.assertEqual(self._empty_mapping(), self._empty_mapping())

    def test_bool(self):
        self.assertTrue(not self._empty_mapping())
        self.assertTrue(self.reference)
        self.assertTrue(bool(self._empty_mapping()) is False)
        self.assertTrue(bool(self.reference) is True)

    def test_keys(self):
        d = self._empty_mapping()
        self.assertEqual(list(d.keys()), [])
        d = self.reference
        self.assertIn(list(self.inmapping.keys())[0], d.keys())
        self.assertNotIn(list(self.other.keys())[0], d.keys())
        self.assertRaises(TypeError, d.keys, None)

    def test_values(self):
        d = self._empty_mapping()
        self.assertEqual(list(d.values()), [])

        self.assertRaises(TypeError, d.values, None)

    def test_items(self):
        d = self._empty_mapping()
        self.assertEqual(list(d.items()), [])

        self.assertRaises(TypeError, d.items, None)

    def test_len(self):
        d = self._empty_mapping()
        self.assertEqual(len(d), 0)

    def test_getitem(self):
        d = self.reference
        self.assertEqual(d[list(self.inmapping.keys())[0]],
                         list(self.inmapping.values())[0])

        self.assertRaises(TypeError, d.__getitem__)

    def test_update(self):
        # mapping argument
        d = self._empty_mapping()
        d.update(self.other)
        self.assertEqual(list(d.items()), list(self.other.items()))

        # No argument
        d = self._empty_mapping()
        d.update()
        self.assertEqual(d, self._empty_mapping())

        # item sequence
        d = self._empty_mapping()
        d.update(self.other.items())
        self.assertEqual(list(d.items()), list(self.other.items()))

        # Iterator
        d = self._empty_mapping()
        d.update(self.other.items())
        self.assertEqual(list(d.items()), list(self.other.items()))

        # FIXME: Doesn't work with UserDict
        # self.assertRaises((TypeError, AttributeError), d.update, None)
        self.assertRaises((TypeError, AttributeError), d.update, 42)

        outerself = self
        class SimpleUserDict:
            def __init__(self):
                self.d = outerself.reference
            def keys(self):
                return self.d.keys()
            def __getitem__(self, i):
                return self.d[i]
        d.clear()
        d.update(SimpleUserDict())
        i1 = sorted(d.items())
        i2 = sorted(self.reference.items())
        self.assertEqual(i1, i2)

        class Exc(Exception): pass

        d = self._empty_mapping()
        class FailingUserDict:
            def keys(self):
                raise Exc
        self.assertRaises(Exc, d.update, FailingUserDict())

        d.clear()

        class FailingUserDict:
            def keys(self):
                class BogonIter:
                    def __init__(self):
                        self.i = 1
                    def __iter__(self):
                        return self
                    def __next__(self):
                        if self.i:
                            self.i = 0
                            return 'a'
                        raise Exc
                return BogonIter()
            def __getitem__(self, key):
                return key
        self.assertRaises(Exc, d.update, FailingUserDict())

        class FailingUserDict:
            def keys(self):
                class BogonIter:
                    def __init__(self):
                        self.i = ord('a')
                    def __iter__(self):
                        return self
                    def __next__(self):
                        if self.i <= ord('z'):
                            rtn = chr(self.i)
                            self.i += 1
                            return rtn
                        raise StopIteration
                return BogonIter()
            def __getitem__(self, key):
                raise Exc
        self.assertRaises(Exc, d.update, FailingUserDict())

        d = self._empty_mapping()
        class badseq(object):
            def __iter__(self):
                return self
            def __next__(self):
                raise Exc()

        self.assertRaises(Exc, d.update, badseq())

        self.assertRaises(ValueError, d.update, [(1, 2, 3)])

    # no test_fromkeys or test_copy as both os.environ and selves don't support it

    def test_get(self):
        d = self._empty_mapping()
        self.assertTrue(d.get(list(self.other.keys())[0]) is None)
        self.assertEqual(d.get(list(self.other.keys())[0], 3), 3)
        d = self.reference
        self.assertTrue(d.get(list(self.other.keys())[0]) is None)
        self.assertEqual(d.get(list(self.other.keys())[0], 3), 3)
        self.assertEqual(d.get(list(self.inmapping.keys())[0]),
                         list(self.inmapping.values())[0])
        self.assertEqual(d.get(list(self.inmapping.keys())[0], 3),
                         list(self.inmapping.values())[0])
        self.assertRaises(TypeError, d.get)
        self.assertRaises(TypeError, d.get, None, None, None)

    def test_setdefault(self):
        d = self._empty_mapping()
        self.assertRaises(TypeError, d.setdefault)

    def test_popitem(self):
        d = self._empty_mapping()
        self.assertRaises(KeyError, d.popitem)
        self.assertRaises(TypeError, d.popitem, 42)

    def test_pop(self):
        d = self._empty_mapping()
        k, v = list(self.inmapping.items())[0]
        d[k] = v
        self.assertRaises(KeyError, d.pop, list(self.other.keys())[0])

        self.assertEqual(d.pop(k), v)
        self.assertEqual(len(d), 0)

        self.assertRaises(KeyError, d.pop, k)


class TestMappingProtocol(BasicTestMappingProtocol):
    def test_constructor(self):
        BasicTestMappingProtocol.test_constructor(self)
        self.assertTrue(self._empty_mapping() is not self._empty_mapping())
        self.assertEqual(self.type2test(x=1, y=2), {"x": 1, "y": 2})

    def test_bool(self):
        BasicTestMappingProtocol.test_bool(self)
        self.assertTrue(not self._empty_mapping())
        self.assertTrue(self._full_mapping({"x": "y"}))
        self.assertTrue(bool(self._empty_mapping()) is False)
        self.assertTrue(bool(self._full_mapping({"x": "y"})) is True)

    def test_keys(self):
        BasicTestMappingProtocol.test_keys(self)
        d = self._empty_mapping()
        self.assertEqual(list(d.keys()), [])
        d = self._full_mapping({'a': 1, 'b': 2})
        k = d.keys()
        self.assertIn('a', k)
        self.assertIn('b', k)
        self.assertNotIn('c', k)

    def test_values(self):
        BasicTestMappingProtocol.test_values(self)
        d = self._full_mapping({1:2})
        self.assertEqual(list(d.values()), [2])

    def test_items(self):
        BasicTestMappingProtocol.test_items(self)

        d = self._full_mapping({1:2})
        self.assertEqual(list(d.items()), [(1, 2)])

    def test_contains(self):
        d = self._empty_mapping()
        self.assertNotIn('a', d)
        self.assertTrue(not ('a' in d))
        self.assertTrue('a' not in d)
        d = self._full_mapping({'a': 1, 'b': 2})
        self.assertIn('a', d)
        self.assertIn('b', d)
        self.assertNotIn('c', d)

        self.assertRaises(TypeError, d.__contains__)

    def test_len(self):
        BasicTestMappingProtocol.test_len(self)
        d = self._full_mapping({'a': 1, 'b': 2})
        self.assertEqual(len(d), 2)

    def test_getitem(self):
        BasicTestMappingProtocol.test_getitem(self)
        d = self._full_mapping({'a': 1, 'b': 2})
        self.assertEqual(d['a'], 1)
        self.assertEqual(d['b'], 2)
        d['c'] = 3
        d['a'] = 4
        self.assertEqual(d['c'], 3)
        self.assertEqual(d['a'], 4)
        del d['b']
        self.assertEqual(d, {'a': 4, 'c': 3})

        self.assertRaises(TypeError, d.__getitem__)

    def test_clear(self):
        d = self._full_mapping({1:1, 2:2, 3:3})
        d.clear()
        self.assertEqual(d, {})

        self.assertRaises(TypeError, d.clear, None)

    def test_update(self):
        BasicTestMappingProtocol.test_update(self)
        # mapping argument
        d = self._empty_mapping()
        d.update({1:100})
        d.update({2:20})
        d.update({1:1, 2:2, 3:3})
        self.assertEqual(d, {1:1, 2:2, 3:3})

        # no argument
        d.update()
        self.assertEqual(d, {1:1, 2:2, 3:3})

        # keyword arguments
        d = self._empty_mapping()
        d.update(x=100)
        d.update(y=20)
        d.update(x=1, y=2, z=3)
        self.assertEqual(d, {"x":1, "y":2, "z":3})

        # item sequence
        d = self._empty_mapping()
        d.update([("x", 100), ("y", 20)])
        self.assertEqual(d, {"x":100, "y":20})

        # Both item sequence and keyword arguments
        d = self._empty_mapping()
        d.update([("x", 100), ("y", 20)], x=1, y=2)
        self.assertEqual(d, {"x":1, "y":2})

        # iterator
        d = self._full_mapping({1:3, 2:4})
        d.update(self._full_mapping({1:2, 3:4, 5:6}).items())
        self.assertEqual(d, {1:2, 2:4, 3:4, 5:6})

        class SimpleUserDict:
            def __init__(self):
                self.d = {1:1, 2:2, 3:3}
            def keys(self):
                return self.d.keys()
            def __getitem__(self, i):
                return self.d[i]
        d.clear()
        d.update(SimpleUserDict())
        self.assertEqual(d, {1:1, 2:2, 3:3})

    def test_fromkeys(self):
        self.assertEqual(self.type2test.fromkeys('abc'), {'a':None, 'b':None, 'c':None})
        d = self._empty_mapping()
        self.assertTrue(not(d.fromkeys('abc') is d))
        self.assertEqual(d.fromkeys('abc'), {'a':None, 'b':None, 'c':None})
        self.assertEqual(d.fromkeys((4,5),0), {4:0, 5:0})
        self.assertEqual(d.fromkeys([]), {})
        def g():
            yield 1
        self.assertEqual(d.fromkeys(g()), {1:None})
        self.assertRaises(TypeError, {}.fromkeys, 3)
        class dictlike(self.type2test): pass
        self.assertEqual(dictlike.fromkeys('a'), {'a':None})
        self.assertEqual(dictlike().fromkeys('a'), {'a':None})
        self.assertTrue(dictlike.fromkeys('a').__class__ is dictlike)
        self.assertTrue(dictlike().fromkeys('a').__class__ is dictlike)
        self.assertTrue(type(dictlike.fromkeys('a')) is dictlike)
        class mydict(self.type2test):
            def __new__(cls):
                return collections.UserDict()
        ud = mydict.fromkeys('ab')
        self.assertEqual(ud, {'a':None, 'b':None})
        self.assertIsInstance(ud, collections.UserDict)
        self.assertRaises(TypeError, dict.fromkeys)

        class Exc(Exception): pass

        class baddict1(self.type2test):
            def __init__(self, *args, **kwargs):
                raise Exc()

        self.assertRaises(Exc, baddict1.fromkeys, [1])

        class BadSeq(object):
            def __iter__(self):
                return self
            def __next__(self):
                raise Exc()

        self.assertRaises(Exc, self.type2test.fromkeys, BadSeq())

        class baddict2(self.type2test):
            def __setitem__(self, key, value):
                raise Exc()

        self.assertRaises(Exc, baddict2.fromkeys, [1])

    def test_copy(self):
        d = self._full_mapping({1:1, 2:2, 3:3})
        self.assertEqual(d.copy(), {1:1, 2:2, 3:3})
        d = self._empty_mapping()
        self.assertEqual(d.copy(), d)
        self.assertIsInstance(d.copy(), d.__class__)
        self.assertRaises(TypeError, d.copy, None)

    def test_get(self):
        BasicTestMappingProtocol.test_get(self)
        d = self._empty_mapping()
        self.assertTrue(d.get('c') is None)
        self.assertEqual(d.get('c', 3), 3)
        d = self._full_mapping({'a' : 1, 'b' : 2})
        self.assertTrue(d.get('c') is None)
        self.assertEqual(d.get('c', 3), 3)
        self.assertEqual(d.get('a'), 1)
        self.assertEqual(d.get('a', 3), 1)

    def test_setdefault(self):
        BasicTestMappingProtocol.test_setdefault(self)
        d = self._empty_mapping()
        self.assertTrue(d.setdefault('key0') is None)
        d.setdefault('key0', [])
        self.assertTrue(d.setdefault('key0') is None)
        d.setdefault('key', []).append(3)
        self.assertEqual(d['key'][0], 3)
        d.setdefault('key', []).append(4)
        self.assertEqual(len(d['key']), 2)

    def test_popitem(self):
        BasicTestMappingProtocol.test_popitem(self)
        for copymode in -1, +1:
            # -1: b has same structure as a
            # +1: b is a.copy()
            for log2size in range(12):
                size = 2**log2size
                a = self._empty_mapping()
                b = self._empty_mapping()
                for i in range(size):
                    a[repr(i)] = i
                    if copymode < 0:
                        b[repr(i)] = i
                if copymode > 0:
                    b = a.copy()
                for i in range(size):
                    ka, va = ta = a.popitem()
                    self.assertEqual(va, int(ka))
                    kb, vb = tb = b.popitem()
                    self.assertEqual(vb, int(kb))
                    self.assertTrue(not(copymode < 0 and ta != tb))
                self.assertTrue(not a)
                self.assertTrue(not b)

    def test_pop(self):
        BasicTestMappingProtocol.test_pop(self)

        # Tests for pop with specified key
        d = self._empty_mapping()
        k, v = 'abc', 'def'

        self.assertEqual(d.pop(k, v), v)
        d[k] = v
        self.assertEqual(d.pop(k, 1), v)


class TestHashMappingProtocol(TestMappingProtocol):

    def test_getitem(self):
        TestMappingProtocol.test_getitem(self)
        class Exc(Exception): pass

        class BadEq(object):
            def __eq__(self, other):
                raise Exc()
            def __hash__(self):
                return 24

        d = self._empty_mapping()
        d[BadEq()] = 42
        self.assertRaises(KeyError, d.__getitem__, 23)

        class BadHash(object):
            fail = False
            def __hash__(self):
                if self.fail:
                    raise Exc()
                else:
                    return 42

        d = self._empty_mapping()
        x = BadHash()
        d[x] = 42
        x.fail = True
        self.assertRaises(Exc, d.__getitem__, x)

    def test_fromkeys(self):
        TestMappingProtocol.test_fromkeys(self)
        class mydict(self.type2test):
            def __new__(cls):
                return collections.UserDict()
        ud = mydict.fromkeys('ab')
        self.assertEqual(ud, {'a':None, 'b':None})
        self.assertIsInstance(ud, collections.UserDict)

    def test_pop(self):
        TestMappingProtocol.test_pop(self)

        class Exc(Exception): pass

        class BadHash(object):
            fail = False
            def __hash__(self):
                if self.fail:
                    raise Exc()
                else:
                    return 42

        d = self._empty_mapping()
        x = BadHash()
        d[x] = 42
        x.fail = True
        self.assertRaises(Exc, d.pop, x)

    def test_mutatingiteration(self):
        d = self._empty_mapping()
        d[1] = 1
        try:
            count = 0
            for i in d:
                d[i+1] = 1
                if count >= 1:
                    self.fail("changing dict size during iteration doesn't raise Error")
                count += 1
        except RuntimeError:
            pass

    def test_repr(self):
        d = self._empty_mapping()
        self.assertEqual(repr(d), '{}')
        d[1] = 2
        self.assertEqual(repr(d), '{1: 2}')
        d = self._empty_mapping()
        d[1] = d
        self.assertEqual(repr(d), '{1: {...}}')

        class Exc(Exception): pass

        class BadRepr(object):
            def __repr__(self):
                raise Exc()

        d = self._full_mapping({1: BadRepr()})
        self.assertRaises(Exc, repr, d)

    def test_repr_deep(self):
        d = self._empty_mapping()
        for i in range(get_c_recursion_limit() + 1):
            d0 = d
            d = self._empty_mapping()
            d[1] = d0
        self.assertRaises(RecursionError, repr, d)

    def test_eq(self):
        self.assertEqual(self._empty_mapping(), self._empty_mapping())
        self.assertEqual(self._full_mapping({1: 2}),
                         self._full_mapping({1: 2}))

        class Exc(Exception): pass

        class BadCmp(object):
            def __eq__(self, other):
                raise Exc()
            def __hash__(self):
                return 1

        d1 = self._full_mapping({BadCmp(): 1})
        d2 = self._full_mapping({1: 1})
        self.assertRaises(Exc, lambda: BadCmp()==1)
        self.assertRaises(Exc, lambda: d1==d2)

    def test_setdefault(self):
        TestMappingProtocol.test_setdefault(self)

        class Exc(Exception): pass

        class BadHash(object):
            fail = False
            def __hash__(self):
                if self.fail:
                    raise Exc()
                else:
                    return 42

        d = self._empty_mapping()
        x = BadHash()
        d[x] = 42
        x.fail = True
        self.assertRaises(Exc, d.setdefault, x, [])
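These protocol classes are harnesses, not standalone tests: a concrete suite pins type2test and inherits every protocol test (test_dict.py in this same tree subclasses mapping_tests.BasicTestMappingProtocol this way). A minimal sketch of such a binding, assuming mapping_tests is importable from the current directory:

import mapping_tests  # i.e. test/dynamo/cpython/3_13/mapping_tests.py

class DictMappingTests(mapping_tests.BasicTestMappingProtocol):
    type2test = dict  # every inherited protocol test now runs against dict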
test/dynamo/cpython/3_13/seq_tests.diff (new file, 68 lines)
@@ -0,0 +1,68 @@
diff --git a/test/dynamo/cpython/3_13/seq_tests.py b/test/dynamo/cpython/3_13/seq_tests.py
index 719c9434a16..4325892276d 100644
--- a/test/dynamo/cpython/3_13/seq_tests.py
+++ b/test/dynamo/cpython/3_13/seq_tests.py
@@ -1,3 +1,54 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import run_tests
+
+__TestCase = CPythonTestCase
+
+
+# redirect import statements
+import sys
+import importlib.abc
+
+redirect_imports = (
+    "test.mapping_tests",
+    "test.typinganndata",
+    "test.test_grammar",
+    "test.test_math",
+    "test.test_iter",
+    "test.typinganndata.ann_module",
+)
+
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+    def find_spec(self, fullname, path, target=None):
+        # Check if the import is the problematic one
+        if fullname in redirect_imports:
+            try:
+                # Attempt to import the standalone module
+                name = fullname.removeprefix("test.")
+                r = importlib.import_module(name)
+                # Redirect the module in sys.modules
+                sys.modules[fullname] = r
+                # Return a module spec from the found module
+                return importlib.util.find_spec(name)
+            except ImportError:
+                return None
+        return None
+
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
+
 """
 Tests common to tuple, list and UserList.UserList
 """
@@ -95,7 +146,7 @@ class LyingList(list):
     def __iter__(self):
         yield 1
 
-class CommonTest(unittest.TestCase):
+class CommonTest(__TestCase):
     # The type to be tested
     type2test = None
 
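The whole port hinges on the one-line base-class swap in the hunk above: the CPython test bodies stay byte-for-byte identical and only what they inherit from changes. A minimal sketch of the same alias pattern outside this tree (the instrumented-harness module and environment switch are hypothetical):

import os
import unittest

if os.environ.get("USE_CUSTOM_HARNESS"):
    # hypothetical instrumented TestCase; any project-specific base works here
    from myproject.testing import InstrumentedTestCase as _Base
else:
    _Base = unittest.TestCase

__TestCase = _Base

class CommonTest(__TestCase):  # test bodies below need no edits at all
    def test_smoke(self):
        self.assertEqual([0, 1] * 2, [0, 1, 0, 1])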
test/dynamo/cpython/3_13/seq_tests.py (new file, 483 lines)
@@ -0,0 +1,483 @@
# ======= BEGIN Dynamo patch =======
# Owner(s): ["module: dynamo"]

# ruff: noqa
# flake8: noqa

import sys
import torch
import torch._dynamo.test_case
import unittest
from torch._dynamo.test_case import CPythonTestCase
from torch.testing._internal.common_utils import run_tests

__TestCase = CPythonTestCase


# redirect import statements
import sys
import importlib.abc

redirect_imports = (
    "test.mapping_tests",
    "test.typinganndata",
    "test.test_grammar",
    "test.test_math",
    "test.test_iter",
    "test.typinganndata.ann_module",
)

class RedirectImportFinder(importlib.abc.MetaPathFinder):
    def find_spec(self, fullname, path, target=None):
        # Check if the import is the problematic one
        if fullname in redirect_imports:
            try:
                # Attempt to import the standalone module
                name = fullname.removeprefix("test.")
                r = importlib.import_module(name)
                # Redirect the module in sys.modules
                sys.modules[fullname] = r
                # Return a module spec from the found module
                return importlib.util.find_spec(name)
            except ImportError:
                return None
        return None

# Add the custom finder to sys.meta_path
sys.meta_path.insert(0, RedirectImportFinder())


# ======= END DYNAMO PATCH =======

"""
Tests common to tuple, list and UserList.UserList
"""

import unittest
import sys
import pickle
from test import support
from test.support import ALWAYS_EQ, NEVER_EQ

# Various iterables
# This is used for checking the constructor (here and in test_deque.py)
def iterfunc(seqn):
    'Regular generator'
    for i in seqn:
        yield i

class Sequence:
    'Sequence using __getitem__'
    def __init__(self, seqn):
        self.seqn = seqn
    def __getitem__(self, i):
        return self.seqn[i]

class IterFunc:
    'Sequence using iterator protocol'
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0
    def __iter__(self):
        return self
    def __next__(self):
        if self.i >= len(self.seqn): raise StopIteration
        v = self.seqn[self.i]
        self.i += 1
        return v

class IterGen:
    'Sequence using iterator protocol defined with a generator'
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0
    def __iter__(self):
        for val in self.seqn:
            yield val

class IterNextOnly:
    'Missing __getitem__ and __iter__'
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0
    def __next__(self):
        if self.i >= len(self.seqn): raise StopIteration
        v = self.seqn[self.i]
        self.i += 1
        return v

class IterNoNext:
    'Iterator missing __next__()'
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0
    def __iter__(self):
        return self

class IterGenExc:
    'Test propagation of exceptions'
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0
    def __iter__(self):
        return self
    def __next__(self):
        3 // 0

class IterFuncStop:
    'Test immediate stop'
    def __init__(self, seqn):
        pass
    def __iter__(self):
        return self
    def __next__(self):
        raise StopIteration

from itertools import chain
def itermulti(seqn):
    'Test multiple tiers of iterators'
    return chain(map(lambda x:x, iterfunc(IterGen(Sequence(seqn)))))

class LyingTuple(tuple):
    def __iter__(self):
        yield 1

class LyingList(list):
    def __iter__(self):
        yield 1

class CommonTest(__TestCase):
    # The type to be tested
    type2test = None

    def test_constructors(self):
        l0 = []
        l1 = [0]
        l2 = [0, 1]

        u = self.type2test()
        u0 = self.type2test(l0)
        u1 = self.type2test(l1)
        u2 = self.type2test(l2)

        uu = self.type2test(u)
        uu0 = self.type2test(u0)
        uu1 = self.type2test(u1)
        uu2 = self.type2test(u2)

        v = self.type2test(tuple(u))
        class OtherSeq:
            def __init__(self, initseq):
                self.__data = initseq
            def __len__(self):
                return len(self.__data)
            def __getitem__(self, i):
                return self.__data[i]
        s = OtherSeq(u0)
        v0 = self.type2test(s)
        self.assertEqual(len(v0), len(s))

        s = "this is also a sequence"
        vv = self.type2test(s)
        self.assertEqual(len(vv), len(s))

        # Create from various iteratables
        for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)):
            for g in (Sequence, IterFunc, IterGen,
                      itermulti, iterfunc):
                self.assertEqual(self.type2test(g(s)), self.type2test(s))
            self.assertEqual(self.type2test(IterFuncStop(s)), self.type2test())
            self.assertEqual(self.type2test(c for c in "123"), self.type2test("123"))
            self.assertRaises(TypeError, self.type2test, IterNextOnly(s))
            self.assertRaises(TypeError, self.type2test, IterNoNext(s))
            self.assertRaises(ZeroDivisionError, self.type2test, IterGenExc(s))

        # Issue #23757
        self.assertEqual(self.type2test(LyingTuple((2,))), self.type2test((1,)))
        self.assertEqual(self.type2test(LyingList([2])), self.type2test([1]))

        with self.assertRaises(TypeError):
            self.type2test(unsupported_arg=[])

    def test_truth(self):
        self.assertFalse(self.type2test())
        self.assertTrue(self.type2test([42]))

    def test_getitem(self):
        u = self.type2test([0, 1, 2, 3, 4])
        for i in range(len(u)):
            self.assertEqual(u[i], i)
            self.assertEqual(u[int(i)], i)
        for i in range(-len(u), -1):
            self.assertEqual(u[i], len(u)+i)
            self.assertEqual(u[int(i)], len(u)+i)
        self.assertRaises(IndexError, u.__getitem__, -len(u)-1)
        self.assertRaises(IndexError, u.__getitem__, len(u))
        self.assertRaises(ValueError, u.__getitem__, slice(0,10,0))

        u = self.type2test()
        self.assertRaises(IndexError, u.__getitem__, 0)
        self.assertRaises(IndexError, u.__getitem__, -1)

        self.assertRaises(TypeError, u.__getitem__)

        a = self.type2test([10, 11])
        self.assertEqual(a[0], 10)
        self.assertEqual(a[1], 11)
        self.assertEqual(a[-2], 10)
        self.assertEqual(a[-1], 11)
        self.assertRaises(IndexError, a.__getitem__, -3)
        self.assertRaises(IndexError, a.__getitem__, 3)

    def test_getslice(self):
        l = [0, 1, 2, 3, 4]
        u = self.type2test(l)

        self.assertEqual(u[0:0], self.type2test())
        self.assertEqual(u[1:2], self.type2test([1]))
        self.assertEqual(u[-2:-1], self.type2test([3]))
        self.assertEqual(u[-1000:1000], u)
        self.assertEqual(u[1000:-1000], self.type2test([]))
        self.assertEqual(u[:], u)
        self.assertEqual(u[1:None], self.type2test([1, 2, 3, 4]))
        self.assertEqual(u[None:3], self.type2test([0, 1, 2]))

        # Extended slices
        self.assertEqual(u[::], u)
        self.assertEqual(u[::2], self.type2test([0, 2, 4]))
        self.assertEqual(u[1::2], self.type2test([1, 3]))
        self.assertEqual(u[::-1], self.type2test([4, 3, 2, 1, 0]))
        self.assertEqual(u[::-2], self.type2test([4, 2, 0]))
        self.assertEqual(u[3::-2], self.type2test([3, 1]))
        self.assertEqual(u[3:3:-2], self.type2test([]))
        self.assertEqual(u[3:2:-2], self.type2test([3]))
        self.assertEqual(u[3:1:-2], self.type2test([3]))
        self.assertEqual(u[3:0:-2], self.type2test([3, 1]))
        self.assertEqual(u[::-100], self.type2test([4]))
        self.assertEqual(u[100:-100:], self.type2test([]))
        self.assertEqual(u[-100:100:], u)
        self.assertEqual(u[100:-100:-1], u[::-1])
        self.assertEqual(u[-100:100:-1], self.type2test([]))
        self.assertEqual(u[-100:100:2], self.type2test([0, 2, 4]))

        # Test extreme cases with long ints
        a = self.type2test([0,1,2,3,4])
        self.assertEqual(a[ -pow(2,128): 3 ], self.type2test([0,1,2]))
        self.assertEqual(a[ 3: pow(2,145) ], self.type2test([3,4]))
        self.assertEqual(a[3::sys.maxsize], self.type2test([3]))

    def test_contains(self):
        u = self.type2test([0, 1, 2])
        for i in u:
            self.assertIn(i, u)
        for i in min(u)-1, max(u)+1:
            self.assertNotIn(i, u)

        self.assertRaises(TypeError, u.__contains__)

    def test_contains_fake(self):
        # Sequences must use rich comparison against each item
        # (unless "is" is true, or an earlier item answered)
        # So ALWAYS_EQ must be found in all non-empty sequences.
        self.assertNotIn(ALWAYS_EQ, self.type2test([]))
        self.assertIn(ALWAYS_EQ, self.type2test([1]))
        self.assertIn(1, self.type2test([ALWAYS_EQ]))
        self.assertNotIn(NEVER_EQ, self.type2test([]))
        self.assertNotIn(ALWAYS_EQ, self.type2test([NEVER_EQ]))
        self.assertIn(NEVER_EQ, self.type2test([ALWAYS_EQ]))

    def test_contains_order(self):
        # Sequences must test in-order.  If a rich comparison has side
        # effects, these will be visible to tests against later members.
        # In this test, the "side effect" is a short-circuiting raise.
        class DoNotTestEq(Exception):
            pass
        class StopCompares:
            def __eq__(self, other):
                raise DoNotTestEq

        checkfirst = self.type2test([1, StopCompares()])
        self.assertIn(1, checkfirst)
        checklast = self.type2test([StopCompares(), 1])
        self.assertRaises(DoNotTestEq, checklast.__contains__, 1)

    def test_len(self):
        self.assertEqual(len(self.type2test()), 0)
        self.assertEqual(len(self.type2test([])), 0)
        self.assertEqual(len(self.type2test([0])), 1)
        self.assertEqual(len(self.type2test([0, 1, 2])), 3)

    def test_minmax(self):
        u = self.type2test([0, 1, 2])
        self.assertEqual(min(u), 0)
        self.assertEqual(max(u), 2)

    def test_addmul(self):
        u1 = self.type2test([0])
        u2 = self.type2test([0, 1])
        self.assertEqual(u1, u1 + self.type2test())
        self.assertEqual(u1, self.type2test() + u1)
        self.assertEqual(u1 + self.type2test([1]), u2)
        self.assertEqual(self.type2test([-1]) + u1, self.type2test([-1, 0]))
        self.assertEqual(self.type2test(), u2*0)
        self.assertEqual(self.type2test(), 0*u2)
        self.assertEqual(self.type2test(), u2*0)
        self.assertEqual(self.type2test(), 0*u2)
        self.assertEqual(u2, u2*1)
        self.assertEqual(u2, 1*u2)
        self.assertEqual(u2, u2*1)
        self.assertEqual(u2, 1*u2)
        self.assertEqual(u2+u2, u2*2)
        self.assertEqual(u2+u2, 2*u2)
        self.assertEqual(u2+u2, u2*2)
        self.assertEqual(u2+u2, 2*u2)
        self.assertEqual(u2+u2+u2, u2*3)
        self.assertEqual(u2+u2+u2, 3*u2)

        class subclass(self.type2test):
            pass
        u3 = subclass([0, 1])
        self.assertEqual(u3, u3*1)
        self.assertIsNot(u3, u3*1)

    def test_iadd(self):
        u = self.type2test([0, 1])
        u += self.type2test()
        self.assertEqual(u, self.type2test([0, 1]))
        u += self.type2test([2, 3])
        self.assertEqual(u, self.type2test([0, 1, 2, 3]))
        u += self.type2test([4, 5])
        self.assertEqual(u, self.type2test([0, 1, 2, 3, 4, 5]))

        u = self.type2test("spam")
        u += self.type2test("eggs")
        self.assertEqual(u, self.type2test("spameggs"))

    def test_imul(self):
        u = self.type2test([0, 1])
        u *= 3
        self.assertEqual(u, self.type2test([0, 1, 0, 1, 0, 1]))
        u *= 0
        self.assertEqual(u, self.type2test([]))

    def test_getitemoverwriteiter(self):
        # Verify that __getitem__ overrides are not recognized by __iter__
        class T(self.type2test):
            def __getitem__(self, key):
                return str(key) + '!!!'
        self.assertEqual(next(iter(T((1,2)))), 1)

    def test_repeat(self):
        for m in range(4):
            s = tuple(range(m))
            for n in range(-3, 5):
                self.assertEqual(self.type2test(s*n), self.type2test(s)*n)
            self.assertEqual(self.type2test(s)*(-4), self.type2test([]))
            self.assertEqual(id(s), id(s*1))

    def test_bigrepeat(self):
        if sys.maxsize <= 2147483647:
            x = self.type2test([0])
            x *= 2**16
            self.assertRaises(MemoryError, x.__mul__, 2**16)
            if hasattr(x, '__imul__'):
                self.assertRaises(MemoryError, x.__imul__, 2**16)

    def test_subscript(self):
        a = self.type2test([10, 11])
        self.assertEqual(a.__getitem__(0), 10)
        self.assertEqual(a.__getitem__(1), 11)
        self.assertEqual(a.__getitem__(-2), 10)
        self.assertEqual(a.__getitem__(-1), 11)
        self.assertRaises(IndexError, a.__getitem__, -3)
        self.assertRaises(IndexError, a.__getitem__, 3)
        self.assertEqual(a.__getitem__(slice(0,1)), self.type2test([10]))
        self.assertEqual(a.__getitem__(slice(1,2)), self.type2test([11]))
        self.assertEqual(a.__getitem__(slice(0,2)), self.type2test([10, 11]))
        self.assertEqual(a.__getitem__(slice(0,3)), self.type2test([10, 11]))
        self.assertEqual(a.__getitem__(slice(3,5)), self.type2test([]))
        self.assertRaises(ValueError, a.__getitem__, slice(0, 10, 0))
        self.assertRaises(TypeError, a.__getitem__, 'x')

    def test_count(self):
        a = self.type2test([0, 1, 2])*3
        self.assertEqual(a.count(0), 3)
        self.assertEqual(a.count(1), 3)
        self.assertEqual(a.count(3), 0)

        self.assertEqual(a.count(ALWAYS_EQ), 9)
        self.assertEqual(self.type2test([ALWAYS_EQ, ALWAYS_EQ]).count(1), 2)
        self.assertEqual(self.type2test([ALWAYS_EQ, ALWAYS_EQ]).count(NEVER_EQ), 2)
        self.assertEqual(self.type2test([NEVER_EQ, NEVER_EQ]).count(ALWAYS_EQ), 0)

        self.assertRaises(TypeError, a.count)

        class BadExc(Exception):
            pass

        class BadCmp:
            def __eq__(self, other):
                if other == 2:
                    raise BadExc()
                return False

        self.assertRaises(BadExc, a.count, BadCmp())

    def test_index(self):
        u = self.type2test([0, 1])
        self.assertEqual(u.index(0), 0)
        self.assertEqual(u.index(1), 1)
        self.assertRaises(ValueError, u.index, 2)

        u = self.type2test([-2, -1, 0, 0, 1, 2])
        self.assertEqual(u.count(0), 2)
        self.assertEqual(u.index(0), 2)
        self.assertEqual(u.index(0, 2), 2)
        self.assertEqual(u.index(-2, -10), 0)
        self.assertEqual(u.index(0, 3), 3)
        self.assertEqual(u.index(0, 3, 4), 3)
        self.assertRaises(ValueError, u.index, 2, 0, -10)

        self.assertEqual(u.index(ALWAYS_EQ), 0)
        self.assertEqual(self.type2test([ALWAYS_EQ, ALWAYS_EQ]).index(1), 0)
        self.assertEqual(self.type2test([ALWAYS_EQ, ALWAYS_EQ]).index(NEVER_EQ), 0)
        self.assertRaises(ValueError, self.type2test([NEVER_EQ, NEVER_EQ]).index, ALWAYS_EQ)

        self.assertRaises(TypeError, u.index)

        class BadExc(Exception):
            pass

        class BadCmp:
            def __eq__(self, other):
                if other == 2:
                    raise BadExc()
                return False

        a = self.type2test([0, 1, 2, 3])
        self.assertRaises(BadExc, a.index, BadCmp())

        a = self.type2test([-2, -1, 0, 0, 1, 2])
        self.assertEqual(a.index(0), 2)
        self.assertEqual(a.index(0, 2), 2)
        self.assertEqual(a.index(0, -4), 2)
        self.assertEqual(a.index(-2, -10), 0)
        self.assertEqual(a.index(0, 3), 3)
        self.assertEqual(a.index(0, -3), 3)
        self.assertEqual(a.index(0, 3, 4), 3)
        self.assertEqual(a.index(0, -3, -2), 3)
        self.assertEqual(a.index(0, -4*sys.maxsize, 4*sys.maxsize), 2)
        self.assertRaises(ValueError, a.index, 0, 4*sys.maxsize,-4*sys.maxsize)
        self.assertRaises(ValueError, a.index, 2, 0, -10)

    def test_pickle(self):
        lst = self.type2test([4, 5, 6, 7])
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            lst2 = pickle.loads(pickle.dumps(lst, proto))
            self.assertEqual(lst2, lst)
            self.assertNotEqual(id(lst2), id(lst))

    @support.suppress_immortalization()
    def test_free_after_iterating(self):
        support.check_free_after_iterating(self, iter, self.type2test)
        support.check_free_after_iterating(self, reversed, self.type2test)
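The containment, count(), and index() tests above lean on test.support's ALWAYS_EQ and NEVER_EQ sentinels to prove that sequences apply rich comparison to each element. A self-contained sketch with locally re-created equivalents (not the actual test.support objects) shows why:

class _AlwaysEq:
    def __eq__(self, other):
        return True
    def __ne__(self, other):
        return False

class _NeverEq:
    def __eq__(self, other):
        return False
    def __ne__(self, other):
        return True

ALWAYS_EQ, NEVER_EQ = _AlwaysEq(), _NeverEq()
assert ALWAYS_EQ in [1, 2, 3]             # reflected __eq__ answers True for any item
assert 1 in [ALWAYS_EQ]                   # comparison is tried from the item's side too
assert NEVER_EQ not in [1, 2, 3]
assert [ALWAYS_EQ, ALWAYS_EQ].count(5) == 2  # count() compares element by element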
test/dynamo/cpython/3_13/test_dict.diff (new file, 122 lines)
@@ -0,0 +1,122 @@
|
|||||||
|
diff --git a/test/dynamo/cpython/3_13/test_dict.py b/test/dynamo/cpython/3_13/test_dict.py
index 4729132c5a5..14f829c1715 100644
--- a/test/dynamo/cpython/3_13/test_dict.py
+++ b/test/dynamo/cpython/3_13/test_dict.py
@@ -1,3 +1,57 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import (
+    run_tests,
+    xfailIfTorchDynamo,
+)
+
+__TestCase = CPythonTestCase
+
+
+# redirect import statements
+import sys
+import importlib.abc
+
+redirect_imports = (
+    "test.mapping_tests",
+    "test.typinganndata",
+    "test.test_grammar",
+    "test.test_math",
+    "test.test_iter",
+    "test.typinganndata.ann_module",
+)
+
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+    def find_spec(self, fullname, path, target=None):
+        # Check if the import is the problematic one
+        if fullname in redirect_imports:
+            try:
+                # Attempt to import the standalone module
+                name = fullname.removeprefix("test.")
+                r = importlib.import_module(name)
+                # Redirect the module in sys.modules
+                sys.modules[fullname] = r
+                # Return a module spec from the found module
+                return importlib.util.find_spec(name)
+            except ImportError:
+                return None
+        return None
+
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
+
 import collections
 import collections.abc
 import gc
@@ -11,7 +65,7 @@ from test import support
 from test.support import import_helper, get_c_recursion_limit


-class DictTest(unittest.TestCase):
+class DictTest(__TestCase):

     def test_invalid_keyword_arguments(self):
         class Custom(dict):
@@ -265,6 +319,7 @@ class DictTest(unittest.TestCase):

         self.assertRaises(ValueError, {}.update, [(1, 2, 3)])

+    @unittest.skip("test hangs")
     def test_fromkeys(self):
         self.assertEqual(dict.fromkeys('abc'), {'a':None, 'b':None, 'c':None})
         d = {}
@@ -477,7 +532,7 @@ class DictTest(unittest.TestCase):
         for copymode in -1, +1:
             # -1: b has same structure as a
             # +1: b is a.copy()
-            for log2size in range(12):
+            for log2size in range(4):
                 size = 2**log2size
                 a = {}
                 b = {}
@@ -1006,18 +1061,6 @@ class DictTest(unittest.TestCase):
             pass
         self._tracked(MyDict())

-    @support.cpython_only
-    def test_track_lazy_instance_dicts(self):
-        class C:
-            pass
-        o = C()
-        d = o.__dict__
-        self._not_tracked(d)
-        o.untracked = 42
-        self._not_tracked(d)
-        o.tracked = []
-        self._tracked(d)
-
     def make_shared_key_dict(self, n):
         class C:
             pass
@@ -1622,7 +1665,7 @@ class DictTest(unittest.TestCase):
         self.assertGreaterEqual(eq_count, 1)


-class CAPITest(unittest.TestCase):
+class CAPITest(__TestCase):

     # Test _PyDict_GetItem_KnownHash()
     @support.cpython_only
@@ -1666,4 +1709,4 @@ class SubclassMappingTests(mapping_tests.BasicTestMappingProtocol):


 if __name__ == "__main__":
-    unittest.main()
+    run_tests()
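The RedirectImportFinder added by this patch leans on two standard importlib hooks: a MetaPathFinder inserted at the front of sys.meta_path gets first say on every import, and an entry placed in sys.modules under the dotted name short-circuits later lookups entirely. The second half of that mechanism can be seen in isolation; a minimal sketch, with illustrative module names:

import importlib
import sys
import types

# A stand-in for the standalone copy the finder would import.
standalone = types.ModuleType("mapping_tests")
standalone.origin = "local copy next to the test file"

# What `sys.modules[fullname] = r` in find_spec accomplishes: once the
# dotted name is cached, importlib serves it without consulting finders.
sys.modules["test.mapping_tests"] = standalone
mod = importlib.import_module("test.mapping_tests")
assert mod is standalone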
test/dynamo/cpython/3_13/test_dict.py (new file, 1712 lines; diff suppressed because it is too large)
test/dynamo/cpython/3_13/test_list.diff (new file, 77 lines)
diff --git a/test/dynamo/cpython/3_13/test_list.py b/test/dynamo/cpython/3_13/test_list.py
index 23ef902aa0b..30e69ff75bd 100644
--- a/test/dynamo/cpython/3_13/test_list.py
+++ b/test/dynamo/cpython/3_13/test_list.py
@@ -1,6 +1,57 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import run_tests
+
+__TestCase = CPythonTestCase
+
+
+# redirect import statements
+import sys
+import importlib.abc
+
+redirect_imports = (
+    "test.mapping_tests",
+    "test.typinganndata",
+    "test.test_grammar",
+    "test.test_math",
+    "test.test_iter",
+    "test.typinganndata.ann_module",
+)
+
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+    def find_spec(self, fullname, path, target=None):
+        # Check if the import is the problematic one
+        if fullname in redirect_imports:
+            try:
+                # Attempt to import the standalone module
+                name = fullname.removeprefix("test.")
+                r = importlib.import_module(name)
+                # Redirect the module in sys.modules
+                sys.modules[fullname] = r
+                # Return a module spec from the found module
+                return importlib.util.find_spec(name)
+            except ImportError:
+                return None
+        return None
+
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
+
 import sys
 import textwrap
-from test import list_tests
+import list_tests
 from test.support import cpython_only
 from test.support.script_helper import assert_python_ok
 import pickle
@@ -324,6 +375,7 @@ class ListTest(list_tests.CommonTest):
         a.append(4)
         self.assertEqual(list(it), [])

+    @unittest.expectedFailure
     def test_deopt_from_append_list(self):
         # gh-132011: it used to crash, because
         # of `CALL_LIST_APPEND` specialization failure.
@@ -345,4 +397,4 @@ class ListTest(list_tests.CommonTest):
         self.assertEqual(rc, 0)

 if __name__ == "__main__":
-    unittest.main()
+    run_tests()
test/dynamo/cpython/3_13/test_list.py (new file, 398 lines)
# ======= BEGIN Dynamo patch =======
# Owner(s): ["module: dynamo"]

# ruff: noqa
# flake8: noqa

import sys
import torch
import torch._dynamo.test_case
import unittest
from torch._dynamo.test_case import CPythonTestCase
from torch.testing._internal.common_utils import run_tests

__TestCase = CPythonTestCase


# redirect import statements
import sys
import importlib.abc

redirect_imports = (
    "test.mapping_tests",
    "test.typinganndata",
    "test.test_grammar",
    "test.test_math",
    "test.test_iter",
    "test.typinganndata.ann_module",
)

class RedirectImportFinder(importlib.abc.MetaPathFinder):
    def find_spec(self, fullname, path, target=None):
        # Check if the import is the problematic one
        if fullname in redirect_imports:
            try:
                # Attempt to import the standalone module
                name = fullname.removeprefix("test.")
                r = importlib.import_module(name)
                # Redirect the module in sys.modules
                sys.modules[fullname] = r
                # Return a module spec from the found module
                return importlib.util.find_spec(name)
            except ImportError:
                return None
        return None

# Add the custom finder to sys.meta_path
sys.meta_path.insert(0, RedirectImportFinder())


# ======= END DYNAMO PATCH =======

import sys
import textwrap
import list_tests
from test.support import cpython_only
from test.support.script_helper import assert_python_ok
import pickle
import unittest

class ListTest(list_tests.CommonTest):
    type2test = list

    def test_basic(self):
        self.assertEqual(list([]), [])
        l0_3 = [0, 1, 2, 3]
        l0_3_bis = list(l0_3)
        self.assertEqual(l0_3, l0_3_bis)
        self.assertTrue(l0_3 is not l0_3_bis)
        self.assertEqual(list(()), [])
        self.assertEqual(list((0, 1, 2, 3)), [0, 1, 2, 3])
        self.assertEqual(list(''), [])
        self.assertEqual(list('spam'), ['s', 'p', 'a', 'm'])
        self.assertEqual(list(x for x in range(10) if x % 2),
                         [1, 3, 5, 7, 9])

        if sys.maxsize == 0x7fffffff:
            # This test can currently only work on 32-bit machines.
            # XXX If/when PySequence_Length() returns a ssize_t, it should be
            # XXX re-enabled.
            # Verify clearing of bug #556025.
            # This assumes that the max data size (sys.maxint) == max
            # address size this also assumes that the address size is at
            # least 4 bytes with 8 byte addresses, the bug is not well
            # tested
            #
            # Note: This test is expected to SEGV under Cygwin 1.3.12 or
            # earlier due to a newlib bug.  See the following mailing list
            # thread for the details:
            self.assertRaises(MemoryError, list, range(sys.maxsize // 2))

        # This code used to segfault in Py2.4a3
        x = []
        x.extend(-y for y in x)
        self.assertEqual(x, [])

    def test_keyword_args(self):
        with self.assertRaisesRegex(TypeError, 'keyword argument'):
            list(sequence=[])

    def test_keywords_in_subclass(self):
        class subclass(list):
            pass
        u = subclass([1, 2])
        self.assertIs(type(u), subclass)
        self.assertEqual(list(u), [1, 2])
        with self.assertRaises(TypeError):
            subclass(sequence=())

        class subclass_with_init(list):
            def __init__(self, seq, newarg=None):
                super().__init__(seq)
                self.newarg = newarg
        u = subclass_with_init([1, 2], newarg=3)
        self.assertIs(type(u), subclass_with_init)
        self.assertEqual(list(u), [1, 2])
        self.assertEqual(u.newarg, 3)

        class subclass_with_new(list):
            def __new__(cls, seq, newarg=None):
                self = super().__new__(cls, seq)
                self.newarg = newarg
                return self
        u = subclass_with_new([1, 2], newarg=3)
        self.assertIs(type(u), subclass_with_new)
        self.assertEqual(list(u), [1, 2])
        self.assertEqual(u.newarg, 3)

    def test_truth(self):
        super().test_truth()
        self.assertTrue(not [])
        self.assertTrue([42])

    def test_identity(self):
        self.assertTrue([] is not [])

    def test_len(self):
        super().test_len()
        self.assertEqual(len([]), 0)
        self.assertEqual(len([0]), 1)
        self.assertEqual(len([0, 1, 2]), 3)

    def test_overflow(self):
        lst = [4, 5, 6, 7]
        n = int((sys.maxsize*2+2) // len(lst))
        def mul(a, b): return a * b
        def imul(a, b): a *= b
        self.assertRaises((MemoryError, OverflowError), mul, lst, n)
        self.assertRaises((MemoryError, OverflowError), imul, lst, n)

    def test_empty_slice(self):
        x = []
        x[:] = x
        self.assertEqual(x, [])

    def test_list_resize_overflow(self):
        # gh-97616: test new_allocated * sizeof(PyObject*) overflow
        # check in list_resize()
        lst = [0] * 65
        del lst[1:]
        self.assertEqual(len(lst), 1)

        size = sys.maxsize
        with self.assertRaises((MemoryError, OverflowError)):
            lst * size
        with self.assertRaises((MemoryError, OverflowError)):
            lst *= size

    def test_repr_mutate(self):
        class Obj:
            @staticmethod
            def __repr__():
                try:
                    mylist.pop()
                except IndexError:
                    pass
                return 'obj'

        mylist = [Obj() for _ in range(5)]
        self.assertEqual(repr(mylist), '[obj, obj, obj]')

    def test_repr_large(self):
        # Check the repr of large list objects
        def check(n):
            l = [0] * n
            s = repr(l)
            self.assertEqual(s,
                '[' + ', '.join(['0'] * n) + ']')
        check(10)       # check our checking code
        check(1000000)

    def test_iterator_pickle(self):
        orig = self.type2test([4, 5, 6, 7])
        data = [10, 11, 12, 13, 14, 15]
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            # initial iterator
            itorig = iter(orig)
            d = pickle.dumps((itorig, orig), proto)
            it, a = pickle.loads(d)
            a[:] = data
            self.assertEqual(type(it), type(itorig))
            self.assertEqual(list(it), data)

            # running iterator
            next(itorig)
            d = pickle.dumps((itorig, orig), proto)
            it, a = pickle.loads(d)
            a[:] = data
            self.assertEqual(type(it), type(itorig))
            self.assertEqual(list(it), data[1:])

            # empty iterator
            for i in range(1, len(orig)):
                next(itorig)
            d = pickle.dumps((itorig, orig), proto)
            it, a = pickle.loads(d)
            a[:] = data
            self.assertEqual(type(it), type(itorig))
            self.assertEqual(list(it), data[len(orig):])

            # exhausted iterator
            self.assertRaises(StopIteration, next, itorig)
            d = pickle.dumps((itorig, orig), proto)
            it, a = pickle.loads(d)
            a[:] = data
            self.assertEqual(list(it), [])

    def test_reversed_pickle(self):
        orig = self.type2test([4, 5, 6, 7])
        data = [10, 11, 12, 13, 14, 15]
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            # initial iterator
            itorig = reversed(orig)
            d = pickle.dumps((itorig, orig), proto)
            it, a = pickle.loads(d)
            a[:] = data
            self.assertEqual(type(it), type(itorig))
            self.assertEqual(list(it), data[len(orig)-1::-1])

            # running iterator
            next(itorig)
            d = pickle.dumps((itorig, orig), proto)
            it, a = pickle.loads(d)
            a[:] = data
            self.assertEqual(type(it), type(itorig))
            self.assertEqual(list(it), data[len(orig)-2::-1])

            # empty iterator
            for i in range(1, len(orig)):
                next(itorig)
            d = pickle.dumps((itorig, orig), proto)
            it, a = pickle.loads(d)
            a[:] = data
            self.assertEqual(type(it), type(itorig))
            self.assertEqual(list(it), [])

            # exhausted iterator
            self.assertRaises(StopIteration, next, itorig)
            d = pickle.dumps((itorig, orig), proto)
            it, a = pickle.loads(d)
            a[:] = data
            self.assertEqual(list(it), [])

    def test_step_overflow(self):
        a = [0, 1, 2, 3, 4]
        a[1::sys.maxsize] = [0]
        self.assertEqual(a[3::sys.maxsize], [3])

    def test_no_comdat_folding(self):
        # Issue 8847: In the PGO build, the MSVC linker's COMDAT folding
        # optimization causes failures in code that relies on distinct
        # function addresses.
        class L(list): pass
        with self.assertRaises(TypeError):
            (3,) + L([1,2])

    def test_equal_operator_modifying_operand(self):
        # test fix for seg fault reported in bpo-38588 part 2.
        class X:
            def __eq__(self,other) :
                list2.clear()
                return NotImplemented

        class Y:
            def __eq__(self, other):
                list1.clear()
                return NotImplemented

        class Z:
            def __eq__(self, other):
                list3.clear()
                return NotImplemented

        list1 = [X()]
        list2 = [Y()]
        self.assertTrue(list1 == list2)

        list3 = [Z()]
        list4 = [1]
        self.assertFalse(list3 == list4)

    def test_lt_operator_modifying_operand(self):
        # See gh-120298
        class evil:
            def __lt__(self, other):
                other.clear()
                return NotImplemented

        a = [[evil()]]
        with self.assertRaises(TypeError):
            a[0] < a

    def test_list_index_modifing_operand(self):
        # See gh-120384
        class evil:
            def __init__(self, lst):
                self.lst = lst
            def __iter__(self):
                yield from self.lst
                self.lst.clear()

        lst = list(range(5))
        operand = evil(lst)
        with self.assertRaises(ValueError):
            lst[::-1] = operand

    @cpython_only
    def test_preallocation(self):
        iterable = [0] * 10
        iter_size = sys.getsizeof(iterable)

        self.assertEqual(iter_size, sys.getsizeof(list([0] * 10)))
        self.assertEqual(iter_size, sys.getsizeof(list(range(10))))

    def test_count_index_remove_crashes(self):
        # bpo-38610: The count(), index(), and remove() methods were not
        # holding strong references to list elements while calling
        # PyObject_RichCompareBool().
        class X:
            def __eq__(self, other):
                lst.clear()
                return NotImplemented

        lst = [X()]
        with self.assertRaises(ValueError):
            lst.index(lst)

        class L(list):
            def __eq__(self, other):
                str(other)
                return NotImplemented

        lst = L([X()])
        lst.count(lst)

        lst = L([X()])
        with self.assertRaises(ValueError):
            lst.remove(lst)

        # bpo-39453: list.__contains__ was not holding strong references
        # to list elements while calling PyObject_RichCompareBool().
        lst = [X(), X()]
        3 in lst
        lst = [X(), X()]
        X() in lst

    def test_tier2_invalidates_iterator(self):
        # GH-121012
        for _ in range(100):
            a = [1, 2, 3]
            it = iter(a)
            for _ in it:
                pass
            a.append(4)
            self.assertEqual(list(it), [])

    @unittest.expectedFailure
    def test_deopt_from_append_list(self):
        # gh-132011: it used to crash, because
        # of `CALL_LIST_APPEND` specialization failure.
        code = textwrap.dedent("""
            l = []
            def lappend(l, x, y):
                l.append((x, y))
            for x in range(3):
                lappend(l, None, None)
            try:
                lappend(list, None, None)
            except TypeError:
                pass
            else:
                raise AssertionError
        """)

        rc, _, _ = assert_python_ok("-c", code)
        self.assertEqual(rc, 0)

if __name__ == "__main__":
    run_tests()
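test_iterator_pickle and test_reversed_pickle above round-trip the iterator and the list together so that mutating the unpickled list is visible to the unpickled iterator. The core behavior they rest on, that list iterators pickle with their position, also holds in isolation:

import pickle

it = iter([4, 5, 6, 7])
next(it)                                 # advance past the first element
clone = pickle.loads(pickle.dumps(it))   # position survives the round trip
assert list(clone) == [5, 6, 7]
assert list(it) == [5, 6, 7]             # dumps() does not consume the original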
test/dynamo/cpython/3_13/test_ordered_dict.diff (new file, 173 lines)
diff --git a/test/dynamo/cpython/3_13/test_ordered_dict.py b/test/dynamo/cpython/3_13/test_ordered_dict.py
index a9b6a84996e..b77eff70414 100644
--- a/test/dynamo/cpython/3_13/test_ordered_dict.py
+++ b/test/dynamo/cpython/3_13/test_ordered_dict.py
@@ -1,3 +1,57 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import (
+    run_tests,
+    xfailIfTorchDynamo,
+)
+
+__TestCase = CPythonTestCase
+
+
+# redirect import statements
+import sys
+import importlib.abc
+
+redirect_imports = (
+    "test.mapping_tests",
+    "test.typinganndata",
+    "test.test_grammar",
+    "test.test_math",
+    "test.test_iter",
+    "test.typinganndata.ann_module",
+)
+
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+    def find_spec(self, fullname, path, target=None):
+        # Check if the import is the problematic one
+        if fullname in redirect_imports:
+            try:
+                # Attempt to import the standalone module
+                name = fullname.removeprefix("test.")
+                r = importlib.import_module(name)
+                # Redirect the module in sys.modules
+                sys.modules[fullname] = r
+                # Return a module spec from the found module
+                return importlib.util.find_spec(name)
+            except ImportError:
+                return None
+        return None
+
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
+
 import builtins
 import contextlib
 import copy
@@ -760,7 +814,7 @@ class _TriggerSideEffectOnEqual:
     def side_effect(self):
         raise NotImplementedError

-class PurePythonOrderedDictTests(OrderedDictTests, unittest.TestCase):
+class PurePythonOrderedDictTests(OrderedDictTests, __TestCase):

     module = py_coll
     OrderedDict = py_coll.OrderedDict
@@ -781,7 +835,7 @@ class PurePythonOrderedDictTests(unittest.TestCase):
         self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2)))


-class CPythonBuiltinDictTests(unittest.TestCase):
+class CPythonBuiltinDictTests(__TestCase):
     """Builtin dict preserves insertion order.

     Reuse some of tests in OrderedDict selectively.
@@ -800,6 +854,7 @@ for method in (
 del method


+
 class CPythonOrderedDictSideEffects:

     def check_runtime_error_issue119004(self, dict1, dict2):
@@ -878,7 +933,7 @@ class CPythonOrderedDictSideEffects:
 @unittest.skipUnless(c_coll, 'requires the C version of the collections module')
 class CPythonOrderedDictTests(OrderedDictTests,
                               CPythonOrderedDictSideEffects,
-                              unittest.TestCase):
+                              __TestCase):

     module = c_coll
     OrderedDict = c_coll.OrderedDict
@@ -986,7 +1041,7 @@ class CPythonOrderedDictSubclassTests(CPythonOrderedDictTests):
         pass


-class PurePythonOrderedDictWithSlotsCopyingTests(unittest.TestCase):
+class PurePythonOrderedDictWithSlotsCopyingTests(__TestCase):

     module = py_coll
     class OrderedDict(py_coll.OrderedDict):
@@ -995,7 +1050,7 @@ class PurePythonOrderedDictWithSlotsCopyingTests(unittest.TestCase):


 @unittest.skipUnless(c_coll, 'requires the C version of the collections module')
-class CPythonOrderedDictWithSlotsCopyingTests(unittest.TestCase):
+class CPythonOrderedDictWithSlotsCopyingTests(__TestCase):

     module = c_coll
     class OrderedDict(c_coll.OrderedDict):
@@ -1008,6 +1063,7 @@ class PurePythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol):
     @classmethod
     def setUpClass(cls):
         cls.type2test = py_coll.OrderedDict
+        super().setUpClass()

     def test_popitem(self):
         d = self._empty_mapping()
@@ -1020,6 +1076,7 @@ class CPythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol):
     @classmethod
     def setUpClass(cls):
         cls.type2test = c_coll.OrderedDict
+        super().setUpClass()

     def test_popitem(self):
         d = self._empty_mapping()
@@ -1033,6 +1090,7 @@ class PurePythonSubclassMappingTests(mapping_tests.BasicTestMappingProtocol):
         class MyOrderedDict(py_coll.OrderedDict):
             pass
         cls.type2test = MyOrderedDict
+        super().setUpClass()

     def test_popitem(self):
         d = self._empty_mapping()
@@ -1047,6 +1105,7 @@ class CPythonSubclassMappingTests(mapping_tests.BasicTestMappingProtocol):
         class MyOrderedDict(c_coll.OrderedDict):
             pass
         cls.type2test = MyOrderedDict
+        super().setUpClass()

     def test_popitem(self):
         d = self._empty_mapping()
@@ -1120,21 +1179,22 @@ class SimpleLRUCacheTests:
         self.assertEqual(list(c), [1, 3, 2])


-class PySimpleLRUCacheTests(SimpleLRUCacheTests, unittest.TestCase):
+class PySimpleLRUCacheTests(SimpleLRUCacheTests, __TestCase):

     class type2test(SimpleLRUCache, py_coll.OrderedDict):
         pass


 @unittest.skipUnless(c_coll, 'requires the C version of the collections module')
-class CSimpleLRUCacheTests(SimpleLRUCacheTests, unittest.TestCase):
+class CSimpleLRUCacheTests(SimpleLRUCacheTests, __TestCase):

     @classmethod
     def setUpClass(cls):
         class type2test(SimpleLRUCache, c_coll.OrderedDict):
             pass
         cls.type2test = type2test
+        super().setUpClass()


 if __name__ == "__main__":
-    unittest.main()
+    run_tests()
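Besides the usual base-class swap, this diff appends super().setUpClass() to every setUpClass override. unittest invokes only the hook it finds on the class itself, so an override that does not chain silently skips the per-class setup the base test case needs. The shape of the fix, as a small sketch (class names illustrative):

import unittest

class Harness(unittest.TestCase):          # stands in for CPythonTestCase
    @classmethod
    def setUpClass(cls):
        cls.harness_ready = True           # per-class setup the tests rely on

class MappingTests(Harness):
    @classmethod
    def setUpClass(cls):
        cls.type2test = dict               # configure the subclass first...
        super().setUpClass()               # ...then chain, as the diff adds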
test/dynamo/cpython/3_13/test_ordered_dict.py (new file, 1200 lines; diff suppressed because it is too large)
test/dynamo/cpython/3_13/test_tuple.diff (new file, 67 lines)
diff --git a/test/dynamo/cpython/3_13/test_tuple.py b/test/dynamo/cpython/3_13/test_tuple.py
index 9ce80c5e8ea..e52c0cbc140 100644
--- a/test/dynamo/cpython/3_13/test_tuple.py
+++ b/test/dynamo/cpython/3_13/test_tuple.py
@@ -1,4 +1,55 @@
-from test import support, seq_tests
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import run_tests
+
+__TestCase = CPythonTestCase
+
+# redirect import statements
+import sys
+import importlib.abc
+
+redirect_imports = (
+    "test.mapping_tests",
+    "test.typinganndata",
+    "test.test_grammar",
+    "test.test_math",
+    "test.test_iter",
+    "test.typinganndata.ann_module",
+)
+
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+    def find_spec(self, fullname, path, target=None):
+        # Check if the import is the problematic one
+        if fullname in redirect_imports:
+            try:
+                # Attempt to import the standalone module
+                name = fullname.removeprefix("test.")
+                r = importlib.import_module(name)
+                # Redirect the module in sys.modules
+                sys.modules[fullname] = r
+                # Return a module spec from the found module
+                return importlib.util.find_spec(name)
+            except ImportError:
+                return None
+        return None
+
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
+
+from test import support
+import seq_tests
 import unittest

 import gc
@@ -510,4 +561,4 @@ class TupleTest(seq_tests.CommonTest):
 #     pileup 262,143 mean 8.0 coll 262,143 z +92683.6

 if __name__ == "__main__":
-    unittest.main()
+    run_tests()
test/dynamo/cpython/3_13/test_tuple.py (new file, 564 lines)
# ======= BEGIN Dynamo patch =======
# Owner(s): ["module: dynamo"]

# ruff: noqa
# flake8: noqa

import sys
import torch
import torch._dynamo.test_case
import unittest
from torch._dynamo.test_case import CPythonTestCase
from torch.testing._internal.common_utils import run_tests

__TestCase = CPythonTestCase

# redirect import statements
import sys
import importlib.abc

redirect_imports = (
    "test.mapping_tests",
    "test.typinganndata",
    "test.test_grammar",
    "test.test_math",
    "test.test_iter",
    "test.typinganndata.ann_module",
)

class RedirectImportFinder(importlib.abc.MetaPathFinder):
    def find_spec(self, fullname, path, target=None):
        # Check if the import is the problematic one
        if fullname in redirect_imports:
            try:
                # Attempt to import the standalone module
                name = fullname.removeprefix("test.")
                r = importlib.import_module(name)
                # Redirect the module in sys.modules
                sys.modules[fullname] = r
                # Return a module spec from the found module
                return importlib.util.find_spec(name)
            except ImportError:
                return None
        return None

# Add the custom finder to sys.meta_path
sys.meta_path.insert(0, RedirectImportFinder())


# ======= END DYNAMO PATCH =======

from test import support
import seq_tests
import unittest

import gc
import pickle

# For tuple hashes, we normally only run a test to ensure that we get
# the same results across platforms in a handful of cases.  If that's
# so, there's no real point to running more.  Set RUN_ALL_HASH_TESTS to
# run more anyway.  That's usually of real interest only when analyzing,
# or changing, the hash algorithm.  In which case it's usually also
# most useful to set JUST_SHOW_HASH_RESULTS, to see all the results
# instead of wrestling with test "failures".  See the bottom of the
# file for extensive notes on what we're testing here and why.
RUN_ALL_HASH_TESTS = False
JUST_SHOW_HASH_RESULTS = False # if RUN_ALL_HASH_TESTS, just display

class TupleTest(seq_tests.CommonTest):
    type2test = tuple

    def test_getitem_error(self):
        t = ()
        msg = "tuple indices must be integers or slices"
        with self.assertRaisesRegex(TypeError, msg):
            t['a']

    def test_constructors(self):
        super().test_constructors()
        # calling built-in types without argument must return empty
        self.assertEqual(tuple(), ())
        t0_3 = (0, 1, 2, 3)
        t0_3_bis = tuple(t0_3)
        self.assertTrue(t0_3 is t0_3_bis)
        self.assertEqual(tuple([]), ())
        self.assertEqual(tuple([0, 1, 2, 3]), (0, 1, 2, 3))
        self.assertEqual(tuple(''), ())
        self.assertEqual(tuple('spam'), ('s', 'p', 'a', 'm'))
        self.assertEqual(tuple(x for x in range(10) if x % 2),
                         (1, 3, 5, 7, 9))

    def test_keyword_args(self):
        with self.assertRaisesRegex(TypeError, 'keyword argument'):
            tuple(sequence=())

    def test_keywords_in_subclass(self):
        class subclass(tuple):
            pass
        u = subclass([1, 2])
        self.assertIs(type(u), subclass)
        self.assertEqual(list(u), [1, 2])
        with self.assertRaises(TypeError):
            subclass(sequence=())

        class subclass_with_init(tuple):
            def __init__(self, arg, newarg=None):
                self.newarg = newarg
        u = subclass_with_init([1, 2], newarg=3)
        self.assertIs(type(u), subclass_with_init)
        self.assertEqual(list(u), [1, 2])
        self.assertEqual(u.newarg, 3)

        class subclass_with_new(tuple):
            def __new__(cls, arg, newarg=None):
                self = super().__new__(cls, arg)
                self.newarg = newarg
                return self
        u = subclass_with_new([1, 2], newarg=3)
        self.assertIs(type(u), subclass_with_new)
        self.assertEqual(list(u), [1, 2])
        self.assertEqual(u.newarg, 3)

    def test_truth(self):
        super().test_truth()
        self.assertTrue(not ())
        self.assertTrue((42, ))

    def test_len(self):
        super().test_len()
        self.assertEqual(len(()), 0)
        self.assertEqual(len((0,)), 1)
        self.assertEqual(len((0, 1, 2)), 3)

    def test_iadd(self):
        super().test_iadd()
        u = (0, 1)
        u2 = u
        u += (2, 3)
        self.assertTrue(u is not u2)

    def test_imul(self):
        super().test_imul()
        u = (0, 1)
        u2 = u
        u *= 3
        self.assertTrue(u is not u2)

    def test_tupleresizebug(self):
        # Check that a specific bug in _PyTuple_Resize() is squashed.
        def f():
            for i in range(1000):
                yield i
        self.assertEqual(list(tuple(f())), list(range(1000)))

    # We expect tuples whose base components have deterministic hashes to
    # have deterministic hashes too - and, indeed, the same hashes across
    # platforms with hash codes of the same bit width.
    def test_hash_exact(self):
        def check_one_exact(t, e32, e64):
            got = hash(t)
            expected = e32 if support.NHASHBITS == 32 else e64
            if got != expected:
                msg = f"FAIL hash({t!r}) == {got} != {expected}"
                self.fail(msg)

        check_one_exact((), 750394483, 5740354900026072187)
        check_one_exact((0,), 1214856301, -8753497827991233192)
        check_one_exact((0, 0), -168982784, -8458139203682520985)
        check_one_exact((0.5,), 2077348973, -408149959306781352)
        check_one_exact((0.5, (), (-2, 3, (4, 6))), 714642271,
                        -1845940830829704396)

    # Various tests for hashing of tuples to check that we get few collisions.
    # Does something only if RUN_ALL_HASH_TESTS is true.
    #
    # Earlier versions of the tuple hash algorithm had massive collisions
    # reported at:
    # - https://bugs.python.org/issue942952
    # - https://bugs.python.org/issue34751
    def test_hash_optional(self):
        from itertools import product

        if not RUN_ALL_HASH_TESTS:
            return

        # If specified, `expected` is a 2-tuple of expected
        # (number_of_collisions, pileup) values, and the test fails if
        # those aren't the values we get.  Also if specified, the test
        # fails if z > `zlimit`.
        def tryone_inner(tag, nbins, hashes, expected=None, zlimit=None):
            from collections import Counter

            nballs = len(hashes)
            mean, sdev = support.collision_stats(nbins, nballs)
            c = Counter(hashes)
            collisions = nballs - len(c)
            z = (collisions - mean) / sdev
            pileup = max(c.values()) - 1
            del c
            got = (collisions, pileup)
            failed = False
            prefix = ""
            if zlimit is not None and z > zlimit:
                failed = True
                prefix = f"FAIL z > {zlimit}; "
            if expected is not None and got != expected:
                failed = True
                prefix += f"FAIL {got} != {expected}; "
            if failed or JUST_SHOW_HASH_RESULTS:
                msg = f"{prefix}{tag}; pileup {pileup:,} mean {mean:.1f} "
                msg += f"coll {collisions:,} z {z:+.1f}"
                if JUST_SHOW_HASH_RESULTS:
                    import sys
                    print(msg, file=sys.__stdout__)
                else:
                    self.fail(msg)

        def tryone(tag, xs,
                   native32=None, native64=None, hi32=None, lo32=None,
                   zlimit=None):
            NHASHBITS = support.NHASHBITS
            hashes = list(map(hash, xs))
            tryone_inner(tag + f"; {NHASHBITS}-bit hash codes",
                         1 << NHASHBITS,
                         hashes,
                         native32 if NHASHBITS == 32 else native64,
                         zlimit)

            if NHASHBITS > 32:
                shift = NHASHBITS - 32
                tryone_inner(tag + "; 32-bit upper hash codes",
                             1 << 32,
                             [h >> shift for h in hashes],
                             hi32,
                             zlimit)

                mask = (1 << 32) - 1
                tryone_inner(tag + "; 32-bit lower hash codes",
                             1 << 32,
                             [h & mask for h in hashes],
                             lo32,
                             zlimit)

        # Tuples of smallish positive integers are common - nice if we
        # get "better than random" for these.
        tryone("range(100) by 3", list(product(range(100), repeat=3)),
               (0, 0), (0, 0), (4, 1), (0, 0))

        # A previous hash had systematic problems when mixing integers of
        # similar magnitude but opposite sign, obscurely related to that
        # j ^ -2 == -j when j is odd.
        cands = list(range(-10, -1)) + list(range(9))

        # Note: -1 is omitted because hash(-1) == hash(-2) == -2, and
        # there's nothing the tuple hash can do to avoid collisions
        # inherited from collisions in the tuple components' hashes.
        tryone("-10 .. 8 by 4", list(product(cands, repeat=4)),
               (0, 0), (0, 0), (0, 0), (0, 0))
        del cands

        # The hashes here are a weird mix of values where all the
        # variation is in the lowest bits and across a single high-order
        # bit - the middle bits are all zeroes.  A decent hash has to
        # both propagate low bits to the left and high bits to the
        # right.  This is also complicated a bit in that there are
        # collisions among the hashes of the integers in L alone.
        L = [n << 60 for n in range(100)]
        tryone("0..99 << 60 by 3", list(product(L, repeat=3)),
               (0, 0), (0, 0), (0, 0), (324, 1))
        del L

        # Used to suffer a massive number of collisions.
        tryone("[-3, 3] by 18", list(product([-3, 3], repeat=18)),
               (7, 1), (0, 0), (7, 1), (6, 1))

        # And even worse.  hash(0.5) has only a single bit set, at the
        # high end.  A decent hash needs to propagate high bits right.
        tryone("[0, 0.5] by 18", list(product([0, 0.5], repeat=18)),
               (5, 1), (0, 0), (9, 1), (12, 1))

        # Hashes of ints and floats are the same across platforms.
        # String hashes vary even on a single platform across runs, due
        # to hash randomization for strings.  So we can't say exactly
        # what this should do.  Instead we insist that the # of
        # collisions is no more than 4 sdevs above the theoretically
        # random mean.  Even if the tuple hash can't achieve that on its
        # own, the string hash is trying to be decently pseudo-random
        # (in all bit positions) on _its_ own.  We can at least test
        # that the tuple hash doesn't systematically ruin that.
        tryone("4-char tuples",
               list(product("abcdefghijklmnopqrstuvwxyz", repeat=4)),
               zlimit=4.0)

        # The "old tuple test".  See https://bugs.python.org/issue942952.
        # Ensures, for example, that the hash:
        #   is non-commutative
        #   spreads closely spaced values
        #   doesn't exhibit cancellation in tuples like (x,(x,y))
        N = 50
        base = list(range(N))
        xp = list(product(base, repeat=2))
        inps = base + list(product(base, xp)) + \
                     list(product(xp, base)) + xp + list(zip(base))
        tryone("old tuple test", inps,
               (2, 1), (0, 0), (52, 49), (7, 1))
        del base, xp, inps

        # The "new tuple test".  See https://bugs.python.org/issue34751.
        # Even more tortured nesting, and a mix of signed ints of very
        # small magnitude.
        n = 5
        A = [x for x in range(-n, n+1) if x != -1]
        B = A + [(a,) for a in A]
        L2 = list(product(A, repeat=2))
        L3 = L2 + list(product(A, repeat=3))
        L4 = L3 + list(product(A, repeat=4))
        # T = list of testcases.  These consist of all (possibly nested
        # at most 2 levels deep) tuples containing at most 4 items from
        # the set A.
        T = A
        T += [(a,) for a in B + L4]
        T += product(L3, B)
        T += product(L2, repeat=2)
        T += product(B, L3)
        T += product(B, B, L2)
        T += product(B, L2, B)
        T += product(L2, B, B)
        T += product(B, repeat=4)
        assert len(T) == 345130
        tryone("new tuple test", T,
               (9, 1), (0, 0), (21, 5), (6, 1))

    def test_repr(self):
        l0 = tuple()
        l2 = (0, 1, 2)
        a0 = self.type2test(l0)
        a2 = self.type2test(l2)

        self.assertEqual(str(a0), repr(l0))
        self.assertEqual(str(a2), repr(l2))
        self.assertEqual(repr(a0), "()")
        self.assertEqual(repr(a2), "(0, 1, 2)")

    def _not_tracked(self, t):
        # Nested tuples can take several collections to untrack
        gc.collect()
        gc.collect()
        self.assertFalse(gc.is_tracked(t), t)

    def _tracked(self, t):
        self.assertTrue(gc.is_tracked(t), t)
        gc.collect()
        gc.collect()
        self.assertTrue(gc.is_tracked(t), t)

    @support.cpython_only
    def test_track_literals(self):
        # Test GC-optimization of tuple literals
        x, y, z = 1.5, "a", []

        self._not_tracked(())
        self._not_tracked((1,))
        self._not_tracked((1, 2))
        self._not_tracked((1, 2, "a"))
        self._not_tracked((1, 2, (None, True, False, ()), int))
        self._not_tracked((object(),))
        self._not_tracked(((1, x), y, (2, 3)))

        # Tuples with mutable elements are always tracked, even if those
        # elements are not tracked right now.
        self._tracked(([],))
        self._tracked(([1],))
        self._tracked(({},))
        self._tracked((set(),))
        self._tracked((x, y, z))

    def check_track_dynamic(self, tp, always_track):
        x, y, z = 1.5, "a", []

        check = self._tracked if always_track else self._not_tracked
        check(tp())
        check(tp([]))
        check(tp(set()))
        check(tp([1, x, y]))
        check(tp(obj for obj in [1, x, y]))
        check(tp(set([1, x, y])))
        check(tp(tuple([obj]) for obj in [1, x, y]))
        check(tuple(tp([obj]) for obj in [1, x, y]))

        self._tracked(tp([z]))
        self._tracked(tp([[x, y]]))
        self._tracked(tp([{x: y}]))
        self._tracked(tp(obj for obj in [x, y, z]))
        self._tracked(tp(tuple([obj]) for obj in [x, y, z]))
        self._tracked(tuple(tp([obj]) for obj in [x, y, z]))

    @support.cpython_only
    def test_track_dynamic(self):
        # Test GC-optimization of dynamically constructed tuples.
        self.check_track_dynamic(tuple, False)

    @support.cpython_only
    def test_track_subtypes(self):
        # Tuple subtypes must always be tracked
        class MyTuple(tuple):
            pass
        self.check_track_dynamic(MyTuple, True)

    @support.cpython_only
    def test_bug7466(self):
        # Trying to untrack an unfinished tuple could crash Python
        self._not_tracked(tuple(gc.collect() for i in range(101)))

    def test_repr_large(self):
        # Check the repr of large list objects
        def check(n):
            l = (0,) * n
            s = repr(l)
            self.assertEqual(s,
                '(' + ', '.join(['0'] * n) + ')')
        check(10)       # check our checking code
        check(1000000)

    def test_iterator_pickle(self):
        # Userlist iterators don't support pickling yet since
        # they are based on generators.
        data = self.type2test([4, 5, 6, 7])
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            itorg = iter(data)
            d = pickle.dumps(itorg, proto)
            it = pickle.loads(d)
            self.assertEqual(type(itorg), type(it))
            self.assertEqual(self.type2test(it), self.type2test(data))

            it = pickle.loads(d)
            next(it)
            d = pickle.dumps(it, proto)
            self.assertEqual(self.type2test(it), self.type2test(data)[1:])

    def test_reversed_pickle(self):
        data = self.type2test([4, 5, 6, 7])
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            itorg = reversed(data)
            d = pickle.dumps(itorg, proto)
            it = pickle.loads(d)
            self.assertEqual(type(itorg), type(it))
            self.assertEqual(self.type2test(it), self.type2test(reversed(data)))

            it = pickle.loads(d)
            next(it)
            d = pickle.dumps(it, proto)
            self.assertEqual(self.type2test(it), self.type2test(reversed(data))[1:])

    def test_no_comdat_folding(self):
        # Issue 8847: In the PGO build, the MSVC linker's COMDAT folding
        # optimization causes failures in code that relies on distinct
        # function addresses.
        class T(tuple): pass
        with self.assertRaises(TypeError):
            [3,] + T((1,2))

    def test_lexicographic_ordering(self):
        # Issue 21100
        a = self.type2test([1, 2])
        b = self.type2test([1, 2, 0])
        c = self.type2test([1, 3])
        self.assertLess(a, b)
        self.assertLess(b, c)

# Notes on testing hash codes.  The primary thing is that Python doesn't
# care about "random" hash codes.  To the contrary, we like them to be
# very regular when possible, so that the low-order bits are as evenly
# distributed as possible.  For integers this is easy: hash(i) == i for
# all not-huge i except i==-1.
#
# For tuples of mixed type there's really no hope of that, so we want
# "randomish" here instead.  But getting close to pseudo-random in all
# bit positions is more expensive than we've been willing to pay for.
#
# We can tolerate large deviations from random - what we don't want is
# catastrophic pileups on a relative handful of hash codes.  The dict
# and set lookup routines remain effective provided that full-width hash
# codes for not-equal objects are distinct.
#
# So we compute various statistics here based on what a "truly random"
# hash would do, but don't automate "pass or fail" based on those
# results.  Instead those are viewed as inputs to human judgment, and the
# automated tests merely ensure we get the _same_ results across
# platforms.  In fact, we normally don't bother to run them at all -
# set RUN_ALL_HASH_TESTS to force it.
#
# When global JUST_SHOW_HASH_RESULTS is True, the tuple hash statistics
# are just displayed to stdout.  A typical output line looks like:
#
# old tuple test; 32-bit upper hash codes; \
#     pileup 49 mean 7.4 coll 52 z +16.4
#
# "old tuple test" is just a string name for the test being run.
#
# "32-bit upper hash codes" means this was run under a 64-bit build and
# we've shifted away the lower 32 bits of the hash codes.
#
# "pileup" is 0 if there were no collisions across those hash codes.
# It's 1 less than the maximum number of times any single hash code was
# seen.  So in this case, there was (at least) one hash code that was
# seen 50 times:  that hash code "piled up" 49 more times than ideal.
#
# "mean" is the number of collisions a perfectly random hash function
# would have yielded, on average.
#
# "coll" is the number of collisions actually seen.
#
# "z" is "coll - mean" divided by the standard deviation of the number
# of collisions a perfectly random hash function would suffer.  A
# positive value is "worse than random", and negative value "better than
# random".  Anything of magnitude greater than 3 would be highly suspect
# for a hash function that claimed to be random.  It's essentially
# impossible that a truly random function would deliver a result 16.4
# sdevs "worse than random".
#
# But we don't care here!  That's why the test isn't coded to fail.
# Knowing something about how the high-order hash code bits behave
# provides insight, but is irrelevant to how the dict and set lookup
# code performs.  The low-order bits are much more important to that,
# and on the same test those did "just like random":
#
# old tuple test; 32-bit lower hash codes; \
#     pileup 1 mean 7.4 coll 7 z -0.2
#
# So there are always tradeoffs to consider.  For another:
#
# 0..99 << 60 by 3; 32-bit hash codes; \
#     pileup 0 mean 116.4 coll 0 z -10.8
#
# That was run under a 32-bit build, and is spectacularly "better than
# random".  On a 64-bit build the wider hash codes are fine too:
#
# 0..99 << 60 by 3; 64-bit hash codes; \
#     pileup 0 mean 0.0 coll 0 z -0.0
#
# but their lower 32 bits are poor:
#
# 0..99 << 60 by 3; 32-bit lower hash codes; \
#     pileup 1 mean 116.4 coll 324 z +19.2
#
# In a statistical sense that's waaaaay too many collisions, but (a) 324
# collisions out of a million hash codes isn't anywhere near being a
# real problem; and, (b) the worst pileup on a single hash code is a measly
# 1 extra.  It's a relatively poor case for the tuple hash, but still
# fine for practical use.
#
# This isn't, which is what Python 3.7.1 produced for the hashes of
# itertools.product([0, 0.5], repeat=18).  Even with a fat 64-bit
# hashcode, the highest pileup was over 16,000 - making a dict/set
# lookup on one of the colliding values thousands of times slower (on
# average) than we expect.
#
# [0, 0.5] by 18; 64-bit hash codes; \
#     pileup 16,383 mean 0.0 coll 262,128 z +6073641856.9
# [0, 0.5] by 18; 32-bit lower hash codes; \
#     pileup 262,143 mean 8.0 coll 262,143 z +92683.6

if __name__ == "__main__":
    run_tests()
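The statistics in the notes above come straight out of tryone_inner: collisions are counted with a Counter, and z expresses the excess over a truly random hash in standard deviations. Pulled out as a standalone helper (same computation, relying on test.support.collision_stats exactly as the file does):

from collections import Counter
from test import support

def hash_collision_report(nbins, hashes):
    # "coll", "pileup" and "z" exactly as defined in the notes above.
    nballs = len(hashes)
    mean, sdev = support.collision_stats(nbins, nballs)
    c = Counter(hashes)
    collisions = nballs - len(c)
    pileup = max(c.values()) - 1
    return collisions, pileup, (collisions - mean) / sdev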
test/dynamo/cpython/3_13/test_userdict.diff (new file, 74 lines)
diff --git a/test/dynamo/cpython/3_13/test_userdict.py b/test/dynamo/cpython/3_13/test_userdict.py
|
||||||
|
index 61e79f553e8..c953390355e 100644
|
||||||
|
--- a/test/dynamo/cpython/3_13/test_userdict.py
|
||||||
|
+++ b/test/dynamo/cpython/3_13/test_userdict.py
|
||||||
|
@@ -1,3 +1,54 @@
|
||||||
|
+# ======= BEGIN Dynamo patch =======
|
||||||
|
+# Owner(s): ["module: dynamo"]
|
||||||
|
+
|
||||||
|
+# ruff: noqa
|
||||||
|
+# flake8: noqa
|
||||||
|
+
|
||||||
|
+import sys
|
||||||
|
+import torch
|
||||||
|
+import torch._dynamo.test_case
|
||||||
|
+import unittest
|
||||||
|
+from torch._dynamo.test_case import CPythonTestCase
|
||||||
|
+from torch.testing._internal.common_utils import run_tests
|
||||||
|
+
|
||||||
|
+__TestCase = CPythonTestCase
|
||||||
|
+
|
||||||
|
+
|
||||||
|
+# redirect import statements
|
||||||
|
+import sys
|
||||||
|
+import importlib.abc
|
||||||
|
+
|
||||||
|
+redirect_imports = (
|
||||||
|
+ "test.mapping_tests",
|
||||||
|
+ "test.typinganndata",
|
||||||
|
+ "test.test_grammar",
|
||||||
|
+ "test.test_math",
|
||||||
|
+ "test.test_iter",
|
||||||
|
+ "test.typinganndata.ann_module",
|
||||||
|
+)
|
||||||
|
+
|
||||||
|
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
|
||||||
|
+ def find_spec(self, fullname, path, target=None):
|
||||||
|
+ # Check if the import is the problematic one
|
||||||
|
+ if fullname in redirect_imports:
|
||||||
|
+ try:
|
||||||
|
+ # Attempt to import the standalone module
|
||||||
|
+ name = fullname.removeprefix("test.")
|
||||||
|
+ r = importlib.import_module(name)
|
||||||
|
+ # Redirect the module in sys.modules
|
||||||
|
+ sys.modules[fullname] = r
|
||||||
|
+ # Return a module spec from the found module
|
||||||
|
+ return importlib.util.find_spec(name)
|
||||||
|
+ except ImportError:
|
||||||
|
+ return None
|
||||||
|
+ return None
|
||||||
|
+
|
||||||
|
+# Add the custom finder to sys.meta_path
|
||||||
|
+sys.meta_path.insert(0, RedirectImportFinder())
|
||||||
|
+
|
||||||
|
+
|
||||||
|
+# ======= END DYNAMO PATCH =======
|
||||||
|
+
|
||||||
|
# Check every path through every method of UserDict
|
||||||
|
|
||||||
|
from test import mapping_tests, support
|
||||||
|
@@ -215,10 +266,10 @@ class UserDictTest(mapping_tests.TestHashMappingProtocol):
|
||||||
|
|
||||||
|
# Decorate existing test with recursion limit, because
|
||||||
|
# the test is for C structure, but `UserDict` is a Python structure.
|
||||||
|
- test_repr_deep = support.infinite_recursion(25)(
|
||||||
|
- mapping_tests.TestHashMappingProtocol.test_repr_deep,
|
||||||
|
- )
|
||||||
|
+ # test_repr_deep = support.infinite_recursion(25)(
|
||||||
|
+ # mapping_tests.TestHashMappingProtocol.test_repr_deep,
|
||||||
|
+ # )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
- unittest.main()
|
||||||
|
+ run_tests()
|
||||||
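The Dynamo patch above redirects `test.*` imports to standalone copies of the CPython helper modules by installing a `sys.meta_path` finder. A stripped-down sketch of the same mechanism, using a hypothetical alias name and the stdlib `json` module as the target (same shape as `RedirectImportFinder` in the diff):

import importlib
import importlib.abc
import importlib.util
import sys

class AliasFinder(importlib.abc.MetaPathFinder):
    # Resolve the made-up name "jsonalias" to the real stdlib json module.
    def find_spec(self, fullname, path, target=None):
        if fullname == "jsonalias":
            real = importlib.import_module("json")
            sys.modules[fullname] = real  # alias under the requested name
            return importlib.util.find_spec("json")
        return None

sys.meta_path.insert(0, AliasFinder())

import jsonalias
assert jsonalias is sys.modules["json"]

Inserting at index 0 matters: the finder must see the made-up name before the regular path-based finders get a chance to fail the import.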
test/dynamo/cpython/3_13/test_userdict.py (new file, 275 lines)
@@ -0,0 +1,275 @@
# ======= BEGIN Dynamo patch =======
# Owner(s): ["module: dynamo"]

# ruff: noqa
# flake8: noqa

import sys
import torch
import torch._dynamo.test_case
import unittest
from torch._dynamo.test_case import CPythonTestCase
from torch.testing._internal.common_utils import run_tests

__TestCase = CPythonTestCase


# redirect import statements
import sys
import importlib.abc

redirect_imports = (
    "test.mapping_tests",
    "test.typinganndata",
    "test.test_grammar",
    "test.test_math",
    "test.test_iter",
    "test.typinganndata.ann_module",
)

class RedirectImportFinder(importlib.abc.MetaPathFinder):
    def find_spec(self, fullname, path, target=None):
        # Check if the import is the problematic one
        if fullname in redirect_imports:
            try:
                # Attempt to import the standalone module
                name = fullname.removeprefix("test.")
                r = importlib.import_module(name)
                # Redirect the module in sys.modules
                sys.modules[fullname] = r
                # Return a module spec from the found module
                return importlib.util.find_spec(name)
            except ImportError:
                return None
        return None

# Add the custom finder to sys.meta_path
sys.meta_path.insert(0, RedirectImportFinder())


# ======= END DYNAMO PATCH =======

# Check every path through every method of UserDict

from test import mapping_tests, support
import unittest
import collections

d0 = {}
d1 = {"one": 1}
d2 = {"one": 1, "two": 2}
d3 = {"one": 1, "two": 3, "three": 5}
d4 = {"one": None, "two": None}
d5 = {"one": 1, "two": 1}

class UserDictTest(mapping_tests.TestHashMappingProtocol):
    type2test = collections.UserDict

    def test_all(self):
        # Test constructors
        u = collections.UserDict()
        u0 = collections.UserDict(d0)
        u1 = collections.UserDict(d1)
        u2 = collections.UserDict(d2)

        uu = collections.UserDict(u)
        uu0 = collections.UserDict(u0)
        uu1 = collections.UserDict(u1)
        uu2 = collections.UserDict(u2)

        # keyword arg constructor
        self.assertEqual(collections.UserDict(one=1, two=2), d2)
        # item sequence constructor
        self.assertEqual(collections.UserDict([('one',1), ('two',2)]), d2)
        self.assertEqual(collections.UserDict(dict=[('one',1), ('two',2)]),
                         {'dict': [('one', 1), ('two', 2)]})
        # both together
        self.assertEqual(collections.UserDict([('one',1), ('two',2)], two=3, three=5), d3)

        # alternate constructor
        self.assertEqual(collections.UserDict.fromkeys('one two'.split()), d4)
        self.assertEqual(collections.UserDict().fromkeys('one two'.split()), d4)
        self.assertEqual(collections.UserDict.fromkeys('one two'.split(), 1), d5)
        self.assertEqual(collections.UserDict().fromkeys('one two'.split(), 1), d5)
        self.assertTrue(u1.fromkeys('one two'.split()) is not u1)
        self.assertIsInstance(u1.fromkeys('one two'.split()), collections.UserDict)
        self.assertIsInstance(u2.fromkeys('one two'.split()), collections.UserDict)

        # Test __repr__
        self.assertEqual(str(u0), str(d0))
        self.assertEqual(repr(u1), repr(d1))
        self.assertIn(repr(u2), ("{'one': 1, 'two': 2}",
                                 "{'two': 2, 'one': 1}"))

        # Test rich comparison and __len__
        all = [d0, d1, d2, u, u0, u1, u2, uu, uu0, uu1, uu2]
        for a in all:
            for b in all:
                self.assertEqual(a == b, len(a) == len(b))

        # Test __getitem__
        self.assertEqual(u2["one"], 1)
        self.assertRaises(KeyError, u1.__getitem__, "two")

        # Test __setitem__
        u3 = collections.UserDict(u2)
        u3["two"] = 2
        u3["three"] = 3

        # Test __delitem__
        del u3["three"]
        self.assertRaises(KeyError, u3.__delitem__, "three")

        # Test clear
        u3.clear()
        self.assertEqual(u3, {})

        # Test copy()
        u2a = u2.copy()
        self.assertEqual(u2a, u2)
        u2b = collections.UserDict(x=42, y=23)
        u2c = u2b.copy() # making a copy of a UserDict is special cased
        self.assertEqual(u2b, u2c)

        class MyUserDict(collections.UserDict):
            def display(self): print(self)

        m2 = MyUserDict(u2)
        m2a = m2.copy()
        self.assertEqual(m2a, m2)

        # SF bug #476616 -- copy() of UserDict subclass shared data
        m2['foo'] = 'bar'
        self.assertNotEqual(m2a, m2)

        # Test keys, items, values
        self.assertEqual(sorted(u2.keys()), sorted(d2.keys()))
        self.assertEqual(sorted(u2.items()), sorted(d2.items()))
        self.assertEqual(sorted(u2.values()), sorted(d2.values()))

        # Test "in".
        for i in u2.keys():
            self.assertIn(i, u2)
            self.assertEqual(i in u1, i in d1)
            self.assertEqual(i in u0, i in d0)

        # Test update
        t = collections.UserDict()
        t.update(u2)
        self.assertEqual(t, u2)

        # Test get
        for i in u2.keys():
            self.assertEqual(u2.get(i), u2[i])
            self.assertEqual(u1.get(i), d1.get(i))
            self.assertEqual(u0.get(i), d0.get(i))

        # Test "in" iteration.
        for i in range(20):
            u2[i] = str(i)
        ikeys = []
        for k in u2:
            ikeys.append(k)
        keys = u2.keys()
        self.assertEqual(set(ikeys), set(keys))

        # Test setdefault
        t = collections.UserDict()
        self.assertEqual(t.setdefault("x", 42), 42)
        self.assertIn("x", t)
        self.assertEqual(t.setdefault("x", 23), 42)

        # Test pop
        t = collections.UserDict(x=42)
        self.assertEqual(t.pop("x"), 42)
        self.assertRaises(KeyError, t.pop, "x")
        self.assertEqual(t.pop("x", 1), 1)
        t["x"] = 42
        self.assertEqual(t.pop("x", 1), 42)

        # Test popitem
        t = collections.UserDict(x=42)
        self.assertEqual(t.popitem(), ("x", 42))
        self.assertRaises(KeyError, t.popitem)

    def test_init(self):
        for kw in 'self', 'other', 'iterable':
            self.assertEqual(list(collections.UserDict(**{kw: 42}).items()),
                             [(kw, 42)])
        self.assertEqual(list(collections.UserDict({}, dict=42).items()),
                         [('dict', 42)])
        self.assertEqual(list(collections.UserDict({}, dict=None).items()),
                         [('dict', None)])
        self.assertEqual(list(collections.UserDict(dict={'a': 42}).items()),
                         [('dict', {'a': 42})])
        self.assertRaises(TypeError, collections.UserDict, 42)
        self.assertRaises(TypeError, collections.UserDict, (), ())
        self.assertRaises(TypeError, collections.UserDict.__init__)

    def test_update(self):
        for kw in 'self', 'dict', 'other', 'iterable':
            d = collections.UserDict()
            d.update(**{kw: 42})
            self.assertEqual(list(d.items()), [(kw, 42)])
        self.assertRaises(TypeError, collections.UserDict().update, 42)
        self.assertRaises(TypeError, collections.UserDict().update, {}, {})
        self.assertRaises(TypeError, collections.UserDict.update)

    def test_missing(self):
        # Make sure UserDict doesn't have a __missing__ method
        self.assertEqual(hasattr(collections.UserDict, "__missing__"), False)
        # Test several cases:
        # (D) subclass defines __missing__ method returning a value
        # (E) subclass defines __missing__ method raising RuntimeError
        # (F) subclass sets __missing__ instance variable (no effect)
        # (G) subclass doesn't define __missing__ at all
        class D(collections.UserDict):
            def __missing__(self, key):
                return 42
        d = D({1: 2, 3: 4})
        self.assertEqual(d[1], 2)
        self.assertEqual(d[3], 4)
        self.assertNotIn(2, d)
        self.assertNotIn(2, d.keys())
        self.assertEqual(d[2], 42)
        class E(collections.UserDict):
            def __missing__(self, key):
                raise RuntimeError(key)
        e = E()
        try:
            e[42]
        except RuntimeError as err:
            self.assertEqual(err.args, (42,))
        else:
            self.fail("e[42] didn't raise RuntimeError")
        class F(collections.UserDict):
            def __init__(self):
                # An instance variable __missing__ should have no effect
                self.__missing__ = lambda key: None
                collections.UserDict.__init__(self)
        f = F()
        try:
            f[42]
        except KeyError as err:
            self.assertEqual(err.args, (42,))
        else:
            self.fail("f[42] didn't raise KeyError")
        class G(collections.UserDict):
            pass
        g = G()
        try:
            g[42]
        except KeyError as err:
            self.assertEqual(err.args, (42,))
        else:
            self.fail("g[42] didn't raise KeyError")

    # Decorate existing test with recursion limit, because
    # the test is for C structure, but `UserDict` is a Python structure.
    # test_repr_deep = support.infinite_recursion(25)(
    #     mapping_tests.TestHashMappingProtocol.test_repr_deep,
    # )


if __name__ == "__main__":
    run_tests()
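The `test_missing` cases pass because `UserDict.__getitem__` looks `__missing__` up on the class, not the instance. A paraphrase of that dispatch (a sketch of the stdlib behavior, not the verbatim source):

class UserDictSketch:
    def __init__(self, data=None):
        self.data = dict(data or {})

    def __getitem__(self, key):
        if key in self.data:
            return self.data[key]
        # Class-level lookup: an instance attribute named __missing__ is
        # ignored, which is why case (F) above still raises KeyError.
        if hasattr(self.__class__, "__missing__"):
            return self.__class__.__missing__(self, key)
        raise KeyError(key)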
test/dynamo/cpython/3_13/test_userlist.diff (new file, 78 lines)
@@ -0,0 +1,78 @@
diff --git a/test/dynamo/cpython/3_13/test_userlist.py b/test/dynamo/cpython/3_13/test_userlist.py
index 312702c8e39..a4532922f5d 100644
--- a/test/dynamo/cpython/3_13/test_userlist.py
+++ b/test/dynamo/cpython/3_13/test_userlist.py
@@ -1,7 +1,58 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import run_tests
+
+__TestCase = CPythonTestCase
+
+
+# redirect import statements
+import sys
+import importlib.abc
+
+redirect_imports = (
+    "test.mapping_tests",
+    "test.typinganndata",
+    "test.test_grammar",
+    "test.test_math",
+    "test.test_iter",
+    "test.typinganndata.ann_module",
+)
+
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+    def find_spec(self, fullname, path, target=None):
+        # Check if the import is the problematic one
+        if fullname in redirect_imports:
+            try:
+                # Attempt to import the standalone module
+                name = fullname.removeprefix("test.")
+                r = importlib.import_module(name)
+                # Redirect the module in sys.modules
+                sys.modules[fullname] = r
+                # Return a module spec from the found module
+                return importlib.util.find_spec(name)
+            except ImportError:
+                return None
+        return None
+
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
+
 # Check every path through every method of UserList
 
 from collections import UserList
-from test import list_tests
+import list_tests
 import unittest
 from test import support
 
@@ -69,9 +120,9 @@ class UserListTest(list_tests.CommonTest):
 
     # Decorate existing test with recursion limit, because
     # the test is for C structure, but `UserList` is a Python structure.
-    test_repr_deep = support.infinite_recursion(25)(
-        list_tests.CommonTest.test_repr_deep,
-    )
+    # test_repr_deep = support.infinite_recursion(25)(
+    #     list_tests.CommonTest.test_repr_deep,
+    # )
 
 if __name__ == "__main__":
-    unittest.main()
+    run_tests()
test/dynamo/cpython/3_13/test_userlist.py (new file, 128 lines)
@@ -0,0 +1,128 @@
# ======= BEGIN Dynamo patch =======
# Owner(s): ["module: dynamo"]

# ruff: noqa
# flake8: noqa

import sys
import torch
import torch._dynamo.test_case
import unittest
from torch._dynamo.test_case import CPythonTestCase
from torch.testing._internal.common_utils import run_tests

__TestCase = CPythonTestCase


# redirect import statements
import sys
import importlib.abc

redirect_imports = (
    "test.mapping_tests",
    "test.typinganndata",
    "test.test_grammar",
    "test.test_math",
    "test.test_iter",
    "test.typinganndata.ann_module",
)

class RedirectImportFinder(importlib.abc.MetaPathFinder):
    def find_spec(self, fullname, path, target=None):
        # Check if the import is the problematic one
        if fullname in redirect_imports:
            try:
                # Attempt to import the standalone module
                name = fullname.removeprefix("test.")
                r = importlib.import_module(name)
                # Redirect the module in sys.modules
                sys.modules[fullname] = r
                # Return a module spec from the found module
                return importlib.util.find_spec(name)
            except ImportError:
                return None
        return None

# Add the custom finder to sys.meta_path
sys.meta_path.insert(0, RedirectImportFinder())


# ======= END DYNAMO PATCH =======

# Check every path through every method of UserList

from collections import UserList
import list_tests
import unittest
from test import support


class UserListTest(list_tests.CommonTest):
    type2test = UserList

    def test_getslice(self):
        super().test_getslice()
        l = [0, 1, 2, 3, 4]
        u = self.type2test(l)
        for i in range(-3, 6):
            self.assertEqual(u[:i], l[:i])
            self.assertEqual(u[i:], l[i:])
            for j in range(-3, 6):
                self.assertEqual(u[i:j], l[i:j])

    def test_slice_type(self):
        l = [0, 1, 2, 3, 4]
        u = UserList(l)
        self.assertIsInstance(u[:], u.__class__)
        self.assertEqual(u[:],u)

    def test_add_specials(self):
        u = UserList("spam")
        u2 = u + "eggs"
        self.assertEqual(u2, list("spameggs"))

    def test_radd_specials(self):
        u = UserList("eggs")
        u2 = "spam" + u
        self.assertEqual(u2, list("spameggs"))
        u2 = u.__radd__(UserList("spam"))
        self.assertEqual(u2, list("spameggs"))

    def test_iadd(self):
        super().test_iadd()
        u = [0, 1]
        u += UserList([0, 1])
        self.assertEqual(u, [0, 1, 0, 1])

    def test_mixedcmp(self):
        u = self.type2test([0, 1])
        self.assertEqual(u, [0, 1])
        self.assertNotEqual(u, [0])
        self.assertNotEqual(u, [0, 2])

    def test_mixedadd(self):
        u = self.type2test([0, 1])
        self.assertEqual(u + [], u)
        self.assertEqual(u + [2], [0, 1, 2])

    def test_getitemoverwriteiter(self):
        # Verify that __getitem__ overrides *are* recognized by __iter__
        class T(self.type2test):
            def __getitem__(self, key):
                return str(key) + '!!!'
        self.assertEqual(next(iter(T((1,2)))), "0!!!")

    def test_userlist_copy(self):
        u = self.type2test([6, 8, 1, 9, 1])
        v = u.copy()
        self.assertEqual(u, v)
        self.assertEqual(type(u), type(v))

    # Decorate existing test with recursion limit, because
    # the test is for C structure, but `UserList` is a Python structure.
    # test_repr_deep = support.infinite_recursion(25)(
    #     list_tests.CommonTest.test_repr_deep,
    # )

if __name__ == "__main__":
    run_tests()
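`test_getitemoverwriteiter` above checks that iterating a `UserList` goes through the (possibly overridden) `__getitem__`. Plain classes that define `__getitem__` but no `__iter__` behave the same way via Python's legacy iteration protocol: `iter()` indexes from 0 until `IndexError`. A self-contained illustration:

class Squares:
    # No __iter__ here: iteration falls back to indexing from zero.
    def __getitem__(self, i):
        if i >= 5:
            raise IndexError(i)
        return i * i

print(list(Squares()))  # [0, 1, 4, 9, 16]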
test/dynamo/cpython/3_13/typinganndata/__init__.py (new file, empty)
test/dynamo/cpython/3_13/typinganndata/_typed_dict_helper.py (new file, 30 lines)
@@ -0,0 +1,30 @@
"""Used to test `get_type_hints()` on a cross-module inherited `TypedDict` class

This script uses future annotations to postpone a type that won't be available
on the module inheriting from to `Foo`. The subclass in the other module should
look something like this:

    class Bar(_typed_dict_helper.Foo, total=False):
        b: int

In addition, it uses multiple levels of Annotated to test the interaction
between the __future__ import, Annotated, and Required.
"""

from __future__ import annotations

from typing import Annotated, Generic, Optional, Required, TypedDict, TypeVar


OptionalIntType = Optional[int]

class Foo(TypedDict):
    a: OptionalIntType

T = TypeVar("T")

class FooGeneric(TypedDict, Generic[T]):
    a: Optional[T]

class VeryAnnotated(TypedDict, total=False):
    a: Annotated[Annotated[Annotated[Required[int], "a"], "b"], "c"]
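As the docstring above explains, a companion module inherits `Foo` and resolves the postponed annotations with `get_type_hints()`. A sketch of what that module could look like (hypothetical file; assumes `_typed_dict_helper` is importable):

from typing import get_type_hints

import _typed_dict_helper  # the helper listed above

class Bar(_typed_dict_helper.Foo, total=False):
    b: int

# get_type_hints() evaluates the stringified annotations in the helper's
# namespace, so OptionalIntType resolves even across the module boundary.
print(get_type_hints(Bar))  # expected: {'a': Optional[int], 'b': <class 'int'>}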
test/dynamo/cpython/3_13/typinganndata/ann_module.py (new file, 62 lines)
@@ -0,0 +1,62 @@


"""
The module for testing variable annotations.
Empty lines above are for good reason (testing for correct line numbers)
"""

from typing import Optional
from functools import wraps

__annotations__[1] = 2

class C:

    x = 5; y: Optional['C'] = None

from typing import Tuple
x: int = 5; y: str = x; f: Tuple[int, int]

class M(type):

    __annotations__['123'] = 123
    o: type = object

(pars): bool = True

class D(C):
    j: str = 'hi'; k: str= 'bye'

from types import new_class
h_class = new_class('H', (C,))
j_class = new_class('J')

class F():
    z: int = 5
    def __init__(self, x):
        pass

class Y(F):
    def __init__(self):
        super(F, self).__init__(123)

class Meta(type):
    def __new__(meta, name, bases, namespace):
        return super().__new__(meta, name, bases, namespace)

class S(metaclass = Meta):
    x: str = 'something'
    y: str = 'something else'

def foo(x: int = 10):
    def bar(y: List[str]):
        x: str = 'yes'
    bar()

def dec(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

u: int | float
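The oddities above (the injected integer key, the parenthesized `(pars)` target, the deliberately undefined `List` inside `foo`) exist so the test suite can import this module and inspect the `__annotations__` dicts it builds at import time. A sketch of that inspection (assumes the file is importable as `ann_module`):

import ann_module

print(ann_module.__annotations__[1])    # 2 -- the key injected near the top
print(ann_module.__annotations__["x"])  # <class 'int'> (evaluated eagerly; no __future__ import here)
print(ann_module.__annotations__["u"])  # int | float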
test/dynamo/cpython/3_13/typinganndata/ann_module2.py (new file, 36 lines)
@@ -0,0 +1,36 @@
"""
Some correct syntax for variable annotation here.
More examples are in test_grammar and test_parser.
"""

from typing import no_type_check, ClassVar

i: int = 1
j: int
x: float = i/10

def f():
    class C: ...
    return C()

f().new_attr: object = object()

class C:
    def __init__(self, x: int) -> None:
        self.x = x

c = C(5)
c.new_attr: int = 10

__annotations__ = {}


@no_type_check
class NTC:
    def meth(self, param: complex) -> None:
        ...

class CV:
    var: ClassVar['CV']

CV.var = CV()
test/dynamo/cpython/3_13/typinganndata/ann_module3.py (new file, 18 lines)
@@ -0,0 +1,18 @@
"""
Correct syntax for variable annotation that should fail at runtime
in a certain manner. More examples are in test_grammar and test_parser.
"""

def f_bad_ann():
    __annotations__[1] = 2

class C_OK:
    def __init__(self, x: int) -> None:
        self.x: no_such_name = x  # This one is OK as proposed by Guido

class D_bad_ann:
    def __init__(self, x: int) -> None:
        sfel.y: int = 0

def g_bad_ann():
    no_such_name.attr: int = 0
test/dynamo/cpython/3_13/typinganndata/ann_module4.py (new file, 5 lines)
@@ -0,0 +1,5 @@
# This ann_module isn't for test_typing,
# it's for test_module

a:int=3
b:str=4
test/dynamo/cpython/3_13/typinganndata/ann_module5.py (new file, 10 lines)
@@ -0,0 +1,10 @@
# Used by test_typing to verify that Final wrapped in ForwardRef works.

from __future__ import annotations

from typing import Final

name: Final[str] = "final"

class MyClass:
    value: Final = 3000
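Because of the `__future__` import, the `Final` annotations above are stored as strings until something resolves them. A small sketch of observing both forms (assumes the file is importable as `ann_module5`):

import typing

import ann_module5  # the module listed above

print(ann_module5.__annotations__["name"])  # 'Final[str]' -- still a string
print(typing.get_type_hints(ann_module5))   # {'name': typing.Final[str]}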
test/dynamo/cpython/3_13/typinganndata/ann_module6.py (new file, 7 lines)
@@ -0,0 +1,7 @@
# Tests that top-level ClassVar is not allowed

from __future__ import annotations

from typing import ClassVar

wrong: ClassVar[int] = 1
test/dynamo/cpython/3_13/typinganndata/ann_module695.py (new file, 22 lines)
@@ -0,0 +1,22 @@
from __future__ import annotations
from typing import Callable


class A[T, *Ts, **P]:
    x: T
    y: tuple[*Ts]
    z: Callable[P, str]


class B[T, *Ts, **P]:
    T = int
    Ts = str
    P = bytes
    x: T
    y: Ts
    z: P


def generic_function[T, *Ts, **P](
    x: T, *y: *Ts, z: P.args, zz: P.kwargs
) -> None: ...
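The `[T, *Ts, **P]` bracket syntax is PEP 695 (Python 3.12+): each class or function gets its own scoped type parameters, exposed via `__type_params__`, and the class-body assignments in `B` above deliberately shadow those parameters for annotations in that body. A quick standalone illustration (not part of the module):

# Requires Python 3.12+ for PEP 695 syntax.
class Pair[T]:
    first: T
    second: T

def first_of[T](xs: list[T]) -> T:
    return xs[0]

print(Pair.__type_params__)      # (T,)
print(first_of.__type_params__)  # (T,)
print(first_of([1, 2, 3]))       # 1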
test/dynamo/cpython/3_13/typinganndata/ann_module7.py (new file, 11 lines)
@@ -0,0 +1,11 @@
# Tests class have ``__text_signature__``

from __future__ import annotations

DEFAULT_BUFFER_SIZE = 8192

class BufferedReader(object):
    """BufferedReader(raw, buffer_size=DEFAULT_BUFFER_SIZE)\n--\n\n
    Create a new buffered reader using the given readable raw IO object.
    """
    pass
test/dynamo/cpython/3_13/typinganndata/ann_module8.py (new file, 10 lines)
@@ -0,0 +1,10 @@
# Test `@no_type_check`,
# see https://bugs.python.org/issue46571

class NoTypeCheck_Outer:
    class Inner:
        x: int


def NoTypeCheck_function(arg: int) -> int:
    ...
test/dynamo/cpython/3_13/typinganndata/ann_module9.py (new file, 14 lines)
@@ -0,0 +1,14 @@
# Test ``inspect.formatannotation``
# https://github.com/python/cpython/issues/96073

from typing import Union, List

ann = Union[List[str], int]

# mock typing._type_repr behaviour
class A: ...

A.__module__ = 'testModule.typing'
A.__qualname__ = 'A'

ann1 = Union[List[A], int]
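`inspect.formatannotation` is the helper `inspect` uses to render annotations in signatures; the module above feeds it a `Union` alias plus a class whose `__module__` mimics `typing`. A quick sketch of calling it directly (the commented output is indicative and may vary across Python versions):

import inspect
from typing import List, Union

ann = Union[List[str], int]
# typing objects are rendered with the "typing." prefix stripped.
print(inspect.formatannotation(ann))  # Union[List[str], int]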
Some files were not shown because too many files have changed in this diff.