Add memory format argument to the clone operator (#27106)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27106

Adds a memory_format option to the `clone` operator. Calling `input_t.clone(memory_format=torch.preserve_format)` introduces new behavior:

1) If the tensor is non-overlapping and dense, the output tensor has the same strides as the input tensor.
2) Otherwise, if the tensor is stored in the channels last format, the output tensor is in the channels last format.
3) In all other cases, the output tensor is contiguous.

---

A dense tensor stores its values in a contiguous block of memory. A non-overlapping tensor is one in which each element occupies its own memory location, with no two elements sharing storage.

Test Plan: Imported from OSS

Differential Revision: D17699357

Pulled By: VitalyFedyunin

fbshipit-source-id: 5ae1537c2aca1abf0bf1eec4416846129c156f66
Committed by: Facebook Github Bot
Parent: 111da77912
Commit: 7b2e8c323c
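In user-facing terms, the three rules from the summary look like this — a minimal sketch against a build that includes this change (note that here the default, with no memory_format argument, is still a contiguous clone, as the tests in this diff confirm):

import torch

nhwc = torch.randn(10, 3, 32, 32).contiguous(memory_format=torch.channels_last)

# Rule 1: non-overlapping and dense -> strides are copied verbatim.
same = nhwc.clone(memory_format=torch.preserve_format)
assert same.stride() == nhwc.stride()

# Rule 2: a strided view is no longer dense, but still "looks like"
# channels last, so the clone comes out channels-last contiguous.
view = nhwc[:, :, ::2, ::2]
clone = view.clone(memory_format=torch.preserve_format)
assert clone.is_contiguous(memory_format=torch.channels_last)

# Rule 3 / default: with no argument the clone is plain contiguous.
flat = nhwc.clone()
assert flat.is_contiguous()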
@@ -419,7 +419,6 @@ bool aten_op_is_already_moved_to_c10(const c10::OperatorName& opName) {
     {"aten::frobenius_norm", "dim"},
     {"aten::nuclear_norm", ""},
     {"aten::nuclear_norm", "dim"},
-    {"aten::clone", ""},
     {"aten::resize_as_", ""},
     {"aten::pow", "Tensor_Scalar"},
     {"aten::zero_", ""},
@@ -1201,6 +1200,7 @@ bool aten_op_is_not_moved_to_c10_yet(const c10::OperatorName& opName) {
     {"aten::frobenius_norm", "out"},
     {"aten::nuclear_norm", "out"},
     {"aten::nuclear_norm", "dim_out"},
+    {"aten::clone", ""},
    {"aten::pow", "Tensor_Scalar_out"},
    {"aten::sub", "out"},
    {"aten::addmm", "out"},
@@ -186,10 +186,15 @@ class CAFFE2_API Tensor {
   int64_t ndimension() const {
     return dim();
   }

   bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const {
     return impl_->is_contiguous(memory_format);
   }

+  bool is_non_overlapping_and_dense() const {
+    return impl_->is_non_overlapping_and_dense();
+  }
+
   at::MemoryFormat suggest_memory_format() const {
     if (impl_->is_strides_like_channels_last()) {
       return at::MemoryFormat::ChannelsLast;
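`suggest_memory_format()` reports ChannelsLast whenever the strides merely look like channels last, even when the tensor is no longer dense. A rough Python rendering of that ordering test (a hypothetical helper for illustration only; the real logic is `is_strides_like_channels_last()` / `compute_strides_like_channels_last()` in C++):

import torch

def looks_like_channels_last(t: torch.Tensor) -> bool:
    # Visit dims in NHWC memory order (C innermost, N outermost) and
    # require each stride to be at least the span of the dims inside it.
    if t.dim() != 4:
        return False
    min_required = 0
    for d in (1, 3, 2, 0):
        if t.stride(d) < min_required:
            return False
        min_required = t.stride(d)
        if t.size(d) > 1:
            min_required *= t.size(d)
    return True

nhwc = torch.randn(10, 3, 32, 32).contiguous(memory_format=torch.channels_last)
assert looks_like_channels_last(nhwc)
assert looks_like_channels_last(nhwc[:, :, ::2, ::2])   # strided view still qualifies
assert not looks_like_channels_last(torch.randn(10, 3, 32, 32))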
@@ -734,7 +739,7 @@ class CAFFE2_API Tensor {
 #ifdef BUILD_NAMEDTENSOR
   Tensor norm(c10::optional<Scalar> p, DimnameList dim, bool keepdim=false) const;
 #endif
-  Tensor clone() const;
+  Tensor clone(c10::optional<MemoryFormat> memory_format=c10::nullopt) const;
   Tensor & resize_as_(const Tensor & the_template) const;
   Tensor pow(Scalar exponent) const;
   Tensor & zero_() const;
@@ -3044,26 +3044,25 @@ inline Tensor Tensor::norm(c10::optional<Scalar> p, DimnameList dim, bool keepdi
 #endif
 }
 #endif
-inline Tensor Tensor::clone() const {
+inline Tensor Tensor::clone(c10::optional<MemoryFormat> memory_format) const {
 #ifdef USE_STATIC_DISPATCH
     at::AutoNonVariableTypeMode _var_guard(true);
     switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) {
         case Backend::CPU:
-            return CPUType::clone(const_cast<Tensor&>(*this));
+            return CPUType::clone(const_cast<Tensor&>(*this), memory_format);
             break;
         case Backend::QuantizedCPU:
-            return QuantizedCPUType::clone(const_cast<Tensor&>(*this));
+            return QuantizedCPUType::clone(const_cast<Tensor&>(*this), memory_format);
            break;
        case Backend::SparseCPU:
-            return SparseCPUType::clone(const_cast<Tensor&>(*this));
+            return SparseCPUType::clone(const_cast<Tensor&>(*this), memory_format);
            break;
        default:
            AT_ERROR("clone not implemented for ", at::toString(type_set()));
    }
 #else
-    static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::clone", ""}).value();
-    return c10::Dispatcher::singleton().callUnboxed<Tensor, const Tensor &>(
-            op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast<Tensor&>(*this));
+    static auto table = globalATenDispatch().getOpTable("aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor");
+    return table->callUnboxed<Tensor, const Tensor &, c10::optional<MemoryFormat>>(const_cast<Tensor&>(*this), memory_format);
 #endif
 }
 inline Tensor & Tensor::resize_as_(const Tensor & the_template) const {
@@ -867,8 +867,20 @@ Tensor from_file(std::string filename, c10::optional<bool> shared, c10::optional

 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ clone ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-Tensor clone(const Tensor& src) {
-  auto self = at::empty_like(src);
+Tensor clone(const Tensor& src, c10::optional<c10::MemoryFormat> optional_memory_format) {
+  auto memory_format =
+      optional_memory_format.value_or(MemoryFormat::Contiguous);
+  if (memory_format == MemoryFormat::Preserve) {
+    if (src.is_non_overlapping_and_dense()) {
+      // Copy all strides
+      auto self = at::empty_strided(src.sizes(), src.strides(), src.options());
+      self.copy_(src);
+      return self;
+    } else {
+      memory_format = src.suggest_memory_format();
+    }
+  }
+  auto self = at::empty_like(src, src.options(), memory_format);
   self.copy_(src);
   return self;
 }
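The dense fast path above is just `empty_strided` plus `copy_`, which is also expressible from Python (a minimal sketch of what Preserve does for a dense input):

import torch

src = torch.randn(4, 6).t()              # transposed: non-overlapping and dense
dst = torch.empty_strided(src.size(), src.stride())
dst.copy_(src)
assert dst.stride() == src.stride() == (1, 6)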
@@ -16,7 +16,7 @@ Tensor mkldnn_reshape(const Tensor& self, IntArrayRef size) {
   AT_ERROR("mkldnn_reshape: ATen not compiled with MKLDNN support");
 }

-Tensor mkldnn_clone(const Tensor& self) {
+Tensor mkldnn_clone(const Tensor& self, c10::optional<c10::MemoryFormat> optional_memory_format) {
   AT_ERROR("mkldnn_clone: ATen not compiled with MKLDNN support");
 }
@@ -54,7 +54,11 @@ Tensor mkldnn_reshape(const Tensor& self, IntArrayRef size) {
   return new_with_itensor_mkldnn(std::move(y), self.options());
 }

-Tensor mkldnn_clone(const Tensor& self) {
+Tensor mkldnn_clone(const Tensor& self, c10::optional<c10::MemoryFormat> optional_memory_format) {
+  TORCH_CHECK(
+      !optional_memory_format.has_value(),
+      "unsupported memory format option ",
+      optional_memory_format.value());
   ideep::tensor& src = itensor_from_mkldnn(self);
   ideep::tensor dst;
   ideep::direct_copy::compute<AllocForMKLDNN>(src, dst);
@@ -3108,8 +3108,7 @@
 - func: nuclear_norm.dim_out(Tensor self, int[2] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   variants: function

-- func: clone(Tensor self) -> Tensor
-  use_c10_dispatcher: full
+- func: clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
   variants: function, method
   dispatch:
     CPU: clone
@@ -155,13 +155,29 @@ Tensor& set_quantizer_(Tensor& self, ConstQuantizerPtr quantizer) {
   return self;
 }

-Tensor quantized_clone(const Tensor& self) {
+Tensor quantized_clone(const Tensor& self, c10::optional<c10::MemoryFormat> optional_memory_format) {
   // TODO: add per channel support
   TORCH_INTERNAL_ASSERT(
       self.qscheme() == at::kPerTensorAffine,
       "clone for quantized Tensor only works for PerTensorAffine scheme right now");

+  auto memory_format =
+      optional_memory_format.value_or(MemoryFormat::Contiguous);
+
+  // TODO: To support all features of MemoryFormat::Preserve we need to add
+  // _empty_affine_quantized_strided function and use it similarly to
+  // Tensor clone(const Tensor& src, c10::optional<c10::MemoryFormat> optional_memory_format)
+  // if (self.is_non_overlapping_and_dense()) -> _empty_affine_quantized_strided
+  if (memory_format == MemoryFormat::Preserve) {
+    memory_format = self.suggest_memory_format();
+  }
+
   Tensor dst = at::_empty_affine_quantized(
-      self.sizes(), self.options(), self.q_scale(), self.q_zero_point());
+      self.sizes(),
+      self.options(),
+      self.q_scale(),
+      self.q_zero_point(),
+      memory_format);

   at::native::copy_(dst, self, false);
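Until a strided quantized allocator (the `_empty_affine_quantized_strided` named in the TODO) exists, Preserve degrades to `suggest_memory_format()` for quantized tensors. A hedged sketch of the user-visible effect, assuming `torch.quantize_per_tensor` for constructing the quantized input:

import torch

x = torch.randn(10, 3, 32, 32).contiguous(memory_format=torch.channels_last)
q = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)

qc = q.clone(memory_format=torch.preserve_format)
# Quantization parameters survive the clone; the layout follows
# suggest_memory_format() rather than the exact source strides.
assert qc.q_scale() == q.q_scale() and qc.q_zero_point() == q.q_zero_point()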
@@ -255,7 +255,11 @@ Tensor _sparse_coo_tensor_unsafe(const Tensor& indices, const Tensor& values_, A

 // NB: Deleted newWithSizeNd variants

-SparseTensor clone_sparse(const SparseTensor& self) {
+SparseTensor clone_sparse(const SparseTensor& self, c10::optional<c10::MemoryFormat> optional_memory_format) {
+  TORCH_CHECK(
+      !optional_memory_format.has_value(),
+      "unsupported memory format option ",
+      optional_memory_format.value());
   SparseTensor other = new_with_dims_sparse(self.sparse_dim(), self.dense_dim(), self.sizes(), self.options());
   copy_into_sparse(other, self._indices(), self._values(), true);
   return other._coalesced_(self.is_coalesced());
@@ -186,10 +186,15 @@ class CAFFE2_API Tensor {
   int64_t ndimension() const {
     return dim();
   }

   bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const {
     return impl_->is_contiguous(memory_format);
   }

+  bool is_non_overlapping_and_dense() const {
+    return impl_->is_non_overlapping_and_dense();
+  }
+
   at::MemoryFormat suggest_memory_format() const {
     if (impl_->is_strides_like_channels_last()) {
       return at::MemoryFormat::ChannelsLast;
@@ -130,6 +130,35 @@ bool TensorImpl::compute_strides_like_channels_last() const {
   return false;
 }

+bool TensorImpl::compute_non_overlapping_and_dense() const {
+  if (dim() == 1) {
+    return size(0) < 2 || stride(0) == 1;
+  }
+  SmallVector<int64_t,5> perm;
+  perm.resize(dim());
+  for (int64_t i = 0; i < dim(); i ++) {
+    perm[i] = i;
+  }
+  // Sort by strides, leaving 0 and 1 sized dims at the end of the array
+  std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {
+    if (sizes_[a] < 2) {
+      return false;
+    }
+    return strides_[a] < strides_[b];
+  });
+  auto require_stride = 1;
+  for (int64_t i = 0; i < dim(); i ++) {
+    if (sizes_[perm[i]] < 2) {
+      return true;
+    }
+    if (strides_[perm[i]] != require_stride) {
+      return false;
+    }
+    require_stride *= sizes_[perm[i]];
+  }
+  return true;
+}
+
 void TensorImpl::release_resources() {
   autograd_meta_.reset();
   if (storage_) {
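In plain terms: sort the dimensions by stride, then require each stride to equal the product of the sizes of all strictly inner dimensions; size-0/1 dimensions are pushed to the end and can never cause overlap. An equivalent sketch in Python (hypothetical helper, for illustration):

def non_overlapping_and_dense(sizes, strides):
    # Size-0/1 dims sort to the end; once we reach one, the remaining
    # dims cannot cause overlap, so the layout is accepted.
    perm = sorted(range(len(sizes)), key=lambda d: (sizes[d] < 2, strides[d]))
    required = 1
    for d in perm:
        if sizes[d] < 2:
            return True
        if strides[d] != required:
            return False
        required *= sizes[d]
    return True

assert non_overlapping_and_dense((6, 4), (1, 6))       # transposed but dense
assert not non_overlapping_and_dense((16,), (2,))      # gaps: every other element
assert not non_overlapping_and_dense((4, 4), (1, 1))   # overlap: elements alias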
@@ -1390,6 +1390,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     is_contiguous_ = false;
     is_channels_last_contiguous_ = false;
     is_channels_last_ = false;
+    is_non_overlapping_and_dense_ = false;
     switch (memory_format) {
       case MemoryFormat::Contiguous: {
         strides_.resize(sizes_.size(), 0);
@@ -1401,6 +1402,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
           }
         }
         is_contiguous_ = true;
+        is_non_overlapping_and_dense_ = true;
         return;
       }
       case MemoryFormat::ChannelsLast: {
@@ -1410,6 +1412,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
         set_sizes_and_strides(sizes(), get_channels_last_strides(sizes()));
         is_channels_last_contiguous_ = true;
         is_channels_last_ = true;
+        is_non_overlapping_and_dense_ = true;
         return;
       }
       case MemoryFormat::Preserve:
@@ -1421,6 +1424,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     return is_channels_last_;
   }

+  bool is_non_overlapping_and_dense() const {
+    return is_non_overlapping_and_dense_;
+  }
+
  private:

  // The Caffe2 Resize() method supports being called both as Resize({2,2}) as
@@ -1501,6 +1508,8 @@ private:

   bool compute_strides_like_channels_last() const;

+  bool compute_non_overlapping_and_dense() const;
+
 protected:
   /**
    * Recompute the cached numel of a tensor. Call this if you modify sizes.
@@ -1517,6 +1526,7 @@ protected:
     is_contiguous_ = compute_contiguous();
     is_channels_last_contiguous_ = compute_channels_last_contiguous();
     is_channels_last_ = is_channels_last_contiguous_ || compute_strides_like_channels_last();
+    is_non_overlapping_and_dense_ = is_contiguous_ || is_channels_last_contiguous_ || compute_non_overlapping_and_dense();
   }

   /**
@@ -1546,6 +1556,9 @@ protected:
       dest_impl->type_set_ = dest_impl->type_set_.remove(TensorTypeId::VariableTensorId);
     }
     dest_impl->is_contiguous_ = src_impl->is_contiguous_;
+    dest_impl->is_channels_last_contiguous_ = src_impl->is_channels_last_contiguous_;
+    dest_impl->is_channels_last_ = src_impl->is_channels_last_;
+    dest_impl->is_non_overlapping_and_dense_ = src_impl->is_non_overlapping_and_dense_;
     dest_impl->is_wrapped_number_ = src_impl->is_wrapped_number_;
     dest_impl->reserved_ = src_impl->reserved_;
     dest_impl->set_version_counter(version_counter);
@@ -1641,6 +1654,11 @@ protected:
   // contiguous memory block.
   bool is_channels_last_contiguous_ = false;

+  // A dense tensor stores its values in a contiguous block of memory.
+  // A non-overlapping tensor is one in which each element occupies its own,
+  // non-repeated memory location.
+  bool is_non_overlapping_and_dense_ = false;
+
   bool is_wrapped_number_ = false;

   // NOTE [ Metadata Change for a Detached Tensor ]
@@ -12402,6 +12402,36 @@ class TestTorchDeviceType(TestCase):
         y = nhwc.permute(0, 1, 3, 2).permute(0, 1, 3, 2)
         self.assertTrue(y.is_contiguous(memory_format=torch.channels_last))

+    def test_memory_format_clone(self, device):
+        nhwc = torch.randn((10, 3, 32, 32), device=device).contiguous(memory_format=torch.channels_last)
+        # nhwc is not memory dense, but looks like channels last
+        nhwc = nhwc[:, :, ::2, ::2]
+        clone = nhwc.clone(memory_format=torch.preserve_format)
+        self.assertFalse(clone.is_contiguous())
+        self.assertTrue(clone.is_contiguous(memory_format=torch.channels_last))
+        self.assertFalse(nhwc.is_contiguous())
+        self.assertFalse(nhwc.is_contiguous(memory_format=torch.channels_last))
+        self.assertEqual(nhwc, clone)
+
+        nhwc = torch.randn((10, 3, 32, 32), device=device).contiguous(memory_format=torch.channels_last)
+        clone = nhwc.clone(memory_format=torch.contiguous_format)
+        self.assertTrue(clone.is_contiguous())
+        self.assertFalse(clone.is_contiguous(memory_format=torch.channels_last))
+        self.assertEqual(nhwc, clone)
+
+        nhwc = torch.randn((10, 3, 32, 32), device=device).contiguous(memory_format=torch.channels_last)
+        clone = nhwc.clone()
+        self.assertTrue(clone.is_contiguous())
+        self.assertFalse(clone.is_contiguous(memory_format=torch.channels_last))
+        self.assertEqual(nhwc, clone)
+
+        x = torch.randn((3, 4, 5, 6, 7, 8, 9), device=device)
+        for _ in range(10):
+            permutation = list(range(len(x.shape)))
+            random.shuffle(permutation)
+            x = x.permute(permutation)
+            self.assertEqual(x.stride(), x.clone(memory_format=torch.preserve_format).stride())
+
     def test_memory_format_empty_like(self, device):
         x = torch.randn(10, 3, 32, 32, device=device)
         nhwc = x.contiguous(memory_format=torch.channels_last)
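The random-permutation loop at the end works because permuting only reorders strides: any permutation of a dense tensor remains non-overlapping and dense, so preserve_format reproduces the permuted strides exactly. A standalone sketch of the same property:

import torch

x = torch.randn(2, 3, 4).permute(2, 0, 1)   # dense; strides merely reordered
y = x.clone(memory_format=torch.preserve_format)
assert y.stride() == x.stride() == (1, 12, 4)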
@@ -233,7 +233,7 @@
 - name: clamp_max(Tensor self, Scalar max) -> Tensor
   self: grad * (self <= max).to(grad.dtype())

-- name: clone(Tensor self) -> Tensor
+- name: clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
   self: grad

 - name: coalesce(Tensor self) -> Tensor
@@ -791,7 +791,7 @@ class ShapePropagator {
           "aten::asin(Tensor self) -> Tensor",
           "aten::atan(Tensor self) -> Tensor",
           "aten::ceil(Tensor self) -> Tensor",
-          "aten::clone(Tensor self) -> Tensor",
+          "aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor",
           "aten::contiguous(Tensor self, *, MemoryFormat memory_format=contiguous_format) -> Tensor",
           "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
           "aten::celu(Tensor self, Scalar alpha) -> Tensor",
@@ -972,7 +972,7 @@ def index_select(g, self, dim, index):
         index = g.op("Constant", value_t=torch.LongTensor([index_const]))
     elif index_dim is not None:
         if index_dim == 0:
-            # Index is a scalar. Reshape it to a size 1 tensor.
+            # Index is a scalar. Reshape it to a size 1 tensor.
             index = g.op("Reshape", index, g.op("Constant", value_t=torch.LongTensor([1])))
     return g.op("Gather", self, index, axis_i=dim)
@@ -1042,7 +1042,7 @@ def cosine_similarity(g, x1, x2, dim, eps):


 # ignore clone operators that are inserted by PyTorch autograd
-def clone(g, input):
+def clone(g, input, unused_memory_format):
     return input

@@ -1950,10 +1950,10 @@ def multinomial(g, input, num_samples, replacement=False, generator=None):
                 dtype_i=sym_help.cast_pytorch_to_onnx['Long'],
                 sample_size_i=num_samples)

-def baddbmm(g, self, batch1, batch2, beta, alpha):
-    dtype = self.type().scalarType()
-    batch_mul = matmul(g, batch1, batch2)
-    mul_a = mul(g, batch_mul, g.op("Cast", alpha, to_i=sym_help.cast_pytorch_to_onnx[dtype]))
+def baddbmm(g, self, batch1, batch2, beta, alpha):
+    dtype = self.type().scalarType()
+    batch_mul = matmul(g, batch1, batch2)
+    mul_a = mul(g, batch_mul, g.op("Cast", alpha, to_i=sym_help.cast_pytorch_to_onnx[dtype]))
     mul_b = mul(g, self, g.op("Cast", beta, to_i=sym_help.cast_pytorch_to_onnx[dtype]))
     return add(g, mul_a, mul_b)