mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Make Context to be Device-agnostic Step by Step (2/N) (#136526)
---- - add new method(getDefaultGenerator, getNewGenerator) into AcceleratorHooksInterface Pull Request resolved: https://github.com/pytorch/pytorch/pull/136526 Approved by: https://github.com/ezyang, https://github.com/EikanWang
This commit is contained in:
@ -44,21 +44,9 @@ class TORCH_API Context {
|
||||
|
||||
if (device_type == at::kCPU) {
|
||||
return at::detail::getDefaultCPUGenerator();
|
||||
} else if (device_type == at::kCUDA) {
|
||||
return at::detail::getCUDAHooks().getDefaultCUDAGenerator(device.index());
|
||||
} else if (device_type == at::kMPS) {
|
||||
return at::detail::getMPSHooks().getDefaultMPSGenerator();
|
||||
} else if (device_type == at::kXPU) {
|
||||
return at::detail::getXPUHooks().getDefaultXPUGenerator(device.index());
|
||||
} else if (device_type == at::kIPU) {
|
||||
return at::detail::getIPUHooks().getDefaultIPUGenerator(device.index());
|
||||
} else if (device_type == at::kHPU) {
|
||||
return at::detail::getHPUHooks().getDefaultHPUGenerator(device.index());
|
||||
} else if (device_type == at::kPrivateUse1) {
|
||||
return at::detail::getPrivateUse1Hooks().getDefaultGenerator(
|
||||
device.index());
|
||||
} else {
|
||||
AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
|
||||
return getAcceleratorHooksInterface(device_type)
|
||||
.getDefaultGenerator(device.index());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -102,7 +102,7 @@ void CUDAHooks::init() const {
|
||||
#endif
|
||||
}
|
||||
|
||||
const Generator& CUDAHooks::getDefaultCUDAGenerator(DeviceIndex device_index) const {
|
||||
const Generator& CUDAHooks::getDefaultGenerator(DeviceIndex device_index) const {
|
||||
return at::cuda::detail::getDefaultCUDAGenerator(device_index);
|
||||
}
|
||||
|
||||
|
@ -21,7 +21,8 @@ struct CUDAHooks : public at::CUDAHooksInterface {
|
||||
void init() const override;
|
||||
Device getDeviceFromPtr(void* data) const override;
|
||||
bool isPinnedPtr(const void* data) const override;
|
||||
const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const override;
|
||||
const Generator& getDefaultGenerator(
|
||||
DeviceIndex device_index = -1) const override;
|
||||
bool hasCUDA() const override;
|
||||
bool hasMAGMA() const override;
|
||||
bool hasCuDNN() const override;
|
||||
|
@ -1,9 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/core/Generator.h>
|
||||
|
||||
#include <c10/core/Allocator.h>
|
||||
#include <c10/core/Device.h>
|
||||
#include <c10/core/Stream.h>
|
||||
#include <c10/core/Allocator.h>
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
|
||||
|
||||
namespace at {
|
||||
|
||||
// AcceleratorHooksInterface is a shared interface provided by all
|
||||
@ -58,7 +62,18 @@ struct TORCH_API AcceleratorHooksInterface {
|
||||
virtual Device getDeviceFromPtr(void* data) const {
|
||||
TORCH_CHECK(false, "Backend doesn't support getDeviceFromPtr()");
|
||||
}
|
||||
|
||||
virtual const Generator& getDefaultGenerator(
|
||||
C10_UNUSED DeviceIndex device_index = -1) const {
|
||||
TORCH_CHECK(false, "Backend doesn`t support getDefaultGenerator()");
|
||||
}
|
||||
|
||||
virtual Generator getNewGenerator(
|
||||
C10_UNUSED DeviceIndex device_index = -1) const {
|
||||
TORCH_CHECK(false, "Backend doesn`t support getNewGenerator()");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace at
|
||||
|
||||
C10_DIAGNOSTIC_POP()
|
||||
|
@ -6,16 +6,13 @@
|
||||
|
||||
#include <ATen/detail/AcceleratorHooksInterface.h>
|
||||
|
||||
// Forward-declares at::Generator and at::cuda::NVRTC
|
||||
// NB: Class must live in `at` due to limitations of Registry.h.
|
||||
namespace at {
|
||||
struct Generator;
|
||||
|
||||
// Forward-declares at::cuda::NVRTC
|
||||
namespace cuda {
|
||||
struct NVRTC;
|
||||
} // namespace cuda
|
||||
} // namespace at
|
||||
|
||||
// NB: Class must live in `at` due to limitations of Registry.h.
|
||||
namespace at {
|
||||
|
||||
#ifdef _MSC_VER
|
||||
constexpr const char* CUDA_HELP =
|
||||
@ -69,8 +66,8 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
|
||||
TORCH_CHECK(false, "Cannot initialize CUDA without ATen_cuda library. ", CUDA_HELP);
|
||||
}
|
||||
|
||||
virtual const Generator& getDefaultCUDAGenerator(
|
||||
[[maybe_unused]] DeviceIndex device_index = -1) const {
|
||||
const Generator& getDefaultGenerator(
|
||||
[[maybe_unused]] DeviceIndex device_index = -1) const override {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Cannot get default CUDA generator without ATen_cuda library. ",
|
||||
|
@ -1,19 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/Allocator.h>
|
||||
#include <c10/core/GeneratorImpl.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
#include <c10/util/Registry.h>
|
||||
|
||||
#include <ATen/detail/AcceleratorHooksInterface.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace at {
|
||||
class Context;
|
||||
}
|
||||
|
||||
// NB: Class must live in `at` due to limitations of Registry.h.
|
||||
namespace at {
|
||||
|
||||
@ -30,8 +24,9 @@ struct TORCH_API HIPHooksInterface : AcceleratorHooksInterface {
|
||||
TORCH_CHECK(false, "Cannot initialize HIP without ATen_hip library.");
|
||||
}
|
||||
|
||||
virtual std::unique_ptr<c10::GeneratorImpl> initHIPGenerator(Context*) const {
|
||||
AT_ERROR("Cannot initialize HIP generator without ATen_hip library.");
|
||||
const Generator& getDefaultGenerator(
|
||||
C10_UNUSED DeviceIndex device_index = -1) const override {
|
||||
TORCH_CHECK(false, "Cannot initialize HIP without ATen_hip library.");
|
||||
}
|
||||
|
||||
virtual bool hasHIP() const {
|
||||
@ -50,10 +45,6 @@ struct TORCH_API HIPHooksInterface : AcceleratorHooksInterface {
|
||||
TORCH_CHECK(false, "Pinned memory requires HIP.");
|
||||
}
|
||||
|
||||
virtual void registerHIPTypes(Context*) const {
|
||||
AT_ERROR("Cannot registerHIPTypes() without ATen_hip library.");
|
||||
}
|
||||
|
||||
virtual int getNumGPUs() const {
|
||||
return 0;
|
||||
}
|
||||
|
@ -1,6 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/core/Generator.h>
|
||||
#include <ATen/detail/AcceleratorHooksInterface.h>
|
||||
|
||||
#include <c10/core/Allocator.h>
|
||||
@ -9,7 +8,7 @@
|
||||
|
||||
namespace at {
|
||||
|
||||
struct TORCH_API IPUHooksInterface: AcceleratorHooksInterface {
|
||||
struct TORCH_API IPUHooksInterface : AcceleratorHooksInterface {
|
||||
~IPUHooksInterface() override = default;
|
||||
|
||||
void init() const override {
|
||||
@ -21,16 +20,14 @@ struct TORCH_API IPUHooksInterface: AcceleratorHooksInterface {
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual const Generator& getDefaultIPUGenerator(
|
||||
DeviceIndex device_index [[maybe_unused]] = -1) const {
|
||||
AT_ERROR(
|
||||
"Cannot get the default IPU generator: the IPU backend is not "
|
||||
"available.");
|
||||
const Generator& getDefaultGenerator(
|
||||
C10_UNUSED DeviceIndex device_index = -1) const override {
|
||||
TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library.");
|
||||
}
|
||||
|
||||
virtual Generator newIPUGenerator(DeviceIndex device_index [[maybe_unused]] = -1) const {
|
||||
AT_ERROR(
|
||||
"Cannot create a new IPU generator: the IPU backend is not available.");
|
||||
Generator getNewGenerator(
|
||||
DeviceIndex device_index [[maybe_unused]] = -1) const override {
|
||||
TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library.");
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2,9 +2,9 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/Allocator.h>
|
||||
#include <ATen/core/Generator.h>
|
||||
#include <ATen/detail/AcceleratorHooksInterface.h>
|
||||
|
||||
#include <c10/core/Allocator.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Registry.h>
|
||||
|
||||
@ -31,7 +31,8 @@ struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface {
|
||||
virtual bool isOnMacOSorNewer(unsigned major = 13, unsigned minor = 0) const {
|
||||
FAIL_MPSHOOKS_FUNC(__func__);
|
||||
}
|
||||
virtual const Generator& getDefaultMPSGenerator() const {
|
||||
const Generator& getDefaultGenerator(
|
||||
C10_UNUSED DeviceIndex device_index = -1) const override {
|
||||
FAIL_MPSHOOKS_FUNC(__func__);
|
||||
}
|
||||
virtual Allocator* getMPSDeviceAllocator() const {
|
||||
|
@ -1,18 +1,20 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/core/Generator.h>
|
||||
#include <ATen/detail/AcceleratorHooksInterface.h>
|
||||
#include <c10/core/Allocator.h>
|
||||
#include <c10/core/Device.h>
|
||||
#include <c10/core/Storage.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
|
||||
|
||||
namespace at {
|
||||
|
||||
struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
|
||||
~PrivateUse1HooksInterface() override = default;
|
||||
virtual const at::Generator& getDefaultGenerator(
|
||||
c10::DeviceIndex device_index) const {
|
||||
|
||||
const at::Generator& getDefaultGenerator(
|
||||
c10::DeviceIndex device_index) const override {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDefaultGenerator`.");
|
||||
|
@ -4,7 +4,6 @@
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Registry.h>
|
||||
|
||||
#include <ATen/core/Generator.h>
|
||||
#include <ATen/detail/AcceleratorHooksInterface.h>
|
||||
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
|
||||
@ -32,17 +31,17 @@ struct TORCH_API XPUHooksInterface : AcceleratorHooksInterface{
|
||||
TORCH_CHECK(false, "Cannot get XPU global device index without ATen_xpu library.");
|
||||
}
|
||||
|
||||
virtual Generator getXPUGenerator(
|
||||
[[maybe_unused]] DeviceIndex device_index = -1) const {
|
||||
TORCH_CHECK(false, "Cannot get XPU generator without ATen_xpu library.");
|
||||
}
|
||||
|
||||
virtual const Generator& getDefaultXPUGenerator(
|
||||
[[maybe_unused]] DeviceIndex device_index = -1) const {
|
||||
const Generator& getDefaultGenerator(
|
||||
[[maybe_unused]] DeviceIndex device_index = -1) const override {
|
||||
TORCH_CHECK(
|
||||
false, "Cannot get default XPU generator without ATen_xpu library.");
|
||||
}
|
||||
|
||||
Generator getNewGenerator(
|
||||
[[maybe_unused]] DeviceIndex device_index = -1) const override {
|
||||
TORCH_CHECK(false, "Cannot get XPU generator without ATen_xpu library.");
|
||||
}
|
||||
|
||||
virtual DeviceIndex getNumGPUs() const {
|
||||
return 0;
|
||||
}
|
||||
|
@ -19,7 +19,8 @@ struct MPSHooks : public at::MPSHooksInterface {
|
||||
bool isOnMacOSorNewer(unsigned major, unsigned minor) const override;
|
||||
|
||||
// MPSGeneratorImpl interface
|
||||
const Generator& getDefaultMPSGenerator() const override;
|
||||
const Generator& getDefaultGenerator(
|
||||
DeviceIndex device_index = -1) const override;
|
||||
|
||||
// MPSStream interface
|
||||
void deviceSynchronize() const override;
|
||||
|
@ -59,7 +59,7 @@ Allocator* MPSHooks::getMPSDeviceAllocator() const {
|
||||
return at::mps::GetMPSAllocator();
|
||||
}
|
||||
|
||||
const Generator& MPSHooks::getDefaultMPSGenerator() const {
|
||||
const Generator& MPSHooks::getDefaultGenerator([[maybe_unused]] DeviceIndex device_index) const {
|
||||
return at::mps::detail::getDefaultMPSGenerator();
|
||||
}
|
||||
|
||||
|
@ -2358,8 +2358,7 @@ DropoutState& get_dropout_state(
|
||||
std::unique_lock<std::mutex> lock{state_cache_mut};
|
||||
auto& state = dropout_state_cache.at(device);
|
||||
if (train && dropout_p > 0) {
|
||||
const auto& gen =
|
||||
at::detail::getCUDAHooks().getDefaultCUDAGenerator(device);
|
||||
const auto& gen = at::detail::getCUDAHooks().getDefaultGenerator(device);
|
||||
auto gen_impl = gen.get<at::CUDAGeneratorImpl>();
|
||||
bool reset_rnn_state = gen_impl->reset_rnn_state();
|
||||
if (!state.buffer.defined() || reset_rnn_state) {
|
||||
|
@ -34,13 +34,12 @@ int32_t XPUHooks::getGlobalIdxFromDevice(const at::Device& device) const {
|
||||
#endif
|
||||
}
|
||||
|
||||
Generator XPUHooks::getXPUGenerator(DeviceIndex device_index) const {
|
||||
return make_generator<at::XPUGeneratorImpl>(device_index);
|
||||
const Generator& XPUHooks::getDefaultGenerator(DeviceIndex device_index) const {
|
||||
return at::xpu::detail::getDefaultXPUGenerator(device_index);
|
||||
}
|
||||
|
||||
const Generator& XPUHooks::getDefaultXPUGenerator(
|
||||
DeviceIndex device_index) const {
|
||||
return at::xpu::detail::getDefaultXPUGenerator(device_index);
|
||||
Generator XPUHooks::getNewGenerator(DeviceIndex device_index) const {
|
||||
return make_generator<at::XPUGeneratorImpl>(device_index);
|
||||
}
|
||||
|
||||
Device XPUHooks::getDeviceFromPtr(void* data) const {
|
||||
|
@ -11,9 +11,9 @@ struct XPUHooks : public at::XPUHooksInterface {
|
||||
bool hasXPU() const override;
|
||||
std::string showConfig() const override;
|
||||
int32_t getGlobalIdxFromDevice(const at::Device& device) const override;
|
||||
Generator getXPUGenerator(DeviceIndex device_index = -1) const override;
|
||||
const Generator& getDefaultXPUGenerator(
|
||||
const Generator& getDefaultGenerator(
|
||||
DeviceIndex device_index = -1) const override;
|
||||
Generator getNewGenerator(DeviceIndex device_index = -1) const override;
|
||||
Device getDeviceFromPtr(void* data) const override;
|
||||
c10::DeviceIndex getNumGPUs() const override;
|
||||
DeviceIndex current_device() const override;
|
||||
|
@ -73,9 +73,9 @@ static PyObject* THPGenerator_pynew(
|
||||
}
|
||||
#endif
|
||||
else if (device.type() == at::kXPU) {
|
||||
self->cdata = at::detail::getXPUHooks().getXPUGenerator(device.index());
|
||||
self->cdata = at::detail::getXPUHooks().getNewGenerator(device.index());
|
||||
} else if (device.type() == at::kIPU) {
|
||||
self->cdata = at::detail::getIPUHooks().newIPUGenerator(device.index());
|
||||
self->cdata = at::detail::getIPUHooks().getNewGenerator(device.index());
|
||||
} else if (device.type() == at::kPrivateUse1) {
|
||||
self->cdata = at::GetGeneratorForPrivateuse1(device.index());
|
||||
} else {
|
||||
|
@ -28,7 +28,7 @@ bool cudnn_is_available() {
|
||||
void manual_seed(uint64_t seed) {
|
||||
if (is_available()) {
|
||||
auto index = at::detail::getCUDAHooks().getCurrentDevice();
|
||||
auto gen = at::detail::getCUDAHooks().getDefaultCUDAGenerator(index);
|
||||
auto gen = at::detail::getCUDAHooks().getDefaultGenerator(index);
|
||||
{
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen.mutex());
|
||||
@ -41,8 +41,7 @@ void manual_seed(uint64_t seed) {
|
||||
void manual_seed_all(uint64_t seed) {
|
||||
auto num_gpu = device_count();
|
||||
for (const auto i : c10::irange(num_gpu)) {
|
||||
auto gen = at::detail::getCUDAHooks().getDefaultCUDAGenerator(
|
||||
static_cast<c10::DeviceIndex>(i));
|
||||
auto gen = at::detail::getCUDAHooks().getDefaultGenerator(i);
|
||||
{
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen.mutex());
|
||||
|
@ -10,7 +10,7 @@ bool is_available() {
|
||||
/// Sets the seed for the MPS's default generator.
|
||||
void manual_seed(uint64_t seed) {
|
||||
if (is_available()) {
|
||||
auto gen = at::detail::getMPSHooks().getDefaultMPSGenerator();
|
||||
auto gen = at::detail::getMPSHooks().getDefaultGenerator();
|
||||
{
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen.mutex());
|
||||
|
@ -14,7 +14,7 @@ bool is_available() {
|
||||
void manual_seed(uint64_t seed) {
|
||||
if (is_available()) {
|
||||
auto index = at::detail::getXPUHooks().getCurrentDevice();
|
||||
auto gen = at::detail::getXPUHooks().getDefaultXPUGenerator(index);
|
||||
auto gen = at::detail::getXPUHooks().getDefaultGenerator(index);
|
||||
{
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen.mutex());
|
||||
@ -27,8 +27,7 @@ void manual_seed(uint64_t seed) {
|
||||
void manual_seed_all(uint64_t seed) {
|
||||
auto num_gpu = device_count();
|
||||
for (const auto i : c10::irange(num_gpu)) {
|
||||
auto gen = at::detail::getXPUHooks().getDefaultXPUGenerator(
|
||||
static_cast<c10::DeviceIndex>(i));
|
||||
auto gen = at::detail::getXPUHooks().getDefaultGenerator(i);
|
||||
{
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen.mutex());
|
||||
|
@ -44,7 +44,7 @@ static PyObject* MPSModule_getDefaultMPSGenerator(
|
||||
HANDLE_TH_ERRORS
|
||||
track_bad_mps_fork();
|
||||
return THPGenerator_initDefaultGenerator(
|
||||
at::detail::getMPSHooks().getDefaultMPSGenerator());
|
||||
at::detail::getMPSHooks().getDefaultGenerator());
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user