mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Restore storage on meta tensors; increase meta coverage (#53973)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53973 Two parts to this PR; I had to put them together because adding support for X causes more test code to be exercised, which in turn may require a fix for Y. The first part is restoring the concept of storage to meta tensors. Previously, meta tensors had a nullptr storage (e.g., `meta_tensor.storage()` is an error.) As I was increasing the coverage of meta tensors, I started running into test cases (specifically memory overlap tests) that were failing because not having storage meant I couldn't check for memory overlap. After some discussion, we decided that it would make sense for meta tensors to model this as well (we already model strides, so getting accurate view information also seems useful). This PR does that by: * Rewrite all of the factory functions in MetaTensor.cpp to use the generic versions (which are very carefully written to not actually poke at the data pointer, so everything works out). The key idea here is we give meta tensors a special allocator, MetaAllocator, which always returns a nullptr even if you ask for a nonzero number of bytes. resize_ is also made generic; the normal variant can be used directly rather than having to instruct it to avoid resizing storage * Turn on memory overlap checking in TensorIterator even for meta tensors * Although meta tensors now have storage, the concept of meta storage is NOT exposed to Python land (as it would imply I would have to codegen MetaFloatStorage, MetaDoubleStorage, etc. classes). So `x.storage()` still raises an error and I have a cludge in `__deepcopy__` to break storage sharing upon deep copy (this is wrong, but no tests exercise this at the moment). The second part is adding more support for the most used functions in the test suite. * Inplace operations have very simple meta functions. I added `fill_`, `zero_`, `random_`, `uniform_` and `normal_`. In the case of random, I take advantage of pbelevich's templates for defining random kernels, so that I can reuse the common scaffolding, and then just register a noop stub that actually does the RNG. (Look, another structured kernels tiny variant!) * `copy_` is now implemented. Copying into a meta tensor is always OK, but copying out of a meta tensor raises an error (as we don't know what the "correct" data to copy out is in this case) * `empty_strided` usage from structured kernels now is implemented (TBH, this could have been done as soon as `empty_strided` was added) * Meta was missing in a few places in TensorOptions/DispatchKey utility functions, so I added them * Autograd engine now correctly homes meta tensors with CPU tensors (they have -1 device index so CUDA queues wouldn't work anyway) * `apply_`, `map_` and `map2_` are special cased to no-op on meta tensor self. These count as inplace operations too but they are implemented a little differently. Getting more meta function support triggers a number of bugs in the test suite, which I then fix: - Linear algebra functions sometimes don't report NotImplementedError because they get swallowed by catch all try blocks. This is tracked in https://github.com/pytorch/pytorch/issues/53739 - dlpack obviously doesn't work with meta tensors, I just disabled the test Signed-off-by: Edward Z. Yang <ezyang@fb.com> Differential Revision: D27036572 Test Plan: Imported from OSS Reviewed By: agolynski, bdhirsh Pulled By: ezyang fbshipit-source-id: 7005ecf4feb92a643c37389fdfbd852dbf00ac78
This commit is contained in:
committed by
Facebook GitHub Bot
parent
94efb48e16
commit
1f36ce6e4d
@ -4,6 +4,8 @@
|
||||
# shellcheck source=./macos-common.sh
|
||||
source "$(dirname "${BASH_SOURCE[0]}")/macos-common.sh"
|
||||
|
||||
export PYTORCH_TEST_SKIP_NOARCH=1
|
||||
|
||||
conda install -y six
|
||||
pip install -q hypothesis "librosa>=0.6.2" "numba<=0.49.1" psutil
|
||||
|
||||
|
@ -21,6 +21,7 @@ export TEST_DIR="${PWD}/test"
|
||||
export TEST_DIR_WIN=$(cygpath -w "${TEST_DIR}")
|
||||
export PYTORCH_FINAL_PACKAGE_DIR="/c/users/circleci/workspace/build-results"
|
||||
export PYTORCH_FINAL_PACKAGE_DIR_WIN=$(cygpath -w "${PYTORCH_FINAL_PACKAGE_DIR}")
|
||||
export PYTORCH_TEST_SKIP_NOARCH=1
|
||||
|
||||
mkdir -p $TMP_DIR/build/torch
|
||||
|
||||
|
@ -940,10 +940,6 @@ void TensorIteratorBase::compute_mem_overlaps(const TensorIteratorConfig& config
|
||||
if (!config.check_mem_overlap_) {
|
||||
return;
|
||||
}
|
||||
if (is_meta_) {
|
||||
// We don't have pointer addresses, cannot check for overlap!
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < num_outputs_; i++) {
|
||||
const auto& output = operands_[i].tensor;
|
||||
if (!output.defined()) continue;
|
||||
|
@ -160,6 +160,16 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
|
||||
return self;
|
||||
}
|
||||
|
||||
// Copies into meta self are OK and just ignored (similar to inplace)
|
||||
if (self.is_meta()) {
|
||||
// TODO: need to see if there is extra error checking needed
|
||||
return self;
|
||||
}
|
||||
|
||||
if (src.is_meta()) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "Cannot copy out of meta tensor; no data!")
|
||||
}
|
||||
|
||||
// Re-dispatch copies when either src or self device not implemented here (e.g. XLA).
|
||||
// _copy_from has a proper device dispatch setup.
|
||||
// This includes:
|
||||
|
@ -225,10 +225,21 @@ struct UniformStub {
|
||||
}
|
||||
};
|
||||
|
||||
template<typename RNG>
|
||||
struct UniformMeta {
|
||||
// No-op!
|
||||
void operator()(TensorIterator& iter, double from, double to, c10::optional<Generator> gen) {
|
||||
}
|
||||
};
|
||||
|
||||
Tensor& uniform_(Tensor& self, double from, double to, c10::optional<Generator> gen) {
|
||||
return at::native::templates::uniform_impl_<UniformStub, Generator>(self, from, to, gen);
|
||||
}
|
||||
|
||||
Tensor& uniform_meta_(Tensor& self, double from, double to, c10::optional<Generator> gen) {
|
||||
return at::native::templates::uniform_impl_<UniformMeta, Generator>(self, from, to, gen);
|
||||
}
|
||||
|
||||
// ==================================================== Normal ========================================================
|
||||
|
||||
template<typename RNG>
|
||||
@ -242,6 +253,11 @@ Tensor& normal_(Tensor& self, double mean, double std, c10::optional<Generator>
|
||||
return at::native::templates::normal_impl_<NormalStub, Generator>(self, mean, std, gen);
|
||||
}
|
||||
|
||||
Tensor& normal_meta_(Tensor& self, double mean, double std, c10::optional<Generator> gen) {
|
||||
TORCH_CHECK(std > 0.0, "normal_ expects std > 0.0, but found std=", std); // TODO: dedupe
|
||||
return self;
|
||||
}
|
||||
|
||||
Tensor& normal_out(Tensor& output, const Tensor& mean, double std, c10::optional<Generator> gen) {
|
||||
return at::native::templates::normal_out_impl<NormalStub, Generator>(output, mean, std, gen);
|
||||
}
|
||||
@ -289,6 +305,15 @@ struct RandomFromToStub {
|
||||
}
|
||||
};
|
||||
|
||||
template<typename RNG>
|
||||
struct RandomFromToMeta {
|
||||
// No-op!
|
||||
void operator()(TensorIterator& iter, uint64_t range, int64_t from, c10::optional<Generator> gen) {
|
||||
}
|
||||
void operator()(TensorIterator& iter, c10::optional<Generator> gen) {
|
||||
}
|
||||
};
|
||||
|
||||
Tensor& random_(Tensor& self, int64_t from, optional<int64_t> to, c10::optional<Generator> gen) {
|
||||
return at::native::templates::random_from_to_impl<RandomFromToStub, Generator>(self, from, to, gen);
|
||||
}
|
||||
@ -297,6 +322,19 @@ Tensor& random_(Tensor& self, int64_t to, c10::optional<Generator> gen) {
|
||||
return random_(self, 0, to, gen);
|
||||
}
|
||||
|
||||
Tensor& random_meta_(Tensor& self, c10::optional<Generator> gen) {
|
||||
// No error checking yay
|
||||
return self;
|
||||
}
|
||||
|
||||
Tensor& random_meta_(Tensor& self, int64_t from, optional<int64_t> to, c10::optional<Generator> gen) {
|
||||
return at::native::templates::random_from_to_impl<RandomFromToMeta, Generator>(self, from, to, gen);
|
||||
}
|
||||
|
||||
Tensor& random_meta_(Tensor& self, int64_t to, c10::optional<Generator> gen) {
|
||||
return random_meta_(self, 0, to, gen);
|
||||
}
|
||||
|
||||
// ====================================================================================================================
|
||||
|
||||
Tensor _standard_gamma_grad_cpu(const Tensor& self, const Tensor& output) {
|
||||
|
@ -41,6 +41,15 @@ Tensor& fill_(Tensor& self, const Tensor& value) {
|
||||
return fill_out(self, value.item());
|
||||
}
|
||||
|
||||
Tensor& fill_meta_(Tensor& self, const Scalar& value) {
|
||||
return self;
|
||||
}
|
||||
|
||||
Tensor& fill_meta_(Tensor& self, const Tensor& value) {
|
||||
TORCH_CHECK(value.dim() == 0, "fill_ only supports 0-dimension value tensor but got tensor with ", value.dim(), " dimensions.");
|
||||
return self;
|
||||
}
|
||||
|
||||
DEFINE_DISPATCH(fill_stub);
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ fill_diagonal ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
@ -114,5 +123,9 @@ Tensor& zero_(Tensor &self) {
|
||||
return self.fill_(0);
|
||||
}
|
||||
|
||||
Tensor& zero_meta_(Tensor& self) {
|
||||
return self;
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
@ -1,81 +1,70 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
// The meta allocator ignores whatever allocation is requested and always
|
||||
// gives you nullptr
|
||||
struct MetaAllocator final : public at::Allocator {
|
||||
MetaAllocator() = default;
|
||||
~MetaAllocator() override = default;
|
||||
static void deleter(void* const pointer) {
|
||||
TORCH_INTERNAL_ASSERT(!pointer);
|
||||
}
|
||||
DataPtr allocate(const size_t nbytes) const override {
|
||||
return {nullptr, nullptr, &deleter, at::Device(DeviceType::Meta)};
|
||||
}
|
||||
DeleterFnPtr raw_deleter() const override {
|
||||
return deleter;
|
||||
}
|
||||
};
|
||||
|
||||
static MetaAllocator g_meta_alloc;
|
||||
|
||||
at::Allocator* GetMetaAllocator() {
|
||||
return &g_meta_alloc;
|
||||
}
|
||||
|
||||
Tensor empty_meta(
|
||||
IntArrayRef size,
|
||||
c10::optional<ScalarType> dtype,
|
||||
c10::optional<Layout> layout,
|
||||
c10::optional<Device> device,
|
||||
c10::optional<bool> pin_memory,
|
||||
c10::optional<c10::MemoryFormat> memory_format
|
||||
c10::optional<ScalarType> dtype_opt,
|
||||
c10::optional<Layout> layout_opt,
|
||||
c10::optional<Device> device_opt,
|
||||
c10::optional<bool> pin_memory_opt,
|
||||
c10::optional<c10::MemoryFormat> memory_format_opt
|
||||
) {
|
||||
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(device_or_default(device).type() == DeviceType::Meta);
|
||||
auto device = device_or_default(device_opt);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(device.type() == DeviceType::Meta);
|
||||
// NB: because there is no SparseMeta (yet), non-strided layout is
|
||||
// exerciseable
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
layout_or_default(layout) == Layout::Strided,
|
||||
layout_or_default(layout_opt) == Layout::Strided,
|
||||
"strided meta tensors not supported yet"
|
||||
);
|
||||
|
||||
check_size_nonnegative(size);
|
||||
|
||||
auto tensor = detail::make_tensor<TensorImpl>(
|
||||
DispatchKeySet{DispatchKey::Meta},
|
||||
scalarTypeToTypeMeta(dtype_or_default(dtype)),
|
||||
device
|
||||
);
|
||||
|
||||
tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size);
|
||||
|
||||
auto memory_format_ = memory_format.value_or(MemoryFormat::Contiguous);
|
||||
tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format_);
|
||||
|
||||
tensor.unsafeGetTensorImpl()->set_storage_access_should_throw();
|
||||
|
||||
return tensor;
|
||||
auto* allocator = GetMetaAllocator();
|
||||
auto dtype = dtype_or_default(dtype_opt);
|
||||
auto r = at::detail::empty_generic(size, allocator, at::DispatchKey::Meta, dtype, device, memory_format_opt);
|
||||
return r;
|
||||
}
|
||||
|
||||
Tensor empty_strided_meta(
|
||||
IntArrayRef size,
|
||||
IntArrayRef stride,
|
||||
c10::optional<ScalarType> dtype,
|
||||
c10::optional<Layout> layout,
|
||||
c10::optional<Device> device,
|
||||
c10::optional<bool> pin_memory
|
||||
c10::optional<ScalarType> dtype_opt,
|
||||
c10::optional<Layout> layout_opt,
|
||||
c10::optional<Device> device_opt,
|
||||
c10::optional<bool> pin_memory_opt
|
||||
) {
|
||||
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(device_or_default(device).type() == DeviceType::Meta);
|
||||
// NB: because there is no SparseMeta (yet), non-strided layout is
|
||||
// exerciseable
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
layout_or_default(layout) == Layout::Strided,
|
||||
"strided meta tensors not supported yet"
|
||||
);
|
||||
|
||||
// NB: pin_memory intentionally ignored; it is a property of storage and
|
||||
// therefore meta does not track it (this is not a forced choice, but it's
|
||||
// the choice we made)
|
||||
|
||||
check_size_nonnegative(size);
|
||||
// TODO: check if strides are negative,
|
||||
// https://github.com/pytorch/pytorch/issues/53391
|
||||
// (bugged here to be consistent with CPU implementation)
|
||||
|
||||
auto tensor = detail::make_tensor<TensorImpl>(
|
||||
DispatchKeySet{DispatchKey::Meta},
|
||||
scalarTypeToTypeMeta(dtype_or_default(dtype)),
|
||||
device
|
||||
);
|
||||
|
||||
tensor.unsafeGetTensorImpl()->set_sizes_and_strides(size, stride);
|
||||
|
||||
tensor.unsafeGetTensorImpl()->set_storage_access_should_throw();
|
||||
|
||||
return tensor;
|
||||
auto t = at::native::empty_meta({0}, dtype_opt, layout_opt, device_opt, pin_memory_opt);
|
||||
// Amazingly the CPU implementation will work for us, because most of resize
|
||||
// is generic except the memcpy, but the memcpy will be skipped if the source
|
||||
// storage is nullptr (which it always is, for meta tensors)
|
||||
at::native::resize_impl_cpu_(t.unsafeGetTensorImpl(), size, stride);
|
||||
return t;
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
|
@ -274,6 +274,8 @@ dispatch:
|
||||
CompositeImplicitAutograd: func
|
||||
|
||||
# overload is ignored, but out functions get suffixed with _out in their name
|
||||
# (NB: no out functions in PyTorch today actually support autograd, but if they
|
||||
# did, you could call them here and autograd would be inferred)
|
||||
func: func.out_overload(...) -> ...
|
||||
dispatch:
|
||||
CompositeImplicitAutograd: func_out
|
||||
|
@ -102,13 +102,12 @@ Tensor& resize_as_(
|
||||
Tensor& resize_(
|
||||
Tensor& self,
|
||||
IntArrayRef size,
|
||||
c10::optional<MemoryFormat> optional_memory_format,
|
||||
bool resize_storage) {
|
||||
c10::optional<MemoryFormat> optional_memory_format) {
|
||||
if (self.has_names()) {
|
||||
return resize_named_tensor_(self, size, optional_memory_format);
|
||||
}
|
||||
auto* self_ = self.unsafeGetTensorImpl();
|
||||
resize_impl_cpu_(self_, size, /*strides=*/c10::nullopt, resize_storage);
|
||||
resize_impl_cpu_(self_, size, /*strides=*/c10::nullopt);
|
||||
if (optional_memory_format.has_value()) {
|
||||
auto memory_format =
|
||||
optional_memory_format.value();
|
||||
@ -121,20 +120,5 @@ Tensor& resize_(
|
||||
return self;
|
||||
}
|
||||
|
||||
Tensor& resize_(
|
||||
Tensor& self,
|
||||
IntArrayRef size,
|
||||
c10::optional<MemoryFormat> optional_memory_format) {
|
||||
return resize_(self, size, optional_memory_format, /*resize_storage=*/true);
|
||||
}
|
||||
|
||||
Tensor& resize_meta_(
|
||||
Tensor& self,
|
||||
IntArrayRef size,
|
||||
c10::optional<MemoryFormat> optional_memory_format) {
|
||||
// meta tensors don't have storage, so don't resize them
|
||||
return resize_(self, size, optional_memory_format, /*resize_storage=*/false);
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
@ -558,7 +558,7 @@
|
||||
- func: as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a)
|
||||
variants: function, method
|
||||
dispatch:
|
||||
CPU, CUDA: as_strided_tensorimpl
|
||||
CPU, CUDA, Meta: as_strided_tensorimpl
|
||||
QuantizedCPU, QuantizedCUDA: as_strided_qtensorimpl
|
||||
device_guard: False
|
||||
|
||||
@ -1522,10 +1522,9 @@
|
||||
variants: method
|
||||
device_guard: False
|
||||
dispatch:
|
||||
CPU: resize_
|
||||
CPU, Meta: resize_
|
||||
CUDA: resize_cuda_
|
||||
QuantizedCPU: quantized_resize_cpu_
|
||||
Meta: resize_meta_
|
||||
|
||||
- func: empty_quantized(int[] size, Tensor qtensor) -> Tensor
|
||||
variants: function
|
||||
@ -1660,11 +1659,13 @@
|
||||
variants: function, method
|
||||
dispatch:
|
||||
CPU, CUDA, QuantizedCPU, QuantizedCUDA: fill_
|
||||
Meta: fill_meta_
|
||||
|
||||
- func: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!)
|
||||
variants: function, method
|
||||
dispatch:
|
||||
CPU, CUDA, QuantizedCPU, QuantizedCUDA: fill_
|
||||
Meta: fill_meta_
|
||||
|
||||
- func: floor(Tensor self) -> Tensor
|
||||
variants: function, method
|
||||
@ -3998,6 +3999,7 @@
|
||||
variants: method, function
|
||||
dispatch:
|
||||
CPU, CUDA: zero_
|
||||
Meta: zero_meta_
|
||||
SparseCPU, SparseCUDA: zero_sparse_
|
||||
MkldnnCPU: mkldnn_zero_
|
||||
|
||||
@ -4699,7 +4701,7 @@
|
||||
variants: method
|
||||
device_guard: False
|
||||
dispatch:
|
||||
CPU, CUDA, QuantizedCPU, QuantizedCUDA: view
|
||||
CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA: view
|
||||
MkldnnCPU: mkldnn_view
|
||||
|
||||
# Warning: If you want to change the name or overload name of this
|
||||
@ -5048,21 +5050,25 @@
|
||||
variants: method
|
||||
dispatch:
|
||||
CPU, CUDA: random_
|
||||
Meta: random_meta_
|
||||
|
||||
- func: random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)
|
||||
variants: method
|
||||
dispatch:
|
||||
CPU, CUDA: random_
|
||||
Meta: random_meta_
|
||||
|
||||
- func: random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!)
|
||||
variants: method
|
||||
dispatch:
|
||||
CPU, CUDA: random_
|
||||
Meta: random_meta_
|
||||
|
||||
- func: uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!)
|
||||
variants: method
|
||||
dispatch:
|
||||
CPU, CUDA: uniform_
|
||||
Meta: uniform_meta_
|
||||
|
||||
- func: cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!)
|
||||
variants: method
|
||||
@ -6235,6 +6241,7 @@
|
||||
variants: method
|
||||
dispatch:
|
||||
CPU, CUDA: normal_
|
||||
Meta: normal_meta_
|
||||
|
||||
- func: normal.Tensor_float_out(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
|
||||
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
|
||||
@ -8451,7 +8458,7 @@
|
||||
python_module: special
|
||||
variants: function
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: special_entr
|
||||
CPU, CUDA: special_entr
|
||||
|
||||
- func: special_entr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: special
|
||||
|
@ -16,6 +16,7 @@ constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends |
|
||||
DispatchKey::PrivateUse2,
|
||||
DispatchKey::PrivateUse3,
|
||||
DispatchKey::MLC,
|
||||
DispatchKey::Meta,
|
||||
});
|
||||
|
||||
bool isBackendDispatchKey(DispatchKey t) {
|
||||
|
@ -7,6 +7,29 @@
|
||||
|
||||
namespace c10 {
|
||||
|
||||
// A storage represents the underlying backing data buffer for a
|
||||
// tensor. This concept was inherited from the original Torch7
|
||||
// codebase; we'd kind of like to get rid of the concept
|
||||
// (see https://github.com/pytorch/pytorch/issues/14797) but
|
||||
// it's hard work and no one has gotten around to doing it.
|
||||
//
|
||||
// NB: storage is supposed to uniquely own a data pointer; e.g.,
|
||||
// two non-null data pointers alias if and only if they are from
|
||||
// the same storage. Technically you can violate this invariant
|
||||
// (e.g., you can create a non-owning StorageImpl with at::from_blob)
|
||||
// but a lot of things won't work correctly, including:
|
||||
//
|
||||
// - An ordinary deleter on such a storage is wrong, because normal deleters
|
||||
// assume unique ownership, but if you have two storages at the same data, that
|
||||
// implies there is some sort of shared ownership. So your deleter would have to
|
||||
// actually be internally doing some sort of refcount thing
|
||||
// - Deepcopy in Python side relies on storage equality and not data pointer
|
||||
// equality; so if there are two separate storages pointing to the same data,
|
||||
// the data will actually get duplicated in that case (one data ptr before, two
|
||||
// data ptrs after)
|
||||
// - Version counts won't work correctly, because we do all VC tracking at the
|
||||
// level of storages (unless you explicitly disconnect the VC with detach);
|
||||
// mutation because data pointers are the same are totally untracked
|
||||
struct C10_API StorageImpl final : public c10::intrusive_ptr_target {
|
||||
public:
|
||||
struct use_byte_size_t {};
|
||||
|
@ -701,6 +701,8 @@ inline DeviceType dispatchKeyToDeviceType(DispatchKey dispatch_key) {
|
||||
return DeviceType::XLA;
|
||||
case DispatchKey::Vulkan:
|
||||
return DeviceType::Vulkan;
|
||||
case DispatchKey::Meta:
|
||||
return DeviceType::Meta;
|
||||
|
||||
// stuff that people are actively developing
|
||||
case DispatchKey::XPU:
|
||||
|
@ -1383,6 +1383,7 @@ class TestLinalg(TestCase):
|
||||
|
||||
# This test compares torch.linalg.norm and numpy.linalg.norm to ensure that
|
||||
# their matrix norm results match
|
||||
@skipMeta # https://github.com/pytorch/pytorch/issues/54082
|
||||
@skipCUDAIfNoMagma
|
||||
@dtypes(torch.float, torch.double)
|
||||
@precisionOverride({torch.float32: 2e-5})
|
||||
@ -1420,6 +1421,7 @@ class TestLinalg(TestCase):
|
||||
for ord in ord_settings:
|
||||
run_test_case(input, ord, dim, keepdim)
|
||||
|
||||
@skipMeta # https://github.com/pytorch/pytorch/issues/53739
|
||||
@skipCPUIfNoLapack
|
||||
@skipCUDAIfNoMagma
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@ -1474,6 +1476,7 @@ class TestLinalg(TestCase):
|
||||
actual = torch.linalg.cond(input, p)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
@skipMeta # https://github.com/pytorch/pytorch/issues/53739
|
||||
@skipCPUIfNoLapack
|
||||
@skipCUDAIfNoMagma
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
@ -3369,6 +3372,7 @@ class TestLinalg(TestCase):
|
||||
a_inv = torch.linalg.tensorinv(a, ind=ind)
|
||||
self.assertEqual(a_inv.shape, a.shape[ind:] + a.shape[:ind])
|
||||
|
||||
@skipMeta # See https://github.com/pytorch/pytorch/issues/53739
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
|
||||
|
@ -29,7 +29,7 @@ from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests,
|
||||
skipCUDAIfNoMagma, skipCUDAVersionIn,
|
||||
onlyCUDA, onlyCPU,
|
||||
dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast,
|
||||
dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast, skipMeta,
|
||||
PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyOnCPUAndCUDA,
|
||||
expectedAlertNondeterministic)
|
||||
from typing import Dict, List
|
||||
@ -5956,6 +5956,7 @@ class TestTorchDeviceType(TestCase):
|
||||
for x in xs:
|
||||
_test_helper(x, op, unary=True)
|
||||
|
||||
@skipMeta
|
||||
def test_dlpack_conversion(self, device):
|
||||
x = torch.randn(1, 2, 3, 4, device=device, dtype=torch.float)
|
||||
z = from_dlpack(to_dlpack(x))
|
||||
|
@ -248,11 +248,12 @@ if (C10_UNLIKELY(current_device.has_value())) {
|
||||
|
||||
if k is SchemaKind.functional:
|
||||
if self.dispatch_key == DispatchKey.Meta:
|
||||
# TODO: dedupe this with below
|
||||
return """
|
||||
if (strides.empty()) {
|
||||
outputs_[output_idx] = at::empty(sizes, options.device(at::kMeta));
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(0, "not implemented yet");
|
||||
outputs_[output_idx] = at::empty_strided(sizes, strides, options.device(at::kMeta));
|
||||
}
|
||||
"""
|
||||
else:
|
||||
|
@ -57,7 +57,11 @@ class Tensor(torch._C._TensorBase):
|
||||
if id(self) in memo:
|
||||
return memo[id(self)]
|
||||
with torch.no_grad():
|
||||
if self.is_sparse or self.device.type == 'xla' or self.device.type == 'mlc':
|
||||
# TODO: skipping storage copy is wrong for meta, as meta
|
||||
# does accurate alias tracking; however, the code below
|
||||
# doesn't work because of
|
||||
# https://github.com/pytorch/pytorch/issues/47442
|
||||
if self.is_sparse or self.device.type in ['xla', 'mlc', 'meta']:
|
||||
new_tensor = self.clone()
|
||||
else:
|
||||
new_storage = self.storage().__deepcopy__(memo)
|
||||
|
@ -60,6 +60,10 @@ at::DeprecatedTypeProperties* get_type(at::Backend backend, at::ScalarType scala
|
||||
PyTypeObject* getPyTypeObject(
|
||||
const at::Storage& storage,
|
||||
const caffe2::TypeMeta dtype) {
|
||||
// TODO: https://github.com/pytorch/pytorch/issues/47442
|
||||
if (storage.device_type() == at::DeviceType::Meta) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "python bindings for meta storage objects not supported");
|
||||
}
|
||||
at::ScalarType scalarType = at::typeMetaToScalarType(dtype);
|
||||
auto attype = &at::getDeprecatedTypeProperties(
|
||||
at::dispatchKeyToBackend(c10::computeDispatchKey(scalarType, c10::nullopt, storage.device_type())),
|
||||
|
@ -1072,7 +1072,7 @@ size_t Engine::ready_queue_size(const std::shared_ptr<GraphTask>& graph_task, at
|
||||
|
||||
// CPU ready queue is per GraphTask, but CUDA device ready queues are shared across all graph tasks
|
||||
auto Engine::ready_queue(std::shared_ptr<ReadyQueue> cpu_ready_queue, at::Device device) -> std::shared_ptr<ReadyQueue>{
|
||||
if (device.type() == at::kCPU) {
|
||||
if (device.type() == at::kCPU || device.type() == at::DeviceType::Meta) {
|
||||
// return the cpu ready queue passed in
|
||||
TORCH_INTERNAL_ASSERT(cpu_ready_queue);
|
||||
return cpu_ready_queue;
|
||||
|
@ -54,6 +54,9 @@ static void recursive_apply(IntArrayRef sizes, ScalarType scalarType, int64_t di
|
||||
}
|
||||
|
||||
Tensor & apply_(Tensor & self, PyObject* fn) {
|
||||
if (self.is_meta()) {
|
||||
return self; // Just skip
|
||||
}
|
||||
if (!self.device().is_cpu()) {
|
||||
throw TypeError("apply_ is only implemented on CPU tensors");
|
||||
}
|
||||
@ -63,13 +66,16 @@ Tensor & apply_(Tensor & self, PyObject* fn) {
|
||||
}
|
||||
|
||||
Tensor & map_(Tensor & self, const Tensor & other_, PyObject* fn) {
|
||||
if (!self.device().is_cpu()) {
|
||||
throw TypeError("map_ is only implemented on CPU tensors");
|
||||
}
|
||||
if (!other_.options().type_equal(self.options())) {
|
||||
throw TypeError("map_: expected %s for 'other' (got %s)",
|
||||
self.toString().c_str(), other_.toString().c_str());
|
||||
}
|
||||
if (self.is_meta()) {
|
||||
return self; // Just skip
|
||||
}
|
||||
if (!self.device().is_cpu()) {
|
||||
throw TypeError("map_ is only implemented on CPU tensors");
|
||||
}
|
||||
Tensor other;
|
||||
std::tie(other) = expand_inplace(self, other_, "map_");
|
||||
auto scalarType = self.scalar_type();
|
||||
@ -78,9 +84,6 @@ Tensor & map_(Tensor & self, const Tensor & other_, PyObject* fn) {
|
||||
}
|
||||
|
||||
Tensor & map2_(Tensor & self, const Tensor & x_, const Tensor & y_, PyObject* fn) {
|
||||
if (!self.device().is_cpu() || !x_.device().is_cpu() || !y_.device().is_cpu()) {
|
||||
throw TypeError("map2_ is only implemented on CPU tensors");
|
||||
}
|
||||
if (!x_.options().type_equal(self.options())) {
|
||||
throw TypeError("map2_: expected %s for argument 'x' (got %s)",
|
||||
self.toString().c_str(), x_.toString().c_str());
|
||||
@ -89,6 +92,12 @@ Tensor & map2_(Tensor & self, const Tensor & x_, const Tensor & y_, PyObject* fn
|
||||
throw TypeError("map2_: expected %s for argument 'y' (got %s)",
|
||||
self.toString().c_str(), y_.toString().c_str());
|
||||
}
|
||||
if (self.is_meta()) {
|
||||
return self; // Just skip
|
||||
}
|
||||
if (!self.device().is_cpu() || !x_.device().is_cpu() || !y_.device().is_cpu()) {
|
||||
throw TypeError("map2_ is only implemented on CPU tensors");
|
||||
}
|
||||
Tensor other1, other2;
|
||||
std::tie(other1, other2) = expand_inplace(self, x_, y_, "map2_");
|
||||
auto scalarType = self.scalar_type();
|
||||
|
Reference in New Issue
Block a user