#include #include #include namespace pybind11::detail { bool type_caster::load(py::handle src, bool /*unused*/) { if (torch::is_symint(src)) { auto node = src.attr("node"); if (py::isinstance(node)) { value = c10::SymInt(py::cast(node)); return true; } value = c10::SymInt(static_cast( c10::make_intrusive(node))); return true; } auto raw_obj = src.ptr(); if (THPVariable_Check(raw_obj)) { auto& var = THPVariable_Unpack(raw_obj); if (var.numel() == 1 && at::isIntegralType(var.dtype().toScalarType(), /*include_bool*/ true)) { auto scalar = var.item(); TORCH_INTERNAL_ASSERT(scalar.isIntegral(/*include bool*/ false)); value = scalar.toSymInt(); return true; } } if (THPUtils_checkIndex(raw_obj)) { value = c10::SymInt{THPUtils_unpackIndex(raw_obj)}; return true; } return false; } py::handle type_caster::cast( const c10::SymInt& si, return_value_policy /* policy */, handle /* parent */) { if (si.is_symbolic()) { auto* py_node = dynamic_cast( si.toSymNodeImplUnowned()); if (py_node) { // Return the Python directly (unwrap) return torch::get_symint_class()(py_node->getPyObj()).release(); } else { // Wrap the C++ into Python auto inner = py::cast(si.toSymNode()); if (!inner) { throw python_error(); } return torch::get_symint_class()(inner).release(); } } else { auto m = si.maybe_as_int(); // NOLINTNEXTLINE(bugprone-unchecked-optional-access) return py::cast(m.value()).release(); } } bool type_caster::load(py::handle src, bool /*unused*/) { if (torch::is_symfloat(src)) { value = c10::SymFloat(static_cast( c10::make_intrusive(src.attr("node")))); return true; } auto raw_obj = src.ptr(); if (THPUtils_checkDouble(raw_obj)) { value = c10::SymFloat{THPUtils_unpackDouble(raw_obj)}; return true; } return false; } py::handle type_caster::cast( const c10::SymFloat& si, return_value_policy /* policy */, handle /* parent */) { if (si.is_symbolic()) { // TODO: generalize this to work with C++ backed class auto* py_node = dynamic_cast(si.toSymNodeImpl().get()); TORCH_INTERNAL_ASSERT(py_node); return torch::get_symfloat_class()(py_node->getPyObj()).release(); } else { return py::cast(si.as_float_unchecked()).release(); } } bool type_caster::load(py::handle src, bool /*unused*/) { if (torch::is_symbool(src)) { value = c10::SymBool(static_cast( c10::make_intrusive(src.attr("node")))); return true; } auto raw_obj = src.ptr(); if (THPUtils_checkBool(raw_obj)) { value = c10::SymBool{THPUtils_unpackBool(raw_obj)}; return true; } return false; } py::handle type_caster::cast( const c10::SymBool& si, return_value_policy /* policy */, handle /* parent */) { if (auto m = si.maybe_as_bool()) { return py::cast(*m).release(); } else { // TODO: generalize this to work with C++ backed class auto* py_node = dynamic_cast(si.toSymNodeImpl().get()); TORCH_INTERNAL_ASSERT(py_node); return torch::get_symbool_class()(py_node->getPyObj()).release(); } } bool type_caster::load(py::handle src, bool /*unused*/) { TORCH_INTERNAL_ASSERT( 0, "pybind11 loading for c10::Scalar NYI (file a bug if you need it)"); } py::handle type_caster::cast( const c10::Scalar& scalar, return_value_policy /* policy */, handle /* parent */) { if (scalar.isIntegral(/*includeBool*/ false)) { // We have to be careful here; we cannot unconditionally route through // SymInt because integer data from Tensors can easily be MIN_INT or // very negative, which conflicts with the allocated range. if (scalar.isSymbolic()) { return py::cast(scalar.toSymInt()).release(); } else { if (scalar.type() == at::ScalarType::UInt64) { return py::cast(scalar.toUInt64()).release(); } else { return py::cast(scalar.toLong()).release(); } } } else if (scalar.isFloatingPoint()) { // This isn't strictly necessary but we add it for symmetry if (scalar.isSymbolic()) { return py::cast(scalar.toSymFloat()).release(); } else { return py::cast(scalar.toDouble()).release(); } } else if (scalar.isBoolean()) { if (scalar.isSymbolic()) { return py::cast(scalar.toSymBool()).release(); } return py::cast(scalar.toBool()).release(); } else if (scalar.isComplex()) { return py::cast(scalar.toComplexDouble()).release(); } else { TORCH_INTERNAL_ASSERT(0, "unrecognized scalar type ", scalar.type()); } } } // namespace pybind11::detail