mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This PR is a supplement to https://github.com/pytorch/pytorch/pull/133980. The previous PR fulfilled the basic functionality of the XPU device guard, but we found that it fails to address structured operators. With the current PR, the code snippet in RegisterXPU.cpp is as follows, where we can see the device guard is successfully generated. ```c++ struct structured_exp_out_functional final : public at::native::structured_exp_out { void set_output_strided( int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names ) override { auto current_device = guard_.current_device(); if (C10_UNLIKELY(current_device.has_value())) { TORCH_INTERNAL_ASSERT(*current_device == options.device(), "structured kernels don't support multi-device outputs"); } else { guard_.reset_device(options.device()); } outputs_[output_idx] = create_out(sizes, strides, options); if (!names.empty()) { namedinference::propagate_names(outputs_[output_idx], names); } // super must happen after, so that downstream can use maybe_get_output // to retrieve the output at::native::structured_exp_out::set_output_raw_strided(output_idx, sizes, strides, options, names); } void set_output_raw_strided( int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names ) override { auto current_device = guard_.current_device(); if (C10_UNLIKELY(current_device.has_value())) { TORCH_INTERNAL_ASSERT(*current_device == options.device(), "structured kernels don't support multi-device outputs"); } else { guard_.reset_device(options.device()); } outputs_[output_idx] = create_out(sizes, strides, options); if (!names.empty()) { namedinference::propagate_names(outputs_[output_idx], names); } // super must happen after, so that downstream can use maybe_get_output // to retrieve the output at::native::structured_exp_out::set_output_raw_strided(output_idx, sizes, strides, options, names); } const Tensor& maybe_get_output(int64_t output_idx) override { return 
outputs_[output_idx]; } std::array<Tensor, 1> outputs_; c10::OptionalDeviceGuard guard_; }; ``` However, without the current change, the generated code is ```c++ struct structured_exp_out_functional final : public at::native::structured_exp_out { void set_output_strided( int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names ) override { outputs_[output_idx] = create_out(sizes, strides, options); if (!names.empty()) { namedinference::propagate_names(outputs_[output_idx], names); } // super must happen after, so that downstream can use maybe_get_output // to retrieve the output at::native::structured_exp_out::set_output_raw_strided(output_idx, sizes, strides, options, names); } void set_output_raw_strided( int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names ) override { outputs_[output_idx] = create_out(sizes, strides, options); if (!names.empty()) { namedinference::propagate_names(outputs_[output_idx], names); } // super must happen after, so that downstream can use maybe_get_output // to retrieve the output at::native::structured_exp_out::set_output_raw_strided(output_idx, sizes, strides, options, names); } const Tensor& maybe_get_output(int64_t output_idx) override { return outputs_[output_idx]; } std::array<Tensor, 1> outputs_; }; ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/138802 Approved by: https://github.com/EikanWang, https://github.com/guangyey, https://github.com/ezyang