Replace all direct cdata access with THPVariable_Unpack (#55799)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55799

I'm going to change the implementation of cdata soon so I need to
abstract over cdata access with a function.  Additionally, many
users are manually casting to THPVariable to access
the member so I can remove these unsafe casts in the client code
(the implementation, of course, is still doing an unsafe cast).

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D27712130

Pulled By: ezyang

fbshipit-source-id: 95fcc013bf3913d67f2c634068eb5b3aab144cb3
This commit is contained in:
Edward Yang
2021-04-15 08:48:00 -07:00
committed by Facebook GitHub Bot
parent 61418aa069
commit 6ec71ed4f9
24 changed files with 183 additions and 181 deletions

View File

@ -39,12 +39,6 @@ py::object cast_handle_sequence(std::vector<py::handle> objs) {
}
void flatten_rec(PyObject* obj, ParsedArgs& args) {
auto as_variable = [](at::Tensor& tensor) // Wrap tensor as Variable
{
PyObject* wappred_obj = THPVariable_Wrap(tensor);
return reinterpret_cast<THPVariable*>(wappred_obj)->cdata;
};
auto& structure = args.desc.structure;
if (six::isTuple(obj)) {
structure.push_back(D::TupleOpen);
@ -68,28 +62,25 @@ void flatten_rec(PyObject* obj, ParsedArgs& args) {
args.desc.strings.emplace_back(str);
args.desc.structure.push_back(D::String);
} else if (THPVariable_Check(obj)) {
auto& var = reinterpret_cast<THPVariable*>(obj)->cdata;
auto& var = THPVariable_Unpack(obj);
args.vars.push_back(var);
args.desc.metadata.emplace_back(var);
args.desc.structure.push_back(D::Variable);
} else if (strcmp(THPUtils_typename(obj), "NoneType") == 0) {
args.desc.structure.push_back(D::NoneType);
} else if (PyBool_Check(obj)) { // Wrap integers in bool tensors
at::Tensor tensor = scalar_to_tensor(at::Scalar(THPUtils_unpackBool(obj)));
auto var = as_variable(tensor);
at::Tensor var = scalar_to_tensor(at::Scalar(THPUtils_unpackBool(obj)));
args.vars.push_back(var);
args.desc.metadata.emplace_back(var);
args.desc.structure.push_back(D::Bool);
} else if (PyLong_Check(obj)) { // Wrap integers in long tensors
at::Tensor tensor = scalar_to_tensor(
at::Tensor var = scalar_to_tensor(
at::Scalar(static_cast<int64_t>(THPUtils_unpackLong(obj))));
auto var = as_variable(tensor);
args.vars.push_back(var);
args.desc.metadata.emplace_back(var);
args.desc.structure.push_back(D::Long);
} else if (PyFloat_Check(obj)) { // Wrap floating points in double tensors
at::Tensor tensor = scalar_to_tensor(THPUtils_unpackDouble(obj));
auto var = as_variable(tensor);
at::Tensor var = scalar_to_tensor(THPUtils_unpackDouble(obj));
args.vars.push_back(var);
args.desc.metadata.emplace_back(var);
args.desc.structure.push_back(D::Double);