mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136949 Approved by: https://github.com/albanD, https://github.com/eqy ghstack dependencies: #136945
165 lines
4.9 KiB
C++
165 lines
4.9 KiB
C++
#include <torch/csrc/autograd/python_legacy_variable.h>
|
|
|
|
#include <ATen/ATen.h>
|
|
|
|
#include <torch/csrc/Exceptions.h>
|
|
#include <torch/csrc/autograd/python_function.h>
|
|
#include <torch/csrc/autograd/python_variable.h>
|
|
#include <torch/csrc/jit/frontend/tracer.h>
|
|
#include <torch/csrc/tensor/python_tensor.h>
|
|
|
|
using namespace at;
|
|
|
|
namespace torch::autograd {
|
|
|
|
static PyObject* THPVariable_pynew(
|
|
PyTypeObject* type,
|
|
PyObject* args,
|
|
PyObject* kwds) {
|
|
HANDLE_TH_ERRORS
|
|
THPObjectPtr _data;
|
|
PyObject* data = nullptr;
|
|
PyObject* grad_fn = nullptr;
|
|
char is_volatile = 0;
|
|
char requires_grad = 0;
|
|
const char* name = nullptr;
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
|
constexpr const char* accepted_args[] = {
|
|
"data", "requires_grad", "volatile", "_grad_fn", "name", nullptr};
|
|
if (!PyArg_ParseTupleAndKeywords(
|
|
args,
|
|
kwds,
|
|
"|ObbOz",
|
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
|
const_cast<char**>(accepted_args),
|
|
&data,
|
|
&requires_grad,
|
|
&is_volatile,
|
|
&grad_fn,
|
|
&name))
|
|
return nullptr;
|
|
|
|
if (grad_fn == Py_None)
|
|
grad_fn = nullptr;
|
|
|
|
if (is_volatile) {
|
|
auto r = PyErr_WarnEx(
|
|
PyExc_UserWarning,
|
|
"volatile was removed and now has no effect. Use `with torch.no_grad():` "
|
|
"instead.",
|
|
1);
|
|
if (r != 0)
|
|
throw python_error();
|
|
}
|
|
|
|
TORCH_CHECK_VALUE(
|
|
!is_volatile || !requires_grad,
|
|
"Variable can't be volatile and require_grad at the same time!");
|
|
if (grad_fn && !THPFunction_Check(grad_fn)) {
|
|
throw TypeError(
|
|
"_grad_fn has to be a Function object or None, but got %s",
|
|
Py_TYPE(grad_fn)->tp_name);
|
|
}
|
|
Variable var;
|
|
if (!data || data == Py_None) {
|
|
// For legacy serialization code, create an empty tensor. This is also used
|
|
// by nn.Parameter() with no arguments.
|
|
auto dispatch_key = torch::tensors::get_default_dispatch_key();
|
|
auto scalar_type = torch::tensors::get_default_scalar_type();
|
|
auto options = TensorOptions(scalar_type)
|
|
.device(dispatchKeyToDeviceType(dispatch_key))
|
|
.layout(dispatchKeyToLayout(dispatch_key));
|
|
var = at::empty({0}, options);
|
|
} else if (THPVariable_Check(data)) {
|
|
var = THPVariable_Unpack(data).detach();
|
|
} else {
|
|
throw torch::TypeError(
|
|
"Variable data has to be a tensor, but got %s", Py_TYPE(data)->tp_name);
|
|
}
|
|
// We set `tensor`'s `allow_tensor_metadata_change` to true here, because we
|
|
// want to allow the following use case for backward compatibility:
|
|
//
|
|
// ```python
|
|
// var = Variable(torch.randn(2, 3))
|
|
// var.resize_(4, 5)
|
|
// ```
|
|
var.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
|
|
|
|
TORCH_CHECK(
|
|
!grad_fn,
|
|
"_grad_fn argument to legacy Variable constructor is no longer supported. "
|
|
"Instead, please invoke your _grad_fn to produce a variable with it as the "
|
|
"_grad_fn.");
|
|
var.set_requires_grad(requires_grad);
|
|
|
|
if (name) {
|
|
impl::set_name(var, name);
|
|
}
|
|
|
|
if (jit::tracer::isTracing() && data && data != Py_None &&
|
|
THPVariable_Check(data)) {
|
|
if (auto* v = jit::tracer::getValueTrace(THPVariable_Unpack(data))) {
|
|
jit::tracer::setValueTrace(var, v);
|
|
}
|
|
}
|
|
|
|
return THPVariable_Wrap(std::move(var));
|
|
END_HANDLE_TH_ERRORS
|
|
}
|
|
|
|
PyTypeObject THPLegacyVariableType = {
|
|
PyVarObject_HEAD_INIT(nullptr, 0)
|
|
"torch._C._LegacyVariableBase", /* tp_name */
|
|
0, /* tp_basicsize */
|
|
0, /* tp_itemsize */
|
|
nullptr, /* tp_dealloc */
|
|
0, /* tp_vectorcall_offset */
|
|
nullptr, /* tp_getattr */
|
|
nullptr, /* tp_setattr */
|
|
nullptr, /* tp_reserved */
|
|
nullptr, /* tp_repr */
|
|
nullptr, /* tp_as_number */
|
|
nullptr, /* tp_as_sequence */
|
|
nullptr, /* tp_as_mapping */
|
|
nullptr, /* tp_hash */
|
|
nullptr, /* tp_call */
|
|
nullptr, /* tp_str */
|
|
nullptr, /* tp_getattro */
|
|
nullptr, /* tp_setattro */
|
|
nullptr, /* tp_as_buffer */
|
|
// NOLINTNEXTLINE(misc-redundant-expression)
|
|
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
|
|
nullptr, /* tp_doc */
|
|
nullptr, /* tp_traverse */
|
|
nullptr, /* tp_clear */
|
|
nullptr, /* tp_richcompare */
|
|
0, /* tp_weaklistoffset */
|
|
nullptr, /* tp_iter */
|
|
nullptr, /* tp_iternext */
|
|
nullptr, /* tp_methods */
|
|
nullptr, /* tp_members */
|
|
nullptr, /* tp_getset */
|
|
nullptr, /* tp_base */
|
|
nullptr, /* tp_dict */
|
|
nullptr, /* tp_descr_get */
|
|
nullptr, /* tp_descr_set */
|
|
0, /* tp_dictoffset */
|
|
nullptr, /* tp_init */
|
|
nullptr, /* tp_alloc */
|
|
THPVariable_pynew /* tp_new */
|
|
};
|
|
|
|
void init_legacy_variable(PyObject* module) {
|
|
if (PyType_Ready(&THPLegacyVariableType) < 0) {
|
|
throw python_error();
|
|
}
|
|
auto obj = (PyObject*)&THPLegacyVariableType;
|
|
Py_INCREF(obj);
|
|
if (PyModule_AddObject(module, "_LegacyVariableBase", obj) < 0) {
|
|
throw python_error();
|
|
}
|
|
}
|
|
|
|
} // namespace torch::autograd
|