mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Memory format support for contiguous and is_contiguous (#20455)
Summary: #19975 was separated by 2 PRs. This one: Introduce MemoryFormat argument to the `x.is_contiguous(memory_format=torch.channels_last)` and to the `y = x.contiguous(memory_format=torch.channels_last)` functions. At this moment both functions just operate with strides and doesn't store any tensor state. (Original RFC #19092) ----- Expands functionality of two tensor functions `.is_contiguous` and `.contiguous` (both python and c++ api). Note: We had several complaints about `.to(memory_format)` function, and decided not to support it. 1. `.contiguous` now support optional keyword-only argument - `memory_format`, which can be either `torch.contiguous_format` or `torch.channels_last`. - Using `torch.contiguous_format` will preserve existing `.contiguous()` behavior. - Calling `x.contiguous(memory_format=torch.channels_last)` returns new tensor which maintain same semantical layout (NCHW), but have different memory allocation pattern. `x.contiguous(memory_format=torch.channels_last)` expects input tensor to be 3d, 4d or 5d; and fails otherwise. 2. `.is_contiguous` now support optional keyword-only argument - `memory_format`, which can be either `torch.contiguous_format` or `torch.channels_last`. - `x.is_contiguous(memory_format=torch.contiguous_format)` preserves same functionality as `x.is_contiguous()` and remains unchanged. - `x.is_contiguous(memory_format=torch.channels_last)` returns true if A) input tensor is contiguous in memory AND B) allocated in the memory in NWHC (or similar for 3d,5d) format. Note: By the end of the phase one `x.is_contiguous(memory_format=torch.channels_last)` will calculate state of the Tensor on every call. This functionality going to be updated later. Pull Request resolved: https://github.com/pytorch/pytorch/pull/20455 Differential Revision: D15341577 Pulled By: VitalyFedyunin fbshipit-source-id: bbb6b4159a8a49149110ad321109a3742383185d
This commit is contained in:
committed by
Facebook Github Bot
parent
09f22d10a6
commit
5b78a5eadb
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/TensorImpl.h>
|
||||
#include <c10/core/MemoryFormat.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
namespace at {
|
||||
@ -36,7 +37,7 @@ struct CAFFE2_API OpaqueTensorImpl : public TensorImpl {
|
||||
AT_ERROR("opaque tensors do not have strides");
|
||||
}
|
||||
|
||||
bool is_contiguous() const override {
|
||||
bool is_contiguous(c10::MemoryFormat memory_format=c10::MemoryFormat::Any) const override {
|
||||
AT_ERROR("opaque tensors do not have is_contiguous");
|
||||
}
|
||||
|
||||
|
@ -51,7 +51,7 @@ SparseTensorImpl::SparseTensorImpl(at::TensorTypeId type_id, const caffe2::TypeM
|
||||
IntArrayRef SparseTensorImpl::strides() const {
|
||||
AT_ERROR("sparse tensors do not have strides");
|
||||
}
|
||||
bool SparseTensorImpl::is_contiguous() const {
|
||||
bool SparseTensorImpl::is_contiguous(at::MemoryFormat memory_format) const {
|
||||
AT_ERROR("sparse tensors do not have is_contiguous");
|
||||
}
|
||||
int64_t SparseTensorImpl::stride(int64_t d) const {
|
||||
|
@ -41,7 +41,7 @@ public:
|
||||
Tensor values() const { return values_; }
|
||||
|
||||
IntArrayRef strides() const override;
|
||||
bool is_contiguous() const override;
|
||||
bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Any) const override;
|
||||
int64_t stride(int64_t d) const override;
|
||||
void resize_dim(int64_t ndim) override;
|
||||
void set_size(int64_t dim, int64_t new_size) override;
|
||||
|
@ -3,6 +3,7 @@
|
||||
#include <ATen/core/Type.h>
|
||||
#include <c10/core/Device.h>
|
||||
#include <c10/core/Layout.h>
|
||||
#include <c10/core/MemoryFormat.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <ATen/core/SparseTensorRef.h>
|
||||
@ -165,8 +166,8 @@ class CAFFE2_API Tensor {
|
||||
int64_t ndimension() const {
|
||||
return dim();
|
||||
}
|
||||
bool is_contiguous() const {
|
||||
return impl_->is_contiguous();
|
||||
bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Any) const {
|
||||
return impl_->is_contiguous(memory_format);
|
||||
}
|
||||
|
||||
// Total bytes consumed by the "view" of elements of the array. Does not
|
||||
@ -373,7 +374,7 @@ class CAFFE2_API Tensor {
|
||||
Tensor & clamp_max_(Scalar max);
|
||||
Tensor clamp_min(Scalar min) const;
|
||||
Tensor & clamp_min_(Scalar min);
|
||||
Tensor contiguous() const;
|
||||
Tensor contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const;
|
||||
Tensor & copy_(const Tensor & src, bool non_blocking=false);
|
||||
Tensor cos() const;
|
||||
Tensor & cos_();
|
||||
|
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/core/MemoryFormat.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <ATen/core/SparseTensorRef.h>
|
||||
#include <c10/core/TensorOptions.h>
|
||||
@ -176,8 +177,8 @@ inline Tensor Tensor::clamp_min(Scalar min) const {
|
||||
inline Tensor & Tensor::clamp_min_(Scalar min) {
|
||||
return dispatch_type().clamp_min_(*this, min);
|
||||
}
|
||||
inline Tensor Tensor::contiguous() const {
|
||||
return dispatch_type().contiguous(*this);
|
||||
inline Tensor Tensor::contiguous(MemoryFormat memory_format) const {
|
||||
return dispatch_type().contiguous(*this, memory_format);
|
||||
}
|
||||
inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) {
|
||||
return dispatch_type().copy_(*this, src, non_blocking);
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include <c10/util/Deprecated.h>
|
||||
#include <ATen/core/Generator.h>
|
||||
#include <c10/core/Layout.h>
|
||||
#include <c10/core/MemoryFormat.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <ATen/core/SparseTensorRef.h>
|
||||
@ -182,7 +183,7 @@ struct CAFFE2_API Type {
|
||||
virtual Tensor & clamp_max_(Tensor & self, Scalar max) const = 0;
|
||||
virtual Tensor clamp_min(const Tensor & self, Scalar min) const = 0;
|
||||
virtual Tensor & clamp_min_(Tensor & self, Scalar min) const = 0;
|
||||
virtual Tensor contiguous(const Tensor & self) const = 0;
|
||||
virtual Tensor contiguous(const Tensor & self, MemoryFormat memory_format) const = 0;
|
||||
virtual Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) const = 0;
|
||||
virtual Tensor cos(const Tensor & self) const = 0;
|
||||
virtual Tensor & cos_(Tensor & self) const = 0;
|
||||
|
@ -319,6 +319,12 @@ struct CAFFE2_API IValue final {
|
||||
return static_cast<at::Layout>(toInt());
|
||||
}
|
||||
|
||||
// MemoryFormat
|
||||
at::MemoryFormat toMemoryFormat() const {
|
||||
return static_cast<at::MemoryFormat>(toInt());
|
||||
}
|
||||
|
||||
|
||||
// for debugging
|
||||
std::string tagKind() const {
|
||||
switch(tag) {
|
||||
|
@ -480,6 +480,7 @@ DEFINE_TO(IValue, toIValue)
|
||||
DEFINE_TO(c10::Device, toDevice)
|
||||
DEFINE_TO(at::ScalarType, toScalarType)
|
||||
DEFINE_TO(at::Layout, toLayout)
|
||||
DEFINE_TO(at::MemoryFormat, toMemoryFormat)
|
||||
|
||||
template <typename T>
|
||||
struct _fake_type {};
|
||||
|
@ -53,11 +53,37 @@ Tensor & detach_(Tensor & self) {
|
||||
}
|
||||
|
||||
Tensor contiguous(const Tensor & self) {
|
||||
if (self.is_contiguous()) {
|
||||
return self;
|
||||
}
|
||||
return self.clone();
|
||||
return contiguous(self, MemoryFormat::Contiguous);
|
||||
}
|
||||
|
||||
Tensor contiguous(const Tensor& self, MemoryFormat memory_format) {
|
||||
if (self.is_contiguous(memory_format)) {
|
||||
return self;
|
||||
}
|
||||
auto result = at::empty_like(self);
|
||||
switch (memory_format) {
|
||||
case MemoryFormat::Any: // Back compatibility with old defaults
|
||||
case MemoryFormat::Contiguous: {
|
||||
break;
|
||||
}
|
||||
case MemoryFormat::ChannelsLast: {
|
||||
AT_CHECK(
|
||||
result.dim() == 4,
|
||||
" required rank 4 tensor to use channels_last format");
|
||||
std::vector<int64_t> newStrides(self.dim());
|
||||
auto sizes = result.sizes();
|
||||
newStrides[1] = 1;
|
||||
newStrides[3] = sizes[1];
|
||||
newStrides[2] = newStrides[3] * sizes[3];
|
||||
newStrides[0] = newStrides[2] * sizes[2];
|
||||
result = result.as_strided(sizes, newStrides);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
AT_CHECK(false, " unsupported memory format");
|
||||
}
|
||||
}
|
||||
return result.copy_(self);
|
||||
}
|
||||
} // namespace native
|
||||
}
|
||||
|
@ -424,7 +424,7 @@
|
||||
- func: constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: contiguous(Tensor self) -> Tensor
|
||||
- func: contiguous(Tensor self, *, MemoryFormat memory_format=contiguous_format) -> Tensor
|
||||
variants: method
|
||||
|
||||
- func: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor
|
||||
|
@ -127,6 +127,8 @@ def type_argument_translations(arg):
|
||||
# we change this at either a JIT schema or C++ level.
|
||||
elif default == 'Mean':
|
||||
default = 'Reduction::Mean'
|
||||
elif default == 'contiguous_format':
|
||||
default = 'MemoryFormat::Contiguous'
|
||||
else:
|
||||
try:
|
||||
default = int(default)
|
||||
|
@ -3,6 +3,7 @@
|
||||
#include <ATen/core/Type.h>
|
||||
#include <c10/core/Device.h>
|
||||
#include <c10/core/Layout.h>
|
||||
#include <c10/core/MemoryFormat.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <ATen/core/SparseTensorRef.h>
|
||||
@ -165,8 +166,8 @@ class CAFFE2_API Tensor {
|
||||
int64_t ndimension() const {
|
||||
return dim();
|
||||
}
|
||||
bool is_contiguous() const {
|
||||
return impl_->is_contiguous();
|
||||
bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Any) const {
|
||||
return impl_->is_contiguous(memory_format);
|
||||
}
|
||||
|
||||
// Total bytes consumed by the "view" of elements of the array. Does not
|
||||
|
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/core/MemoryFormat.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <ATen/core/SparseTensorRef.h>
|
||||
#include <c10/core/TensorOptions.h>
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include <c10/util/Deprecated.h>
|
||||
#include <ATen/core/Generator.h>
|
||||
#include <c10/core/Layout.h>
|
||||
#include <c10/core/MemoryFormat.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <ATen/core/SparseTensorRef.h>
|
||||
|
49
c10/core/MemoryFormat.h
Normal file
49
c10/core/MemoryFormat.h
Normal file
@ -0,0 +1,49 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/Backend.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
// Memory format is not the property of a Tensor. It is the way to tell an
|
||||
// operator how the result should be organized in memory and nothing more. That
|
||||
// means memory format should never be used as return value for any tensor state
|
||||
// interrogation functions (internally and externally).
|
||||
//
|
||||
// Possible options are:
|
||||
// Any:
|
||||
// An operator can return Tensor with any memory format. This describes the
|
||||
// current behavior of operators.
|
||||
//
|
||||
// Preserve:
|
||||
// If any of the input tensors is in channels_last format, operator output
|
||||
// should be in channels_last format
|
||||
//
|
||||
// Contiguous:
|
||||
// Regardless of input tensors format, the output should be contiguous Tensor.
|
||||
//
|
||||
// ChannelsLast:
|
||||
// Regardless of input tensors format, the output should be in channels_last format.
|
||||
|
||||
|
||||
namespace c10 {
|
||||
enum class MemoryFormat : int8_t { Any, Preserve, Contiguous, ChannelsLast };
|
||||
|
||||
inline std::ostream& operator<<(
|
||||
std::ostream& stream,
|
||||
at::MemoryFormat memory_format) {
|
||||
switch (memory_format) {
|
||||
case MemoryFormat::Any:
|
||||
return stream << "Any";
|
||||
case MemoryFormat::Preserve:
|
||||
return stream << "Preserve";
|
||||
case MemoryFormat::Contiguous:
|
||||
return stream << "Contiguous";
|
||||
case MemoryFormat::ChannelsLast:
|
||||
return stream << "ChannelsLast";
|
||||
default:
|
||||
AT_ERROR("Unknown memory format");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace c10
|
@ -113,6 +113,26 @@ bool TensorImpl::has_storage() const {
|
||||
return storage_;
|
||||
}
|
||||
|
||||
bool TensorImpl::is_contiguous(at::MemoryFormat memory_format) const {
|
||||
#ifdef DEBUG
|
||||
AT_ASSERT(compute_contiguous() == is_contiguous_);
|
||||
#endif
|
||||
if (memory_format == at::MemoryFormat::ChannelsLast) {
|
||||
if (dim() == 4) {
|
||||
auto strides_1 = 1;
|
||||
auto strides_3 = sizes_[1];
|
||||
auto strides_2 = strides_3 * sizes_[3];
|
||||
auto strides_0 = strides_2 * sizes_[2];
|
||||
if (strides_0 == strides_[0] && strides_1 == strides_[1] &&
|
||||
strides_2 == strides_[2] && strides_3 == strides_[3]) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return is_contiguous_;
|
||||
}
|
||||
|
||||
const Storage& TensorImpl::storage() const {
|
||||
return storage_;
|
||||
}
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include <numeric>
|
||||
|
||||
#include <c10/core/Backend.h>
|
||||
#include <c10/core/MemoryFormat.h>
|
||||
#include <c10/core/Storage.h>
|
||||
#include <c10/core/TensorOptions.h>
|
||||
#include <c10/core/TensorTypeId.h>
|
||||
@ -387,12 +388,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
* compute_contiguous() for the exact definition of whether or not
|
||||
* a tensor is contiguous or not.
|
||||
*/
|
||||
virtual bool is_contiguous() const {
|
||||
#ifdef DEBUG
|
||||
AT_ASSERT(compute_contiguous() == is_contiguous_);
|
||||
#endif
|
||||
return is_contiguous_;
|
||||
}
|
||||
virtual bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Any) const;
|
||||
|
||||
bool is_sparse() const {
|
||||
// NB: This method is not virtual and avoid dispatches for performance reasons.
|
||||
|
@ -519,8 +519,8 @@ class CAFFE2_API Tensor final {
|
||||
return impl_.get()->strides();
|
||||
}
|
||||
|
||||
inline bool is_contiguous() const {
|
||||
return impl_.get()->is_contiguous();
|
||||
inline bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Any) const {
|
||||
return impl_.get()->is_contiguous(memory_format);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -9370,7 +9370,7 @@ class _TestTorchMixin(object):
|
||||
i, j = 41, 43
|
||||
with BytesIOContext() as f:
|
||||
pickle.dump(i, f)
|
||||
torch.save(a, f)
|
||||
torch.save(a, f)
|
||||
pickle.dump(j, f)
|
||||
torch.save(b, f)
|
||||
f.seek(0)
|
||||
@ -11404,6 +11404,40 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
|
||||
weight), torch.tensor(bias), 1, epsilon, True)
|
||||
torch.testing.assert_allclose(expected_norm, actual_norm)
|
||||
|
||||
def test_memory_format(self):
|
||||
x = torch.randn(10, 3, 32, 32)
|
||||
nhwc = x.contiguous(memory_format=torch.channels_last)
|
||||
self.assertFalse(nhwc.is_contiguous())
|
||||
self.assertTrue(nhwc.is_contiguous(memory_format=torch.channels_last))
|
||||
self.assertEqual(nhwc, x)
|
||||
|
||||
## These sections are intentionally commented, as they suppose to pass when we
|
||||
## switch from the computational check of the tensor memory format to the
|
||||
## actual layout support
|
||||
#
|
||||
#
|
||||
# def fake_nhwc(N, C, H, W):
|
||||
# alloc = torch.randn(N, H, W, C)
|
||||
# return alloc.permute(0, 3, 1, 2)
|
||||
#
|
||||
# fake = fake_nhwc(10, 3, 32, 32)
|
||||
# self.assertFalse(
|
||||
# fake.is_contiguous(memory_format=torch.channels_last),
|
||||
# " must be tagged to be identified as channels_last")
|
||||
#
|
||||
# def test_memory_format_permute(self):
|
||||
# x = torch.randn(10, 3, 32, 32)
|
||||
# nhwc = x.contiguous(memory_format=torch.channels_last)
|
||||
# y = nhwc.permute(0, 1, 3, 2).permute(0, 1, 3, 2)
|
||||
# self.assertFalse(y.is_contiguous(memory_format=torch.channels_last))
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
|
||||
def test_memory_format_permute_cuda(self):
|
||||
x = torch.randn(10, 3, 32, 32)
|
||||
nhwc = x.contiguous(memory_format=torch.channels_last).cuda()
|
||||
y = nhwc.permute(0, 1, 3, 2).permute(0, 1, 3, 2)
|
||||
self.assertFalse(y.is_contiguous(memory_format=torch.channels_last))
|
||||
|
||||
def test_subclass_tensors(self):
|
||||
# raise an error when trying to subclass FloatTensor
|
||||
with self.assertRaisesRegex(TypeError, "type 'torch.FloatTensor' is not an acceptable base type"):
|
||||
@ -11423,7 +11457,6 @@ INPLACE_METHOD = 2
|
||||
FUNCTIONAL = 4
|
||||
DIM_ARG = None
|
||||
|
||||
|
||||
def make_neg_dim_test(name, tensor_arg, arg_constr, types, extra_dim=0):
|
||||
def neg_dim_test(self):
|
||||
if isinstance(tensor_arg, list):
|
||||
|
@ -21,6 +21,7 @@ using at::Context;
|
||||
using at::Device;
|
||||
using at::Generator;
|
||||
using at::IntArrayRef;
|
||||
using at::MemoryFormat;
|
||||
using at::Scalar;
|
||||
using at::ScalarType;
|
||||
using at::SparseTensorRef;
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
// ${generated_comment}
|
||||
|
||||
|
||||
#include "torch/csrc/Device.h"
|
||||
#include "torch/csrc/DynamicTypes.h"
|
||||
#include "torch/csrc/Exceptions.h"
|
||||
@ -15,6 +16,7 @@
|
||||
|
||||
using at::Tensor;
|
||||
using at::Scalar;
|
||||
using at::MemoryFormat;
|
||||
using namespace torch::autograd::utils;
|
||||
|
||||
namespace torch { namespace autograd {
|
||||
|
@ -143,17 +143,24 @@ static PyObject * THPVariable_dim(PyObject* self, PyObject* args)
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static Tensor dispatch_contiguous(const Tensor & self) {
|
||||
static Tensor dispatch_contiguous(const Tensor & self, at::MemoryFormat memory_format) {
|
||||
AutoNoGIL no_gil;
|
||||
OptionalDeviceGuard device_guard(device_of(self));
|
||||
return self.contiguous();
|
||||
return self.contiguous(memory_format);
|
||||
}
|
||||
static PyObject * THPVariable_contiguous(PyObject* self, PyObject* args)
|
||||
|
||||
static PyObject * THPVariable_contiguous(PyObject* self, PyObject* args, PyObject* kwargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
static PythonArgParser parser({
|
||||
"contiguous(*, MemoryFormat memory_format=contiguous_format)",
|
||||
});
|
||||
ParsedArgs<1> parsed_args;
|
||||
auto r = parser.parse(args, kwargs, parsed_args);
|
||||
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
|
||||
auto memory_format = r.toMemoryFormat(0);
|
||||
// avoids touching the GIL or current device if self is already contiguous
|
||||
if (self_.is_contiguous()) {
|
||||
if (self_.is_contiguous(memory_format)) {
|
||||
// NOTE: this logic is duplicated from VariableType.cpp. Since we need to
|
||||
// record this call to contiguous() in the trace regardless of whether
|
||||
// we actually call contiguous here, we need to record this information
|
||||
@ -163,13 +170,14 @@ static Tensor dispatch_contiguous(const Tensor & self) {
|
||||
auto node = tracer_state->graph->create(jit::aten::contiguous, /*num_outputs=*/0);
|
||||
jit::tracer::recordSourceLocation(node);
|
||||
jit::tracer::addInputs(node, "self", self_);
|
||||
jit::tracer::addInputs(node, "memory_format", memory_format);
|
||||
tracer_state->graph->insertNode(node);
|
||||
jit::tracer::addOutput(node, self_);
|
||||
}
|
||||
Py_INCREF(self);
|
||||
return self;
|
||||
}
|
||||
return THPVariable_Wrap(dispatch_contiguous(self_));
|
||||
return THPVariable_Wrap(dispatch_contiguous(self_, memory_format));
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
@ -432,15 +440,21 @@ static PyObject * THPVariable_requires_grad_(PyObject* self, PyObject* args, PyO
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
inline bool dispatch_is_contiguous(Tensor & self) {
|
||||
return self.is_contiguous();
|
||||
inline bool dispatch_is_contiguous(Tensor & self, MemoryFormat memory_format) {
|
||||
return self.is_contiguous(memory_format);
|
||||
}
|
||||
|
||||
static PyObject * THPVariable_is_contiguous(PyObject* self_, PyObject* args)
|
||||
static PyObject * THPVariable_is_contiguous(PyObject* self_, PyObject* args, PyObject* kwargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
static PythonArgParser parser({
|
||||
"is_contiguous(*, MemoryFormat memory_format=contiguous_format)",
|
||||
});
|
||||
ParsedArgs<1> parsed_args;
|
||||
auto r = parser.parse(args, kwargs, parsed_args);
|
||||
auto memory_format = r.toMemoryFormat(0);
|
||||
auto& self = reinterpret_cast<THPVariable*>(self_)->cdata;
|
||||
return wrap(dispatch_is_contiguous(self));
|
||||
return wrap(dispatch_is_contiguous(self, memory_format));
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
@ -686,7 +700,7 @@ PyMethodDef variable_methods[] = {
|
||||
{"apply_", (PyCFunction)THPVariable_apply_, METH_O, NULL},
|
||||
{"byte", (PyCFunction)THPVariable_byte, METH_NOARGS, NULL},
|
||||
{"char", (PyCFunction)THPVariable_char, METH_NOARGS, NULL},
|
||||
{"contiguous", (PyCFunction)THPVariable_contiguous, METH_NOARGS, NULL},
|
||||
{"contiguous", (PyCFunction)THPVariable_contiguous, METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
{"copy_", (PyCFunction)THPVariable_copy_, METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
{"cpu", (PyCFunction)THPVariable_cpu, METH_NOARGS, NULL},
|
||||
{"cuda", (PyCFunction)THPVariable_cuda, METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
@ -698,7 +712,7 @@ PyMethodDef variable_methods[] = {
|
||||
{"bool", (PyCFunction)THPVariable_bool, METH_NOARGS, NULL},
|
||||
{"half", (PyCFunction)THPVariable_half, METH_NOARGS, NULL},
|
||||
{"int", (PyCFunction)THPVariable_int, METH_NOARGS, NULL},
|
||||
{"is_contiguous", (PyCFunction)THPVariable_is_contiguous, METH_NOARGS, NULL},
|
||||
{"is_contiguous", (PyCFunction)THPVariable_is_contiguous, METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
{"item", (PyCFunction)THPVariable_item, METH_NOARGS, NULL},
|
||||
{"long", (PyCFunction)THPVariable_long, METH_NOARGS, NULL},
|
||||
{"map_", (PyCFunction)THPVariable_map_, METH_VARARGS | METH_KEYWORDS, NULL},
|
||||
|
@ -158,6 +158,7 @@ def add_torch_libs():
|
||||
"torch/csrc/DynamicTypes.cpp",
|
||||
"torch/csrc/Generator.cpp",
|
||||
"torch/csrc/Layout.cpp",
|
||||
"torch/csrc/MemoryFormat.cpp",
|
||||
"torch/csrc/Module.cpp",
|
||||
"torch/csrc/PtrWrapper.cpp",
|
||||
"torch/csrc/Size.cpp",
|
||||
@ -233,6 +234,7 @@ def add_torch_libs():
|
||||
"torch/csrc/utils/tensor_apply.cpp",
|
||||
"torch/csrc/utils/tensor_dtypes.cpp",
|
||||
"torch/csrc/utils/tensor_layouts.cpp",
|
||||
"torch/csrc/utils/tensor_memoryformats.cpp",
|
||||
"torch/csrc/utils/tensor_list.cpp",
|
||||
"torch/csrc/utils/tensor_new.cpp",
|
||||
"torch/csrc/utils/tensor_numpy.cpp",
|
||||
|
@ -40,6 +40,7 @@ TYPE_MAP = {
|
||||
'std::array<bool,4>': 'bool[4]',
|
||||
'std::string': 'str',
|
||||
'Scalar': 'Scalar',
|
||||
'MemoryFormat': 'MemoryFormat',
|
||||
'Scalar?': 'Scalar?',
|
||||
'Tensor': 'Tensor',
|
||||
'Tensor?': 'Tensor?',
|
||||
@ -96,6 +97,7 @@ FROM_IVALUE = {
|
||||
'IntArrayRef': '{}.toIntList()->elements()',
|
||||
'Layout': '{}.toLayout()',
|
||||
'Layout?': '{}.toOptional<c10::Layout>()',
|
||||
'MemoryFormat': '{}.toMemoryFormat()',
|
||||
'Scalar': '{}.toScalar()',
|
||||
'Scalar?': '{}.toOptional<Scalar>()',
|
||||
'ScalarType': '{}.toScalarType()',
|
||||
@ -483,6 +485,7 @@ def signature(decl, should_match_schema=True):
|
||||
.replace('true', 'True') \
|
||||
.replace('false', 'False') \
|
||||
.replace('Reduction::Mean', 'Mean') \
|
||||
.replace('MemoryFormat::Contiguous', 'contiguous_format') \
|
||||
.replace('{}', 'None' if is_tensor_arg(arg) else '[]') \
|
||||
.replace('{', '[') \
|
||||
.replace('}', ']')
|
||||
|
@ -46,6 +46,7 @@ set(TORCH_PYTHON_SRCS
|
||||
${TORCH_SRC_DIR}/csrc/TypeInfo.cpp
|
||||
${TORCH_SRC_DIR}/csrc/Generator.cpp
|
||||
${TORCH_SRC_DIR}/csrc/Layout.cpp
|
||||
${TORCH_SRC_DIR}/csrc/MemoryFormat.cpp
|
||||
${TORCH_SRC_DIR}/csrc/Module.cpp
|
||||
${TORCH_SRC_DIR}/csrc/PtrWrapper.cpp
|
||||
${TORCH_SRC_DIR}/csrc/Size.cpp
|
||||
@ -89,6 +90,7 @@ set(TORCH_PYTHON_SRCS
|
||||
${TORCH_SRC_DIR}/csrc/utils/tensor_dtypes.cpp
|
||||
${TORCH_SRC_DIR}/csrc/utils/tensor_layouts.cpp
|
||||
${TORCH_SRC_DIR}/csrc/utils/tensor_list.cpp
|
||||
${TORCH_SRC_DIR}/csrc/utils/tensor_memoryformats.cpp
|
||||
${TORCH_SRC_DIR}/csrc/utils/tensor_new.cpp
|
||||
${TORCH_SRC_DIR}/csrc/utils/tensor_numpy.cpp
|
||||
${TORCH_SRC_DIR}/csrc/utils/tensor_types.cpp
|
||||
|
80
torch/csrc/MemoryFormat.cpp
Normal file
80
torch/csrc/MemoryFormat.cpp
Normal file
@ -0,0 +1,80 @@
|
||||
#include <torch/csrc/MemoryFormat.h>
|
||||
|
||||
#include <torch/csrc/Exceptions.h>
|
||||
#include <torch/csrc/utils/object_ptr.h>
|
||||
#include <torch/csrc/utils/python_strings.h>
|
||||
|
||||
#include <c10/core/MemoryFormat.h>
|
||||
|
||||
#include <structmember.h>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
|
||||
PyObject *THPMemoryFormat_New(at::MemoryFormat memory_format, const std::string& name)
|
||||
{
|
||||
auto type = (PyTypeObject*)&THPMemoryFormatType;
|
||||
auto self = THPObjectPtr{type->tp_alloc(type, 0)};
|
||||
if (!self) throw python_error();
|
||||
auto self_ = reinterpret_cast<THPMemoryFormat*>(self.get());
|
||||
self_->memory_format = memory_format;
|
||||
std::strncpy (self_->name, name.c_str(), MEMORY_FORMAT_NAME_LEN);
|
||||
self_->name[MEMORY_FORMAT_NAME_LEN] = '\0';
|
||||
return self.release();
|
||||
}
|
||||
|
||||
PyObject *THPMemoryFormat_repr(THPMemoryFormat *self)
|
||||
{
|
||||
return THPUtils_packString(self->name);
|
||||
}
|
||||
|
||||
PyTypeObject THPMemoryFormatType = {
|
||||
PyVarObject_HEAD_INIT(nullptr, 0)
|
||||
"torch.memory_format", /* tp_name */
|
||||
sizeof(THPMemoryFormat), /* tp_basicsize */
|
||||
0, /* tp_itemsize */
|
||||
nullptr, /* tp_dealloc */
|
||||
nullptr, /* tp_print */
|
||||
nullptr, /* tp_getattr */
|
||||
nullptr, /* tp_setattr */
|
||||
nullptr, /* tp_reserved */
|
||||
(reprfunc)THPMemoryFormat_repr, /* tp_repr */
|
||||
nullptr, /* tp_as_number */
|
||||
nullptr, /* tp_as_sequence */
|
||||
nullptr, /* tp_as_mapping */
|
||||
nullptr, /* tp_hash */
|
||||
nullptr, /* tp_call */
|
||||
nullptr, /* tp_str */
|
||||
nullptr, /* tp_getattro */
|
||||
nullptr, /* tp_setattro */
|
||||
nullptr, /* tp_as_buffer */
|
||||
Py_TPFLAGS_DEFAULT, /* tp_flags */
|
||||
nullptr, /* tp_doc */
|
||||
nullptr, /* tp_traverse */
|
||||
nullptr, /* tp_clear */
|
||||
nullptr, /* tp_richcompare */
|
||||
0, /* tp_weaklistoffset */
|
||||
nullptr, /* tp_iter */
|
||||
nullptr, /* tp_iternext */
|
||||
nullptr, /* tp_methods */
|
||||
nullptr, /* tp_members */
|
||||
nullptr, /* tp_getset */
|
||||
nullptr, /* tp_base */
|
||||
nullptr, /* tp_dict */
|
||||
nullptr, /* tp_descr_get */
|
||||
nullptr, /* tp_descr_set */
|
||||
0, /* tp_dictoffset */
|
||||
nullptr, /* tp_init */
|
||||
nullptr, /* tp_alloc */
|
||||
nullptr, /* tp_new */
|
||||
};
|
||||
|
||||
void THPMemoryFormat_init(PyObject *module)
|
||||
{
|
||||
if (PyType_Ready(&THPMemoryFormatType) < 0) {
|
||||
throw python_error();
|
||||
}
|
||||
Py_INCREF(&THPMemoryFormatType);
|
||||
if (PyModule_AddObject(module, "memory_format", (PyObject *)&THPMemoryFormatType) != 0) {
|
||||
throw python_error();
|
||||
}
|
||||
}
|
25
torch/csrc/MemoryFormat.h
Normal file
25
torch/csrc/MemoryFormat.h
Normal file
@ -0,0 +1,25 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/python_headers.h>
|
||||
|
||||
#include <c10/core/MemoryFormat.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
const int MEMORY_FORMAT_NAME_LEN = 64;
|
||||
|
||||
struct THPMemoryFormat {
|
||||
PyObject_HEAD
|
||||
at::MemoryFormat memory_format;
|
||||
char name[MEMORY_FORMAT_NAME_LEN + 1];
|
||||
};
|
||||
|
||||
extern PyTypeObject THPMemoryFormatType;
|
||||
|
||||
inline bool THPMemoryFormat_Check(PyObject *obj) {
|
||||
return Py_TYPE(obj) == &THPMemoryFormatType;
|
||||
}
|
||||
|
||||
PyObject * THPMemoryFormat_New(at::MemoryFormat memory_format, const std::string& name);
|
||||
|
||||
void THPMemoryFormat_init(PyObject *module);
|
@ -25,6 +25,7 @@
|
||||
#include <torch/csrc/DataLoader.h>
|
||||
#include <torch/csrc/Generator.h>
|
||||
#include <torch/csrc/Layout.h>
|
||||
#include <torch/csrc/MemoryFormat.h>
|
||||
#include <torch/csrc/TypeInfo.h>
|
||||
#include <torch/csrc/autograd/generated/python_nn_functions.h>
|
||||
#include <torch/csrc/autograd/python_legacy_variable.h>
|
||||
@ -34,6 +35,7 @@
|
||||
#include <torch/csrc/utils/tensor_dtypes.h>
|
||||
#include <torch/csrc/utils/python_strings.h>
|
||||
#include <torch/csrc/utils/tensor_layouts.h>
|
||||
#include <torch/csrc/utils/tensor_memoryformats.h>
|
||||
#include <torch/csrc/utils/tensor_numpy.h>
|
||||
#include <torch/csrc/jit/python_tracer.h>
|
||||
#include <torch/csrc/jit/init.h>
|
||||
@ -97,6 +99,7 @@ static PyObject * THPModule_initExtension(PyObject *_unused, PyObject *shm_manag
|
||||
return nullptr;
|
||||
}
|
||||
torch::utils::initializeLayouts();
|
||||
torch::utils::initializeMemoryFormats();
|
||||
torch::utils::initializeDtypes();
|
||||
torch::tensors::initialize_python_bindings();
|
||||
std::string path = THPUtils_unpackString(shm_manager_path);
|
||||
@ -589,6 +592,7 @@ PyObject* initModule() {
|
||||
THPDtype_init(module);
|
||||
THPDTypeInfo_init(module);
|
||||
THPLayout_init(module);
|
||||
THPMemoryFormat_init(module);
|
||||
THPDevice_init(module);
|
||||
ASSERT_TRUE(THPVariable_initModule(module));
|
||||
ASSERT_TRUE(THPFunction_initModule(module));
|
||||
|
@ -54,8 +54,8 @@ IntArrayRef Variable::Impl::strides() const {
|
||||
return data_.strides();
|
||||
}
|
||||
|
||||
bool Variable::Impl::is_contiguous() const {
|
||||
return data_.is_contiguous();
|
||||
bool Variable::Impl::is_contiguous(MemoryFormat memory_format) const {
|
||||
return data_.is_contiguous(memory_format);
|
||||
}
|
||||
|
||||
int64_t Variable::Impl::dim() const {
|
||||
|
@ -409,7 +409,7 @@ struct TORCH_API Variable::Impl : public at::TensorImpl {
|
||||
int64_t numel() const override;
|
||||
at::IntArrayRef sizes() const override;
|
||||
at::IntArrayRef strides() const override;
|
||||
bool is_contiguous() const override;
|
||||
bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Any) const override;
|
||||
int64_t size(int64_t d) const override;
|
||||
int64_t stride(int64_t d) const override;
|
||||
void resize_dim(int64_t ndim) override;
|
||||
|
@ -701,7 +701,7 @@ class ShapePropagator {
|
||||
"aten::atan(Tensor self) -> Tensor",
|
||||
"aten::ceil(Tensor self) -> Tensor",
|
||||
"aten::clone(Tensor self) -> Tensor",
|
||||
"aten::contiguous(Tensor self) -> Tensor",
|
||||
"aten::contiguous(Tensor self, *, MemoryFormat memory_format=contiguous_format) -> Tensor",
|
||||
"aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
|
||||
"aten::celu(Tensor self, Scalar alpha) -> Tensor",
|
||||
"aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
|
||||
|
@ -158,6 +158,8 @@ struct SchemaParser {
|
||||
return static_cast<int64_t>(at::kStrided);
|
||||
} else if ("Mean" == text) {
|
||||
return static_cast<int64_t>(Reduction::Mean);
|
||||
} else if ("contiguous_format" == text) {
|
||||
return static_cast<int64_t>(c10::MemoryFormat::Contiguous);
|
||||
} else {
|
||||
throw ErrorReport(L.cur().range) << "invalid numeric default value";
|
||||
}
|
||||
|
@ -36,6 +36,7 @@ TypeAndAlias SchemaTypeParser::parseBaseType() {
|
||||
{"Generator", GeneratorType::get()},
|
||||
{"ScalarType", IntType::get()},
|
||||
{"Layout", IntType::get()},
|
||||
{"MemoryFormat", IntType::get()},
|
||||
{"Device", DeviceObjType::get()},
|
||||
{"Scalar", NumberType::get()},
|
||||
{"str", StringType::get()},
|
||||
|
@ -402,11 +402,11 @@ const std::vector<std::string> functions = {
|
||||
|
||||
return torch._dim_arange(like, dim), backward
|
||||
|
||||
def contiguous(self):
|
||||
def contiguous(self, *, memory_format: int=0):
|
||||
def backward(grad_output):
|
||||
return grad_output
|
||||
return grad_output, None
|
||||
|
||||
return self.contiguous(), backward
|
||||
return self.contiguous(memory_format=memory_format), backward
|
||||
|
||||
def dot(self, tensor):
|
||||
def backward(grad_output):
|
||||
|
@ -451,6 +451,9 @@ void addInputs(Node* n, const char* name, at::Layout value) {
|
||||
void addInputs(Node* n, const char* name, at::ScalarType value) {
|
||||
detail::genericAddInput(n, static_cast<int64_t>(value));
|
||||
}
|
||||
void addInputs(Node* n, const char* name, at::MemoryFormat value) {
|
||||
detail::genericAddInput(n, static_cast<int64_t>(value));
|
||||
}
|
||||
void addInputs(
|
||||
Node* n,
|
||||
const char* name,
|
||||
|
@ -148,6 +148,7 @@ TORCH_API void addInputs(
|
||||
Node* n,
|
||||
const char* name,
|
||||
const c10::optional<at::ScalarType>& value);
|
||||
TORCH_API void addInputs(Node* n, const char* name, at::MemoryFormat value);
|
||||
TORCH_API void addInputs(Node* n, const char* name, at::Generator* value);
|
||||
|
||||
template<typename T>
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include <torch/csrc/Exceptions.h>
|
||||
#include <torch/csrc/Layout.h>
|
||||
#include <torch/csrc/MemoryFormat.h>
|
||||
#include <torch/csrc/utils/invalid_arguments.h>
|
||||
#include <torch/csrc/utils/python_strings.h>
|
||||
|
||||
@ -28,6 +29,7 @@ static std::unordered_map<std::string, ParameterType> type_map = {
|
||||
{"PyObject*", ParameterType::PYOBJECT},
|
||||
{"ScalarType", ParameterType::SCALARTYPE},
|
||||
{"Layout", ParameterType::LAYOUT},
|
||||
{"MemoryFormat", ParameterType::MEMORY_FORMAT},
|
||||
{"Device", ParameterType::DEVICE},
|
||||
{"std::string", ParameterType::STRING},
|
||||
};
|
||||
@ -169,6 +171,7 @@ bool FunctionParameter::check(PyObject* obj) {
|
||||
case ParameterType::PYOBJECT: return true;
|
||||
case ParameterType::SCALARTYPE: return THPDtype_Check(obj);
|
||||
case ParameterType::LAYOUT: return THPLayout_Check(obj);
|
||||
case ParameterType::MEMORY_FORMAT: return THPMemoryFormat_Check(obj);
|
||||
case ParameterType::DEVICE:
|
||||
return THPUtils_checkLong(obj) || THPUtils_checkString(obj) || THPDevice_Check(obj);
|
||||
case ParameterType::STRING: return THPUtils_checkString(obj);
|
||||
@ -190,6 +193,7 @@ std::string FunctionParameter::type_name() const {
|
||||
case ParameterType::PYOBJECT: return "object";
|
||||
case ParameterType::SCALARTYPE: return "torch.dtype";
|
||||
case ParameterType::LAYOUT: return "torch.layout";
|
||||
case ParameterType::MEMORY_FORMAT: return "torch.memory_format";
|
||||
case ParameterType::DEVICE: return "torch.device";
|
||||
case ParameterType::STRING: return "str";
|
||||
default: throw std::runtime_error("unknown parameter type");
|
||||
|
@ -47,6 +47,7 @@
|
||||
#include <torch/csrc/DynamicTypes.h>
|
||||
#include <torch/csrc/Exceptions.h>
|
||||
#include <torch/csrc/Generator.h>
|
||||
#include <torch/csrc/MemoryFormat.h>
|
||||
#include <torch/csrc/autograd/python_variable.h>
|
||||
#include <torch/csrc/jit/tracer.h>
|
||||
#include <torch/csrc/tensor/python_tensor.h>
|
||||
@ -70,7 +71,7 @@ namespace torch {
|
||||
|
||||
enum class ParameterType {
|
||||
TENSOR, SCALAR, INT64, DOUBLE, TENSOR_LIST, INT_LIST, GENERATOR,
|
||||
BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, DEVICE, STRING
|
||||
BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, MEMORY_FORMAT, DEVICE, STRING
|
||||
};
|
||||
|
||||
struct FunctionParameter;
|
||||
@ -134,6 +135,7 @@ struct PythonArgs {
|
||||
inline at::Device device(int i);
|
||||
inline at::Device deviceWithDefault(int i, const at::Device& default_device);
|
||||
inline c10::optional<at::Device> deviceOptional(int i);
|
||||
inline at::MemoryFormat toMemoryFormat(int i);
|
||||
inline std::string string(int i);
|
||||
inline PyObject* pyobject(int i);
|
||||
inline int64_t toInt64(int i);
|
||||
@ -388,6 +390,13 @@ inline c10::optional<at::Device> PythonArgs::deviceOptional(int i) {
|
||||
return device(i);
|
||||
}
|
||||
|
||||
inline at::MemoryFormat PythonArgs::toMemoryFormat(int i) {
|
||||
if (!args[i]) return at::MemoryFormat::Any;
|
||||
AT_CHECK(THPMemoryFormat_Check(args[i]), "memory_format arg must be an instance of the torch.memory_format");
|
||||
const auto memory_format = reinterpret_cast<THPMemoryFormat*>(args[i]);
|
||||
return memory_format->memory_format;
|
||||
}
|
||||
|
||||
inline std::string PythonArgs::string(int i) {
|
||||
if (!args[i]) return "";
|
||||
return THPUtils_unpackString(args[i]);
|
||||
|
39
torch/csrc/utils/tensor_memoryformats.cpp
Normal file
39
torch/csrc/utils/tensor_memoryformats.cpp
Normal file
@ -0,0 +1,39 @@
|
||||
#include <torch/csrc/utils/tensor_memoryformats.h>
|
||||
|
||||
#include <torch/csrc/DynamicTypes.h>
|
||||
#include <torch/csrc/Exceptions.h>
|
||||
#include <torch/csrc/MemoryFormat.h>
|
||||
#include <c10/core/MemoryFormat.h>
|
||||
|
||||
#include <torch/csrc/python_headers.h>
|
||||
#include <torch/csrc/utils/object_ptr.h>
|
||||
|
||||
|
||||
namespace torch {
|
||||
namespace utils {
|
||||
|
||||
#define _ADD_MEMORY_FORMAT(format, name) \
|
||||
{ \
|
||||
std::string module_name = "torch."; \
|
||||
PyObject* memory_format = THPMemoryFormat_New(format, module_name + name); \
|
||||
Py_INCREF(memory_format); \
|
||||
if (PyModule_AddObject(torch_module, name, memory_format) != 0) { \
|
||||
throw python_error(); \
|
||||
} \
|
||||
}
|
||||
|
||||
void initializeMemoryFormats() {
|
||||
auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
|
||||
if (!torch_module) {
|
||||
throw python_error();
|
||||
}
|
||||
|
||||
_ADD_MEMORY_FORMAT(at::MemoryFormat::Any, "any_format");
|
||||
_ADD_MEMORY_FORMAT(at::MemoryFormat::Preserve, "preserve_format");
|
||||
_ADD_MEMORY_FORMAT(at::MemoryFormat::Contiguous, "contiguous_format");
|
||||
_ADD_MEMORY_FORMAT(at::MemoryFormat::ChannelsLast, "channels_last");
|
||||
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
} // namespace torch
|
7
torch/csrc/utils/tensor_memoryformats.h
Normal file
7
torch/csrc/utils/tensor_memoryformats.h
Normal file
@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
namespace torch { namespace utils {
|
||||
|
||||
void initializeMemoryFormats();
|
||||
|
||||
}} // namespace torch::utils
|
@ -22,7 +22,7 @@ import warnings
|
||||
|
||||
# This file exports ONNX ops for opset 9
|
||||
# Opset 9 is supported by ONNX release 1.4.1
|
||||
# release on 01/23/19
|
||||
# release on 01/23/19
|
||||
|
||||
|
||||
# Note [Pointwise by scalar]
|
||||
@ -1392,7 +1392,10 @@ def detach(g, input):
|
||||
return input
|
||||
|
||||
|
||||
def contiguous(g, input):
|
||||
@parse_args('v', 'i')
|
||||
def contiguous(g, input, memory_format):
|
||||
if memory_format > 2: # allower values are any, preserve and contiguous_format
|
||||
raise RuntimeError("onnx memory_format support is not implemented")
|
||||
return input
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user