mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Python function to extract information on mobile::Module from flatbuffer (#77328)"
This reverts commit 69fa49f1230f80d1a0667e0a6ac8aca2746431b6. Reverted https://github.com/pytorch/pytorch/pull/77328 on behalf of https://github.com/atalman
This commit is contained in:
@ -918,34 +918,6 @@ class TestSaveLoadFlatbuffer(JitTestCase):
|
||||
output = m_loaded()
|
||||
self.assertEqual(output, None)
|
||||
|
||||
def test_module_info_flatbuffer(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(Foo, self).__init__()
|
||||
self.foo = torch.nn.Linear(2, 2)
|
||||
self.bar = torch.nn.Linear(2, 2)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.foo(x)
|
||||
x = self.bar(x)
|
||||
return x
|
||||
|
||||
first_script_module = torch.jit.script(Foo())
|
||||
first_saved_module = io.BytesIO()
|
||||
torch.jit.save_jit_module_to_flatbuffer(
|
||||
first_script_module, first_saved_module)
|
||||
first_saved_module.seek(0)
|
||||
expected = {
|
||||
'bytecode_version': 4,
|
||||
'operator_version': 4,
|
||||
'function_names': {'__torch__.___torch_mangle_0.Foo.forward'},
|
||||
'type_names': set(),
|
||||
'opname_to_num_args': {'aten::linear': 3}}
|
||||
self.assertEqual(
|
||||
torch.jit._serialization.get_flatbuffer_module_info(first_saved_module),
|
||||
expected)
|
||||
|
||||
|
||||
def test_save_load_params_buffers_submodules(self):
|
||||
"""
|
||||
Check that parameters, buffers, and submodules are the same after loading.
|
||||
|
@ -99,18 +99,5 @@ extern "C"
|
||||
reinterpret_cast<char*>(detached_buffer.data()),
|
||||
detached_buffer.size());
|
||||
});
|
||||
pym.def("_get_module_info_from_flatbuffer", [](std::string flatbuffer_content) {
|
||||
py::gil_scoped_acquire acquire;
|
||||
py::dict result;
|
||||
mobile::ModuleInfo minfo = torch::jit::get_module_info_from_flatbuffer(
|
||||
&flatbuffer_content[0]);
|
||||
result["bytecode_version"] = minfo.bytecode_version;
|
||||
result["operator_version"] = minfo.operator_version;
|
||||
result["function_names"] = minfo.function_names;
|
||||
result["type_names"] = minfo.type_names;
|
||||
result["opname_to_num_args"] = minfo.opname_to_num_args;
|
||||
return result;
|
||||
});
|
||||
|
||||
return module;
|
||||
}
|
||||
|
@ -30,8 +30,6 @@ struct Code {
|
||||
// be done in parseMethods().
|
||||
std::vector<mobile::Function*> functions_;
|
||||
size_t register_size_ = 0; // Aggregated output size.
|
||||
// initialized means operators_ array is filled with operators
|
||||
bool initialized = false;
|
||||
};
|
||||
|
||||
} // namespace mobile
|
||||
|
@ -240,6 +240,8 @@ std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
|
||||
function->append_constant(getIValue(i));
|
||||
}
|
||||
|
||||
std::unordered_set<std::string> unsupported_op_names;
|
||||
|
||||
appendUpgraderFunctions(function.get());
|
||||
// 2. Decides if upgrader is needed
|
||||
const uint32_t operator_version = module_->operator_version();
|
||||
@ -252,13 +254,19 @@ std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
|
||||
num_args = op->num_args_serialized();
|
||||
}
|
||||
|
||||
function->append_operator(
|
||||
auto op_found = function->append_operator(
|
||||
op->name()->str(), op->overload_name()->str(), num_args);
|
||||
|
||||
if (!op_found) {
|
||||
unsupported_op_names.emplace(
|
||||
op->name()->str() + "/" + op->overload_name()->str());
|
||||
}
|
||||
}
|
||||
|
||||
if (should_load_operators_) {
|
||||
function->initialize_operators(true);
|
||||
}
|
||||
TORCH_CHECK(
|
||||
unsupported_op_names.empty(),
|
||||
"Unsupported ops: ",
|
||||
c10::Join(", ", unsupported_op_names));
|
||||
|
||||
for (const auto i : *method->type_annotations()) {
|
||||
function->append_type(getOrCreateTypeAnnotations(i));
|
||||
@ -717,13 +725,5 @@ uint64_t get_bytecode_version(const std::string& filename) {
|
||||
return flatbuffer_module->bytecode_version();
|
||||
}
|
||||
|
||||
mobile::ModuleInfo get_module_info_from_flatbuffer(char* flatbuffer_content) {
|
||||
auto* ff_module = mobile::serialization::GetMutableModule(flatbuffer_content);
|
||||
FlatbufferLoader loader;
|
||||
loader.setShouldLoadOperators(false);
|
||||
mobile::Module m = loader.parseModule(ff_module);
|
||||
return mobile::get_module_info(m);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -30,11 +30,6 @@ using ExtraFilesMap = std::unordered_map<std::string, std::string>;
|
||||
// Parse a mobile::Module from flatbuffer's in-memory Module representation.
|
||||
// The caller is assumed to manage the lifetimes of Module.
|
||||
// This function does step 3 described above.
|
||||
// If should_copy_tensor_memory is true, then the returned module will NOT
|
||||
// have refences to flatbuffer_module, so it can be discarded.
|
||||
// If should_copy_tensor_memory is false, then returned module will have
|
||||
// tensors that points inside of flatbuffer_module; the caller need to make
|
||||
// sure that flatbuffer_module outlives returned Module.
|
||||
TORCH_API mobile::Module initialize_mobile_module(
|
||||
mobile::serialization::Module* flatbuffer_module,
|
||||
c10::optional<at::Device> device = c10::nullopt,
|
||||
@ -71,9 +66,6 @@ TORCH_API std::tuple<std::shared_ptr<char>, size_t> get_stream_content(
|
||||
TORCH_API uint64_t get_bytecode_version(std::istream& in);
|
||||
TORCH_API uint64_t get_bytecode_version(const std::string& filename);
|
||||
|
||||
TORCH_API mobile::ModuleInfo get_module_info_from_flatbuffer(
|
||||
char* flatbuffer_content);
|
||||
|
||||
class TORCH_API FlatbufferLoader {
|
||||
public:
|
||||
FlatbufferLoader();
|
||||
@ -126,14 +118,6 @@ class TORCH_API FlatbufferLoader {
|
||||
should_copy_tensor_memory_ = should_copy_tensor_memory;
|
||||
}
|
||||
|
||||
// Whether or not should load operators in functions.
|
||||
// Not loading operators is useful because if an operator is not found
|
||||
// then we throw exceptions, and sometimes we want to print out
|
||||
// what operators are included before that to debug.
|
||||
void setShouldLoadOperators(bool should_load_operators) {
|
||||
should_load_operators_ = should_load_operators;
|
||||
}
|
||||
|
||||
std::shared_ptr<mobile::CompilationUnit> mcu_;
|
||||
std::shared_ptr<CompilationUnit> cu_;
|
||||
|
||||
@ -157,7 +141,6 @@ class TORCH_API FlatbufferLoader {
|
||||
mobile::serialization::Module* module_ = nullptr;
|
||||
bool module_parsed_ = false;
|
||||
bool should_copy_tensor_memory_ = false;
|
||||
bool should_load_operators_ = true;
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
|
@ -4,7 +4,6 @@
|
||||
#include <torch/csrc/jit/mobile/parse_bytecode.h>
|
||||
#include <torch/csrc/jit/mobile/parse_operators.h>
|
||||
#include <torch/csrc/jit/mobile/prim_ops_registery.h>
|
||||
#include <torch/csrc/jit/mobile/type_parser.h>
|
||||
#include <torch/csrc/jit/runtime/instruction.h>
|
||||
#include <torch/csrc/jit/runtime/operator.h>
|
||||
|
||||
@ -50,48 +49,14 @@ bool Function::append_operator(
|
||||
const c10::optional<int>& num_specified_args) {
|
||||
// Keep the original opname in code_
|
||||
code_.op_names_.emplace_back(name, overload_name);
|
||||
const auto& opname = code_.op_names_.back();
|
||||
code_.operator_input_sizes_.emplace_back(num_specified_args.value_or(-1));
|
||||
return true;
|
||||
}
|
||||
|
||||
void print_unsupported_ops_and_throw(
|
||||
const std::unordered_set<std::string>& unsupported_ops) {}
|
||||
|
||||
std::string operator_str(const c10::OperatorName& opname) {
|
||||
std::string result = opname.name;
|
||||
if (!opname.overload_name.empty()) {
|
||||
result += "." + opname.overload_name;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool Function::initialize_operators(bool should_check_operators) {
|
||||
if (code_.initialized) {
|
||||
return true;
|
||||
}
|
||||
std::unordered_set<std::string> unsupported_op_names;
|
||||
code_.operators_.clear();
|
||||
bool all_ops_supported = true;
|
||||
for (int i = 0; i < code_.op_names_.size(); i++) {
|
||||
const auto& opname = code_.op_names_[i];
|
||||
int num_args = code_.operator_input_sizes_[i];
|
||||
c10::optional<int> num_specified_args =
|
||||
num_args < 0 ? c10::nullopt : c10::optional<int>(num_args);
|
||||
auto func = makeOperatorFunction(opname, num_specified_args);
|
||||
if (!func.has_value()) {
|
||||
unsupported_op_names.insert(operator_str(opname));
|
||||
all_ops_supported = false;
|
||||
return false;
|
||||
}
|
||||
code_.operators_.emplace_back(*func);
|
||||
}
|
||||
if (should_check_operators) {
|
||||
TORCH_CHECK(
|
||||
unsupported_op_names.empty(),
|
||||
"Following ops cannot be found. Please check if the operator library is included in the build. If built with selected ops, check if these ops are in the list. If you are a Meta employee, please see fburl.com/missing_ops for a fix. Or post it in https://discuss.pytorch.org/",
|
||||
c10::Join(", ", unsupported_op_names));
|
||||
}
|
||||
code_.initialized = all_ops_supported;
|
||||
return all_ops_supported;
|
||||
return true;
|
||||
}
|
||||
|
||||
void Function::append_constant(const c10::IValue& constant) {
|
||||
@ -131,7 +96,6 @@ const c10::FunctionSchema& Function::getSchema() const {
|
||||
}
|
||||
|
||||
void Function::run(Stack& stack) {
|
||||
initialize_operators(/* should_check_operators */ true);
|
||||
if (hasSchema()) { // if we have a schema then resolve optional args if any
|
||||
getSchema().checkAndNormalizeInputs<c10::DynamicType>(
|
||||
stack, std::unordered_map<std::string, IValue>{} /*kwargs*/);
|
||||
|
@ -63,12 +63,6 @@ class TORCH_API Function : public torch::jit::Function {
|
||||
const std::vector<c10::TypePtr>& types,
|
||||
const size_t register_size);
|
||||
|
||||
// if not initialize, initialize by loading operators.
|
||||
// return true of all op loaded, return false if some op is not found
|
||||
// in the current runtime. Then, the ops that did not found will be filled
|
||||
// in unsupported_op_names
|
||||
bool initialize_operators(bool should_check_operators);
|
||||
|
||||
private:
|
||||
c10::QualifiedName name_;
|
||||
Code code_;
|
||||
@ -79,8 +73,6 @@ c10::optional<std::function<void(Stack&)>> makeOperatorFunction(
|
||||
c10::OperatorName opname,
|
||||
c10::optional<int> num_specified_args);
|
||||
|
||||
TORCH_API std::string operator_str(const c10::OperatorName& opname);
|
||||
|
||||
} // namespace mobile
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -3,7 +3,6 @@
|
||||
#include <torch/csrc/jit/backends/backend_exception.h>
|
||||
#include <torch/csrc/jit/mobile/interpreter.h>
|
||||
#include <torch/csrc/jit/mobile/observer.h>
|
||||
#include <torch/csrc/jit/mobile/type_parser.h>
|
||||
#include <torch/csrc/jit/runtime/jit_exception.h>
|
||||
#include <exception>
|
||||
|
||||
@ -264,40 +263,6 @@ c10::IValue Method::operator()(std::vector<c10::IValue> stack) const {
|
||||
return stack.front();
|
||||
}
|
||||
|
||||
c10::optional<std::string> print_type(const c10::Type& t) {
|
||||
auto namedType = t.cast<c10::NamedType>();
|
||||
if (namedType && namedType->name()) {
|
||||
return namedType->name().value().qualifiedName();
|
||||
}
|
||||
if (auto dyn = t.castRaw<c10::DynamicType>()) {
|
||||
return dyn->fallback()->annotation_str();
|
||||
}
|
||||
return c10::nullopt;
|
||||
}
|
||||
|
||||
TORCH_API ModuleInfo get_module_info(const mobile::Module& module) {
|
||||
ModuleInfo minfo;
|
||||
minfo.operator_version = module.min_operator_version();
|
||||
minfo.bytecode_version = module.bytecode_version();
|
||||
std::vector<std::string> type_name_list;
|
||||
for (const auto& func_ptr : module.compilation_unit().methods()) {
|
||||
const auto& function = *func_ptr;
|
||||
for (int i = 0; i < function.get_code().op_names_.size(); i++) {
|
||||
const auto& op = function.get_code().op_names_[i];
|
||||
minfo.opname_to_num_args[mobile::operator_str(op)] =
|
||||
function.get_code().operator_input_sizes_[i];
|
||||
}
|
||||
for (const c10::TypePtr& tp : function.get_code().types_) {
|
||||
type_name_list.push_back(tp->annotation_str(print_type));
|
||||
}
|
||||
minfo.function_names.insert(function.qualname().qualifiedName());
|
||||
}
|
||||
c10::TypeParser parser(type_name_list);
|
||||
parser.parseList();
|
||||
minfo.type_names = parser.getContainedTypes();
|
||||
return minfo;
|
||||
}
|
||||
|
||||
} // namespace mobile
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -163,16 +163,6 @@ class TORCH_API Module {
|
||||
// Extra handle for the module to delete when itself is deleted
|
||||
std::shared_ptr<char> mem_to_delete_;
|
||||
};
|
||||
|
||||
struct TORCH_API ModuleInfo {
|
||||
uint64_t bytecode_version;
|
||||
uint64_t operator_version;
|
||||
std::unordered_map<std::string, int> opname_to_num_args;
|
||||
std::unordered_set<std::string> function_names;
|
||||
std::unordered_set<std::string> type_names;
|
||||
};
|
||||
TORCH_API ModuleInfo get_module_info(const mobile::Module& module);
|
||||
|
||||
} // namespace mobile
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -5,10 +5,27 @@ namespace torch {
|
||||
namespace jit {
|
||||
namespace mobile {
|
||||
|
||||
void parseOperators(
|
||||
std::string operator_str(
|
||||
const std::string& name,
|
||||
const std::string& overloadname) {
|
||||
std::string result = name;
|
||||
if (!overloadname.empty()) {
|
||||
result += "." + overloadname;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads operators by looking them up in the Dispatcher and returns
|
||||
* the set of operator names (with overload) that are not supported
|
||||
* by the current runtime.
|
||||
*/
|
||||
std::unordered_set<std::string> load_and_find_unsupported_operator_names(
|
||||
c10::ivalue::TupleElements&& ops_list,
|
||||
const uint64_t& module_load_options,
|
||||
mobile::Function* function) {
|
||||
std::unordered_set<std::string> unsupported_op_names;
|
||||
// ops_list is the list of operator names that were read in from
|
||||
// bytecode.plk for the method that is currently being processed.
|
||||
for (auto& op : std::move(ops_list)) {
|
||||
auto op_item = std::move(*std::move(op).toTuple()).elements();
|
||||
TORCH_CHECK(
|
||||
@ -20,13 +37,41 @@ void parseOperators(
|
||||
if (op_item.size() > 2) {
|
||||
num_args = op_item[2].toInt();
|
||||
}
|
||||
function->append_operator(
|
||||
auto op_found = function->append_operator(
|
||||
op_item[0].toString()->string(),
|
||||
op_item[1].toString()->string(),
|
||||
num_args);
|
||||
if (!op_found) {
|
||||
unsupported_op_names.emplace(operator_str(
|
||||
op_item[0].toString()->string(), op_item[1].toString()->string()));
|
||||
}
|
||||
}
|
||||
return unsupported_op_names;
|
||||
}
|
||||
|
||||
void print_unsupported_ops_and_throw(
|
||||
const std::unordered_set<std::string>& unsupported_ops) {
|
||||
std::string error_message("{");
|
||||
for (const auto& op_name : unsupported_ops) {
|
||||
error_message += op_name + ", ";
|
||||
}
|
||||
error_message += "}";
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Following ops cannot be found. Please check if the operator library is included in the build. If built with selected ops, check if these ops are in the list. If you are a Meta employee, please see fburl.com/missing_ops for a fix. Or post it in https://discuss.pytorch.org/",
|
||||
error_message);
|
||||
}
|
||||
|
||||
void parseOperators(
|
||||
c10::ivalue::TupleElements&& ops_list,
|
||||
const uint64_t& module_load_options,
|
||||
mobile::Function* function) {
|
||||
std::unordered_set<std::string> unsupported_op_names =
|
||||
load_and_find_unsupported_operator_names(std::move(ops_list), function);
|
||||
if ((module_load_options & MobileModuleLoadOptions::OPERATOR_CHECK) &&
|
||||
!unsupported_op_names.empty()) {
|
||||
print_unsupported_ops_and_throw(unsupported_op_names);
|
||||
}
|
||||
function->initialize_operators(
|
||||
(module_load_options & MobileModuleLoadOptions::OPERATOR_CHECK));
|
||||
}
|
||||
|
||||
} // namespace mobile
|
||||
|
@ -535,7 +535,6 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
|
||||
op.overload_name,
|
||||
op.num_specified_args);
|
||||
}
|
||||
upgrader_function.function.initialize_operators(true);
|
||||
}
|
||||
return upgrader_function_list;
|
||||
};
|
||||
|
@ -49,12 +49,7 @@ from torch.jit._trace import (
|
||||
)
|
||||
from torch.jit._async import fork, wait
|
||||
from torch.jit._decomposition_utils import _register_decomposition
|
||||
from torch.jit._serialization import (
|
||||
save,
|
||||
load,
|
||||
jit_module_from_flatbuffer,
|
||||
save_jit_module_to_flatbuffer,
|
||||
)
|
||||
from torch.jit._serialization import save, load, jit_module_from_flatbuffer, save_jit_module_to_flatbuffer
|
||||
from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph, set_fusion_strategy
|
||||
from torch.jit._freeze import freeze, optimize_for_inference, run_frozen_optimizations
|
||||
from torch.jit._ir_utils import _InsertPoint
|
||||
|
@ -184,17 +184,13 @@ def validate_map_location(map_location=None):
|
||||
return map_location
|
||||
|
||||
|
||||
def get_ff_module():
|
||||
def jit_module_from_flatbuffer(f):
|
||||
try:
|
||||
import torch._C_flatbuffer as ff
|
||||
return ff
|
||||
except ImportError:
|
||||
print("Please include //caffe2:_C_flatbuffer as dependency.")
|
||||
raise
|
||||
|
||||
|
||||
def jit_module_from_flatbuffer(f):
|
||||
ff = get_ff_module()
|
||||
if isinstance(f, string_classes):
|
||||
if not os.path.exists(f): # type: ignore[type-var]
|
||||
raise ValueError("The provided filename {} does not exist".format(f)) # type: ignore[str-bytes-safe]
|
||||
@ -246,40 +242,14 @@ def save_jit_module_to_flatbuffer(m, f):
|
||||
# Save to file
|
||||
torch.jit.save_jit_module_to_flatbuffer(m, 'scriptmodule.ff')
|
||||
"""
|
||||
ff = get_ff_module()
|
||||
try:
|
||||
import torch._C_flatbuffer as ff
|
||||
except ImportError:
|
||||
print("Please include //caffe2:_C_flatbuffer as dependency.")
|
||||
raise
|
||||
if isinstance(f, str) or isinstance(f, pathlib.Path):
|
||||
f = str(f)
|
||||
ff._save_jit_module(m._c, f)
|
||||
else:
|
||||
s = ff._save_jit_module_to_bytes(m._c)
|
||||
f.write(s)
|
||||
|
||||
|
||||
def get_flatbuffer_module_info(path_or_file):
|
||||
r"""Get some information regarding a model file in flatbuffer format.
|
||||
|
||||
|
||||
Args:
|
||||
path_or_file: Either str, Path or file like object (BytesIO OK).
|
||||
If it's str or Path, we will read the file referenced by that
|
||||
path as Bytes.
|
||||
|
||||
Returns:
|
||||
A dict with metadata on what that file contains, currently looks like
|
||||
this:
|
||||
{
|
||||
'bytecode_version': 4, # int
|
||||
'operator_version': 4, # int
|
||||
'function_names': {
|
||||
'__torch__.___torch_mangle_0.Foo.forward'}, # set
|
||||
'type_names': set(), # set
|
||||
'opname_to_num_args': {'aten::linear': 3} # Dict[str, int]
|
||||
}
|
||||
"""
|
||||
ff = get_ff_module()
|
||||
if isinstance(path_or_file, str) or isinstance(path_or_file, pathlib.Path):
|
||||
with open(path_or_file, 'rb') as f:
|
||||
all_bytes = f.read()
|
||||
else:
|
||||
all_bytes = path_or_file.read()
|
||||
return ff._get_module_info_from_flatbuffer(all_bytes)
|
||||
|
Reference in New Issue
Block a user