Add non_blocking to Tensor/Module.to (#7312)

* Add non_blocking to Tensor/Module.to

* flake8

* Add argparse tests

* cpp parse

* Use C++ parser

* use a common parse function with Tensor.to

* fix test_jit

* use THPObjectPtr

* increase refcount for None, True, and False

* address comments

* address comments
Tongzhou Wang
2018-06-04 18:46:52 -04:00
committed by GitHub
parent ec4a0f332e
commit c0a419e6ba
14 changed files with 177 additions and 111 deletions
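
For context, the user-visible change: after this commit, Tensor.to and Module.to accept a non_blocking flag (defaulting to False) alongside device and dtype. A minimal usage sketch, assuming a CUDA-enabled build (the tensor and module names are made up for illustration):

    import torch

    device = torch.device('cuda')

    t = torch.randn(4)
    t = t.to(device, non_blocking=True)                  # device only
    t = t.to(device, torch.float64, non_blocking=True)   # device + dtype
    t = t.to(torch.float64)                              # dtype only; non_blocking defaults to False

    model = torch.nn.Linear(4, 2)
    model.to(device, non_blocking=True)                  # Module.to accepts the same flag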

@@ -8,6 +8,7 @@
 #include "torch/csrc/autograd/python_variable.h"
 #include "torch/csrc/autograd/utils/python_error_messages.h"
 #include "torch/csrc/autograd/utils/wrap_outputs.h"
+#include "torch/csrc/autograd/utils/python_arg_parsing.h"
 #include "torch/csrc/jit/tracer.h"
 #ifdef WITH_CUDA
 #include "torch/csrc/cuda/Stream.h"
@@ -558,31 +559,22 @@ static PyObject * THPVariable_storage_type(PyObject* self, PyObject* arg)
 static PyObject * THPVariable_to(PyObject* self, PyObject* args, PyObject* kwargs)
 {
   HANDLE_TH_ERRORS
-  static PythonArgParser parser({
-    "to(Device device, ScalarType dtype=None)",
-    "to(ScalarType dtype)",
-    "to(Tensor other)",
-  });
-  auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
-  ParsedArgs<2> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  if (r.idx == 0) {
-    auto device = r.device(0);
-    auto deviceAutoGPU = device.deviceInt64();
-    auto scalarType = r.scalartypeWithDefault(1, self_.type().scalarType());
-    auto& layout = *torch::getLayout(self_.type().backend());
-    auto& type = torch::getType(scalarType, layout, device.type);
-    return THPVariable_Wrap(torch::utils::dispatch_type_conversion(self_, type, deviceAutoGPU, false));
-  } else if (r.idx == 1) {
-    auto scalarType = r.scalartype(0);
-    auto& type = self_.type().toScalarType(scalarType);
+  auto parsed = parse_to_conversion(args, kwargs);
+  auto& device = std::get<0>(parsed);
+  auto& scalarType = std::get<1>(parsed);
+  auto non_blocking = std::get<2>(parsed);
+  if (!device) {
+    // device not given
+    auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
+    auto& type = self_.type().toScalarType(scalarType.value_or(self_.type().scalarType()));
     return THPVariable_Wrap(torch::utils::dispatch_type_conversion(self_, type));
-  } else if (r.idx == 2) {
-    auto other = r.tensor(0);
-    auto& type = other.type();
-    auto deviceType = torch::getDeviceType(type);
-    auto deviceAutoGPU = (deviceType == DeviceType::CPU) ? -1 : other.get_device();
-    return THPVariable_Wrap(torch::utils::dispatch_type_conversion(self_, type, deviceAutoGPU, false));
+  } else {
+    // device and maybe dtype are given
+    auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
+    auto deviceAutoGPU = device->deviceInt64();
+    auto& layout = *torch::getLayout(self_.type().backend());
+    auto& type = torch::getType(scalarType.value_or(self_.type().scalarType()), layout, device->type);
+    return THPVariable_Wrap(torch::utils::dispatch_type_conversion(self_, type, deviceAutoGPU, non_blocking));
   }
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
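
One caveat worth attaching to the binding above: non_blocking=True only makes the host-to-device copy truly asynchronous when the source tensor sits in pinned (page-locked) host memory; otherwise the copy synchronizes as usual. A hedged sketch of the intended pattern (variable names are illustrative):

    import torch

    # Pinned host memory lets the transfer overlap with CPU work.
    batch = torch.empty(64, 3, 224, 224).pin_memory()
    gpu_batch = batch.to('cuda', non_blocking=True)  # may return before the copy finishes
    # ... unrelated CPU work can proceed while the transfer is in flight ...
    torch.cuda.synchronize()  # ensure the copy is done before reading results on the host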