Add memory_format support to `to` and `type` operators (#27107)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27107

Adds a memory_format keyword argument (positional for cpp). The 'Preserve' behavior now follows these rules:

1) If the tensor is non-overlapping and dense, the output tensor has the same strides as the input tensor.
2) If not (1) and the tensor is stored in the channels-last format, the output tensor is in the channels-last format.
3) In all other cases, the output tensor is contiguous.

---

A dense tensor is a tensor that stores its values in a contiguous block of memory. A non-overlapping tensor is a tensor in which each element occupies its own distinct memory location.

Test Plan: Imported from OSS

Differential Revision: D17931062

Pulled By: VitalyFedyunin

fbshipit-source-id: 2c5dd3dd05bf58a9a29f25562cd45190b009c3f9
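To make the three 'Preserve' rules concrete, here is a small Python sketch of the intended behavior (illustrative only, assuming a build that includes this change; exact layouts may vary across versions):

    import torch

    base = torch.randn(2, 3, 4, 5)

    # Rule 1: non-overlapping and dense -- the output keeps the input's
    # strides, even for a permuted (dense but non-contiguous) tensor.
    x = base.permute(0, 2, 3, 1)
    y = x.to(torch.float64, memory_format=torch.preserve_format)
    assert y.stride() == x.stride()

    # Rule 2: channels-last inputs come out channels-last.
    cl = base.contiguous(memory_format=torch.channels_last)
    z = cl.to(torch.float64, memory_format=torch.preserve_format)
    assert z.is_contiguous(memory_format=torch.channels_last)

    # Rule 3: anything else (here, a non-dense strided slice) comes out contiguous.
    s = base[:, :, ::2, :]
    w = s.to(torch.float64, memory_format=torch.preserve_format)
    assert w.is_contiguous()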
committed by Facebook Github Bot
parent cbe5ab1109
commit d39ab0312a
@@ -336,31 +336,32 @@ static PyObject * THPVariable_invert(PyObject* self, PyObject* args) {
   END_HANDLE_TH_ERRORS
 }
 
-static Tensor dispatch_to(const Tensor & self, Device device, bool non_blocking, bool copy) {
+static Tensor dispatch_to(const Tensor & self, Device device, bool non_blocking, bool copy, c10::optional<c10::MemoryFormat> optional_memory_format) {
   AutoNoGIL no_gil;
   // NOTE: this is where we record aten::to in the graph during tracing. However, the behavior of aten::to
   // is different with respect to TensorOptions fields that are not present: aten::to inherits fields that
   // are missing from the self argument while the tracer assumes that they should be populated with the
   // default values (eg. float for scalar type). By explicitly copying over the tensor options here we fully
   // specify all tensor options and thus record the proper trace
-  return self.to(self.options().device(device), non_blocking, copy);
+  return self.to(self.options().device(device), non_blocking, copy, optional_memory_format);
 }
 
-static Tensor dispatch_to(const Tensor & self, ScalarType dtype, bool non_blocking, bool copy) {
+static Tensor dispatch_to(const Tensor & self, ScalarType dtype, bool non_blocking, bool copy, c10::optional<c10::MemoryFormat> optional_memory_format) {
   AutoNoGIL no_gil;
-  return self.to(dtype, non_blocking, copy);
+  return self.to(dtype, non_blocking, copy, optional_memory_format);
 }
 
-static Tensor dispatch_to(const Tensor & self, Device device, ScalarType dtype, bool non_blocking, bool copy) {
+static Tensor dispatch_to(const Tensor & self, Device device, ScalarType dtype, bool non_blocking, bool copy, c10::optional<c10::MemoryFormat> optional_memory_format) {
   AutoNoGIL no_gil;
-  return self.to(device, dtype, non_blocking, copy);
+  return self.to(device, dtype, non_blocking, copy, optional_memory_format);
 }
 
 static PyObject * THPVariable_cpu(PyObject* self, PyObject* args)
 {
   HANDLE_TH_ERRORS
   auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
-  return THPVariable_Wrap(dispatch_to(self_, at::Device(at::DeviceType::CPU), false, false));
+  // Setting to MemoryFormat::Contiguous now, will change to accept memory_format in next PR
+  return THPVariable_Wrap(dispatch_to(self_, at::Device(at::DeviceType::CPU), false, false, MemoryFormat::Contiguous));
   END_HANDLE_TH_ERRORS
 }
 
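The three dispatch_to overloads above back the dtype-only, device-only, and device-plus-dtype forms of Tensor.to, each now threading the optional memory format through, while .cpu() pins MemoryFormat::Contiguous until the follow-up PR. A sketch of the corresponding Python call shapes (assuming a build with this change):

    import torch

    x = torch.randn(2, 3, 4, 5).contiguous(memory_format=torch.channels_last)

    # Each form of .to() accepts the new keyword and lands in the matching
    # dispatch_to overload.
    a = x.to(torch.float64, memory_format=torch.preserve_format)         # dtype only
    b = x.to("cpu", memory_format=torch.preserve_format)                 # device only
    c = x.to("cpu", torch.float64, memory_format=torch.preserve_format)  # device + dtype

    # .cpu() takes no memory_format at this point in the stack and
    # hard-codes contiguous output, per the "next PR" comment in the diff.
    d = x.cpu()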
@@ -407,14 +408,16 @@ static PyObject * THPVariable_cuda(PyObject* self, PyObject* args, PyObject* kwargs)
   auto device = r.isNone(0) ? at::Device(at::DeviceType::CUDA) : r.device(0);
   TORCH_CHECK(device.is_cuda(), "Invalid device, must be cuda device");
   torch::utils::cuda_lazy_init();
-  return THPVariable_Wrap(dispatch_to(self_, device, r.toBool(1), false));
+  // Setting to MemoryFormat::Contiguous now, will change to accept memory_format in next PR
+  return THPVariable_Wrap(dispatch_to(self_, device, r.toBool(1), false, MemoryFormat::Contiguous));
   END_HANDLE_TH_ERRORS
 }
 
 static PyObject * THPVariable_to_type(PyObject* self, ScalarType scalarType) {
   HANDLE_TH_ERRORS
   auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
-  return THPVariable_Wrap(dispatch_to(self_, scalarType, false, false));
+  // Setting to MemoryFormat::Contiguous now, will change to accept memory_format in next PR
+  return THPVariable_Wrap(dispatch_to(self_, scalarType, false, false, MemoryFormat::Contiguous));
   END_HANDLE_TH_ERRORS
 }
 static PyObject * THPVariable_byte(PyObject* self, PyObject* args) {
@@ -639,6 +642,7 @@ static PyObject * THPVariable_to(PyObject* self, PyObject* args, PyObject* kwargs)
   auto& scalarType = std::get<1>(parsed);
   auto non_blocking = std::get<2>(parsed);
   auto copy = std::get<3>(parsed);
+  auto opt_memory_format = std::get<4>(parsed);
   auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
   if (device && device->is_cuda()) {
     torch::utils::cuda_lazy_init();
@@ -647,11 +651,11 @@ static PyObject * THPVariable_to(PyObject* self, PyObject* args, PyObject* kwargs)
     Py_INCREF(self);
     return self;
   } else if (!device) {
-    return THPVariable_Wrap(dispatch_to(self_, *scalarType, non_blocking, copy));
+    return THPVariable_Wrap(dispatch_to(self_, *scalarType, non_blocking, copy, opt_memory_format));
   } else if (!scalarType) {
-    return THPVariable_Wrap(dispatch_to(self_, *device, non_blocking, copy));
+    return THPVariable_Wrap(dispatch_to(self_, *device, non_blocking, copy, opt_memory_format));
   } else {
-    return THPVariable_Wrap(dispatch_to(self_, *device, *scalarType, non_blocking, copy));
+    return THPVariable_Wrap(dispatch_to(self_, *device, *scalarType, non_blocking, copy, opt_memory_format));
   }
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
@@ -670,16 +674,17 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwargs)
 {
   HANDLE_TH_ERRORS
   static PythonArgParser parser({
-    "type(PyObject* dtype=None, bool non_blocking=False)",
-    "type(PyObject* dtype=None, bool async=False)|deprecated"
+    "type(PyObject* dtype=None, bool non_blocking=False, *, MemoryFormat? memory_format=None)",
+    "type(PyObject* dtype=None, bool async=False, *, MemoryFormat? memory_format=None)|deprecated"
   });
   auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
-  ParsedArgs<2> parsed_args;
+  ParsedArgs<3> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.isNone(0)) {
     return THPUtils_packString(torch::utils::type_to_string(self_.type()));
   }
   auto obj = r.pyobject(0);
+  auto opt_memory_format = r.memoryformatOptional(2);
   std::string type_name;
   bool is_dtype = false;
   if (PyType_Check(obj)) {
@@ -710,7 +715,7 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwargs)
   if (device.is_cuda()) {
     torch::utils::cuda_lazy_init();
   }
-  return THPVariable_Wrap(dispatch_to(self_, device, scalar_type, /*non_blocking=*/ r.toBool(1), /*copy=*/ false));
+  return THPVariable_Wrap(dispatch_to(self_, device, scalar_type, /*non_blocking=*/ r.toBool(1), /*copy=*/ false, opt_memory_format));
   END_HANDLE_TH_ERRORS
 }
 
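With the parser signatures above, memory_format also becomes a keyword-only argument of Tensor.type(). A usage sketch (assuming a build with this change; the preserved layout follows the 'Preserve' rules from the summary):

    import torch

    x = torch.randn(2, 3, 4, 5).contiguous(memory_format=torch.channels_last)

    # memory_format is keyword-only, per the "*, MemoryFormat? memory_format=None"
    # signature added to the parser.
    y = x.type(torch.DoubleTensor, memory_format=torch.preserve_format)
    assert y.is_contiguous(memory_format=torch.channels_last)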