Add Symbool support in python to C++ translation (#98453)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98453
Approved by: https://github.com/ezyang
This commit is contained in:
Tugsbayasgalan Manlaibaatar
2023-04-11 10:50:29 -07:00
committed by PyTorch MergeBot
parent bc8cb62bcb
commit 39fd7f945f
22 changed files with 287 additions and 15 deletions

View File

@ -87,6 +87,9 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
} else if (torch::is_symfloat(py::handle(obj))) {
save_symint = true;
scalar = at::Scalar(std::numeric_limits<double>::quiet_NaN());
} else if (torch::is_symbool(py::handle(obj))) {
save_symint = true;
scalar = at::Scalar(true);
} else {
throw py::cast_error(
c10::str("Unable to cast ", py::str(obj), " to Tensor"));
@ -171,6 +174,11 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
return py::cast<c10::SymFloat>(obj);
}
return py::cast<double>(obj);
case TypeKind::SymBoolType:
if (torch::is_symbool(obj.ptr())) {
return py::cast<c10::SymBool>(obj);
}
return py::cast<bool>(obj);
case TypeKind::NoneType:
if (!obj.is_none()) {
throw py::cast_error(
@ -285,6 +293,21 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
return listToIValue<double>(obj);
}
}
case TypeKind::SymBoolType: {
bool is_symbolic = false;
for (auto it = obj.begin(); it != obj.end(); it++) {
auto elm = *it;
if (torch::is_symbool(elm)) {
is_symbolic = true;
break;
}
}
if (is_symbolic) {
return listToIValue<c10::SymBool>(obj);
} else {
return listToIValue<bool>(obj);
}
}
case TypeKind::FloatType:
if (!N || !py::isinstance<py::float_>(obj)) {
return IValue(py::cast<std::vector<double>>(obj));
@ -451,6 +474,8 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
return py::cast<c10::SymInt>(obj);
} else if (torch::is_symfloat(obj)) {
return py::cast<c10::SymFloat>(obj);
} else if (torch::is_symbool(obj)) {
return py::cast<c10::SymBool>(obj);
} else {
throw py::cast_error(
c10::str("Cannot cast ", py::str(obj), " to ", type->repr_str()));
@ -675,6 +700,8 @@ py::object toPyObject(IValue ivalue) {
return py::cast(std::move(ivalue).toSymInt());
} else if (ivalue.isSymFloat()) {
return py::cast(std::move(ivalue).toSymFloat());
} else if (ivalue.isSymBool()) {
return py::cast(std::move(ivalue).toSymBool());
} else {
AT_ERROR(
"Missing cases in 'toPyObject'! Can't convert ",