mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-29 19:24:55 +08:00
Add Symbool support in python to C++ translation (#98453)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98453 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
bc8cb62bcb
commit
39fd7f945f
@ -87,6 +87,9 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
|
||||
} else if (torch::is_symfloat(py::handle(obj))) {
|
||||
save_symint = true;
|
||||
scalar = at::Scalar(std::numeric_limits<double>::quiet_NaN());
|
||||
} else if (torch::is_symbool(py::handle(obj))) {
|
||||
save_symint = true;
|
||||
scalar = at::Scalar(true);
|
||||
} else {
|
||||
throw py::cast_error(
|
||||
c10::str("Unable to cast ", py::str(obj), " to Tensor"));
|
||||
@ -171,6 +174,11 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
|
||||
return py::cast<c10::SymFloat>(obj);
|
||||
}
|
||||
return py::cast<double>(obj);
|
||||
case TypeKind::SymBoolType:
|
||||
if (torch::is_symbool(obj.ptr())) {
|
||||
return py::cast<c10::SymBool>(obj);
|
||||
}
|
||||
return py::cast<bool>(obj);
|
||||
case TypeKind::NoneType:
|
||||
if (!obj.is_none()) {
|
||||
throw py::cast_error(
|
||||
@ -285,6 +293,21 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
|
||||
return listToIValue<double>(obj);
|
||||
}
|
||||
}
|
||||
case TypeKind::SymBoolType: {
|
||||
bool is_symbolic = false;
|
||||
for (auto it = obj.begin(); it != obj.end(); it++) {
|
||||
auto elm = *it;
|
||||
if (torch::is_symbool(elm)) {
|
||||
is_symbolic = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (is_symbolic) {
|
||||
return listToIValue<c10::SymBool>(obj);
|
||||
} else {
|
||||
return listToIValue<bool>(obj);
|
||||
}
|
||||
}
|
||||
case TypeKind::FloatType:
|
||||
if (!N || !py::isinstance<py::float_>(obj)) {
|
||||
return IValue(py::cast<std::vector<double>>(obj));
|
||||
@ -451,6 +474,8 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
|
||||
return py::cast<c10::SymInt>(obj);
|
||||
} else if (torch::is_symfloat(obj)) {
|
||||
return py::cast<c10::SymFloat>(obj);
|
||||
} else if (torch::is_symbool(obj)) {
|
||||
return py::cast<c10::SymBool>(obj);
|
||||
} else {
|
||||
throw py::cast_error(
|
||||
c10::str("Cannot cast ", py::str(obj), " to ", type->repr_str()));
|
||||
@ -675,6 +700,8 @@ py::object toPyObject(IValue ivalue) {
|
||||
return py::cast(std::move(ivalue).toSymInt());
|
||||
} else if (ivalue.isSymFloat()) {
|
||||
return py::cast(std::move(ivalue).toSymFloat());
|
||||
} else if (ivalue.isSymBool()) {
|
||||
return py::cast(std::move(ivalue).toSymBool());
|
||||
} else {
|
||||
AT_ERROR(
|
||||
"Missing cases in 'toPyObject'! Can't convert ",
|
||||
|
||||
Reference in New Issue
Block a user