Fixes #69221. Builds on top of #107000, fixing the buck build issue linked [here](https://github.com/pytorch/pytorch/pull/107000#issuecomment-1708857375).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108832
Approved by: https://github.com/zou3519
114 lines | 3.4 KiB | C++
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS

// ${generated_comment}

#include "torch/csrc/Device.h"
#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/Exceptions.h"
#include "torch/csrc/autograd/python_nn_functions.h"
#include "torch/csrc/autograd/generated/python_return_types.h"
#include "torch/csrc/autograd/python_variable.h"
#include "torch/csrc/autograd/utils/wrap_outputs.h"
#include "torch/csrc/autograd/utils/python_arg_parsing.h"
#include "torch/csrc/utils/pycfunction_helpers.h"
#include "torch/csrc/utils/python_arg_parser.h"
#include "torch/csrc/utils/structseq.h"
#include "torch/csrc/utils/tensor_memoryformats.h"
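// NOTE: when building with per-operator headers (AT_PER_OPERATOR_HEADERS),
// codegen expands $ops_headers to just the ATen operator headers this file
// needs, instead of pulling in the monolithic ATen/Functions.h.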
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
$ops_headers
#endif

using at::Tensor;
using at::Scalar;
using at::MemoryFormat;
using at::Generator;
using at::IntArrayRef;
using at::ArrayRef;

using namespace torch::autograd::utils;

namespace torch { namespace autograd {

static PyObject* THPNNVariableFunctionsModule = NULL;

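// torch._C._nn._parse_to: the argument parser shared with nn.Module.to().
// Accepts the three overloads declared below and returns a
// (device, dtype, non_blocking, memory_format) 4-tuple, substituting Py_None
// for anything left unspecified. The `copy` flag is parsed but rejected
// (allow_copy=false), since nn.Module.to() never copies.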
static PyObject * THPVariable__parse_to(PyObject* module, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "to(Device device=None, ScalarType dtype=None, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)",
    "to(ScalarType dtype, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)",
    "to(Tensor tensor, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)",
  });
  ParsedArgs<5> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.has_torch_function()) {
    return handle_torch_function(r, args, kwargs, THPNNVariableFunctionsModule, "torch.nn", "_parse_to");
  }
  auto parsed = parse_to_conversion(r, /*allow_copy*/ false); // we don't want copy for nn.Module.to
  auto& device = std::get<0>(parsed);
  auto& scalarType = std::get<1>(parsed);
  auto non_blocking = std::get<2>(parsed);
  auto opt_memory_format = std::get<4>(parsed);
  auto tuple = THPObjectPtr{PyTuple_New(4)};
  if (!tuple) throw python_error();
  if (device) {
    PyTuple_SET_ITEM(tuple.get(), 0, THPDevice_New(*device));
  } else {
    Py_INCREF(Py_None);
    PyTuple_SET_ITEM(tuple.get(), 0, Py_None);
  }
  if (scalarType) {
    PyTuple_SET_ITEM(tuple.get(), 1, torch::autograd::utils::wrap(torch::getTHPDtype(*scalarType)));
  } else {
    Py_INCREF(Py_None);
    PyTuple_SET_ITEM(tuple.get(), 1, Py_None);
  }
  PyTuple_SET_ITEM(tuple.get(), 2, torch::autograd::utils::wrap(non_blocking));
  if (opt_memory_format.has_value()) {
    PyTuple_SET_ITEM(tuple.get(), 3, torch::utils::getTHPMemoryFormat(opt_memory_format.value()));
  } else {
    Py_INCREF(Py_None);
    PyTuple_SET_ITEM(tuple.get(), 3, Py_None);
  }
  return tuple.release();
  END_HANDLE_TH_ERRORS
}

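// Illustrative only (not part of the template): from Python the binding
// behaves roughly like
//   >>> torch._C._nn._parse_to(dtype=torch.float64)
//   (None, torch.float64, False, None)
// with unspecified slots filled by None.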
// generated forward declarations start here

${py_forwards}

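// Python method table for torch._C._nn. ${py_method_defs} expands to the
// generated nn function bindings; the table is NULL-terminated.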
static PyMethodDef nn_functions[] = {
  {"_parse_to", castPyCFunctionWithKeywords(THPVariable__parse_to),
   METH_VARARGS | METH_KEYWORDS, nullptr},
  ${py_method_defs}
  {NULL}
};

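// Creates the torch._C._nn submodule, attaches it to the given module
// (torch._C), and stashes a borrowed pointer in THPNNVariableFunctionsModule
// for use by handle_torch_function above.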
void initNNFunctions(PyObject* module) {
  static struct PyModuleDef def = {
     PyModuleDef_HEAD_INIT,
     "torch._C._nn",
     NULL,
     -1,
     nn_functions
  };
  PyObject* nn = PyModule_Create(&def);
  THPNNVariableFunctionsModule = nn;
  if (!nn) {
    throw python_error();
  }
  // steals a reference to nn
  if (PyModule_AddObject(module, "_nn", nn) != 0) {
    throw python_error();
  }
}

// generated methods start here

${py_methods}

}} // namespace torch::autograd