mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54864 Support primitive type attributes. Needed for Silero model. Test Plan: Imported from OSS Reviewed By: nikithamalgifb Differential Revision: D27408982 Pulled By: SplitInfinity fbshipit-source-id: 16b291eedbe9f9bb31d7664a29a484555df53755
This commit is contained in:
committed by
Facebook GitHub Bot
parent
ce48b14060
commit
cd9dd653e9
@ -20,6 +20,9 @@ static constexpr char ListClose = ']';
|
||||
static constexpr char TupleOpen = '(';
|
||||
static constexpr char TupleClose = ')';
|
||||
static constexpr char Variable = 'v';
|
||||
static constexpr char Bool = 'b';
|
||||
static constexpr char Long = 'l';
|
||||
static constexpr char Double = 'd';
|
||||
static constexpr char String = 's';
|
||||
static constexpr char NoneType = 'n';
|
||||
} // namespace D
|
||||
@ -36,6 +39,12 @@ py::object cast_handle_sequence(std::vector<py::handle> objs) {
|
||||
}
|
||||
|
||||
void flatten_rec(PyObject* obj, ParsedArgs& args) {
|
||||
auto as_variable = [](at::Tensor& tensor) // Wrap tensor as Variable
|
||||
{
|
||||
PyObject* wappred_obj = THPVariable_Wrap(tensor);
|
||||
return reinterpret_cast<THPVariable*>(wappred_obj)->cdata;
|
||||
};
|
||||
|
||||
auto& structure = args.desc.structure;
|
||||
if (six::isTuple(obj)) {
|
||||
structure.push_back(D::TupleOpen);
|
||||
@ -65,6 +74,25 @@ void flatten_rec(PyObject* obj, ParsedArgs& args) {
|
||||
args.desc.structure.push_back(D::Variable);
|
||||
} else if (strcmp(THPUtils_typename(obj), "NoneType") == 0) {
|
||||
args.desc.structure.push_back(D::NoneType);
|
||||
} else if (PyBool_Check(obj)) { // Wrap integers in bool tensors
|
||||
at::Tensor tensor = scalar_to_tensor(at::Scalar(THPUtils_unpackBool(obj)));
|
||||
auto var = as_variable(tensor);
|
||||
args.vars.push_back(var);
|
||||
args.desc.metadata.emplace_back(var);
|
||||
args.desc.structure.push_back(D::Bool);
|
||||
} else if (PyLong_Check(obj)) { // Wrap integers in long tensors
|
||||
at::Tensor tensor = scalar_to_tensor(
|
||||
at::Scalar(static_cast<int64_t>(THPUtils_unpackLong(obj))));
|
||||
auto var = as_variable(tensor);
|
||||
args.vars.push_back(var);
|
||||
args.desc.metadata.emplace_back(var);
|
||||
args.desc.structure.push_back(D::Long);
|
||||
} else if (PyFloat_Check(obj)) { // Wrap floating points in double tensors
|
||||
at::Tensor tensor = scalar_to_tensor(THPUtils_unpackDouble(obj));
|
||||
auto var = as_variable(tensor);
|
||||
args.vars.push_back(var);
|
||||
args.desc.metadata.emplace_back(var);
|
||||
args.desc.structure.push_back(D::Double);
|
||||
} else {
|
||||
std::string msg =
|
||||
"Only tuples, lists and Variables are supported as JIT inputs/outputs. "
|
||||
@ -142,6 +170,9 @@ py::object unflatten_rec(
|
||||
} else if (type == D::NoneType) {
|
||||
return py::reinterpret_borrow<py::object>(py::none());
|
||||
} else {
|
||||
// if (type == D::Long || type == D::Double || type == D::Bool ||
|
||||
// D::Variable) unwrap variables (D::Variable), or unwrap primitive types
|
||||
// (Long, Double, Bool) as variables for tracer.
|
||||
if (var_it == var_it_end)
|
||||
throw std::runtime_error("Not enough Variables given to unflatten");
|
||||
auto var = *var_it++;
|
||||
|
Reference in New Issue
Block a user