Compare commits

...

10 Commits

SHA1 Message Date
6c3db7e9a6 change to use linux-jammy-cuda12_8-py3_10-gcc11-build 2025-10-14 22:12:13 -07:00
f8ee3738bc change to use existing trunk.yml 2025-10-14 16:51:28 -07:00
c3d2da4fe7 clean up 2025-10-14 15:01:04 -07:00
c7827342f9 add workflow dispatch 2025-10-14 15:01:04 -07:00
c7b1335d09 fix ci 2025-10-14 15:01:04 -07:00
da9410474f lint 2025-10-14 15:01:04 -07:00
6081c6b0b8 fix linux tests 2025-10-14 15:01:04 -07:00
01a153c7d0 modify CI 2025-10-14 15:01:04 -07:00
e4735c67a4 add posix support and shim library path for lib torch 2025-10-14 15:00:20 -07:00
72627e8190 Windows CI test 2025-10-14 15:00:20 -07:00
10 changed files with 596 additions and 12 deletions

View File

@@ -113,6 +113,7 @@ case "$tag" in
    UCX_COMMIT=${_UCX_COMMIT}
    UCC_COMMIT=${_UCC_COMMIT}
    TRITON=yes
    INSTALL_MINGW=yes
    ;;
  pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11)
    CUDA_VERSION=13.0.0
@@ -361,6 +362,7 @@ docker build \
  --build-arg "OPENBLAS=${OPENBLAS:-}" \
  --build-arg "SKIP_SCCACHE_INSTALL=${SKIP_SCCACHE_INSTALL:-}" \
  --build-arg "SKIP_LLVM_SRC_BUILD_INSTALL=${SKIP_LLVM_SRC_BUILD_INSTALL:-}" \
  --build-arg "INSTALL_MINGW=${INSTALL_MINGW:-}" \
  -f $(dirname ${DOCKERFILE})/Dockerfile \
  -t "$tmp_tag" \
  "$@" \

View File

@@ -0,0 +1,10 @@
#!/bin/bash

set -ex

# Install MinGW-w64 for Windows cross-compilation
apt-get update
apt-get install -y g++-mingw-w64-x86-64-posix

echo "MinGW-w64 installed successfully"
x86_64-w64-mingw32-g++ --version

View File

@@ -103,6 +103,11 @@ COPY ci_commit_pins/torchbench.txt torchbench.txt
RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt

ARG INSTALL_MINGW
COPY ./common/install_mingw.sh install_mingw.sh
RUN if [ -n "${INSTALL_MINGW}" ]; then bash ./install_mingw.sh; fi
RUN rm install_mingw.sh

ARG TRITON
ARG TRITON_CPU

View File

@@ -485,6 +485,49 @@ test_inductor_aoti() {
  /usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference cpp/test_vec_half_AVX2 -dist=loadfile
}

test_inductor_aoti_cross_compile_for_windows() {
  # MinGW-w64 is preinstalled in the Docker image via .ci/docker/common/install_mingw.sh
  TEST_REPORTS_DIR=$(pwd)/test/test-reports
  mkdir -p "$TEST_REPORTS_DIR"

  # The artifact is downloaded to win-cuda-libs in the workspace by GitHub Actions.
  # The artifact contains only the .lib files (no directory structure).
  WIN_CUDA_LIBS_DOWNLOAD_DIR="$(pwd)/win-cuda-libs"
  if [[ ! -d "$WIN_CUDA_LIBS_DOWNLOAD_DIR" ]]; then
    echo "ERROR: Windows CUDA libs directory not found at $WIN_CUDA_LIBS_DOWNLOAD_DIR"
    echo "The artifact should have been downloaded by GitHub Actions before running the docker container"
    exit 1
  fi
  echo "Contents of downloaded Windows CUDA libs:"
  ls -lah "$WIN_CUDA_LIBS_DOWNLOAD_DIR/" || true

  # Create the expected directory structure and move the CUDA libs to lib/x64
  WIN_TORCH_LIBS_DIR="$(pwd)/win-cuda-libs-structured"
  mkdir -p "$WIN_TORCH_LIBS_DIR/lib/x64"

  # Move the downloaded CUDA libs to the expected location
  if [ -f "$WIN_CUDA_LIBS_DOWNLOAD_DIR/cuda.lib" ]; then
    mv "$WIN_CUDA_LIBS_DOWNLOAD_DIR/cuda.lib" "$WIN_TORCH_LIBS_DIR/lib/x64/cuda.lib"
  fi
  if [ -f "$WIN_CUDA_LIBS_DOWNLOAD_DIR/cudart.lib" ]; then
    mv "$WIN_CUDA_LIBS_DOWNLOAD_DIR/cudart.lib" "$WIN_TORCH_LIBS_DIR/lib/x64/cudart.lib"
  fi

  # Set the WINDOWS_CUDA_HOME environment variable
  export WINDOWS_CUDA_HOME="$WIN_TORCH_LIBS_DIR"
  echo "WINDOWS_CUDA_HOME is set to: $WINDOWS_CUDA_HOME"

  echo "Contents of Windows torch libs after restructuring:"
  ls -lah "$WIN_TORCH_LIBS_DIR/lib/x64/" || true
  ls -lah "$(pwd)/win-torch-wheel-extracted/torch/lib" || true

  python test/inductor/test_aoti_cross_compile_windows.py -k compile --package-dir "$TEST_REPORTS_DIR" --win-torch-lib-dir "$(pwd)/win-torch-wheel-extracted/torch/lib"
}

test_inductor_cpp_wrapper_shard() {
  if [[ -z "$NUM_TEST_SHARDS" ]]; then
    echo "NUM_TEST_SHARDS must be defined to run a Python test shard"
@@ -1717,6 +1760,8 @@ elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then
  test_inductor_triton_cpu
elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then
  test_inductor_micro_benchmark
elif [[ "${TEST_CONFIG}" == *aoti_cross_compile_for_windows* ]]; then
  test_inductor_aoti_cross_compile_for_windows
elif [[ "${TEST_CONFIG}" == *huggingface* ]]; then
  install_torchvision
  id=$((SHARD_NUMBER-1))

View File

@@ -224,6 +224,42 @@ jobs:
        continue-on-error: true
        uses: ./.github/actions/download-td-artifacts

      - name: Download Windows torch wheel for cross-compilation
        if: matrix.win_torch_wheel_artifact != ''
        uses: seemethere/download-artifact-s3@1da556a7aa0a088e3153970611f6c432d58e80e6 # v4.2.0
        with:
          name: ${{ matrix.win_torch_wheel_artifact }}
          path: win-torch-wheel

      - name: Download CUDA libraries for cross-compilation
        if: matrix.win_torch_wheel_artifact != ''
        uses: seemethere/download-artifact-s3@1da556a7aa0a088e3153970611f6c432d58e80e6 # v4.2.0
        with:
          name: win-vs2022-cuda12.8-py3-cuda-libs
          path: win-cuda-libs

      - name: Extract Windows wheel
        if: matrix.win_torch_wheel_artifact != ''
        shell: bash
        run: |
          set -x

          # Find the wheel file
          WHEEL_FILE=$(find win-torch-wheel -name "*.whl" -type f | head -n 1)
          if [ -z "$WHEEL_FILE" ]; then
            echo "Error: No wheel file found in win-torch-wheel directory"
            exit 1
          fi
          echo "Found wheel file: $WHEEL_FILE"

          # Unzip the wheel file
          unzip -q "$WHEEL_FILE" -d win-torch-wheel-extracted
          echo "Extracted wheel contents"

          # Verify CUDA libraries are present
          echo "Downloaded CUDA libraries:"
          ls -la win-cuda-libs/

      - name: Parse ref
        id: parse-ref
        run: .github/scripts/parse_ref.py
@@ -296,6 +332,7 @@ jobs:
          HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
          SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.SCRIBE_GRAPHQL_ACCESS_TOKEN }}
          ARTIFACTS_FILE_SUFFIX: ${{ github.job }}-${{ matrix.config }}-${{ matrix.shard }}-${{ matrix.num_shards }}-${{ matrix.runner }}_${{ steps.get-job-id.outputs.job-id }}
          WIN_TORCH_LIBS_ARTIFACT: ${{ matrix.win_torch_libs_artifact || '' }}
        run: |
          set -x
@@ -373,6 +410,7 @@ jobs:
            -e SCRIBE_GRAPHQL_ACCESS_TOKEN \
            -e DASHBOARD_TAG \
            -e ARTIFACTS_FILE_SUFFIX \
            -e WIN_TORCH_LIBS_ARTIFACT \
            --memory="${TOTAL_AVAILABLE_MEMORY_IN_GB%.*}g" \
            --memory-swap="${TOTAL_MEMORY_WITH_SWAP}g" \
            --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \

View File

@@ -50,6 +50,13 @@ on:
        default: "windows.4xlarge.nonephemeral"
        description: |
          Label of the runner this job should run on.
      upload-win-cuda-libs-for-cross-compile:
        required: false
        type: boolean
        default: false
        description: |
          If set, collect and upload Windows torch libs and CUDA libs for cross-compilation.
          This is only needed for AOTI cross-compilation workflows.

  outputs:
    test-matrix:
@@ -168,6 +175,42 @@ jobs:
        run: |
          .ci/pytorch/win-build.sh

      # Collect Windows torch libs and CUDA libs for cross-compilation
      - name: Collect Windows CUDA libs for cross-compilation
        if: steps.build.outcome != 'skipped' && inputs.cuda-version != 'cpu' && inputs.upload-win-cuda-libs-for-cross-compile
        shell: bash
        run: |
          set -ex

          # Create directory structure
          mkdir -p /c/${{ github.run_id }}/win-cuda-libs/lib
          mkdir -p /c/${{ github.run_id }}/win-cuda-libs/lib/x64

          # Copy CUDA libs
          CUDA_PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${{ inputs.cuda-version }}"
          if [ -f "${CUDA_PATH}/lib/x64/cuda.lib" ]; then
            cp "${CUDA_PATH}/lib/x64/cuda.lib" /c/${{ github.run_id }}/win-cuda-libs/lib/x64/
          fi
          if [ -f "${CUDA_PATH}/lib/x64/cudart.lib" ]; then
            cp "${CUDA_PATH}/lib/x64/cudart.lib" /c/${{ github.run_id }}/win-cuda-libs/lib/x64/
          fi

          # List collected files
          echo "Collected CUDA libs:"
          ls -lah /c/${{ github.run_id }}/win-cuda-libs/lib/x64/

      # Upload CUDA libs for cross-compilation
      - name: Upload CUDA libs to S3
        if: steps.build.outcome != 'skipped' && inputs.cuda-version != 'cpu' && inputs.upload-win-cuda-libs-for-cross-compile
        uses: seemethere/upload-artifact-s3@baba72d0712b404f646cebe0730933554ebce96a # v5.1.0
        with:
          retention-days: 14
          if-no-files-found: error
          name: ${{ inputs.build-environment }}-cuda-libs
          path: C:\${{ github.run_id }}\win-cuda-libs\lib\x64

      # Upload to GitHub so that people can click and download artifacts
      - name: Upload artifacts to s3
        if: steps.build.outcome != 'skipped'

View File

@@ -188,6 +188,7 @@ jobs:
      build-environment: win-vs2022-cuda12.8-py3
      cuda-version: "12.8"
      runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
      upload-win-cuda-libs-for-cross-compile: true
    secrets: inherit

  inductor-build:
@@ -200,6 +201,23 @@
      cuda-arch-list: '8.0'
    secrets: inherit

  # Test cross-compiled models with Windows libs extracted from the wheel
  cross-compile-linux-test:
    name: cross-compile-linux-test
    uses: ./.github/workflows/_linux-test.yml
    needs:
      - linux-jammy-cuda12_8-py3_10-gcc11-build
      - get-label-type
      - win-vs2022-cuda12_8-py3-build
    with:
      build-environment: linux-jammy-cuda12.8-py3.10-gcc11
      docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }}
      test-matrix: |
        { include: [
          { config: "aoti_cross_compile_for_windows", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", win_torch_wheel_artifact: "win-vs2022-cuda12.8-py3" },
        ]}
    secrets: inherit

  verify-cachebench-cpu-build:
    name: verify-cachebench-cpu-build
    uses: ./.github/workflows/_linux-build.yml

View File

@@ -0,0 +1,391 @@
# Owner(s): ["module: inductor"]
import os
import platform
import subprocess
import tempfile
import unittest
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional

import torch
import torch._inductor.config
from torch._inductor.test_case import TestCase
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, requires_gpu


def _check_mingw_gcc_available():
    try:
        result = subprocess.run(
            ["x86_64-w64-mingw32-gcc", "--version"],
            capture_output=True,
            text=True,
            timeout=10,
        )
        return result.returncode == 0
    except (subprocess.SubprocessError, FileNotFoundError, subprocess.TimeoutExpired):
        return False


@dataclass
class ModelTestConfig:
    """Configuration for a model test case."""

    name: str
    model_class: type
    example_inputs: tuple[torch.Tensor, ...]
    dynamic_shapes: Optional[dict[str, Any]] = None
    inductor_configs: Optional[dict[str, Any]] = None
    rtol: float = 1e-4
    atol: float = 1e-4


class WindowsCrossCompilationTestFramework:
    """
    Framework for testing cross-compilation from Linux to Windows.
    Provides reusable logic for creating compile and load test methods.
    """

    _base_path: Optional[Path] = None
    _win_torch_libs_path: Optional[str] = None

    @classmethod
    def base_path(cls) -> Path:
        """Get or create the base path for package files."""
        if cls._base_path is None:
            cls._base_path = Path(tempfile.mkdtemp(prefix="aoti_cross_compile_"))
        return cls._base_path

    @classmethod
    def set_base_path(cls, path: Optional[Path | str] = None) -> None:
        """Set the base path for package files."""
        cls._base_path = Path(path) if path else None

    @classmethod
    def set_win_torch_libs_path(cls, path: Optional[str] = None) -> None:
        """Set the path for Windows torch libs."""
        cls._win_torch_libs_path = path

    @classmethod
    def get_package_path(cls, model_name: str) -> str:
        """Get the path for a model's .pt2 package file."""
        package_dir = cls.base_path()
        package_dir.mkdir(parents=True, exist_ok=True)
        return str(package_dir / f"{model_name}_windows.pt2")

    @classmethod
    def get_win_torch_libs_path(cls) -> str:
        """Get the path for Windows torch libs."""
        if cls._win_torch_libs_path is None:
            raise RuntimeError("Windows torch libs path not set")
        return str(cls._win_torch_libs_path)

    @classmethod
    def create_compile_test(cls, config: ModelTestConfig):
        """Create a compile test method for a model configuration."""

        def compile_test(self):
            if platform.system() == "Windows":
                raise unittest.SkipTest(
                    "This test should run on Linux for cross-compilation"
                )
            if not _check_mingw_gcc_available():
                raise unittest.SkipTest("requires x86_64-w64-mingw32-gcc")
            if not HAS_GPU:
                raise unittest.SkipTest("Test requires GPU")
            self.assertTrue("WINDOWS_CUDA_HOME" in os.environ)
            with torch.no_grad():
                # Windows cross-compilation is only used for GPU.
                # AOTI for CPU should be able to work as native compilation on Windows.
                device = GPU_TYPE
                model = config.model_class().to(device=device)
                example_inputs = config.example_inputs
                # Inputs should already be on GPU_TYPE but ensure they are
                example_inputs = tuple(inp.to(device) for inp in example_inputs)
                # Export the model
                exported = torch.export.export(
                    model, example_inputs, dynamic_shapes=config.dynamic_shapes
                )
                # Prepare inductor configs
                inductor_configs = {
                    "aot_inductor.cross_target_platform": "windows",
                    "aot_inductor.precompile_headers": False,
                    "aot_inductor.package_constants_on_disk_format": "binary_blob",
                    "aot_inductor.package_constants_in_so": False,
                    "aot_inductor.aoti_shim_library_path": cls.get_win_torch_libs_path(),
                }
                if config.inductor_configs:
                    inductor_configs.update(config.inductor_configs)
                # Compile and package directly to the expected location
                package_path = cls.get_package_path(config.name)
                torch._inductor.aoti_compile_and_package(
                    exported,
                    package_path=package_path,
                    inductor_configs=inductor_configs,
                )
                self.assertTrue(
                    os.path.exists(package_path),
                    f"Package file should exist at {package_path}",
                )

        return compile_test

    @classmethod
    def create_load_test(cls, config: ModelTestConfig):
        """Create a load test method for a model configuration."""

        def load_test(self):
            if platform.system() != "Windows":
                raise unittest.SkipTest("This test should run on Windows")
            if not HAS_GPU:
                raise unittest.SkipTest("Test requires GPU")
            package_path = cls.get_package_path(config.name)
            if not os.path.exists(package_path):
                raise unittest.SkipTest(
                    f"Package file not found at {package_path}. "
                    f"Run test_{config.name}_compile first."
                )
            with torch.no_grad():
                # Windows cross-compilation is only used for GPU.
                # AOTI for CPU should be able to work as native compilation on Windows.
                device = GPU_TYPE
                # Create original model for comparison
                original_model = config.model_class().to(device=device)
                example_inputs = config.example_inputs
                # Inputs should already be on GPU_TYPE but ensure they are
                example_inputs = tuple(inp.to(device) for inp in example_inputs)
                # Load the compiled package
                loaded_model = torch._inductor.aoti_load_package(package_path)
                # Test with the same inputs
                original_output = original_model(*example_inputs)
                loaded_output = loaded_model(*example_inputs)
                # Compare outputs
                torch.testing.assert_close(
                    original_output, loaded_output, rtol=config.rtol, atol=config.atol
                )

        return load_test


def auto_generate_tests(test_class):
    """
    Class decorator to automatically generate compile/load test methods
    from _define_* methods that return ModelTestConfig.
    """
    # Find all _define_* methods that return ModelTestConfig
    define_methods = {}
    for name in dir(test_class):
        if name.startswith("_define_") and callable(getattr(test_class, name)):
            method = getattr(test_class, name)
            # Try to call the method to see if it returns ModelTestConfig
            try:
                # Create a temporary instance to call the method
                temp_instance = test_class.__new__(test_class)
                result = method(temp_instance)
                if isinstance(result, ModelTestConfig):
                    define_methods[name] = result
            except Exception:
                # If the method fails, skip it
                pass

    # Generate compile/load methods for each discovered definition
    for define_name, config in define_methods.items():
        model_name = define_name[8:]  # Remove '_define_' prefix

        # Create compile test method
        compile_method_name = f"test_{model_name}_compile"
        compile_method = WindowsCrossCompilationTestFramework.create_compile_test(
            config
        )
        compile_method.__name__ = compile_method_name
        compile_method.__doc__ = f"Step 1: Cross-compile {model_name} model on Linux"
        compile_method = requires_gpu()(compile_method)
        setattr(test_class, compile_method_name, compile_method)

        # Create load test method
        load_method_name = f"test_{model_name}_load"
        load_method = WindowsCrossCompilationTestFramework.create_load_test(config)
        load_method.__name__ = load_method_name
        load_method.__doc__ = f"Step 2: Load and test {model_name} model on Windows"
        load_method = requires_gpu()(load_method)
        setattr(test_class, load_method_name, load_method)

    return test_class


@auto_generate_tests
class TestAOTInductorWindowsCrossCompilation(TestCase):
    """
    Test class for AOT Inductor Windows cross-compilation.
    Define _define_* methods that return ModelTestConfig, and the decorator
    will auto-generate compile/load test methods.
    """

    def _define_simple(self):
        """Define the Simple model and its test configuration."""

        class Simple(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = torch.nn.Linear(10, 16)
                self.relu = torch.nn.ReLU()
                self.fc2 = torch.nn.Linear(16, 1)
                self.sigmoid = torch.nn.Sigmoid()

            def forward(self, x):
                x = self.fc1(x)
                x = self.relu(x)
                x = self.fc2(x)
                x = self.sigmoid(x)
                return x

        return ModelTestConfig(
            name="simple",
            model_class=Simple,
            example_inputs=(torch.randn(8, 10, device=GPU_TYPE),),
            dynamic_shapes={"x": {0: torch.export.Dim("batch", min=1, max=1024)}},
        )

    def _define_simple_cnn(self):
        """Define the SimpleCNN model and its test configuration."""

        class SimpleCNN(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 16, 3)
                self.relu = torch.nn.ReLU()
                self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))
                self.fc = torch.nn.Linear(16, 10)

            def forward(self, x):
                x = self.conv1(x)
                x = self.relu(x)
                x = self.pool(x)
                x = x.flatten(1)
                x = self.fc(x)
                return x

        return ModelTestConfig(
            name="simple_cnn",
            model_class=SimpleCNN,
            example_inputs=(torch.randn(2, 3, 32, 32, device=GPU_TYPE),),
            dynamic_shapes={"x": {0: torch.export.Dim("batch", min=1, max=16)}},
            rtol=1e-3,
            atol=1e-3,
        )

    def _define_transformer(self):
        """Define the SimpleTransformer model and its test configuration."""

        class SimpleTransformer(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.embedding = torch.nn.Linear(128, 256)
                self.attention = torch.nn.MultiheadAttention(256, 8, batch_first=True)
                self.norm1 = torch.nn.LayerNorm(256)
                self.ffn = torch.nn.Sequential(
                    torch.nn.Linear(256, 1024),
                    torch.nn.ReLU(),
                    torch.nn.Linear(1024, 256),
                )
                self.norm2 = torch.nn.LayerNorm(256)
                self.output = torch.nn.Linear(256, 10)

            def forward(self, x):
                # x shape: (batch, seq_len, input_dim)
                x = self.embedding(x)
                attn_out, _ = self.attention(x, x, x)
                x = self.norm1(x + attn_out)
                ffn_out = self.ffn(x)
                x = self.norm2(x + ffn_out)
                x = x.mean(dim=1)  # Global average pooling
                x = self.output(x)
                return x

        return ModelTestConfig(
            name="transformer",
            model_class=SimpleTransformer,
            example_inputs=(torch.randn(4, 16, 128, device=GPU_TYPE),),
            dynamic_shapes={"x": {0: torch.export.Dim("batch", min=1, max=32)}},
            rtol=1e-3,
            atol=1e-3,
        )


if __name__ == "__main__":
    import sys

    from torch._inductor.test_case import run_tests

    # Check for --package-dir argument and remove it before unittest sees it
    package_dir = None
    win_torch_lib_dir = None
    filtered_argv = []
    i = 0
    while i < len(sys.argv):
        if sys.argv[i] == "--package-dir":
            if i + 1 < len(sys.argv):
                package_dir = sys.argv[i + 1]
                i += 2  # Skip both --package-dir and its value
            else:
                print("Error: --package-dir requires a valid directory path")
                sys.exit(1)
        elif sys.argv[i].startswith("--package-dir="):
            package_dir = sys.argv[i].split("=", 1)[1]
            i += 1
        elif sys.argv[i] == "--win-torch-lib-dir":
            if i + 1 < len(sys.argv):
                win_torch_lib_dir = sys.argv[i + 1]
                i += 2  # Skip both --win-torch-lib-dir and its value
            else:
                print("Error: --win-torch-lib-dir requires a valid directory path")
                sys.exit(1)
        elif sys.argv[i].startswith("--win-torch-lib-dir="):
            win_torch_lib_dir = sys.argv[i].split("=", 1)[1]
            i += 1
        else:
            filtered_argv.append(sys.argv[i])
            i += 1

    # Validate and set the base path for package storage
    if package_dir:
        try:
            package_path = Path(package_dir)
            package_path.mkdir(parents=True, exist_ok=True)
            # Test write access
            test_file = package_path / ".test_write"
            test_file.touch()
            test_file.unlink()
            WindowsCrossCompilationTestFramework.set_base_path(package_path)
        except Exception:
            print("Error: --package-dir requires a valid directory path")
            sys.exit(1)

    # Set Windows torch libs path if provided (only needed for compile tests)
    if win_torch_lib_dir:
        WindowsCrossCompilationTestFramework.set_win_torch_libs_path(win_torch_lib_dir)

    # Update sys.argv to remove our custom arguments
    sys.argv = filtered_argv

    if HAS_GPU:
        run_tests(needs="filelock")

View File

@@ -394,6 +394,8 @@ JIT_EXECUTOR_TESTS = [
]

INDUCTOR_TESTS = [test for test in TESTS if test.startswith(INDUCTOR_TEST_PREFIX)]
# Cross-compilation tests require special setup and should only run in their dedicated workflow
WIN_CROSS_COMPILE_TESTS = ["inductor/test_aoti_cross_compile_windows"]
DISTRIBUTED_TESTS = [test for test in TESTS if test.startswith(DISTRIBUTED_TEST_PREFIX)]
TORCH_EXPORT_TESTS = [test for test in TESTS if test.startswith("export")]
AOT_DISPATCH_TESTS = [
@@ -1480,6 +1482,12 @@ def parse_args():
        action="store_true",
        help="exclude quantization tests",
    )
    parser.add_argument(
        "--exclude-win-cross-compile-tests",
        action="store_true",
        default=True,
        help="exclude Windows cross-compilation tests (they require MinGW and Windows libs)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
@@ -1656,6 +1664,11 @@ def get_selected_tests(options) -> list[str]:
    if options.exclude_quantization_tests:
        options.exclude.extend(QUANTIZATION_TESTS)

    # Exclude cross-compilation tests by default: they require special setup (MinGW, Windows libs)
    # and should only run in their dedicated CI workflow
    if options.exclude_win_cross_compile_tests:
        options.exclude.extend(WIN_CROSS_COMPILE_TESTS)

    # these tests failing in CUDA 11.6 temporary disabling. issue https://github.com/pytorch/pytorch/issues/75375
    if torch.version.cuda is not None:
        options.exclude.extend(["distributions/test_constraints"])

View File

@@ -68,6 +68,8 @@ _IS_LINUX = sys.platform.startswith("linux")
_IS_MACOS = sys.platform.startswith("darwin")
_IS_WINDOWS = sys.platform == "win32"

MINGW_GXX = "x86_64-w64-mingw32-g++"

SUBPROCESS_DECODE_ARGS = (locale.getpreferredencoding(),) if _IS_WINDOWS else ()

log = logging.getLogger(__name__)
@@ -333,9 +335,9 @@ def check_msvc_cl_language_id(compiler: str) -> None:

@functools.cache
def check_mingw_win32_flavor(compiler: str) -> None:
def check_mingw_win32_flavor(compiler: str) -> str:
    """
    Check if MinGW `compiler` exists and whether it is the win32 flavor (instead of posix flavor).
    Check that MinGW `compiler` exists and return its flavor (win32 or posix).
    """
    try:
        out = subprocess.check_output(
@@ -346,10 +348,22 @@
    except Exception as e:
        raise RuntimeError(f"Failed to run {compiler} -v") from e

    flavor: Optional[str] = None
    for line in out.splitlines():
        if "Thread model" in line:
            if line.split(":")[1].strip().lower() != "win32":
                raise RuntimeError(f"Compiler: {compiler} is not win32 flavor.")
            flavor = line.split(":")[1].strip().lower()
    if flavor is None:
        raise RuntimeError(
            f"Cannot determine the flavor of {compiler} (win32 or posix): no 'Thread model' line in `{compiler} -v` output"
        )
    if flavor not in ("win32", "posix"):
        raise RuntimeError(
            f"Only the win32 and posix flavors of {compiler} are supported; got {flavor}"
        )
    return flavor


def get_cpp_compiler() -> str:
@@ -358,7 +372,7 @@ def get_cpp_compiler() -> str:
        and sys.platform != "win32"
    ):
        # we're doing cross-compilation
        compiler = "x86_64-w64-mingw32-g++"
        compiler = MINGW_GXX
        if not config.aot_inductor.package_cpp_only:
            check_mingw_win32_flavor(compiler)
        return compiler
@@ -919,8 +933,6 @@ def _get_shared_cflags(do_link: bool) -> list[str]:
        # This causes undefined symbols to behave the same as linux
        return ["shared", "fPIC", "undefined dynamic_lookup"]

    flags = []
    if config.aot_inductor.cross_target_platform == "windows":
        flags.extend(["static-libstdc++", "static-libgcc", "fPIC"])
    if do_link:
        flags.append("shared")
@@ -961,6 +973,11 @@ def get_cpp_options(
        passthrough_args.append(" ".join(extra_flags))

    if config.aot_inductor.cross_target_platform == "windows":
        passthrough_args.extend(["-static-libstdc++", "-static-libgcc"])
        if check_mingw_win32_flavor(MINGW_GXX) == "posix":
            # The posix flavor links against libwinpthread; pin it static so the
            # resulting DLL does not need the MinGW runtime DLL at load time.
            passthrough_args.append("-Wl,-Bstatic -lwinpthread -Wl,-Bdynamic")

    return (
        definitions,
        include_dirs,
@@ -1133,12 +1150,14 @@ def _get_torch_related_args(
    assert config.aot_inductor.aoti_shim_library, (
        "'config.aot_inductor.aoti_shim_library' must be set when 'cross_target_platform' is 'windows'."
    )
    assert config.aot_inductor.aoti_shim_library_path, (
        "'config.aot_inductor.aoti_shim_library_path' must be set to the path of the AOTI shim library"
        " when 'cross_target_platform' is 'windows'."
    )
    libraries.append(config.aot_inductor.aoti_shim_library)
    libraries_dirs.append(config.aot_inductor.aoti_shim_library_path)
    if config.aot_inductor.cross_target_platform == "windows":
        assert config.aot_inductor.aoti_shim_library_path, (
            "'config.aot_inductor.aoti_shim_library_path' must be set to the path of the AOTI shim library"
            " when 'cross_target_platform' is 'windows'."
        )
        libraries_dirs.append(config.aot_inductor.aoti_shim_library_path)

    if _IS_WINDOWS:
        libraries.append("sleef")