Adding memory_format to empty and empty_like operators (#20558)

Summary:
Original RFC https://github.com/pytorch/pytorch/issues/19092

To ensure that we are not introducing BC breaking change, empty_like returns contiguous tensor by default.

```python
nCwh = torch.randn(N, C, H, W)
nhwC = nCwh.contiguous(memory_format=torch.channels_last)

new_nCwh = torch.empty_like(nhwC)
new_nCwh.is_contiguous(memory_format=torch.channels_last) == False
```

Now we need a way to preserve memory format in `empty_like`

```python
nCwh = torch.randn(N, C, H, W)
nhwC = nCwh.contiguous(memory_format=torch.channels_last)

new_nhwC = torch.empty_like(nhwC, memory_format=torch.preserve_format)
new_nhwC.is_contiguous(memory_format=torch.channels_last) == True

like_nCwh = torch.empty_like(nCwh, memory_format=torch.preserve_format)
like_nCwh.is_contiguous(memory_format=torch.channels_last) == False
```

Usage of `torch.preserve_format` allows us to avoid `if` constructs.

We can also generate different memory format outputs

```python
nCwh = torch.randn(N, C, H, W)
nhwC = nCwh.contiguous(memory_format=torch.channels_last)

new_nhwC = torch.empty_like(nCwh, memory_format=torch.channels_last)
new_nhwC.is_contiguous(memory_format=torch.channels_last) == True

new_nCwh = torch.empty_like(nhwC, memory_format=torch.contiguous_format)
new_nCwh.is_contiguous(memory_format=torch.channels_last) == False
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20558

Differential Revision: D15502474

Pulled By: VitalyFedyunin

fbshipit-source-id: 2e120d57eefad6fb8e04b8322c79871392f64331
This commit is contained in:
Vitaly Fedyunin
2019-06-26 11:40:31 -07:00
committed by Facebook Github Bot
parent 5bdc4db26e
commit 516c7e4456
31 changed files with 228 additions and 99 deletions

View File

@ -166,7 +166,7 @@ static PyObject * THPVariable_contiguous(PyObject* self, PyObject* args, PyObjec
ParsedArgs<1> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
auto memory_format = r.toMemoryFormat(0);
auto memory_format = r.memoryformat(0);
// avoids touching the GIL or current device if self is already contiguous
if (self_.is_contiguous(memory_format)) {
// NOTE: this logic is duplicated from VariableType.cpp. Since we need to
@ -485,7 +485,7 @@ static PyObject * THPVariable_is_contiguous(PyObject* self_, PyObject* args, PyO
});
ParsedArgs<1> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
auto memory_format = r.toMemoryFormat(0);
auto memory_format = r.memoryformat(0);
auto& self = reinterpret_cast<THPVariable*>(self_)->cdata;
return wrap(dispatch_is_contiguous(self, memory_format));
END_HANDLE_TH_ERRORS