diff --git a/aten/src/ATen/test/xpu_generator_test.cpp b/aten/src/ATen/test/xpu_generator_test.cpp index f47ca4d72118..0b915c1b0cc9 100644 --- a/aten/src/ATen/test/xpu_generator_test.cpp +++ b/aten/src/ATen/test/xpu_generator_test.cpp @@ -80,3 +80,19 @@ TEST(XpuGeneratorTest, testMultithreadingGetSetCurrentSeed) { t2.join(); EXPECT_EQ(gen1.current_seed(), initial_seed+3); } + +TEST(XpuGeneratorTest, testRNGForking) { + // See Note [Acquire lock when using random generators] + if (!at::xpu::is_available()) return; + auto default_gen = at::xpu::detail::getDefaultXPUGenerator(); + auto current_gen = at::xpu::detail::createXPUGenerator(); + { + std::lock_guard lock(default_gen.mutex()); + current_gen = default_gen.clone(); // capture the current state of default generator + } + auto target_value = at::randn({1000}, at::kXPU); + // Dramatically alter the internal state of the main generator + auto x = at::randn({100000}, at::kXPU); + auto forked_value = at::randn({1000}, current_gen, at::kXPU); + ASSERT_EQ(target_value.sum().item(), forked_value.sum().item()); +} diff --git a/aten/src/ATen/xpu/PhiloxXpuState.h b/aten/src/ATen/xpu/PhiloxXpuState.h new file mode 100644 index 000000000000..039b992b89ba --- /dev/null +++ b/aten/src/ATen/xpu/PhiloxXpuState.h @@ -0,0 +1,45 @@ +#pragma once + +namespace at { + +struct PhiloxXpuState { + PhiloxXpuState() = default; + PhiloxXpuState(uint64_t seed, uint64_t offset) { + seed_.val = seed; + offset_.val = offset; + } + // for graph capture + PhiloxXpuState( + int64_t* seed, + int64_t* offset_extragraph, + uint32_t offset_intragraph) { + seed_.ptr = seed; + offset_.ptr = offset_extragraph; + offset_intragraph_ = offset_intragraph; + captured_ = true; + } + + union Payload { + uint64_t val; + int64_t* ptr; + }; + + Payload seed_{}; + Payload offset_{}; + uint32_t offset_intragraph_ = 0; + bool captured_ = false; +}; + +namespace xpu::philox { +inline std::tuple unpack(at::PhiloxXpuState arg) { + if (arg.captured_) { + return std::make_tuple( + static_cast(*arg.seed_.ptr), + static_cast(*(arg.offset_.ptr) + arg.offset_intragraph_)); + } else { + return std::make_tuple(arg.seed_.val, arg.offset_.val); + } +} + +} // namespace xpu::philox +} // namespace at diff --git a/aten/src/ATen/xpu/XPUGeneratorImpl.cpp b/aten/src/ATen/xpu/XPUGeneratorImpl.cpp index 1af0f4f890df..14f3059cc2b3 100644 --- a/aten/src/ATen/xpu/XPUGeneratorImpl.cpp +++ b/aten/src/ATen/xpu/XPUGeneratorImpl.cpp @@ -1,9 +1,14 @@ +#include +#include #include #include +#include #include #include #include +constexpr uint64_t PHILOX_ROUND_SIZE = 4; + namespace at { namespace xpu::detail { namespace { @@ -58,29 +63,82 @@ Generator createXPUGenerator(DeviceIndex device) { } // namespace xpu::detail +// Creates a clone of this XPU Generator State. +c10::intrusive_ptr XPUGeneratorState::clone() { + return make_intrusive( + seed_, philox_offset_per_thread_, offset_intragraph_); +} + +// Function to increase the internal offset based on the specified increment. +void XPUGeneratorState::increase(uint64_t increment) { + increment = ((increment + PHILOX_ROUND_SIZE - 1) / PHILOX_ROUND_SIZE) * + PHILOX_ROUND_SIZE; + if (at::xpu::currentStreamCaptureStatus() != + at::xpu::CaptureStatus::Executing) { + TORCH_INTERNAL_ASSERT( + capturing_, + "Attempt to increase offset for a XPU generator not in capture mode."); + TORCH_INTERNAL_ASSERT( + offset_intragraph_ % 4 == 0, "RNG offset must be a multiple of 4."); + TORCH_INTERNAL_ASSERT( + offset_intragraph_ <= std::numeric_limits::max() - increment, + "Increment causes overflow in the offset value."); + offset_intragraph_ += increment; + } else { + TORCH_INTERNAL_ASSERT( + !capturing_, + "Offset increment outside graph capture encountered unexpectedly."); + TORCH_INTERNAL_ASSERT( + philox_offset_per_thread_ % 4 == 0, + "RNG offset must be a multiple of 4."); + philox_offset_per_thread_ += increment; + } +} + XPUGeneratorImpl::XPUGeneratorImpl(DeviceIndex device_index) : GeneratorImpl{ Device(DeviceType::XPU, device_index), - DispatchKeySet(c10::DispatchKey::XPU)} {} + DispatchKeySet(c10::DispatchKey::XPU)} { + at::xpu::assertNotCapturing("Cannot construct a new XPUGeneratorImpl"); + state_ = make_intrusive(); +} + +XPUGeneratorImpl::XPUGeneratorImpl( + DeviceIndex device_index, + intrusive_ptr state) + : GeneratorImpl{Device(DeviceType::XPU, device_index), DispatchKeySet(c10::DispatchKey::XPU)}, + state_(std::move(state)) {} void XPUGeneratorImpl::set_current_seed(uint64_t seed) { - seed_ = seed; - set_philox_offset_per_thread(0); + if (C10_LIKELY( + at::xpu::currentStreamCaptureStatus() == + at::xpu::CaptureStatus::Executing)) { + state_->seed_ = seed; + state_->philox_offset_per_thread_ = 0; + } else { + TORCH_CHECK( + state_->seed_ == seed, + "XPUGeneratorImpl::set_current_seed can be called during stream capture only if new seed is the same as the original seed."); + } } void XPUGeneratorImpl::set_offset(uint64_t offset) { + at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::set_offset"); set_philox_offset_per_thread(offset); } uint64_t XPUGeneratorImpl::get_offset() const { - return philox_offset_per_thread_; + at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::get_offset"); + return state_->philox_offset_per_thread_; } uint64_t XPUGeneratorImpl::current_seed() const { - return seed_; + at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::current_seed"); + return state_->seed_; } uint64_t XPUGeneratorImpl::seed() { + at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::seed"); auto random = c10::detail::getNonDeterministicRandom(true); this->set_current_seed(random); return random; @@ -110,39 +168,65 @@ c10::intrusive_ptr XPUGeneratorImpl::get_state() const { } void XPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) { + at::xpu::assertNotCapturing( + "Please ensure to utilize the XPUGeneratorImpl::set_state_index method during capturing."); static const size_t seed_size = sizeof(uint64_t); static const size_t offset_size = sizeof(uint64_t); static const size_t total_size = seed_size + offset_size; at::detail::check_rng_state(new_state); - auto new_state_size = new_state.numel(); - TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size"); - uint64_t input_seed; + bool no_philox_seed = false; + auto new_state_size = new_state.numel(); + if (new_state_size == total_size - offset_size) { + no_philox_seed = true; + } else { + TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size"); + } + + uint64_t input_seed = 0; auto new_rng_state = new_state.data_dtype_initialized(); memcpy(&input_seed, new_rng_state, seed_size); this->set_current_seed(input_seed); - uint64_t philox_offset; - memcpy(&philox_offset, new_rng_state + seed_size, offset_size); + uint64_t philox_offset = 0; + if (!no_philox_seed) { + memcpy(&philox_offset, new_rng_state + seed_size, offset_size); + } this->set_philox_offset_per_thread(philox_offset); } void XPUGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) { TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4"); - philox_offset_per_thread_ = offset; + state_->philox_offset_per_thread_ = offset; } uint64_t XPUGeneratorImpl::philox_offset_per_thread() const { - return philox_offset_per_thread_; + return state_->philox_offset_per_thread_; +} + +PhiloxXpuState XPUGeneratorImpl::philox_xpu_state(uint64_t increment) { + if (at::xpu::currentStreamCaptureStatus() != + at::xpu::CaptureStatus::Executing) { + uint32_t offset = state_->offset_intragraph_; + state_->increase(increment); + return PhiloxXpuState( + state_->seed_extragraph_.data_ptr(), + state_->offset_extragraph_.data_ptr(), + offset); + } else { + uint64_t offset = state_->philox_offset_per_thread_; + state_->increase(increment); + return PhiloxXpuState(state_->seed_, offset); + } } std::pair XPUGeneratorImpl::philox_engine_inputs( uint64_t increment) { - increment = ((increment + 3) / 4) * 4; - TORCH_INTERNAL_ASSERT(this->philox_offset_per_thread_ % 4 == 0); - uint64_t offset = this->philox_offset_per_thread_; - this->philox_offset_per_thread_ += increment; - return std::make_pair(this->seed_, offset); + at::xpu::assertNotCapturing( + "Refactor this op to use XPUGeneratorImpl::philox_xpu_state. Cannot call XPUGeneratorImpl::philox_engine_inputs"); + uint64_t offset = state_->philox_offset_per_thread_; + state_->increase(increment); + return std::make_pair(state_->seed_, offset); } DeviceType XPUGeneratorImpl::device_type() { @@ -154,9 +238,8 @@ std::shared_ptr XPUGeneratorImpl::clone() const { } XPUGeneratorImpl* XPUGeneratorImpl::clone_impl() const { - auto gen = new XPUGeneratorImpl(this->device().index()); - gen->set_current_seed(this->seed_); - gen->set_philox_offset_per_thread(this->philox_offset_per_thread_); + at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::clone_impl"); + auto gen = new XPUGeneratorImpl(this->device().index(), state_->clone()); return gen; } diff --git a/aten/src/ATen/xpu/XPUGeneratorImpl.h b/aten/src/ATen/xpu/XPUGeneratorImpl.h index a1f264382a36..331f7387a629 100644 --- a/aten/src/ATen/xpu/XPUGeneratorImpl.h +++ b/aten/src/ATen/xpu/XPUGeneratorImpl.h @@ -1,12 +1,43 @@ #pragma once #include +#include +#include +#include namespace at { +namespace xpu { +struct XPUGraph; +} + +struct XPUGeneratorState : public c10::intrusive_ptr_target { + uint64_t seed_; + uint64_t philox_offset_per_thread_; + uint32_t offset_intragraph_; + bool capturing_{}; + at::TensorBase seed_extragraph_{}; + at::TensorBase offset_extragraph_{}; + + XPUGeneratorState( + uint64_t seed = default_rng_seed_val, + uint64_t philox_offset_per_thread = 0, + uint32_t offset_intragraph = 0) + : seed_(seed), + philox_offset_per_thread_(philox_offset_per_thread), + offset_intragraph_(offset_intragraph) {} + + void increase(uint64_t increment); + + c10::intrusive_ptr clone(); +}; + struct TORCH_XPU_API XPUGeneratorImpl : public GeneratorImpl { // Constructors XPUGeneratorImpl(DeviceIndex device_index = -1); + XPUGeneratorImpl( + DeviceIndex device_index, + c10::intrusive_ptr state_); ~XPUGeneratorImpl() override = default; // XPUGeneratorImpl methods @@ -18,15 +49,18 @@ struct TORCH_XPU_API XPUGeneratorImpl : public GeneratorImpl { uint64_t seed() override; void set_state(const c10::TensorImpl& new_state) override; c10::intrusive_ptr get_state() const override; + void set_philox_offset_per_thread(uint64_t offset); uint64_t philox_offset_per_thread() const; + + PhiloxXpuState philox_xpu_state(uint64_t increment); + // will remove once all ops are refactored to use philox_xpu_state. std::pair philox_engine_inputs(uint64_t increment); static c10::DeviceType device_type(); private: XPUGeneratorImpl* clone_impl() const override; - uint64_t seed_ = default_rng_seed_val; - uint64_t philox_offset_per_thread_ = 0; + c10::intrusive_ptr state_; }; namespace xpu::detail { diff --git a/aten/src/ATen/xpu/XPUGraphsUtils.h b/aten/src/ATen/xpu/XPUGraphsUtils.h new file mode 100644 index 000000000000..b18fe4ef0417 --- /dev/null +++ b/aten/src/ATen/xpu/XPUGraphsUtils.h @@ -0,0 +1,22 @@ +#pragma once + +#include + +namespace at::xpu { + +inline CaptureStatus currentStreamCaptureStatus() { + return c10::xpu::currentStreamCaptureStatusMayInitCtx(); +} + +inline void assertNotCapturing(const std::string& attempt) { + auto status = currentStreamCaptureStatus(); + TORCH_CHECK( + status == CaptureStatus::Executing, + attempt, + " during XPU graph capture. If you need this call to be captured, " + "please file an issue. " + "Current xpuStreamCaptureStatus: ", + status); +} + +} // namespace at::xpu diff --git a/c10/xpu/XPUGraphsC10Utils.h b/c10/xpu/XPUGraphsC10Utils.h new file mode 100644 index 000000000000..b60fc4ac30a6 --- /dev/null +++ b/c10/xpu/XPUGraphsC10Utils.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include + +// XPU Graphs utils used by c10 and aten. +using namespace sycl::ext::oneapi::experimental; +namespace c10::xpu { + +static_assert( + int8_t(queue_state::executing) == 0, + "unexpected int(queue_state::executing) value"); +static_assert( + int8_t(queue_state::recording) == 1, + "unexpected int(queue_state::recording) value"); + +enum class CaptureStatus : int8_t { + Executing = int8_t(queue_state::executing), + Recording = int8_t(queue_state::recording) +}; + +inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) { + switch (status) { + case CaptureStatus::Executing: + os << "Executing"; + break; + case CaptureStatus::Recording: + os << "Recording"; + break; + default: + TORCH_INTERNAL_ASSERT( + false, "Unknown XPU graph CaptureStatus", int(status)); + } + return os; +} + +inline CaptureStatus currentStreamCaptureStatusMayInitCtx() { + auto state = c10::xpu::getCurrentXPUStream().queue().ext_oneapi_get_state(); + return CaptureStatus(state); +} + +} // namespace c10::xpu