mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-06 09:17:11 +08:00
feature: adding the ability to restore shapes after loading a traced model (#90744)
Adds the ability to store inputs used in tracing models when calling torch.jit.save and restore the input shapes using torch.jit.load if the appropriate variables are set. Fixes [89185](https://github.com/pytorch/pytorch/issues/89185) Pull Request resolved: https://github.com/pytorch/pytorch/pull/90744 Approved by: https://github.com/davidberard98
This commit is contained in:
committed by
PyTorch MergeBot
parent
c7c7238976
commit
0d0ebcdfe5
@ -33,6 +33,7 @@
|
||||
#include <torch/csrc/jit/frontend/parser.h>
|
||||
#include <torch/csrc/jit/frontend/tracer.h>
|
||||
#include <torch/csrc/jit/ir/constants.h>
|
||||
#include <torch/csrc/jit/ir/graph_utils.h>
|
||||
#include <torch/csrc/jit/ir/irparser.h>
|
||||
#include <torch/csrc/jit/passes/inliner.h>
|
||||
#include <torch/csrc/jit/passes/shape_analysis.h>
|
||||
@ -436,91 +437,6 @@ struct VISIBILITY_HIDDEN ModuleSelf : public Self {
|
||||
std::shared_ptr<ConcreteModuleType> concreteType_;
|
||||
};
|
||||
|
||||
static TypePtr getTensorType(const at::Tensor& t, bool complete) {
|
||||
auto r = TensorType::create(t);
|
||||
if (!complete) {
|
||||
r = r->dimensionedOnly();
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
static TypePtr inferShapeAndTypeForInput(
|
||||
TypePtr input_type,
|
||||
Stack::const_iterator& s_iter,
|
||||
const Stack::const_iterator& s_iter_end,
|
||||
bool complete) {
|
||||
if (auto tuple_type = input_type->cast<TupleType>()) {
|
||||
std::vector<TypePtr> types;
|
||||
for (const auto& sub_type : tuple_type->containedTypes()) {
|
||||
TORCH_INTERNAL_ASSERT(s_iter != s_iter_end);
|
||||
types.emplace_back(
|
||||
inferShapeAndTypeForInput(sub_type, s_iter, s_iter_end, complete));
|
||||
}
|
||||
return TupleType::create(types);
|
||||
} else if (auto list_type = input_type->cast<ListType>()) {
|
||||
const TypePtr& sub_type = list_type->getElementType();
|
||||
auto elem_type =
|
||||
inferShapeAndTypeForInput(sub_type, s_iter, s_iter_end, complete);
|
||||
return ListType::create(elem_type);
|
||||
} else if (auto tensor_type = input_type->cast<TensorType>()) {
|
||||
auto type = getTensorType(s_iter->toTensor(), complete);
|
||||
s_iter++;
|
||||
return type;
|
||||
} else if (auto optional_type = input_type->cast<OptionalType>()) {
|
||||
const TypePtr& sub_type = optional_type->getElementType();
|
||||
auto elem_type =
|
||||
inferShapeAndTypeForInput(sub_type, s_iter, s_iter_end, complete);
|
||||
return OptionalType::create(elem_type);
|
||||
} else {
|
||||
// Primitive type, keep as is.
|
||||
s_iter++;
|
||||
return input_type;
|
||||
}
|
||||
}
|
||||
|
||||
static void setInputTensorTypes(
|
||||
Graph& g,
|
||||
const Stack& stack,
|
||||
bool complete,
|
||||
const std::vector<int>& param_count_list = {}) {
|
||||
at::ArrayRef<Value*> input_values = g.inputs();
|
||||
auto s_iter = stack.begin();
|
||||
size_t list_idx = 0;
|
||||
if (!param_count_list.empty()) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
input_values.size() == param_count_list.size(),
|
||||
" input_values:",
|
||||
input_values.size(),
|
||||
" vs param_count_list:",
|
||||
param_count_list.size());
|
||||
}
|
||||
for (auto v : input_values) {
|
||||
// Leave packed param types alone. This is needed for downstream passes
|
||||
// (like alias analysis) to work properly. This will be unpacked later
|
||||
// in unpackQuantizedWeights.
|
||||
if (auto named_type = v->type()->cast<c10::NamedType>()) {
|
||||
if (auto qualname = named_type->name()) {
|
||||
if (getCustomClass(qualname->qualifiedName())) {
|
||||
if (param_count_list.empty()) {
|
||||
AT_ASSERT(s_iter != stack.end());
|
||||
s_iter++;
|
||||
} else {
|
||||
if (param_count_list[list_idx] > 0) {
|
||||
AT_ASSERT(s_iter != stack.end());
|
||||
}
|
||||
s_iter += param_count_list[list_idx];
|
||||
}
|
||||
list_idx++;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
v->setType(
|
||||
inferShapeAndTypeForInput(v->type(), s_iter, stack.end(), complete));
|
||||
list_idx++;
|
||||
}
|
||||
}
|
||||
|
||||
static std::shared_ptr<Graph> _propagate_shapes(
|
||||
Graph& graph,
|
||||
std::vector<at::Tensor> inputs,
|
||||
@ -1190,7 +1106,8 @@ void initJitScriptBindings(PyObject* module) {
|
||||
const py::function& var_name_lookup_fn,
|
||||
bool strict,
|
||||
bool force_outplace,
|
||||
const std::vector<std::string>& argument_names) {
|
||||
const std::vector<std::string>& argument_names,
|
||||
bool store_inputs) {
|
||||
// prereq: Module's buffers and parameters are unique
|
||||
// this was ensured in python before calling this function
|
||||
auto typed_inputs = toTraceableStack(input_tuple);
|
||||
@ -1208,6 +1125,9 @@ void initJitScriptBindings(PyObject* module) {
|
||||
auto fn = self._ivalue()->compilation_unit()->create_function(
|
||||
method_name, graph);
|
||||
self.type()->addMethod(fn);
|
||||
if (store_inputs) {
|
||||
self.store_traced_inputs(name, typed_inputs);
|
||||
}
|
||||
didFinishEmitModule(self);
|
||||
},
|
||||
py::arg("name"),
|
||||
@ -1216,7 +1136,8 @@ void initJitScriptBindings(PyObject* module) {
|
||||
py::arg("var_name_lookup_fn"),
|
||||
py::arg("strict"),
|
||||
py::arg("force_outplace"),
|
||||
py::arg("argument_names") = std::vector<std::string>())
|
||||
py::arg("argument_names") = std::vector<std::string>(),
|
||||
py::arg("store_inputs"))
|
||||
.def(
|
||||
"_create_method_from_trace_with_dict",
|
||||
[](Module& self,
|
||||
@ -1226,7 +1147,8 @@ void initJitScriptBindings(PyObject* module) {
|
||||
const py::function& var_name_lookup_fn,
|
||||
bool strict,
|
||||
bool force_outplace,
|
||||
const std::vector<std::string>& argument_names) {
|
||||
const std::vector<std::string>& argument_names,
|
||||
bool store_inputs) {
|
||||
// prereq: Module's buffers and parameters are unique
|
||||
// this was ensured in python before calling this function
|
||||
auto typed_inputs = toTraceableStack(input_dict);
|
||||
@ -1244,6 +1166,9 @@ void initJitScriptBindings(PyObject* module) {
|
||||
const auto method_name = QualifiedName(*self.type()->name(), name);
|
||||
auto fn = self._ivalue()->compilation_unit()->create_function(
|
||||
method_name, graph);
|
||||
if (store_inputs) {
|
||||
self.store_traced_inputs(name, typed_inputs);
|
||||
}
|
||||
self.type()->addMethod(fn);
|
||||
didFinishEmitModule(self);
|
||||
},
|
||||
@ -1253,7 +1178,8 @@ void initJitScriptBindings(PyObject* module) {
|
||||
py::arg("var_name_lookup_fn"),
|
||||
py::arg("strict"),
|
||||
py::arg("force_outplace"),
|
||||
py::arg("argument_names") = std::vector<std::string>())
|
||||
py::arg("argument_names") = std::vector<std::string>(),
|
||||
py::arg("store_inputs"))
|
||||
.def(
|
||||
"_get_forward_hooks",
|
||||
[](const Module& m) {
|
||||
@ -1272,6 +1198,11 @@ void initJitScriptBindings(PyObject* module) {
|
||||
}
|
||||
return funcs;
|
||||
})
|
||||
.def(
|
||||
"_retrieve_traced_inputs",
|
||||
[](const Module& m) {
|
||||
return ScriptDict(m.retrieve_traced_inputs());
|
||||
})
|
||||
.def_property_readonly(
|
||||
"code",
|
||||
[](Module& self) {
|
||||
@ -1864,7 +1795,8 @@ void initJitScriptBindings(PyObject* module) {
|
||||
[](std::shared_ptr<CompilationUnit> cu,
|
||||
const std::string& filename,
|
||||
py::object map_location,
|
||||
const py::dict& extra_files) {
|
||||
const py::dict& extra_files,
|
||||
bool restore_shapes = false) {
|
||||
c10::optional<at::Device> optional_device;
|
||||
if (!map_location.is_none()) {
|
||||
AT_ASSERT(THPDevice_Check(map_location.ptr()));
|
||||
@ -1873,7 +1805,12 @@ void initJitScriptBindings(PyObject* module) {
|
||||
}
|
||||
ExtraFilesMap extra_files_map = extra_files_from_python(extra_files);
|
||||
auto ret = import_ir_module(
|
||||
std::move(cu), filename, optional_device, extra_files_map);
|
||||
std::move(cu),
|
||||
filename,
|
||||
optional_device,
|
||||
extra_files_map,
|
||||
/*load_debug_files*/ true,
|
||||
restore_shapes);
|
||||
extra_files_to_python(extra_files_map, extra_files);
|
||||
return ret;
|
||||
});
|
||||
@ -1903,7 +1840,8 @@ void initJitScriptBindings(PyObject* module) {
|
||||
[](std::shared_ptr<CompilationUnit> cu,
|
||||
const std::string& buffer,
|
||||
py::object map_location,
|
||||
const py::dict& extra_files) {
|
||||
const py::dict& extra_files,
|
||||
bool restore_shapes = false) {
|
||||
std::istringstream in(buffer);
|
||||
c10::optional<at::Device> optional_device;
|
||||
if (!map_location.is_none()) {
|
||||
@ -1913,7 +1851,12 @@ void initJitScriptBindings(PyObject* module) {
|
||||
}
|
||||
ExtraFilesMap extra_files_map = extra_files_from_python(extra_files);
|
||||
auto ret = import_ir_module(
|
||||
std::move(cu), in, optional_device, extra_files_map);
|
||||
std::move(cu),
|
||||
in,
|
||||
optional_device,
|
||||
extra_files_map,
|
||||
/*load_debug_files*/ true,
|
||||
restore_shapes);
|
||||
extra_files_to_python(extra_files_map, extra_files);
|
||||
return ret;
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user