Revert "Python function to extract information on mobile::Module from flatbuffer (#77328)"

This reverts commit 69fa49f1230f80d1a0667e0a6ac8aca2746431b6.

Reverted https://github.com/pytorch/pytorch/pull/77328 on behalf of https://github.com/atalman
PyTorch MergeBot
2022-05-17 01:35:05 +00:00
parent efcbbb177e
commit 5e3e5a5403
13 changed files with 75 additions and 215 deletions

View File

@@ -918,34 +918,6 @@ class TestSaveLoadFlatbuffer(JitTestCase):
output = m_loaded()
self.assertEqual(output, None)
def test_module_info_flatbuffer(self):
class Foo(torch.nn.Module):
def __init__(self):
super(Foo, self).__init__()
self.foo = torch.nn.Linear(2, 2)
self.bar = torch.nn.Linear(2, 2)
def forward(self, x):
x = self.foo(x)
x = self.bar(x)
return x
first_script_module = torch.jit.script(Foo())
first_saved_module = io.BytesIO()
torch.jit.save_jit_module_to_flatbuffer(
first_script_module, first_saved_module)
first_saved_module.seek(0)
expected = {
'bytecode_version': 4,
'operator_version': 4,
'function_names': {'__torch__.___torch_mangle_0.Foo.forward'},
'type_names': set(),
'opname_to_num_args': {'aten::linear': 3}}
self.assertEqual(
torch.jit._serialization.get_flatbuffer_module_info(first_saved_module),
expected)
def test_save_load_params_buffers_submodules(self):
"""
Check that parameters, buffers, and submodules are the same after loading.

View File

@@ -99,18 +99,5 @@ extern "C"
reinterpret_cast<char*>(detached_buffer.data()),
detached_buffer.size());
});
pym.def("_get_module_info_from_flatbuffer", [](std::string flatbuffer_content) {
py::gil_scoped_acquire acquire;
py::dict result;
mobile::ModuleInfo minfo = torch::jit::get_module_info_from_flatbuffer(
&flatbuffer_content[0]);
result["bytecode_version"] = minfo.bytecode_version;
result["operator_version"] = minfo.operator_version;
result["function_names"] = minfo.function_names;
result["type_names"] = minfo.type_names;
result["opname_to_num_args"] = minfo.opname_to_num_args;
return result;
});
return module;
}

View File

@@ -30,8 +30,6 @@ struct Code {
// be done in parseMethods().
std::vector<mobile::Function*> functions_;
size_t register_size_ = 0; // Aggregated output size.
// initialized means operators_ array is filled with operators
bool initialized = false;
};
} // namespace mobile

View File

@@ -240,6 +240,8 @@ std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
function->append_constant(getIValue(i));
}
std::unordered_set<std::string> unsupported_op_names;
appendUpgraderFunctions(function.get());
// 2. Decides if upgrader is needed
const uint32_t operator_version = module_->operator_version();
@@ -252,13 +254,19 @@ std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
num_args = op->num_args_serialized();
}
function->append_operator(
auto op_found = function->append_operator(
op->name()->str(), op->overload_name()->str(), num_args);
if (!op_found) {
unsupported_op_names.emplace(
op->name()->str() + "/" + op->overload_name()->str());
}
}
if (should_load_operators_) {
function->initialize_operators(true);
}
TORCH_CHECK(
unsupported_op_names.empty(),
"Unsupported ops: ",
c10::Join(", ", unsupported_op_names));
for (const auto i : *method->type_annotations()) {
function->append_type(getOrCreateTypeAnnotations(i));
@@ -717,13 +725,5 @@ uint64_t get_bytecode_version(const std::string& filename) {
return flatbuffer_module->bytecode_version();
}
mobile::ModuleInfo get_module_info_from_flatbuffer(char* flatbuffer_content) {
auto* ff_module = mobile::serialization::GetMutableModule(flatbuffer_content);
FlatbufferLoader loader;
loader.setShouldLoadOperators(false);
mobile::Module m = loader.parseModule(ff_module);
return mobile::get_module_info(m);
}
} // namespace jit
} // namespace torch

View File

@@ -30,11 +30,6 @@ using ExtraFilesMap = std::unordered_map<std::string, std::string>;
// Parse a mobile::Module from flatbuffer's in-memory Module representation.
// The caller is assumed to manage the lifetime of the Module.
// This function does step 3 described above.
// If should_copy_tensor_memory is true, then the returned module will NOT
// have references to flatbuffer_module, so it can be discarded.
// If should_copy_tensor_memory is false, then the returned module will have
// tensors that point inside flatbuffer_module; the caller needs to make
// sure that flatbuffer_module outlives the returned Module.
TORCH_API mobile::Module initialize_mobile_module(
mobile::serialization::Module* flatbuffer_module,
c10::optional<at::Device> device = c10::nullopt,
@@ -71,9 +66,6 @@ TORCH_API std::tuple<std::shared_ptr<char>, size_t> get_stream_content(
TORCH_API uint64_t get_bytecode_version(std::istream& in);
TORCH_API uint64_t get_bytecode_version(const std::string& filename);
TORCH_API mobile::ModuleInfo get_module_info_from_flatbuffer(
char* flatbuffer_content);
class TORCH_API FlatbufferLoader {
public:
FlatbufferLoader();
@@ -126,14 +118,6 @@ class TORCH_API FlatbufferLoader {
should_copy_tensor_memory_ = should_copy_tensor_memory;
}
// Whether or not operators in functions should be loaded.
// Not loading operators is useful because if an operator is not found
// we throw an exception, and sometimes we want to print out
// which operators are included before that for debugging.
void setShouldLoadOperators(bool should_load_operators) {
should_load_operators_ = should_load_operators;
}
std::shared_ptr<mobile::CompilationUnit> mcu_;
std::shared_ptr<CompilationUnit> cu_;
@@ -157,7 +141,6 @@ class TORCH_API FlatbufferLoader {
mobile::serialization::Module* module_ = nullptr;
bool module_parsed_ = false;
bool should_copy_tensor_memory_ = false;
bool should_load_operators_ = true;
};
} // namespace jit

View File

@@ -4,7 +4,6 @@
#include <torch/csrc/jit/mobile/parse_bytecode.h>
#include <torch/csrc/jit/mobile/parse_operators.h>
#include <torch/csrc/jit/mobile/prim_ops_registery.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/runtime/operator.h>
@@ -50,48 +49,14 @@ bool Function::append_operator(
const c10::optional<int>& num_specified_args) {
// Keep the original opname in code_
code_.op_names_.emplace_back(name, overload_name);
const auto& opname = code_.op_names_.back();
code_.operator_input_sizes_.emplace_back(num_specified_args.value_or(-1));
return true;
}
void print_unsupported_ops_and_throw(
const std::unordered_set<std::string>& unsupported_ops) {}
std::string operator_str(const c10::OperatorName& opname) {
std::string result = opname.name;
if (!opname.overload_name.empty()) {
result += "." + opname.overload_name;
}
return result;
}
bool Function::initialize_operators(bool should_check_operators) {
if (code_.initialized) {
return true;
}
std::unordered_set<std::string> unsupported_op_names;
code_.operators_.clear();
bool all_ops_supported = true;
for (int i = 0; i < code_.op_names_.size(); i++) {
const auto& opname = code_.op_names_[i];
int num_args = code_.operator_input_sizes_[i];
c10::optional<int> num_specified_args =
num_args < 0 ? c10::nullopt : c10::optional<int>(num_args);
auto func = makeOperatorFunction(opname, num_specified_args);
if (!func.has_value()) {
unsupported_op_names.insert(operator_str(opname));
all_ops_supported = false;
return false;
}
code_.operators_.emplace_back(*func);
}
if (should_check_operators) {
TORCH_CHECK(
unsupported_op_names.empty(),
"Following ops cannot be found. Please check if the operator library is included in the build. If built with selected ops, check if these ops are in the list. If you are a Meta employee, please see fburl.com/missing_ops for a fix. Or post it in https://discuss.pytorch.org/",
c10::Join(", ", unsupported_op_names));
}
code_.initialized = all_ops_supported;
return all_ops_supported;
return true;
}
void Function::append_constant(const c10::IValue& constant) {
@@ -131,7 +96,6 @@ const c10::FunctionSchema& Function::getSchema() const {
}
void Function::run(Stack& stack) {
initialize_operators(/* should_check_operators */ true);
if (hasSchema()) { // if we have a schema then resolve optional args if any
getSchema().checkAndNormalizeInputs<c10::DynamicType>(
stack, std::unordered_map<std::string, IValue>{} /*kwargs*/);
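
For context, the deleted initialize_operators defers operator resolution until first use, guarded by the Code::initialized flag. A minimal Python sketch of that deferred-resolution pattern; REGISTRY and all names here are illustrative stand-ins, not PyTorch APIs:

# Toy operator registry standing in for makeOperatorFunction lookups.
REGISTRY = {"aten::linear": lambda *args: None}

class LazyCode:
    def __init__(self, op_names):
        self.op_names = op_names
        self.operators = []
        self.initialized = False  # mirrors the removed Code::initialized flag

    def initialize_operators(self, should_check_operators):
        if self.initialized:  # resolve at most once
            return True
        unsupported = set()
        self.operators = []
        for name in self.op_names:
            fn = REGISTRY.get(name)  # analogous to makeOperatorFunction
            if fn is None:
                unsupported.add(name)
            else:
                self.operators.append(fn)
        if should_check_operators and unsupported:
            raise RuntimeError("Following ops cannot be found: " + ", ".join(sorted(unsupported)))
        self.initialized = not unsupported
        return self.initialized

LazyCode(["aten::linear"]).initialize_operators(True)  # -> True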

View File

@@ -63,12 +63,6 @@ class TORCH_API Function : public torch::jit::Function {
const std::vector<c10::TypePtr>& types,
const size_t register_size);
// If not initialized, initialize by loading operators.
// Returns true if all ops loaded; returns false if some op is not found
// in the current runtime. Ops that were not found are filled
// into unsupported_op_names.
bool initialize_operators(bool should_check_operators);
private:
c10::QualifiedName name_;
Code code_;
@@ -79,8 +73,6 @@ c10::optional<std::function<void(Stack&)>> makeOperatorFunction(
c10::OperatorName opname,
c10::optional<int> num_specified_args);
TORCH_API std::string operator_str(const c10::OperatorName& opname);
} // namespace mobile
} // namespace jit
} // namespace torch

View File

@@ -3,7 +3,6 @@
#include <torch/csrc/jit/backends/backend_exception.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/observer.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <exception>
@@ -264,40 +263,6 @@ c10::IValue Method::operator()(std::vector<c10::IValue> stack) const {
return stack.front();
}
c10::optional<std::string> print_type(const c10::Type& t) {
auto namedType = t.cast<c10::NamedType>();
if (namedType && namedType->name()) {
return namedType->name().value().qualifiedName();
}
if (auto dyn = t.castRaw<c10::DynamicType>()) {
return dyn->fallback()->annotation_str();
}
return c10::nullopt;
}
TORCH_API ModuleInfo get_module_info(const mobile::Module& module) {
ModuleInfo minfo;
minfo.operator_version = module.min_operator_version();
minfo.bytecode_version = module.bytecode_version();
std::vector<std::string> type_name_list;
for (const auto& func_ptr : module.compilation_unit().methods()) {
const auto& function = *func_ptr;
for (int i = 0; i < function.get_code().op_names_.size(); i++) {
const auto& op = function.get_code().op_names_[i];
minfo.opname_to_num_args[mobile::operator_str(op)] =
function.get_code().operator_input_sizes_[i];
}
for (const c10::TypePtr& tp : function.get_code().types_) {
type_name_list.push_back(tp->annotation_str(print_type));
}
minfo.function_names.insert(function.qualname().qualifiedName());
}
c10::TypeParser parser(type_name_list);
parser.parseList();
minfo.type_names = parser.getContainedTypes();
return minfo;
}
} // namespace mobile
} // namespace jit
} // namespace torch
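
For reference, the deleted get_module_info walks each method, maps every operator name to its serialized argument count, and collects function and type names. A rough Python equivalent of that bookkeeping over a hypothetical, already-parsed method table:

# Hypothetical pre-parsed method table; shapes mirror the deleted C++ loop.
methods = {
    "__torch__.___torch_mangle_0.Foo.forward": {
        "ops": [("aten::linear", 3)],  # (op name, serialized num args)
        "types": [],                   # type annotation strings
    },
}

minfo = {
    "bytecode_version": 4,    # module.bytecode_version()
    "operator_version": 4,    # module.min_operator_version()
    "opname_to_num_args": {},
    "function_names": set(),
    "type_names": set(),
}
for qualname, method in methods.items():
    minfo["function_names"].add(qualname)
    for opname, num_args in method["ops"]:
        minfo["opname_to_num_args"][opname] = num_args
    minfo["type_names"].update(method["types"])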

View File

@@ -163,16 +163,6 @@ class TORCH_API Module {
// Extra handle for the module to delete when itself is deleted
std::shared_ptr<char> mem_to_delete_;
};
struct TORCH_API ModuleInfo {
uint64_t bytecode_version;
uint64_t operator_version;
std::unordered_map<std::string, int> opname_to_num_args;
std::unordered_set<std::string> function_names;
std::unordered_set<std::string> type_names;
};
TORCH_API ModuleInfo get_module_info(const mobile::Module& module);
} // namespace mobile
} // namespace jit
} // namespace torch

View File

@@ -5,10 +5,27 @@ namespace torch {
namespace jit {
namespace mobile {
void parseOperators(
std::string operator_str(
const std::string& name,
const std::string& overloadname) {
std::string result = name;
if (!overloadname.empty()) {
result += "." + overloadname;
}
return result;
}
/**
* Loads operators by looking them up in the Dispatcher and returns
* the set of operator names (with overload) that are not supported
* by the current runtime.
*/
std::unordered_set<std::string> load_and_find_unsupported_operator_names(
c10::ivalue::TupleElements&& ops_list,
const uint64_t& module_load_options,
mobile::Function* function) {
std::unordered_set<std::string> unsupported_op_names;
// ops_list is the list of operator names that were read in from
// bytecode.pkl for the method that is currently being processed.
for (auto& op : std::move(ops_list)) {
auto op_item = std::move(*std::move(op).toTuple()).elements();
TORCH_CHECK(
@@ -20,13 +37,41 @@ void parseOperators(
if (op_item.size() > 2) {
num_args = op_item[2].toInt();
}
function->append_operator(
auto op_found = function->append_operator(
op_item[0].toString()->string(),
op_item[1].toString()->string(),
num_args);
if (!op_found) {
unsupported_op_names.emplace(operator_str(
op_item[0].toString()->string(), op_item[1].toString()->string()));
}
}
return unsupported_op_names;
}
void print_unsupported_ops_and_throw(
const std::unordered_set<std::string>& unsupported_ops) {
std::string error_message("{");
for (const auto& op_name : unsupported_ops) {
error_message += op_name + ", ";
}
error_message += "}";
TORCH_CHECK(
false,
"Following ops cannot be found. Please check if the operator library is included in the build. If built with selected ops, check if these ops are in the list. If you are a Meta employee, please see fburl.com/missing_ops for a fix. Or post it in https://discuss.pytorch.org/",
error_message);
}
void parseOperators(
c10::ivalue::TupleElements&& ops_list,
const uint64_t& module_load_options,
mobile::Function* function) {
std::unordered_set<std::string> unsupported_op_names =
load_and_find_unsupported_operator_names(std::move(ops_list), function);
if ((module_load_options & MobileModuleLoadOptions::OPERATOR_CHECK) &&
!unsupported_op_names.empty()) {
print_unsupported_ops_and_throw(unsupported_op_names);
}
function->initialize_operators(
(module_load_options & MobileModuleLoadOptions::OPERATOR_CHECK));
}
} // namespace mobile
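
The restored parseOperators splits the work in two: load_and_find_unsupported_operator_names records every miss while appending operators, and the OPERATOR_CHECK load option decides whether to throw. A minimal Python sketch of that flow, assuming a toy registry and flag value:

# Stand-in for MobileModuleLoadOptions::OPERATOR_CHECK; the real value lives in C++.
OPERATOR_CHECK = 1

def operator_str(name, overload_name):
    # Mirrors the C++ helper above: "name.overload" when an overload exists.
    return name + "." + overload_name if overload_name else name

def parse_operators(ops_list, module_load_options, registry):
    # registry maps (name, overload) to a callable; a stand-in for the dispatcher.
    unsupported = set()
    for name, overload, num_args in ops_list:
        if (name, overload) not in registry:  # append_operator returning false
            unsupported.add(operator_str(name, overload))
    if (module_load_options & OPERATOR_CHECK) and unsupported:
        raise RuntimeError("Following ops cannot be found: {" + ", ".join(sorted(unsupported)) + "}")

parse_operators([("aten::linear", "", 3)], OPERATOR_CHECK, {("aten::linear", ""): print})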

View File

@@ -535,7 +535,6 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
op.overload_name,
op.num_specified_args);
}
upgrader_function.function.initialize_operators(true);
}
return upgrader_function_list;
};

View File

@@ -49,12 +49,7 @@ from torch.jit._trace import (
)
from torch.jit._async import fork, wait
from torch.jit._decomposition_utils import _register_decomposition
from torch.jit._serialization import (
save,
load,
jit_module_from_flatbuffer,
save_jit_module_to_flatbuffer,
)
from torch.jit._serialization import save, load, jit_module_from_flatbuffer, save_jit_module_to_flatbuffer
from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph, set_fusion_strategy
from torch.jit._freeze import freeze, optimize_for_inference, run_frozen_optimizations
from torch.jit._ir_utils import _InsertPoint

View File

@@ -184,17 +184,13 @@ def validate_map_location(map_location=None):
return map_location
def get_ff_module():
def jit_module_from_flatbuffer(f):
try:
import torch._C_flatbuffer as ff
return ff
except ImportError:
print("Please include //caffe2:_C_flatbuffer as dependency.")
raise
def jit_module_from_flatbuffer(f):
ff = get_ff_module()
if isinstance(f, string_classes):
if not os.path.exists(f): # type: ignore[type-var]
raise ValueError("The provided filename {} does not exist".format(f)) # type: ignore[str-bytes-safe]
@@ -246,40 +242,14 @@ def save_jit_module_to_flatbuffer(m, f):
# Save to file
torch.jit.save_jit_module_to_flatbuffer(m, 'scriptmodule.ff')
"""
ff = get_ff_module()
try:
import torch._C_flatbuffer as ff
except ImportError:
print("Please include //caffe2:_C_flatbuffer as dependency.")
raise
if isinstance(f, str) or isinstance(f, pathlib.Path):
f = str(f)
ff._save_jit_module(m._c, f)
else:
s = ff._save_jit_module_to_bytes(m._c)
f.write(s)
def get_flatbuffer_module_info(path_or_file):
r"""Get some information regarding a model file in flatbuffer format.
Args:
path_or_file: Either str, Path or file like object (BytesIO OK).
If it's str or Path, we will read the file referenced by that
path as Bytes.
Returns:
A dict with metadata on what that file contains, currently looks like
this:
{
'bytecode_version': 4, # int
'operator_version': 4, # int
'function_names': {
'__torch__.___torch_mangle_0.Foo.forward'}, # set
'type_names': set(), # set
'opname_to_num_args': {'aten::linear': 3} # Dict[str, int]
}
"""
ff = get_ff_module()
if isinstance(path_or_file, str) or isinstance(path_or_file, pathlib.Path):
with open(path_or_file, 'rb') as f:
all_bytes = f.read()
else:
all_bytes = path_or_file.read()
return ff._get_module_info_from_flatbuffer(all_bytes)
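
For reference, the docstring above describes the dict that the removed helper returns. A usage sketch mirroring the deleted test_module_info_flatbuffer; this runs only on builds that still include #77328, i.e., before this revert:

import io
import torch

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.foo = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.foo(x)

m = torch.jit.script(Foo())
buf = io.BytesIO()
torch.jit.save_jit_module_to_flatbuffer(m, buf)
buf.seek(0)
# Removed by this revert; present only on builds that include #77328.
info = torch.jit._serialization.get_flatbuffer_module_info(buf)
print(info["bytecode_version"], info["opname_to_num_args"])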