Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-22 22:25:10 +08:00
Introduce mlc device (ML Compute device) to PyTorch's device list (#50634)
Summary: Apple recently announced ML Compute, a new framework available in macOS Big Sur that lets users accelerate the training of neural networks on Mac hardware. This PR is the first in a series of PRs that will enable the integration with ML Compute. Most of the integration code will live in a separate subrepo named `mlc`. The integration with `mlc` (ML Compute) will be very similar to that of xla. We rely on registering our ops through:

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl_UNBOXED(<op_schema_name>, &customized_op_kernel)
  ...
}

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50634

Reviewed By: malfet

Differential Revision: D26614213

Pulled By: smessmer

fbshipit-source-id: 3b492b346c61cc3950ac880ac01a82fbdddbc07b
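For illustration, a minimal sketch of that registration pattern, written with the boxed-compatible m.impl form of the same mechanism (impl_UNBOXED above is the unboxed variant). The kernel my_add_kernel and its CPU-redispatch body are hypothetical stand-ins, not code from this PR:

#include <ATen/ATen.h>
#include <torch/library.h>

// Hypothetical kernel for aten::add.Tensor on the PrivateUse1 (here: MLC)
// backend. A real backend would lower the op to ML Compute; this sketch
// just redispatches to the CPU implementation so it compiles and runs.
at::Tensor my_add_kernel(const at::Tensor& self,
                         const at::Tensor& other,
                         const at::Scalar& alpha) {
  return at::add(self.cpu(), other.cpu(), alpha);
}

// Register the kernel for the PrivateUse1 dispatch key; a backend would
// repeat this for every aten op it supports.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("add.Tensor", my_add_kernel);
}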
Committed by: Facebook GitHub Bot
Parent: 2bdf6305a0
Commit: 30cb6ac53c
@@ -157,6 +157,7 @@ class ExperimentalFeatureConfigNode(TreeConfigNode):
         next_nodes = {
             "asan": AsanConfigNode,
             "xla": XlaConfigNode,
+            "mlc": MLCConfigNode,
             "vulkan": VulkanConfigNode,
             "parallel_tbb": ParallelTBBConfigNode,
             "parallel_native": ParallelNativeConfigNode,
@@ -193,6 +194,16 @@ class XlaConfigNode(TreeConfigNode):
     def child_constructor(self):
         return ImportantConfigNode
 
+class MLCConfigNode(TreeConfigNode):
+    def modify_label(self, label):
+        return "MLC=" + str(label)
+
+    def init2(self, node_name):
+        self.props["is_mlc"] = node_name
+
+    def child_constructor(self):
+        return ImportantConfigNode
+
 
 class AsanConfigNode(TreeConfigNode):
     def modify_label(self, label):
@@ -73,6 +73,9 @@ class TORCH_API Context {
   bool hasXLA() const {
     return c10::impl::hasDeviceGuardImpl(at::DeviceType::XLA);
   }
+  bool hasMLC() const {
+    return c10::impl::hasDeviceGuardImpl(at::DeviceType::MLC);
+  }
   // defined in header so that getNonVariableType has ability to inline
   // call_once check. getNonVariableType is called fairly frequently
   THCState* lazyInitCUDA() {
@@ -276,6 +279,10 @@ static inline bool hasXLA() {
   return globalContext().hasXLA();
 }
 
+static inline bool hasMLC() {
+  return globalContext().hasMLC();
+}
+
 // Despite its name, this function returns the number of *CUDA* GPUs.
 static inline size_t getNumGPUs() {
   // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS
@@ -194,6 +194,7 @@ std::string show_config() {
 
   // TODO: do HIP
   // TODO: do XLA
+  // TODO: do MLC
 
   return ss.str();
 }
@@ -48,4 +48,8 @@ TORCH_LIBRARY_IMPL(_, AutogradXLA, m) {
   m.fallback(torch::CppFunction::makeFallthrough());
 }
 
+TORCH_LIBRARY_IMPL(_, AutogradMLC, m) {
+  m.fallback(torch::CppFunction::makeFallthrough());
+}
+
 }
@@ -397,6 +397,7 @@ _(aten, is_coalesced) \
 _(aten, is_complex) \
 _(aten, is_contiguous) \
 _(aten, is_cuda) \
+_(aten, is_mlc) \
 _(aten, is_distributed) \
 _(aten, is_floating_point) \
 _(aten, is_nonzero) \
@@ -364,6 +364,9 @@ class TORCH_API Tensor {
   /// Returns if a `Tensor` is mkldnn tensor.
   bool is_mkldnn() const;
 
+  /// Returns if a `Tensor` is mlc tensor.
+  bool is_mlc() const;
+
   /// Returns if a `Tensor` is vulkan tensor.
   bool is_vulkan() const;
 
@@ -145,6 +145,15 @@ bool is_mkldnn(Tensor self) {
   return self.is_mkldnn();
 }
 
+bool Tensor::is_mlc() const {
+  // NB: this is not a native function to avoid dispatching overhead.
+  return impl_->is_mlc();
+}
+
+bool is_mlc(Tensor self) {
+  return self.is_mlc();
+}
+
 bool Tensor::is_vulkan() const {
   // NB: this is not a native function to avoid dispatching overhead.
   return impl_->is_vulkan();
@@ -45,6 +45,7 @@ enum class Backend {
   QuantizedXPU,
   Undefined,
   MkldnnCPU,
+  MLC,
   NumOptions
 };
 
@@ -99,6 +100,8 @@ static inline Backend toDense(Backend b) {
       return Backend::QuantizedCUDA;
     case Backend::QuantizedXPU:
       return Backend::QuantizedXPU;
+    case Backend::MLC:
+      return Backend::MLC;
     default:
       throw std::runtime_error("Unknown backend");
   }
@@ -117,6 +120,8 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) {
     return Backend::MSNPU;
   } else if (t == DispatchKey::XLA || t == DispatchKey::AutogradXLA) {
     return Backend::XLA;
+  } else if (t == DispatchKey::MLC || t == DispatchKey::AutogradMLC) {
+    return Backend::MLC;
   } else if (t == DispatchKey::Vulkan) {
     return Backend::Vulkan;
   } else if (t == DispatchKey::Metal) {
@@ -182,6 +187,8 @@ static inline DispatchKey backendToDispatchKey(Backend b) {
       return DispatchKey::QuantizedCUDA;
     case Backend::Undefined:
       return DispatchKey::Undefined;
+    case Backend::MLC:
+      return DispatchKey::MLC;
     default:
       throw std::runtime_error("Unknown backend");
   }
@@ -220,6 +227,8 @@ static inline DeviceType backendToDeviceType(Backend b) {
       return DeviceType::Vulkan;
     case Backend::Metal:
       return DeviceType::Metal;
+    case Backend::MLC:
+      return DeviceType::MLC;
     case Backend::Undefined:
       AT_ERROR("Undefined backend is not a valid device type");
     default:
@@ -250,6 +259,8 @@ static inline Backend backendToCPU(Backend b) {
     case Backend::MSNPU:
     case Backend::XLA:
       return Backend::CPU;
+    case Backend::MLC:
+      return Backend::CPU;
     case Backend::MkldnnCPU:
       return Backend::MkldnnCPU;
     case Backend::QuantizedCPU:
@@ -302,6 +313,7 @@ static inline Backend backendToCUDA(Backend b) {
     case Backend::FPGA:
     case Backend::MSNPU:
     case Backend::XLA:
+    case Backend::MLC:
       return Backend::CUDA;
     case Backend::SparseXPU:
     case Backend::SparseCPU:
@@ -324,6 +336,7 @@ static inline Backend backendToHIP(Backend b) {
     case Backend::FPGA:
     case Backend::MSNPU:
     case Backend::XLA:
+    case Backend::MLC:
       return Backend::HIP;
     case Backend::SparseXPU:
     case Backend::SparseCPU:
@@ -354,6 +367,8 @@ static inline const char* toString(Backend b) {
       return "MSNPU";
     case Backend::XLA:
       return "XLA";
+    case Backend::MLC:
+      return "MLC";
     case Backend::SparseCPU:
       return "SparseCPU";
     case Backend::SparseCUDA:
@@ -46,6 +46,7 @@ DeviceType parse_type(const std::string& device_string) {
       {"msnpu", DeviceType::MSNPU},
       {"xla", DeviceType::XLA},
       {"vulkan", DeviceType::Vulkan},
+      {"mlc", DeviceType::MLC},
   }};
   auto device = std::find_if(
       types.begin(),
@@ -57,7 +58,7 @@ DeviceType parse_type(const std::string& device_string) {
     return device->second;
   }
   AT_ERROR(
-      "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, xla, vulkan device type at start of device string: ",
+      "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, mlc, xla, vulkan device type at start of device string: ",
       device_string);
 }
 } // namespace
@@ -27,6 +27,8 @@ std::string DeviceTypeName(DeviceType d, bool lower_case) {
       return lower_case ? "msnpu" : "MSNPU";
     case DeviceType::XLA:
       return lower_case ? "xla" : "XLA";
+    case DeviceType::MLC:
+      return lower_case ? "mlc" : "MLC";
     case DeviceType::Vulkan:
       return lower_case ? "vulkan" : "VULKAN";
     case DeviceType::Metal:
@@ -65,6 +67,7 @@ bool isValidDeviceType(DeviceType d) {
     case DeviceType::FPGA:
     case DeviceType::MSNPU:
     case DeviceType::XLA:
+    case DeviceType::MLC:
     case DeviceType::Vulkan:
     case DeviceType::Metal:
     case DeviceType::XPU:
@@ -26,11 +26,12 @@ enum class DeviceType : int8_t {
   Vulkan = 10, // Vulkan
   Metal = 11, // Metal
   XPU = 12, // XPU
+  MLC = 13, // ML Compute / Apple
   // NB: If you add more devices:
   //  - Change the implementations of DeviceTypeName and isValidDeviceType
   //    in DeviceType.cpp
   //  - Change the number below
-  COMPILE_TIME_MAX_DEVICE_TYPES = 13,
+  COMPILE_TIME_MAX_DEVICE_TYPES = 14,
 };
 
 constexpr DeviceType kCPU = DeviceType::CPU;
@@ -39,6 +40,7 @@ constexpr DeviceType kHIP = DeviceType::HIP;
 constexpr DeviceType kFPGA = DeviceType::FPGA;
 constexpr DeviceType kMSNPU = DeviceType::MSNPU;
 constexpr DeviceType kXLA = DeviceType::XLA;
+constexpr DeviceType kMLC = DeviceType::MLC;
 constexpr DeviceType kVulkan = DeviceType::Vulkan;
 constexpr DeviceType kMetal = DeviceType::Metal;
 constexpr DeviceType kXPU = DeviceType::XPU;
@@ -21,6 +21,8 @@ const char* toString(DispatchKey t) {
       return "MSNPU";
     case DispatchKey::XLA:
       return "XLA";
+    case DispatchKey::MLC:
+      return "MLC";
     case DispatchKey::Vulkan:
       return "Vulkan";
     case DispatchKey::Metal:
@@ -80,6 +82,8 @@ const char* toString(DispatchKey t) {
       return "AutogradCUDA";
     case DispatchKey::AutogradXLA:
       return "AutogradXLA";
+    case DispatchKey::AutogradMLC:
+      return "AutogradMLC";
     case DispatchKey::AutogradNestedTensor:
       return "AutogradNestedTensor";
     case DispatchKey::AutogradPrivateUse1:
@@ -143,6 +147,8 @@ DispatchKey getAutogradKeyFromBackend(DispatchKey t) {
       return DispatchKey::AutogradCUDA;
     case DispatchKey::XLA:
       return DispatchKey::AutogradXLA;
+    case DispatchKey::MLC:
+      return DispatchKey::AutogradMLC;
     case DispatchKey::NestedTensor:
       return DispatchKey::AutogradNestedTensor;
     case DispatchKey::PrivateUse1:
@@ -62,6 +62,7 @@ enum class DispatchKey : uint8_t {
   MSNPU, // unused externally, but tested at
          // test/cpp_extensions/msnpu_extension.cpp
   XLA, // lives out of tree at https://github.com/pytorch/xla
+  MLC, // lives out of tree at https://github.com/pytorch/MLCompute
   Vulkan,
   Metal,
   XPU, // For out of tree Intel's heterogeneous computing plug-in
@@ -224,9 +225,9 @@ enum class DispatchKey : uint8_t {
   AutogradCPU,
   AutogradCUDA,
   AutogradXLA,
-  AutogradNestedTensor, // lives out of tree at
-                        // https://github.com/pytorch/nestedtensor
   AutogradXPU,
+  AutogradMLC,
+  AutogradNestedTensor, // lives out of tree at https://github.com/pytorch/nestedtensor
   // Here are some reserved pre-autograd keys for user-defined backends, see
   // Note [Private use DispatchKey]
   AutogradPrivateUse1,
@@ -14,6 +14,7 @@ constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends |
     DispatchKey::PrivateUse1,
     DispatchKey::PrivateUse2,
     DispatchKey::PrivateUse3,
+    DispatchKey::MLC,
 });
 
 bool isBackendDispatchKey(DispatchKey t) {
@@ -48,6 +49,8 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
       return DispatchKeySet(DispatchKey::CUDA);
     case DispatchKey::AutogradXLA:
       return DispatchKeySet(DispatchKey::XLA);
+    case DispatchKey::AutogradMLC:
+      return DispatchKeySet(DispatchKey::MLC);
     case DispatchKey::AutogradNestedTensor:
       return DispatchKeySet(DispatchKey::NestedTensor);
     case DispatchKey::AutogradXPU:
@@ -195,6 +195,7 @@ constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({
     DispatchKey::AutogradCUDA,
     DispatchKey::AutogradXLA,
     DispatchKey::AutogradNestedTensor,
+    DispatchKey::AutogradMLC,
     DispatchKey::AutogradXPU,
     DispatchKey::AutogradPrivateUse1,
     DispatchKey::AutogradPrivateUse2,
@@ -543,6 +543,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     return key_set_.has(DispatchKey::Metal);
   }
 
+  bool is_mlc() const {
+    return key_set_.has(DispatchKey::MLC);
+  }
+
   // TODO: remove this once we don't automatically enabled Autograd dispatch keys
   // in TensorImpl constructor.
   // DON'T USE THIS API!! It's only created for testing purpose in
@@ -629,6 +629,8 @@ inline DispatchKey computeDispatchKey(c10::optional<ScalarType> dtype, c10::opti
       return DispatchKey::MSNPU;
     case DeviceType::XLA:
       return DispatchKey::XLA;
+    case DeviceType::MLC:
+      return DispatchKey::MLC;
     case DeviceType::Vulkan:
       return DispatchKey::Vulkan;
     case DeviceType::Metal:
|
|||||||
return DeviceType::MSNPU;
|
return DeviceType::MSNPU;
|
||||||
} else if (tid == DispatchKey::XLA) {
|
} else if (tid == DispatchKey::XLA) {
|
||||||
return DeviceType::XLA;
|
return DeviceType::XLA;
|
||||||
|
} else if (tid == DispatchKey::MLC) {
|
||||||
|
return DeviceType::MLC;
|
||||||
} else if (tid == DispatchKey::SparseCPU) {
|
} else if (tid == DispatchKey::SparseCPU) {
|
||||||
return DeviceType::CPU;
|
return DeviceType::CPU;
|
||||||
} else if (tid == DispatchKey::SparseCUDA) {
|
} else if (tid == DispatchKey::SparseCUDA) {
|
||||||
|
@@ -198,8 +198,9 @@ enum DeviceTypeProto {
   PROTO_FPGA = 7; // FPGA
   PROTO_MSNPU = 8; // MSNPU
   PROTO_XLA = 9; // XLA / TPU
+  PROTO_MLC = 10; // ML Compute
   // Change the following number if you add more devices in the code.
-  PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 10;
+  PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 11;
 }
 
 // Device-specific options. We do not distinguish DeviceOption protos for
@@ -179,6 +179,12 @@ def _rebuild_xla_tensor(data, dtype, device, requires_grad):
     return tensor
 
 
+def _rebuild_mlc_tensor(data, dtype, device, requires_grad):
+    tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
+    tensor.requires_grad = requires_grad
+    return tensor
+
+
 def _rebuild_qtensor(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks):
     qscheme = quantizer_params[0]
     if qscheme == torch.per_tensor_affine:
@@ -95,6 +95,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
       .value("FPGA", c10::DeviceType::FPGA)
       .value("MSNPU", c10::DeviceType::MSNPU)
       .value("XLA", c10::DeviceType::XLA)
+      .value("MLC", c10::DeviceType::MLC)
       .value("Vulkan", c10::DeviceType::Vulkan)
       .value("Metal", c10::DeviceType::Metal);
 
@@ -610,6 +610,17 @@ PyObject *THPVariable_is_mkldnn(THPVariable *self, void *unused)
   END_HANDLE_TH_ERRORS
 }
 
+PyObject *THPVariable_is_mlc(THPVariable *self, void *unused)
+{
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function((PyObject *)self)) {
+    return handle_torch_function_getter(self, "is_mlc");
+  }
+  auto& self_ = self->cdata;
+  return torch::autograd::utils::wrap(self_.is_mlc());
+  END_HANDLE_TH_ERRORS
+}
+
 PyObject *THPVariable_is_vulkan(THPVariable *self, void *unused)
 {
   HANDLE_TH_ERRORS
@@ -751,6 +762,7 @@ static struct PyGetSetDef THPVariable_properties[] = {
   {"is_xpu", (getter)THPVariable_is_xpu, nullptr, nullptr, nullptr},
   {"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr},
   {"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr},
+  {"is_mlc", (getter)THPVariable_is_mlc, nullptr, nullptr, nullptr},
   {"is_vulkan", (getter)THPVariable_is_vulkan, nullptr, nullptr, nullptr},
   {"is_complex", (getter)THPVariable_is_complex, nullptr, nullptr, nullptr},
   {"is_quantized", (getter)THPVariable_is_quantized, nullptr, nullptr, nullptr},
@@ -110,6 +110,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
       {"is_xpu", "prim"},
       {"is_sparse", "prim"},
       {"is_mkldnn", "prim"},
+      {"is_mlc", "prim"},
       {"is_quantized", "prim"},
       {"is_vulkan", "prim"},
       {"is_meta", "prim"},
@@ -289,6 +289,14 @@ RegisterOperators reg(
           push(stack, a.is_mkldnn());
         },
         aliasAnalysisFromSchema()),
+    Operator(
+        "prim::is_mlc(Tensor a) -> bool",
+        [](Stack* stack) {
+          at::Tensor a;
+          pop(stack, a);
+          push(stack, a.is_mlc());
+        },
+        aliasAnalysisFromSchema()),
     Operator(
         "prim::is_vulkan(Tensor a) -> bool",
         [](Stack* stack) {
@@ -61,6 +61,9 @@ Backend backendToBackendOfDeviceType(Backend b, DeviceType d) {
       return Backend::XLA;
     case DeviceType::XPU:
       return backendToXPU(b);
+    case DeviceType::MLC:
+      TORCH_CHECK(!isSparse(b), "Sparse not implemented for MLC");
+      return Backend::MLC;
     default:
       AT_ERROR("Unknown device type");
   }
@@ -292,6 +292,8 @@ inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) {
       return c10::DispatchKey::CUDA;
     case c10::DeviceType::XLA:
       return c10::DispatchKey::XLA;
+    case c10::DeviceType::MLC:
+      return c10::DispatchKey::MLC;
     case c10::DeviceType::HIP:
       return c10::DispatchKey::HIP;
     case c10::DeviceType::MSNPU:
@@ -910,6 +910,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     Tensor.is_xpu.__get__: lambda self: -1,
     Tensor.is_leaf.__get__: lambda self: -1,
     Tensor.is_meta.__get__: lambda self: -1,
+    Tensor.is_mlc.__get__: lambda self: -1,
     Tensor.is_mkldnn.__get__: lambda self: -1,
     Tensor.is_quantized.__get__: lambda self: -1,
     Tensor.is_sparse.__get__: lambda self: -1,
@@ -57,7 +57,7 @@ class Tensor(torch._C._TensorBase):
         if id(self) in memo:
             return memo[id(self)]
         with torch.no_grad():
-            if self.is_sparse or self.device.type == 'xla':
+            if self.is_sparse or self.device.type == 'xla' or self.device.type == 'mlc':
                 new_tensor = self.clone()
             else:
                 new_storage = self.storage().__deepcopy__(memo)
@@ -123,6 +123,12 @@ class Tensor(torch._C._TensorBase):
                        str(self.device),
                        self.requires_grad)
             return (torch._utils._rebuild_xla_tensor, arg_xla)
+        if self.device.type == 'mlc':
+            arg_mlc = (self.cpu().numpy(),
+                       self.dtype,
+                       str(self.device),
+                       self.requires_grad)
+            return (torch._utils._rebuild_mlc_tensor, arg_mlc)
         if self.is_quantized:
             # quantizer_params can be different type based on torch attribute
             quantizer_params: Union[Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int]]
@@ -257,7 +257,7 @@ class RemoteModuleTest(RpcAgentTestFixture):
 
         with self.assertRaisesRegex(
             RuntimeError,
-            r"Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, xla, vulkan"
+            r"Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, mlc, xla, vulkan"
             " device type at start of device string",
         ):
             list(
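Taken together, the plumbing above makes the new device visible from Python: parse_type accepts the "mlc" device string and the is_mlc getter is exposed on Tensor. A short usage sketch; the .to('mlc') lines are hypothetical and only work once the out-of-tree mlc backend is installed and registered:

import torch

d = torch.device('mlc')   # accepted by the device-string parser after this PR
print(d.type)             # 'mlc'

x = torch.ones(2, 2)
print(x.is_mlc)           # False: a CPU tensor does not carry the MLC dispatch key

# With the out-of-tree backend installed (hypothetical):
# y = x.to('mlc')
# assert y.is_mlc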