Reland D33284352: [jit][edge] Do not reuse mobile type parser for all unpicklers. (#71048)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71048

reland D33284352 (0a921ba0d0)
ghstack-source-id: 146735646

Test Plan: All Github CI: ciflow rerun -l ciflow/all

Reviewed By: gmagogsfm

Differential Revision: D33489731

fbshipit-source-id: 3e160209a1abb193ad3eed3018054aa7d331025e
Author: Zhengxu Chen
Date: 2022-01-10 12:38:33 -08:00
Committed by: Facebook GitHub Bot
parent fb66f561b1
commit 30699cbfd5
21 changed files with 87 additions and 54 deletions
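
Context for the diff below: `Unpickler` previously hard-coded the mobile `c10::parseType` as its fallback for resolving type strings. This change threads a `type_parser` function pointer through `torch::jit::unpickle`, `readArchiveAndTensors`, and the `Unpickler` constructors; the default is a new `Unpickler::defaultTypeParser` built on `ScriptTypeParser`, and the mobile/edge call sites opt back into `c10::parseType` explicitly. A minimal sketch of the resulting call pattern (the `data`/`size` buffer is illustrative, standing in for a pickled record read from an archive):

#include <torch/csrc/jit/mobile/type_parser.h> // declares c10::parseType
#include <torch/csrc/jit/serialization/pickle.h> // declares torch::jit::unpickle

void demo(const char* data, size_t size) {
  // Full-JIT path: defaults apply, type strings go through ScriptTypeParser.
  c10::IValue full = torch::jit::unpickle(data, size);

  // Mobile/edge path: opt into the lightweight mobile type parser explicitly.
  c10::IValue lite = torch::jit::unpickle(
      data,
      size,
      /*type_resolver=*/nullptr,
      /*tensor_table=*/{},
      c10::parseType);
}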

View File

@@ -2,11 +2,7 @@
 #include <test/cpp/jit/test_utils.h>

 #include <ATen/core/jit_type.h>
-
-namespace c10 {
-TypePtr parseType(const std::string& pythonStr);
-std::vector<TypePtr> parseType(std::vector<std::string>& pythonStr);
-} // namespace c10
+#include <torch/csrc/jit/mobile/type_parser.h>

 namespace torch {
 namespace jit {

View File

@@ -5,13 +5,11 @@
  * cd ~/pytorch && python torch/csrc/jit/mobile/upgrader_mobile.cpp
  */

-#include <caffe2/serialize/versions.h>
 #include <torch/csrc/jit/mobile/upgrader_mobile.h>

-#include <ATen/core/ivalue.h>
-namespace c10 {
-TypePtr parseType(const std::string& pythonStr);
-} // namespace c10
+#include <ATen/core/ivalue.h>
+#include <caffe2/serialize/versions.h>
+#include <torch/csrc/jit/mobile/type_parser.h>

 namespace torch {
 namespace jit {

View File

@@ -101,13 +101,11 @@ UPGRADER_CPP_SRC = CodeTemplate("""/**
  * cd ~/pytorch && python torch/csrc/jit/mobile/upgrader_mobile.cpp
  */

-#include <caffe2/serialize/versions.h>
 #include <torch/csrc/jit/mobile/upgrader_mobile.h>

-#include <ATen/core/ivalue.h>
-namespace c10 {
-TypePtr parseType(const std::string& pythonStr);
-} // namespace c10
+#include <ATen/core/ivalue.h>
+#include <caffe2/serialize/versions.h>
+#include <torch/csrc/jit/mobile/type_parser.h>

 namespace torch {
 namespace jit {

View File

@@ -29,8 +29,6 @@ struct Tree;
 using TreeRef = c10::intrusive_ptr<Tree>;
 using TreeList = at::SmallVector<TreeRef, 4>;

-static const TreeList empty_trees = {};
-
 struct Tree : c10::intrusive_ptr_target {
   Tree(int kind_) : kind_(kind_) {}
   int kind() const {
@@ -46,6 +44,7 @@ struct Tree : c10::intrusive_ptr_target {
     throw std::runtime_error("stringValue can only be called on TK_STRING");
   }
   virtual const TreeList& trees() const {
+    static const TreeList empty_trees = {};
     return empty_trees;
   }
   const TreeRef& tree(size_t i) const {
@@ -149,11 +148,11 @@ struct Compound : public Tree {
     return false;
   }
   TreeRef map(const std::function<TreeRef(TreeRef)>& fn) override {
-    TreeList trees_;
+    TreeList ret;
     for (auto& t : trees()) {
-      trees_.push_back(fn(t));
+      ret.push_back(fn(t));
     }
-    return Compound::create(kind(), range(), std::move(trees_));
+    return Compound::create(kind(), range(), std::move(ret));
   }

   const SourceRange& range() const override {
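
The `empty_trees` hunk above is an incidental cleanup: a namespace-scope static is replaced with a function-local static, which is guaranteed to be initialized on first use and so avoids static-initialization-order hazards across translation units. A generic sketch of the idiom (illustrative, not PyTorch code):

#include <vector>

// A function-local static is constructed the first time control passes
// through its declaration, so callers can never observe it uninitialized.
const std::vector<int>& emptyList() {
  static const std::vector<int> instance{};
  return instance;
}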

View File

@@ -1,5 +1,6 @@
 #include <torch/csrc/jit/frontend/source_range.h>
 #include <torch/csrc/jit/mobile/debug_info.h>
+#include <torch/csrc/jit/mobile/type_parser.h>
 #include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
 #include <torch/csrc/jit/serialization/source_range_serialization.h>
@@ -122,10 +123,13 @@ MobileDebugTable::MobileDebugTable(
   size_t debug_size{0};
   std::tie(debug_data, debug_size) = reader->getRecord(record_name);
   auto ivalues =
-      std::move(
-          *jit::unpickle(
-               reinterpret_cast<const char*>(debug_data.get()), debug_size)
-               .toTuple())
+      std::move(*jit::unpickle(
+                     reinterpret_cast<const char*>(debug_data.get()),
+                     debug_size,
+                     nullptr,
+                     {},
+                     c10::parseType)
+                     .toTuple())
           .elements();
   SourceRangeDeserializer deserializer;
   for (auto& val : ivalues) {

View File

@@ -12,6 +12,7 @@
 #include <torch/csrc/jit/api/compilation_unit.h>
 #include <torch/csrc/jit/mobile/interpreter.h>
 #include <torch/csrc/jit/mobile/observer.h>
+#include <torch/csrc/jit/mobile/type_parser.h>
 #include <torch/csrc/jit/mobile/upgrader_mobile.h>
 #include <torch/csrc/jit/runtime/instruction.h>
 #include <torch/csrc/jit/serialization/import_export_constants.h>
@@ -78,11 +79,6 @@
 // - Argument::{known_length_,kwarg_only_}
 // - FunctionSchema::{overload_name_, is_vararg_, is_varret_}
-namespace c10 {
-// std::string serializeType(const Type &t);
-TypePtr parseType(const std::string& pythonStr);
-} // namespace c10
-
 namespace torch {
 namespace jit {

 using caffe2::serialize::IStreamAdapter;
@@ -502,7 +498,8 @@ c10::IValue BytecodeDeserializer::readArchive(
       type_resolver,
       obj_loader,
       device_,
-      *reader_.get());
+      *reader_.get(),
+      nullptr);
   return ivalues;
 }

View File

@@ -5,6 +5,7 @@
 #include <caffe2/serialize/inline_container.h>
 #include <torch/csrc/jit/api/compilation_unit.h>
 #include <torch/csrc/jit/mobile/observer.h>
+#include <torch/csrc/jit/mobile/type_parser.h>
 #include <torch/csrc/jit/runtime/instruction.h>
 #include <torch/csrc/jit/serialization/unpickler.h>
 #include <torch/custom_class.h>
@@ -14,11 +15,6 @@
 #include <string>
 #include <vector>
-namespace c10 {
-// std::string serializeType(const Type &t);
-TypePtr parseType(const std::string& pythonStr);
-} // namespace c10
-
 namespace torch {
 namespace jit {

 using caffe2::serialize::IStreamAdapter;
@@ -151,7 +147,9 @@ c10::IValue BytecodeDeserializer::readArchive(
       std::move(obj_loader),
       std::move(read_record),
       // NOLINTNEXTLINE(performance-move-const-arg)
-      std::move(device));
+      std::move(device),
+      false,
+      nullptr);
   return unpickler.parse_ivalue();
 }

View File

@@ -53,7 +53,8 @@ c10::IValue readArchive(
       type_resolver,
       obj_loader,
       device,
-      stream_reader);
+      stream_reader,
+      nullptr);
   return ivalues;
 }

View File

@@ -1,3 +1,5 @@
+#include <torch/csrc/jit/mobile/type_parser.h>
+
 #include <ATen/core/jit_type.h>
 #include <c10/util/string_view.h>
 #include <torch/csrc/jit/frontend/parser_constants.h>

View File

@@ -1,3 +1,5 @@
 #pragma once
+
+#include <ATen/core/dynamic_type.h>
 #include <ATen/core/jit_type.h>

View File

@@ -384,8 +384,8 @@ void listAdd(Stack& stack) {
 }

 void listInplaceAdd(Stack& stack) {
-  c10::List<IValue> b = pop(stack).to<List<IValue>>();
-  c10::List<IValue> a = pop(stack).to<List<IValue>>();
+  c10::List<IValue> b = pop(stack).to<c10::List<IValue>>();
+  c10::List<IValue> a = pop(stack).to<c10::List<IValue>>();
   a.append(std::move(b));
   push(stack, std::move(a));
 }

View File

@@ -973,7 +973,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
         TORCH_SELECTIVE_SCHEMA(
             "aten::index.Tensor_hacked_twin(Tensor self, Tensor[] indices) -> Tensor"),
         [](Stack& stack) {
-          auto indices = pop(stack).to<List<c10::optional<at::Tensor>>>();
+          auto indices = pop(stack).to<c10::List<c10::optional<at::Tensor>>>();
           auto self = pop(stack).toTensor();
           auto result = at::index(self, indices);
           push(stack, std::move(result));
@@ -986,7 +986,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
           auto unsafe = pop(stack).toBool();
           auto accumulate = pop(stack).toBool();
           auto values = pop(stack).toTensor();
-          auto indices = pop(stack).to<List<c10::optional<at::Tensor>>>();
+          auto indices = pop(stack).to<c10::List<c10::optional<at::Tensor>>>();
           auto self = pop(stack).toTensor();
           auto result =
               at::_index_put_impl_(self, indices, values, accumulate, unsafe);
@@ -999,7 +999,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
         [](Stack& stack) {
           auto accumulate = pop(stack).toBool();
           auto values = pop(stack).toTensor();
-          auto indices = pop(stack).to<List<c10::optional<at::Tensor>>>();
+          auto indices = pop(stack).to<c10::List<c10::optional<at::Tensor>>>();
           auto self = pop(stack).toTensor();
           auto result = at::index_put_(self, indices, values, accumulate);
           push(stack, std::move(result));
@@ -1011,7 +1011,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
         [](Stack& stack) {
           auto accumulate = pop(stack).toBool();
           auto values = pop(stack).toTensor();
-          auto indices = pop(stack).to<List<c10::optional<at::Tensor>>>();
+          auto indices = pop(stack).to<c10::List<c10::optional<at::Tensor>>>();
           auto self = pop(stack).toTensor();
           auto result = at::index_put_(self, indices, values, accumulate);
           push(stack, std::move(result));

View File

@@ -1,4 +1,5 @@
 #include <torch/csrc/jit/api/compilation_unit.h>
+#include <torch/csrc/jit/mobile/type_parser.h>
 #include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
 #include <torch/csrc/jit/serialization/pickle.h>
@@ -214,7 +215,12 @@ ska::flat_hash_map<int64_t, DebugInfoTuple> CallStackDebugInfoUnpickler::
     size_t size,
     const ska::flat_hash_map<int64_t, SourceRange>& source_range_map,
     const std::shared_ptr<CompilationUnit>& cu) {
-  auto ival = jit::unpickle(reinterpret_cast<const char*>(data.get()), size);
+  auto ival = jit::unpickle(
+      reinterpret_cast<const char*>(data.get()),
+      size,
+      nullptr,
+      {},
+      c10::parseType);
   ska::flat_hash_map<int64_t, DebugInfoTuple> callstack_ptrs;
   auto ivalues = std::move(*std::move(ival).toTuple()).elements();
   for (auto& val : ivalues) {

View File

@@ -176,6 +176,7 @@ IValue ScriptModuleDeserializer::readArchive(const std::string& archive_name) {
       obj_loader,
       device_,
       *reader_.get(),
+      nullptr,
       storage_context_);
 }

View File

@@ -12,6 +12,7 @@ IValue readArchiveAndTensors(
     c10::optional<ObjLoader> obj_loader,
     c10::optional<at::Device> device,
     caffe2::serialize::PyTorchStreamReader& stream_reader,
+    c10::TypePtr (*type_parser)(const std::string&),
     std::shared_ptr<DeserializationStorageContext> storage_context) {
   std::string picklename = pickle_prefix + archive_name + ".pkl";
   at::DataPtr pickle_ptr;
@@ -47,6 +48,7 @@ IValue readArchiveAndTensors(
       std::move(read_record),
       device,
       false,
+      type_parser,
       storage_context);
   unpickler.set_version(stream_reader.version());
   return unpickler.parse_ivalue();

View File

@@ -20,6 +20,8 @@ TORCH_API IValue readArchiveAndTensors(
     c10::optional<ObjLoader> obj_loader,
     c10::optional<at::Device> device,
     caffe2::serialize::PyTorchStreamReader& stream_reader,
+    c10::TypePtr (*type_parser)(const std::string&) =
+        Unpickler::defaultTypeParser,
     std::shared_ptr<DeserializationStorageContext> storage_context = nullptr);

 bool check_zip_file(

View File

@@ -120,9 +120,10 @@ IValue pickle_load(const std::vector<char>& data) {

 IValue unpickle(
     std::function<size_t(char*, size_t)> reader,
     TypeResolver type_resolver,
-    c10::ArrayRef<at::Tensor> tensor_table) {
+    c10::ArrayRef<at::Tensor> tensor_table,
+    c10::TypePtr (*type_parser)(const std::string&)) {
   Unpickler unpickler(
-      std::move(reader), std::move(type_resolver), tensor_table);
+      std::move(reader), std::move(type_resolver), tensor_table, type_parser);
   return unpickler.parse_ivalue();
 }

@@ -130,7 +131,8 @@ IValue unpickle(
     const char* data,
     size_t size,
     TypeResolver type_resolver,
-    c10::ArrayRef<at::Tensor> tensor_table) {
+    c10::ArrayRef<at::Tensor> tensor_table,
+    c10::TypePtr (*type_parser)(const std::string&)) {
   size_t bytes_read = 0;
   return unpickle(
       [&](char* buffer, size_t len) -> size_t {
@@ -145,7 +147,8 @@ IValue unpickle(
         return len;
       },
       std::move(type_resolver),
-      tensor_table);
+      tensor_table,
+      type_parser);
 }

 } // namespace jit

View File

@@ -69,7 +69,9 @@ TORCH_API IValue pickle_load(const std::vector<char>& data);
 TORCH_API IValue unpickle(
     std::function<size_t(char*, size_t)> reader,
     TypeResolver type_resolver,
-    c10::ArrayRef<at::Tensor> tensor_table);
+    c10::ArrayRef<at::Tensor> tensor_table,
+    c10::TypePtr (*type_parser)(const std::string&) =
+        Unpickler::defaultTypeParser);

 /// Decode a chunk of memory containing pickled data into its `torch::IValue`s.
 ///
@@ -81,7 +83,9 @@ TORCH_API IValue unpickle(
     const char* data,
     size_t size,
     TypeResolver type_resolver = nullptr,
-    c10::ArrayRef<at::Tensor> tensor_table = {});
+    c10::ArrayRef<at::Tensor> tensor_table = {},
+    c10::TypePtr (*type_parser)(const std::string&) =
+        Unpickler::defaultTypeParser);

 } // namespace jit
 } // namespace torch
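
Design note on the two `unpickle` overloads above: the new parameter is a trailing function pointer defaulted to `Unpickler::defaultTypeParser`, so a pre-existing call such as

  // Unchanged call sites keep compiling: every new argument is defaulted.
  c10::IValue v = torch::jit::unpickle(data, size);

keeps its behavior, while mobile/edge callers pass `c10::parseType` explicitly. (The choice of a raw function pointer over `std::function` is not explained in the commit; a plausible reason is that a stateless parser needs no capture state, keeping the default argument and the stored member trivial.)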

View File

@@ -1,6 +1,7 @@
 #include <torch/csrc/jit/serialization/source_range_serialization.h>
 #include <torch/csrc/jit/serialization/source_range_serialization_impl.h>

+#include <torch/csrc/jit/mobile/type_parser.h>
 #include <torch/csrc/jit/serialization/pickle.h>

 namespace torch {
@@ -111,8 +112,13 @@ void ConcreteSourceRangeUnpickler::unpickle() {
     return;
   }

-  auto ivaluesTuple =
-      jit::unpickle(reinterpret_cast<const char*>(data.get()), size).toTuple();
+  auto ivaluesTuple = jit::unpickle(
+                          reinterpret_cast<const char*>(data.get()),
+                          size,
+                          nullptr,
+                          {},
+                          c10::parseType)
+                          .toTuple();
   const auto& ivalues = ivaluesTuple->elements();
   unpickled_records = std::make_shared<SourceRangeRecords>();

View File

@@ -565,7 +565,7 @@ void Unpickler::readGlobal(
       if (type_resolver_ == nullptr) {
         // If we haven't injected a custom way of retrieving types from
        // names, use a barebones type parser.
-        type = c10::parseType(type_str);
+        type = type_parser_(type_str);
      } else {
         type = type_resolver_(type_str).type_;
       }
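
This hunk also explains why several call sites above (the bytecode and module deserializers) can safely pass `nullptr` for the new parameter: they always install a `type_resolver`, so the `type_parser_` branch is never taken. A condensed sketch of the precedence (member names from the diff; surrounding logic omitted):

// Condensed from Unpickler::readGlobal; not the full function.
c10::TypePtr resolved;
if (type_resolver_ == nullptr) {
  // No resolver injected: fall back to the configured parser, which is
  // defaultTypeParser unless the caller passed e.g. c10::parseType.
  resolved = type_parser_(type_str);
} else {
  resolved = type_resolver_(type_str).type_;
}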

View File

@@ -4,6 +4,7 @@
 #include <c10/util/ArrayRef.h>
 #include <caffe2/serialize/inline_container.h>
 #include <torch/csrc/Export.h>
+#include <torch/csrc/jit/frontend/script_type_parser.h>
 #include <torch/csrc/jit/serialization/pickler.h>

 namespace torch {
@@ -24,6 +25,8 @@ class DeserializationStorageContext;
 class TORCH_API Unpickler {
   TH_DISALLOW_COPY_AND_ASSIGN(Unpickler);

+  using TypeParserT = c10::TypePtr (*)(const std::string&);
+
  public:
   // tensors inside the pickle are references to the tensor_table.
   // class_resolver is to resolve strong class type, type_resolver_ is
@@ -34,11 +37,13 @@ class TORCH_API Unpickler {
   Unpickler(
       std::function<size_t(char*, size_t)> reader,
       TypeResolver type_resolver,
-      c10::ArrayRef<at::Tensor> tensor_table)
+      c10::ArrayRef<at::Tensor> tensor_table,
+      TypeParserT type_parser = defaultTypeParser)
       : reader_(std::move(reader)),
         tensor_table_(tensor_table),
         type_resolver_(std::move(type_resolver)),
         use_storage_device_(false),
+        type_parser_(type_parser),
         version_(caffe2::serialize::kProducedFileFormatVersion) {}

   // tensors inside the pickle contain meta-data, the raw tensor
@@ -51,6 +56,7 @@ class TORCH_API Unpickler {
       std::function<at::DataPtr(const std::string&)> read_record,
       c10::optional<at::Device> device,
       bool use_storage_device = false,
+      TypeParserT type_parser = defaultTypeParser,
       std::shared_ptr<DeserializationStorageContext> storage_context = nullptr)
       : reader_(std::move(reader)),
         tensor_table_(),
@@ -60,6 +66,7 @@ class TORCH_API Unpickler {
         // NOLINTNEXTLINE(performance-move-const-arg)
         device_(std::move(device)),
         use_storage_device_(use_storage_device),
+        type_parser_(type_parser),
         storage_context_(std::move(storage_context)),
         version_(caffe2::serialize::kProducedFileFormatVersion) {}
@@ -83,6 +90,11 @@ class TORCH_API Unpickler {
     version_ = version_number;
   }

+  static c10::TypePtr defaultTypeParser(const std::string& str) {
+    ScriptTypeParser parser;
+    return parser.parseType(str);
+  }
+
  private:
   // No arguments ensures that a template argument must be specified
   // so that the number of bytes read / type read is explicit
@@ -156,6 +168,8 @@ class TORCH_API Unpickler {
   // value of this flag is false.
   const bool use_storage_device_;

+  TypeParserT type_parser_{defaultTypeParser};
+
   // Used for torch.package to enable sharing of storages across
   // ScriptModules and eager modules
   std::shared_ptr<DeserializationStorageContext> storage_context_;
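
Taken together with the pickle.cpp change, a hedged sketch of how a caller would now drive the tensor-table `Unpickler` overload with the mobile parser; the buffer-reading lambda mirrors the one in `jit::unpickle`, and the function name and buffer handling are illustrative:

#include <algorithm>
#include <cstring>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/unpickler.h>

c10::IValue unpickleWithMobileParser(const char* data, size_t size) {
  size_t bytes_read = 0;
  torch::jit::Unpickler unpickler(
      // Copy up to `len` bytes of the pickle stream into `buffer`,
      // returning how many were actually produced (0 signals exhaustion).
      [&](char* buffer, size_t len) -> size_t {
        len = std::min(size - bytes_read, len);
        std::memcpy(buffer, data + bytes_read, len);
        bytes_read += len;
        return len;
      },
      /*type_resolver=*/nullptr,
      /*tensor_table=*/{},
      /*type_parser=*/c10::parseType);
  return unpickler.parse_ivalue();
}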