[XPU] Enhance XPUGeneratorImpl functionality to support XPUGraph (#163332)

As this [XPUGraph RFC](https://github.com/pytorch/pytorch/issues/162143) descripted. This PR enhances `XPUGeneratorImpl` to support XPUGraph.
In this PR, we add `XPUGerneratorState` and `PhiloxXpuState`. Which makes XPUGraph update philox state during graph capture and replay correctly

XPUGraph PR submission plan:

- [ ] 1, Enhance XPUGenerator functionality. Add XPUGeneratorState and philoxState
- [ ] 2, implemenet XPUGraph capture_begin/capture_end/instantiate functionality
- [ ] 3, implemenet XPUGraph replay/debug_dump/reset functionality
- [ ] 4, python APIs: is_current_stream_capturing/graph_pool_handle/graph
- [ ] 5, python APIs: make_graphed_callables

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163332
Approved by: https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/albanD
This commit is contained in:
Ma, Jing1
2025-10-13 02:10:41 +00:00
committed by PyTorch MergeBot
parent 8de85896e0
commit 59ad8f1ac6
6 changed files with 264 additions and 22 deletions

View File

@ -80,3 +80,19 @@ TEST(XpuGeneratorTest, testMultithreadingGetSetCurrentSeed) {
t2.join();
EXPECT_EQ(gen1.current_seed(), initial_seed+3);
}
TEST(XpuGeneratorTest, testRNGForking) {
// See Note [Acquire lock when using random generators]
if (!at::xpu::is_available()) return;
auto default_gen = at::xpu::detail::getDefaultXPUGenerator();
auto current_gen = at::xpu::detail::createXPUGenerator();
{
std::lock_guard<std::mutex> lock(default_gen.mutex());
current_gen = default_gen.clone(); // capture the current state of default generator
}
auto target_value = at::randn({1000}, at::kXPU);
// Dramatically alter the internal state of the main generator
auto x = at::randn({100000}, at::kXPU);
auto forked_value = at::randn({1000}, current_gen, at::kXPU);
ASSERT_EQ(target_value.sum().item<double>(), forked_value.sum().item<double>());
}

View File

@ -0,0 +1,45 @@
#pragma once
namespace at {
struct PhiloxXpuState {
PhiloxXpuState() = default;
PhiloxXpuState(uint64_t seed, uint64_t offset) {
seed_.val = seed;
offset_.val = offset;
}
// for graph capture
PhiloxXpuState(
int64_t* seed,
int64_t* offset_extragraph,
uint32_t offset_intragraph) {
seed_.ptr = seed;
offset_.ptr = offset_extragraph;
offset_intragraph_ = offset_intragraph;
captured_ = true;
}
union Payload {
uint64_t val;
int64_t* ptr;
};
Payload seed_{};
Payload offset_{};
uint32_t offset_intragraph_ = 0;
bool captured_ = false;
};
namespace xpu::philox {
inline std::tuple<uint64_t, uint64_t> unpack(at::PhiloxXpuState arg) {
if (arg.captured_) {
return std::make_tuple(
static_cast<uint64_t>(*arg.seed_.ptr),
static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
} else {
return std::make_tuple(arg.seed_.val, arg.offset_.val);
}
}
} // namespace xpu::philox
} // namespace at

View File

@ -1,9 +1,14 @@
#include <ATen/Functions.h>
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <ATen/xpu/XPUGeneratorImpl.h>
#include <ATen/xpu/XPUGraphsUtils.h>
#include <c10/core/StreamGuard.h>
#include <c10/util/CallOnce.h>
#include <c10/xpu/XPUFunctions.h>
constexpr uint64_t PHILOX_ROUND_SIZE = 4;
namespace at {
namespace xpu::detail {
namespace {
@ -58,29 +63,82 @@ Generator createXPUGenerator(DeviceIndex device) {
} // namespace xpu::detail
// Creates a clone of this XPU Generator State.
c10::intrusive_ptr<XPUGeneratorState> XPUGeneratorState::clone() {
return make_intrusive<XPUGeneratorState>(
seed_, philox_offset_per_thread_, offset_intragraph_);
}
// Function to increase the internal offset based on the specified increment.
void XPUGeneratorState::increase(uint64_t increment) {
increment = ((increment + PHILOX_ROUND_SIZE - 1) / PHILOX_ROUND_SIZE) *
PHILOX_ROUND_SIZE;
if (at::xpu::currentStreamCaptureStatus() !=
at::xpu::CaptureStatus::Executing) {
TORCH_INTERNAL_ASSERT(
capturing_,
"Attempt to increase offset for a XPU generator not in capture mode.");
TORCH_INTERNAL_ASSERT(
offset_intragraph_ % 4 == 0, "RNG offset must be a multiple of 4.");
TORCH_INTERNAL_ASSERT(
offset_intragraph_ <= std::numeric_limits<uint32_t>::max() - increment,
"Increment causes overflow in the offset value.");
offset_intragraph_ += increment;
} else {
TORCH_INTERNAL_ASSERT(
!capturing_,
"Offset increment outside graph capture encountered unexpectedly.");
TORCH_INTERNAL_ASSERT(
philox_offset_per_thread_ % 4 == 0,
"RNG offset must be a multiple of 4.");
philox_offset_per_thread_ += increment;
}
}
XPUGeneratorImpl::XPUGeneratorImpl(DeviceIndex device_index)
: GeneratorImpl{
Device(DeviceType::XPU, device_index),
DispatchKeySet(c10::DispatchKey::XPU)} {}
DispatchKeySet(c10::DispatchKey::XPU)} {
at::xpu::assertNotCapturing("Cannot construct a new XPUGeneratorImpl");
state_ = make_intrusive<XPUGeneratorState>();
}
XPUGeneratorImpl::XPUGeneratorImpl(
DeviceIndex device_index,
intrusive_ptr<XPUGeneratorState> state)
: GeneratorImpl{Device(DeviceType::XPU, device_index), DispatchKeySet(c10::DispatchKey::XPU)},
state_(std::move(state)) {}
void XPUGeneratorImpl::set_current_seed(uint64_t seed) {
seed_ = seed;
set_philox_offset_per_thread(0);
if (C10_LIKELY(
at::xpu::currentStreamCaptureStatus() ==
at::xpu::CaptureStatus::Executing)) {
state_->seed_ = seed;
state_->philox_offset_per_thread_ = 0;
} else {
TORCH_CHECK(
state_->seed_ == seed,
"XPUGeneratorImpl::set_current_seed can be called during stream capture only if new seed is the same as the original seed.");
}
}
void XPUGeneratorImpl::set_offset(uint64_t offset) {
at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::set_offset");
set_philox_offset_per_thread(offset);
}
uint64_t XPUGeneratorImpl::get_offset() const {
return philox_offset_per_thread_;
at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::get_offset");
return state_->philox_offset_per_thread_;
}
uint64_t XPUGeneratorImpl::current_seed() const {
return seed_;
at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::current_seed");
return state_->seed_;
}
uint64_t XPUGeneratorImpl::seed() {
at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::seed");
auto random = c10::detail::getNonDeterministicRandom(true);
this->set_current_seed(random);
return random;
@ -110,39 +168,65 @@ c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const {
}
void XPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
at::xpu::assertNotCapturing(
"Please ensure to utilize the XPUGeneratorImpl::set_state_index method during capturing.");
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(uint64_t);
static const size_t total_size = seed_size + offset_size;
at::detail::check_rng_state(new_state);
auto new_state_size = new_state.numel();
TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size");
uint64_t input_seed;
bool no_philox_seed = false;
auto new_state_size = new_state.numel();
if (new_state_size == total_size - offset_size) {
no_philox_seed = true;
} else {
TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size");
}
uint64_t input_seed = 0;
auto new_rng_state = new_state.data_dtype_initialized<uint8_t>();
memcpy(&input_seed, new_rng_state, seed_size);
this->set_current_seed(input_seed);
uint64_t philox_offset;
uint64_t philox_offset = 0;
if (!no_philox_seed) {
memcpy(&philox_offset, new_rng_state + seed_size, offset_size);
}
this->set_philox_offset_per_thread(philox_offset);
}
void XPUGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) {
TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4");
philox_offset_per_thread_ = offset;
state_->philox_offset_per_thread_ = offset;
}
uint64_t XPUGeneratorImpl::philox_offset_per_thread() const {
return philox_offset_per_thread_;
return state_->philox_offset_per_thread_;
}
PhiloxXpuState XPUGeneratorImpl::philox_xpu_state(uint64_t increment) {
if (at::xpu::currentStreamCaptureStatus() !=
at::xpu::CaptureStatus::Executing) {
uint32_t offset = state_->offset_intragraph_;
state_->increase(increment);
return PhiloxXpuState(
state_->seed_extragraph_.data_ptr<int64_t>(),
state_->offset_extragraph_.data_ptr<int64_t>(),
offset);
} else {
uint64_t offset = state_->philox_offset_per_thread_;
state_->increase(increment);
return PhiloxXpuState(state_->seed_, offset);
}
}
std::pair<uint64_t, uint64_t> XPUGeneratorImpl::philox_engine_inputs(
uint64_t increment) {
increment = ((increment + 3) / 4) * 4;
TORCH_INTERNAL_ASSERT(this->philox_offset_per_thread_ % 4 == 0);
uint64_t offset = this->philox_offset_per_thread_;
this->philox_offset_per_thread_ += increment;
return std::make_pair(this->seed_, offset);
at::xpu::assertNotCapturing(
"Refactor this op to use XPUGeneratorImpl::philox_xpu_state. Cannot call XPUGeneratorImpl::philox_engine_inputs");
uint64_t offset = state_->philox_offset_per_thread_;
state_->increase(increment);
return std::make_pair(state_->seed_, offset);
}
DeviceType XPUGeneratorImpl::device_type() {
@ -154,9 +238,8 @@ std::shared_ptr<XPUGeneratorImpl> XPUGeneratorImpl::clone() const {
}
XPUGeneratorImpl* XPUGeneratorImpl::clone_impl() const {
auto gen = new XPUGeneratorImpl(this->device().index());
gen->set_current_seed(this->seed_);
gen->set_philox_offset_per_thread(this->philox_offset_per_thread_);
at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::clone_impl");
auto gen = new XPUGeneratorImpl(this->device().index(), state_->clone());
return gen;
}

View File

@ -1,12 +1,43 @@
#pragma once
#include <ATen/core/Generator.h>
#include <ATen/core/TensorBase.h>
#include <ATen/xpu/PhiloxXpuState.h>
#include <unordered_set>
namespace at {
namespace xpu {
struct XPUGraph;
}
struct XPUGeneratorState : public c10::intrusive_ptr_target {
uint64_t seed_;
uint64_t philox_offset_per_thread_;
uint32_t offset_intragraph_;
bool capturing_{};
at::TensorBase seed_extragraph_{};
at::TensorBase offset_extragraph_{};
XPUGeneratorState(
uint64_t seed = default_rng_seed_val,
uint64_t philox_offset_per_thread = 0,
uint32_t offset_intragraph = 0)
: seed_(seed),
philox_offset_per_thread_(philox_offset_per_thread),
offset_intragraph_(offset_intragraph) {}
void increase(uint64_t increment);
c10::intrusive_ptr<XPUGeneratorState> clone();
};
struct TORCH_XPU_API XPUGeneratorImpl : public GeneratorImpl {
// Constructors
XPUGeneratorImpl(DeviceIndex device_index = -1);
XPUGeneratorImpl(
DeviceIndex device_index,
c10::intrusive_ptr<XPUGeneratorState> state_);
~XPUGeneratorImpl() override = default;
// XPUGeneratorImpl methods
@ -18,15 +49,18 @@ struct TORCH_XPU_API XPUGeneratorImpl : public GeneratorImpl {
uint64_t seed() override;
void set_state(const c10::TensorImpl& new_state) override;
c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
void set_philox_offset_per_thread(uint64_t offset);
uint64_t philox_offset_per_thread() const;
PhiloxXpuState philox_xpu_state(uint64_t increment);
// will remove once all ops are refactored to use philox_xpu_state.
std::pair<uint64_t, uint64_t> philox_engine_inputs(uint64_t increment);
static c10::DeviceType device_type();
private:
XPUGeneratorImpl* clone_impl() const override;
uint64_t seed_ = default_rng_seed_val;
uint64_t philox_offset_per_thread_ = 0;
c10::intrusive_ptr<XPUGeneratorState> state_;
};
namespace xpu::detail {

View File

@ -0,0 +1,22 @@
#pragma once
#include <c10/xpu/XPUGraphsC10Utils.h>
namespace at::xpu {
inline CaptureStatus currentStreamCaptureStatus() {
return c10::xpu::currentStreamCaptureStatusMayInitCtx();
}
inline void assertNotCapturing(const std::string& attempt) {
auto status = currentStreamCaptureStatus();
TORCH_CHECK(
status == CaptureStatus::Executing,
attempt,
" during XPU graph capture. If you need this call to be captured, "
"please file an issue. "
"Current xpuStreamCaptureStatus: ",
status);
}
} // namespace at::xpu

View File

@ -0,0 +1,42 @@
#pragma once
#include <c10/xpu/XPUStream.h>
#include <iostream>
// XPU Graphs utils used by c10 and aten.
using namespace sycl::ext::oneapi::experimental;
namespace c10::xpu {
static_assert(
int8_t(queue_state::executing) == 0,
"unexpected int(queue_state::executing) value");
static_assert(
int8_t(queue_state::recording) == 1,
"unexpected int(queue_state::recording) value");
enum class CaptureStatus : int8_t {
Executing = int8_t(queue_state::executing),
Recording = int8_t(queue_state::recording)
};
inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) {
switch (status) {
case CaptureStatus::Executing:
os << "Executing";
break;
case CaptureStatus::Recording:
os << "Recording";
break;
default:
TORCH_INTERNAL_ASSERT(
false, "Unknown XPU graph CaptureStatus", int(status));
}
return os;
}
inline CaptureStatus currentStreamCaptureStatusMayInitCtx() {
auto state = c10::xpu::getCurrentXPUStream().queue().ext_oneapi_get_state();
return CaptureStatus(state);
}
} // namespace c10::xpu