mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[JIT] Add SchemaInfo python bindings to init.cpp (#81518)
- Added python bindings for SchemaInfo class, SchemaArgument struct, and SchemaArgType enum. - Tested that argument values are added correctly to SchemaInfo binding in test_schema_check.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/81518 Approved by: https://github.com/davidberard98
This commit is contained in:
committed by
PyTorch MergeBot
parent
dadfe1c7bf
commit
8e454cc702
@ -1,6 +1,7 @@
|
||||
#include <pybind11/pytypes.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
#include <torch/csrc/utils/python_arg_parser.h>
|
||||
#include <torch/csrc/utils/schema_info.h>
|
||||
|
||||
#include <ATen/core/operator_name.h>
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
@ -103,8 +104,6 @@
|
||||
#include <c10/util/signal_handler.h>
|
||||
#include <caffe2/serialize/inline_container.h>
|
||||
|
||||
#include <ATen/core/function_schema.h>
|
||||
|
||||
#include <pybind11/cast.h>
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/iostream.h>
|
||||
@ -121,11 +120,14 @@
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
using ::c10::AliasInfo;
|
||||
using ::c10::Argument;
|
||||
using ::c10::FunctionSchema;
|
||||
using c10::AliasInfo;
|
||||
using c10::Argument;
|
||||
using c10::FunctionSchema;
|
||||
using c10::SchemaArgType;
|
||||
using c10::SchemaArgument;
|
||||
using caffe2::serialize::PyTorchStreamReader;
|
||||
using caffe2::serialize::PyTorchStreamWriter;
|
||||
using torch::utils::SchemaInfo;
|
||||
|
||||
static std::shared_ptr<c10::SymbolicIntNode> toSymIntNode(
|
||||
std::shared_ptr<c10::SymbolicIntNode> a,
|
||||
@ -1662,7 +1664,63 @@ void initJITBindings(PyObject* module) {
|
||||
}
|
||||
return type.value();
|
||||
});
|
||||
|
||||
py::enum_<SchemaArgType>(m, "_SchemaArgType")
|
||||
.value("input", SchemaArgType::input)
|
||||
.value("output", SchemaArgType::output);
|
||||
py::class_<SchemaArgument>(m, "_SchemaArgument")
|
||||
.def(py::init<SchemaArgType, size_t>())
|
||||
.def_readwrite("type", &SchemaArgument::type)
|
||||
.def_readwrite("index", &SchemaArgument::index);
|
||||
py::class_<SchemaInfo>(m, "_SchemaInfo")
|
||||
.def(py::init<FunctionSchema>())
|
||||
.def("is_mutable", [](SchemaInfo& self) { return self.is_mutable(); })
|
||||
.def(
|
||||
"is_mutable",
|
||||
[](SchemaInfo& self, size_t index) { return self.is_mutable(index); })
|
||||
.def(
|
||||
"is_mutable",
|
||||
[](SchemaInfo& self, c10::string_view name) {
|
||||
return self.is_mutable(name);
|
||||
})
|
||||
.def(
|
||||
"may_alias",
|
||||
[](SchemaInfo& self,
|
||||
const SchemaArgument& lhs,
|
||||
const SchemaArgument& rhs) { return self.may_alias(lhs, rhs); })
|
||||
.def(
|
||||
"may_contain_alias",
|
||||
[](SchemaInfo& self,
|
||||
const SchemaArgument& lhs,
|
||||
const SchemaArgument& rhs) {
|
||||
return self.may_contain_alias(lhs, rhs);
|
||||
})
|
||||
.def(
|
||||
"add_argument_value",
|
||||
[](SchemaInfo& self,
|
||||
const std::string& name,
|
||||
const py::object& value) {
|
||||
if (name == "input") {
|
||||
self.addArgumentValue("self", toTypeInferredIValue(value));
|
||||
} else {
|
||||
self.addArgumentValue(name, toTypeInferredIValue(value));
|
||||
}
|
||||
})
|
||||
.def("add_argument_values", [](SchemaInfo& self, const py::dict& values) {
|
||||
std::unordered_map<std::string, IValue> value_map;
|
||||
for (const auto& key_pair : values) {
|
||||
IValue key = toTypeInferredIValue(key_pair.first);
|
||||
IValue value = toTypeInferredIValue(key_pair.second);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
key.isString(),
|
||||
"Add argument value keys types should be strings.");
|
||||
if (key.toStringRef() == "input") {
|
||||
value_map["self"] = value;
|
||||
} else {
|
||||
value_map[key.toStringRef()] = value;
|
||||
}
|
||||
}
|
||||
self.addArgumentValues(value_map);
|
||||
});
|
||||
py::class_<FunctionSchema>(m, "FunctionSchema")
|
||||
.def_property_readonly(
|
||||
"name", [](FunctionSchema& self) { return self.name(); })
|
||||
|
Reference in New Issue
Block a user