Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Make meta a device (getting rid of empty_meta) (#53143)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53143

Meta is now an honest-to-goodness device type, like cpu, so you can use device='meta' to trigger allocation of meta tensors. This is way better than empty_meta, since we now have a working API for most factory functions (they don't necessarily work yet, though, because we still need to register Meta versions of those functions).

Some subtleties:

- I decided to drop the concept of CPU versus CUDA meta tensors; meta tensors are device agnostic. It's hard to say exactly what the correct level of abstraction is here, but in this particular case implementation considerations trump semantic ones: it is far easier to have just a meta device than to have a meta device AND a cpu device AND a cuda device. This may limit the applicability of meta tensors for tracing models that do explicit cpu()/cuda() conversions (unless, perhaps, we make those operations no-ops on meta tensors).
- I noticed that the DeviceType uppercase strings are kind of weird. Are they really supposed to be all caps?
- I moved the Meta dispatch key to live with the rest of the "device" dispatch keys.
- I intentionally did NOT add a Backend for Meta. For now, I'm going to hope meta tensors never exercise any of the Backend conversion code; even if they do, it's better to fix that code to stop converting to and from Backend at all.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: samestep

Differential Revision: D26763552

Pulled By: ezyang

fbshipit-source-id: 14633b6ca738e60b921db66a763155d01795480d
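As a quick illustration of the new API, here is a minimal sketch mirroring the updated test_empty_meta hunk below (the commented outputs are what the test asserts):

    import torch

    # Factory functions now accept device='meta'; no data is allocated,
    # so absurdly large tensors are fine.
    x = torch.empty(2 ** 20, 2 ** 20, device='meta')
    y = torch.empty(2 ** 20, device='meta')

    # Operations on meta tensors compute only metadata (shape, dtype).
    z = x + y
    print(z.size())     # torch.Size([1048576, 1048576])

    # Reading actual values raises, since there is no storage behind z:
    # z[0][0].item()    -> RuntimeError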
Committed by: Facebook GitHub Bot
Parent: fd3004d3ee
Commit: 0f81a69a96
@@ -1373,17 +1373,9 @@ void TensorIterator::set_output(int64_t output_idx, IntArrayRef sizes, IntArrayR
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx < num_outputs_);
   if (!op.tensor.defined()) {
     if (strides.empty()) {
-      if (is_meta_) {
-        op.tensor = at::empty_meta(sizes, options);
-      } else {
-        op.tensor = at::empty(sizes, options);
-      }
+      op.tensor = at::empty(sizes, options);
     } else {
-      if (is_meta_) {
-        TORCH_INTERNAL_ASSERT(0, "meta strided not yet implemented");
-      } else {
-        op.tensor = at::empty_strided(sizes, strides, options);
-      }
+      op.tensor = at::empty_strided(sizes, strides, options);
     }
     op.current_dtype = op.target_dtype;
   } else if (op.will_resize) {
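The is_meta_ special case above can go away because at::empty now dispatches on the device carried in options: with the Meta: empty_meta entry added to the dispatch table below, an options struct holding the meta device reaches empty_meta through ordinary dispatch. A sketch of the equivalent Python-level behavior, assuming this build:

    import torch

    # empty() with a meta device routes to the Meta kernel (empty_meta);
    # no special case in TensorIterator is needed anymore.
    t = torch.empty(4, 4, device='meta')
    assert t.device.type == 'meta'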
aten/src/ATen/detail/MetaGuardImpl.cpp (new file, 9 lines)
@@ -0,0 +1,9 @@
+#include <c10/core/impl/DeviceGuardImplInterface.h>
+#include <c10/macros/Macros.h>
+
+namespace at {
+namespace detail {
+
+C10_REGISTER_GUARD_IMPL(Meta, c10::impl::NoOpDeviceGuardImpl<DeviceType::Meta>);
+
+}} // namespace at::detail
@@ -4,7 +4,6 @@
 namespace at {
 namespace native {
 
-// Will be promoted to a public API later, but not now
 Tensor empty_meta(
     IntArrayRef size,
     c10::optional<ScalarType> dtype,
@@ -16,11 +15,7 @@ Tensor empty_meta(
   // TODO: deduplicate this logic with empty_cpu
 
   auto tensor = detail::make_tensor<TensorImpl>(
-      // NB: We include the computed dispatch key, not because it will actually
-      // participate in dispatch, but so that tests like is_sparse/is_cuda
-      // give the correct result (a CUDA meta tensor "is cuda"). If we don't
-      // like this, remove the computeDispatchKey line
-      DispatchKeySet{DispatchKey::Meta, computeDispatchKey(dtype, layout, device)},
+      DispatchKeySet{DispatchKey::Meta},
      scalarTypeToTypeMeta(dtype_or_default(dtype)),
      device
  );
@@ -1570,8 +1570,6 @@
   CPU: _embedding_bag_per_sample_weights_backward_cpu
   CUDA: _embedding_bag_per_sample_weights_backward_cuda
 
-- func: empty_meta(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
-
 - func: empty.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
   use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
   device_guard: False
@@ -1580,6 +1578,7 @@
   dispatch:
     CPU: empty_cpu
     CUDA: empty_cuda
+    Meta: empty_meta
     MkldnnCPU: empty_mkldnn
     SparseCPU, SparseCUDA: empty_sparse
 
@@ -47,6 +47,7 @@ DeviceType parse_type(const std::string& device_string) {
       {"xla", DeviceType::XLA},
       {"vulkan", DeviceType::Vulkan},
       {"mlc", DeviceType::MLC},
+      {"meta", DeviceType::Meta},
   }};
   auto device = std::find_if(
       types.begin(),
@@ -58,7 +59,7 @@ DeviceType parse_type(const std::string& device_string) {
     return device->second;
   }
   TORCH_CHECK(false,
-    "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, mlc, xla, vulkan device type at start of device string: ",
+    "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, mlc, xla, vulkan, meta device type at start of device string: ",
     device_string);
 }
 } // namespace
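With the "meta" entry registered in parse_type, meta device strings now parse anywhere a device is accepted. A small sketch, assuming this build:

    import torch

    d = torch.device('meta')   # parses via the new {"meta", DeviceType::Meta} entry
    assert d.type == 'meta'

    # An unknown string still trips the TORCH_CHECK above:
    # torch.device('bogus')    -> RuntimeError: Expected one of cpu, cuda, ..., meta
    #                             device type at start of device string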
@@ -35,6 +35,8 @@ std::string DeviceTypeName(DeviceType d, bool lower_case) {
       return lower_case ? "metal" : "METAL";
     case DeviceType::XPU:
       return lower_case ? "xpu" : "XPU";
+    case DeviceType::Meta:
+      return lower_case ? "meta" : "META";
     default:
       TORCH_CHECK(false,
         "Unknown device: ",
@@ -71,6 +73,7 @@ bool isValidDeviceType(DeviceType d) {
     case DeviceType::Vulkan:
     case DeviceType::Metal:
     case DeviceType::XPU:
+    case DeviceType::Meta:
       return true;
     default:
       return false;
@@ -26,12 +26,13 @@ enum class DeviceType : int8_t {
   Vulkan = 10, // Vulkan
   Metal = 11, // Metal
   XPU = 12, // XPU
-  MLC = 13, //ML Compute / Apple
+  MLC = 13, // ML Compute / Apple
+  Meta = 14, // Meta (tensors with no data)
   // NB: If you add more devices:
   //  - Change the implementations of DeviceTypeName and isValidDeviceType
   //    in DeviceType.cpp
   //  - Change the number below
-  COMPILE_TIME_MAX_DEVICE_TYPES = 14,
+  COMPILE_TIME_MAX_DEVICE_TYPES = 15,
 };
 
 constexpr DeviceType kCPU = DeviceType::CPU;
@@ -41,6 +42,7 @@ constexpr DeviceType kFPGA = DeviceType::FPGA;
 constexpr DeviceType kMSNPU = DeviceType::MSNPU;
 constexpr DeviceType kXLA = DeviceType::XLA;
 constexpr DeviceType kMLC = DeviceType::MLC;
+constexpr DeviceType kMeta = DeviceType::Meta;
 constexpr DeviceType kVulkan = DeviceType::Vulkan;
 constexpr DeviceType kMetal = DeviceType::Metal;
 constexpr DeviceType kXPU = DeviceType::XPU;
@@ -76,6 +76,13 @@ enum class DispatchKey : uint8_t {
   OpenCL,
   IDEEP,
 
+  // A meta tensor is a tensor without any data associated with it. (They
+  // have also colloquially been referred to as tensors on the "null" device).
+  // A meta tensor can be used to dry run operators without actually doing any
+  // computation, e.g., add on two meta tensors would give you another meta
+  // tensor with the output shape and dtype, but wouldn't actually add anything.
+  Meta,
+
   // Here are backends which specify more specialized operators
   // based on the dtype of the tensor.
   QuantizedCPU, // registered at build/aten/src/ATen/RegisterQuantizedCPU.cpp
@@ -123,58 +130,6 @@ enum class DispatchKey : uint8_t {
   // If you add new backend keys after PrivateUse3, please also update it here.
   EndOfBackendKeys = PrivateUse3,
 
-  // The meta function characterizes how an operation affects the metadata of a
-  // tensor (shape, dtype) without doing any of the actual computation. A
-  // meta tensor can be used to dry run operators without actually doing
-  // any computation, e.g., add on two meta tensors would give you another
-  // meta tensor with the output shape and dtype, but wouldn't actually
-  // add anything. A meta implementation typically would look something like:
-  //
-  //   Tensor meta::add(const Tensor& self, const Tensor& other) {
-  //     TORCH_CHECK(self.size().equals(other.size()));
-  //     return at::empty_like(self, self.size());
-  //   }
-  //
-  // The meta function would get invoked if you ran an operator passing
-  // in meta tensors. The call stack in such a case would look something like
-  // this:
-  //
-  //   at::add(x: Meta, y: Meta) {
-  //     return [dispatch] meta::add(x: Meta, y: Meta) {
-  //       output_shape = ...
-  //       [dispatch] meta::empty(output_shape) {
-  //         return ... meta tensor with output_shape but no data allocated ...
-  //       }
-  //     }
-  //   }
-  //
-  // Meta functions have an important secondary function, which is they can
-  // be used as tensor "allocators". A typical backend implementation should
-  // be implemented in this way:
-  //
-  //   Tensor cpu::add(const Tensor& self, const Tensor& other) {
-  //     Tensor result = meta::add(self, other);
-  //     // ... do the actual computation into result ...
-  //     return result;
-  //   }
-  //
-  // In this case, the internal at::empty_like invocation would dispatch to the
-  // CPU factory function, not the meta factory function. The call stack in
-  // this case looks like:
-  //
-  //   at::add(x: CPU, y: CPU) {
-  //     return [dispatch] cpu::add(x: CPU, y: CPU) {
-  //       output = [direct] meta::add(x: CPU, y: CPU) {
-  //         output_shape = ...
-  //         [dispatch] cpu::empty(output_shape)
-  //       }
-  //       ... compute on output ...
-  //       return output;
-  //     }
-  //   }
-  //
-  Meta,
-
   // In some situations, it is not immediately obvious what the correct
   // backend for function is, because the function in question doesn't
   // have any "tensor" arguments. In this case, a BackendSelect function
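The long comment removed above documents the "meta function" pattern: shape/dtype inference with no compute. A rough Python analogue of the idea, with a hypothetical helper name that is not PyTorch API:

    import torch

    # Hypothetical Python analogue of the C++ meta::add sketched in the
    # removed comment: check metadata, allocate an empty result on the
    # meta device, and do no arithmetic at all.
    def meta_add(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
        assert self.size() == other.size()
        return torch.empty(self.size(), dtype=self.dtype, device='meta')

    a = torch.empty(8, device='meta')
    b = torch.empty(8, device='meta')
    out = meta_add(a, b)   # shape (8,), correct dtype, no data allocated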
@@ -635,6 +635,8 @@ inline DispatchKey computeDispatchKey(c10::optional<ScalarType> dtype, c10::opti
       return DispatchKey::Vulkan;
     case DeviceType::Metal:
       return DispatchKey::Metal;
+    case DeviceType::Meta:
+      return DispatchKey::Meta;
     default:
       TORCH_CHECK(false, "Unsupported device type for dense layout: ", device_.type());
   }
@@ -691,6 +693,8 @@ inline DeviceType computeDeviceType(DispatchKey tid) {
     return DeviceType::XLA;
   } else if (tid == DispatchKey::MLC) {
     return DeviceType::MLC;
+  } else if (tid == DispatchKey::Meta) {
+    return DeviceType::Meta;
   } else if (tid == DispatchKey::SparseCPU) {
     return DeviceType::CPU;
   } else if (tid == DispatchKey::SparseCUDA) {
@@ -74,6 +74,7 @@ allow_list = [
     ("aten::_foreach_addcdiv", datetime.date(2021, 2, 25)),
     ("aten::mkldnn_linear", datetime.date(2021, 3, 2)),
     ("aten::linalg_multi_dot", datetime.date(2021, 3, 25)),
+    ("aten::empty_meta", datetime.date(2021, 4, 1)),
 ]
 
 def allow_listed(schema, allow_list):
@@ -2556,8 +2556,8 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         self.assertEqual(output3, output2)
 
     def test_empty_meta(self):
-        x = torch.empty_meta(2 ** 20, 2 ** 20)
-        y = torch.empty_meta(2 ** 20)
+        x = torch.empty(2 ** 20, 2 ** 20, device='meta')
+        y = torch.empty(2 ** 20, device='meta')
         z = x + y
         self.assertEqual(z.size(), (2 ** 20, 2 ** 20))
         self.assertRaises(RuntimeError, lambda: z[0][0].item())
@@ -2568,14 +2568,14 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         # integrated testing strategy
         # NB: Can't make the exponent too big, or it will overflow
         # signed 64-bit integer
-        x = torch.empty_meta(2 * 10 ** 8, 3, 2 * 10 ** 8)
+        x = torch.empty(2 * 10 ** 8, 3, 2 * 10 ** 8, device='meta')
         z = torch.nn.functional.interpolate(x, scale_factor=2)
         self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8))
         self.assertRaises(RuntimeError, lambda: z[0][0][0].item())
 
         # interpolate doesn't seem to support out=
         # (not sure why passing None here doesn't work? How strange...)
-        z = torch.empty_meta(0)
+        z = torch.empty(0, device='meta')
         torch._C._nn.upsample_nearest1d(x, (4 * 10 ** 8,), 2, out=z)
         self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8))
         self.assertRaises(RuntimeError, lambda: z[0][0][0].item())
@@ -247,7 +247,7 @@ if (C10_UNLIKELY(current_device.has_value())) {
         if self.dispatch_key == DispatchKey.Meta:
             return """
 if (strides.empty()) {
-    outputs_[output_idx] = at::empty_meta(sizes, options);
+    outputs_[output_idx] = at::empty(sizes, options.device(at::kMeta));
 } else {
     TORCH_INTERNAL_ASSERT(0, "not implemented yet");
 }
@@ -96,6 +96,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
       .value("MSNPU", c10::DeviceType::MSNPU)
       .value("XLA", c10::DeviceType::XLA)
       .value("MLC", c10::DeviceType::MLC)
+      .value("Meta", c10::DeviceType::Meta)
       .value("Vulkan", c10::DeviceType::Vulkan)
       .value("Metal", c10::DeviceType::Metal);
 
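With the binding above, the Meta device type should also be visible from Python through the autograd extension's DeviceType enum. A sketch; the module path is an assumption about this build:

    # Assumes the py::enum_ registered in THPAutograd_initExtension is
    # exposed under torch._C._autograd, as in contemporary builds.
    from torch._C._autograd import DeviceType

    print(DeviceType.Meta)   # DeviceType.Meta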
@@ -294,6 +294,8 @@ inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) {
       return c10::DispatchKey::XLA;
     case c10::DeviceType::MLC:
       return c10::DispatchKey::MLC;
+    case c10::DeviceType::Meta:
+      return c10::DispatchKey::Meta;
     case c10::DeviceType::HIP:
       return c10::DispatchKey::HIP;
     case c10::DeviceType::MSNPU:
@@ -129,7 +129,6 @@ def get_ignored_functions() -> Set[Callable]:
         torch.cudnn_grid_sampler,
         torch.cudnn_is_acceptable,
         torch.empty,
-        torch.empty_meta,
         torch.empty_strided,
         torch.empty_quantized,
         torch.eye,
@@ -257,8 +257,7 @@ class RemoteModuleTest(RpcAgentTestFixture):
 
         with self.assertRaisesRegex(
             RuntimeError,
-            r"Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, msnpu, mlc, xla, vulkan"
-            " device type at start of device string",
+            r"Expected one of .+ device type at start of device string",
         ):
             list(
                 self._create_remote_module_iter(