mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Adding memory_format to empty and empty_like operators (#20558)
Summary: Original RFC https://github.com/pytorch/pytorch/issues/19092 To ensure that we are not introducing a BC-breaking change, empty_like returns a contiguous tensor by default. ```python nCwh = torch.randn(N, C, H, W) nhwC = nCwh.contiguous(memory_format=torch.channels_last) new_nCwh = torch.empty_like(nhwC) new_nCwh.is_contiguous(memory_format=torch.channels_last) == False ``` Now we need a way to preserve memory format in `empty_like` ```python nCwh = torch.randn(N, C, H, W) nhwC = nCwh.contiguous(memory_format=torch.channels_last) new_nhwC = torch.empty_like(nhwC, memory_format=torch.preserve_format) new_nhwC.is_contiguous(memory_format=torch.channels_last) == True like_nCwh = torch.empty_like(nCwh, memory_format=torch.preserve_format) like_nCwh.is_contiguous(memory_format=torch.channels_last) == False ``` Usage of `torch.preserve_format` allows us to avoid `if` constructs. We can also generate outputs with different memory formats ```python nCwh = torch.randn(N, C, H, W) nhwC = nCwh.contiguous(memory_format=torch.channels_last) new_nhwC = torch.empty_like(nCwh, memory_format=torch.channels_last) new_nhwC.is_contiguous(memory_format=torch.channels_last) == True new_nCwh = torch.empty_like(nhwC, memory_format=torch.contiguous_format) new_nCwh.is_contiguous(memory_format=torch.channels_last) == False ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/20558 Differential Revision: D15502474 Pulled By: VitalyFedyunin fbshipit-source-id: 2e120d57eefad6fb8e04b8322c79871392f64331
This commit is contained in:
committed by
Facebook Github Bot
parent
5bdc4db26e
commit
516c7e4456
@@ -166,7 +166,7 @@ static PyObject * THPVariable_contiguous(PyObject* self, PyObject* args, PyObject* kwargs)
   ParsedArgs<1> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
-  auto memory_format = r.toMemoryFormat(0);
+  auto memory_format = r.memoryformat(0);
   // avoids touching the GIL or current device if self is already contiguous
   if (self_.is_contiguous(memory_format)) {
     // NOTE: this logic is duplicated from VariableType.cpp. Since we need to
@@ -485,7 +485,7 @@ static PyObject * THPVariable_is_contiguous(PyObject* self_, PyObject* args, PyObject* kwargs)
   });
   ParsedArgs<1> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
-  auto memory_format = r.toMemoryFormat(0);
+  auto memory_format = r.memoryformat(0);
   auto& self = reinterpret_cast<THPVariable*>(self_)->cdata;
   return wrap(dispatch_is_contiguous(self, memory_format));
   END_HANDLE_TH_ERRORS
|
Reference in New Issue
Block a user