Add memory format argument to the clone operator (#27106)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27106

Adds memory_format option to the `clone` operator.

Introduces new `clone` behavior when called as `input_t.clone(memory_format=torch.preserve_format)`:
1) If the tensor is non-overlapping and dense, the output tensor will have the same strides as the input tensor.
2) If (1) does not hold and the tensor is stored in the channels last format, the output tensor will be in the channels last format.
3) In all other cases the output tensor will be contiguous.

 ---
A dense tensor is a tensor that stores its values in a contiguous block of memory.
A non-overlapping tensor is a tensor in which every element occupies its own memory location, i.e. no two elements share an address.

Test Plan: Imported from OSS

Differential Revision: D17699357

Pulled By: VitalyFedyunin

fbshipit-source-id: 5ae1537c2aca1abf0bf1eec4416846129c156f66
This commit is contained in:
Vitaly Fedyunin
2019-10-03 12:04:42 -07:00
committed by Facebook Github Bot
parent 111da77912
commit 7b2e8c323c
15 changed files with 147 additions and 26 deletions

View File

@ -419,7 +419,6 @@ bool aten_op_is_already_moved_to_c10(const c10::OperatorName& opName) {
{"aten::frobenius_norm", "dim"},
{"aten::nuclear_norm", ""},
{"aten::nuclear_norm", "dim"},
{"aten::clone", ""},
{"aten::resize_as_", ""},
{"aten::pow", "Tensor_Scalar"},
{"aten::zero_", ""},
@ -1201,6 +1200,7 @@ bool aten_op_is_not_moved_to_c10_yet(const c10::OperatorName& opName) {
{"aten::frobenius_norm", "out"},
{"aten::nuclear_norm", "out"},
{"aten::nuclear_norm", "dim_out"},
{"aten::clone", ""},
{"aten::pow", "Tensor_Scalar_out"},
{"aten::sub", "out"},
{"aten::addmm", "out"},

View File

@ -186,10 +186,15 @@ class CAFFE2_API Tensor {
int64_t ndimension() const {
return dim();
}
// Returns whether this tensor is contiguous with respect to `memory_format`
// (MemoryFormat::Contiguous when the argument is omitted). The answer is
// obtained from the underlying TensorImpl.
bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const {
return impl_->is_contiguous(memory_format);
}
// Returns whether this tensor is non-overlapping and dense, as reported by the
// underlying TensorImpl.
bool is_non_overlapping_and_dense() const {
return impl_->is_non_overlapping_and_dense();
}
at::MemoryFormat suggest_memory_format() const {
if (impl_->is_strides_like_channels_last()) {
return at::MemoryFormat::ChannelsLast;
@ -734,7 +739,7 @@ class CAFFE2_API Tensor {
#ifdef BUILD_NAMEDTENSOR
Tensor norm(c10::optional<Scalar> p, DimnameList dim, bool keepdim=false) const;
#endif
Tensor clone() const;
Tensor clone(c10::optional<MemoryFormat> memory_format=c10::nullopt) const;
Tensor & resize_as_(const Tensor & the_template) const;
Tensor pow(Scalar exponent) const;
Tensor & zero_() const;

View File

@ -3044,26 +3044,25 @@ inline Tensor Tensor::norm(c10::optional<Scalar> p, DimnameList dim, bool keepdi
#endif
}
#endif
inline Tensor Tensor::clone() const {
// Method form of aten::clone. `memory_format` is optional and is passed through
// unchanged to the implementation that ends up handling the call.
inline Tensor Tensor::clone(c10::optional<MemoryFormat> memory_format) const {
#ifdef USE_STATIC_DISPATCH
// Static dispatch: select the backend-specific clone from this tensor's type
// set; backends other than CPU, QuantizedCPU and SparseCPU raise an error.
at::AutoNonVariableTypeMode _var_guard(true);
switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) {
case Backend::CPU:
return CPUType::clone(const_cast<Tensor&>(*this));
return CPUType::clone(const_cast<Tensor&>(*this), memory_format);
break;
case Backend::QuantizedCPU:
return QuantizedCPUType::clone(const_cast<Tensor&>(*this));
return QuantizedCPUType::clone(const_cast<Tensor&>(*this), memory_format);
break;
case Backend::SparseCPU:
return SparseCPUType::clone(const_cast<Tensor&>(*this));
return SparseCPUType::clone(const_cast<Tensor&>(*this), memory_format);
break;
default:
AT_ERROR("clone not implemented for ", at::toString(type_set()));
}
#else
// Dynamic dispatch: the operator entry is looked up once (function-local
// static) and then invoked unboxed.
static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::clone", ""}).value();
return c10::Dispatcher::singleton().callUnboxed<Tensor, const Tensor &>(
op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast<Tensor&>(*this));
static auto table = globalATenDispatch().getOpTable("aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor");
return table->callUnboxed<Tensor, const Tensor &, c10::optional<MemoryFormat>>(const_cast<Tensor&>(*this), memory_format);
#endif
}
inline Tensor & Tensor::resize_as_(const Tensor & the_template) const {

View File

@ -867,8 +867,20 @@ Tensor from_file(std::string filename, c10::optional<bool> shared, c10::optional
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ clone ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Tensor clone(const Tensor& src) {
auto self = at::empty_like(src);
// Produces a copy of `src` whose layout is chosen by `optional_memory_format`:
//  - no value: the copy is allocated with MemoryFormat::Contiguous;
//  - MemoryFormat::Preserve: if `src` is non-overlapping and dense the copy
//    reuses the exact strides of `src`; otherwise the format returned by
//    `src.suggest_memory_format()` is used;
//  - any other value: that format is handed to `empty_like` as is.
Tensor clone(const Tensor& src, c10::optional<c10::MemoryFormat> optional_memory_format) {
  auto target_format = optional_memory_format.has_value()
      ? optional_memory_format.value()
      : MemoryFormat::Contiguous;
  const bool preserve_requested = (target_format == MemoryFormat::Preserve);

  if (preserve_requested && src.is_non_overlapping_and_dense()) {
    // The source strides already describe a gap-free layout; reuse them verbatim.
    auto result = at::empty_strided(src.sizes(), src.strides(), src.options());
    result.copy_(src);
    return result;
  }

  if (preserve_requested) {
    // Strides cannot be reused, so fall back to the closest named format.
    target_format = src.suggest_memory_format();
  }

  auto result = at::empty_like(src, src.options(), target_format);
  result.copy_(src);
  return result;
}

View File

@ -16,7 +16,7 @@ Tensor mkldnn_reshape(const Tensor& self, IntArrayRef size) {
AT_ERROR("mkldnn_reshape: ATen not compiled with MKLDNN support");
}
Tensor mkldnn_clone(const Tensor& self) {
// Variant used when ATen is built without MKLDNN support: always raises an
// error; neither argument is inspected.
Tensor mkldnn_clone(const Tensor& self, c10::optional<c10::MemoryFormat> optional_memory_format) {
AT_ERROR("mkldnn_clone: ATen not compiled with MKLDNN support");
}
@ -54,7 +54,11 @@ Tensor mkldnn_reshape(const Tensor& self, IntArrayRef size) {
return new_with_itensor_mkldnn(std::move(y), self.options());
}
Tensor mkldnn_clone(const Tensor& self) {
Tensor mkldnn_clone(const Tensor& self, c10::optional<c10::MemoryFormat> optional_memory_format) {
TORCH_CHECK(
!optional_memory_format.has_value(),
"unsupported memory format option ",
optional_memory_format.value());
ideep::tensor& src = itensor_from_mkldnn(self);
ideep::tensor dst;
ideep::direct_copy::compute<AllocForMKLDNN>(src, dst);

View File

@ -3108,8 +3108,7 @@
- func: nuclear_norm.dim_out(Tensor self, int[2] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
variants: function
- func: clone(Tensor self) -> Tensor
use_c10_dispatcher: full
- func: clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
variants: function, method
dispatch:
CPU: clone

View File

@ -155,13 +155,29 @@ Tensor& set_quantizer_(Tensor& self, ConstQuantizerPtr quantizer) {
return self;
}
Tensor quantized_clone(const Tensor& self) {
Tensor quantized_clone(const Tensor& self, c10::optional<c10::MemoryFormat> optional_memory_format) {
// TODO: add per channel support
TORCH_INTERNAL_ASSERT(
self.qscheme() == at::kPerTensorAffine,
"clone for quantized Tensor only works for PerTensorAffine scheme right now");
auto memory_format =
optional_memory_format.value_or(MemoryFormat::Contiguous);
// TODO: To support all features of MemoryFormat::Preserve we need to add
// _empty_affine_quantized_strided function and use it similarly to
// Tensor clone(const Tensor& src, c10::optional<c10::MemoryFormat> optional_memory_format)
// if (self.is_non_overlapping_and_dense()) -> _empty_affine_quantized_strided
if (memory_format == MemoryFormat::Preserve) {
memory_format = self.suggest_memory_format();
}
Tensor dst = at::_empty_affine_quantized(
self.sizes(), self.options(), self.q_scale(), self.q_zero_point());
self.sizes(),
self.options(),
self.q_scale(),
self.q_zero_point(),
memory_format);
at::native::copy_(dst, self, false);

View File

@ -255,7 +255,11 @@ Tensor _sparse_coo_tensor_unsafe(const Tensor& indices, const Tensor& values_, A
// NB: Deleted newWithSizeNd variants
SparseTensor clone_sparse(const SparseTensor& self) {
SparseTensor clone_sparse(const SparseTensor& self, c10::optional<c10::MemoryFormat> optional_memory_format) {
TORCH_CHECK(
!optional_memory_format.has_value(),
"unsupported memory format option ",
optional_memory_format.value());
SparseTensor other = new_with_dims_sparse(self.sparse_dim(), self.dense_dim(), self.sizes(), self.options());
copy_into_sparse(other, self._indices(), self._values(), true);
return other._coalesced_(self.is_coalesced());

View File

@ -186,10 +186,15 @@ class CAFFE2_API Tensor {
int64_t ndimension() const {
return dim();
}
// Returns whether this tensor is contiguous with respect to `memory_format`
// (MemoryFormat::Contiguous when the argument is omitted). The answer is
// obtained from the underlying TensorImpl.
bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const {
return impl_->is_contiguous(memory_format);
}
// Returns whether this tensor is non-overlapping and dense, as reported by the
// underlying TensorImpl.
bool is_non_overlapping_and_dense() const {
return impl_->is_non_overlapping_and_dense();
}
at::MemoryFormat suggest_memory_format() const {
if (impl_->is_strides_like_channels_last()) {
return at::MemoryFormat::ChannelsLast;

View File

@ -130,6 +130,35 @@ bool TensorImpl::compute_strides_like_channels_last() const {
return false;
}
// Determines whether the elements of this tensor fill one gap-free block of
// memory without any two elements sharing an address. Equivalently: after
// ignoring dimensions of size 0 or 1, the strides are those of a contiguous
// tensor under some permutation of the remaining dimensions.
//
// Returns true for tensors with no dimensions and for tensors whose only
// dimensions have size 0 or 1.
bool TensorImpl::compute_non_overlapping_and_dense() const {
  if (dim() == 1) {
    return size(0) < 2 || stride(0) == 1;
  }
  SmallVector<int64_t,5> perm;
  perm.resize(dim());
  for (int64_t i = 0; i < dim(); i ++) {
    perm[i] = i;
  }
  // Sort by strides, leaving 0 and 1 sized dims at the end of the array.
  // The comparator has to be a strict weak ordering: every dim of size < 2 is
  // equivalent to every other such dim and greater than every dim of size >= 2.
  // The stride of a size < 2 dim is arbitrary and must never take part in the
  // comparison, otherwise such a dim can end up in front of a real dim and the
  // scan below stops too early.
  std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {
    if (sizes_[a] < 2) {
      return false;
    } else if (sizes_[b] < 2) {
      return true;
    }
    return strides_[a] < strides_[b];
  });
  // 64-bit accumulator: the product of sizes can exceed the range of int.
  int64_t require_stride = 1;
  for (int64_t i = 0; i < dim(); i ++) {
    if (sizes_[perm[i]] < 2) {
      // Only size 0/1 dims remain; they impose no constraint on the layout.
      return true;
    }
    if (strides_[perm[i]] != require_stride) {
      return false;
    }
    require_stride *= sizes_[perm[i]];
  }
  return true;
}
void TensorImpl::release_resources() {
autograd_meta_.reset();
if (storage_) {

View File

@ -1390,6 +1390,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
is_contiguous_ = false;
is_channels_last_contiguous_ = false;
is_channels_last_ = false;
is_non_overlapping_and_dense_ = false;
switch (memory_format) {
case MemoryFormat::Contiguous: {
strides_.resize(sizes_.size(), 0);
@ -1401,6 +1402,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
}
}
is_contiguous_ = true;
is_non_overlapping_and_dense_ = true;
return;
}
case MemoryFormat::ChannelsLast: {
@ -1410,6 +1412,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
set_sizes_and_strides(sizes(), get_channels_last_strides(sizes()));
is_channels_last_contiguous_ = true;
is_channels_last_ = true;
is_non_overlapping_and_dense_ = true;
return;
}
case MemoryFormat::Preserve:
@ -1421,6 +1424,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return is_channels_last_;
}
// Returns the stored is_non_overlapping_and_dense_ flag; nothing is recomputed
// by this call.
bool is_non_overlapping_and_dense() const {
return is_non_overlapping_and_dense_;
}
private:
// The Caffe2 Resize() method supports being called both as Resize({2,2}) as
@ -1501,6 +1508,8 @@ private:
bool compute_strides_like_channels_last() const;
bool compute_non_overlapping_and_dense() const;
protected:
/**
* Recompute the cached numel of a tensor. Call this if you modify sizes.
@ -1517,6 +1526,7 @@ protected:
is_contiguous_ = compute_contiguous();
is_channels_last_contiguous_ = compute_channels_last_contiguous();
is_channels_last_ = is_channels_last_contiguous_ || compute_strides_like_channels_last();
is_non_overlapping_and_dense_ = is_contiguous_ || is_channels_last_contiguous_ || compute_non_overlapping_and_dense();
}
/**
@ -1546,6 +1556,9 @@ protected:
dest_impl->type_set_ = dest_impl->type_set_.remove(TensorTypeId::VariableTensorId);
}
dest_impl->is_contiguous_ = src_impl->is_contiguous_;
dest_impl->is_channels_last_contiguous_ = src_impl->is_channels_last_contiguous_;
dest_impl->is_channels_last_ = src_impl->is_channels_last_;
dest_impl->is_non_overlapping_and_dense_ = src_impl->is_non_overlapping_and_dense_;
dest_impl->is_wrapped_number_ = src_impl->is_wrapped_number_;
dest_impl->reserved_ = src_impl->reserved_;
dest_impl->set_version_counter(version_counter);
@ -1641,6 +1654,11 @@ protected:
// contiguous memory block.
bool is_channels_last_contiguous_ = false;
// A dense tensor is a tensor that stores its values in a contiguous block of
// memory. A non-overlapping tensor is a tensor in which every element occupies
// its own memory location, i.e. no two elements share an address.
bool is_non_overlapping_and_dense_ = false;
bool is_wrapped_number_ = false;
// NOTE [ Metadata Change for a Detached Tensor ]

View File

@ -12402,6 +12402,36 @@ class TestTorchDeviceType(TestCase):
y = nhwc.permute(0, 1, 3, 2).permute(0, 1, 3, 2)
self.assertTrue(y.is_contiguous(memory_format=torch.channels_last))
# Exercises Tensor.clone() with the different memory_format arguments and checks
# the layout and values of the result.
def test_memory_format_clone(self, device):
nhwc = torch.randn((10, 3, 32, 32), device=device).contiguous(memory_format=torch.channels_last)
# nhwc is not memory dense, but looks like channels last
nhwc = nhwc[:, :, ::2, ::2]
# preserve_format on the strided slice: the clone is expected to be channels
# last contiguous while the source itself is contiguous in neither format
clone = nhwc.clone(memory_format=torch.preserve_format)
self.assertFalse(clone.is_contiguous())
self.assertTrue(clone.is_contiguous(memory_format=torch.channels_last))
self.assertFalse(nhwc.is_contiguous())
self.assertFalse(nhwc.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(nhwc, clone)
# explicit contiguous_format on a channels last source: the clone is expected
# to be plain contiguous
nhwc = torch.randn((10, 3, 32, 32), device=device).contiguous(memory_format=torch.channels_last)
clone = nhwc.clone(memory_format=torch.contiguous_format)
self.assertTrue(clone.is_contiguous())
self.assertFalse(clone.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(nhwc, clone)
# no memory_format argument: same expectations as contiguous_format
nhwc = torch.randn((10, 3, 32, 32), device=device).contiguous(memory_format=torch.channels_last)
clone = nhwc.clone()
self.assertTrue(clone.is_contiguous())
self.assertFalse(clone.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(nhwc, clone)
# preserve_format on randomly permuted 7-d tensors: the clone is expected to
# report exactly the same strides as its source
x = torch.randn((3, 4, 5, 6, 7, 8, 9), device=device)
for _ in range(10):
permutation = list(range(len(x.shape)))
random.shuffle(permutation)
x = x.permute(permutation)
self.assertEqual(x.stride(), x.clone(memory_format=torch.preserve_format).stride())
def test_memory_format_empty_like(self, device):
x = torch.randn(10, 3, 32, 32, device=device)
nhwc = x.contiguous(memory_format=torch.channels_last)

View File

@ -233,7 +233,7 @@
- name: clamp_max(Tensor self, Scalar max) -> Tensor
self: grad * (self <= max).to(grad.dtype())
- name: clone(Tensor self) -> Tensor
- name: clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
self: grad
- name: coalesce(Tensor self) -> Tensor

View File

@ -791,7 +791,7 @@ class ShapePropagator {
"aten::asin(Tensor self) -> Tensor",
"aten::atan(Tensor self) -> Tensor",
"aten::ceil(Tensor self) -> Tensor",
"aten::clone(Tensor self) -> Tensor",
"aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor",
"aten::contiguous(Tensor self, *, MemoryFormat memory_format=contiguous_format) -> Tensor",
"aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
"aten::celu(Tensor self, Scalar alpha) -> Tensor",

View File

@ -972,7 +972,7 @@ def index_select(g, self, dim, index):
index = g.op("Constant", value_t=torch.LongTensor([index_const]))
elif index_dim is not None:
if index_dim == 0:
# Index is a scalar. Reshape it to a size 1 tensor.
# Index is a scalar. Reshape it to a size 1 tensor.
index = g.op("Reshape", index, g.op("Constant", value_t=torch.LongTensor([1])))
return g.op("Gather", self, index, axis_i=dim)
@ -1042,7 +1042,7 @@ def cosine_similarity(g, x1, x2, dim, eps):
# ignore clone operators that are inserted by PyTorch autograd
def clone(g, input):
def clone(g, input, unused_memory_format):
# The memory format argument is accepted but never read; the input value is
# handed back unchanged, so no node is added to the graph `g`.
return input
@ -1950,10 +1950,10 @@ def multinomial(g, input, num_samples, replacement=False, generator=None):
dtype_i=sym_help.cast_pytorch_to_onnx['Long'],
sample_size_i=num_samples)
def baddbmm(g, self, batch1, batch2, beta, alpha):
dtype = self.type().scalarType()
batch_mul = matmul(g, batch1, batch2)
mul_a = mul(g, batch_mul, g.op("Cast", alpha, to_i=sym_help.cast_pytorch_to_onnx[dtype]))
def baddbmm(g, self, batch1, batch2, beta, alpha):
    # Builds alpha * (batch1 @ batch2) + beta * self, casting both scalars to
    # the ONNX type that corresponds to the scalar type of `self`.
    scalar_type = self.type().scalarType()
    product = matmul(g, batch1, batch2)
    alpha_cast = g.op("Cast", alpha, to_i=sym_help.cast_pytorch_to_onnx[scalar_type])
    scaled_product = mul(g, product, alpha_cast)
    beta_cast = g.op("Cast", beta, to_i=sym_help.cast_pytorch_to_onnx[scalar_type])
    scaled_self = mul(g, self, beta_cast)
    return add(g, scaled_product, scaled_self)