[AOTI XPU] Support AOT Inductor for Intel GPU. (#140269)

This PR adds XPU support for AOT Inductor and reuses the corresponding unit tests.
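For context, a minimal end-to-end sketch of what this enables (the module, shapes, and device string below are illustrative; assumes a PyTorch build with XPU support):

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x) + 1

model = M().to("xpu")
example_inputs = (torch.randn(8, 16, device="xpu"),)

# AOT-compile to a shared library, then load it back; with this PR the "xpu"
# device string dispatches to the new AOTIModelContainerRunnerXpu.
so_path = torch._export.aot_compile(model, example_inputs)
compiled = torch._export.aot_load(so_path, device="xpu")
print(compiled(*example_inputs))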

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140269
Approved by: https://github.com/desertfire, https://github.com/EikanWang
ghstack dependencies: #140268

Co-authored-by: Bin Bao <binbao@meta.com>
Bin Bao
2024-12-09 13:59:19 -08:00
committed by PyTorch MergeBot
parent a1c6cf7e9f
commit 6680a83e89
17 changed files with 513 additions and 184 deletions

View File

@ -793,9 +793,12 @@ libtorch_python_xpu_sources = [
"torch/csrc/xpu/Event.cpp",
"torch/csrc/xpu/Module.cpp",
"torch/csrc/xpu/Stream.cpp",
"torch/csrc/inductor/aoti_runner/model_container_runner_xpu.cpp",
"torch/csrc/inductor/aoti_torch/shim_xpu.cpp",
]
libtorch_xpu_sources = libtorch_python_xpu_sources
libtorch_python_core_sources = [
"torch/csrc/DataLoader.cpp",
"torch/csrc/DeviceAccelerator.cpp",

View File

@ -1050,6 +1050,7 @@ endif()
if(USE_XPU)
list(APPEND Caffe2_XPU_SRCS ${GENERATED_CXX_TORCH_XPU})
list(APPEND Caffe2_XPU_SRCS ${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/shim_xpu.cpp)
add_library(torch_xpu ${Caffe2_XPU_SRCS})
torch_compile_options(torch_xpu) # see cmake/public/utils.cmake
target_compile_definitions(torch_xpu PRIVATE USE_XPU)

View File

@ -15,6 +15,7 @@ import torch._inductor.config
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
import torch.nn as nn
from torch._dynamo import config as dynamo_config
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.testing import rand_strided, same
from torch._dynamo.utils import counters
from torch._inductor import config
@ -40,15 +41,17 @@ from torch.testing._internal.common_utils import (
IS_MACOS,
IS_WINDOWS,
skipIfRocm,
skipIfXpu,
TEST_WITH_ROCM,
)
from torch.testing._internal.inductor_utils import GPU_TYPE
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda
from torch.testing._internal.triton_utils import HAS_GPU, requires_gpu
from torch.utils import _pytree as pytree
from torch.utils._triton import has_triton_tma
if HAS_CUDA:
if HAS_GPU:
import triton # @manual
from triton import language as tl
@ -198,7 +201,7 @@ class AOTInductorTestsTemplate:
with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
self.check_model(Model(self.device), example_inputs)
@requires_cuda
@requires_gpu
def test_duplicate_constant_folding(self):
class Model(torch.nn.Module):
def __init__(self, device):
@ -216,14 +219,21 @@ class AOTInductorTestsTemplate:
with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
self.check_model(Model(self.device), example_inputs)
@requires_cuda
@requires_gpu
def test_multi_device(self):
if self.device == "cpu" and GPU_TYPE == "xpu":
raise unittest.SkipTest(
"In this scenario, the test case will run XPU code in "
"AOTIModelContainerRunnerCpu, which is not reasonable,"
"See issue #140805"
)
class Model(torch.nn.Module):
def forward(self, x):
x = x + 1
x = x.cpu()
x = x + 2
x = x.cuda()
x = x.to(GPU_TYPE)
return x
example_inputs = (torch.randn(32, 64, device=self.device),)
@ -420,15 +430,8 @@ class AOTInductorTestsTemplate:
torch.randn(10, 10, device=self.device),
torch.randn(10, 10, device=self.device),
)
if self.device == "cuda":
ctx = torch.cuda.amp.autocast
elif self.device == "cpu":
ctx = torch.cpu.amp.autocast
else:
raise AssertionError("Unsupported device")
with config.patch({"fallback_random": True}):
with ctx():
with torch.amp.autocast(device_type=self.device):
self.check_model(fn, example_inputs)
def test_missing_output(self):
@ -621,8 +624,8 @@ class AOTInductorTestsTemplate:
)
def test_assert_async(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
@ -658,6 +661,7 @@ class AOTInductorTestsTemplate:
"FP8 is only supported on H100+",
)
@skipIfRocm # _scaled_mm_out_cuda is not compiled for ROCm platform
@skipIfXpu
def test_fp8(self):
# cuda only
if self.device != "cuda":
@ -682,16 +686,16 @@ class AOTInductorTestsTemplate:
dtype = torch.float16
a_scale = torch.Tensor([1.0]).to(device="cuda")
b_scale = torch.Tensor([1.0]).to(device="cuda")
input_bias = torch.rand(32, device="cuda", dtype=dtype)
a_scale = torch.Tensor([1.0]).to(device=GPU_TYPE)
b_scale = torch.Tensor([1.0]).to(device=GPU_TYPE)
input_bias = torch.rand(32, device=GPU_TYPE, dtype=dtype)
weight_shape = (32, 16)
weight = torch.rand(*weight_shape, device="cuda", dtype=dtype).T
weight = torch.rand(*weight_shape, device=GPU_TYPE, dtype=dtype).T
a_inverse_scale = 1 / a_scale
b_inverse_scale = 1 / b_scale
x_shape = (16, 16)
x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(torch.float8_e4m3fn)
x = torch.rand(*x_shape, device=GPU_TYPE, dtype=dtype).to(torch.float8_e4m3fn)
dim0_x = Dim("dim0_x", min=1, max=2048)
dynamic_shapes = ({0: dim0_x}, None, None, None, None)
self.check_model(
@ -705,9 +709,10 @@ class AOTInductorTestsTemplate:
"FP8 is only supported on H100+",
)
@skipIfRocm # _scaled_mm_out_cuda is not compiled for ROCm platform
@skipIfXpu
def test_fp8_view_of_param(self):
# cuda only
if self.device != "cuda":
if self.device != GPU_TYPE:
return
class Model(torch.nn.Module):
@ -1025,9 +1030,10 @@ class AOTInductorTestsTemplate:
)
self.check_model(Repro(), example_inputs)
@skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
def test_fallback_kernel_with_symexpr_output(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Module(torch.nn.Module):
def forward(self, q, k, v):
@ -1076,8 +1082,8 @@ class AOTInductorTestsTemplate:
torch.testing.assert_close(m(*inputs), aot_model(*inputs))
def test_large_grid(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
@ -1262,8 +1268,16 @@ class AOTInductorTestsTemplate:
return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
input1 = (torch.ones(3, 3), torch.ones(5), torch.ones(3, 3))
input2 = (torch.ones(10, 3), torch.ones(6), torch.ones(10, 3))
input1 = (
torch.ones(3, 3, device=self.device),
torch.ones(5, device=self.device),
torch.ones(3, 3, device=self.device),
)
input2 = (
torch.ones(10, 3, device=self.device),
torch.ones(6, device=self.device),
torch.ones(10, 3, device=self.device),
)
inputs = (input1, input2)
dynamic_shapes = {"x": {0: Dim("d")}, "y": {0: Dim("d1")}, "z": {0: Dim("d")}}
self.check_model_with_multiple_inputs(
@ -1390,6 +1404,9 @@ class AOTInductorTestsTemplate:
@unittest.skipIf(IS_MACOS, "no CUDA on Mac")
def test_zero_grid_with_backed_symbols(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Repro(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
@ -1412,7 +1429,7 @@ class AOTInductorTestsTemplate:
example_inputs,
dynamic_shapes=dynamic_shapes,
)
aot_inductor_module = AOTIRunnerUtil.load("cuda", so_path)
aot_inductor_module = AOTIRunnerUtil.load(self.device, so_path)
aot_inductor_module(*example_inputs)
# Re-run where dynamic dim size is 0.
@ -1543,8 +1560,8 @@ class AOTInductorTestsTemplate:
self.code_check_count(model, example_inputs, "empty_strided", 2)
def test_buffer_mutation_4(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
@ -1555,14 +1572,17 @@ class AOTInductorTestsTemplate:
)
def forward(self, x):
return x + self._tensor_constant0.to(torch.device(type="cuda", index=0))
return x + self._tensor_constant0.to(
torch.device(type=GPU_TYPE, index=0)
)
example_inputs = (
torch.randint(1, size=[38], dtype=torch.int64, device="cuda"),
torch.randint(1, size=[38], dtype=torch.int64, device=GPU_TYPE),
)
torch._export.aot_compile(Model(), example_inputs)
@skipCUDAIf(True, "Test for x86 backend")
@skipIfXpu
def test_buffer_mutation_and_force_mmap_weights(self):
class Model(nn.Module):
def __init__(self):
@ -1594,8 +1614,8 @@ class AOTInductorTestsTemplate:
@requires_multigpu()
def test_replicate_on_devices(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self, w1, w2):
@ -1614,29 +1634,33 @@ class AOTInductorTestsTemplate:
result_cpu = Model(w1, w2)(*inputs)
# Compile model with AOTInductor
with torch.cuda.device(0):
device_interface = get_interface_for_device(GPU_TYPE)
with device_interface.device(0):
so_path = AOTIRunnerUtil.compile(
model=Model(w1.cuda(0), w2.cuda(0)),
example_inputs=tuple(t.cuda(0) for t in inputs),
model=Model(
w1.to(torch.device(GPU_TYPE, 0)), w2.to(torch.device(GPU_TYPE, 0))
),
example_inputs=tuple(t.to(torch.device(GPU_TYPE, 0)) for t in inputs),
)
# Run model on cuda:N
for i in range(torch.cuda.device_count()):
with torch.cuda.device(i):
example_inputs = tuple(t.cuda(i) for t in inputs)
optimized = AOTIRunnerUtil.load("cuda", so_path)
result_cuda = optimized(*example_inputs)
self.assertTrue(same(result_cpu, result_cuda.cpu()))
# Run model on gpu:N
for i in range(device_interface.device_count()):
with device_interface.device(i):
example_inputs = tuple(t.to(torch.device(GPU_TYPE, i)) for t in inputs)
optimized = AOTIRunnerUtil.load(GPU_TYPE, so_path)
result_gpu = optimized(*example_inputs)
self.assertTrue(same(result_cpu, result_gpu.cpu()))
@requires_multigpu()
def test_on_cuda_device1(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
def test_on_gpu_device1(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
device_interface = get_interface_for_device(GPU_TYPE)
try:
torch.cuda.get_device_properties(1)
device_interface.get_device_properties(1)
except AssertionError:
raise unittest.SkipTest("CUDA device 1 is not available") from None
raise unittest.SkipTest("GPU device 1 is not available") from None
class Model(torch.nn.Module):
def __init__(self):
@ -1653,7 +1677,7 @@ class AOTInductorTestsTemplate:
x = self.sigmoid(x)
return x
device = "cuda:1"
device = f"{GPU_TYPE}:1"
model = Model().to(device)
example_inputs = (torch.randn(8, 10, device=device),)
expected = model(*example_inputs)
@ -1689,9 +1713,9 @@ class AOTInductorTestsTemplate:
)
@requires_multigpu()
def test_non_default_cuda_device(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
def test_non_default_gpu_device(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self, weight):
@ -1705,18 +1729,23 @@ class AOTInductorTestsTemplate:
inputs = (torch.randn(10, 10), torch.randn(10, 10))
result_cpu = Model(weight)(*inputs)
with torch.cuda.device(0), torch.no_grad():
result_cuda_0 = AOTIRunnerUtil.run(
"cuda", Model(weight.cuda(0)), tuple(t.cuda(0) for t in inputs)
device_interface = get_interface_for_device(GPU_TYPE)
with device_interface.device(0), torch.no_grad():
result_gpu_0 = AOTIRunnerUtil.run(
GPU_TYPE,
Model(weight.to(torch.device(GPU_TYPE, 0))),
tuple(t.to(torch.device(GPU_TYPE, 0)) for t in inputs),
)
with torch.cuda.device(1), torch.no_grad():
result_cuda_1 = AOTIRunnerUtil.run(
"cuda", Model(weight.cuda(1)), tuple(t.cuda(1) for t in inputs)
with device_interface.device(1), torch.no_grad():
result_gpu_1 = AOTIRunnerUtil.run(
GPU_TYPE,
Model(weight.to(torch.device(GPU_TYPE, 1))),
tuple(t.to(torch.device(GPU_TYPE, 1)) for t in inputs),
)
self.assertTrue(same(result_cpu, result_cuda_0.cpu()))
self.assertTrue(same(result_cpu, result_cuda_1.cpu()))
self.assertTrue(same(result_cpu, result_gpu_0.cpu()))
self.assertTrue(same(result_cpu, result_gpu_1.cpu()))
def test_reuse_kernel(self):
class Model(torch.nn.Module):
@ -1739,7 +1768,7 @@ class AOTInductorTestsTemplate:
model, example_inputs, atol=1e-4, rtol=1e-4
) # 1e-4 is the tol value used in pytorch/torch/_dynamo/utils.py
if self.device == "cuda":
if self.device == GPU_TYPE:
self.code_check_count(
model, example_inputs, "triton_poi_fused_sin_0 = loadKernel(", 1
)
@ -1809,8 +1838,8 @@ class AOTInductorTestsTemplate:
self.check_model(m, example_inputs, dynamic_shapes=dynamic_shapes)
def test_fake_tensor_device_validation(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
@ -1824,7 +1853,7 @@ class AOTInductorTestsTemplate:
# Export on CPU
exported_program = export(Model(), example_inputs)
# Compile exported model on CUDA
# Compile exported model on GPU
gm = exported_program.graph_module.to(self.device)
with self.assertRaisesRegex(ValueError, "Device mismatch between fake input"):
torch._inductor.aot_compile(
@ -1903,7 +1932,7 @@ class AOTInductorTestsTemplate:
def forward(self, x):
return torch.ops.aten.normal_functional.default(x)
self.check_model(Model(), (torch.empty(4, 1, 4, 4),))
self.check_model(Model(), (torch.empty(4, 1, 4, 4, device=self.device),))
def test_empty_graph(self):
class Model(torch.nn.Module):
@ -2101,8 +2130,8 @@ class AOTInductorTestsTemplate:
@common_utils.parametrize("dynamic", [False, True])
@common_utils.parametrize("autotune", [False, True])
def test_triton_kernel(self, grid_type, num_dims, dynamic, autotune):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
@ -2171,8 +2200,8 @@ class AOTInductorTestsTemplate:
self.check_model(Model(), (x, y), dynamic_shapes=dynamic_shapes)
def test_triton_kernel_dynamic_shape_with_div(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
@triton.jit
def pass_kernel(x, num):
@ -2195,8 +2224,8 @@ class AOTInductorTestsTemplate:
self.check_model(Model(), (x,), dynamic_shapes=dynamic_shapes)
def test_triton_kernel_reinterpret_view(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
@triton.jit
def pass_kernel(x, y):
@ -2224,8 +2253,8 @@ class AOTInductorTestsTemplate:
@common_utils.parametrize("dynamic", [False, True])
def test_triton_kernel_tma_descriptor_1d(self, dynamic):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
if not has_triton_tma():
raise unittest.SkipTest("requires Triton TMA")
@ -2280,8 +2309,8 @@ class AOTInductorTestsTemplate:
@common_utils.parametrize("dynamic", [False, True])
def test_triton_kernel_tma_descriptor_2d(self, dynamic):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
if not has_triton_tma():
raise unittest.SkipTest("requires Triton TMA")
@ -2340,8 +2369,8 @@ class AOTInductorTestsTemplate:
)
def test_triton_kernel_sympy_expr_arg(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def forward(self, x, e):
@ -2366,8 +2395,8 @@ class AOTInductorTestsTemplate:
def test_triton_kernel_sympy_fn_like_arg(self):
# This test should hit sympy.expand("sqrt") which crashes with
# AttributeError: 'function' object has no attribute 'expand'.
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def forward(self, x):
@ -2386,8 +2415,8 @@ class AOTInductorTestsTemplate:
self.check_model(Model(), inputs)
def test_triton_kernel_with_none_input(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
@ -2427,8 +2456,8 @@ class AOTInductorTestsTemplate:
self.check_model(Model(), example_inputs)
def test_triton_kernel_equal_to_1_arg(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def forward(self, x, y):
@ -2446,8 +2475,8 @@ class AOTInductorTestsTemplate:
@common_utils.parametrize("dynamic", [False, True])
def test_triton_kernel_equal_to_1_float_arg(self, dynamic):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def forward(self, x, y):
@ -2482,8 +2511,8 @@ class AOTInductorTestsTemplate:
)
def test_triton_kernel_weird_param_order(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
@ -2595,8 +2624,8 @@ class AOTInductorTestsTemplate:
self.check_model(Model(), inputs)
def test_repeated_user_defined_triton_kernel(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
@ -2865,8 +2894,8 @@ class AOTInductorTestsTemplate:
self.check_model(model, example_inputs)
def test_triton_kernel_extern_kernel_arg(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def forward(self, x, y):
@ -2876,15 +2905,15 @@ class AOTInductorTestsTemplate:
return out
example_inputs = (
torch.randn(4, 4, device="cuda"),
torch.randn(4, 4, device="cuda"),
torch.randn(4, 4, device=GPU_TYPE),
torch.randn(4, 4, device=GPU_TYPE),
)
self.check_model(Model(), example_inputs)
def test_triton_kernel_multi_output_arg(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def forward(self, x, y):
@ -2894,16 +2923,17 @@ class AOTInductorTestsTemplate:
return out
example_inputs = (
torch.randn(4, 4, device="cuda"),
torch.randn(4, 4, device="cuda"),
torch.randn(4, 4, device=GPU_TYPE),
torch.randn(4, 4, device=GPU_TYPE),
)
self.check_model(Model(), example_inputs)
# @skipIfXpu(msg="torch.xpu.memory_allocated not supported yet")
def test_triton_kernel_reinterpret_view_mem_leak(self):
# Check for memory leak when using user-defined Triton Kernel + AOTI.
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
@ -2917,22 +2947,23 @@ class AOTInductorTestsTemplate:
return out
example_inputs = (
torch.randn(4, 4, device="cuda"),
torch.randn(1, 16, device="cuda"),
torch.randn(4, 4, device=GPU_TYPE),
torch.randn(1, 16, device=GPU_TYPE),
)
so_path: str = AOTIRunnerUtil.compile(
Model(),
example_inputs,
)
aot_inductor_module = AOTIRunnerUtil.load("cuda", so_path)
aot_inductor_module = AOTIRunnerUtil.load(GPU_TYPE, so_path)
# Don't assign outputs to a variable b/c it will allocate GPU memory.
device: int = torch.cuda.current_device()
mem_before = torch.cuda.memory_allocated(device)
device_interface = get_interface_for_device(GPU_TYPE)
device: int = device_interface.current_device()
mem_before = device_interface.memory_allocated(device)
aot_inductor_module(*example_inputs)
aot_inductor_module(*example_inputs)
mem_after = torch.cuda.memory_allocated(device)
mem_after = device_interface.memory_allocated(device)
self.assertEqual(mem_before, mem_after)
actual = aot_inductor_module(*example_inputs)
@ -2943,8 +2974,8 @@ class AOTInductorTestsTemplate:
@common_utils.parametrize("dynamic", [False, True])
@common_utils.parametrize("autotuning", [False, True])
def test_triton_kernel_unbacked_symint_in_grid(self, dynamic, autotuning):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def forward(self, x, y, n_elements_tensor):
@ -2974,8 +3005,8 @@ class AOTInductorTestsTemplate:
return output
example_inputs = (
torch.randn(123, device="cuda"),
torch.randn(123, device="cuda"),
torch.randn(123, device=GPU_TYPE),
torch.randn(123, device=GPU_TYPE),
torch.tensor(123),
)
@ -2996,8 +3027,8 @@ class AOTInductorTestsTemplate:
@skipIfRocm # USE_MEM_EFF_ATTENTION was not enabled for build.
def test_scaled_dot_product_efficient_attention(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def forward(self, q, k, v, attn_bias):
@ -3006,10 +3037,10 @@ class AOTInductorTestsTemplate:
)[0]
example_inputs = (
torch.randn(4, 4, 36, 36, device="cuda"),
torch.randn(4, 4, 36, 36, device="cuda"),
torch.randn(4, 4, 36, 36, device="cuda"),
torch.randn(4, 4, 36, 36, device="cuda"),
torch.randn(4, 4, 36, 36, device=GPU_TYPE),
torch.randn(4, 4, 36, 36, device=GPU_TYPE),
torch.randn(4, 4, 36, 36, device=GPU_TYPE),
torch.randn(4, 4, 36, 36, device=GPU_TYPE),
)
self.check_model(Model(), example_inputs)
@ -3510,9 +3541,9 @@ class AOTInductorTestsTemplate:
kernel_calls = (
[
("triton_poi_fused_0", 1),
("aoti_torch_cuda_addmm_out", 2),
(f"aoti_torch_{GPU_TYPE}_addmm_out", 2),
]
if self.device == "cuda"
if self.device == GPU_TYPE
else [
("aoti_torch_cpu_addmm_out", 2),
]
@ -3573,8 +3604,8 @@ class AOTInductorTestsTemplate:
FileCheck().check_not(f"after_launch - {kernel_name}").run(code)
def test_aoti_debug_printer_user_defined_triton_kernel(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
@ -3650,8 +3681,8 @@ class AOTInductorTestsTemplate:
).run(code)
def test_aoti_debug_printer_sym_inputs(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
from torch.testing._internal.triton_utils import add_kernel
@ -3661,8 +3692,8 @@ class AOTInductorTestsTemplate:
def forward(self, x):
maxlen = max(x.item(), 512)
a = torch.ones(maxlen, device="cuda")
b = torch.ones(maxlen, device="cuda")
a = torch.ones(maxlen, device=GPU_TYPE)
b = torch.ones(maxlen, device=GPU_TYPE)
out = torch.zeros_like(a)
# unbacked symint in grid
add_kernel[(1, 1, maxlen)](a, b, out, maxlen, 32)
@ -3739,8 +3770,8 @@ class AOTInductorTestsTemplate:
@dynamo_config.patch({"capture_scalar_outputs": True})
def test_sym_i64_input_codegen(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
from torch.testing._internal.triton_utils import add_kernel
@ -3750,8 +3781,8 @@ class AOTInductorTestsTemplate:
def forward(self, x):
x_symint = x.item()
a = torch.ones(x_symint, device="cuda")
b = torch.ones(x_symint, device="cuda")
a = torch.ones(x_symint, device=GPU_TYPE)
b = torch.ones(x_symint, device=GPU_TYPE)
out = torch.zeros_like(a)
# unbacked symint in grid
add_kernel[(1, 1, x_symint)](a, b, out, x_symint, 32)
@ -3786,8 +3817,8 @@ class AOTInductorTestsTemplate:
self.check_model(Model(), example_inputs)
def test_none_args_aot_codegen(self):
if self.device != "cuda":
raise unittest.SkipTest("requires CUDA")
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
@triton.autotune(
configs=[
@ -4099,9 +4130,9 @@ def fail_cpu(is_skip=False):
)
def fail_cuda(is_skip=False):
def fail_gpu(suffixes: Tuple[str, ...], is_skip=False):
return TestFailure(
("cuda"),
suffixes,
is_skip=is_skip,
)
@ -4115,10 +4146,14 @@ CPU_TEST_FAILURES = {
}
# test_failures, xfail by default, set is_skip=True to skip
CUDA_TEST_FAILURES = {
GPU_TEST_FAILURES = {
# quantized unsupported for GPU
"test_quantized_linear": fail_cuda(),
"test_quanatized_int8_linear": fail_cuda(),
"test_quantized_linear": fail_gpu(("cuda", "xpu")),
"test_quanatized_int8_linear": fail_gpu(("cuda", "xpu")),
# No fft implementation for XPU yet.
"test_fft_c2c": fail_gpu(("xpu",)),
# No scaled_dot_product_efficient_attention implementation for XPU yet.
"test_scaled_dot_product_efficient_attention": fail_gpu(("xpu",)),
}
@ -4141,9 +4176,9 @@ copy_tests(
@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS")
class AOTInductorTestABICompatibleCuda(TestCase):
device = "cuda"
device_type = "cuda"
class AOTInductorTestABICompatibleGpu(TestCase):
device = GPU_TYPE
device_type = GPU_TYPE
check_model = check_model
check_model_with_multiple_inputs = check_model_with_multiple_inputs
code_check_count = code_check_count
@ -4153,14 +4188,14 @@ class AOTInductorTestABICompatibleCuda(TestCase):
copy_tests(
AOTInductorTestsTemplate,
AOTInductorTestABICompatibleCuda,
"cuda",
CUDA_TEST_FAILURES,
AOTInductorTestABICompatibleGpu,
GPU_TYPE,
GPU_TEST_FAILURES,
)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
# cpp_extension N/A in fbcode
if HAS_CUDA or sys.platform == "darwin":
if HAS_GPU or sys.platform == "darwin":
run_tests(needs="filelock")

View File

@ -92,11 +92,12 @@ class AOTIRunnerUtil:
temp_so_path, device == "cpu"
)
else:
return (
torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)
if device == "cpu"
else torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)
)
if device == "cpu":
return torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)
elif device == "xpu":
return torch._C._aoti.AOTIModelContainerRunnerXpu(so_path, 1, device)
else:
return torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)
@staticmethod
def load(device, so_path):

View File

@ -18,6 +18,7 @@ def alloc_tensor_by_stealing_from_void_ptr(
class AOTIModelContainerRunnerCpu: ...
class AOTIModelContainerRunnerCuda: ...
class AOTIModelContainerRunnerXpu: ...
# Defined in torch/csrc/inductor/aoti_package/pybind.cpp
class AOTIModelPackageLoader: ...

View File

@ -373,6 +373,9 @@ def aot_load(so_path: str, device: str) -> Callable:
runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1) # type: ignore[call-arg]
elif device == "cuda" or device.startswith("cuda:"):
runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) # type: ignore[assignment, call-arg]
elif device == "xpu" or device.startswith("xpu:"):
runner = torch._C._aoti.AOTIModelContainerRunnerXpu(so_path, 1, device) # type: ignore[assignment, call-arg]
else:
raise RuntimeError("Unsupported device " + device)

View File

@ -1531,7 +1531,9 @@ class AotCodeCompiler:
object_output_dir,
) = get_name_and_dir_from_output_file_path(consts_s)
object_build_options = CppTorchDeviceOptions(
device_type=device_type,
# The Intel compiler fails to compile this manually constructed assembly file.
# It is fine to use gcc to compile the .S to a .o and link it with the Intel compiler.
device_type=device_type if device_type != "xpu" else "cpu",
aot_mode=graph.aot_mode,
compile_only=True,
use_absolute_path=use_absolute_path,

View File

@ -20,7 +20,7 @@ from .. import config, ir
from ..utils import _align, ALIGN_BYTES, cache_on_self, normalize_name
from ..virtualized import V
from .aoti_hipify_utils import maybe_hipify_code_wrapper
from .common import IndentedBuffer, Kernel
from .common import get_device_op_overrides, IndentedBuffer, Kernel
from .cpp_utils import cexpr, DEVICE_TO_ATEN, DTYPE_TO_ATEN, DTYPE_TO_CPP
from .triton_utils import should_unwrap_unspec_arg
from .wrapper import (
@ -66,6 +66,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
self.custom_op_wrapper_loaded = False
# For GEMM kernels that must be initialized and are resolved at linking.
self.initialized_kernels: Dict[str, Kernel] = {}
self.device_codegen = get_device_op_overrides(self.device)
@staticmethod
def create(
@ -571,7 +572,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
)
for kernel in sorted(declare_kernel):
self.prefix.writeline(
maybe_hipify_code_wrapper(f" CUfunction {kernel}{{nullptr}};")
maybe_hipify_code_wrapper(
f" {self.device_codegen.cpp_kernel_type()} {kernel}{{nullptr}};"
)
)
for name, kernel in self.initialized_kernels.items():
assert hasattr(

View File

@ -0,0 +1,34 @@
#if !defined(C10_MOBILE) && !defined(ANDROID)
#include <torch/csrc/inductor/aoti_runner/model_container_runner_xpu.h>
namespace torch::inductor {
AOTIModelContainerRunnerXpu::AOTIModelContainerRunnerXpu(
const std::string& model_so_path,
size_t num_models,
const std::string& device_str,
const std::string& cubin_dir)
: AOTIModelContainerRunner(
model_so_path,
num_models,
device_str,
cubin_dir) {}
AOTIModelContainerRunnerXpu::~AOTIModelContainerRunnerXpu() = default;
std::vector<at::Tensor> AOTIModelContainerRunnerXpu::run(
std::vector<at::Tensor>& inputs) {
at::xpu::XPUStream xpu_stream = c10::xpu::getCurrentXPUStream();
return AOTIModelContainerRunner::run(
inputs, reinterpret_cast<AOTInductorStreamHandle>(&(xpu_stream.queue())));
}
std::vector<at::Tensor> AOTIModelContainerRunnerXpu::run_with_xpu_stream(
std::vector<at::Tensor>& inputs,
at::xpu::XPUStream xpu_stream) {
return AOTIModelContainerRunner::run(
inputs, reinterpret_cast<AOTInductorStreamHandle>(&(xpu_stream.queue())));
}
} // namespace torch::inductor
#endif

View File

@ -0,0 +1,30 @@
#if !defined(C10_MOBILE) && !defined(ANDROID)
#pragma once
#include <c10/xpu/XPUStream.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>
namespace torch::inductor {
// NOTICE: Following APIs are subject to change due to active development
// We provide NO BC guarantee for these APIs
class TORCH_API AOTIModelContainerRunnerXpu : public AOTIModelContainerRunner {
public:
// @param device_str: xpu device string, e.g. "xpu", "xpu:0"
AOTIModelContainerRunnerXpu(
const std::string& model_so_path,
size_t num_models = 1,
const std::string& device_str = "xpu",
const std::string& cubin_dir = "");
~AOTIModelContainerRunnerXpu();
std::vector<at::Tensor> run(std::vector<at::Tensor>& inputs);
std::vector<at::Tensor> run_with_xpu_stream(
std::vector<at::Tensor>& inputs,
at::xpu::XPUStream xpu_stream);
};
} // namespace torch::inductor
#endif

View File

@ -2,6 +2,9 @@
#ifdef USE_CUDA
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
#endif
#ifdef USE_XPU
#include <torch/csrc/inductor/aoti_runner/model_container_runner_xpu.h>
#endif
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
@ -51,6 +54,30 @@ void initAOTIRunnerBindings(PyObject* module) {
static_cast<void (AOTIModelContainerRunnerCuda::*)(
std::unordered_map<std::string, at::Tensor>&, bool, bool)>(
&AOTIModelContainerRunnerCuda::update_constant_buffer));
#endif
#ifdef USE_XPU
py::class_<AOTIModelContainerRunnerXpu>(m, "AOTIModelContainerRunnerXpu")
.def(py::init<const std::string&, int>())
.def(py::init<const std::string&, int, const std::string&>())
.def(py::init<
const std::string&,
int,
const std::string&,
const std::string&>())
.def("run", &AOTIModelContainerRunnerXpu::run)
.def("get_call_spec", &AOTIModelContainerRunnerXpu::get_call_spec)
.def(
"get_constant_names_to_original_fqns",
&AOTIModelContainerRunnerXpu::getConstantNamesToOriginalFQNs)
.def(
"get_constant_names_to_dtypes",
&AOTIModelContainerRunnerXpu::getConstantNamesToDtypes)
.def(
"update_constant_buffer",
static_cast<void (AOTIModelContainerRunnerXpu::*)(
std::unordered_map<std::string, at::Tensor>&, bool, bool)>(
&AOTIModelContainerRunnerXpu::update_constant_buffer));
#endif
m.def(

View File

@ -30,7 +30,27 @@ using DeviceStreamType = cudaStream_t;
} // namespace torch::aot_inductor
#else // !USE_CUDA
#elif defined(USE_XPU)
#include <level_zero/ze_api.h>
#include <sycl/sycl.hpp>
#include <sstream>
#define AOTI_RUNTIME_DEVICE_CHECK(EXPR) \
do { \
const ze_result_t status = EXPR; \
if (status != ZE_RESULT_SUCCESS) { \
std::stringstream ss; \
ss << "L0 runtime error: " << std::hex << std::uppercase << status; \
throw std::runtime_error(ss.str()); \
} \
} while (0)
namespace torch::aot_inductor {
using DeviceStreamType = sycl::queue*;
} // namespace torch::aot_inductor
#else
#define AOTI_RUNTIME_DEVICE_CHECK(EXPR) \
bool ok = EXPR; \

View File

@ -15,7 +15,11 @@
// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule
// applies to other files under torch/csrc/inductor/aoti_runtime/.
#include <torch/csrc/inductor/aoti_runtime/device_utils.h>
#ifdef USE_XPU
#include <torch/csrc/inductor/aoti_runtime/utils_xpu.h>
#else
#include <torch/csrc/inductor/aoti_runtime/utils.h>
#endif
#define AOTI_RUNTIME_CHECK(EXPR, MSG) \
do { \
@ -44,13 +48,27 @@ namespace {
#ifdef USE_CUDA
using CUDAPtr = std::unique_ptr<void, std::function<void(void*)>>;
using GPUPtr = std::unique_ptr<void, std::function<void(void*)>>;
CUDAPtr RAII_cudaMalloc(size_t num_bytes) {
GPUPtr RAII_gpuMalloc(size_t num_bytes) {
void* data_ptr;
AOTI_RUNTIME_DEVICE_CHECK(cudaMalloc((void**)&data_ptr, num_bytes));
auto deleter = [](void* ptr) { AOTI_RUNTIME_DEVICE_CHECK(cudaFree(ptr)); };
return CUDAPtr(data_ptr, deleter);
return GPUPtr(data_ptr, deleter);
}
#endif // USE_CUDA
#ifdef USE_XPU
using GPUPtr = std::unique_ptr<void, std::function<void(void*)>>;
GPUPtr RAII_gpuMalloc(size_t num_bytes) {
sycl::queue* queue_ptr = nullptr;
aoti_torch_get_current_sycl_queue((void**)&queue_ptr);
void* data_ptr = sycl::malloc_device(num_bytes, *queue_ptr);
auto deleter = [queue_ptr](void* ptr) { sycl::free(ptr, *queue_ptr); };
return GPUPtr(data_ptr, deleter);
}
#endif // USE_XPU
@ -74,7 +92,7 @@ inline void parse_device_str(
const std::string& device_str,
int32_t& device_type,
int32_t& device_idx) {
std::regex re("(cpu|cuda)(:([0-9]+))?");
std::regex re("(cpu|cuda|xpu)(:([0-9]+))?");
std::smatch sm;
bool matched = std::regex_match(device_str, sm, re);
AOTI_RUNTIME_CHECK(matched, "Invalid device: " + device_str);
@ -83,6 +101,10 @@ inline void parse_device_str(
device_type = aoti_torch_device_type_cpu();
} else if (sm[1].str() == "cuda") {
device_type = aoti_torch_device_type_cuda();
#ifdef USE_XPU
} else if (sm[1].str() == "xpu") {
device_type = aoti_torch_device_type_xpu();
#endif
} else {
AOTI_RUNTIME_CHECK(false, "Invalid device: " + device_str);
}
@ -124,6 +146,13 @@ class AOTInductorModelBase {
AOTI_RUNTIME_DEVICE_CHECK(cudaSetDevice(device_idx_));
}
#endif // USE_CUDA
#ifdef USE_XPU
if (device_idx_ == -1) {
aoti_torch_get_current_xpu_device(&device_idx_);
} else {
aoti_torch_set_current_xpu_device(device_idx_);
}
#endif // USE_XPU
}
// NOLINTNEXTLINE(modernize-use-equals-default)
@ -137,6 +166,12 @@ class AOTInductorModelBase {
}
}
#endif // USE_CUDA
#ifdef USE_XPU
if (run_finished_) {
(*run_finished_)->wait_and_throw();
delete *run_finished_;
}
#endif // USE_XPU
}
AOTInductorModelBase(AOTInductorModelBase&&) = delete;
@ -160,14 +195,25 @@ class AOTInductorModelBase {
AOTI_RUNTIME_DEVICE_CHECK(cudaEventCreate(&run_finished));
run_finished_.emplace(run_finished);
}
#elif defined(USE_XPU)
if (run_finished_) {
(*run_finished_)->wait_and_throw();
delete *run_finished_;
run_finished_.reset();
}
#else // !USE_CUDA && !USE_XPU
run_finished_ = false;
#endif
auto* model = static_cast<Model*>(this);
model->run_impl(input_handles, output_handles, stream, proxy_executor);
#ifdef USE_CUDA
AOTI_RUNTIME_DEVICE_CHECK(cudaEventRecord(*run_finished_, stream));
#else // !USE_CUDA
run_finished_ = false;
auto* model = static_cast<Model*>(this);
model->run_impl(input_handles, output_handles, stream, proxy_executor);
#elif defined(USE_XPU)
run_finished_ = std::make_optional<sycl::event*>(new sycl::event(
static_cast<sycl::queue*>(stream)->ext_oneapi_submit_barrier()));
#else // !USE_CUDA && !USE_XPU
run_finished_ = true;
#endif // USE_CUDA
}
@ -182,9 +228,15 @@ class AOTInductorModelBase {
AOTI_RUNTIME_DEVICE_CHECK(cudaEventCreate(&run_finished));
run_finished_.emplace(run_finished);
}
#else // USE_CUDA
#elif defined(USE_XPU)
if (run_finished_) {
(*run_finished_)->wait_and_throw();
delete *run_finished_;
run_finished_.reset();
}
#else // !USE_CUDA && !USE_XPU
run_finished_ = false;
#endif // USE_CUDA
#endif
auto* model = static_cast<Model*>(this);
auto folded_constants =
@ -192,7 +244,13 @@ class AOTInductorModelBase {
#ifdef USE_CUDA
AOTI_RUNTIME_DEVICE_CHECK(cudaEventRecord(*run_finished_, stream));
#else // USE_CUDA
#elif defined(USE_XPU)
// sycl::queue* queue_ptr = nullptr;
// aoti_torch_get_current_sycl_queue((void**)&queue_ptr);
run_finished_ = std::make_optional<sycl::event*>(new sycl::event(
static_cast<sycl::queue*>(stream)->ext_oneapi_submit_barrier()));
#else // !USE_CUDA && !USE_XPU
run_finished_ = true;
#endif // USE_CUDA
@ -206,9 +264,9 @@ class AOTInductorModelBase {
std::vector<size_t> constants_internal_offset(num_constants);
if (device_type_ != aoti_torch_device_type_cpu()) {
size_t blob_size = 0;
compute_cuda_constant_blob(blob_size, constants_internal_offset);
#ifdef USE_CUDA
constant_blob_ = RAII_cudaMalloc(blob_size);
compute_gpu_constant_blob(blob_size, constants_internal_offset);
#if defined(USE_CUDA) || defined(USE_XPU)
constant_blob_ = RAII_gpuMalloc(blob_size);
#endif
}
if (!include_weights) {
@ -218,7 +276,7 @@ class AOTInductorModelBase {
size_t bytes_read = 0;
for (size_t i = 0; i < num_constants; i++) {
bool from_folded = this->constant_from_folded(i);
#ifndef USE_CUDA
#if not defined(USE_XPU) && not defined(USE_CUDA)
if (from_folded) {
// We do not reallocate and copy for CPU.
continue;
@ -284,8 +342,8 @@ class AOTInductorModelBase {
}
}
#ifdef USE_CUDA
CUDAPtr&& release_constant_blob() {
#if defined(USE_CUDA) || defined(USE_XPU)
GPUPtr&& release_constant_blob() {
return std::move(constant_blob_);
}
#endif
@ -303,17 +361,26 @@ class AOTInductorModelBase {
size_t bytes_read,
size_t data_size,
bool skip_copy) {
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_XPU)
auto* constants_ptr = static_cast<uint8_t*>(constant_blob_.get());
uint8_t* internal_ptr = constants_ptr + constant_offset;
// Copy data to GPU memory
// TODO: Handle shared storage case.
if (!skip_copy) {
#ifdef USE_XPU
sycl::queue* queue_ptr = nullptr;
aoti_torch_get_current_sycl_queue((void**)&queue_ptr);
queue_ptr
->memcpy(internal_ptr, _get_constants_start() + bytes_read, data_size)
.wait();
#else
AOTI_RUNTIME_DEVICE_CHECK(cudaMemcpy(
internal_ptr,
_get_constants_start() + bytes_read,
data_size,
cudaMemcpyHostToDevice));
#endif
}
return internal_ptr;
@ -324,10 +391,10 @@ class AOTInductorModelBase {
#endif // USE_CUDA
}
void compute_cuda_constant_blob(
void compute_gpu_constant_blob(
size_t& blob_size,
std::vector<size_t>& constants_internal_offset) {
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_XPU)
size_t num_constants = this->num_constants();
// Compute required blob size with 64-alignment if on GPU.
blob_size = 0;
@ -477,7 +544,15 @@ class AOTInductorModelBase {
throw std::runtime_error(
std::string("The model did not finish successfully. Error: ") +
cudaGetErrorString(cudaGetLastError()));
#else // !USE_CUDA
#elif defined(USE_XPU)
if (!run_finished_) {
throw std::runtime_error{"Model XPU event was not initialized"};
}
using namespace sycl::info;
return (*run_finished_)->get_info<event::command_execution_status>() ==
event_command_status::complete;
#else // !USE_CUDA && !USE_XPU
return run_finished_;
#endif // USE_CUDA
}
@ -491,6 +566,12 @@ class AOTInductorModelBase {
AOTI_RUNTIME_DEVICE_CHECK(cudaEventSynchronize(*run_finished_));
#endif // USE_CUDA
#ifdef USE_XPU
if (!run_finished_) {
throw std::runtime_error{"Model event was not initialized"};
}
(*run_finished_)->wait_and_throw();
#endif
}
protected:
@ -562,10 +643,11 @@ class AOTInductorModelBase {
std::shared_ptr<ConstantMap> constants_map_;
std::shared_ptr<std::vector<ConstantHandle>> constants_;
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_XPU)
// Holds the blob storage for constants' at::Tensor for CUDA.
CUDAPtr constant_blob_;
GPUPtr constant_blob_;
#endif // USE_CUDA
#ifdef USE_MMAP_SELF
uint8_t* self_mmap = NULL;
#endif
@ -582,6 +664,8 @@ class AOTInductorModelBase {
// AOTModelContainer can re-use this instance.
#ifdef USE_CUDA
std::optional<cudaEvent_t> run_finished_;
#elif defined(USE_XPU)
std::optional<sycl::event*> run_finished_;
#else // !USE_CUDA
bool run_finished_{};
#endif

View File

@ -51,12 +51,11 @@ class AOTInductorModelContainer {
for (size_t i = 0; i < num_outputs; i++) {
output_names_.emplace_back(model->output_name(static_cast<int64_t>(i)));
}
model->load_constants();
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_XPU)
constant_blob_ = model->release_constant_blob();
constants_internal_offset_.resize(model->num_constants());
model->compute_cuda_constant_blob(blob_size_, constants_internal_offset_);
model->compute_gpu_constant_blob(blob_size_, constants_internal_offset_);
#endif
for (auto& model : models_) {
@ -276,7 +275,7 @@ class AOTInductorModelContainer {
continue;
}
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_XPU)
AtenTensorHandle tensor;
if (_should_skip_update(idx) && use_inactive) {
tensor = original_constants_map->find(constant_name)->second.get();
@ -293,13 +292,20 @@ class AOTInductorModelContainer {
int64_t constant_size;
aoti_torch_get_data_ptr(tensor, &user_constant_ptr);
aoti_torch_get_storage_size(tensor, &constant_size);
#ifdef USE_XPU
sycl::queue* queue_ptr = nullptr;
aoti_torch_get_current_sycl_queue((void**)&queue_ptr);
queue_ptr
->memcpy(internal_constants_ptr, user_constant_ptr, constant_size)
.wait();
#else
AOTI_RUNTIME_DEVICE_CHECK(cudaMemcpy(
internal_constants_ptr,
user_constant_ptr,
constant_size,
cudaMemcpyDefault));
#endif
// Generate Tensor from container handled blob.
// We extract stride and offset from provided Tensor since we do not
// guarantee that the tensor is contiguous.
@ -317,7 +323,11 @@ class AOTInductorModelContainer {
stride,
offset,
models_[0]->constant_dtype(idx),
#ifdef USE_XPU
aoti_torch_device_type_xpu(),
#else
aoti_torch_device_type_cuda(),
#endif
device_idx,
&tensor_handle));
#else // USE_CUDA
@ -397,10 +407,10 @@ class AOTInductorModelContainer {
const char* in_spec_;
const char* out_spec_;
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_XPU)
// Holds the blob storage for constants' at::Tensor for CUDA.
CUDAPtr constant_blob_;
CUDAPtr constant_blob_secondary_;
GPUPtr constant_blob_;
GPUPtr constant_blob_secondary_;
// Let's place this within USE_CUDA at the moment before we fully support
// update for CPU cases.
@ -461,14 +471,14 @@ class AOTInductorModelContainer {
// make sure no one is executing the model.
std::shared_mutex model_exec_mutex_;
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_XPU)
void* get_constant_blob_ptr(bool get_inactive) {
if ((get_inactive && use_secondary_) ||
(!get_inactive && !use_secondary_)) {
return constant_blob_.get();
} else {
if (!constant_blob_secondary_) {
constant_blob_secondary_ = RAII_cudaMalloc(blob_size_);
constant_blob_secondary_ = RAII_gpuMalloc(blob_size_);
}
return constant_blob_secondary_.get();
}

View File

@ -15,6 +15,11 @@ inline void delete_xpu_guard(void* ptr) {
aoti_torch_delete_xpu_guard(reinterpret_cast<XPUGuardHandle>(ptr)));
}
inline void delete_xpu_stream_guard(void* ptr) {
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_xpu_stream_guard(
reinterpret_cast<XPUStreamGuardHandle>(ptr)));
}
class AOTIXpuGuard {
public:
AOTIXpuGuard(int32_t device_index) : guard_(nullptr, delete_xpu_guard) {
@ -32,5 +37,20 @@ class AOTIXpuGuard {
private:
std::unique_ptr<XPUGuardOpaque, DeleterFnPtr> guard_;
};
class AOTIXpuStreamGuard {
public:
AOTIXpuStreamGuard(void* stream, int32_t device_index)
: guard_(nullptr, delete_xpu_stream_guard) {
XPUStreamGuardHandle ptr = nullptr;
AOTI_TORCH_ERROR_CODE_CHECK(
aoti_torch_create_xpu_stream_guard(stream, device_index, &ptr));
guard_.reset(ptr);
}
private:
std::unique_ptr<XPUStreamGuardOpaque, DeleterFnPtr> guard_;
};
} // namespace torch::aot_inductor
#endif // USE_XPU

View File

@ -25,9 +25,26 @@ aoti_torch_xpu_guard_set_index(XPUGuardHandle guard, int32_t device_index);
struct XPUStreamGuardOpaque;
using XPUStreamGuardHandle = XPUStreamGuardOpaque*;
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_xpu_stream_guard(
void* stream,
int32_t device_index,
XPUStreamGuardHandle* ret_guard // returns new reference
);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_delete_xpu_stream_guard(XPUStreamGuardHandle guard);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_current_xpu_stream(int32_t device_index, void** ret_stream);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_current_xpu_device(int32_t* device_index);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_set_current_xpu_device(const int32_t& device_index);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_current_sycl_queue(void** ret);
#ifdef __cplusplus
} // extern "C"
#endif

View File

@ -30,8 +30,46 @@ AOTITorchError aoti_torch_xpu_guard_set_index(
{ reinterpret_cast<at::DeviceGuard*>(guard)->set_index(device_index); });
}
AOTITorchError aoti_torch_create_xpu_stream_guard(
void* stream,
int32_t device_index,
XPUStreamGuardHandle* ret_guard) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
assert(stream);
at::StreamGuard* guard =
new at::StreamGuard(at::xpu::getStreamFromExternal(
static_cast<sycl::queue*>(stream), device_index)
.unwrap());
*ret_guard = reinterpret_cast<XPUStreamGuardHandle>(guard);
});
}
AOTITorchError aoti_torch_delete_xpu_stream_guard(XPUStreamGuardHandle guard) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
{ delete reinterpret_cast<at::StreamGuard*>(guard); });
}
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_current_xpu_stream(int32_t device_index, void** ret_stream) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *ret_stream = &(at::xpu::getCurrentXPUStream(device_index).queue()); });
}
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_current_xpu_device(int32_t* device_index) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *device_index = static_cast<int32_t>(c10::xpu::current_device()); });
}
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_set_current_xpu_device(const int32_t& device_index) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
{ c10::xpu::set_device(static_cast<int8_t>(device_index)); });
}
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_current_sycl_queue(void** ret) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
int32_t device_index = static_cast<int32_t>(c10::xpu::current_device());
*ret = &(at::xpu::getCurrentXPUStream(device_index).queue());
});
}