Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Intel GPU] device guard codegen for XPU (#133980)
This PR is a supplement to #130082. The previous PR #130082 implemented the basic codegen functionality, but we found that it fails to handle the device-sameness check in many unit tests. This PR enables XPU device guard code generation. With it, the code snippet in `RegisterXPU.cpp` is as follows, where the device guard is generated correctly:

```c++
namespace {

at::Tensor & wrapper_XPU_Tensor_float_out_normal_out(const at::Tensor & mean, double std, ::std::optional<at::Generator> generator, at::Tensor & out) {
  std::optional<Device> common_device = std::nullopt;
  (void)common_device; // Suppress unused variable warning
  c10::impl::check_and_update_common_device(common_device, out, "wrapper_XPU_Tensor_float_out_normal_out", "out");
  c10::impl::check_and_update_common_device(common_device, mean, "wrapper_XPU_Tensor_float_out_normal_out", "mean");
  const OptionalDeviceGuard device_guard(device_of(out));
  return at::native::normal_out(mean, std, generator, out);
}

} // anonymous namespace
```

Without this change, the generated code is:

```c++
namespace {

at::Tensor & wrapper_XPU_Tensor_float_out_normal_out(const at::Tensor & mean, double std, ::std::optional<at::Generator> generator, at::Tensor & out) {
  // No device check
  // DeviceGuard omitted
  return at::native::normal_out(mean, std, generator, out);
}

} // anonymous namespace
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133980
Approved by: https://github.com/EikanWang, https://github.com/malfet
Committed by: PyTorch MergeBot
Parent: 30b98940b8
Commit: a7a53b796b
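The two generated snippets in the description differ only in the guard lines. As a minimal sketch, assuming nothing beyond what those snippets show, this illustrative Python (not the actual torchgen emitter) captures how a backend's `device_guard` flag gates what the emitted wrapper contains:

```python
# Minimal sketch, not the actual torchgen emitter: the backend's
# device_guard flag decides which lines the generated C++ wrapper gets.
def guard_snippet(backend_requests_guard: bool, guarded_arg: str = "out") -> str:
    if not backend_requests_guard:
        return "// No device check\n// DeviceGuard omitted"
    return (
        "std::optional<Device> common_device = std::nullopt;\n"
        f"const OptionalDeviceGuard device_guard(device_of({guarded_arg}));"
    )


# Before this PR the XPU backend fell into the first branch; with it, the second.
print(guard_snippet(False))
print(guard_snippet(True))
```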
```diff
@@ -59,6 +59,7 @@ from torchgen.model import (
     is_cuda_dispatch_key,
     is_generic_dispatch_key,
     is_ufunc_dispatch_key,
+    is_xpu_dispatch_key,
     Location,
     NativeFunction,
     NativeFunctionsGroup,
@@ -184,7 +185,7 @@ def parse_native_yaml_struct(
             use_out_as_primary=True,
             external=False,
             # Only cuda-like devices in tree require device guards
-            device_guard=is_cuda_dispatch_key(k),
+            device_guard=is_cuda_dispatch_key(k) or is_xpu_dispatch_key(k),
             index=v,
         )
     return ParsedYaml(rs, indices)
```
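These hunks (which appear to belong to torchgen's native-functions parsing, e.g. `torchgen/gen.py`) make XPU dispatch keys request device guards the same way CUDA keys do. A quick, hypothetical check of the combined predicate, assuming a torchgen checkout that already contains this change:

```python
# Hypothetical check, assuming torchgen with this change is importable.
from torchgen.model import DispatchKey, is_cuda_dispatch_key, is_xpu_dispatch_key

for k in (DispatchKey.CUDA, DispatchKey.XPU, DispatchKey.CPU):
    # BackendIndex.device_guard is now computed like this:
    print(k, is_cuda_dispatch_key(k) or is_xpu_dispatch_key(k))
# Expected: CUDA -> True, XPU -> True, CPU -> False
```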
```diff
@@ -323,6 +323,18 @@ def is_cuda_dispatch_key(dk: DispatchKey) -> bool:
     }
 
 
+# XPU specific dispatcy keys
+def is_xpu_dispatch_key(dk: DispatchKey) -> bool:
+    return dk in {
+        DispatchKey.XPU,
+        DispatchKey.QuantizedXPU,
+        DispatchKey.SparseXPU,
+        DispatchKey.SparseCsrXPU,
+        DispatchKey.NestedTensorXPU,
+        DispatchKey.AutogradXPU,
+    }
+
+
 # Structured kernel generation is only supported for certain key types;
 # otherwise use old-style
 def is_structured_dispatch_key(dk: DispatchKey) -> bool:
```
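The new `is_xpu_dispatch_key` helper mirrors `is_cuda_dispatch_key` and covers the XPU backend variants listed above. A small, hypothetical sanity check of the helper (again assuming the torchgen module with this change is on the path):

```python
# Hypothetical sanity check of the new predicate.
from torchgen.model import DispatchKey, is_xpu_dispatch_key

xpu_keys = [k for k in DispatchKey if is_xpu_dispatch_key(k)]
assert all("XPU" in k.name for k in xpu_keys)
assert DispatchKey.AutogradXPU in xpu_keys and DispatchKey.CUDA not in xpu_keys
print(sorted(k.name for k in xpu_keys))
```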