Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-31 20:34:54 +08:00)

Compare commits: benchmarki...mlazos/tes (160 commits)
Commit SHAs in this range (author and date columns did not survive extraction):

```
277ff19a55 2b2245d5db 206e9d5160 064bb3cebc 0350c7e72c f7c09f864a c2e9115757 b90fc2ec27
0cd18ba1ca bfae151269 9cbbc2593b 5616fa4a68 c33fc9dae3 9ce2732b68 dbad6d71c7 b85c460749
6a781619bf c99e91b1d7 c014e4bcaa daff263062 15e9119a69 7368eeba5e 7a79de1c0f bd10ea4e6c
43390d8b13 ad26ec6abe 3e71016459 489afa829a 472773c7f9 f01e628e3b 932733e0e6 108422ac26
da4aacabac 9b5308cd58 b019a33f8f 0fab32290a faf973da5e 78624679a8 5f1c3c67b2 bbda22e648
0f3db20132 eb93c0adb1 1193bf0855 26aa8dcf27 5acb8d5080 abc2264e8f 22a4cabd19 ed1ff7d0fb
2f03673ebf f57754e815 d6edefefbf d89d213118 22641f42b6 967937872f f9dc20c7a3 fb67fa9968
35fc5c49b4 b6b9311f4f bbdf469f0e 2120eeb8de 1b569e5490 30ac7f4d4e 65d8dba735 3bdceab124
802ffd06c8 fc0135ca11 3027051590 e7bf72c908 7183f52675 8002d22ce3 31f95b5d2e 4b1f047a33
ba3f91af97 0f81c7a28d 7e8532077f 1ece53b157 9d6f0d5991 3c05167489 aec3ef1008 dc82e911e7
639f459cb6 f889dea97d 208965a9d6 5a7442b91f d66a55def0 382b38ed1b bcbd2a22b2 0df96e3921
30f7079c93 d173ba5a75 0fdd568b78 a4b0023f3b ba51f4876d 852b99eba0 20ee5f9044 9c06dff1ce
c3de2c7c6b 4a302b5731 adfd5b293a 0289313551 58ead04ee9 172015fc11 9371491529 d6cb0fe576
0134150ebb 61bfb3df9f 2c1cb38d95 5b6fd277f9 818f76a745 dc0f09a478 0c6c7780d9 9ba67e99bb
d5e0704247 43b18d098b b040d63ce4 7d17253af8 fdbf314278 c7e8e8ee19 1237f271aa 08fdc64c86
83a0e4e6f9 2bc8fec744 cb56df55dc 629fca295e 3afbab66f7 e8f5c24d17 20ec61a02f 5a21d6f982
0db9c64d68 6f992e1b3f 634ce22601 8883e494b3 41092cb86c 733e684b11 2c6f24c62d 53b0f6f543
ef1d45b12d d6e29bf875 3c74a72ea0 cd9ff41282 447b481c79 9c7ed3e46e 07343efc15 b394c6e89c
c0864bb389 316e7a9293 2d932a2e01 4613081b72 946a4c2bdc ba0a91b3ea 22a1b3b5d0 40abb2b403
2b3ac17aa2 81b7c96697 6cda280483 bbd45f1f1f 0f0d5749a0 65b1aedd09 3e05a48927 d865b784e4
```
```diff
@@ -27,6 +27,7 @@ if [ "$DESIRED_CUDA" = "cpu" ]; then
    USE_PRIORITIZED_TEXT_FOR_LD=1 python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn
else
    echo "BASE_CUDA_VERSION is set to: $DESIRED_CUDA"
    export USE_SYSTEM_NCCL=1
    #USE_PRIORITIZED_TEXT_FOR_LD for enable linker script optimization https://github.com/pytorch/pytorch/pull/121975/files
    USE_PRIORITIZED_TEXT_FOR_LD=1 python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn --enable-cuda
fi
```
```diff
@@ -8,16 +8,6 @@ retry () {
    "$@" || (sleep 10 && "$@") || (sleep 20 && "$@") || (sleep 40 && "$@")
}

# A bunch of custom pip dependencies for ONNX
pip_install \
  beartype==0.15.0 \
  filelock==3.9.0 \
  flatbuffers==2.0 \
  mock==5.0.1 \
  ninja==1.10.2 \
  networkx==2.5 \
  numpy==1.24.2

# ONNXRuntime should be installed before installing
# onnx-weekly. Otherwise, onnx-weekly could be
# overwritten by onnx.
@@ -29,11 +19,8 @@ pip_install \
  transformers==4.36.2

pip_install coloredlogs packaging

pip_install onnxruntime==1.18.1
pip_install onnxscript==0.2.6 --no-deps
# required by onnxscript
pip_install ml_dtypes
pip_install onnxscript==0.3.0

# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
# package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
```
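For reference, the caching step described above amounts to instantiating the pinned model once so its weights land under `~/.cache/huggingface/hub/`. A minimal sketch of that warm-up, with a placeholder model id (the script pins its own choice):

```python
from transformers import AutoModel, AutoTokenizer

# Placeholder model id; the CI script downloads whichever model the tests pin.
MODEL_ID = "sshleifer/tiny-gpt2"

# from_pretrained downloads on first use and caches under ~/.cache/huggingface/hub/
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModel.from_pretrained(MODEL_ID)
```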
```diff
@@ -51,7 +51,12 @@ as_jenkins git clone --recursive ${TRITON_REPO} triton
cd triton
as_jenkins git checkout ${TRITON_PINNED_COMMIT}
as_jenkins git submodule update --init --recursive
cd python

# Old versions of python have setup.py in ./python; newer versions have it in ./
if [ ! -f setup.py ]; then
  cd python
fi

pip_install pybind11==2.13.6

# TODO: remove patch setup.py once we have a proper fix for https://github.com/triton-lang/triton/issues/4527
```
```diff
@@ -41,14 +41,11 @@ fbscribelogger==0.1.7
#Pinned versions: 0.1.6
#test that import:

flatbuffers==2.0 ; platform_machine != "s390x"
flatbuffers==24.12.23
#Description: cross platform serialization library
#Pinned versions: 2.0
#Pinned versions: 24.12.23
#test that import:

flatbuffers ; platform_machine == "s390x"
#Description: cross platform serialization library; Newer version is required on s390x for new python version

hypothesis==5.35.1
# Pin hypothesis to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136
#Description: advanced library for generating parametrized tests
```
```diff
@@ -15,6 +15,9 @@ export INSTALL_TEST=0 # dont install test binaries into site-packages
export USE_CUPTI_SO=0
export USE_CUSPARSELT=${USE_CUSPARSELT:-1} # Enable if not disabled by libtorch build
export USE_CUFILE=${USE_CUFILE:-1}
export USE_SYSTEM_NCCL=1
export NCCL_INCLUDE_DIR="/usr/local/cuda/include/"
export NCCL_LIB_DIR="/usr/local/cuda/lib64/"

# Keep an array of cmake variables to add to
if [[ -z "$CMAKE_ARGS" ]]; then
@@ -172,12 +175,9 @@ if [[ $CUDA_VERSION == 12* ]]; then
        export LIB_SO_RPATH=$CUDA_RPATHS':$ORIGIN'
        export FORCE_RPATH="--force-rpath"
        export USE_STATIC_NCCL=0
        export USE_SYSTEM_NCCL=1
        export ATEN_STATIC_CUDA=0
        export USE_CUDA_STATIC_LINK=0
        export USE_CUPTI_SO=1
        export NCCL_INCLUDE_DIR="/usr/local/cuda/include/"
        export NCCL_LIB_DIR="/usr/local/cuda/lib64/"
    fi
elif [[ $CUDA_VERSION == "11.8" ]]; then
    export USE_STATIC_CUDNN=0
@@ -254,12 +254,9 @@ elif [[ $CUDA_VERSION == "11.8" ]]; then
        export LIB_SO_RPATH=$CUDA_RPATHS':$ORIGIN'
        export FORCE_RPATH="--force-rpath"
        export USE_STATIC_NCCL=0
        export USE_SYSTEM_NCCL=1
        export ATEN_STATIC_CUDA=0
        export USE_CUDA_STATIC_LINK=0
        export USE_CUPTI_SO=1
        export NCCL_INCLUDE_DIR="/usr/local/cuda/include/"
        export NCCL_LIB_DIR="/usr/local/cuda/lib64/"
    fi
else
    echo "Unknown cuda version $CUDA_VERSION"
```
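Taken together, these hunks hoist the system-NCCL settings out of the per-CUDA-version branches so they apply unconditionally. A rough sketch of the equivalent manual configuration (the build command is illustrative, not the CI entry point):

```sh
# Link against the NCCL installed on the system rather than the bundled copy.
export USE_SYSTEM_NCCL=1
export NCCL_INCLUDE_DIR="/usr/local/cuda/include/"
export NCCL_LIB_DIR="/usr/local/cuda/lib64/"

# Illustrative: a PyTorch build driven from setup.py picks these up via CMake.
python setup.py bdist_wheel
```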
```diff
@@ -324,6 +324,12 @@ test_python_smoke() {
  assert_git_not_dirty
}

test_h100_distributed() {
  # Distributed tests at H100
  time python test/run_test.py --include distributed/_composable/test_composability/test_pp_composability.py  $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
  assert_git_not_dirty
}

test_lazy_tensor_meta_reference_disabled() {
  export TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE=1
  echo "Testing lazy tensor operations without meta reference"
@@ -595,7 +601,6 @@ test_perf_for_dashboard() {
    elif [[ "${TEST_CONFIG}" == *cpu_aarch64* ]]; then
      device=cpu_aarch64
    fi
    test_inductor_set_cpu_affinity
  elif [[ "${TEST_CONFIG}" == *cuda_a10g* ]]; then
    device=cuda_a10g
  elif [[ "${TEST_CONFIG}" == *h100* ]]; then
@@ -604,6 +609,9 @@ test_perf_for_dashboard() {
    device=rocm
  fi

  # Always set CPU affinity because metrics like compilation time requires CPU
  test_inductor_set_cpu_affinity

  for mode in "${modes[@]}"; do
    if [[ "$mode" == "inference" ]]; then
      dtype=bfloat16
@@ -1639,7 +1647,7 @@ elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then
    install_torchaudio cuda
  fi
  install_torchvision
  TORCH_CUDA_ARCH_LIST="8.0;8.6" pip_install git+https://github.com/pytorch/ao.git
  TORCH_CUDA_ARCH_LIST="8.0;8.6" install_torchao
  id=$((SHARD_NUMBER-1))
  # https://github.com/opencv/opencv-python/issues/885
  pip_install opencv-python==4.8.0.74
@@ -1724,6 +1732,8 @@ elif [[ "${BUILD_ENVIRONMENT}" == *xpu* ]]; then
  test_xpu_bin
elif [[ "${TEST_CONFIG}" == smoke ]]; then
  test_python_smoke
elif [[ "${TEST_CONFIG}" == h100_distributed ]]; then
  test_h100_distributed
else
  install_torchvision
  install_monkeytype
```
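The new shard is selected purely through the TEST_CONFIG environment variable, so a local reproduction would look roughly like the following (assuming an H100 machine, a built checkout, and that this script is the usual `.ci/pytorch/test.sh` entry point):

```sh
# Hypothetical local run of the new distributed shard:
TEST_CONFIG=h100_distributed .ci/pytorch/test.sh
```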
.github/actions/reuse-old-whl/reuse_old_whl.py (21 lines changed, vendored)
```diff
@@ -120,6 +120,23 @@ def ok_changed_file(file: str) -> bool:
def check_changed_files(sha: str) -> bool:
    # Return true if all the changed files are in the list of allowed files to
    # be changed to reuse the old whl

    # Removing any files is not allowed since rysnc will not remove files
    removed_files = (
        subprocess.check_output(
            ["git", "diff", "--name-only", sha, "HEAD", "--diff-filter=D"],
            text=True,
            stderr=subprocess.DEVNULL,
        )
        .strip()
        .split()
    )
    if removed_files:
        print(
            f"Removed files between {sha} and HEAD: {removed_files}, cannot reuse old whl"
        )
        return False

    changed_files = (
        subprocess.check_output(
            ["git", "diff", "--name-only", sha, "HEAD"],
@@ -190,6 +207,10 @@ def unzip_artifact_and_replace_files() -> None:
        subprocess.check_output(
            ["unzip", "-o", new_path, "-d", f"artifacts/dist/{new_path.stem}"],
        )

        # Remove the old wheel (which is now a zip file)
        os.remove(new_path)

        # Copy python files into the artifact
        subprocess.check_output(
            ["rsync", "-avz", "torch", f"artifacts/dist/{new_path.stem}"],
```
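The new guard relies on git's `--diff-filter=D`, which restricts the listing to deletions; the same check can be run by hand (the base SHA is a placeholder):

```sh
# Any output here means the old wheel cannot be reused, because rsync
# will copy new and changed files but never propagate deletions.
git diff --name-only <base-sha> HEAD --diff-filter=D
```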
.github/pytorch-probot.yml (1 line changed, vendored)
```diff
@@ -28,6 +28,7 @@ ciflow_push_tags:
- ciflow/op-benchmark
- ciflow/pull
- ciflow/h100
- ciflow/h100-distributed
retryable_workflows:
- pull
- trunk
```
.github/scripts/build_triton_wheel.py (15 lines changed, vendored)
```diff
@@ -65,6 +65,7 @@ def build_triton(
    with TemporaryDirectory() as tmpdir:
        triton_basedir = Path(tmpdir) / "triton"
        triton_pythondir = triton_basedir / "python"

        triton_repo = "https://github.com/openai/triton"
        if device == "rocm":
            triton_pkg_name = "pytorch-triton-rocm"
@@ -101,11 +102,19 @@ def build_triton(
            )
            print("ROCm libraries setup for triton installation...")

        check_call(
            [sys.executable, "setup.py", "bdist_wheel"], cwd=triton_pythondir, env=env
        # old triton versions have setup.py in the python/ dir,
        # new versions have it in the root dir.
        triton_setupdir = (
            triton_basedir
            if (triton_basedir / "setup.py").exists()
            else triton_pythondir
        )

        whl_path = next(iter((triton_pythondir / "dist").glob("*.whl")))
        check_call(
            [sys.executable, "setup.py", "bdist_wheel"], cwd=triton_setupdir, env=env
        )

        whl_path = next(iter((triton_setupdir / "dist").glob("*.whl")))
        shutil.copy(whl_path, Path.cwd())

        if device == "rocm":
```
.github/workflows/build-triton-wheel.yml (9 lines changed, vendored)
```diff
@@ -139,6 +139,15 @@ jobs:

          docker exec -t "${container_name}" yum install -y zlib-devel zip
          docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}"  -m pip install -U setuptools==78.1.0 pybind11==2.13.1 auditwheel wheel
          set +e
          docker exec -t "${container_name}" command -v pip
          has_pip=$?
          set -e
          if [ $has_pip -eq 0 ] ; then
              docker exec -t "${container_name}" pip install -U cmake --force-reinstall
          else
              docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}"  -m pip install -U cmake --force-reinstall
          fi

          if [[ ("${{ matrix.device }}" == "cuda" || "${{ matrix.device }}" == "rocm" || "${{ matrix.device }}" == "aarch64" ) ]]; then
            # With this install, it gets clang 16.0.6.
```
.github/workflows/h100-distributed.yml (new file, 53 lines, vendored)
```diff
@@ -0,0 +1,53 @@
name: Limited CI for distributed tests on H100

on:
  pull_request:
    paths:
      - .github/workflows/h100-distributed.yml
  workflow_dispatch:
  push:
    tags:
      - ciflow/h100-distributed/*

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
  cancel-in-progress: true

jobs:

  get-label-type:
    if: github.repository_owner == 'pytorch'
    name: get-label-type
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
    with:
      triggering_actor: ${{ github.triggering_actor }}
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
      curr_branch: ${{ github.head_ref || github.ref_name }}
      curr_ref_type: ${{ github.ref_type }}

  linux-focal-cuda12_6-py3_10-gcc11-sm90-build:
    name: linux-focal-cuda12.6-py3.10-gcc11-sm90
    uses: ./.github/workflows/_linux-build.yml
    needs: get-label-type
    with:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      runner: "linux.12xlarge"
      build-environment: linux-focal-cuda12.6-py3.10-gcc11-sm90
      docker-image-name: ci-image:pytorch-linux-focal-cuda12.6-cudnn9-py3-gcc11
      cuda-arch-list: '9.0'
      test-matrix: |
        { include: [
          { config: "h100_distributed", shard: 1, num_shards: 1, runner: "linux.aws.h100.8" },
        ]}
    secrets: inherit

  linux-focal-cuda12_6-py3_10-gcc11-sm90-test:
    name: linux-focal-cuda12.6-py3.10-gcc11-sm90
    uses: ./.github/workflows/_linux-test.yml
    needs:
      - linux-focal-cuda12_6-py3_10-gcc11-sm90-build
    with:
      build-environment: linux-focal-cuda12.6-py3.10-gcc11-sm90
      docker-image: ${{ needs.linux-focal-cuda12_6-py3_10-gcc11-sm90-build.outputs.docker-image }}
      test-matrix: ${{ needs.linux-focal-cuda12_6-py3_10-gcc11-sm90-build.outputs.test-matrix }}
    secrets: inherit
```
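Besides edits to the workflow file itself and manual dispatch, the push trigger above fires on the registered ciflow tag. A hedged sketch of kicking it off for a pull request (12345 stands in for a real PR number; in practice these tags are usually applied through the ciflow bot rather than by hand):

```sh
git push origin HEAD:refs/tags/ciflow/h100-distributed/12345
```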
```diff
@@ -64,6 +64,7 @@ include_patterns = [
    'aten/src/ATen/xpu/**/*.cpp',
    'aten/src/ATen/core/boxing/**/*.h',
    'aten/src/ATen/core/dispatch/**/*.h',
    'aten/src/ATen/core/Formatting.cpp',
    'aten/src/ATen/native/mps/**/*.metal',
    'aten/src/ATen/native/mps/**/*.mm',
    'aten/src/ATen/native/mps/**/*.h',
```
```diff
@@ -290,6 +290,7 @@ header_template_rule(
    substitutions = {
        "@AT_CUDNN_ENABLED@": "1",
        "@AT_CUSPARSELT_ENABLED@": "0",
        "@AT_HIPSPARSELT_ENABLED@": "0",
        "@AT_ROCM_ENABLED@": "0",
        "@AT_MAGMA_ENABLED@": "0",
        "@NVCC_FLAGS_EXTRA@": "",
```
```diff
@@ -101,6 +101,13 @@ else()
  set(AT_CUSPARSELT_ENABLED 1)
endif()

# Add hipSPARSELt support flag
if(USE_ROCM AND ROCM_VERSION VERSION_GREATER_EQUAL "6.4.0")
  set(AT_HIPSPARSELT_ENABLED 1)
else()
  set(AT_HIPSPARSELT_ENABLED 0)
endif()

list(APPEND ATen_CPU_INCLUDE
  ${CMAKE_CURRENT_SOURCE_DIR}/src)
add_subdirectory(src/ATen)
```
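Code that needs hipSPARSELt can then be gated at compile time on the generated flag; per the CUDAConfig.h.in hunk further down, it is exposed as a function-like macro. A minimal sketch (the kernel calls are placeholders):

```cpp
#include <ATen/cuda/CUDAConfig.h> // generated header; defines AT_HIPSPARSELT_ENABLED()

void run_sparse_matmul() {
#if AT_HIPSPARSELT_ENABLED()
  // ROCm >= 6.4.0 build: take the hipSPARSELt-backed path (placeholder).
  // launch_hipsparselt_matmul();
#else
  // Builds without hipSPARSELt fall back to a dense path (placeholder).
  // launch_dense_matmul();
#endif
}
```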
```diff
@@ -34,6 +34,7 @@ set_bool(AT_MAGMA_ENABLED USE_MAGMA)
set_bool(CAFFE2_STATIC_LINK_CUDA_INT CAFFE2_STATIC_LINK_CUDA)
set_bool(AT_CUDNN_ENABLED CAFFE2_USE_CUDNN)
set_bool(AT_CUSPARSELT_ENABLED CAFFE2_USE_CUSPARSELT)
set_bool(AT_HIPSPARSELT_ENABLED CAFFE2_USE_HIPSPARSELT)

configure_file(Config.h.in "${CMAKE_CURRENT_SOURCE_DIR}/Config.h")
# TODO: Do not generate CUDAConfig.h for ROCm BUILDS
```
```diff
@@ -28,8 +28,7 @@ c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) {
      opt_device_type = at::getAccelerator(false);
    }
    if (opt_device_type.has_value()) {
      return at::globalContext().getPinnedMemoryAllocator(
          opt_device_type.value());
      return at::globalContext().getPinnedMemoryAllocator(opt_device_type);
    } else {
      TORCH_CHECK(
          false, "Need to provide pin_memory allocator to use pin memory.")
```
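On the Python side this allocator sits behind pin_memory; a quick sketch, assuming a build with an accelerator available (on a CPU-only build the TORCH_CHECK above fires instead):

```python
import torch

# pin_memory=True routes through GetCPUAllocatorMaybePinned; with an
# accelerator present, its pinned-memory allocator backs the buffer.
t = torch.empty(1024, pin_memory=True)
print(t.is_pinned())
```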
```diff
@@ -1,18 +1,22 @@
#include <ATen/core/Formatting.h>
#include <c10/util/irange.h>
#include <fmt/compile.h>
#include <fmt/format.h>
#include <fmt/ostream.h>

#include <cmath>
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <string>
#include <tuple>

namespace c10 {
std::ostream& operator<<(std::ostream & out, Backend b) {
std::ostream& operator<<(std::ostream& out, Backend b) {
  return out << toString(b);
}

std::ostream& operator<<(std::ostream & out, const Scalar& s) {
std::ostream& operator<<(std::ostream& out, const Scalar& s) {
  if (s.isFloatingPoint()) {
    return out << s.toDouble();
  }
@@ -35,179 +39,189 @@ std::ostream& operator<<(std::ostream & out, const Scalar& s) {
}

std::string toString(const Scalar& s) {
  std::stringstream out;
  out << s;
  return std::move(out).str();
}
  return fmt::format("{}", fmt::streamed(s));
}
} // namespace c10

namespace at {

//not all C++ compilers have default float so we define our own here
inline static std::ios_base& defaultfloat(std::ios_base& __base) {
  __base.unsetf(std::ios_base::floatfield);
  return __base;
}
//saves/restores number formatting inside scope
struct FormatGuard {
  FormatGuard(std::ostream & out)
  : out(out) {
    saved.copyfmt(out);
  }
  ~FormatGuard() {
    out.copyfmt(saved);
  }
  FormatGuard(const FormatGuard&) = delete;
  FormatGuard(FormatGuard&&) = delete;
  FormatGuard& operator=(const FormatGuard&) = delete;
  FormatGuard& operator=(FormatGuard&&) = delete;
private:
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  std::ostream & out;
  std::ios saved{nullptr};
};

std::ostream& operator<<(std::ostream & out, const DeprecatedTypeProperties& t) {
std::ostream& operator<<(std::ostream& out, const DeprecatedTypeProperties& t) {
  return out << t.toString();
}

static std::tuple<double, int> __printFormat(std::ostream& stream, const Tensor& self) {
enum class FormatType {
  Default, // 'g' format (defaultfloat equivalent)
  Scientific, // 'e' format with precision 4
  Fixed // 'f' format with precision 4
};

struct PrintFormat {
  double scale;
  int width;
  FormatType type;

  PrintFormat(double s, int w, FormatType t = FormatType::Default)
      : scale(s), width(w), type(t) {}
};

static PrintFormat __printFormat(const Tensor& self) {
  auto size = self.numel();
  if(size == 0) {
    return std::make_tuple(1., 0);
  if (size == 0) {
    return PrintFormat(1., 0);
  }

  bool intMode = true;
  auto self_p = self.const_data_ptr<double>();
  for (const auto i : c10::irange(size)) {
    auto z = self_p[i];
    if(std::isfinite(z)) {
      if(z != std::ceil(z)) {
    if (std::isfinite(z)) {
      if (z != std::ceil(z)) {
        intMode = false;
        break;
      }
    }
  }

  int64_t offset = 0;
  while(!std::isfinite(self_p[offset])) {
  while (offset < size && !std::isfinite(self_p[offset])) {
    offset = offset + 1;
    if(offset == size) {
      break;
    }
  }

  double expMin = 1;
  double expMax = 1;
  if(offset != size) {
    expMin = fabs(self_p[offset]);
    expMax = fabs(self_p[offset]);
  if (offset != size) {
    expMin = std::fabs(self_p[offset]);
    expMax = std::fabs(self_p[offset]);
    for (const auto i : c10::irange(offset, size)) {
      double z = fabs(self_p[i]);
      if(std::isfinite(z)) {
        if(z < expMin) {
          expMin = z;
        }
        if(self_p[i] > expMax) {
          expMax = z;
        }
      double z = std::fabs(self_p[i]);
      if (std::isfinite(z)) {
        expMin = std::min(expMin, z);
        expMax = std::max(expMax, z);
      }
    }
    if(expMin != 0) {
    if (expMin != 0) {
      expMin = std::floor(std::log10(expMin)) + 1;
    } else {
      expMin = 1;
    }
    if(expMax != 0) {
    if (expMax != 0) {
      expMax = std::floor(std::log10(expMax)) + 1;
    } else {
      expMax = 1;
    }
  }

  double scale = 1;
  int sz = 11;
  if(intMode) {
    if(expMax > 9) {

  if (intMode) {
    if (expMax > 9) {
      sz = 11;
      stream << std::scientific << std::setprecision(4);
      return PrintFormat(scale, sz, FormatType::Scientific);
    } else {
      sz = static_cast<int>(expMax) + 1;
      stream << defaultfloat;
      return PrintFormat(scale, sz, FormatType::Default);
    }
  } else {
    if(expMax-expMin > 4) {
    if (expMax - expMin > 4) {
      sz = 11;
      if(std::fabs(expMax) > 99 || std::fabs(expMin) > 99) {
      if (std::fabs(expMax) > 99 || std::fabs(expMin) > 99) {
        sz = sz + 1;
      }
      stream << std::scientific << std::setprecision(4);
      return PrintFormat(scale, sz, FormatType::Scientific);
    } else {
      if(expMax > 5 || expMax < 0) {
      if (expMax > 5 || expMax < 0) {
        sz = 7;
        scale = std::pow(10, expMax-1);
        stream << std::fixed << std::setprecision(4);
        scale = std::pow(10, expMax - 1);
        return PrintFormat(scale, sz, FormatType::Fixed);
      } else {
        if(expMax == 0) {
        if (expMax == 0) {
          sz = 7;
        } else {
          sz = static_cast<int>(expMax) + 6;
        }
        stream << std::fixed << std::setprecision(4);
        return PrintFormat(scale, sz, FormatType::Fixed);
      }
    }
  }
  return std::make_tuple(scale, sz);
}

static void __printIndent(std::ostream &stream, int64_t indent)
{
  for ([[maybe_unused]] const auto i : c10::irange(indent)) {
    stream << " ";
// Precompiled format specs
static constexpr auto FMT_G = FMT_COMPILE("{:>{}g}");
static constexpr auto FMT_E4 = FMT_COMPILE("{:>{}.4e}");
static constexpr auto FMT_F4 = FMT_COMPILE("{:>{}.4f}");

// Print a single value directly into the stream buffer with no temporaries
static void printValue(std::ostream& stream, double v, const PrintFormat& pf) {
  auto out_it = std::ostreambuf_iterator<char>(stream);
  double val = v / pf.scale;
  switch (pf.type) {
    case FormatType::Default:
      fmt::format_to(out_it, FMT_G, val, pf.width);
      break;
    case FormatType::Scientific:
      fmt::format_to(out_it, FMT_E4, val, pf.width);
      break;
    case FormatType::Fixed:
      fmt::format_to(out_it, FMT_F4, val, pf.width);
      break;
  }
}

static void printScale(std::ostream & stream, double scale) {
  FormatGuard guard(stream);
  stream << defaultfloat << scale << " *" << '\n';
}
static void __printMatrix(std::ostream& stream, const Tensor& self, int64_t linesize, int64_t indent)
{
  auto [scale, sz] = __printFormat(stream, self);
static void __printMatrix(
    std::ostream& stream,
    const Tensor& self,
    int64_t linesize,
    int64_t indent) {
  auto printFmt = __printFormat(self);

  __printIndent(stream, indent);
  int64_t nColumnPerLine = (linesize-indent)/(sz+1);
  int64_t nColumnPerLine = (linesize - indent) / (printFmt.width + 1);
  int64_t firstColumn = 0;
  int64_t lastColumn = -1;
  while(firstColumn < self.size(1)) {
    if(firstColumn + nColumnPerLine <= self.size(1)) {

  while (firstColumn < self.size(1)) {
    if (firstColumn + nColumnPerLine <= self.size(1)) {
      lastColumn = firstColumn + nColumnPerLine - 1;
    } else {
      lastColumn = self.size(1) - 1;
    }
    if(nColumnPerLine < self.size(1)) {
      if(firstColumn != 0) {
        stream << '\n';

    if (nColumnPerLine < self.size(1)) {
      if (firstColumn != 0) {
        stream.put('\n');
      }
      stream << "Columns " << firstColumn+1 << " to " << lastColumn+1;
      __printIndent(stream, indent);
      fmt::print(
          stream,
          "Columns {} to {}{:>{}s}",
          firstColumn + 1,
          lastColumn + 1,
          "", // empty string to pad
          indent // width to pad to
      );
    }
    if(scale != 1) {
      printScale(stream,scale);
      __printIndent(stream, indent);

    if (printFmt.scale != 1) {
      fmt::print(stream, "{} *\n{:>{}s}", printFmt.scale, "", indent);
    }

    for (const auto l : c10::irange(self.size(0))) {
      Tensor row = self.select(0,l);
      const double *row_ptr = row.const_data_ptr<double>();
      for (const auto c : c10::irange(firstColumn, lastColumn+1)) {
        stream << std::setw(sz) << row_ptr[c]/scale;
        if(c == lastColumn) {
          stream << '\n';
          if(l != self.size(0)-1) {
            if(scale != 1) {
              __printIndent(stream, indent);
              stream << " ";
      Tensor row = self.select(0, l);
      const double* row_ptr = row.const_data_ptr<double>();

      for (const auto c : c10::irange(firstColumn, lastColumn + 1)) {
        printValue(stream, row_ptr[c], printFmt);

        if (c == lastColumn) {
          stream.put('\n');
          if (l != self.size(0) - 1) {
            if (printFmt.scale != 1) {
              fmt::print(stream, "{:>{}s} ", "", indent);
            } else {
              __printIndent(stream, indent);
              fmt::print(stream, "{:>{}s}", "", indent);
            }
          }
        } else {
          stream << " ";
          stream.put(' ');
        }
      }
    }
@@ -215,20 +229,21 @@ static void __printMatrix(std::ostream& stream, const Tensor& self, int64_t line
  }
}

static void __printTensor(std::ostream& stream, Tensor& self, int64_t linesize)
{
  std::vector<int64_t> counter(self.ndimension()-2);
static void __printTensor(
    std::ostream& stream,
    Tensor& self,
    int64_t linesize) {
  std::vector<int64_t> counter(self.ndimension() - 2, 0);
  counter[0] = -1;

  bool start = true;
  bool finished = false;
  counter[0] = -1;
  for (const auto i : c10::irange(1, counter.size())) {
    counter[i] = 0;
  }
  while(true) {
    for(int64_t i = 0; self.ndimension()-2; i++) {

  while (true) {
    for (int64_t i = 0; self.ndimension() - 2; i++) {
      counter[i] = counter[i] + 1;
      if(counter[i] >= self.size(i)) {
        if(i == self.ndimension()-3) {
      if (counter[i] >= self.size(i)) {
        if (i == self.ndimension() - 3) {
          finished = true;
          break;
        }
@@ -237,108 +252,133 @@ static void __printTensor(std::ostream& stream, Tensor& self, int64_t linesize)
        break;
      }
    }
    if(finished) {
    if (finished) {
      break;
    }
    if(start) {
    if (start) {
      start = false;
    } else {
      stream << '\n';
      stream.put('\n');
    }
    stream << "(";

    stream.put('(');
    Tensor tensor = self;
    for (const auto i : c10::irange(self.ndimension()-2)) {
    for (const auto i : c10::irange(self.ndimension() - 2)) {
      tensor = tensor.select(0, counter[i]);
      stream << counter[i]+1 << ",";
      fmt::print(stream, "{},", counter[i] + 1);
    }
    stream << ".,.) = " << '\n';
    fmt::print(stream, ".,.) = \n");
    __printMatrix(stream, tensor, linesize, 1);
  }
}

void print(const Tensor & t, int64_t linesize) {
  print(std::cout,t,linesize);
void print(const Tensor& t, int64_t linesize) {
  print(std::cout, t, linesize);
}
std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesize) {
  FormatGuard guard(stream);
  if(!tensor_.defined()) {
    stream << "[ Tensor (undefined) ]";
  } else if (tensor_.is_sparse()) {
    stream << "[ " << tensor_.toString() << "{}\n";
    stream << "indices:\n" << tensor_._indices() << "\n";
    stream << "values:\n" << tensor_._values() << "\n";
    stream << "size:\n" << tensor_.sizes() << "\n";
    stream << "]";
  } else {
    Tensor tensor;
    if (tensor_.is_quantized()) {
      tensor = tensor_.dequantize().to(kCPU, kDouble).contiguous();
    } else if (tensor_.is_mkldnn()) {
      stream << "MKLDNN Tensor: ";
      tensor = tensor_.to_dense().to(kCPU, kDouble).contiguous();
    } else if (tensor_.is_mps()) {
      // MPS does not support double tensors, so first copy then convert
      tensor = tensor_.to(kCPU).to(kDouble).contiguous();
    } else {
      tensor = tensor_.to(kCPU, kDouble).contiguous();
    }
    if(tensor.ndimension() == 0) {
      stream << defaultfloat << tensor.const_data_ptr<double>()[0] << '\n';
      stream << "[ " << tensor_.toString() << "{}";
    } else if(tensor.ndimension() == 1) {
      if (tensor.numel() > 0) {
        auto [scale, sz] = __printFormat(stream, tensor);
        if(scale != 1) {
          printScale(stream, scale);
        }
        const double* tensor_p = tensor.const_data_ptr<double>();
        for (const auto i : c10::irange(tensor.size(0))) {
          stream << std::setw(sz) << tensor_p[i]/scale << '\n';
        }
      }
      stream << "[ " << tensor_.toString() << "{" << tensor.size(0) << "}";
    } else if(tensor.ndimension() == 2) {
      if (tensor.numel() > 0) {
        __printMatrix(stream, tensor, linesize, 0);
      }
      stream << "[ " << tensor_.toString() << "{" << tensor.size(0) << "," <<  tensor.size(1) << "}";
    } else {
      if (tensor.numel() > 0) {
        __printTensor(stream, tensor, linesize);
      }
      stream << "[ " << tensor_.toString() << "{" << tensor.size(0);
      for (const auto i : c10::irange(1, tensor.ndimension())) {
        stream << "," << tensor.size(i);
      }
      stream << "}";
    }
    if (tensor_.is_quantized()) {
      stream << ", qscheme: " << toString(tensor_.qscheme());
      if (tensor_.qscheme() == c10::kPerTensorAffine) {
        stream << ", scale: " << tensor_.q_scale();
        stream << ", zero_point: " << tensor_.q_zero_point();
      } else if (tensor_.qscheme() == c10::kPerChannelAffine ||
          tensor_.qscheme() == c10::kPerChannelAffineFloatQParams) {
        stream << ", scales: ";
        Tensor scales = tensor_.q_per_channel_scales();
        print(stream, scales, linesize);
        stream << ", zero_points: ";
        Tensor zero_points = tensor_.q_per_channel_zero_points();
        print(stream, zero_points, linesize);
        stream << ", axis: " << tensor_.q_per_channel_axis();
      }
    }

    // Proxy check for if autograd was built
    if (tensor.getIntrusivePtr()->autograd_meta()) {
      auto& fw_grad = tensor._fw_grad(/* level */ 0);
      if (fw_grad.defined()) {
        stream << ", tangent:" << '\n' << fw_grad;
std::ostream& print(
    std::ostream& stream,
    const Tensor& tensor_,
    int64_t linesize) {
  if (!tensor_.defined()) {
    fmt::print(stream, "[ Tensor (undefined) ]");
    return stream;
  }

  if (tensor_.is_sparse()) {
    fmt::print(stream, "[ {}{{}}\nindices:\n", tensor_.toString());
    print(stream, tensor_._indices(), linesize);
    fmt::print(stream, "\nvalues:\n");
    print(stream, tensor_._values(), linesize);
    fmt::print(stream, "\nsize:\n{}\n]", fmt::streamed(tensor_.sizes()));
    return stream;
  }

  Tensor tensor;

  if (tensor_.is_quantized()) {
    tensor = tensor_.dequantize().to(kCPU, kDouble).contiguous();
  } else if (tensor_.is_mkldnn()) {
    fmt::print(stream, "MKLDNN Tensor: ");
    tensor = tensor_.to_dense().to(kCPU, kDouble).contiguous();
  } else if (tensor_.is_mps()) {
    // MPS does not support double tensors, so first copy then convert
    tensor = tensor_.to(kCPU).to(kDouble).contiguous();
  } else {
    tensor = tensor_.to(kCPU, kDouble).contiguous();
  }

  if (tensor.ndimension() == 0) {
    fmt::print(
        stream,
        "{}\n[ {}{{}}",
        tensor.const_data_ptr<double>()[0],
        tensor_.toString());
  } else if (tensor.ndimension() == 1) {
    if (tensor.numel() > 0) {
      auto printFmt = __printFormat(tensor);
      if (printFmt.scale != 1) {
        fmt::print(stream, "{} *\n", printFmt.scale);
      }
      const double* tensor_p = tensor.const_data_ptr<double>();
      for (const auto i : c10::irange(tensor.size(0))) {
        printValue(stream, tensor_p[i], printFmt);
        stream.put('\n');
      }
    }
    stream << " ]";
    fmt::print(stream, "[ {}{{{}}}", tensor_.toString(), tensor.size(0));
  } else if (tensor.ndimension() == 2) {
    if (tensor.numel() > 0) {
      __printMatrix(stream, tensor, linesize, 0);
    }
    fmt::print(
        stream,
        "[ {}{{{},{}}}",
        tensor_.toString(),
        tensor.size(0),
        tensor.size(1));
  } else {
    if (tensor.numel() > 0) {
      __printTensor(stream, tensor, linesize);
    }
    fmt::print(stream, "[ {}{{{}", tensor_.toString(), tensor.size(0));
    for (const auto i : c10::irange(1, tensor.ndimension())) {
      fmt::print(stream, ",{}", tensor.size(i));
    }
    fmt::print(stream, "}}");
  }

  // Add quantization info
  if (tensor_.is_quantized()) {
    fmt::print(stream, ", qscheme: {}", toString(tensor_.qscheme()));
    if (tensor_.qscheme() == c10::kPerTensorAffine) {
      fmt::print(
          stream,
          ", scale: {}, zero_point: {}",
          tensor_.q_scale(),
          tensor_.q_zero_point());
    } else if (
        tensor_.qscheme() == c10::kPerChannelAffine ||
        tensor_.qscheme() == c10::kPerChannelAffineFloatQParams) {
      fmt::print(stream, ", scales: ");
      print(stream, tensor_.q_per_channel_scales(), linesize);
      fmt::print(stream, ", zero_points: ");
      print(stream, tensor_.q_per_channel_zero_points(), linesize);
      fmt::print(stream, ", axis: {}", tensor_.q_per_channel_axis());
    }
  }

  // Proxy check for if autograd was built
  if (tensor.getIntrusivePtr()->autograd_meta()) {
    auto& fw_grad = tensor._fw_grad(/* level */ 0);
    if (fw_grad.defined()) {
      fmt::print(stream, ", tangent:\n");
      print(stream, fw_grad, linesize);
    }
  }

  fmt::print(stream, " ]");
  return stream;
}

}
} // namespace at
```
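The core idiom this rewrite introduces is formatting straight into the stream buffer through an ostreambuf_iterator with precompiled format specs, avoiding the temporary strings and stateful iostream flags of the old code. A standalone sketch of that pattern:

```cpp
#include <fmt/compile.h>
#include <fmt/format.h>

#include <iostream>
#include <iterator>

// Precompiled spec: right-aligned, dynamic width, fixed precision 4,
// mirroring the FMT_F4 spec used by printValue above.
static constexpr auto kFixed4 = FMT_COMPILE("{:>{}.4f}");

int main() {
  auto out_it = std::ostreambuf_iterator<char>(std::cout);
  // Writes "    3.1416" (width 10) directly into cout's buffer.
  fmt::format_to(out_it, kFixed4, 3.14159, 10);
  std::cout << '\n';
}
```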
```diff
@@ -205,7 +205,7 @@ std::pair<Vectorized<double>, Vectorized<double>> inline interleave2<double>(
    const Vectorized<double>& a,
    const Vectorized<double>& b) {
  // inputs:
  //   a = {a0, a1, a3, a3}
  //   a = {a0, a1, a2, a3}
  //   b = {b0, b1, b2, b3}

  // swap lanes:
```
```diff
@@ -8,6 +8,7 @@
// only be included from C++ files.
#define AT_CUDNN_ENABLED() @AT_CUDNN_ENABLED@
#define AT_CUSPARSELT_ENABLED() @AT_CUSPARSELT_ENABLED@
#define AT_HIPSPARSELT_ENABLED() @AT_HIPSPARSELT_ENABLED@
#define AT_ROCM_ENABLED() @AT_ROCM_ENABLED@
#define AT_MAGMA_ENABLED() @AT_MAGMA_ENABLED@
```
```diff
@@ -159,6 +159,7 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
  DispatchKey::XLA,
  DispatchKey::CUDA,
  DispatchKey::CPU,
  DispatchKey::PrivateUse1,
});

inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {
```
```diff
@@ -143,7 +143,7 @@ static Tensor make_feature_noise(const Tensor& input) {
}

static bool is_fused_kernel_acceptable(const Tensor& input, double p) {
  return (input.is_cuda() || input.is_xpu() || input.is_lazy()) && p > 0 && p < 1 && input.numel() > 0;
  return (input.is_cuda() || input.is_xpu() || input.is_lazy() || input.is_privateuseone()) && p > 0 && p < 1 && input.numel() > 0;
}

// NB: sure, we could have used different overloads here, but I would feel insecure
```
```diff
@@ -56,7 +56,8 @@ void dumpTensorCout(const Tensor& tensor) {

static c10::intrusive_ptr<TensorWrapper> makeTensorWrapperPtr(const Tensor& tensor, int64_t level, const std::shared_ptr<bool>& life_handle) {
  auto keys_to_propagate = kKeysToPropagateToWrapper | DispatchKeySet({
      DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA});
      DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA,
      DispatchKey::AutogradPrivateUse1});
  auto key_set = getKeysToPropagateToWrapper(tensor, keys_to_propagate);
  key_set = key_set.add(DispatchKey::FuncTorchGradWrapper);
  return c10::make_intrusive<TensorWrapper>(key_set, tensor, level, life_handle);
@@ -76,7 +77,8 @@ static Tensor unsafeMakeTensorWrapper(
  }

  auto keys_to_propagate = kKeysToPropagateToWrapper | DispatchKeySet({
      DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA});
      DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA,
      DispatchKey::AutogradPrivateUse1});
  auto key_set = getKeysToPropagateToWrapper(tensor, keys_to_propagate);
  key_set = key_set.add(DispatchKey::FuncTorchGradWrapper);
  auto result = at::detail::make_tensor<TensorWrapper>(
```
```diff
@@ -5,6 +5,7 @@
#include <ATen/miopen/miopen-wrapper.h>
#include <ATen/core/Tensor.h>
#include <ATen/TensorUtils.h>
#include <c10/macros/Export.h>

namespace at { namespace native {

@@ -37,9 +38,9 @@ struct DescriptorDeleter {
// initialized the first time you call set() or any other initializing
// function.
template <typename T, miopenStatus_t (*ctor)(T**), miopenStatus_t (*dtor)(T*)>
class Descriptor
{
public:
// NOLINTNEXTLINE(bugprone-exception-escape)
class TORCH_CUDA_CPP_API Descriptor {
 public:
  // Use desc() to access the underlying descriptor pointer in
  // a read-only fashion.  Most client code should use this.
  // If the descriptor was never initialized, this will return
@@ -55,7 +56,7 @@ public:
protected:
  void init() {
    if (desc_ == nullptr) {
      T* raw_desc;
      T* raw_desc = nullptr;
      MIOPEN_CHECK(ctor(&raw_desc));
      desc_.reset(raw_desc);
    }
@@ -64,13 +65,12 @@ private:
  std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
};

class TensorDescriptor
  : public Descriptor<miopenTensorDescriptor,
                      &miopenCreateTensorDescriptor,
                      &miopenDestroyTensorDescriptor>
{
public:
  TensorDescriptor() {}
class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor<
                                               miopenTensorDescriptor,
                                               &miopenCreateTensorDescriptor,
                                               &miopenDestroyTensorDescriptor> {
 public:
  TensorDescriptor() = default;
  explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) {
    set(t, pad);
  }
@@ -88,11 +88,10 @@ private:

std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d);

class FilterDescriptor
  : public Descriptor<miopenTensorDescriptor,
                      &miopenCreateTensorDescriptor,
                      &miopenDestroyTensorDescriptor>
{
class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor<
                                               miopenTensorDescriptor,
                                               &miopenCreateTensorDescriptor,
                                               &miopenDestroyTensorDescriptor> {
 public:
  void set(const at::Tensor &t, int64_t pad = 0) {
    set(t, at::MemoryFormat::Contiguous, pad);
@@ -106,11 +105,11 @@ private:
  }
};

struct ConvolutionDescriptor
  : public Descriptor<miopenConvolutionDescriptor,
                      &miopenCreateConvolutionDescriptor,
                      &miopenDestroyConvolutionDescriptor>
{
struct TORCH_CUDA_CPP_API ConvolutionDescriptor
    : public Descriptor<
          miopenConvolutionDescriptor,
          &miopenCreateConvolutionDescriptor,
          &miopenDestroyConvolutionDescriptor> {
  void set(miopenDataType_t dataType, miopenConvolutionMode_t c_mode,  int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups, bool benchmark, bool deterministic) {
    MIOPEN_CHECK(miopenInitConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale, c_mode));
    MIOPEN_CHECK(miopenSetConvolutionGroupCount(mut_desc(), groups));
@@ -121,11 +120,12 @@ struct ConvolutionDescriptor
  }
};

struct DropoutDescriptor
  : public Descriptor<miopenDropoutDescriptor,
                      &miopenCreateDropoutDescriptor,
                      &miopenDestroyDropoutDescriptor>
{
// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_CUDA_CPP_API DropoutDescriptor
    : public Descriptor<
          miopenDropoutDescriptor,
          &miopenCreateDropoutDescriptor,
          &miopenDestroyDropoutDescriptor> {
    void set(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes,
             unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) {
      MIOPEN_CHECK(miopenSetDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode));
@@ -137,7 +137,7 @@ struct DropoutDescriptor
    }
};

struct RNNDescriptor
struct TORCH_CUDA_CPP_API RNNDescriptor
  : public Descriptor<miopenRNNDescriptor,
                      &miopenCreateRNNDescriptor,
                      &miopenDestroyRNNDescriptor>
```
```diff
@@ -1,9 +1,11 @@
#include <ATen/miopen/Exceptions.h>
#include <ATen/miopen/Handle.h>
#include <ATen/hip/detail/DeviceThreadHandles.h>
#include <ATen/miopen/Handle.h>
#include <c10/hip/HIPStream.h>

namespace at { namespace native {
#include <ATen/hip/Exceptions.h>
#include <ATen/miopen/Exceptions.h>

namespace at::native {
namespace {

void createMIOpenHandle(miopenHandle_t *handle) {
@@ -11,30 +13,33 @@ void createMIOpenHandle(miopenHandle_t *handle) {
}

void destroyMIOpenHandle(miopenHandle_t handle) {
// this is because of something dumb in the ordering of
// destruction. Sometimes atexit, the cuda context (or something)
// would already be destroyed by the time this gets destroyed. It
// happens in fbcode setting. @colesbury and I decided to not destroy
// the handle as a workaround.
//   - @soumith
//
// Further note: this is now disabled globally, because we are seeing
// the same issue as mentioned above in CUDA 11 CI.
//   - @zasdfgbnm
//
// #ifdef NO_MIOPEN_DESTROY_HANDLE
// #else
//   miopenDestroy(handle);
// #endif
  // this is because of something dumb in the ordering of
  // destruction. Sometimes atexit, the cuda context (or something)
  // would already be destroyed by the time this gets destroyed. It
  // happens in fbcode setting. @colesbury and I decided to not destroy
  // the handle as a workaround.
  //   - @soumith
  //
  // Further note: this is now disabled globally, because we are seeing
  // the same issue as mentioned above in CUDA 11 CI.
  //   - @zasdfgbnm
  //
  // #ifdef NO_MIOPEN_DESTROY_HANDLE
  // #else
  //   miopenDestroy(handle);
  // #endif
}

using MIOpenPoolType = at::cuda::DeviceThreadHandlePool<miopenHandle_t, createMIOpenHandle, destroyMIOpenHandle>;
using MIOpenPoolType = at::cuda::DeviceThreadHandlePool<
    miopenHandle_t,
    createMIOpenHandle,
    destroyMIOpenHandle>;

} // namespace

miopenHandle_t getMiopenHandle() {
  int device;
  HIP_CHECK(hipGetDevice(&device));
  c10::DeviceIndex device = 0;
  AT_CUDA_CHECK(c10::hip::GetDevice(&device));

  // Thread local PoolWindows are lazily-initialized
  // to avoid initialization issues that caused hangs on Windows.
@@ -46,8 +51,8 @@ miopenHandle_t getMiopenHandle() {
      pool->newPoolWindow());

  auto handle = myPoolWindow->reserve(device);
  MIOPEN_CHECK(miopenSetStream(handle, at::hip::getCurrentHIPStream()));
  MIOPEN_CHECK(miopenSetStream(handle, c10::hip::getCurrentHIPStream()));
  return handle;
}

}} // namespace at::native
} // namespace at::native
```
```diff
@@ -1,9 +1,9 @@
#pragma once

#include <ATen/miopen/miopen-wrapper.h>
#include <c10/macros/Export.h>

namespace at { namespace native {
namespace at::native {

miopenHandle_t getMiopenHandle();

}} // namespace
TORCH_CUDA_CPP_API miopenHandle_t getMiopenHandle();
} // namespace at::native
```
```diff
@@ -1,12 +1,13 @@
#pragma once

#include <ATen/miopen/miopen-wrapper.h>
#include <ATen/Tensor.h>
#include <ATen/miopen/miopen-wrapper.h>
#include <c10/macros/Export.h>

namespace at { namespace native {
namespace at::native {

miopenDataType_t getMiopenDataType(const at::Tensor& tensor);
TORCH_CUDA_CPP_API miopenDataType_t getMiopenDataType(const at::Tensor& tensor);

int64_t miopen_version();

}}  // namespace at::miopen
} // namespace at::native
```
```diff
@@ -138,7 +138,7 @@ inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,

  // storageOffset
  TORCH_CHECK(
      storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
    TORCH_GUARD_OR_TRUE(sym_ge(storage_offset, 0)), "Tensor: invalid storage offset ", storage_offset);

  // set_storage_{device} (except set_storage_meta__symint)
  // will (unsafely) set the storage offset and then call resize_impl that
```
```diff
@@ -431,7 +431,7 @@ Tensor& set_storage_meta__symint(
      size, stride, storage_offset);

  // Matches maybe_resize_storage_cpu no-numel behavior
  if (TORCH_GUARD_SIZE_OBLIVIOUS(result.sym_numel().sym_ne(0))) {
  if (TORCH_GUARD_OR_TRUE(result.sym_numel().sym_ne(0))) {
    // maybe_resize_storage_cpu can handle no storage exists at all but
    // that should never be the case here
    TORCH_INTERNAL_ASSERT(storage);
@@ -440,12 +440,7 @@ Tensor& set_storage_meta__symint(
    // All meta data pointers are the same, so we don't have to "re" allocate
    // it.  TODO: Actually this might not quite be correct if we use special
    // pointers to track whether or not fake cuda tensors are pinned or not
    const auto itemsize = result.dtype().itemsize();
    c10::SymInt new_size_bytes = result.is_contiguous()
        ? at::detail::computeStorageNbytesContiguous(
              size, itemsize, std::move(storage_offset))
        : at::detail::computeStorageNbytes(
              size, stride, itemsize, std::move(storage_offset));

    // TODO: When there are unbacked SymInts, we unconditionally skip the
    // setter.  This is technically wrong, but we cannot conveniently test
    // the real condition in many cases, because a lot of people are using
@@ -454,10 +449,20 @@ Tensor& set_storage_meta__symint(
    //
    // The old behavior was to unconditionally set_nbytes, but I think not
    // setting it is more safe.
    if (new_size_bytes.has_hint() && storage.sym_nbytes().has_hint() &&
        TORCH_GUARD_SIZE_OBLIVIOUS(
            new_size_bytes.sym_gt(storage.sym_nbytes()))) {
      storage.set_nbytes(std::move(new_size_bytes));
    if (result.sym_numel().has_hint()) {
      const auto itemsize = result.dtype().itemsize();

      c10::SymInt new_size_bytes = result.is_contiguous()
          ? at::detail::computeStorageNbytesContiguous(
                size, itemsize, std::move(storage_offset))
          : at::detail::computeStorageNbytes(
                size, stride, itemsize, std::move(storage_offset));

      if (new_size_bytes.has_hint() && storage.sym_nbytes().has_hint() &&
          TORCH_GUARD_SIZE_OBLIVIOUS(
              new_size_bytes.sym_gt(storage.sym_nbytes()))) {
        storage.set_nbytes(std::move(new_size_bytes));
      }
    }
  }
  return result;
```
| @ -345,8 +345,8 @@ static inline void launch_vectorized_kernel( | ||||
|       auto output_calc = TrivialOffsetCalculator<1>(); | ||||
|       auto loader = memory::LoadWithoutCast(); | ||||
|       auto storer = memory::StoreWithoutCast(); | ||||
|       int64_t grid_unrolled = (N + io_block_work_size<io_size>() - 1) / io_block_work_size<io_size>(); | ||||
|       unrolled_elementwise_kernel<func_t, array_t, elems_per_thread<io_size>()> | ||||
|       int64_t grid_unrolled = (N + elementwise_block_work_size() - 1) / elementwise_block_work_size(); | ||||
|       unrolled_elementwise_kernel<func_t, array_t, elementwise_thread_work_size()> | ||||
|           <<<grid_unrolled, num_threads(), 0, stream>>>( | ||||
|               N, f, data, input_calc, output_calc, loader, storer); | ||||
|       C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||
|  | ||||
| @ -28,9 +28,15 @@ __device__ inline int min(int a, int b) { | ||||
|   return a <= b ? a : b; | ||||
| } | ||||
|  | ||||
| #ifdef USE_ROCM | ||||
| #define CUDA_MAX_THREADS 256 | ||||
| #define BLOCK_STRIDE_FWD 2 // increasing block_stride to lower # of blocks launched | ||||
| #define BLOCK_STRIDE_BWD 4 // increasing block_stride to lower # of blocks launched | ||||
| #else | ||||
| #define CUDA_MAX_THREADS 1024 // this is safe, in reality 256 is our limit | ||||
|  | ||||
| #define BLOCK_STRIDE 2 // increasing block_stride to lower # of blocks launched | ||||
| #define BLOCK_STRIDE_FWD 2 // increasing block_stride to lower # of blocks launched | ||||
| #define BLOCK_STRIDE_BWD 2 // increasing block_stride to lower # of blocks launched | ||||
| #endif | ||||
|  | ||||
| static __device__ inline int p_start(int size, int pad, int kernel, int dilation, int stride) { | ||||
|   return (size + pad < ((kernel - 1) * dilation + 1)) ? 0 : (size + pad - ((kernel - 1) * dilation + 1)) / stride + 1; | ||||
| @ -464,10 +470,10 @@ const Tensor& indices) { | ||||
|           int grid_x = nbatch*kernel_stride_C; | ||||
|           int grid_y = std::min<int>( | ||||
|               at::cuda::getCurrentDeviceProperties()->maxGridSize[1], | ||||
|               ceil_div(safe_downcast<int, int64_t>(outputWidth), block_y*BLOCK_STRIDE)); | ||||
|               ceil_div(safe_downcast<int, int64_t>(outputWidth), block_y*BLOCK_STRIDE_FWD)); | ||||
|           int grid_z = std::min<int>( | ||||
|               at::cuda::getCurrentDeviceProperties()->maxGridSize[2], | ||||
|               ceil_div(safe_downcast<int, int64_t>(outputHeight), block_z*BLOCK_STRIDE)); | ||||
|               ceil_div(safe_downcast<int, int64_t>(outputHeight), block_z*BLOCK_STRIDE_FWD)); | ||||
|           const dim3 grid(grid_x, grid_y, grid_z); | ||||
|  | ||||
|           size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * (sizeof(int) + sizeof(scalar_t)); | ||||
| @ -599,10 +605,10 @@ const Tensor& gradInput) { | ||||
|           int grid_x = nbatch*kernel_stride_C; | ||||
|           int grid_y = std::min<int>( | ||||
|               at::cuda::getCurrentDeviceProperties()->maxGridSize[1], | ||||
|               ceil_div(safe_downcast<int, int64_t>(inputWidth), block_y*BLOCK_STRIDE)); | ||||
|               ceil_div(safe_downcast<int, int64_t>(inputWidth), block_y*BLOCK_STRIDE_BWD)); | ||||
|           int grid_z = std::min<int>( | ||||
|               at::cuda::getCurrentDeviceProperties()->maxGridSize[2], | ||||
|               ceil_div(safe_downcast<int, int64_t>(inputHeight), block_z*BLOCK_STRIDE)); | ||||
|               ceil_div(safe_downcast<int, int64_t>(inputHeight), block_z*BLOCK_STRIDE_BWD)); | ||||
|           const dim3 grid(grid_x, grid_y, grid_z); | ||||
|  | ||||
|           size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * sizeof(accscalar_t); | ||||
|  | ||||
| @ -1159,7 +1159,8 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){ | ||||
|       config.ctas_per_output = div_up(num_mp, 2); | ||||
|     else if (config.ctas_per_output < 16) | ||||
|       config.ctas_per_output = 1; | ||||
|     if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension) | ||||
|     bool is_channel_last = iter.tensor_base(1).is_contiguous(at::MemoryFormat::ChannelsLast); | ||||
|     if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last) | ||||
|       config.ctas_per_output = 4; | ||||
| #endif | ||||
|     if (config.ctas_per_output > 1) { | ||||
|  | ||||
aten/src/ATen/native/mkldnn/xpu/detail/DnnlExt.h (new file, 594 lines)
							| @ -0,0 +1,594 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <ATen/ATen.h> | ||||
|  | ||||
| #include <ATen/native/mkldnn/xpu/detail/LRUCache.h> | ||||
| #include <ATen/native/mkldnn/xpu/detail/Utils.h> | ||||
| #include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h> | ||||
|  | ||||
| #include <oneapi/dnnl/dnnl.h> | ||||
| #include <oneapi/dnnl/dnnl.hpp> | ||||
|  | ||||
| namespace std { | ||||
|  | ||||
| template <> | ||||
| struct hash<dnnl::memory::dims> { | ||||
|   size_t operator()(dnnl::memory::dims const& vec) const { | ||||
|     size_t seed = vec.size(); | ||||
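|     // boost-style hash_combine: fold each dim into the seed via the golden-ratio constant | ||||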
|     for (auto& i : vec) { | ||||
|       seed ^= i + 0x9e3779b9 + (seed << 6) + (seed >> 2); | ||||
|     } | ||||
|     return seed; | ||||
|   } | ||||
| }; | ||||
|  | ||||
| } // namespace std | ||||
|  | ||||
| using namespace dnnl; | ||||
|  | ||||
| namespace at::native::onednn { | ||||
|  | ||||
| class primitive_ext : public primitive { | ||||
|   static constexpr int max_args = 12; | ||||
|  | ||||
|  public: | ||||
|   primitive_ext(const primitive& base) : primitive(base) {} | ||||
|   primitive_ext(primitive&& base) : primitive(std::move(base)) {} | ||||
|  | ||||
|   /// Returns a memory descriptor. | ||||
|   /// | ||||
|   /// @note | ||||
|   ///     There are also convenience methods | ||||
|   ///     #dnnl::primitive_desc_base::src_desc(), | ||||
|   ///     #dnnl::primitive_desc_base::dst_desc(), and others. | ||||
|   /// | ||||
|   /// @param what The kind of parameter to query; can be | ||||
|   ///     #dnnl::query::src_md, #dnnl::query::dst_md, etc. | ||||
|   /// @param idx Index of the parameter. For example, convolution bias can | ||||
|   ///     be queried with what = #dnnl::query::weights_md and idx = 1. | ||||
|   /// @returns The requested memory descriptor. | ||||
|   /// @returns A zero memory descriptor if the primitive does not have a | ||||
|   ///     parameter of the specified kind or index. | ||||
|   const_dnnl_memory_desc_t query_md(query what, int idx = 0) const { | ||||
|     std::vector<query> valid_q{ | ||||
|         query::src_md, | ||||
|         query::diff_src_md, | ||||
|         query::weights_md, | ||||
|         query::diff_weights_md, | ||||
|         query::dst_md, | ||||
|         query::diff_dst_md, | ||||
|         query::workspace_md, | ||||
|         query::scratchpad_md, | ||||
|         query::exec_arg_md}; | ||||
|     if (!std::any_of(valid_q.cbegin(), valid_q.cend(), [=](query q) { | ||||
|           return what == q; | ||||
|         })) | ||||
|       DNNL_THROW_ERROR( | ||||
|           dnnl_invalid_arguments, "memory descriptor query is invalid"); | ||||
|  | ||||
|     const_dnnl_memory_desc_t cdesc = dnnl_primitive_desc_query_md( | ||||
|         this->get_primitive_desc(), dnnl::convert_to_c(what), idx); | ||||
|  | ||||
|     return cdesc ? cdesc : nullptr; | ||||
|   } | ||||
|  | ||||
|   /// Returns a source memory descriptor. | ||||
|   /// @param idx Source index. | ||||
|   /// @returns Source memory descriptor. | ||||
|   /// @returns A zero memory descriptor if the primitive does not have a | ||||
|   ///     source parameter with index @p idx. | ||||
|   const_dnnl_memory_desc_t src_desc(int idx) const { | ||||
|     return query_md(query::src_md, idx); | ||||
|   } | ||||
|  | ||||
|   /// Returns a destination memory descriptor. | ||||
|   /// @param idx Destination index. | ||||
|   /// @returns Destination memory descriptor. | ||||
|   /// @returns A zero memory descriptor if the primitive does not have a | ||||
|   ///     destination parameter with index @p idx. | ||||
|   const_dnnl_memory_desc_t dst_desc(int idx) const { | ||||
|     return query_md(query::dst_md, idx); | ||||
|   } | ||||
|  | ||||
|   /// Returns a weights memory descriptor. | ||||
|   /// @param idx Weights index. | ||||
|   /// @returns Weights memory descriptor. | ||||
|   /// @returns A zero memory descriptor if the primitive does not have a | ||||
|   ///     weights parameter with index @p idx. | ||||
|   const_dnnl_memory_desc_t weights_desc(int idx) const { | ||||
|     return query_md(query::weights_md, idx); | ||||
|   } | ||||
|  | ||||
|   /// Returns a diff source memory descriptor. | ||||
|   /// @param idx Diff source index. | ||||
|   /// @returns Diff source memory descriptor. | ||||
|   /// @returns A zero memory descriptor if the primitive does not have a | ||||
|   ///     diff source parameter with index @p idx. | ||||
|   const_dnnl_memory_desc_t diff_src_desc(int idx) const { | ||||
|     return query_md(query::diff_src_md, idx); | ||||
|   } | ||||
|  | ||||
|   /// Returns a diff destination memory descriptor. | ||||
|   /// @param idx Diff destination index. | ||||
|   /// @returns Diff destination memory descriptor. | ||||
|   /// @returns A zero memory descriptor if the primitive does not have a | ||||
|   ///     diff destination parameter with index @p idx. | ||||
|   const_dnnl_memory_desc_t diff_dst_desc(int idx) const { | ||||
|     return query_md(query::diff_dst_md, idx); | ||||
|   } | ||||
|  | ||||
|   /// Returns a diff weights memory descriptor. | ||||
|   /// @param idx Diff weights index. | ||||
|   /// @returns Diff weights memory descriptor. | ||||
|   /// @returns A zero memory descriptor if the primitive does not have a | ||||
|   ///     diff weights parameter with index @p idx. | ||||
|   const_dnnl_memory_desc_t diff_weights_desc(int idx) const { | ||||
|     return query_md(query::diff_weights_md, idx); | ||||
|   } | ||||
|  | ||||
|   const_dnnl_memory_desc_t exec_arg_desc(int idx) const { | ||||
|     return query_md(query::exec_arg_md, idx); | ||||
|   } | ||||
|  | ||||
|   // Separate versions without the index argument for documentation | ||||
|   // purposes. | ||||
|  | ||||
|   /// Returns a source memory descriptor. | ||||
|   /// @returns Source memory descriptor. | ||||
|   /// @returns A zero memory descriptor if the primitive does not have a | ||||
|   ///     source parameter. | ||||
|   const_dnnl_memory_desc_t src_desc() const { | ||||
|     return src_desc(0); | ||||
|   } | ||||
|  | ||||
|   /// Returns a destination memory descriptor. | ||||
|   /// @returns Destination memory descriptor. | ||||
|   /// @returns A zero memory descriptor if the primitive does not have a | ||||
|   ///     destination parameter. | ||||
|   const_dnnl_memory_desc_t dst_desc() const { | ||||
|     return dst_desc(0); | ||||
|   } | ||||
|  | ||||
|   /// Returns a weights memory descriptor. | ||||
|   /// @returns Weights memory descriptor. | ||||
|   /// @returns A zero memory descriptor if the primitive does not have a | ||||
|   ///     weights parameter. | ||||
|   const_dnnl_memory_desc_t weights_desc() const { | ||||
|     return weights_desc(0); | ||||
|   } | ||||
|  | ||||
|   /// Returns a diff source memory descriptor. | ||||
|   /// @returns Diff source memory descriptor. | ||||
|   /// @returns A zero memory descriptor if the primitive does not have a | ||||
|   ///     diff source parameter. | ||||
|   const_dnnl_memory_desc_t diff_src_desc() const { | ||||
|     return diff_src_desc(0); | ||||
|   } | ||||
|  | ||||
|   /// Returns a diff destination memory descriptor. | ||||
|   /// @returns Diff destination memory descriptor. | ||||
|   /// @returns A zero memory descriptor if the primitive does not have a | ||||
|   ///     diff destination parameter. | ||||
|   const_dnnl_memory_desc_t diff_dst_desc() const { | ||||
|     return diff_dst_desc(0); | ||||
|   } | ||||
|  | ||||
|   /// Returns a diff weights memory descriptor. | ||||
|   /// @returns Diff weights memory descriptor. | ||||
|   /// @returns A zero memory descriptor if the primitive does not have a | ||||
|   ///     diff weights parameter. | ||||
|   const_dnnl_memory_desc_t diff_weights_desc() const { | ||||
|     return diff_weights_desc(0); | ||||
|   } | ||||
|  | ||||
|   /// Returns the workspace memory descriptor. | ||||
|   /// @returns Workspace memory descriptor. | ||||
|   /// @returns A zero memory descriptor if the primitive does not require | ||||
|   ///     workspace parameter. | ||||
|   const_dnnl_memory_desc_t workspace_desc() const { | ||||
|     return query_md(query::workspace_md, 0); | ||||
|   } | ||||
|  | ||||
|   /// Returns the scratchpad memory descriptor. | ||||
|   /// @returns scratchpad memory descriptor. | ||||
|   /// @returns A zero memory descriptor if the primitive does not require | ||||
|   ///     scratchpad parameter. | ||||
|   /// @sa @ref dev_guide_attributes_scratchpad | ||||
|   const_dnnl_memory_desc_t scratchpad_desc() const { | ||||
|     return query_md(query::scratchpad_md, 0); | ||||
|   } | ||||
|  | ||||
|   inline memory make_memory( | ||||
|       const_dnnl_memory_desc_t md_t, | ||||
|       const engine& aengine, | ||||
|       void* handle = DNNL_MEMORY_ALLOCATE) const { | ||||
|     sycl_interop::memory_kind kind = dnnl::sycl_interop::memory_kind::usm; | ||||
|     dnnl_memory_t c_memory; | ||||
|     error::wrap_c_api( | ||||
|         dnnl_sycl_interop_memory_create( | ||||
|             &c_memory, md_t, aengine.get(), convert_to_c(kind), handle), | ||||
|         "could not create a memory"); | ||||
|     return memory(c_memory); | ||||
|   } | ||||
|  | ||||
|   memory make_src(const engine& aengine, void* handle = DNNL_MEMORY_ALLOCATE) | ||||
|       const { | ||||
|     return make_memory(src_desc(), aengine, handle); | ||||
|   } | ||||
|  | ||||
|   memory make_weight(const engine& aengine, void* handle = DNNL_MEMORY_ALLOCATE) | ||||
|       const { | ||||
|     return make_memory(weights_desc(), aengine, handle); | ||||
|   } | ||||
|  | ||||
|   memory make_bias(const engine& aengine, void* handle = DNNL_MEMORY_ALLOCATE) | ||||
|       const { | ||||
|     return make_memory(weights_desc(1), aengine, handle); | ||||
|   } | ||||
|  | ||||
|   memory make_dst(const engine& aengine, void* handle = DNNL_MEMORY_ALLOCATE) | ||||
|       const { | ||||
|     return make_memory(dst_desc(), aengine, handle); | ||||
|   } | ||||
|  | ||||
|   memory make_scratchpad( | ||||
|       const engine& aengine, | ||||
|       void* handle = DNNL_MEMORY_ALLOCATE) const { | ||||
|     return make_memory(scratchpad_desc(), aengine, handle); | ||||
|   } | ||||
|  | ||||
|   size_t get_scratchpad_size() const { | ||||
|     return dnnl_memory_desc_get_size(scratchpad_desc()); | ||||
|   } | ||||
|  | ||||
|   memory make_args(int arg_class, const engine& aengine, void* handle) const { | ||||
|     switch (arg_class) { | ||||
|       case DNNL_ARG_SRC: | ||||
|         return make_src(aengine, handle); | ||||
|       case DNNL_ARG_WEIGHTS: | ||||
|         return make_weight(aengine, handle); | ||||
|       case DNNL_ARG_SCRATCHPAD: | ||||
|         return make_scratchpad(aengine, handle); | ||||
|       case DNNL_ARG_DST: | ||||
|         return make_dst(aengine, handle); | ||||
|       case DNNL_ARG_BIAS: | ||||
|         return make_bias(aengine, handle); | ||||
|       default: | ||||
|         TORCH_INTERNAL_ASSERT( | ||||
|             false, "unsupported argument class for primitive_ext"); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   template <typename M> | ||||
|   void set_attribute(int slot, int arg_class, void* handle, M constructor) { | ||||
|     if (mem_arg_cache[slot]) | ||||
|       mem_arg_cache[slot].set_data_handle(handle); | ||||
|     else { | ||||
|       mem_arg_cache[slot] = constructor(); | ||||
|       c_args[slot].arg = arg_class; | ||||
|       c_args[slot].memory = mem_arg_cache[slot].get(); | ||||
|     } | ||||
|   } | ||||
|  | ||||
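|   // Slots [0, slot_off) are reserved for attribute memories installed via | ||||
|   // set_attribute(); execute() fills runtime arguments starting at slot_off. | ||||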
|   sycl::event execute( | ||||
|       const stream& astream, | ||||
|       const engine& aengine, | ||||
|       std::vector<std::pair<int, void*>>&& handles, | ||||
|       int slot_off = 2) { | ||||
|     auto off = slot_off; | ||||
|     for (const auto& p : handles) { | ||||
|       auto& m_arg = mem_arg_cache[off]; | ||||
|       if (m_arg) | ||||
|         m_arg.set_data_handle(p.second); | ||||
|       else { | ||||
|         m_arg = make_args(p.first, aengine, p.second); | ||||
|         c_args[off].arg = p.first; | ||||
|         c_args[off].memory = m_arg.get(); | ||||
|       } | ||||
|       ++off; | ||||
|     } | ||||
|  | ||||
|     sycl::event return_event; | ||||
|     std::vector<sycl::event> deps{}; | ||||
|     error::wrap_c_api( | ||||
|         dnnl_sycl_interop_primitive_execute( | ||||
|             this->get(), astream.get(), off, c_args, &deps, &return_event), | ||||
|         "could not execute a primitive"); | ||||
|     return return_event; | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   memory mem_arg_cache[max_args]; | ||||
|   dnnl_exec_arg_t c_args[max_args]; | ||||
| }; | ||||
|  | ||||
| // Specifies the combined data types of input and weight tensors. | ||||
| // For example, f32 means both input and weight are FP32, | ||||
| // bf16_int4 means input is BF16 and weight is INT4. | ||||
| enum class joint_dtypes_t { f32 = 0, f16, bf16, int8, f16_int4, bf16_int4 }; | ||||
|  | ||||
| // Specifies the transposition state of input and weight tensors. | ||||
| // Convention: first letter = input, second letter = weight. | ||||
| // 'n' = not transposed, 't' = transposed. | ||||
| // For example, 'nt' means input is not transposed, weight is transposed. | ||||
| enum class trans_type_t { nn = 0, nt, tn, tt }; | ||||
|  | ||||
| // Specifies the type and placement of bias in the computation. | ||||
| // 'none' = no bias, | ||||
| // 'scalar' = a single scalar bias applied to all elements, | ||||
| // 'm' = per-row bias (typically matched to input rows), | ||||
| // 'n' = per-column bias (typically matched to output channels), | ||||
| // 'mn' = full bias matrix matching the output dimensions. | ||||
| enum class bias_type_t { none = 0, scalar, m, n, mn }; | ||||
|  | ||||
| template <typename T> | ||||
| T concat(const T& t1, at::ScalarType d) { | ||||
|   T t; | ||||
|   t.insert(t.end(), t1.begin(), t1.end()); | ||||
|   t.push_back((int64_t)d); | ||||
|  | ||||
|   return t; | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| T concat(const T& t1, bool b) { | ||||
|   T t; | ||||
|   t.insert(t.end(), t1.begin(), t1.end()); | ||||
|   t.push_back(b); | ||||
|  | ||||
|   return t; | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| T concat(const T& t1, int b) { | ||||
|   T t; | ||||
|   t.insert(t.end(), t1.begin(), t1.end()); | ||||
|   t.push_back(b); | ||||
|  | ||||
|   return t; | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| T concat(const T& t1, const T& t2) { | ||||
|   T t; | ||||
|   t.insert(t.end(), t1.begin(), t1.end()); | ||||
|   t.insert(t.end(), t2.begin(), t2.end()); | ||||
|  | ||||
|   return t; | ||||
| } | ||||
|  | ||||
| template <typename T1, typename T2, typename... Ts> | ||||
| T1 concat(const T1& t1, const T2& t2, const Ts&... ts) { | ||||
|   return concat(concat(t1, t2), ts...); | ||||
| } | ||||
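|  | ||||
| // e.g. concat(memory::dims{1, 2}, 3, true) yields {1, 2, 3, 1}: ints and bools | ||||
| // are appended as int64_t elements, so heterogeneous keys flatten into one dims | ||||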
|  | ||||
| template <joint_dtypes_t Ts> | ||||
| struct onednn_types_mapper; | ||||
|  | ||||
| template <> | ||||
| struct onednn_types_mapper<joint_dtypes_t::f16_int4> { | ||||
|   static inline std::tuple<dnnl::memory::data_type, dnnl::memory::data_type> | ||||
|   get() { | ||||
|     return std::make_tuple( | ||||
|         dnnl::memory::data_type::f16, dnnl::memory::data_type::u4); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <> | ||||
| struct onednn_types_mapper<joint_dtypes_t::bf16_int4> { | ||||
|   static inline std::tuple<dnnl::memory::data_type, dnnl::memory::data_type> | ||||
|   get() { | ||||
|     return std::make_tuple( | ||||
|         dnnl::memory::data_type::bf16, dnnl::memory::data_type::u4); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| // TODO: the bias types may not be right | ||||
| static inline dnnl::memory::dims get_bias_type( | ||||
|     bias_type_t b_dims, | ||||
|     const int m, | ||||
|     const int n) { | ||||
|   switch (b_dims) { | ||||
|     case bias_type_t::none: | ||||
|       return {0}; | ||||
|     case bias_type_t::scalar: | ||||
|       return {1, 1}; | ||||
|     case bias_type_t::m: | ||||
|       return {m, 1}; | ||||
|     case bias_type_t::n: | ||||
|       return {1, n}; | ||||
|     case bias_type_t::mn: | ||||
|       return {m, n}; | ||||
|     default: | ||||
|       TORCH_INTERNAL_ASSERT(false, "unsupported bias type ..."); | ||||
|   } | ||||
| } | ||||
|  | ||||
| // TODO: use template specialization on struct | ||||
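| // The primary template below is a deliberate no-op; only specializations such | ||||
| // as trans_type_t::nt fill in the strides. | ||||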
| template <trans_type_t Tt> | ||||
| inline void get_strides( | ||||
|     memory::dims& src_strides, | ||||
|     memory::dims& wei_strides, | ||||
|     memory::dims& dst_strides, | ||||
|     const int64_t lda, | ||||
|     const int64_t ldb, | ||||
|     const int64_t ldc) {} | ||||
|  | ||||
| template <> | ||||
| inline void get_strides<trans_type_t::nt>( | ||||
|     memory::dims& src_strides, | ||||
|     memory::dims& wei_strides, | ||||
|     memory::dims& dst_strides, | ||||
|     const int64_t lda, | ||||
|     const int64_t ldb, | ||||
|     const int64_t ldc) { | ||||
|   src_strides = {lda, 1}; | ||||
|   wei_strides = {1, ldb}; | ||||
|   dst_strides = {ldc, 1}; | ||||
| } | ||||
|  | ||||
| using primitive_cache = | ||||
|     at::native::onednn::lru_cache<memory::dims, primitive_ext>; | ||||
|  | ||||
| template <trans_type_t Tt, joint_dtypes_t Ts, typename F> | ||||
| struct matmul_primitive_cache_t { | ||||
|   static inline primitive_ext& get( | ||||
|       const int m, | ||||
|       const int n, | ||||
|       const int k, | ||||
|       const int64_t lda, | ||||
|       const int64_t ldb, | ||||
|       const int64_t ldc, | ||||
|       const bias_type_t | ||||
|           b_dims, // bias shape is runtime data, so it is not a template parameter | ||||
|       const int device_id, | ||||
|       F f_attr, | ||||
|       const int64_t scale_group_size, | ||||
|       const int64_t zp_group_size) { | ||||
|     auto& cached = get_cache(device_id); | ||||
|     memory::dims src_strides, wei_strides, dst_strides; | ||||
|     get_strides<Tt>(src_strides, wei_strides, dst_strides, lda, ldb, ldc); | ||||
|     auto pri_key = at::native::onednn::concat( | ||||
|         src_strides, | ||||
|         wei_strides, | ||||
|         m, | ||||
|         n, | ||||
|         k, | ||||
|         int(b_dims), | ||||
|         int(scale_group_size), | ||||
|         int(zp_group_size)); | ||||
|     auto iter = cached.find(pri_key); | ||||
|     if (iter == cached.end()) { | ||||
|       auto [src_dt, wei_dt] = onednn_types_mapper<Ts>::get(); | ||||
|       auto bias_dims = get_bias_type(b_dims, m, n); | ||||
|  | ||||
|       auto src_md = memory::desc({m, k}, src_dt, src_strides); | ||||
|       auto wei_md = memory::desc({k, n}, wei_dt, wei_strides); | ||||
|       auto dst_md = memory::desc({m, n}, src_dt, dst_strides); | ||||
|       auto bias_format = b_dims == bias_type_t::none | ||||
|           ? dnnl::memory::format_tag::undef | ||||
|           : dnnl::memory::format_tag::ab; | ||||
|       auto bias_md = | ||||
|           memory::desc(bias_dims, src_dt, bias_format); // {m, n} or {1, n} | ||||
|  | ||||
|       primitive_attr pattr; | ||||
|       f_attr(pattr); | ||||
|  | ||||
|       dnnl::matmul::primitive_desc matmul_pd; | ||||
|       auto aengine = | ||||
|           at::native::onednn::GpuEngineManager::Instance().get_engine( | ||||
|               device_id); | ||||
|       if (b_dims == bias_type_t::none) { | ||||
|         matmul_pd = dnnl::matmul::primitive_desc( | ||||
|             aengine, src_md, wei_md, dst_md, pattr); | ||||
|       } else { | ||||
|         matmul_pd = dnnl::matmul::primitive_desc( | ||||
|             aengine, src_md, wei_md, bias_md, dst_md, pattr); | ||||
|       } | ||||
|  | ||||
|       return cached.insert({pri_key, primitive_ext(dnnl::matmul(matmul_pd))}) | ||||
|           .first->second; | ||||
|     } else { | ||||
|       return iter->second; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   static constexpr int max_cache_capacity = 512; | ||||
|   // if the default constructor of primitive_cache could read the environment | ||||
|   // variable, this bookkeeping would be unnecessary | ||||
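|   // one thread-local cache per device id; the array size caps support at 16 devices | ||||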
|   static inline thread_local std::array<primitive_cache, 16> mappings; | ||||
|  | ||||
|   // this won't be needed if primitive_cache had a good default constructor | ||||
|   static inline primitive_cache& get_cache(const int device_id) { | ||||
|     auto& mapping = mappings[device_id]; | ||||
|     if (mapping.max_size() == 0) { | ||||
|       mapping.resize(max_cache_capacity); | ||||
|     } | ||||
|     return mapping; | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <joint_dtypes_t Ts, typename F> | ||||
| static inline primitive_ext& matmul_primitive_create_and_cache( | ||||
|     const trans_type_t Tt, | ||||
|     const bias_type_t b_dims, | ||||
|     const int m, | ||||
|     const int n, | ||||
|     const int k, | ||||
|     const int64_t lda, | ||||
|     const int64_t ldb, | ||||
|     const int64_t ldc, | ||||
|     const int device_id, | ||||
|     F attr, | ||||
|     const int64_t scale_group_size, | ||||
|     const int64_t zp_group_size) { | ||||
|   switch (Tt) { | ||||
|     case trans_type_t::nt: | ||||
|       return matmul_primitive_cache_t<trans_type_t::nt, Ts, F>::get( | ||||
|           m, | ||||
|           n, | ||||
|           k, | ||||
|           lda, | ||||
|           ldb, | ||||
|           ldc, | ||||
|           b_dims, | ||||
|           device_id, | ||||
|           attr, | ||||
|           scale_group_size, | ||||
|           zp_group_size); | ||||
|     default: | ||||
|       TORCH_INTERNAL_ASSERT(false, "unsupported trans type ..."); | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename F> | ||||
| static inline primitive_ext& matmul_primitive_create_and_cache( | ||||
|     const joint_dtypes_t Ts, | ||||
|     const trans_type_t Tt, | ||||
|     const bias_type_t b_dims, | ||||
|     const int m, | ||||
|     const int n, | ||||
|     const int k, | ||||
|     const int64_t lda, | ||||
|     const int64_t ldb, // is weight ldb necessary? | ||||
|     const int64_t ldc, | ||||
|     const int device_id, | ||||
|     F attr, | ||||
|     const int64_t scale_group_size = 0, | ||||
|     const int64_t zp_group_size = 0) { | ||||
|   switch (Ts) { | ||||
|     case joint_dtypes_t::f16_int4: | ||||
|       return matmul_primitive_create_and_cache<joint_dtypes_t::f16_int4, F>( | ||||
|           Tt, | ||||
|           b_dims, | ||||
|           m, | ||||
|           n, | ||||
|           k, | ||||
|           lda, | ||||
|           ldb, | ||||
|           ldc, | ||||
|           device_id, | ||||
|           attr, | ||||
|           scale_group_size, | ||||
|           zp_group_size); | ||||
|     case joint_dtypes_t::bf16_int4: | ||||
|       return matmul_primitive_create_and_cache<joint_dtypes_t::bf16_int4, F>( | ||||
|           Tt, | ||||
|           b_dims, | ||||
|           m, | ||||
|           n, | ||||
|           k, | ||||
|           lda, | ||||
|           ldb, | ||||
|           ldc, | ||||
|           device_id, | ||||
|           attr, | ||||
|           scale_group_size, | ||||
|           zp_group_size); | ||||
|     default: | ||||
|       TORCH_INTERNAL_ASSERT(false, "Only support int4 ..."); | ||||
|   } | ||||
| } | ||||
|  | ||||
| } // namespace at::native::onednn | ||||
aten/src/ATen/native/mkldnn/xpu/detail/LRUCache.h (new file, 110 lines)
							| @ -0,0 +1,110 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <iterator> | ||||
| #include <list> | ||||
| #include <unordered_map> | ||||
| #include <utility> | ||||
|  | ||||
| namespace at::native::onednn { | ||||
|  | ||||
| template < | ||||
|     class key_t, | ||||
|     class value_t, | ||||
|     template <typename...> class map_t = std::unordered_map> | ||||
| class lru_cache { | ||||
|  public: | ||||
|   using value_type = std::pair<key_t, value_t>; | ||||
|   using list_type = std::list<value_type>; | ||||
|   using list_iter = typename list_type::iterator; | ||||
|   using map_type = map_t<key_t, list_iter>; | ||||
|   using const_list_iter = typename list_type::const_iterator; | ||||
|   using size_type = typename list_type::size_type; | ||||
|  | ||||
|   explicit lru_cache(size_type capacity) : capacity_(capacity) {} | ||||
|   lru_cache() : capacity_(0) {} | ||||
|  | ||||
|   [[nodiscard]] size_type size() const noexcept { | ||||
|     return map_.size(); | ||||
|   } | ||||
|   [[nodiscard]] size_type max_size() const noexcept { | ||||
|     return capacity_; | ||||
|   } | ||||
|   [[nodiscard]] bool empty() const noexcept { | ||||
|     return vlist_.empty(); | ||||
|   } | ||||
|  | ||||
|   void resize(size_type new_capacity) { | ||||
|     capacity_ = new_capacity; | ||||
|     trim(); | ||||
|   } | ||||
|  | ||||
|   list_iter begin() noexcept { | ||||
|     return vlist_.begin(); | ||||
|   } | ||||
|   const_list_iter begin() const noexcept { | ||||
|     return vlist_.begin(); | ||||
|   } | ||||
|   list_iter end() noexcept { | ||||
|     return vlist_.end(); | ||||
|   } | ||||
|   const_list_iter end() const noexcept { | ||||
|     return vlist_.end(); | ||||
|   } | ||||
|  | ||||
|   void clear() noexcept { | ||||
|     map_.clear(); | ||||
|     vlist_.clear(); | ||||
|   } | ||||
|  | ||||
|   void swap(lru_cache& other) noexcept { | ||||
|     using std::swap; | ||||
|     swap(vlist_, other.vlist_); | ||||
|     swap(map_, other.map_); | ||||
|     swap(capacity_, other.capacity_); | ||||
|   } | ||||
|  | ||||
|   list_iter find(const key_t& key) { | ||||
|     auto it = map_.find(key); | ||||
|     if (it == map_.end()) | ||||
|       return end(); | ||||
|     vlist_.splice(vlist_.begin(), vlist_, it->second); | ||||
|     return it->second; | ||||
|   } | ||||
|  | ||||
|   std::pair<list_iter, bool> insert(const value_type& value) { | ||||
|     auto it = map_.find(value.first); | ||||
|     if (it != map_.end()) { | ||||
|       // Move existing to front | ||||
|       vlist_.splice(vlist_.begin(), vlist_, it->second); | ||||
|       return {it->second, false}; | ||||
|     } | ||||
|  | ||||
|     // Insert new at front | ||||
|     vlist_.emplace_front(value); | ||||
|     map_[value.first] = vlist_.begin(); | ||||
|  | ||||
|     trim(); | ||||
|  | ||||
|     return {vlist_.begin(), true}; | ||||
|   } | ||||
|  | ||||
|   list_iter erase(list_iter pos) { | ||||
|     map_.erase(pos->first); | ||||
|     return vlist_.erase(pos); | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   void trim() { | ||||
|     while (map_.size() > capacity_) { | ||||
|       auto last = std::prev(vlist_.end()); | ||||
|       map_.erase(last->first); | ||||
|       vlist_.pop_back(); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   list_type vlist_; | ||||
|   map_type map_; | ||||
|   size_type capacity_; | ||||
| }; | ||||
|  | ||||
| } // namespace at::native::onednn | ||||
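|  | ||||
| // A minimal usage sketch of the cache above (hypothetical key/value types, not | ||||
| // part of this header): find() promotes an entry to most-recently-used, and | ||||
| // insert() evicts from the tail once size exceeds capacity. | ||||
| // | ||||
| //   at::native::onednn::lru_cache<int, std::string> cache(2); | ||||
| //   cache.insert({1, "a"}); | ||||
| //   cache.insert({2, "b"}); | ||||
| //   cache.find(1);          // promotes key 1 to the front | ||||
| //   cache.insert({3, "c"}); // evicts key 2, the least-recently-used | ||||
| //   assert(cache.find(2) == cache.end()); | ||||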
| @ -294,6 +294,13 @@ bool is_onednn_matmul_strides(const at::Tensor& tensor) { | ||||
|   if (tensor.is_contiguous()) | ||||
|     return true; | ||||
|  | ||||
|   if (tensor.storage_offset() > 0) { | ||||
|     // currently oneDNN requires 64-byte alignment | ||||
|     constexpr int alignment_byte = 64; | ||||
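|     // e.g. an fp16 tensor with storage_offset 1 over a 64-byte-aligned base | ||||
|     // yields a pointer 2 bytes past alignment, so it is rejected here | ||||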
|     if (reinterpret_cast<uintptr_t>(tensor.data_ptr()) % alignment_byte > 0) | ||||
|       return false; | ||||
|   } | ||||
|  | ||||
|   // overlapping cases are not supported | ||||
|   dnnl::memory::dims strides = get_onednn_strides(tensor); | ||||
|   int64_t storage_size = 1; | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| #include <c10/xpu/XPUFunctions.h> | ||||
|  | ||||
| #include <ATen/native/mkldnn/xpu/detail/Attr.h> | ||||
| #include <ATen/native/mkldnn/xpu/detail/DnnlExt.h> | ||||
| #include <ATen/native/mkldnn/xpu/detail/Utils.h> | ||||
|  | ||||
| #include <oneapi/dnnl/dnnl.hpp> | ||||
| @ -8,22 +9,13 @@ | ||||
|  | ||||
| namespace at::native::onednn { | ||||
|  | ||||
| void woq_matmul_int4( | ||||
|     Tensor& result, // torchao: [M, K], dtype: fp16,bf16 | ||||
|     const Tensor& mat1_, // torchao: [M, K], dtype: fp16,bf16 | ||||
|     const Tensor& mat2_, // torchao quantized weight, [K/8, N], dtype: uint4x8 | ||||
|     const Tensor& scale, // torchao: [K/group_size, N], dtype: fp16,bf16 | ||||
|     const Tensor& zp, // torchao: [K/group_size, N], dtype: int8 | ||||
| void woq_matmul_int4_impl( | ||||
|     Tensor& result, | ||||
|     const Tensor& mat1_, | ||||
|     const Tensor& mat2_, | ||||
|     const Tensor& scale, | ||||
|     const Tensor& zp, | ||||
|     int64_t group_size) { | ||||
|   size_t dims = result.dim(); | ||||
|   TORCH_CHECK( | ||||
|       dims == 2, "INT4 matmul at XPU only works with 2D input, got ", dims); | ||||
|   TORCH_CHECK(result.defined(), "oneDNN matmul result should be defined"); | ||||
|  | ||||
|   at::Device cur_device = at::Device(at::kXPU, at::xpu::current_device()); | ||||
|   TORCH_CHECK( | ||||
|       cur_device == mat1_.device(), | ||||
|       "_weight_int4pack_mm_with_scales_and_zeros input should be on current device."); | ||||
|   auto& engine = GpuEngineManager::Instance().get_engine(); | ||||
|   auto& stream = GpuStreamManager::Instance().get_stream(); | ||||
|  | ||||
| @ -176,4 +168,162 @@ void woq_matmul_int4( | ||||
|   args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, zp_usr_m}); | ||||
|   dnnl::sycl_interop::execute(matmul_p, stream, args); | ||||
| } | ||||
|  | ||||
| static inline void set_quant_primitive_attr( | ||||
|     primitive_attr& pattr, | ||||
|     const Tensor& scale, | ||||
|     const Tensor& zp, | ||||
|     const int64_t group_size) { | ||||
|   // set scale and zero point for matmul args | ||||
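|   // mask (1 << 0) + (1 << 1) marks both K and N as varying; with groups | ||||
|   // {group_size, 1} the scales/zero points are per-group along K and | ||||
|   // per-channel along N (grouped int4 quantization) | ||||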
|   pattr.set_scales( | ||||
|       DNNL_ARG_WEIGHTS, | ||||
|       /* mask */ (1 << 0) + (1 << 1), | ||||
|       {group_size, 1}, | ||||
|       get_onednn_dtype(scale)); | ||||
|   pattr.set_zero_points( | ||||
|       DNNL_ARG_WEIGHTS, | ||||
|       /* mask */ (1 << 0) + (1 << 1), | ||||
|       {group_size, 1}, | ||||
|       memory::data_type::s8); | ||||
| } | ||||
|  | ||||
| void woq_matmul_int4_impl_cache( | ||||
|     Tensor& result, | ||||
|     const Tensor& mat1, | ||||
|     const Tensor& mat2, | ||||
|     const Tensor& scale, | ||||
|     const Tensor& zp, | ||||
|     int64_t group_size) { | ||||
|   auto a_sz = mat1.sizes(); | ||||
|   auto c_sz = result.sizes(); | ||||
|  | ||||
|   const int m = | ||||
|       std::reduce(a_sz.begin(), a_sz.end() - 1, 1, std::multiplies<int64_t>()); | ||||
|   const int n = *(c_sz.end() - 1); | ||||
|   const int k = *(a_sz.end() - 1); | ||||
|  | ||||
|   const int64_t ldb = mat2.strides()[mat2.dim() - 2] * 8; // for int4 matmul | ||||
|   const int64_t lda = mat1.strides()[mat1.dim() - 2]; | ||||
|   const int64_t ldc = result.strides()[result.dim() - 2]; | ||||
|  | ||||
|   bias_type_t b_type = bias_type_t::none; | ||||
|   trans_type_t tt = trans_type_t::nt; // only nt is supported for int4 matmul | ||||
|  | ||||
|   joint_dtypes_t jd; | ||||
|   if (mat1.scalar_type() == at::ScalarType::Half) { | ||||
|     jd = joint_dtypes_t::f16_int4; | ||||
|   } else if (mat1.scalar_type() == at::ScalarType::BFloat16) { | ||||
|     jd = joint_dtypes_t::bf16_int4; | ||||
|   } else { | ||||
|     TORCH_INTERNAL_ASSERT( | ||||
|         false, "Unsupported data type for int4 matmul: ", mat1.scalar_type()); | ||||
|   } | ||||
|  | ||||
|   auto f_attr = [&](primitive_attr& pattr) { | ||||
|     pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | ||||
|  | ||||
|     if (jd == joint_dtypes_t::f16_int4) { | ||||
|       pattr.set_fpmath_mode(dnnl::fpmath_mode::f16, true); | ||||
|     } else if (jd == joint_dtypes_t::bf16_int4) { | ||||
|       pattr.set_fpmath_mode(dnnl::fpmath_mode::bf16, true); | ||||
|     } | ||||
|  | ||||
|     set_quant_primitive_attr(pattr, scale, zp, group_size); | ||||
|  | ||||
| #if ONEDNN_SUPPORT_DETERMINISTIC | ||||
|     if (at::globalContext().deterministicAlgorithms() || | ||||
|         at::globalContext().deterministicMkldnn()) { | ||||
|       pattr.set_deterministic(true); | ||||
|     } | ||||
| #endif | ||||
|   }; | ||||
|  | ||||
|   int64_t zp_group_size = group_size; | ||||
|   auto device_id = c10::xpu::current_device(); | ||||
|   auto& matmul_ext = matmul_primitive_create_and_cache( | ||||
|       jd, | ||||
|       tt, | ||||
|       b_type, | ||||
|       m, | ||||
|       n, | ||||
|       k, | ||||
|       lda, | ||||
|       ldb, | ||||
|       ldc, | ||||
|       device_id, | ||||
|       f_attr, | ||||
|       group_size, | ||||
|       zp_group_size); | ||||
|  | ||||
|   auto& engine = GpuEngineManager::Instance().get_engine(); | ||||
|  | ||||
|   int arg_off = 0; | ||||
|   // set scale and zero point for matmul args | ||||
|   matmul_ext.set_attribute( | ||||
|       arg_off++, | ||||
|       DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, | ||||
|       scale.data_ptr(), | ||||
|       [&]() { | ||||
|         return make_onednn_memory( | ||||
|             get_onednn_md(scale), engine, scale.data_ptr()); | ||||
|       }); | ||||
|  | ||||
|   // set zp_md for asymmetric quantization | ||||
|   matmul_ext.set_attribute( | ||||
|       arg_off++, | ||||
|       DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, | ||||
|       zp.data_ptr(), | ||||
|       [&]() { | ||||
|         int num_groups = k / group_size; | ||||
|         memory zp_usr_m( | ||||
|             {{num_groups, n}, memory::data_type::s8, {n, 1}}, | ||||
|             engine, | ||||
|             zp.data_ptr()); | ||||
|         return zp_usr_m; | ||||
|       }); | ||||
|  | ||||
|   // set general args | ||||
|   std::vector<std::pair<int, void*>> arg_handles; | ||||
|   arg_handles.reserve(8); | ||||
|  | ||||
|   arg_handles.emplace_back(DNNL_ARG_SRC, mat1.data_ptr()); | ||||
|   arg_handles.emplace_back(DNNL_ARG_WEIGHTS, mat2.data_ptr()); | ||||
|   arg_handles.emplace_back(DNNL_ARG_DST, result.data_ptr()); | ||||
|  | ||||
|   int scratchpad_size = matmul_ext.get_scratchpad_size(); | ||||
|   Tensor scratchpad_tensor = at::empty( | ||||
|       {scratchpad_size}, mat1.options().dtype(at::kByte), std::nullopt); | ||||
|   arg_handles.emplace_back(DNNL_ARG_SCRATCHPAD, scratchpad_tensor.data_ptr()); | ||||
|  | ||||
|   auto& strm = GpuStreamManager::Instance().get_stream(); | ||||
|   auto qint4_matmul_event = | ||||
|       matmul_ext.execute(strm, engine, std::move(arg_handles), arg_off); | ||||
| } | ||||
|  | ||||
| void woq_matmul_int4( | ||||
|     Tensor& result, // torchao: [M, K], dtype: fp16,bf16 | ||||
|     const Tensor& mat1_, // torchao: [M, K], dtype: fp16,bf16 | ||||
|     const Tensor& mat2_, // torchao quantized weight, [K/8, N], dtype: uint4x8 | ||||
|     const Tensor& scale, // torchao: [K/group_size, N], dtype: fp16,bf16 | ||||
|     const Tensor& zp, // torchao: [K/group_size, N], dtype: int8 | ||||
|     int64_t group_size, | ||||
|     bool pri_cache) { | ||||
|   size_t dims = result.dim(); | ||||
|   TORCH_CHECK( | ||||
|       dims == 2, "INT4 matmul at XPU only works with 2D input, got ", dims); | ||||
|   TORCH_CHECK(result.defined(), "oneDNN matmul result should be defined"); | ||||
|  | ||||
|   const int device_id = c10::xpu::current_device(); | ||||
|   at::Device cur_device = at::Device(at::kXPU, device_id); | ||||
|   TORCH_CHECK( | ||||
|       cur_device == mat1_.device(), | ||||
|       "_weight_int4pack_mm_with_scales_and_zeros input should be on current device."); | ||||
|  | ||||
|   if (pri_cache) { | ||||
|     woq_matmul_int4_impl_cache(result, mat1_, mat2_, scale, zp, group_size); | ||||
|   } else { | ||||
|     woq_matmul_int4_impl(result, mat1_, mat2_, scale, zp, group_size); | ||||
|   } | ||||
| } | ||||
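|  | ||||
| // Shape relationships for a hypothetical call (torchao int4 packing, e.g. | ||||
| // group_size = 32): mat1_ is [M, K] fp16/bf16, mat2_ is [K/8, N] uint4x8, | ||||
| // scale/zp are [K/group_size, N], and result is [M, N]. pri_cache selects | ||||
| // the LRU-cached primitive path above. | ||||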
|  | ||||
| } // namespace at::native::onednn | ||||
|  | ||||
| @ -95,7 +95,8 @@ TORCH_API void woq_matmul_int4( | ||||
|     const at::Tensor& mat2_, // quantized weight, [K/8, N] | ||||
|     const at::Tensor& scale, // [K/group_size, N] | ||||
|     const at::Tensor& zp, // [k/group_size, N] | ||||
|     int64_t group_size); | ||||
|     int64_t group_size, | ||||
|     bool pri_cache = true); | ||||
|  | ||||
| dnnl::memory::dims conv_dst_size( | ||||
|     int64_t ndim, | ||||
|  | ||||
| @ -295,6 +295,127 @@ kernel void masked_fill_scalar_strided( | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename T, typename index_t> | ||||
| kernel void index_copy_dense( | ||||
|     device T* output, | ||||
|     constant T* input, | ||||
|     constant T* source, | ||||
|     constant index_t* indices, | ||||
|     constant uint& dim, | ||||
|     constant long* sizes, | ||||
|     constant uint& ndim, | ||||
|     constant uint& indices_numel, | ||||
|     uint thread_index [[thread_position_in_grid]]) { | ||||
|   // first copy input to output | ||||
|   output[thread_index] = input[thread_index]; | ||||
|  | ||||
|   // calculate pos in the tensor using a signed counter | ||||
|   long pos[max_ndim]; | ||||
|   long linear_idx = thread_index; | ||||
|   for (int i = static_cast<int>(ndim) - 1; i >= 0; --i) { | ||||
|     pos[i] = linear_idx % sizes[i]; | ||||
|     linear_idx /= sizes[i]; | ||||
|   } | ||||
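|   // e.g. sizes = {2, 3, 4}, thread_index = 17 -> pos = {1, 1, 1} | ||||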
|  | ||||
|   // check if this position's dim coordinate is in the indices | ||||
|   long dim_pos = pos[dim]; | ||||
|  | ||||
|   // search through indices to see if current dim pos should be updated | ||||
|   for (uint i = 0; i < indices_numel; i++) { | ||||
|     if (indices[i] == dim_pos) { | ||||
|       // this position should be updated from source | ||||
|       // calculate source offset where the source tensor has the same shape | ||||
|       // except along dim where it has size = indices_numel | ||||
|       long source_offset = 0; | ||||
|       long stride = 1; | ||||
|       for (int j = static_cast<int>(ndim) - 1; j >= 0; --j) { | ||||
|         if (j == static_cast<int>(dim)) { | ||||
|           // for the indexed dimension, use position i | ||||
|           source_offset += i * stride; | ||||
|           stride *= indices_numel; | ||||
|         } else { | ||||
|           // for other dimensions use the same position | ||||
|           source_offset += pos[j] * stride; | ||||
|           stride *= sizes[j]; | ||||
|         } | ||||
|       } | ||||
|  | ||||
|       output[thread_index] = source[source_offset]; | ||||
|       break; | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename T, typename index_t> | ||||
| kernel void index_copy_strided( | ||||
|     device T* output, | ||||
|     constant T* input, | ||||
|     constant T* source, | ||||
|     constant index_t* indices, | ||||
|     constant uint& dim, | ||||
|     constant long* sizes, | ||||
|     constant uint& ndim, | ||||
|     constant uint& indices_numel, | ||||
|     constant long* input_strides, | ||||
|     constant long* output_strides, | ||||
|     constant long* source_strides, | ||||
|     uint thread_index [[thread_position_in_grid]]) { | ||||
|   int pos[max_ndim]; | ||||
|   pos_from_thread_index(int(thread_index), pos, sizes, ndim); | ||||
|  | ||||
|   // compute offsets for the output and input tensors | ||||
|   long output_offset = offset_from_coord(pos, output_strides, ndim); | ||||
|   long input_offset = offset_from_coord(pos, input_strides, ndim); | ||||
|  | ||||
|   output[output_offset] = input[input_offset]; | ||||
|  | ||||
|   // save the original coordinate along the dim we're updating | ||||
|   int orig_dim = pos[dim]; | ||||
|  | ||||
|   // find the last index in the indices array that equals this coordinate | ||||
|   int last_matching_index = -1; | ||||
|   for (uint i = 0; i < indices_numel; i++) { | ||||
|     if (indices[i] == orig_dim) { | ||||
|       last_matching_index = int(i); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   // if a matching index was found, use it to update the output | ||||
|   if (last_matching_index != -1) { | ||||
|     pos[dim] = last_matching_index; | ||||
|     long source_offset = offset_from_coord(pos, source_strides, ndim); | ||||
|     output[output_offset] = source[source_offset]; | ||||
|   } | ||||
| } | ||||
|  | ||||
| #define INSTANTIATE_INDEX_COPY(T, index_t)                      \ | ||||
|   template [[host_name("index_copy_dense_" #T "_" #index_t)]]   \ | ||||
|   kernel void index_copy_dense<T, index_t>(                     \ | ||||
|       device T*,                                                \ | ||||
|       constant T*,                                              \ | ||||
|       constant T*,                                              \ | ||||
|       constant index_t*,                                        \ | ||||
|       constant uint&,                                           \ | ||||
|       constant long*,                                           \ | ||||
|       constant uint&,                                           \ | ||||
|       constant uint&,                                           \ | ||||
|       uint);                                                    \ | ||||
|                                                                 \ | ||||
|   template [[host_name("index_copy_strided_" #T "_" #index_t)]] \ | ||||
|   kernel void index_copy_strided<T, index_t>(                   \ | ||||
|       device T*,                                                \ | ||||
|       constant T*,                                              \ | ||||
|       constant T*,                                              \ | ||||
|       constant index_t*,                                        \ | ||||
|       constant uint&,                                           \ | ||||
|       constant long*,                                           \ | ||||
|       constant uint&,                                           \ | ||||
|       constant uint&,                                           \ | ||||
|       constant long*,                                           \ | ||||
|       constant long*,                                           \ | ||||
|       constant long*,                                           \ | ||||
|       uint); | ||||
|  | ||||
| #define REGISTER_MASKED_FILL_SCALAR(SIZE, DTYPE)                            \ | ||||
|   template [[host_name("masked_fill_scalar_strided_" #SIZE)]] kernel void   \ | ||||
|   masked_fill_scalar_strided<DTYPE>(                                        \ | ||||
| @ -317,3 +438,28 @@ REGISTER_MASKED_FILL_SCALAR(64bit, long); | ||||
| REGISTER_MASKED_FILL_SCALAR(32bit, int); | ||||
| REGISTER_MASKED_FILL_SCALAR(16bit, short); | ||||
| REGISTER_MASKED_FILL_SCALAR(8bit, char); | ||||
| INSTANTIATE_INDEX_COPY(float, int); | ||||
| INSTANTIATE_INDEX_COPY(float, long); | ||||
| INSTANTIATE_INDEX_COPY(bool, int); | ||||
| INSTANTIATE_INDEX_COPY(bool, long); | ||||
| INSTANTIATE_INDEX_COPY(half, int); | ||||
| INSTANTIATE_INDEX_COPY(half, long); | ||||
| INSTANTIATE_INDEX_COPY(int, int); | ||||
| INSTANTIATE_INDEX_COPY(int, long); | ||||
| INSTANTIATE_INDEX_COPY(long, int); | ||||
| INSTANTIATE_INDEX_COPY(long, long); | ||||
| INSTANTIATE_INDEX_COPY(short, int); | ||||
| INSTANTIATE_INDEX_COPY(short, long); | ||||
| INSTANTIATE_INDEX_COPY(char, int); | ||||
| INSTANTIATE_INDEX_COPY(char, long); | ||||
| INSTANTIATE_INDEX_COPY(uchar, int); | ||||
| INSTANTIATE_INDEX_COPY(uchar, long); | ||||
|  | ||||
| #if __METAL_VERSION__ >= 310 | ||||
| INSTANTIATE_INDEX_COPY(bfloat, int); | ||||
| INSTANTIATE_INDEX_COPY(bfloat, long); | ||||
| #endif | ||||
| INSTANTIATE_INDEX_COPY(float2, int); | ||||
| INSTANTIATE_INDEX_COPY(float2, long); | ||||
| INSTANTIATE_INDEX_COPY(half2, int); | ||||
| INSTANTIATE_INDEX_COPY(half2, long); | ||||
|  | ||||
| @ -34,6 +34,7 @@ | ||||
| #include <ATen/ops/flip_native.h> | ||||
| #include <ATen/ops/index.h> | ||||
| #include <ATen/ops/index_add_native.h> | ||||
| #include <ATen/ops/index_copy_native.h> | ||||
| #include <ATen/ops/index_fill_native.h> | ||||
| #include <ATen/ops/index_put.h> | ||||
| #include <ATen/ops/index_select_native.h> | ||||
| @ -252,6 +253,78 @@ static void index_put_kernel_mps(TensorIterator& iter, | ||||
| } | ||||
| } // namespace mps | ||||
|  | ||||
| TORCH_IMPL_FUNC(index_copy_out_mps)(const Tensor& self, | ||||
|                                     int64_t dim, | ||||
|                                     const Tensor& index, | ||||
|                                     const Tensor& source, | ||||
|                                     const Tensor& result) { | ||||
|   using namespace mps; | ||||
|  | ||||
|   // special-case for 0-dim tensors | ||||
|   if (self.dim() == 0) { | ||||
|     TORCH_CHECK(index.numel() == 1, | ||||
|                 "index_copy_(): attempting to index a 0-dim tensor with an index tensor of size ", | ||||
|                 index.numel()); | ||||
|     int64_t idx = index.item<int64_t>(); | ||||
|     TORCH_CHECK(idx == 0, "index_copy_(): the only valid index for a 0-dim tensor is 0, but got ", idx); | ||||
|     result.copy_(source); | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   dim = maybe_wrap_dim(dim, self.dim()); | ||||
|  | ||||
|   // early return for empty index | ||||
|   if (index.numel() == 0) { | ||||
|     result.copy_(self); | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   for (int64_t i = 0; i < self.dim(); i++) { | ||||
|     if (i != dim) { | ||||
|       TORCH_CHECK(self.size(i) == source.size(i), | ||||
|                   "index_copy_(): self and source must have same size at dimension ", | ||||
|                   i, | ||||
|                   "; self has size ", | ||||
|                   self.size(i), | ||||
|                   ", source has size ", | ||||
|                   source.size(i)); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   TORCH_CHECK(source.size(dim) == index.numel(), | ||||
|               "index_copy_(): Number of indices (", | ||||
|               index.numel(), | ||||
|               ") should be equal to source.size(dim) (", | ||||
|               source.size(dim), | ||||
|               ")"); | ||||
|  | ||||
|   auto stream = getCurrentMPSStream(); | ||||
|   auto device = MPSDevice::getInstance()->device(); | ||||
|  | ||||
|   const bool is_dense = | ||||
|       self.is_contiguous() && source.is_contiguous() && result.is_contiguous() && index.is_contiguous(); | ||||
|  | ||||
|   auto dense_or_strided = is_dense ? "dense" : "strided"; | ||||
|   auto long_or_int = (index.scalar_type() == ScalarType::Long) ? "long" : "int"; | ||||
|   auto indexCopyPSO = lib.getPipelineStateForFunc( | ||||
|       fmt::format("index_copy_{}_{}_{}", dense_or_strided, scalarToMetalTypeString(result), long_or_int)); | ||||
|  | ||||
|   dispatch_sync_with_rethrow(stream->queue(), ^() { | ||||
|     @autoreleasepool { | ||||
|       auto computeEncoder = stream->commandEncoder(); | ||||
|       uint32_t dim_arg = static_cast<uint32_t>(dim); | ||||
|       uint32_t ndim = self.dim(); | ||||
|       uint32_t indices_numel = index.numel(); | ||||
|       [computeEncoder setComputePipelineState:indexCopyPSO]; | ||||
|       mtl_setArgs(computeEncoder, result, self, source, index, dim_arg, self.sizes(), ndim, indices_numel); | ||||
|       if (!is_dense) { | ||||
|         mtl_setArgs<8>(computeEncoder, self.strides(), result.strides(), source.strides()); | ||||
|       } | ||||
|       mtl_dispatch1DJob(computeEncoder, indexCopyPSO, result.numel()); | ||||
|     } | ||||
|   }); | ||||
| } | ||||
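|  | ||||
| // A minimal semantics sketch via the C++ API (assumes an MPS-enabled build; | ||||
| // shapes are illustrative, not from this diff): | ||||
| // | ||||
| //   auto self   = torch::zeros({5}, torch::kMPS); | ||||
| //   auto index  = torch::tensor({0, 2}, torch::dtype(torch::kLong).device(torch::kMPS)); | ||||
| //   auto source = torch::tensor({1.0f, 2.0f}, torch::TensorOptions().device(torch::kMPS)); | ||||
| //   auto out    = self.index_copy(0, index, source); // out: [1, 0, 2, 0, 0] | ||||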
|  | ||||
| static Tensor nonzero_fallback(const Tensor& self) { | ||||
|   return at::nonzero(self.to("cpu")).to("mps"); | ||||
| } | ||||
|  | ||||
| @ -35,14 +35,15 @@ static void _mps_linear_nograph(const Tensor& input, const Tensor& weight, const | ||||
|                                                                                 shape:getMPSShape(weight.sizes())]; | ||||
|       weightDesc.preferPackedRows = YES; | ||||
|       [weightDesc transposeDimension:0 withDimension:1]; | ||||
|       MPSNDArray* weightNDArray = [[MPSNDArray alloc] initWithBuffer:weightBuf | ||||
|                                                               offset:weight.storage_offset() * weight.element_size() | ||||
|                                                           descriptor:weightDesc]; | ||||
|       MPSNDArray* weightNDArray = [[[MPSNDArray alloc] initWithBuffer:weightBuf | ||||
|                                                                offset:weight.storage_offset() * weight.element_size() | ||||
|                                                            descriptor:weightDesc] autorelease]; | ||||
|  | ||||
|       if (is_bias_defined) { | ||||
|         auto biasNDArray = getMPSNDArray(bias, bias.sizes(), bias.strides()); | ||||
|         auto cachedKernel = LookUpOrCreateCachedKernel<MPSCachedKernel>( | ||||
|             key, [&]() { return [[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:3]; }); | ||||
|         auto cachedKernel = LookUpOrCreateCachedKernel<MPSCachedKernel>(key, [&]() { | ||||
|           return [[[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:3] autorelease]; | ||||
|         }); | ||||
|         auto kernel = cachedKernel->kernel<MPSNDArrayMatrixMultiplication>(); | ||||
|  | ||||
|         getMPSProfiler().beginProfileKernel(kernel, "mps_linear", {input, weight, bias}); | ||||
| @ -52,8 +53,9 @@ static void _mps_linear_nograph(const Tensor& input, const Tensor& weight, const | ||||
|                       destinationArray:outNDArray]; | ||||
|         getMPSProfiler().endProfileKernel(kernel); | ||||
|       } else { | ||||
|         auto cachedKernel = LookUpOrCreateCachedKernel<MPSCachedKernel>( | ||||
|             key, [&]() { return [[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:2]; }); | ||||
|         auto cachedKernel = LookUpOrCreateCachedKernel<MPSCachedKernel>(key, [&]() { | ||||
|           return [[[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:2] autorelease]; | ||||
|         }); | ||||
|         auto kernel = cachedKernel->kernel<MPSNDArrayMatrixMultiplication>(); | ||||
|         getMPSProfiler().beginProfileKernel(kernel, "mps_linear", {input, weight, bias}); | ||||
|         [kernel encodeToCommandEncoder:computeEncoder | ||||
|  | ||||
| @ -3110,6 +3110,7 @@ | ||||
|   - dim -> int dim | ||||
|   dispatch: | ||||
|     CPU, CUDA: index_copy_out | ||||
|     MPS: index_copy_out_mps | ||||
|  | ||||
| - func: index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!) | ||||
|   variants: method | ||||
|  | ||||
| @ -1,5 +1,7 @@ | ||||
| #include <ATen/native/sparse/cuda/cuSPARSELtOps.h> | ||||
|  | ||||
| #include <unordered_map> | ||||
| #include <mutex> | ||||
| #include <string_view> | ||||
| #if AT_CUSPARSELT_ENABLED() | ||||
|  | ||||
| namespace at::native { | ||||
| @ -15,6 +17,45 @@ namespace at::native { | ||||
| thread_local cusparseLtHandle_t handle; | ||||
| thread_local bool handle_initialized = false; | ||||
|  | ||||
| #ifdef USE_ROCM | ||||
| // Single global flag for platform-wide hipSparseLt support | ||||
| c10::once_flag g_hipSparseLtSupportInitFlag; | ||||
| static bool g_hipSparseLtSupported = false; | ||||
|  | ||||
| // Initialize the hipSparseLt support status once for the platform | ||||
| static void initHipSparseLtSupport() { | ||||
|     // Default to not supported | ||||
|     g_hipSparseLtSupported = false; | ||||
|  | ||||
|     // Check only the first available device | ||||
|     try { | ||||
|         if (at::cuda::device_count() > 0) { | ||||
|             g_hipSparseLtSupported = at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx942"}, 0); | ||||
|         } | ||||
|     } catch (const std::exception&) { | ||||
|         // If an exception occurs during device property check, we assume hipSparseLt is not supported | ||||
|         // This could happen due to driver issues, device access problems, or other runtime errors | ||||
|         g_hipSparseLtSupported = false; | ||||
|         TORCH_WARN("Exception occurred while checking hipSparseLt support. Assuming not supported."); | ||||
|     } | ||||
| } | ||||
|  | ||||
| static bool isHipSparseLtSupported() { | ||||
|     // Initialize support check only once | ||||
|     c10::call_once(g_hipSparseLtSupportInitFlag, initHipSparseLtSupport); | ||||
|  | ||||
    // Throw if unsupported; otherwise return the cached result (platform-wide)
|     if (!g_hipSparseLtSupported) { | ||||
        TORCH_CHECK(
            false,
            "hipSparseLt not supported on this device; supported architectures: "
            "gfx950, gfx942. "
            "Required ROCm version: 6.4.0 or later.");
|     } | ||||
|     return g_hipSparseLtSupported; | ||||
| } | ||||
| #endif | ||||
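For readers more at home in Python, the probe-once-and-cache idea above looks roughly like this (a sketch; `query_first_device_arch` is a hypothetical stand-in for the real device-property query):

```python
import functools

def query_first_device_arch() -> str:
    # Hypothetical stand-in for the real device-property query.
    return "gfx942"

@functools.lru_cache(maxsize=1)
def hipsparselt_supported() -> bool:
    # The probe runs at most once per process and the result is cached,
    # mirroring c10::call_once plus the static flag in the C++ above.
    try:
        return query_first_device_arch() in {"gfx950", "gfx942"}
    except Exception:
        return False  # any probe failure counts as unsupported

print(hipsparselt_supported())  # subsequent calls hit the cache
```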
|  | ||||
| at::Tensor _cslt_compress(const Tensor& sparse_input) { | ||||
|   if (!handle_initialized) { | ||||
|     TORCH_CUDASPARSE_CHECK(cusparseLtInit(&handle)); | ||||
| @ -25,6 +66,10 @@ at::Tensor _cslt_compress(const Tensor& sparse_input) { | ||||
|   cudaDataType type; | ||||
|   auto compression_factor = 9; | ||||
|  | ||||
|   #ifdef USE_ROCM | ||||
|   TORCH_CHECK(isHipSparseLtSupported()); | ||||
|   #endif | ||||
|  | ||||
|   switch (sparse_input.scalar_type()) { | ||||
|     case at::ScalarType::Char: | ||||
|       type = CUDA_R_8I; | ||||
| @ -36,17 +81,19 @@ at::Tensor _cslt_compress(const Tensor& sparse_input) { | ||||
|     case at::ScalarType::BFloat16: | ||||
|       type = CUDA_R_16BF; | ||||
|       break; | ||||
| #ifndef USE_ROCM | ||||
|     case at::ScalarType::Float: | ||||
|       type = CUDA_R_32F; | ||||
|       break; | ||||
| #if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 | ||||
| #endif | ||||
| #if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 && !defined(USE_ROCM) | ||||
|     case at::ScalarType::Float8_e4m3fn: | ||||
|       type = CUDA_R_8F_E4M3; | ||||
|       compression_factor = 10; | ||||
|       break; | ||||
| #endif | ||||
|     default: | ||||
|       TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt compressed matrix"); | ||||
|       TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt/hipSparseLt compressed matrix"); | ||||
|       break; | ||||
|   } | ||||
|  | ||||
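For orientation, these kernels are normally reached from Python through semi-structured (2:4) sparsity. A minimal sketch, assuming a CUDA or ROCm device with 2:4 sparse support and fp16 weights:

```python
import torch
from torch.sparse import to_sparse_semi_structured

# Weight that already satisfies the 2:4 pattern: two nonzeros in
# every contiguous group of four elements along a row.
w = torch.rand(128, 128, dtype=torch.float16, device="cuda")
mask = torch.tensor([1, 1, 0, 0], dtype=torch.bool, device="cuda").tile(128, 32)
w_sparse = to_sparse_semi_structured(w * mask)  # compression happens here

x = torch.rand(128, 64, dtype=torch.float16, device="cuda")
y = torch.mm(w_sparse, x)  # dispatches to the cuSPARSELt/hipSparseLt matmul path
```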
| @ -120,6 +167,10 @@ std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl( | ||||
|   cusparseComputeType compute_type; | ||||
|   auto compression_factor = 9; | ||||
|  | ||||
|   #ifdef USE_ROCM | ||||
|   TORCH_CHECK(isHipSparseLtSupported()); | ||||
|   #endif | ||||
|  | ||||
|   switch (compressed_A.scalar_type()) { | ||||
|     case at::ScalarType::Char: | ||||
|       input_type = CUDA_R_8I; | ||||
| @ -131,7 +182,7 @@ std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl( | ||||
|  | ||||
// cuSPARSELt v0.5.2 onwards changes CUSPARSE_COMPUTE_TF32, CUSPARSE_COMPUTE_16F
| // to CUSPARSE_COMPUTE_32F | ||||
| #if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 502 | ||||
| #if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 502 || defined(USE_ROCM) | ||||
|     case at::ScalarType::Half: | ||||
|       input_type = CUDA_R_16F; | ||||
|       output_type = CUDA_R_16F; | ||||
| @ -144,14 +195,16 @@ std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl( | ||||
|       C_type = CUDA_R_16BF; | ||||
|       compute_type = CUSPARSE_COMPUTE_32F; | ||||
|       break; | ||||
| #ifndef USE_ROCM | ||||
|     case at::ScalarType::Float: | ||||
|       input_type = CUDA_R_32F; | ||||
|       output_type = CUDA_R_32F; | ||||
|       C_type = CUDA_R_32F; | ||||
|       compute_type = CUSPARSE_COMPUTE_32F; | ||||
|       break; | ||||
| #endif | ||||
// if cuSPARSELt >= 0.6.2, we can add Float8 support
| #if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 | ||||
| #if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 && !defined(USE_ROCM) | ||||
|     case at::ScalarType::Float8_e4m3fn: | ||||
|       input_type = CUDA_R_8F_E4M3; | ||||
|       output_type = CUDA_R_8F_E4M3; | ||||
| @ -214,7 +267,7 @@ std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl( | ||||
|       } | ||||
|     } | ||||
| // cslt 0.6.2+: fp8 fp8 -> {fp8, fp16, bf16, fp32} support | ||||
| #if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 | ||||
| #if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 && !defined(USE_ROCM) | ||||
|     else if (input_type == CUDA_R_8F_E4M3) { | ||||
|       switch (out_dtype) { | ||||
|         case at::ScalarType::Float8_e4m3fn: | ||||
|  | ||||
| @ -968,8 +968,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti | ||||
|   int64_t batch_size = query.size(0); | ||||
|  | ||||
|   if (batch_size > MAX_BATCH_SIZE) { | ||||
|     TORCH_CHECK(!compute_log_sumexp && (dropout_p == 0.0), | ||||
|                 "Efficient attention cannot produce valid seed, logsumexp and offset outputs when " | ||||
|     TORCH_CHECK(dropout_p == 0.0, | ||||
|                 "Efficient attention cannot produce valid seed and offset outputs when " | ||||
|                 "the batch size exceeds (", MAX_BATCH_SIZE, ")."); | ||||
|   } | ||||
|   auto process_chunk = [&](const Tensor& q_chunk, | ||||
| @ -1030,6 +1030,17 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti | ||||
|     } | ||||
|     Tensor final_attention = at::empty_strided(sizes, attn.strides(), attn.options()); | ||||
|     final_attention.slice(0, start, end).copy_(attn); | ||||
|     Tensor final_log_sumexp; | ||||
|     if (compute_log_sumexp && log_sumexp.numel() > 0) { | ||||
|       std::vector<int64_t> lse_sizes; | ||||
|       lse_sizes.reserve(log_sumexp.dim()); | ||||
|       lse_sizes.push_back(batch_size); | ||||
|       for (int i = 1; i < log_sumexp.dim(); i++) { | ||||
|         lse_sizes.push_back(log_sumexp.size(i)); | ||||
|       } | ||||
|       final_log_sumexp = at::empty(std::move(lse_sizes), log_sumexp.options()); | ||||
|       final_log_sumexp.slice(0, start, end).copy_(log_sumexp); | ||||
|     } | ||||
|  | ||||
|     for (start = end; start < batch_size; start += MAX_BATCH_SIZE) { | ||||
|       end = std::min(start + MAX_BATCH_SIZE, batch_size); | ||||
| @ -1045,10 +1056,13 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti | ||||
|       auto [chunk_attn, chunk_log_sumexp, chunk_seed, chunk_offset] = | ||||
|           process_chunk(query_chunk, key_chunk, value_chunk, bias_chunk); | ||||
|       final_attention.slice(0, start, end).copy_(chunk_attn); | ||||
|       if (compute_log_sumexp && chunk_log_sumexp.numel() > 0) { | ||||
|         final_log_sumexp.slice(0, start, end).copy_(chunk_log_sumexp); | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     return std::make_tuple(std::move(final_attention), | ||||
|               std::move(log_sumexp), | ||||
|               std::move(final_log_sumexp), | ||||
|               std::move(seed), | ||||
|               std::move(offset)); | ||||
|   } | ||||
|  | ||||
| @ -24,6 +24,8 @@ | ||||
| #include <ATen/Functions.h> | ||||
| #include <ATen/NativeFunctions.h> | ||||
| #else | ||||
| #include <ATen/ops/zeros_like.h> | ||||
| #include <ATen/ops/empty_strided.h> | ||||
| #include <ATen/ops/_flash_attention_backward.h> | ||||
| #include <ATen/ops/_flash_attention_backward_native.h> | ||||
| #include <ATen/ops/_efficient_attention_backward.h> | ||||
| @ -905,40 +907,56 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_e | ||||
|   if (!grad_out_.defined()) { | ||||
|     return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}); | ||||
|   } | ||||
|   auto grad_out = grad_out_.transpose(1, 2); | ||||
|   auto out_t = out.transpose(1, 2); | ||||
|   auto q_t = query.transpose(1, 2); | ||||
|   auto k_t = key.transpose(1, 2); | ||||
|   auto v_t = value.transpose(1, 2); | ||||
|   constexpr int64_t MAX_BATCH_SIZE = (1LL << 16) - 1; | ||||
|   int64_t batch_size = query.size(0); | ||||
|  | ||||
|   if (batch_size > MAX_BATCH_SIZE) { | ||||
|     TORCH_CHECK(dropout_p == 0.0, | ||||
|                 "Efficient attention backward cannot handle dropout when " | ||||
|                 "the batch size exceeds (", MAX_BATCH_SIZE, ")."); | ||||
|   } | ||||
|   auto grad_out_t = grad_out_.transpose(1, 2); | ||||
|   auto query_t = query.transpose(1, 2); | ||||
|   auto key_t = key.transpose(1, 2); | ||||
|   auto value_t = value.transpose(1, 2); | ||||
|   auto out_t = out.transpose(1, 2); | ||||
|  | ||||
|   auto process_chunk = [&](const Tensor& grad_out_chunk, | ||||
|                           const Tensor& query_chunk, | ||||
|                           const Tensor& key_chunk, | ||||
|                           const Tensor& value_chunk, | ||||
|                           const std::optional<Tensor>& attn_bias_chunk, | ||||
|                           const Tensor& out_chunk, | ||||
|                           const Tensor& logsumexp_chunk) | ||||
|       -> std::tuple<Tensor, Tensor, Tensor, Tensor> { | ||||
|   // This is needed because SaveVariable automatically converts | ||||
|   // std::optional to undefined tensor | ||||
|   std::optional<Tensor> kernel_bias; | ||||
|   if (attn_bias.defined()) { | ||||
|     kernel_bias = attn_bias; | ||||
|   if (attn_bias_chunk.has_value() && attn_bias_chunk.value().defined()) { | ||||
|     kernel_bias = attn_bias_chunk.value(); | ||||
|   } | ||||
  // Will add with signature changes for dropout and bias
|   // We are only handling Dense inputs, but this should be passed | ||||
|   // from forward to backward | ||||
|   int64_t max_seqlen_q = q_t.size(1); | ||||
|   int64_t max_seqlen_k = k_t.size(1); | ||||
|   int64_t max_seqlen_q = query_chunk.size(2); | ||||
|   int64_t max_seqlen_k = key_chunk.size(2); | ||||
|  | ||||
|   sdp::CustomMaskType custom_mask_type = causal | ||||
|     ? sdp::CustomMaskType::CausalFromTopLeft | ||||
|     : sdp::CustomMaskType::NoCustomMask; | ||||
|   auto [grad_q, grad_k, grad_v, grad_bias] = | ||||
|       at::_efficient_attention_backward( | ||||
|           grad_out, | ||||
|           q_t, | ||||
|           k_t, | ||||
|           v_t, | ||||
|           grad_out_chunk, | ||||
|           query_chunk, | ||||
|           key_chunk, | ||||
|           value_chunk, | ||||
|           kernel_bias, | ||||
|           out_t, | ||||
|           out_chunk, | ||||
|           std::nullopt, | ||||
|           std::nullopt, | ||||
|           max_seqlen_q, | ||||
|           max_seqlen_k, | ||||
|           logsumexp, | ||||
|           logsumexp_chunk, | ||||
|           dropout_p, | ||||
|           philox_seed, | ||||
|           philox_offset, | ||||
| @ -947,7 +965,90 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_e | ||||
|           scale, | ||||
|           std::nullopt);  // num_split_keys | ||||
|   return std::make_tuple( | ||||
|       grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), grad_bias); | ||||
|       grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), std::move(grad_bias)); | ||||
|   }; | ||||
|  | ||||
|   // process in chunks if batch size exceeds maximum | ||||
|   if (batch_size > MAX_BATCH_SIZE) { | ||||
|     Tensor final_grad_q, final_grad_k, final_grad_v, final_grad_bias; | ||||
|  | ||||
|     auto create_strided_output = [batch_size](const Tensor& tensor) -> Tensor { | ||||
|       if (!tensor.defined()) { | ||||
|         return Tensor{}; | ||||
|       } | ||||
|       int dim = tensor.dim(); | ||||
|       std::vector<int64_t> sizes; | ||||
|       sizes.reserve(dim); | ||||
|       sizes.push_back(batch_size); | ||||
|       for (int i = 1; i < dim; i++) { | ||||
|         sizes.push_back(tensor.size(i)); | ||||
|       } | ||||
|       return at::empty_strided(std::move(sizes), tensor.strides(), tensor.options()); | ||||
|     }; | ||||
|  | ||||
|     if (grad_input_mask[0]) { | ||||
|       final_grad_q = create_strided_output(query); | ||||
|     } | ||||
|  | ||||
|     if (grad_input_mask[1]) { | ||||
|       final_grad_k = create_strided_output(key); | ||||
|     } | ||||
|  | ||||
|     if (grad_input_mask[2]) { | ||||
|       final_grad_v = create_strided_output(value); | ||||
|     } | ||||
|     if (grad_input_mask[3] && attn_bias.defined()) { | ||||
|       final_grad_bias = at::zeros_like(attn_bias); | ||||
|     } | ||||
|  | ||||
|     for (int64_t start = 0; start < batch_size; start += MAX_BATCH_SIZE) { | ||||
|       int64_t end = std::min(start + MAX_BATCH_SIZE, batch_size); | ||||
|  | ||||
|       Tensor grad_out_chunk = grad_out_t.slice(0, start, end); | ||||
|       Tensor query_chunk = query_t.slice(0, start, end); | ||||
|       Tensor key_chunk = key_t.slice(0, start, end); | ||||
|       Tensor value_chunk = value_t.slice(0, start, end); | ||||
|       Tensor attn_bias_chunk; | ||||
|       if (attn_bias.defined()) { | ||||
|         attn_bias_chunk = attn_bias.slice(0, start, end); | ||||
|       } else { | ||||
|         attn_bias_chunk.reset(); | ||||
|       } | ||||
|       Tensor out_chunk = out_t.slice(0, start, end); | ||||
|       Tensor logsumexp_chunk = logsumexp.numel() > 0 ? logsumexp.slice(0, start, end) : logsumexp; | ||||
|  | ||||
|       auto [chunk_grad_q, chunk_grad_k, chunk_grad_v, chunk_grad_bias] = | ||||
|           process_chunk(grad_out_chunk, query_chunk, key_chunk, value_chunk, | ||||
|                       attn_bias_chunk, out_chunk, logsumexp_chunk); | ||||
|  | ||||
|       if (grad_input_mask[0] && chunk_grad_q.defined()) { | ||||
|         final_grad_q.slice(0, start, end).copy_(chunk_grad_q); | ||||
|       } | ||||
|       if (grad_input_mask[1] && chunk_grad_k.defined()) { | ||||
|         final_grad_k.slice(0, start, end).copy_(chunk_grad_k); | ||||
|       } | ||||
|       if (grad_input_mask[2] && chunk_grad_v.defined()) { | ||||
|         final_grad_v.slice(0, start, end).copy_(chunk_grad_v); | ||||
|       } | ||||
|       if (grad_input_mask[3] && chunk_grad_bias.defined()) { | ||||
|         final_grad_bias.add_(chunk_grad_bias); | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     return std::make_tuple( | ||||
|         std::move(final_grad_q), | ||||
|         std::move(final_grad_k), | ||||
|         std::move(final_grad_v), | ||||
|         std::move(final_grad_bias)); | ||||
|   } | ||||
|   // when batch size is within allowed size, no chunking needed | ||||
|   else { | ||||
|     std::optional<Tensor> attn_bias_opt; | ||||
|     if (attn_bias.defined()) { | ||||
|       attn_bias_opt = attn_bias; | ||||
|     } | ||||
|     return process_chunk(grad_out_t, query_t, key_t, value_t, attn_bias_opt, out_t, logsumexp); | ||||
|   } | ||||
| } | ||||
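The chunk-and-stitch pattern used in both the forward and backward paths above reduces to the following (a sketch; the real code also threads logsumexp through the loop, accumulates bias gradients, and preserves output strides):

```python
import torch

MAX_BATCH = 3  # stand-in for the (1 << 16) - 1 kernel limit above

def chunked(op, x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for start in range(0, x.size(0), MAX_BATCH):
        end = min(start + MAX_BATCH, x.size(0))
        out[start:end].copy_(op(x[start:end]))  # run the kernel per batch slice
    return out

y = chunked(torch.neg, torch.arange(10.0))
assert torch.equal(y, -torch.arange(10.0))
```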
|  | ||||
| } // namespace at::native | ||||
|  | ||||
| @ -1018,9 +1018,6 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs): | ||||
|  | ||||
|     Writes to ./speedups.csv | ||||
|     """ | ||||
|     # if args.dynamic_shapes: | ||||
|     #     return speedup_experiment_ds(args, model_iter_fn, model, example_inputs) | ||||
|  | ||||
|     timings = np.zeros((args.repeat, 2), np.float64) | ||||
|     # if we randomize the input, we should also check the result is correct | ||||
|     should_randomize_input = args.randomize_input | ||||
| @ -1179,82 +1176,6 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs): | ||||
|     return msg | ||||
|  | ||||
|  | ||||
| # WARNING: This code is currently dead | ||||
| def speedup_experiment_ds(args, model_iter_fn, model, example_inputs): | ||||
|     """ | ||||
|     Run dynamic shapes benchmarks. | ||||
|  | ||||
|     Requires dynamic shape compatible models, which provide a list of example inputs. | ||||
|  | ||||
|     Warms up using the first input example and then iterates the inputs, | ||||
|     measuring (and expecting minimal) variance between the runtime for different examples. | ||||
|  | ||||
|     """ | ||||
|     timings = np.zeros((args.repeat, len(example_inputs), 2), np.float64) | ||||
|  | ||||
|     if args.repeat > 5: | ||||
|         print( | ||||
|             f"\ndynamic shapes experiments are slow, consider setting --repeat less than {args.repeat}\n" | ||||
|         ) | ||||
|  | ||||
|     nwarmup = 4 | ||||
|     for rep in range(args.repeat): | ||||
|         # Start each rep fresh, e.g. only warmup on example 0 | ||||
|         torch._dynamo.reset() | ||||
|         optimized_model_iter_fn = optimize_ctx(model_iter_fn) | ||||
|         for _ in range(nwarmup): | ||||
|             optimized_model_iter_fn(model, example_inputs[0]) | ||||
|  | ||||
|         for input_idx, inputs in enumerate(example_inputs): | ||||
|             # interleave the runs to handle frequency scaling and load changes | ||||
|             timings[rep, input_idx, 0] = timed( | ||||
|                 model, model_iter_fn, inputs, return_result=False | ||||
|             ) | ||||
|             # different from regular speedup_experiment, we _DO_ want to allow recompilation | ||||
|             timings[rep, input_idx, 1] = timed( | ||||
|                 model, optimized_model_iter_fn, inputs, return_result=False | ||||
|             ) | ||||
|     medians = np.median(timings, axis=0) | ||||
|     speedups = list(medians[:, 0] / medians[:, 1]) | ||||
|     speedups_mean = np.mean(speedups) | ||||
|     speedups_median = np.median(speedups) | ||||
|     speedups_var = np.var(speedups) | ||||
|  | ||||
|     # TODO this x[0] is not going to work in general but bert only has 1 input | ||||
|     shapes = [x[0].shape for x in example_inputs] | ||||
|     shape_keys = sorted(set(shapes)) | ||||
|     shape_speedups = { | ||||
|         shape: [ | ||||
|             it[1] for it in filter(lambda it: it[0] == shape, zip(shapes, speedups)) | ||||
|         ] | ||||
|         for shape in shape_keys | ||||
|     } | ||||
|     output_str = ( | ||||
|         f"mean: {speedups_mean:.3f}, median: {speedups_median:.3f}, var: {speedups_var:.3f}" | ||||
|         + "\nSpeedups by shape: " | ||||
|         + "\n".join( | ||||
|             [ | ||||
|                 f"{shape}: " | ||||
|                 + ", ".join([f"{speedup: .3g}" for speedup in shape_speedups[shape]]) | ||||
|                 for shape in shape_keys | ||||
|             ] | ||||
|         ) | ||||
|     ) | ||||
|     write_outputs( | ||||
|         output_filename, | ||||
|         ("dev", "name", "batch_size", "speedup mean", "speedup median", "speedup var"), | ||||
|         [ | ||||
|             current_device, | ||||
|             current_name, | ||||
|             current_batch_size, | ||||
|             speedups_mean, | ||||
|             speedups_median, | ||||
|             speedups_var, | ||||
|         ], | ||||
|     ) | ||||
|     return output_str | ||||
|  | ||||
|  | ||||
| def overhead_experiment(*args, model_iter_fn): | ||||
|     """ | ||||
|     Measure overheads of TorchDynamo by running with no backend (only | ||||
|  | ||||
| @ -54,12 +54,9 @@ class Benchmark(BenchmarkBase): | ||||
|         torch._dynamo.reset() | ||||
|  | ||||
|     def _work(self): | ||||
|         # enable_cpp_symbolic_shape_guards has impact on this benchmark | ||||
|         # Keep using False value for consistency. | ||||
|         with ( | ||||
|             fresh_inductor_cache(), | ||||
|             torch._inductor.config.patch(force_shape_pad=self._force_shape_pad), | ||||
|             torch._dynamo.config.patch("enable_cpp_symbolic_shape_guards", False), | ||||
|         ): | ||||
|             opt_m = torch.compile(backend=self.backend(), dynamic=self.is_dynamic())( | ||||
|                 self.m.cuda() if self._is_gpu else self.m | ||||
|  | ||||
| @ -247,7 +247,10 @@ class BenchmarkBase(ABC): | ||||
|                     instruction_count=r, | ||||
|                 ) | ||||
|         if self._enable_compile_time_instruction_count: | ||||
|             r = self._count_compile_time_instructions() | ||||
            # enable_cpp_symbolic_shape_guards has an impact on these benchmarks.
            # Keep it set to False for consistency.
|             with config.patch("enable_cpp_symbolic_shape_guards", False): | ||||
|                 r = self._count_compile_time_instructions() | ||||
|  | ||||
|             self.results.append( | ||||
|                 ( | ||||
|  | ||||
| @ -1,8 +1,8 @@ | ||||
| add_loop_eager,compile_time_instruction_count,2953000000,0.015 | ||||
| add_loop_eager,compile_time_instruction_count,2937000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| add_loop_eager_dynamic,compile_time_instruction_count,5738000000,0.025 | ||||
| add_loop_eager_dynamic,compile_time_instruction_count,4300194436,0.025 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -10,7 +10,7 @@ add_loop_inductor,compile_time_instruction_count,29370000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44490000000,0.025 | ||||
| add_loop_inductor_dynamic_gpu,compile_time_instruction_count,38747844521,0.025 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -18,15 +18,15 @@ add_loop_inductor_gpu,compile_time_instruction_count,25900000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| basic_modules_ListOfLinears_eager,compile_time_instruction_count,939900000,0.015 | ||||
| basic_modules_ListOfLinears_eager,compile_time_instruction_count,952700000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18270000000,0.015 | ||||
| basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18390000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16310000000,0.015 | ||||
| basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16450000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10370000 | ||||
|  | ||||
|  | ||||
|  | ||||
| update_hint_regression,compile_time_instruction_count,1700000000,0.02 | ||||
| update_hint_regression,compile_time_instruction_count,1661000000,0.02 | ||||
|  | ||||
|  | ||||
|  | ||||
| float_args,compile_time_instruction_count,452500000,0.015 | ||||
| float_args,compile_time_instruction_count,455500000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -62,7 +62,7 @@ aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6022000000,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| aotdispatcher_partitioner_cpu,compile_time_instruction_count,8672000000,0.015 | ||||
| aotdispatcher_partitioner_cpu,compile_time_instruction_count,8724000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -70,7 +70,7 @@ aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1917000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3859000000,0.015 | ||||
| aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3838000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
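Each row above is `benchmark,metric,expected_count,relative_tolerance`. A hedged sketch of how such an entry is typically consumed (the actual check lives in the benchmark harness, not shown here):

```python
def within_tolerance(measured: int, expected: int, rel_tol: float) -> bool:
    # e.g. within_tolerance(m, 2937000000, 0.015) for add_loop_eager
    return abs(measured - expected) <= expected * rel_tol

assert within_tolerance(2_950_000_000, 2_937_000_000, 0.015)
```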
| @ -80,7 +80,6 @@ def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> | ||||
|  | ||||
| @dataclass(frozen=True, kw_only=True) | ||||
| class ExperimentConfig: | ||||
|     autotune_fallback_to_aten: bool = False | ||||
|     max_autotune: bool = True | ||||
|     coordinate_descent_tuning: bool = True | ||||
|     max_autotune_gemm_backends: str = "ATEN" | ||||
| @ -91,7 +90,6 @@ class ExperimentConfig: | ||||
|  | ||||
|     def to_options(self) -> dict[str, Any]: | ||||
|         return { | ||||
|             "autotune_fallback_to_aten": self.autotune_fallback_to_aten, | ||||
|             "max_autotune": self.max_autotune, | ||||
|             "coordinate_descent_tuning": self.coordinate_descent_tuning, | ||||
|             "max_autotune_gemm_backends": self.max_autotune_gemm_backends, | ||||
|  | ||||
| @ -38,8 +38,8 @@ void c10_cuda_check_implementation( | ||||
|         "Device-side assertions were explicitly omitted for this error check; the error probably arose while initializing the DSA handlers."); | ||||
|   } | ||||
| #endif | ||||
|  | ||||
|   TORCH_CHECK(false, check_message); | ||||
|   throw c10::AcceleratorError( | ||||
|       {__func__, __FILE__, int32_t(__LINE__)}, err, check_message); | ||||
| } | ||||
|  | ||||
| } // namespace c10::cuda | ||||
|  | ||||
| @ -200,6 +200,7 @@ static void initGlobalStreamState() { | ||||
| // Init a single CUDA or HIP stream | ||||
| // See Note [HIP Lazy Streams] | ||||
| static void initSingleStream(int p, DeviceIndex device_index, int i) { | ||||
|   CUDAGuard device_guard(device_index); | ||||
|   auto& stream = streams[p][device_index][i]; | ||||
|   auto pri = -p; // lower number is higher priority | ||||
|  | ||||
|  | ||||
| @ -295,6 +295,19 @@ class C10_API SyntaxError : public Error { | ||||
|   using Error::Error; | ||||
| }; | ||||
|  | ||||
// Raised when an accelerator API call hits an error.
// These turn into AcceleratorError when they cross into Python.
| class C10_API AcceleratorError : public Error { | ||||
|   int32_t error_code; | ||||
|  | ||||
|  public: | ||||
|   AcceleratorError(SourceLocation loc, int32_t code, const std::string& msg) | ||||
|       : Error(loc, msg), error_code(code) {} | ||||
|   int32_t get_error_code() const { | ||||
|     return error_code; | ||||
|   } | ||||
| }; | ||||
|  | ||||
| // Base error type for all distributed errors. | ||||
| // These turn into DistError when they cross into Python. | ||||
| class C10_API DistError : public Error { | ||||
|  | ||||
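From Python, a failure surfaced through this path can be caught roughly as follows (a hedged sketch: it assumes the class is exposed as `torch.AcceleratorError`; on builds without it, the same path raises a plain `RuntimeError`):

```python
import torch

try:
    torch.cuda.synchronize()  # any failing CUDA API call surfaces here
except torch.AcceleratorError as err:  # assumed Python-side binding name
    print("accelerator error:", err)
except RuntimeError as err:
    print("older builds raise plain RuntimeError from the same path:", err)
```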
| @ -133,8 +133,13 @@ inline void initGlobalDevicePoolState() { | ||||
| #else | ||||
|   // The default context is utilized for each Intel GPU device, allowing the | ||||
|   // retrieval of the context from any GPU device. | ||||
|   const auto& platform = gDevicePool.devices[0]->get_platform(); | ||||
|   gDevicePool.context = std::make_unique<sycl::context>( | ||||
|       gDevicePool.devices[0]->get_platform().ext_oneapi_get_default_context()); | ||||
| #if SYCL_COMPILER_VERSION >= 20250200 | ||||
|       platform.khr_get_default_context()); | ||||
| #else | ||||
|       platform.ext_oneapi_get_default_context()); | ||||
| #endif | ||||
| #endif | ||||
| } | ||||
|  | ||||
|  | ||||
| @ -1063,7 +1063,7 @@ if(USE_ROCM) | ||||
|  | ||||
|     # Math libraries | ||||
|     list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS | ||||
|       roc::hipblas roc::rocblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver roc::hipblaslt) | ||||
|       roc::hipblas roc::rocblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsparselt roc::hipsolver roc::hipblaslt) | ||||
|  | ||||
|     # ---[ Kernel asserts | ||||
|     # Kernel asserts is disabled for ROCm by default. | ||||
|  | ||||
| @ -57,7 +57,8 @@ if(NCCL_FOUND)  # obtaining NCCL version and some sanity checks | ||||
|   include(CheckCXXSymbolExists) | ||||
|   check_cxx_symbol_exists(NCCL_VERSION_CODE nccl.h NCCL_VERSION_DEFINED) | ||||
|  | ||||
|   if (NCCL_VERSION_DEFINED) | ||||
  # this condition check only works for non-static NCCL linking
|   if (NCCL_VERSION_DEFINED AND NOT USE_STATIC_NCCL) | ||||
|     set(file "${PROJECT_BINARY_DIR}/detect_nccl_version.cc") | ||||
|     file(WRITE ${file} " | ||||
|       #include <iostream> | ||||
| @ -65,7 +66,6 @@ if(NCCL_FOUND)  # obtaining NCCL version and some sanity checks | ||||
|       int main() | ||||
|       { | ||||
|         std::cout << NCCL_MAJOR << '.' << NCCL_MINOR << '.' << NCCL_PATCH << std::endl; | ||||
|  | ||||
|         int x; | ||||
|         ncclGetVersion(&x); | ||||
|         return x == NCCL_VERSION_CODE; | ||||
| @ -80,11 +80,9 @@ if(NCCL_FOUND)  # obtaining NCCL version and some sanity checks | ||||
| (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES}) Please set NCCL_INCLUDE_DIR and NCCL_LIB_DIR manually.") | ||||
|     endif() | ||||
|     message(STATUS "NCCL version: ${NCCL_VERSION_FROM_HEADER}") | ||||
|   else() | ||||
|     message(STATUS "NCCL version < 2.3.5-5") | ||||
|   endif () | ||||
|   set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES}) | ||||
|  | ||||
|   set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES}) | ||||
|   message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})") | ||||
|   mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES) | ||||
| endif() | ||||
|  | ||||
| @ -151,6 +151,7 @@ if(HIP_FOUND) | ||||
|   find_package_and_print_version(miopen REQUIRED) | ||||
|   find_package_and_print_version(hipfft REQUIRED) | ||||
|   find_package_and_print_version(hipsparse REQUIRED) | ||||
|   find_package_and_print_version(hipsparselt REQUIRED) | ||||
|   find_package_and_print_version(rocprim REQUIRED) | ||||
|   find_package_and_print_version(hipcub REQUIRED) | ||||
|   find_package_and_print_version(rocthrust REQUIRED) | ||||
|  | ||||
| @ -26,8 +26,8 @@ As shown in the CPU example section of :class:`torch.autocast`, "automatic mixed | ||||
| datatype of ``torch.bfloat16`` only uses :class:`torch.autocast`. | ||||
|  | ||||
| .. warning:: | ||||
|     ``torch.cuda.amp.autocast(args...)`` and ``torch.cpu.amp.autocast(args...)`` will be deprecated. Please use ``torch.autocast("cuda", args...)`` or ``torch.autocast("cpu", args...)`` instead. | ||||
|     ``torch.cuda.amp.GradScaler(args...)`` and ``torch.cpu.amp.GradScaler(args...)`` will be deprecated. Please use ``torch.GradScaler("cuda", args...)`` or ``torch.GradScaler("cpu", args...)`` instead. | ||||
    ``torch.cuda.amp.autocast(args...)`` and ``torch.cpu.amp.autocast(args...)`` are deprecated. Please use ``torch.amp.autocast("cuda", args...)`` or ``torch.amp.autocast("cpu", args...)`` instead.
    ``torch.cuda.amp.GradScaler(args...)`` and ``torch.cpu.amp.GradScaler(args...)`` are deprecated. Please use ``torch.amp.GradScaler("cuda", args...)`` or ``torch.amp.GradScaler("cpu", args...)`` instead.
|  | ||||
| :class:`torch.autocast` and :class:`torch.cpu.amp.autocast` are new in version `1.10`. | ||||
|  | ||||
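A minimal sketch of the replacement spellings named in the warning above:

```python
import torch

scaler = torch.amp.GradScaler("cuda")          # instead of torch.cuda.amp.GradScaler
model = torch.nn.Linear(8, 8).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(3):
    opt.zero_grad(set_to_none=True)
    with torch.amp.autocast("cuda"):           # instead of torch.cuda.amp.autocast
        loss = model(torch.randn(4, 8, device="cuda")).sum()
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()
```

Both `torch.amp.autocast` and `torch.amp.GradScaler` take the device type as their first argument, so the same loop works on CPU by swapping in `"cpu"`.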
|  | ||||
| @ -40,6 +40,7 @@ torch.cuda | ||||
|     temperature | ||||
|     power_draw | ||||
|     clock_rate | ||||
|     AcceleratorError | ||||
|     OutOfMemoryError | ||||
|  | ||||
| Random Number Generator | ||||
|  | ||||
| @ -31,6 +31,7 @@ torch.fx.experimental.symbolic_shapes | ||||
|     PropagateUnbackedSymInts | ||||
|     DivideByKey | ||||
|     InnerTensorKey | ||||
|     Specialization | ||||
|  | ||||
|     hint_int | ||||
|     is_concrete_int | ||||
|  | ||||
| @ -360,8 +360,7 @@ Suppose we want to define a sparse tensor with the entry 3 at location | ||||
| Unspecified elements are assumed to have the same value, fill value, | ||||
| which is zero by default. We would then write: | ||||
|  | ||||
|     >>> i = [[0, 1, 1], | ||||
|              [2, 0, 2]] | ||||
|     >>> i = [[0, 1, 1], [2, 0, 2]] | ||||
|     >>> v =  [3, 4, 5] | ||||
|     >>> s = torch.sparse_coo_tensor(i, v, (2, 3)) | ||||
|     >>> s | ||||
|  | ||||
| @ -1,5 +1,7 @@ | ||||
| #!/bin/bash | ||||
| # Updates Triton to the pinned version for this copy of PyTorch | ||||
| PYTHON="python3" | ||||
| PIP="$PYTHON -m pip" | ||||
| BRANCH=$(git rev-parse --abbrev-ref HEAD) | ||||
| DOWNLOAD_PYTORCH_ORG="https://download.pytorch.org/whl" | ||||
|  | ||||
| @ -8,9 +10,9 @@ if [[ -z "${USE_XPU}" ]]; then | ||||
|  | ||||
|     TRITON_VERSION="pytorch-triton==$(cat .ci/docker/triton_version.txt)" | ||||
|     if [[ "$BRANCH" =~ .*release.* ]]; then | ||||
|         pip install --index-url ${DOWNLOAD_PYTORCH_ORG}/test/ $TRITON_VERSION | ||||
|         ${PIP} install --index-url ${DOWNLOAD_PYTORCH_ORG}/test/ $TRITON_VERSION | ||||
|     else | ||||
|         pip install --index-url ${DOWNLOAD_PYTORCH_ORG}/nightly/ $TRITON_VERSION+git$(head -c 8 .ci/docker/ci_commit_pins/triton.txt) | ||||
|         ${PIP} install --index-url ${DOWNLOAD_PYTORCH_ORG}/nightly/ $TRITON_VERSION+git$(head -c 8 .ci/docker/ci_commit_pins/triton.txt) | ||||
|     fi | ||||
| else | ||||
|     # The Triton xpu logic is as follows: | ||||
| @ -21,11 +23,11 @@ else | ||||
|     TRITON_VERSION="pytorch-triton-xpu==$(cat .ci/docker/triton_version.txt)" | ||||
|     TRITON_XPU_COMMIT_ID="$(head -c 8 .ci/docker/ci_commit_pins/triton-xpu.txt)" | ||||
|     if [[ -z "${TRITON_XPU_BUILD_FROM_SOURCE}" ]]; then | ||||
|         pip install --index-url ${DOWNLOAD_PYTORCH_ORG}/nightly/ ${TRITON_VERSION}+git${TRITON_XPU_COMMIT_ID} | ||||
|         ${PIP} install --index-url ${DOWNLOAD_PYTORCH_ORG}/nightly/ ${TRITON_VERSION}+git${TRITON_XPU_COMMIT_ID} | ||||
|     else | ||||
|         TRITON_XPU_REPO="https://github.com/intel/intel-xpu-backend-for-triton" | ||||
|  | ||||
|         # force-reinstall to ensure the pinned version is installed | ||||
|         pip install --force-reinstall "git+${TRITON_XPU_REPO}@${TRITON_XPU_COMMIT_ID}#subdirectory=python" | ||||
|         ${PIP} install --force-reinstall "git+${TRITON_XPU_REPO}@${TRITON_XPU_COMMIT_ID}#subdirectory=python" | ||||
|     fi | ||||
| fi | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| # PyTorch Release Scripts | ||||
# PyTorch release scripts for performing the branch cut and applying release-only changes
|  | ||||
This is a collection of scripts used for release activities.
|  | ||||
| @ -7,54 +7,12 @@ These are a collection of scripts that are to be used for release activities. | ||||
| >       The basic idea being that there should be no potential to do anything dangerous unless | ||||
| >       `DRY_RUN` is explicitly set to `disabled`. | ||||
|  | ||||
| ## Requirements to actually run these scripts | ||||
| * AWS access to pytorch account | ||||
| * Access to upload conda packages to the `pytorch` conda channel | ||||
| * Access to the PyPI repositories | ||||
| ### Order of Execution | ||||
|  | ||||
1. Run cut-release-branch.sh to cut the release branch
2. Run tag-docker-images.sh to tag the current docker images with the release tag and push them to docker.io. These images will be used to build the release.
3. Run apply-release-changes.sh to create a PR with the release-only changes, similar to this [PR](https://github.com/pytorch/pytorch/pull/149056)
|  | ||||
| ## Promote | ||||
| #### Promoting packages | ||||
|  | ||||
| These are scripts related to promotion of release candidates to GA channels, these | ||||
| can actually be used to promote pytorch, libtorch, and related domain libraries. | ||||
|  | ||||
| ### Usage | ||||
|  | ||||
| Usage should be fairly straightforward and should actually require no extra variables | ||||
| if you are running from the correct git tags. (i.e. the GA tag to promote is currently | ||||
| checked out) | ||||
|  | ||||
| `PACKAGE_TYPE` and `PACKAGE_NAME` can be swapped out to promote other packages. | ||||
|  | ||||
| #### Promoting pytorch wheels | ||||
| ```bash | ||||
| promote/s3_to_s3.sh | ||||
| ``` | ||||
|  | ||||
| #### Promoting libtorch archives | ||||
| ```bash | ||||
| PACKAGE_TYPE=libtorch PACKAGE_NAME=libtorch promote/s3_to_s3.sh | ||||
| ``` | ||||
|  | ||||
| #### Promoting conda packages | ||||
| ```bash | ||||
| promote/conda_to_conda.sh | ||||
| ``` | ||||
|  | ||||
| #### Promoting wheels to PyPI | ||||
| **WARNING**: These can only be run once and cannot be undone, run with caution | ||||
| ``` | ||||
| promote/wheel_to_pypi.sh | ||||
| ``` | ||||
|  | ||||
| ## Restoring backups | ||||
|  | ||||
| All release candidates are currently backed up to `s3://pytorch-backup/${TAG_NAME}` and | ||||
| can be restored to the test channels with the `restore-backup.sh` script. | ||||
|  | ||||
| Which backup to restore from is dictated by the `RESTORE_FROM` environment variable. | ||||
|  | ||||
| ### Usage | ||||
| ```bash | ||||
| RESTORE_FROM=v1.5.0-rc5 ./restore-backup.sh | ||||
| ``` | ||||
Scripts for promoting PyTorch packages live in the test-infra repository. Please follow the [README.md](https://github.com/pytorch/test-infra/blob/main/release/README.md) there.
|  | ||||
| @ -1,61 +0,0 @@ | ||||
| #!/usr/bin/env bash | ||||
|  | ||||
| exit_if_not_on_git_tag() { | ||||
|     # Have an override for debugging purposes | ||||
|     if [[ -n "${TEST_WITHOUT_GIT_TAG-}" ]] ;then | ||||
|         >&2 echo "+ WARN: Continuing without being on a git tag" | ||||
|         exit 0 | ||||
|     fi | ||||
|     # Exit if we're not currently on a git tag | ||||
|     if ! git describe --tags --exact >/dev/null 2>/dev/null; then | ||||
|         >&2 echo "- ERROR: Attempting to promote on a non-git tag, must have tagged current commit locally first" | ||||
|         exit 1 | ||||
|     fi | ||||
|     # Exit if we're currently on an RC | ||||
|     if git describe --tags | grep "-rc" >/dev/null 2>/dev/null; then | ||||
|         >&2 echo "- ERROR: Attempting to promote on a non GA git tag, current tag must be a GA tag" | ||||
|         >&2 echo "         Example: v1.5.0" | ||||
|         exit 1 | ||||
|     fi | ||||
| } | ||||
|  | ||||
| get_pytorch_version() { | ||||
|     if [[ -n "${TEST_WITHOUT_GIT_TAG-}" ]];then | ||||
|         if  [[ -z "${TEST_PYTORCH_PROMOTE_VERSION-}" ]]; then | ||||
|             >&2 echo "- ERROR: Specified TEST_WITHOUT_GIT_TAG without specifying TEST_PYTORCH_PROMOTE_VERSION" | ||||
|             >&2 echo "-        TEST_PYTORCH_PROMOTE_VERSION must be specified" | ||||
|             exit 1 | ||||
|         else | ||||
|             echo "${TEST_PYTORCH_PROMOTE_VERSION}" | ||||
|             exit 0 | ||||
|         fi | ||||
|     fi | ||||
|     exit_if_not_on_git_tag | ||||
|     # Echo git tag, strip leading v | ||||
|     git describe --tags | sed -e 's/^v//' | ||||
| } | ||||
|  | ||||
| aws_promote() { | ||||
|     package_name=$1 | ||||
|     pytorch_version=$(get_pytorch_version) | ||||
|     # Dry run by default | ||||
|     DRY_RUN=${DRY_RUN:-enabled} | ||||
|     DRY_RUN_FLAG="--dryrun" | ||||
|     if [[ $DRY_RUN = "disabled" ]]; then | ||||
|         DRY_RUN_FLAG="" | ||||
|     fi | ||||
|     AWS=${AWS:-aws} | ||||
|     ( | ||||
|         set -x | ||||
|         ${AWS} s3 cp ${DRY_RUN_FLAG} \ | ||||
|             --only-show-errors \ | ||||
|             --acl public-read \ | ||||
|             --recursive \ | ||||
|             --exclude '*' \ | ||||
|             --include "*${package_name}-${pytorch_version}*" \ | ||||
|             "${PYTORCH_S3_FROM/\/$//}" \ | ||||
|             "${PYTORCH_S3_TO/\/$//}" | ||||
|     ) | ||||
|     # ^ We grep for package_name-.*pytorch_version to avoid any situations where domain libraries have | ||||
|     #   the same version on our S3 buckets | ||||
| } | ||||
| @ -1,45 +0,0 @@ | ||||
| #!/usr/bin/env bash | ||||
|  | ||||
| # Preps binaries for publishing to pypi by removing the | ||||
| # version suffix we normally add for all binaries | ||||
| # (outside of default ones, CUDA 10.2 currently) | ||||
|  | ||||
| # Usage is: | ||||
| # $ prep_binary_for_pypy.sh <path_to_whl_file> <path_to_multiple_whl_files> | ||||
|  | ||||
| # Will output a whl in your current directory | ||||
|  | ||||
| set -eou pipefail | ||||
| shopt -s globstar | ||||
|  | ||||
| OUTPUT_DIR=${OUTPUT_DIR:-$(pwd)} | ||||
|  | ||||
| tmp_dir="$(mktemp -d)" | ||||
| trap 'rm -rf ${tmp_dir}' EXIT | ||||
|  | ||||
| for whl_file in "$@"; do | ||||
|     whl_file=$(realpath "${whl_file}") | ||||
|     whl_dir="${tmp_dir}/$(basename "${whl_file}")_unzipped" | ||||
|     mkdir -pv "${whl_dir}" | ||||
|     ( | ||||
|         set -x | ||||
|         unzip -q "${whl_file}" -d "${whl_dir}" | ||||
|     ) | ||||
|     version_with_suffix=$(grep '^Version:' "${whl_dir}"/*/METADATA | cut -d' ' -f2) | ||||
|     version_with_suffix_escaped=${version_with_suffix/+/%2B} | ||||
|     # Remove all suffixed +bleh versions | ||||
|     version_no_suffix=${version_with_suffix/+*/} | ||||
|     new_whl_file=${OUTPUT_DIR}/$(basename "${whl_file/${version_with_suffix_escaped}/${version_no_suffix}}") | ||||
|     dist_info_folder=$(find "${whl_dir}" -type d -name '*.dist-info' | head -1) | ||||
|     basename_dist_info_folder=$(basename "${dist_info_folder}") | ||||
|     dirname_dist_info_folder=$(dirname "${dist_info_folder}") | ||||
|     ( | ||||
|         set -x | ||||
|         find "${dist_info_folder}" -type f -exec sed -i "s!${version_with_suffix}!${version_no_suffix}!" {} \; | ||||
|         # Moves distinfo from one with a version suffix to one without | ||||
|         # Example: torch-1.8.0+cpu.dist-info => torch-1.8.0.dist-info | ||||
|         mv "${dist_info_folder}" "${dirname_dist_info_folder}/${basename_dist_info_folder/${version_with_suffix}/${version_no_suffix}}" | ||||
|         cd "${whl_dir}" | ||||
|         zip -qr "${new_whl_file}" . | ||||
|     ) | ||||
| done | ||||
| @ -1,19 +0,0 @@ | ||||
| #!/usr/bin/env bash | ||||
|  | ||||
| set -eou pipefail | ||||
|  | ||||
| DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" | ||||
| source "${DIR}/common_utils.sh" | ||||
|  | ||||
| # Allow for users to pass PACKAGE_NAME | ||||
| # For use with other packages, i.e. torchvision, etc. | ||||
| PACKAGE_NAME=${PACKAGE_NAME:-torch} | ||||
| PACKAGE_TYPE=${PACKAGE_TYPE:-whl} | ||||
|  | ||||
| PYTORCH_S3_BUCKET=${PYTORCH_S3_BUCKET:-s3://pytorch} | ||||
| FROM=${FROM:-test} | ||||
| PYTORCH_S3_FROM=${PYTORCH_S3_FROM:-${PYTORCH_S3_BUCKET}/${PACKAGE_TYPE}/${FROM}} | ||||
| TO=${TO:-} | ||||
| PYTORCH_S3_TO=${PYTORCH_S3_TO:-${PYTORCH_S3_BUCKET}/${PACKAGE_TYPE}/${TO}} | ||||
|  | ||||
| aws_promote "${PACKAGE_NAME}" | ||||
| @ -1,69 +0,0 @@ | ||||
| #!/usr/bin/env bash | ||||
|  | ||||
| set -eou pipefail | ||||
|  | ||||
| DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" | ||||
| source "${DIR}/common_utils.sh" | ||||
|  | ||||
| # Allow for users to pass PACKAGE_NAME | ||||
| # For use with other packages, i.e. torchvision, etc. | ||||
| PACKAGE_NAME=${PACKAGE_NAME:-torch} | ||||
|  | ||||
| pytorch_version="$(get_pytorch_version)" | ||||
| # Refers to the specific package we'd like to promote | ||||
| # i.e. VERSION_SUFFIX='%2Bcu102' | ||||
| #      torch-1.8.0+cu102 -> torch-1.8.0 | ||||
| VERSION_SUFFIX=${VERSION_SUFFIX:-} | ||||
# Refers to the specific platform we'd like to promote
| # i.e. PLATFORM=linux_x86_64 | ||||
| # For domains like torchaudio / torchtext this is to be left blank | ||||
| PLATFORM=${PLATFORM:-} | ||||
|  | ||||
| pkgs_to_promote=$(\ | ||||
|     curl -fsSL https://download.pytorch.org/whl/torch_stable.html \ | ||||
|         | grep "${PACKAGE_NAME}-${pytorch_version}${VERSION_SUFFIX}-" \ | ||||
|         | grep "${PLATFORM}" \ | ||||
|         | cut -d '"' -f2 | ||||
| ) | ||||
|  | ||||
| tmp_dir="$(mktemp -d)" | ||||
| output_tmp_dir="$(mktemp -d)" | ||||
| trap 'rm -rf ${tmp_dir} ${output_tmp_dir}' EXIT | ||||
| pushd "${output_tmp_dir}" | ||||
|  | ||||
| # Dry run by default | ||||
| DRY_RUN=${DRY_RUN:-enabled} | ||||
| # On dry run just echo the commands that are meant to be run | ||||
| TWINE_UPLOAD="echo twine upload" | ||||
| if [[ $DRY_RUN = "disabled" ]]; then | ||||
|     TWINE_UPLOAD="twine upload" | ||||
| fi | ||||
|  | ||||
| for pkg in ${pkgs_to_promote}; do | ||||
|     pkg_basename="$(basename "${pkg}")" | ||||
|     # Don't attempt to change if manylinux2014 | ||||
|     if [[ "${pkg}" != *manylinux2014* ]]; then | ||||
|         pkg_basename="$(basename "${pkg//linux/manylinux1}")" | ||||
|     fi | ||||
|     orig_pkg="${tmp_dir}/${pkg_basename}" | ||||
|     ( | ||||
|         set -x | ||||
|         # Download package, sub out linux for manylinux1 | ||||
|         curl -fsSL -o "${orig_pkg}" "https://download.pytorch.org/whl/${pkg}" | ||||
|     ) | ||||
|  | ||||
|     if [[ -n "${VERSION_SUFFIX}" ]]; then | ||||
|         OUTPUT_DIR="${output_tmp_dir}" ${DIR}/prep_binary_for_pypi.sh "${orig_pkg}" | ||||
|     else | ||||
|         mv "${orig_pkg}" "${output_tmp_dir}/" | ||||
|     fi | ||||
|  | ||||
|     ( | ||||
|         set -x | ||||
|         ${TWINE_UPLOAD} \ | ||||
|             --disable-progress-bar \ | ||||
|             --non-interactive \ | ||||
|             ./*.whl | ||||
|         rm -rf ./*.whl | ||||
|     ) | ||||
| done | ||||
| @ -1,31 +0,0 @@ | ||||
| #!/usr/bin/env bash | ||||
|  | ||||
| set -eou pipefail | ||||
|  | ||||
| DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" | ||||
| source "${DIR}/promote/common_utils.sh" | ||||
|  | ||||
| if [[ -z "${RESTORE_FROM:-}" ]]; then | ||||
|     echo "ERROR: RESTORE_FROM environment variable must be specified" | ||||
|     echo "       example: RESTORE_FROM=v1.6.0-rc3 ${0}" | ||||
|     exit 1 | ||||
| fi | ||||
|  | ||||
| DRY_RUN=${DRY_RUN:-enabled} | ||||
|  | ||||
| PYTORCH_S3_BACKUP_BUCKET=${PYTORCH_S3_BACKUP_BUCKET:-s3://pytorch-backup/${RESTORE_FROM}} | ||||
| PYTORCH_S3_TEST_BUCKET=${PYTORCH_S3_TEST_BUCKET:-s3://pytorch/} | ||||
| PYTORCH_S3_FROM=${PYTORCH_S3_FROM:-${PYTORCH_S3_BACKUP_BUCKET}} | ||||
| PYTORCH_S3_TO=${PYTORCH_S3_TO:-s3://pytorch/} | ||||
|  | ||||
| restore_wheels() { | ||||
|     aws_promote torch whl | ||||
| } | ||||
|  | ||||
| restore_libtorch() { | ||||
|     aws_promote libtorch-* libtorch | ||||
| } | ||||
|  | ||||
|  | ||||
| restore_wheels | ||||
| restore_libtorch | ||||
| @ -1322,5 +1322,33 @@ class TestFullyShardOldImport(FSDPTestMultiThread): | ||||
|         model(inp).sum().backward() | ||||
|  | ||||
|  | ||||
| class TestFullyShardMixedDtypeParam(FSDPTestMultiThread): | ||||
|     @property | ||||
|     def world_size(self) -> int: | ||||
|         return 2 | ||||
|  | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     def test_mixed_dtypes_no_grad_param(self): | ||||
|         class Model(torch.nn.Module): | ||||
|             def __init__(self): | ||||
|                 super().__init__() | ||||
|                 # no grad params with different dtypes | ||||
|                 self.w_fp8 = torch.nn.Parameter( | ||||
|                     torch.empty((256, 256), dtype=torch.float8_e4m3fn), | ||||
|                     requires_grad=False, | ||||
|                 ) | ||||
|                 self.w_fp32 = torch.nn.Parameter( | ||||
|                     torch.empty((256, 256), dtype=torch.float32) | ||||
|                 ) | ||||
|  | ||||
|             def forward(self, input): | ||||
|                 return | ||||
|  | ||||
|         mesh = init_device_mesh(device_type.type, (self.world_size,)) | ||||
|         model = Model() | ||||
|         fully_shard(model, mesh=mesh) | ||||
|         model(0) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     run_tests() | ||||
|  | ||||
| @ -1,135 +0,0 @@ | ||||
| # Owner(s): ["oncall: distributed_checkpointing"] | ||||
|  | ||||
| import os | ||||
| import sys | ||||
|  | ||||
| import torch | ||||
| import torch.distributed.checkpoint as dist_cp | ||||
| from torch import distributed as dist | ||||
| from torch.distributed.checkpoint.scripts._consolidate_hf_safetensors import ( | ||||
|     consolidate_safetensors_files, | ||||
| ) | ||||
| from torch.distributed.device_mesh import init_device_mesh | ||||
| from torch.distributed.tensor import DTensor, Shard | ||||
| from torch.testing._internal.common_utils import run_tests | ||||
| from torch.testing._internal.distributed._tensor.common_dtensor import ( | ||||
|     DTensorTestBase, | ||||
|     skip_if_lt_x_gpu, | ||||
|     with_comms, | ||||
| ) | ||||
| from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir | ||||
|  | ||||
|  | ||||
| class TestConsolidateHFSafeTensors(DTensorTestBase): | ||||
|     def _create_d_tensors(self) -> None: | ||||
|         global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) | ||||
|         mesh_shape = (self.world_size,) | ||||
|         mesh_1d = init_device_mesh(self.device_type, mesh_shape) | ||||
|  | ||||
|         # Create local tensor with row-wise sharding | ||||
|         rows_per_rank = global_tensor.shape[0] // self.world_size | ||||
|         start_row = self.rank * rows_per_rank | ||||
|         end_row = start_row + rows_per_rank | ||||
|         local_tensor = global_tensor[start_row:end_row].clone() | ||||
|  | ||||
|         # Create DTensor with row-wise sharding | ||||
|         dtensor = DTensor.from_local( | ||||
|             local_tensor, | ||||
|             device_mesh=mesh_1d, | ||||
|             placements=[Shard(0)], | ||||
|             shape=global_tensor.shape, | ||||
|             stride=(4, 1), | ||||
|         ) | ||||
|  | ||||
|         # Create local tensor with column-wise sharding | ||||
|         cols_per_rank = global_tensor.shape[1] // self.world_size | ||||
|         start_col = self.rank * cols_per_rank | ||||
|         end_col = start_col + cols_per_rank | ||||
|         local_tensor_col = global_tensor[:, start_col:end_col].clone() | ||||
|  | ||||
|         # Create DTensor with column-wise sharding | ||||
|         dtensor_col = DTensor.from_local( | ||||
|             local_tensor_col, | ||||
|             device_mesh=mesh_1d, | ||||
|             placements=[Shard(1)],  # Column-wise sharding | ||||
|             shape=global_tensor.shape, | ||||
|             stride=(4, 1), | ||||
|         ) | ||||
|  | ||||
|         state_dict_to_save = {"dtensor": dtensor, "dtensor_col": dtensor_col} | ||||
|         dist_cp.save( | ||||
|             state_dict=state_dict_to_save, | ||||
|             storage_writer=dist_cp._HuggingFaceStorageWriter( | ||||
|                 path=self.temp_dir, save_sharded=True | ||||
|             ), | ||||
|         ) | ||||
|         dist.barrier() | ||||
|         os.sync() | ||||
|  | ||||
|     @with_comms | ||||
|     @with_temp_dir | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     def test_consolidate_to_one_file(self) -> None: | ||||
|         try: | ||||
|             import safetensors | ||||
|         except ImportError: | ||||
|             print("safetensors not installed") | ||||
|             sys.exit(0) | ||||
|  | ||||
|         checkpoint_dir = self.temp_dir | ||||
|         output_dir = os.path.join(checkpoint_dir, "consolidated") | ||||
|         os.makedirs(output_dir, exist_ok=True) | ||||
|  | ||||
|         self._create_d_tensors() | ||||
|  | ||||
|         global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) | ||||
|  | ||||
|         if self.rank == 0: | ||||
|             consolidate_safetensors_files(checkpoint_dir, output_dir) | ||||
|  | ||||
|             file_path = os.path.join(output_dir, "model-00001-of-00001.safetensors") | ||||
|             loaded_dict = safetensors.torch.load_file(file_path) | ||||
|             self.assertEqual(loaded_dict.keys(), {"dtensor", "dtensor_col"}) | ||||
|             self.assertTrue(torch.equal(loaded_dict["dtensor"], global_tensor)) | ||||
|             self.assertTrue(torch.equal(loaded_dict["dtensor_col"], global_tensor)) | ||||
|         dist.barrier() | ||||
|  | ||||
|     @with_comms | ||||
|     @with_temp_dir | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     def test_consolidate_to_two_files(self): | ||||
|         try: | ||||
|             import safetensors | ||||
|         except ImportError: | ||||
|             print("safetensors not installed") | ||||
|             sys.exit(0) | ||||
|  | ||||
|         checkpoint_dir = self.temp_dir | ||||
|         output_dir = os.path.join(checkpoint_dir, "consolidated") | ||||
|         os.makedirs(output_dir, exist_ok=True) | ||||
|  | ||||
|         self._create_d_tensors() | ||||
|  | ||||
|         global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) | ||||
|  | ||||
|         if self.rank == 0: | ||||
|             fqn_to_index_mapping = {"dtensor": 1, "dtensor_col": 2} | ||||
|             consolidate_safetensors_files( | ||||
|                 checkpoint_dir, output_dir, fqn_to_index_mapping | ||||
|             ) | ||||
|  | ||||
|             file1_path = os.path.join(output_dir, "model-00001-of-00002.safetensors") | ||||
|             file2_path = os.path.join(output_dir, "model-00002-of-00002.safetensors") | ||||
|  | ||||
|             loaded_dict = safetensors.torch.load_file(file1_path) | ||||
|             self.assertEqual(loaded_dict.keys(), {"dtensor"}) | ||||
|             self.assertTrue(torch.equal(loaded_dict["dtensor"], global_tensor)) | ||||
|  | ||||
|             loaded_dict_col = safetensors.torch.load_file(file2_path) | ||||
|             self.assertEqual(loaded_dict_col.keys(), {"dtensor_col"}) | ||||
|             self.assertTrue(torch.equal(loaded_dict_col["dtensor_col"], global_tensor)) | ||||
|         dist.barrier() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     run_tests() | ||||
| @ -1,420 +0,0 @@ | ||||
| # Owner(s): ["oncall: distributed_checkpointing"] | ||||
|  | ||||
| import sys | ||||
|  | ||||
| import torch | ||||
| import torch.distributed.checkpoint as dist_cp | ||||
| from torch.distributed.checkpoint import _HuggingFaceLoadPlanner | ||||
| from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner | ||||
| from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys | ||||
| from torch.distributed.device_mesh import init_device_mesh | ||||
| from torch.distributed.tensor import distribute_tensor, Replicate, Shard, zeros | ||||
| from torch.testing._internal.common_utils import ( | ||||
|     instantiate_parametrized_tests, | ||||
|     run_tests, | ||||
|     TestCase, | ||||
| ) | ||||
| from torch.testing._internal.distributed._tensor.common_dtensor import ( | ||||
|     DTensorTestBase, | ||||
|     skip_if_lt_x_gpu, | ||||
|     with_comms, | ||||
| ) | ||||
| from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir | ||||
|  | ||||
|  | ||||
| CHECKPOINT_DIR = "checkpoint" | ||||
|  | ||||
|  | ||||
| class MyTestModule(torch.nn.Module): | ||||
|     def __init__(self) -> None: | ||||
|         super().__init__() | ||||
|         self.linear_1 = torch.nn.Linear(5, 5) | ||||
|         self.linear_2 = torch.nn.Linear(5, 1) | ||||
|         self.emb = torch.nn.EmbeddingBag(5, 10) | ||||
|  | ||||
| class TestSingleRankSaveLoad(TestCase): | ||||
|     @with_temp_dir | ||||
|     def test_save(self) -> None: | ||||
|         try: | ||||
|             from safetensors.torch import load_file | ||||
|         except ImportError: | ||||
|             print("safetensors not installed") | ||||
|             sys.exit(0) | ||||
|  | ||||
|         CHECKPOINT_DIR = self.temp_dir | ||||
|  | ||||
|         state_dict_to_save = MyTestModule().state_dict() | ||||
|         dist_cp.save( | ||||
|             state_dict=state_dict_to_save, | ||||
|             storage_writer=dist_cp._HuggingFaceStorageWriter(path=CHECKPOINT_DIR), | ||||
|         ) | ||||
|  | ||||
|         state_dict_loaded = load_file(CHECKPOINT_DIR + "/model-00001-of-00001.safetensors") | ||||
|         self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys())) | ||||
|         for key in state_dict_to_save.keys(): | ||||
|             self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_loaded[key])) | ||||
|  | ||||
|     @with_temp_dir | ||||
|     def test_load(self) -> None: | ||||
|         try: | ||||
|             from safetensors.torch import save_file | ||||
|         except ImportError: | ||||
|             print("safetensors not installed") | ||||
|             sys.exit(0) | ||||
|  | ||||
|         CHECKPOINT_DIR = self.temp_dir | ||||
|  | ||||
|         state_dict_to_save = MyTestModule().state_dict() | ||||
|         state_dict_to_load = MyTestModule().state_dict() | ||||
|         save_file(state_dict_to_save, CHECKPOINT_DIR + "/model-00001-of-00001.safetensors") | ||||
|  | ||||
|         dist_cp.load( | ||||
|             state_dict=state_dict_to_load, | ||||
|             storage_reader=dist_cp._HuggingFaceStorageReader(path=CHECKPOINT_DIR), | ||||
|         ) | ||||
|  | ||||
|         self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys())) | ||||
|         for key in state_dict_to_save.keys(): | ||||
|             self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_to_load[key])) | ||||
|  | ||||
|     @with_temp_dir | ||||
|     def test_load_into_empty_dict(self) -> None: | ||||
|         try: | ||||
|             from safetensors.torch import save_file | ||||
|         except ImportError: | ||||
|             print("safetensors not installed") | ||||
|             sys.exit(0) | ||||
|  | ||||
|         CHECKPOINT_DIR = self.temp_dir | ||||
|  | ||||
|         state_dict_to_save = MyTestModule().state_dict() | ||||
|         save_file(state_dict_to_save, CHECKPOINT_DIR + "/model-00001-of-00001.safetensors") | ||||
|  | ||||
|         state_dict_loaded = _load_state_dict_from_keys( | ||||
|             storage_reader=dist_cp._HuggingFaceStorageReader(path=CHECKPOINT_DIR), | ||||
|         ) | ||||
|  | ||||
|         self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys())) | ||||
|         for key in state_dict_to_save.keys(): | ||||
|             self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_loaded[key])) | ||||
|  | ||||
|     @with_temp_dir | ||||
|     def test_load_allowing_resize(self) -> None: | ||||
|         try: | ||||
|             from safetensors.torch import save_file | ||||
|         except ImportError: | ||||
|             print("safetensors not installed") | ||||
|             sys.exit(0) | ||||
|  | ||||
|         CHECKPOINT_DIR = self.temp_dir | ||||
|  | ||||
|         state_dict_to_save = MyTestModule().state_dict() | ||||
|         save_file(state_dict_to_save, CHECKPOINT_DIR + "/model-00001-of-00001.safetensors") | ||||
|  | ||||
|         state_dict_to_load = {} | ||||
|         for key in state_dict_to_save.keys(): | ||||
|             state_dict_to_load[key] = torch.zeros(1) | ||||
|  | ||||
|         dist_cp.load( | ||||
|             state_dict=state_dict_to_load, | ||||
|             storage_reader=dist_cp._HuggingFaceStorageReader(path=CHECKPOINT_DIR), | ||||
|             planner=_HuggingFaceLoadPlanner(allow_tensor_resize=True), | ||||
|         ) | ||||
|  | ||||
|         self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys())) | ||||
|         for key in state_dict_to_save.keys(): | ||||
|             self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_to_load[key])) | ||||
|  | ||||
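| # --- Editorial sketch (not part of the diff): the resize-on-load path covered | ||||
| # by test_load_allowing_resize above. "ckpt" is a placeholder directory that | ||||
| # is assumed to already contain model-00001-of-00001.safetensors. | ||||
| import torch | ||||
| import torch.distributed.checkpoint as dist_cp | ||||
| from torch.distributed.checkpoint import _HuggingFaceLoadPlanner | ||||
|  | ||||
| # Deliberately wrong-shaped destination tensors: with allow_tensor_resize=True | ||||
| # the planner resizes them to match the stored tensors instead of erroring. | ||||
| state_dict = {"linear_1.weight": torch.zeros(1), "linear_1.bias": torch.zeros(1)} | ||||
| dist_cp.load( | ||||
|     state_dict=state_dict, | ||||
|     storage_reader=dist_cp._HuggingFaceStorageReader(path="ckpt"), | ||||
|     planner=_HuggingFaceLoadPlanner(allow_tensor_resize=True), | ||||
| ) | ||||
|  | ||||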
| ONE_D_PLACEMENTS = [ | ||||
|     [Shard(0)], | ||||
|     [Replicate()], | ||||
| ] | ||||
| ONE_D_TO_ONE_D_PLACEMENTS = [ | ||||
|     ([Replicate()], [Shard(0)]), | ||||
|     ([Shard(0)], [Replicate()]), | ||||
| ] | ||||
|  | ||||
| TWO_D_PLACEMENTS = [ | ||||
|     [Replicate(), Replicate()], | ||||
|     [Replicate(), Shard(0)], | ||||
|     [Shard(0), Replicate()], | ||||
|     [Shard(0), Shard(0)], | ||||
| ] | ||||
| TWO_D_TO_TWO_D_PLACEMENTS = [] | ||||
| for p1 in TWO_D_PLACEMENTS: | ||||
|     for p2 in TWO_D_PLACEMENTS: | ||||
|         if p1 != p2: | ||||
|             TWO_D_TO_TWO_D_PLACEMENTS.append((p1, p2)) | ||||
|  | ||||
|  | ||||
| @instantiate_parametrized_tests | ||||
| class TestDTensorReshardPlacementChange(DTensorTestBase): | ||||
|     """ | ||||
|     Test DCP reshard for DTensor with placement changes, but without world_size or mesh_tensor changes. | ||||
|     """ | ||||
|  | ||||
|     @with_comms | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     @with_temp_dir | ||||
|     def test_1d_to_1d_reshard_placement_change(self) -> None: | ||||
|         try: | ||||
|             import safetensors | ||||
|         except ImportError: | ||||
|             print("safetensors not installed") | ||||
|             sys.exit(0) | ||||
|  | ||||
|         CHECKPOINT_DIR = self.temp_dir | ||||
|  | ||||
|         for one_d_to_one_d_placements in ONE_D_TO_ONE_D_PLACEMENTS: | ||||
|             original_placement, new_placement = one_d_to_one_d_placements | ||||
|  | ||||
|             global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) | ||||
|             mesh_shape = (self.world_size,) | ||||
|             device_mesh = init_device_mesh(self.device_type, mesh_shape) | ||||
|             dtensor = distribute_tensor( | ||||
|                 global_tensor, device_mesh, placements=original_placement | ||||
|             ) | ||||
|             state_dict_to_save = {"dtensor": dtensor} | ||||
|  | ||||
|             dist_cp.save( | ||||
|                 state_dict=state_dict_to_save, | ||||
|                 storage_writer=dist_cp._HuggingFaceStorageWriter( | ||||
|                     path=CHECKPOINT_DIR, | ||||
|                     save_sharded=True, | ||||
|                 ), | ||||
|             ) | ||||
|  | ||||
|             zero_dtensor = zeros( | ||||
|                 [4, 4], device_mesh=device_mesh, placements=new_placement | ||||
|             ) | ||||
|             state_dict_to_load = {"dtensor": zero_dtensor} | ||||
|  | ||||
|             dist_cp.load( | ||||
|                 state_dict=state_dict_to_load, | ||||
|                 storage_reader=dist_cp._HuggingFaceStorageReader( | ||||
|                     CHECKPOINT_DIR, | ||||
|                 ), | ||||
|             ) | ||||
|  | ||||
|             # materialize the whole tensor to compare with the original global_tensor | ||||
|             state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute( | ||||
|                 device_mesh, | ||||
|                 placements=[Replicate()], | ||||
|             ) | ||||
|             self.assertEqual(global_tensor, state_dict_to_load["dtensor"].to_local()) | ||||
|  | ||||
|             # redistribute the tensor back to its original placement for comparison. | ||||
|             state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute( | ||||
|                 device_mesh, | ||||
|                 placements=original_placement, | ||||
|             ) | ||||
|             self.assertEqual( | ||||
|                 state_dict_to_save["dtensor"].to_local(), | ||||
|                 state_dict_to_load["dtensor"].to_local(), | ||||
|             ) | ||||
|  | ||||
|     @with_comms | ||||
|     @skip_if_lt_x_gpu(4) | ||||
|     @with_temp_dir | ||||
|     def test_2d_to_2d_reshard_placement_change(self) -> None: | ||||
|         try: | ||||
|             import safetensors | ||||
|         except ImportError: | ||||
|             print("safetensors not installed") | ||||
|             sys.exit(0) | ||||
|  | ||||
|         CHECKPOINT_DIR = self.temp_dir | ||||
|         for two_d_to_two_d_placements in TWO_D_TO_TWO_D_PLACEMENTS: | ||||
|             original_placement, new_placement = two_d_to_two_d_placements | ||||
|  | ||||
|             global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) | ||||
|             mesh_shape = (2, self.world_size // 2) | ||||
|             mesh_2d = init_device_mesh(self.device_type, mesh_shape) | ||||
|             dtensor = distribute_tensor( | ||||
|                 global_tensor, | ||||
|                 mesh_2d, | ||||
|                 placements=original_placement, | ||||
|             ) | ||||
|             state_dict_to_save = {"dtensor": dtensor} | ||||
|  | ||||
|             dist_cp.save( | ||||
|                 state_dict=state_dict_to_save, | ||||
|                 storage_writer=dist_cp._HuggingFaceStorageWriter(path=CHECKPOINT_DIR, save_sharded=True), | ||||
|                 planner=dist_cp.DefaultSavePlanner(), | ||||
|             ) | ||||
|  | ||||
|             zero_dtensor = zeros([4, 4], device_mesh=mesh_2d, placements=new_placement) | ||||
|             state_dict_to_load = {"dtensor": zero_dtensor} | ||||
|  | ||||
|             dist_cp.load( | ||||
|                 state_dict=state_dict_to_load, | ||||
|                 storage_reader=dist_cp._HuggingFaceStorageReader(CHECKPOINT_DIR), | ||||
|             ) | ||||
|  | ||||
|             state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute( | ||||
|                 mesh_2d, | ||||
|                 placements=[Replicate(), Replicate()], | ||||
|             ) | ||||
|             self.assertEqual(global_tensor, state_dict_to_load["dtensor"].to_local()) | ||||
|  | ||||
|             state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute( | ||||
|                 mesh_2d, | ||||
|                 placements=original_placement, | ||||
|             ) | ||||
|             self.assertEqual( | ||||
|                 state_dict_to_save["dtensor"].to_local(), | ||||
|                 state_dict_to_load["dtensor"].to_local(), | ||||
|             ) | ||||
|  | ||||
|  | ||||
| class TestDTensorReshardMeshChange(DTensorTestBase): | ||||
|     """ | ||||
|     Test DCP reshard for DTensor with placement and mesh_tensor changes. | ||||
|     """ | ||||
|  | ||||
|     @with_comms | ||||
|     @with_temp_dir | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     def test_1d_to_2d_reshard_mesh_change(self) -> None: | ||||
|         try: | ||||
|             import safetensors | ||||
|         except ImportError: | ||||
|             print("safetensors not installed") | ||||
|             sys.exit(0) | ||||
|  | ||||
|         CHECKPOINT_DIR = self.temp_dir | ||||
|         for placements_1d in ONE_D_PLACEMENTS: | ||||
|             global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) | ||||
|             mesh_shape = (self.world_size,) | ||||
|             mesh_1d = init_device_mesh(self.device_type, mesh_shape) | ||||
|             dtensor = distribute_tensor( | ||||
|                 global_tensor, mesh_1d, placements=placements_1d | ||||
|             ) | ||||
|             state_dict_to_save = {"dtensor": dtensor} | ||||
|  | ||||
|             dist_cp.save( | ||||
|                 state_dict=state_dict_to_save, | ||||
|                 storage_writer=dist_cp._HuggingFaceStorageWriter(path=CHECKPOINT_DIR, save_sharded=True), | ||||
|             ) | ||||
|  | ||||
|             for placements_2d in TWO_D_PLACEMENTS: | ||||
|                 mesh_shape = (2, self.world_size // 2) | ||||
|                 mesh_2d = init_device_mesh(self.device_type, mesh_shape) | ||||
|  | ||||
|                 zero_dtensor = zeros( | ||||
|                     [4, 4], device_mesh=mesh_2d, placements=placements_2d | ||||
|                 ) | ||||
|                 state_dict_to_load = {"dtensor": zero_dtensor} | ||||
|  | ||||
|                 dist_cp.load( | ||||
|                     state_dict=state_dict_to_load, | ||||
|                     storage_reader=dist_cp._HuggingFaceStorageReader(CHECKPOINT_DIR), | ||||
|                     planner=dist_cp.DefaultLoadPlanner(), | ||||
|                 ) | ||||
|  | ||||
|                 # materialize the whole tensor to compare with the original global_tensor | ||||
|                 state_dict_to_load["dtensor"] = state_dict_to_load[ | ||||
|                     "dtensor" | ||||
|                 ].redistribute( | ||||
|                     mesh_2d, | ||||
|                     placements=[Replicate(), Replicate()], | ||||
|                 ) | ||||
|                 self.assertEqual( | ||||
|                     global_tensor, state_dict_to_load["dtensor"].to_local() | ||||
|                 ) | ||||
|  | ||||
|     @with_comms | ||||
|     @with_temp_dir | ||||
|     @skip_if_lt_x_gpu(4) | ||||
|     def test_2d_to_1d_reshard_mesh_change(self) -> None: | ||||
|         try: | ||||
|             import safetensors | ||||
|         except ImportError: | ||||
|             print("safetensors not installed") | ||||
|             sys.exit(0) | ||||
|  | ||||
|         CHECKPOINT_DIR = self.temp_dir | ||||
|         for placements_2d in TWO_D_PLACEMENTS: | ||||
|             global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) | ||||
|             mesh_shape = (2, self.world_size // 2) | ||||
|             mesh_2d = init_device_mesh(self.device_type, mesh_shape) | ||||
|             dtensor = distribute_tensor( | ||||
|                 global_tensor, mesh_2d, placements=placements_2d | ||||
|             ) | ||||
|             state_dict_to_save = {"dtensor": dtensor} | ||||
|  | ||||
|             dist_cp.save( | ||||
|                 state_dict=state_dict_to_save, | ||||
|                 storage_writer=dist_cp._HuggingFaceStorageWriter(path=CHECKPOINT_DIR, save_sharded=True), | ||||
|                 planner=dist_cp.DefaultSavePlanner(), | ||||
|             ) | ||||
|  | ||||
|             for placements_1d in ONE_D_PLACEMENTS: | ||||
|                 mesh_shape = (self.world_size,) | ||||
|                 mesh_1d = init_device_mesh(self.device_type, mesh_shape) | ||||
|  | ||||
|                 zero_dtensor = zeros( | ||||
|                     [4, 4], device_mesh=mesh_1d, placements=placements_1d | ||||
|                 ) | ||||
|                 state_dict_to_load = {"dtensor": zero_dtensor} | ||||
|  | ||||
|                 dist_cp.load( | ||||
|                     state_dict=state_dict_to_load, | ||||
|                     storage_reader=dist_cp._HuggingFaceStorageReader(CHECKPOINT_DIR), | ||||
|                     planner=dist_cp.DefaultLoadPlanner(), | ||||
|                 ) | ||||
|  | ||||
|                 # materialize the whole tensor to compare with the original global_tensor | ||||
|                 state_dict_to_load["dtensor"] = state_dict_to_load[ | ||||
|                     "dtensor" | ||||
|                 ].redistribute( | ||||
|                     mesh_1d, | ||||
|                     placements=[Replicate()], | ||||
|                 ) | ||||
|                 self.assertEqual( | ||||
|                     global_tensor, state_dict_to_load["dtensor"].to_local() | ||||
|                 ) | ||||
|  | ||||
|     @with_comms | ||||
|     @with_temp_dir | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     def test_dtensor_checkpoint_resharding_with_empty_shard(self): | ||||
|         """ | ||||
|         Test DTensor checkpoint resharding when the DTensor contains empty shards. | ||||
|         """ | ||||
|         try: | ||||
|             import safetensors | ||||
|         except ImportError: | ||||
|             print("safetensors not installed") | ||||
|             sys.exit(0) | ||||
|  | ||||
|         tensor = torch.rand(1).cuda() | ||||
|         mesh = init_device_mesh(self.device_type, (self.world_size,)) | ||||
|         dtensor = distribute_tensor(tensor, mesh, [Shard(0)]) | ||||
|         ref_state_dict = {"dtensor": dtensor} | ||||
|  | ||||
|         dist_cp.save( | ||||
|             state_dict=ref_state_dict, | ||||
|             storage_writer=dist_cp._HuggingFaceStorageWriter(path=self.temp_dir, save_sharded=True), | ||||
|         ) | ||||
|  | ||||
|         tensor = torch.rand(1).cuda() | ||||
|         mesh_2 = init_device_mesh(self.device_type, (2, self.world_size // 2)) | ||||
|         dtensor = distribute_tensor(tensor, mesh_2, [Shard(0), Shard(0)]) | ||||
|         state_dict = {"dtensor": dtensor} | ||||
|         dist_cp.load( | ||||
|             state_dict=state_dict, | ||||
|             storage_reader=dist_cp._HuggingFaceStorageReader(self.temp_dir), | ||||
|         ) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     run_tests() | ||||
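|  | ||||
| # --- Editorial sketch (not part of the diff): the sharded-save / reshard-load | ||||
| # round trip the deleted tests above exercised, condensed to a single | ||||
| # placement change. Assumes an initialized process group and GPUs; "ckpt" is | ||||
| # a placeholder path. | ||||
| import torch | ||||
| import torch.distributed as dist | ||||
| import torch.distributed.checkpoint as dist_cp | ||||
| from torch.distributed.device_mesh import init_device_mesh | ||||
| from torch.distributed.tensor import distribute_tensor, Replicate, Shard, zeros | ||||
|  | ||||
| mesh = init_device_mesh("cuda", (dist.get_world_size(),)) | ||||
| global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) | ||||
| dtensor = distribute_tensor(global_tensor, mesh, placements=[Shard(0)]) | ||||
|  | ||||
| # save_sharded=True keeps per-rank shards instead of a consolidated file. | ||||
| dist_cp.save( | ||||
|     state_dict={"dtensor": dtensor}, | ||||
|     storage_writer=dist_cp._HuggingFaceStorageWriter(path="ckpt", save_sharded=True), | ||||
| ) | ||||
|  | ||||
| # Load back under a different placement, then replicate to compare globally. | ||||
| loaded = {"dtensor": zeros([4, 4], device_mesh=mesh, placements=[Replicate()])} | ||||
| dist_cp.load( | ||||
|     state_dict=loaded, | ||||
|     storage_reader=dist_cp._HuggingFaceStorageReader("ckpt"), | ||||
| ) | ||||
| full = loaded["dtensor"].redistribute(mesh, placements=[Replicate()]).to_local() | ||||
| assert torch.equal(full, global_tensor.to(full.device)) | ||||
|  | ||||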
| @ -8,7 +8,10 @@ import tempfile | ||||
| from unittest.mock import MagicMock | ||||
|  | ||||
| import torch | ||||
| from torch.distributed.checkpoint import DefaultLoadPlanner | ||||
| from torch.distributed.checkpoint._hf_planner import ( | ||||
|     _FqnToFileMapping, | ||||
|     _HuggingFaceLoadPlanner, | ||||
| ) | ||||
| from torch.distributed.checkpoint._hf_storage import ( | ||||
|     _HuggingFaceStorageReader, | ||||
|     _HuggingFaceStorageWriter, | ||||
| @ -18,19 +21,14 @@ from torch.distributed.checkpoint.default_planner import DefaultSavePlanner | ||||
| from torch.distributed.checkpoint.filesystem import _StorageInfo, FileSystem | ||||
| from torch.distributed.checkpoint.metadata import ( | ||||
|     BytesStorageMetadata, | ||||
|     ChunkStorageMetadata, | ||||
|     Metadata, | ||||
|     MetadataIndex, | ||||
|     TensorProperties, | ||||
|     TensorStorageMetadata, | ||||
| ) | ||||
| from torch.distributed.checkpoint.planner import ( | ||||
|     LoadItemType, | ||||
|     LoadPlan, | ||||
|     ReadItem, | ||||
|     SavePlan, | ||||
| from torch.distributed.checkpoint.planner import LoadPlan, SavePlan | ||||
| from torch.distributed.checkpoint.planner_helpers import ( | ||||
|     _create_read_items, | ||||
|     _create_write_item_for_tensor, | ||||
| ) | ||||
| from torch.distributed.checkpoint.planner_helpers import _create_write_item_for_tensor | ||||
| from torch.distributed.checkpoint.storage import WriteResult | ||||
| from torch.testing._internal.common_utils import run_tests, TestCase | ||||
|  | ||||
| @ -38,66 +36,9 @@ from torch.testing._internal.common_utils import run_tests, TestCase | ||||
| class TestHfStorage(TestCase): | ||||
|     def test_write_data_hf(self) -> None: | ||||
|         mock_module = MagicMock() | ||||
|         mock_module.save.return_value = b"" | ||||
|         sys.modules["safetensors.torch"] = mock_module | ||||
|         sys.modules["safetensors"] = mock_module | ||||
|         sys.modules["huggingface_hub"] = mock_module | ||||
|  | ||||
|         with tempfile.TemporaryDirectory() as path: | ||||
|             writer = _HuggingFaceStorageWriter( | ||||
|                 path=path, | ||||
|                 fqn_to_index_mapping={"tensor_0": 1, "tensor_1": 2}, | ||||
|             ) | ||||
|             writer.fs = FileSystem() | ||||
|  | ||||
|             tensor0 = torch.rand(4) | ||||
|             tensor1 = torch.rand(10) | ||||
|             write_item_1 = _create_write_item_for_tensor("tensor_0", tensor0) | ||||
|             write_item_2 = _create_write_item_for_tensor("tensor_1", tensor1) | ||||
|  | ||||
|             state_dict = {"tensor_0": tensor0, "tensor_1": tensor1} | ||||
|  | ||||
|             save_plan = SavePlan( | ||||
|                 [write_item_1, write_item_2], | ||||
|                 storage_data={"fqn_to_file_mapping": {"tensor_0": 1, "tensor_1": 2}}, | ||||
|             ) | ||||
|             save_planner = DefaultSavePlanner() | ||||
|             save_planner.set_up_planner(state_dict=state_dict) | ||||
|  | ||||
|             write_results = writer.write_data(save_plan, save_planner) | ||||
|  | ||||
|             write_results.wait() | ||||
|             actual_write_results = write_results.value() | ||||
|  | ||||
|             expected_write_results = [ | ||||
|                 WriteResult( | ||||
|                     index=MetadataIndex( | ||||
|                         fqn="tensor_0", offset=torch.Size([0]), index=None | ||||
|                     ), | ||||
|                     size_in_bytes=tensor0.numel() * tensor0.element_size(), | ||||
|                     storage_data=_StorageInfo( | ||||
|                         relative_path="model-00001-of-00002.safetensors", | ||||
|                         offset=0, | ||||
|                         length=tensor0.numel() * tensor0.element_size(), | ||||
|                     ), | ||||
|                 ), | ||||
|                 WriteResult( | ||||
|                     index=MetadataIndex( | ||||
|                         fqn="tensor_1", offset=torch.Size([0]), index=None | ||||
|                     ), | ||||
|                     size_in_bytes=tensor1.numel() * tensor1.element_size(), | ||||
|                     storage_data=_StorageInfo( | ||||
|                         relative_path="model-00002-of-00002.safetensors", | ||||
|                         offset=0, | ||||
|                         length=tensor1.numel() * tensor1.element_size(), | ||||
|                     ), | ||||
|                 ), | ||||
|             ] | ||||
|  | ||||
|             self.assertEqual( | ||||
|                 actual_write_results, | ||||
|                 expected_write_results, | ||||
|             ) | ||||
|  | ||||
|     def test_write_data_with_sharding(self) -> None: | ||||
|         mock_module = MagicMock() | ||||
|         mock_module.save.return_value = b"" | ||||
|         sys.modules["safetensors.torch"] = mock_module | ||||
| @ -105,7 +46,7 @@ class TestHfStorage(TestCase): | ||||
|         with tempfile.TemporaryDirectory() as path: | ||||
|             writer = _HuggingFaceStorageWriter( | ||||
|                 path=path, | ||||
|                 save_sharded=True, | ||||
|                 fqn_to_index_mapping={"tensor_0": 1, "tensor_1": 1}, | ||||
|             ) | ||||
|             writer.fs = FileSystem() | ||||
|  | ||||
| @ -118,7 +59,7 @@ class TestHfStorage(TestCase): | ||||
|  | ||||
|             save_plan = SavePlan( | ||||
|                 [write_item_1, write_item_2], | ||||
|                 storage_data={"shard_index": 1}, | ||||
|                 storage_data=_FqnToFileMapping({"tensor_0": 1, "tensor_1": 1}), | ||||
|             ) | ||||
|             save_planner = DefaultSavePlanner() | ||||
|             save_planner.set_up_planner(state_dict=state_dict) | ||||
| @ -135,7 +76,7 @@ class TestHfStorage(TestCase): | ||||
|                     ), | ||||
|                     size_in_bytes=tensor0.numel() * tensor0.element_size(), | ||||
|                     storage_data=_StorageInfo( | ||||
|                         relative_path="shard-00001-model-00001-of-00001.safetensors", | ||||
|                         relative_path="model-00001-of-00001.safetensors", | ||||
|                         offset=0, | ||||
|                         length=tensor0.numel() * tensor0.element_size(), | ||||
|                     ), | ||||
| @ -146,7 +87,7 @@ class TestHfStorage(TestCase): | ||||
|                     ), | ||||
|                     size_in_bytes=tensor1.numel() * tensor1.element_size(), | ||||
|                     storage_data=_StorageInfo( | ||||
|                         relative_path="shard-00001-model-00001-of-00001.safetensors", | ||||
|                         relative_path="model-00001-of-00001.safetensors", | ||||
|                         offset=0, | ||||
|                         length=tensor1.numel() * tensor1.element_size(), | ||||
|                     ), | ||||
| @ -159,84 +100,43 @@ class TestHfStorage(TestCase): | ||||
|             ) | ||||
|  | ||||
|     def test_read_data_hf(self) -> None: | ||||
|         mock_safetensors = MagicMock() | ||||
|         sys.modules["safetensors"] = mock_safetensors | ||||
|         mock_module = MagicMock() | ||||
|         sys.modules["safetensors"] = mock_module | ||||
|         sys.modules["huggingface_hub"] = mock_module | ||||
|  | ||||
|         # Create test tensors | ||||
|         tensor_0 = torch.tensor([1.0, 2.0, 3.0, 4.0]) | ||||
|  | ||||
|         # Mock the deserialize function to return our test tensors | ||||
|         # The format matches what's expected in the read_data method | ||||
|         mock_safetensors.deserialize.return_value = [ | ||||
|             ("tensor_0", { | ||||
|                 "data": tensor_0.numpy().tobytes(), | ||||
|                 "dtype": "F32", | ||||
|                 "shape": [4] | ||||
|             }), | ||||
|         ] | ||||
|         name = "tensor_0" | ||||
|         tensor_0 = torch.rand(4) | ||||
|         mock_module = MagicMock() | ||||
|         mock_module.load.return_value = {name: tensor_0} | ||||
|         sys.modules["safetensors.torch"] = mock_module | ||||
|  | ||||
|         with tempfile.TemporaryDirectory() as path: | ||||
|             # Create the reader | ||||
|             reader = _HuggingFaceStorageReader(path=path) | ||||
|             reader.fs = FileSystem() | ||||
|             file_name = "model-00001-of-00001" | ||||
|  | ||||
|             # Create test file | ||||
|             file_name = "model-00001-of-00001.safetensors" | ||||
|             file_path = os.path.join(path, file_name) | ||||
|             pathlib.Path(file_path).touch() | ||||
|             pathlib.Path(os.path.join(path, file_name)).touch() | ||||
|  | ||||
|             # Set up storage data with _StorageInfo objects | ||||
|             storage_data = { | ||||
|                 "tensor_0": _StorageInfo(file_path, 0, tensor_0.numel() * tensor_0.element_size()), | ||||
|             } | ||||
|             reader.set_up_storage_reader( | ||||
|                 Metadata( | ||||
|                     state_dict_metadata={name: BytesStorageMetadata()}, | ||||
|                     storage_data={name: file_name}, | ||||
|                 ), | ||||
|                 is_coordinator=True, | ||||
|             ) | ||||
|  | ||||
|  | ||||
|             reader.storage_data = storage_data | ||||
|  | ||||
|             # Create target tensors that will be updated by read_data | ||||
|             target_tensor_0 = torch.zeros(4) | ||||
|             state_dict = { | ||||
|                 "tensor_0": target_tensor_0, | ||||
|             } | ||||
|  | ||||
|             # Create read items for the load plan | ||||
|             read_items = [] | ||||
|             for name, tensor in state_dict.items(): | ||||
|                 storage_index = MetadataIndex(fqn=name, offset=torch.Size([0]), index=None) | ||||
|                 dest_index = MetadataIndex(fqn=name, offset=torch.Size([0]), index=None) | ||||
|                 read_items.append( | ||||
|                     ReadItem( | ||||
|                         type=LoadItemType.TENSOR, | ||||
|                         storage_index=storage_index, | ||||
|                         dest_index=dest_index, | ||||
|                         storage_offsets=[0, 0], | ||||
|                         dest_offsets=[0, 0], | ||||
|                         lengths=tensor.size(), | ||||
|                     ) | ||||
|                 ) | ||||
|  | ||||
|             # Create load plan and planner | ||||
|             read_items = _create_read_items(name, BytesStorageMetadata(), file_name) | ||||
|             load_plan = LoadPlan(read_items) | ||||
|             load_planner = DefaultLoadPlanner() | ||||
|             load_planner.set_up_planner( | ||||
|                 state_dict=state_dict, | ||||
|                 metadata=Metadata( | ||||
|                     state_dict_metadata={ | ||||
|                         "tensor_0": TensorStorageMetadata( | ||||
|                             properties=TensorProperties(dtype=torch.float32), | ||||
|                             size=torch.Size([4]), | ||||
|                             chunks=[ChunkStorageMetadata(offsets=[0], sizes=torch.Size([4]))], | ||||
|                         ) | ||||
|                     }, | ||||
|                     storage_data=storage_data, | ||||
|                 ), | ||||
|             ) | ||||
|             load_planner = _HuggingFaceLoadPlanner() | ||||
|             load_planner.set_up_planner(state_dict={name: torch.rand(4)}) | ||||
|  | ||||
|             # Call read_data | ||||
|             future = reader.read_data(load_plan, load_planner) | ||||
|             future.wait() | ||||
|             read_data = reader.read_data(load_plan, load_planner) | ||||
|             read_data.wait() | ||||
|  | ||||
|             # Verify results - the target tensors should now contain the values from our test tensor | ||||
|             self.assertTrue(torch.equal(state_dict["tensor_0"], tensor_0)) | ||||
|             loaded_tensor = load_planner.original_state_dict[name] | ||||
|             self.assertEqual(loaded_tensor, tensor_0) | ||||
|  | ||||
|     def test_write_metadata_hf(self) -> None: | ||||
|     def test_metadata_hf(self) -> None: | ||||
|         mock_module = MagicMock() | ||||
|         sys.modules["huggingface_hub"] = mock_module | ||||
|         with tempfile.TemporaryDirectory() as path: | ||||
| @ -260,6 +160,7 @@ class TestHfStorage(TestCase): | ||||
|  | ||||
|             writer = _HuggingFaceStorageWriter( | ||||
|                 path=path, | ||||
|                 fqn_to_index_mapping=_FqnToFileMapping({}), | ||||
|             ) | ||||
|             writer.fs = FileSystem() | ||||
|             writer.finish( | ||||
| @ -284,16 +185,26 @@ class TestHfStorage(TestCase): | ||||
|                 metadata = json.load(f) | ||||
|                 self.assertEqual(metadata, expected_metadata) | ||||
|  | ||||
|     def test_read_metadata_hf(self): | ||||
|             reader = _HuggingFaceStorageReader(path=path) | ||||
|             reader.fs = FileSystem() | ||||
|             metadata = reader.read_metadata() | ||||
|             self.assertEqual(metadata.storage_data, expected_metadata["weight_map"]) | ||||
|  | ||||
|     def test_read_metadata_when_metadata_file_does_not_exist(self) -> None: | ||||
|         mock_module = MagicMock() | ||||
|         sys.modules["huggingface_hub"] = mock_module | ||||
|  | ||||
|         with tempfile.TemporaryDirectory() as path: | ||||
|             reader = _HuggingFaceStorageReader(path=path) | ||||
|  | ||||
|             key = "tensor_0" | ||||
|             reader.fs = FileSystem() | ||||
|             # there is one safetensor file, but no metadata file, | ||||
|             # so we create metadata from the safetensor file | ||||
|             keys = ["tensor_0", "tensor_1"] | ||||
|             file_name = "test.safetensors" | ||||
|             with open(os.path.join(path, file_name), "wb") as f: | ||||
|                 # write metadata the same way it would be in safetensors file | ||||
|                 metadata_contents = json.dumps( | ||||
|                     {'tensor_0': {'dtype': "F32", "shape": [5, 10], "data_offsets": [0, 200]}} | ||||
|                     {"tensor_0": "value_0", "tensor_1": "value_1"} | ||||
|                 ) | ||||
|                 metadata_bytes = metadata_contents.encode("utf-8") | ||||
|  | ||||
| @ -305,16 +216,13 @@ class TestHfStorage(TestCase): | ||||
|             self.assertEqual( | ||||
|                 metadata.state_dict_metadata, | ||||
|                 { | ||||
|                     key: TensorStorageMetadata( | ||||
|                             properties=TensorProperties(dtype=torch.float32), | ||||
|                             size=torch.Size([5, 10]), | ||||
|                             chunks=[ChunkStorageMetadata(offsets=[0, 0], sizes=torch.Size([5, 10]))], | ||||
|                     ), | ||||
|                     keys[0]: BytesStorageMetadata(), | ||||
|                     keys[1]: BytesStorageMetadata(), | ||||
|                 }, | ||||
|             ) | ||||
|             self.assertEqual( | ||||
|                 metadata.storage_data, | ||||
|                 {key: _StorageInfo(os.path.join(path, file_name), 0, 200, transform_descriptors=None)}, | ||||
|                 {keys[0]: file_name, keys[1]: file_name}, | ||||
|             ) | ||||
|  | ||||
|  | ||||
|  | ||||
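| # --- Editorial sketch (not part of the diff): the dependency-mocking pattern | ||||
| # used throughout these unit tests. Injecting a MagicMock into sys.modules | ||||
| # makes `import safetensors` succeed even when the package is not installed. | ||||
| import sys | ||||
| from unittest.mock import MagicMock | ||||
|  | ||||
| mock_module = MagicMock() | ||||
| mock_module.save.return_value = b""  # what the storage writer expects back | ||||
| sys.modules["safetensors"] = mock_module | ||||
| sys.modules["safetensors.torch"] = mock_module | ||||
| sys.modules["huggingface_hub"] = mock_module | ||||
|  | ||||
| import safetensors  # resolves to the mock, not a real installation | ||||
|  | ||||
| assert safetensors.save({}) == b"" | ||||
|  | ||||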
							
								
								
									
test/dynamo/cpython/3_13/list_tests.diff (new file, 67 lines)
							| @ -0,0 +1,67 @@ | ||||
| diff --git a/test/dynamo/cpython/3_13/list_tests.py b/test/dynamo/cpython/3_13/list_tests.py | ||||
| index dbc5ef4f9f2..2b9f3b9311f 100644 | ||||
| --- a/test/dynamo/cpython/3_13/list_tests.py | ||||
| +++ b/test/dynamo/cpython/3_13/list_tests.py | ||||
| @@ -1,3 +1,53 @@ | ||||
| +# ======= BEGIN Dynamo patch ======= | ||||
| +# Owner(s): ["module: dynamo"] | ||||
| + | ||||
| +# ruff: noqa | ||||
| +# flake8: noqa | ||||
| + | ||||
| +import sys | ||||
| +import torch | ||||
| +import torch._dynamo.test_case | ||||
| +import unittest | ||||
| +from torch._dynamo.test_case import CPythonTestCase | ||||
| +from torch.testing._internal.common_utils import run_tests | ||||
| + | ||||
| +__TestCase = CPythonTestCase | ||||
| + | ||||
| +# redirect import statements | ||||
| +import sys | ||||
| +import importlib.abc | ||||
| + | ||||
| +redirect_imports = ( | ||||
| +    "test.mapping_tests", | ||||
| +    "test.typinganndata", | ||||
| +    "test.test_grammar", | ||||
| +    "test.test_math", | ||||
| +    "test.test_iter", | ||||
| +    "test.typinganndata.ann_module", | ||||
| +) | ||||
| + | ||||
| +class RedirectImportFinder(importlib.abc.MetaPathFinder): | ||||
| +    def find_spec(self, fullname, path, target=None): | ||||
| +        # Check if the import is the problematic one | ||||
| +        if fullname in redirect_imports: | ||||
| +            try: | ||||
| +                # Attempt to import the standalone module | ||||
| +                name = fullname.removeprefix("test.") | ||||
| +                r = importlib.import_module(name) | ||||
| +                # Redirect the module in sys.modules | ||||
| +                sys.modules[fullname] = r | ||||
| +                # Return a module spec from the found module | ||||
| +                return importlib.util.find_spec(name) | ||||
| +            except ImportError: | ||||
| +                return None | ||||
| +        return None | ||||
| + | ||||
| +# Add the custom finder to sys.meta_path | ||||
| +sys.meta_path.insert(0, RedirectImportFinder()) | ||||
| + | ||||
| + | ||||
| +# ======= END DYNAMO PATCH ======= | ||||
| + | ||||
|  """ | ||||
|  Tests common to list and UserList.UserList | ||||
|  """ | ||||
| @@ -5,7 +55,7 @@ Tests common to list and UserList.UserList | ||||
|  import sys | ||||
|  from functools import cmp_to_key | ||||
|   | ||||
| -from test import seq_tests | ||||
| +import seq_tests | ||||
|  from test.support import ALWAYS_EQ, NEVER_EQ, get_c_recursion_limit | ||||
|   | ||||
|   | ||||
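| # --- Editorial sketch (not part of the diff): the import-redirection trick | ||||
| # installed by the Dynamo patch above, reduced to a single module. Assumes a | ||||
| # standalone seq_tests.py is importable from sys.path. | ||||
| import importlib | ||||
| import importlib.abc | ||||
| import importlib.util | ||||
| import sys | ||||
|  | ||||
| class RedirectFinder(importlib.abc.MetaPathFinder): | ||||
|     def find_spec(self, fullname, path, target=None): | ||||
|         if fullname == "test.seq_tests": | ||||
|             mod = importlib.import_module("seq_tests")  # the standalone copy | ||||
|             sys.modules[fullname] = mod                 # alias the old name | ||||
|             return importlib.util.find_spec("seq_tests") | ||||
|         return None | ||||
|  | ||||
| sys.meta_path.insert(0, RedirectFinder()) | ||||
|  | ||||
| # After this, `from test import seq_tests` resolves to the standalone module | ||||
| # rather than to the CPython test package. | ||||
|  | ||||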
							
								
								
									
test/dynamo/cpython/3_13/list_tests.py (new file, 627 lines)
							| @ -0,0 +1,627 @@ | ||||
| # ======= BEGIN Dynamo patch ======= | ||||
| # Owner(s): ["module: dynamo"] | ||||
|  | ||||
| # ruff: noqa | ||||
| # flake8: noqa | ||||
|  | ||||
| import sys | ||||
| import torch | ||||
| import torch._dynamo.test_case | ||||
| import unittest | ||||
| from torch._dynamo.test_case import CPythonTestCase | ||||
| from torch.testing._internal.common_utils import run_tests | ||||
|  | ||||
| __TestCase = CPythonTestCase | ||||
|  | ||||
| # redirect import statements | ||||
| import sys | ||||
| import importlib.abc | ||||
|  | ||||
| redirect_imports = ( | ||||
|     "test.mapping_tests", | ||||
|     "test.typinganndata", | ||||
|     "test.test_grammar", | ||||
|     "test.test_math", | ||||
|     "test.test_iter", | ||||
|     "test.typinganndata.ann_module", | ||||
| ) | ||||
|  | ||||
| class RedirectImportFinder(importlib.abc.MetaPathFinder): | ||||
|     def find_spec(self, fullname, path, target=None): | ||||
|         # Check if the import is the problematic one | ||||
|         if fullname in redirect_imports: | ||||
|             try: | ||||
|                 # Attempt to import the standalone module | ||||
|                 name = fullname.removeprefix("test.") | ||||
|                 r = importlib.import_module(name) | ||||
|                 # Redirect the module in sys.modules | ||||
|                 sys.modules[fullname] = r | ||||
|                 # Return a module spec from the found module | ||||
|                 return importlib.util.find_spec(name) | ||||
|             except ImportError: | ||||
|                 return None | ||||
|         return None | ||||
|  | ||||
| # Add the custom finder to sys.meta_path | ||||
| sys.meta_path.insert(0, RedirectImportFinder()) | ||||
|  | ||||
|  | ||||
| # ======= END DYNAMO PATCH ======= | ||||
|  | ||||
| """ | ||||
| Tests common to list and UserList.UserList | ||||
| """ | ||||
|  | ||||
| import sys | ||||
| from functools import cmp_to_key | ||||
|  | ||||
| import seq_tests | ||||
| from test.support import ALWAYS_EQ, NEVER_EQ, get_c_recursion_limit | ||||
|  | ||||
|  | ||||
| class CommonTest(seq_tests.CommonTest): | ||||
|  | ||||
|     def test_init(self): | ||||
|         # Iterable arg is optional | ||||
|         self.assertEqual(self.type2test([]), self.type2test()) | ||||
|  | ||||
|         # Init clears previous values | ||||
|         a = self.type2test([1, 2, 3]) | ||||
|         a.__init__() | ||||
|         self.assertEqual(a, self.type2test([])) | ||||
|  | ||||
|         # Init overwrites previous values | ||||
|         a = self.type2test([1, 2, 3]) | ||||
|         a.__init__([4, 5, 6]) | ||||
|         self.assertEqual(a, self.type2test([4, 5, 6])) | ||||
|  | ||||
|         # Mutables always return a new object | ||||
|         b = self.type2test(a) | ||||
|         self.assertNotEqual(id(a), id(b)) | ||||
|         self.assertEqual(a, b) | ||||
|  | ||||
|     def test_getitem_error(self): | ||||
|         a = [] | ||||
|         msg = "list indices must be integers or slices" | ||||
|         with self.assertRaisesRegex(TypeError, msg): | ||||
|             a['a'] | ||||
|  | ||||
|     def test_setitem_error(self): | ||||
|         a = [] | ||||
|         msg = "list indices must be integers or slices" | ||||
|         with self.assertRaisesRegex(TypeError, msg): | ||||
|             a['a'] = "python" | ||||
|  | ||||
|     def test_repr(self): | ||||
|         l0 = [] | ||||
|         l2 = [0, 1, 2] | ||||
|         a0 = self.type2test(l0) | ||||
|         a2 = self.type2test(l2) | ||||
|  | ||||
|         self.assertEqual(str(a0), str(l0)) | ||||
|         self.assertEqual(repr(a0), repr(l0)) | ||||
|         self.assertEqual(repr(a2), repr(l2)) | ||||
|         self.assertEqual(str(a2), "[0, 1, 2]") | ||||
|         self.assertEqual(repr(a2), "[0, 1, 2]") | ||||
|  | ||||
|         a2.append(a2) | ||||
|         a2.append(3) | ||||
|         self.assertEqual(str(a2), "[0, 1, 2, [...], 3]") | ||||
|         self.assertEqual(repr(a2), "[0, 1, 2, [...], 3]") | ||||
|  | ||||
|     def test_repr_deep(self): | ||||
|         a = self.type2test([]) | ||||
|         for i in range(get_c_recursion_limit() + 1): | ||||
|             a = self.type2test([a]) | ||||
|         self.assertRaises(RecursionError, repr, a) | ||||
|  | ||||
|     def test_set_subscript(self): | ||||
|         a = self.type2test(range(20)) | ||||
|         self.assertRaises(ValueError, a.__setitem__, slice(0, 10, 0), [1,2,3]) | ||||
|         self.assertRaises(TypeError, a.__setitem__, slice(0, 10), 1) | ||||
|         self.assertRaises(ValueError, a.__setitem__, slice(0, 10, 2), [1,2]) | ||||
|         self.assertRaises(TypeError, a.__getitem__, 'x', 1) | ||||
|         a[slice(2,10,3)] = [1,2,3] | ||||
|         self.assertEqual(a, self.type2test([0, 1, 1, 3, 4, 2, 6, 7, 3, | ||||
|                                             9, 10, 11, 12, 13, 14, 15, | ||||
|                                             16, 17, 18, 19])) | ||||
|  | ||||
|     def test_reversed(self): | ||||
|         a = self.type2test(range(20)) | ||||
|         r = reversed(a) | ||||
|         self.assertEqual(list(r), self.type2test(range(19, -1, -1))) | ||||
|         self.assertRaises(StopIteration, next, r) | ||||
|         self.assertEqual(list(reversed(self.type2test())), | ||||
|                          self.type2test()) | ||||
|         # Bug 3689: make sure list-reversed-iterator doesn't have __len__ | ||||
|         self.assertRaises(TypeError, len, reversed([1,2,3])) | ||||
|  | ||||
|     def test_setitem(self): | ||||
|         a = self.type2test([0, 1]) | ||||
|         a[0] = 0 | ||||
|         a[1] = 100 | ||||
|         self.assertEqual(a, self.type2test([0, 100])) | ||||
|         a[-1] = 200 | ||||
|         self.assertEqual(a, self.type2test([0, 200])) | ||||
|         a[-2] = 100 | ||||
|         self.assertEqual(a, self.type2test([100, 200])) | ||||
|         self.assertRaises(IndexError, a.__setitem__, -3, 200) | ||||
|         self.assertRaises(IndexError, a.__setitem__, 2, 200) | ||||
|  | ||||
|         a = self.type2test([]) | ||||
|         self.assertRaises(IndexError, a.__setitem__, 0, 200) | ||||
|         self.assertRaises(IndexError, a.__setitem__, -1, 200) | ||||
|         self.assertRaises(TypeError, a.__setitem__) | ||||
|  | ||||
|         a = self.type2test([0,1,2,3,4]) | ||||
|         a[0] = 1 | ||||
|         a[1] = 2 | ||||
|         a[2] = 3 | ||||
|         self.assertEqual(a, self.type2test([1,2,3,3,4])) | ||||
|         a[0] = 5 | ||||
|         a[1] = 6 | ||||
|         a[2] = 7 | ||||
|         self.assertEqual(a, self.type2test([5,6,7,3,4])) | ||||
|         a[-2] = 88 | ||||
|         a[-1] = 99 | ||||
|         self.assertEqual(a, self.type2test([5,6,7,88,99])) | ||||
|         a[-2] = 8 | ||||
|         a[-1] = 9 | ||||
|         self.assertEqual(a, self.type2test([5,6,7,8,9])) | ||||
|  | ||||
|         msg = "list indices must be integers or slices" | ||||
|         with self.assertRaisesRegex(TypeError, msg): | ||||
|             a['a'] = "python" | ||||
|  | ||||
|     def test_delitem(self): | ||||
|         a = self.type2test([0, 1]) | ||||
|         del a[1] | ||||
|         self.assertEqual(a, [0]) | ||||
|         del a[0] | ||||
|         self.assertEqual(a, []) | ||||
|  | ||||
|         a = self.type2test([0, 1]) | ||||
|         del a[-2] | ||||
|         self.assertEqual(a, [1]) | ||||
|         del a[-1] | ||||
|         self.assertEqual(a, []) | ||||
|  | ||||
|         a = self.type2test([0, 1]) | ||||
|         self.assertRaises(IndexError, a.__delitem__, -3) | ||||
|         self.assertRaises(IndexError, a.__delitem__, 2) | ||||
|  | ||||
|         a = self.type2test([]) | ||||
|         self.assertRaises(IndexError, a.__delitem__, 0) | ||||
|  | ||||
|         self.assertRaises(TypeError, a.__delitem__) | ||||
|  | ||||
|     def test_setslice(self): | ||||
|         l = [0, 1] | ||||
|         a = self.type2test(l) | ||||
|  | ||||
|         for i in range(-3, 4): | ||||
|             a[:i] = l[:i] | ||||
|             self.assertEqual(a, l) | ||||
|             a2 = a[:] | ||||
|             a2[:i] = a[:i] | ||||
|             self.assertEqual(a2, a) | ||||
|             a[i:] = l[i:] | ||||
|             self.assertEqual(a, l) | ||||
|             a2 = a[:] | ||||
|             a2[i:] = a[i:] | ||||
|             self.assertEqual(a2, a) | ||||
|             for j in range(-3, 4): | ||||
|                 a[i:j] = l[i:j] | ||||
|                 self.assertEqual(a, l) | ||||
|                 a2 = a[:] | ||||
|                 a2[i:j] = a[i:j] | ||||
|                 self.assertEqual(a2, a) | ||||
|  | ||||
|         aa2 = a2[:] | ||||
|         aa2[:0] = [-2, -1] | ||||
|         self.assertEqual(aa2, [-2, -1, 0, 1]) | ||||
|         aa2[0:] = [] | ||||
|         self.assertEqual(aa2, []) | ||||
|  | ||||
|         a = self.type2test([1, 2, 3, 4, 5]) | ||||
|         a[:-1] = a | ||||
|         self.assertEqual(a, self.type2test([1, 2, 3, 4, 5, 5])) | ||||
|         a = self.type2test([1, 2, 3, 4, 5]) | ||||
|         a[1:] = a | ||||
|         self.assertEqual(a, self.type2test([1, 1, 2, 3, 4, 5])) | ||||
|         a = self.type2test([1, 2, 3, 4, 5]) | ||||
|         a[1:-1] = a | ||||
|         self.assertEqual(a, self.type2test([1, 1, 2, 3, 4, 5, 5])) | ||||
|  | ||||
|         a = self.type2test([]) | ||||
|         a[:] = tuple(range(10)) | ||||
|         self.assertEqual(a, self.type2test(range(10))) | ||||
|  | ||||
|         self.assertRaises(TypeError, a.__setitem__, slice(0, 1, 5)) | ||||
|  | ||||
|         self.assertRaises(TypeError, a.__setitem__) | ||||
|  | ||||
|     def test_slice_assign_iterator(self): | ||||
|         x = self.type2test(range(5)) | ||||
|         x[0:3] = reversed(range(3)) | ||||
|         self.assertEqual(x, self.type2test([2, 1, 0, 3, 4])) | ||||
|  | ||||
|         x[:] = reversed(range(3)) | ||||
|         self.assertEqual(x, self.type2test([2, 1, 0])) | ||||
|  | ||||
|     def test_delslice(self): | ||||
|         a = self.type2test([0, 1]) | ||||
|         del a[1:2] | ||||
|         del a[0:1] | ||||
|         self.assertEqual(a, self.type2test([])) | ||||
|  | ||||
|         a = self.type2test([0, 1]) | ||||
|         del a[1:2] | ||||
|         del a[0:1] | ||||
|         self.assertEqual(a, self.type2test([])) | ||||
|  | ||||
|         a = self.type2test([0, 1]) | ||||
|         del a[-2:-1] | ||||
|         self.assertEqual(a, self.type2test([1])) | ||||
|  | ||||
|         a = self.type2test([0, 1]) | ||||
|         del a[-2:-1] | ||||
|         self.assertEqual(a, self.type2test([1])) | ||||
|  | ||||
|         a = self.type2test([0, 1]) | ||||
|         del a[1:] | ||||
|         del a[:1] | ||||
|         self.assertEqual(a, self.type2test([])) | ||||
|  | ||||
|         a = self.type2test([0, 1]) | ||||
|         del a[1:] | ||||
|         del a[:1] | ||||
|         self.assertEqual(a, self.type2test([])) | ||||
|  | ||||
|         a = self.type2test([0, 1]) | ||||
|         del a[-1:] | ||||
|         self.assertEqual(a, self.type2test([0])) | ||||
|  | ||||
|         a = self.type2test([0, 1]) | ||||
|         del a[-1:] | ||||
|         self.assertEqual(a, self.type2test([0])) | ||||
|  | ||||
|         a = self.type2test([0, 1]) | ||||
|         del a[:] | ||||
|         self.assertEqual(a, self.type2test([])) | ||||
|  | ||||
|     def test_append(self): | ||||
|         a = self.type2test([]) | ||||
|         a.append(0) | ||||
|         a.append(1) | ||||
|         a.append(2) | ||||
|         self.assertEqual(a, self.type2test([0, 1, 2])) | ||||
|  | ||||
|         self.assertRaises(TypeError, a.append) | ||||
|  | ||||
|     def test_extend(self): | ||||
|         a1 = self.type2test([0]) | ||||
|         a2 = self.type2test((0, 1)) | ||||
|         a = a1[:] | ||||
|         a.extend(a2) | ||||
|         self.assertEqual(a, a1 + a2) | ||||
|  | ||||
|         a.extend(self.type2test([])) | ||||
|         self.assertEqual(a, a1 + a2) | ||||
|  | ||||
|         a.extend(a) | ||||
|         self.assertEqual(a, self.type2test([0, 0, 1, 0, 0, 1])) | ||||
|  | ||||
|         a = self.type2test("spam") | ||||
|         a.extend("eggs") | ||||
|         self.assertEqual(a, list("spameggs")) | ||||
|  | ||||
|         self.assertRaises(TypeError, a.extend, None) | ||||
|         self.assertRaises(TypeError, a.extend) | ||||
|  | ||||
|         # overflow test. issue1621 | ||||
|         class CustomIter: | ||||
|             def __iter__(self): | ||||
|                 return self | ||||
|             def __next__(self): | ||||
|                 raise StopIteration | ||||
|             def __length_hint__(self): | ||||
|                 return sys.maxsize | ||||
|         a = self.type2test([1,2,3,4]) | ||||
|         a.extend(CustomIter()) | ||||
|         self.assertEqual(a, [1,2,3,4]) | ||||
|  | ||||
|  | ||||
|     def test_insert(self): | ||||
|         a = self.type2test([0, 1, 2]) | ||||
|         a.insert(0, -2) | ||||
|         a.insert(1, -1) | ||||
|         a.insert(2, 0) | ||||
|         self.assertEqual(a, [-2, -1, 0, 0, 1, 2]) | ||||
|  | ||||
|         b = a[:] | ||||
|         b.insert(-2, "foo") | ||||
|         b.insert(-200, "left") | ||||
|         b.insert(200, "right") | ||||
|         self.assertEqual(b, self.type2test(["left",-2,-1,0,0,"foo",1,2,"right"])) | ||||
|  | ||||
|         self.assertRaises(TypeError, a.insert) | ||||
|  | ||||
|     def test_pop(self): | ||||
|         a = self.type2test([-1, 0, 1]) | ||||
|         a.pop() | ||||
|         self.assertEqual(a, [-1, 0]) | ||||
|         a.pop(0) | ||||
|         self.assertEqual(a, [0]) | ||||
|         self.assertRaises(IndexError, a.pop, 5) | ||||
|         a.pop(0) | ||||
|         self.assertEqual(a, []) | ||||
|         self.assertRaises(IndexError, a.pop) | ||||
|         self.assertRaises(TypeError, a.pop, 42, 42) | ||||
|         a = self.type2test([0, 10, 20, 30, 40]) | ||||
|  | ||||
|     def test_remove(self): | ||||
|         a = self.type2test([0, 0, 1]) | ||||
|         a.remove(1) | ||||
|         self.assertEqual(a, [0, 0]) | ||||
|         a.remove(0) | ||||
|         self.assertEqual(a, [0]) | ||||
|         a.remove(0) | ||||
|         self.assertEqual(a, []) | ||||
|  | ||||
|         self.assertRaises(ValueError, a.remove, 0) | ||||
|  | ||||
|         self.assertRaises(TypeError, a.remove) | ||||
|  | ||||
|         a = self.type2test([1, 2]) | ||||
|         self.assertRaises(ValueError, a.remove, NEVER_EQ) | ||||
|         self.assertEqual(a, [1, 2]) | ||||
|         a.remove(ALWAYS_EQ) | ||||
|         self.assertEqual(a, [2]) | ||||
|         a = self.type2test([ALWAYS_EQ]) | ||||
|         a.remove(1) | ||||
|         self.assertEqual(a, []) | ||||
|         a = self.type2test([ALWAYS_EQ]) | ||||
|         a.remove(NEVER_EQ) | ||||
|         self.assertEqual(a, []) | ||||
|         a = self.type2test([NEVER_EQ]) | ||||
|         self.assertRaises(ValueError, a.remove, ALWAYS_EQ) | ||||
|  | ||||
|         class BadExc(Exception): | ||||
|             pass | ||||
|  | ||||
|         class BadCmp: | ||||
|             def __eq__(self, other): | ||||
|                 if other == 2: | ||||
|                     raise BadExc() | ||||
|                 return False | ||||
|  | ||||
|         a = self.type2test([0, 1, 2, 3]) | ||||
|         self.assertRaises(BadExc, a.remove, BadCmp()) | ||||
|  | ||||
|         class BadCmp2: | ||||
|             def __eq__(self, other): | ||||
|                 raise BadExc() | ||||
|  | ||||
|         d = self.type2test('abcdefghcij') | ||||
|         d.remove('c') | ||||
|         self.assertEqual(d, self.type2test('abdefghcij')) | ||||
|         d.remove('c') | ||||
|         self.assertEqual(d, self.type2test('abdefghij')) | ||||
|         self.assertRaises(ValueError, d.remove, 'c') | ||||
|         self.assertEqual(d, self.type2test('abdefghij')) | ||||
|  | ||||
|         # Handle comparison errors | ||||
|         d = self.type2test(['a', 'b', BadCmp2(), 'c']) | ||||
|         e = self.type2test(d) | ||||
|         self.assertRaises(BadExc, d.remove, 'c') | ||||
|         for x, y in zip(d, e): | ||||
|             # verify that original order and values are retained. | ||||
|             self.assertIs(x, y) | ||||
|  | ||||
|     def test_index(self): | ||||
|         super().test_index() | ||||
|         a = self.type2test([-2, -1, 0, 0, 1, 2]) | ||||
|         a.remove(0) | ||||
|         self.assertRaises(ValueError, a.index, 2, 0, 4) | ||||
|         self.assertEqual(a, self.type2test([-2, -1, 0, 1, 2])) | ||||
|  | ||||
|         # Test modifying the list during index's iteration | ||||
|         class EvilCmp: | ||||
|             def __init__(self, victim): | ||||
|                 self.victim = victim | ||||
|             def __eq__(self, other): | ||||
|                 del self.victim[:] | ||||
|                 return False | ||||
|         a = self.type2test() | ||||
|         a[:] = [EvilCmp(a) for _ in range(100)] | ||||
|         # This used to seg fault before patch #1005778 | ||||
|         self.assertRaises(ValueError, a.index, None) | ||||
|  | ||||
|     def test_reverse(self): | ||||
|         u = self.type2test([-2, -1, 0, 1, 2]) | ||||
|         u2 = u[:] | ||||
|         u.reverse() | ||||
|         self.assertEqual(u, [2, 1, 0, -1, -2]) | ||||
|         u.reverse() | ||||
|         self.assertEqual(u, u2) | ||||
|  | ||||
|         self.assertRaises(TypeError, u.reverse, 42) | ||||
|  | ||||
|     def test_clear(self): | ||||
|         u = self.type2test([2, 3, 4]) | ||||
|         u.clear() | ||||
|         self.assertEqual(u, []) | ||||
|  | ||||
|         u = self.type2test([]) | ||||
|         u.clear() | ||||
|         self.assertEqual(u, []) | ||||
|  | ||||
|         u = self.type2test([]) | ||||
|         u.append(1) | ||||
|         u.clear() | ||||
|         u.append(2) | ||||
|         self.assertEqual(u, [2]) | ||||
|  | ||||
|         self.assertRaises(TypeError, u.clear, None) | ||||
|  | ||||
|     def test_copy(self): | ||||
|         u = self.type2test([1, 2, 3]) | ||||
|         v = u.copy() | ||||
|         self.assertEqual(v, [1, 2, 3]) | ||||
|  | ||||
|         u = self.type2test([]) | ||||
|         v = u.copy() | ||||
|         self.assertEqual(v, []) | ||||
|  | ||||
|         # test that it's indeed a copy and not a reference | ||||
|         u = self.type2test(['a', 'b']) | ||||
|         v = u.copy() | ||||
|         v.append('i') | ||||
|         self.assertEqual(u, ['a', 'b']) | ||||
|         self.assertEqual(v, u + ['i']) | ||||
|  | ||||
|         # test that it's a shallow, not a deep copy | ||||
|         u = self.type2test([1, 2, [3, 4], 5]) | ||||
|         v = u.copy() | ||||
|         self.assertEqual(u, v) | ||||
|         self.assertIs(v[3], u[3]) | ||||
|  | ||||
|         self.assertRaises(TypeError, u.copy, None) | ||||
|  | ||||
|     def test_sort(self): | ||||
|         u = self.type2test([1, 0]) | ||||
|         u.sort() | ||||
|         self.assertEqual(u, [0, 1]) | ||||
|  | ||||
|         u = self.type2test([2,1,0,-1,-2]) | ||||
|         u.sort() | ||||
|         self.assertEqual(u, self.type2test([-2,-1,0,1,2])) | ||||
|  | ||||
|         self.assertRaises(TypeError, u.sort, 42, 42) | ||||
|  | ||||
|         def revcmp(a, b): | ||||
|             if a == b: | ||||
|                 return 0 | ||||
|             elif a < b: | ||||
|                 return 1 | ||||
|             else: # a > b | ||||
|                 return -1 | ||||
|         u.sort(key=cmp_to_key(revcmp)) | ||||
|         self.assertEqual(u, self.type2test([2,1,0,-1,-2])) | ||||
|  | ||||
|         # The following dumps core in unpatched Python 1.5: | ||||
|         def myComparison(x,y): | ||||
|             xmod, ymod = x%3, y%7 | ||||
|             if xmod == ymod: | ||||
|                 return 0 | ||||
|             elif xmod < ymod: | ||||
|                 return -1 | ||||
|             else: # xmod > ymod | ||||
|                 return 1 | ||||
|         z = self.type2test(range(12)) | ||||
|         z.sort(key=cmp_to_key(myComparison)) | ||||
|  | ||||
|         self.assertRaises(TypeError, z.sort, 2) | ||||
|  | ||||
|         def selfmodifyingComparison(x,y): | ||||
|             z.append(1) | ||||
|             if x == y: | ||||
|                 return 0 | ||||
|             elif x < y: | ||||
|                 return -1 | ||||
|             else: # x > y | ||||
|                 return 1 | ||||
|         self.assertRaises(ValueError, z.sort, | ||||
|                           key=cmp_to_key(selfmodifyingComparison)) | ||||
|  | ||||
|         self.assertRaises(TypeError, z.sort, 42, 42, 42, 42) | ||||
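|         # Editorial aside (sketch): functools.cmp_to_key adapts an old-style | ||||
|         # cmp(a, b) -> negative/zero/positive function into a key object for | ||||
|         # sort()/sorted(), e.g.: | ||||
|         #     sorted([3, 1, 2], key=cmp_to_key(lambda a, b: b - a))  # [3, 2, 1] | ||||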
|  | ||||
|     def test_slice(self): | ||||
|         u = self.type2test("spam") | ||||
|         u[:2] = "h" | ||||
|         self.assertEqual(u, list("ham")) | ||||
|  | ||||
|     def test_iadd(self): | ||||
|         super().test_iadd() | ||||
|         u = self.type2test([0, 1]) | ||||
|         u2 = u | ||||
|         u += [2, 3] | ||||
|         self.assertIs(u, u2) | ||||
|  | ||||
|         u = self.type2test("spam") | ||||
|         u += "eggs" | ||||
|         self.assertEqual(u, self.type2test("spameggs")) | ||||
|  | ||||
|         self.assertRaises(TypeError, u.__iadd__, None) | ||||
|  | ||||
|     def test_imul(self): | ||||
|         super().test_imul() | ||||
|         s = self.type2test([]) | ||||
|         oldid = id(s) | ||||
|         s *= 10 | ||||
|         self.assertEqual(id(s), oldid) | ||||
|  | ||||
|     def test_extendedslicing(self): | ||||
|         #  subscript | ||||
|         a = self.type2test([0,1,2,3,4]) | ||||
|  | ||||
|         #  deletion | ||||
|         del a[::2] | ||||
|         self.assertEqual(a, self.type2test([1,3])) | ||||
|         a = self.type2test(range(5)) | ||||
|         del a[1::2] | ||||
|         self.assertEqual(a, self.type2test([0,2,4])) | ||||
|         a = self.type2test(range(5)) | ||||
|         del a[1::-2] | ||||
|         self.assertEqual(a, self.type2test([0,2,3,4])) | ||||
|         a = self.type2test(range(10)) | ||||
|         del a[::1000] | ||||
|         self.assertEqual(a, self.type2test([1, 2, 3, 4, 5, 6, 7, 8, 9])) | ||||
|         #  assignment | ||||
|         a = self.type2test(range(10)) | ||||
|         a[::2] = [-1]*5 | ||||
|         self.assertEqual(a, self.type2test([-1, 1, -1, 3, -1, 5, -1, 7, -1, 9])) | ||||
|         a = self.type2test(range(10)) | ||||
|         a[::-4] = [10]*3 | ||||
|         self.assertEqual(a, self.type2test([0, 10, 2, 3, 4, 10, 6, 7, 8, 10])) | ||||
|         a = self.type2test(range(4)) | ||||
|         a[::-1] = a | ||||
|         self.assertEqual(a, self.type2test([3, 2, 1, 0])) | ||||
|         a = self.type2test(range(10)) | ||||
|         b = a[:] | ||||
|         c = a[:] | ||||
|         a[2:3] = self.type2test(["two", "elements"]) | ||||
|         b[slice(2,3)] = self.type2test(["two", "elements"]) | ||||
|         c[2:3:] = self.type2test(["two", "elements"]) | ||||
|         self.assertEqual(a, b) | ||||
|         self.assertEqual(a, c) | ||||
|         a = self.type2test(range(10)) | ||||
|         a[::2] = tuple(range(5)) | ||||
|         self.assertEqual(a, self.type2test([0, 1, 1, 3, 2, 5, 3, 7, 4, 9])) | ||||
|         # test issue7788 | ||||
|         a = self.type2test(range(10)) | ||||
|         del a[9::1<<333] | ||||
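|         # Editorial note: extended-slice assignment must match lengths | ||||
|         # exactly; a[::2] on a 10-element list selects 5 slots, so the | ||||
|         # right-hand side has to supply exactly 5 items or ValueError is | ||||
|         # raised.  Plain slices such as a[2:3] = [...] may resize freely. | ||||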
|  | ||||
|     def test_constructor_exception_handling(self): | ||||
|         # Bug #1242657 | ||||
|         class F(object): | ||||
|             def __iter__(self): | ||||
|                 raise KeyboardInterrupt | ||||
|         self.assertRaises(KeyboardInterrupt, list, F()) | ||||
|  | ||||
|     def test_exhausted_iterator(self): | ||||
|         a = self.type2test([1, 2, 3]) | ||||
|         exhit = iter(a) | ||||
|         empit = iter(a) | ||||
|         for x in exhit:  # exhaust the iterator | ||||
|             next(empit)  # not exhausted | ||||
|         a.append(9) | ||||
|         self.assertEqual(list(exhit), []) | ||||
|         self.assertEqual(list(empit), [9]) | ||||
|         self.assertEqual(a, self.type2test([1, 2, 3, 9])) | ||||
|  | ||||
|         # gh-115733: Crash when iterating over exhausted iterator | ||||
|         exhit = iter(self.type2test([1, 2, 3])) | ||||
|         for _ in exhit: | ||||
|             next(exhit, 1) | ||||
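|         # Editorial note: an exhausted list iterator drops its reference to | ||||
|         # the list, so items appended afterwards remain invisible to it and | ||||
|         # next(exhit, default) keeps returning the default. | ||||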
test/dynamo/cpython/3_13/mapping_tests.diff (new file, 67 lines)
							| @ -0,0 +1,67 @@ | ||||
| diff --git a/test/dynamo/cpython/3_13/mapping_tests.py b/test/dynamo/cpython/3_13/mapping_tests.py | ||||
| index ed89a81a6ea..eed59a68e94 100644 | ||||
| --- a/test/dynamo/cpython/3_13/mapping_tests.py | ||||
| +++ b/test/dynamo/cpython/3_13/mapping_tests.py | ||||
| @@ -1,10 +1,61 @@ | ||||
| +# ======= BEGIN Dynamo patch ======= | ||||
| +# Owner(s): ["module: dynamo"] | ||||
| + | ||||
| +# ruff: noqa | ||||
| +# flake8: noqa | ||||
| + | ||||
| +import sys | ||||
| +import torch | ||||
| +import torch._dynamo.test_case | ||||
| +import unittest | ||||
| +from torch._dynamo.test_case import CPythonTestCase | ||||
| +from torch.testing._internal.common_utils import run_tests | ||||
| + | ||||
| +__TestCase = CPythonTestCase | ||||
| + | ||||
| + | ||||
| +# redirect import statements | ||||
| +import sys | ||||
| +import importlib.abc | ||||
| + | ||||
| +redirect_imports = ( | ||||
| +    "test.mapping_tests", | ||||
| +    "test.typinganndata", | ||||
| +    "test.test_grammar", | ||||
| +    "test.test_math", | ||||
| +    "test.test_iter", | ||||
| +    "test.typinganndata.ann_module", | ||||
| +) | ||||
| + | ||||
| +class RedirectImportFinder(importlib.abc.MetaPathFinder): | ||||
| +    def find_spec(self, fullname, path, target=None): | ||||
| +        # Check if the import is the problematic one | ||||
| +        if fullname in redirect_imports: | ||||
| +            try: | ||||
| +                # Attempt to import the standalone module | ||||
| +                name = fullname.removeprefix("test.") | ||||
| +                r = importlib.import_module(name) | ||||
| +                # Redirect the module in sys.modules | ||||
| +                sys.modules[fullname] = r | ||||
| +                # Return a module spec from the found module | ||||
| +                return importlib.util.find_spec(name) | ||||
| +            except ImportError: | ||||
| +                return None | ||||
| +        return None | ||||
| + | ||||
| +# Add the custom finder to sys.meta_path | ||||
| +sys.meta_path.insert(0, RedirectImportFinder()) | ||||
| + | ||||
| + | ||||
| +# ======= END DYNAMO PATCH ======= | ||||
| + | ||||
|  # tests common to dict and UserDict | ||||
|  import unittest | ||||
|  import collections | ||||
|  from test.support import get_c_recursion_limit | ||||
|   | ||||
|   | ||||
| -class BasicTestMappingProtocol(unittest.TestCase): | ||||
| +class BasicTestMappingProtocol(__TestCase): | ||||
|      # This base class can be used to check that an object conforms to the | ||||
|      # mapping protocol | ||||
|   | ||||
test/dynamo/cpython/3_13/mapping_tests.py (new file, 719 lines)
							| @ -0,0 +1,719 @@ | ||||
| # ======= BEGIN Dynamo patch ======= | ||||
| # Owner(s): ["module: dynamo"] | ||||
|  | ||||
| # ruff: noqa | ||||
| # flake8: noqa | ||||
|  | ||||
| import sys | ||||
| import torch | ||||
| import torch._dynamo.test_case | ||||
| import unittest | ||||
| from torch._dynamo.test_case import CPythonTestCase | ||||
| from torch.testing._internal.common_utils import run_tests | ||||
|  | ||||
| __TestCase = CPythonTestCase | ||||
|  | ||||
|  | ||||
| # redirect import statements | ||||
| import sys | ||||
| import importlib.abc | ||||
|  | ||||
| redirect_imports = ( | ||||
|     "test.mapping_tests", | ||||
|     "test.typinganndata", | ||||
|     "test.test_grammar", | ||||
|     "test.test_math", | ||||
|     "test.test_iter", | ||||
|     "test.typinganndata.ann_module", | ||||
| ) | ||||
|  | ||||
| class RedirectImportFinder(importlib.abc.MetaPathFinder): | ||||
|     def find_spec(self, fullname, path, target=None): | ||||
|         # Check if the import is the problematic one | ||||
|         if fullname in redirect_imports: | ||||
|             try: | ||||
|                 # Attempt to import the standalone module | ||||
|                 name = fullname.removeprefix("test.") | ||||
|                 r = importlib.import_module(name) | ||||
|                 # Redirect the module in sys.modules | ||||
|                 sys.modules[fullname] = r | ||||
|                 # Return a module spec from the found module | ||||
|                 return importlib.util.find_spec(name) | ||||
|             except ImportError: | ||||
|                 return None | ||||
|         return None | ||||
|  | ||||
| # Add the custom finder to sys.meta_path | ||||
| sys.meta_path.insert(0, RedirectImportFinder()) | ||||
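| # Illustration (added commentary, not part of the patch): once the finder is | ||||
| # installed, "import test.test_math" resolves to the local sibling module | ||||
| # "test_math", and the redirected module ends up cached in sys.modules under | ||||
| # both names. | ||||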
|  | ||||
|  | ||||
| # ======= END DYNAMO PATCH ======= | ||||
|  | ||||
| # tests common to dict and UserDict | ||||
| import unittest | ||||
| import collections | ||||
| from test.support import get_c_recursion_limit | ||||
|  | ||||
|  | ||||
| class BasicTestMappingProtocol(__TestCase): | ||||
|     # This base class can be used to check that an object conforms to the | ||||
|     # mapping protocol | ||||
|  | ||||
|     # Functions that can be useful to override to adapt to dictionary | ||||
|     # semantics | ||||
|     type2test = None # which class is being tested (overwrite in subclasses) | ||||
|  | ||||
|     def _reference(self): | ||||
|         """Return a dictionary of values which are invariant by storage | ||||
|         in the object under test.""" | ||||
|         return {"1": "2", "key1":"value1", "key2":(1,2,3)} | ||||
|     def _empty_mapping(self): | ||||
|         """Return an empty mapping object""" | ||||
|         return self.type2test() | ||||
|     def _full_mapping(self, data): | ||||
|         """Return a mapping object with the value contained in data | ||||
|         dictionary""" | ||||
|         x = self._empty_mapping() | ||||
|         for key, value in data.items(): | ||||
|             x[key] = value | ||||
|         return x | ||||
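|  | ||||
|     # Usage sketch (an assumption for illustration, not part of this file): | ||||
|     # a concrete suite plugs a mapping type into this harness by subclassing | ||||
|     # and setting type2test, e.g.: | ||||
|     # | ||||
|     #     class GeneralMappingTests(BasicTestMappingProtocol): | ||||
|     #         type2test = dict | ||||
|     # | ||||
|     # after which every test_* method below runs against dict. | ||||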
|  | ||||
|     def __init__(self, *args, **kw): | ||||
|         unittest.TestCase.__init__(self, *args, **kw) | ||||
|         self.reference = self._reference().copy() | ||||
|  | ||||
|         # A (key, value) pair not in the mapping | ||||
|         key, value = self.reference.popitem() | ||||
|         self.other = {key:value} | ||||
|  | ||||
|         # A (key, value) pair in the mapping | ||||
|         key, value = self.reference.popitem() | ||||
|         self.inmapping = {key:value} | ||||
|         self.reference[key] = value | ||||
|  | ||||
|     def test_read(self): | ||||
|         # Test for read only operations on mapping | ||||
|         p = self._empty_mapping() | ||||
|         p1 = dict(p)  # workaround for singleton objects | ||||
|         d = self._full_mapping(self.reference) | ||||
|         if d is p: | ||||
|             p = p1 | ||||
|         #Indexing | ||||
|         for key, value in self.reference.items(): | ||||
|             self.assertEqual(d[key], value) | ||||
|         knownkey = list(self.other.keys())[0] | ||||
|         self.assertRaises(KeyError, lambda:d[knownkey]) | ||||
|         #len | ||||
|         self.assertEqual(len(p), 0) | ||||
|         self.assertEqual(len(d), len(self.reference)) | ||||
|         #__contains__ | ||||
|         for k in self.reference: | ||||
|             self.assertIn(k, d) | ||||
|         for k in self.other: | ||||
|             self.assertNotIn(k, d) | ||||
|         #cmp | ||||
|         self.assertEqual(p, p) | ||||
|         self.assertEqual(d, d) | ||||
|         self.assertNotEqual(p, d) | ||||
|         self.assertNotEqual(d, p) | ||||
|         #bool | ||||
|         if p: self.fail("Empty mapping must compare to False") | ||||
|         if not d: self.fail("Full mapping must compare to True") | ||||
|         # keys(), items(), iterkeys() ... | ||||
|         def check_iterandlist(iter, lst, ref): | ||||
|             self.assertTrue(hasattr(iter, '__next__')) | ||||
|             self.assertTrue(hasattr(iter, '__iter__')) | ||||
|             x = list(iter) | ||||
|             self.assertTrue(set(x)==set(lst)==set(ref)) | ||||
|         check_iterandlist(iter(d.keys()), list(d.keys()), | ||||
|                           self.reference.keys()) | ||||
|         check_iterandlist(iter(d), list(d.keys()), self.reference.keys()) | ||||
|         check_iterandlist(iter(d.values()), list(d.values()), | ||||
|                           self.reference.values()) | ||||
|         check_iterandlist(iter(d.items()), list(d.items()), | ||||
|                           self.reference.items()) | ||||
|         #get | ||||
|         key, value = next(iter(d.items())) | ||||
|         knownkey, knownvalue = next(iter(self.other.items())) | ||||
|         self.assertEqual(d.get(key, knownvalue), value) | ||||
|         self.assertEqual(d.get(knownkey, knownvalue), knownvalue) | ||||
|         self.assertNotIn(knownkey, d) | ||||
|  | ||||
|     def test_write(self): | ||||
|         # Test for write operations on mapping | ||||
|         p = self._empty_mapping() | ||||
|         #Indexing | ||||
|         for key, value in self.reference.items(): | ||||
|             p[key] = value | ||||
|             self.assertEqual(p[key], value) | ||||
|         for key in self.reference.keys(): | ||||
|             del p[key] | ||||
|             self.assertRaises(KeyError, lambda:p[key]) | ||||
|         p = self._empty_mapping() | ||||
|         #update | ||||
|         p.update(self.reference) | ||||
|         self.assertEqual(dict(p), self.reference) | ||||
|         items = list(p.items()) | ||||
|         p = self._empty_mapping() | ||||
|         p.update(items) | ||||
|         self.assertEqual(dict(p), self.reference) | ||||
|         d = self._full_mapping(self.reference) | ||||
|         #setdefault | ||||
|         key, value = next(iter(d.items())) | ||||
|         knownkey, knownvalue = next(iter(self.other.items())) | ||||
|         self.assertEqual(d.setdefault(key, knownvalue), value) | ||||
|         self.assertEqual(d[key], value) | ||||
|         self.assertEqual(d.setdefault(knownkey, knownvalue), knownvalue) | ||||
|         self.assertEqual(d[knownkey], knownvalue) | ||||
|         #pop | ||||
|         self.assertEqual(d.pop(knownkey), knownvalue) | ||||
|         self.assertNotIn(knownkey, d) | ||||
|         self.assertRaises(KeyError, d.pop, knownkey) | ||||
|         default = 909 | ||||
|         d[knownkey] = knownvalue | ||||
|         self.assertEqual(d.pop(knownkey, default), knownvalue) | ||||
|         self.assertNotIn(knownkey, d) | ||||
|         self.assertEqual(d.pop(knownkey, default), default) | ||||
|         #popitem | ||||
|         key, value = d.popitem() | ||||
|         self.assertNotIn(key, d) | ||||
|         self.assertEqual(value, self.reference[key]) | ||||
|         p=self._empty_mapping() | ||||
|         self.assertRaises(KeyError, p.popitem) | ||||
|  | ||||
|     def test_constructor(self): | ||||
|         self.assertEqual(self._empty_mapping(), self._empty_mapping()) | ||||
|  | ||||
|     def test_bool(self): | ||||
|         self.assertTrue(not self._empty_mapping()) | ||||
|         self.assertTrue(self.reference) | ||||
|         self.assertTrue(bool(self._empty_mapping()) is False) | ||||
|         self.assertTrue(bool(self.reference) is True) | ||||
|  | ||||
|     def test_keys(self): | ||||
|         d = self._empty_mapping() | ||||
|         self.assertEqual(list(d.keys()), []) | ||||
|         d = self.reference | ||||
|         self.assertIn(list(self.inmapping.keys())[0], d.keys()) | ||||
|         self.assertNotIn(list(self.other.keys())[0], d.keys()) | ||||
|         self.assertRaises(TypeError, d.keys, None) | ||||
|  | ||||
|     def test_values(self): | ||||
|         d = self._empty_mapping() | ||||
|         self.assertEqual(list(d.values()), []) | ||||
|  | ||||
|         self.assertRaises(TypeError, d.values, None) | ||||
|  | ||||
|     def test_items(self): | ||||
|         d = self._empty_mapping() | ||||
|         self.assertEqual(list(d.items()), []) | ||||
|  | ||||
|         self.assertRaises(TypeError, d.items, None) | ||||
|  | ||||
|     def test_len(self): | ||||
|         d = self._empty_mapping() | ||||
|         self.assertEqual(len(d), 0) | ||||
|  | ||||
|     def test_getitem(self): | ||||
|         d = self.reference | ||||
|         self.assertEqual(d[list(self.inmapping.keys())[0]], | ||||
|                          list(self.inmapping.values())[0]) | ||||
|  | ||||
|         self.assertRaises(TypeError, d.__getitem__) | ||||
|  | ||||
|     def test_update(self): | ||||
|         # mapping argument | ||||
|         d = self._empty_mapping() | ||||
|         d.update(self.other) | ||||
|         self.assertEqual(list(d.items()), list(self.other.items())) | ||||
|  | ||||
|         # No argument | ||||
|         d = self._empty_mapping() | ||||
|         d.update() | ||||
|         self.assertEqual(d, self._empty_mapping()) | ||||
|  | ||||
|         # item sequence | ||||
|         d = self._empty_mapping() | ||||
|         d.update(self.other.items()) | ||||
|         self.assertEqual(list(d.items()), list(self.other.items())) | ||||
|  | ||||
|         # Iterator | ||||
|         d = self._empty_mapping() | ||||
|         d.update(self.other.items()) | ||||
|         self.assertEqual(list(d.items()), list(self.other.items())) | ||||
|  | ||||
|         # FIXME: Doesn't work with UserDict | ||||
|         # self.assertRaises((TypeError, AttributeError), d.update, None) | ||||
|         self.assertRaises((TypeError, AttributeError), d.update, 42) | ||||
|  | ||||
|         outerself = self | ||||
|         class SimpleUserDict: | ||||
|             def __init__(self): | ||||
|                 self.d = outerself.reference | ||||
|             def keys(self): | ||||
|                 return self.d.keys() | ||||
|             def __getitem__(self, i): | ||||
|                 return self.d[i] | ||||
|         d.clear() | ||||
|         d.update(SimpleUserDict()) | ||||
|         i1 = sorted(d.items()) | ||||
|         i2 = sorted(self.reference.items()) | ||||
|         self.assertEqual(i1, i2) | ||||
|  | ||||
|         class Exc(Exception): pass | ||||
|  | ||||
|         d = self._empty_mapping() | ||||
|         class FailingUserDict: | ||||
|             def keys(self): | ||||
|                 raise Exc | ||||
|         self.assertRaises(Exc, d.update, FailingUserDict()) | ||||
|  | ||||
|         d.clear() | ||||
|  | ||||
|         class FailingUserDict: | ||||
|             def keys(self): | ||||
|                 class BogonIter: | ||||
|                     def __init__(self): | ||||
|                         self.i = 1 | ||||
|                     def __iter__(self): | ||||
|                         return self | ||||
|                     def __next__(self): | ||||
|                         if self.i: | ||||
|                             self.i = 0 | ||||
|                             return 'a' | ||||
|                         raise Exc | ||||
|                 return BogonIter() | ||||
|             def __getitem__(self, key): | ||||
|                 return key | ||||
|         self.assertRaises(Exc, d.update, FailingUserDict()) | ||||
|  | ||||
|         class FailingUserDict: | ||||
|             def keys(self): | ||||
|                 class BogonIter: | ||||
|                     def __init__(self): | ||||
|                         self.i = ord('a') | ||||
|                     def __iter__(self): | ||||
|                         return self | ||||
|                     def __next__(self): | ||||
|                         if self.i <= ord('z'): | ||||
|                             rtn = chr(self.i) | ||||
|                             self.i += 1 | ||||
|                             return rtn | ||||
|                         raise StopIteration | ||||
|                 return BogonIter() | ||||
|             def __getitem__(self, key): | ||||
|                 raise Exc | ||||
|         self.assertRaises(Exc, d.update, FailingUserDict()) | ||||
|  | ||||
|         d = self._empty_mapping() | ||||
|         class badseq(object): | ||||
|             def __iter__(self): | ||||
|                 return self | ||||
|             def __next__(self): | ||||
|                 raise Exc() | ||||
|  | ||||
|         self.assertRaises(Exc, d.update, badseq()) | ||||
|  | ||||
|         self.assertRaises(ValueError, d.update, [(1, 2, 3)]) | ||||
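|  | ||||
|         # Editorial note: update() accepts a mapping exposing keys() and | ||||
|         # __getitem__, an iterable of (key, value) pairs, or keyword | ||||
|         # arguments; the Failing* helpers above probe each lookup stage for | ||||
|         # correct exception propagation. | ||||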
|  | ||||
|     # no test_fromkeys or test_copy since neither os.environ nor shelves supports them | ||||
|  | ||||
|     def test_get(self): | ||||
|         d = self._empty_mapping() | ||||
|         self.assertTrue(d.get(list(self.other.keys())[0]) is None) | ||||
|         self.assertEqual(d.get(list(self.other.keys())[0], 3), 3) | ||||
|         d = self.reference | ||||
|         self.assertTrue(d.get(list(self.other.keys())[0]) is None) | ||||
|         self.assertEqual(d.get(list(self.other.keys())[0], 3), 3) | ||||
|         self.assertEqual(d.get(list(self.inmapping.keys())[0]), | ||||
|                          list(self.inmapping.values())[0]) | ||||
|         self.assertEqual(d.get(list(self.inmapping.keys())[0], 3), | ||||
|                          list(self.inmapping.values())[0]) | ||||
|         self.assertRaises(TypeError, d.get) | ||||
|         self.assertRaises(TypeError, d.get, None, None, None) | ||||
|  | ||||
|     def test_setdefault(self): | ||||
|         d = self._empty_mapping() | ||||
|         self.assertRaises(TypeError, d.setdefault) | ||||
|  | ||||
|     def test_popitem(self): | ||||
|         d = self._empty_mapping() | ||||
|         self.assertRaises(KeyError, d.popitem) | ||||
|         self.assertRaises(TypeError, d.popitem, 42) | ||||
|  | ||||
|     def test_pop(self): | ||||
|         d = self._empty_mapping() | ||||
|         k, v = list(self.inmapping.items())[0] | ||||
|         d[k] = v | ||||
|         self.assertRaises(KeyError, d.pop, list(self.other.keys())[0]) | ||||
|  | ||||
|         self.assertEqual(d.pop(k), v) | ||||
|         self.assertEqual(len(d), 0) | ||||
|  | ||||
|         self.assertRaises(KeyError, d.pop, k) | ||||
|  | ||||
|  | ||||
| class TestMappingProtocol(BasicTestMappingProtocol): | ||||
|     def test_constructor(self): | ||||
|         BasicTestMappingProtocol.test_constructor(self) | ||||
|         self.assertTrue(self._empty_mapping() is not self._empty_mapping()) | ||||
|         self.assertEqual(self.type2test(x=1, y=2), {"x": 1, "y": 2}) | ||||
|  | ||||
|     def test_bool(self): | ||||
|         BasicTestMappingProtocol.test_bool(self) | ||||
|         self.assertTrue(not self._empty_mapping()) | ||||
|         self.assertTrue(self._full_mapping({"x": "y"})) | ||||
|         self.assertTrue(bool(self._empty_mapping()) is False) | ||||
|         self.assertTrue(bool(self._full_mapping({"x": "y"})) is True) | ||||
|  | ||||
|     def test_keys(self): | ||||
|         BasicTestMappingProtocol.test_keys(self) | ||||
|         d = self._empty_mapping() | ||||
|         self.assertEqual(list(d.keys()), []) | ||||
|         d = self._full_mapping({'a': 1, 'b': 2}) | ||||
|         k = d.keys() | ||||
|         self.assertIn('a', k) | ||||
|         self.assertIn('b', k) | ||||
|         self.assertNotIn('c', k) | ||||
|  | ||||
|     def test_values(self): | ||||
|         BasicTestMappingProtocol.test_values(self) | ||||
|         d = self._full_mapping({1:2}) | ||||
|         self.assertEqual(list(d.values()), [2]) | ||||
|  | ||||
|     def test_items(self): | ||||
|         BasicTestMappingProtocol.test_items(self) | ||||
|  | ||||
|         d = self._full_mapping({1:2}) | ||||
|         self.assertEqual(list(d.items()), [(1, 2)]) | ||||
|  | ||||
|     def test_contains(self): | ||||
|         d = self._empty_mapping() | ||||
|         self.assertNotIn('a', d) | ||||
|         self.assertTrue(not ('a' in d)) | ||||
|         self.assertTrue('a' not in d) | ||||
|         d = self._full_mapping({'a': 1, 'b': 2}) | ||||
|         self.assertIn('a', d) | ||||
|         self.assertIn('b', d) | ||||
|         self.assertNotIn('c', d) | ||||
|  | ||||
|         self.assertRaises(TypeError, d.__contains__) | ||||
|  | ||||
|     def test_len(self): | ||||
|         BasicTestMappingProtocol.test_len(self) | ||||
|         d = self._full_mapping({'a': 1, 'b': 2}) | ||||
|         self.assertEqual(len(d), 2) | ||||
|  | ||||
|     def test_getitem(self): | ||||
|         BasicTestMappingProtocol.test_getitem(self) | ||||
|         d = self._full_mapping({'a': 1, 'b': 2}) | ||||
|         self.assertEqual(d['a'], 1) | ||||
|         self.assertEqual(d['b'], 2) | ||||
|         d['c'] = 3 | ||||
|         d['a'] = 4 | ||||
|         self.assertEqual(d['c'], 3) | ||||
|         self.assertEqual(d['a'], 4) | ||||
|         del d['b'] | ||||
|         self.assertEqual(d, {'a': 4, 'c': 3}) | ||||
|  | ||||
|         self.assertRaises(TypeError, d.__getitem__) | ||||
|  | ||||
|     def test_clear(self): | ||||
|         d = self._full_mapping({1:1, 2:2, 3:3}) | ||||
|         d.clear() | ||||
|         self.assertEqual(d, {}) | ||||
|  | ||||
|         self.assertRaises(TypeError, d.clear, None) | ||||
|  | ||||
|     def test_update(self): | ||||
|         BasicTestMappingProtocol.test_update(self) | ||||
|         # mapping argument | ||||
|         d = self._empty_mapping() | ||||
|         d.update({1:100}) | ||||
|         d.update({2:20}) | ||||
|         d.update({1:1, 2:2, 3:3}) | ||||
|         self.assertEqual(d, {1:1, 2:2, 3:3}) | ||||
|  | ||||
|         # no argument | ||||
|         d.update() | ||||
|         self.assertEqual(d, {1:1, 2:2, 3:3}) | ||||
|  | ||||
|         # keyword arguments | ||||
|         d = self._empty_mapping() | ||||
|         d.update(x=100) | ||||
|         d.update(y=20) | ||||
|         d.update(x=1, y=2, z=3) | ||||
|         self.assertEqual(d, {"x":1, "y":2, "z":3}) | ||||
|  | ||||
|         # item sequence | ||||
|         d = self._empty_mapping() | ||||
|         d.update([("x", 100), ("y", 20)]) | ||||
|         self.assertEqual(d, {"x":100, "y":20}) | ||||
|  | ||||
|         # Both item sequence and keyword arguments | ||||
|         d = self._empty_mapping() | ||||
|         d.update([("x", 100), ("y", 20)], x=1, y=2) | ||||
|         self.assertEqual(d, {"x":1, "y":2}) | ||||
|  | ||||
|         # iterator | ||||
|         d = self._full_mapping({1:3, 2:4}) | ||||
|         d.update(self._full_mapping({1:2, 3:4, 5:6}).items()) | ||||
|         self.assertEqual(d, {1:2, 2:4, 3:4, 5:6}) | ||||
|  | ||||
|         class SimpleUserDict: | ||||
|             def __init__(self): | ||||
|                 self.d = {1:1, 2:2, 3:3} | ||||
|             def keys(self): | ||||
|                 return self.d.keys() | ||||
|             def __getitem__(self, i): | ||||
|                 return self.d[i] | ||||
|         d.clear() | ||||
|         d.update(SimpleUserDict()) | ||||
|         self.assertEqual(d, {1:1, 2:2, 3:3}) | ||||
|  | ||||
|     def test_fromkeys(self): | ||||
|         self.assertEqual(self.type2test.fromkeys('abc'), {'a':None, 'b':None, 'c':None}) | ||||
|         d = self._empty_mapping() | ||||
|         self.assertTrue(not(d.fromkeys('abc') is d)) | ||||
|         self.assertEqual(d.fromkeys('abc'), {'a':None, 'b':None, 'c':None}) | ||||
|         self.assertEqual(d.fromkeys((4,5),0), {4:0, 5:0}) | ||||
|         self.assertEqual(d.fromkeys([]), {}) | ||||
|         def g(): | ||||
|             yield 1 | ||||
|         self.assertEqual(d.fromkeys(g()), {1:None}) | ||||
|         self.assertRaises(TypeError, {}.fromkeys, 3) | ||||
|         class dictlike(self.type2test): pass | ||||
|         self.assertEqual(dictlike.fromkeys('a'), {'a':None}) | ||||
|         self.assertEqual(dictlike().fromkeys('a'), {'a':None}) | ||||
|         self.assertTrue(dictlike.fromkeys('a').__class__ is dictlike) | ||||
|         self.assertTrue(dictlike().fromkeys('a').__class__ is dictlike) | ||||
|         self.assertTrue(type(dictlike.fromkeys('a')) is dictlike) | ||||
|         class mydict(self.type2test): | ||||
|             def __new__(cls): | ||||
|                 return collections.UserDict() | ||||
|         ud = mydict.fromkeys('ab') | ||||
|         self.assertEqual(ud, {'a':None, 'b':None}) | ||||
|         self.assertIsInstance(ud, collections.UserDict) | ||||
|         self.assertRaises(TypeError, dict.fromkeys) | ||||
|  | ||||
|         class Exc(Exception): pass | ||||
|  | ||||
|         class baddict1(self.type2test): | ||||
|             def __init__(self, *args, **kwargs): | ||||
|                 raise Exc() | ||||
|  | ||||
|         self.assertRaises(Exc, baddict1.fromkeys, [1]) | ||||
|  | ||||
|         class BadSeq(object): | ||||
|             def __iter__(self): | ||||
|                 return self | ||||
|             def __next__(self): | ||||
|                 raise Exc() | ||||
|  | ||||
|         self.assertRaises(Exc, self.type2test.fromkeys, BadSeq()) | ||||
|  | ||||
|         class baddict2(self.type2test): | ||||
|             def __setitem__(self, key, value): | ||||
|                 raise Exc() | ||||
|  | ||||
|         self.assertRaises(Exc, baddict2.fromkeys, [1]) | ||||
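|  | ||||
|         # Editorial note: fromkeys() is a classmethod, so it goes through the | ||||
|         # subclass __new__; when __new__ returns a non-dict (mydict above) | ||||
|         # the keys are inserted via the returned object's __setitem__, which | ||||
|         # is also why baddict2.fromkeys([1]) raises. | ||||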
|  | ||||
|     def test_copy(self): | ||||
|         d = self._full_mapping({1:1, 2:2, 3:3}) | ||||
|         self.assertEqual(d.copy(), {1:1, 2:2, 3:3}) | ||||
|         d = self._empty_mapping() | ||||
|         self.assertEqual(d.copy(), d) | ||||
|         self.assertIsInstance(d.copy(), d.__class__) | ||||
|         self.assertRaises(TypeError, d.copy, None) | ||||
|  | ||||
|     def test_get(self): | ||||
|         BasicTestMappingProtocol.test_get(self) | ||||
|         d = self._empty_mapping() | ||||
|         self.assertTrue(d.get('c') is None) | ||||
|         self.assertEqual(d.get('c', 3), 3) | ||||
|         d = self._full_mapping({'a' : 1, 'b' : 2}) | ||||
|         self.assertTrue(d.get('c') is None) | ||||
|         self.assertEqual(d.get('c', 3), 3) | ||||
|         self.assertEqual(d.get('a'), 1) | ||||
|         self.assertEqual(d.get('a', 3), 1) | ||||
|  | ||||
|     def test_setdefault(self): | ||||
|         BasicTestMappingProtocol.test_setdefault(self) | ||||
|         d = self._empty_mapping() | ||||
|         self.assertTrue(d.setdefault('key0') is None) | ||||
|         d.setdefault('key0', []) | ||||
|         self.assertTrue(d.setdefault('key0') is None) | ||||
|         d.setdefault('key', []).append(3) | ||||
|         self.assertEqual(d['key'][0], 3) | ||||
|         d.setdefault('key', []).append(4) | ||||
|         self.assertEqual(len(d['key']), 2) | ||||
|  | ||||
|     def test_popitem(self): | ||||
|         BasicTestMappingProtocol.test_popitem(self) | ||||
|         for copymode in -1, +1: | ||||
|             # -1: b has same structure as a | ||||
|             # +1: b is a.copy() | ||||
|             for log2size in range(12): | ||||
|                 size = 2**log2size | ||||
|                 a = self._empty_mapping() | ||||
|                 b = self._empty_mapping() | ||||
|                 for i in range(size): | ||||
|                     a[repr(i)] = i | ||||
|                     if copymode < 0: | ||||
|                         b[repr(i)] = i | ||||
|                 if copymode > 0: | ||||
|                     b = a.copy() | ||||
|                 for i in range(size): | ||||
|                     ka, va = ta = a.popitem() | ||||
|                     self.assertEqual(va, int(ka)) | ||||
|                     kb, vb = tb = b.popitem() | ||||
|                     self.assertEqual(vb, int(kb)) | ||||
|                     self.assertTrue(not(copymode < 0 and ta != tb)) | ||||
|                 self.assertTrue(not a) | ||||
|                 self.assertTrue(not b) | ||||
|  | ||||
|     def test_pop(self): | ||||
|         BasicTestMappingProtocol.test_pop(self) | ||||
|  | ||||
|         # Tests for pop with specified key | ||||
|         d = self._empty_mapping() | ||||
|         k, v = 'abc', 'def' | ||||
|  | ||||
|         self.assertEqual(d.pop(k, v), v) | ||||
|         d[k] = v | ||||
|         self.assertEqual(d.pop(k, 1), v) | ||||
|  | ||||
|  | ||||
| class TestHashMappingProtocol(TestMappingProtocol): | ||||
|  | ||||
|     def test_getitem(self): | ||||
|         TestMappingProtocol.test_getitem(self) | ||||
|         class Exc(Exception): pass | ||||
|  | ||||
|         class BadEq(object): | ||||
|             def __eq__(self, other): | ||||
|                 raise Exc() | ||||
|             def __hash__(self): | ||||
|                 return 24 | ||||
|  | ||||
|         d = self._empty_mapping() | ||||
|         d[BadEq()] = 42 | ||||
|         self.assertRaises(KeyError, d.__getitem__, 23) | ||||
|  | ||||
|         class BadHash(object): | ||||
|             fail = False | ||||
|             def __hash__(self): | ||||
|                 if self.fail: | ||||
|                     raise Exc() | ||||
|                 else: | ||||
|                     return 42 | ||||
|  | ||||
|         d = self._empty_mapping() | ||||
|         x = BadHash() | ||||
|         d[x] = 42 | ||||
|         x.fail = True | ||||
|         self.assertRaises(Exc, d.__getitem__, x) | ||||
|  | ||||
|     def test_fromkeys(self): | ||||
|         TestMappingProtocol.test_fromkeys(self) | ||||
|         class mydict(self.type2test): | ||||
|             def __new__(cls): | ||||
|                 return collections.UserDict() | ||||
|         ud = mydict.fromkeys('ab') | ||||
|         self.assertEqual(ud, {'a':None, 'b':None}) | ||||
|         self.assertIsInstance(ud, collections.UserDict) | ||||
|  | ||||
|     def test_pop(self): | ||||
|         TestMappingProtocol.test_pop(self) | ||||
|  | ||||
|         class Exc(Exception): pass | ||||
|  | ||||
|         class BadHash(object): | ||||
|             fail = False | ||||
|             def __hash__(self): | ||||
|                 if self.fail: | ||||
|                     raise Exc() | ||||
|                 else: | ||||
|                     return 42 | ||||
|  | ||||
|         d = self._empty_mapping() | ||||
|         x = BadHash() | ||||
|         d[x] = 42 | ||||
|         x.fail = True | ||||
|         self.assertRaises(Exc, d.pop, x) | ||||
|  | ||||
|     def test_mutatingiteration(self): | ||||
|         d = self._empty_mapping() | ||||
|         d[1] = 1 | ||||
|         try: | ||||
|             count = 0 | ||||
|             for i in d: | ||||
|                 d[i+1] = 1 | ||||
|                 if count >= 1: | ||||
|                     self.fail("changing dict size during iteration doesn't raise Error") | ||||
|                 count += 1 | ||||
|         except RuntimeError: | ||||
|             pass | ||||
|  | ||||
|     def test_repr(self): | ||||
|         d = self._empty_mapping() | ||||
|         self.assertEqual(repr(d), '{}') | ||||
|         d[1] = 2 | ||||
|         self.assertEqual(repr(d), '{1: 2}') | ||||
|         d = self._empty_mapping() | ||||
|         d[1] = d | ||||
|         self.assertEqual(repr(d), '{1: {...}}') | ||||
|  | ||||
|         class Exc(Exception): pass | ||||
|  | ||||
|         class BadRepr(object): | ||||
|             def __repr__(self): | ||||
|                 raise Exc() | ||||
|  | ||||
|         d = self._full_mapping({1: BadRepr()}) | ||||
|         self.assertRaises(Exc, repr, d) | ||||
|  | ||||
|     def test_repr_deep(self): | ||||
|         d = self._empty_mapping() | ||||
|         for i in range(get_c_recursion_limit() + 1): | ||||
|             d0 = d | ||||
|             d = self._empty_mapping() | ||||
|             d[1] = d0 | ||||
|         self.assertRaises(RecursionError, repr, d) | ||||
|  | ||||
|     def test_eq(self): | ||||
|         self.assertEqual(self._empty_mapping(), self._empty_mapping()) | ||||
|         self.assertEqual(self._full_mapping({1: 2}), | ||||
|                          self._full_mapping({1: 2})) | ||||
|  | ||||
|         class Exc(Exception): pass | ||||
|  | ||||
|         class BadCmp(object): | ||||
|             def __eq__(self, other): | ||||
|                 raise Exc() | ||||
|             def __hash__(self): | ||||
|                 return 1 | ||||
|  | ||||
|         d1 = self._full_mapping({BadCmp(): 1}) | ||||
|         d2 = self._full_mapping({1: 1}) | ||||
|         self.assertRaises(Exc, lambda: BadCmp()==1) | ||||
|         self.assertRaises(Exc, lambda: d1==d2) | ||||
|  | ||||
|     def test_setdefault(self): | ||||
|         TestMappingProtocol.test_setdefault(self) | ||||
|  | ||||
|         class Exc(Exception): pass | ||||
|  | ||||
|         class BadHash(object): | ||||
|             fail = False | ||||
|             def __hash__(self): | ||||
|                 if self.fail: | ||||
|                     raise Exc() | ||||
|                 else: | ||||
|                     return 42 | ||||
|  | ||||
|         d = self._empty_mapping() | ||||
|         x = BadHash() | ||||
|         d[x] = 42 | ||||
|         x.fail = True | ||||
|         self.assertRaises(Exc, d.setdefault, x, []) | ||||
test/dynamo/cpython/3_13/seq_tests.diff (new file, 68 lines)
							| @ -0,0 +1,68 @@ | ||||
| diff --git a/test/dynamo/cpython/3_13/seq_tests.py b/test/dynamo/cpython/3_13/seq_tests.py | ||||
| index 719c9434a16..4325892276d 100644 | ||||
| --- a/test/dynamo/cpython/3_13/seq_tests.py | ||||
| +++ b/test/dynamo/cpython/3_13/seq_tests.py | ||||
| @@ -1,3 +1,54 @@ | ||||
| +# ======= BEGIN Dynamo patch ======= | ||||
| +# Owner(s): ["module: dynamo"] | ||||
| + | ||||
| +# ruff: noqa | ||||
| +# flake8: noqa | ||||
| + | ||||
| +import sys | ||||
| +import torch | ||||
| +import torch._dynamo.test_case | ||||
| +import unittest | ||||
| +from torch._dynamo.test_case import CPythonTestCase | ||||
| +from torch.testing._internal.common_utils import run_tests | ||||
| + | ||||
| +__TestCase = CPythonTestCase | ||||
| + | ||||
| + | ||||
| +# redirect import statements | ||||
| +import sys | ||||
| +import importlib.abc | ||||
| + | ||||
| +redirect_imports = ( | ||||
| +    "test.mapping_tests", | ||||
| +    "test.typinganndata", | ||||
| +    "test.test_grammar", | ||||
| +    "test.test_math", | ||||
| +    "test.test_iter", | ||||
| +    "test.typinganndata.ann_module", | ||||
| +) | ||||
| + | ||||
| +class RedirectImportFinder(importlib.abc.MetaPathFinder): | ||||
| +    def find_spec(self, fullname, path, target=None): | ||||
| +        # Check if the import is the problematic one | ||||
| +        if fullname in redirect_imports: | ||||
| +            try: | ||||
| +                # Attempt to import the standalone module | ||||
| +                name = fullname.removeprefix("test.") | ||||
| +                r = importlib.import_module(name) | ||||
| +                # Redirect the module in sys.modules | ||||
| +                sys.modules[fullname] = r | ||||
| +                # Return a module spec from the found module | ||||
| +                return importlib.util.find_spec(name) | ||||
| +            except ImportError: | ||||
| +                return None | ||||
| +        return None | ||||
| + | ||||
| +# Add the custom finder to sys.meta_path | ||||
| +sys.meta_path.insert(0, RedirectImportFinder()) | ||||
| + | ||||
| + | ||||
| +# ======= END DYNAMO PATCH ======= | ||||
| + | ||||
|  """ | ||||
|  Tests common to tuple, list and UserList.UserList | ||||
|  """ | ||||
| @@ -95,7 +146,7 @@ class LyingList(list): | ||||
|      def __iter__(self): | ||||
|          yield 1 | ||||
|   | ||||
| -class CommonTest(unittest.TestCase): | ||||
| +class CommonTest(__TestCase): | ||||
|      # The type to be tested | ||||
|      type2test = None | ||||
|   | ||||
test/dynamo/cpython/3_13/seq_tests.py (new file, 483 lines)
							| @ -0,0 +1,483 @@ | ||||
| # ======= BEGIN Dynamo patch ======= | ||||
| # Owner(s): ["module: dynamo"] | ||||
|  | ||||
| # ruff: noqa | ||||
| # flake8: noqa | ||||
|  | ||||
| import sys | ||||
| import torch | ||||
| import torch._dynamo.test_case | ||||
| import unittest | ||||
| from torch._dynamo.test_case import CPythonTestCase | ||||
| from torch.testing._internal.common_utils import run_tests | ||||
|  | ||||
| __TestCase = CPythonTestCase | ||||
|  | ||||
|  | ||||
| # redirect import statements | ||||
| import sys | ||||
| import importlib.abc | ||||
|  | ||||
| redirect_imports = ( | ||||
|     "test.mapping_tests", | ||||
|     "test.typinganndata", | ||||
|     "test.test_grammar", | ||||
|     "test.test_math", | ||||
|     "test.test_iter", | ||||
|     "test.typinganndata.ann_module", | ||||
| ) | ||||
|  | ||||
| class RedirectImportFinder(importlib.abc.MetaPathFinder): | ||||
|     def find_spec(self, fullname, path, target=None): | ||||
|         # Check if the import is the problematic one | ||||
|         if fullname in redirect_imports: | ||||
|             try: | ||||
|                 # Attempt to import the standalone module | ||||
|                 name = fullname.removeprefix("test.") | ||||
|                 r = importlib.import_module(name) | ||||
|                 # Redirect the module in sys.modules | ||||
|                 sys.modules[fullname] = r | ||||
|                 # Return a module spec from the found module | ||||
|                 return importlib.util.find_spec(name) | ||||
|             except ImportError: | ||||
|                 return None | ||||
|         return None | ||||
|  | ||||
| # Add the custom finder to sys.meta_path | ||||
| sys.meta_path.insert(0, RedirectImportFinder()) | ||||
|  | ||||
|  | ||||
| # ======= END DYNAMO PATCH ======= | ||||
|  | ||||
| """ | ||||
| Tests common to tuple, list and UserList.UserList | ||||
| """ | ||||
|  | ||||
| import unittest | ||||
| import sys | ||||
| import pickle | ||||
| from test import support | ||||
| from test.support import ALWAYS_EQ, NEVER_EQ | ||||
|  | ||||
| # Various iterables | ||||
| # This is used for checking the constructor (here and in test_deque.py) | ||||
| def iterfunc(seqn): | ||||
|     'Regular generator' | ||||
|     for i in seqn: | ||||
|         yield i | ||||
|  | ||||
| class Sequence: | ||||
|     'Sequence using __getitem__' | ||||
|     def __init__(self, seqn): | ||||
|         self.seqn = seqn | ||||
|     def __getitem__(self, i): | ||||
|         return self.seqn[i] | ||||
|  | ||||
| class IterFunc: | ||||
|     'Sequence using iterator protocol' | ||||
|     def __init__(self, seqn): | ||||
|         self.seqn = seqn | ||||
|         self.i = 0 | ||||
|     def __iter__(self): | ||||
|         return self | ||||
|     def __next__(self): | ||||
|         if self.i >= len(self.seqn): raise StopIteration | ||||
|         v = self.seqn[self.i] | ||||
|         self.i += 1 | ||||
|         return v | ||||
|  | ||||
| class IterGen: | ||||
|     'Sequence using iterator protocol defined with a generator' | ||||
|     def __init__(self, seqn): | ||||
|         self.seqn = seqn | ||||
|         self.i = 0 | ||||
|     def __iter__(self): | ||||
|         for val in self.seqn: | ||||
|             yield val | ||||
|  | ||||
| class IterNextOnly: | ||||
|     'Missing __getitem__ and __iter__' | ||||
|     def __init__(self, seqn): | ||||
|         self.seqn = seqn | ||||
|         self.i = 0 | ||||
|     def __next__(self): | ||||
|         if self.i >= len(self.seqn): raise StopIteration | ||||
|         v = self.seqn[self.i] | ||||
|         self.i += 1 | ||||
|         return v | ||||
|  | ||||
| class IterNoNext: | ||||
|     'Iterator missing __next__()' | ||||
|     def __init__(self, seqn): | ||||
|         self.seqn = seqn | ||||
|         self.i = 0 | ||||
|     def __iter__(self): | ||||
|         return self | ||||
|  | ||||
| class IterGenExc: | ||||
|     'Test propagation of exceptions' | ||||
|     def __init__(self, seqn): | ||||
|         self.seqn = seqn | ||||
|         self.i = 0 | ||||
|     def __iter__(self): | ||||
|         return self | ||||
|     def __next__(self): | ||||
|         3 // 0 | ||||
|  | ||||
| class IterFuncStop: | ||||
|     'Test immediate stop' | ||||
|     def __init__(self, seqn): | ||||
|         pass | ||||
|     def __iter__(self): | ||||
|         return self | ||||
|     def __next__(self): | ||||
|         raise StopIteration | ||||
|  | ||||
| from itertools import chain | ||||
| def itermulti(seqn): | ||||
|     'Test multiple tiers of iterators' | ||||
|     return chain(map(lambda x:x, iterfunc(IterGen(Sequence(seqn))))) | ||||
|  | ||||
| class LyingTuple(tuple): | ||||
|     def __iter__(self): | ||||
|         yield 1 | ||||
|  | ||||
| class LyingList(list): | ||||
|     def __iter__(self): | ||||
|         yield 1 | ||||
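|  | ||||
| # Editorial note: the helpers above feed the constructor every iteration | ||||
| # style: __getitem__-only sequences, classic iterators, generators, chained | ||||
| # ("multi-tier") iterators, and deliberately broken protocols that must | ||||
| # surface TypeError, ZeroDivisionError, or an immediate StopIteration. | ||||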
|  | ||||
| class CommonTest(__TestCase): | ||||
|     # The type to be tested | ||||
|     type2test = None | ||||
|  | ||||
|     def test_constructors(self): | ||||
|         l0 = [] | ||||
|         l1 = [0] | ||||
|         l2 = [0, 1] | ||||
|  | ||||
|         u = self.type2test() | ||||
|         u0 = self.type2test(l0) | ||||
|         u1 = self.type2test(l1) | ||||
|         u2 = self.type2test(l2) | ||||
|  | ||||
|         uu = self.type2test(u) | ||||
|         uu0 = self.type2test(u0) | ||||
|         uu1 = self.type2test(u1) | ||||
|         uu2 = self.type2test(u2) | ||||
|  | ||||
|         v = self.type2test(tuple(u)) | ||||
|         class OtherSeq: | ||||
|             def __init__(self, initseq): | ||||
|                 self.__data = initseq | ||||
|             def __len__(self): | ||||
|                 return len(self.__data) | ||||
|             def __getitem__(self, i): | ||||
|                 return self.__data[i] | ||||
|         s = OtherSeq(u0) | ||||
|         v0 = self.type2test(s) | ||||
|         self.assertEqual(len(v0), len(s)) | ||||
|  | ||||
|         s = "this is also a sequence" | ||||
|         vv = self.type2test(s) | ||||
|         self.assertEqual(len(vv), len(s)) | ||||
|  | ||||
|         # Create from various iterables | ||||
|         for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)): | ||||
|             for g in (Sequence, IterFunc, IterGen, | ||||
|                       itermulti, iterfunc): | ||||
|                 self.assertEqual(self.type2test(g(s)), self.type2test(s)) | ||||
|             self.assertEqual(self.type2test(IterFuncStop(s)), self.type2test()) | ||||
|             self.assertEqual(self.type2test(c for c in "123"), self.type2test("123")) | ||||
|             self.assertRaises(TypeError, self.type2test, IterNextOnly(s)) | ||||
|             self.assertRaises(TypeError, self.type2test, IterNoNext(s)) | ||||
|             self.assertRaises(ZeroDivisionError, self.type2test, IterGenExc(s)) | ||||
|  | ||||
|         # Issue #23757 | ||||
|         self.assertEqual(self.type2test(LyingTuple((2,))), self.type2test((1,))) | ||||
|         self.assertEqual(self.type2test(LyingList([2])), self.type2test([1])) | ||||
|  | ||||
|         with self.assertRaises(TypeError): | ||||
|             self.type2test(unsupported_arg=[]) | ||||
|  | ||||
|     def test_truth(self): | ||||
|         self.assertFalse(self.type2test()) | ||||
|         self.assertTrue(self.type2test([42])) | ||||
|  | ||||
|     def test_getitem(self): | ||||
|         u = self.type2test([0, 1, 2, 3, 4]) | ||||
|         for i in range(len(u)): | ||||
|             self.assertEqual(u[i], i) | ||||
|             self.assertEqual(u[int(i)], i) | ||||
|         for i in range(-len(u), -1): | ||||
|             self.assertEqual(u[i], len(u)+i) | ||||
|             self.assertEqual(u[int(i)], len(u)+i) | ||||
|         self.assertRaises(IndexError, u.__getitem__, -len(u)-1) | ||||
|         self.assertRaises(IndexError, u.__getitem__, len(u)) | ||||
|         self.assertRaises(ValueError, u.__getitem__, slice(0,10,0)) | ||||
|  | ||||
|         u = self.type2test() | ||||
|         self.assertRaises(IndexError, u.__getitem__, 0) | ||||
|         self.assertRaises(IndexError, u.__getitem__, -1) | ||||
|  | ||||
|         self.assertRaises(TypeError, u.__getitem__) | ||||
|  | ||||
|         a = self.type2test([10, 11]) | ||||
|         self.assertEqual(a[0], 10) | ||||
|         self.assertEqual(a[1], 11) | ||||
|         self.assertEqual(a[-2], 10) | ||||
|         self.assertEqual(a[-1], 11) | ||||
|         self.assertRaises(IndexError, a.__getitem__, -3) | ||||
|         self.assertRaises(IndexError, a.__getitem__, 3) | ||||
|  | ||||
|     def test_getslice(self): | ||||
|         l = [0, 1, 2, 3, 4] | ||||
|         u = self.type2test(l) | ||||
|  | ||||
|         self.assertEqual(u[0:0], self.type2test()) | ||||
|         self.assertEqual(u[1:2], self.type2test([1])) | ||||
|         self.assertEqual(u[-2:-1], self.type2test([3])) | ||||
|         self.assertEqual(u[-1000:1000], u) | ||||
|         self.assertEqual(u[1000:-1000], self.type2test([])) | ||||
|         self.assertEqual(u[:], u) | ||||
|         self.assertEqual(u[1:None], self.type2test([1, 2, 3, 4])) | ||||
|         self.assertEqual(u[None:3], self.type2test([0, 1, 2])) | ||||
|  | ||||
|         # Extended slices | ||||
|         self.assertEqual(u[::], u) | ||||
|         self.assertEqual(u[::2], self.type2test([0, 2, 4])) | ||||
|         self.assertEqual(u[1::2], self.type2test([1, 3])) | ||||
|         self.assertEqual(u[::-1], self.type2test([4, 3, 2, 1, 0])) | ||||
|         self.assertEqual(u[::-2], self.type2test([4, 2, 0])) | ||||
|         self.assertEqual(u[3::-2], self.type2test([3, 1])) | ||||
|         self.assertEqual(u[3:3:-2], self.type2test([])) | ||||
|         self.assertEqual(u[3:2:-2], self.type2test([3])) | ||||
|         self.assertEqual(u[3:1:-2], self.type2test([3])) | ||||
|         self.assertEqual(u[3:0:-2], self.type2test([3, 1])) | ||||
|         self.assertEqual(u[::-100], self.type2test([4])) | ||||
|         self.assertEqual(u[100:-100:], self.type2test([])) | ||||
|         self.assertEqual(u[-100:100:], u) | ||||
|         self.assertEqual(u[100:-100:-1], u[::-1]) | ||||
|         self.assertEqual(u[-100:100:-1], self.type2test([])) | ||||
|         self.assertEqual(u[-100:100:2], self.type2test([0, 2, 4])) | ||||
|  | ||||
|         # Test extreme cases with long ints | ||||
|         a = self.type2test([0,1,2,3,4]) | ||||
|         self.assertEqual(a[ -pow(2,128): 3 ], self.type2test([0,1,2])) | ||||
|         self.assertEqual(a[ 3: pow(2,145) ], self.type2test([3,4])) | ||||
|         self.assertEqual(a[3::sys.maxsize], self.type2test([3])) | ||||
|  | ||||
|     def test_contains(self): | ||||
|         u = self.type2test([0, 1, 2]) | ||||
|         for i in u: | ||||
|             self.assertIn(i, u) | ||||
|         for i in min(u)-1, max(u)+1: | ||||
|             self.assertNotIn(i, u) | ||||
|  | ||||
|         self.assertRaises(TypeError, u.__contains__) | ||||
|  | ||||
|     def test_contains_fake(self): | ||||
|         # Sequences must use rich comparison against each item | ||||
|         # (unless "is" is true, or an earlier item answered) | ||||
|         # So ALWAYS_EQ must be found in all non-empty sequences. | ||||
|         self.assertNotIn(ALWAYS_EQ, self.type2test([])) | ||||
|         self.assertIn(ALWAYS_EQ, self.type2test([1])) | ||||
|         self.assertIn(1, self.type2test([ALWAYS_EQ])) | ||||
|         self.assertNotIn(NEVER_EQ, self.type2test([])) | ||||
|         self.assertNotIn(ALWAYS_EQ, self.type2test([NEVER_EQ])) | ||||
|         self.assertIn(NEVER_EQ, self.type2test([ALWAYS_EQ])) | ||||
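|  | ||||
|         # ALWAYS_EQ and NEVER_EQ come from test.support; roughly (sketch): | ||||
|         #     class _AlwaysEq: | ||||
|         #         def __eq__(self, other): return True | ||||
|         #         def __ne__(self, other): return False | ||||
|         # so membership tests find ALWAYS_EQ in any non-empty sequence. | ||||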
|  | ||||
|     def test_contains_order(self): | ||||
|         # Sequences must test in-order.  If a rich comparison has side | ||||
|         # effects, these will be visible to tests against later members. | ||||
|         # In this test, the "side effect" is a short-circuiting raise. | ||||
|         class DoNotTestEq(Exception): | ||||
|             pass | ||||
|         class StopCompares: | ||||
|             def __eq__(self, other): | ||||
|                 raise DoNotTestEq | ||||
|  | ||||
|         checkfirst = self.type2test([1, StopCompares()]) | ||||
|         self.assertIn(1, checkfirst) | ||||
|         checklast = self.type2test([StopCompares(), 1]) | ||||
|         self.assertRaises(DoNotTestEq, checklast.__contains__, 1) | ||||
|  | ||||
|     def test_len(self): | ||||
|         self.assertEqual(len(self.type2test()), 0) | ||||
|         self.assertEqual(len(self.type2test([])), 0) | ||||
|         self.assertEqual(len(self.type2test([0])), 1) | ||||
|         self.assertEqual(len(self.type2test([0, 1, 2])), 3) | ||||
|  | ||||
|     def test_minmax(self): | ||||
|         u = self.type2test([0, 1, 2]) | ||||
|         self.assertEqual(min(u), 0) | ||||
|         self.assertEqual(max(u), 2) | ||||
|  | ||||
|     def test_addmul(self): | ||||
|         u1 = self.type2test([0]) | ||||
|         u2 = self.type2test([0, 1]) | ||||
|         self.assertEqual(u1, u1 + self.type2test()) | ||||
|         self.assertEqual(u1, self.type2test() + u1) | ||||
|         self.assertEqual(u1 + self.type2test([1]), u2) | ||||
|         self.assertEqual(self.type2test([-1]) + u1, self.type2test([-1, 0])) | ||||
|         self.assertEqual(self.type2test(), u2*0) | ||||
|         self.assertEqual(self.type2test(), 0*u2) | ||||
|         self.assertEqual(u2, u2*1) | ||||
|         self.assertEqual(u2, 1*u2) | ||||
|         self.assertEqual(u2+u2, u2*2) | ||||
|         self.assertEqual(u2+u2, 2*u2) | ||||
|         self.assertEqual(u2+u2+u2, u2*3) | ||||
|         self.assertEqual(u2+u2+u2, 3*u2) | ||||
|  | ||||
|         class subclass(self.type2test): | ||||
|             pass | ||||
|         u3 = subclass([0, 1]) | ||||
|         self.assertEqual(u3, u3*1) | ||||
|         self.assertIsNot(u3, u3*1) | ||||
|  | ||||
|     def test_iadd(self): | ||||
|         u = self.type2test([0, 1]) | ||||
|         u += self.type2test() | ||||
|         self.assertEqual(u, self.type2test([0, 1])) | ||||
|         u += self.type2test([2, 3]) | ||||
|         self.assertEqual(u, self.type2test([0, 1, 2, 3])) | ||||
|         u += self.type2test([4, 5]) | ||||
|         self.assertEqual(u, self.type2test([0, 1, 2, 3, 4, 5])) | ||||
|  | ||||
|         u = self.type2test("spam") | ||||
|         u += self.type2test("eggs") | ||||
|         self.assertEqual(u, self.type2test("spameggs")) | ||||
|  | ||||
|     def test_imul(self): | ||||
|         u = self.type2test([0, 1]) | ||||
|         u *= 3 | ||||
|         self.assertEqual(u, self.type2test([0, 1, 0, 1, 0, 1])) | ||||
|         u *= 0 | ||||
|         self.assertEqual(u, self.type2test([])) | ||||
|  | ||||
|     def test_getitemoverwriteiter(self): | ||||
|         # Verify that __getitem__ overrides are not recognized by __iter__ | ||||
|         class T(self.type2test): | ||||
|             def __getitem__(self, key): | ||||
|                 return str(key) + '!!!' | ||||
|         self.assertEqual(next(iter(T((1,2)))), 1) | ||||
|  | ||||
|     def test_repeat(self): | ||||
|         for m in range(4): | ||||
|             s = tuple(range(m)) | ||||
|             for n in range(-3, 5): | ||||
|                 self.assertEqual(self.type2test(s*n), self.type2test(s)*n) | ||||
|             self.assertEqual(self.type2test(s)*(-4), self.type2test([])) | ||||
|             self.assertEqual(id(s), id(s*1)) | ||||
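|             # Note (not upstream commentary): s is a plain tuple here, | ||||
|             # and CPython returns the operand itself for tuple * 1 rather | ||||
|             # than a copy, which is what the id() equality above relies on. | ||||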
|  | ||||
|     def test_bigrepeat(self): | ||||
|         if sys.maxsize <= 2147483647: | ||||
|             x = self.type2test([0]) | ||||
|             x *= 2**16 | ||||
|             self.assertRaises(MemoryError, x.__mul__, 2**16) | ||||
|             if hasattr(x, '__imul__'): | ||||
|                 self.assertRaises(MemoryError, x.__imul__, 2**16) | ||||
|  | ||||
|     def test_subscript(self): | ||||
|         a = self.type2test([10, 11]) | ||||
|         self.assertEqual(a.__getitem__(0), 10) | ||||
|         self.assertEqual(a.__getitem__(1), 11) | ||||
|         self.assertEqual(a.__getitem__(-2), 10) | ||||
|         self.assertEqual(a.__getitem__(-1), 11) | ||||
|         self.assertRaises(IndexError, a.__getitem__, -3) | ||||
|         self.assertRaises(IndexError, a.__getitem__, 3) | ||||
|         self.assertEqual(a.__getitem__(slice(0,1)), self.type2test([10])) | ||||
|         self.assertEqual(a.__getitem__(slice(1,2)), self.type2test([11])) | ||||
|         self.assertEqual(a.__getitem__(slice(0,2)), self.type2test([10, 11])) | ||||
|         self.assertEqual(a.__getitem__(slice(0,3)), self.type2test([10, 11])) | ||||
|         self.assertEqual(a.__getitem__(slice(3,5)), self.type2test([])) | ||||
|         self.assertRaises(ValueError, a.__getitem__, slice(0, 10, 0)) | ||||
|         self.assertRaises(TypeError, a.__getitem__, 'x') | ||||
|  | ||||
|     def test_count(self): | ||||
|         a = self.type2test([0, 1, 2])*3 | ||||
|         self.assertEqual(a.count(0), 3) | ||||
|         self.assertEqual(a.count(1), 3) | ||||
|         self.assertEqual(a.count(3), 0) | ||||
|  | ||||
|         self.assertEqual(a.count(ALWAYS_EQ), 9) | ||||
|         self.assertEqual(self.type2test([ALWAYS_EQ, ALWAYS_EQ]).count(1), 2) | ||||
|         self.assertEqual(self.type2test([ALWAYS_EQ, ALWAYS_EQ]).count(NEVER_EQ), 2) | ||||
|         self.assertEqual(self.type2test([NEVER_EQ, NEVER_EQ]).count(ALWAYS_EQ), 0) | ||||
|  | ||||
|         self.assertRaises(TypeError, a.count) | ||||
|  | ||||
|         class BadExc(Exception): | ||||
|             pass | ||||
|  | ||||
|         class BadCmp: | ||||
|             def __eq__(self, other): | ||||
|                 if other == 2: | ||||
|                     raise BadExc() | ||||
|                 return False | ||||
|  | ||||
|         self.assertRaises(BadExc, a.count, BadCmp()) | ||||
|  | ||||
|     def test_index(self): | ||||
|         u = self.type2test([0, 1]) | ||||
|         self.assertEqual(u.index(0), 0) | ||||
|         self.assertEqual(u.index(1), 1) | ||||
|         self.assertRaises(ValueError, u.index, 2) | ||||
|  | ||||
|         u = self.type2test([-2, -1, 0, 0, 1, 2]) | ||||
|         self.assertEqual(u.count(0), 2) | ||||
|         self.assertEqual(u.index(0), 2) | ||||
|         self.assertEqual(u.index(0, 2), 2) | ||||
|         self.assertEqual(u.index(-2, -10), 0) | ||||
|         self.assertEqual(u.index(0, 3), 3) | ||||
|         self.assertEqual(u.index(0, 3, 4), 3) | ||||
|         self.assertRaises(ValueError, u.index, 2, 0, -10) | ||||
|  | ||||
|         self.assertEqual(u.index(ALWAYS_EQ), 0) | ||||
|         self.assertEqual(self.type2test([ALWAYS_EQ, ALWAYS_EQ]).index(1), 0) | ||||
|         self.assertEqual(self.type2test([ALWAYS_EQ, ALWAYS_EQ]).index(NEVER_EQ), 0) | ||||
|         self.assertRaises(ValueError, self.type2test([NEVER_EQ, NEVER_EQ]).index, ALWAYS_EQ) | ||||
|  | ||||
|         self.assertRaises(TypeError, u.index) | ||||
|  | ||||
|         class BadExc(Exception): | ||||
|             pass | ||||
|  | ||||
|         class BadCmp: | ||||
|             def __eq__(self, other): | ||||
|                 if other == 2: | ||||
|                     raise BadExc() | ||||
|                 return False | ||||
|  | ||||
|         a = self.type2test([0, 1, 2, 3]) | ||||
|         self.assertRaises(BadExc, a.index, BadCmp()) | ||||
|  | ||||
|         a = self.type2test([-2, -1, 0, 0, 1, 2]) | ||||
|         self.assertEqual(a.index(0), 2) | ||||
|         self.assertEqual(a.index(0, 2), 2) | ||||
|         self.assertEqual(a.index(0, -4), 2) | ||||
|         self.assertEqual(a.index(-2, -10), 0) | ||||
|         self.assertEqual(a.index(0, 3), 3) | ||||
|         self.assertEqual(a.index(0, -3), 3) | ||||
|         self.assertEqual(a.index(0, 3, 4), 3) | ||||
|         self.assertEqual(a.index(0, -3, -2), 3) | ||||
|         self.assertEqual(a.index(0, -4*sys.maxsize, 4*sys.maxsize), 2) | ||||
|         self.assertRaises(ValueError, a.index, 0, 4*sys.maxsize,-4*sys.maxsize) | ||||
|         self.assertRaises(ValueError, a.index, 2, 0, -10) | ||||
|  | ||||
|     def test_pickle(self): | ||||
|         lst = self.type2test([4, 5, 6, 7]) | ||||
|         for proto in range(pickle.HIGHEST_PROTOCOL + 1): | ||||
|             lst2 = pickle.loads(pickle.dumps(lst, proto)) | ||||
|             self.assertEqual(lst2, lst) | ||||
|             self.assertNotEqual(id(lst2), id(lst)) | ||||
|  | ||||
|     @support.suppress_immortalization() | ||||
|     def test_free_after_iterating(self): | ||||
|         support.check_free_after_iterating(self, iter, self.type2test) | ||||
|         support.check_free_after_iterating(self, reversed, self.type2test) | ||||
							
								
								
									
test/dynamo/cpython/3_13/test_dict.diff (new file, 122 lines)
| @ -0,0 +1,122 @@ | ||||
| diff --git a/test/dynamo/cpython/3_13/test_dict.py b/test/dynamo/cpython/3_13/test_dict.py | ||||
| index 4729132c5a5..14f829c1715 100644 | ||||
| --- a/test/dynamo/cpython/3_13/test_dict.py | ||||
| +++ b/test/dynamo/cpython/3_13/test_dict.py | ||||
| @@ -1,3 +1,57 @@ | ||||
| +# ======= BEGIN Dynamo patch ======= | ||||
| +# Owner(s): ["module: dynamo"] | ||||
| + | ||||
| +# ruff: noqa | ||||
| +# flake8: noqa | ||||
| + | ||||
| +import sys | ||||
| +import torch | ||||
| +import torch._dynamo.test_case | ||||
| +import unittest | ||||
| +from torch._dynamo.test_case import CPythonTestCase | ||||
| +from torch.testing._internal.common_utils import ( | ||||
| +    run_tests, | ||||
| +    xfailIfTorchDynamo, | ||||
| +) | ||||
| + | ||||
| +__TestCase = CPythonTestCase | ||||
| + | ||||
| + | ||||
| +# redirect import statements | ||||
| +import sys | ||||
| +import importlib.abc | ||||
| + | ||||
| +redirect_imports = ( | ||||
| +    "test.mapping_tests", | ||||
| +    "test.typinganndata", | ||||
| +    "test.test_grammar", | ||||
| +    "test.test_math", | ||||
| +    "test.test_iter", | ||||
| +    "test.typinganndata.ann_module", | ||||
| +) | ||||
| + | ||||
| +class RedirectImportFinder(importlib.abc.MetaPathFinder): | ||||
| +    def find_spec(self, fullname, path, target=None): | ||||
| +        # Check if the import is the problematic one | ||||
| +        if fullname in redirect_imports: | ||||
| +            try: | ||||
| +                # Attempt to import the standalone module | ||||
| +                name = fullname.removeprefix("test.") | ||||
| +                r = importlib.import_module(name) | ||||
| +                # Redirect the module in sys.modules | ||||
| +                sys.modules[fullname] = r | ||||
| +                # Return a module spec from the found module | ||||
| +                return importlib.util.find_spec(name) | ||||
| +            except ImportError: | ||||
| +                return None | ||||
| +        return None | ||||
| + | ||||
| +# Add the custom finder to sys.meta_path | ||||
| +sys.meta_path.insert(0, RedirectImportFinder()) | ||||
| + | ||||
| + | ||||
| +# ======= END DYNAMO PATCH ======= | ||||
| + | ||||
|  import collections | ||||
|  import collections.abc | ||||
|  import gc | ||||
| @@ -11,7 +65,7 @@ from test import support | ||||
|  from test.support import import_helper, get_c_recursion_limit | ||||
|   | ||||
|   | ||||
| -class DictTest(unittest.TestCase): | ||||
| +class DictTest(__TestCase): | ||||
|   | ||||
|      def test_invalid_keyword_arguments(self): | ||||
|          class Custom(dict): | ||||
| @@ -265,6 +319,7 @@ class DictTest(unittest.TestCase): | ||||
|   | ||||
|          self.assertRaises(ValueError, {}.update, [(1, 2, 3)]) | ||||
|   | ||||
| +    @unittest.skip("test hangs") | ||||
|      def test_fromkeys(self): | ||||
|          self.assertEqual(dict.fromkeys('abc'), {'a':None, 'b':None, 'c':None}) | ||||
|          d = {} | ||||
| @@ -477,7 +532,7 @@ class DictTest(unittest.TestCase): | ||||
|          for copymode in -1, +1: | ||||
|              # -1: b has same structure as a | ||||
|              # +1: b is a.copy() | ||||
| -            for log2size in range(12): | ||||
| +            for log2size in range(4): | ||||
|                  size = 2**log2size | ||||
|                  a = {} | ||||
|                  b = {} | ||||
| @@ -1006,18 +1061,6 @@ class DictTest(unittest.TestCase): | ||||
|              pass | ||||
|          self._tracked(MyDict()) | ||||
|   | ||||
| -    @support.cpython_only | ||||
| -    def test_track_lazy_instance_dicts(self): | ||||
| -        class C: | ||||
| -            pass | ||||
| -        o = C() | ||||
| -        d = o.__dict__ | ||||
| -        self._not_tracked(d) | ||||
| -        o.untracked = 42 | ||||
| -        self._not_tracked(d) | ||||
| -        o.tracked = [] | ||||
| -        self._tracked(d) | ||||
| - | ||||
|      def make_shared_key_dict(self, n): | ||||
|          class C: | ||||
|              pass | ||||
| @@ -1622,7 +1665,7 @@ class DictTest(unittest.TestCase): | ||||
|                  self.assertGreaterEqual(eq_count, 1) | ||||
|   | ||||
|   | ||||
| -class CAPITest(unittest.TestCase): | ||||
| +class CAPITest(__TestCase): | ||||
|   | ||||
|      # Test _PyDict_GetItem_KnownHash() | ||||
|      @support.cpython_only | ||||
| @@ -1666,4 +1709,4 @@ class SubclassMappingTests(mapping_tests.BasicTestMappingProtocol): | ||||
|   | ||||
|   | ||||
|  if __name__ == "__main__": | ||||
| -    unittest.main() | ||||
| +    run_tests() | ||||
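
The Dynamo patch above installs a custom importlib.abc.MetaPathFinder so that
`test.<name>` imports resolve to the standalone copies of those helper modules
shipped alongside the ported tests. Below is a minimal self-contained sketch of
the same redirection technique; the AliasFinder name and the
`unittest.jsonalias` -> `json` mapping are illustrative stand-ins, not part of
the patch:

    import importlib
    import importlib.abc
    import importlib.util
    import sys

    class AliasFinder(importlib.abc.MetaPathFinder):
        """Serve imports of an advertised name from a different real module."""

        aliases = {"unittest.jsonalias": "json"}  # hypothetical mapping

        def find_spec(self, fullname, path, target=None):
            real = self.aliases.get(fullname)
            if real is None:
                return None  # not ours: defer to the remaining finders
            # Import the real module, publish it under the advertised name,
            # then hand the import machinery the real module's spec.
            sys.modules[fullname] = importlib.import_module(real)
            return importlib.util.find_spec(real)

    sys.meta_path.insert(0, AliasFinder())

    mod = importlib.import_module("unittest.jsonalias")
    assert mod is sys.modules["json"]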
							
								
								
									
test/dynamo/cpython/3_13/test_dict.py (new file, 1712 lines; file diff suppressed because it is too large)

test/dynamo/cpython/3_13/test_list.diff (new file, 77 lines)
| @ -0,0 +1,77 @@ | ||||
| diff --git a/test/dynamo/cpython/3_13/test_list.py b/test/dynamo/cpython/3_13/test_list.py | ||||
| index 23ef902aa0b..30e69ff75bd 100644 | ||||
| --- a/test/dynamo/cpython/3_13/test_list.py | ||||
| +++ b/test/dynamo/cpython/3_13/test_list.py | ||||
| @@ -1,6 +1,57 @@ | ||||
| +# ======= BEGIN Dynamo patch ======= | ||||
| +# Owner(s): ["module: dynamo"] | ||||
| + | ||||
| +# ruff: noqa | ||||
| +# flake8: noqa | ||||
| + | ||||
| +import sys | ||||
| +import torch | ||||
| +import torch._dynamo.test_case | ||||
| +import unittest | ||||
| +from torch._dynamo.test_case import CPythonTestCase | ||||
| +from torch.testing._internal.common_utils import run_tests | ||||
| + | ||||
| +__TestCase = CPythonTestCase | ||||
| + | ||||
| + | ||||
| +# redirect import statements | ||||
| +import sys | ||||
| +import importlib.abc | ||||
| + | ||||
| +redirect_imports = ( | ||||
| +    "test.mapping_tests", | ||||
| +    "test.typinganndata", | ||||
| +    "test.test_grammar", | ||||
| +    "test.test_math", | ||||
| +    "test.test_iter", | ||||
| +    "test.typinganndata.ann_module", | ||||
| +) | ||||
| + | ||||
| +class RedirectImportFinder(importlib.abc.MetaPathFinder): | ||||
| +    def find_spec(self, fullname, path, target=None): | ||||
| +        # Check if the import is the problematic one | ||||
| +        if fullname in redirect_imports: | ||||
| +            try: | ||||
| +                # Attempt to import the standalone module | ||||
| +                name = fullname.removeprefix("test.") | ||||
| +                r = importlib.import_module(name) | ||||
| +                # Redirect the module in sys.modules | ||||
| +                sys.modules[fullname] = r | ||||
| +                # Return a module spec from the found module | ||||
| +                return importlib.util.find_spec(name) | ||||
| +            except ImportError: | ||||
| +                return None | ||||
| +        return None | ||||
| + | ||||
| +# Add the custom finder to sys.meta_path | ||||
| +sys.meta_path.insert(0, RedirectImportFinder()) | ||||
| + | ||||
| + | ||||
| +# ======= END DYNAMO PATCH ======= | ||||
| + | ||||
|  import sys | ||||
|  import textwrap | ||||
| -from test import list_tests | ||||
| +import list_tests | ||||
|  from test.support import cpython_only | ||||
|  from test.support.script_helper import assert_python_ok | ||||
|  import pickle | ||||
| @@ -324,6 +375,7 @@ class ListTest(list_tests.CommonTest): | ||||
|              a.append(4) | ||||
|              self.assertEqual(list(it), []) | ||||
|   | ||||
| +    @unittest.expectedFailure | ||||
|      def test_deopt_from_append_list(self): | ||||
|          # gh-132011: it used to crash, because | ||||
|          # of `CALL_LIST_APPEND` specialization failure. | ||||
| @@ -345,4 +397,4 @@ class ListTest(list_tests.CommonTest): | ||||
|          self.assertEqual(rc, 0) | ||||
|   | ||||
|  if __name__ == "__main__": | ||||
| -    unittest.main() | ||||
| +    run_tests() | ||||
							
								
								
									
test/dynamo/cpython/3_13/test_list.py (new file, 398 lines)
							| @ -0,0 +1,398 @@ | ||||
| # ======= BEGIN Dynamo patch ======= | ||||
| # Owner(s): ["module: dynamo"] | ||||
|  | ||||
| # ruff: noqa | ||||
| # flake8: noqa | ||||
|  | ||||
| import sys | ||||
| import torch | ||||
| import torch._dynamo.test_case | ||||
| import unittest | ||||
| from torch._dynamo.test_case import CPythonTestCase | ||||
| from torch.testing._internal.common_utils import run_tests | ||||
|  | ||||
| __TestCase = CPythonTestCase | ||||
|  | ||||
|  | ||||
| # redirect import statements | ||||
| import sys | ||||
| import importlib.abc | ||||
|  | ||||
| redirect_imports = ( | ||||
|     "test.mapping_tests", | ||||
|     "test.typinganndata", | ||||
|     "test.test_grammar", | ||||
|     "test.test_math", | ||||
|     "test.test_iter", | ||||
|     "test.typinganndata.ann_module", | ||||
| ) | ||||
|  | ||||
| class RedirectImportFinder(importlib.abc.MetaPathFinder): | ||||
|     def find_spec(self, fullname, path, target=None): | ||||
|         # Check if the import is the problematic one | ||||
|         if fullname in redirect_imports: | ||||
|             try: | ||||
|                 # Attempt to import the standalone module | ||||
|                 name = fullname.removeprefix("test.") | ||||
|                 r = importlib.import_module(name) | ||||
|                 # Redirect the module in sys.modules | ||||
|                 sys.modules[fullname] = r | ||||
|                 # Return a module spec from the found module | ||||
|                 return importlib.util.find_spec(name) | ||||
|             except ImportError: | ||||
|                 return None | ||||
|         return None | ||||
|  | ||||
| # Add the custom finder to sys.meta_path | ||||
| sys.meta_path.insert(0, RedirectImportFinder()) | ||||
|  | ||||
|  | ||||
| # ======= END DYNAMO PATCH ======= | ||||
|  | ||||
| import sys | ||||
| import textwrap | ||||
| import list_tests | ||||
| from test.support import cpython_only | ||||
| from test.support.script_helper import assert_python_ok | ||||
| import pickle | ||||
| import unittest | ||||
|  | ||||
| class ListTest(list_tests.CommonTest): | ||||
|     type2test = list | ||||
|  | ||||
|     def test_basic(self): | ||||
|         self.assertEqual(list([]), []) | ||||
|         l0_3 = [0, 1, 2, 3] | ||||
|         l0_3_bis = list(l0_3) | ||||
|         self.assertEqual(l0_3, l0_3_bis) | ||||
|         self.assertTrue(l0_3 is not l0_3_bis) | ||||
|         self.assertEqual(list(()), []) | ||||
|         self.assertEqual(list((0, 1, 2, 3)), [0, 1, 2, 3]) | ||||
|         self.assertEqual(list(''), []) | ||||
|         self.assertEqual(list('spam'), ['s', 'p', 'a', 'm']) | ||||
|         self.assertEqual(list(x for x in range(10) if x % 2), | ||||
|                          [1, 3, 5, 7, 9]) | ||||
|  | ||||
|         if sys.maxsize == 0x7fffffff: | ||||
|             # This test can currently only work on 32-bit machines. | ||||
|             # XXX If/when PySequence_Length() returns a ssize_t, it should be | ||||
|             # XXX re-enabled. | ||||
|             # Verify clearing of bug #556025. | ||||
|             # This assumes that the max data size (sys.maxint) == max | ||||
|             # address size; it also assumes that the address size is at | ||||
|             # least 4 bytes.  With 8 byte addresses, the bug is not well | ||||
|             # tested. | ||||
|             # | ||||
|             # Note: This test is expected to SEGV under Cygwin 1.3.12 or | ||||
|             # earlier due to a newlib bug.  See the following mailing list | ||||
|             # thread for the details: | ||||
|             self.assertRaises(MemoryError, list, range(sys.maxsize // 2)) | ||||
|  | ||||
|         # This code used to segfault in Py2.4a3 | ||||
|         x = [] | ||||
|         x.extend(-y for y in x) | ||||
|         self.assertEqual(x, []) | ||||
|  | ||||
|     def test_keyword_args(self): | ||||
|         with self.assertRaisesRegex(TypeError, 'keyword argument'): | ||||
|             list(sequence=[]) | ||||
|  | ||||
|     def test_keywords_in_subclass(self): | ||||
|         class subclass(list): | ||||
|             pass | ||||
|         u = subclass([1, 2]) | ||||
|         self.assertIs(type(u), subclass) | ||||
|         self.assertEqual(list(u), [1, 2]) | ||||
|         with self.assertRaises(TypeError): | ||||
|             subclass(sequence=()) | ||||
|  | ||||
|         class subclass_with_init(list): | ||||
|             def __init__(self, seq, newarg=None): | ||||
|                 super().__init__(seq) | ||||
|                 self.newarg = newarg | ||||
|         u = subclass_with_init([1, 2], newarg=3) | ||||
|         self.assertIs(type(u), subclass_with_init) | ||||
|         self.assertEqual(list(u), [1, 2]) | ||||
|         self.assertEqual(u.newarg, 3) | ||||
|  | ||||
|         class subclass_with_new(list): | ||||
|             def __new__(cls, seq, newarg=None): | ||||
|                 self = super().__new__(cls, seq) | ||||
|                 self.newarg = newarg | ||||
|                 return self | ||||
|         u = subclass_with_new([1, 2], newarg=3) | ||||
|         self.assertIs(type(u), subclass_with_new) | ||||
|         self.assertEqual(list(u), [1, 2]) | ||||
|         self.assertEqual(u.newarg, 3) | ||||
|  | ||||
|     def test_truth(self): | ||||
|         super().test_truth() | ||||
|         self.assertTrue(not []) | ||||
|         self.assertTrue([42]) | ||||
|  | ||||
|     def test_identity(self): | ||||
|         self.assertTrue([] is not []) | ||||
|  | ||||
|     def test_len(self): | ||||
|         super().test_len() | ||||
|         self.assertEqual(len([]), 0) | ||||
|         self.assertEqual(len([0]), 1) | ||||
|         self.assertEqual(len([0, 1, 2]), 3) | ||||
|  | ||||
|     def test_overflow(self): | ||||
|         lst = [4, 5, 6, 7] | ||||
|         n = int((sys.maxsize*2+2) // len(lst)) | ||||
|         def mul(a, b): return a * b | ||||
|         def imul(a, b): a *= b | ||||
|         self.assertRaises((MemoryError, OverflowError), mul, lst, n) | ||||
|         self.assertRaises((MemoryError, OverflowError), imul, lst, n) | ||||
|  | ||||
|     def test_empty_slice(self): | ||||
|         x = [] | ||||
|         x[:] = x | ||||
|         self.assertEqual(x, []) | ||||
|  | ||||
|     def test_list_resize_overflow(self): | ||||
|         # gh-97616: test new_allocated * sizeof(PyObject*) overflow | ||||
|         # check in list_resize() | ||||
|         lst = [0] * 65 | ||||
|         del lst[1:] | ||||
|         self.assertEqual(len(lst), 1) | ||||
|  | ||||
|         size = sys.maxsize | ||||
|         with self.assertRaises((MemoryError, OverflowError)): | ||||
|             lst * size | ||||
|         with self.assertRaises((MemoryError, OverflowError)): | ||||
|             lst *= size | ||||
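|         # Rough arithmetic behind the gh-97616 guard above (a note, not | ||||
|         # upstream commentary): on a 64-bit build sys.maxsize == 2**63 - 1, | ||||
|         # so the resize would need (2**63 - 1) * 8 bytes (assuming 8-byte | ||||
|         # pointers), about 7.4e19; new_allocated * sizeof(PyObject*) wraps | ||||
|         # around unless list_resize() checks that multiplication for | ||||
|         # overflow first. | ||||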
|  | ||||
|     def test_repr_mutate(self): | ||||
|         class Obj: | ||||
|             @staticmethod | ||||
|             def __repr__(): | ||||
|                 try: | ||||
|                     mylist.pop() | ||||
|                 except IndexError: | ||||
|                     pass | ||||
|                 return 'obj' | ||||
|  | ||||
|         mylist = [Obj() for _ in range(5)] | ||||
|         self.assertEqual(repr(mylist), '[obj, obj, obj]') | ||||
|  | ||||
|     def test_repr_large(self): | ||||
|         # Check the repr of large list objects | ||||
|         def check(n): | ||||
|             l = [0] * n | ||||
|             s = repr(l) | ||||
|             self.assertEqual(s, | ||||
|                 '[' + ', '.join(['0'] * n) + ']') | ||||
|         check(10)       # check our checking code | ||||
|         check(1000000) | ||||
|  | ||||
|     def test_iterator_pickle(self): | ||||
|         orig = self.type2test([4, 5, 6, 7]) | ||||
|         data = [10, 11, 12, 13, 14, 15] | ||||
|         for proto in range(pickle.HIGHEST_PROTOCOL + 1): | ||||
|             # initial iterator | ||||
|             itorig = iter(orig) | ||||
|             d = pickle.dumps((itorig, orig), proto) | ||||
|             it, a = pickle.loads(d) | ||||
|             a[:] = data | ||||
|             self.assertEqual(type(it), type(itorig)) | ||||
|             self.assertEqual(list(it), data) | ||||
|  | ||||
|             # running iterator | ||||
|             next(itorig) | ||||
|             d = pickle.dumps((itorig, orig), proto) | ||||
|             it, a = pickle.loads(d) | ||||
|             a[:] = data | ||||
|             self.assertEqual(type(it), type(itorig)) | ||||
|             self.assertEqual(list(it), data[1:]) | ||||
|  | ||||
|             # empty iterator | ||||
|             for i in range(1, len(orig)): | ||||
|                 next(itorig) | ||||
|             d = pickle.dumps((itorig, orig), proto) | ||||
|             it, a = pickle.loads(d) | ||||
|             a[:] = data | ||||
|             self.assertEqual(type(it), type(itorig)) | ||||
|             self.assertEqual(list(it), data[len(orig):]) | ||||
|  | ||||
|             # exhausted iterator | ||||
|             self.assertRaises(StopIteration, next, itorig) | ||||
|             d = pickle.dumps((itorig, orig), proto) | ||||
|             it, a = pickle.loads(d) | ||||
|             a[:] = data | ||||
|             self.assertEqual(list(it), []) | ||||
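|         # A compact standalone illustration of the round trip above | ||||
|         # (assuming plain CPython lists): a pickled list iterator stores | ||||
|         # its position, not the items it has already produced. | ||||
|         # | ||||
|         #     it = iter([4, 5, 6, 7]) | ||||
|         #     next(it)                           # consume one item | ||||
|         #     clone = pickle.loads(pickle.dumps(it)) | ||||
|         #     assert list(clone) == [5, 6, 7]    # resumes mid-sequence | ||||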
|  | ||||
|     def test_reversed_pickle(self): | ||||
|         orig = self.type2test([4, 5, 6, 7]) | ||||
|         data = [10, 11, 12, 13, 14, 15] | ||||
|         for proto in range(pickle.HIGHEST_PROTOCOL + 1): | ||||
|             # initial iterator | ||||
|             itorig = reversed(orig) | ||||
|             d = pickle.dumps((itorig, orig), proto) | ||||
|             it, a = pickle.loads(d) | ||||
|             a[:] = data | ||||
|             self.assertEqual(type(it), type(itorig)) | ||||
|             self.assertEqual(list(it), data[len(orig)-1::-1]) | ||||
|  | ||||
|             # running iterator | ||||
|             next(itorig) | ||||
|             d = pickle.dumps((itorig, orig), proto) | ||||
|             it, a = pickle.loads(d) | ||||
|             a[:] = data | ||||
|             self.assertEqual(type(it), type(itorig)) | ||||
|             self.assertEqual(list(it), data[len(orig)-2::-1]) | ||||
|  | ||||
|             # empty iterator | ||||
|             for i in range(1, len(orig)): | ||||
|                 next(itorig) | ||||
|             d = pickle.dumps((itorig, orig), proto) | ||||
|             it, a = pickle.loads(d) | ||||
|             a[:] = data | ||||
|             self.assertEqual(type(it), type(itorig)) | ||||
|             self.assertEqual(list(it), []) | ||||
|  | ||||
|             # exhausted iterator | ||||
|             self.assertRaises(StopIteration, next, itorig) | ||||
|             d = pickle.dumps((itorig, orig), proto) | ||||
|             it, a = pickle.loads(d) | ||||
|             a[:] = data | ||||
|             self.assertEqual(list(it), []) | ||||
|  | ||||
|     def test_step_overflow(self): | ||||
|         a = [0, 1, 2, 3, 4] | ||||
|         a[1::sys.maxsize] = [0] | ||||
|         self.assertEqual(a[3::sys.maxsize], [3]) | ||||
|  | ||||
|     def test_no_comdat_folding(self): | ||||
|         # Issue 8847: In the PGO build, the MSVC linker's COMDAT folding | ||||
|         # optimization causes failures in code that relies on distinct | ||||
|         # function addresses. | ||||
|         class L(list): pass | ||||
|         with self.assertRaises(TypeError): | ||||
|             (3,) + L([1,2]) | ||||
|  | ||||
|     def test_equal_operator_modifying_operand(self): | ||||
|         # test fix for seg fault reported in bpo-38588 part 2. | ||||
|         class X: | ||||
|             def __eq__(self,other) : | ||||
|                 list2.clear() | ||||
|                 return NotImplemented | ||||
|  | ||||
|         class Y: | ||||
|             def __eq__(self, other): | ||||
|                 list1.clear() | ||||
|                 return NotImplemented | ||||
|  | ||||
|         class Z: | ||||
|             def __eq__(self, other): | ||||
|                 list3.clear() | ||||
|                 return NotImplemented | ||||
|  | ||||
|         list1 = [X()] | ||||
|         list2 = [Y()] | ||||
|         self.assertTrue(list1 == list2) | ||||
|  | ||||
|         list3 = [Z()] | ||||
|         list4 = [1] | ||||
|         self.assertFalse(list3 == list4) | ||||
|  | ||||
|     def test_lt_operator_modifying_operand(self): | ||||
|         # See gh-120298 | ||||
|         class evil: | ||||
|             def __lt__(self, other): | ||||
|                 other.clear() | ||||
|                 return NotImplemented | ||||
|  | ||||
|         a = [[evil()]] | ||||
|         with self.assertRaises(TypeError): | ||||
|             a[0] < a | ||||
|  | ||||
|     def test_list_index_modifing_operand(self): | ||||
|         # See gh-120384 | ||||
|         class evil: | ||||
|             def __init__(self, lst): | ||||
|                 self.lst = lst | ||||
|             def __iter__(self): | ||||
|                 yield from self.lst | ||||
|                 self.lst.clear() | ||||
|  | ||||
|         lst = list(range(5)) | ||||
|         operand = evil(lst) | ||||
|         with self.assertRaises(ValueError): | ||||
|             lst[::-1] = operand | ||||
|  | ||||
|     @cpython_only | ||||
|     def test_preallocation(self): | ||||
|         iterable = [0] * 10 | ||||
|         iter_size = sys.getsizeof(iterable) | ||||
|  | ||||
|         self.assertEqual(iter_size, sys.getsizeof(list([0] * 10))) | ||||
|         self.assertEqual(iter_size, sys.getsizeof(list(range(10)))) | ||||
|  | ||||
|     def test_count_index_remove_crashes(self): | ||||
|         # bpo-38610: The count(), index(), and remove() methods were not | ||||
|         # holding strong references to list elements while calling | ||||
|         # PyObject_RichCompareBool(). | ||||
|         class X: | ||||
|             def __eq__(self, other): | ||||
|                 lst.clear() | ||||
|                 return NotImplemented | ||||
|  | ||||
|         lst = [X()] | ||||
|         with self.assertRaises(ValueError): | ||||
|             lst.index(lst) | ||||
|  | ||||
|         class L(list): | ||||
|             def __eq__(self, other): | ||||
|                 str(other) | ||||
|                 return NotImplemented | ||||
|  | ||||
|         lst = L([X()]) | ||||
|         lst.count(lst) | ||||
|  | ||||
|         lst = L([X()]) | ||||
|         with self.assertRaises(ValueError): | ||||
|             lst.remove(lst) | ||||
|  | ||||
|         # bpo-39453: list.__contains__ was not holding strong references | ||||
|         # to list elements while calling PyObject_RichCompareBool(). | ||||
|         lst = [X(), X()] | ||||
|         3 in lst | ||||
|         lst = [X(), X()] | ||||
|         X() in lst | ||||
|  | ||||
|     def test_tier2_invalidates_iterator(self): | ||||
|         # GH-121012 | ||||
|         for _ in range(100): | ||||
|             a = [1, 2, 3] | ||||
|             it = iter(a) | ||||
|             for _ in it: | ||||
|                 pass | ||||
|             a.append(4) | ||||
|             self.assertEqual(list(it), []) | ||||
|  | ||||
|     @unittest.expectedFailure | ||||
|     def test_deopt_from_append_list(self): | ||||
|         # gh-132011: it used to crash, because | ||||
|         # of `CALL_LIST_APPEND` specialization failure. | ||||
|         code = textwrap.dedent(""" | ||||
|             l = [] | ||||
|             def lappend(l, x, y): | ||||
|                 l.append((x, y)) | ||||
|             for x in range(3): | ||||
|                 lappend(l, None, None) | ||||
|             try: | ||||
|                 lappend(list, None, None) | ||||
|             except TypeError: | ||||
|                 pass | ||||
|             else: | ||||
|                 raise AssertionError | ||||
|         """) | ||||
|  | ||||
|         rc, _, _ = assert_python_ok("-c", code) | ||||
|         self.assertEqual(rc, 0) | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     run_tests() | ||||
							
								
								
									
test/dynamo/cpython/3_13/test_ordered_dict.diff (new file, 173 lines)
| @ -0,0 +1,173 @@ | ||||
| diff --git a/test/dynamo/cpython/3_13/test_ordered_dict.py b/test/dynamo/cpython/3_13/test_ordered_dict.py | ||||
| index a9b6a84996e..b77eff70414 100644 | ||||
| --- a/test/dynamo/cpython/3_13/test_ordered_dict.py | ||||
| +++ b/test/dynamo/cpython/3_13/test_ordered_dict.py | ||||
| @@ -1,3 +1,57 @@ | ||||
| +# ======= BEGIN Dynamo patch ======= | ||||
| +# Owner(s): ["module: dynamo"] | ||||
| + | ||||
| +# ruff: noqa | ||||
| +# flake8: noqa | ||||
| + | ||||
| +import sys | ||||
| +import torch | ||||
| +import torch._dynamo.test_case | ||||
| +import unittest | ||||
| +from torch._dynamo.test_case import CPythonTestCase | ||||
| +from torch.testing._internal.common_utils import ( | ||||
| +    run_tests, | ||||
| +    xfailIfTorchDynamo, | ||||
| +) | ||||
| + | ||||
| +__TestCase = CPythonTestCase | ||||
| + | ||||
| + | ||||
| +# redirect import statements | ||||
| +import sys | ||||
| +import importlib.abc | ||||
| + | ||||
| +redirect_imports = ( | ||||
| +    "test.mapping_tests", | ||||
| +    "test.typinganndata", | ||||
| +    "test.test_grammar", | ||||
| +    "test.test_math", | ||||
| +    "test.test_iter", | ||||
| +    "test.typinganndata.ann_module", | ||||
| +) | ||||
| + | ||||
| +class RedirectImportFinder(importlib.abc.MetaPathFinder): | ||||
| +    def find_spec(self, fullname, path, target=None): | ||||
| +        # Check if the import is the problematic one | ||||
| +        if fullname in redirect_imports: | ||||
| +            try: | ||||
| +                # Attempt to import the standalone module | ||||
| +                name = fullname.removeprefix("test.") | ||||
| +                r = importlib.import_module(name) | ||||
| +                # Redirect the module in sys.modules | ||||
| +                sys.modules[fullname] = r | ||||
| +                # Return a module spec from the found module | ||||
| +                return importlib.util.find_spec(name) | ||||
| +            except ImportError: | ||||
| +                return None | ||||
| +        return None | ||||
| + | ||||
| +# Add the custom finder to sys.meta_path | ||||
| +sys.meta_path.insert(0, RedirectImportFinder()) | ||||
| + | ||||
| + | ||||
| +# ======= END DYNAMO PATCH ======= | ||||
| + | ||||
|  import builtins | ||||
|  import contextlib | ||||
|  import copy | ||||
| @@ -760,7 +814,7 @@ class _TriggerSideEffectOnEqual: | ||||
|      def side_effect(self): | ||||
|          raise NotImplementedError | ||||
|   | ||||
| -class PurePythonOrderedDictTests(OrderedDictTests, unittest.TestCase): | ||||
| +class PurePythonOrderedDictTests(OrderedDictTests, __TestCase): | ||||
|   | ||||
|      module = py_coll | ||||
|      OrderedDict = py_coll.OrderedDict | ||||
| @@ -781,7 +835,7 @@ class PurePythonOrderedDictTests(OrderedDictTests, unittest.TestCase): | ||||
|          self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2))) | ||||
|   | ||||
|   | ||||
| -class CPythonBuiltinDictTests(unittest.TestCase): | ||||
| +class CPythonBuiltinDictTests(__TestCase): | ||||
|      """Builtin dict preserves insertion order. | ||||
|   | ||||
|      Reuse some of tests in OrderedDict selectively. | ||||
| @@ -800,6 +854,7 @@ for method in ( | ||||
|  del method | ||||
|   | ||||
|   | ||||
| + | ||||
|  class CPythonOrderedDictSideEffects: | ||||
|   | ||||
|      def check_runtime_error_issue119004(self, dict1, dict2): | ||||
| @@ -878,7 +933,7 @@ class CPythonOrderedDictSideEffects: | ||||
|  @unittest.skipUnless(c_coll, 'requires the C version of the collections module') | ||||
|  class CPythonOrderedDictTests(OrderedDictTests, | ||||
|                                CPythonOrderedDictSideEffects, | ||||
| -                              unittest.TestCase): | ||||
| +                              __TestCase): | ||||
|   | ||||
|      module = c_coll | ||||
|      OrderedDict = c_coll.OrderedDict | ||||
| @@ -986,7 +1041,7 @@ class CPythonOrderedDictSubclassTests(CPythonOrderedDictTests): | ||||
|          pass | ||||
|   | ||||
|   | ||||
| -class PurePythonOrderedDictWithSlotsCopyingTests(unittest.TestCase): | ||||
| +class PurePythonOrderedDictWithSlotsCopyingTests(__TestCase): | ||||
|   | ||||
|      module = py_coll | ||||
|      class OrderedDict(py_coll.OrderedDict): | ||||
| @@ -995,7 +1050,7 @@ class PurePythonOrderedDictWithSlotsCopyingTests(unittest.TestCase): | ||||
|   | ||||
|   | ||||
|  @unittest.skipUnless(c_coll, 'requires the C version of the collections module') | ||||
| -class CPythonOrderedDictWithSlotsCopyingTests(unittest.TestCase): | ||||
| +class CPythonOrderedDictWithSlotsCopyingTests(__TestCase): | ||||
|   | ||||
|      module = c_coll | ||||
|      class OrderedDict(c_coll.OrderedDict): | ||||
| @@ -1008,6 +1063,7 @@ class PurePythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol): | ||||
|      @classmethod | ||||
|      def setUpClass(cls): | ||||
|          cls.type2test = py_coll.OrderedDict | ||||
| +        super().setUpClass() | ||||
|   | ||||
|      def test_popitem(self): | ||||
|          d = self._empty_mapping() | ||||
| @@ -1020,6 +1076,7 @@ class CPythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol): | ||||
|      @classmethod | ||||
|      def setUpClass(cls): | ||||
|          cls.type2test = c_coll.OrderedDict | ||||
| +        super().setUpClass() | ||||
|   | ||||
|      def test_popitem(self): | ||||
|          d = self._empty_mapping() | ||||
| @@ -1033,6 +1090,7 @@ class PurePythonSubclassMappingTests(mapping_tests.BasicTestMappingProtocol): | ||||
|          class MyOrderedDict(py_coll.OrderedDict): | ||||
|              pass | ||||
|          cls.type2test = MyOrderedDict | ||||
| +        super().setUpClass() | ||||
|   | ||||
|      def test_popitem(self): | ||||
|          d = self._empty_mapping() | ||||
| @@ -1047,6 +1105,7 @@ class CPythonSubclassMappingTests(mapping_tests.BasicTestMappingProtocol): | ||||
|          class MyOrderedDict(c_coll.OrderedDict): | ||||
|              pass | ||||
|          cls.type2test = MyOrderedDict | ||||
| +        super().setUpClass() | ||||
|   | ||||
|      def test_popitem(self): | ||||
|          d = self._empty_mapping() | ||||
| @@ -1120,21 +1179,22 @@ class SimpleLRUCacheTests: | ||||
|          self.assertEqual(list(c), [1, 3, 2]) | ||||
|   | ||||
|   | ||||
| -class PySimpleLRUCacheTests(SimpleLRUCacheTests, unittest.TestCase): | ||||
| +class PySimpleLRUCacheTests(SimpleLRUCacheTests, __TestCase): | ||||
|   | ||||
|      class type2test(SimpleLRUCache, py_coll.OrderedDict): | ||||
|          pass | ||||
|   | ||||
|   | ||||
|  @unittest.skipUnless(c_coll, 'requires the C version of the collections module') | ||||
| -class CSimpleLRUCacheTests(SimpleLRUCacheTests, unittest.TestCase): | ||||
| +class CSimpleLRUCacheTests(SimpleLRUCacheTests, __TestCase): | ||||
|   | ||||
|      @classmethod | ||||
|      def setUpClass(cls): | ||||
|          class type2test(SimpleLRUCache, c_coll.OrderedDict): | ||||
|              pass | ||||
|          cls.type2test = type2test | ||||
| +        super().setUpClass() | ||||
|   | ||||
|   | ||||
|  if __name__ == "__main__": | ||||
| -    unittest.main() | ||||
| +    run_tests() | ||||
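
One pattern worth noting in this diff: every `setUpClass` the patch touches
gains a trailing `super().setUpClass()`. The substituted base class presumably
does real setup work there, so a subclass that overrides `setUpClass` just to
configure `type2test` must chain to the parent or that work is silently
skipped. A minimal sketch of the rule (class names here are illustrative, not
from the diff):

    import unittest

    class Base(unittest.TestCase):
        @classmethod
        def setUpClass(cls):
            super().setUpClass()
            cls.resource = "ready"  # stand-in for the base class's setup

    class Derived(Base):
        @classmethod
        def setUpClass(cls):
            cls.type2test = dict    # class-level configuration, as in the diff
            super().setUpClass()    # without this, Base.setUpClass never runs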
							
								
								
									
test/dynamo/cpython/3_13/test_ordered_dict.py (new file, 1200 lines; file diff suppressed because it is too large)

test/dynamo/cpython/3_13/test_tuple.diff (new file, 67 lines)
| @ -0,0 +1,67 @@ | ||||
| diff --git a/test/dynamo/cpython/3_13/test_tuple.py b/test/dynamo/cpython/3_13/test_tuple.py | ||||
| index 9ce80c5e8ea..e52c0cbc140 100644 | ||||
| --- a/test/dynamo/cpython/3_13/test_tuple.py | ||||
| +++ b/test/dynamo/cpython/3_13/test_tuple.py | ||||
| @@ -1,4 +1,55 @@ | ||||
| -from test import support, seq_tests | ||||
| +# ======= BEGIN Dynamo patch ======= | ||||
| +# Owner(s): ["module: dynamo"] | ||||
| + | ||||
| +# ruff: noqa | ||||
| +# flake8: noqa | ||||
| + | ||||
| +import sys | ||||
| +import torch | ||||
| +import torch._dynamo.test_case | ||||
| +import unittest | ||||
| +from torch._dynamo.test_case import CPythonTestCase | ||||
| +from torch.testing._internal.common_utils import run_tests | ||||
| + | ||||
| +__TestCase = CPythonTestCase | ||||
| + | ||||
| +# redirect import statements | ||||
| +import sys | ||||
| +import importlib.abc | ||||
| + | ||||
| +redirect_imports = ( | ||||
| +    "test.mapping_tests", | ||||
| +    "test.typinganndata", | ||||
| +    "test.test_grammar", | ||||
| +    "test.test_math", | ||||
| +    "test.test_iter", | ||||
| +    "test.typinganndata.ann_module", | ||||
| +) | ||||
| + | ||||
| +class RedirectImportFinder(importlib.abc.MetaPathFinder): | ||||
| +    def find_spec(self, fullname, path, target=None): | ||||
| +        # Check if the import is the problematic one | ||||
| +        if fullname in redirect_imports: | ||||
| +            try: | ||||
| +                # Attempt to import the standalone module | ||||
| +                name = fullname.removeprefix("test.") | ||||
| +                r = importlib.import_module(name) | ||||
| +                # Redirect the module in sys.modules | ||||
| +                sys.modules[fullname] = r | ||||
| +                # Return a module spec from the found module | ||||
| +                return importlib.util.find_spec(name) | ||||
| +            except ImportError: | ||||
| +                return None | ||||
| +        return None | ||||
| + | ||||
| +# Add the custom finder to sys.meta_path | ||||
| +sys.meta_path.insert(0, RedirectImportFinder()) | ||||
| + | ||||
| + | ||||
| +# ======= END DYNAMO PATCH ======= | ||||
| + | ||||
| +from test import support | ||||
| +import seq_tests | ||||
|  import unittest | ||||
|   | ||||
|  import gc | ||||
| @@ -510,4 +561,4 @@ class TupleTest(seq_tests.CommonTest): | ||||
|  #            pileup 262,143 mean 8.0 coll 262,143 z +92683.6 | ||||
|   | ||||
|  if __name__ == "__main__": | ||||
| -    unittest.main() | ||||
| +    run_tests() | ||||
							
								
								
									
test/dynamo/cpython/3_13/test_tuple.py (new file, 564 lines)
							| @ -0,0 +1,564 @@ | ||||
| # ======= BEGIN Dynamo patch ======= | ||||
| # Owner(s): ["module: dynamo"] | ||||
|  | ||||
| # ruff: noqa | ||||
| # flake8: noqa | ||||
|  | ||||
| import sys | ||||
| import torch | ||||
| import torch._dynamo.test_case | ||||
| import unittest | ||||
| from torch._dynamo.test_case import CPythonTestCase | ||||
| from torch.testing._internal.common_utils import run_tests | ||||
|  | ||||
| __TestCase = CPythonTestCase | ||||
|  | ||||
| # redirect import statements | ||||
| import sys | ||||
| import importlib.abc | ||||
|  | ||||
| redirect_imports = ( | ||||
|     "test.mapping_tests", | ||||
|     "test.typinganndata", | ||||
|     "test.test_grammar", | ||||
|     "test.test_math", | ||||
|     "test.test_iter", | ||||
|     "test.typinganndata.ann_module", | ||||
| ) | ||||
|  | ||||
| class RedirectImportFinder(importlib.abc.MetaPathFinder): | ||||
|     def find_spec(self, fullname, path, target=None): | ||||
|         # Check if the import is the problematic one | ||||
|         if fullname in redirect_imports: | ||||
|             try: | ||||
|                 # Attempt to import the standalone module | ||||
|                 name = fullname.removeprefix("test.") | ||||
|                 r = importlib.import_module(name) | ||||
|                 # Redirect the module in sys.modules | ||||
|                 sys.modules[fullname] = r | ||||
|                 # Return a module spec from the found module | ||||
|                 return importlib.util.find_spec(name) | ||||
|             except ImportError: | ||||
|                 return None | ||||
|         return None | ||||
|  | ||||
| # Add the custom finder to sys.meta_path | ||||
| sys.meta_path.insert(0, RedirectImportFinder()) | ||||
|  | ||||
|  | ||||
| # ======= END DYNAMO PATCH ======= | ||||
|  | ||||
| from test import support | ||||
| import seq_tests | ||||
| import unittest | ||||
|  | ||||
| import gc | ||||
| import pickle | ||||
|  | ||||
| # For tuple hashes, we normally only run a test to ensure that we get | ||||
| # the same results across platforms in a handful of cases.  If that's | ||||
| # so, there's no real point to running more.  Set RUN_ALL_HASH_TESTS to | ||||
| # run more anyway.  That's usually of real interest only when analyzing, | ||||
| # or changing, the hash algorithm.  In which case it's usually also | ||||
| # most useful to set JUST_SHOW_HASH_RESULTS, to see all the results | ||||
| # instead of wrestling with test "failures".  See the bottom of the | ||||
| # file for extensive notes on what we're testing here and why. | ||||
| RUN_ALL_HASH_TESTS = False | ||||
| JUST_SHOW_HASH_RESULTS = False # if RUN_ALL_HASH_TESTS, just display | ||||
|  | ||||
| class TupleTest(seq_tests.CommonTest): | ||||
|     type2test = tuple | ||||
|  | ||||
|     def test_getitem_error(self): | ||||
|         t = () | ||||
|         msg = "tuple indices must be integers or slices" | ||||
|         with self.assertRaisesRegex(TypeError, msg): | ||||
|             t['a'] | ||||
|  | ||||
|     def test_constructors(self): | ||||
|         super().test_constructors() | ||||
|         # calling built-in types without argument must return empty | ||||
|         self.assertEqual(tuple(), ()) | ||||
|         t0_3 = (0, 1, 2, 3) | ||||
|         t0_3_bis = tuple(t0_3) | ||||
|         self.assertTrue(t0_3 is t0_3_bis) | ||||
|         self.assertEqual(tuple([]), ()) | ||||
|         self.assertEqual(tuple([0, 1, 2, 3]), (0, 1, 2, 3)) | ||||
|         self.assertEqual(tuple(''), ()) | ||||
|         self.assertEqual(tuple('spam'), ('s', 'p', 'a', 'm')) | ||||
|         self.assertEqual(tuple(x for x in range(10) if x % 2), | ||||
|                          (1, 3, 5, 7, 9)) | ||||
|  | ||||
|     def test_keyword_args(self): | ||||
|         with self.assertRaisesRegex(TypeError, 'keyword argument'): | ||||
|             tuple(sequence=()) | ||||
|  | ||||
|     def test_keywords_in_subclass(self): | ||||
|         class subclass(tuple): | ||||
|             pass | ||||
|         u = subclass([1, 2]) | ||||
|         self.assertIs(type(u), subclass) | ||||
|         self.assertEqual(list(u), [1, 2]) | ||||
|         with self.assertRaises(TypeError): | ||||
|             subclass(sequence=()) | ||||
|  | ||||
|         class subclass_with_init(tuple): | ||||
|             def __init__(self, arg, newarg=None): | ||||
|                 self.newarg = newarg | ||||
|         u = subclass_with_init([1, 2], newarg=3) | ||||
|         self.assertIs(type(u), subclass_with_init) | ||||
|         self.assertEqual(list(u), [1, 2]) | ||||
|         self.assertEqual(u.newarg, 3) | ||||
|  | ||||
|         class subclass_with_new(tuple): | ||||
|             def __new__(cls, arg, newarg=None): | ||||
|                 self = super().__new__(cls, arg) | ||||
|                 self.newarg = newarg | ||||
|                 return self | ||||
|         u = subclass_with_new([1, 2], newarg=3) | ||||
|         self.assertIs(type(u), subclass_with_new) | ||||
|         self.assertEqual(list(u), [1, 2]) | ||||
|         self.assertEqual(u.newarg, 3) | ||||
|  | ||||
|     def test_truth(self): | ||||
|         super().test_truth() | ||||
|         self.assertTrue(not ()) | ||||
|         self.assertTrue((42, )) | ||||
|  | ||||
|     def test_len(self): | ||||
|         super().test_len() | ||||
|         self.assertEqual(len(()), 0) | ||||
|         self.assertEqual(len((0,)), 1) | ||||
|         self.assertEqual(len((0, 1, 2)), 3) | ||||
|  | ||||
|     def test_iadd(self): | ||||
|         super().test_iadd() | ||||
|         u = (0, 1) | ||||
|         u2 = u | ||||
|         u += (2, 3) | ||||
|         self.assertTrue(u is not u2) | ||||
|  | ||||
|     def test_imul(self): | ||||
|         super().test_imul() | ||||
|         u = (0, 1) | ||||
|         u2 = u | ||||
|         u *= 3 | ||||
|         self.assertTrue(u is not u2) | ||||
|  | ||||
|     def test_tupleresizebug(self): | ||||
|         # Check that a specific bug in _PyTuple_Resize() is squashed. | ||||
|         def f(): | ||||
|             for i in range(1000): | ||||
|                 yield i | ||||
|         self.assertEqual(list(tuple(f())), list(range(1000))) | ||||
|  | ||||
|     # We expect tuples whose base components have deterministic hashes to | ||||
|     # have deterministic hashes too - and, indeed, the same hashes across | ||||
|     # platforms with hash codes of the same bit width. | ||||
|     def test_hash_exact(self): | ||||
|         def check_one_exact(t, e32, e64): | ||||
|             got = hash(t) | ||||
|             expected = e32 if support.NHASHBITS == 32 else e64 | ||||
|             if got != expected: | ||||
|                 msg = f"FAIL hash({t!r}) == {got} != {expected}" | ||||
|                 self.fail(msg) | ||||
|  | ||||
|         check_one_exact((), 750394483, 5740354900026072187) | ||||
|         check_one_exact((0,), 1214856301, -8753497827991233192) | ||||
|         check_one_exact((0, 0), -168982784, -8458139203682520985) | ||||
|         check_one_exact((0.5,), 2077348973, -408149959306781352) | ||||
|         check_one_exact((0.5, (), (-2, 3, (4, 6))), 714642271, | ||||
|                         -1845940830829704396) | ||||
|  | ||||
|     # Various tests for hashing of tuples to check that we get few collisions. | ||||
|     # Does something only if RUN_ALL_HASH_TESTS is true. | ||||
|     # | ||||
|     # Earlier versions of the tuple hash algorithm had massive collisions | ||||
|     # reported at: | ||||
|     # - https://bugs.python.org/issue942952 | ||||
|     # - https://bugs.python.org/issue34751 | ||||
|     def test_hash_optional(self): | ||||
|         from itertools import product | ||||
|  | ||||
|         if not RUN_ALL_HASH_TESTS: | ||||
|             return | ||||
|  | ||||
|         # If specified, `expected` is a 2-tuple of expected | ||||
|         # (number_of_collisions, pileup) values, and the test fails if | ||||
|         # those aren't the values we get.  Also if specified, the test | ||||
|         # fails if z > `zlimit`. | ||||
|         def tryone_inner(tag, nbins, hashes, expected=None, zlimit=None): | ||||
|             from collections import Counter | ||||
|  | ||||
|             nballs = len(hashes) | ||||
|             mean, sdev = support.collision_stats(nbins, nballs) | ||||
|             c = Counter(hashes) | ||||
|             collisions = nballs - len(c) | ||||
|             z = (collisions - mean) / sdev | ||||
|             pileup = max(c.values()) - 1 | ||||
|             del c | ||||
|             got = (collisions, pileup) | ||||
|             failed = False | ||||
|             prefix = "" | ||||
|             if zlimit is not None and z > zlimit: | ||||
|                 failed = True | ||||
|                 prefix = f"FAIL z > {zlimit}; " | ||||
|             if expected is not None and got != expected: | ||||
|                 failed = True | ||||
|                 prefix += f"FAIL {got} != {expected}; " | ||||
|             if failed or JUST_SHOW_HASH_RESULTS: | ||||
|                 msg = f"{prefix}{tag}; pileup {pileup:,} mean {mean:.1f} " | ||||
|                 msg += f"coll {collisions:,} z {z:+.1f}" | ||||
|                 if JUST_SHOW_HASH_RESULTS: | ||||
|                     import sys | ||||
|                     print(msg, file=sys.__stdout__) | ||||
|                 else: | ||||
|                     self.fail(msg) | ||||
|  | ||||
|         def tryone(tag, xs, | ||||
|                    native32=None, native64=None, hi32=None, lo32=None, | ||||
|                    zlimit=None): | ||||
|             NHASHBITS = support.NHASHBITS | ||||
|             hashes = list(map(hash, xs)) | ||||
|             tryone_inner(tag + f"; {NHASHBITS}-bit hash codes", | ||||
|                          1 << NHASHBITS, | ||||
|                          hashes, | ||||
|                          native32 if NHASHBITS == 32 else native64, | ||||
|                          zlimit) | ||||
|  | ||||
|             if NHASHBITS > 32: | ||||
|                 shift = NHASHBITS - 32 | ||||
|                 tryone_inner(tag + "; 32-bit upper hash codes", | ||||
|                              1 << 32, | ||||
|                              [h >> shift for h in hashes], | ||||
|                              hi32, | ||||
|                              zlimit) | ||||
|  | ||||
|                 mask = (1 << 32) - 1 | ||||
|                 tryone_inner(tag + "; 32-bit lower hash codes", | ||||
|                              1 << 32, | ||||
|                              [h & mask for h in hashes], | ||||
|                              lo32, | ||||
|                              zlimit) | ||||
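|         # Rough statistics behind tryone_inner (a note, not upstream | ||||
|         # commentary): throwing nballs hashes uniformly at random into | ||||
|         # nbins buckets yields an expected collision count of about | ||||
|         #     nballs - nbins * (1 - (1 - 1/nbins) ** nballs) | ||||
|         # support.collision_stats() returns that mean plus its standard | ||||
|         # deviation, so the z score above measures how many sdevs the | ||||
|         # observed collision count sits above pure chance. | ||||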
|  | ||||
|         # Tuples of smallish positive integers are common - nice if we | ||||
|         # get "better than random" for these. | ||||
|         tryone("range(100) by 3", list(product(range(100), repeat=3)), | ||||
|                (0, 0), (0, 0), (4, 1), (0, 0)) | ||||
|  | ||||
|         # A previous hash had systematic problems when mixing integers of | ||||
|         # similar magnitude but opposite sign, obscurely related to that | ||||
|         # j ^ -2 == -j when j is odd. | ||||
|         cands = list(range(-10, -1)) + list(range(9)) | ||||
|  | ||||
|         # Note:  -1 is omitted because hash(-1) == hash(-2) == -2, and | ||||
|         # there's nothing the tuple hash can do to avoid collisions | ||||
|         # inherited from collisions in the tuple components' hashes. | ||||
|         tryone("-10 .. 8 by 4", list(product(cands, repeat=4)), | ||||
|                (0, 0), (0, 0), (0, 0), (0, 0)) | ||||
|         del cands | ||||
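|         # Worked instance of the identity cited above (a note, not | ||||
|         # upstream commentary): in two's complement 3 ^ -2 == -3 and | ||||
|         # 5 ^ -2 == -5, i.e. for odd j the pair (j, -j) differs only by | ||||
|         # an XOR with the constant -2, a difference an XOR-heavy mix can | ||||
|         # cancel systematically. | ||||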
|  | ||||
|         # The hashes here are a weird mix of values where all the | ||||
|         # variation is in the lowest bits and across a single high-order | ||||
|         # bit - the middle bits are all zeroes. A decent hash has to | ||||
|         # both propagate low bits to the left and high bits to the | ||||
|         # right.  This is also complicated a bit in that there are | ||||
|         # collisions among the hashes of the integers in L alone. | ||||
|         L = [n << 60 for n in range(100)] | ||||
|         tryone("0..99 << 60 by 3", list(product(L, repeat=3)), | ||||
|                (0, 0), (0, 0), (0, 0), (324, 1)) | ||||
|         del L | ||||
|  | ||||
|         # Used to suffer a massive number of collisions. | ||||
|         tryone("[-3, 3] by 18", list(product([-3, 3], repeat=18)), | ||||
|                (7, 1), (0, 0), (7, 1), (6, 1)) | ||||
|  | ||||
|         # And even worse.  hash(0.5) has only a single bit set, at the | ||||
|         # high end. A decent hash needs to propagate high bits right. | ||||
|         tryone("[0, 0.5] by 18", list(product([0, 0.5], repeat=18)), | ||||
|                (5, 1), (0, 0), (9, 1), (12, 1)) | ||||
|  | ||||
|         # Hashes of ints and floats are the same across platforms. | ||||
|         # String hashes vary even on a single platform across runs, due | ||||
|         # to hash randomization for strings.  So we can't say exactly | ||||
|         # what this should do.  Instead we insist that the # of | ||||
|         # collisions is no more than 4 sdevs above the theoretically | ||||
|         # random mean.  Even if the tuple hash can't achieve that on its | ||||
|         # own, the string hash is trying to be decently pseudo-random | ||||
|         # (in all bit positions) on _its_ own.  We can at least test | ||||
|         # that the tuple hash doesn't systematically ruin that. | ||||
|         tryone("4-char tuples", | ||||
|                list(product("abcdefghijklmnopqrstuvwxyz", repeat=4)), | ||||
|                zlimit=4.0) | ||||
|  | ||||
|         # The "old tuple test".  See https://bugs.python.org/issue942952. | ||||
|         # Ensures, for example, that the hash: | ||||
|         #   is non-commutative | ||||
|         #   spreads closely spaced values | ||||
|         #   doesn't exhibit cancellation in tuples like (x,(x,y)) | ||||
|         N = 50 | ||||
|         base = list(range(N)) | ||||
|         xp = list(product(base, repeat=2)) | ||||
|         inps = base + list(product(base, xp)) + \ | ||||
|                      list(product(xp, base)) + xp + list(zip(base)) | ||||
|         tryone("old tuple test", inps, | ||||
|                (2, 1), (0, 0), (52, 49), (7, 1)) | ||||
|         del base, xp, inps | ||||
|  | ||||
|         # The "new tuple test".  See https://bugs.python.org/issue34751. | ||||
|         # Even more tortured nesting, and a mix of signed ints of very | ||||
|         # small magnitude. | ||||
|         n = 5 | ||||
|         A = [x for x in range(-n, n+1) if x != -1] | ||||
|         B = A + [(a,) for a in A] | ||||
|         L2 = list(product(A, repeat=2)) | ||||
|         L3 = L2 + list(product(A, repeat=3)) | ||||
|         L4 = L3 + list(product(A, repeat=4)) | ||||
|         # T = list of testcases. These consist of all (possibly nested | ||||
|         # at most 2 levels deep) tuples containing at most 4 items from | ||||
|         # the set A. | ||||
|         T = A | ||||
|         T += [(a,) for a in B + L4] | ||||
|         T += product(L3, B) | ||||
|         T += product(L2, repeat=2) | ||||
|         T += product(B, L3) | ||||
|         T += product(B, B, L2) | ||||
|         T += product(B, L2, B) | ||||
|         T += product(L2, B, B) | ||||
|         T += product(B, repeat=4) | ||||
|         assert len(T) == 345130 | ||||
|         tryone("new tuple test", T, | ||||
|                (9, 1), (0, 0), (21, 5), (6, 1)) | ||||
|  | ||||
|     def test_repr(self): | ||||
|         l0 = tuple() | ||||
|         l2 = (0, 1, 2) | ||||
|         a0 = self.type2test(l0) | ||||
|         a2 = self.type2test(l2) | ||||
|  | ||||
|         self.assertEqual(str(a0), repr(l0)) | ||||
|         self.assertEqual(str(a2), repr(l2)) | ||||
|         self.assertEqual(repr(a0), "()") | ||||
|         self.assertEqual(repr(a2), "(0, 1, 2)") | ||||
|  | ||||
|     def _not_tracked(self, t): | ||||
|         # Nested tuples can take several collections to untrack | ||||
|         gc.collect() | ||||
|         gc.collect() | ||||
|         self.assertFalse(gc.is_tracked(t), t) | ||||
|  | ||||
|     def _tracked(self, t): | ||||
|         self.assertTrue(gc.is_tracked(t), t) | ||||
|         gc.collect() | ||||
|         gc.collect() | ||||
|         self.assertTrue(gc.is_tracked(t), t) | ||||
|  | ||||
|     @support.cpython_only | ||||
|     def test_track_literals(self): | ||||
|         # Test GC-optimization of tuple literals | ||||
|         x, y, z = 1.5, "a", [] | ||||
|  | ||||
|         self._not_tracked(()) | ||||
|         self._not_tracked((1,)) | ||||
|         self._not_tracked((1, 2)) | ||||
|         self._not_tracked((1, 2, "a")) | ||||
|         self._not_tracked((1, 2, (None, True, False, ()), int)) | ||||
|         self._not_tracked((object(),)) | ||||
|         self._not_tracked(((1, x), y, (2, 3))) | ||||
|  | ||||
|         # Tuples with mutable elements are always tracked, even if those | ||||
|         # elements are not tracked right now. | ||||
|         self._tracked(([],)) | ||||
|         self._tracked(([1],)) | ||||
|         self._tracked(({},)) | ||||
|         self._tracked((set(),)) | ||||
|         self._tracked((x, y, z)) | ||||
|  | ||||
|     def check_track_dynamic(self, tp, always_track): | ||||
|         x, y, z = 1.5, "a", [] | ||||
|  | ||||
|         check = self._tracked if always_track else self._not_tracked | ||||
|         check(tp()) | ||||
|         check(tp([])) | ||||
|         check(tp(set())) | ||||
|         check(tp([1, x, y])) | ||||
|         check(tp(obj for obj in [1, x, y])) | ||||
|         check(tp(set([1, x, y]))) | ||||
|         check(tp(tuple([obj]) for obj in [1, x, y])) | ||||
|         check(tuple(tp([obj]) for obj in [1, x, y])) | ||||
|  | ||||
|         self._tracked(tp([z])) | ||||
|         self._tracked(tp([[x, y]])) | ||||
|         self._tracked(tp([{x: y}])) | ||||
|         self._tracked(tp(obj for obj in [x, y, z])) | ||||
|         self._tracked(tp(tuple([obj]) for obj in [x, y, z])) | ||||
|         self._tracked(tuple(tp([obj]) for obj in [x, y, z])) | ||||
|  | ||||
|     @support.cpython_only | ||||
|     def test_track_dynamic(self): | ||||
|         # Test GC-optimization of dynamically constructed tuples. | ||||
|         self.check_track_dynamic(tuple, False) | ||||
|  | ||||
|     @support.cpython_only | ||||
|     def test_track_subtypes(self): | ||||
|         # Tuple subtypes must always be tracked | ||||
|         class MyTuple(tuple): | ||||
|             pass | ||||
|         self.check_track_dynamic(MyTuple, True) | ||||
|  | ||||
|     @support.cpython_only | ||||
|     def test_bug7466(self): | ||||
|         # Trying to untrack an unfinished tuple could crash Python | ||||
|         self._not_tracked(tuple(gc.collect() for i in range(101))) | ||||
|  | ||||
|     def test_repr_large(self): | ||||
|         # Check the repr of large tuple objects | ||||
|         def check(n): | ||||
|             l = (0,) * n | ||||
|             s = repr(l) | ||||
|             self.assertEqual(s, | ||||
|                 '(' + ', '.join(['0'] * n) + ')') | ||||
|         check(10)       # check our checking code | ||||
|         check(1000000) | ||||
|  | ||||
|     def test_iterator_pickle(self): | ||||
|         # Userlist iterators don't support pickling yet since | ||||
|         # they are based on generators. | ||||
|         data = self.type2test([4, 5, 6, 7]) | ||||
|         for proto in range(pickle.HIGHEST_PROTOCOL + 1): | ||||
|             itorg = iter(data) | ||||
|             d = pickle.dumps(itorg, proto) | ||||
|             it = pickle.loads(d) | ||||
|             self.assertEqual(type(itorg), type(it)) | ||||
|             self.assertEqual(self.type2test(it), self.type2test(data)) | ||||
|  | ||||
|             it = pickle.loads(d) | ||||
|             next(it) | ||||
|             d = pickle.dumps(it, proto) | ||||
|             self.assertEqual(self.type2test(it), self.type2test(data)[1:]) | ||||
|  | ||||
|     def test_reversed_pickle(self): | ||||
|         data = self.type2test([4, 5, 6, 7]) | ||||
|         for proto in range(pickle.HIGHEST_PROTOCOL + 1): | ||||
|             itorg = reversed(data) | ||||
|             d = pickle.dumps(itorg, proto) | ||||
|             it = pickle.loads(d) | ||||
|             self.assertEqual(type(itorg), type(it)) | ||||
|             self.assertEqual(self.type2test(it), self.type2test(reversed(data))) | ||||
|  | ||||
|             it = pickle.loads(d) | ||||
|             next(it) | ||||
|             d = pickle.dumps(it, proto) | ||||
|             self.assertEqual(self.type2test(it), self.type2test(reversed(data))[1:]) | ||||
|  | ||||
|     def test_no_comdat_folding(self): | ||||
|         # Issue 8847: In the PGO build, the MSVC linker's COMDAT folding | ||||
|         # optimization causes failures in code that relies on distinct | ||||
|         # function addresses. | ||||
|         class T(tuple): pass | ||||
|         with self.assertRaises(TypeError): | ||||
|             [3,] + T((1,2)) | ||||
|  | ||||
|     def test_lexicographic_ordering(self): | ||||
|         # Issue 21100 | ||||
|         a = self.type2test([1, 2]) | ||||
|         b = self.type2test([1, 2, 0]) | ||||
|         c = self.type2test([1, 3]) | ||||
|         self.assertLess(a, b) | ||||
|         self.assertLess(b, c) | ||||
|  | ||||
| # Notes on testing hash codes.  The primary thing is that Python doesn't | ||||
| # care about "random" hash codes.  To the contrary, we like them to be | ||||
| # very regular when possible, so that the low-order bits are as evenly | ||||
| # distributed as possible.  For integers this is easy: hash(i) == i for | ||||
| # all not-huge i except i==-1. | ||||
| # | ||||
| # For tuples of mixed type there's really no hope of that, so we want | ||||
| # "randomish" here instead.  But getting close to pseudo-random in all | ||||
| # bit positions is more expensive than we've been willing to pay for. | ||||
| # | ||||
| # We can tolerate large deviations from random - what we don't want is | ||||
| # catastrophic pileups on a relative handful of hash codes.  The dict | ||||
| # and set lookup routines remain effective provided that full-width hash | ||||
| # codes for not-equal objects are distinct. | ||||
| # | ||||
| # So we compute various statistics here based on what a "truly random" | ||||
| # hash would do, but don't automate "pass or fail" based on those | ||||
| # results.  Instead those are viewed as inputs to human judgment, and the | ||||
| # automated tests merely ensure we get the _same_ results across | ||||
| # platforms.  In fact, we normally don't bother to run them at all - | ||||
| # set RUN_ALL_HASH_TESTS to force it. | ||||
| # | ||||
| # When global JUST_SHOW_HASH_RESULTS is True, the tuple hash statistics | ||||
| # are just displayed to stdout.  A typical output line looks like: | ||||
| # | ||||
| # old tuple test; 32-bit upper hash codes; \ | ||||
| #             pileup 49 mean 7.4 coll 52 z +16.4 | ||||
| # | ||||
| # "old tuple test" is just a string name for the test being run. | ||||
| # | ||||
| # "32-bit upper hash codes" means this was run under a 64-bit build and | ||||
| # we've shifted away the lower 32 bits of the hash codes. | ||||
| # | ||||
| # "pileup" is 0 if there were no collisions across those hash codes. | ||||
| # It's 1 less than the maximum number of times any single hash code was | ||||
| # seen.  So in this case, there was (at least) one hash code that was | ||||
| # seen 50 times:  that hash code "piled up" 49 more times than ideal. | ||||
| # | ||||
| # "mean" is the number of collisions a perfectly random hash function | ||||
| # would have yielded, on average. | ||||
| # | ||||
| # "coll" is the number of collisions actually seen. | ||||
| # | ||||
| # "z" is "coll - mean" divided by the standard deviation of the number | ||||
| # of collisions a perfectly random hash function would suffer.  A | ||||
| # positive value is "worse than random", a negative value "better than | ||||
| # random".  Anything of magnitude greater than 3 would be highly suspect | ||||
| # for a hash function that claimed to be random.  It's essentially | ||||
| # impossible that a truly random function would deliver a result 16.4 | ||||
| # sdevs "worse than random". | ||||
| # | ||||
| # But we don't care here!  That's why the test isn't coded to fail. | ||||
| # Knowing something about how the high-order hash code bits behave | ||||
| # provides insight, but is irrelevant to how the dict and set lookup | ||||
| # code performs.  The low-order bits are much more important to that, | ||||
| # and on the same test those did "just like random": | ||||
| # | ||||
| # old tuple test; 32-bit lower hash codes; \ | ||||
| #            pileup 1 mean 7.4 coll 7 z -0.2 | ||||
| # | ||||
| # So there are always tradeoffs to consider.  For another: | ||||
| # | ||||
| # 0..99 << 60 by 3; 32-bit hash codes; \ | ||||
| #            pileup 0 mean 116.4 coll 0 z -10.8 | ||||
| # | ||||
| # That was run under a 32-bit build, and is spectacularly "better than | ||||
| # random".  On a 64-bit build the wider hash codes are fine too: | ||||
| # | ||||
| # 0..99 << 60 by 3; 64-bit hash codes; \ | ||||
| #             pileup 0 mean 0.0 coll 0 z -0.0 | ||||
| # | ||||
| # but their lower 32 bits are poor: | ||||
| # | ||||
| # 0..99 << 60 by 3; 32-bit lower hash codes; \ | ||||
| #             pileup 1 mean 116.4 coll 324 z +19.2 | ||||
| # | ||||
| # In a statistical sense that's waaaaay too many collisions, but (a) 324 | ||||
| # collisions out of a million hash codes isn't anywhere near being a | ||||
| # real problem; and, (b) the worst pileup on a single hash code is a measly | ||||
| # 1 extra.  It's a relatively poor case for the tuple hash, but still | ||||
| # fine for practical use. | ||||
| # | ||||
| # The next case, though, is a genuine problem.  It's what Python 3.7.1 | ||||
| # produced for the hashes of itertools.product([0, 0.5], repeat=18). | ||||
| # Even with a fat 64-bit hashcode, the highest pileup was over 16,000, | ||||
| # making a dict/set lookup on one of the colliding values thousands of | ||||
| # times slower (on average) than we expect. | ||||
| # | ||||
| # [0, 0.5] by 18; 64-bit hash codes; \ | ||||
| #            pileup 16,383 mean 0.0 coll 262,128 z +6073641856.9 | ||||
| # [0, 0.5] by 18; 32-bit lower hash codes; \ | ||||
| #            pileup 262,143 mean 8.0 coll 262,143 z +92683.6 | ||||
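|  | ||||
| # A minimal sketch (not the helper these tests actually use) of how the | ||||
| # pileup/mean/coll/z statistics above can be computed.  It assumes k | ||||
| # hash codes thrown uniformly at random into n bins with k*k << n, in | ||||
| # which regime the collision count is roughly Poisson with mean | ||||
| # k*(k-1)/(2*n) and sdev about sqrt(mean). | ||||
| def _collision_stats_sketch(nbins, hashes): | ||||
|     import math | ||||
|     from collections import Counter | ||||
|     counts = Counter(hashes) | ||||
|     coll = len(hashes) - len(counts)     # collisions actually seen | ||||
|     pileup = max(counts.values()) - 1    # worst pileup on a single code | ||||
|     n, k = nbins, len(hashes) | ||||
|     mean = k * (k - 1) / (2.0 * n)       # expected collisions if random | ||||
|     sdev = math.sqrt(mean) | ||||
|     z = (coll - mean) / sdev if sdev else 0.0 | ||||
|     return pileup, mean, coll, z | ||||
| # Sanity check against the sample output above:  for "0..99 << 60 by 3" | ||||
| # with 32-bit lower hash codes, k=10**6 and n=2**32 give mean 116.4, | ||||
| # and coll=324 then gives z +19.2. | ||||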
|  | ||||
| if __name__ == "__main__": | ||||
|     run_tests() | ||||
test/dynamo/cpython/3_13/test_userdict.diff (new file, 74 lines)
							| @ -0,0 +1,74 @@ | ||||
| diff --git a/test/dynamo/cpython/3_13/test_userdict.py b/test/dynamo/cpython/3_13/test_userdict.py | ||||
| index 61e79f553e8..c953390355e 100644 | ||||
| --- a/test/dynamo/cpython/3_13/test_userdict.py | ||||
| +++ b/test/dynamo/cpython/3_13/test_userdict.py | ||||
| @@ -1,3 +1,54 @@ | ||||
| +# ======= BEGIN Dynamo patch ======= | ||||
| +# Owner(s): ["module: dynamo"] | ||||
| + | ||||
| +# ruff: noqa | ||||
| +# flake8: noqa | ||||
| + | ||||
| +import sys | ||||
| +import torch | ||||
| +import torch._dynamo.test_case | ||||
| +import unittest | ||||
| +from torch._dynamo.test_case import CPythonTestCase | ||||
| +from torch.testing._internal.common_utils import run_tests | ||||
| + | ||||
| +__TestCase = CPythonTestCase | ||||
| + | ||||
| + | ||||
| +# redirect import statements | ||||
| +import sys | ||||
| +import importlib.abc | ||||
| + | ||||
| +redirect_imports = ( | ||||
| +    "test.mapping_tests", | ||||
| +    "test.typinganndata", | ||||
| +    "test.test_grammar", | ||||
| +    "test.test_math", | ||||
| +    "test.test_iter", | ||||
| +    "test.typinganndata.ann_module", | ||||
| +) | ||||
| + | ||||
| +class RedirectImportFinder(importlib.abc.MetaPathFinder): | ||||
| +    def find_spec(self, fullname, path, target=None): | ||||
| +        # Check if the import is the problematic one | ||||
| +        if fullname in redirect_imports: | ||||
| +            try: | ||||
| +                # Attempt to import the standalone module | ||||
| +                name = fullname.removeprefix("test.") | ||||
| +                r = importlib.import_module(name) | ||||
| +                # Redirect the module in sys.modules | ||||
| +                sys.modules[fullname] = r | ||||
| +                # Return a module spec from the found module | ||||
| +                return importlib.util.find_spec(name) | ||||
| +            except ImportError: | ||||
| +                return None | ||||
| +        return None | ||||
| + | ||||
| +# Add the custom finder to sys.meta_path | ||||
| +sys.meta_path.insert(0, RedirectImportFinder()) | ||||
| + | ||||
| + | ||||
| +# ======= END DYNAMO PATCH ======= | ||||
| + | ||||
|  # Check every path through every method of UserDict | ||||
|   | ||||
|  from test import mapping_tests, support | ||||
| @@ -215,10 +266,10 @@ class UserDictTest(mapping_tests.TestHashMappingProtocol): | ||||
|   | ||||
|      # Decorate existing test with recursion limit, because | ||||
|      # the test is for C structure, but `UserDict` is a Python structure. | ||||
| -    test_repr_deep = support.infinite_recursion(25)( | ||||
| -        mapping_tests.TestHashMappingProtocol.test_repr_deep, | ||||
| -    ) | ||||
| +    # test_repr_deep = support.infinite_recursion(25)( | ||||
| +    #     mapping_tests.TestHashMappingProtocol.test_repr_deep, | ||||
| +    # ) | ||||
|   | ||||
|   | ||||
|  if __name__ == "__main__": | ||||
| -    unittest.main() | ||||
| +    run_tests() | ||||
test/dynamo/cpython/3_13/test_userdict.py (new file, 275 lines)
							| @ -0,0 +1,275 @@ | ||||
| # ======= BEGIN Dynamo patch ======= | ||||
| # Owner(s): ["module: dynamo"] | ||||
|  | ||||
| # ruff: noqa | ||||
| # flake8: noqa | ||||
|  | ||||
| import sys | ||||
| import torch | ||||
| import torch._dynamo.test_case | ||||
| import unittest | ||||
| from torch._dynamo.test_case import CPythonTestCase | ||||
| from torch.testing._internal.common_utils import run_tests | ||||
|  | ||||
| __TestCase = CPythonTestCase | ||||
|  | ||||
|  | ||||
| # redirect import statements | ||||
| import sys | ||||
| import importlib.abc | ||||
|  | ||||
| redirect_imports = ( | ||||
|     "test.mapping_tests", | ||||
|     "test.typinganndata", | ||||
|     "test.test_grammar", | ||||
|     "test.test_math", | ||||
|     "test.test_iter", | ||||
|     "test.typinganndata.ann_module", | ||||
| ) | ||||
|  | ||||
| class RedirectImportFinder(importlib.abc.MetaPathFinder): | ||||
|     def find_spec(self, fullname, path, target=None): | ||||
|         # Check if the import is the problematic one | ||||
|         if fullname in redirect_imports: | ||||
|             try: | ||||
|                 # Attempt to import the standalone module | ||||
|                 name = fullname.removeprefix("test.") | ||||
|                 r = importlib.import_module(name) | ||||
|                 # Redirect the module in sys.modules | ||||
|                 sys.modules[fullname] = r | ||||
|                 # Return a module spec from the found module | ||||
|                 return importlib.util.find_spec(name) | ||||
|             except ImportError: | ||||
|                 return None | ||||
|         return None | ||||
|  | ||||
| # Add the custom finder to sys.meta_path | ||||
| sys.meta_path.insert(0, RedirectImportFinder()) | ||||
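|  | ||||
| # Note (illustration, not part of the upstream test): with this finder | ||||
| # first on sys.meta_path, the `test.mapping_tests` import below resolves | ||||
| # to the standalone mapping_tests module that sits alongside these | ||||
| # tests, while `test.support` (absent from redirect_imports) still loads | ||||
| # from CPython's own test package. | ||||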
|  | ||||
|  | ||||
| # ======= END DYNAMO PATCH ======= | ||||
|  | ||||
| # Check every path through every method of UserDict | ||||
|  | ||||
| from test import mapping_tests, support | ||||
| import unittest | ||||
| import collections | ||||
|  | ||||
| d0 = {} | ||||
| d1 = {"one": 1} | ||||
| d2 = {"one": 1, "two": 2} | ||||
| d3 = {"one": 1, "two": 3, "three": 5} | ||||
| d4 = {"one": None, "two": None} | ||||
| d5 = {"one": 1, "two": 1} | ||||
|  | ||||
| class UserDictTest(mapping_tests.TestHashMappingProtocol): | ||||
|     type2test = collections.UserDict | ||||
|  | ||||
|     def test_all(self): | ||||
|         # Test constructors | ||||
|         u = collections.UserDict() | ||||
|         u0 = collections.UserDict(d0) | ||||
|         u1 = collections.UserDict(d1) | ||||
|         u2 = collections.UserDict(d2) | ||||
|  | ||||
|         uu = collections.UserDict(u) | ||||
|         uu0 = collections.UserDict(u0) | ||||
|         uu1 = collections.UserDict(u1) | ||||
|         uu2 = collections.UserDict(u2) | ||||
|  | ||||
|         # keyword arg constructor | ||||
|         self.assertEqual(collections.UserDict(one=1, two=2), d2) | ||||
|         # item sequence constructor | ||||
|         self.assertEqual(collections.UserDict([('one',1), ('two',2)]), d2) | ||||
|         self.assertEqual(collections.UserDict(dict=[('one',1), ('two',2)]), | ||||
|                          {'dict': [('one', 1), ('two', 2)]}) | ||||
|         # both together | ||||
|         self.assertEqual(collections.UserDict([('one',1), ('two',2)], two=3, three=5), d3) | ||||
|  | ||||
|         # alternate constructor | ||||
|         self.assertEqual(collections.UserDict.fromkeys('one two'.split()), d4) | ||||
|         self.assertEqual(collections.UserDict().fromkeys('one two'.split()), d4) | ||||
|         self.assertEqual(collections.UserDict.fromkeys('one two'.split(), 1), d5) | ||||
|         self.assertEqual(collections.UserDict().fromkeys('one two'.split(), 1), d5) | ||||
|         self.assertTrue(u1.fromkeys('one two'.split()) is not u1) | ||||
|         self.assertIsInstance(u1.fromkeys('one two'.split()), collections.UserDict) | ||||
|         self.assertIsInstance(u2.fromkeys('one two'.split()), collections.UserDict) | ||||
|  | ||||
|         # Test __repr__ | ||||
|         self.assertEqual(str(u0), str(d0)) | ||||
|         self.assertEqual(repr(u1), repr(d1)) | ||||
|         self.assertIn(repr(u2), ("{'one': 1, 'two': 2}", | ||||
|                                  "{'two': 2, 'one': 1}")) | ||||
|  | ||||
|         # Test rich comparison and __len__ | ||||
|         all = [d0, d1, d2, u, u0, u1, u2, uu, uu0, uu1, uu2] | ||||
|         for a in all: | ||||
|             for b in all: | ||||
|                 self.assertEqual(a == b, len(a) == len(b)) | ||||
|  | ||||
|         # Test __getitem__ | ||||
|         self.assertEqual(u2["one"], 1) | ||||
|         self.assertRaises(KeyError, u1.__getitem__, "two") | ||||
|  | ||||
|         # Test __setitem__ | ||||
|         u3 = collections.UserDict(u2) | ||||
|         u3["two"] = 2 | ||||
|         u3["three"] = 3 | ||||
|  | ||||
|         # Test __delitem__ | ||||
|         del u3["three"] | ||||
|         self.assertRaises(KeyError, u3.__delitem__, "three") | ||||
|  | ||||
|         # Test clear | ||||
|         u3.clear() | ||||
|         self.assertEqual(u3, {}) | ||||
|  | ||||
|         # Test copy() | ||||
|         u2a = u2.copy() | ||||
|         self.assertEqual(u2a, u2) | ||||
|         u2b = collections.UserDict(x=42, y=23) | ||||
|         u2c = u2b.copy() # making a copy of a UserDict is special cased | ||||
|         self.assertEqual(u2b, u2c) | ||||
|  | ||||
|         class MyUserDict(collections.UserDict): | ||||
|             def display(self): print(self) | ||||
|  | ||||
|         m2 = MyUserDict(u2) | ||||
|         m2a = m2.copy() | ||||
|         self.assertEqual(m2a, m2) | ||||
|  | ||||
|         # SF bug #476616 -- copy() of UserDict subclass shared data | ||||
|         m2['foo'] = 'bar' | ||||
|         self.assertNotEqual(m2a, m2) | ||||
|  | ||||
|         # Test keys, items, values | ||||
|         self.assertEqual(sorted(u2.keys()), sorted(d2.keys())) | ||||
|         self.assertEqual(sorted(u2.items()), sorted(d2.items())) | ||||
|         self.assertEqual(sorted(u2.values()), sorted(d2.values())) | ||||
|  | ||||
|         # Test "in". | ||||
|         for i in u2.keys(): | ||||
|             self.assertIn(i, u2) | ||||
|             self.assertEqual(i in u1, i in d1) | ||||
|             self.assertEqual(i in u0, i in d0) | ||||
|  | ||||
|         # Test update | ||||
|         t = collections.UserDict() | ||||
|         t.update(u2) | ||||
|         self.assertEqual(t, u2) | ||||
|  | ||||
|         # Test get | ||||
|         for i in u2.keys(): | ||||
|             self.assertEqual(u2.get(i), u2[i]) | ||||
|             self.assertEqual(u1.get(i), d1.get(i)) | ||||
|             self.assertEqual(u0.get(i), d0.get(i)) | ||||
|  | ||||
|         # Test "in" iteration. | ||||
|         for i in range(20): | ||||
|             u2[i] = str(i) | ||||
|         ikeys = [] | ||||
|         for k in u2: | ||||
|             ikeys.append(k) | ||||
|         keys = u2.keys() | ||||
|         self.assertEqual(set(ikeys), set(keys)) | ||||
|  | ||||
|         # Test setdefault | ||||
|         t = collections.UserDict() | ||||
|         self.assertEqual(t.setdefault("x", 42), 42) | ||||
|         self.assertIn("x", t) | ||||
|         self.assertEqual(t.setdefault("x", 23), 42) | ||||
|  | ||||
|         # Test pop | ||||
|         t = collections.UserDict(x=42) | ||||
|         self.assertEqual(t.pop("x"), 42) | ||||
|         self.assertRaises(KeyError, t.pop, "x") | ||||
|         self.assertEqual(t.pop("x", 1), 1) | ||||
|         t["x"] = 42 | ||||
|         self.assertEqual(t.pop("x", 1), 42) | ||||
|  | ||||
|         # Test popitem | ||||
|         t = collections.UserDict(x=42) | ||||
|         self.assertEqual(t.popitem(), ("x", 42)) | ||||
|         self.assertRaises(KeyError, t.popitem) | ||||
|  | ||||
|     def test_init(self): | ||||
|         for kw in 'self', 'other', 'iterable': | ||||
|             self.assertEqual(list(collections.UserDict(**{kw: 42}).items()), | ||||
|                              [(kw, 42)]) | ||||
|         self.assertEqual(list(collections.UserDict({}, dict=42).items()), | ||||
|                          [('dict', 42)]) | ||||
|         self.assertEqual(list(collections.UserDict({}, dict=None).items()), | ||||
|                          [('dict', None)]) | ||||
|         self.assertEqual(list(collections.UserDict(dict={'a': 42}).items()), | ||||
|                          [('dict', {'a': 42})]) | ||||
|         self.assertRaises(TypeError, collections.UserDict, 42) | ||||
|         self.assertRaises(TypeError, collections.UserDict, (), ()) | ||||
|         self.assertRaises(TypeError, collections.UserDict.__init__) | ||||
|  | ||||
|     def test_update(self): | ||||
|         for kw in 'self', 'dict', 'other', 'iterable': | ||||
|             d = collections.UserDict() | ||||
|             d.update(**{kw: 42}) | ||||
|             self.assertEqual(list(d.items()), [(kw, 42)]) | ||||
|         self.assertRaises(TypeError, collections.UserDict().update, 42) | ||||
|         self.assertRaises(TypeError, collections.UserDict().update, {}, {}) | ||||
|         self.assertRaises(TypeError, collections.UserDict.update) | ||||
|  | ||||
|     def test_missing(self): | ||||
|         # Make sure UserDict doesn't have a __missing__ method | ||||
|         self.assertEqual(hasattr(collections.UserDict, "__missing__"), False) | ||||
|         # Test several cases: | ||||
|         # (D) subclass defines __missing__ method returning a value | ||||
|         # (E) subclass defines __missing__ method raising RuntimeError | ||||
|         # (F) subclass sets __missing__ instance variable (no effect) | ||||
|         # (G) subclass doesn't define __missing__ at all | ||||
|         class D(collections.UserDict): | ||||
|             def __missing__(self, key): | ||||
|                 return 42 | ||||
|         d = D({1: 2, 3: 4}) | ||||
|         self.assertEqual(d[1], 2) | ||||
|         self.assertEqual(d[3], 4) | ||||
|         self.assertNotIn(2, d) | ||||
|         self.assertNotIn(2, d.keys()) | ||||
|         self.assertEqual(d[2], 42) | ||||
|         class E(collections.UserDict): | ||||
|             def __missing__(self, key): | ||||
|                 raise RuntimeError(key) | ||||
|         e = E() | ||||
|         try: | ||||
|             e[42] | ||||
|         except RuntimeError as err: | ||||
|             self.assertEqual(err.args, (42,)) | ||||
|         else: | ||||
|             self.fail("e[42] didn't raise RuntimeError") | ||||
|         class F(collections.UserDict): | ||||
|             def __init__(self): | ||||
|                 # An instance variable __missing__ should have no effect | ||||
|                 self.__missing__ = lambda key: None | ||||
|                 collections.UserDict.__init__(self) | ||||
|         f = F() | ||||
|         try: | ||||
|             f[42] | ||||
|         except KeyError as err: | ||||
|             self.assertEqual(err.args, (42,)) | ||||
|         else: | ||||
|             self.fail("f[42] didn't raise KeyError") | ||||
|         class G(collections.UserDict): | ||||
|             pass | ||||
|         g = G() | ||||
|         try: | ||||
|             g[42] | ||||
|         except KeyError as err: | ||||
|             self.assertEqual(err.args, (42,)) | ||||
|         else: | ||||
|             self.fail("g[42] didn't raise KeyError") | ||||
|  | ||||
|     # Decorate existing test with recursion limit, because | ||||
|     # the test is for C structure, but `UserDict` is a Python structure. | ||||
|     # test_repr_deep = support.infinite_recursion(25)( | ||||
|     #     mapping_tests.TestHashMappingProtocol.test_repr_deep, | ||||
|     # ) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     run_tests() | ||||
test/dynamo/cpython/3_13/test_userlist.diff (new file, 78 lines)
							| @ -0,0 +1,78 @@ | ||||
| diff --git a/test/dynamo/cpython/3_13/test_userlist.py b/test/dynamo/cpython/3_13/test_userlist.py | ||||
| index 312702c8e39..a4532922f5d 100644 | ||||
| --- a/test/dynamo/cpython/3_13/test_userlist.py | ||||
| +++ b/test/dynamo/cpython/3_13/test_userlist.py | ||||
| @@ -1,7 +1,58 @@ | ||||
| +# ======= BEGIN Dynamo patch ======= | ||||
| +# Owner(s): ["module: dynamo"] | ||||
| + | ||||
| +# ruff: noqa | ||||
| +# flake8: noqa | ||||
| + | ||||
| +import sys | ||||
| +import torch | ||||
| +import torch._dynamo.test_case | ||||
| +import unittest | ||||
| +from torch._dynamo.test_case import CPythonTestCase | ||||
| +from torch.testing._internal.common_utils import run_tests | ||||
| + | ||||
| +__TestCase = CPythonTestCase | ||||
| + | ||||
| + | ||||
| +# redirect import statements | ||||
| +import sys | ||||
| +import importlib.abc | ||||
| + | ||||
| +redirect_imports = ( | ||||
| +    "test.mapping_tests", | ||||
| +    "test.typinganndata", | ||||
| +    "test.test_grammar", | ||||
| +    "test.test_math", | ||||
| +    "test.test_iter", | ||||
| +    "test.typinganndata.ann_module", | ||||
| +) | ||||
| + | ||||
| +class RedirectImportFinder(importlib.abc.MetaPathFinder): | ||||
| +    def find_spec(self, fullname, path, target=None): | ||||
| +        # Check if the import is the problematic one | ||||
| +        if fullname in redirect_imports: | ||||
| +            try: | ||||
| +                # Attempt to import the standalone module | ||||
| +                name = fullname.removeprefix("test.") | ||||
| +                r = importlib.import_module(name) | ||||
| +                # Redirect the module in sys.modules | ||||
| +                sys.modules[fullname] = r | ||||
| +                # Return a module spec from the found module | ||||
| +                return importlib.util.find_spec(name) | ||||
| +            except ImportError: | ||||
| +                return None | ||||
| +        return None | ||||
| + | ||||
| +# Add the custom finder to sys.meta_path | ||||
| +sys.meta_path.insert(0, RedirectImportFinder()) | ||||
| + | ||||
| + | ||||
| +# ======= END DYNAMO PATCH ======= | ||||
| + | ||||
|  # Check every path through every method of UserList | ||||
|   | ||||
|  from collections import UserList | ||||
| -from test import list_tests | ||||
| +import list_tests | ||||
|  import unittest | ||||
|  from test import support | ||||
|   | ||||
| @@ -69,9 +120,9 @@ class UserListTest(list_tests.CommonTest): | ||||
|   | ||||
|      # Decorate existing test with recursion limit, because | ||||
|      # the test is for C structure, but `UserList` is a Python structure. | ||||
| -    test_repr_deep = support.infinite_recursion(25)( | ||||
| -        list_tests.CommonTest.test_repr_deep, | ||||
| -    ) | ||||
| +    # test_repr_deep = support.infinite_recursion(25)( | ||||
| +    #     list_tests.CommonTest.test_repr_deep, | ||||
| +    # ) | ||||
|   | ||||
|  if __name__ == "__main__": | ||||
| -    unittest.main() | ||||
| +    run_tests() | ||||
test/dynamo/cpython/3_13/test_userlist.py (new file, 128 lines)
							| @ -0,0 +1,128 @@ | ||||
| # ======= BEGIN Dynamo patch ======= | ||||
| # Owner(s): ["module: dynamo"] | ||||
|  | ||||
| # ruff: noqa | ||||
| # flake8: noqa | ||||
|  | ||||
| import sys | ||||
| import torch | ||||
| import torch._dynamo.test_case | ||||
| import unittest | ||||
| from torch._dynamo.test_case import CPythonTestCase | ||||
| from torch.testing._internal.common_utils import run_tests | ||||
|  | ||||
| __TestCase = CPythonTestCase | ||||
|  | ||||
|  | ||||
| # redirect import statements | ||||
| import sys | ||||
| import importlib.abc | ||||
|  | ||||
| redirect_imports = ( | ||||
|     "test.mapping_tests", | ||||
|     "test.typinganndata", | ||||
|     "test.test_grammar", | ||||
|     "test.test_math", | ||||
|     "test.test_iter", | ||||
|     "test.typinganndata.ann_module", | ||||
| ) | ||||
|  | ||||
| class RedirectImportFinder(importlib.abc.MetaPathFinder): | ||||
|     def find_spec(self, fullname, path, target=None): | ||||
|         # Check if the import is the problematic one | ||||
|         if fullname in redirect_imports: | ||||
|             try: | ||||
|                 # Attempt to import the standalone module | ||||
|                 name = fullname.removeprefix("test.") | ||||
|                 r = importlib.import_module(name) | ||||
|                 # Redirect the module in sys.modules | ||||
|                 sys.modules[fullname] = r | ||||
|                 # Return a module spec from the found module | ||||
|                 return importlib.util.find_spec(name) | ||||
|             except ImportError: | ||||
|                 return None | ||||
|         return None | ||||
|  | ||||
| # Add the custom finder to sys.meta_path | ||||
| sys.meta_path.insert(0, RedirectImportFinder()) | ||||
|  | ||||
|  | ||||
| # ======= END DYNAMO PATCH ======= | ||||
|  | ||||
| # Check every path through every method of UserList | ||||
|  | ||||
| from collections import UserList | ||||
| import list_tests | ||||
| import unittest | ||||
| from test import support | ||||
|  | ||||
|  | ||||
| class UserListTest(list_tests.CommonTest): | ||||
|     type2test = UserList | ||||
|  | ||||
|     def test_getslice(self): | ||||
|         super().test_getslice() | ||||
|         l = [0, 1, 2, 3, 4] | ||||
|         u = self.type2test(l) | ||||
|         for i in range(-3, 6): | ||||
|             self.assertEqual(u[:i], l[:i]) | ||||
|             self.assertEqual(u[i:], l[i:]) | ||||
|             for j in range(-3, 6): | ||||
|                 self.assertEqual(u[i:j], l[i:j]) | ||||
|  | ||||
|     def test_slice_type(self): | ||||
|         l = [0, 1, 2, 3, 4] | ||||
|         u = UserList(l) | ||||
|         self.assertIsInstance(u[:], u.__class__) | ||||
|         self.assertEqual(u[:],u) | ||||
|  | ||||
|     def test_add_specials(self): | ||||
|         u = UserList("spam") | ||||
|         u2 = u + "eggs" | ||||
|         self.assertEqual(u2, list("spameggs")) | ||||
|  | ||||
|     def test_radd_specials(self): | ||||
|         u = UserList("eggs") | ||||
|         u2 = "spam" + u | ||||
|         self.assertEqual(u2, list("spameggs")) | ||||
|         u2 = u.__radd__(UserList("spam")) | ||||
|         self.assertEqual(u2, list("spameggs")) | ||||
|  | ||||
|     def test_iadd(self): | ||||
|         super().test_iadd() | ||||
|         u = [0, 1] | ||||
|         u += UserList([0, 1]) | ||||
|         self.assertEqual(u, [0, 1, 0, 1]) | ||||
|  | ||||
|     def test_mixedcmp(self): | ||||
|         u = self.type2test([0, 1]) | ||||
|         self.assertEqual(u, [0, 1]) | ||||
|         self.assertNotEqual(u, [0]) | ||||
|         self.assertNotEqual(u, [0, 2]) | ||||
|  | ||||
|     def test_mixedadd(self): | ||||
|         u = self.type2test([0, 1]) | ||||
|         self.assertEqual(u + [], u) | ||||
|         self.assertEqual(u + [2], [0, 1, 2]) | ||||
|  | ||||
|     def test_getitemoverwriteiter(self): | ||||
|         # Verify that __getitem__ overrides *are* recognized by __iter__ | ||||
|         class T(self.type2test): | ||||
|             def __getitem__(self, key): | ||||
|                 return str(key) + '!!!' | ||||
|         self.assertEqual(next(iter(T((1,2)))), "0!!!") | ||||
|  | ||||
|     def test_userlist_copy(self): | ||||
|         u = self.type2test([6, 8, 1, 9, 1]) | ||||
|         v = u.copy() | ||||
|         self.assertEqual(u, v) | ||||
|         self.assertEqual(type(u), type(v)) | ||||
|  | ||||
|     # Decorate existing test with recursion limit, because | ||||
|     # the test is for C structure, but `UserList` is a Python structure. | ||||
|     # test_repr_deep = support.infinite_recursion(25)( | ||||
|     #     list_tests.CommonTest.test_repr_deep, | ||||
|     # ) | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     run_tests() | ||||
test/dynamo/cpython/3_13/typinganndata/__init__.py (new file, 0 lines)
test/dynamo/cpython/3_13/typinganndata/_typed_dict_helper.py (new file, 30 lines)
							| @ -0,0 +1,30 @@ | ||||
| """Used to test `get_type_hints()` on a cross-module inherited `TypedDict` class | ||||
|  | ||||
| This script uses future annotations to postpone a type that won't be available | ||||
| on the module inheriting from to `Foo`. The subclass in the other module should | ||||
| look something like this: | ||||
|  | ||||
|     class Bar(_typed_dict_helper.Foo, total=False): | ||||
|         b: int | ||||
|  | ||||
| In addition, it uses multiple levels of Annotated to test the interaction | ||||
| between the __future__ import, Annotated, and Required. | ||||
| """ | ||||
|  | ||||
| from __future__ import annotations | ||||
|  | ||||
| from typing import Annotated, Generic, Optional, Required, TypedDict, TypeVar | ||||
|  | ||||
|  | ||||
| OptionalIntType = Optional[int] | ||||
|  | ||||
| class Foo(TypedDict): | ||||
|     a: OptionalIntType | ||||
|  | ||||
| T = TypeVar("T") | ||||
|  | ||||
| class FooGeneric(TypedDict, Generic[T]): | ||||
|     a: Optional[T] | ||||
|  | ||||
| class VeryAnnotated(TypedDict, total=False): | ||||
|     a: Annotated[Annotated[Annotated[Required[int], "a"], "b"], "c"] | ||||
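|  | ||||
| # Illustrative cross-module use (assumed, not part of this helper): | ||||
| # | ||||
| #     from typing import Optional, get_type_hints | ||||
| #     import _typed_dict_helper | ||||
| # | ||||
| #     class Bar(_typed_dict_helper.Foo, total=False): | ||||
| #         b: int | ||||
| # | ||||
| #     assert get_type_hints(Bar) == {"a": Optional[int], "b": int} | ||||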
test/dynamo/cpython/3_13/typinganndata/ann_module.py (new file, 62 lines)
							| @ -0,0 +1,62 @@ | ||||
|  | ||||
|  | ||||
| """ | ||||
| The module for testing variable annotations. | ||||
| Empty lines above are for good reason (testing for correct line numbers) | ||||
| """ | ||||
|  | ||||
| from typing import Optional | ||||
| from functools import wraps | ||||
|  | ||||
| __annotations__[1] = 2 | ||||
|  | ||||
| class C: | ||||
|  | ||||
|     x = 5; y: Optional['C'] = None | ||||
|  | ||||
| from typing import Tuple | ||||
| x: int = 5; y: str = x; f: Tuple[int, int] | ||||
|  | ||||
| class M(type): | ||||
|  | ||||
|     __annotations__['123'] = 123 | ||||
|     o: type = object | ||||
|  | ||||
| (pars): bool = True | ||||
|  | ||||
| class D(C): | ||||
|     j: str = 'hi'; k: str= 'bye' | ||||
|  | ||||
| from types import new_class | ||||
| h_class = new_class('H', (C,)) | ||||
| j_class = new_class('J') | ||||
|  | ||||
| class F(): | ||||
|     z: int = 5 | ||||
|     def __init__(self, x): | ||||
|         pass | ||||
|  | ||||
| class Y(F): | ||||
|     def __init__(self): | ||||
|         super(F, self).__init__(123) | ||||
|  | ||||
| class Meta(type): | ||||
|     def __new__(meta, name, bases, namespace): | ||||
|         return super().__new__(meta, name, bases, namespace) | ||||
|  | ||||
| class S(metaclass = Meta): | ||||
|     x: str = 'something' | ||||
|     y: str = 'something else' | ||||
|  | ||||
| def foo(x: int = 10): | ||||
|     def bar(y: List[str]): | ||||
|         x: str = 'yes' | ||||
|     bar() | ||||
|  | ||||
| def dec(func): | ||||
|     @wraps(func) | ||||
|     def wrapper(*args, **kwargs): | ||||
|         return func(*args, **kwargs) | ||||
|     return wrapper | ||||
|  | ||||
| u: int | float | ||||
test/dynamo/cpython/3_13/typinganndata/ann_module2.py (new file, 36 lines)
							| @ -0,0 +1,36 @@ | ||||
| """ | ||||
| Some correct syntax for variable annotation here. | ||||
| More examples are in test_grammar and test_parser. | ||||
| """ | ||||
|  | ||||
| from typing import no_type_check, ClassVar | ||||
|  | ||||
| i: int = 1 | ||||
| j: int | ||||
| x: float = i/10 | ||||
|  | ||||
| def f(): | ||||
|     class C: ... | ||||
|     return C() | ||||
|  | ||||
| f().new_attr: object = object() | ||||
|  | ||||
| class C: | ||||
|     def __init__(self, x: int) -> None: | ||||
|         self.x = x | ||||
|  | ||||
| c = C(5) | ||||
| c.new_attr: int = 10 | ||||
|  | ||||
| __annotations__ = {} | ||||
|  | ||||
|  | ||||
| @no_type_check | ||||
| class NTC: | ||||
|     def meth(self, param: complex) -> None: | ||||
|         ... | ||||
|  | ||||
| class CV: | ||||
|     var: ClassVar['CV'] | ||||
|  | ||||
| CV.var = CV() | ||||
test/dynamo/cpython/3_13/typinganndata/ann_module3.py (new file, 18 lines)
							| @ -0,0 +1,18 @@ | ||||
| """ | ||||
| Correct syntax for variable annotation that should fail at runtime | ||||
| in a certain manner. More examples are in test_grammar and test_parser. | ||||
| """ | ||||
|  | ||||
| def f_bad_ann(): | ||||
|     __annotations__[1] = 2 | ||||
|  | ||||
| class C_OK: | ||||
|     def __init__(self, x: int) -> None: | ||||
|         self.x: no_such_name = x  # This one is OK as proposed by Guido | ||||
|  | ||||
| class D_bad_ann: | ||||
|     def __init__(self, x: int) -> None: | ||||
|         sfel.y: int = 0 | ||||
|  | ||||
| def g_bad_ann(): | ||||
|     no_such_name.attr: int = 0 | ||||
test/dynamo/cpython/3_13/typinganndata/ann_module4.py (new file, 5 lines)
							| @ -0,0 +1,5 @@ | ||||
| # This ann_module isn't for test_typing, | ||||
| # it's for test_module | ||||
|  | ||||
| a:int=3 | ||||
| b:str=4 | ||||
test/dynamo/cpython/3_13/typinganndata/ann_module5.py (new file, 10 lines)
							| @ -0,0 +1,10 @@ | ||||
| # Used by test_typing to verify that Final wrapped in ForwardRef works. | ||||
|  | ||||
| from __future__ import annotations | ||||
|  | ||||
| from typing import Final | ||||
|  | ||||
| name: Final[str] = "final" | ||||
|  | ||||
| class MyClass: | ||||
|     value: Final = 3000 | ||||
test/dynamo/cpython/3_13/typinganndata/ann_module6.py (new file, 7 lines)
							| @ -0,0 +1,7 @@ | ||||
| # Tests that top-level ClassVar is not allowed | ||||
|  | ||||
| from __future__ import annotations | ||||
|  | ||||
| from typing import ClassVar | ||||
|  | ||||
| wrong: ClassVar[int] = 1 | ||||
test/dynamo/cpython/3_13/typinganndata/ann_module695.py (new file, 22 lines)
							| @ -0,0 +1,22 @@ | ||||
| from __future__ import annotations | ||||
| from typing import Callable | ||||
|  | ||||
|  | ||||
| class A[T, *Ts, **P]: | ||||
|     x: T | ||||
|     y: tuple[*Ts] | ||||
|     z: Callable[P, str] | ||||
|  | ||||
|  | ||||
| class B[T, *Ts, **P]: | ||||
|     T = int | ||||
|     Ts = str | ||||
|     P = bytes | ||||
|     x: T | ||||
|     y: Ts | ||||
|     z: P | ||||
|  | ||||
|  | ||||
| def generic_function[T, *Ts, **P]( | ||||
|     x: T, *y: *Ts, z: P.args, zz: P.kwargs | ||||
| ) -> None: ... | ||||
test/dynamo/cpython/3_13/typinganndata/ann_module7.py (new file, 11 lines)
							| @ -0,0 +1,11 @@ | ||||
| # Tests that classes have ``__text_signature__`` | ||||
|  | ||||
| from __future__ import annotations | ||||
|  | ||||
| DEFAULT_BUFFER_SIZE = 8192 | ||||
|  | ||||
| class BufferedReader(object): | ||||
|     """BufferedReader(raw, buffer_size=DEFAULT_BUFFER_SIZE)\n--\n\n | ||||
|     Create a new buffered reader using the given readable raw IO object. | ||||
|     """ | ||||
|     pass | ||||
test/dynamo/cpython/3_13/typinganndata/ann_module8.py (new file, 10 lines)
							| @ -0,0 +1,10 @@ | ||||
| # Test `@no_type_check`, | ||||
| # see https://bugs.python.org/issue46571 | ||||
|  | ||||
| class NoTypeCheck_Outer: | ||||
|     class Inner: | ||||
|         x: int | ||||
|  | ||||
|  | ||||
| def NoTypeCheck_function(arg: int) -> int: | ||||
|     ... | ||||
test/dynamo/cpython/3_13/typinganndata/ann_module9.py (new file, 14 lines)
							| @ -0,0 +1,14 @@ | ||||
| # Test ``inspect.formatannotation`` | ||||
| # https://github.com/python/cpython/issues/96073 | ||||
|  | ||||
| from typing import Union, List | ||||
|  | ||||
| ann = Union[List[str], int] | ||||
|  | ||||
| # mock typing._type_repr behaviour | ||||
| class A: ... | ||||
|  | ||||
| A.__module__ = 'testModule.typing' | ||||
| A.__qualname__ = 'A' | ||||
|  | ||||
| ann1 = Union[List[A], int] | ||||
Some files were not shown because too many files have changed in this diff.