mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-02 23:15:01 +08:00
refactor self to be a class again (#22207)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22207 ghimport-source-id: 36ee8bd17411a2e220665ad2a27364653061070e Test Plan: Imported from OSS Differential Revision: D15998758 Pulled By: suo fbshipit-source-id: 14bad87bb6e44bf1a43ae86339d8cc7b311c76dd
This commit is contained in:
committed by
Facebook Github Bot
parent
c0674cebf1
commit
ee9c8a75f4
@ -213,12 +213,22 @@ FunctionSchema getSchemaWithNameAndDefaults(
|
||||
schema.is_varret());
|
||||
}
|
||||
|
||||
static Self moduleSelf(const Module& m, const py::object& py_m) {
|
||||
return [m, py_m](Value* v) {
|
||||
v->setType(m.module_object()->type());
|
||||
return std::make_shared<ModuleValue>(v, m, py_m);
|
||||
};
|
||||
}
|
||||
struct ModuleSelf : public Self {
|
||||
ModuleSelf(const Module& m, py::object& py_m)
|
||||
: Self(), module_(m), pyModule_(py_m) {}
|
||||
|
||||
std::shared_ptr<SugaredValue> makeSugared(Value* v) const override {
|
||||
v->setType(module_.type());
|
||||
return std::make_shared<ModuleValue>(v, module_, pyModule_);
|
||||
}
|
||||
ClassTypePtr getClassType() const override {
|
||||
return module_.type();
|
||||
}
|
||||
|
||||
private:
|
||||
const Module& module_;
|
||||
const py::object& pyModule_;
|
||||
};
|
||||
|
||||
static TypePtr getTensorType(
|
||||
const at::Tensor& t,
|
||||
@ -361,9 +371,9 @@ void initJitScriptBindings(PyObject* module) {
|
||||
py::object py_m,
|
||||
const std::string& script,
|
||||
ResolutionCallback rcb) {
|
||||
c10::optional<Self> self;
|
||||
const auto self = ModuleSelf(m, py_m);
|
||||
m.class_compilation_unit()->define(
|
||||
m.name(), script, pythonResolver(rcb), moduleSelf(m, py_m));
|
||||
m.name(), script, pythonResolver(rcb), &self);
|
||||
didFinishEmitModule(m);
|
||||
})
|
||||
.def(
|
||||
@ -380,8 +390,8 @@ void initJitScriptBindings(PyObject* module) {
|
||||
resolvers.push_back(pythonResolver(callback));
|
||||
}
|
||||
const auto prefix = QualifiedName(m.name());
|
||||
m.class_compilation_unit()->define(
|
||||
prefix, defs, resolvers, moduleSelf(m, py_m));
|
||||
const auto self = ModuleSelf(m, py_m);
|
||||
m.class_compilation_unit()->define(prefix, defs, resolvers, &self);
|
||||
// Stitch in default arguments for each Def if provided
|
||||
auto defaults_it = defaults.begin();
|
||||
auto defs_it = defs.begin();
|
||||
@ -737,8 +747,8 @@ void initJitScriptBindings(PyObject* module) {
|
||||
rcbs.push_back(
|
||||
pythonResolver(rcb, classDef.name().name(), classType));
|
||||
}
|
||||
cu->define(
|
||||
QualifiedName(classname), methodDefs, rcbs, simpleSelf(classType));
|
||||
const auto self = SimpleSelf(classType);
|
||||
cu->define(classname, methodDefs, rcbs, &self);
|
||||
});
|
||||
|
||||
m.def("parse_type_comment", [](const std::string& comment) {
|
||||
@ -781,15 +791,14 @@ void initJitScriptBindings(PyObject* module) {
|
||||
"_jit_import_functions",
|
||||
[](CompilationUnit& cu,
|
||||
const std::string& src,
|
||||
const std::vector<at::Tensor>& constant_table,
|
||||
const Self& self) {
|
||||
const std::vector<at::Tensor>& constant_table) {
|
||||
import_functions(
|
||||
c10::nullopt,
|
||||
*CompilationUnit::_get_python_cu_const(),
|
||||
cu,
|
||||
std::make_shared<Source>(src),
|
||||
constant_table,
|
||||
self,
|
||||
nullptr,
|
||||
nullptr);
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user