diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh
index 9ca0decd087e..3e2dc09ef495 100755
--- a/.ci/pytorch/test.sh
+++ b/.ci/pytorch/test.sh
@@ -485,6 +485,22 @@ test_inductor_aoti() {
   /usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference cpp/test_vec_half_AVX2 -dist=loadfile
 }
 
+test_inductor_aoti_cross_compile_for_windows() {
+
+  TEST_REPORTS_DIR=$(pwd)/test/test-reports
+  mkdir -p "$TEST_REPORTS_DIR"
+
+  # Set the WINDOWS_CUDA_HOME environment variable
+  WINDOWS_CUDA_HOME="$(pwd)/win-torch-wheel-extracted"
+  export WINDOWS_CUDA_HOME
+
+  echo "WINDOWS_CUDA_HOME is set to: $WINDOWS_CUDA_HOME"
+  echo "Contents:"
+  ls -lah "$(pwd)/win-torch-wheel-extracted/lib/x64/" || true
+
+  python test/inductor/test_aoti_cross_compile_windows.py -k compile --package-dir "$TEST_REPORTS_DIR" --win-torch-lib-dir "$(pwd)/win-torch-wheel-extracted/torch/lib"
+}
+
 test_inductor_cpp_wrapper_shard() {
   if [[ -z "$NUM_TEST_SHARDS" ]]; then
     echo "NUM_TEST_SHARDS must be defined to run a Python test shard"
@@ -1718,6 +1734,8 @@ elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then
   test_inductor_triton_cpu
 elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then
   test_inductor_micro_benchmark
+elif [[ "${TEST_CONFIG}" == *aoti_cross_compile_for_windows* ]]; then
+  test_inductor_aoti_cross_compile_for_windows
 elif [[ "${TEST_CONFIG}" == *huggingface* ]]; then
   install_torchvision
   id=$((SHARD_NUMBER-1))
diff --git a/.github/workflows/_linux-test.yml b/.github/workflows/_linux-test.yml
index 89f13d3fea8f..29c2fc8e0847 100644
--- a/.github/workflows/_linux-test.yml
+++ b/.github/workflows/_linux-test.yml
@@ -224,6 +224,46 @@ jobs:
         continue-on-error: true
         uses: ./.github/actions/download-td-artifacts
 
+      - name: Download Windows torch wheel for cross-compilation
+        if: matrix.win_torch_wheel_artifact != ''
+        uses: seemethere/download-artifact-s3@1da556a7aa0a088e3153970611f6c432d58e80e6 # v4.2.0
+        with:
+          name: ${{ matrix.win_torch_wheel_artifact }}
+          path: win-torch-wheel
+
+      - name: Extract Windows wheel and set up CUDA libraries
+        if: matrix.win_torch_wheel_artifact != ''
+        shell: bash
+        run: |
+          set -x
+
+          # Find the wheel file
+          WHEEL_FILE=$(find win-torch-wheel -name "*.whl" -type f | head -n 1)
+          if [ -z "$WHEEL_FILE" ]; then
+            echo "Error: No wheel file found in win-torch-wheel directory"
+            exit 1
+          fi
+          echo "Found wheel file: $WHEEL_FILE"
+
+          # Unzip the wheel file
+          unzip -q "$WHEEL_FILE" -d win-torch-wheel-extracted
+          echo "Extracted wheel contents"
+
+          # Set up the CUDA libraries (cuda.lib and cudart.lib) directory
+          mkdir -p win-torch-wheel-extracted/lib/x64
+          if [ -f "win-torch-wheel/cuda.lib" ]; then
+            mv win-torch-wheel/cuda.lib win-torch-wheel-extracted/lib/x64/
+            echo "Moved cuda.lib to win-torch-wheel-extracted/lib/x64/"
+          fi
+          if [ -f "win-torch-wheel/cudart.lib" ]; then
+            mv win-torch-wheel/cudart.lib win-torch-wheel-extracted/lib/x64/
+            echo "Moved cudart.lib to win-torch-wheel-extracted/lib/x64/"
+          fi
+
+          # Verify that the CUDA libraries are present
+          echo "CUDA libraries:"
+          ls -la win-torch-wheel-extracted/lib/x64/ || echo "No CUDA libraries found"
+
       - name: Parse ref
         id: parse-ref
         run: .github/scripts/parse_ref.py
diff --git a/.github/workflows/_win-build.yml b/.github/workflows/_win-build.yml
index 153f6007b3f0..0fd3cf7f3972 100644
--- a/.github/workflows/_win-build.yml
+++ b/.github/workflows/_win-build.yml
@@ -168,6 +168,33 @@ jobs:
         run: |
           .ci/pytorch/win-build.sh
 
+      # Collect the Windows CUDA import libs (cuda.lib and cudart.lib) for cross-compilation
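+      # (the Windows torch .lib files are not copied here; they ship inside the
+      # wheel artifact and are extracted on the Linux test side)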
+      - name: Collect Windows CUDA libs for cross-compilation
+        if: steps.build.outcome != 'skipped' && inputs.cuda-version != 'cpu'
+        shell: bash
+        run: |
+          set -ex
+
+          # Create the directory structure if it does not exist
+          mkdir -p /c/${{ github.run_id }}/build-results
+
+          # Copy the CUDA libs
+          CUDA_PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${{ inputs.cuda-version }}"
+
+          if [ -f "${CUDA_PATH}/lib/x64/cuda.lib" ]; then
+            cp "${CUDA_PATH}/lib/x64/cuda.lib" /c/${{ github.run_id }}/build-results/
+          fi
+
+          if [ -f "${CUDA_PATH}/lib/x64/cudart.lib" ]; then
+            cp "${CUDA_PATH}/lib/x64/cudart.lib" /c/${{ github.run_id }}/build-results/
+          fi
+
+          # List the collected files
+          echo "Collected CUDA libs:"
+          ls -lah /c/${{ github.run_id }}/build-results/*.lib
+
       # Upload to github so that people can click and download artifacts
       - name: Upload artifacts to s3
         if: steps.build.outcome != 'skipped'
diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml
index c8aab0aee10e..710b6cfa9eaf 100644
--- a/.github/workflows/trunk.yml
+++ b/.github/workflows/trunk.yml
@@ -200,6 +200,23 @@ jobs:
       cuda-arch-list: '8.0'
     secrets: inherit
 
+  # Test cross-compiled models with the Windows libs extracted from the wheel
+  cross-compile-linux-test:
+    name: cross-compile-linux-test
+    uses: ./.github/workflows/_linux-test.yml
+    needs:
+      - linux-jammy-cuda12_8-py3_10-gcc11-build
+      - get-label-type
+      - win-vs2022-cuda12_8-py3-build
+    with:
+      build-environment: linux-jammy-cuda12.8-py3.10-gcc11
+      docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }}
+      test-matrix: |
+        { include: [
+          { config: "aoti_cross_compile_for_windows", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", win_torch_wheel_artifact: "win-vs2022-cuda12.8-py3" },
+        ]}
+    secrets: inherit
+
   verify-cachebench-cpu-build:
     name: verify-cachebench-cpu-build
     uses: ./.github/workflows/_linux-build.yml
diff --git a/test/inductor/test_aoti_cross_compile_windows.py b/test/inductor/test_aoti_cross_compile_windows.py
new file mode 100644
index 000000000000..04065add9081
--- /dev/null
+++ b/test/inductor/test_aoti_cross_compile_windows.py
@@ -0,0 +1,383 @@
+# Owner(s): ["module: inductor"]
+import os
+import platform
+import tempfile
+import unittest
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Optional
+
+import torch
+import torch._inductor.config
+from torch._inductor.test_case import TestCase
+from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, requires_gpu
+
+
+@dataclass
+class ModelTestConfig:
+    """Configuration for a model test case."""
+
+    name: str
+    model_class: type
+    example_inputs: tuple[torch.Tensor, ...]
+    dynamic_shapes: Optional[dict[str, Any]] = None
+    inductor_configs: Optional[dict[str, Any]] = None
+    rtol: float = 1e-4
+    atol: float = 1e-4
+
+
+class WindowsCrossCompilationTestFramework:
+    """
+    Framework for testing cross-compilation from Linux to Windows.
+
+    Provides reusable logic for creating compile and load test methods.
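+
+    The compile half runs on Linux and writes a .pt2 package under a shared
+    base path; the load half runs on Windows and reads that same package, so
+    both halves resolve their paths through this class.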
+ """ + + _base_path: Optional[Path] = None + _win_torch_libs_path: Optional[str] = None + + @classmethod + def base_path(cls) -> Path: + """Get or create the base path for package files.""" + if cls._base_path is None: + cls._base_path = Path(tempfile.mkdtemp(prefix="aoti_cross_compile_")) + return cls._base_path + + @classmethod + def set_base_path(cls, path: Optional[Path | str] = None) -> None: + """Set the base path for package files.""" + cls._base_path = Path(path) if path else None + + @classmethod + def set_win_torch_libs_path(cls, path: Optional[str] = None) -> None: + """Set the path for Windows torch libs.""" + cls._win_torch_libs_path = path + + @classmethod + def get_package_path(cls, model_name: str) -> str: + """Get the path for a model's .pt2 package file.""" + package_dir = cls.base_path() + package_dir.mkdir(parents=True, exist_ok=True) + return str(package_dir / f"{model_name}_windows.pt2") + + @classmethod + def get_win_torch_libs_path(cls) -> str: + """Get the path for Windows torch libs.""" + if cls._win_torch_libs_path is None: + raise RuntimeError("Windows torch libs path not set") + return str(cls._win_torch_libs_path) + + @classmethod + def create_compile_test(cls, config: ModelTestConfig): + """Create a compile test method for a model configuration.""" + + def compile_test(self): + if platform.system() == "Windows": + raise unittest.SkipTest( + "This test should run on Linux for cross-compilation" + ) + + self.assertTrue("WINDOWS_CUDA_HOME" in os.environ) + + with torch.no_grad(): + # Windows cross-compilation is only used for GPU. + # AOTI for CPU should be able to work as native compilation on Windows. + device = GPU_TYPE + model = config.model_class().to(device=device) + example_inputs = config.example_inputs + + # Inputs should already be on GPU_TYPE but ensure they are + example_inputs = tuple(inp.to(device) for inp in example_inputs) + + # Export the model + exported = torch.export.export( + model, example_inputs, dynamic_shapes=config.dynamic_shapes + ) + + # Prepare inductor configs + inductor_configs = { + "aot_inductor.cross_target_platform": "windows", + "aot_inductor.precompile_headers": False, + "aot_inductor.package_constants_on_disk_format": "binary_blob", + "aot_inductor.package_constants_in_so": False, + "aot_inductor.aoti_shim_library_path": cls.get_win_torch_libs_path(), + } + if config.inductor_configs: + inductor_configs.update(config.inductor_configs) + + # Compile and package directly to the expected location + package_path = cls.get_package_path(config.name) + torch._inductor.aoti_compile_and_package( + exported, + package_path=package_path, + inductor_configs=inductor_configs, + ) + + self.assertTrue( + os.path.exists(package_path), + f"Package file should exist at {package_path}", + ) + + return compile_test + + @classmethod + def create_load_test(cls, config: ModelTestConfig): + """Create a load test method for a model configuration.""" + + def load_test(self): + if platform.system() != "Windows": + raise unittest.SkipTest("This test should run on Windows") + + if not HAS_GPU: + raise unittest.SkipTest("Test requires GPU") + + package_path = cls.get_package_path(config.name) + if not os.path.exists(package_path): + raise unittest.SkipTest( + f"Package file not found at {package_path}. " + f"Run test_{config.name}_compile first." + ) + + with torch.no_grad(): + # Windows cross-compilation is only used for GPU. + # AOTI for CPU should be able to work as native compilation on Windows. 
+                device = GPU_TYPE
+
+                # Create the original model for comparison
+                original_model = config.model_class().to(device=device)
+                example_inputs = config.example_inputs
+
+                # Inputs should already be on GPU_TYPE, but make sure they are
+                example_inputs = tuple(inp.to(device) for inp in example_inputs)
+
+                # Load the compiled package
+                loaded_model = torch._inductor.aoti_load_package(package_path)
+
+                # Test with the same inputs
+                original_output = original_model(*example_inputs)
+                loaded_output = loaded_model(*example_inputs)
+
+                # Compare outputs
+                torch.testing.assert_close(
+                    original_output, loaded_output, rtol=config.rtol, atol=config.atol
+                )
+
+        return load_test
+
+
+def auto_generate_tests(test_class):
+    """
+    Class decorator to automatically generate compile/load test methods
+    from _define_* methods that return ModelTestConfig.
+    """
+    # Find all _define_* methods that return ModelTestConfig
+    define_methods = {}
+    for name in dir(test_class):
+        if name.startswith("_define_") and callable(getattr(test_class, name)):
+            method = getattr(test_class, name)
+            # Call the method to see whether it returns a ModelTestConfig
+            try:
+                # Create a temporary instance to call the method
+                temp_instance = test_class.__new__(test_class)
+                result = method(temp_instance)
+                if isinstance(result, ModelTestConfig):
+                    define_methods[name] = result
+            except Exception:
+                # If the method fails, skip it
+                pass
+
+    # Generate compile/load methods for each discovered definition
+    for define_name, config in define_methods.items():
+        model_name = define_name[8:]  # Remove the '_define_' prefix
+
+        # Create the compile test method
+        compile_method_name = f"test_{model_name}_compile"
+        compile_method = WindowsCrossCompilationTestFramework.create_compile_test(
+            config
+        )
+        compile_method.__name__ = compile_method_name
+        compile_method.__doc__ = f"Step 1: Cross-compile {model_name} model on Linux"
+        compile_method = requires_gpu()(compile_method)
+        setattr(test_class, compile_method_name, compile_method)
+
+        # Create the load test method
+        load_method_name = f"test_{model_name}_load"
+        load_method = WindowsCrossCompilationTestFramework.create_load_test(config)
+        load_method.__name__ = load_method_name
+        load_method.__doc__ = f"Step 2: Load and test {model_name} model on Windows"
+        load_method = requires_gpu()(load_method)
+        setattr(test_class, load_method_name, load_method)
+
+    return test_class
+
+
+@auto_generate_tests
+class TestAOTInductorWindowsCrossCompilation(TestCase):
+    """
+    Test class for AOT Inductor Windows cross-compilation.
+
+    Define _define_* methods that return a ModelTestConfig, and the decorator
+    will auto-generate the corresponding compile/load test methods.
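+
+    For example, _define_simple below expands into test_simple_compile,
+    which runs on Linux, and test_simple_load, which runs on Windows.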
+ """ + + def _define_simple(self): + """Define the Simple model and its test configuration.""" + + class Simple(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(10, 16) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(16, 1) + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, x): + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return x + + return ModelTestConfig( + name="simple", + model_class=Simple, + example_inputs=(torch.randn(8, 10, device=GPU_TYPE),), + dynamic_shapes={"x": {0: torch.export.Dim("batch", min=1, max=1024)}}, + ) + + def _define_simple_cnn(self): + """Define the SimpleCNN model and its test configuration.""" + + class SimpleCNN(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 16, 3) + self.relu = torch.nn.ReLU() + self.pool = torch.nn.AdaptiveAvgPool2d((1, 1)) + self.fc = torch.nn.Linear(16, 10) + + def forward(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.pool(x) + x = x.flatten(1) + x = self.fc(x) + return x + + return ModelTestConfig( + name="simple_cnn", + model_class=SimpleCNN, + example_inputs=(torch.randn(2, 3, 32, 32, device=GPU_TYPE),), + dynamic_shapes={"x": {0: torch.export.Dim("batch", min=1, max=16)}}, + rtol=1e-3, + atol=1e-3, + ) + + def _define_transformer(self): + """Define the SimpleTransformer model and its test configuration.""" + + class SimpleTransformer(torch.nn.Module): + def __init__(self): + super().__init__() + self.embedding = torch.nn.Linear(128, 256) + self.attention = torch.nn.MultiheadAttention(256, 8, batch_first=True) + self.norm1 = torch.nn.LayerNorm(256) + self.ffn = torch.nn.Sequential( + torch.nn.Linear(256, 1024), + torch.nn.ReLU(), + torch.nn.Linear(1024, 256), + ) + self.norm2 = torch.nn.LayerNorm(256) + self.output = torch.nn.Linear(256, 10) + + def forward(self, x): + # x shape: (batch, seq_len, input_dim) + x = self.embedding(x) + attn_out, _ = self.attention(x, x, x) + x = self.norm1(x + attn_out) + ffn_out = self.ffn(x) + x = self.norm2(x + ffn_out) + x = x.mean(dim=1) # Global average pooling + x = self.output(x) + return x + + return ModelTestConfig( + name="transformer", + model_class=SimpleTransformer, + example_inputs=(torch.randn(4, 16, 128, device=GPU_TYPE),), + dynamic_shapes={"x": {0: torch.export.Dim("batch", min=1, max=32)}}, + rtol=1e-3, + atol=1e-3, + ) + + +if __name__ == "__main__": + import sys + + from torch._inductor.test_case import run_tests + + # Check for --package-dir argument and remove it before unittest sees it + package_dir = None + win_torch_lib_dir = None + filtered_argv = [] + i = 0 + while i < len(sys.argv): + if sys.argv[i] == "--package-dir": + if i + 1 < len(sys.argv): + package_dir = sys.argv[i + 1] + i += 2 # Skip both --package-dir and its value + else: + print("Error: --package-dir requires a valid directory path") + sys.exit(1) + elif sys.argv[i].startswith("--package-dir="): + package_dir = sys.argv[i].split("=", 1)[1] + i += 1 + elif sys.argv[i] == "--win-torch-lib-dir": + if i + 1 < len(sys.argv): + win_torch_lib_dir = sys.argv[i + 1] + i += 2 # Skip both --win-torch-lib-dir and its value + else: + print("Error: --win-torch-lib-dir requires a valid directory path") + sys.exit(1) + elif sys.argv[i].startswith("--win-torch-lib-dir="): + win_torch_lib_dir = sys.argv[i].split("=", 1)[1] + i += 1 + else: + filtered_argv.append(sys.argv[i]) + i += 1 + + # Validate and set the base path for package storage + if package_dir: + try: + 
+            package_path = Path(package_dir)
+            package_path.mkdir(parents=True, exist_ok=True)
+            # Test write access
+            test_file = package_path / ".test_write"
+            test_file.touch()
+            test_file.unlink()
+            WindowsCrossCompilationTestFramework.set_base_path(package_path)
+        except Exception:
+            print("Error: --package-dir must point to a writable directory")
+            sys.exit(1)
+
+    # Set the Windows torch libs path if provided (only needed for compile tests)
+    if win_torch_lib_dir:
+        WindowsCrossCompilationTestFramework.set_win_torch_libs_path(win_torch_lib_dir)
+
+    # Update sys.argv to remove our custom arguments
+    sys.argv = filtered_argv
+
+    if HAS_GPU:
+        run_tests(needs="filelock")
diff --git a/tools/testing/discover_tests.py b/tools/testing/discover_tests.py
index 13511b1ec129..1210326a02db 100644
--- a/tools/testing/discover_tests.py
+++ b/tools/testing/discover_tests.py
@@ -107,6 +107,7 @@ TESTS = discover_tests(
         "lazy/test_meta_kernel",
         "lazy/test_extract_compiled_graph",
         "test/inductor/test_aot_inductor_utils",
+        "inductor/test_aoti_cross_compile_windows",
         "onnx/test_onnxscript_no_runtime",
         "onnx/test_pytorch_onnx_onnxruntime_cuda",
         "onnx/test_models",
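
For reference, a condensed sketch of the flow the generated tests exercise; the paths are placeholders, and the inductor config keys mirror the ones set in create_compile_test above:

# Standalone sketch of the cross-compile flow (placeholder paths; assumes
# WINDOWS_CUDA_HOME points at the extracted wheel tree, as test.sh sets up).
import torch
import torch._inductor

model = torch.nn.Linear(10, 1).to("cuda")
example_inputs = (torch.randn(8, 10, device="cuda"),)
exported = torch.export.export(model, example_inputs)

torch._inductor.aoti_compile_and_package(
    exported,
    package_path="simple_windows.pt2",  # placeholder output path
    inductor_configs={
        "aot_inductor.cross_target_platform": "windows",
        "aot_inductor.precompile_headers": False,
        "aot_inductor.package_constants_on_disk_format": "binary_blob",
        "aot_inductor.package_constants_in_so": False,
        # placeholder: directory containing the Windows torch .lib files
        "aot_inductor.aoti_shim_library_path": "win-torch-wheel-extracted/torch/lib",
    },
)

# Then, on a Windows machine with a GPU:
#   loaded = torch._inductor.aoti_load_package("simple_windows.pt2")
#   output = loaded(torch.randn(8, 10, device="cuda"))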