Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "Make Context to be Device-agnostic Step by Step (1/N) (#136519)"
This reverts commit 4a8e49389c33934234dc89616fd17a58e760e2e7. Reverted https://github.com/pytorch/pytorch/pull/136519 on behalf of https://github.com/clee2000 due to breaking internal tests related to MTIA; @ezyang may have a forward fix ([comment](https://github.com/pytorch/pytorch/pull/136519#issuecomment-2414588302))
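The change is mechanical throughout the tree: the device-agnostic `Context::lazyInitDevice(DeviceType)` entry point introduced by #136519 goes away, and call sites return to the per-backend `lazyInit*()` methods. A minimal sketch of that call-site swap, assuming an ATen C++ build; the function name below is illustrative, not from the commit:

```cpp
#include <ATen/Context.h>

// Illustrative only: the one-line swap this revert performs at CUDA call sites.
void warm_up_cuda() {
  // Pre-revert (state introduced by #136519): one device-agnostic entry point.
  // at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);

  // Post-revert: the per-backend entry point is back.
  at::globalContext().lazyInitCUDA();
}
```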
@@ -39,8 +39,8 @@ class TORCH_API Context {
  const Generator& defaultGenerator(Device device) {
    c10::DeviceType device_type = device.type();
    lazyInitDevice(device_type);
    initCUDAIfNeeded(device_type);
    initHIPIfNeeded(device_type);
    if (device_type == at::kCPU) {
      return at::detail::getDefaultCPUGenerator();
    } else if (device_type == at::kCUDA) {

@@ -58,7 +58,6 @@ class TORCH_API Context {
      AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
    }
  }

  const AcceleratorHooksInterface& getAcceleratorHooksInterface(
      std::optional<c10::DeviceType> opt_device_type = std::nullopt) {
    c10::DeviceType device_type = opt_device_type.has_value()

@@ -81,17 +80,16 @@ class TORCH_API Context {
          c10::DeviceTypeName(device_type), " device type not an accelerator.");
    }
  }

  Device getDeviceFromPtr(void* data, c10::DeviceType device_type) {
    lazyInitDevice(device_type);
    initCUDAIfNeeded(device_type);
    initHIPIfNeeded(device_type);
    initXPUIfNeeded(device_type);
    if (device_type == at::kCPU) {
      return c10::DeviceType::CPU;
    } else {
      return getAcceleratorHooksInterface(device_type).getDeviceFromPtr(data);
    }
  }

  bool isPinnedPtr(
      const void* data,
      std::optional<c10::DeviceType> device_type = std::nullopt) {

@@ -104,20 +102,10 @@ class TORCH_API Context {
    }
    return getAcceleratorHooksInterface(opt_device_type).isPinnedPtr(data);
  }

  Allocator* getPinnedMemoryAllocator(
      std::optional<c10::DeviceType> device_type = std::nullopt) {
    return getAcceleratorHooksInterface(device_type).getPinnedMemoryAllocator();
  }

  void lazyInitDevice(c10::DeviceType device_type) {
    if (device_type != at::kCPU) {
      c10::call_once(init_[static_cast<int8_t>(device_type)], [&] {
        getAcceleratorHooksInterface(device_type).init();
      });
    }
  }

  static bool hasOpenMP();
  static bool hasMKL();
  static bool hasLAPACK();

@@ -170,6 +158,27 @@ class TORCH_API Context {
  static bool hasMAIA() {
    return c10::impl::hasDeviceGuardImpl(c10::DeviceType::MAIA);
  }
  // defined in header so that getNonVariableType has ability to inline
  // call_once check. getNonVariableType is called fairly frequently
  void lazyInitCUDA() {
    c10::call_once(thc_init, [&] { detail::getCUDAHooks().initCUDA(); });
  }
  void lazyInitHIP() {
    c10::call_once(thh_init, [&] { detail::getHIPHooks().initHIP(); });
  }
  void lazyInitXPU() {
    c10::call_once(thx_init, [&] { detail::getXPUHooks().initXPU(); });
  }
  void lazyInitMTIA() {
    c10::call_once(th_mtia_init, [&] { detail::getMTIAHooks().initMTIA(); });
  }
  void lazyInitPrivateUse1() {
    c10::call_once(thp_init, [&] {
      if (isPrivateUse1HooksRegistered()) {
        at::detail::getPrivateUse1Hooks().initPrivateUse1();
      }
    });
  }
  static const at::cuda::NVRTC& getNVRTC() {
    return detail::getCUDAHooks().nvrtc();
  }

@@ -344,26 +353,28 @@ class TORCH_API Context {
  bool allowFP16ReductionCPU() const;
  void setAllowFP16ReductionCPU(bool);

  // Preserved for BC
  void lazyInitCUDA() {
    lazyInitDevice(at::kCUDA);
  }
  void lazyInitHIP() {
    lazyInitDevice(at::kHIP);
  }
  void lazyInitXPU() {
    lazyInitDevice(at::kXPU);
  }
  void lazyInitMTIA() {
    lazyInitDevice(at::kMTIA);
  }
  void lazyInitPrivateUse1() {
    lazyInitDevice(at::kPrivateUse1);
  }

 private:
  void initCUDAIfNeeded(c10::DeviceType p) {
    if (p == c10::DeviceType::CUDA) {
      lazyInitCUDA();
    }
  }
  void initHIPIfNeeded(c10::DeviceType p) {
    if (p == c10::DeviceType::HIP) {
      lazyInitHIP();
    }
  }
  void initXPUIfNeeded(c10::DeviceType p) {
    if (p == c10::DeviceType::XPU) {
      lazyInitXPU();
    }
  }
  static bool checkCuBLASConfigDeterministic();
  std::array<c10::once_flag, at::COMPILE_TIME_MAX_DEVICE_TYPES> init_;
  c10::once_flag thc_init;
  c10::once_flag thh_init;
  c10::once_flag thx_init;
  c10::once_flag th_mtia_init;
  c10::once_flag thp_init;
  bool enabled_cudnn = true;
  bool deterministic_cudnn = false;
  bool deterministic_mkldnn = false;
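The Context hunks above also swap the single `init_` once-flag array (which guarded the generic `lazyInitDevice()`) back to the per-backend `thc_init`, `thh_init`, `thx_init`, `th_mtia_init`, and `thp_init` flags, so there is no longer one device-agnostic lazy-init path. Device-generic code that had moved to `lazyInitDevice(device_type)` has to fan out to the restored per-backend methods itself. A minimal sketch of that fan-out, assuming an ATen C++ build; the helper name is hypothetical and not part of the commit:

```cpp
#include <ATen/Context.h>

// Hypothetical helper: route a DeviceType to the per-backend lazy-init
// methods that this revert restores on at::Context.
void lazy_init_for(c10::DeviceType t) {
  auto& ctx = at::globalContext();
  switch (t) {
    case c10::DeviceType::CUDA:        ctx.lazyInitCUDA(); break;
    case c10::DeviceType::HIP:         ctx.lazyInitHIP(); break;
    case c10::DeviceType::XPU:         ctx.lazyInitXPU(); break;
    case c10::DeviceType::MTIA:        ctx.lazyInitMTIA(); break;
    case c10::DeviceType::PrivateUse1: ctx.lazyInitPrivateUse1(); break;
    default: break;  // CPU and other backends need no lazy init here
  }
}
```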
@@ -10,7 +10,7 @@ TensorBase empty_cuda(
    ScalarType dtype,
    std::optional<Device> device_opt,
    std::optional<c10::MemoryFormat> memory_format_opt) {
  at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
  at::globalContext().lazyInitCUDA();
  const auto device = device_or_default(device_opt);
  TORCH_INTERNAL_ASSERT(device.is_cuda());
  const DeviceGuard device_guard(device);

@@ -50,7 +50,7 @@ TensorBase empty_strided_cuda(
    IntArrayRef stride,
    ScalarType dtype,
    std::optional<Device> device_opt) {
  at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
  at::globalContext().lazyInitCUDA();
  const auto device = device_or_default(device_opt);
  TORCH_INTERNAL_ASSERT(device.is_cuda());
  const DeviceGuard device_guard(device);

@@ -34,7 +34,7 @@ void init_p2p_access_cache(int64_t num_devices) {
} // namespace detail

bool get_p2p_access(int dev, int dev_to_access) {
  at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
  at::globalContext().lazyInitCUDA();

  TORCH_CHECK(dev >= 0 || dev < num_devices_,
      dev, " is not a device");

@@ -84,7 +84,7 @@ struct _Initializer {
// NB: deleter is dynamic, because we need it to live in a separate
// compilation unit (alt is to have another method in hooks, but
// let's not if we don't need to!)
void CUDAHooks::init() const {
void CUDAHooks::initCUDA() const {
  C10_LOG_API_USAGE_ONCE("aten.init.cuda");
  // Force the update to enable unit testing. This code get executed before unit tests
  // have a chance to enable vitals.

@@ -19,7 +19,7 @@ TORCH_CUDA_CPP_API void set_magma_init_fn(void (*magma_init_fn)());
// The real implementation of CUDAHooksInterface
struct CUDAHooks : public at::CUDAHooksInterface {
  CUDAHooks(at::CUDAHooksArgs) {}
  void init() const override;
  void initCUDA() const override;
  Device getDeviceFromPtr(void* data) const override;
  bool isPinnedPtr(const void* data) const override;
  const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const override;

@@ -19,10 +19,6 @@ struct TORCH_API AcceleratorHooksInterface {
  // Whether the device at device_index is fully initialized or not.
  virtual bool hasPrimaryContext(DeviceIndex device_index) const = 0;

  virtual void init() const {
    TORCH_CHECK(false, "Backend doesn`t support init()");
  }

  virtual DeviceIndex deviceCount() const {
    return 0;
  }

@@ -65,7 +65,7 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
  ~CUDAHooksInterface() override = default;

  // Initialize THCState and, transitively, the CUDA state
  void init() const override {
  virtual void initCUDA() const {
    TORCH_CHECK(false, "Cannot initialize CUDA without ATen_cuda library. ", CUDA_HELP);
  }

@@ -26,8 +26,9 @@ struct TORCH_API HIPHooksInterface : AcceleratorHooksInterface {
  // squelch -Werror=non-virtual-dtor
  ~HIPHooksInterface() override = default;

  void init() const override {
    TORCH_CHECK(false, "Cannot initialize HIP without ATen_hip library.");
  // Initialize the HIP library state
  virtual void initHIP() const {
    AT_ERROR("Cannot initialize HIP without ATen_hip library.");
  }

  virtual std::unique_ptr<c10::GeneratorImpl> initHIPGenerator(Context*) const {

@@ -1,25 +1,14 @@
#pragma once

#include <ATen/core/Generator.h>
#include <ATen/detail/AcceleratorHooksInterface.h>

#include <c10/core/Allocator.h>
#include <c10/util/Exception.h>
#include <c10/util/Registry.h>

namespace at {

struct TORCH_API IPUHooksInterface: AcceleratorHooksInterface {
  ~IPUHooksInterface() override = default;

  void init() const override {
    TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library.");
  }

  bool hasPrimaryContext(DeviceIndex device_index) const override {
    TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library.");
    return false;
  }
struct TORCH_API IPUHooksInterface {
  virtual ~IPUHooksInterface() = default;

  virtual const Generator& getDefaultIPUGenerator(
      DeviceIndex device_index [[maybe_unused]] = -1) const {

@@ -3,24 +3,13 @@
#include <c10/util/Exception.h>
#include <c10/util/Registry.h>

#include <ATen/detail/AcceleratorHooksInterface.h>

// NB: Class must live in `at` due to limitations of Registry.h.
namespace at {

struct TORCH_API MAIAHooksInterface : AcceleratorHooksInterface {
struct TORCH_API MAIAHooksInterface {
  // This should never actually be implemented, but it is used to
  // squelch -Werror=non-virtual-dtor
  ~MAIAHooksInterface() override = default;

  void init() const override {
    TORCH_CHECK(false, "Cannot initialize MAIA without ATen_maia library.");
  }

  bool hasPrimaryContext(DeviceIndex device_index) const override {
    TORCH_CHECK(false, "Cannot initialize MAIA without ATen_maia library.");
    return false;
  }
  virtual ~MAIAHooksInterface() = default;

  virtual std::string showConfig() const {
    TORCH_CHECK(false, "Cannot query detailed MAIA version information.");

@@ -22,7 +22,7 @@ struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface {
  ~MPSHooksInterface() override = default;

  // Initialize the MPS library state
  void init() const override {
  virtual void initMPS() const {
    FAIL_MPSHOOKS_FUNC(__func__);
  }
  virtual bool hasMPS() const {

@@ -31,7 +31,7 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {

  ~MTIAHooksInterface() override = default;

  void init() const override {
  virtual void initMTIA() const {
    // Avoid logging here, since MTIA needs init devices first then it will know
    // how many devices are available. Make it as no-op if mtia extension is not
    // dynamically loaded.

@@ -109,11 +109,6 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return nullptr;
  }

  // Perserved for BC
  virtual void initMTIA() const {
    return;
  }
};

struct TORCH_API MTIAHooksArgs {};

@@ -40,7 +40,7 @@ struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
        "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `hasPrimaryContext`.");
  }

  void init() const override {}
  virtual void initPrivateUse1() const {}
  virtual void resizePrivateUse1Bytes(
      const c10::Storage& storage,
      size_t newsize) const {
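For out-of-tree backends the pattern is the same: with the unified `init()` gone from `AcceleratorHooksInterface`, a `PrivateUse1HooksInterface` implementation overrides the restored `initPrivateUse1()` virtual, which `Context::lazyInitPrivateUse1()` calls once hooks are registered (see the Context hunk near the top). A sketch under that assumption; the class name and method bodies are illustrative:

```cpp
#include <ATen/detail/PrivateUse1HooksInterface.h>

// Hypothetical out-of-tree hooks implementation after the revert.
struct MyPrivateUse1Hooks : public at::PrivateUse1HooksInterface {
  // Post-revert: override the backend-specific virtual instead of init().
  void initPrivateUse1() const override {
    // one-time runtime/allocator setup for the custom device would go here
  }

  bool hasPrimaryContext(c10::DeviceIndex device_index) const override {
    return true;  // report an initialized primary context for the sketch
  }
};
```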
@@ -14,8 +14,10 @@ namespace at {
struct TORCH_API XPUHooksInterface : AcceleratorHooksInterface{
  ~XPUHooksInterface() override = default;

  void init() const override {
    TORCH_CHECK(false, "Cannot initialize XPU without ATen_xpu library.");
  virtual void initXPU() const {
    TORCH_CHECK(
        false,
        "Cannot initialize XPU without ATen_xpu library.");
  }

  virtual bool hasXPU() const {

@@ -12,7 +12,7 @@ namespace at::mps {
// The real implementation of MPSHooksInterface
struct MPSHooks : public at::MPSHooksInterface {
  MPSHooks(at::MPSHooksArgs) {}
  void init() const override;
  void initMPS() const override;

  // MPSDevice interface
  bool hasMPS() const override;

@@ -10,7 +10,7 @@

namespace at::mps {

void MPSHooks::init() const {
void MPSHooks::initMPS() const {
  C10_LOG_API_USAGE_ONCE("aten.init.mps");
  // TODO: initialize MPS devices and streams here
}

@@ -30,7 +30,7 @@ void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes) {
  c10::cuda::CUDAGuard guard(device.index());
  at::DataPtr data = allocator->allocate(size_bytes);
  if (storage->data_ptr()) {
    at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
    at::globalContext().lazyInitCUDA();

    C10_CUDA_CHECK(
        cudaMemcpyAsync(

@@ -138,9 +138,7 @@ __managed__ int input[] = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };

TEST(InclusiveScanSplit, CubTest) {
  if (!at::cuda::is_available()) return;
  at::globalContext().lazyInitDevice(
      c10::DeviceType::CUDA); // This is required to use PyTorch's caching
                              // allocator.
  at::globalContext().lazyInitCUDA(); // This is required to use PyTorch's caching allocator.

  int *output1;
  cudaMallocManaged(&output1, sizeof(int) * 10);

@@ -164,9 +162,7 @@ TEST(InclusiveScanSplit, CubTest) {

TEST(ExclusiveScanSplit, CubTest) {
  if (!at::cuda::is_available()) return;
  at::globalContext().lazyInitDevice(
      c10::DeviceType::CUDA); // This is required to use PyTorch's caching
                              // allocator.
  at::globalContext().lazyInitCUDA(); // This is required to use PyTorch's caching allocator.

  int *output2;
  cudaMallocManaged(&output2, sizeof(int) * 10);

@@ -9,7 +9,7 @@

namespace at::xpu::detail {

void XPUHooks::init() const {
void XPUHooks::initXPU() const {
  C10_LOG_API_USAGE_ONCE("aten.init.xpu");
  const auto device_count = c10::xpu::device_count_ensure_non_zero();
  c10::xpu::XPUCachingAllocator::init(device_count);

@@ -7,7 +7,7 @@ namespace at::xpu::detail {
// The real implementation of XPUHooksInterface
struct XPUHooks : public at::XPUHooksInterface {
  XPUHooks(at::XPUHooksArgs) {}
  void init() const override;
  void initXPU() const override;
  bool hasXPU() const override;
  std::string showConfig() const override;
  int32_t getGlobalIdxFromDevice(const at::Device& device) const override;

@@ -66,7 +66,7 @@ class AsyncInputIsOutputTest : public AsyncTest {
        numTensors_(numTensors),
        numDevices_(cudaNumDevices()) {
    // Allocate inputs on available devices in a round robin fashion.
    ::at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
    ::at::globalContext().lazyInitCUDA();
    inputs_.resize(numTensors_);
    for (const auto i : c10::irange(numTensors_)) {
      inputs_[i] = at::empty(

@@ -75,7 +75,7 @@ class NCCLTest : public NCCLTestBase {
      int inputDim = 3)
      : NCCLTestBase(path, pgTimeout), rank_(rank), worldSize_(worldSize) {
    // Each device has a single tensor to perf the NCCL op
    ::at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
    ::at::globalContext().lazyInitCUDA();
    tensors_.resize(numDevices_);
    inputs_.resize(numDevices_);
    outputs_.resize(numDevices_);

@@ -139,7 +139,7 @@ struct MTIAGuardImpl final : public c10::impl::DeviceGuardImplInterface {

struct MTIAHooks : public at::MTIAHooksInterface {
  explicit MTIAHooks(at::MTIAHooksArgs) {}
  void init() const override {}
  void initMTIA() const override {}

  bool hasMTIA() const override {
    return true;

@@ -43,7 +43,7 @@ std::vector<at::DeprecatedTypeProperties*> allCPUTypes() {
}

std::vector<at::DeprecatedTypeProperties*> allCUDATypes() {
  at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
  at::globalContext().lazyInitCUDA();
  return allTypesForBackends({Backend::CUDA, Backend::SparseCUDA});
}

@@ -52,7 +52,7 @@ std::vector<at::DeprecatedTypeProperties*> allXPUTypes() {
}

std::vector<at::DeprecatedTypeProperties*> allPrivateUser1Types() {
  at::globalContext().lazyInitDevice(c10::DeviceType::PrivateUse1);
  at::globalContext().lazyInitPrivateUse1();
  return allTypesForBackends(
      {Backend::PrivateUse1, Backend::SparsePrivateUse1});
}

@@ -890,7 +890,7 @@ PyObject* THCPModule_attachOutOfMemoryObserver(
    }
    Py_XDECREF(result);
  };
  at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
  at::globalContext().lazyInitCUDA();
  c10::cuda::CUDACachingAllocator::attachOutOfMemoryObserver(std::move(obs));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS

@@ -1425,7 +1425,7 @@ static PyObject* THCPModule_initExtension(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
  poison_fork();
  at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
  at::globalContext().lazyInitCUDA();

  auto m = THPObjectPtr(PyImport_ImportModule("torch.cuda"));
  if (!m)

@@ -138,7 +138,7 @@ void _record_memory_history(
  } else if (record_context) {
    when = c10::cuda::CUDACachingAllocator::RecordContext::STATE;
  }
  at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
  at::globalContext().lazyInitCUDA();
  _initRecordAnnotations();
  c10::cuda::CUDACachingAllocator::recordHistory(
      enabled, recorder, trace_alloc_max_entries, when);

@@ -189,7 +189,7 @@ void _record_memory_history(
      when = c10::cuda::CUDACachingAllocator::RecordContext::STATE;
    }
  }
  at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
  at::globalContext().lazyInitCUDA();
  _initRecordAnnotations();
  c10::cuda::CUDACachingAllocator::recordHistory(
      enabled.has_value(), recorder, max_entries, when);

@@ -1032,10 +1032,9 @@ ProcessGroupNCCL::ProcessGroupNCCL(
  // SEGMENT_FREE action occurs.
  // We attach hooks only once at the first PG creation.
  // Attaching hooks fails if CUDACachingAllocator is not initialized, so
  // Init for CUDA is called (and is a no-op if CUDA is already
  // initialized).
  // lazyInitCUDA is called (and is a no-op if CUDA is already initialized).
  if (useTensorRegisterAllocatorHook_ && !allocatorHooksAttached) {
    at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
    at::globalContext().lazyInitCUDA();
    c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker(
        &cacheAllocatorRegisterHook);
    c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker(

@@ -21,7 +21,7 @@ static C10_UNUSED at::Tensor to_dispatch(
    bool non_blocking,
    bool copy) {
  if (device && device->is_cuda()) {
    at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
    at::globalContext().lazyInitCUDA();
  }
  if (!device && !scalarType && !copy) {
    return self;

@@ -39,7 +39,7 @@ void initModule(PyObject* module) {
  m.def("_mtia_init", []() {
    TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
    poison_fork();
    at::globalContext().lazyInitDevice(c10::DeviceType::MTIA);
    at::globalContext().lazyInitMTIA();
  });

  m.def("_mtia_isBuilt", []() {

@@ -363,7 +363,7 @@ static PyObject* THXPModule_initExtension(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
  poison_fork();
  at::globalContext().lazyInitDevice(c10::DeviceType::XPU);
  at::globalContext().lazyInitXPU();

  auto m = THPObjectPtr(PyImport_ImportModule("torch.xpu"));
  if (!m)

@@ -515,7 +515,9 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si

# CUDA requires special handling
if is_cuda_dispatch_key(self.backend_index.dispatch_key):
    device_guard = f"globalContext().lazyInitDevice(c10::DeviceType::CUDA);\n{device_guard}"
    device_guard = (
        f"globalContext().lazyInitCUDA();\n{device_guard}"
    )
else:
    # kernel is operating on existing tensors
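The torchgen hunk above switches the line injected into generated CUDA wrappers back to `globalContext().lazyInitCUDA();`, emitted ahead of the usual device guard. Roughly, the wrapper prelude then looks like the sketch below; this is a paraphrase for a hypothetical operator, not an excerpt from the generated code:

```cpp
#include <ATen/ATen.h>
#include <ATen/DeviceGuard.h>

namespace at::native { at::Tensor foo_cuda(const at::Tensor& self); }  // hypothetical kernel

// Paraphrased shape of a torchgen-emitted CUDA wrapper; "foo" is a stand-in op.
at::Tensor wrapper_CUDA_foo(const at::Tensor& self) {
  at::globalContext().lazyInitCUDA();  // the injected lazy-init line (post-revert)
  const c10::OptionalDeviceGuard device_guard(at::device_of(self));
  return at::native::foo_cuda(self);   // dispatch into the actual CUDA kernel
}
```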