Introduce mlc device (ML Compute device) to PyTorch's device list (#50634)

Summary:
Apple recently announced ML Compute, a new framework available in macOS Big Sur that lets users accelerate the training of neural networks on Mac hardware. This PR is the first in a series of PRs that will enable the integration with ML Compute. Most of the integration code will live in a separate subrepo named `mlc`.
The integration with `mlc` (ML Compute) will be very similar to that of xla: we rely on registering our ops through

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl_UNBOXED(<op_schema_name>, &customized_op_kernel);
  ...
}
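For context, here is what user-facing code is expected to look like once the out-of-tree `mlc` backend is installed and has registered its kernels. This is a hypothetical sketch, not part of this PR; it only exercises the device-string parsing and the `is_mlc` property that this PR adds:

import torch

device = torch.device('mlc')           # 'mlc' now parses as a valid device string
x = torch.randn(4, 4, device=device)   # assumes the mlc backend provides a factory kernel
y = (x @ x).relu()                     # assumes matmul/relu kernels are registered
print(y.is_mlc)                        # True: new Tensor property added by this PR
print(y.device.type)                   # 'mlc'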

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50634

Reviewed By: malfet

Differential Revision: D26614213

Pulled By: smessmer

fbshipit-source-id: 3b492b346c61cc3950ac880ac01a82fbdddbc07b
Authored by Bel H on 2021-02-24 22:37:12 -08:00; committed by Facebook GitHub Bot
parent 2bdf6305a0
commit 30cb6ac53c
28 changed files with 124 additions and 7 deletions

View File

@@ -157,6 +157,7 @@ class ExperimentalFeatureConfigNode(TreeConfigNode):
         next_nodes = {
             "asan": AsanConfigNode,
             "xla": XlaConfigNode,
+            "mlc": MLCConfigNode,
             "vulkan": VulkanConfigNode,
             "parallel_tbb": ParallelTBBConfigNode,
             "parallel_native": ParallelNativeConfigNode,
@@ -193,6 +194,16 @@ class XlaConfigNode(TreeConfigNode):
     def child_constructor(self):
         return ImportantConfigNode

+class MLCConfigNode(TreeConfigNode):
+    def modify_label(self, label):
+        return "MLC=" + str(label)
+
+    def init2(self, node_name):
+        self.props["is_mlc"] = node_name
+
+    def child_constructor(self):
+        return ImportantConfigNode
+
 class AsanConfigNode(TreeConfigNode):
     def modify_label(self, label):

View File

@@ -73,6 +73,9 @@ class TORCH_API Context {
   bool hasXLA() const {
     return c10::impl::hasDeviceGuardImpl(at::DeviceType::XLA);
   }
+  bool hasMLC() const {
+    return c10::impl::hasDeviceGuardImpl(at::DeviceType::MLC);
+  }
   // defined in header so that getNonVariableType has ability to inline
   // call_once check. getNonVariableType is called fairly frequently
   THCState* lazyInitCUDA() {
@@ -276,6 +279,10 @@ static inline bool hasXLA() {
   return globalContext().hasXLA();
 }

+static inline bool hasMLC() {
+  return globalContext().hasMLC();
+}
+
 // Despite its name, this function returns the number of *CUDA* GPUs.
 static inline size_t getNumGPUs() {
   // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS

View File

@@ -194,6 +194,7 @@ std::string show_config() {
   // TODO: do HIP
   // TODO: do XLA
+  // TODO: do MLC

   return ss.str();
 }

View File

@@ -48,4 +48,8 @@ TORCH_LIBRARY_IMPL(_, AutogradXLA, m) {
   m.fallback(torch::CppFunction::makeFallthrough());
 }

+TORCH_LIBRARY_IMPL(_, AutogradMLC, m) {
+  m.fallback(torch::CppFunction::makeFallthrough());
+}
+
 }

View File

@@ -397,6 +397,7 @@ _(aten, is_coalesced) \
 _(aten, is_complex) \
 _(aten, is_contiguous) \
 _(aten, is_cuda) \
+_(aten, is_mlc) \
 _(aten, is_distributed) \
 _(aten, is_floating_point) \
 _(aten, is_nonzero) \

View File

@@ -364,6 +364,9 @@ class TORCH_API Tensor {
   /// Returns if a `Tensor` is mkldnn tensor.
   bool is_mkldnn() const;

+  /// Returns if a `Tensor` is mlc tensor.
+  bool is_mlc() const;
+
   /// Returns if a `Tensor` is vulkan tensor.
   bool is_vulkan() const;

View File

@@ -145,6 +145,15 @@ bool is_mkldnn(Tensor self) {
   return self.is_mkldnn();
 }

+bool Tensor::is_mlc() const {
+  // NB: this is not a native function to avoid dispatching overhead.
+  return impl_->is_mlc();
+}
+
+bool is_mlc(Tensor self) {
+  return self.is_mlc();
+}
+
 bool Tensor::is_vulkan() const {
   // NB: this is not a native function to avoid dispatching overhead.
   return impl_->is_vulkan();

View File

@@ -45,6 +45,7 @@ enum class Backend {
   QuantizedXPU,
   Undefined,
   MkldnnCPU,
+  MLC,
   NumOptions
 };

@@ -99,6 +100,8 @@ static inline Backend toDense(Backend b) {
       return Backend::QuantizedCUDA;
     case Backend::QuantizedXPU:
       return Backend::QuantizedXPU;
+    case Backend::MLC:
+      return Backend::MLC;
     default:
       throw std::runtime_error("Unknown backend");
   }

@@ -117,6 +120,8 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) {
     return Backend::MSNPU;
   } else if (t == DispatchKey::XLA || t == DispatchKey::AutogradXLA) {
     return Backend::XLA;
+  } else if (t == DispatchKey::MLC || t == DispatchKey::AutogradMLC) {
+    return Backend::MLC;
   } else if (t == DispatchKey::Vulkan) {
     return Backend::Vulkan;
   } else if (t == DispatchKey::Metal) {

@@ -182,6 +187,8 @@ static inline DispatchKey backendToDispatchKey(Backend b) {
       return DispatchKey::QuantizedCUDA;
     case Backend::Undefined:
       return DispatchKey::Undefined;
+    case Backend::MLC:
+      return DispatchKey::MLC;
     default:
       throw std::runtime_error("Unknown backend");
   }

@@ -220,6 +227,8 @@ static inline DeviceType backendToDeviceType(Backend b) {
       return DeviceType::Vulkan;
     case Backend::Metal:
       return DeviceType::Metal;
+    case Backend::MLC:
+      return DeviceType::MLC;
     case Backend::Undefined:
       AT_ERROR("Undefined backend is not a valid device type");
     default:

@@ -250,6 +259,8 @@ static inline Backend backendToCPU(Backend b) {
     case Backend::MSNPU:
     case Backend::XLA:
       return Backend::CPU;
+    case Backend::MLC:
+      return Backend::CPU;
     case Backend::MkldnnCPU:
       return Backend::MkldnnCPU;
     case Backend::QuantizedCPU:

@@ -302,6 +313,7 @@ static inline Backend backendToCUDA(Backend b) {
     case Backend::FPGA:
     case Backend::MSNPU:
     case Backend::XLA:
+    case Backend::MLC:
       return Backend::CUDA;
     case Backend::SparseXPU:
     case Backend::SparseCPU:

@@ -324,6 +336,7 @@ static inline Backend backendToHIP(Backend b) {
     case Backend::FPGA:
     case Backend::MSNPU:
     case Backend::XLA:
+    case Backend::MLC:
       return Backend::HIP;
     case Backend::SparseXPU:
     case Backend::SparseCPU:

@@ -354,6 +367,8 @@ static inline const char* toString(Backend b) {
       return "MSNPU";
     case Backend::XLA:
       return "XLA";
+    case Backend::MLC:
+      return "MLC";
     case Backend::SparseCPU:
       return "SparseCPU";
     case Backend::SparseCUDA:

View File

@@ -46,6 +46,7 @@ DeviceType parse_type(const std::string& device_string) {
       {"msnpu", DeviceType::MSNPU},
       {"xla", DeviceType::XLA},
       {"vulkan", DeviceType::Vulkan},
+      {"mlc", DeviceType::MLC},
   }};
   auto device = std::find_if(
       types.begin(),
@@ -57,7 +58,7 @@ DeviceType parse_type(const std::string& device_string) {
     return device->second;
   }
   AT_ERROR(
-      "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, xla, vulkan device type at start of device string: ",
+      "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, mlc, xla, vulkan device type at start of device string: ",
       device_string);
 }
 } // namespace

View File

@@ -27,6 +27,8 @@ std::string DeviceTypeName(DeviceType d, bool lower_case) {
       return lower_case ? "msnpu" : "MSNPU";
     case DeviceType::XLA:
       return lower_case ? "xla" : "XLA";
+    case DeviceType::MLC:
+      return lower_case ? "mlc" : "MLC";
     case DeviceType::Vulkan:
       return lower_case ? "vulkan" : "VULKAN";
     case DeviceType::Metal:
@@ -65,6 +67,7 @@ bool isValidDeviceType(DeviceType d) {
     case DeviceType::FPGA:
     case DeviceType::MSNPU:
     case DeviceType::XLA:
+    case DeviceType::MLC:
     case DeviceType::Vulkan:
     case DeviceType::Metal:
     case DeviceType::XPU:

View File

@@ -26,11 +26,12 @@ enum class DeviceType : int8_t {
   Vulkan = 10, // Vulkan
   Metal = 11, // Metal
   XPU = 12, // XPU
+  MLC = 13, // ML Compute / Apple
   // NB: If you add more devices:
   //  - Change the implementations of DeviceTypeName and isValidDeviceType
   //    in DeviceType.cpp
   //  - Change the number below
-  COMPILE_TIME_MAX_DEVICE_TYPES = 13,
+  COMPILE_TIME_MAX_DEVICE_TYPES = 14,
 };

 constexpr DeviceType kCPU = DeviceType::CPU;
@@ -39,6 +40,7 @@ constexpr DeviceType kHIP = DeviceType::HIP;
 constexpr DeviceType kFPGA = DeviceType::FPGA;
 constexpr DeviceType kMSNPU = DeviceType::MSNPU;
 constexpr DeviceType kXLA = DeviceType::XLA;
+constexpr DeviceType kMLC = DeviceType::MLC;
 constexpr DeviceType kVulkan = DeviceType::Vulkan;
 constexpr DeviceType kMetal = DeviceType::Metal;
 constexpr DeviceType kXPU = DeviceType::XPU;

View File

@@ -21,6 +21,8 @@ const char* toString(DispatchKey t) {
       return "MSNPU";
     case DispatchKey::XLA:
       return "XLA";
+    case DispatchKey::MLC:
+      return "MLC";
     case DispatchKey::Vulkan:
       return "Vulkan";
     case DispatchKey::Metal:
@@ -80,6 +82,8 @@ const char* toString(DispatchKey t) {
       return "AutogradCUDA";
     case DispatchKey::AutogradXLA:
       return "AutogradXLA";
+    case DispatchKey::AutogradMLC:
+      return "AutogradMLC";
     case DispatchKey::AutogradNestedTensor:
       return "AutogradNestedTensor";
     case DispatchKey::AutogradPrivateUse1:
@@ -143,6 +147,8 @@ DispatchKey getAutogradKeyFromBackend(DispatchKey t) {
       return DispatchKey::AutogradCUDA;
     case DispatchKey::XLA:
       return DispatchKey::AutogradXLA;
+    case DispatchKey::MLC:
+      return DispatchKey::AutogradMLC;
     case DispatchKey::NestedTensor:
       return DispatchKey::AutogradNestedTensor;
     case DispatchKey::PrivateUse1:

View File

@@ -62,6 +62,7 @@ enum class DispatchKey : uint8_t {
   MSNPU, // unused externally, but tested at
          // test/cpp_extensions/msnpu_extension.cpp
   XLA, // lives out of tree at https://github.com/pytorch/xla
+  MLC, // lives out of tree at https://github.com/pytorch/MLCompute
   Vulkan,
   Metal,
   XPU, // For out of tree Intel's heterogeneous computing plug-in
@@ -224,9 +225,9 @@ enum class DispatchKey : uint8_t {
   AutogradCPU,
   AutogradCUDA,
   AutogradXLA,
-  AutogradNestedTensor, // lives out of tree at
-                        // https://github.com/pytorch/nestedtensor
   AutogradXPU,
+  AutogradMLC,
+  AutogradNestedTensor, // lives out of tree at https://github.com/pytorch/nestedtensor
   // Here are some reserved pre-autograd keys for user-defined backends, see
   // Note [Private use DispatchKey]
   AutogradPrivateUse1,

View File

@@ -14,6 +14,7 @@ constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends |
     DispatchKey::PrivateUse1,
     DispatchKey::PrivateUse2,
     DispatchKey::PrivateUse3,
+    DispatchKey::MLC,
 });

 bool isBackendDispatchKey(DispatchKey t) {
@@ -48,6 +49,8 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
       return DispatchKeySet(DispatchKey::CUDA);
     case DispatchKey::AutogradXLA:
       return DispatchKeySet(DispatchKey::XLA);
+    case DispatchKey::AutogradMLC:
+      return DispatchKeySet(DispatchKey::MLC);
     case DispatchKey::AutogradNestedTensor:
       return DispatchKeySet(DispatchKey::NestedTensor);
     case DispatchKey::AutogradXPU:

View File

@@ -195,6 +195,7 @@ constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({
     DispatchKey::AutogradCUDA,
     DispatchKey::AutogradXLA,
     DispatchKey::AutogradNestedTensor,
+    DispatchKey::AutogradMLC,
     DispatchKey::AutogradXPU,
     DispatchKey::AutogradPrivateUse1,
     DispatchKey::AutogradPrivateUse2,

View File

@@ -543,6 +543,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     return key_set_.has(DispatchKey::Metal);
   }

+  bool is_mlc() const {
+    return key_set_.has(DispatchKey::MLC);
+  }
+
   // TODO: remove this once we don't automatically enabled Autograd dispatch keys
   // in TensorImpl constructor.
   // DON'T USE THIS API!! It's only created for testing purpose in

View File

@@ -629,6 +629,8 @@ inline DispatchKey computeDispatchKey(c10::optional<ScalarType> dtype, c10::opti
       return DispatchKey::MSNPU;
     case DeviceType::XLA:
       return DispatchKey::XLA;
+    case DeviceType::MLC:
+      return DispatchKey::MLC;
     case DeviceType::Vulkan:
       return DispatchKey::Vulkan;
     case DeviceType::Metal:
@@ -687,6 +689,8 @@ inline DeviceType computeDeviceType(DispatchKey tid) {
     return DeviceType::MSNPU;
   } else if (tid == DispatchKey::XLA) {
     return DeviceType::XLA;
+  } else if (tid == DispatchKey::MLC) {
+    return DeviceType::MLC;
   } else if (tid == DispatchKey::SparseCPU) {
     return DeviceType::CPU;
   } else if (tid == DispatchKey::SparseCUDA) {

View File

@@ -198,8 +198,9 @@ enum DeviceTypeProto {
   PROTO_FPGA = 7; // FPGA
   PROTO_MSNPU = 8; // MSNPU
   PROTO_XLA = 9; // XLA / TPU
+  PROTO_MLC = 10; // ML Compute
   // Change the following number if you add more devices in the code.
-  PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 10;
+  PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 11;
 }

 // Device-specific options. We do not distinguish DeviceOption protos for

View File

@@ -179,6 +179,12 @@ def _rebuild_xla_tensor(data, dtype, device, requires_grad):
     return tensor


+def _rebuild_mlc_tensor(data, dtype, device, requires_grad):
+    tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
+    tensor.requires_grad = requires_grad
+    return tensor
+
+
 def _rebuild_qtensor(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks):
     qscheme = quantizer_params[0]
     if qscheme == torch.per_tensor_affine:

View File

@@ -95,6 +95,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
       .value("FPGA", c10::DeviceType::FPGA)
       .value("MSNPU", c10::DeviceType::MSNPU)
       .value("XLA", c10::DeviceType::XLA)
+      .value("MLC", c10::DeviceType::MLC)
       .value("Vulkan", c10::DeviceType::Vulkan)
       .value("Metal", c10::DeviceType::Metal);

View File

@@ -610,6 +610,17 @@ PyObject *THPVariable_is_mkldnn(THPVariable *self, void *unused)
   END_HANDLE_TH_ERRORS
 }

+PyObject *THPVariable_is_mlc(THPVariable *self, void *unused)
+{
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function((PyObject *)self)) {
+    return handle_torch_function_getter(self, "is_mlc");
+  }
+  auto& self_ = self->cdata;
+  return torch::autograd::utils::wrap(self_.is_mlc());
+  END_HANDLE_TH_ERRORS
+}
+
 PyObject *THPVariable_is_vulkan(THPVariable *self, void *unused)
 {
   HANDLE_TH_ERRORS
@@ -751,6 +762,7 @@ static struct PyGetSetDef THPVariable_properties[] = {
   {"is_xpu", (getter)THPVariable_is_xpu, nullptr, nullptr, nullptr},
   {"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr},
   {"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr},
+  {"is_mlc", (getter)THPVariable_is_mlc, nullptr, nullptr, nullptr},
   {"is_vulkan", (getter)THPVariable_is_vulkan, nullptr, nullptr, nullptr},
   {"is_complex", (getter)THPVariable_is_complex, nullptr, nullptr, nullptr},
   {"is_quantized", (getter)THPVariable_is_quantized, nullptr, nullptr, nullptr},

View File

@@ -110,6 +110,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
       {"is_xpu", "prim"},
       {"is_sparse", "prim"},
       {"is_mkldnn", "prim"},
+      {"is_mlc", "prim"},
       {"is_quantized", "prim"},
       {"is_vulkan", "prim"},
       {"is_meta", "prim"},

View File

@@ -289,6 +289,14 @@ RegisterOperators reg(
           push(stack, a.is_mkldnn());
         },
         aliasAnalysisFromSchema()),
+    Operator(
+        "prim::is_mlc(Tensor a) -> bool",
+        [](Stack* stack) {
+          at::Tensor a;
+          pop(stack, a);
+          push(stack, a.is_mlc());
+        },
+        aliasAnalysisFromSchema()),
     Operator(
         "prim::is_vulkan(Tensor a) -> bool",
         [](Stack* stack) {
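
Together with the {"is_mlc", "prim"} mapping in the previous file, this operator makes the property visible to TorchScript. A minimal sketch of what should now compile (hypothetical function name, any tensor input):

import torch

@torch.jit.script
def on_mlc(t: torch.Tensor) -> bool:
    # the attribute access is lowered to the prim::is_mlc operator above
    return t.is_mlc

print(on_mlc(torch.zeros(1)))  # False for a CPU tensor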

View File

@@ -61,6 +61,9 @@ Backend backendToBackendOfDeviceType(Backend b, DeviceType d) {
       return Backend::XLA;
     case DeviceType::XPU:
       return backendToXPU(b);
+    case DeviceType::MLC:
+      TORCH_CHECK(!isSparse(b), "Sparse not implemented for MLC");
+      return Backend::MLC;
     default:
       AT_ERROR("Unknown device type");
   }

View File

@@ -292,6 +292,8 @@ inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) {
       return c10::DispatchKey::CUDA;
     case c10::DeviceType::XLA:
       return c10::DispatchKey::XLA;
+    case c10::DeviceType::MLC:
+      return c10::DispatchKey::MLC;
     case c10::DeviceType::HIP:
       return c10::DispatchKey::HIP;
     case c10::DeviceType::MSNPU:

View File

@@ -910,6 +910,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         Tensor.is_xpu.__get__: lambda self: -1,
         Tensor.is_leaf.__get__: lambda self: -1,
         Tensor.is_meta.__get__: lambda self: -1,
+        Tensor.is_mlc.__get__: lambda self: -1,
         Tensor.is_mkldnn.__get__: lambda self: -1,
         Tensor.is_quantized.__get__: lambda self: -1,
         Tensor.is_sparse.__get__: lambda self: -1,

View File

@@ -57,7 +57,7 @@ class Tensor(torch._C._TensorBase):
         if id(self) in memo:
             return memo[id(self)]
         with torch.no_grad():
-            if self.is_sparse or self.device.type == 'xla':
+            if self.is_sparse or self.device.type == 'xla' or self.device.type == 'mlc':
                 new_tensor = self.clone()
             else:
                 new_storage = self.storage().__deepcopy__(memo)
@@ -123,6 +123,12 @@ class Tensor(torch._C._TensorBase):
                        str(self.device),
                        self.requires_grad)
             return (torch._utils._rebuild_xla_tensor, arg_xla)
+        if self.device.type == 'mlc':
+            arg_mlc = (self.cpu().numpy(),
+                       self.dtype,
+                       str(self.device),
+                       self.requires_grad)
+            return (torch._utils._rebuild_mlc_tensor, arg_mlc)
         if self.is_quantized:
             # quantizer_params can be different type based on torch attribute
             quantizer_params: Union[Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int]]
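
As with xla tensors, mlc tensors do not expose their storage directly, so `__deepcopy__` falls back to `clone()` and pickling round-trips through a CPU numpy array that `_rebuild_mlc_tensor` restores onto the device. A sketch of the resulting behavior, assuming a working out-of-tree mlc backend:

import copy, io
import torch

t = torch.ones(3, device='mlc', requires_grad=True)  # assumes mlc kernels exist
u = copy.deepcopy(t)                 # takes the clone() path, stays on the mlc device

buf = io.BytesIO()
torch.save(t, buf)                   # serialized as (cpu numpy data, dtype, device string, requires_grad)
buf.seek(0)
restored = torch.load(buf)           # rebuilt by torch._utils._rebuild_mlc_tensor
assert restored.is_mlc and restored.requires_grad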

View File

@@ -257,7 +257,7 @@ class RemoteModuleTest(RpcAgentTestFixture):
         with self.assertRaisesRegex(
             RuntimeError,
-            r"Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, xla, vulkan"
+            r"Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, mlc, xla, vulkan"
             " device type at start of device string",
         ):
             list(