Restore copy_ overload with async arg (#19641)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19641
ghimport-source-id: 7099221334505bacdc209cff8bf29e3004c30379

Differential Revision: D15056755

Pulled By: li-roy

fbshipit-source-id: e9063b606e72a70fc1270fbcdcf1c0b23d876dd3
Roy Li
2019-04-24 17:48:18 -07:00
committed by Facebook Github Bot
parent c08f3d06c3
commit a6811e17c0
2 changed files with 22 additions and 1 deletion
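A minimal, hypothetical usage sketch of the restored overload (not part of this commit; it assumes a build containing this change and runs on CPU, where non_blocking has no effect): the second, deprecated signature in the diff below keeps accepting the old async keyword name in place of non_blocking. Because async became a reserved word in Python 3.7, the deprecated spelling has to go through a kwargs dict on newer interpreters.

import torch

src = torch.randn(4)
dst = torch.empty(4)

# Current spelling of the overload.
dst.copy_(src, non_blocking=True)

# Deprecated spelling restored by this change. `async` is a reserved keyword
# in Python 3.7+, so it is passed via a kwargs dict here; on Python <= 3.6
# `dst.copy_(src, async=True)` parses as well.
dst.copy_(src, **{"async": True})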

@@ -172,6 +172,26 @@ static Tensor dispatch_contiguous(const Tensor & self) {
  END_HANDLE_TH_ERRORS
}
static Tensor dispatch_copy_(Tensor & self, const Tensor & other, bool non_blocking) {
  AutoNoGIL no_gil;
  OptionalDeviceGuard device_guard(device_of(self));
  return self.copy_(other, non_blocking);
}
static PyObject * THPVariable_copy_(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "copy_(Tensor other, bool non_blocking=False)",
    "copy_(Tensor other, bool async=False)|deprecated"
  });
  auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
  ParsedArgs<2> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  return THPVariable_Wrap(dispatch_copy_(self_, r.tensor(0), r.toBool(1)));
  END_HANDLE_TH_ERRORS
}
static double dispatch_to_CDouble(const Tensor & self) {
  AutoNoGIL no_gil;
  OptionalDeviceGuard device_guard(device_of(self));
@@ -667,6 +687,7 @@ PyMethodDef variable_methods[] = {
  {"byte", (PyCFunction)THPVariable_byte, METH_NOARGS, NULL},
  {"char", (PyCFunction)THPVariable_char, METH_NOARGS, NULL},
  {"contiguous", (PyCFunction)THPVariable_contiguous, METH_NOARGS, NULL},
  {"copy_", (PyCFunction)THPVariable_copy_, METH_VARARGS | METH_KEYWORDS, NULL},
  {"cpu", (PyCFunction)THPVariable_cpu, METH_NOARGS, NULL},
  {"cuda", (PyCFunction)THPVariable_cuda, METH_VARARGS | METH_KEYWORDS, NULL},
  {"dim", (PyCFunction)THPVariable_dim, METH_NOARGS, NULL},