Make Context to be Device-agnostic Step by Step (2/N) (#136526)

----

- add new method(getDefaultGenerator, getNewGenerator) into AcceleratorHooksInterface
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136526
Approved by: https://github.com/ezyang, https://github.com/EikanWang
This commit is contained in:
FFFrog
2024-11-06 11:11:23 +08:00
committed by PyTorch MergeBot
parent ca30704f0b
commit c03324de2d
20 changed files with 70 additions and 82 deletions

View File

@ -44,21 +44,9 @@ class TORCH_API Context {
if (device_type == at::kCPU) {
return at::detail::getDefaultCPUGenerator();
} else if (device_type == at::kCUDA) {
return at::detail::getCUDAHooks().getDefaultCUDAGenerator(device.index());
} else if (device_type == at::kMPS) {
return at::detail::getMPSHooks().getDefaultMPSGenerator();
} else if (device_type == at::kXPU) {
return at::detail::getXPUHooks().getDefaultXPUGenerator(device.index());
} else if (device_type == at::kIPU) {
return at::detail::getIPUHooks().getDefaultIPUGenerator(device.index());
} else if (device_type == at::kHPU) {
return at::detail::getHPUHooks().getDefaultHPUGenerator(device.index());
} else if (device_type == at::kPrivateUse1) {
return at::detail::getPrivateUse1Hooks().getDefaultGenerator(
device.index());
} else {
AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
return getAcceleratorHooksInterface(device_type)
.getDefaultGenerator(device.index());
}
}

View File

@ -102,7 +102,7 @@ void CUDAHooks::init() const {
#endif
}
const Generator& CUDAHooks::getDefaultCUDAGenerator(DeviceIndex device_index) const {
const Generator& CUDAHooks::getDefaultGenerator(DeviceIndex device_index) const {
return at::cuda::detail::getDefaultCUDAGenerator(device_index);
}

View File

@ -21,7 +21,8 @@ struct CUDAHooks : public at::CUDAHooksInterface {
void init() const override;
Device getDeviceFromPtr(void* data) const override;
bool isPinnedPtr(const void* data) const override;
const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const override;
const Generator& getDefaultGenerator(
DeviceIndex device_index = -1) const override;
bool hasCUDA() const override;
bool hasMAGMA() const override;
bool hasCuDNN() const override;

View File

@ -1,9 +1,13 @@
#pragma once
#include <ATen/core/Generator.h>
#include <c10/core/Allocator.h>
#include <c10/core/Device.h>
#include <c10/core/Stream.h>
#include <c10/core/Allocator.h>
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
namespace at {
// AcceleratorHooksInterface is a shared interface provided by all
@ -58,7 +62,18 @@ struct TORCH_API AcceleratorHooksInterface {
virtual Device getDeviceFromPtr(void* data) const {
TORCH_CHECK(false, "Backend doesn't support getDeviceFromPtr()");
}
virtual const Generator& getDefaultGenerator(
C10_UNUSED DeviceIndex device_index = -1) const {
TORCH_CHECK(false, "Backend doesn`t support getDefaultGenerator()");
}
virtual Generator getNewGenerator(
C10_UNUSED DeviceIndex device_index = -1) const {
TORCH_CHECK(false, "Backend doesn`t support getNewGenerator()");
}
};
} // namespace at
C10_DIAGNOSTIC_POP()

View File

@ -6,16 +6,13 @@
#include <ATen/detail/AcceleratorHooksInterface.h>
// Forward-declares at::Generator and at::cuda::NVRTC
// NB: Class must live in `at` due to limitations of Registry.h.
namespace at {
struct Generator;
// Forward-declares at::cuda::NVRTC
namespace cuda {
struct NVRTC;
} // namespace cuda
} // namespace at
// NB: Class must live in `at` due to limitations of Registry.h.
namespace at {
#ifdef _MSC_VER
constexpr const char* CUDA_HELP =
@ -69,8 +66,8 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
TORCH_CHECK(false, "Cannot initialize CUDA without ATen_cuda library. ", CUDA_HELP);
}
virtual const Generator& getDefaultCUDAGenerator(
[[maybe_unused]] DeviceIndex device_index = -1) const {
const Generator& getDefaultGenerator(
[[maybe_unused]] DeviceIndex device_index = -1) const override {
TORCH_CHECK(
false,
"Cannot get default CUDA generator without ATen_cuda library. ",

View File

@ -1,19 +1,13 @@
#pragma once
#include <c10/core/Allocator.h>
#include <c10/core/GeneratorImpl.h>
#include <c10/util/Exception.h>
#include <c10/util/Registry.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
#include <memory>
namespace at {
class Context;
}
// NB: Class must live in `at` due to limitations of Registry.h.
namespace at {
@ -30,8 +24,9 @@ struct TORCH_API HIPHooksInterface : AcceleratorHooksInterface {
TORCH_CHECK(false, "Cannot initialize HIP without ATen_hip library.");
}
virtual std::unique_ptr<c10::GeneratorImpl> initHIPGenerator(Context*) const {
AT_ERROR("Cannot initialize HIP generator without ATen_hip library.");
const Generator& getDefaultGenerator(
C10_UNUSED DeviceIndex device_index = -1) const override {
TORCH_CHECK(false, "Cannot initialize HIP without ATen_hip library.");
}
virtual bool hasHIP() const {
@ -50,10 +45,6 @@ struct TORCH_API HIPHooksInterface : AcceleratorHooksInterface {
TORCH_CHECK(false, "Pinned memory requires HIP.");
}
virtual void registerHIPTypes(Context*) const {
AT_ERROR("Cannot registerHIPTypes() without ATen_hip library.");
}
virtual int getNumGPUs() const {
return 0;
}

View File

@ -1,6 +1,5 @@
#pragma once
#include <ATen/core/Generator.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
#include <c10/core/Allocator.h>
@ -9,7 +8,7 @@
namespace at {
struct TORCH_API IPUHooksInterface: AcceleratorHooksInterface {
struct TORCH_API IPUHooksInterface : AcceleratorHooksInterface {
~IPUHooksInterface() override = default;
void init() const override {
@ -21,16 +20,14 @@ struct TORCH_API IPUHooksInterface: AcceleratorHooksInterface {
return false;
}
virtual const Generator& getDefaultIPUGenerator(
DeviceIndex device_index [[maybe_unused]] = -1) const {
AT_ERROR(
"Cannot get the default IPU generator: the IPU backend is not "
"available.");
const Generator& getDefaultGenerator(
C10_UNUSED DeviceIndex device_index = -1) const override {
TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library.");
}
virtual Generator newIPUGenerator(DeviceIndex device_index [[maybe_unused]] = -1) const {
AT_ERROR(
"Cannot create a new IPU generator: the IPU backend is not available.");
Generator getNewGenerator(
DeviceIndex device_index [[maybe_unused]] = -1) const override {
TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library.");
}
};

View File

@ -2,9 +2,9 @@
#pragma once
#include <c10/core/Allocator.h>
#include <ATen/core/Generator.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
#include <c10/core/Allocator.h>
#include <c10/util/Exception.h>
#include <c10/util/Registry.h>
@ -31,7 +31,8 @@ struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface {
virtual bool isOnMacOSorNewer(unsigned major = 13, unsigned minor = 0) const {
FAIL_MPSHOOKS_FUNC(__func__);
}
virtual const Generator& getDefaultMPSGenerator() const {
const Generator& getDefaultGenerator(
C10_UNUSED DeviceIndex device_index = -1) const override {
FAIL_MPSHOOKS_FUNC(__func__);
}
virtual Allocator* getMPSDeviceAllocator() const {

View File

@ -1,18 +1,20 @@
#pragma once
#include <ATen/core/Generator.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
#include <c10/core/Allocator.h>
#include <c10/core/Device.h>
#include <c10/core/Storage.h>
#include <c10/util/Exception.h>
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
namespace at {
struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
~PrivateUse1HooksInterface() override = default;
virtual const at::Generator& getDefaultGenerator(
c10::DeviceIndex device_index) const {
const at::Generator& getDefaultGenerator(
c10::DeviceIndex device_index) const override {
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDefaultGenerator`.");

View File

@ -4,7 +4,6 @@
#include <c10/util/Exception.h>
#include <c10/util/Registry.h>
#include <ATen/core/Generator.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
@ -32,17 +31,17 @@ struct TORCH_API XPUHooksInterface : AcceleratorHooksInterface{
TORCH_CHECK(false, "Cannot get XPU global device index without ATen_xpu library.");
}
virtual Generator getXPUGenerator(
[[maybe_unused]] DeviceIndex device_index = -1) const {
TORCH_CHECK(false, "Cannot get XPU generator without ATen_xpu library.");
}
virtual const Generator& getDefaultXPUGenerator(
[[maybe_unused]] DeviceIndex device_index = -1) const {
const Generator& getDefaultGenerator(
[[maybe_unused]] DeviceIndex device_index = -1) const override {
TORCH_CHECK(
false, "Cannot get default XPU generator without ATen_xpu library.");
}
Generator getNewGenerator(
[[maybe_unused]] DeviceIndex device_index = -1) const override {
TORCH_CHECK(false, "Cannot get XPU generator without ATen_xpu library.");
}
virtual DeviceIndex getNumGPUs() const {
return 0;
}

View File

@ -19,7 +19,8 @@ struct MPSHooks : public at::MPSHooksInterface {
bool isOnMacOSorNewer(unsigned major, unsigned minor) const override;
// MPSGeneratorImpl interface
const Generator& getDefaultMPSGenerator() const override;
const Generator& getDefaultGenerator(
DeviceIndex device_index = -1) const override;
// MPSStream interface
void deviceSynchronize() const override;

View File

@ -59,7 +59,7 @@ Allocator* MPSHooks::getMPSDeviceAllocator() const {
return at::mps::GetMPSAllocator();
}
const Generator& MPSHooks::getDefaultMPSGenerator() const {
const Generator& MPSHooks::getDefaultGenerator([[maybe_unused]] DeviceIndex device_index) const {
return at::mps::detail::getDefaultMPSGenerator();
}

View File

@ -2358,8 +2358,7 @@ DropoutState& get_dropout_state(
std::unique_lock<std::mutex> lock{state_cache_mut};
auto& state = dropout_state_cache.at(device);
if (train && dropout_p > 0) {
const auto& gen =
at::detail::getCUDAHooks().getDefaultCUDAGenerator(device);
const auto& gen = at::detail::getCUDAHooks().getDefaultGenerator(device);
auto gen_impl = gen.get<at::CUDAGeneratorImpl>();
bool reset_rnn_state = gen_impl->reset_rnn_state();
if (!state.buffer.defined() || reset_rnn_state) {

View File

@ -34,13 +34,12 @@ int32_t XPUHooks::getGlobalIdxFromDevice(const at::Device& device) const {
#endif
}
Generator XPUHooks::getXPUGenerator(DeviceIndex device_index) const {
return make_generator<at::XPUGeneratorImpl>(device_index);
const Generator& XPUHooks::getDefaultGenerator(DeviceIndex device_index) const {
return at::xpu::detail::getDefaultXPUGenerator(device_index);
}
const Generator& XPUHooks::getDefaultXPUGenerator(
DeviceIndex device_index) const {
return at::xpu::detail::getDefaultXPUGenerator(device_index);
Generator XPUHooks::getNewGenerator(DeviceIndex device_index) const {
return make_generator<at::XPUGeneratorImpl>(device_index);
}
Device XPUHooks::getDeviceFromPtr(void* data) const {

View File

@ -11,9 +11,9 @@ struct XPUHooks : public at::XPUHooksInterface {
bool hasXPU() const override;
std::string showConfig() const override;
int32_t getGlobalIdxFromDevice(const at::Device& device) const override;
Generator getXPUGenerator(DeviceIndex device_index = -1) const override;
const Generator& getDefaultXPUGenerator(
const Generator& getDefaultGenerator(
DeviceIndex device_index = -1) const override;
Generator getNewGenerator(DeviceIndex device_index = -1) const override;
Device getDeviceFromPtr(void* data) const override;
c10::DeviceIndex getNumGPUs() const override;
DeviceIndex current_device() const override;

View File

@ -73,9 +73,9 @@ static PyObject* THPGenerator_pynew(
}
#endif
else if (device.type() == at::kXPU) {
self->cdata = at::detail::getXPUHooks().getXPUGenerator(device.index());
self->cdata = at::detail::getXPUHooks().getNewGenerator(device.index());
} else if (device.type() == at::kIPU) {
self->cdata = at::detail::getIPUHooks().newIPUGenerator(device.index());
self->cdata = at::detail::getIPUHooks().getNewGenerator(device.index());
} else if (device.type() == at::kPrivateUse1) {
self->cdata = at::GetGeneratorForPrivateuse1(device.index());
} else {

View File

@ -28,7 +28,7 @@ bool cudnn_is_available() {
void manual_seed(uint64_t seed) {
if (is_available()) {
auto index = at::detail::getCUDAHooks().getCurrentDevice();
auto gen = at::detail::getCUDAHooks().getDefaultCUDAGenerator(index);
auto gen = at::detail::getCUDAHooks().getDefaultGenerator(index);
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen.mutex());
@ -41,8 +41,7 @@ void manual_seed(uint64_t seed) {
void manual_seed_all(uint64_t seed) {
auto num_gpu = device_count();
for (const auto i : c10::irange(num_gpu)) {
auto gen = at::detail::getCUDAHooks().getDefaultCUDAGenerator(
static_cast<c10::DeviceIndex>(i));
auto gen = at::detail::getCUDAHooks().getDefaultGenerator(i);
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen.mutex());

View File

@ -10,7 +10,7 @@ bool is_available() {
/// Sets the seed for the MPS's default generator.
void manual_seed(uint64_t seed) {
if (is_available()) {
auto gen = at::detail::getMPSHooks().getDefaultMPSGenerator();
auto gen = at::detail::getMPSHooks().getDefaultGenerator();
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen.mutex());

View File

@ -14,7 +14,7 @@ bool is_available() {
void manual_seed(uint64_t seed) {
if (is_available()) {
auto index = at::detail::getXPUHooks().getCurrentDevice();
auto gen = at::detail::getXPUHooks().getDefaultXPUGenerator(index);
auto gen = at::detail::getXPUHooks().getDefaultGenerator(index);
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen.mutex());
@ -27,8 +27,7 @@ void manual_seed(uint64_t seed) {
void manual_seed_all(uint64_t seed) {
auto num_gpu = device_count();
for (const auto i : c10::irange(num_gpu)) {
auto gen = at::detail::getXPUHooks().getDefaultXPUGenerator(
static_cast<c10::DeviceIndex>(i));
auto gen = at::detail::getXPUHooks().getDefaultGenerator(i);
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen.mutex());

View File

@ -44,7 +44,7 @@ static PyObject* MPSModule_getDefaultMPSGenerator(
HANDLE_TH_ERRORS
track_bad_mps_fork();
return THPGenerator_initDefaultGenerator(
at::detail::getMPSHooks().getDefaultMPSGenerator());
at::detail::getMPSHooks().getDefaultGenerator());
END_HANDLE_TH_ERRORS
}