mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-02 06:24:59 +08:00
[jit] add a compiled script module (#5630)
Add script::Module C++ class to represent script modules switch AST -> IR conversion to work on Modules/Methods rather than raw graphs function-only AST -> IR conversion is just a simplified case where there is only one module with a single method and no parameters. introduce SugaredValue in compiler.h to represent values in scope in a script function that are not first-class and that get desugared. This is used to represent the module's self parameter, as well as python function calls, and method calls on tensor provide a Python ScriptModule that provides a nice API on top of script::Module allowing for the definition of script modules with methods, parameters, and submodules Not in this PR but intended for the future: ScriptModule actually subclasses nn.Module, with most methods implemented Unification of tracedmodule and script module functionality into one container class. Detailed changelog: * Switch compiler over to using Module, but don't use them yet. * Remove intermediate attribute encoding in compiler * Create SugaredValue object to handle resolution of compiled module. * switch to_ir to modules, implement Select * hacky python wrappers * Private ScriptModule * Add `define` to script module * Attributes use TK_LIST_LITERAL this anticipates adding a real list literal expression to the language. * Add a metaclass to make sure script stubs are registered * Add a test * Doc createResolutionCallback * Docs and minor editing * Address PR comments * Document * Fix unicode issue
This commit is contained in:
committed by
Edward Z. Yang
parent
dede63689f
commit
41285edbb6
40
torch/csrc/jit/script/module.cpp
Normal file
40
torch/csrc/jit/script/module.cpp
Normal file
@ -0,0 +1,40 @@
|
||||
#include "torch/csrc/jit/script/module.h"
|
||||
|
||||
namespace torch { namespace jit { namespace script {
|
||||
|
||||
static std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs) {
|
||||
std::unordered_map<Value*, Value*> value_map;
|
||||
auto value_map_func = [&](Value* v) { return value_map.at(v); };
|
||||
JIT_ASSERT(callee.inputs().size() == inputs.size());
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
value_map[callee.inputs()[i]] = inputs[i];
|
||||
}
|
||||
for (auto* node : callee.nodes()) {
|
||||
auto* new_node =
|
||||
g.insertNode(g.createClone(node, value_map_func));
|
||||
for (size_t i = 0; i < node->outputs().size(); ++i) {
|
||||
value_map[node->outputs()[i]] = new_node->outputs()[i];
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Value*> outputs;
|
||||
for (auto* output : callee.outputs()) {
|
||||
outputs.push_back(value_map_func(output));
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
|
||||
std::vector<Value*> Method::emit_call_to(Method & callee, ArrayRef<Value*> inputs) {
|
||||
JIT_ASSERT(!executor);
|
||||
auto fn = callee.graph();
|
||||
JIT_ASSERT(inputs.size() == callee.num_inputs());
|
||||
std::vector<Value*> all_inputs = inputs;
|
||||
// parameters to callee method (which become parameters to _this_ method
|
||||
// if they were not already)
|
||||
for(at::Tensor* member : callee.member_inputs) {
|
||||
all_inputs.push_back(get_or_add_parameter(member));
|
||||
}
|
||||
return inlineCallTo(*graph(), *callee.graph(), all_inputs);
|
||||
}
|
||||
|
||||
}}}
|
||||
Reference in New Issue
Block a user