mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-31 12:15:03 +08:00 
			
		
		
		
	Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/142442 Approved by: https://github.com/albanD
		
			
				
	
	
		
			167 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			167 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| #include <torch/csrc/utils/pybind.h>
 | |
| #include <torch/csrc/utils/python_arg_parser.h>
 | |
| #include <torch/csrc/utils/python_symnode.h>
 | |
| 
 | |
| namespace pybind11::detail {
 | |
| 
 | |
| bool type_caster<c10::SymInt>::load(py::handle src, bool) {
 | |
|   if (torch::is_symint(src)) {
 | |
|     auto node = src.attr("node");
 | |
|     if (py::isinstance<c10::SymNodeImpl>(node)) {
 | |
|       value = c10::SymInt(py::cast<c10::SymNode>(node));
 | |
|       return true;
 | |
|     }
 | |
| 
 | |
|     value = c10::SymInt(static_cast<c10::SymNode>(
 | |
|         c10::make_intrusive<torch::impl::PythonSymNodeImpl>(node)));
 | |
|     return true;
 | |
|   }
 | |
| 
 | |
|   auto raw_obj = src.ptr();
 | |
| 
 | |
|   if (THPVariable_Check(raw_obj)) {
 | |
|     auto& var = THPVariable_Unpack(raw_obj);
 | |
|     if (var.numel() == 1 &&
 | |
|         at::isIntegralType(var.dtype().toScalarType(), /*include_bool*/ true)) {
 | |
|       auto scalar = var.item();
 | |
|       TORCH_INTERNAL_ASSERT(scalar.isIntegral(/*include bool*/ false));
 | |
|       value = scalar.toSymInt();
 | |
|       return true;
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   if (THPUtils_checkIndex(raw_obj)) {
 | |
|     value = c10::SymInt{THPUtils_unpackIndex(raw_obj)};
 | |
|     return true;
 | |
|   }
 | |
|   return false;
 | |
| }
 | |
| 
 | |
| py::handle type_caster<c10::SymInt>::cast(
 | |
|     const c10::SymInt& si,
 | |
|     return_value_policy /* policy */,
 | |
|     handle /* parent */) {
 | |
|   if (si.is_symbolic()) {
 | |
|     auto* py_node = dynamic_cast<torch::impl::PythonSymNodeImpl*>(
 | |
|         si.toSymNodeImplUnowned());
 | |
|     if (py_node) {
 | |
|       // Return the Python directly (unwrap)
 | |
|       return torch::get_symint_class()(py_node->getPyObj()).release();
 | |
|     } else {
 | |
|       // Wrap the C++ into Python
 | |
|       auto inner = py::cast(si.toSymNode());
 | |
|       if (!inner) {
 | |
|         throw python_error();
 | |
|       }
 | |
|       return torch::get_symint_class()(inner).release();
 | |
|     }
 | |
|   } else {
 | |
|     auto m = si.maybe_as_int();
 | |
|     // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
 | |
|     return py::cast(m.value()).release();
 | |
|   }
 | |
| }
 | |
| 
 | |
| bool type_caster<c10::SymFloat>::load(py::handle src, bool) {
 | |
|   if (torch::is_symfloat(src)) {
 | |
|     value = c10::SymFloat(static_cast<c10::SymNode>(
 | |
|         c10::make_intrusive<torch::impl::PythonSymNodeImpl>(src.attr("node"))));
 | |
|     return true;
 | |
|   }
 | |
| 
 | |
|   auto raw_obj = src.ptr();
 | |
|   if (THPUtils_checkDouble(raw_obj)) {
 | |
|     value = c10::SymFloat{THPUtils_unpackDouble(raw_obj)};
 | |
|     return true;
 | |
|   }
 | |
|   return false;
 | |
| }
 | |
| 
 | |
| py::handle type_caster<c10::SymFloat>::cast(
 | |
|     const c10::SymFloat& si,
 | |
|     return_value_policy /* policy */,
 | |
|     handle /* parent */) {
 | |
|   if (si.is_symbolic()) {
 | |
|     // TODO: generalize this to work with C++ backed class
 | |
|     auto* py_node =
 | |
|         dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
 | |
|     TORCH_INTERNAL_ASSERT(py_node);
 | |
|     return torch::get_symfloat_class()(py_node->getPyObj()).release();
 | |
|   } else {
 | |
|     return py::cast(si.as_float_unchecked()).release();
 | |
|   }
 | |
| }
 | |
| 
 | |
| bool type_caster<c10::SymBool>::load(py::handle src, bool) {
 | |
|   if (torch::is_symbool(src)) {
 | |
|     value = c10::SymBool(static_cast<c10::SymNode>(
 | |
|         c10::make_intrusive<torch::impl::PythonSymNodeImpl>(src.attr("node"))));
 | |
|     return true;
 | |
|   }
 | |
| 
 | |
|   auto raw_obj = src.ptr();
 | |
|   if (THPUtils_checkBool(raw_obj)) {
 | |
|     value = c10::SymBool{THPUtils_unpackBool(raw_obj)};
 | |
|     return true;
 | |
|   }
 | |
|   return false;
 | |
| }
 | |
| 
 | |
| py::handle type_caster<c10::SymBool>::cast(
 | |
|     const c10::SymBool& si,
 | |
|     return_value_policy /* policy */,
 | |
|     handle /* parent */) {
 | |
|   if (auto m = si.maybe_as_bool()) {
 | |
|     return py::cast(*m).release();
 | |
|   } else {
 | |
|     // TODO: generalize this to work with C++ backed class
 | |
|     auto* py_node =
 | |
|         dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
 | |
|     TORCH_INTERNAL_ASSERT(py_node);
 | |
|     return torch::get_symbool_class()(py_node->getPyObj()).release();
 | |
|   }
 | |
| }
 | |
| 
 | |
| bool type_caster<c10::Scalar>::load(py::handle src, bool) {
 | |
|   TORCH_INTERNAL_ASSERT(
 | |
|       0, "pybind11 loading for c10::Scalar NYI (file a bug if you need it)");
 | |
| }
 | |
| 
 | |
| py::handle type_caster<c10::Scalar>::cast(
 | |
|     const c10::Scalar& scalar,
 | |
|     return_value_policy /* policy */,
 | |
|     handle /* parent */) {
 | |
|   if (scalar.isIntegral(/*includeBool*/ false)) {
 | |
|     // We have to be careful here; we cannot unconditionally route through
 | |
|     // SymInt because integer data from Tensors can easily be MIN_INT or
 | |
|     // very negative, which conflicts with the allocated range.
 | |
|     if (scalar.isSymbolic()) {
 | |
|       return py::cast(scalar.toSymInt()).release();
 | |
|     } else {
 | |
|       if (scalar.type() == at::ScalarType::UInt64) {
 | |
|         return py::cast(scalar.toUInt64()).release();
 | |
|       } else {
 | |
|         return py::cast(scalar.toLong()).release();
 | |
|       }
 | |
|     }
 | |
|   } else if (scalar.isFloatingPoint()) {
 | |
|     // This isn't strictly necessary but we add it for symmetry
 | |
|     if (scalar.isSymbolic()) {
 | |
|       return py::cast(scalar.toSymFloat()).release();
 | |
|     } else {
 | |
|       return py::cast(scalar.toDouble()).release();
 | |
|     }
 | |
|   } else if (scalar.isBoolean()) {
 | |
|     if (scalar.isSymbolic()) {
 | |
|       return py::cast(scalar.toSymBool()).release();
 | |
|     }
 | |
|     return py::cast(scalar.toBool()).release();
 | |
|   } else if (scalar.isComplex()) {
 | |
|     return py::cast(scalar.toComplexDouble()).release();
 | |
|   } else {
 | |
|     TORCH_INTERNAL_ASSERT(0, "unrecognized scalar type ", scalar.type());
 | |
|   }
 | |
| }
 | |
| 
 | |
| } // namespace pybind11::detail
 |