Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-31 12:15:03 +08:00)

Compare commits (129 commits)
| SHA1 | Author | Date | |
|---|---|---|---|
| aa91de45d0 | |||
| f4eca0e3b3 | |||
| ed3438ff13 | |||
| 7dcb568c8f | |||
| bb7c9a2d41 | |||
| 159c2140f7 | |||
| 62a746f62c | |||
| 8627454c84 | |||
| 93964ed6ab | |||
| 80f8be9840 | |||
| e36a6fcf0f | |||
| 23af32a078 | |||
| a81a2e54ed | |||
| 4b7aed89d8 | |||
| 1aeac304b8 | |||
| 56893ca1f6 | |||
| af8c232b75 | |||
| 4908fb53c3 | |||
| 3c8b90542c | |||
| 1f21f8544c | |||
| 1330c638be | |||
| 8bc4a467a7 | |||
| e769026bcb | |||
| 14f8d86136 | |||
| c43ccfbc2d | |||
| cfb8aec1a4 | |||
| 98ce93db0b | |||
| 627482a7b7 | |||
| c5e7bb08b0 | |||
| 6f9b4ccf8f | |||
| 708dc6e3cd | |||
| 7803d2c244 | |||
| 033b7d1e1a | |||
| d734b26141 | |||
| a27c002186 | |||
| 0f462740a0 | |||
| d6aaf08344 | |||
| 13304401df | |||
| 8e48d1ba25 | |||
| 6189a5f731 | |||
| 48a7e8cc70 | |||
| f17e2ab1f9 | |||
| e14b290d1e | |||
| 04ddea44fd | |||
| 57a54a04b6 | |||
| 6edfb3062c | |||
| 72fedf0575 | |||
| b26d4c9a7a | |||
| bb25c60945 | |||
| fdcef1477c | |||
| e93706c2c8 | |||
| 26eefd5ae2 | |||
| 28c42cc280 | |||
| 0661ecdb38 | |||
| 7a0f93344e | |||
| 63276edb7c | |||
| dfda2dfd53 | |||
| 876824f174 | |||
| a89d5e97ec | |||
| 4660e38e5a | |||
| 5236007806 | |||
| 167ad09be5 | |||
| bcbb45b746 | |||
| 3cad2403cb | |||
| e63476b236 | |||
| 4f641aa1a2 | |||
| 8dbac62edb | |||
| 7a1e267d4a | |||
| 2291199e9b | |||
| 0e9f9c3a61 | |||
| 0e9e3cf996 | |||
| c5c9e20f11 | |||
| d1993c27ae | |||
| 928ac57c2a | |||
| f2206b1ed8 | |||
| 4ca3f435fb | |||
| 79fd497423 | |||
| 9b7a8c4d05 | |||
| 16475a829f | |||
| 6cfb080d84 | |||
| bc38c5baa1 | |||
| 607489f3d0 | |||
| 89a6dbe73a | |||
| c6392fcc06 | |||
| c52c4052d8 | |||
| 175299416b | |||
| a97cefac15 | |||
| 821458d97a | |||
| c9f16f201a | |||
| b229455ddd | |||
| a5419743c6 | |||
| a63221a335 | |||
| c9485f8ff3 | |||
| 71b272e4a3 | |||
| 39450e7b00 | |||
| f1eb99e2e4 | |||
| bb635a11f8 | |||
| 3bfa35d62e | |||
| 7a3791c5d0 | |||
| e28983be76 | |||
| 794b48c9f4 | |||
| 65845d7291 | |||
| 3009b6959a | |||
| df4ebddbe0 | |||
| e13cf68d03 | |||
| 814338826e | |||
| c527292c43 | |||
| d4554bc284 | |||
| f6ea41ead2 | |||
| 489860f3c2 | |||
| 9494b09549 | |||
| c230ac7300 | |||
| 77cafe105a | |||
| 66308fb470 | |||
| 232dd65c15 | |||
| 505ee42570 | |||
| 9babcae1ed | |||
| 69a5a5ac02 | |||
| a4e74f416b | |||
| cb7f45fd34 | |||
| 5937861eba | |||
| bb3f3cc65e | |||
| 0819de412d | |||
| 6db37d7206 | |||
| 559e8d1c20 | |||
| 6702f545d8 | |||
| ddf3124b05 | |||
| 457b27f92f | |||
| b6a48ff69f | 
| @ -31,8 +31,7 @@ pip install -r /pytorch/requirements.txt | ||||
| pip install auditwheel==6.2.0 wheel | ||||
| if [ "$DESIRED_CUDA" = "cpu" ]; then | ||||
|     echo "BASE_CUDA_VERSION is not set. Building cpu wheel." | ||||
|     #USE_PRIORITIZED_TEXT_FOR_LD for enable linker script optimization https://github.com/pytorch/pytorch/pull/121975/files | ||||
|     USE_PRIORITIZED_TEXT_FOR_LD=1 python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn | ||||
|     python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn | ||||
| else | ||||
|     echo "BASE_CUDA_VERSION is set to: $DESIRED_CUDA" | ||||
|     export USE_SYSTEM_NCCL=1 | ||||
| @ -46,6 +45,5 @@ else | ||||
|         export USE_NVIDIA_PYPI_LIBS=1 | ||||
|     fi | ||||
|  | ||||
|     #USE_PRIORITIZED_TEXT_FOR_LD for enable linker script optimization https://github.com/pytorch/pytorch/pull/121975/files | ||||
|     USE_PRIORITIZED_TEXT_FOR_LD=1 python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn --enable-cuda | ||||
|     python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn --enable-cuda | ||||
| fi | ||||
|  | ||||
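The two lines dropped above remove the explicit `USE_PRIORITIZED_TEXT_FOR_LD=1` override from the aarch64 CI wrapper; the CMakeLists.txt change later in this diff makes that option default to ON for AArch64 Linux builds, so the wrapper no longer needs to set it. If one still wanted to force it by hand, here is a hedged sketch of the removed invocation (script path and flag copied from the diff; the env-var plumbing is the same mechanism the old line used):

```python
import os
import subprocess
import sys

# Force the prioritized-text linker optimization explicitly, mirroring the
# removed "USE_PRIORITIZED_TEXT_FOR_LD=1 python ..." line for the CPU wheel.
env = dict(os.environ, USE_PRIORITIZED_TEXT_FOR_LD="1")
subprocess.run(
    [sys.executable, "/pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py",
     "--enable-mkldnn"],
    env=env,
    check=True,
)
```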
| @ -317,7 +317,7 @@ if __name__ == "__main__": | ||||
|     ).decode() | ||||
|  | ||||
|     print("Building PyTorch wheel") | ||||
|     build_vars = "CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000 " | ||||
|     build_vars = "" | ||||
|     # MAX_JOB=5 is not required for CPU backend (see commit 465d98b) | ||||
|     if enable_cuda: | ||||
|         build_vars += "MAX_JOBS=5 " | ||||
|  | ||||
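The `CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000` entry removed from `build_vars` above is likewise superseded: the CMakeLists.txt hunk later in this diff adds the same 64 KiB max-page-size link option for AArch64 Linux directly. For reference only, the removed setting amounted to exporting:

```python
import os

# Reference sketch of the build variable the removed line used to prepend.
os.environ["CMAKE_SHARED_LINKER_FLAGS"] = "-Wl,-z,max-page-size=0x10000"
```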
| @ -1 +1 @@ | ||||
| 56392aa978594cc155fa8af48cd949f5b5f1823a | ||||
| e0dda9059d082537cee36be6c5e4fe3b18c880c0 | ||||
|  | ||||
| @ -42,22 +42,27 @@ install_pip_dependencies() { | ||||
|   # A workaround, ExecuTorch has moved to numpy 2.0 which is not compatible with the current | ||||
|   # numba and scipy version used in PyTorch CI | ||||
|   conda_run pip uninstall -y numba scipy | ||||
|   # Yaspin is needed for running CI test (get_benchmark_analysis_data.py) | ||||
|   pip_install yaspin==3.1.0 | ||||
|  | ||||
|   popd | ||||
| } | ||||
|  | ||||
| setup_executorch() { | ||||
|   pushd executorch | ||||
|  | ||||
|   export PYTHON_EXECUTABLE=python | ||||
|   export CMAKE_ARGS="-DEXECUTORCH_BUILD_PYBIND=ON -DEXECUTORCH_BUILD_XNNPACK=ON -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON" | ||||
|   export CMAKE_ARGS="-DEXECUTORCH_BUILD_PYBIND=ON -DEXECUTORCH_BUILD_XNNPACK=ON -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON -DEXECUTORCH_BUILD_TESTS=ON" | ||||
|  | ||||
|   as_jenkins .ci/scripts/setup-linux.sh --build-tool cmake || true | ||||
|   popd | ||||
| } | ||||
|  | ||||
| clone_executorch | ||||
| install_buck2 | ||||
| install_conda_dependencies | ||||
| install_pip_dependencies | ||||
| setup_executorch | ||||
| if [ $# -eq 0 ]; then | ||||
|   clone_executorch | ||||
|   install_buck2 | ||||
|   install_conda_dependencies | ||||
|   install_pip_dependencies | ||||
|   pushd executorch | ||||
|   setup_executorch | ||||
|   popd | ||||
| else | ||||
|   "$@" | ||||
| fi | ||||
|  | ||||
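The dispatcher added above lets the install script either run its full default sequence (no arguments) or execute a single named function. The ExecuTorch test changes later in this diff use exactly that second mode (`"${INSTALL_SCRIPT}" setup_executorch`). A hedged sketch of the same call from Python, with the script path taken from the test-script hunk below:

```python
import subprocess

# Run only the setup_executorch step of the install script, as the CI test
# harness does after this change; assumes the repo root as working directory.
subprocess.run(
    ["bash", ".ci/docker/common/install_executorch.sh", "setup_executorch"],
    check=True,
)
```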
| @ -1,40 +0,0 @@ | ||||
| #!/bin/bash | ||||
|  | ||||
| # This is where the local pytorch install in the docker image is located | ||||
| pt_checkout="/var/lib/jenkins/workspace" | ||||
| source "$pt_checkout/.ci/pytorch/common_utils.sh" | ||||
| echo "functorch_doc_push_script.sh: Invoked with $*" | ||||
|  | ||||
| set -ex -o pipefail | ||||
|  | ||||
| version=${DOCS_VERSION:-nightly} | ||||
| echo "version: $version" | ||||
|  | ||||
| # Build functorch docs | ||||
| pushd $pt_checkout/functorch/docs | ||||
| make html | ||||
| popd | ||||
|  | ||||
| git clone https://github.com/pytorch/functorch -b gh-pages --depth 1 functorch_ghpages | ||||
| pushd functorch_ghpages | ||||
|  | ||||
| if [ "$version" == "main" ]; then | ||||
|   version=nightly | ||||
| fi | ||||
|  | ||||
| git rm -rf "$version" || true | ||||
| mv "$pt_checkout/functorch/docs/build/html" "$version" | ||||
|  | ||||
| git add "$version" || true | ||||
| git status | ||||
| git config user.email "soumith+bot@pytorch.org" | ||||
| git config user.name "pytorchbot" | ||||
| # If there aren't changes, don't make a commit; push is no-op | ||||
| git commit -m "Generate Python docs from pytorch/pytorch@${GITHUB_SHA}" || true | ||||
| git status | ||||
|  | ||||
| if [[ "${WITH_PUSH:-}" == true ]]; then | ||||
|   git push -u origin gh-pages | ||||
| fi | ||||
|  | ||||
| popd | ||||
File: .ci/pytorch/numba-cuda-13.patch (new file, 25 lines)
							| @ -0,0 +1,25 @@ | ||||
| From 6e08c9d08e9de59c7af28b720289debbbd384764 Mon Sep 17 00:00:00 2001 | ||||
| From: Michael Wang <13521008+isVoid@users.noreply.github.com> | ||||
| Date: Tue, 1 Apr 2025 17:28:05 -0700 | ||||
| Subject: [PATCH] Avoid bumping certain driver API to avoid future breakage | ||||
|  (#185) | ||||
|  | ||||
| Co-authored-by: isVoid <isVoid@users.noreply.github.com> | ||||
| --- | ||||
|  numba_cuda/numba/cuda/cudadrv/driver.py | 3 +++ | ||||
|  1 file changed, 3 insertions(+) | ||||
|  | ||||
| diff --git a/numba_cuda/numba/cuda/cudadrv/driver.py b/numba_cuda/numba/cuda/cudadrv/driver.py | ||||
| index 1641bf77..233e9ed7 100644 | ||||
| --- a/numba_cuda/numba/cuda/cudadrv/driver.py | ||||
| +++ b/numba_cuda/numba/cuda/cudadrv/driver.py | ||||
| @@ -365,6 +365,9 @@ def _find_api(self, fname): | ||||
|          else: | ||||
|              variants = ('_v2', '') | ||||
|   | ||||
| +        if fname in ("cuCtxGetDevice", "cuCtxSynchronize"): | ||||
| +            return getattr(self.lib, fname) | ||||
| + | ||||
|          for variant in variants: | ||||
|              try: | ||||
|                  return getattr(self.lib, f'{fname}{variant}') | ||||
| @ -32,6 +32,16 @@ if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *s390x* && -d /v | ||||
|   git config --global --add safe.directory /var/lib/jenkins/workspace | ||||
| fi | ||||
|  | ||||
|  | ||||
| # Patch numba to avoid CUDA-13 crash, see https://github.com/pytorch/pytorch/issues/162878 | ||||
| NUMBA_CUDA_DIR=$(python -c "import os;import numba.cuda; print(os.path.dirname(numba.cuda.__file__))" 2>/dev/null || true) | ||||
| if [ -n "$NUMBA_CUDA_DIR" ]; then | ||||
|   NUMBA_PATCH="$(dirname "$(realpath "${BASH_SOURCE[0]}")")/numba-cuda-13.patch" | ||||
|   pushd "$NUMBA_CUDA_DIR" | ||||
|   patch -p4 <"$NUMBA_PATCH" | ||||
|   popd | ||||
| fi | ||||
|  | ||||
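A note on why `patch -p4` is correct here: the bundled patch modifies `a/numba_cuda/numba/cuda/cudadrv/driver.py`, so stripping four leading path components leaves `cudadrv/driver.py`, which resolves against the installed `numba.cuda` package directory the script just located. A small sketch of that path arithmetic (requires numba to be installed, as in the CI image):

```python
import os
import numba.cuda

# Same lookup the shell snippet above performs with "python -c".
numba_cuda_dir = os.path.dirname(numba.cuda.__file__)
# After "-p4", the patch applies to this file inside the installed package.
print(os.path.join(numba_cuda_dir, "cudadrv", "driver.py"))
```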
| echo "Environment variables:" | ||||
| env | ||||
|  | ||||
| @ -1540,14 +1550,10 @@ test_executorch() { | ||||
|   install_torchvision | ||||
|   install_torchaudio | ||||
|  | ||||
|   INSTALL_SCRIPT="$(pwd)/.ci/docker/common/install_executorch.sh" | ||||
|  | ||||
|   pushd /executorch | ||||
|  | ||||
|   export PYTHON_EXECUTABLE=python | ||||
|   export CMAKE_ARGS="-DEXECUTORCH_BUILD_PYBIND=ON -DEXECUTORCH_BUILD_XNNPACK=ON -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON" | ||||
|  | ||||
|   # NB: We need to rebuild ExecuTorch runner here because it depends on PyTorch | ||||
|   # from the PR | ||||
|   bash .ci/scripts/setup-linux.sh --build-tool cmake | ||||
|   "${INSTALL_SCRIPT}" setup_executorch | ||||
|  | ||||
|   echo "Run ExecuTorch unit tests" | ||||
|   pytest -v -n auto | ||||
| @ -1561,10 +1567,6 @@ test_executorch() { | ||||
|  | ||||
|   popd | ||||
|  | ||||
|   # Test torchgen generated code for Executorch. | ||||
|   echo "Testing ExecuTorch op registration" | ||||
|   "$BUILD_BIN_DIR"/test_edge_op_registration | ||||
|  | ||||
|   assert_git_not_dirty | ||||
| } | ||||
|  | ||||
| @ -1572,6 +1574,7 @@ test_linux_aarch64() { | ||||
|   python test/run_test.py --include test_modules test_mkldnn test_mkldnn_fusion test_openmp test_torch test_dynamic_shapes \ | ||||
|         test_transformers test_multiprocessing test_numpy_interop test_autograd test_binary_ufuncs test_complex test_spectral_ops \ | ||||
|         test_foreach test_reductions test_unary_ufuncs test_tensor_creation_ops test_ops \ | ||||
|         distributed/elastic/timer/api_test distributed/elastic/timer/local_timer_example distributed/elastic/timer/local_timer_test \ | ||||
|         --shard "$SHARD_NUMBER" "$NUM_TEST_SHARDS" --verbose | ||||
|  | ||||
|   # Dynamo tests | ||||
| @ -1614,26 +1617,6 @@ test_operator_benchmark() { | ||||
|       --expected "expected_ci_operator_benchmark_eager_float32_cpu.csv" | ||||
| } | ||||
|  | ||||
| test_operator_microbenchmark() { | ||||
|   TEST_REPORTS_DIR=$(pwd)/test/test-reports | ||||
|   mkdir -p "$TEST_REPORTS_DIR" | ||||
|   TEST_DIR=$(pwd) | ||||
|  | ||||
|   pip_uninstall torch torchvision torchaudio | ||||
|   pip_install torch==2.8.0 torchvision torchaudio ninja --force-reinstall | ||||
|   cd benchmarks/operator_benchmark/pt_extension | ||||
|   python -m pip install . | ||||
|  | ||||
|   cd "${TEST_DIR}"/benchmarks/operator_benchmark | ||||
|   for OP_BENCHMARK_TESTS in matmul mm addmm bmm; do | ||||
|     $TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \ | ||||
|       --output-json-for-dashboard "${TEST_REPORTS_DIR}/operator_microbenchmark_${OP_BENCHMARK_TESTS}_compile.json" \ | ||||
|       --benchmark-name "PyTorch operator microbenchmark" --use-compile | ||||
|     $TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \ | ||||
|       --output-json-for-dashboard "${TEST_REPORTS_DIR}/operator_microbenchmark_${OP_BENCHMARK_TESTS}.json" \ | ||||
|       --benchmark-name "PyTorch operator microbenchmark" | ||||
|   done | ||||
| } | ||||
|  | ||||
| if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then | ||||
|   (cd test && python -c "import torch; print(torch.__config__.show())") | ||||
| @ -1688,8 +1671,6 @@ elif [[ "${TEST_CONFIG}" == *operator_benchmark* ]]; then | ||||
|     test_operator_benchmark cpu ${TEST_MODE} | ||||
|  | ||||
|   fi | ||||
| elif [[ "${TEST_CONFIG}" == *operator_microbenchmark* ]]; then | ||||
|   test_operator_microbenchmark | ||||
| elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then | ||||
|   test_inductor_distributed | ||||
| elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then | ||||
|  | ||||
| @ -264,7 +264,7 @@ def unzip_artifact_and_replace_files() -> None: | ||||
|         change_content_to_new_version(f"artifacts/dist/{old_stem}/torch/version.py") | ||||
|  | ||||
|         for file in Path(f"artifacts/dist/{old_stem}").glob( | ||||
|             "*.dist-info/**", | ||||
|             "*.dist-info/*", | ||||
|         ): | ||||
|             change_content_to_new_version(file) | ||||
|  | ||||
|  | ||||
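The one-character pattern change above matters because a `pathlib` glob component of `**` matches directories recursively rather than the files inside them (on the pre-3.13 Pythons used in CI), so the old pattern never handed any `*.dist-info` metadata files to `change_content_to_new_version`. A small illustration; the concrete path is only an example:

```python
from pathlib import Path

dist = Path("artifacts/dist/example_stem")   # hypothetical path for illustration
print(list(dist.glob("*.dist-info/**")))     # directories only: nothing is rewritten
print(list(dist.glob("*.dist-info/*")))      # METADATA, RECORD, ...: files get rewritten
```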
File: .github/ci_commit_pins/vllm.txt (vendored, 2 lines changed)
							| @ -1 +1 @@ | ||||
| 5bcc153d7bf69ef34bc5788a33f60f1792cf2861 | ||||
| 5963b98b465007e3cfb0d39447e4459a8afa96dc | ||||
|  | ||||
File: .github/workflows/_docs.yml (vendored, 14 lines changed)
							| @ -75,10 +75,6 @@ jobs: | ||||
|             runner: ${{ inputs.runner_prefix }}linux.2xlarge | ||||
|             # It takes less than 30m to finish python docs unless there are issues | ||||
|             timeout-minutes: 30 | ||||
|           - docs_type: functorch | ||||
|             runner: ${{ inputs.runner_prefix }}linux.2xlarge | ||||
|             # It takes less than 15m to finish functorch docs unless there are issues | ||||
|             timeout-minutes: 15 | ||||
|     # Set a fixed name for this job instead of using the current matrix-generated name, i.e. build-docs (cpp, linux.12xlarge, 180) | ||||
|     # The current name requires updating the database last docs push query from test-infra every time the matrix is updated | ||||
|     name: build-docs-${{ matrix.docs_type }}-${{ inputs.push }} | ||||
| @ -211,16 +207,6 @@ jobs: | ||||
|           path: cppdocs/ | ||||
|           s3-prefix: pytorch/pytorch/${{ github.event.pull_request.number }}/cppdocs | ||||
|  | ||||
|       - name: Upload functorch Docs Preview | ||||
|         uses: seemethere/upload-artifact-s3@baba72d0712b404f646cebe0730933554ebce96a # v5.1.0 | ||||
|         if: ${{ github.event_name == 'pull_request' && matrix.docs_type == 'functorch' && steps.build-docs.outcome == 'success' }} | ||||
|         with: | ||||
|           retention-days: 14 | ||||
|           s3-bucket: doc-previews | ||||
|           if-no-files-found: error | ||||
|           path: functorch_ghpages/nightly/ | ||||
|           s3-prefix: pytorch/pytorch/${{ github.event.pull_request.number }}/functorchdocs | ||||
|  | ||||
|       - name: Teardown Linux | ||||
|         uses: pytorch/test-infra/.github/actions/teardown-linux@main | ||||
|         if: always() | ||||
|  | ||||
File: .github/workflows/_linux-test.yml (vendored, 4 lines changed)
							| @ -169,7 +169,7 @@ jobs: | ||||
|         id: install-nvidia-driver | ||||
|         uses: pytorch/test-infra/.github/actions/setup-nvidia@main | ||||
|         with: | ||||
|           driver-version: ${{ matrix.config == 'legacy_nvidia_driver' && '525.105.17' || '570.133.07' }} | ||||
|           driver-version: ${{ matrix.config == 'legacy_nvidia_driver' && '525.105.17' || '580.82.07' }} | ||||
|         if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' && !contains(matrix.runner, 'b200') }} | ||||
|  | ||||
|       - name: Setup GPU_FLAG for docker run | ||||
| @ -273,8 +273,6 @@ jobs: | ||||
|           TEST_CONFIG: ${{ matrix.config }} | ||||
|           SHARD_NUMBER: ${{ matrix.shard }} | ||||
|           NUM_TEST_SHARDS: ${{ matrix.num_shards }} | ||||
|           EXTRA_FLAGS: ${{ matrix.extra_flags || '' }} | ||||
|           OP_BENCHMARK_TESTS: ${{ matrix.op_benchmark_tests }} | ||||
|           REENABLED_ISSUES: ${{ steps.keep-going.outputs.reenabled-issues }} | ||||
|           CONTINUE_THROUGH_ERROR: ${{ steps.keep-going.outputs.keep-going }} | ||||
|           VERBOSE_TEST_LOGS: ${{ steps.keep-going.outputs.ci-verbose-test-logs }} | ||||
|  | ||||
File: .github/workflows/build-vllm-wheel.yml (vendored, 2 lines changed)
							| @ -178,7 +178,7 @@ jobs: | ||||
|       contents: read | ||||
|     container: | ||||
|       image: continuumio/miniconda3:4.12.0 | ||||
|     environment: ${{ (github.event_name == 'push' && github.event.ref == 'refs/heads/main') && 'nightly-wheel-upload' || '' }} | ||||
|     environment: ${{ ((github.event_name == 'push' && github.event.ref == 'refs/heads/main') || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && 'nightly-wheel-upload' || '' }} | ||||
|     steps: | ||||
|       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 | ||||
|  | ||||
|  | ||||
File: .github/workflows/docker-builds.yml (vendored, 3 lines changed)
							| @ -71,8 +71,7 @@ jobs: | ||||
|           pytorch-linux-jammy-py3-clang12-onnx, | ||||
|           pytorch-linux-jammy-linter, | ||||
|           pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-linter, | ||||
|           # Executorch pin needs update | ||||
|           # pytorch-linux-jammy-py3-clang12-executorch, | ||||
|           pytorch-linux-jammy-py3-clang12-executorch, | ||||
|           pytorch-linux-jammy-py3.12-triton-cpu, | ||||
|           pytorch-linux-noble-riscv64-py3.12-gcc14 | ||||
|         ] | ||||
|  | ||||
File: .github/workflows/operator_microbenchmark.yml (vendored, deleted, 46 lines)
							| @ -1,46 +0,0 @@ | ||||
| name: operator_microbenchmark | ||||
|  | ||||
| on: | ||||
|   push: | ||||
|     tags: | ||||
|       - ciflow/op-benchmark/* | ||||
|   workflow_dispatch: | ||||
|   schedule: | ||||
|     # Run at 06:00 UTC everyday | ||||
|     - cron: 0 6 * * * | ||||
|  | ||||
| concurrency: | ||||
|   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} | ||||
|   cancel-in-progress: true | ||||
|  | ||||
| permissions: | ||||
|   id-token: write | ||||
|   contents: read | ||||
|  | ||||
| jobs: | ||||
|   opmicrobenchmark-build: | ||||
|     if: github.repository_owner == 'pytorch' | ||||
|     name: opmicrobenchmark-build | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|     with: | ||||
|       runner: linux.12xlarge.memory | ||||
|       build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 | ||||
|       cuda-arch-list: '8.0 9.0' | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "operator_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.aws.h100" }, | ||||
|           { config: "operator_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.aws.a100" }, | ||||
|         ]} | ||||
|     secrets: inherit | ||||
|  | ||||
|   opmicrobenchmark-test: | ||||
|     name: opmicrobenchmark-test | ||||
|     uses: ./.github/workflows/_linux-test.yml | ||||
|     needs: opmicrobenchmark-build | ||||
|     with: | ||||
|       timeout-minutes: 500 | ||||
|       build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 | ||||
|       docker-image: ${{ needs.opmicrobenchmark-build.outputs.docker-image }} | ||||
|       test-matrix: ${{ needs.opmicrobenchmark-build.outputs.test-matrix }} | ||||
|     secrets: inherit | ||||
| @ -1,46 +0,0 @@ | ||||
| name: operator_microbenchmark_b200 | ||||
|  | ||||
| on: | ||||
|   push: | ||||
|     tags: | ||||
|       - ciflow/op-benchmark/* | ||||
|   workflow_dispatch: | ||||
|   schedule: | ||||
|     # Run at 06:00 UTC everyday | ||||
|     - cron: 0 6 * * * | ||||
|  | ||||
| concurrency: | ||||
|   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} | ||||
|   cancel-in-progress: true | ||||
|  | ||||
| permissions: | ||||
|   id-token: write | ||||
|   contents: read | ||||
|  | ||||
| jobs: | ||||
|   opmicrobenchmark-build: | ||||
|     if: github.repository_owner == 'pytorch' | ||||
|     name: opmicrobenchmark-build | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|     with: | ||||
|       runner: linux.12xlarge.memory | ||||
|       build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 | ||||
|       cuda-arch-list: '10.0' | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "operator_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.dgx.b200" }, | ||||
|         ]} | ||||
|     secrets: inherit | ||||
|  | ||||
|   opmicrobenchmark-test: | ||||
|     name: opmicrobenchmark-test | ||||
|     uses: ./.github/workflows/_linux-test.yml | ||||
|     needs: opmicrobenchmark-build | ||||
|     with: | ||||
|       timeout-minutes: 500 | ||||
|       build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 | ||||
|       docker-image: ${{ needs.opmicrobenchmark-build.outputs.docker-image }} | ||||
|       test-matrix: ${{ needs.opmicrobenchmark-build.outputs.test-matrix }} | ||||
|       aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only | ||||
|     secrets: inherit | ||||
File: .github/workflows/pull.yml (vendored, 28 lines changed)
							| @ -127,6 +127,8 @@ jobs: | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|     needs: get-label-type | ||||
|     with: | ||||
|       # More memory is needed to build with asan | ||||
|       runner: linux.2xlarge.memory | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build-environment: linux-jammy-py3.10-clang18-asan | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan | ||||
| @ -316,32 +318,6 @@ jobs: | ||||
|         ]} | ||||
|     secrets: inherit | ||||
|  | ||||
|   linux-jammy-py3-clang12-executorch-build: | ||||
|     if: false  # Docker build needs pin update | ||||
|     name: linux-jammy-py3-clang12-executorch | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|     needs: get-label-type | ||||
|     with: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build-environment: linux-jammy-py3-clang12-executorch | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-py3-clang12-executorch | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "executorch", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, | ||||
|         ]} | ||||
|     secrets: inherit | ||||
|  | ||||
|   linux-jammy-py3-clang12-executorch-test: | ||||
|     name: linux-jammy-py3-clang12-executorch | ||||
|     uses: ./.github/workflows/_linux-test.yml | ||||
|     needs: linux-jammy-py3-clang12-executorch-build | ||||
|     if: false # Has been broken for a while | ||||
|     with: | ||||
|       build-environment: linux-jammy-py3-clang12-executorch | ||||
|       docker-image: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.docker-image }} | ||||
|       test-matrix: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.test-matrix }} | ||||
|     secrets: inherit | ||||
|  | ||||
|   linux-jammy-cuda12_8-py3_10-gcc9-inductor-build: | ||||
|     name: cuda12.8-py3.10-gcc9-sm75 | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|  | ||||
File: .github/workflows/slow.yml (vendored, 2 lines changed)
							| @ -140,6 +140,8 @@ jobs: | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|     needs: get-label-type | ||||
|     with: | ||||
|       # More memory is needed to build with asan | ||||
|       runner: linux.2xlarge.memory | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build-environment: linux-jammy-py3.10-clang18-asan | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan | ||||
|  | ||||
File: .github/workflows/trunk.yml (vendored, 24 lines changed)
							| @ -259,3 +259,27 @@ jobs: | ||||
|       docker-image: ${{ needs.verify-cachebench-cpu-build.outputs.docker-image }} | ||||
|       test-matrix: ${{ needs.verify-cachebench-cpu-build.outputs.test-matrix }} | ||||
|     secrets: inherit | ||||
|  | ||||
|   linux-jammy-py3-clang12-executorch-build: | ||||
|     name: linux-jammy-py3-clang12-executorch | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|     needs: get-label-type | ||||
|     with: | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build-environment: linux-jammy-py3-clang12-executorch | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-py3-clang12-executorch | ||||
|       test-matrix: | | ||||
|         { include: [ | ||||
|           { config: "executorch", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, | ||||
|         ]} | ||||
|     secrets: inherit | ||||
|  | ||||
|   linux-jammy-py3-clang12-executorch-test: | ||||
|     name: linux-jammy-py3-clang12-executorch | ||||
|     uses: ./.github/workflows/_linux-test.yml | ||||
|     needs: linux-jammy-py3-clang12-executorch-build | ||||
|     with: | ||||
|       build-environment: linux-jammy-py3-clang12-executorch | ||||
|       docker-image: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.docker-image }} | ||||
|       test-matrix: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.test-matrix }} | ||||
|     secrets: inherit | ||||
|  | ||||
File: .github/workflows/vllm.yml (vendored, 2 lines changed)
							| @ -36,6 +36,8 @@ jobs: | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|     needs: get-label-type | ||||
|     with: | ||||
|       # When building vLLM, uv doesn't like that we rename wheel without changing the wheel metadata | ||||
|       allow-reuse-old-whl: false | ||||
|       build-additional-packages: "vision audio" | ||||
|       build-external-packages: "vllm" | ||||
|       build-environment: linux-jammy-cuda12.8-py3.12-gcc11 | ||||
|  | ||||
File: .gitignore (vendored, 3 lines changed)
							| @ -259,6 +259,9 @@ gen | ||||
| .pytest_cache | ||||
| aten/build/* | ||||
|  | ||||
| # Linker scripts for prioritized text optimization | ||||
| cmake/linker_script.ld | ||||
|  | ||||
| # Bram | ||||
| plsdontbreak | ||||
|  | ||||
|  | ||||
| @ -964,7 +964,6 @@ exclude_patterns = [ | ||||
|     'test/jit/**',  # should be run through test/test_jit.py | ||||
|     'test/ao/sparsity/**',  # should be run through test/test_ao_sparsity.py | ||||
|     'test/fx/**',  # should be run through test/test_fx.py | ||||
|     'test/bottleneck_test/**',  # excluded by test/run_test.py | ||||
|     'test/package/**',  # excluded by test/run_test.py | ||||
|     'test/distributed/argparse_util_test.py', | ||||
|     'test/distributed/bin/test_script.py', | ||||
| @ -1410,8 +1409,6 @@ exclude_patterns = [ | ||||
|     'torch/utils/benchmark/utils/timer.py', | ||||
|     'torch/utils/benchmark/utils/valgrind_wrapper/__init__.py', | ||||
|     'torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py', | ||||
|     'torch/utils/bottleneck/__init__.py', | ||||
|     'torch/utils/bottleneck/__main__.py', | ||||
|     'torch/utils/bundled_inputs.py', | ||||
|     'torch/utils/checkpoint.py', | ||||
|     'torch/utils/collect_env.py', | ||||
|  | ||||
| @ -380,6 +380,13 @@ cmake_dependent_option(BUILD_BUNDLE_PTXAS "Bundle PTX into torch/bin fodler" | ||||
|                        OFF "USE_CUDA" OFF) | ||||
| cmake_dependent_option(USE_KLEIDIAI "Use KleidiAI for the ARM CPU & AARCH64 architecture." ON | ||||
|                         "CPU_AARCH64" OFF) | ||||
| # prioritized text linker, ON by default for AArch64+Linux, option visible to all AArch64, x86 and ppc64le. | ||||
| set(USE_PRIORITIZED_TEXT_DEFAULT OFF) | ||||
| if(LINUX AND CPU_AARCH64) | ||||
|   set(USE_PRIORITIZED_TEXT_DEFAULT ON) | ||||
| endif() | ||||
| cmake_dependent_option(USE_PRIORITIZED_TEXT_FOR_LD "Use prioritized text linker for ld." | ||||
|   "${USE_PRIORITIZED_TEXT_DEFAULT}" "CPU_INTEL OR CPU_AARCH64 OR CPU_POWER" OFF) | ||||
|  | ||||
| option(USE_MIMALLOC "Use mimalloc" OFF) | ||||
| # Enable third party mimalloc library to improve memory allocation performance | ||||
| @ -657,6 +664,11 @@ endif(MSVC) | ||||
|  | ||||
| string(APPEND CMAKE_CUDA_FLAGS " -Xfatbin -compress-all") | ||||
|  | ||||
| # Set linker max-page-size to 64KiB on AArch64 Linux | ||||
| if(LINUX AND CPU_AARCH64) | ||||
|   add_link_options_if_supported("-z,max-page-size=0x10000") | ||||
| endif() | ||||
|  | ||||
| # Set INTERN_BUILD_MOBILE for all mobile builds. Components that are not | ||||
| # applicable to mobile are disabled by this variable. Setting | ||||
| # `BUILD_PYTORCH_MOBILE_WITH_HOST_TOOLCHAIN` environment variable can force it | ||||
| @ -891,7 +903,7 @@ IF(USE_FBGEMM_GENAI AND USE_ROCM AND NOT "gfx942" IN_LIST PYTORCH_ROCM_ARCH) | ||||
| endif() | ||||
|  | ||||
| # Set USE_FBGEMM_GENAI to ON for CUDA build on SM100. | ||||
| if(USE_CUDA AND "$ENV{TORCH_CUDA_ARCH_LIST}" MATCHES "10.0" AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8) | ||||
| if(USE_CUDA AND "$ENV{TORCH_CUDA_ARCH_LIST}" MATCHES "10.0" AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8 AND NOT WIN32) | ||||
|   message(STATUS "Setting USE_FBGEMM_GENAI to ON, doing CUDA build for SM100a") | ||||
|   set(USE_FBGEMM_GENAI ON) | ||||
| endif() | ||||
| @ -1421,3 +1433,57 @@ if(BUILD_BUNDLE_PTXAS AND USE_CUDA) | ||||
|   install(PROGRAMS "${PROJECT_BINARY_DIR}/ptxas" | ||||
|           DESTINATION "${CMAKE_INSTALL_BINDIR}") | ||||
| endif() | ||||
|  | ||||
| if(USE_PRIORITIZED_TEXT_FOR_LD) | ||||
|   add_compile_options( | ||||
|     $<$<COMPILE_LANGUAGE:C,CXX>:-ffunction-sections> | ||||
|     $<$<COMPILE_LANGUAGE:C,CXX>:-fdata-sections> | ||||
|   ) | ||||
|   set(LINKER_SCRIPT_FILE_OUT "${CMAKE_SOURCE_DIR}/cmake/linker_script.ld") | ||||
|   set(LINKER_SCRIPT_FILE_IN "${CMAKE_SOURCE_DIR}/cmake/prioritized_text.txt") | ||||
|  | ||||
|   add_custom_command( | ||||
|     OUTPUT "${LINKER_SCRIPT_FILE_OUT}" | ||||
|     COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py --filein "${LINKER_SCRIPT_FILE_IN}" --fout "${LINKER_SCRIPT_FILE_OUT}" | ||||
|     DEPENDS ${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py "${LINKER_SCRIPT_FILE_IN}" | ||||
|     COMMENT "Generating prioritized text linker files" | ||||
|     VERBATIM | ||||
|   ) | ||||
|  | ||||
|   add_custom_target(generate_linker_script DEPENDS "${LINKER_SCRIPT_FILE_OUT}") | ||||
|  | ||||
|   if(BUILD_PYTHON) | ||||
|     set(LINKER_OPT_TARGETS torch_python) | ||||
|   endif() | ||||
|  | ||||
|   if(NOT BUILD_LIBTORCHLESS) | ||||
|     list(APPEND LINKER_OPT_TARGETS torch_cpu c10) | ||||
|     if(USE_CUDA) | ||||
|       list(APPEND LINKER_OPT_TARGETS torch_cuda c10_cuda) | ||||
|     endif() | ||||
|     if(USE_XPU) | ||||
|       list(APPEND LINKER_OPT_TARGETS torch_xpu c10_xpu) | ||||
|     endif() | ||||
|     if(USE_ROCM) | ||||
|       list(APPEND LINKER_OPT_TARGETS torch_hip c10_hip) | ||||
|     endif() | ||||
|   endif() | ||||
|  | ||||
|   foreach(tgt IN LISTS LINKER_OPT_TARGETS) | ||||
|     if(TARGET ${tgt}) | ||||
|       add_dependencies("${tgt}" generate_linker_script) | ||||
|       target_link_options_if_supported(${tgt} "-T,${LINKER_SCRIPT_FILE_OUT}") | ||||
|       set_property(TARGET ${tgt} APPEND PROPERTY LINK_DEPENDS "${LINKER_SCRIPT_FILE_OUT}") | ||||
|     else() | ||||
|        message(WARNING "Requested target '${tgt}' for linker script optimization was not found.") | ||||
|     endif() | ||||
|   endforeach() | ||||
|  | ||||
| else() | ||||
|   if(LINUX AND CPU_AARCH64) | ||||
|     message(WARNING [[ | ||||
|     It is strongly recommend to enable linker script optimization for all AArch64 Linux builds. | ||||
|     To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1 | ||||
|     ]]) | ||||
|   endif() | ||||
| endif() | ||||
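For reference, the `add_custom_command` above is equivalent to the following manual invocation, with the arguments copied verbatim from the CMake snippet (paths are relative to the repository root):

```python
import subprocess
import sys

# Regenerate cmake/linker_script.ld from the prioritized-symbol list by hand.
subprocess.run(
    [sys.executable, "tools/setup_helpers/generate_linker_script.py",
     "--filein", "cmake/prioritized_text.txt",
     "--fout", "cmake/linker_script.ld"],
    check=True,
)
```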
| @ -65,14 +65,24 @@ DLDataType getDLDataType(const Tensor& t) { | ||||
|       break; | ||||
|     // TODO(#146647): use macro here instead of spelling out each shell dtype | ||||
|     case ScalarType::Float8_e5m2: | ||||
|       dtype.code = DLDataTypeCode::kDLFloat8_e5m2; | ||||
|       break; | ||||
|     case ScalarType::Float8_e5m2fnuz: | ||||
|       dtype.code = DLDataTypeCode::kDLFloat8_e5m2fnuz; | ||||
|       break; | ||||
|     case ScalarType::Float8_e4m3fn: | ||||
|       dtype.code = DLDataTypeCode::kDLFloat8_e4m3fn; | ||||
|       break; | ||||
|     case ScalarType::Float8_e4m3fnuz: | ||||
|       dtype.code = DLDataTypeCode::kDLFloat8_e4m3fnuz; | ||||
|       break; | ||||
|     case ScalarType::Float8_e8m0fnu: | ||||
|       TORCH_CHECK_BUFFER(false, "float8 types are not supported by dlpack"); | ||||
|       dtype.code = DLDataTypeCode::kDLFloat8_e8m0fnu; | ||||
|       break; | ||||
|     case ScalarType::Float4_e2m1fn_x2: | ||||
|       TORCH_CHECK_BUFFER(false, "float4 types are not supported by dlpack"); | ||||
|       dtype.code = DLDataTypeCode::kDLFloat4_e2m1fn; | ||||
|       dtype.lanes = 2; | ||||
|       dtype.bits = 4; | ||||
|       break; | ||||
|     case ScalarType::QInt8: | ||||
|     case ScalarType::QUInt8: | ||||
| @ -177,7 +187,11 @@ static Device getATenDevice(DLDeviceType type, c10::DeviceIndex index, void* dat | ||||
|  | ||||
| ScalarType toScalarType(const DLDataType& dtype) { | ||||
|   ScalarType stype = ScalarType::Undefined; | ||||
|   TORCH_CHECK_BUFFER(dtype.lanes == 1, "ATen does not support lanes != 1"); | ||||
|   if (dtype.code != DLDataTypeCode::kDLFloat4_e2m1fn) { | ||||
|     TORCH_CHECK_BUFFER( | ||||
|         dtype.lanes == 1, | ||||
|         "ATen does not support lanes != 1 for dtype code", std::to_string(dtype.code)); | ||||
|   } | ||||
|   switch (dtype.code) { | ||||
|     case DLDataTypeCode::kDLUInt: | ||||
|       switch (dtype.bits) { | ||||
| @ -269,6 +283,73 @@ ScalarType toScalarType(const DLDataType& dtype) { | ||||
|               false, "Unsupported kDLBool bits ", std::to_string(dtype.bits)); | ||||
|       } | ||||
|       break; | ||||
|     case DLDataTypeCode::kDLFloat8_e5m2: | ||||
|       switch (dtype.bits) { | ||||
|         case 8: | ||||
|           stype = ScalarType::Float8_e5m2; | ||||
|           break; | ||||
|         default: | ||||
|           TORCH_CHECK_BUFFER( | ||||
|               false, "Unsupported kDLFloat8_e5m2 bits ", std::to_string(dtype.bits)); | ||||
|       } | ||||
|       break; | ||||
|     case DLDataTypeCode::kDLFloat8_e5m2fnuz: | ||||
|       switch (dtype.bits) { | ||||
|         case 8: | ||||
|           stype = ScalarType::Float8_e5m2fnuz; | ||||
|           break; | ||||
|         default: | ||||
|           TORCH_CHECK_BUFFER( | ||||
|               false, "Unsupported kDLFloat8_e5m2fnuz bits ", std::to_string(dtype.bits)); | ||||
|       } | ||||
|       break; | ||||
|     case DLDataTypeCode::kDLFloat8_e4m3fn: | ||||
|       switch (dtype.bits) { | ||||
|         case 8: | ||||
|           stype = ScalarType::Float8_e4m3fn; | ||||
|           break; | ||||
|         default: | ||||
|           TORCH_CHECK_BUFFER( | ||||
|               false, "Unsupported kDLFloat8_e4m3fn bits ", std::to_string(dtype.bits)); | ||||
|       } | ||||
|       break; | ||||
|     case DLDataTypeCode::kDLFloat8_e4m3fnuz: | ||||
|       switch (dtype.bits) { | ||||
|         case 8: | ||||
|           stype = ScalarType::Float8_e4m3fnuz; | ||||
|           break; | ||||
|         default: | ||||
|           TORCH_CHECK_BUFFER( | ||||
|               false, "Unsupported kDLFloat8_e4m3fnuz bits ", std::to_string(dtype.bits)); | ||||
|       } | ||||
|       break; | ||||
|     case DLDataTypeCode::kDLFloat8_e8m0fnu: | ||||
|       switch (dtype.bits) { | ||||
|         case 8: | ||||
|           stype = ScalarType::Float8_e8m0fnu; | ||||
|           break; | ||||
|         default: | ||||
|           TORCH_CHECK_BUFFER( | ||||
|               false, "Unsupported kDLFloat8_e8m0fnu bits ", std::to_string(dtype.bits)); | ||||
|       } | ||||
|       break; | ||||
|     case DLDataTypeCode::kDLFloat4_e2m1fn: | ||||
|       switch (dtype.bits) { | ||||
|         case 4: | ||||
|           switch (dtype.lanes) { | ||||
|             case 2: | ||||
|               stype = ScalarType::Float4_e2m1fn_x2; | ||||
|               break; | ||||
|             default: | ||||
|               TORCH_CHECK_BUFFER( | ||||
|                 false, "Unsupported kDLFloat4_e2m1fn lanes ", std::to_string(dtype.lanes)); | ||||
|           } | ||||
|           break; | ||||
|         default: | ||||
|           TORCH_CHECK_BUFFER( | ||||
|               false, "Unsupported kDLFloat4_e2m1fn bits ", std::to_string(dtype.bits)); | ||||
|       } | ||||
|       break; | ||||
|     default: | ||||
|       TORCH_CHECK_BUFFER(false, "Unsupported code ", std::to_string(dtype.code)); | ||||
|   } | ||||
| @ -354,8 +435,8 @@ T* toDLPackImpl(const Tensor& src) { | ||||
|   atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDevice(src.device()); | ||||
|   atDLMTensor->tensor.dl_tensor.ndim = static_cast<int32_t>(src.dim()); | ||||
|   atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src); | ||||
|   atDLMTensor->tensor.dl_tensor.shape = view.sizes().data(); | ||||
|   atDLMTensor->tensor.dl_tensor.strides = view.strides().data(); | ||||
|   atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(view.sizes().data()); | ||||
|   atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(view.strides().data()); | ||||
|   atDLMTensor->tensor.dl_tensor.byte_offset = 0; | ||||
|   fillVersion(&atDLMTensor->tensor); | ||||
|  | ||||
|  | ||||
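The converter changes above map the float8 variants (and the packed float4 type) onto the new DLPack dtype codes instead of rejecting them with `TORCH_CHECK_BUFFER`. A minimal sketch of the user-visible effect, assuming a build that contains this change:

```python
import torch

# float8 tensors can now round-trip through the DLPack protocol instead of
# raising "float8 types are not supported by dlpack".
x = torch.randn(4, 4).to(torch.float8_e5m2)
y = torch.from_dlpack(x)              # goes through __dlpack__ / the versioned struct
assert y.dtype == torch.float8_e5m2
assert y.data_ptr() == x.data_ptr()   # memory is shared, not copied
```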
| @ -266,11 +266,14 @@ CUDAGeneratorImpl::CUDAGeneratorImpl( | ||||
|  * See Note [Acquire lock when using random generators] | ||||
|  */ | ||||
| void CUDAGeneratorImpl::set_current_seed(uint64_t seed) { | ||||
|   at::cuda::assertNotCapturing( | ||||
|       "Cannot call CUDAGeneratorImpl::set_current_seed"); | ||||
|   state_->seed_ = seed; | ||||
|   state_->philox_offset_per_thread_ = 0; | ||||
|   no_reset_rnn_state_.clear(); | ||||
|   if (C10_LIKELY(at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None)) { | ||||
|     state_->seed_ = seed; | ||||
|     state_->philox_offset_per_thread_ = 0; | ||||
|     no_reset_rnn_state_.clear(); | ||||
|   } else { | ||||
|     TORCH_CHECK(state_->seed_ == seed, "CUDAGeneratorImpl::set_current_seed can be called during stream capture only if new seed is the same as the original seed."); | ||||
|     // no-op case | ||||
|   } | ||||
| } | ||||
|  | ||||
| /** | ||||
| @ -299,9 +302,6 @@ uint64_t CUDAGeneratorImpl::get_offset() const { | ||||
|  * Gets the current seed of CUDAGeneratorImpl. | ||||
|  */ | ||||
| uint64_t CUDAGeneratorImpl::current_seed() const { | ||||
|   // Debatable if current_seed() should be allowed in captured regions. | ||||
|   // Conservatively disallow it for now. | ||||
|   at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::current_seed"); | ||||
|   return state_->seed_; | ||||
| } | ||||
|  | ||||
| @ -346,8 +346,6 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const { | ||||
|  * and size of the internal state. | ||||
|  */ | ||||
| void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) { | ||||
|   at::cuda::assertNotCapturing( | ||||
|       "Please ensure to utilize the CUDAGeneratorImpl::set_state_index method during capturing."); | ||||
|   static const size_t seed_size = sizeof(uint64_t); | ||||
|   static const size_t offset_size = sizeof(int64_t); | ||||
|   static const size_t total_size = seed_size + offset_size; | ||||
| @ -402,15 +400,27 @@ c10::intrusive_ptr<c10::GeneratorImpl> CUDAGeneratorImpl::graphsafe_get_state() | ||||
|  */ | ||||
| void CUDAGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) { | ||||
|   // see Note [Why enforce RNG offset % 4 == 0?] | ||||
|  | ||||
|   // Note: If you use CUDNN RNN's, calling | ||||
|   // set_philox_offset_per_thread instead of set_offset will cause the | ||||
|   // cudnn RNN rng state to become stale. | ||||
|   TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4"); | ||||
|   state_->philox_offset_per_thread_ = offset; | ||||
|   if (C10_LIKELY(at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None)) { | ||||
|     state_->philox_offset_per_thread_ = offset; | ||||
|   } else { | ||||
|     state_->offset_intragraph_ = offset; | ||||
|   } | ||||
| } | ||||
|  | ||||
| /** | ||||
|  * Gets the current philox_offset_per_thread_ of CUDAGeneratorImpl. | ||||
|  */ | ||||
| uint64_t CUDAGeneratorImpl::philox_offset_per_thread() const { | ||||
|   return state_->philox_offset_per_thread_; | ||||
|   if (C10_LIKELY(at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None)) { | ||||
|     return state_->philox_offset_per_thread_; | ||||
|   } else { | ||||
|     return state_->offset_intragraph_; | ||||
|   } | ||||
| } | ||||
|  | ||||
| /** | ||||
|  | ||||
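Taken together, these generator changes tolerate re-seeding during CUDA Graph capture as long as the seed is unchanged, and make the philox offset accessors operate on the intra-graph offset while a capture is in progress. A hedged Python-level sketch of the pattern that previously raised (assumes a CUDA device and graph-safe RNG usage):

```python
import torch

torch.cuda.manual_seed(1234)          # seed set before capture, as usual
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # Re-setting the *same* seed inside capture is now a no-op instead of an
    # error; a different seed would still be rejected by the TORCH_CHECK.
    torch.cuda.manual_seed(1234)
    out = torch.rand(8, device="cuda")
g.replay()
```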
| @ -19,7 +19,7 @@ | ||||
| #define DLPACK_MAJOR_VERSION 1 | ||||
|  | ||||
| /*! \brief The current minor version of dlpack */ | ||||
| #define DLPACK_MINOR_VERSION 0 | ||||
| #define DLPACK_MINOR_VERSION 1 | ||||
|  | ||||
| /*! \brief DLPACK_DLL prefix for windows */ | ||||
| #ifdef _WIN32 | ||||
| @ -32,9 +32,7 @@ | ||||
| #define DLPACK_DLL | ||||
| #endif | ||||
|  | ||||
| // NOLINTNEXTLINE(modernize-deprecated-headers) | ||||
| #include <stdint.h> | ||||
| // NOLINTNEXTLINE(modernize-deprecated-headers) | ||||
| #include <stddef.h> | ||||
|  | ||||
| #ifdef __cplusplus | ||||
| @ -159,6 +157,26 @@ typedef enum { | ||||
|   kDLComplex = 5U, | ||||
|   /*! \brief boolean */ | ||||
|   kDLBool = 6U, | ||||
|   /*! \brief FP8 data types */ | ||||
|   kDLFloat8_e3m4 = 7U, | ||||
|   kDLFloat8_e4m3 = 8U, | ||||
|   kDLFloat8_e4m3b11fnuz = 9U, | ||||
|   kDLFloat8_e4m3fn = 10U, | ||||
|   kDLFloat8_e4m3fnuz = 11U, | ||||
|   kDLFloat8_e5m2 = 12U, | ||||
|   kDLFloat8_e5m2fnuz = 13U, | ||||
|   kDLFloat8_e8m0fnu = 14U, | ||||
|   /*! \brief FP6 data types | ||||
|    * Setting bits != 6 is currently unspecified, and the producer must ensure it is set | ||||
|    * while the consumer must stop importing if the value is unexpected. | ||||
|    */ | ||||
|   kDLFloat6_e2m3fn = 15U, | ||||
|   kDLFloat6_e3m2fn = 16U, | ||||
|   /*! \brief FP4 data types | ||||
|    * Setting bits != 4 is currently unspecified, and the producer must ensure it is set | ||||
|    * while the consumer must stop importing if the value is unexpected. | ||||
|    */ | ||||
|   kDLFloat4_e2m1fn = 17U, | ||||
| } DLDataTypeCode; | ||||
|  | ||||
| /*! | ||||
| @ -172,6 +190,12 @@ typedef enum { | ||||
|  *   - int8: type_code = 0, bits = 8, lanes = 1 | ||||
|  *   - std::complex<float>: type_code = 5, bits = 64, lanes = 1 | ||||
|  *   - bool: type_code = 6, bits = 8, lanes = 1 (as per common array library convention, the underlying storage size of bool is 8 bits) | ||||
|  *   - float8_e4m3: type_code = 8, bits = 8, lanes = 1 (packed in memory) | ||||
|  *   - float6_e3m2fn: type_code = 16, bits = 6, lanes = 1 (packed in memory) | ||||
|  *   - float4_e2m1fn: type_code = 17, bits = 4, lanes = 1 (packed in memory) | ||||
|  * | ||||
|  *  When a sub-byte type is packed, DLPack requires the data to be in little bit-endian, i.e., | ||||
|  *  for a packed data set D ((D >> (i * bits)) && bit_mask) stores the i-th element. | ||||
|  */ | ||||
| typedef struct { | ||||
|   /*! | ||||
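To make the packing rule quoted above concrete, here is a worked example of the little bit-endian layout for a 4-bit type, using the shift-and-mask recipe from the comment (values chosen arbitrarily):

```python
# Two 4-bit elements packed into one byte: element i occupies bits [i*4, i*4 + 4).
packed = 0b1011_0010            # element 0 = 0b0010, element 1 = 0b1011
bits, bit_mask = 4, 0xF
elems = [(packed >> (i * bits)) & bit_mask for i in range(2)]
assert elems == [0b0010, 0b1011]
```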
| @ -229,12 +253,12 @@ typedef struct { | ||||
|   /*! \brief The data type of the pointer*/ | ||||
|   DLDataType dtype; | ||||
|   /*! \brief The shape of the tensor */ | ||||
|   const int64_t* shape; | ||||
|   int64_t* shape; | ||||
|   /*! | ||||
|    * \brief strides of the tensor (in number of elements, not bytes) | ||||
|    *  can be NULL, indicating tensor is compact and row-majored. | ||||
|    */ | ||||
|   const int64_t* strides; | ||||
|   int64_t* strides; | ||||
|   /*! \brief The offset in bytes to the beginning pointer to data */ | ||||
|   uint64_t byte_offset; | ||||
| } DLTensor; | ||||
| @ -269,7 +293,7 @@ typedef struct DLManagedTensor { | ||||
|   void (*deleter)(struct DLManagedTensor * self); | ||||
| } DLManagedTensor; | ||||
|  | ||||
| // bit masks used in in the DLManagedTensorVersioned | ||||
| // bit masks used in the DLManagedTensorVersioned | ||||
|  | ||||
| /*! \brief bit mask to indicate that the tensor is read only. */ | ||||
| #define DLPACK_FLAG_BITMASK_READ_ONLY (1UL << 0UL) | ||||
| @ -282,6 +306,14 @@ typedef struct DLManagedTensor { | ||||
|  */ | ||||
| #define DLPACK_FLAG_BITMASK_IS_COPIED (1UL << 1UL) | ||||
|  | ||||
| /* | ||||
|  * \brief bit mask to indicate that whether a sub-byte type is packed or padded. | ||||
|  * | ||||
|  * The default for sub-byte types (ex: fp4/fp6) is assumed packed. This flag can | ||||
|  * be set by the producer to signal that a tensor of sub-byte type is padded. | ||||
|  */ | ||||
| #define DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED (1UL << 2UL) | ||||
|  | ||||
| /*! | ||||
|  * \brief A versioned and managed C Tensor object, manage memory of DLTensor. | ||||
|  * | ||||
|  | ||||
| @ -171,6 +171,8 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { | ||||
|  | ||||
|   POINTWISE_BOXED(fill_.Scalar); | ||||
|   POINTWISE_BOXED(zero_); | ||||
|   // This is special because this op doesn't return anything | ||||
|   m.impl("_assert_tensor_metadata", native::_assert_tensor_metadata); | ||||
|  | ||||
| #undef UNARY_POINTWISE | ||||
| #undef UNARY_POINTWISE_ALL | ||||
|  | ||||
| @ -81,7 +81,7 @@ Tensor math_channel_shuffle(const Tensor& self, int64_t groups) { | ||||
|   // TODO: contiguous can be made to preserve the memory format | ||||
|   // of the input. However since the above reshape clobbers h and w | ||||
|   // it may not be safe to do that, since channels_last contiguous | ||||
|   // may think oc and and the last dim correspond to h,w? | ||||
|   // may think oc and the last dim correspond to h,w? | ||||
|   // It is not clear, however from initial looking around it feels that | ||||
|   // this may not be correct. | ||||
|   // In this case channels last will likely require custom implementation | ||||
|  | ||||
| @ -1,3 +1,4 @@ | ||||
| #pragma once | ||||
| #include <ATen/core/Tensor.h> | ||||
| #include <ATen/Config.h> | ||||
| #include <cstdint> | ||||
|  | ||||
| @ -85,11 +85,11 @@ void cpu_max_unpool( | ||||
|     if constexpr (is_3d) { | ||||
|       TORCH_CHECK(false, "Found an invalid max index: ", optional_error_index.value(), | ||||
|           " (output volumes are of size ", output_depth, | ||||
|           "x", output_height, "x", output_width); | ||||
|           "x", output_height, "x", output_width, ")"); | ||||
|     } else { | ||||
|       TORCH_CHECK(false, "Found an invalid max index: ", optional_error_index.value(), | ||||
|           " (output volumes are of size ", output_height, | ||||
|           "x", output_width); | ||||
|           "x", output_width, ")"); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|  | ||||
| @ -226,6 +226,38 @@ __global__ void CatArrayBatchedCopy_contig( | ||||
|     } | ||||
| } | ||||
|  | ||||
|  | ||||
| template <typename T, typename IndexType, int Dims, int batch_size, int stride_size, int alignment, int elems_per_vec> | ||||
| __global__ void CatArrayBatchedCopy_vectorized( | ||||
|     char* output, | ||||
|     CatArrInputTensorMetadata<T, IndexType, batch_size, stride_size> inputs, | ||||
|     TensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os, | ||||
|     const int concatDim, | ||||
|     IndexType trailingSize) { | ||||
|  | ||||
|     IndexType tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||
|     IndexType nElements = inputs.nElements[blockIdx.y] / elems_per_vec; | ||||
|  | ||||
|     if(tid >= nElements) return; | ||||
|  | ||||
|     const char * data = (char*)inputs.input[blockIdx.y]; | ||||
|     IndexType offset = inputs.offset[blockIdx.y] * trailingSize / elems_per_vec; | ||||
|     IndexType dimSize = inputs.dimSize[blockIdx.y] * trailingSize / elems_per_vec; | ||||
|     int64_t dataOffset = (int64_t)offset  * alignment; // in bytes | ||||
|  | ||||
|     IndexType stride = gridDim.x * blockDim.x; | ||||
|  | ||||
|     while( tid < nElements){ | ||||
|       int64_t elementOffset = (int64_t)CatArrIndexToOffset<IndexType, Dims>::compute( | ||||
|                     os.tensorSize, os.tensorStride, dimSize, concatDim, tid) * alignment; // in bytes | ||||
|       auto vec = at::native::memory::ld_vec<alignment>(data + (int64_t)alignment * tid); | ||||
|       at::native::memory::st_vec<alignment>(output + dataOffset + elementOffset, vec); | ||||
|       tid += stride; | ||||
|     } | ||||
| } | ||||
|  | ||||
|  | ||||
|  | ||||
| /* | ||||
|   Specialized implementation of the CatArrayBatchedCopy written to generate wide memory loads | ||||
|   to improve memory bandwidth throughput. | ||||
| @ -296,12 +328,27 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i | ||||
|   scalar_t *data = (scalar_t *)(out.mutable_data_ptr()); | ||||
|   CatArrInputTensorMetadata<scalar_t, unsigned int, batch_size, stride_size> catMetaData; | ||||
|   TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> outputParam; | ||||
|   // If all batches are contiguous we can call a specialized implementation | ||||
|   // which requires the input tensor addresses to be aligned to a | ||||
|   // 16 Byte boundary. | ||||
|  | ||||
|   constexpr bool isContig = stride_size == 1; | ||||
|   bool isAligned = true; | ||||
|   constexpr int alignment = 16; | ||||
|  | ||||
|   // Next, let's initialize the size, stride arrays for the output Tensor. | ||||
|   // for contig case, we'll canonicalize output strides, so that | ||||
|   // we don't have arbitrary strides for dims of size 0 | ||||
|   size_t stride0 = 1; | ||||
|   if (memory_format == c10::MemoryFormat::Contiguous) { | ||||
|     for (int i = 0; i < nDims; ++i) { | ||||
|     for (int i = nDims - 1; i >= 0; --i) { | ||||
|       outputParam.tensorSize[i] = out.size(i); | ||||
|       outputParam.tensorStride[i] = out.stride(i); | ||||
|       if (isContig) { | ||||
|         outputParam.tensorStride[i] = stride0; | ||||
|         stride0 *= out.size(i); | ||||
|       } else { | ||||
|         outputParam.tensorStride[i] = out.stride(i); | ||||
|       } | ||||
|     } | ||||
|   } else if (memory_format == c10::MemoryFormat::ChannelsLast || memory_format == c10::MemoryFormat::ChannelsLast3d) { | ||||
|     // permute the semantics of dims from NCHW to NHWC so that the input | ||||
| @ -320,12 +367,15 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i | ||||
|  | ||||
|   at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); | ||||
|  | ||||
|   // If all batches are contiguous we can call a specialized implementation | ||||
|   // which requires the input tensor addresses to be aligned to a | ||||
|   // 16 Byte boundary. | ||||
|  | ||||
|   bool isContig = true; | ||||
|   bool isAligned = true; | ||||
|   // for channels last computing slice size correctly is much more involved, so we never send it | ||||
|   // on the fully vectorized path | ||||
|   // we need output stride in cat dimension to be multiple of alignment, | ||||
|   // if we ever use it to compute offsets | ||||
|   // for catting in 0th dimension it doesn't matter | ||||
|   bool isInOutAligned = isContig && at::native::memory::get_alignment(data) >= alignment && | ||||
|                         memory_format == c10::MemoryFormat::Contiguous && (dimension == 0 || | ||||
|                         outputParam.tensorStride[dimension - 1] * sizeof(scalar_t) % alignment == 0); | ||||
|   unsigned int max_elements_per_tensor = 0; | ||||
|  | ||||
|   // Now we loop | ||||
| @ -341,6 +391,16 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i | ||||
|       // high-dimensional tensor | ||||
|       if (inputs[i+batchCounter].get().numel() > 0) { | ||||
|         dimSize = inputs[i+batchCounter].get().size(dimension); | ||||
|         if (isInOutAligned) { | ||||
|           auto t = inputs[i+batchCounter].get(); | ||||
|           // Similarly to the output stride, we cannot trust the stride value to | ||||
|           // determine the slice size if the corresponding dimension has size 1; | ||||
|           // in that case we have to multiply all the subsequent sizes. | ||||
|           int64_t slice_size = dimension == 0 ? t.numel() : t.sizes()[dimension - 1] != 1 ? | ||||
|              t.strides()[dimension - 1] : c10::multiply_integers(t.sizes().begin() + dimension, t.sizes().end()); | ||||
|           slice_size *= sizeof(scalar_t); | ||||
|           isInOutAligned &= (slice_size % alignment == 0); | ||||
|         } | ||||
|       } | ||||
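The slice-size rule above cannot trust strides[dimension - 1] when that dimension has size 1, so it falls back to multiplying the sizes from the cat dim onward (the diff uses c10::multiply_integers for that product). A standalone sketch of the same rule:

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// If the dim just before the cat dim has size 1 its stride is unreliable, so
// multiply the sizes of the cat dim and everything after it instead.
std::int64_t slice_elems(
    const std::vector<std::int64_t>& sizes,
    const std::vector<std::int64_t>& strides,
    int cat_dim) {
  if (cat_dim == 0) {
    // The whole tensor is one slice: product of all sizes (numel).
    return std::accumulate(
        sizes.begin(), sizes.end(), std::int64_t{1}, std::multiplies<>());
  }
  if (sizes[cat_dim - 1] != 1) {
    return strides[cat_dim - 1];
  }
  return std::accumulate(
      sizes.begin() + cat_dim, sizes.end(), std::int64_t{1}, std::multiplies<>());
}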
|  | ||||
|       catMetaData.input[batchCounter] = (scalar_t*)(inputs[i+batchCounter].get().const_data_ptr()); | ||||
| @ -351,10 +411,12 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i | ||||
| #ifdef USE_ROCM | ||||
|       // On ROCm, CatArrayBatchedCopy_contig is faster | ||||
|       isAligned = false; | ||||
|       isInOutAligned = false; | ||||
| #else | ||||
|       // If at least one of the inputs is not aligned, we can't call the | ||||
|       // CatArrayBatchedCopy_alignedK_contig | ||||
|       isAligned &= is_aligned_vec4(catMetaData.input[batchCounter]); | ||||
|       isInOutAligned &= at::native::memory::get_alignment(catMetaData.input[batchCounter]) >= alignment; | ||||
| #endif | ||||
|  | ||||
|       if (stride_size > 1) { | ||||
| @ -365,7 +427,6 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i | ||||
|           catMetaData.tensorStride[batchCounter].tensorStride[j] = strides[j]; | ||||
|         } | ||||
|         catMetaData.isContiguous[batchCounter] = false; | ||||
|         isContig = false; | ||||
|       } else { | ||||
|         catMetaData.isContiguous[batchCounter] = true; | ||||
|       } | ||||
| @ -388,10 +449,13 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i | ||||
|           max_elements_per_tensor, batchCounter); | ||||
| #else | ||||
|     dim3 applyBlock, catGrid; | ||||
|     if (isContig && sizeof(scalar_t) > 2) { | ||||
|     if (isInOutAligned) { | ||||
|       std::tie(catGrid, applyBlock) = getCatGridContig<scalar_t, alignment>( | ||||
|         max_elements_per_tensor, batchCounter); | ||||
|     } else if (isContig && isAligned && sizeof(scalar_t) > 2) { | ||||
|       std::tie(catGrid, applyBlock) = getCatGridContig<scalar_t, ALIGNED_VEC_LOAD_BYTES_16>( | ||||
|           max_elements_per_tensor, batchCounter); | ||||
|     } else if (isContig && sizeof(scalar_t) == 2) { | ||||
|     } else if (isContig && isAligned && sizeof(scalar_t) == 2) { | ||||
|       std::tie(catGrid, applyBlock) = getCatGridContig<scalar_t, ALIGNED_VEC_LOAD_BYTES_8>( | ||||
|           max_elements_per_tensor, batchCounter); | ||||
|     } else { | ||||
| @ -399,6 +463,30 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i | ||||
|       getCatGrid(batchCounter, catGrid); | ||||
|     } | ||||
| #endif | ||||
|     int32_t trailingSize; | ||||
|     TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> kernelOutputParam; | ||||
|     if (isInOutAligned) { | ||||
|       // In this case we can and should flatten the tensors after the cat dim: | ||||
|       // we want to view the tensors as if they consisted of `alignment`-sized elements. | ||||
|       // We might not be able to cleanly divide just the last dim, since it is not | ||||
|       // necessarily a multiple of the alignment. However, we know that the full | ||||
|       // concatenated slice is a multiple of the alignment, so if we flatten all the | ||||
|       // dims after and including the cat dim, the result is divisible by the alignment. | ||||
|       // We then divide the last output size by elems_per_vec, and divide all strides | ||||
|       // except the last by elems_per_vec (the last stride is always 1). | ||||
|       // For the inputs, we fix up the sizes and strides in the kernel directly. | ||||
|       kernelOutputParam = outputParam; | ||||
|       nDims = dimension + 1; | ||||
|       constexpr auto elems_per_vec = alignment / sizeof(scalar_t); | ||||
|       auto out_size = dimension == 0 ? out.numel() : kernelOutputParam.tensorStride[dimension-1]; | ||||
|       kernelOutputParam.tensorSize[dimension] = out_size / elems_per_vec; | ||||
|       trailingSize = outputParam.tensorStride[dimension]; | ||||
|       kernelOutputParam.tensorStride[dimension] = 1; | ||||
|       for (int i = 0; i < dimension; ++i) { | ||||
|         kernelOutputParam.tensorStride[i] /= elems_per_vec; | ||||
|       } | ||||
|     } | ||||
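As a concrete illustration of the rescaling above (an assumed example, not taken from the diff): a contiguous float output of shape [4, 6, 8] concatenated along dim 1 with 16-byte alignment gives elems_per_vec = 4, so the kernel sees a flattened [4, 12] view of float4 elements, with trailingSize = 8 kept around so the kernel can restore the original per-input layout.

#include <cassert>

int main() {
  // Contiguous float output of shape [4, 6, 8], cat along dim 1,
  // alignment = 16 bytes -> elems_per_vec = 16 / sizeof(float) = 4.
  int sizes[3] = {4, 6, 8};
  int strides[3] = {48, 8, 1}; // canonicalized contiguous strides, in elements
  constexpr int dimension = 1;
  constexpr int elems_per_vec = 4;

  int trailingSize = strides[dimension];        // 8, restored inside the kernel
  int out_size = strides[dimension - 1];        // 48 elements per flattened slice
  int vec_size = out_size / elems_per_vec;      // 12 float4 "elements" per slice
  int vec_stride0 = strides[0] / elems_per_vec; // 12

  // The kernel effectively sees a [4, 12] view of float4 values, which covers
  // exactly the original 4 * 6 * 8 = 192 floats.
  assert(vec_size * elems_per_vec * sizes[0] == 4 * 6 * 8);
  assert(vec_stride0 == vec_size);
  (void)trailingSize;
  return 0;
}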
|  | ||||
|     if (memory_format != c10::MemoryFormat::Contiguous) { | ||||
|       switch (dimension) { | ||||
| @ -413,7 +501,12 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i | ||||
|     } | ||||
|     // Template Declarations for dim = 1, 2, 3, 4 | ||||
| #define HANDLE_CASE(DIMS) \ | ||||
|     if (isContig && isAligned && sizeof(scalar_t) > 2 && sizeof(scalar_t) <= 8) {\ | ||||
|     if (isInOutAligned) {\ | ||||
|       constexpr auto elems_per_vec = alignment / sizeof(scalar_t); \ | ||||
|       CatArrayBatchedCopy_vectorized<scalar_t, unsigned int, DIMS, batch_size, stride_size, alignment, elems_per_vec><<<\ | ||||
|       catGrid, applyBlock, 0, stream.stream()>>>(\ | ||||
|         (char*)data, catMetaData, kernelOutputParam, dimension, trailingSize);\ | ||||
|     } else if (isContig && isAligned && sizeof(scalar_t) > 2 && sizeof(scalar_t) <= 8) {\ | ||||
|       CatArrayBatchedCopy_alignedK_contig<scalar_t, unsigned int, DIMS, batch_size, stride_size, ALIGNED_VEC_LOAD_BYTES_16><<<\ | ||||
|           catGrid, applyBlock, 0, stream.stream()>>>(\ | ||||
|               data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\ | ||||
|  | ||||
| @ -5,12 +5,20 @@ | ||||
|  | ||||
| namespace at::native { | ||||
|  | ||||
| __global__ void weight_int8pack_mm_kernel(const float* x, const int8_t* w, const float* scale, float* out, int B, int K, int N) { | ||||
| __global__ void weight_int8pack_mm_kernel( | ||||
|     const float* x, | ||||
|     const int8_t* w, | ||||
|     const float* scale, | ||||
|     float* out, | ||||
|     int B, | ||||
|     int K, | ||||
|     int N) { | ||||
|   // one thread per output element: [B, N] | ||||
|   int b = blockIdx.y * blockDim.y + threadIdx.y; | ||||
|   int n = blockIdx.x * blockDim.x + threadIdx.x; | ||||
|  | ||||
|   if (b >= B || n >= N) return; | ||||
|   if (b >= B || n >= N) | ||||
|     return; | ||||
|  | ||||
|   float acc = 0.0f; | ||||
|   for (int k = 0; k < K; ++k) { | ||||
| @ -20,7 +28,11 @@ __global__ void weight_int8pack_mm_kernel(const float* x, const int8_t* w, const | ||||
|   out[b * N + n] = acc * scale[n]; | ||||
| } | ||||
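For reference, the computation the kernel above performs per output element is a plain dot product over K, scaled per output channel. A hedged CPU sketch (the x/w indexing inside the elided loop body is assumed to be row-major [B, K] and [N, K]):

#include <cstdint>

// CPU reference: out[b][n] = scale[n] * sum_k x[b][k] * w[n][k].
void weight_int8pack_mm_ref(
    const float* x, const std::int8_t* w, const float* scale,
    float* out, int B, int K, int N) {
  for (int b = 0; b < B; ++b) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += x[b * K + k] * static_cast<float>(w[n * K + k]);
      }
      out[b * N + n] = acc * scale[n];
    }
  }
}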
|  | ||||
| void launch_weight_int8pack_mm_cuda_kernel(const Tensor& x, const Tensor& w_int8, const Tensor& scale, Tensor& out) { | ||||
| void launch_weight_int8pack_mm_cuda_kernel( | ||||
|     const Tensor& x, | ||||
|     const Tensor& w_int8, | ||||
|     const Tensor& scale, | ||||
|     Tensor& out) { | ||||
|   const int B = x.size(0); | ||||
|   const int K = x.size(1); | ||||
|   const int N = w_int8.size(0); | ||||
| @ -35,12 +47,16 @@ void launch_weight_int8pack_mm_cuda_kernel(const Tensor& x, const Tensor& w_int8 | ||||
|       w_int8.data_ptr<int8_t>(), | ||||
|       scale.data_ptr<float>(), | ||||
|       out.data_ptr<float>(), | ||||
|       B, K, N); | ||||
|       B, | ||||
|       K, | ||||
|       N); | ||||
| } | ||||
|  | ||||
|  | ||||
| // Main GPU entry point | ||||
| at::Tensor _weight_int8pack_mm_cuda(const at::Tensor& x, const at::Tensor& w_int8, const at::Tensor& scale) { | ||||
| at::Tensor _weight_int8pack_mm_cuda( | ||||
|     const at::Tensor& x, | ||||
|     const at::Tensor& w_int8, | ||||
|     const at::Tensor& scale) { | ||||
|   // --- Check inputs --- | ||||
|   TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor"); | ||||
|   TORCH_CHECK(w_int8.is_cuda(), "w must be a CUDA tensor"); | ||||
| @ -50,12 +66,16 @@ at::Tensor _weight_int8pack_mm_cuda(const at::Tensor& x, const at::Tensor& w_int | ||||
|   TORCH_CHECK(w_int8.dim() == 2, "w must be 2D"); | ||||
|   TORCH_CHECK(scale.dim() == 1, "scale must be 1D"); | ||||
|  | ||||
|   TORCH_CHECK(x.size(1) == w_int8.size(1), "K dimension mismatch: x.size(1) != w.size(1)"); | ||||
|   TORCH_CHECK(w_int8.size(0) == scale.size(0), "Output dim mismatch: w.size(0) != scale.size(0)"); | ||||
|   TORCH_CHECK( | ||||
|       x.size(1) == w_int8.size(1), | ||||
|       "K dimension mismatch: x.size(1) != w.size(1)"); | ||||
|   TORCH_CHECK( | ||||
|       w_int8.size(0) == scale.size(0), | ||||
|       "Output dim mismatch: w.size(0) != scale.size(0)"); | ||||
|  | ||||
|   // --- Determine shapes --- | ||||
|   auto B = x.size(0);  // batch size | ||||
|   auto N = w_int8.size(0);  // output dim | ||||
|   auto B = x.size(0); // batch size | ||||
|   auto N = w_int8.size(0); // output dim | ||||
|  | ||||
|   // Ensure inputs are in the correct types for the kernel | ||||
|   auto x_f32 = x.to(at::kFloat); | ||||
| @ -63,12 +83,13 @@ at::Tensor _weight_int8pack_mm_cuda(const at::Tensor& x, const at::Tensor& w_int | ||||
|   auto scale_f32 = scale.to(at::kFloat); | ||||
|  | ||||
|   // --- Allocate output --- | ||||
|   auto out = at::empty({B, N}, x.options().dtype(at::kFloat)); | ||||
|   auto out = at::empty({B, N}, x_f32.options()); | ||||
|  | ||||
|   // --- Launch kernel --- | ||||
|   launch_weight_int8pack_mm_cuda_kernel(x_f32, w_int8_contiguous, scale_f32, out); | ||||
|   launch_weight_int8pack_mm_cuda_kernel( | ||||
|       x_f32, w_int8_contiguous, scale_f32, out); | ||||
|  | ||||
|   return out; | ||||
|   return out.to(x.dtype()); | ||||
| } | ||||
|  | ||||
| } // namespace at::native | ||||
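A hedged usage sketch of the entry point above from C++, assuming the generated at::_weight_int8pack_mm wrapper dispatches to this CUDA implementation; the shapes follow the checks in the function (x is [B, K], w_int8 is [N, K] int8, scale is [N]):

#include <ATen/ATen.h>

int main() {
  // Assumed shapes: x [B=8, K=64], w_int8 [N=32, K=64], scale [N=32].
  auto x = at::randn({8, 64}, at::dtype(at::kBFloat16).device(at::kCUDA));
  auto w = at::randint(-128, 127, {32, 64}, at::dtype(at::kChar).device(at::kCUDA));
  auto scale = at::rand({32}, at::dtype(at::kBFloat16).device(at::kCUDA));

  // Assumed generated wrapper for the native function implemented above;
  // the result comes back as [8, 32] in x's dtype (see out.to(x.dtype())).
  auto out = at::_weight_int8pack_mm(x, w, scale);
  return out.size(0) == 8 && out.size(1) == 32 ? 0 : 1;
}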
|  | ||||
| @ -482,7 +482,9 @@ auto build_graph( | ||||
|   auto scaled_dot_product_flash_attention_options = | ||||
|       fe::graph::SDPA_attributes() | ||||
|           .set_name("CUDNN_SDPA") | ||||
|           .set_generate_stats(return_softmaxstats) | ||||
|           .set_is_inference(return_softmaxstats == false) | ||||
|           // TODO(eqy): switch to this API once cuDNN FE is upgraded | ||||
|           // .set_generate_stats(return_softmaxstats) | ||||
|           .set_causal_mask(is_causal) | ||||
|           .set_attn_scale(attn_scale); | ||||
|   if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) { | ||||
| @ -702,7 +704,9 @@ auto build_graph_nestedtensor( | ||||
|   auto scaled_dot_product_flash_attention_options = | ||||
|       fe::graph::SDPA_attributes() | ||||
|           .set_name("CUDNN_SDPA_NESTEDTENSOR") | ||||
|           .set_generate_stats(return_softmaxstats) | ||||
|           .set_is_inference(return_softmaxstats == false) | ||||
|           // TODO(eqy): switch to this API once cuDNN FE is upgraded | ||||
|           // .set_generate_stats(return_softmaxstats) | ||||
|           .set_causal_mask(is_causal) | ||||
|           .set_attn_scale(attn_scale) | ||||
|           .set_seq_len_q(SEQ_LEN_Q_) | ||||
|  | ||||
							
								
								
									
25  aten/src/ATen/native/mps/kernels/EmbeddingBag.h  (new file)
| @ -0,0 +1,25 @ | ||||
| #pragma once | ||||
| #include <c10/metal/common.h> | ||||
|  | ||||
| #ifdef __METAL__ | ||||
| enum class EmbeddingBagMode { SUM = 0, MEAN, MAX }; | ||||
| #else | ||||
| #include <ATen/native/EmbeddingBag.h> | ||||
| using at::native::EmbeddingBagMode; | ||||
| #endif | ||||
|  | ||||
| template <typename idx_type_t = uint32_t> | ||||
| struct EmbeddingBagParams { | ||||
|   ::c10::metal::array<idx_type_t, 2> weight_strides; | ||||
|   ::c10::metal::array<idx_type_t, 2> output_strides; | ||||
|   ::c10::metal::array<idx_type_t, 2> max_indices_strides; | ||||
|  | ||||
|   idx_type_t per_sample_weights_strides; | ||||
|  | ||||
|   idx_type_t num_indices; | ||||
|   idx_type_t num_bags; | ||||
|   idx_type_t feature_size; | ||||
|  | ||||
|   EmbeddingBagMode mode; | ||||
|   int64_t padding_idx; | ||||
| }; | ||||
							
								
								
									
212  aten/src/ATen/native/mps/kernels/EmbeddingBag.metal  (new file)
| @ -0,0 +1,212 @ | ||||
| #include <ATen/native/mps/kernels/EmbeddingBag.h> | ||||
| #include <c10/metal/utils.h> | ||||
| #include <metal_array> | ||||
| #include <metal_stdlib> | ||||
|  | ||||
| using namespace metal; | ||||
| using namespace c10::metal; | ||||
|  | ||||
| template <EmbeddingBagMode M, typename T> | ||||
| struct ReductionOpInit { | ||||
|   inline opmath_t<T> operator()() { | ||||
|     return 0; | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <typename T> | ||||
| struct ReductionOpInit<EmbeddingBagMode::MAX, T> { | ||||
|   inline opmath_t<T> operator()() { | ||||
|     return static_cast<opmath_t<T>>(-INFINITY); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <EmbeddingBagMode M, typename T> | ||||
| struct ReductionOp { | ||||
|   inline opmath_t<T> operator()( | ||||
|       T weight_val, | ||||
|       opmath_t<T> out_val, | ||||
|       uint32_t per_sample_weights_index, | ||||
|       constant T* per_sample_weights, | ||||
|       uint32_t per_sample_weights_strides); | ||||
| }; | ||||
|  | ||||
| template <typename T> | ||||
| struct ReductionOp<EmbeddingBagMode::SUM, T> { | ||||
|   inline opmath_t<T> operator()( | ||||
|       T weight_val, | ||||
|       opmath_t<T> out_val, | ||||
|       uint32_t per_sample_weights_index, | ||||
|       constant T* per_sample_weights, | ||||
|       uint32_t per_sample_weights_strides) { | ||||
|     if (per_sample_weights_strides) { | ||||
|       T per_sample_weight = per_sample_weights | ||||
|           [per_sample_weights_strides * per_sample_weights_index]; | ||||
|       return static_cast<opmath_t<T>>(per_sample_weight) * | ||||
|           static_cast<opmath_t<T>>(weight_val) + | ||||
|           out_val; | ||||
|     } else { | ||||
|       return static_cast<opmath_t<T>>(weight_val) + out_val; | ||||
|     } | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <typename T> | ||||
| struct ReductionOp<EmbeddingBagMode::MEAN, T> { | ||||
|   inline opmath_t<T> operator()( | ||||
|       T weight_val, | ||||
|       opmath_t<T> out_val, | ||||
|       uint32_t, | ||||
|       constant T*, | ||||
|       uint32_t) { | ||||
|     return static_cast<opmath_t<T>>(weight_val) + out_val; | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <typename T> | ||||
| struct ReductionOp<EmbeddingBagMode::MAX, T> { | ||||
|   inline opmath_t<T> operator()( | ||||
|       T weight_val, | ||||
|       opmath_t<T> out_val, | ||||
|       uint32_t, | ||||
|       constant T*, | ||||
|       uint32_t) { | ||||
|     return max(static_cast<opmath_t<T>>(weight_val), out_val); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <EmbeddingBagMode M, typename T> | ||||
| struct ReductionOpFinal { | ||||
|   inline T operator()(opmath_t<T> val, uint32_t) { | ||||
|     return static_cast<T>(val); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <typename T> | ||||
| struct ReductionOpFinal<EmbeddingBagMode::MEAN, T> { | ||||
|   inline T operator()(opmath_t<T> val, uint32_t count) { | ||||
|     auto out = val / count; | ||||
|     return static_cast<T>((count == 0) ? 0 : out); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <typename T> | ||||
| struct ReductionOpFinal<EmbeddingBagMode::MAX, T> { | ||||
|   inline T operator()(opmath_t<T> val, uint32_t count) { | ||||
|     return static_cast<T>((count == 0) ? 0 : val); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <EmbeddingBagMode M, typename T, typename I> | ||||
| void embedding_bag_impl( | ||||
|     constant T* weight, | ||||
|     constant I* indices, | ||||
|     constant I* offsets, | ||||
|     constant T* per_sample_weights, | ||||
|     device T* output, | ||||
|     device I* offset2bag, | ||||
|     device I* bag_size, | ||||
|     device I* max_indices, | ||||
|     constant EmbeddingBagParams<uint32_t>& params, | ||||
|     uint tid) { | ||||
|   auto num_indices = params.num_indices; | ||||
|   auto num_bags = params.num_bags; | ||||
|   auto feature_size = params.feature_size; | ||||
|   auto padding_idx = params.padding_idx; | ||||
|   auto per_sample_weights_strides = params.per_sample_weights_strides; | ||||
|   constant auto& output_strides = params.output_strides; | ||||
|   constant auto& weight_strides = params.weight_strides; | ||||
|   constant auto& max_indices_strides = params.max_indices_strides; | ||||
|  | ||||
|   auto bag_idx = tid / feature_size; | ||||
|   auto feature_idx = tid % feature_size; | ||||
|  | ||||
|   output += bag_idx * output_strides[0] + feature_idx * output_strides[1]; | ||||
|  | ||||
|   uint32_t offsets_end = min(bag_idx + 1, num_bags - 1); | ||||
|   bool is_last_bag = bag_idx + 1 == num_bags; | ||||
|   uint32_t indices_start = static_cast<uint32_t>(offsets[bag_idx]); | ||||
|   uint32_t indices_end = is_last_bag * (num_indices) + | ||||
|       (!is_last_bag) * (static_cast<uint32_t>(offsets[offsets_end])); | ||||
|  | ||||
|   auto out_val = ReductionOpInit<M, T>()(); | ||||
|  | ||||
|   uint32_t bag_size_ = 0; | ||||
|  | ||||
|   for (uint32_t indices_idx = indices_start; indices_idx < indices_end; | ||||
|        indices_idx++) { | ||||
|     I weight_idx = indices[indices_idx]; | ||||
|     bool pad = (weight_idx == padding_idx); | ||||
|     T weight_val = weight | ||||
|         [static_cast<uint32_t>(weight_idx) * weight_strides[0] + | ||||
|          feature_idx * weight_strides[1]]; | ||||
|  | ||||
|     bag_size_ += static_cast<uint32_t>(!pad); | ||||
|  | ||||
|     auto tmp_val = ReductionOp<M, T>()( | ||||
|         weight_val, | ||||
|         out_val, | ||||
|         indices_idx, | ||||
|         per_sample_weights, | ||||
|         per_sample_weights_strides); | ||||
|  | ||||
|     out_val = pad ? out_val : tmp_val; | ||||
|   } | ||||
|  | ||||
|   *output = ReductionOpFinal<M, T>()(out_val, bag_size_); | ||||
| } | ||||
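The index range each bag covers is derived from offsets exactly as in the loop bounds above: bag b spans [offsets[b], offsets[b + 1]), and the last bag runs to num_indices. A host-side sketch of that rule (a standalone helper, not part of the shader):

#include <cstdint>
#include <utility>
#include <vector>

// Bag b covers indices [offsets[b], offsets[b + 1]); the last bag runs to
// num_indices instead of reading past the end of offsets.
std::pair<std::uint32_t, std::uint32_t> bag_range(
    const std::vector<std::int64_t>& offsets,
    std::uint32_t bag_idx,
    std::uint32_t num_bags,
    std::uint32_t num_indices) {
  auto start = static_cast<std::uint32_t>(offsets[bag_idx]);
  const bool is_last_bag = (bag_idx + 1 == num_bags);
  auto end = is_last_bag ? num_indices
                         : static_cast<std::uint32_t>(offsets[bag_idx + 1]);
  return {start, end};
}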
|  | ||||
| #define DISPATCH_IMPL(MODE)        \ | ||||
|   return embedding_bag_impl<MODE>( \ | ||||
|       weight,                      \ | ||||
|       indices,                     \ | ||||
|       offsets,                     \ | ||||
|       per_sample_weights,          \ | ||||
|       output,                      \ | ||||
|       offset2bag,                  \ | ||||
|       bag_size,                    \ | ||||
|       max_indices,                 \ | ||||
|       params,                      \ | ||||
|       tid) | ||||
|  | ||||
| template <typename T, typename I> | ||||
| kernel void embedding_bag( | ||||
|     constant T* weight [[buffer(0)]], | ||||
|     constant I* indices [[buffer(1)]], | ||||
|     constant I* offsets [[buffer(2)]], | ||||
|     constant T* per_sample_weights [[buffer(3)]], | ||||
|     device T* output [[buffer(4)]], | ||||
|     device I* offset2bag [[buffer(5)]], | ||||
|     device I* bag_size [[buffer(6)]], | ||||
|     device I* max_indices [[buffer(7)]], | ||||
|     constant EmbeddingBagParams<uint32_t>& params [[buffer(8)]], | ||||
|     uint tid [[thread_position_in_grid]]) { | ||||
|   switch (params.mode) { | ||||
|     case EmbeddingBagMode::SUM: | ||||
|       DISPATCH_IMPL(EmbeddingBagMode::SUM); | ||||
|     case EmbeddingBagMode::MEAN: | ||||
|       DISPATCH_IMPL(EmbeddingBagMode::MEAN); | ||||
|     case EmbeddingBagMode::MAX: | ||||
|       DISPATCH_IMPL(EmbeddingBagMode::MAX); | ||||
|   } | ||||
| } | ||||
|  | ||||
| #define REGISTER_EMBEDDING_BAG_OP(T, I)                             \ | ||||
|   template [[host_name("embedding_bag_" #T "_" #I)]]                \ | ||||
|   kernel void embedding_bag<T, I>(                                  \ | ||||
|       constant T * weight [[buffer(0)]],                            \ | ||||
|       constant I * indices [[buffer(1)]],                           \ | ||||
|       constant I * offsets [[buffer(2)]],                           \ | ||||
|       constant T * per_sample_weights [[buffer(3)]],                \ | ||||
|       device T * output [[buffer(4)]],                              \ | ||||
|       device I * offset2bag [[buffer(5)]],                          \ | ||||
|       device I * bag_size [[buffer(6)]],                            \ | ||||
|       device I * max_indices [[buffer(7)]],                         \ | ||||
|       constant EmbeddingBagParams<uint32_t> & params [[buffer(8)]], \ | ||||
|       uint tid [[thread_position_in_grid]]); | ||||
|  | ||||
| REGISTER_EMBEDDING_BAG_OP(float, int); | ||||
| REGISTER_EMBEDDING_BAG_OP(float, long); | ||||
| REGISTER_EMBEDDING_BAG_OP(half, int); | ||||
| REGISTER_EMBEDDING_BAG_OP(half, long); | ||||
| REGISTER_EMBEDDING_BAG_OP(bfloat, int); | ||||
| REGISTER_EMBEDDING_BAG_OP(bfloat, long); | ||||
							
								
								
									
179  aten/src/ATen/native/mps/operations/EmbeddingBag.mm  (new file)
| @ -0,0 +1,179 @ | ||||
| #define TORCH_ASSERT_ONLY_METHOD_OPERATORS | ||||
| #include <ATen/TensorUtils.h> | ||||
| #include <ATen/core/Tensor.h> | ||||
| #include <ATen/mps/MPSProfiler.h> | ||||
| #include <ATen/native/EmbeddingBag.h> | ||||
| #include <ATen/native/Pool.h> | ||||
| #include <ATen/native/mps/OperationUtils.h> | ||||
| #include <ATen/native/mps/kernels/EmbeddingBag.h> | ||||
|  | ||||
| #include <fmt/format.h> | ||||
|  | ||||
| #ifndef AT_PER_OPERATOR_HEADERS | ||||
| #include <ATen/Functions.h> | ||||
| #include <ATen/NativeFunctions.h> | ||||
| #else | ||||
| #include <ATen/ops/_embedding_bag_forward_only_native.h> | ||||
| #include <ATen/ops/_embedding_bag_native.h> | ||||
| #include <ATen/ops/empty.h> | ||||
| #endif | ||||
|  | ||||
| namespace at::native { | ||||
|  | ||||
| #ifndef PYTORCH_JIT_COMPILE_SHADERS | ||||
| static auto& lib = mps::MetalShaderLibrary::getBundledLibrary(); | ||||
| #else | ||||
| #include <ATen/native/mps/EmbeddingBag_metallib.h> | ||||
| #endif | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| std::pair<Tensor, Tensor> promoteIndicesAndOffsets(const Tensor& indices, const Tensor& offsets) { | ||||
|   const auto commonType = promoteTypes(offsets.scalar_type(), indices.scalar_type()); | ||||
|   return {indices.scalar_type() == commonType ? indices : indices.toType(commonType), | ||||
|           offsets.scalar_type() == commonType ? offsets : offsets.toType(commonType)}; | ||||
| } | ||||
|  | ||||
| } // namespace | ||||
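promoteIndicesAndOffsets relies on c10::promoteTypes, so mixing int32 indices with int64 offsets (or vice versa) yields int64 for both. A minimal check of that assumption:

#include <ATen/ATen.h>
#include <cassert>

int main() {
  // c10::promoteTypes follows the usual promotion lattice: Int + Long -> Long,
  // so mixed-width indices/offsets both end up as int64 above.
  assert(c10::promoteTypes(at::kInt, at::kLong) == at::kLong);
  assert(c10::promoteTypes(at::kLong, at::kInt) == at::kLong);
  return 0;
}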
|  | ||||
| namespace mps { | ||||
|  | ||||
| static std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_mps_impl( | ||||
|     const Tensor& weight, | ||||
|     const Tensor& indices_, | ||||
|     const Tensor& offsets_, | ||||
|     const bool scale_grad_by_freq, | ||||
|     const int64_t mode, | ||||
|     bool sparse, | ||||
|     const std::optional<Tensor>& per_sample_weights_opt, | ||||
|     bool include_last_offset, | ||||
|     int64_t padding_idx) { | ||||
|   TORCH_CHECK(indices_.dim() == 1, "input has to be a 1D Tensor, but got Tensor of dimension ", indices_.dim()); | ||||
|   if (indices_.dim() == 1) { | ||||
|     TORCH_CHECK(offsets_.dim() == 1, "offsets has to be a 1D Tensor, but got Tensor of dimension ", offsets_.dim()); | ||||
|   } | ||||
|   TORCH_CHECK(weight.dim() == 2, "weight has to be a 2D Tensor, but got Tensor of dimension ", weight.dim()); | ||||
|  | ||||
|   Tensor indices, offsets; | ||||
|   std::tie(indices, offsets) = promoteIndicesAndOffsets(indices_, offsets_); | ||||
|   auto indices_arg = TensorArg(indices, "indices", 1); | ||||
|   checkScalarTypes("embedding_bag_mps", indices_arg, {kLong, kInt}); | ||||
|   auto offsets_arg = TensorArg(offsets, "offsets", 1); | ||||
|   checkScalarTypes("embedding_bag_mps", offsets_arg, {kLong, kInt}); | ||||
|   checkSameType("embedding_bag_mps", indices_arg, offsets_arg); | ||||
|   auto weight_arg = TensorArg(weight, "weight", 1); | ||||
|  | ||||
|   int64_t num_indices = indices.size(0); | ||||
|   int64_t num_bags = offsets.size(0); | ||||
|   if (include_last_offset) { | ||||
|     num_bags -= 1; | ||||
|   } | ||||
|   int64_t feature_size = weight.size(1); | ||||
|  | ||||
|   auto bag_size = at::empty(offsets.sizes(), indices.options()); | ||||
|   auto offset2bag = at::empty({indices.size(0)}, indices.options()); | ||||
|   auto output = at::empty({num_bags, feature_size}, weight.options()); | ||||
|  | ||||
|   Tensor max_indices; | ||||
|  | ||||
|   if (mode == EmbeddingBagMode::MAX) { | ||||
|     max_indices = at::empty({num_bags, feature_size}, indices.options()); | ||||
|   } else { | ||||
|     max_indices = at::empty({0}, indices.options()); | ||||
|   } | ||||
|  | ||||
|   EmbeddingBagParams<uint32_t> params; | ||||
|  | ||||
|   for (const auto dim : c10::irange(weight.dim())) { | ||||
|     params.weight_strides[dim] = safe_downcast<uint32_t, int64_t>(weight.stride(dim)); | ||||
|     params.output_strides[dim] = safe_downcast<uint32_t, int64_t>(output.stride(dim)); | ||||
|  | ||||
|     if (mode == EmbeddingBagMode::MAX) { | ||||
|       params.max_indices_strides[dim] = safe_downcast<uint32_t, int64_t>(max_indices.stride(dim)); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   bool use_per_sample_weights = per_sample_weights_opt.has_value() && per_sample_weights_opt->defined(); | ||||
|   params.per_sample_weights_strides = use_per_sample_weights ? per_sample_weights_opt->stride(0) : 0; | ||||
|  | ||||
|   params.num_indices = num_indices; | ||||
|   params.num_bags = num_bags; | ||||
|   params.feature_size = feature_size; | ||||
|   params.mode = static_cast<EmbeddingBagMode>(mode); | ||||
|   params.padding_idx = padding_idx; | ||||
|  | ||||
|   auto num_threads = output.numel(); | ||||
|   MPSStream* stream = getCurrentMPSStream(); | ||||
|  | ||||
|   dispatch_sync_with_rethrow(stream->queue(), ^() { | ||||
|     @autoreleasepool { | ||||
|       id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder(); | ||||
|       auto pipeline_state = lib.getPipelineStateForFunc( | ||||
|           fmt::format("embedding_bag_{}_{}", scalarToMetalTypeString(weight), scalarToMetalTypeString(indices))); | ||||
|  | ||||
|       getMPSProfiler().beginProfileKernel(pipeline_state, "embedding_bag", {weight, indices, offsets}); | ||||
|       [computeEncoder setComputePipelineState:pipeline_state]; | ||||
|       mtl_setArgs(computeEncoder, | ||||
|                   weight, | ||||
|                   indices, | ||||
|                   offsets, | ||||
|                   use_per_sample_weights ? per_sample_weights_opt : std::nullopt, | ||||
|                   output, | ||||
|                   offset2bag, | ||||
|                   bag_size, | ||||
|                   max_indices, | ||||
|                   params); | ||||
|  | ||||
|       mtl_dispatch1DJob(computeEncoder, pipeline_state, num_threads); | ||||
|       getMPSProfiler().endProfileKernel(pipeline_state); | ||||
|     } | ||||
|   }); | ||||
|  | ||||
|   return std::tuple<Tensor, Tensor, Tensor, Tensor>( | ||||
|       std::move(output), std::move(offset2bag), std::move(bag_size), std::move(max_indices)); | ||||
| } | ||||
|  | ||||
| } // namespace mps | ||||
|  | ||||
| std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_mps(const Tensor& weight, | ||||
|                                                               const Tensor& indices, | ||||
|                                                               const Tensor& offsets, | ||||
|                                                               const bool scale_grad_by_freq, | ||||
|                                                               const int64_t mode, | ||||
|                                                               bool sparse, | ||||
|                                                               const std::optional<Tensor>& per_sample_weights_opt, | ||||
|                                                               bool include_last_offset, | ||||
|                                                               int64_t padding_idx) { | ||||
|   return mps::_embedding_bag_mps_impl(weight, | ||||
|                                       indices, | ||||
|                                       offsets, | ||||
|                                       scale_grad_by_freq, | ||||
|                                       mode, | ||||
|                                       sparse, | ||||
|                                       per_sample_weights_opt, | ||||
|                                       include_last_offset, | ||||
|                                       padding_idx); | ||||
| } | ||||
|  | ||||
| std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_forward_only_mps( | ||||
|     const Tensor& weight, | ||||
|     const Tensor& indices, | ||||
|     const Tensor& offsets, | ||||
|     const bool scale_grad_by_freq, | ||||
|     const int64_t mode, | ||||
|     bool sparse, | ||||
|     const std::optional<Tensor>& per_sample_weights_opt, | ||||
|     bool include_last_offset, | ||||
|     int64_t padding_idx) { | ||||
|   return _embedding_bag_mps(weight, | ||||
|                             indices, | ||||
|                             offsets, | ||||
|                             scale_grad_by_freq, | ||||
|                             mode, | ||||
|                             sparse, | ||||
|                             per_sample_weights_opt, | ||||
|                             include_last_offset, | ||||
|                             padding_idx); | ||||
| } | ||||
|  | ||||
| } // namespace at::native | ||||
| @ -2351,6 +2351,7 @@ | ||||
|   dispatch: | ||||
|     CPU: _embedding_bag_forward_only_cpu | ||||
|     CUDA: _embedding_bag_forward_only_cuda | ||||
|     MPS: _embedding_bag_forward_only_mps | ||||
|   autogen: _embedding_bag_forward_only.out | ||||
|  | ||||
| - func: _rowwise_prune(Tensor weight, Tensor mask, ScalarType compressed_indices_dtype) -> (Tensor, Tensor) | ||||
| @ -2372,6 +2373,7 @@ | ||||
|   dispatch: | ||||
|     CPU: _embedding_bag_cpu | ||||
|     CUDA: _embedding_bag_cuda | ||||
|     MPS: _embedding_bag_mps | ||||
|   autogen: _embedding_bag.out | ||||
|   tags: core | ||||
|  | ||||
| @ -4372,7 +4374,7 @@ | ||||
|   variants: function, method | ||||
|   dispatch: | ||||
|     CPU: narrow_copy_dense_cpu | ||||
|     SparseCPU, SparseCUDA: narrow_copy_sparse | ||||
|     SparseCPU, SparseCUDA, SparseMPS: narrow_copy_sparse | ||||
|     CompositeExplicitAutogradNonFunctional: narrow_copy_dense_symint | ||||
|   tags: view_copy | ||||
|  | ||||
| @ -6660,7 +6662,7 @@ | ||||
| - func: zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) | ||||
|   dispatch: | ||||
|     CompositeExplicitAutograd: zeros_out | ||||
|     SparseCPU, SparseCUDA, SparseMeta: zeros_sparse_out | ||||
|     SparseCPU, SparseCUDA, SparseMPS, SparseMeta: zeros_sparse_out | ||||
|  | ||||
| - func: zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor | ||||
|   dispatch: | ||||
| @ -10699,6 +10701,7 @@ | ||||
|   dispatch: | ||||
|     CompositeExplicitAutograd: foreach_tensor_div_list_kernel_slow | ||||
|     CUDA: foreach_tensor_div_list_kernel_cuda | ||||
|     MTIA: foreach_tensor_div_list_kernel_mtia | ||||
|  | ||||
| - func: _foreach_div_.List(Tensor(a!)[] self, Tensor[] other) -> () | ||||
|   device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices | ||||
| @ -10706,6 +10709,7 @@ | ||||
|   dispatch: | ||||
|     CompositeExplicitAutograd: foreach_tensor_div_list_kernel_slow_ | ||||
|     CUDA: foreach_tensor_div_list_kernel_cuda_ | ||||
|     MTIA: foreach_tensor_div_list_kernel_mtia_ | ||||
|   autogen: _foreach_div.List_out | ||||
|  | ||||
| - func: _foreach_div.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] | ||||
| @ -10729,6 +10733,7 @@ | ||||
|   dispatch: | ||||
|     CompositeExplicitAutograd: foreach_tensor_div_tensor_kernel_slow | ||||
|     CUDA: foreach_tensor_div_tensor_kernel_cuda | ||||
|     MTIA: foreach_tensor_div_tensor_kernel_mtia | ||||
|  | ||||
| - func: _foreach_div_.Tensor(Tensor(a!)[] self, Tensor other) -> () | ||||
|   device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices | ||||
| @ -10736,6 +10741,7 @@ | ||||
|   dispatch: | ||||
|     CompositeExplicitAutograd: foreach_tensor_div_tensor_kernel_slow_ | ||||
|     CUDA: foreach_tensor_div_tensor_kernel_cuda_ | ||||
|     MTIA: foreach_tensor_div_tensor_kernel_mtia_ | ||||
|   autogen: _foreach_div.Tensor_out | ||||
|  | ||||
| - func: _foreach_clamp_max.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] | ||||
|  | ||||
| @ -1,5 +1,6 @@ | ||||
| #define TORCH_ASSERT_ONLY_METHOD_OPERATORS | ||||
| #include <ATen/core/Tensor.h> | ||||
| #include <ATen/Dispatch.h> | ||||
| #include <ATen/ceil_div.h> | ||||
| #include <ATen/native/cuda/Loops.cuh> | ||||
| #include <c10/cuda/CUDAGuard.h> | ||||
| @ -21,10 +22,11 @@ | ||||
| namespace at::native { | ||||
|  | ||||
| namespace { | ||||
| template <typename T> | ||||
| __global__ void ChooseQuantizationParamsKernelImpl( | ||||
|     const int64_t* fake_quant_on, | ||||
|     const float* x_min, | ||||
|     const float* x_max, | ||||
|     const T* x_min, | ||||
|     const T* x_max, | ||||
|     int32_t qmin, | ||||
|     int32_t qmax, | ||||
|     int size, | ||||
| @ -93,34 +95,44 @@ __global__ void ChooseQuantizationParamsKernelImpl( | ||||
|   } | ||||
| } | ||||
|  | ||||
| __device__ inline bool isinf_device(float v) { | ||||
|   return ::isinf(v); | ||||
| } | ||||
| __device__ inline bool isinf_device(c10::BFloat16 v) { | ||||
|   return ::isinf(static_cast<float>(v)); | ||||
| } | ||||
|  | ||||
| // CUDA kernel to compute Moving Average Min/Max of the tensor. | ||||
| // It uses the running_min and running_max along with averaging const, c. | ||||
| // The formula used to compute the new min/max is as follows | ||||
| // | ||||
| // running_min = (1 - c) * running_min + c * x_min, if running_min != inf | ||||
| // running_min = x_min, if running_min == inf | ||||
| template <typename T> | ||||
| __global__ void MovingAverageMinMax( | ||||
|     const int64_t* observer_on, | ||||
|     const float* x_min, | ||||
|     const float* x_max, | ||||
|     float* running_min, | ||||
|     float* running_max, | ||||
|     const T* x_min, | ||||
|     const T* x_max, | ||||
|     T* running_min, | ||||
|     T* running_max, | ||||
|     const float averaging_const, | ||||
|     const int size) { | ||||
|   int i = blockIdx.x * blockDim.x + threadIdx.x; | ||||
|  | ||||
|   if (*observer_on == 1) { | ||||
|     if (i < size) { | ||||
|       float curr_min = x_min[i]; | ||||
|       float curr_max = x_max[i]; | ||||
|       T curr_min = x_min[i]; | ||||
|       T curr_max = x_max[i]; | ||||
|  | ||||
|       float adjusted_min = ::isinf(running_min[i]) | ||||
|           ? curr_min | ||||
|           : (running_min[i]) + averaging_const * (curr_min - (running_min[i])); | ||||
|       T averaging_const_t = static_cast<T>(averaging_const); | ||||
|  | ||||
|       float adjusted_max = ::isinf(running_max[i]) | ||||
|           ? curr_max | ||||
|           : (running_max[i]) + averaging_const * (curr_max - (running_max[i])); | ||||
|       T adjusted_min = isinf_device(running_min[i]) ? curr_min | ||||
|                                                     : (running_min[i]) + | ||||
|               averaging_const_t * (curr_min - (running_min[i])); | ||||
|  | ||||
|       T adjusted_max = isinf_device(running_max[i]) ? curr_max | ||||
|                                                     : (running_max[i]) + | ||||
|               averaging_const_t * (curr_max - (running_max[i])); | ||||
|  | ||||
|       running_min[i] = adjusted_min; | ||||
|       running_max[i] = adjusted_max; | ||||
| @ -142,40 +154,51 @@ void _calculate_moving_average( | ||||
|   at::Tensor x_min, x_max; | ||||
|  | ||||
|   int64_t* observer_on_data = observer_on.data_ptr<int64_t>(); | ||||
|   float* running_min_data = running_min.data_ptr<float>(); | ||||
|   float* running_max_data = running_max.data_ptr<float>(); | ||||
|   cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream(); | ||||
|  | ||||
|   if (per_row_fq) { | ||||
|     std::tie(x_min, x_max) = at::aminmax(x, 1); | ||||
|     float* x_min_data = x_min.data_ptr<float>(); | ||||
|     float* x_max_data = x_max.data_ptr<float>(); | ||||
|     int num_threads = std::min(size, (int64_t)512); | ||||
|     const uint64_t num_blocks = ceil_div<uint64_t>(size, num_threads); | ||||
|     AT_DISPATCH_FLOATING_TYPES_AND( | ||||
|         at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] { | ||||
|           scalar_t* x_min_data = x_min.data_ptr<scalar_t>(); | ||||
|           scalar_t* x_max_data = x_max.data_ptr<scalar_t>(); | ||||
|  | ||||
|     // Moving Average Min/Max observer for activations | ||||
|     MovingAverageMinMax<<<num_blocks, num_threads, 0, cuda_stream>>>( | ||||
|         observer_on_data, | ||||
|         x_min_data, | ||||
|         x_max_data, | ||||
|         running_min_data, | ||||
|         running_max_data, | ||||
|         averaging_const, | ||||
|         size); | ||||
|           scalar_t* running_min_data = running_min.data_ptr<scalar_t>(); | ||||
|           scalar_t* running_max_data = running_max.data_ptr<scalar_t>(); | ||||
|  | ||||
|           // Moving Average Min/Max observer for activations | ||||
|           MovingAverageMinMax<<<num_blocks, num_threads, 0, cuda_stream>>>( | ||||
|               observer_on_data, | ||||
|               x_min_data, | ||||
|               x_max_data, | ||||
|               running_min_data, | ||||
|               running_max_data, | ||||
|               averaging_const, | ||||
|               size); | ||||
|         }); | ||||
|     C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||
|   } else { | ||||
|     std::tie(x_min, x_max) = at::aminmax(x); | ||||
|     float* x_min_data = x_min.data_ptr<float>(); | ||||
|     float* x_max_data = x_max.data_ptr<float>(); | ||||
|     // Moving Average Min/Max observer for activations | ||||
|     MovingAverageMinMax<<<1, 1, 0, cuda_stream>>>( | ||||
|         observer_on_data, | ||||
|         x_min_data, | ||||
|         x_max_data, | ||||
|         running_min_data, | ||||
|         running_max_data, | ||||
|         averaging_const, | ||||
|         1 /*size*/); | ||||
|     AT_DISPATCH_FLOATING_TYPES_AND( | ||||
|         at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] { | ||||
|           scalar_t* x_min_data = x_min.data_ptr<scalar_t>(); | ||||
|           scalar_t* x_max_data = x_max.data_ptr<scalar_t>(); | ||||
|  | ||||
|           scalar_t* running_min_data = running_min.data_ptr<scalar_t>(); | ||||
|           scalar_t* running_max_data = running_max.data_ptr<scalar_t>(); | ||||
|  | ||||
|           // Moving Average Min/Max observer for activations | ||||
|           MovingAverageMinMax<<<1, 1, 0, cuda_stream>>>( | ||||
|               observer_on_data, | ||||
|               x_min_data, | ||||
|               x_max_data, | ||||
|               running_min_data, | ||||
|               running_max_data, | ||||
|               averaging_const, | ||||
|               1 /*size*/); | ||||
|         }); | ||||
|     C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||
|   } | ||||
| } | ||||
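_calculate_moving_average applies the exponential-moving-average rule documented above MovingAverageMinMax. A scalar sketch of the same update, written generically so it covers both float and a bfloat16-like type that converts to float (an illustration, not the launch code):

#include <cmath>

// Scalar sketch of the MovingAverageMinMax update:
//   running = x                             if running is +/-inf (first observation)
//   running = running + c * (x - running)   otherwise, i.e. (1 - c) * running + c * x
template <typename T>
void ema_update(T& running_min, T& running_max, T x_min, T x_max, float c) {
  const T c_t = static_cast<T>(c);
  running_min = std::isinf(static_cast<float>(running_min))
      ? x_min
      : static_cast<T>(running_min + c_t * (x_min - running_min));
  running_max = std::isinf(static_cast<float>(running_max))
      ? x_max
      : static_cast<T>(running_max + c_t * (x_max - running_max));
}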
| @ -198,34 +221,44 @@ void _calc_moving_avg_qparams_helper( | ||||
|   cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream(); | ||||
|   int64_t* fake_quant_on_data = fake_quant_on.data_ptr<int64_t>(); | ||||
|   if (per_row_fq) { | ||||
|     float* running_min_data = running_min.data_ptr<float>(); | ||||
|     float* running_max_data = running_max.data_ptr<float>(); | ||||
|     int num_threads = std::min(size, (int64_t)512); | ||||
|     const uint64_t num_blocks = ceil_div<uint64_t>(size, num_threads); | ||||
|     ChooseQuantizationParamsKernelImpl<<<num_blocks, num_threads, 0, cuda_stream>>>( | ||||
|         fake_quant_on_data, | ||||
|         running_min_data, | ||||
|         running_max_data, | ||||
|         qmin, | ||||
|         qmax, | ||||
|         size, | ||||
|         symmetric_quant, | ||||
|         scale_ptr, | ||||
|         zp_ptr); | ||||
|     AT_DISPATCH_FLOATING_TYPES_AND( | ||||
|         at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] { | ||||
|           scalar_t* running_min_data = running_min.data_ptr<scalar_t>(); | ||||
|           scalar_t* running_max_data = running_max.data_ptr<scalar_t>(); | ||||
|           int num_threads = std::min(size, (int64_t)512); | ||||
|           const uint64_t num_blocks = ceil_div<uint64_t>(size, num_threads); | ||||
|           ChooseQuantizationParamsKernelImpl<<< | ||||
|               num_blocks, | ||||
|               num_threads, | ||||
|               0, | ||||
|               cuda_stream>>>( | ||||
|               fake_quant_on_data, | ||||
|               running_min_data, | ||||
|               running_max_data, | ||||
|               qmin, | ||||
|               qmax, | ||||
|               size, | ||||
|               symmetric_quant, | ||||
|               scale_ptr, | ||||
|               zp_ptr); | ||||
|         }); | ||||
|     C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||
|   } else { | ||||
|     float* running_min_data = running_min.data_ptr<float>(); | ||||
|     float* running_max_data = running_max.data_ptr<float>(); | ||||
|     ChooseQuantizationParamsKernelImpl<<<1, 1, 0, cuda_stream>>>( | ||||
|         fake_quant_on_data, | ||||
|         running_min_data, | ||||
|         running_max_data, | ||||
|         qmin, | ||||
|         qmax, | ||||
|         1, // size | ||||
|         symmetric_quant, // preserve_sparsity | ||||
|         scale_ptr, | ||||
|         zp_ptr); | ||||
|     AT_DISPATCH_FLOATING_TYPES_AND( | ||||
|         at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] { | ||||
|           scalar_t* running_min_data = running_min.data_ptr<scalar_t>(); | ||||
|           scalar_t* running_max_data = running_max.data_ptr<scalar_t>(); | ||||
|           ChooseQuantizationParamsKernelImpl<<<1, 1, 0, cuda_stream>>>( | ||||
|               fake_quant_on_data, | ||||
|               running_min_data, | ||||
|               running_max_data, | ||||
|               qmin, | ||||
|               qmax, | ||||
|               1, // size | ||||
|               symmetric_quant, // preserve_sparsity | ||||
|               scale_ptr, | ||||
|               zp_ptr); | ||||
|         }); | ||||
|     C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||
|   } | ||||
| } | ||||
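ChooseQuantizationParamsKernelImpl's body falls outside this hunk; the sketch below shows the conventional affine rule such a kernel is expected to implement (stretch the range to include zero, derive the scale, then round and clamp the zero point). This is an assumption about the elided code, not a copy of it:

#include <algorithm>
#include <cmath>

// Assumed affine-quantization rule: stretch [min, max] to contain zero,
// derive the scale from the quantized range, then round and clamp the zero point.
void choose_qparams_ref(
    float min, float max, int qmin, int qmax, float* scale, int* zero_point) {
  min = std::min(min, 0.0f);
  max = std::max(max, 0.0f);
  float s = (max - min) / static_cast<float>(qmax - qmin);
  if (s == 0.0f) {
    s = 1e-7f; // avoid a zero scale for degenerate ranges
  }
  float zp = static_cast<float>(qmin) - min / s;
  zp = std::clamp(zp, static_cast<float>(qmin), static_cast<float>(qmax));
  *scale = s;
  *zero_point = static_cast<int>(std::lround(zp));
}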
|  | ||||
| @ -42,7 +42,7 @@ TEST(MPSObjCInterfaceTest, MPSCustomKernel) { | ||||
|     id<MTLLibrary> customKernelLibrary = [device newLibraryWithSource: [NSString stringWithUTF8String:CUSTOM_KERNEL] | ||||
|                                                               options: nil | ||||
|                                                                 error: &error]; | ||||
|     TORCH_CHECK(customKernelLibrary, "Failed to to create custom kernel library, error: ", error.localizedDescription.UTF8String); | ||||
|     TORCH_CHECK(customKernelLibrary, "Failed to create custom kernel library, error: ", error.localizedDescription.UTF8String); | ||||
|  | ||||
|     id<MTLFunction> customFunction = [customKernelLibrary newFunctionWithName: @"add_arrays"]; | ||||
|     TORCH_CHECK(customFunction, "Failed to create function state object for the kernel"); | ||||
|  | ||||
| @ -76,4 +76,23 @@ int32_t getGlobalIdxFromDevice(DeviceIndex device) { | ||||
|   return device_global_idxs[device]; | ||||
| } | ||||
|  | ||||
| // Check if a device can access the memory of a peer device directly. | ||||
| bool canDeviceAccessPeer(DeviceIndex device, DeviceIndex peer) { | ||||
|   if (device == -1) { | ||||
|     device = c10::xpu::current_device(); | ||||
|   } | ||||
|   if (peer == -1) { | ||||
|     peer = c10::xpu::current_device(); | ||||
|   } | ||||
|   check_device_index(device); | ||||
|   check_device_index(peer); | ||||
|   // A device can always access itself | ||||
|   if (device == peer) { | ||||
|     return true; | ||||
|   } | ||||
|   return c10::xpu::get_raw_device(device).ext_oneapi_can_access_peer( | ||||
|       c10::xpu::get_raw_device(peer), | ||||
|       sycl::ext::oneapi::peer_access::access_supported); | ||||
| } | ||||
|  | ||||
| } // namespace at::xpu | ||||
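A hedged usage sketch of the new helper; the include path is assumed to be the header modified just below, and passing -1 for either index means the current device:

#include <ATen/xpu/XPUContext.h> // assumed location of the declaration below
#include <iostream>

int main() {
  // Query whether XPU device 0 can directly access the memory of XPU device 1.
  const bool ok = at::xpu::canDeviceAccessPeer(/*device=*/0, /*peer=*/1);
  std::cout << "peer access 0 -> 1: " << (ok ? "yes" : "no") << "\n";
  return 0;
}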
|  | ||||
| @ -17,4 +17,6 @@ TORCH_XPU_API DeviceProp* getDeviceProperties(DeviceIndex device); | ||||
|  | ||||
| TORCH_XPU_API int32_t getGlobalIdxFromDevice(DeviceIndex device); | ||||
|  | ||||
| TORCH_XPU_API bool canDeviceAccessPeer(DeviceIndex device, DeviceIndex peer); | ||||
|  | ||||
| } // namespace at::xpu | ||||
|  | ||||
| @ -205,7 +205,7 @@ llama,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| llama_v2_7b_16h,model_fail_to_load,0 | ||||
| llama_v2_7b_16h,pass_due_to_skip,0 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| 
 | 
| @ -178,7 +178,7 @@ llama,fail_to_run,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| llama_v2_7b_16h,model_fail_to_load,0 | ||||
| llama_v2_7b_16h,pass_due_to_skip,0 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| 
 | 
| @ -198,7 +198,7 @@ llama,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| llama_v2_7b_16h,model_fail_to_load,0 | ||||
| llama_v2_7b_16h,pass_due_to_skip,0 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| 
 | 
| @ -171,3 +171,23 @@ XLNetLMHeadModel,pass,5 | ||||
|  | ||||
|  | ||||
| YituTechConvBert,pass,5 | ||||
|  | ||||
|  | ||||
|  | ||||
| meta-llama/Llama-3.2-1B,eager_failed_to_run,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| google/gemma-2-2b,eager_failed_to_run,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| google/gemma-3-4b-it,eager_failed_to_run,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| openai/whisper-tiny,eager_failed_to_run,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| Qwen/Qwen3-0.6B,eager_failed_to_run,0 | ||||
|  | ||||
| 
 | 
| @ -198,7 +198,7 @@ llama,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| llama_v2_7b_16h,model_fail_to_load,0 | ||||
| llama_v2_7b_16h,pass_due_to_skip,0 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| 
 | 
| @ -18,7 +18,6 @@ import torch | ||||
|  | ||||
| # needs to be imported after torch | ||||
| import torch.utils.cpp_extension as cpp_extension  # noqa: F401 | ||||
| from torch.utils.benchmark import Timer | ||||
|  | ||||
|  | ||||
| """Performance microbenchmarks. | ||||
| @ -349,24 +348,10 @@ class BenchmarkRunner: | ||||
|             func = test_case.run_jit_forward | ||||
|         if self.use_compile: | ||||
|             func = test_case.run_compile_forward | ||||
|  | ||||
|         if not cuda_sync: | ||||
|             forward_time = timeit.timeit( | ||||
|                 functools.partial(func, iters, print_per_iter, cuda_sync), number=1 | ||||
|             ) | ||||
|             return forward_time | ||||
|         # Stable timing with Timer | ||||
|         timer = Timer( | ||||
|             stmt="func(iters, print_per_iter, cuda_sync)", | ||||
|             globals={ | ||||
|                 "func": func, | ||||
|                 "iters": iters, | ||||
|                 "print_per_iter": print_per_iter, | ||||
|                 "cuda_sync": cuda_sync, | ||||
|             }, | ||||
|         forward_time = timeit.timeit( | ||||
|             functools.partial(func, iters, print_per_iter, cuda_sync), number=1 | ||||
|         ) | ||||
|         result = timer.adaptive_autorange(min_run_time=0.0001) | ||||
|         return result.median * iters | ||||
|         return forward_time | ||||
|  | ||||
|     def _launch_backward(self, test_case, iters, print_per_iter=False): | ||||
|         """This function runs forward path of an op to get an output. Then the backward path is executed | ||||
|  | ||||
| @ -161,8 +161,6 @@ class PyTorchOperatorTestCase: | ||||
|         if self._compile_forward_graph is None: | ||||
|             self._compile_forward_graph = self._generate_compile_forward_graph() | ||||
|         self._compile_forward_graph(num_runs) | ||||
|         if cuda_sync: | ||||
|             torch.cuda.synchronize(torch.cuda.current_device()) | ||||
|  | ||||
|     def _print_per_iter(self): | ||||
|         # print last 50 values | ||||
|  | ||||
| @ -52,6 +52,27 @@ class AddBenchmark(op_bench.TorchBenchmarkBase): | ||||
| op_bench.generate_pt_test(add_long_configs + add_short_configs, AddBenchmark) | ||||
| op_bench.generate_pt_gradient_test(add_long_configs + add_short_configs, AddBenchmark) | ||||
|  | ||||
|  | ||||
| """Mircobenchmark for addmm operator.""" | ||||
|  | ||||
|  | ||||
| class AddmmBenchmark(op_bench.TorchBenchmarkBase): | ||||
|     def init(self, M, N, K, device): | ||||
|         self.inputs = { | ||||
|             "input_one": torch.rand(M, K, device=device, requires_grad=self.auto_set()), | ||||
|             "mat1": torch.rand(M, N, device=device, requires_grad=self.auto_set()), | ||||
|             "mat2": torch.rand(N, K, device=device, requires_grad=self.auto_set()), | ||||
|         } | ||||
|         self.set_module_name("addmm") | ||||
|  | ||||
|     def forward(self, input_one, mat1, mat2): | ||||
|         return torch.addmm(input_one, mat1, mat2) | ||||
|  | ||||
|  | ||||
| op_bench.generate_pt_test(add_long_configs + add_short_configs, AddmmBenchmark) | ||||
| op_bench.generate_pt_gradient_test(add_long_configs + add_short_configs, AddmmBenchmark) | ||||
|  | ||||
|  | ||||
| """Mircobenchmark for addr operator.""" | ||||
|  | ||||
|  | ||||
| @ -85,5 +106,46 @@ addr_configs = op_bench.cross_product_configs( | ||||
| op_bench.generate_pt_test(addr_configs, AddrBenchmark) | ||||
| op_bench.generate_pt_gradient_test(addr_configs, AddrBenchmark) | ||||
|  | ||||
|  | ||||
| """Mircobenchmark for addbmm operator.""" | ||||
|  | ||||
|  | ||||
| class AddbmmBenchmark(op_bench.TorchBenchmarkBase): | ||||
|     def init(self, B, M, N, K, device): | ||||
|         self.inputs = { | ||||
|             "input_one": torch.rand( | ||||
|                 (M, N), device=device, requires_grad=self.auto_set() | ||||
|             ), | ||||
|             "batch1": torch.rand( | ||||
|                 (B, M, K), device=device, requires_grad=self.auto_set() | ||||
|             ), | ||||
|             "batch2": torch.rand( | ||||
|                 ( | ||||
|                     B, | ||||
|                     K, | ||||
|                     N, | ||||
|                 ), | ||||
|                 device=device, | ||||
|                 requires_grad=self.auto_set(), | ||||
|             ), | ||||
|         } | ||||
|         self.set_module_name("addbmm") | ||||
|  | ||||
|     def forward(self, input_one, batch1, batch2): | ||||
|         return torch.addbmm(input_one, batch1, batch2) | ||||
|  | ||||
|  | ||||
| addbmm_configs = op_bench.cross_product_configs( | ||||
|     B=[2, 100], | ||||
|     M=[8, 256], | ||||
|     N=[256, 16], | ||||
|     K=[15, 16], | ||||
|     device=["cpu", "cuda"], | ||||
|     tags=["addbmm"], | ||||
| ) | ||||
|  | ||||
| op_bench.generate_pt_test(addbmm_configs, AddbmmBenchmark) | ||||
| op_bench.generate_pt_gradient_test(addbmm_configs, AddbmmBenchmark) | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     op_bench.benchmark_runner.main() | ||||
|  | ||||
| @ -1,115 +0,0 @@ | ||||
| import operator_benchmark as op_bench | ||||
|  | ||||
| import torch | ||||
|  | ||||
|  | ||||
| """Microbenchmarks for add_(matmul) operator. Supports both Caffe2/PyTorch.""" | ||||
|  | ||||
| # Configs for PT add operator | ||||
| addmm_long_configs = op_bench.cross_product_configs( | ||||
|     M=[256, 1024, 3000], | ||||
|     N=[512, 4096], | ||||
|     K=[512, 4096], | ||||
|     device=["cuda"], | ||||
|     tags=["long"], | ||||
|     dtype=[torch.float16, torch.bfloat16, torch.float32], | ||||
| ) | ||||
|  | ||||
|  | ||||
| addmm_short_configs = op_bench.config_list( | ||||
|     attr_names=["M", "N", "K"], | ||||
|     attrs=[ | ||||
|         [1, 1, 1], | ||||
|         [64, 64, 64], | ||||
|         [64, 64, 128], | ||||
|     ], | ||||
|     cross_product_configs={ | ||||
|         "device": ["cpu", "cuda"], | ||||
|         "dtype": [torch.float], | ||||
|     }, | ||||
|     tags=["short"], | ||||
| ) | ||||
|  | ||||
|  | ||||
| """Mircobenchmark for addmm operator.""" | ||||
|  | ||||
|  | ||||
| class AddmmBenchmark(op_bench.TorchBenchmarkBase): | ||||
|     def init(self, M, N, K, device, dtype): | ||||
|         self.inputs = { | ||||
|             "input_one": torch.rand( | ||||
|                 M, K, device=device, requires_grad=self.auto_set(), dtype=dtype | ||||
|             ), | ||||
|             "mat1": torch.rand( | ||||
|                 M, N, device=device, requires_grad=self.auto_set(), dtype=dtype | ||||
|             ), | ||||
|             "mat2": torch.rand( | ||||
|                 N, K, device=device, requires_grad=self.auto_set(), dtype=dtype | ||||
|             ), | ||||
|         } | ||||
|         self.set_module_name("addmm") | ||||
|  | ||||
|     def forward(self, input_one, mat1, mat2): | ||||
|         return torch.addmm(input_one, mat1, mat2) | ||||
|  | ||||
|  | ||||
| op_bench.generate_pt_test(addmm_long_configs + addmm_long_configs, AddmmBenchmark) | ||||
| op_bench.generate_pt_gradient_test( | ||||
|     addmm_long_configs + addmm_long_configs, AddmmBenchmark | ||||
| ) | ||||
|  | ||||
| """Mircobenchmark for addbmm operator.""" | ||||
|  | ||||
|  | ||||
| class AddbmmBenchmark(op_bench.TorchBenchmarkBase): | ||||
|     def init(self, B, M, N, K, device, dtype): | ||||
|         self.inputs = { | ||||
|             "input_one": torch.rand( | ||||
|                 (M, N), device=device, requires_grad=self.auto_set(), dtype=dtype | ||||
|             ), | ||||
|             "batch1": torch.rand( | ||||
|                 (B, M, K), device=device, requires_grad=self.auto_set(), dtype=dtype | ||||
|             ), | ||||
|             "batch2": torch.rand( | ||||
|                 ( | ||||
|                     B, | ||||
|                     K, | ||||
|                     N, | ||||
|                 ), | ||||
|                 device=device, | ||||
|                 requires_grad=self.auto_set(), | ||||
|                 dtype=dtype, | ||||
|             ), | ||||
|         } | ||||
|         self.set_module_name("addbmm") | ||||
|  | ||||
|     def forward(self, input_one, batch1, batch2): | ||||
|         return torch.addbmm(input_one, batch1, batch2) | ||||
|  | ||||
|  | ||||
| addbmm_long_configs = op_bench.cross_product_configs( | ||||
|     B=[8, 32], | ||||
|     M=[256, 1024], | ||||
|     N=[256, 1024], | ||||
|     K=[64, 128], | ||||
|     device=["cuda"], | ||||
|     dtype=[torch.float16, torch.bfloat16, torch.float32], | ||||
|     tags=["long"], | ||||
| ) | ||||
| addbmm_short_configs = op_bench.cross_product_configs( | ||||
|     B=[1, 8], | ||||
|     M=[8, 128], | ||||
|     N=[32, 64], | ||||
|     K=[256, 512], | ||||
|     device=["cpu", "cuda"], | ||||
|     dtype=[torch.float16, torch.bfloat16, torch.float32], | ||||
|     tags=["short"], | ||||
| ) | ||||
|  | ||||
| op_bench.generate_pt_test(addbmm_long_configs + addbmm_short_configs, AddbmmBenchmark) | ||||
| op_bench.generate_pt_gradient_test( | ||||
|     addbmm_long_configs + addbmm_short_configs, AddbmmBenchmark | ||||
| ) | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     op_bench.benchmark_runner.main() | ||||
| @ -27,12 +27,12 @@ batched_binary_configs_short = op_bench.config_list( | ||||
| ) | ||||
|  | ||||
| batched_binary_configs_long = op_bench.cross_product_configs( | ||||
|     B=[8, 32], | ||||
|     M=[256, 1024], | ||||
|     N=[256, 1024], | ||||
|     K=[64, 128], | ||||
|     device=["cuda"], | ||||
|     dtype=[torch.float32, torch.bfloat16, torch.float16], | ||||
|     B=[1, 128], | ||||
|     M=[8, 128], | ||||
|     N=[32, 64], | ||||
|     K=[4, 256], | ||||
|     device=["cpu", "cuda"], | ||||
|     dtype=[torch.float, torch.bfloat16], | ||||
|     tags=["long"], | ||||
| ) | ||||
|  | ||||
| @ -40,12 +40,8 @@ batched_binary_configs_long = op_bench.cross_product_configs( | ||||
| class BatchedBinaryOpBenchmark(op_bench.TorchBenchmarkBase): | ||||
|     def init(self, B, M, N, K, device, dtype, op_func): | ||||
|         self.inputs = { | ||||
|             "batch1": torch.rand( | ||||
|                 (B, M, N), device=device, dtype=dtype, requires_grad=self.auto_set() | ||||
|             ), | ||||
|             "batch2": torch.rand( | ||||
|                 (B, N, K), device=device, dtype=dtype, requires_grad=self.auto_set() | ||||
|             ), | ||||
|             "batch1": torch.rand((B, M, N), device=device).to(dtype=dtype), | ||||
|             "batch2": torch.rand((B, N, K), device=device).to(dtype=dtype), | ||||
|         } | ||||
|         self.op_func = op_func | ||||
|  | ||||
| @ -58,11 +54,6 @@ op_bench.generate_pt_tests_from_op_list( | ||||
|     batched_binary_configs_short + batched_binary_configs_long, | ||||
|     BatchedBinaryOpBenchmark, | ||||
| ) | ||||
| op_bench.generate_pt_gradient_tests_from_op_list( | ||||
|     batched_binary_ops, | ||||
|     batched_binary_configs_long, | ||||
|     BatchedBinaryOpBenchmark, | ||||
| ) | ||||
|  | ||||
|  | ||||
| # batched ternary ops | ||||
| @ -75,15 +66,9 @@ batched_ternary_ops = op_bench.op_list( | ||||
| class BatchedTernaryOpBenchmark(op_bench.TorchBenchmarkBase): | ||||
|     def init(self, B, M, N, K, device, dtype, op_func): | ||||
|         self.inputs = { | ||||
|             "input_": torch.rand( | ||||
|                 (B, M, K), device=device, dtype=dtype, requires_grad=self.auto_set() | ||||
|             ), | ||||
|             "batch1": torch.rand( | ||||
|                 (B, M, N), device=device, dtype=dtype, requires_grad=self.auto_set() | ||||
|             ), | ||||
|             "batch2": torch.rand( | ||||
|                 (B, N, K), device=device, dtype=dtype, requires_grad=self.auto_set() | ||||
|             ), | ||||
|             "input_": torch.rand((B, M, K), device=device).to(dtype=dtype), | ||||
|             "batch1": torch.rand((B, M, N), device=device).to(dtype=dtype), | ||||
|             "batch2": torch.rand((B, N, K), device=device).to(dtype=dtype), | ||||
|         } | ||||
|         self.op_func = op_func | ||||
|  | ||||
| @ -96,12 +81,6 @@ op_bench.generate_pt_tests_from_op_list( | ||||
|     batched_binary_configs_short + batched_binary_configs_long, | ||||
|     BatchedTernaryOpBenchmark, | ||||
| ) | ||||
| op_bench.generate_pt_gradient_tests_from_op_list( | ||||
|     batched_ternary_ops, | ||||
|     batched_binary_configs_long, | ||||
|     BatchedTernaryOpBenchmark, | ||||
| ) | ||||
|  | ||||
|  | ||||
| # TODO: does it automatically register new scripts? | ||||
|  | ||||
|  | ||||
| @ -13,46 +13,33 @@ mm_short_configs = op_bench.config_list( | ||||
|         [128, 128, 128, True, False], | ||||
|         [256, 256, 256, False, True], | ||||
|     ], | ||||
|     cross_product_configs={"device": ["cpu", "cuda"], "dtype": [torch.float]}, | ||||
|     cross_product_configs={ | ||||
|         "device": ["cpu", "cuda"], | ||||
|     }, | ||||
|     tags=["short"], | ||||
| ) | ||||
|  | ||||
|  | ||||
| mm_long_configs = op_bench.cross_product_configs( | ||||
|     M=[256, 1024, 3000], | ||||
|     N=[512, 4096], | ||||
|     K=[512, 4096], | ||||
|     M=[32], | ||||
|     N=[512, 128], | ||||
|     K=[64], | ||||
|     trans_a=[False, True], | ||||
|     trans_b=[True, False], | ||||
|     device=["cuda"], | ||||
|     dtype=[torch.float16, torch.bfloat16, torch.float32], | ||||
|     device=["cpu", "cuda"], | ||||
|     tags=["long"], | ||||
| ) | ||||
|  | ||||
|  | ||||
| class MatMulBenchmark(op_bench.TorchBenchmarkBase): | ||||
|     def init(self, M, N, K, trans_a, trans_b, device, dtype): | ||||
|         # Create tensors without requires_grad first, then set it separately | ||||
|         # This avoids creating graph leaves that cannot be deep copied | ||||
|         if trans_a: | ||||
|             input_one = torch.rand(M, N, device=device, dtype=dtype) | ||||
|         else: | ||||
|             input_one = torch.rand(N, M, device=device, dtype=dtype).t() | ||||
|  | ||||
|         if trans_b: | ||||
|             input_two = torch.rand(N, K, device=device, dtype=dtype) | ||||
|         else: | ||||
|             input_two = torch.rand(K, N, device=device, dtype=dtype).t() | ||||
|  | ||||
|         # Set requires_grad after tensor creation to avoid graph leaf issues | ||||
|         if self.auto_set(): | ||||
|             input_one.requires_grad_(True) | ||||
|         if self.auto_set(): | ||||
|             input_two.requires_grad_(True) | ||||
|  | ||||
|     def init(self, M, N, K, trans_a, trans_b, device): | ||||
|         self.inputs = { | ||||
|             "input_one": input_one, | ||||
|             "input_two": input_two, | ||||
|             "input_one": torch.rand(M, N, device=device) | ||||
|             if trans_a | ||||
|             else torch.rand(N, M, device=device).t(), | ||||
|             "input_two": torch.rand(N, K, device=device) | ||||
|             if trans_b | ||||
|             else torch.rand(K, N, device=device).t(), | ||||
|         } | ||||
|         self.set_module_name("matmul") | ||||
|  | ||||
| @ -61,7 +48,6 @@ class MatMulBenchmark(op_bench.TorchBenchmarkBase): | ||||
|  | ||||
|  | ||||
| op_bench.generate_pt_test(mm_long_configs + mm_short_configs, MatMulBenchmark) | ||||
| op_bench.generate_pt_gradient_test(mm_long_configs, MatMulBenchmark) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|  | ||||
| @ -23,11 +23,11 @@ mm_short_configs = op_bench.config_list( | ||||
| ) | ||||
|  | ||||
| mm_long_configs = op_bench.cross_product_configs( | ||||
|     M=[256, 1024, 3000], | ||||
|     N=[512, 4096], | ||||
|     K=[512, 4096], | ||||
|     device=["cuda"], | ||||
|     dtype=[torch.float16, torch.bfloat16, torch.float32], | ||||
|     M=[8, 128], | ||||
|     N=[32, 64], | ||||
|     K=[256, 512], | ||||
|     device=["cpu", "cuda"], | ||||
|     dtype=[torch.float, torch.bfloat16], | ||||
|     tags=["long"], | ||||
| ) | ||||
|  | ||||
| @ -35,12 +35,8 @@ mm_long_configs = op_bench.cross_product_configs( | ||||
| class MmOpBenchmark(op_bench.TorchBenchmarkBase): | ||||
|     def init(self, M, N, K, device, dtype, op_func): | ||||
|         self.inputs = { | ||||
|             "input_one": torch.randn( | ||||
|                 M, N, device=device, requires_grad=self.auto_set(), dtype=dtype | ||||
|             ), | ||||
|             "input_two": torch.randn( | ||||
|                 N, K, device=device, requires_grad=self.auto_set(), dtype=dtype | ||||
|             ), | ||||
|             "input_one": torch.randn(M, N, device=device).to(dtype=dtype), | ||||
|             "input_two": torch.randn(N, K, device=device).to(dtype=dtype), | ||||
|         } | ||||
|         self.op_func = op_func | ||||
|  | ||||
| @ -51,9 +47,6 @@ class MmOpBenchmark(op_bench.TorchBenchmarkBase): | ||||
| op_bench.generate_pt_tests_from_op_list( | ||||
|     ops_list, mm_short_configs + mm_long_configs, MmOpBenchmark | ||||
| ) | ||||
| op_bench.generate_pt_gradient_tests_from_op_list( | ||||
|     ops_list, mm_long_configs, MmOpBenchmark | ||||
| ) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|  | ||||
| @ -61,6 +61,22 @@ struct C10_API Storage { | ||||
|             allocator, | ||||
|             resizable)) {} | ||||
|  | ||||
|   // Creates storage with pre-allocated memory buffer. Allocator is given for | ||||
|   // potential future reallocations, however it can be nullptr if the storage | ||||
|   // is non-resizable | ||||
|   Storage( | ||||
|       use_byte_size_t /*use_byte_size*/, | ||||
|       SymInt size_bytes, | ||||
|       at::DataPtr data_ptr, | ||||
|       at::Allocator* allocator = nullptr, | ||||
|       bool resizable = false) | ||||
|       : storage_impl_(c10::make_intrusive<StorageImpl>( | ||||
|             StorageImpl::use_byte_size_t(), | ||||
|             std::move(size_bytes), | ||||
|             std::move(data_ptr), | ||||
|             allocator, | ||||
|             resizable)) {} | ||||
|  | ||||
|  protected: | ||||
|   explicit Storage(unsafe_borrow_t, const Storage& rhs) | ||||
|       : storage_impl_(c10::intrusive_ptr<c10::StorageImpl>::reclaim( | ||||
|  | ||||
| @ -3269,7 +3269,7 @@ class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl { | ||||
|     is_le<sizeof(autograd_meta_),         16,  FieldNameEnum::autograd_meta_>(); | ||||
|     is_le<sizeof(extra_meta_),            16,  FieldNameEnum::extra_meta_>(); | ||||
|     are_equal<sizeof(version_counter_),    8,  FieldNameEnum::version_counter_>(); | ||||
|     are_equal<sizeof(pyobj_slot_),   16,  FieldNameEnum::pyobj_slot_>(); | ||||
|     are_equal<sizeof(pyobj_slot_),         8,  FieldNameEnum::pyobj_slot_>(); | ||||
|     are_equal<sizeof(sizes_and_strides_), 88,  FieldNameEnum::sizes_and_strides_>(); | ||||
|     are_equal<sizeof(storage_offset_),     8,  FieldNameEnum::storage_offset_>(); | ||||
|     are_equal<sizeof(numel_),              8,  FieldNameEnum::numel_>(); | ||||
|  | ||||
| @ -1,4 +1,5 @@ | ||||
| #include <c10/core/impl/DeviceGuardImplInterface.h> | ||||
| #include <c10/core/impl/FakeGuardImpl.h> | ||||
| #include <array> | ||||
|  | ||||
| namespace c10::impl { | ||||
| @ -14,4 +15,26 @@ DeviceGuardImplRegistrar::DeviceGuardImplRegistrar( | ||||
|   device_guard_impl_registry[static_cast<size_t>(type)].store(impl); | ||||
| } | ||||
|  | ||||
| namespace { | ||||
| thread_local std::unique_ptr<DeviceGuardImplInterface> tls_fake_device_guard = | ||||
|     nullptr; | ||||
| } | ||||
|  | ||||
| void ensureCUDADeviceGuardSet() { | ||||
|   constexpr auto cuda_idx = static_cast<std::size_t>(DeviceType::CUDA); | ||||
|  | ||||
|   const DeviceGuardImplInterface* p = | ||||
|       device_guard_impl_registry[cuda_idx].load(); | ||||
|  | ||||
|   // A non-null `p` indicates that the CUDA guard is already registered, | ||||
|   // implying this is a CUDA build. | ||||
|   if (p && p->deviceCount() == 0) { | ||||
|     // When p->deviceCount() == 0, the CUDA build is enabled but no CUDA | ||||
|     // devices are available, so override the CUDA guard interface with a | ||||
|     // no-op device guard. | ||||
|     tls_fake_device_guard = std::make_unique<FakeGuardImpl<DeviceType::CUDA>>(); | ||||
|     device_guard_impl_registry[cuda_idx].store(tls_fake_device_guard.get()); | ||||
|   } | ||||
| } | ||||
|  | ||||
| } // namespace c10::impl | ||||
|  | ||||
| @ -6,6 +6,7 @@ | ||||
| #include <c10/util/Exception.h> | ||||
|  | ||||
| // Just for C10_ANONYMOUS_VARIABLE | ||||
| #include <c10/core/impl/TorchDispatchModeTLS.h> | ||||
| #include <c10/util/Registry.h> | ||||
|  | ||||
| #include <array> | ||||
| @ -251,7 +252,7 @@ struct C10_API DeviceGuardImplInterface { | ||||
| // for devices that don't actually have a concept of device index.  Prominent | ||||
| // examples are CPU and Meta. | ||||
| template <DeviceType D> | ||||
| struct NoOpDeviceGuardImpl final : public DeviceGuardImplInterface { | ||||
| struct NoOpDeviceGuardImpl : public DeviceGuardImplInterface { | ||||
|   NoOpDeviceGuardImpl() = default; | ||||
|   DeviceType type() const override { | ||||
|     return D; | ||||
| @ -371,5 +372,7 @@ inline bool hasDeviceGuardImpl(DeviceType type) { | ||||
|   return device_guard_impl_registry[static_cast<size_t>(type)].load(); | ||||
| } | ||||
|  | ||||
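| // Installs a no-op guard for the CUDA device type when this is a CUDA build | ||||
| // but no CUDA devices are present; otherwise does nothing. | ||||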
| void C10_API ensureCUDADeviceGuardSet(); | ||||
|  | ||||
| } // namespace impl | ||||
| } // namespace c10 | ||||
|  | ||||
| @ -13,11 +13,10 @@ struct C10_API PyInterpreterHooksInterface { | ||||
|  | ||||
|   // Get the PyInterpreter instance | ||||
|   // Stub implementation throws error when Python is not available | ||||
|   // We return nullptr rather than throwing an error since there are bits of c10 | ||||
|   // that expect an empty PyObjectSlot when python is not available. | ||||
|   virtual PyInterpreter* getPyInterpreter() const { | ||||
|     TORCH_CHECK( | ||||
|         false, | ||||
|         "PyTorch was compiled without Python support. " | ||||
|         "Cannot access Python interpreter from C++."); | ||||
|     return nullptr; | ||||
|   } | ||||
| }; | ||||
|  | ||||
|  | ||||
| @ -2,7 +2,7 @@ | ||||
|  | ||||
| namespace c10::impl { | ||||
|  | ||||
| PyObjectSlot::PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {} | ||||
| PyObjectSlot::PyObjectSlot() : pyobj_(nullptr) {} | ||||
|  | ||||
| PyObjectSlot::~PyObjectSlot() { | ||||
|   maybe_destroy_pyobj(); | ||||
| @ -10,9 +10,9 @@ PyObjectSlot::~PyObjectSlot() { | ||||
|  | ||||
| void PyObjectSlot::maybe_destroy_pyobj() { | ||||
|   if (owns_pyobj()) { | ||||
|     TORCH_INTERNAL_ASSERT(pyobj_interpreter_ != nullptr); | ||||
|     TORCH_INTERNAL_ASSERT(getGlobalPyInterpreter() != nullptr); | ||||
|     TORCH_INTERNAL_ASSERT(pyobj_ != nullptr); | ||||
|     (*pyobj_interpreter_.load(std::memory_order_acquire)) | ||||
|     (*getGlobalPyInterpreter()) | ||||
|         ->decref(_unchecked_untagged_pyobj(), /*has_pyobj_slot*/ true); | ||||
|     // NB: this destructor can only be entered when there are no | ||||
|     // references to this C++ object (obviously), NOR any references | ||||
| @ -25,7 +25,7 @@ void PyObjectSlot::maybe_destroy_pyobj() { | ||||
| } | ||||
|  | ||||
| PyInterpreter* PyObjectSlot::pyobj_interpreter() { | ||||
|   return pyobj_interpreter_.load(std::memory_order_acquire); | ||||
|   return getGlobalPyInterpreter(); | ||||
| } | ||||
|  | ||||
| PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const { | ||||
| @ -35,7 +35,7 @@ PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const { | ||||
| } | ||||
|  | ||||
| PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const { | ||||
|   auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire); | ||||
|   auto interpreter = getGlobalPyInterpreter(); | ||||
|   if (interpreter) { | ||||
|     return *interpreter; | ||||
|   } | ||||
|  | ||||
| @ -6,10 +6,17 @@ | ||||
| #include <c10/util/python_stub.h> | ||||
| #include <optional> | ||||
|  | ||||
| #include <atomic> | ||||
|  | ||||
| namespace c10::impl { | ||||
|  | ||||
| // Function pointer type for getting the global interpreter | ||||
| using GetPyInterpreterFn = PyInterpreter* (*)(); | ||||
|  | ||||
| // Global function pointer (set by csrc initialization) | ||||
| C10_API extern GetPyInterpreterFn g_get_pyinterpreter_fn; | ||||
|  | ||||
| // Helper function to get the global interpreter | ||||
| C10_API PyInterpreter* getGlobalPyInterpreter(); | ||||
|  | ||||
| struct C10_API PyObjectSlot { | ||||
|  public: | ||||
|   PyObjectSlot(); | ||||
| @ -26,8 +33,6 @@ struct C10_API PyObjectSlot { | ||||
|   // NB: THIS FUNCTION CAN RAISE AN EXCEPTION.  Make sure to clean up after | ||||
|   // PyObject if necessary! | ||||
|   void init_pyobj(PyObject* pyobj) { | ||||
|     pyobj_interpreter_.store( | ||||
|         getGlobalPyInterpreter(), std::memory_order_relaxed); | ||||
|     pyobj_ = pyobj; | ||||
|   } | ||||
|  | ||||
| @ -55,18 +60,15 @@ struct C10_API PyObjectSlot { | ||||
|  | ||||
|   // @todo alban: I'm not too sure what's going on here, we can probably delete | ||||
|   // it but it's worthwhile making sure | ||||
|   std::optional<PyObject*> check_pyobj(bool ignore_hermetic_tls = false) const { | ||||
|     impl::PyInterpreter* interpreter = | ||||
|         pyobj_interpreter_.load(std::memory_order_acquire); | ||||
|     if (interpreter == nullptr) { | ||||
|   std::optional<PyObject*> check_pyobj() const { | ||||
|     impl::PyInterpreter* interpreter = getGlobalPyInterpreter(); | ||||
|     if (interpreter == nullptr || pyobj_ == nullptr) { | ||||
|       return std::nullopt; | ||||
|     } | ||||
|  | ||||
|     if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) { | ||||
|     if (c10::impl::HermeticPyObjectTLS::get_state()) { | ||||
|       return std::nullopt; | ||||
|     } else { | ||||
|       return _unchecked_untagged_pyobj(); | ||||
|     } | ||||
|     return _unchecked_untagged_pyobj(); | ||||
|   } | ||||
|  | ||||
|   PyInterpreter& load_pyobj_interpreter() const; | ||||
| @ -76,30 +78,6 @@ struct C10_API PyObjectSlot { | ||||
|   void set_owns_pyobj(bool b); | ||||
|  | ||||
|  private: | ||||
|   // This field contains the interpreter tag for this object.  See | ||||
|   // Note [Python interpreter tag] for general context | ||||
|   // | ||||
|   // Note [Memory ordering on Python interpreter tag] | ||||
|   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
|   // What memory_order do we need when accessing this atomic?  We don't | ||||
|   // need a single total modification order (as provided by | ||||
|   // memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only | ||||
|   // transition from -1 to some positive integer and never changes afterwards. | ||||
|   // Because there is only one modification, it trivially already has a total | ||||
|   // modification order (e.g., we don't need fences or locked instructions on | ||||
|   // x86) | ||||
|   // | ||||
|   // In fact, one could make a reasonable argument that relaxed reads are OK, | ||||
|   // due to the presence of external locking (GIL) to ensure that interactions | ||||
|   // with other data structures are still correctly synchronized, so that | ||||
|   // we fall in the "Single-Location Data Structures" case as described in | ||||
|   // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf | ||||
|   // However, on x86, it doesn't matter if I use acquire or relaxed on the load | ||||
|   // as I get the same assembly in both cases.  So I just use the more | ||||
|   // conservative acquire (which will impede compiler optimizations but I don't | ||||
|   // care) | ||||
|   std::atomic<PyInterpreter*> pyobj_interpreter_; | ||||
|  | ||||
|   // This field contains a reference to a PyObject representing this Tensor. | ||||
|   // If pyobj is nullptr, when we transfer Tensor to Python, we allocate a new | ||||
|   // PyObject for it and set this field.  This field does not have to be | ||||
|  | ||||
| @ -52,7 +52,7 @@ struct maybe_bool { | ||||
| template <typename src_t> | ||||
| struct maybe_bool<true, src_t> { | ||||
|   C10_HOST_DEVICE static inline decltype(auto) apply(src_t src) { | ||||
|     // Don't use bool operator so as to to also compile for ComplexHalf. | ||||
|     // Don't use bool operator so as to also compile for ComplexHalf. | ||||
|     return src.real() || src.imag(); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| #if defined(__aarch64__) && defined(CAFFE2_PERF_WITH_SVE128) | ||||
| #if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(CAFFE2_PERF_WITH_SVE128) | ||||
| #include <arm_neon.h> | ||||
| #include <arm_neon_sve_bridge.h> | ||||
| #include <arm_sve.h> | ||||
| @ -240,4 +240,4 @@ template void compute_batch_box_cox__sve128<float>( | ||||
|  | ||||
| } // namespace caffe2::details | ||||
|  | ||||
| #endif // __aarch64__ && CAFFE2_PERF_WITH_SVE128 | ||||
| #endif // __aarch64__ && __ARM_FEATURE_SVE && CAFFE2_PERF_WITH_SVE128 | ||||
|  | ||||
| @ -158,6 +158,7 @@ function(caffe2_print_configuration_summary) | ||||
|   if(${USE_KLEIDIAI}) | ||||
|     message(STATUS "  USE_KLEIDIAI          : ${USE_KLEIDIAI}") | ||||
|   endif() | ||||
|   message(STATUS "  USE_PRIORITIZED_TEXT_FOR_LD : ${USE_PRIORITIZED_TEXT_FOR_LD}") | ||||
|   message(STATUS "  USE_UCC               : ${USE_UCC}") | ||||
|   if(${USE_UCC}) | ||||
|     message(STATUS "    USE_SYSTEM_UCC        : ${USE_SYSTEM_UCC}") | ||||
|  | ||||
| @ -482,6 +482,7 @@ function(torch_update_find_cuda_flags) | ||||
| endfunction() | ||||
|  | ||||
| include(CheckCXXCompilerFlag) | ||||
| include(CheckLinkerFlag) | ||||
|  | ||||
| ############################################################################## | ||||
| # Check if given flag is supported and append it to provided output var | ||||
| @ -511,3 +512,22 @@ function(target_compile_options_if_supported target flag) | ||||
|     target_compile_options(${target} PRIVATE ${flag}) | ||||
|   endif() | ||||
| endfunction() | ||||
|  | ||||
| # Check if a global link option is supported | ||||
| function(add_link_options_if_supported flag) | ||||
|   check_linker_flag(C "LINKER:${flag}" _supported) | ||||
|   if("${_supported}") | ||||
|     add_link_options("LINKER:${flag}") | ||||
|   else() | ||||
|     message(WARNING "Attempted to use unsupported link option : ${flag}.") | ||||
|   endif() | ||||
| endfunction() | ||||
|  | ||||
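| # Check if a per-target link option is supported and apply it privately if so | ||||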
| function(target_link_options_if_supported tgt flag) | ||||
|   check_linker_flag(C "LINKER:${flag}" _supported) | ||||
|   if("${_supported}") | ||||
|     target_link_options("${tgt}" PRIVATE "LINKER:${flag}") | ||||
|   else() | ||||
|     message(WARNING "Attempted to use unsupported link option : ${flag}.") | ||||
|   endif() | ||||
| endfunction() | ||||
							
								
								
									
										
docs/source/_static/img/inductor_provenance/index_2.png (new binary file, 563 KiB)
docs/source/_static/img/inductor_provenance/kernel_source_1.png (new binary file, 281 KiB)
docs/source/_static/img/inductor_provenance/kernel_source_2.png (new binary file, 348 KiB)
docs/source/_static/img/inductor_provenance/kernel_source_3.png (new binary file, 747 KiB)
| @ -1,62 +0,0 @@ | ||||
| torch.utils.bottleneck | ||||
| ====================== | ||||
|  | ||||
| .. automodule:: torch.utils.bottleneck | ||||
| .. currentmodule:: torch.utils.bottleneck | ||||
|  | ||||
| `torch.utils.bottleneck` is a tool that can be used as an initial step for | ||||
| debugging bottlenecks in your program. It summarizes runs of your script with | ||||
| the Python profiler and PyTorch's autograd profiler. | ||||
|  | ||||
| Run it on the command line with | ||||
|  | ||||
| :: | ||||
|  | ||||
|     python -m torch.utils.bottleneck /path/to/source/script.py [args] | ||||
|  | ||||
| where [args] are any number of arguments to `script.py`, or run | ||||
| ``python -m torch.utils.bottleneck -h`` for more usage instructions. | ||||
|  | ||||
| .. warning:: | ||||
|     Because your script will be profiled, please ensure that it exits in a | ||||
|     finite amount of time. | ||||
|  | ||||
| .. warning:: | ||||
|     Due to the asynchronous nature of CUDA kernels, when running against | ||||
|     CUDA code, the cProfile output and CPU-mode autograd profilers may | ||||
|     not show correct timings: the reported CPU time reports the amount of time | ||||
|     used to launch the kernels but does not include the time the kernel | ||||
|     spent executing on a GPU unless the operation does a synchronize. | ||||
|     Ops that do synchronize appear to be extremely expensive under regular | ||||
|     CPU-mode profilers. | ||||
|     In these cases where timings are incorrect, the CUDA-mode autograd profiler | ||||
|     may be helpful. | ||||
|  | ||||
| .. note:: | ||||
|     To decide which (CPU-only-mode or CUDA-mode) autograd profiler output to | ||||
|     look at, you should first check if your script is CPU-bound | ||||
|     ("CPU total time is much greater than CUDA total time"). | ||||
|     If it is CPU-bound, looking at the results of the CPU-mode autograd | ||||
|     profiler will help. If on the other hand your script spends most of its | ||||
|     time executing on the GPU, then it makes sense to start | ||||
|     looking for responsible CUDA operators in the output of the CUDA-mode | ||||
|     autograd profiler. | ||||
|  | ||||
|     Of course the reality is much more complicated and your script might not be | ||||
|     in one of those two extremes depending on the part of the model you're | ||||
|     evaluating. If the profiler outputs don't help, you could try looking at | ||||
|     the result of :func:`torch.autograd.profiler.emit_nvtx()` with ``nvprof``. | ||||
|     However, please take into account that the NVTX overhead is very high and | ||||
|     often gives a heavily skewed timeline. Similarly, ``Intel® VTune™ Profiler`` | ||||
|     helps to analyze performance on Intel platforms further with | ||||
|     :func:`torch.autograd.profiler.emit_itt()`. | ||||
|  | ||||
| .. warning:: | ||||
|     If you are profiling CUDA code, the first profiler that ``bottleneck`` runs | ||||
|     (cProfile) will include the CUDA startup time (CUDA buffer allocation cost) | ||||
|     in its time reporting. This should not matter if your bottlenecks result | ||||
|     in code much slower than the CUDA startup time. | ||||
|  | ||||
| For more complicated uses of the profilers (like in a multi-GPU case), | ||||
| please see https://docs.python.org/3/library/profile.html | ||||
| or :func:`torch.autograd.profiler.profile()` for more information. | ||||
| @ -210,10 +210,6 @@ templates_path = [ | ||||
| coverage_ignore_functions = [ | ||||
|     # torch | ||||
|     "typename", | ||||
|     # torch.cuda | ||||
|     "check_error", | ||||
|     "cudart", | ||||
|     "is_bf16_supported", | ||||
|     # torch.cuda._sanitizer | ||||
|     "zip_arguments", | ||||
|     "zip_by_key", | ||||
|  | ||||
| @ -15,6 +15,7 @@ | ||||
|  | ||||
|     StreamContext | ||||
|     can_device_access_peer | ||||
|     check_error | ||||
|     current_blas_handle | ||||
|     current_device | ||||
|     current_stream | ||||
| @ -34,6 +35,7 @@ | ||||
|     init | ||||
|     ipc_collect | ||||
|     is_available | ||||
|     is_bf16_supported | ||||
|     is_initialized | ||||
|     is_tf32_supported | ||||
|     memory_usage | ||||
|  | ||||
| @ -260,3 +260,73 @@ these features. | ||||
| ```{eval-rst} | ||||
| .. py:module:: torch.distributed.tensor.device_mesh | ||||
| ``` | ||||
|  | ||||
| ## Mixed Tensor and DTensor operations | ||||
|  | ||||
| So you got the following error message. | ||||
| ``` | ||||
| got mixed torch.Tensor and DTensor, need to convert all | ||||
| torch.Tensor to DTensor before calling distributed operators! | ||||
| ``` | ||||
|  | ||||
| There are two cases. | ||||
|  | ||||
| ### Case 1: this is user error | ||||
|  | ||||
| The most common way to run into this error is to create a regular Tensor | ||||
| (using a factory function) and then perform a Tensor-DTensor operation, | ||||
| like the following: | ||||
|  | ||||
| ``` | ||||
| tensor = torch.arange(10) | ||||
| return tensor + dtensor | ||||
| ``` | ||||
|  | ||||
| We disallow mixed Tensor-DTensor operations: if any input to an operation | ||||
| (e.g. torch.add) is a DTensor, then all Tensor inputs must be DTensors. | ||||
| This is because the semantics are ambiguous. We don't know whether `tensor` is | ||||
| the same across ranks or different, so we ask that the user | ||||
| figure out how to construct a DTensor with accurate placements from `tensor`. | ||||
|  | ||||
| If each rank does have the same `tensor`, then please construct a replicated | ||||
| DTensor: | ||||
|  | ||||
| ``` | ||||
| tensor = torch.arange(10) | ||||
| tensor = DTensor.from_local(tensor, placements=(Replicate(),)) | ||||
| return tensor + dtensor | ||||
| ``` | ||||
|  | ||||
| If you want to create a DTensor with shards, below is how to do it. | ||||
| Semantically this means that your Tensor data is split between the shards | ||||
| and that operations act on the "full stacked data". | ||||
|  | ||||
| ``` | ||||
| tensor = torch.full([], RANK) | ||||
| tensor = DTensor.from_local(tensor, placements=(Shard(0),)) | ||||
| return tensor + dtensor | ||||
| ``` | ||||
|  | ||||
| There are other things you may wish to do with your tensor beyond | ||||
| these situations (these are not the only two options!); for example, you can | ||||
| distribute a full tensor across a device mesh, as sketched below. | ||||
|  | ||||
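| As a hedged sketch (not from the original doc), `distribute_tensor` takes a | ||||
| full tensor that is identical on all ranks and places it onto an explicit | ||||
| device mesh; `WORLD_SIZE` is a placeholder like `RANK` above, and the mesh | ||||
| construction here is an assumption for illustration. | ||||
|  | ||||
| ``` | ||||
| import torch | ||||
| from torch.distributed.device_mesh import init_device_mesh | ||||
| from torch.distributed.tensor import Shard, distribute_tensor | ||||
|  | ||||
| # Assumes the default process group is already initialized. | ||||
| mesh = init_device_mesh("cuda", (WORLD_SIZE,)) | ||||
|  | ||||
| # Every rank passes the same full tensor; each rank keeps only its own shard. | ||||
| full = torch.arange(10) | ||||
| sharded = distribute_tensor(full, mesh, placements=(Shard(0),)) | ||||
| # Assumes `dtensor` lives on the same mesh. | ||||
| return sharded + dtensor | ||||
| ``` | ||||
|  | ||||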
| ### Case 2: the error came from PyTorch framework code | ||||
|  | ||||
| Sometimes the problem is that PyTorch framework code attempts to perform mixed | ||||
| Tensor-DTensor operations. These are bugs in PyTorch; please file an issue | ||||
| so that we can fix them. | ||||
|  | ||||
| On the user side, the only thing you can do is to avoid using the operation | ||||
| that caused the issue and file a bug report. | ||||
|  | ||||
| For PyTorch Developers: one approach to fixing this is to rewrite PyTorch | ||||
| framework code to avoid mixed Tensor-DTensor code (as in the previous section). | ||||
|  | ||||
| For PyTorch Developers: the second approach is to turn on DTensor implicit | ||||
| replication inside the right places in PyTorch framework code (see the | ||||
| sketch after the links below). When enabled, mixed Tensor-DTensor operations | ||||
| assume that the non-DTensors can be replicated. Please be careful when using | ||||
| this as it can lead to silent incorrectness. | ||||
|  | ||||
| - [Turning on implicit replication in Python](https://github.com/pytorch/pytorch/blob/d8e6b2fddc54c748d976e8f0ebe4b63ebe36d85b/torch/distributed/tensor/experimental/__init__.py#L15) | ||||
| - [Turning on implicit replication in C++](https://github.com/pytorch/pytorch/blob/7a0f93344e2c851b9bcf2b9c3225a323d48fde26/aten/src/ATen/DTensorState.h#L10) | ||||
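|  | ||||
| A minimal sketch of the Python toggle (hedged: this assumes the experimental | ||||
| `implicit_replication` context manager linked above keeps that name and | ||||
| import path): | ||||
|  | ||||
| ``` | ||||
| import torch | ||||
| from torch.distributed.tensor import DTensor, Replicate | ||||
| from torch.distributed.tensor.experimental import implicit_replication | ||||
|  | ||||
| dtensor = DTensor.from_local(torch.arange(10), placements=(Replicate(),)) | ||||
| plain = torch.arange(10) | ||||
|  | ||||
| # Inside the context, plain Tensors in mixed ops are treated as replicated. | ||||
| with implicit_replication(): | ||||
|     out = plain + dtensor | ||||
| ``` | ||||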
|  | ||||
| @ -22,6 +22,7 @@ The MTIA backend is implemented out of the tree, only interfaces are be defined | ||||
|     device_count | ||||
|     init | ||||
|     is_available | ||||
|     is_bf16_supported | ||||
|     is_initialized | ||||
|     memory_stats | ||||
|     get_device_capability | ||||
|  | ||||
| @ -76,7 +76,6 @@ storage | ||||
| torch.testing <testing> | ||||
| torch.utils <utils> | ||||
| torch.utils.benchmark <benchmark_utils> | ||||
| torch.utils.bottleneck <bottleneck> | ||||
| torch.utils.checkpoint <checkpoint> | ||||
| torch.utils.cpp_extension <cpp_extension> | ||||
| torch.utils.data <data> | ||||
|  | ||||
| @ -3,12 +3,6 @@ | ||||
| TorchInductor and AOTInductor Provenance Tracking | ||||
| ================================================= | ||||
|  | ||||
| .. warning:: | ||||
|     This feature is a prototype under active development and there will be | ||||
|     breaking changes in future releases. | ||||
|     The current compatibility of this tool is limited to the latest nightly build of PyTorch. | ||||
|  | ||||
|  | ||||
| This section describes how to use the provenance tracking feature for TorchInductor and AOTInductor in ``tlparse``. | ||||
| Provenance tracking helps you visualize the relationships between the input GraphModule to (AOT)Inductor and the optimized code generated. This feature allows you to trace how your original operations are transformed during compilation. | ||||
|  | ||||
| @ -37,7 +31,7 @@ Follow these steps to enable and use provenance tracking in your PyTorch project | ||||
|  | ||||
|    .. code-block:: bash | ||||
|  | ||||
|      TORCH_TRACE=~/my_trace_log_dir TORCH_LOGS="+inductor" TORCH_COMPILE_DEBUG=1 python your_program.py | ||||
|      TORCH_TRACE=~/my_trace_log_dir INDUCTOR_PROVENANCE=1 python your_program.py | ||||
|  | ||||
|    This will generate a log file in ``~/my_trace_log_dir``. The log file will be used by tlparse to generate the provenance tracking highlighter. | ||||
| 3. Run ``tlparse`` on the log with ``--inductor-provenance`` flag. For example: | ||||
| @ -62,6 +56,24 @@ For a demo, see: https://github.com/pytorch/tlparse/pull/93 | ||||
|  .. image:: _static/img/inductor_provenance/index.png | ||||
|  | ||||
|  | ||||
| Source code corresponding to each Inductor kernel | ||||
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
|  | ||||
| With ``INDUCTOR_PROVENANCE=1``, you can also view the source code corresponding to each Inductor kernel in tlparse. To access it, click the "readable_html" link next to "inductor_provenance_tracking_kernel_stack_traces.json" in the tlparse output. | ||||
|  | ||||
|  .. image:: _static/img/inductor_provenance/index_2.png | ||||
|  | ||||
|  | ||||
| Below are some example screenshots. The ``:1`` and ``:467`` suffixes at the end of the kernel names are used to distinguish different calls to the same kernel. We refer to these suffixes as debug handles. | ||||
|  | ||||
|  .. image:: _static/img/inductor_provenance/kernel_source_1.png | ||||
|  .. image:: _static/img/inductor_provenance/kernel_source_2.png | ||||
|  | ||||
| You can also find the debug handle in the comments within the kernel source code. | ||||
|  | ||||
|  .. image:: _static/img/inductor_provenance/kernel_source_3.png | ||||
|  | ||||
|  | ||||
| See Also | ||||
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
|  | ||||
|  | ||||
| @ -12,6 +12,7 @@ | ||||
|     :nosignatures: | ||||
|  | ||||
|     StreamContext | ||||
|     can_device_access_peer | ||||
|     current_device | ||||
|     current_stream | ||||
|     device | ||||
| @ -25,6 +26,7 @@ | ||||
|     get_stream_from_external | ||||
|     init | ||||
|     is_available | ||||
|     is_bf16_supported | ||||
|     is_initialized | ||||
|     set_device | ||||
|     set_stream | ||||
|  | ||||
| @ -1187,8 +1187,7 @@ int64_t _Tensor_ndim(mpy::handle h) { | ||||
| mpy::handle handle_from_tensor(Arena& A, TensorRef t) { | ||||
|   // fast case: tensor is live in python | ||||
|   std::optional<PyObject*> mb_obj = | ||||
|       t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( | ||||
|           /*ignore_hermetic_tls=*/false); | ||||
|       t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(); | ||||
|   if (mb_obj.has_value() && | ||||
|       !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) { | ||||
|     return *mb_obj; | ||||
|  | ||||
							
								
								
									
setup.py (24 changed lines)
							| @ -227,9 +227,6 @@ | ||||
| #      Static link mimalloc into C10, and use mimalloc in alloc_cpu & alloc_free. | ||||
| #      By default, It is only enabled on Windows. | ||||
| # | ||||
| #   USE_PRIORITIZED_TEXT_FOR_LD | ||||
| #      Uses prioritized text from cmake/prioritized_text.txt for LD | ||||
| # | ||||
| #   BUILD_LIBTORCH_WHL | ||||
| #      Builds libtorch.so and its dependencies as a wheel | ||||
| # | ||||
| @ -323,7 +320,6 @@ from tools.setup_helpers.env import ( | ||||
|     IS_LINUX, | ||||
|     IS_WINDOWS, | ||||
| ) | ||||
| from tools.setup_helpers.generate_linker_script import gen_linker_script | ||||
|  | ||||
|  | ||||
| def str2bool(value: str | None) -> bool: | ||||
| @ -1627,26 +1623,6 @@ def main() -> None: | ||||
|     if BUILD_PYTHON_ONLY: | ||||
|         install_requires += [f"{LIBTORCH_PKG_NAME}=={TORCH_VERSION}"] | ||||
|  | ||||
|     if str2bool(os.getenv("USE_PRIORITIZED_TEXT_FOR_LD")): | ||||
|         gen_linker_script( | ||||
|             filein="cmake/prioritized_text.txt", fout="cmake/linker_script.ld" | ||||
|         ) | ||||
|         linker_script_path = os.path.abspath("cmake/linker_script.ld") | ||||
|         os.environ["LDFLAGS"] = os.getenv("LDFLAGS", "") + f" -T{linker_script_path}" | ||||
|         os.environ["CFLAGS"] = ( | ||||
|             os.getenv("CFLAGS", "") + " -ffunction-sections -fdata-sections" | ||||
|         ) | ||||
|         os.environ["CXXFLAGS"] = ( | ||||
|             os.getenv("CXXFLAGS", "") + " -ffunction-sections -fdata-sections" | ||||
|         ) | ||||
|     elif platform.system() == "Linux" and platform.processor() == "aarch64": | ||||
|         print_box( | ||||
|             """ | ||||
|             WARNING: we strongly recommend enabling linker script optimization for ARM + CUDA. | ||||
|             To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1 | ||||
|             """ | ||||
|         ) | ||||
|  | ||||
|     # Parse the command line and check the arguments before we proceed with | ||||
|     # building deps and setup. We need to set values so `--help` works. | ||||
|     dist = Distribution() | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| # Owner(s): ["module: unknown"] | ||||
| # Owner(s): ["module: sparse"] | ||||
|  | ||||
| import copy | ||||
|  | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| # Owner(s): ["module: unknown"] | ||||
| # Owner(s): ["module: sparse"] | ||||
|  | ||||
|  | ||||
| import torch | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| # Owner(s): ["module: unknown"] | ||||
| # Owner(s): ["module: sparse"] | ||||
|  | ||||
| import copy | ||||
| import warnings | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| # Owner(s): ["module: unknown"] | ||||
| # Owner(s): ["module: sparse"] | ||||
|  | ||||
| import copy | ||||
| import itertools | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| # Owner(s): ["module: unknown"] | ||||
| # Owner(s): ["module: sparse"] | ||||
|  | ||||
| import copy | ||||
| import io | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| # Owner(s): ["module: unknown"] | ||||
| # Owner(s): ["module: sparse"] | ||||
|  | ||||
|  | ||||
| import torch | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| # Owner(s): ["module: unknown"] | ||||
| # Owner(s): ["module: sparse"] | ||||
|  | ||||
| import warnings | ||||
|  | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| # Owner(s): ["module: unknown"] | ||||
| # Owner(s): ["module: sparse"] | ||||
|  | ||||
| import itertools | ||||
| import re | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| # Owner(s): ["module: unknown"] | ||||
| # Owner(s): ["module: sparse"] | ||||
|  | ||||
|  | ||||
| import logging | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| # Owner(s): ["module: unknown"] | ||||
| # Owner(s): ["module: sparse"] | ||||
| import copy | ||||
| import random | ||||
|  | ||||
|  | ||||
| @ -1,7 +0,0 @@ | ||||
| # Owner(s): ["module: unknown"] | ||||
|  | ||||
| import torch | ||||
|  | ||||
|  | ||||
| x = torch.ones((3, 3), requires_grad=True) | ||||
| (3 * x).sum().backward() | ||||
| @ -1,17 +0,0 @@ | ||||
| # Owner(s): ["module: unknown"] | ||||
|  | ||||
| import argparse | ||||
|  | ||||
| import torch | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser() | ||||
|  | ||||
|     # Required args. Raises error if they aren't passed. | ||||
|     parser.add_argument("--foo", help="foo", required=True) | ||||
|     parser.add_argument("--bar", help="bar", required=True) | ||||
|     _ = parser.parse_args() | ||||
|  | ||||
|     x = torch.ones((3, 3), requires_grad=True) | ||||
|     (3 * x).sum().backward() | ||||
| @ -1,29 +0,0 @@ | ||||
| # Owner(s): ["module: unknown"] | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
|  | ||||
| class Model(nn.Module): | ||||
|     def __init__(self) -> None: | ||||
|         super().__init__() | ||||
|         self.linear = nn.Linear(20, 20) | ||||
|  | ||||
|     def forward(self, input): | ||||
|         out = self.linear(input[:, 10:30]) | ||||
|         return out.sum() | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|     data = torch.randn(10, 50).cuda() | ||||
|     model = Model().cuda() | ||||
|     optimizer = torch.optim.SGD(model.parameters(), lr=0.0001) | ||||
|     for _ in range(10): | ||||
|         optimizer.zero_grad() | ||||
|         loss = model(data) | ||||
|         loss.backward() | ||||
|         optimizer.step() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     main() | ||||
| @ -879,12 +879,15 @@ void test_cuda_alloc_test() { | ||||
|   if (cudaStatus != cudaSuccess || device_idx == -1) { | ||||
|     throw std::runtime_error("cudaGetDevice failed!"); | ||||
|   } | ||||
|  | ||||
|   c10::cuda::CUDACachingAllocator::emptyCache(); | ||||
|   c10::cuda::CUDACachingAllocator::DeviceStats stats = | ||||
|       c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx); | ||||
|   size_t initTorchActive = stats.active_bytes[0].current; | ||||
|   size_t initTorchActive = stats.allocated_bytes[0].current; | ||||
|   auto runner = std::make_unique<torch::inductor::AOTIModelContainerRunnerCuda>( | ||||
|       model_so_path); | ||||
|   size_t torchActive = stats.active_bytes[0].current; | ||||
|   stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx); | ||||
|   size_t torchActive = stats.allocated_bytes[0].current; | ||||
|  | ||||
|   ASSERT_EQ(initTorchActive + DATASIZE, torchActive); | ||||
|  | ||||
| @ -1113,8 +1116,7 @@ TEST(AotInductorTest, MultiStreamTestCuda) { | ||||
|   test_multi_cuda_streams("cuda"); | ||||
| } | ||||
|  | ||||
| // TODO: ENABLE CUDACachingAllocator Test | ||||
| TEST(DISABLED_AotInductorTest, CudaAllocTestCuda) { | ||||
| TEST(AotInductorTest, CudaAllocTestCuda) { | ||||
|   test_cuda_alloc_test(); | ||||
| } | ||||
| #endif | ||||
|  | ||||
| @ -117,6 +117,49 @@ class TestFullyShardStateDictMultiProcess(FSDPTest): | ||||
|         for key, value in ref_sharded_sd.items(): | ||||
|             self.assertEqual(value, sharded_sd[key]) | ||||
|  | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     def test_cached_state_dict(self): | ||||
|         self.run_subtests( | ||||
|             {"mlp_dim": [2, 3, 4, 5], "mutate_after_state_dict": [True, False]}, | ||||
|             self._test_cached_state_dict, | ||||
|         ) | ||||
|  | ||||
|     def _test_cached_state_dict(self, mlp_dim: int, mutate_after_state_dict: bool): | ||||
|         torch.manual_seed(42) | ||||
|         model = nn.Linear(mlp_dim, mlp_dim, bias=False) | ||||
|         fully_shard(model, reshard_after_forward=True) | ||||
|         optim = torch.optim.AdamW(model.parameters(), lr=1e-2) | ||||
|  | ||||
|         # call .state_dict() once and use `sd` directly to reduce cpu overhead | ||||
|         sd = model.state_dict() | ||||
|         assert isinstance(model.weight, DTensor) | ||||
|  | ||||
|         if not mutate_after_state_dict: | ||||
|             self.assertTrue( | ||||
|                 sd["weight"]._local_tensor.untyped_storage().data_ptr() | ||||
|                 == model.weight._local_tensor.untyped_storage().data_ptr() | ||||
|             ) | ||||
|         else: | ||||
|             model = model.cpu() | ||||
|             model = model.cuda() | ||||
|             self.assertTrue( | ||||
|                 sd["weight"]._local_tensor.untyped_storage().data_ptr() | ||||
|                 != model.weight._local_tensor.untyped_storage().data_ptr() | ||||
|             ) | ||||
|  | ||||
|         torch.manual_seed(42 + self.rank) | ||||
|         inp = torch.rand(mlp_dim, mlp_dim, device="cuda") | ||||
|         for _ in range(5): | ||||
|             optim.zero_grad() | ||||
|             loss = model(inp).sum() | ||||
|             loss.backward() | ||||
|             optim.step() | ||||
|             if not mutate_after_state_dict: | ||||
|                 self.assertTrue( | ||||
|                     sd["weight"]._local_tensor.untyped_storage().data_ptr() | ||||
|                     == model.weight._local_tensor.untyped_storage().data_ptr() | ||||
|                 ) | ||||
|  | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     def test_dp_state_dict_cpu_offload(self): | ||||
|         self.run_subtests( | ||||
|  | ||||
| @ -1490,8 +1490,8 @@ class TestFullyShardWorldSize1(FSDPTest): | ||||
|     @skip_if_lt_x_gpu(1) | ||||
|     def test_train_parity_single_worldsize1(self): | ||||
|         """ | ||||
|         Tests train parity with DDP for a single FSDP group when sharding | ||||
|         parameters on dim-0. | ||||
|         Tests train parity with DDP for a single FSDP group | ||||
|         when sharding parameters on dim-0. | ||||
|         """ | ||||
|         self.run_subtests( | ||||
|             { | ||||
| @ -1539,9 +1539,7 @@ class TestFullyShardWorldSize1(FSDPTest): | ||||
|                 losses.append(model(*inp).sum()) | ||||
|                 losses[-1].backward() | ||||
|  | ||||
|             # Before there was 1 all-gather and 1 reduce-scatter | ||||
|             # Now there is 1 reduce-scatter | ||||
|             self.assertEqual(comm_mode.get_total_counts(), 1) | ||||
|             self.assertEqual(comm_mode.get_total_counts(), 0) | ||||
|             optim.step() | ||||
|  | ||||
|             self.assertEqual(losses[0], losses[1]) | ||||
|  | ||||
| @ -294,11 +294,11 @@ class TestFullyShard2DTraining(FSDPTest): | ||||
|             with CommDebugMode() as bwd_comm_mode: | ||||
|                 loss.backward() | ||||
|             bwd_comm_counts = bwd_comm_mode.get_comm_counts() | ||||
|             self.assertEqual(len(bwd_comm_counts), 2) | ||||
|             self.assertEqual(len(bwd_comm_counts), 1) | ||||
|             # First MLP's input gradient does not need to be all-reduced | ||||
|             self.assertEqual(bwd_comm_counts[funcol.all_reduce], num_mlps - 1) | ||||
|             self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], 0) | ||||
|             self.assertEqual(bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_mlps) | ||||
|             self.assertEqual(bwd_comm_counts[c10d_ops._reduce_scatter_base_], 0) | ||||
|             ref_loss.backward() | ||||
|  | ||||
|             optim.step() | ||||
|  | ||||
							
								
								
									
test/distributed/_composable/test_replicate_training.py (new file, 656 lines)
							| @ -0,0 +1,656 @@ | ||||
| # Owner(s): ["oncall: distributed"] | ||||
|  | ||||
| import contextlib | ||||
| import copy | ||||
| import functools | ||||
| import itertools | ||||
| import unittest | ||||
| from collections.abc import Iterable | ||||
| from typing import Union | ||||
|  | ||||
| import torch | ||||
| import torch.distributed as dist | ||||
| import torch.nn as nn | ||||
| from torch.distributed._composable.replicate_with_fsdp import replicate | ||||
| from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, OffloadPolicy | ||||
| from torch.distributed.tensor import DTensor, init_device_mesh | ||||
| from torch.testing._internal.common_distributed import skip_if_lt_x_gpu | ||||
| from torch.testing._internal.common_fsdp import ( | ||||
|     check_sharded_parity, | ||||
|     compiled_fsdp_test, | ||||
|     FSDPTest, | ||||
|     FSDPTestMultiThread, | ||||
|     MLP, | ||||
|     patch_all_gather, | ||||
|     patch_reduce_scatter, | ||||
| ) | ||||
| from torch.testing._internal.common_utils import ( | ||||
|     get_cycles_per_ms, | ||||
|     run_tests, | ||||
|     TEST_HPU, | ||||
|     wrapSwapTensorsTest, | ||||
| ) | ||||
| from torch.testing._internal.distributed._tensor.common_dtensor import ( | ||||
|     ModelArgs, | ||||
|     Transformer, | ||||
|     TransformerBlock, | ||||
| ) | ||||
|  | ||||
|  | ||||
| c10d_ops = torch.ops.c10d | ||||
| funcol = torch.ops.c10d_functional | ||||
|  | ||||
| from torch.testing._internal.common_fsdp import get_devtype | ||||
|  | ||||
|  | ||||
| device_type = torch.device(get_devtype()) | ||||
|  | ||||
|  | ||||
| class TestReplicateForwardInputs(FSDPTestMultiThread): | ||||
|     @property | ||||
|     def world_size(self) -> int: | ||||
|         return 2 | ||||
|  | ||||
|     @skip_if_lt_x_gpu(1) | ||||
|     def test_root_move_forward_input_to_device(self): | ||||
|         device = torch.device(device_type.type, 0) | ||||
|  | ||||
|         class ParamlessModule(nn.Module): | ||||
|             def forward(self, x: torch.Tensor, ys: tuple[torch.Tensor, ...]): | ||||
|                 # Check that Replicate moved the inputs to GPU, including recursing | ||||
|                 # into the tuple data structure | ||||
|                 assert x.device == device, f"Expects {device} but got {x.device}" | ||||
|                 assert ys[0].device == device, ( | ||||
|                     f"Expects {device} but got {ys[0].device}" | ||||
|                 ) | ||||
|                 assert ys[1].device == device, ( | ||||
|                     f"Expects {device} but got {ys[1].device}" | ||||
|                 ) | ||||
|                 y = ys[0] + ys[1] | ||||
|                 return x + y + 1 | ||||
|  | ||||
|         model = ParamlessModule().to(device) | ||||
|         replicate(model).to(device) | ||||
|         x = torch.randn((3,)) | ||||
|         ys = (torch.randn((3,)), torch.randn((3,))) | ||||
|         self.assertEqual(x.device, torch.device("cpu")) | ||||
|         self.assertEqual(ys[0].device, torch.device("cpu")) | ||||
|         self.assertEqual(ys[1].device, torch.device("cpu")) | ||||
|         model(x, ys) | ||||
|  | ||||
|  | ||||
| class TestReplicateRegisteredParams(FSDPTestMultiThread): | ||||
|     @property | ||||
|     def world_size(self) -> int: | ||||
|         return 4 | ||||
|  | ||||
|     @skip_if_lt_x_gpu(1) | ||||
|     def test_param_registration_after_forward(self): | ||||
|         """Tests the parameter registration after forward.""" | ||||
|         device = torch.device(device_type.type, 0) | ||||
|         # Single Replicate group | ||||
|         for reshard_after_forward in (True, False, None): | ||||
|             torch.manual_seed(42) | ||||
|             model = MLP(3, device) | ||||
|             # Since seed is per process, not per thread, we broadcast to ensure | ||||
|             # the same parameters across ranks | ||||
|             for param in model.parameters(): | ||||
|                 dist.broadcast(param, src=0) | ||||
|             ref_model = copy.deepcopy(model) | ||||
|             replicate(model, reshard_after_forward=reshard_after_forward)  # root only | ||||
|             inp = torch.randn((2, 3), device=device_type.type) | ||||
|             self._assert_dtensor_params(model.parameters()) | ||||
|             self._assert_same_params(model.parameters(), ref_model.parameters()) | ||||
|             model(inp) | ||||
|             if reshard_after_forward: | ||||
|                 self._assert_dtensor_params(model.parameters()) | ||||
|             else: | ||||
|                 self._assert_tensor_params(model.parameters()) | ||||
|             self._assert_same_params(model.parameters(), ref_model.parameters()) | ||||
|             model.reshard()  # however, we can manually reshard | ||||
|             self._assert_dtensor_params(model.parameters()) | ||||
|             self._assert_same_params(model.parameters(), ref_model.parameters()) | ||||
|  | ||||
|         # Multiple Replicate groups | ||||
|         for reshard_after_forward in (True, False, None): | ||||
|             torch.manual_seed(42) | ||||
|             model = nn.Sequential(MLP(3, device), MLP(3, device)) | ||||
|             for param in model.parameters(): | ||||
|                 dist.broadcast(param, src=0) | ||||
|             ref_model = copy.deepcopy(model) | ||||
|             replicate(model[0].in_proj, reshard_after_forward=reshard_after_forward) | ||||
|             replicate(model[0].out_proj, reshard_after_forward=reshard_after_forward) | ||||
|             replicate(model, reshard_after_forward=reshard_after_forward) | ||||
|  | ||||
|             self._assert_dtensor_params(model.parameters()) | ||||
|             self._assert_same_params(model.parameters(), ref_model.parameters()) | ||||
|             model(inp) | ||||
|             non_root_params = list(model[0].in_proj.parameters()) + list( | ||||
|                 model[0].out_proj.parameters() | ||||
|             ) | ||||
|             root_params = list(set(model.parameters()) - set(non_root_params)) | ||||
|             if reshard_after_forward is None: | ||||
|                 self._assert_dtensor_params(non_root_params) | ||||
|                 self._assert_tensor_params(root_params) | ||||
|             elif reshard_after_forward: | ||||
|                 self._assert_dtensor_params(non_root_params) | ||||
|                 self._assert_dtensor_params(root_params) | ||||
|             else: | ||||
|                 self._assert_tensor_params(non_root_params) | ||||
|                 self._assert_tensor_params(root_params) | ||||
|             self._assert_same_params(model.parameters(), ref_model.parameters()) | ||||
|             for module in model.modules(): | ||||
|                 if isinstance(module, FSDPModule): | ||||
|                     module.reshard()  # however, we can manually reshard | ||||
|             self._assert_dtensor_params(model.parameters()) | ||||
|             self._assert_same_params(model.parameters(), ref_model.parameters()) | ||||
|  | ||||
|     @skip_if_lt_x_gpu(1) | ||||
|     def test_param_registration_after_backward(self): | ||||
|         """Tests the parameter registration after backward.""" | ||||
|         device = torch.device(device_type.type, 0) | ||||
|         # Single Replicate group | ||||
|         for reshard_after_forward in (True, False): | ||||
|             model = MLP(8, device) | ||||
|             replicate(model, reshard_after_forward=reshard_after_forward)  # root only | ||||
|             inp = torch.randn((2, 8), device=device_type.type) | ||||
|             self._assert_dtensor_params(model.parameters()) | ||||
|             model(inp).sum().backward() | ||||
|             self._assert_dtensor_params(model.parameters()) | ||||
|  | ||||
|         # Multiple Replicate groups | ||||
|         for reshard_after_forward in (True, False): | ||||
|             model = MLP(8, device) | ||||
|             replicate(model.in_proj, reshard_after_forward=reshard_after_forward) | ||||
|             replicate(model.out_proj, reshard_after_forward=reshard_after_forward) | ||||
|             replicate(model, reshard_after_forward=reshard_after_forward) | ||||
|             self._assert_dtensor_params(model.parameters()) | ||||
|             model(inp).sum().backward() | ||||
|             self._assert_dtensor_params(model.parameters()) | ||||
|  | ||||
|     def _assert_tensor_params(self, params: Iterable[nn.Parameter]): | ||||
|         # need to iterate over the list multiple times | ||||
|         params = list(params) | ||||
|         self.assertGreater(len(params), 0) | ||||
|         for param in params: | ||||
|             self.assertNotIsInstance(param, DTensor) | ||||
|             self.assertIsInstance(param, torch.Tensor) | ||||
|  | ||||
|     def _assert_dtensor_params(self, params: Iterable[nn.Parameter]): | ||||
|         params = list(params) | ||||
|         self.assertGreater(len(params), 0) | ||||
|         for param in params: | ||||
|             self.assertIsInstance(param, DTensor) | ||||
|  | ||||
|     def _assert_same_params( | ||||
|         self, params: Iterable[nn.Parameter], ref_params: Iterable[nn.Parameter] | ||||
|     ): | ||||
|         params, ref_params = list(params), list(ref_params) | ||||
|         self.assertEqual(len(params), len(ref_params)) | ||||
|         for param, ref_param in zip(params, ref_params): | ||||
|             if isinstance(param, DTensor): | ||||
|                 param = param.full_tensor() | ||||
|             self.assertEqual(param.shape, ref_param.shape) | ||||
|             self.assertEqual(param, ref_param) | ||||
|  | ||||
|  | ||||
| class TestReplicateCastAfterInit(FSDPTestMultiThread): | ||||
|     @property | ||||
|     def world_size(self) -> int: | ||||
|         return 2 | ||||
|  | ||||
|     @skip_if_lt_x_gpu(1) | ||||
|     @wrapSwapTensorsTest(True) | ||||
|     def test_to_float64_after_init(self): | ||||
|         """Tests that the user can cast the module to float64 after init.""" | ||||
|         # NOTE: Test fp64 instead of a lower precision dtype like bf16 for | ||||
|         # better numerics. The important part is changing the dtype. | ||||
|  | ||||
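|         # The reference model is an fp64 copy trained with manual all-reduce | ||||
|         # gradient averaging (DDP-style), so the parity check below isolates | ||||
|         # the post-init dtype cast. | ||||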
|         torch.manual_seed(42) | ||||
|         mlp_dim, device, dtype = 4, device_type, torch.float64 | ||||
|         model = MLP(mlp_dim, device=device) | ||||
|         for param in model.parameters(): | ||||
|             dist.broadcast(param, src=0) | ||||
|         ref_model = copy.deepcopy(model).to(dtype) | ||||
|  | ||||
|         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2) | ||||
|         for module in (model.in_proj, model.out_proj, model): | ||||
|             replicate(module) | ||||
|         model.to(dtype) | ||||
|         for param in model.parameters(): | ||||
|             self.assertEqual(param.dtype, dtype) | ||||
|             self.assertEqual(param.to_local().dtype, dtype) | ||||
|             self.assertEqual(param._spec.tensor_meta.dtype, dtype) | ||||
|         optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True) | ||||
|         check_sharded_parity(self, ref_model, model) | ||||
|         torch.manual_seed(42 + self.rank + 1) | ||||
|         inp = torch.randn((2, mlp_dim), device=device_type.type, dtype=dtype) | ||||
|         for iter_idx in range(10): | ||||
|             losses: list[torch.Tensor] = [] | ||||
|             for _model in (ref_model, model): | ||||
|                 losses.append(_model(inp).sum()) | ||||
|                 losses[-1].backward() | ||||
|  | ||||
|             for param in ref_model.parameters(): | ||||
|                 if param.grad is not None: | ||||
|                     dist.all_reduce(param.grad) | ||||
|                     param.grad.div_(self.world_size) | ||||
|  | ||||
|             self.assertEqual(losses[0], losses[1]) | ||||
|             check_sharded_parity(self, ref_model, model) | ||||
|             for param in model.parameters(): | ||||
|                 self.assertEqual(param.dtype, dtype) | ||||
|                 self.assertEqual(param.to_local().dtype, dtype) | ||||
|                 self.assertEqual(param._spec.tensor_meta.dtype, dtype) | ||||
|                 self.assertEqual(param.grad.dtype, dtype) | ||||
|                 self.assertEqual(param.grad.to_local().dtype, dtype) | ||||
|                 self.assertEqual(param.grad._spec.tensor_meta.dtype, dtype) | ||||
|             for _optim in (ref_optim, optim): | ||||
|                 _optim.step() | ||||
|                 _optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) | ||||
|  | ||||
|  | ||||
| class TestReplicate1DTrainingCore(FSDPTest): | ||||
|     @property | ||||
|     def world_size(self) -> int: | ||||
|         return min(8, torch.get_device_module(device_type).device_count()) | ||||
|  | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     def test_train_parity_single_group(self): | ||||
|         """ | ||||
|         Tests train parity with DDP for a single FSDP group when sharding | ||||
|         parameters on dim-0. | ||||
|         """ | ||||
|         self.run_subtests( | ||||
|             { | ||||
|                 "lin_shapes": [ | ||||
|                     [(16, 15), (15, 8)], | ||||
|                     [(7, 15), (15, 3)], | ||||
|                     [(16, 17), (17, 8)], | ||||
|                 ], | ||||
|                 "use_shard_placement_fn": [False], | ||||
|             }, | ||||
|             self._test_train_parity_single_group, | ||||
|         ) | ||||
|  | ||||
|     def _test_train_parity_single_group( | ||||
|         self, lin_shapes: list[tuple[int, int]], use_shard_placement_fn: bool | ||||
|     ): | ||||
|         torch.manual_seed(42) | ||||
|         model = nn.Sequential( | ||||
|             nn.Linear(*lin_shapes[0]), nn.ReLU(), nn.Linear(*lin_shapes[1]) | ||||
|         ) | ||||
|         ref_model = copy.deepcopy(model).to(device_type) | ||||
|         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2) | ||||
|  | ||||
|         replicate(model) | ||||
|         optim = torch.optim.Adam(model.parameters(), lr=1e-2) | ||||
|         torch.manual_seed(42 + self.rank + 1) | ||||
|         inp = (torch.randn((4, lin_shapes[0][0]), device=device_type.type),) | ||||
|         for iter_idx in range(10): | ||||
|             losses: list[torch.Tensor] = [] | ||||
|             for _model in (ref_model, model): | ||||
|                 losses.append(_model(*inp).sum()) | ||||
|                 losses[-1].backward() | ||||
|  | ||||
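|             # Emulate DDP for the reference model: all-reduce and average the | ||||
|             # gradients across ranks before stepping | ||||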
|             for param in ref_model.parameters(): | ||||
|                 if param.grad is not None: | ||||
|                     dist.all_reduce(param.grad) | ||||
|                     param.grad.div_(self.world_size) | ||||
|  | ||||
|             for _optim in (ref_optim, optim): | ||||
|                 _optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) | ||||
|                 _optim.step() | ||||
|             self.assertEqual(losses[0], losses[1]) | ||||
|  | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     @unittest.skipIf(TEST_HPU, "Sleep kernel not supported for HPU") | ||||
|     @compiled_fsdp_test(compile_compute_on_module=Transformer) | ||||
|     def test_train_parity_multi_groups(self): | ||||
|         """ | ||||
|         Tests train parity against DDP when using multiple parameter groups for | ||||
|         communication (to overlap communication with computation and reduce | ||||
|         memory usage). | ||||
|         """ | ||||
|         self.run_subtests( | ||||
|             { | ||||
|                 "reshard_after_forward": [True, False], | ||||
|                 "test_device_type": [device_type.type], | ||||
|                 "offload_policy": [OffloadPolicy()], | ||||
|                 "delay_after_forward": [False, True], | ||||
|                 "delay_before_all_gather": [False, True], | ||||
|                 "delay_before_reduce_scatter": [False, True], | ||||
|                 "delay_before_optim": [False, True], | ||||
|                 "unshard_async_op": [False], | ||||
|             }, | ||||
|             self._test_train_parity_multi_group, | ||||
|         ) | ||||
|  | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     @unittest.skipIf(TEST_HPU, "sleep kernel not supported on HPU") | ||||
|     def test_train_parity_multi_group_cpu_offload_eager(self): | ||||
|         """ | ||||
|         Tests train parity when using multiple parameter groups for | ||||
|         communication and CPU offloading. | ||||
|         """ | ||||
|         self.run_subtests( | ||||
|             { | ||||
|                 "reshard_after_forward": [True],  # save CI time | ||||
|                 "offload_policy": [ | ||||
|                     CPUOffloadPolicy(pin_memory=True), | ||||
|                     CPUOffloadPolicy(pin_memory=False), | ||||
|                 ], | ||||
|                 "test_device_type": [device_type.type], | ||||
|                 "delay_after_forward": [False, True], | ||||
|                 "delay_before_all_gather": [False, True], | ||||
|                 "delay_before_reduce_scatter": [False, True], | ||||
|                 "delay_before_optim": [False, True], | ||||
|                 "unshard_async_op": [False], | ||||
|             }, | ||||
|             self._test_train_parity_multi_group, | ||||
|         ) | ||||
|  | ||||
|     def _test_train_parity_multi_group( | ||||
|         self, | ||||
|         reshard_after_forward: Union[bool, int], | ||||
|         offload_policy: OffloadPolicy, | ||||
|         test_device_type: str, | ||||
|         delay_after_forward: bool, | ||||
|         delay_before_all_gather: bool, | ||||
|         delay_before_reduce_scatter: bool, | ||||
|         delay_before_optim: bool, | ||||
|         unshard_async_op: bool, | ||||
|     ): | ||||
|         # Only test individual delays or all four delays to save test time | ||||
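|         # (summing the booleans counts the enabled delays; skip the mixed | ||||
|         # cases with exactly two or three delays enabled) | ||||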
|         if ( | ||||
|             delay_after_forward | ||||
|             + delay_before_all_gather | ||||
|             + delay_before_reduce_scatter | ||||
|             + delay_before_optim | ||||
|             in (2, 3) | ||||
|         ): | ||||
|             return | ||||
|         assert test_device_type in ("cuda", "hpu", "xpu", "cpu"), f"{test_device_type}" | ||||
|         torch.manual_seed(42) | ||||
|         vocab_size = 1024 | ||||
|         model_args = ModelArgs( | ||||
|             n_layers=3, | ||||
|             n_heads=4, | ||||
|             vocab_size=vocab_size, | ||||
|             max_seq_len=64, | ||||
|             dropout_p=0, | ||||
|         ) | ||||
|         model = Transformer(model_args) | ||||
|         ref_model = copy.deepcopy(model).to(device_type) | ||||
|  | ||||
|         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2) | ||||
|         mesh = init_device_mesh( | ||||
|             test_device_type, | ||||
|             (self.world_size, 1), | ||||
|             mesh_dim_names=("replicate", "shard"), | ||||
|         ) | ||||
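|         # A (world_size, 1) mesh: the "shard" dim has size 1, so each rank | ||||
|         # keeps a full replica and parity is checked against DDP-style manual | ||||
|         # gradient averaging below | ||||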
|         fully_shard_fn = functools.partial( | ||||
|             replicate, | ||||
|             device_mesh=mesh, | ||||
|             reshard_after_forward=reshard_after_forward, | ||||
|             offload_policy=offload_policy, | ||||
|         ) | ||||
|         for module in model.modules(): | ||||
|             if isinstance(module, TransformerBlock): | ||||
|                 fully_shard_fn(module) | ||||
|         fully_shard_fn(model) | ||||
|         if unshard_async_op: | ||||
|             model._set_unshard_async_op(unshard_async_op) | ||||
|         optim = torch.optim.Adam(model.parameters(), lr=1e-2) | ||||
|  | ||||
|         delay_in_ms = 100 | ||||
|         orig_all_gather = dist.all_gather_into_tensor | ||||
|         orig_reduce_scatter = dist.reduce_scatter_tensor | ||||
|  | ||||
|         def delayed_all_gather(*args, **kwargs): | ||||
|             torch.get_device_module(device_type)._sleep( | ||||
|                 int(delay_in_ms * get_cycles_per_ms()) | ||||
|             ) | ||||
|             return orig_all_gather(*args, **kwargs) | ||||
|  | ||||
|         def delayed_reduce_scatter(*args, **kwargs): | ||||
|             torch.get_device_module(device_type)._sleep( | ||||
|                 int(delay_in_ms * get_cycles_per_ms()) | ||||
|             ) | ||||
|             return orig_reduce_scatter(*args, **kwargs) | ||||
|  | ||||
|         torch.manual_seed(42 + self.rank + 1) | ||||
|         patch_all_gather_ctx = ( | ||||
|             patch_all_gather(delayed_all_gather) | ||||
|             if delay_before_all_gather | ||||
|             else contextlib.nullcontext() | ||||
|         ) | ||||
|         patch_reduce_scatter_ctx = ( | ||||
|             patch_reduce_scatter(delayed_reduce_scatter) | ||||
|             if delay_before_reduce_scatter | ||||
|             else contextlib.nullcontext() | ||||
|         ) | ||||
|         with patch_all_gather_ctx, patch_reduce_scatter_ctx: | ||||
|             for iter_idx in range(10): | ||||
|                 inp = torch.randint(0, vocab_size, (3, 64), device=device_type) | ||||
|                 losses: list[torch.Tensor] = [] | ||||
|                 for _model, _optim in ((ref_model, ref_optim), (model, optim)): | ||||
|                     losses.append(_model(inp).sum()) | ||||
|                     if _model is model and delay_after_forward: | ||||
|                         torch.get_device_module(device_type)._sleep( | ||||
|                             int(delay_in_ms * get_cycles_per_ms()) | ||||
|                         ) | ||||
|                     losses[-1].backward() | ||||
|                     if _model is model and delay_before_optim: | ||||
|                         torch.get_device_module(device_type)._sleep( | ||||
|                             int(delay_in_ms * get_cycles_per_ms()) | ||||
|                         ) | ||||
|  | ||||
|                 for param in ref_model.parameters(): | ||||
|                     if param.grad is not None: | ||||
|                         dist.all_reduce(param.grad) | ||||
|                         param.grad.div_(self.world_size) | ||||
|  | ||||
|                 for _optim in (ref_optim, optim): | ||||
|                     _optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) | ||||
|                     _optim.step() | ||||
|                 self.assertEqual(losses[0], losses[1]) | ||||
|  | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     def test_non_root_forward_backward(self): | ||||
|         """ | ||||
|         Tests running forward/backward through the root and then through a | ||||
|         non-root. The non-root needs to synchronize streams/queue the callback. | ||||
|         """ | ||||
|         torch.manual_seed(42) | ||||
|         lin_dim = 32 | ||||
|         model = nn.Sequential(*[MLP(lin_dim, torch.device("cpu")) for _ in range(3)]) | ||||
|         ref_model = copy.deepcopy(model).to(device_type) | ||||
|         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2) | ||||
|         for mlp in model: | ||||
|             replicate(mlp) | ||||
|         replicate(model) | ||||
|         optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True) | ||||
|         torch.manual_seed(42 + self.rank) | ||||
|         inp = torch.randn((8, lin_dim), device=device_type) | ||||
|  | ||||
|         ref_root_loss = ref_model(inp).sum() | ||||
|         ref_root_loss.backward() | ||||
|         for param in ref_model.parameters(): | ||||
|             dist.all_reduce(param.grad) | ||||
|             param.grad.detach().div_(self.world_size) | ||||
|         ref_optim.step() | ||||
|         ref_optim.zero_grad() | ||||
|         ref_nonroot_loss = ref_model[0](inp).sum() | ||||
|         ref_nonroot_loss.backward() | ||||
|         for param in ref_model.parameters(): | ||||
|             if param.grad is not None: | ||||
|                 dist.all_reduce(param.grad) | ||||
|                 param.grad.detach().div_(self.world_size) | ||||
|         ref_optim.step() | ||||
|  | ||||
|         root_loss = model(inp).sum() | ||||
|         root_loss.backward() | ||||
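|         # Delay the device so the following non-root forward/backward | ||||
|         # exercises the stream sync / queued callback noted in the docstring | ||||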
|         torch.get_device_module(device_type)._sleep(int(100 * get_cycles_per_ms())) | ||||
|         optim.step() | ||||
|         optim.zero_grad() | ||||
|         nonroot_loss = model[0](inp).sum() | ||||
|         nonroot_loss.backward() | ||||
|         optim.step() | ||||
|  | ||||
|         self.assertEqual(ref_root_loss, root_loss) | ||||
|         self.assertEqual(ref_nonroot_loss, nonroot_loss) | ||||
|         self.assertEqual(ref_model(inp).sum(), model(inp).sum()) | ||||
|  | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     def test_multi_forward_module(self): | ||||
|         """ | ||||
|         Tests parity when running a module that participates multiple | ||||
|         times in forward. | ||||
|         """ | ||||
|         self.run_subtests( | ||||
|             {"reshard_after_forward": [True, False]}, | ||||
|             self._test_multi_forward_module, | ||||
|         ) | ||||
|  | ||||
|     def _test_multi_forward_module(self, reshard_after_forward: Union[bool, int]): | ||||
|         class MultiForwardModule(nn.Module): | ||||
|             def __init__(self, device: torch.device): | ||||
|                 super().__init__() | ||||
|                 self.inner = nn.Linear(4, 4, device=device) | ||||
|                 self.outer = nn.Linear(4, 5, device=device) | ||||
|  | ||||
|             def forward(self, x): | ||||
|                 i = self.inner(x) | ||||
|                 j = self.inner(x) | ||||
|                 return self.outer(i + j) | ||||
|  | ||||
|         torch.manual_seed(42) | ||||
|         model = MultiForwardModule(device=device_type.type) | ||||
|         ref_model = copy.deepcopy(model).to(device_type) | ||||
|  | ||||
|         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2) | ||||
|         replicate(model.inner) | ||||
|         replicate(model) | ||||
|         optim = torch.optim.Adam(model.parameters(), lr=1e-2) | ||||
|  | ||||
|         torch.manual_seed(42 + self.rank) | ||||
|         inp = torch.randn((32, 4), device=device_type.type) | ||||
|         for iter_idx in range(10): | ||||
|             losses: list[torch.Tensor] = [] | ||||
|             for _model in (ref_model, model): | ||||
|                 losses.append(_model(inp).sum()) | ||||
|                 losses[-1].backward() | ||||
|  | ||||
|             for param in ref_model.parameters(): | ||||
|                 if param.grad is not None: | ||||
|                     dist.all_reduce(param.grad) | ||||
|                     param.grad.div_(self.world_size) | ||||
|  | ||||
|             for _optim in (ref_optim, optim): | ||||
|                 _optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) | ||||
|                 _optim.step() | ||||
|  | ||||
|             self.assertEqual(losses[0], losses[1]) | ||||
|  | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     def test_explicit_prefetching(self): | ||||
|         torch.manual_seed(42) | ||||
|         model_args = ModelArgs(n_layers=8, dropout_p=0.0) | ||||
|         model = Transformer(model_args) | ||||
|         ref_model = copy.deepcopy(model).to(device_type) | ||||
|         ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2) | ||||
|  | ||||
|         for layer in itertools.chain(model.layers, [model]): | ||||
|             replicate(layer) | ||||
|         optim = torch.optim.AdamW(model.parameters(), lr=1e-2) | ||||
|  | ||||
|         num_to_forward_prefetch = num_to_backward_prefetch = 2 | ||||
|         for i, layer in enumerate(model.layers): | ||||
|             if i >= len(model.layers) - num_to_forward_prefetch: | ||||
|                 break | ||||
|             layers_to_prefetch = [ | ||||
|                 model.layers[i + j] for j in range(1, num_to_forward_prefetch + 1) | ||||
|             ] | ||||
|             layer.set_modules_to_forward_prefetch(layers_to_prefetch) | ||||
|         for i, layer in enumerate(model.layers): | ||||
|             if i < num_to_backward_prefetch: | ||||
|                 continue | ||||
|             layers_to_prefetch = [ | ||||
|                 model.layers[i - j] for j in range(1, num_to_backward_prefetch + 1) | ||||
|             ] | ||||
|             layer.set_modules_to_backward_prefetch(layers_to_prefetch) | ||||
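|         # Each layer now explicitly prefetches the next two layers during | ||||
|         # forward and the previous two layers during backward | ||||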
|  | ||||
|         torch.manual_seed(42 + self.rank) | ||||
|         inp = torch.randint(0, model_args.vocab_size, (2, 8), device=device_type.type) | ||||
|         for _ in range(10): | ||||
|             losses: list[torch.Tensor] = [] | ||||
|  | ||||
|             for _model in (ref_model, model): | ||||
|                 losses.append(_model(inp).sum()) | ||||
|                 losses[-1].backward() | ||||
|  | ||||
|             for param in ref_model.parameters(): | ||||
|                 if param.grad is not None: | ||||
|                     dist.all_reduce(param.grad) | ||||
|                     param.grad.div_(self.world_size) | ||||
|  | ||||
|             for _optim in (ref_optim, optim): | ||||
|                 _optim.zero_grad() | ||||
|                 _optim.step() | ||||
|  | ||||
|             self.assertEqual(losses[0], losses[1]) | ||||
|  | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     @unittest.skipIf(TEST_HPU, "Sleep is not supported on HPU") | ||||
|     def test_post_optim_event(self): | ||||
|         torch.manual_seed(42) | ||||
|         model_args = ModelArgs(dropout_p=0.0) | ||||
|         model = Transformer(model_args) | ||||
|         ref_model = copy.deepcopy(model).to(device_type.type) | ||||
|         ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2) | ||||
|         for layer in itertools.chain(model.layers, [model]): | ||||
|             replicate(layer) | ||||
|         optim = torch.optim.AdamW(model.parameters(), lr=1e-2) | ||||
|  | ||||
|         def step_post_hook( | ||||
|             fsdp_module: FSDPModule, opt: torch.optim.Optimizer, args, kwargs | ||||
|         ) -> None: | ||||
|             post_optim_event = ( | ||||
|                 torch.get_device_module(device_type).current_stream().record_event() | ||||
|             ) | ||||
|             fsdp_module.set_post_optim_event(post_optim_event) | ||||
|  | ||||
|         optim.register_step_post_hook(functools.partial(step_post_hook, model)) | ||||
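|         # After every optimizer step, the hook records a device event and | ||||
|         # hands it to the root module via set_post_optim_event so the next | ||||
|         # iteration can wait on it without a blocking CPU sync | ||||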
|  | ||||
|         torch.manual_seed(42 + self.rank) | ||||
|         inp = torch.randint(0, model_args.vocab_size, (2, 8), device=device_type.type) | ||||
|         # Track all losses and check for equality at the end to avoid a CPU | ||||
|         # sync point after each iteration | ||||
|         ref_losses: list[torch.Tensor] = [] | ||||
|         losses: list[torch.Tensor] = [] | ||||
|         for _ in range(10): | ||||
|             ref_optim.zero_grad() | ||||
|             ref_losses.append(ref_model(inp).sum()) | ||||
|             ref_losses[-1].backward() | ||||
|  | ||||
|             for param in ref_model.parameters(): | ||||
|                 if param.grad is not None: | ||||
|                     dist.all_reduce(param.grad) | ||||
|                     param.grad.div_(self.world_size) | ||||
|  | ||||
|             ref_optim.step() | ||||
|         for _ in range(10): | ||||
|             optim.zero_grad() | ||||
|             losses.append(model(inp).sum()) | ||||
|             losses[-1].backward() | ||||
|             optim.step() | ||||
|             # Sleep after the optimizer step to allow CPU to run ahead into the | ||||
|             # next iteration's forward, exercising the post-optim stream sync | ||||
|             torch.get_device_module(device_type)._sleep(int(25 * get_cycles_per_ms())) | ||||
|         for ref_loss, loss in zip(ref_losses, losses): | ||||
|             self.assertEqual(ref_loss, loss) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     run_tests() | ||||
| @ -47,11 +47,14 @@ _LOGGER = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| class TestCoalesce(TestCase): | ||||
|     def helper_test_coalesce(self, layout): | ||||
|     def helper_test_coalesce(self, layout, coalesced_layout=None): | ||||
|         layoutR = coalesce(layout) | ||||
|  | ||||
|         _LOGGER.debug(f"{layout}  =>  {layoutR}") | ||||
|  | ||||
|         if coalesced_layout: | ||||
|             self.assertEqual(coalesced_layout.shape, layoutR.shape) | ||||
|             self.assertEqual(coalesced_layout.stride, layoutR.stride) | ||||
|         self.assertEqual(size(layoutR), size(layout)) | ||||
|  | ||||
|         for i in range(size(layout)): | ||||
| @ -82,11 +85,17 @@ class TestCoalesce(TestCase): | ||||
|         layout = Layout((2, (4, 6))) | ||||
|         self.helper_test_coalesce(layout) | ||||
|  | ||||
|         layout = Layout((1, 2), (8, 1)) | ||||
|         coalesced_layout = Layout(2, 1) | ||||
|         self.helper_test_coalesce(layout, coalesced_layout) | ||||
|  | ||||
|         layout = Layout((2, 4), (4, 1)) | ||||
|         self.helper_test_coalesce(layout) | ||||
|         coalesced_layout = Layout(8, 1) | ||||
|         self.helper_test_coalesce(layout, coalesced_layout) | ||||
|  | ||||
|         layout = Layout((2, 4, 6), (24, 6, 1)) | ||||
|         self.helper_test_coalesce(layout) | ||||
|         coalesced_layout = Layout(48, 1) | ||||
|         self.helper_test_coalesce(layout, coalesced_layout) | ||||
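|         # In the cases above, two adjacent modes merge when the left stride | ||||
|         # equals the right extent times the right stride (e.g. 4 == 4 * 1), | ||||
|         # and size-1 modes are dropped | ||||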
|  | ||||
|         layout = Layout((2, 1, 3), (2, 4, 4)) | ||||
|         self.helper_test_coalesce(layout) | ||||
| @ -94,6 +103,10 @@ class TestCoalesce(TestCase): | ||||
|         layout = Layout(((2, 2), (2, 2)), ((1, 4), (8, 32))) | ||||
|         self.helper_test_coalesce(layout) | ||||
|  | ||||
|         layout = Layout(((2, 2), (2, 2)), ((32, 8), (4, 1))) | ||||
|         coalesced_layout = Layout((2, 4, 2), (32, 4, 1)) | ||||
|         self.helper_test_coalesce(layout, coalesced_layout) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     run_tests() | ||||
|  | ||||
| @ -208,11 +208,26 @@ class TestComposition(TestCase): | ||||
|         layoutB = Layout((6), (1)) | ||||
|         self.helper_test_composition(layoutA, layoutB) | ||||
|  | ||||
|         # Pre-coalesced RHS | ||||
|         layoutA = Layout((8, 6, 4), (7, 4, 1)) | ||||
|         layoutB = Layout((6), (1)) | ||||
|         self.helper_test_composition(layoutA, layoutB) | ||||
|  | ||||
|         # Case where the stride divisibility condition is not met | ||||
|         with self.assertRaises(AssertionError): | ||||
|             layoutA = Layout((4, 6, 8, 10), (2, 3, 5, 7)) | ||||
|             layoutB = Layout(6, 12) | ||||
|             self.helper_test_composition(layoutA, layoutB) | ||||
|  | ||||
|         # Mid-layout truncation | ||||
|         layoutA = Layout((4, 6, 8, 10), (2, 3, 5, 7)) | ||||
|         layoutA = Layout((10, 8, 6, 4), (7, 5, 3, 2)) | ||||
|         layoutB = Layout(6, 12) | ||||
|         self.helper_test_composition(layoutA, layoutB) | ||||
|  | ||||
|         layoutA = Layout((4,), (3,)) | ||||
|         layoutB = Layout((6,), (2,)) | ||||
|         self.helper_test_composition(layoutA, layoutB) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     run_tests() | ||||
|  | ||||
| @ -67,20 +67,159 @@ class TestIntTuple(TestCase): | ||||
|  | ||||
|         self.assertEqual(shape_div((6, (3, 4)), 36), (1, (1, 2))) | ||||
|  | ||||
|     def test_prefix_product(self): | ||||
|         self.assertEqual(prefix_product(2), 1) | ||||
|     def test_suffix_product(self): | ||||
|         self.assertEqual(suffix_product(2), 1) | ||||
|  | ||||
|         self.assertEqual(prefix_product((3, 2)), (1, 3)) | ||||
|         self.assertEqual(suffix_product((3, 2)), (2, 1)) | ||||
|  | ||||
|         self.assertEqual(prefix_product((3, 2, 4)), (1, 3, 6)) | ||||
|         self.assertEqual(suffix_product((3, 2, 4)), (8, 4, 1)) | ||||
|  | ||||
|         self.assertEqual(prefix_product(((2, 3), 4)), ((1, 2), 6)) | ||||
|         self.assertEqual(suffix_product(((2, 3), 4)), ((12, 4), 1)) | ||||
|  | ||||
|         self.assertEqual( | ||||
|             prefix_product(((2, 3), (2, 1, 2), (5, 2, 1))), | ||||
|             ((1, 2), (6, 12, 12), (24, 120, 240)), | ||||
|             suffix_product(((2, 3), (2, 1, 2), (5, 2, 1))), | ||||
|             ((120, 40), (20, 20, 10), (2, 1, 1)), | ||||
|         ) | ||||
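|         # Each entry is the product of all extents to its right in the | ||||
|         # flattened shape; crd2idx/idx2crd below use these as default strides | ||||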
|  | ||||
|     def test_crd2idx_basic(self): | ||||
|         # Test basic int/int case | ||||
|         self.assertEqual(crd2idx(2, 5, 1), 2) | ||||
|         self.assertEqual(crd2idx(0, 5, 1), 0) | ||||
|         self.assertEqual(crd2idx(4, 5, 1), 4) | ||||
|  | ||||
|         # Test with custom stride | ||||
|         self.assertEqual(crd2idx(2, 5, 3), 6) | ||||
|         self.assertEqual(crd2idx(1, 5, 3), 3) | ||||
|  | ||||
|     def test_crd2idx_tuple(self): | ||||
|         # Test tuple coordinates with default stride | ||||
|         self.assertEqual(crd2idx((1, 2), (3, 4)), 6)  # 1*4 + 2*1 = 6 | ||||
|         self.assertEqual(crd2idx((0, 0), (3, 4)), 0) | ||||
|         self.assertEqual(crd2idx((2, 3), (3, 4)), 11)  # 2*4 + 3*1 = 11 | ||||
|  | ||||
|         # Test with custom stride | ||||
|         self.assertEqual(crd2idx((1, 2), (3, 4), (8, 2)), 12)  # 1*8 + 2*2 = 12 | ||||
|  | ||||
|         # Test 3D case | ||||
|         self.assertEqual(crd2idx((1, 0, 2), (2, 3, 4)), 14)  # 1*12 + 0*4 + 2*1 = 14 | ||||
|  | ||||
|     def test_crd2idx_none(self): | ||||
|         # Test None coordinate (should default to 0) | ||||
|         self.assertEqual(crd2idx(None, 5), 0) | ||||
|         self.assertEqual(crd2idx(None, (3, 4)), 0) | ||||
|  | ||||
|     def test_crd2idx_int_with_tuple_shape(self): | ||||
|         # Test single integer coordinate with multi-dimensional shape and stride | ||||
|         # When crd is an int and shape is a tuple, the int is first converted to a multi-dim coordinate | ||||
|         self.assertEqual(crd2idx(0, (2, 2), (2, 1)), 0)  # 0 -> (0,0) -> 0*2 + 0*1 = 0 | ||||
|         self.assertEqual(crd2idx(1, (2, 2), (2, 1)), 1)  # 1 -> (0,1) -> 0*2 + 1*1 = 1 | ||||
|         self.assertEqual(crd2idx(2, (2, 2), (2, 1)), 2)  # 2 -> (1,0) -> 1*2 + 0*1 = 2 | ||||
|         self.assertEqual(crd2idx(3, (2, 2), (2, 1)), 3)  # 3 -> (1,1) -> 1*2 + 1*1 = 3 | ||||
|  | ||||
|         # Test with non-trivial strides | ||||
|         self.assertEqual(crd2idx(0, (2, 3), (6, 2)), 0)  # 0 -> (0,0) -> 0*6 + 0*2 = 0 | ||||
|         self.assertEqual(crd2idx(1, (2, 3), (6, 2)), 2)  # 1 -> (0,1) -> 0*6 + 1*2 = 2 | ||||
|         self.assertEqual(crd2idx(2, (2, 3), (6, 2)), 4)  # 2 -> (0,2) -> 0*6 + 2*2 = 4 | ||||
|         self.assertEqual(crd2idx(3, (2, 3), (6, 2)), 6)  # 3 -> (1,0) -> 1*6 + 0*2 = 6 | ||||
|         self.assertEqual(crd2idx(4, (2, 3), (6, 2)), 8)  # 4 -> (1,1) -> 1*6 + 1*2 = 8 | ||||
|         self.assertEqual(crd2idx(5, (2, 3), (6, 2)), 10)  # 5 -> (1,2) -> 1*6 + 2*2 = 10 | ||||
|  | ||||
|         # Test with larger strides | ||||
|         self.assertEqual(crd2idx(0, (3, 2), (10, 5)), 0)  # 0 -> (0,0) -> 0*10 + 0*5 = 0 | ||||
|         self.assertEqual(crd2idx(1, (3, 2), (10, 5)), 5)  # 1 -> (0,1) -> 0*10 + 1*5 = 5 | ||||
|         self.assertEqual( | ||||
|             crd2idx(2, (3, 2), (10, 5)), 10 | ||||
|         )  # 2 -> (1,0) -> 1*10 + 0*5 = 10 | ||||
|         self.assertEqual( | ||||
|             crd2idx(3, (3, 2), (10, 5)), 15 | ||||
|         )  # 3 -> (1,1) -> 1*10 + 1*5 = 15 | ||||
|         self.assertEqual( | ||||
|             crd2idx(4, (3, 2), (10, 5)), 20 | ||||
|         )  # 4 -> (2,0) -> 2*10 + 0*5 = 20 | ||||
|         self.assertEqual( | ||||
|             crd2idx(5, (3, 2), (10, 5)), 25 | ||||
|         )  # 5 -> (2,1) -> 2*10 + 1*5 = 25 | ||||
|  | ||||
|         # Test with 3D shape and various strides | ||||
|         self.assertEqual( | ||||
|             crd2idx(0, (2, 2, 2), (8, 4, 2)), 0 | ||||
|         )  # 0 -> (0,0,0) -> 0*8 + 0*4 + 0*2 = 0 | ||||
|         self.assertEqual( | ||||
|             crd2idx(1, (2, 2, 2), (8, 4, 2)), 2 | ||||
|         )  # 1 -> (0,0,1) -> 0*8 + 0*4 + 1*2 = 2 | ||||
|         self.assertEqual( | ||||
|             crd2idx(2, (2, 2, 2), (8, 4, 2)), 4 | ||||
|         )  # 2 -> (0,1,0) -> 0*8 + 1*4 + 0*2 = 4 | ||||
|         self.assertEqual( | ||||
|             crd2idx(3, (2, 2, 2), (8, 4, 2)), 6 | ||||
|         )  # 3 -> (0,1,1) -> 0*8 + 1*4 + 1*2 = 6 | ||||
|         self.assertEqual( | ||||
|             crd2idx(4, (2, 2, 2), (8, 4, 2)), 8 | ||||
|         )  # 4 -> (1,0,0) -> 1*8 + 0*4 + 0*2 = 8 | ||||
|         self.assertEqual( | ||||
|             crd2idx(7, (2, 2, 2), (8, 4, 2)), 14 | ||||
|         )  # 7 -> (1,1,1) -> 1*8 + 1*4 + 1*2 = 14 | ||||
|  | ||||
|         self.assertEqual( | ||||
|             crd2idx(4, ((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))), 8 | ||||
|         )  # 4 -> ((0, 0, 0), (1, 0, 0)) -> 1*8 = 8 | ||||
|  | ||||
|     def test_idx2crd_basic(self): | ||||
|         # Test basic int/int case | ||||
|         self.assertEqual(idx2crd(2, 5, 1), 2) | ||||
|         self.assertEqual(idx2crd(0, 5, 1), 0) | ||||
|         self.assertEqual(idx2crd(4, 5, 1), 4) | ||||
|  | ||||
|         # Test with custom stride | ||||
|         self.assertEqual(idx2crd(6, 5, 3), 2)  # (6 // 3) % 5 = 2 | ||||
|         self.assertEqual(idx2crd(3, 5, 3), 1)  # (3 // 3) % 5 = 1 | ||||
|  | ||||
|     def test_idx2crd_tuple(self): | ||||
|         # Test tuple shape with default stride | ||||
|         self.assertEqual(idx2crd(6, (3, 4)), (1, 2))  # 6 -> (1, 2) | ||||
|         self.assertEqual(idx2crd(0, (3, 4)), (0, 0)) | ||||
|         self.assertEqual(idx2crd(11, (3, 4)), (2, 3)) | ||||
|  | ||||
|         # Test 3D case | ||||
|         self.assertEqual(idx2crd(14, (2, 3, 4)), (1, 0, 2))  # strides (12, 4, 1): (14 // 12) % 2, (14 // 4) % 3, 14 % 4 | ||||
|  | ||||
|     def test_crd2idx_idx2crd_roundtrip(self): | ||||
|         # Test that crd2idx and idx2crd are inverse operations | ||||
|         shapes = [ | ||||
|             5, | ||||
|             (3, 4), | ||||
|             (2, 3, 4), | ||||
|             (2, 2, 2, 2), | ||||
|         ] | ||||
|  | ||||
|         for shape in shapes: | ||||
|             size = product(shape) | ||||
|             for idx in range(size): | ||||
|                 crd = idx2crd(idx, shape) | ||||
|                 recovered_idx = crd2idx(crd, shape) | ||||
|                 self.assertEqual( | ||||
|                     recovered_idx, idx, f"Failed roundtrip for shape {shape}, idx {idx}" | ||||
|                 ) | ||||
|  | ||||
|     def test_idx2crd_crd2idx_roundtrip(self): | ||||
|         # Test roundtrip starting from coordinates | ||||
|         test_cases = [ | ||||
|             (0, 5), | ||||
|             (4, 5), | ||||
|             ((0, 0), (3, 4)), | ||||
|             ((1, 2), (3, 4)), | ||||
|             ((2, 3), (3, 4)), | ||||
|             ((0, 0, 0), (2, 3, 4)), | ||||
|             ((1, 2, 3), (2, 3, 4)), | ||||
|         ] | ||||
|  | ||||
|         for crd, shape in test_cases: | ||||
|             idx = crd2idx(crd, shape) | ||||
|             recovered_crd = idx2crd(idx, shape) | ||||
|             self.assertEqual( | ||||
|                 recovered_crd, crd, f"Failed roundtrip for crd {crd}, shape {shape}" | ||||
|             ) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     run_tests() | ||||
|  | ||||
| @ -1,21 +1,21 @@ | ||||
| #!/usr/bin/env python3 | ||||
| # Owner(s): ["oncall: r2p"] | ||||
|  | ||||
| import functools | ||||
|  | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the BSD-style license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
|  | ||||
| import functools | ||||
| # LICENSE file in the root directory  of this source tree. | ||||
| import json | ||||
| import os | ||||
| import signal | ||||
| import unittest | ||||
| import uuid | ||||
| from multiprocessing.pool import ThreadPool | ||||
| from typing import Any | ||||
| from unittest.mock import call, patch | ||||
| from unittest.mock import call, MagicMock, patch | ||||
|  | ||||
| import torch.distributed as dist | ||||
| import torch.distributed.elastic.rendezvous.registry as rdzv_registry | ||||
| @ -29,6 +29,7 @@ from torch.distributed.elastic.agent.server.api import ( | ||||
|     WorkerSpec, | ||||
|     WorkerState, | ||||
| ) | ||||
| from torch.distributed.elastic.events import EventSource | ||||
| from torch.distributed.elastic.multiprocessing import SignalException | ||||
| from torch.distributed.elastic.multiprocessing.errors import ProcessFailure | ||||
| from torch.distributed.elastic.rendezvous import RendezvousHandler, RendezvousParameters | ||||
| @ -157,6 +158,243 @@ def monres(state: WorkerState): | ||||
|         return RunResult(state=state) | ||||
|  | ||||
|  | ||||
| class RecordWorkerEventsTest(unittest.TestCase): | ||||
|     def setUp(self): | ||||
|         self.spec = MagicMock() | ||||
|         self.spec.role = "test_role" | ||||
|         self.spec.get_entrypoint_name.return_value = "test_entrypoint" | ||||
|         self.spec.rdzv_handler.get_run_id.return_value = "test_run_id" | ||||
|         self.spec.rdzv_handler.get_backend.return_value = "test_backend" | ||||
|         self.spec.max_restarts = 3 | ||||
|  | ||||
|         self.agent = TestAgent(self.spec) | ||||
|  | ||||
|         # Create a mock worker spec and agent | ||||
|         self.agent._worker_group = MagicMock() | ||||
|         self.agent._worker_group.spec = MagicMock() | ||||
|         self.agent._worker_group.spec.event_log_handler = "test_handler" | ||||
|  | ||||
|         # Setup worker group | ||||
|         self.worker_group = WorkerGroup(self.spec) | ||||
|         self.worker_group.group_world_size = 2 | ||||
|         self.worker_group.group_rank = 1 | ||||
|         self.agent._worker_group = self.worker_group | ||||
|  | ||||
|         # Create the test workers (global ranks 0 and 1) | ||||
|         self.workers = [ | ||||
|             Worker( | ||||
|                 local_rank=0, | ||||
|                 global_rank=0, | ||||
|                 role_rank=0, | ||||
|                 world_size=2, | ||||
|                 role_world_size=2, | ||||
|             ), | ||||
|             Worker( | ||||
|                 local_rank=1, | ||||
|                 global_rank=1, | ||||
|                 role_rank=1, | ||||
|                 world_size=2, | ||||
|                 role_world_size=2, | ||||
|             ), | ||||
|         ] | ||||
|         self.workers[0].id = 0 | ||||
|         self.workers[1].id = 1 | ||||
|         self.agent._worker_group.workers = self.workers | ||||
|  | ||||
|     @patch("torch.distributed.elastic.agent.server.api.record") | ||||
|     def test_record_worker_events_success(self, mock_record): | ||||
|         # Create a RunResult with successful workers | ||||
|         result = RunResult( | ||||
|             state=WorkerState.SUCCEEDED, | ||||
|             return_values={0: "result0", 1: "result1"}, | ||||
|             failures={}, | ||||
|         ) | ||||
|  | ||||
|         # Call the method under test | ||||
|         self.agent._record_worker_events(result) | ||||
|  | ||||
|         # Verify record was called twice (once for each worker) | ||||
|         self.assertEqual(mock_record.call_count, 2) | ||||
|  | ||||
|         # Check that both calls were for SUCCEEDED events | ||||
|         for call_args in mock_record.call_args_list: | ||||
|             event = call_args[0][0] | ||||
|  | ||||
|             self.assertEqual(event.source, EventSource.WORKER) | ||||
|             self.assertEqual(event.metadata["state"], "SUCCEEDED") | ||||
|             self.assertIsNone(event.metadata["raw_error"]) | ||||
|             md = json.loads(event.metadata["metadata"]) | ||||
|             self.assertEqual(md["exit_code"], [None]) | ||||
|             self.assertEqual(md["worker_pid"], [None]) | ||||
|  | ||||
|     @patch("torch.distributed.elastic.agent.server.api.record") | ||||
|     def test_record_worker_events_failure(self, mock_record): | ||||
|         # Create failures with error data | ||||
|         failure0 = ProcessFailure( | ||||
|             local_rank=0, pid=1000, exitcode=1, error_file="error0.json" | ||||
|         ) | ||||
|  | ||||
|         # Create a RunResult with one failed worker and one terminated worker | ||||
|         result = RunResult( | ||||
|             state=WorkerState.FAILED, | ||||
|             return_values={}, | ||||
|             failures={0: failure0},  # Only worker 0 has a specific failure | ||||
|         ) | ||||
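|         # Worker 1 has no failure entry, so it should be reported as | ||||
|         # TERMINATED | ||||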
|  | ||||
|         # Call the method under test | ||||
|         self.agent._record_worker_events(result) | ||||
|  | ||||
|         # Verify record was called twice (once for each worker) | ||||
|         self.assertEqual(mock_record.call_count, 2) | ||||
|  | ||||
|         # Get the calls | ||||
|         calls = mock_record.call_args_list | ||||
|  | ||||
|         # Check first call for the failed worker (global_rank=0) | ||||
|         failed_event = calls[0][0][0] | ||||
|         self.assertEqual(failed_event.source, EventSource.WORKER) | ||||
|         self.assertEqual(failed_event.metadata["state"], "FAILED") | ||||
|         self.assertEqual(failed_event.metadata["global_rank"], 0) | ||||
|         md = json.loads(failed_event.metadata["metadata"]) | ||||
|         self.assertEqual(failed_event.metadata["raw_error"], '{"message": "<NONE>"}') | ||||
|         self.assertEqual(md["exit_code"], [1]) | ||||
|         self.assertEqual(md["worker_pid"], [1000]) | ||||
|  | ||||
|         # Check second call for the terminated worker (global_rank=1) | ||||
|         terminated_event = calls[1][0][0] | ||||
|         self.assertEqual(terminated_event.source, EventSource.WORKER) | ||||
|         self.assertEqual(terminated_event.metadata["state"], "TERMINATED") | ||||
|         self.assertEqual(terminated_event.metadata["global_rank"], 1) | ||||
|         self.assertIsNone(terminated_event.metadata["raw_error"]) | ||||
|         md = json.loads(terminated_event.metadata["metadata"]) | ||||
|         self.assertEqual(md["exit_code"], [None]) | ||||
|         self.assertEqual(md["worker_pid"], [None]) | ||||
|  | ||||
|  | ||||
| class ConstructEventTest(unittest.TestCase): | ||||
|     def setUp(self): | ||||
|         # Create minimal spec and agent for testing | ||||
|         self.spec = MagicMock() | ||||
|         self.spec.role = "test_role" | ||||
|         self.spec.get_entrypoint_name.return_value = "test_entrypoint" | ||||
|         self.spec.rdzv_handler.get_run_id.return_value = "test_run_id" | ||||
|         self.spec.rdzv_handler.get_backend.return_value = "test_backend" | ||||
|         self.spec.max_restarts = 3 | ||||
|  | ||||
|         self.agent = TestAgent(self.spec) | ||||
|         self.agent._remaining_restarts = 2 | ||||
|         self.agent._total_execution_time = 42 | ||||
|  | ||||
|         # Setup worker group | ||||
|         self.worker_group = WorkerGroup(self.spec) | ||||
|         self.worker_group.group_world_size = 2 | ||||
|         self.worker_group.group_rank = 1 | ||||
|         self.agent._worker_group = self.worker_group | ||||
|  | ||||
|         # Create a test worker | ||||
|         self.worker = Worker( | ||||
|             local_rank=0, global_rank=5, role_rank=3, world_size=8, role_world_size=4 | ||||
|         ) | ||||
|         self.worker.id = 12345 | ||||
|  | ||||
|     def test_construct_event_agent_success(self): | ||||
|         # Test constructing an agent success event | ||||
|         event = self.agent._construct_event(state="SUCCEEDED", source=EventSource.AGENT) | ||||
|  | ||||
|         # Verify basic event properties | ||||
|         self.assertEqual(event.name, "torchelastic.worker.status.SUCCEEDED") | ||||
|         self.assertEqual(event.source, EventSource.AGENT) | ||||
|  | ||||
|         # Verify metadata | ||||
|         metadata = event.metadata | ||||
|         self.assertEqual(metadata["run_id"], "test_run_id") | ||||
|         self.assertIsNone(metadata["global_rank"]) | ||||
|         self.assertEqual(metadata["group_rank"], 1) | ||||
|         self.assertIsNone(metadata["worker_id"]) | ||||
|         self.assertEqual(metadata["role"], "test_role") | ||||
|         self.assertEqual(metadata["state"], "SUCCEEDED") | ||||
|         self.assertEqual(metadata["total_run_time"], 42) | ||||
|         self.assertEqual(metadata["rdzv_backend"], "test_backend") | ||||
|         self.assertIsNone(metadata["raw_error"]) | ||||
|         self.assertEqual( | ||||
|             metadata["agent_restarts"], 1 | ||||
|         )  # max_restarts - remaining_restarts | ||||
|         self.assertIsNone(metadata["duration_ms"]) | ||||
|  | ||||
|         # Verify JSON metadata | ||||
|         md_dict = json.loads(metadata["metadata"]) | ||||
|         self.assertEqual(md_dict["group_world_size"], 2) | ||||
|         self.assertEqual(md_dict["entry_point"], "test_entrypoint") | ||||
|  | ||||
|     def test_construct_event_worker_failure(self): | ||||
|         # Test constructing a worker failure event with raw error | ||||
|         raw_error = json.dumps( | ||||
|             {"error_message": "Test error", "traceback": "stack trace"} | ||||
|         ) | ||||
|         event = self.agent._construct_event( | ||||
|             state="FAILED", | ||||
|             source=EventSource.WORKER, | ||||
|             worker=self.worker, | ||||
|             raw_error=raw_error, | ||||
|             exit_code=1, | ||||
|         ) | ||||
|  | ||||
|         # Verify basic event properties | ||||
|         self.assertEqual(event.name, "torchelastic.worker.status.FAILED") | ||||
|         self.assertEqual(event.source, EventSource.WORKER) | ||||
|  | ||||
|         # Verify metadata | ||||
|         metadata = event.metadata | ||||
|         self.assertEqual(metadata["run_id"], "test_run_id") | ||||
|         self.assertEqual(metadata["global_rank"], 5) | ||||
|         self.assertEqual(metadata["group_rank"], 1) | ||||
|         self.assertEqual(metadata["worker_id"], "12345") | ||||
|         self.assertEqual(metadata["role"], "test_role") | ||||
|         self.assertEqual(metadata["state"], "FAILED") | ||||
|         self.assertEqual(metadata["total_run_time"], 42) | ||||
|         self.assertEqual(metadata["rdzv_backend"], "test_backend") | ||||
|         self.assertEqual(metadata["raw_error"], raw_error) | ||||
|         self.assertEqual(metadata["agent_restarts"], 1) | ||||
|  | ||||
|         # Verify worker-specific metadata | ||||
|         md_dict = json.loads(metadata["metadata"]) | ||||
|         self.assertEqual(md_dict["local_rank"], [0]) | ||||
|         self.assertEqual(md_dict["role_rank"], [3]) | ||||
|         self.assertEqual(md_dict["role_world_size"], [4]) | ||||
|         self.assertEqual(md_dict["exit_code"], [1]) | ||||
|  | ||||
|     def test_construct_event_with_duration(self): | ||||
|         # Test constructing an event with duration_ms | ||||
|         event = self.agent._construct_event( | ||||
|             state="RENDEZVOUS", source=EventSource.AGENT, duration_ms=123.45 | ||||
|         ) | ||||
|  | ||||
|         # Verify duration is set correctly | ||||
|         self.assertEqual(event.metadata["duration_ms"], 123.45) | ||||
|  | ||||
|     def test_construct_event_worker_no_error(self): | ||||
|         # Test constructing a worker event without error info | ||||
|         event = self.agent._construct_event( | ||||
|             state="HEALTHY", source=EventSource.WORKER, worker=self.worker | ||||
|         ) | ||||
|  | ||||
|         # Verify error fields are None | ||||
|         metadata = event.metadata | ||||
|         self.assertIsNone(metadata["raw_error"]) | ||||
|  | ||||
|         # Check worker info is set | ||||
|         self.assertEqual(metadata["global_rank"], 5) | ||||
|         self.assertEqual(metadata["worker_id"], "12345") | ||||
|  | ||||
|         # Check metadata JSON | ||||
|         md_dict = json.loads(metadata["metadata"]) | ||||
|         self.assertEqual(md_dict["local_rank"], [0]) | ||||
|         self.assertEqual(md_dict["role_rank"], [3]) | ||||
|         self.assertEqual(md_dict["role_world_size"], [4]) | ||||
|         self.assertEqual(md_dict.get("exit_code", [None]), [None])  # no real exit code recorded | ||||
|  | ||||
|  | ||||
| class SimpleElasticAgentTest(unittest.TestCase): | ||||
|     def _get_worker_spec( | ||||
|         self, | ||||
|  | ||||
| @ -568,9 +568,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): | ||||
|             ) | ||||
|  | ||||
|             results = pc.wait(period=0.1) | ||||
|  | ||||
|             self.assertTrue(results.is_failed()) | ||||
|             self.assertEqual(1, len(results.failures)) | ||||
|             self.assertEqual(2, len(results.failures)) | ||||
|  | ||||
|             failure = results.failures[0] | ||||
|             self.assertEqual(138, failure.exitcode) | ||||
| @ -583,6 +582,13 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): | ||||
|             self.assertTrue(pc._stderr_tail.stopped()) | ||||
|             self.assertTrue(pc._stdout_tail.stopped()) | ||||
|  | ||||
|             failure = results.failures[1] | ||||
|             self.assertEqual(-15, failure.exitcode) | ||||
|             self.assertEqual("SIGTERM", failure.signal_name()) | ||||
|             self.assertEqual("<NONE>", failure.error_file_data["message"]) | ||||
|             # Assert that the failure message contains expected substrings | ||||
|             self.assertIn("Signal 15 (SIGTERM) received by PID", failure.message) | ||||
|  | ||||
|         def test_binary_raises(self): | ||||
|             pc = start_processes( | ||||
|                 name="echo", | ||||
|  | ||||
Some files were not shown because too many files have changed in this diff