Introduce mlc device (ML Compute device) to PyTorch's device list (#50634)

Summary:
Apple recently announced ML Compute, a new framework available in macOS Big Sur that enables users to accelerate the training of neural networks on Mac hardware. This PR is the first in a series of PRs that will enable the integration with ML Compute. Most of the integration code will live in a separate subrepo named `mlc`.
The integration with `mlc` (ML Compute) will be very similar to that of xla. We rely on registering our ops through:

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl_UNBOXED(<op_schema_name>, &customized_op_kernel);
  ...
}
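
For illustration, a registration under this scheme might look like the sketch below. This is not code from the `mlc` subrepo: the `mlc_add` kernel and its CPU-fallback body are hypothetical placeholders, and the sketch uses the plain `m.impl` registration call rather than `impl_UNBOXED`.

// Minimal sketch of an out-of-tree backend kernel registration (hypothetical).
#include <ATen/ATen.h>
#include <torch/library.h>

namespace {

// Placeholder kernel for aten::add.Tensor: a real ML Compute kernel would
// build and execute an MLC graph here instead of delegating to the CPU
// reference implementation, and would return a tensor on the backend device
// rather than a CPU tensor.
at::Tensor mlc_add(const at::Tensor& self, const at::Tensor& other,
                   const at::Scalar& alpha) {
  return at::add(self.cpu(), other.cpu(), alpha);
}

} // namespace

// Register the kernel under the backend's dispatch key so that calls to
// aten::add.Tensor on tensors carrying that key are routed to mlc_add.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("add.Tensor", &mlc_add);
}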

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50634

Reviewed By: malfet

Differential Revision: D26614213

Pulled By: smessmer

fbshipit-source-id: 3b492b346c61cc3950ac880ac01a82fbdddbc07b
Author: Bel H
Date: 2021-02-24 22:37:12 -08:00
Committed by: Facebook GitHub Bot
Parent: 2bdf6305a0
Commit: 30cb6ac53c

28 changed files with 124 additions and 7 deletions

View File

@ -157,6 +157,7 @@ class ExperimentalFeatureConfigNode(TreeConfigNode):
next_nodes = {
"asan": AsanConfigNode,
"xla": XlaConfigNode,
"mlc": MLCConfigNode,
"vulkan": VulkanConfigNode,
"parallel_tbb": ParallelTBBConfigNode,
"parallel_native": ParallelNativeConfigNode,
@ -193,6 +194,16 @@ class XlaConfigNode(TreeConfigNode):
def child_constructor(self):
return ImportantConfigNode
class MLCConfigNode(TreeConfigNode):
def modify_label(self, label):
return "MLC=" + str(label)
def init2(self, node_name):
self.props["is_mlc"] = node_name
def child_constructor(self):
return ImportantConfigNode
class AsanConfigNode(TreeConfigNode):
def modify_label(self, label):

View File

@ -73,6 +73,9 @@ class TORCH_API Context {
bool hasXLA() const {
return c10::impl::hasDeviceGuardImpl(at::DeviceType::XLA);
}
bool hasMLC() const {
return c10::impl::hasDeviceGuardImpl(at::DeviceType::MLC);
}
// defined in header so that getNonVariableType has ability to inline
// call_once check. getNonVariableType is called fairly frequently
THCState* lazyInitCUDA() {
@ -276,6 +279,10 @@ static inline bool hasXLA() {
return globalContext().hasXLA();
}
static inline bool hasMLC() {
return globalContext().hasMLC();
}
// Despite its name, this function returns the number of *CUDA* GPUs.
static inline size_t getNumGPUs() {
// WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS

View File

@ -194,6 +194,7 @@ std::string show_config() {
// TODO: do HIP
// TODO: do XLA
// TODO: do MLC
return ss.str();
}

View File

@ -48,4 +48,8 @@ TORCH_LIBRARY_IMPL(_, AutogradXLA, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
TORCH_LIBRARY_IMPL(_, AutogradMLC, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
}

View File

@ -397,6 +397,7 @@ _(aten, is_coalesced) \
_(aten, is_complex) \
_(aten, is_contiguous) \
_(aten, is_cuda) \
_(aten, is_mlc) \
_(aten, is_distributed) \
_(aten, is_floating_point) \
_(aten, is_nonzero) \

View File

@ -364,6 +364,9 @@ class TORCH_API Tensor {
/// Returns if a `Tensor` is mkldnn tensor.
bool is_mkldnn() const;
/// Returns if a `Tensor` is mlc tensor.
bool is_mlc() const;
/// Returns if a `Tensor` is vulkan tensor.
bool is_vulkan() const;

View File

@ -145,6 +145,15 @@ bool is_mkldnn(Tensor self) {
return self.is_mkldnn();
}
bool Tensor::is_mlc() const {
// NB: this is not a native function to avoid dispatching overhead.
return impl_->is_mlc();
}
bool is_mlc(Tensor self) {
return self.is_mlc();
}
bool Tensor::is_vulkan() const {
// NB: this is not a native function to avoid dispatching overhead.
return impl_->is_vulkan();

View File

@ -45,6 +45,7 @@ enum class Backend {
QuantizedXPU,
Undefined,
MkldnnCPU,
MLC,
NumOptions
};
@ -99,6 +100,8 @@ static inline Backend toDense(Backend b) {
return Backend::QuantizedCUDA;
case Backend::QuantizedXPU:
return Backend::QuantizedXPU;
case Backend::MLC:
return Backend::MLC;
default:
throw std::runtime_error("Unknown backend");
}
@ -117,6 +120,8 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) {
return Backend::MSNPU;
} else if (t == DispatchKey::XLA || t == DispatchKey::AutogradXLA) {
return Backend::XLA;
} else if (t == DispatchKey::MLC || t == DispatchKey::AutogradMLC) {
return Backend::MLC;
} else if (t == DispatchKey::Vulkan) {
return Backend::Vulkan;
} else if (t == DispatchKey::Metal) {
@ -182,6 +187,8 @@ static inline DispatchKey backendToDispatchKey(Backend b) {
return DispatchKey::QuantizedCUDA;
case Backend::Undefined:
return DispatchKey::Undefined;
case Backend::MLC:
return DispatchKey::MLC;
default:
throw std::runtime_error("Unknown backend");
}
@ -220,6 +227,8 @@ static inline DeviceType backendToDeviceType(Backend b) {
return DeviceType::Vulkan;
case Backend::Metal:
return DeviceType::Metal;
case Backend::MLC:
return DeviceType::MLC;
case Backend::Undefined:
AT_ERROR("Undefined backend is not a valid device type");
default:
@ -250,6 +259,8 @@ static inline Backend backendToCPU(Backend b) {
case Backend::MSNPU:
case Backend::XLA:
return Backend::CPU;
case Backend::MLC:
return Backend::CPU;
case Backend::MkldnnCPU:
return Backend::MkldnnCPU;
case Backend::QuantizedCPU:
@ -302,6 +313,7 @@ static inline Backend backendToCUDA(Backend b) {
case Backend::FPGA:
case Backend::MSNPU:
case Backend::XLA:
case Backend::MLC:
return Backend::CUDA;
case Backend::SparseXPU:
case Backend::SparseCPU:
@ -324,6 +336,7 @@ static inline Backend backendToHIP(Backend b) {
case Backend::FPGA:
case Backend::MSNPU:
case Backend::XLA:
case Backend::MLC:
return Backend::HIP;
case Backend::SparseXPU:
case Backend::SparseCPU:
@ -354,6 +367,8 @@ static inline const char* toString(Backend b) {
return "MSNPU";
case Backend::XLA:
return "XLA";
case Backend::MLC:
return "MLC";
case Backend::SparseCPU:
return "SparseCPU";
case Backend::SparseCUDA:

View File

@ -46,6 +46,7 @@ DeviceType parse_type(const std::string& device_string) {
{"msnpu", DeviceType::MSNPU},
{"xla", DeviceType::XLA},
{"vulkan", DeviceType::Vulkan},
{"mlc", DeviceType::MLC},
}};
auto device = std::find_if(
types.begin(),
@ -57,7 +58,7 @@ DeviceType parse_type(const std::string& device_string) {
return device->second;
}
AT_ERROR(
"Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, xla, vulkan device type at start of device string: ",
"Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, mlc, xla, vulkan device type at start of device string: ",
device_string);
}
} // namespace

View File

@ -27,6 +27,8 @@ std::string DeviceTypeName(DeviceType d, bool lower_case) {
return lower_case ? "msnpu" : "MSNPU";
case DeviceType::XLA:
return lower_case ? "xla" : "XLA";
case DeviceType::MLC:
return lower_case ? "mlc" : "MLC";
case DeviceType::Vulkan:
return lower_case ? "vulkan" : "VULKAN";
case DeviceType::Metal:
@ -65,6 +67,7 @@ bool isValidDeviceType(DeviceType d) {
case DeviceType::FPGA:
case DeviceType::MSNPU:
case DeviceType::XLA:
case DeviceType::MLC:
case DeviceType::Vulkan:
case DeviceType::Metal:
case DeviceType::XPU:

View File

@ -26,11 +26,12 @@ enum class DeviceType : int8_t {
Vulkan = 10, // Vulkan
Metal = 11, // Metal
XPU = 12, // XPU
MLC = 13, //ML Compute / Apple
// NB: If you add more devices:
// - Change the implementations of DeviceTypeName and isValidDeviceType
// in DeviceType.cpp
// - Change the number below
COMPILE_TIME_MAX_DEVICE_TYPES = 13,
COMPILE_TIME_MAX_DEVICE_TYPES = 14,
};
constexpr DeviceType kCPU = DeviceType::CPU;
@ -39,6 +40,7 @@ constexpr DeviceType kHIP = DeviceType::HIP;
constexpr DeviceType kFPGA = DeviceType::FPGA;
constexpr DeviceType kMSNPU = DeviceType::MSNPU;
constexpr DeviceType kXLA = DeviceType::XLA;
constexpr DeviceType kMLC = DeviceType::MLC;
constexpr DeviceType kVulkan = DeviceType::Vulkan;
constexpr DeviceType kMetal = DeviceType::Metal;
constexpr DeviceType kXPU = DeviceType::XPU;

View File

@ -21,6 +21,8 @@ const char* toString(DispatchKey t) {
return "MSNPU";
case DispatchKey::XLA:
return "XLA";
case DispatchKey::MLC:
return "MLC";
case DispatchKey::Vulkan:
return "Vulkan";
case DispatchKey::Metal:
@ -80,6 +82,8 @@ const char* toString(DispatchKey t) {
return "AutogradCUDA";
case DispatchKey::AutogradXLA:
return "AutogradXLA";
case DispatchKey::AutogradMLC:
return "AutogradMLC";
case DispatchKey::AutogradNestedTensor:
return "AutogradNestedTensor";
case DispatchKey::AutogradPrivateUse1:
@ -143,6 +147,8 @@ DispatchKey getAutogradKeyFromBackend(DispatchKey t) {
return DispatchKey::AutogradCUDA;
case DispatchKey::XLA:
return DispatchKey::AutogradXLA;
case DispatchKey::MLC:
return DispatchKey::AutogradMLC;
case DispatchKey::NestedTensor:
return DispatchKey::AutogradNestedTensor;
case DispatchKey::PrivateUse1:

View File

@ -62,6 +62,7 @@ enum class DispatchKey : uint8_t {
MSNPU, // unused externally, but tested at
// test/cpp_extensions/msnpu_extension.cpp
XLA, // lives out of tree at https://github.com/pytorch/xla
MLC, // lives out of tree at https://github.com/pytorch/MLCompute
Vulkan,
Metal,
XPU, // For out of tree Intel's heterogeneous computing plug-in
@ -224,9 +225,9 @@ enum class DispatchKey : uint8_t {
AutogradCPU,
AutogradCUDA,
AutogradXLA,
AutogradNestedTensor, // lives out of tree at
// https://github.com/pytorch/nestedtensor
AutogradXPU,
AutogradMLC,
AutogradNestedTensor, // lives out of tree at https://github.com/pytorch/nestedtensor
// Here are some reserved pre-autograd keys for user-defined backends, see
// Note [Private use DispatchKey]
AutogradPrivateUse1,

View File

@ -14,6 +14,7 @@ constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends |
DispatchKey::PrivateUse1,
DispatchKey::PrivateUse2,
DispatchKey::PrivateUse3,
DispatchKey::MLC,
});
bool isBackendDispatchKey(DispatchKey t) {
@ -48,6 +49,8 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
return DispatchKeySet(DispatchKey::CUDA);
case DispatchKey::AutogradXLA:
return DispatchKeySet(DispatchKey::XLA);
case DispatchKey::AutogradMLC:
return DispatchKeySet(DispatchKey::MLC);
case DispatchKey::AutogradNestedTensor:
return DispatchKeySet(DispatchKey::NestedTensor);
case DispatchKey::AutogradXPU:

View File

@ -195,6 +195,7 @@ constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({
DispatchKey::AutogradCUDA,
DispatchKey::AutogradXLA,
DispatchKey::AutogradNestedTensor,
DispatchKey::AutogradMLC,
DispatchKey::AutogradXPU,
DispatchKey::AutogradPrivateUse1,
DispatchKey::AutogradPrivateUse2,

View File

@ -543,6 +543,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return key_set_.has(DispatchKey::Metal);
}
bool is_mlc() const {
return key_set_.has(DispatchKey::MLC);
}
// TODO: remove this once we don't automatically enabled Autograd dispatch keys
// in TensorImpl constructor.
// DON'T USE THIS API!! It's only created for testing purpose in

View File

@ -629,6 +629,8 @@ inline DispatchKey computeDispatchKey(c10::optional<ScalarType> dtype, c10::opti
return DispatchKey::MSNPU;
case DeviceType::XLA:
return DispatchKey::XLA;
case DeviceType::MLC:
return DispatchKey::MLC;
case DeviceType::Vulkan:
return DispatchKey::Vulkan;
case DeviceType::Metal:
@ -687,6 +689,8 @@ inline DeviceType computeDeviceType(DispatchKey tid) {
return DeviceType::MSNPU;
} else if (tid == DispatchKey::XLA) {
return DeviceType::XLA;
} else if (tid == DispatchKey::MLC) {
return DeviceType::MLC;
} else if (tid == DispatchKey::SparseCPU) {
return DeviceType::CPU;
} else if (tid == DispatchKey::SparseCUDA) {

View File

@ -198,8 +198,9 @@ enum DeviceTypeProto {
PROTO_FPGA = 7; // FPGA
PROTO_MSNPU = 8; // MSNPU
PROTO_XLA = 9; // XLA / TPU
PROTO_MLC = 10; // ML Compute
// Change the following number if you add more devices in the code.
PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 10;
PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 11;
}
// Device-specific options. We do not distinguish DeviceOption protos for

View File

@ -179,6 +179,12 @@ def _rebuild_xla_tensor(data, dtype, device, requires_grad):
return tensor
def _rebuild_mlc_tensor(data, dtype, device, requires_grad):
tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
tensor.requires_grad = requires_grad
return tensor
def _rebuild_qtensor(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks):
qscheme = quantizer_params[0]
if qscheme == torch.per_tensor_affine:

View File

@ -95,6 +95,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
.value("FPGA", c10::DeviceType::FPGA)
.value("MSNPU", c10::DeviceType::MSNPU)
.value("XLA", c10::DeviceType::XLA)
.value("MLC", c10::DeviceType::MLC)
.value("Vulkan", c10::DeviceType::Vulkan)
.value("Metal", c10::DeviceType::Metal);

View File

@ -610,6 +610,17 @@ PyObject *THPVariable_is_mkldnn(THPVariable *self, void *unused)
END_HANDLE_TH_ERRORS
}
PyObject *THPVariable_is_mlc(THPVariable *self, void *unused)
{
HANDLE_TH_ERRORS
if (check_has_torch_function((PyObject *)self)) {
return handle_torch_function_getter(self, "is_mlc");
}
auto& self_ = self->cdata;
return torch::autograd::utils::wrap(self_.is_mlc());
END_HANDLE_TH_ERRORS
}
PyObject *THPVariable_is_vulkan(THPVariable *self, void *unused)
{
HANDLE_TH_ERRORS
@ -751,6 +762,7 @@ static struct PyGetSetDef THPVariable_properties[] = {
{"is_xpu", (getter)THPVariable_is_xpu, nullptr, nullptr, nullptr},
{"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr},
{"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr},
{"is_mlc", (getter)THPVariable_is_mlc, nullptr, nullptr, nullptr},
{"is_vulkan", (getter)THPVariable_is_vulkan, nullptr, nullptr, nullptr},
{"is_complex", (getter)THPVariable_is_complex, nullptr, nullptr, nullptr},
{"is_quantized", (getter)THPVariable_is_quantized, nullptr, nullptr, nullptr},

View File

@ -110,6 +110,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
{"is_xpu", "prim"},
{"is_sparse", "prim"},
{"is_mkldnn", "prim"},
{"is_mlc", "prim"},
{"is_quantized", "prim"},
{"is_vulkan", "prim"},
{"is_meta", "prim"},

View File

@ -289,6 +289,14 @@ RegisterOperators reg(
push(stack, a.is_mkldnn());
},
aliasAnalysisFromSchema()),
Operator(
"prim::is_mlc(Tensor a) -> bool",
[](Stack* stack) {
at::Tensor a;
pop(stack, a);
push(stack, a.is_mlc());
},
aliasAnalysisFromSchema()),
Operator(
"prim::is_vulkan(Tensor a) -> bool",
[](Stack* stack) {

View File

@ -61,6 +61,9 @@ Backend backendToBackendOfDeviceType(Backend b, DeviceType d) {
return Backend::XLA;
case DeviceType::XPU:
return backendToXPU(b);
case DeviceType::MLC:
TORCH_CHECK(!isSparse(b), "Sparse not implemented for MLC");
return Backend::MLC;
default:
AT_ERROR("Unknown device type");
}

View File

@ -292,6 +292,8 @@ inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) {
return c10::DispatchKey::CUDA;
case c10::DeviceType::XLA:
return c10::DispatchKey::XLA;
case c10::DeviceType::MLC:
return c10::DispatchKey::MLC;
case c10::DeviceType::HIP:
return c10::DispatchKey::HIP;
case c10::DeviceType::MSNPU:

View File

@ -910,6 +910,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
Tensor.is_xpu.__get__: lambda self: -1,
Tensor.is_leaf.__get__: lambda self: -1,
Tensor.is_meta.__get__: lambda self: -1,
Tensor.is_mlc.__get__: lambda self: -1,
Tensor.is_mkldnn.__get__: lambda self: -1,
Tensor.is_quantized.__get__: lambda self: -1,
Tensor.is_sparse.__get__: lambda self: -1,

View File

@ -57,7 +57,7 @@ class Tensor(torch._C._TensorBase):
if id(self) in memo:
return memo[id(self)]
with torch.no_grad():
if self.is_sparse or self.device.type == 'xla':
if self.is_sparse or self.device.type == 'xla' or self.device.type == 'mlc':
new_tensor = self.clone()
else:
new_storage = self.storage().__deepcopy__(memo)
@ -123,6 +123,12 @@ class Tensor(torch._C._TensorBase):
str(self.device),
self.requires_grad)
return (torch._utils._rebuild_xla_tensor, arg_xla)
if self.device.type == 'mlc':
arg_mlc = (self.cpu().numpy(),
self.dtype,
str(self.device),
self.requires_grad)
return (torch._utils._rebuild_mlc_tensor, arg_mlc)
if self.is_quantized:
# quantizer_params can be different type based on torch attribute
quantizer_params: Union[Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int]]

View File

@ -257,7 +257,7 @@ class RemoteModuleTest(RpcAgentTestFixture):
with self.assertRaisesRegex(
RuntimeError,
r"Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, xla, vulkan"
r"Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, mlc, xla, vulkan"
" device type at start of device string",
):
list(