[AOTI XPU] Support AOT Inductor for Intel GPU. (#140269)
This PR adds XPU support for AOT Inductor and reuses the corresponding unit tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140269
Approved by: https://github.com/desertfire, https://github.com/EikanWang
ghstack dependencies: #140268
Co-authored-by: Bin Bao <binbao@meta.com>
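For orientation, a minimal sketch of the workflow this PR enables is shown below. It is a hedged example, not part of the commit: it assumes an XPU-enabled build, and torch._export.aot_compile and torch._inductor.aot_load are the same private entry points exercised by the test and aot_load changes in this diff, so their signatures may change between releases.

    import torch

    class Small(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1

    # Compile ahead of time for the Intel GPU; aot_compile returns the path of the
    # generated shared library. Requires a PyTorch build with XPU support.
    example_inputs = (torch.randn(8, 16, device="xpu"),)
    so_path = torch._export.aot_compile(Small().to("xpu"), example_inputs)

    # Load the compiled artifact with the new XPU runner and run it.
    compiled = torch._inductor.aot_load(so_path, device="xpu")
    out = compiled(*example_inputs)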
Committed by: PyTorch MergeBot
Parent: a1c6cf7e9f
Commit: 6680a83e89
@ -793,9 +793,12 @@ libtorch_python_xpu_sources = [
    "torch/csrc/xpu/Event.cpp",
    "torch/csrc/xpu/Module.cpp",
    "torch/csrc/xpu/Stream.cpp",
    "torch/csrc/inductor/aoti_runner/model_container_runner_xpu.cpp",
    "torch/csrc/inductor/aoti_torch/shim_xpu.cpp",
]

libtorch_xpu_sources = libtorch_python_xpu_sources

libtorch_python_core_sources = [
    "torch/csrc/DataLoader.cpp",
    "torch/csrc/DeviceAccelerator.cpp",

@ -1050,6 +1050,7 @@ endif()

if(USE_XPU)
  list(APPEND Caffe2_XPU_SRCS ${GENERATED_CXX_TORCH_XPU})
  list(APPEND Caffe2_XPU_SRCS ${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/shim_xpu.cpp)
  add_library(torch_xpu ${Caffe2_XPU_SRCS})
  torch_compile_options(torch_xpu) # see cmake/public/utils.cmake
  target_compile_definitions(torch_xpu PRIVATE USE_XPU)
@ -15,6 +15,7 @@ import torch._inductor.config
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
import torch.nn as nn
from torch._dynamo import config as dynamo_config
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.testing import rand_strided, same
from torch._dynamo.utils import counters
from torch._inductor import config

@ -40,15 +41,17 @@ from torch.testing._internal.common_utils import (
    IS_MACOS,
    IS_WINDOWS,
    skipIfRocm,
    skipIfXpu,
    TEST_WITH_ROCM,
)
from torch.testing._internal.inductor_utils import GPU_TYPE
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda
from torch.testing._internal.triton_utils import HAS_GPU, requires_gpu
from torch.utils import _pytree as pytree
from torch.utils._triton import has_triton_tma


if HAS_CUDA:
if HAS_GPU:
    import triton  # @manual
    from triton import language as tl
@ -198,7 +201,7 @@ class AOTInductorTestsTemplate:
|
||||
with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
|
||||
self.check_model(Model(self.device), example_inputs)
|
||||
|
||||
@requires_cuda
|
||||
@requires_gpu
|
||||
def test_duplicate_constant_folding(self):
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self, device):
|
||||
@ -216,14 +219,21 @@ class AOTInductorTestsTemplate:
|
||||
with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
|
||||
self.check_model(Model(self.device), example_inputs)
|
||||
|
||||
@requires_cuda
|
||||
@requires_gpu
|
||||
def test_multi_device(self):
|
||||
if self.device == "cpu" and GPU_TYPE == "xpu":
|
||||
raise unittest.SkipTest(
|
||||
"In this scenario, the test case will run XPU code in "
|
||||
"AOTIModelContainerRunnerCpu, which is not reasonable,"
|
||||
"See issue #140805"
|
||||
)
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
x = x + 1
|
||||
x = x.cpu()
|
||||
x = x + 2
|
||||
x = x.cuda()
|
||||
x = x.to(GPU_TYPE)
|
||||
return x
|
||||
|
||||
example_inputs = (torch.randn(32, 64, device=self.device),)
|
||||
@ -420,15 +430,8 @@ class AOTInductorTestsTemplate:
|
||||
torch.randn(10, 10, device=self.device),
|
||||
torch.randn(10, 10, device=self.device),
|
||||
)
|
||||
if self.device == "cuda":
|
||||
ctx = torch.cuda.amp.autocast
|
||||
elif self.device == "cpu":
|
||||
ctx = torch.cpu.amp.autocast
|
||||
else:
|
||||
raise AssertionError("Unsupported device")
|
||||
|
||||
with config.patch({"fallback_random": True}):
|
||||
with ctx():
|
||||
with torch.amp.autocast(device_type=self.device):
|
||||
self.check_model(fn, example_inputs)
|
||||
|
||||
def test_missing_output(self):
|
||||
@ -621,8 +624,8 @@ class AOTInductorTestsTemplate:
|
||||
)
|
||||
|
||||
def test_assert_async(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU_TYPE")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -658,6 +661,7 @@ class AOTInductorTestsTemplate:
|
||||
"FP8 is only supported on H100+",
|
||||
)
|
||||
@skipIfRocm # _scaled_mm_out_cuda is not compiled for ROCm platform
|
||||
@skipIfXpu
|
||||
def test_fp8(self):
|
||||
# cuda only
|
||||
if self.device != "cuda":
|
||||
@ -682,16 +686,16 @@ class AOTInductorTestsTemplate:
|
||||
|
||||
dtype = torch.float16
|
||||
|
||||
a_scale = torch.Tensor([1.0]).to(device="cuda")
|
||||
b_scale = torch.Tensor([1.0]).to(device="cuda")
|
||||
input_bias = torch.rand(32, device="cuda", dtype=dtype)
|
||||
a_scale = torch.Tensor([1.0]).to(device=GPU_TYPE)
|
||||
b_scale = torch.Tensor([1.0]).to(device=GPU_TYPE)
|
||||
input_bias = torch.rand(32, device=GPU_TYPE, dtype=dtype)
|
||||
weight_shape = (32, 16)
|
||||
weight = torch.rand(*weight_shape, device="cuda", dtype=dtype).T
|
||||
weight = torch.rand(*weight_shape, device=GPU_TYPE, dtype=dtype).T
|
||||
a_inverse_scale = 1 / a_scale
|
||||
b_inverse_scale = 1 / b_scale
|
||||
|
||||
x_shape = (16, 16)
|
||||
x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(torch.float8_e4m3fn)
|
||||
x = torch.rand(*x_shape, device=GPU_TYPE, dtype=dtype).to(torch.float8_e4m3fn)
|
||||
dim0_x = Dim("dim0_x", min=1, max=2048)
|
||||
dynamic_shapes = ({0: dim0_x}, None, None, None, None)
|
||||
self.check_model(
|
||||
@ -705,9 +709,10 @@ class AOTInductorTestsTemplate:
|
||||
"FP8 is only supported on H100+",
|
||||
)
|
||||
@skipIfRocm # _scaled_mm_out_cuda is not compiled for ROCm platform
|
||||
@skipIfXpu
|
||||
def test_fp8_view_of_param(self):
|
||||
# cuda only
|
||||
if self.device != "cuda":
|
||||
if self.device != GPU_TYPE:
|
||||
return
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
@ -1025,9 +1030,10 @@ class AOTInductorTestsTemplate:
|
||||
)
|
||||
self.check_model(Repro(), example_inputs)
|
||||
|
||||
@skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
|
||||
def test_fallback_kernel_with_symexpr_output(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, q, k, v):
|
||||
@ -1076,8 +1082,8 @@ class AOTInductorTestsTemplate:
|
||||
torch.testing.assert_close(m(*inputs), aot_model(*inputs))
|
||||
|
||||
def test_large_grid(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -1262,8 +1268,16 @@ class AOTInductorTestsTemplate:
|
||||
|
||||
return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
|
||||
|
||||
input1 = (torch.ones(3, 3), torch.ones(5), torch.ones(3, 3))
|
||||
input2 = (torch.ones(10, 3), torch.ones(6), torch.ones(10, 3))
|
||||
input1 = (
|
||||
torch.ones(3, 3, device=self.device),
|
||||
torch.ones(5, device=self.device),
|
||||
torch.ones(3, 3, device=self.device),
|
||||
)
|
||||
input2 = (
|
||||
torch.ones(10, 3, device=self.device),
|
||||
torch.ones(6, device=self.device),
|
||||
torch.ones(10, 3, device=self.device),
|
||||
)
|
||||
inputs = (input1, input2)
|
||||
dynamic_shapes = {"x": {0: Dim("d")}, "y": {0: Dim("d1")}, "z": {0: Dim("d")}}
|
||||
self.check_model_with_multiple_inputs(
|
||||
@ -1390,6 +1404,9 @@ class AOTInductorTestsTemplate:
|
||||
|
||||
@unittest.skipIf(IS_MACOS, "no CUDA on Mac")
|
||||
def test_zero_grid_with_backed_symbols(self):
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class Repro(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@ -1412,7 +1429,7 @@ class AOTInductorTestsTemplate:
|
||||
example_inputs,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
)
|
||||
aot_inductor_module = AOTIRunnerUtil.load("cuda", so_path)
|
||||
aot_inductor_module = AOTIRunnerUtil.load(self.device, so_path)
|
||||
aot_inductor_module(*example_inputs)
|
||||
|
||||
# Re-run where dynamic dim size is 0.
|
||||
@ -1543,8 +1560,8 @@ class AOTInductorTestsTemplate:
|
||||
self.code_check_count(model, example_inputs, "empty_strided", 2)
|
||||
|
||||
def test_buffer_mutation_4(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -1555,14 +1572,17 @@ class AOTInductorTestsTemplate:
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return x + self._tensor_constant0.to(torch.device(type="cuda", index=0))
|
||||
return x + self._tensor_constant0.to(
|
||||
torch.device(type=GPU_TYPE, index=0)
|
||||
)
|
||||
|
||||
example_inputs = (
|
||||
torch.randint(1, size=[38], dtype=torch.int64, device="cuda"),
|
||||
torch.randint(1, size=[38], dtype=torch.int64, device=GPU_TYPE),
|
||||
)
|
||||
torch._export.aot_compile(Model(), example_inputs)
|
||||
|
||||
@skipCUDAIf(True, "Test for x86 backend")
|
||||
@skipIfXpu
|
||||
def test_buffer_mutation_and_force_mmap_weights(self):
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
@ -1594,8 +1614,8 @@ class AOTInductorTestsTemplate:
|
||||
|
||||
@requires_multigpu()
|
||||
def test_replicate_on_devices(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self, w1, w2):
|
||||
@ -1614,29 +1634,33 @@ class AOTInductorTestsTemplate:
|
||||
result_cpu = Model(w1, w2)(*inputs)
|
||||
|
||||
# Compile model with AOTInductor
|
||||
with torch.cuda.device(0):
|
||||
device_interface = get_interface_for_device(GPU_TYPE)
|
||||
with device_interface.device(0):
|
||||
so_path = AOTIRunnerUtil.compile(
|
||||
model=Model(w1.cuda(0), w2.cuda(0)),
|
||||
example_inputs=tuple(t.cuda(0) for t in inputs),
|
||||
model=Model(
|
||||
w1.to(torch.device(GPU_TYPE, 0)), w2.to(torch.device(GPU_TYPE, 0))
|
||||
),
|
||||
example_inputs=tuple(t.to(torch.device(GPU_TYPE, 0)) for t in inputs),
|
||||
)
|
||||
|
||||
# Run model on cuda:N
|
||||
for i in range(torch.cuda.device_count()):
|
||||
with torch.cuda.device(i):
|
||||
example_inputs = tuple(t.cuda(i) for t in inputs)
|
||||
optimized = AOTIRunnerUtil.load("cuda", so_path)
|
||||
result_cuda = optimized(*example_inputs)
|
||||
self.assertTrue(same(result_cpu, result_cuda.cpu()))
|
||||
# Run model on gpu:N
|
||||
for i in range(device_interface.device_count()):
|
||||
with device_interface.device(i):
|
||||
example_inputs = tuple(t.to(torch.device(GPU_TYPE, i)) for t in inputs)
|
||||
optimized = AOTIRunnerUtil.load(GPU_TYPE, so_path)
|
||||
result_gpu = optimized(*example_inputs)
|
||||
self.assertTrue(same(result_cpu, result_gpu.cpu()))
|
||||
|
||||
@requires_multigpu()
|
||||
def test_on_cuda_device1(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
def test_on_gpu_device1(self):
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
device_interface = get_interface_for_device(GPU_TYPE)
|
||||
try:
|
||||
torch.cuda.get_device_properties(1)
|
||||
device_interface.get_device_properties(1)
|
||||
except AssertionError:
|
||||
raise unittest.SkipTest("CUDA device 1 is not available") from None
|
||||
raise unittest.SkipTest("GPU device 1 is not available") from None
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
@ -1653,7 +1677,7 @@ class AOTInductorTestsTemplate:
|
||||
x = self.sigmoid(x)
|
||||
return x
|
||||
|
||||
device = "cuda:1"
|
||||
device = f"{GPU_TYPE}:1"
|
||||
model = Model().to(device)
|
||||
example_inputs = (torch.randn(8, 10, device=device),)
|
||||
expected = model(*example_inputs)
|
||||
@ -1689,9 +1713,9 @@ class AOTInductorTestsTemplate:
|
||||
)
|
||||
|
||||
@requires_multigpu()
|
||||
def test_non_default_cuda_device(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
def test_non_default_gpu_device(self):
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self, weight):
|
||||
@ -1705,18 +1729,23 @@ class AOTInductorTestsTemplate:
|
||||
inputs = (torch.randn(10, 10), torch.randn(10, 10))
|
||||
result_cpu = Model(weight)(*inputs)
|
||||
|
||||
with torch.cuda.device(0), torch.no_grad():
|
||||
result_cuda_0 = AOTIRunnerUtil.run(
|
||||
"cuda", Model(weight.cuda(0)), tuple(t.cuda(0) for t in inputs)
|
||||
device_interface = get_interface_for_device(GPU_TYPE)
|
||||
with device_interface.device(0), torch.no_grad():
|
||||
result_gpu_0 = AOTIRunnerUtil.run(
|
||||
GPU_TYPE,
|
||||
Model(weight.to(torch.device(GPU_TYPE, 0))),
|
||||
tuple(t.to(torch.device(GPU_TYPE, 0)) for t in inputs),
|
||||
)
|
||||
|
||||
with torch.cuda.device(1), torch.no_grad():
|
||||
result_cuda_1 = AOTIRunnerUtil.run(
|
||||
"cuda", Model(weight.cuda(1)), tuple(t.cuda(1) for t in inputs)
|
||||
with device_interface.device(1), torch.no_grad():
|
||||
result_gpu_1 = AOTIRunnerUtil.run(
|
||||
GPU_TYPE,
|
||||
Model(weight.to(torch.device(GPU_TYPE, 1))),
|
||||
tuple(t.to(torch.device(GPU_TYPE, 1)) for t in inputs),
|
||||
)
|
||||
|
||||
self.assertTrue(same(result_cpu, result_cuda_0.cpu()))
|
||||
self.assertTrue(same(result_cpu, result_cuda_1.cpu()))
|
||||
self.assertTrue(same(result_cpu, result_gpu_0.cpu()))
|
||||
self.assertTrue(same(result_cpu, result_gpu_1.cpu()))
|
||||
|
||||
def test_reuse_kernel(self):
|
||||
class Model(torch.nn.Module):
|
||||
@ -1739,7 +1768,7 @@ class AOTInductorTestsTemplate:
|
||||
model, example_inputs, atol=1e-4, rtol=1e-4
|
||||
) # 1e-4 is the tol value used in pytorch/torch/_dynamo/utils.py
|
||||
|
||||
if self.device == "cuda":
|
||||
if self.device == GPU_TYPE:
|
||||
self.code_check_count(
|
||||
model, example_inputs, "triton_poi_fused_sin_0 = loadKernel(", 1
|
||||
)
|
||||
@ -1809,8 +1838,8 @@ class AOTInductorTestsTemplate:
|
||||
self.check_model(m, example_inputs, dynamic_shapes=dynamic_shapes)
|
||||
|
||||
def test_fake_tensor_device_validation(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -1824,7 +1853,7 @@ class AOTInductorTestsTemplate:
|
||||
# Export on CPU
|
||||
exported_program = export(Model(), example_inputs)
|
||||
|
||||
# Compile exported model on CUDA
|
||||
# Compile exported model on GPU
|
||||
gm = exported_program.graph_module.to(self.device)
|
||||
with self.assertRaisesRegex(ValueError, "Device mismatch between fake input"):
|
||||
torch._inductor.aot_compile(
|
||||
@ -1903,7 +1932,7 @@ class AOTInductorTestsTemplate:
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.normal_functional.default(x)
|
||||
|
||||
self.check_model(Model(), (torch.empty(4, 1, 4, 4),))
|
||||
self.check_model(Model(), (torch.empty(4, 1, 4, 4, device=self.device),))
|
||||
|
||||
def test_empty_graph(self):
|
||||
class Model(torch.nn.Module):
|
||||
@ -2101,8 +2130,8 @@ class AOTInductorTestsTemplate:
|
||||
@common_utils.parametrize("dynamic", [False, True])
|
||||
@common_utils.parametrize("autotune", [False, True])
|
||||
def test_triton_kernel(self, grid_type, num_dims, dynamic, autotune):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -2171,8 +2200,8 @@ class AOTInductorTestsTemplate:
|
||||
self.check_model(Model(), (x, y), dynamic_shapes=dynamic_shapes)
|
||||
|
||||
def test_triton_kernel_dynamic_shape_with_div(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
@triton.jit
|
||||
def pass_kernel(x, num):
|
||||
@ -2195,8 +2224,8 @@ class AOTInductorTestsTemplate:
|
||||
self.check_model(Model(), (x,), dynamic_shapes=dynamic_shapes)
|
||||
|
||||
def test_triton_kernel_reinterpret_view(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
@triton.jit
|
||||
def pass_kernel(x, y):
|
||||
@ -2224,8 +2253,8 @@ class AOTInductorTestsTemplate:
|
||||
|
||||
@common_utils.parametrize("dynamic", [False, True])
|
||||
def test_triton_kernel_tma_descriptor_1d(self, dynamic):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
if not has_triton_tma():
|
||||
raise unittest.SkipTest("requires Triton TMA")
|
||||
|
||||
@ -2280,8 +2309,8 @@ class AOTInductorTestsTemplate:
|
||||
|
||||
@common_utils.parametrize("dynamic", [False, True])
|
||||
def test_triton_kernel_tma_descriptor_2d(self, dynamic):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
if not has_triton_tma():
|
||||
raise unittest.SkipTest("requires Triton TMA")
|
||||
|
||||
@ -2340,8 +2369,8 @@ class AOTInductorTestsTemplate:
|
||||
)
|
||||
|
||||
def test_triton_kernel_sympy_expr_arg(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x, e):
|
||||
@ -2366,8 +2395,8 @@ class AOTInductorTestsTemplate:
|
||||
def test_triton_kernel_sympy_fn_like_arg(self):
|
||||
# This test should hit sympy.expand("sqrt") which crashes with
|
||||
# AttributeError: 'function' object has no attribute 'expand'.
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
@ -2386,8 +2415,8 @@ class AOTInductorTestsTemplate:
|
||||
self.check_model(Model(), inputs)
|
||||
|
||||
def test_triton_kernel_with_none_input(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -2427,8 +2456,8 @@ class AOTInductorTestsTemplate:
|
||||
self.check_model(Model(), example_inputs)
|
||||
|
||||
def test_triton_kernel_equal_to_1_arg(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
@ -2446,8 +2475,8 @@ class AOTInductorTestsTemplate:
|
||||
|
||||
@common_utils.parametrize("dynamic", [False, True])
|
||||
def test_triton_kernel_equal_to_1_float_arg(self, dynamic):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
@ -2482,8 +2511,8 @@ class AOTInductorTestsTemplate:
|
||||
)
|
||||
|
||||
def test_triton_kernel_weird_param_order(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -2595,8 +2624,8 @@ class AOTInductorTestsTemplate:
|
||||
self.check_model(Model(), inputs)
|
||||
|
||||
def test_repeated_user_defined_triton_kernel(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -2865,8 +2894,8 @@ class AOTInductorTestsTemplate:
|
||||
self.check_model(model, example_inputs)
|
||||
|
||||
def test_triton_kernel_extern_kernel_arg(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
@ -2876,15 +2905,15 @@ class AOTInductorTestsTemplate:
|
||||
return out
|
||||
|
||||
example_inputs = (
|
||||
torch.randn(4, 4, device="cuda"),
|
||||
torch.randn(4, 4, device="cuda"),
|
||||
torch.randn(4, 4, device=GPU_TYPE),
|
||||
torch.randn(4, 4, device=GPU_TYPE),
|
||||
)
|
||||
|
||||
self.check_model(Model(), example_inputs)
|
||||
|
||||
def test_triton_kernel_multi_output_arg(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
@ -2894,16 +2923,17 @@ class AOTInductorTestsTemplate:
|
||||
return out
|
||||
|
||||
example_inputs = (
|
||||
torch.randn(4, 4, device="cuda"),
|
||||
torch.randn(4, 4, device="cuda"),
|
||||
torch.randn(4, 4, device=GPU_TYPE),
|
||||
torch.randn(4, 4, device=GPU_TYPE),
|
||||
)
|
||||
|
||||
self.check_model(Model(), example_inputs)
|
||||
|
||||
# @skipIfXpu(msg="torch.xpu.memory_allocated not supported yet")
|
||||
def test_triton_kernel_reinterpret_view_mem_leak(self):
|
||||
# Check for memory leak when using user-defined Triton Kernel + AOTI.
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -2917,22 +2947,23 @@ class AOTInductorTestsTemplate:
|
||||
return out
|
||||
|
||||
example_inputs = (
|
||||
torch.randn(4, 4, device="cuda"),
|
||||
torch.randn(1, 16, device="cuda"),
|
||||
torch.randn(4, 4, device=GPU_TYPE),
|
||||
torch.randn(1, 16, device=GPU_TYPE),
|
||||
)
|
||||
|
||||
so_path: str = AOTIRunnerUtil.compile(
|
||||
Model(),
|
||||
example_inputs,
|
||||
)
|
||||
aot_inductor_module = AOTIRunnerUtil.load("cuda", so_path)
|
||||
aot_inductor_module = AOTIRunnerUtil.load(GPU_TYPE, so_path)
|
||||
|
||||
# Don't assign outputs to a variable b/c it will allocate GPU memory.
|
||||
device: int = torch.cuda.current_device()
|
||||
mem_before = torch.cuda.memory_allocated(device)
|
||||
device_interface = get_interface_for_device(GPU_TYPE)
|
||||
device: int = device_interface.current_device()
|
||||
mem_before = device_interface.memory_allocated(device)
|
||||
aot_inductor_module(*example_inputs)
|
||||
aot_inductor_module(*example_inputs)
|
||||
mem_after = torch.cuda.memory_allocated(device)
|
||||
mem_after = device_interface.memory_allocated(device)
|
||||
self.assertEqual(mem_before, mem_after)
|
||||
|
||||
actual = aot_inductor_module(*example_inputs)
|
||||
@ -2943,8 +2974,8 @@ class AOTInductorTestsTemplate:
|
||||
@common_utils.parametrize("dynamic", [False, True])
|
||||
@common_utils.parametrize("autotuning", [False, True])
|
||||
def test_triton_kernel_unbacked_symint_in_grid(self, dynamic, autotuning):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x, y, n_elements_tensor):
|
||||
@ -2974,8 +3005,8 @@ class AOTInductorTestsTemplate:
|
||||
return output
|
||||
|
||||
example_inputs = (
|
||||
torch.randn(123, device="cuda"),
|
||||
torch.randn(123, device="cuda"),
|
||||
torch.randn(123, device=GPU_TYPE),
|
||||
torch.randn(123, device=GPU_TYPE),
|
||||
torch.tensor(123),
|
||||
)
|
||||
|
||||
@ -2996,8 +3027,8 @@ class AOTInductorTestsTemplate:
|
||||
|
||||
@skipIfRocm # USE_MEM_EFF_ATTENTION was not enabled for build.
|
||||
def test_scaled_dot_product_efficient_attention(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, q, k, v, attn_bias):
|
||||
@ -3006,10 +3037,10 @@ class AOTInductorTestsTemplate:
|
||||
)[0]
|
||||
|
||||
example_inputs = (
|
||||
torch.randn(4, 4, 36, 36, device="cuda"),
|
||||
torch.randn(4, 4, 36, 36, device="cuda"),
|
||||
torch.randn(4, 4, 36, 36, device="cuda"),
|
||||
torch.randn(4, 4, 36, 36, device="cuda"),
|
||||
torch.randn(4, 4, 36, 36, device=GPU_TYPE),
|
||||
torch.randn(4, 4, 36, 36, device=GPU_TYPE),
|
||||
torch.randn(4, 4, 36, 36, device=GPU_TYPE),
|
||||
torch.randn(4, 4, 36, 36, device=GPU_TYPE),
|
||||
)
|
||||
self.check_model(Model(), example_inputs)
|
||||
|
||||
@ -3510,9 +3541,9 @@ class AOTInductorTestsTemplate:
|
||||
kernel_calls = (
|
||||
[
|
||||
("triton_poi_fused_0", 1),
|
||||
("aoti_torch_cuda_addmm_out", 2),
|
||||
(f"aoti_torch_{GPU_TYPE}_addmm_out", 2),
|
||||
]
|
||||
if self.device == "cuda"
|
||||
if self.device == GPU_TYPE
|
||||
else [
|
||||
("aoti_torch_cpu_addmm_out", 2),
|
||||
]
|
||||
@ -3573,8 +3604,8 @@ class AOTInductorTestsTemplate:
|
||||
FileCheck().check_not(f"after_launch - {kernel_name}").run(code)
|
||||
|
||||
def test_aoti_debug_printer_user_defined_triton_kernel(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -3650,8 +3681,8 @@ class AOTInductorTestsTemplate:
|
||||
).run(code)
|
||||
|
||||
def test_aoti_debug_printer_sym_inputs(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
from torch.testing._internal.triton_utils import add_kernel
|
||||
|
||||
@ -3661,8 +3692,8 @@ class AOTInductorTestsTemplate:
|
||||
|
||||
def forward(self, x):
|
||||
maxlen = max(x.item(), 512)
|
||||
a = torch.ones(maxlen, device="cuda")
|
||||
b = torch.ones(maxlen, device="cuda")
|
||||
a = torch.ones(maxlen, device=GPU_TYPE)
|
||||
b = torch.ones(maxlen, device=GPU_TYPE)
|
||||
out = torch.zeros_like(a)
|
||||
# unbacked symint in grid
|
||||
add_kernel[(1, 1, maxlen)](a, b, out, maxlen, 32)
|
||||
@ -3739,8 +3770,8 @@ class AOTInductorTestsTemplate:
|
||||
|
||||
@dynamo_config.patch({"capture_scalar_outputs": True})
|
||||
def test_sym_i64_input_codegen(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
from torch.testing._internal.triton_utils import add_kernel
|
||||
|
||||
@ -3750,8 +3781,8 @@ class AOTInductorTestsTemplate:
|
||||
|
||||
def forward(self, x):
|
||||
x_symint = x.item()
|
||||
a = torch.ones(x_symint, device="cuda")
|
||||
b = torch.ones(x_symint, device="cuda")
|
||||
a = torch.ones(x_symint, device=GPU_TYPE)
|
||||
b = torch.ones(x_symint, device=GPU_TYPE)
|
||||
out = torch.zeros_like(a)
|
||||
# unbacked symint in grid
|
||||
add_kernel[(1, 1, x_symint)](a, b, out, x_symint, 32)
|
||||
@ -3786,8 +3817,8 @@ class AOTInductorTestsTemplate:
|
||||
self.check_model(Model(), example_inputs)
|
||||
|
||||
def test_none_args_aot_codegen(self):
|
||||
if self.device != "cuda":
|
||||
raise unittest.SkipTest("requires CUDA")
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
@ -4099,9 +4130,9 @@ def fail_cpu(is_skip=False):
    )


def fail_cuda(is_skip=False):
def fail_gpu(suffixes: Tuple[str, ...], is_skip=False):
    return TestFailure(
        ("cuda"),
        suffixes,
        is_skip=is_skip,
    )

@ -4115,10 +4146,14 @@ CPU_TEST_FAILURES = {
}

# test_failures, xfail by default, set is_skip=True to skip
CUDA_TEST_FAILURES = {
GPU_TEST_FAILURES = {
    # quantized unsupported for GPU
    "test_quantized_linear": fail_cuda(),
    "test_quanatized_int8_linear": fail_cuda(),
    "test_quantized_linear": fail_gpu(("cuda", "xpu")),
    "test_quanatized_int8_linear": fail_gpu(("cuda", "xpu")),
    # No fft implementation for XPU yet.
    "test_fft_c2c": fail_gpu(("xpu",)),
    # No scaled_dot_product_efficient_attention implementation for XPU yet.
    "test_scaled_dot_product_efficient_attention": fail_gpu(("xpu",)),
}


@ -4141,9 +4176,9 @@ copy_tests(


@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS")
class AOTInductorTestABICompatibleCuda(TestCase):
    device = "cuda"
    device_type = "cuda"
class AOTInductorTestABICompatibleGpu(TestCase):
    device = GPU_TYPE
    device_type = GPU_TYPE
    check_model = check_model
    check_model_with_multiple_inputs = check_model_with_multiple_inputs
    code_check_count = code_check_count

@ -4153,14 +4188,14 @@ class AOTInductorTestABICompatibleCuda(TestCase):

copy_tests(
    AOTInductorTestsTemplate,
    AOTInductorTestABICompatibleCuda,
    "cuda",
    CUDA_TEST_FAILURES,
    AOTInductorTestABICompatibleGpu,
    GPU_TYPE,
    GPU_TEST_FAILURES,
)

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    # cpp_extension N/A in fbcode
    if HAS_CUDA or sys.platform == "darwin":
    if HAS_GPU or sys.platform == "darwin":
        run_tests(needs="filelock")
@ -92,11 +92,12 @@ class AOTIRunnerUtil:
                temp_so_path, device == "cpu"
            )
        else:
            return (
                torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)
                if device == "cpu"
                else torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)
            )
            if device == "cpu":
                return torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)
            elif device == "xpu":
                return torch._C._aoti.AOTIModelContainerRunnerXpu(so_path, 1, device)
            else:
                return torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)

    @staticmethod
    def load(device, so_path):
@ -18,6 +18,7 @@ def alloc_tensor_by_stealing_from_void_ptr(

class AOTIModelContainerRunnerCpu: ...
class AOTIModelContainerRunnerCuda: ...
class AOTIModelContainerRunnerXpu: ...

# Defined in torch/csrc/inductor/aoti_package/pybind.cpp
class AOTIModelPackageLoader: ...
@ -373,6 +373,9 @@ def aot_load(so_path: str, device: str) -> Callable:
        runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)  # type: ignore[call-arg]
    elif device == "cuda" or device.startswith("cuda:"):
        runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)  # type: ignore[assignment, call-arg]
    elif device == "xpu" or device.startswith("xpu:"):
        runner = torch._C._aoti.AOTIModelContainerRunnerXpu(so_path, 1, device)  # type: ignore[assignment, call-arg]
    else:
        raise RuntimeError("Unsupported device " + device)
@ -1531,7 +1531,9 @@ class AotCodeCompiler:
            object_output_dir,
        ) = get_name_and_dir_from_output_file_path(consts_s)
        object_build_options = CppTorchDeviceOptions(
            device_type=device_type,
            # The Intel compiler fails to compile this manually constructed assembly file.
            # It is OK to use gcc to compile the .S into a .o and link it with the Intel compiler.
            device_type=device_type if device_type != "xpu" else "cpu",
            aot_mode=graph.aot_mode,
            compile_only=True,
            use_absolute_path=use_absolute_path,
@ -20,7 +20,7 @@ from .. import config, ir
from ..utils import _align, ALIGN_BYTES, cache_on_self, normalize_name
from ..virtualized import V
from .aoti_hipify_utils import maybe_hipify_code_wrapper
from .common import IndentedBuffer, Kernel
from .common import get_device_op_overrides, IndentedBuffer, Kernel
from .cpp_utils import cexpr, DEVICE_TO_ATEN, DTYPE_TO_ATEN, DTYPE_TO_CPP
from .triton_utils import should_unwrap_unspec_arg
from .wrapper import (

@ -66,6 +66,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
        self.custom_op_wrapper_loaded = False
        # For GEMM kernels that must be initialized and are resolved at linking.
        self.initialized_kernels: Dict[str, Kernel] = {}
        self.device_codegen = get_device_op_overrides(self.device)

    @staticmethod
    def create(

@ -571,7 +572,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
            )
            for kernel in sorted(declare_kernel):
                self.prefix.writeline(
                    maybe_hipify_code_wrapper(f" CUfunction {kernel}{{nullptr}};")
                    maybe_hipify_code_wrapper(
                        f" {self.device_codegen.cpp_kernel_type()} {kernel}{{nullptr}};"
                    )
                )
            for name, kernel in self.initialized_kernels.items():
                assert hasattr(
torch/csrc/inductor/aoti_runner/model_container_runner_xpu.cpp (new file)
@ -0,0 +1,34 @@
#if !defined(C10_MOBILE) && !defined(ANDROID)
#include <torch/csrc/inductor/aoti_runner/model_container_runner_xpu.h>

namespace torch::inductor {

AOTIModelContainerRunnerXpu::AOTIModelContainerRunnerXpu(
    const std::string& model_so_path,
    size_t num_models,
    const std::string& device_str,
    const std::string& cubin_dir)
    : AOTIModelContainerRunner(
          model_so_path,
          num_models,
          device_str,
          cubin_dir) {}

AOTIModelContainerRunnerXpu::~AOTIModelContainerRunnerXpu() = default;

std::vector<at::Tensor> AOTIModelContainerRunnerXpu::run(
    std::vector<at::Tensor>& inputs) {
  at::xpu::XPUStream xpu_stream = c10::xpu::getCurrentXPUStream();
  return AOTIModelContainerRunner::run(
      inputs, reinterpret_cast<AOTInductorStreamHandle>(&(xpu_stream.queue())));
}

std::vector<at::Tensor> AOTIModelContainerRunnerXpu::run_with_xpu_stream(
    std::vector<at::Tensor>& inputs,
    at::xpu::XPUStream xpu_stream) {
  return AOTIModelContainerRunner::run(
      inputs, reinterpret_cast<AOTInductorStreamHandle>(&(xpu_stream.queue())));
}

} // namespace torch::inductor
#endif

torch/csrc/inductor/aoti_runner/model_container_runner_xpu.h (new file, 30 lines)
@ -0,0 +1,30 @@
#if !defined(C10_MOBILE) && !defined(ANDROID)
#pragma once

#include <c10/xpu/XPUStream.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner.h>

namespace torch::inductor {

// NOTICE: Following APIs are subject to change due to active development
// We provide NO BC guarantee for these APIs
class TORCH_API AOTIModelContainerRunnerXpu : public AOTIModelContainerRunner {
 public:
  // @param device_str: xpu device string, e.g. "xpu", "xpu:0"
  AOTIModelContainerRunnerXpu(
      const std::string& model_so_path,
      size_t num_models = 1,
      const std::string& device_str = "xpu",
      const std::string& cubin_dir = "");

  ~AOTIModelContainerRunnerXpu();

  std::vector<at::Tensor> run(std::vector<at::Tensor>& inputs);

  std::vector<at::Tensor> run_with_xpu_stream(
      std::vector<at::Tensor>& inputs,
      at::xpu::XPUStream xpu_stream);
};

} // namespace torch::inductor
#endif
@ -2,6 +2,9 @@
|
||||
#ifdef USE_CUDA
|
||||
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
|
||||
#endif
|
||||
#ifdef USE_XPU
|
||||
#include <torch/csrc/inductor/aoti_runner/model_container_runner_xpu.h>
|
||||
#endif
|
||||
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/utils.h>
|
||||
|
||||
@ -51,6 +54,30 @@ void initAOTIRunnerBindings(PyObject* module) {
|
||||
static_cast<void (AOTIModelContainerRunnerCuda::*)(
|
||||
std::unordered_map<std::string, at::Tensor>&, bool, bool)>(
|
||||
&AOTIModelContainerRunnerCuda::update_constant_buffer));
|
||||
#endif
|
||||
#ifdef USE_XPU
|
||||
py::class_<AOTIModelContainerRunnerXpu>(m, "AOTIModelContainerRunnerXpu")
|
||||
.def(py::init<const std::string&, int>())
|
||||
.def(py::init<const std::string&, int, const std::string&>())
|
||||
.def(py::init<
|
||||
const std::string&,
|
||||
int,
|
||||
const std::string&,
|
||||
const std::string&>())
|
||||
.def("run", &AOTIModelContainerRunnerXpu::run)
|
||||
.def("get_call_spec", &AOTIModelContainerRunnerXpu::get_call_spec)
|
||||
.def(
|
||||
"get_constant_names_to_original_fqns",
|
||||
&AOTIModelContainerRunnerXpu::getConstantNamesToOriginalFQNs)
|
||||
.def(
|
||||
"get_constant_names_to_dtypes",
|
||||
&AOTIModelContainerRunnerXpu::getConstantNamesToDtypes)
|
||||
.def(
|
||||
"update_constant_buffer",
|
||||
static_cast<void (AOTIModelContainerRunnerXpu::*)(
|
||||
std::unordered_map<std::string, at::Tensor>&, bool, bool)>(
|
||||
&AOTIModelContainerRunnerXpu::update_constant_buffer));
|
||||
|
||||
#endif
|
||||
|
||||
m.def(
|
||||
|
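As a hedged illustration of the binding registered above, the new runner can also be driven directly from Python; the constructor arguments mirror the py::init overloads just shown, and the .so path below is a placeholder for an artifact produced by an earlier aot_compile call.

    import torch

    so_path = "/path/to/model.so"  # placeholder; produced by a prior torch._export.aot_compile(...) call
    # (model_so_path, num_models, device_str), matching py::init<const std::string&, int, const std::string&>
    runner = torch._C._aoti.AOTIModelContainerRunnerXpu(so_path, 1, "xpu")
    outputs = runner.run([torch.randn(8, 16, device="xpu")])  # run() takes a list of tensors
    print(runner.get_call_spec())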
@ -30,7 +30,27 @@ using DeviceStreamType = cudaStream_t;

} // namespace torch::aot_inductor

#else // !USE_CUDA
#elif defined(USE_XPU)
#include <level_zero/ze_api.h>
#include <sycl/sycl.hpp>
#include <sstream>
#define AOTI_RUNTIME_DEVICE_CHECK(EXPR) \
  do { \
    const ze_result_t status = EXPR; \
    if (status != ZE_RESULT_SUCCESS) { \
      std::stringstream ss; \
      ss << "L0 runtime error: " << std::hex << std::uppercase << status; \
      throw std::runtime_error(ss.str()); \
    } \
  } while (0)

namespace torch::aot_inductor {

using DeviceStreamType = sycl::queue*;

} // namespace torch::aot_inductor

#else

#define AOTI_RUNTIME_DEVICE_CHECK(EXPR) \
  bool ok = EXPR; \
@ -15,7 +15,11 @@
// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule
// applies to other files under torch/csrc/inductor/aoti_runtime/.
#include <torch/csrc/inductor/aoti_runtime/device_utils.h>
#ifdef USE_XPU
#include <torch/csrc/inductor/aoti_runtime/utils_xpu.h>
#else
#include <torch/csrc/inductor/aoti_runtime/utils.h>
#endif

#define AOTI_RUNTIME_CHECK(EXPR, MSG) \
  do { \

@ -44,13 +48,27 @@ namespace {

#ifdef USE_CUDA

using CUDAPtr = std::unique_ptr<void, std::function<void(void*)>>;
using GPUPtr = std::unique_ptr<void, std::function<void(void*)>>;

CUDAPtr RAII_cudaMalloc(size_t num_bytes) {
GPUPtr RAII_gpuMalloc(size_t num_bytes) {
  void* data_ptr;
  AOTI_RUNTIME_DEVICE_CHECK(cudaMalloc((void**)&data_ptr, num_bytes));
  auto deleter = [](void* ptr) { AOTI_RUNTIME_DEVICE_CHECK(cudaFree(ptr)); };
  return CUDAPtr(data_ptr, deleter);
  return GPUPtr(data_ptr, deleter);
}

#endif // USE_CUDA

#ifdef USE_XPU

using GPUPtr = std::unique_ptr<void, std::function<void(void*)>>;

GPUPtr RAII_gpuMalloc(size_t num_bytes) {
  sycl::queue* queue_ptr = nullptr;
  aoti_torch_get_current_sycl_queue((void**)&queue_ptr);
  void* data_ptr = sycl::malloc_device(num_bytes, *queue_ptr);
  auto deleter = [queue_ptr](void* ptr) { sycl::free(ptr, *queue_ptr); };
  return GPUPtr(data_ptr, deleter);
}

#endif // USE_CUDA
@ -74,7 +92,7 @@ inline void parse_device_str(
|
||||
const std::string& device_str,
|
||||
int32_t& device_type,
|
||||
int32_t& device_idx) {
|
||||
std::regex re("(cpu|cuda)(:([0-9]+))?");
|
||||
std::regex re("(cpu|cuda|xpu)(:([0-9]+))?");
|
||||
std::smatch sm;
|
||||
bool matched = std::regex_match(device_str, sm, re);
|
||||
AOTI_RUNTIME_CHECK(matched, "Invalid device: " + device_str);
|
||||
@ -83,6 +101,10 @@ inline void parse_device_str(
|
||||
device_type = aoti_torch_device_type_cpu();
|
||||
} else if (sm[1].str() == "cuda") {
|
||||
device_type = aoti_torch_device_type_cuda();
|
||||
#ifdef USE_XPU
|
||||
} else if (sm[1].str() == "xpu") {
|
||||
device_type = aoti_torch_device_type_xpu();
|
||||
#endif
|
||||
} else {
|
||||
AOTI_RUNTIME_CHECK(false, "Invalid device: " + device_str);
|
||||
}
|
||||
@ -124,6 +146,13 @@ class AOTInductorModelBase {
|
||||
AOTI_RUNTIME_DEVICE_CHECK(cudaSetDevice(device_idx_));
|
||||
}
|
||||
#endif // USE_CUDA
|
||||
#ifdef USE_XPU
|
||||
if (device_idx_ == -1) {
|
||||
aoti_torch_get_current_xpu_device(&device_idx_);
|
||||
} else {
|
||||
aoti_torch_set_current_xpu_device(device_idx_);
|
||||
}
|
||||
#endif // USE_XPU
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(modernize-use-equals-default)
|
||||
@ -137,6 +166,12 @@ class AOTInductorModelBase {
|
||||
}
|
||||
}
|
||||
#endif // USE_CUDA
|
||||
#ifdef USE_XPU
|
||||
if (run_finished_) {
|
||||
(*run_finished_)->wait_and_throw();
|
||||
delete *run_finished_;
|
||||
}
|
||||
#endif // USE_XPU
|
||||
}
|
||||
|
||||
AOTInductorModelBase(AOTInductorModelBase&&) = delete;
|
||||
@ -160,14 +195,25 @@ class AOTInductorModelBase {
|
||||
AOTI_RUNTIME_DEVICE_CHECK(cudaEventCreate(&run_finished));
|
||||
run_finished_.emplace(run_finished);
|
||||
}
|
||||
#elif defined(USE_XPU)
|
||||
if (run_finished_) {
|
||||
(*run_finished_)->wait_and_throw();
|
||||
delete *run_finished_;
|
||||
run_finished_.reset();
|
||||
}
|
||||
#else // !USE_CUDA && !USE_XPU
|
||||
run_finished_ = false;
|
||||
#endif
|
||||
|
||||
auto* model = static_cast<Model*>(this);
|
||||
model->run_impl(input_handles, output_handles, stream, proxy_executor);
|
||||
|
||||
#ifdef USE_CUDA
|
||||
AOTI_RUNTIME_DEVICE_CHECK(cudaEventRecord(*run_finished_, stream));
|
||||
#else // !USE_CUDA
|
||||
run_finished_ = false;
|
||||
auto* model = static_cast<Model*>(this);
|
||||
model->run_impl(input_handles, output_handles, stream, proxy_executor);
|
||||
#elif defined(USE_XPU)
|
||||
run_finished_ = std::make_optional<sycl::event*>(new sycl::event(
|
||||
static_cast<sycl::queue*>(stream)->ext_oneapi_submit_barrier()));
|
||||
#else // !USE_CUDA && !USE_XPU
|
||||
run_finished_ = true;
|
||||
#endif // USE_CUDA
|
||||
}
|
||||
@ -182,9 +228,15 @@ class AOTInductorModelBase {
|
||||
AOTI_RUNTIME_DEVICE_CHECK(cudaEventCreate(&run_finished));
|
||||
run_finished_.emplace(run_finished);
|
||||
}
|
||||
#else // USE_CUDA
|
||||
#elif defined(USE_XPU)
|
||||
if (run_finished_) {
|
||||
(*run_finished_)->wait_and_throw();
|
||||
delete *run_finished_;
|
||||
run_finished_.reset();
|
||||
}
|
||||
#else // !USE_CUDA && !USE_XPU
|
||||
run_finished_ = false;
|
||||
#endif // USE_CUDA
|
||||
#endif
|
||||
|
||||
auto* model = static_cast<Model*>(this);
|
||||
auto folded_constants =
|
||||
@ -192,7 +244,13 @@ class AOTInductorModelBase {
|
||||
|
||||
#ifdef USE_CUDA
|
||||
AOTI_RUNTIME_DEVICE_CHECK(cudaEventRecord(*run_finished_, stream));
|
||||
#else // USE_CUDA
|
||||
#elif defined(USE_XPU)
|
||||
// sycl::queue* queue_ptr = nullptr;
|
||||
// aoti_torch_get_current_sycl_queue((void**)&queue_ptr);
|
||||
run_finished_ = std::make_optional<sycl::event*>(new sycl::event(
|
||||
static_cast<sycl::queue*>(stream)->ext_oneapi_submit_barrier()));
|
||||
|
||||
#else // !USE_CUDA && !USE_XPU
|
||||
run_finished_ = true;
|
||||
#endif // USE_CUDA
|
||||
|
||||
@ -206,9 +264,9 @@ class AOTInductorModelBase {
|
||||
std::vector<size_t> constants_internal_offset(num_constants);
|
||||
if (device_type_ != aoti_torch_device_type_cpu()) {
|
||||
size_t blob_size = 0;
|
||||
compute_cuda_constant_blob(blob_size, constants_internal_offset);
|
||||
#ifdef USE_CUDA
|
||||
constant_blob_ = RAII_cudaMalloc(blob_size);
|
||||
compute_gpu_constant_blob(blob_size, constants_internal_offset);
|
||||
#if defined(USE_CUDA) || defined(USE_XPU)
|
||||
constant_blob_ = RAII_gpuMalloc(blob_size);
|
||||
#endif
|
||||
}
|
||||
if (!include_weights) {
|
||||
@ -218,7 +276,7 @@ class AOTInductorModelBase {
|
||||
size_t bytes_read = 0;
|
||||
for (size_t i = 0; i < num_constants; i++) {
|
||||
bool from_folded = this->constant_from_folded(i);
|
||||
#ifndef USE_CUDA
|
||||
#if not defined(USE_XPU) && not defined(USE_CUDA)
|
||||
if (from_folded) {
|
||||
// We do not reallocate and copy for CPU.
|
||||
continue;
|
||||
@ -284,8 +342,8 @@ class AOTInductorModelBase {
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
CUDAPtr&& release_constant_blob() {
|
||||
#if defined(USE_CUDA) || defined(USE_XPU)
|
||||
GPUPtr&& release_constant_blob() {
|
||||
return std::move(constant_blob_);
|
||||
}
|
||||
#endif
|
||||
@ -303,17 +361,26 @@ class AOTInductorModelBase {
|
||||
size_t bytes_read,
|
||||
size_t data_size,
|
||||
bool skip_copy) {
|
||||
#ifdef USE_CUDA
|
||||
#if defined(USE_CUDA) || defined(USE_XPU)
|
||||
auto* constants_ptr = static_cast<uint8_t*>(constant_blob_.get());
|
||||
uint8_t* internal_ptr = constants_ptr + constant_offset;
|
||||
// Copy data to GPU memory
|
||||
// TODO: Handle shared storage case.
|
||||
if (!skip_copy) {
|
||||
#ifdef USE_XPU
|
||||
sycl::queue* queue_ptr = nullptr;
|
||||
aoti_torch_get_current_sycl_queue((void**)&queue_ptr);
|
||||
queue_ptr
|
||||
->memcpy(internal_ptr, _get_constants_start() + bytes_read, data_size)
|
||||
.wait();
|
||||
|
||||
#else
|
||||
AOTI_RUNTIME_DEVICE_CHECK(cudaMemcpy(
|
||||
internal_ptr,
|
||||
_get_constants_start() + bytes_read,
|
||||
data_size,
|
||||
cudaMemcpyHostToDevice));
|
||||
#endif
|
||||
}
|
||||
return internal_ptr;
|
||||
|
||||
@ -324,10 +391,10 @@ class AOTInductorModelBase {
|
||||
#endif // USE_CUDA
|
||||
}
|
||||
|
||||
void compute_cuda_constant_blob(
|
||||
void compute_gpu_constant_blob(
|
||||
size_t& blob_size,
|
||||
std::vector<size_t>& constants_internal_offset) {
|
||||
#ifdef USE_CUDA
|
||||
#if defined(USE_CUDA) || defined(USE_XPU)
|
||||
size_t num_constants = this->num_constants();
|
||||
// Compute required blob size with 64-alignment if on GPU.
|
||||
blob_size = 0;
|
||||
@ -477,7 +544,15 @@ class AOTInductorModelBase {
|
||||
throw std::runtime_error(
|
||||
std::string("The model did not finish successfully. Error: ") +
|
||||
cudaGetErrorString(cudaGetLastError()));
|
||||
#else // !USE_CUDA
|
||||
#elif defined(USE_XPU)
|
||||
if (!run_finished_) {
|
||||
throw std::runtime_error{"Model XPU event was not initialized"};
|
||||
}
|
||||
using namespace sycl::info;
|
||||
return (*run_finished_)->get_info<event::command_execution_status>() ==
|
||||
event_command_status::complete;
|
||||
|
||||
#else // !USE_CUDA && !USE_XPU
|
||||
return run_finished_;
|
||||
#endif // USE_CUDA
|
||||
}
|
||||
@ -491,6 +566,12 @@ class AOTInductorModelBase {
|
||||
|
||||
AOTI_RUNTIME_DEVICE_CHECK(cudaEventSynchronize(*run_finished_));
|
||||
#endif // USE_CUDA
|
||||
#ifdef USE_XPU
|
||||
if (!run_finished_) {
|
||||
throw std::runtime_error{"Model event was not initialized"};
|
||||
}
|
||||
(*run_finished_)->wait_and_throw();
|
||||
#endif
|
||||
}
|
||||
|
||||
protected:
|
||||
@ -562,10 +643,11 @@ class AOTInductorModelBase {
|
||||
std::shared_ptr<ConstantMap> constants_map_;
|
||||
std::shared_ptr<std::vector<ConstantHandle>> constants_;
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#if defined(USE_CUDA) || defined(USE_XPU)
|
||||
// Holds the blob storage for constants' at::Tensor for CUDA.
|
||||
CUDAPtr constant_blob_;
|
||||
GPUPtr constant_blob_;
|
||||
#endif // USE_CUDA
|
||||
|
||||
#ifdef USE_MMAP_SELF
|
||||
uint8_t* self_mmap = NULL;
|
||||
#endif
|
||||
@ -582,6 +664,8 @@ class AOTInductorModelBase {
|
||||
// AOTModelContainer can re-use this instance.
|
||||
#ifdef USE_CUDA
|
||||
std::optional<cudaEvent_t> run_finished_;
|
||||
#elif defined(USE_XPU)
|
||||
std::optional<sycl::event*> run_finished_;
|
||||
#else // !USE_CUDA
|
||||
bool run_finished_{};
|
||||
#endif
|
||||
|
@ -51,12 +51,11 @@ class AOTInductorModelContainer {
|
||||
for (size_t i = 0; i < num_outputs; i++) {
|
||||
output_names_.emplace_back(model->output_name(static_cast<int64_t>(i)));
|
||||
}
|
||||
|
||||
model->load_constants();
|
||||
#ifdef USE_CUDA
|
||||
#if defined(USE_CUDA) || defined(USE_XPU)
|
||||
constant_blob_ = model->release_constant_blob();
|
||||
constants_internal_offset_.resize(model->num_constants());
|
||||
model->compute_cuda_constant_blob(blob_size_, constants_internal_offset_);
|
||||
model->compute_gpu_constant_blob(blob_size_, constants_internal_offset_);
|
||||
#endif
|
||||
|
||||
for (auto& model : models_) {
|
||||
@ -276,7 +275,7 @@ class AOTInductorModelContainer {
|
||||
continue;
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#if defined(USE_CUDA) || defined(USE_XPU)
|
||||
AtenTensorHandle tensor;
|
||||
if (_should_skip_update(idx) && use_inactive) {
|
||||
tensor = original_constants_map->find(constant_name)->second.get();
|
||||
@ -293,13 +292,20 @@ class AOTInductorModelContainer {
|
||||
int64_t constant_size;
|
||||
aoti_torch_get_data_ptr(tensor, &user_constant_ptr);
|
||||
aoti_torch_get_storage_size(tensor, &constant_size);
|
||||
#ifdef USE_XPU
|
||||
sycl::queue* queue_ptr = nullptr;
|
||||
aoti_torch_get_current_sycl_queue((void**)&queue_ptr);
|
||||
queue_ptr
|
||||
->memcpy(internal_constants_ptr, user_constant_ptr, constant_size)
|
||||
.wait();
|
||||
|
||||
#else
|
||||
AOTI_RUNTIME_DEVICE_CHECK(cudaMemcpy(
|
||||
internal_constants_ptr,
|
||||
user_constant_ptr,
|
||||
constant_size,
|
||||
cudaMemcpyDefault));
|
||||
|
||||
#endif
|
||||
// Generate Tensor from container handled blob.
|
||||
// We extract stride and offset from provided Tensor since we do not
|
||||
// guarantee that the tensor is contiguous.
|
||||
@ -317,7 +323,11 @@ class AOTInductorModelContainer {
|
||||
stride,
|
||||
offset,
|
||||
models_[0]->constant_dtype(idx),
|
||||
#ifdef USE_XPU
|
||||
aoti_torch_device_type_xpu(),
|
||||
#else
|
||||
aoti_torch_device_type_cuda(),
|
||||
#endif
|
||||
device_idx,
|
||||
&tensor_handle));
|
||||
#else // USE_CUDA
|
||||
@ -397,10 +407,10 @@ class AOTInductorModelContainer {
|
||||
const char* in_spec_;
|
||||
const char* out_spec_;
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#if defined(USE_CUDA) || defined(USE_XPU)
|
||||
// Holds the blob storage for constants' at::Tensor for CUDA.
|
||||
CUDAPtr constant_blob_;
|
||||
CUDAPtr constant_blob_secondary_;
|
||||
GPUPtr constant_blob_;
|
||||
GPUPtr constant_blob_secondary_;
|
||||
|
||||
// Let's place this within USE_CUDA at the moment before we fully support
|
||||
// update for CPU cases.
|
||||
@ -461,14 +471,14 @@ class AOTInductorModelContainer {
|
||||
// make sure no one is executing the model.
|
||||
std::shared_mutex model_exec_mutex_;
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#if defined(USE_CUDA) || defined(USE_XPU)
|
||||
void* get_constant_blob_ptr(bool get_inactive) {
|
||||
if ((get_inactive && use_secondary_) ||
|
||||
(!get_inactive && !use_secondary_)) {
|
||||
return constant_blob_.get();
|
||||
} else {
|
||||
if (!constant_blob_secondary_) {
|
||||
constant_blob_secondary_ = RAII_cudaMalloc(blob_size_);
|
||||
constant_blob_secondary_ = RAII_gpuMalloc(blob_size_);
|
||||
}
|
||||
return constant_blob_secondary_.get();
|
||||
}
|
||||
|
@ -15,6 +15,11 @@ inline void delete_xpu_guard(void* ptr) {
      aoti_torch_delete_xpu_guard(reinterpret_cast<XPUGuardHandle>(ptr)));
}

inline void delete_xpu_stream_guard(void* ptr) {
  AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_xpu_stream_guard(
      reinterpret_cast<XPUStreamGuardHandle>(ptr)));
}

class AOTIXpuGuard {
 public:
  AOTIXpuGuard(int32_t device_index) : guard_(nullptr, delete_xpu_guard) {

@ -32,5 +37,20 @@ class AOTIXpuGuard {
 private:
  std::unique_ptr<XPUGuardOpaque, DeleterFnPtr> guard_;
};

class AOTIXpuStreamGuard {
 public:
  AOTIXpuStreamGuard(void* stream, int32_t device_index)
      : guard_(nullptr, delete_xpu_stream_guard) {
    XPUStreamGuardHandle ptr = nullptr;
    AOTI_TORCH_ERROR_CODE_CHECK(
        aoti_torch_create_xpu_stream_guard(stream, device_index, &ptr));
    guard_.reset(ptr);
  }

 private:
  std::unique_ptr<XPUStreamGuardOpaque, DeleterFnPtr> guard_;
};

} // namespace torch::aot_inductor
#endif // USE_XPU
@ -25,9 +25,26 @@ aoti_torch_xpu_guard_set_index(XPUGuardHandle guard, int32_t device_index);
struct XPUStreamGuardOpaque;
using XPUStreamGuardHandle = XPUStreamGuardOpaque*;

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_xpu_stream_guard(
    void* stream,
    int32_t device_index,
    XPUStreamGuardHandle* ret_guard // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_delete_xpu_stream_guard(XPUStreamGuardHandle guard);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_current_xpu_stream(int32_t device_index, void** ret_stream);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_current_xpu_device(int32_t* device_index);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_set_current_xpu_device(const int32_t& device_index);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_current_sycl_queue(void** ret);

#ifdef __cplusplus
} // extern "C"
#endif
@ -30,8 +30,46 @@ AOTITorchError aoti_torch_xpu_guard_set_index(
      { reinterpret_cast<at::DeviceGuard*>(guard)->set_index(device_index); });
}

AOTITorchError aoti_torch_create_xpu_stream_guard(
    void* stream,
    int32_t device_index,
    XPUStreamGuardHandle* ret_guard) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    assert(stream);
    at::StreamGuard* guard =
        new at::StreamGuard(at::xpu::getStreamFromExternal(
                                static_cast<sycl::queue*>(stream), device_index)
                                .unwrap());
    *ret_guard = reinterpret_cast<XPUStreamGuardHandle>(guard);
  });
}

AOTITorchError aoti_torch_delete_xpu_stream_guard(XPUStreamGuardHandle guard) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
      { delete reinterpret_cast<at::StreamGuard*>(guard); });
}

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_current_xpu_stream(int32_t device_index, void** ret_stream) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *ret_stream = &(at::xpu::getCurrentXPUStream(device_index).queue()); });
}

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_current_xpu_device(int32_t* device_index) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *device_index = static_cast<int32_t>(c10::xpu::current_device()); });
}

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_set_current_xpu_device(const int32_t& device_index) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
      { c10::xpu::set_device(static_cast<int8_t>(device_index)); });
}

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_current_sycl_queue(void** ret) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    int32_t device_index = static_cast<int32_t>(c10::xpu::current_device());
    *ret = &(at::xpu::getCurrentXPUStream(device_index).queue());
  });
}