mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-05 08:24:57 +08:00
Add non_blocking to Tensor/Module.to (#7312)
* Add non_blocking to Tensor/Module.to
* flake8
* Add argparse tests
* cpp parse
* Use C++ parser
* use a common parse function with Tensor.to
* fix test_jit
* use THPObjectPtr
* increase refcount for None, True, and False
* address comments
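For context, a minimal sketch of how the new argument is used from Python (a sketch assuming the post-commit signatures; non_blocking defaults to False, and the copy only actually overlaps with host work for CPU-to-GPU transfers out of pinned memory):

    import torch

    if torch.cuda.is_available():
        # Pinned (page-locked) source memory is required for a truly
        # asynchronous host-to-device copy; otherwise the transfer is
        # effectively synchronous.
        src = torch.randn(1024, 1024).pin_memory()
        dst = src.to('cuda', non_blocking=True)  # enqueues the copy, returns immediately
        torch.cuda.synchronize()                 # wait before relying on dst's contents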
@@ -8,6 +8,7 @@
 #include "torch/csrc/autograd/python_variable.h"
 #include "torch/csrc/autograd/utils/python_error_messages.h"
 #include "torch/csrc/autograd/utils/wrap_outputs.h"
+#include "torch/csrc/autograd/utils/python_arg_parsing.h"
 #include "torch/csrc/jit/tracer.h"
 #ifdef WITH_CUDA
 #include "torch/csrc/cuda/Stream.h"
@@ -558,31 +559,22 @@
 static PyObject * THPVariable_to(PyObject* self, PyObject* args, PyObject* kwargs)
 {
   HANDLE_TH_ERRORS
-  static PythonArgParser parser({
-    "to(Device device, ScalarType dtype=None)",
-    "to(ScalarType dtype)",
-    "to(Tensor other)",
-  });
-  auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
-  ParsedArgs<2> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  if (r.idx == 0) {
-    auto device = r.device(0);
-    auto deviceAutoGPU = device.deviceInt64();
-    auto scalarType = r.scalartypeWithDefault(1, self_.type().scalarType());
-    auto& layout = *torch::getLayout(self_.type().backend());
-    auto& type = torch::getType(scalarType, layout, device.type);
-    return THPVariable_Wrap(torch::utils::dispatch_type_conversion(self_, type, deviceAutoGPU, false));
-  } else if (r.idx == 1) {
-    auto scalarType = r.scalartype(0);
-    auto& type = self_.type().toScalarType(scalarType);
+  auto parsed = parse_to_conversion(args, kwargs);
+  auto& device = std::get<0>(parsed);
+  auto& scalarType = std::get<1>(parsed);
+  auto non_blocking = std::get<2>(parsed);
+  if (!device) {
+    // device not given
+    auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
+    auto& type = self_.type().toScalarType(scalarType.value_or(self_.type().scalarType()));
     return THPVariable_Wrap(torch::utils::dispatch_type_conversion(self_, type));
-  } else if (r.idx == 2) {
-    auto other = r.tensor(0);
-    auto& type = other.type();
-    auto deviceType = torch::getDeviceType(type);
-    auto deviceAutoGPU = (deviceType == DeviceType::CPU) ? -1 : other.get_device();
-    return THPVariable_Wrap(torch::utils::dispatch_type_conversion(self_, type, deviceAutoGPU, false));
+  } else {
+    // device and maybe dtype are given
+    auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
+    auto deviceAutoGPU = device->deviceInt64();
+    auto& layout = *torch::getLayout(self_.type().backend());
+    auto& type = torch::getType(scalarType.value_or(self_.type().scalarType()), layout, device->type);
+    return THPVariable_Wrap(torch::utils::dispatch_type_conversion(self_, type, deviceAutoGPU, non_blocking));
   }
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
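The restructure replaces the three-way overload dispatch with the shared parse_to_conversion helper (per the commit message, the same parse function now backs Tensor.to), leaving two branches keyed on whether a device was given. A hedged Python-level sketch of the call patterns each branch serves (exact 0.4-era signatures may differ):

    import torch

    x = torch.randn(2, 3)

    # dtype only -> the "device not given" branch: keep x's device, change dtype
    y = x.to(torch.float64)

    # device (and optionally dtype) -> the "device and maybe dtype" branch,
    # where the new non_blocking flag is forwarded to the conversion
    if torch.cuda.is_available():
        z = x.to('cuda', torch.half, non_blocking=True)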