This PR is a supplement to #130082. The previous PR #130082 provides the basic codegen functionality, but we found that it fails to handle the device-sameness check in many unit tests. The current PR enables XPU device guard code generation. With this PR, the code snippet in `RegisterXPU.cpp` is as follows, where we can see the device guard is successfully generated.

```c++
namespace {
at::Tensor & wrapper_XPU_Tensor_float_out_normal_out(const at::Tensor & mean, double std, ::std::optional<at::Generator> generator, at::Tensor & out) {
  std::optional<Device> common_device = std::nullopt;
  (void)common_device; // Suppress unused variable warning
  c10::impl::check_and_update_common_device(common_device, out, "wrapper_XPU_Tensor_float_out_normal_out", "out");
  c10::impl::check_and_update_common_device(common_device, mean, "wrapper_XPU_Tensor_float_out_normal_out", "mean");
  const OptionalDeviceGuard device_guard(device_of(out));
  return at::native::normal_out(mean, std, generator, out);
}
} // anonymous namespace
```

Without the current change, by contrast, the generated code is

```c++
namespace {
at::Tensor & wrapper_XPU_Tensor_float_out_normal_out(const at::Tensor & mean, double std, ::std::optional<at::Generator> generator, at::Tensor & out) {
  // No device check
  // DeviceGuard omitted
  return at::native::normal_out(mean, std, generator, out);
}
} // anonymous namespace
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133980
Approved by: https://github.com/EikanWang, https://github.com/malfet
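For illustration, the sketch below shows what the generated `check_and_update_common_device` calls mean for a caller: mixing tensors from different devices in one out-variant call is rejected with a clear error before the kernel runs. This example is not part of the PR; it assumes a PyTorch build with XPU support, and the exact error message may differ.

```c++
// Minimal sketch (not from the PR): exercising the device-sameness check
// that the generated wrapper now performs for XPU kernels.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Assumes an XPU-enabled build; at::kXPU tensors require an XPU device at runtime.
  at::Tensor mean = at::zeros({4}, at::TensorOptions().device(at::kXPU));
  at::Tensor out  = at::empty({4});  // defaults to CPU

  try {
    // With the generated check_and_update_common_device calls, an XPU `mean`
    // combined with a CPU `out` is detected and rejected up front.
    at::normal_out(out, mean, /*std=*/1.0);
  } catch (const c10::Error& e) {
    std::cout << "device mismatch detected: " << e.what_without_backtrace() << "\n";
  }
  return 0;
}
```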