mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Replace all direct cdata access with THPVariable_Unpack (#55799)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55799 I'm going to change the implementation of cdata soon so I need to abstract over cdata access with a function. Additionally, many users are casting manually casting to THPVariable to access the member so I can remove these unsafe casts in the client code (the implementation, of course, is still doing an unsafe cast.) Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D27712130 Pulled By: ezyang fbshipit-source-id: 95fcc013bf3913d67f2c634068eb5b3aab144cb3
This commit is contained in:
committed by
Facebook GitHub Bot
parent
61418aa069
commit
6ec71ed4f9
@ -39,12 +39,6 @@ py::object cast_handle_sequence(std::vector<py::handle> objs) {
|
||||
}
|
||||
|
||||
void flatten_rec(PyObject* obj, ParsedArgs& args) {
|
||||
auto as_variable = [](at::Tensor& tensor) // Wrap tensor as Variable
|
||||
{
|
||||
PyObject* wappred_obj = THPVariable_Wrap(tensor);
|
||||
return reinterpret_cast<THPVariable*>(wappred_obj)->cdata;
|
||||
};
|
||||
|
||||
auto& structure = args.desc.structure;
|
||||
if (six::isTuple(obj)) {
|
||||
structure.push_back(D::TupleOpen);
|
||||
@ -68,28 +62,25 @@ void flatten_rec(PyObject* obj, ParsedArgs& args) {
|
||||
args.desc.strings.emplace_back(str);
|
||||
args.desc.structure.push_back(D::String);
|
||||
} else if (THPVariable_Check(obj)) {
|
||||
auto& var = reinterpret_cast<THPVariable*>(obj)->cdata;
|
||||
auto& var = THPVariable_Unpack(obj);
|
||||
args.vars.push_back(var);
|
||||
args.desc.metadata.emplace_back(var);
|
||||
args.desc.structure.push_back(D::Variable);
|
||||
} else if (strcmp(THPUtils_typename(obj), "NoneType") == 0) {
|
||||
args.desc.structure.push_back(D::NoneType);
|
||||
} else if (PyBool_Check(obj)) { // Wrap integers in bool tensors
|
||||
at::Tensor tensor = scalar_to_tensor(at::Scalar(THPUtils_unpackBool(obj)));
|
||||
auto var = as_variable(tensor);
|
||||
at::Tensor var = scalar_to_tensor(at::Scalar(THPUtils_unpackBool(obj)));
|
||||
args.vars.push_back(var);
|
||||
args.desc.metadata.emplace_back(var);
|
||||
args.desc.structure.push_back(D::Bool);
|
||||
} else if (PyLong_Check(obj)) { // Wrap integers in long tensors
|
||||
at::Tensor tensor = scalar_to_tensor(
|
||||
at::Tensor var = scalar_to_tensor(
|
||||
at::Scalar(static_cast<int64_t>(THPUtils_unpackLong(obj))));
|
||||
auto var = as_variable(tensor);
|
||||
args.vars.push_back(var);
|
||||
args.desc.metadata.emplace_back(var);
|
||||
args.desc.structure.push_back(D::Long);
|
||||
} else if (PyFloat_Check(obj)) { // Wrap floating points in double tensors
|
||||
at::Tensor tensor = scalar_to_tensor(THPUtils_unpackDouble(obj));
|
||||
auto var = as_variable(tensor);
|
||||
at::Tensor var = scalar_to_tensor(THPUtils_unpackDouble(obj));
|
||||
args.vars.push_back(var);
|
||||
args.desc.metadata.emplace_back(var);
|
||||
args.desc.structure.push_back(D::Double);
|
||||
|
Reference in New Issue
Block a user