mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[XPU] Enhance XPUGeneratorImpl functionality to support XPUGraph (#163332)
As this [XPUGraph RFC](https://github.com/pytorch/pytorch/issues/162143) descripted. This PR enhances `XPUGeneratorImpl` to support XPUGraph. In this PR, we add `XPUGerneratorState` and `PhiloxXpuState`. Which makes XPUGraph update philox state during graph capture and replay correctly XPUGraph PR submission plan: - [ ] 1, Enhance XPUGenerator functionality. Add XPUGeneratorState and philoxState - [ ] 2, implemenet XPUGraph capture_begin/capture_end/instantiate functionality - [ ] 3, implemenet XPUGraph replay/debug_dump/reset functionality - [ ] 4, python APIs: is_current_stream_capturing/graph_pool_handle/graph - [ ] 5, python APIs: make_graphed_callables Pull Request resolved: https://github.com/pytorch/pytorch/pull/163332 Approved by: https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
8de85896e0
commit
59ad8f1ac6
@ -80,3 +80,19 @@ TEST(XpuGeneratorTest, testMultithreadingGetSetCurrentSeed) {
|
||||
t2.join();
|
||||
EXPECT_EQ(gen1.current_seed(), initial_seed+3);
|
||||
}
|
||||
|
||||
TEST(XpuGeneratorTest, testRNGForking) {
|
||||
// See Note [Acquire lock when using random generators]
|
||||
if (!at::xpu::is_available()) return;
|
||||
auto default_gen = at::xpu::detail::getDefaultXPUGenerator();
|
||||
auto current_gen = at::xpu::detail::createXPUGenerator();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(default_gen.mutex());
|
||||
current_gen = default_gen.clone(); // capture the current state of default generator
|
||||
}
|
||||
auto target_value = at::randn({1000}, at::kXPU);
|
||||
// Dramatically alter the internal state of the main generator
|
||||
auto x = at::randn({100000}, at::kXPU);
|
||||
auto forked_value = at::randn({1000}, current_gen, at::kXPU);
|
||||
ASSERT_EQ(target_value.sum().item<double>(), forked_value.sum().item<double>());
|
||||
}
|
||||
|
45
aten/src/ATen/xpu/PhiloxXpuState.h
Normal file
45
aten/src/ATen/xpu/PhiloxXpuState.h
Normal file
@ -0,0 +1,45 @@
|
||||
#pragma once
|
||||
|
||||
namespace at {
|
||||
|
||||
struct PhiloxXpuState {
|
||||
PhiloxXpuState() = default;
|
||||
PhiloxXpuState(uint64_t seed, uint64_t offset) {
|
||||
seed_.val = seed;
|
||||
offset_.val = offset;
|
||||
}
|
||||
// for graph capture
|
||||
PhiloxXpuState(
|
||||
int64_t* seed,
|
||||
int64_t* offset_extragraph,
|
||||
uint32_t offset_intragraph) {
|
||||
seed_.ptr = seed;
|
||||
offset_.ptr = offset_extragraph;
|
||||
offset_intragraph_ = offset_intragraph;
|
||||
captured_ = true;
|
||||
}
|
||||
|
||||
union Payload {
|
||||
uint64_t val;
|
||||
int64_t* ptr;
|
||||
};
|
||||
|
||||
Payload seed_{};
|
||||
Payload offset_{};
|
||||
uint32_t offset_intragraph_ = 0;
|
||||
bool captured_ = false;
|
||||
};
|
||||
|
||||
namespace xpu::philox {
|
||||
inline std::tuple<uint64_t, uint64_t> unpack(at::PhiloxXpuState arg) {
|
||||
if (arg.captured_) {
|
||||
return std::make_tuple(
|
||||
static_cast<uint64_t>(*arg.seed_.ptr),
|
||||
static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
|
||||
} else {
|
||||
return std::make_tuple(arg.seed_.val, arg.offset_.val);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xpu::philox
|
||||
} // namespace at
|
@ -1,9 +1,14 @@
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/Utils.h>
|
||||
#include <ATen/xpu/XPUGeneratorImpl.h>
|
||||
#include <ATen/xpu/XPUGraphsUtils.h>
|
||||
#include <c10/core/StreamGuard.h>
|
||||
#include <c10/util/CallOnce.h>
|
||||
#include <c10/xpu/XPUFunctions.h>
|
||||
|
||||
constexpr uint64_t PHILOX_ROUND_SIZE = 4;
|
||||
|
||||
namespace at {
|
||||
namespace xpu::detail {
|
||||
namespace {
|
||||
@ -58,29 +63,82 @@ Generator createXPUGenerator(DeviceIndex device) {
|
||||
|
||||
} // namespace xpu::detail
|
||||
|
||||
// Creates a clone of this XPU Generator State.
|
||||
c10::intrusive_ptr<XPUGeneratorState> XPUGeneratorState::clone() {
|
||||
return make_intrusive<XPUGeneratorState>(
|
||||
seed_, philox_offset_per_thread_, offset_intragraph_);
|
||||
}
|
||||
|
||||
// Function to increase the internal offset based on the specified increment.
|
||||
void XPUGeneratorState::increase(uint64_t increment) {
|
||||
increment = ((increment + PHILOX_ROUND_SIZE - 1) / PHILOX_ROUND_SIZE) *
|
||||
PHILOX_ROUND_SIZE;
|
||||
if (at::xpu::currentStreamCaptureStatus() !=
|
||||
at::xpu::CaptureStatus::Executing) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
capturing_,
|
||||
"Attempt to increase offset for a XPU generator not in capture mode.");
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
offset_intragraph_ % 4 == 0, "RNG offset must be a multiple of 4.");
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
offset_intragraph_ <= std::numeric_limits<uint32_t>::max() - increment,
|
||||
"Increment causes overflow in the offset value.");
|
||||
offset_intragraph_ += increment;
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
!capturing_,
|
||||
"Offset increment outside graph capture encountered unexpectedly.");
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
philox_offset_per_thread_ % 4 == 0,
|
||||
"RNG offset must be a multiple of 4.");
|
||||
philox_offset_per_thread_ += increment;
|
||||
}
|
||||
}
|
||||
|
||||
XPUGeneratorImpl::XPUGeneratorImpl(DeviceIndex device_index)
|
||||
: GeneratorImpl{
|
||||
Device(DeviceType::XPU, device_index),
|
||||
DispatchKeySet(c10::DispatchKey::XPU)} {}
|
||||
DispatchKeySet(c10::DispatchKey::XPU)} {
|
||||
at::xpu::assertNotCapturing("Cannot construct a new XPUGeneratorImpl");
|
||||
state_ = make_intrusive<XPUGeneratorState>();
|
||||
}
|
||||
|
||||
XPUGeneratorImpl::XPUGeneratorImpl(
|
||||
DeviceIndex device_index,
|
||||
intrusive_ptr<XPUGeneratorState> state)
|
||||
: GeneratorImpl{Device(DeviceType::XPU, device_index), DispatchKeySet(c10::DispatchKey::XPU)},
|
||||
state_(std::move(state)) {}
|
||||
|
||||
void XPUGeneratorImpl::set_current_seed(uint64_t seed) {
|
||||
seed_ = seed;
|
||||
set_philox_offset_per_thread(0);
|
||||
if (C10_LIKELY(
|
||||
at::xpu::currentStreamCaptureStatus() ==
|
||||
at::xpu::CaptureStatus::Executing)) {
|
||||
state_->seed_ = seed;
|
||||
state_->philox_offset_per_thread_ = 0;
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
state_->seed_ == seed,
|
||||
"XPUGeneratorImpl::set_current_seed can be called during stream capture only if new seed is the same as the original seed.");
|
||||
}
|
||||
}
|
||||
|
||||
void XPUGeneratorImpl::set_offset(uint64_t offset) {
|
||||
at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::set_offset");
|
||||
set_philox_offset_per_thread(offset);
|
||||
}
|
||||
|
||||
uint64_t XPUGeneratorImpl::get_offset() const {
|
||||
return philox_offset_per_thread_;
|
||||
at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::get_offset");
|
||||
return state_->philox_offset_per_thread_;
|
||||
}
|
||||
|
||||
uint64_t XPUGeneratorImpl::current_seed() const {
|
||||
return seed_;
|
||||
at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::current_seed");
|
||||
return state_->seed_;
|
||||
}
|
||||
|
||||
uint64_t XPUGeneratorImpl::seed() {
|
||||
at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::seed");
|
||||
auto random = c10::detail::getNonDeterministicRandom(true);
|
||||
this->set_current_seed(random);
|
||||
return random;
|
||||
@ -110,39 +168,65 @@ c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const {
|
||||
}
|
||||
|
||||
void XPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
|
||||
at::xpu::assertNotCapturing(
|
||||
"Please ensure to utilize the XPUGeneratorImpl::set_state_index method during capturing.");
|
||||
static const size_t seed_size = sizeof(uint64_t);
|
||||
static const size_t offset_size = sizeof(uint64_t);
|
||||
static const size_t total_size = seed_size + offset_size;
|
||||
|
||||
at::detail::check_rng_state(new_state);
|
||||
auto new_state_size = new_state.numel();
|
||||
TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size");
|
||||
|
||||
uint64_t input_seed;
|
||||
bool no_philox_seed = false;
|
||||
auto new_state_size = new_state.numel();
|
||||
if (new_state_size == total_size - offset_size) {
|
||||
no_philox_seed = true;
|
||||
} else {
|
||||
TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size");
|
||||
}
|
||||
|
||||
uint64_t input_seed = 0;
|
||||
auto new_rng_state = new_state.data_dtype_initialized<uint8_t>();
|
||||
memcpy(&input_seed, new_rng_state, seed_size);
|
||||
this->set_current_seed(input_seed);
|
||||
uint64_t philox_offset;
|
||||
memcpy(&philox_offset, new_rng_state + seed_size, offset_size);
|
||||
uint64_t philox_offset = 0;
|
||||
if (!no_philox_seed) {
|
||||
memcpy(&philox_offset, new_rng_state + seed_size, offset_size);
|
||||
}
|
||||
this->set_philox_offset_per_thread(philox_offset);
|
||||
}
|
||||
|
||||
void XPUGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) {
|
||||
TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4");
|
||||
philox_offset_per_thread_ = offset;
|
||||
state_->philox_offset_per_thread_ = offset;
|
||||
}
|
||||
|
||||
uint64_t XPUGeneratorImpl::philox_offset_per_thread() const {
|
||||
return philox_offset_per_thread_;
|
||||
return state_->philox_offset_per_thread_;
|
||||
}
|
||||
|
||||
PhiloxXpuState XPUGeneratorImpl::philox_xpu_state(uint64_t increment) {
|
||||
if (at::xpu::currentStreamCaptureStatus() !=
|
||||
at::xpu::CaptureStatus::Executing) {
|
||||
uint32_t offset = state_->offset_intragraph_;
|
||||
state_->increase(increment);
|
||||
return PhiloxXpuState(
|
||||
state_->seed_extragraph_.data_ptr<int64_t>(),
|
||||
state_->offset_extragraph_.data_ptr<int64_t>(),
|
||||
offset);
|
||||
} else {
|
||||
uint64_t offset = state_->philox_offset_per_thread_;
|
||||
state_->increase(increment);
|
||||
return PhiloxXpuState(state_->seed_, offset);
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<uint64_t, uint64_t> XPUGeneratorImpl::philox_engine_inputs(
|
||||
uint64_t increment) {
|
||||
increment = ((increment + 3) / 4) * 4;
|
||||
TORCH_INTERNAL_ASSERT(this->philox_offset_per_thread_ % 4 == 0);
|
||||
uint64_t offset = this->philox_offset_per_thread_;
|
||||
this->philox_offset_per_thread_ += increment;
|
||||
return std::make_pair(this->seed_, offset);
|
||||
at::xpu::assertNotCapturing(
|
||||
"Refactor this op to use XPUGeneratorImpl::philox_xpu_state. Cannot call XPUGeneratorImpl::philox_engine_inputs");
|
||||
uint64_t offset = state_->philox_offset_per_thread_;
|
||||
state_->increase(increment);
|
||||
return std::make_pair(state_->seed_, offset);
|
||||
}
|
||||
|
||||
DeviceType XPUGeneratorImpl::device_type() {
|
||||
@ -154,9 +238,8 @@ std::shared_ptr<XPUGeneratorImpl> XPUGeneratorImpl::clone() const {
|
||||
}
|
||||
|
||||
XPUGeneratorImpl* XPUGeneratorImpl::clone_impl() const {
|
||||
auto gen = new XPUGeneratorImpl(this->device().index());
|
||||
gen->set_current_seed(this->seed_);
|
||||
gen->set_philox_offset_per_thread(this->philox_offset_per_thread_);
|
||||
at::xpu::assertNotCapturing("Cannot call XPUGeneratorImpl::clone_impl");
|
||||
auto gen = new XPUGeneratorImpl(this->device().index(), state_->clone());
|
||||
return gen;
|
||||
}
|
||||
|
||||
|
@ -1,12 +1,43 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/core/Generator.h>
|
||||
#include <ATen/core/TensorBase.h>
|
||||
#include <ATen/xpu/PhiloxXpuState.h>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace at {
|
||||
|
||||
namespace xpu {
|
||||
struct XPUGraph;
|
||||
}
|
||||
|
||||
struct XPUGeneratorState : public c10::intrusive_ptr_target {
|
||||
uint64_t seed_;
|
||||
uint64_t philox_offset_per_thread_;
|
||||
uint32_t offset_intragraph_;
|
||||
bool capturing_{};
|
||||
at::TensorBase seed_extragraph_{};
|
||||
at::TensorBase offset_extragraph_{};
|
||||
|
||||
XPUGeneratorState(
|
||||
uint64_t seed = default_rng_seed_val,
|
||||
uint64_t philox_offset_per_thread = 0,
|
||||
uint32_t offset_intragraph = 0)
|
||||
: seed_(seed),
|
||||
philox_offset_per_thread_(philox_offset_per_thread),
|
||||
offset_intragraph_(offset_intragraph) {}
|
||||
|
||||
void increase(uint64_t increment);
|
||||
|
||||
c10::intrusive_ptr<XPUGeneratorState> clone();
|
||||
};
|
||||
|
||||
struct TORCH_XPU_API XPUGeneratorImpl : public GeneratorImpl {
|
||||
// Constructors
|
||||
XPUGeneratorImpl(DeviceIndex device_index = -1);
|
||||
XPUGeneratorImpl(
|
||||
DeviceIndex device_index,
|
||||
c10::intrusive_ptr<XPUGeneratorState> state_);
|
||||
~XPUGeneratorImpl() override = default;
|
||||
|
||||
// XPUGeneratorImpl methods
|
||||
@ -18,15 +49,18 @@ struct TORCH_XPU_API XPUGeneratorImpl : public GeneratorImpl {
|
||||
uint64_t seed() override;
|
||||
void set_state(const c10::TensorImpl& new_state) override;
|
||||
c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
|
||||
|
||||
void set_philox_offset_per_thread(uint64_t offset);
|
||||
uint64_t philox_offset_per_thread() const;
|
||||
|
||||
PhiloxXpuState philox_xpu_state(uint64_t increment);
|
||||
// will remove once all ops are refactored to use philox_xpu_state.
|
||||
std::pair<uint64_t, uint64_t> philox_engine_inputs(uint64_t increment);
|
||||
static c10::DeviceType device_type();
|
||||
|
||||
private:
|
||||
XPUGeneratorImpl* clone_impl() const override;
|
||||
uint64_t seed_ = default_rng_seed_val;
|
||||
uint64_t philox_offset_per_thread_ = 0;
|
||||
c10::intrusive_ptr<XPUGeneratorState> state_;
|
||||
};
|
||||
|
||||
namespace xpu::detail {
|
||||
|
22
aten/src/ATen/xpu/XPUGraphsUtils.h
Normal file
22
aten/src/ATen/xpu/XPUGraphsUtils.h
Normal file
@ -0,0 +1,22 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/xpu/XPUGraphsC10Utils.h>
|
||||
|
||||
namespace at::xpu {
|
||||
|
||||
inline CaptureStatus currentStreamCaptureStatus() {
|
||||
return c10::xpu::currentStreamCaptureStatusMayInitCtx();
|
||||
}
|
||||
|
||||
inline void assertNotCapturing(const std::string& attempt) {
|
||||
auto status = currentStreamCaptureStatus();
|
||||
TORCH_CHECK(
|
||||
status == CaptureStatus::Executing,
|
||||
attempt,
|
||||
" during XPU graph capture. If you need this call to be captured, "
|
||||
"please file an issue. "
|
||||
"Current xpuStreamCaptureStatus: ",
|
||||
status);
|
||||
}
|
||||
|
||||
} // namespace at::xpu
|
42
c10/xpu/XPUGraphsC10Utils.h
Normal file
42
c10/xpu/XPUGraphsC10Utils.h
Normal file
@ -0,0 +1,42 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/xpu/XPUStream.h>
|
||||
#include <iostream>
|
||||
|
||||
// XPU Graphs utils used by c10 and aten.
|
||||
using namespace sycl::ext::oneapi::experimental;
|
||||
namespace c10::xpu {
|
||||
|
||||
static_assert(
|
||||
int8_t(queue_state::executing) == 0,
|
||||
"unexpected int(queue_state::executing) value");
|
||||
static_assert(
|
||||
int8_t(queue_state::recording) == 1,
|
||||
"unexpected int(queue_state::recording) value");
|
||||
|
||||
enum class CaptureStatus : int8_t {
|
||||
Executing = int8_t(queue_state::executing),
|
||||
Recording = int8_t(queue_state::recording)
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) {
|
||||
switch (status) {
|
||||
case CaptureStatus::Executing:
|
||||
os << "Executing";
|
||||
break;
|
||||
case CaptureStatus::Recording:
|
||||
os << "Recording";
|
||||
break;
|
||||
default:
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
false, "Unknown XPU graph CaptureStatus", int(status));
|
||||
}
|
||||
return os;
|
||||
}
|
||||
|
||||
inline CaptureStatus currentStreamCaptureStatusMayInitCtx() {
|
||||
auto state = c10::xpu::getCurrentXPUStream().queue().ext_oneapi_get_state();
|
||||
return CaptureStatus(state);
|
||||
}
|
||||
|
||||
} // namespace c10::xpu
|
Reference in New Issue
Block a user