mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Reland D33284352: [jit][edge] Do not reuse mobile type parser for all unpicklers. (#71048)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71048
reland D33284352 (0a921ba0d0)
ghstack-source-id: 146735646
Test Plan: All Github CI: ciflow rerun -l ciflow/all
Reviewed By: gmagogsfm
Differential Revision: D33489731
fbshipit-source-id: 3e160209a1abb193ad3eed3018054aa7d331025e
This commit is contained in:
committed by
Facebook GitHub Bot
parent
fb66f561b1
commit
30699cbfd5
@ -2,11 +2,7 @@
|
||||
#include <test/cpp/jit/test_utils.h>
|
||||
|
||||
#include <ATen/core/jit_type.h>
|
||||
|
||||
namespace c10 {
|
||||
TypePtr parseType(const std::string& pythonStr);
|
||||
std::vector<TypePtr> parseType(std::vector<std::string>& pythonStr);
|
||||
} // namespace c10
|
||||
#include <torch/csrc/jit/mobile/type_parser.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
@ -5,13 +5,11 @@
|
||||
* cd ~/pytorch && python torch/csrc/jit/mobile/upgrader_mobile.cpp
|
||||
*/
|
||||
|
||||
#include <caffe2/serialize/versions.h>
|
||||
#include <torch/csrc/jit/mobile/upgrader_mobile.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
|
||||
namespace c10 {
|
||||
TypePtr parseType(const std::string& pythonStr);
|
||||
} // namespace c10
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <caffe2/serialize/versions.h>
|
||||
#include <torch/csrc/jit/mobile/type_parser.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
@ -101,13 +101,11 @@ UPGRADER_CPP_SRC = CodeTemplate("""/**
|
||||
* cd ~/pytorch && python torch/csrc/jit/mobile/upgrader_mobile.cpp
|
||||
*/
|
||||
|
||||
#include <caffe2/serialize/versions.h>
|
||||
#include <torch/csrc/jit/mobile/upgrader_mobile.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
|
||||
namespace c10 {
|
||||
TypePtr parseType(const std::string& pythonStr);
|
||||
} // namespace c10
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <caffe2/serialize/versions.h>
|
||||
#include <torch/csrc/jit/mobile/type_parser.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
@ -29,8 +29,6 @@ struct Tree;
|
||||
using TreeRef = c10::intrusive_ptr<Tree>;
|
||||
using TreeList = at::SmallVector<TreeRef, 4>;
|
||||
|
||||
static const TreeList empty_trees = {};
|
||||
|
||||
struct Tree : c10::intrusive_ptr_target {
|
||||
Tree(int kind_) : kind_(kind_) {}
|
||||
int kind() const {
|
||||
@ -46,6 +44,7 @@ struct Tree : c10::intrusive_ptr_target {
|
||||
throw std::runtime_error("stringValue can only be called on TK_STRING");
|
||||
}
|
||||
virtual const TreeList& trees() const {
|
||||
static const TreeList empty_trees = {};
|
||||
return empty_trees;
|
||||
}
|
||||
const TreeRef& tree(size_t i) const {
|
||||
@ -149,11 +148,11 @@ struct Compound : public Tree {
|
||||
return false;
|
||||
}
|
||||
TreeRef map(const std::function<TreeRef(TreeRef)>& fn) override {
|
||||
TreeList trees_;
|
||||
TreeList ret;
|
||||
for (auto& t : trees()) {
|
||||
trees_.push_back(fn(t));
|
||||
ret.push_back(fn(t));
|
||||
}
|
||||
return Compound::create(kind(), range(), std::move(trees_));
|
||||
return Compound::create(kind(), range(), std::move(ret));
|
||||
}
|
||||
|
||||
const SourceRange& range() const override {
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#include <torch/csrc/jit/frontend/source_range.h>
|
||||
#include <torch/csrc/jit/mobile/debug_info.h>
|
||||
#include <torch/csrc/jit/mobile/type_parser.h>
|
||||
#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
|
||||
#include <torch/csrc/jit/serialization/source_range_serialization.h>
|
||||
|
||||
@ -122,10 +123,13 @@ MobileDebugTable::MobileDebugTable(
|
||||
size_t debug_size{0};
|
||||
std::tie(debug_data, debug_size) = reader->getRecord(record_name);
|
||||
auto ivalues =
|
||||
std::move(
|
||||
*jit::unpickle(
|
||||
reinterpret_cast<const char*>(debug_data.get()), debug_size)
|
||||
.toTuple())
|
||||
std::move(*jit::unpickle(
|
||||
reinterpret_cast<const char*>(debug_data.get()),
|
||||
debug_size,
|
||||
nullptr,
|
||||
{},
|
||||
c10::parseType)
|
||||
.toTuple())
|
||||
.elements();
|
||||
SourceRangeDeserializer deserializer;
|
||||
for (auto& val : ivalues) {
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||
#include <torch/csrc/jit/mobile/interpreter.h>
|
||||
#include <torch/csrc/jit/mobile/observer.h>
|
||||
#include <torch/csrc/jit/mobile/type_parser.h>
|
||||
#include <torch/csrc/jit/mobile/upgrader_mobile.h>
|
||||
#include <torch/csrc/jit/runtime/instruction.h>
|
||||
#include <torch/csrc/jit/serialization/import_export_constants.h>
|
||||
@ -78,11 +79,6 @@
|
||||
// - Argument::{known_length_,kwarg_only_}
|
||||
// - FunctionSchema::{overload_name_, is_vararg_, is_varret_}
|
||||
|
||||
namespace c10 {
|
||||
// std::string serializeType(const Type &t);
|
||||
TypePtr parseType(const std::string& pythonStr);
|
||||
} // namespace c10
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
using caffe2::serialize::IStreamAdapter;
|
||||
@ -502,7 +498,8 @@ c10::IValue BytecodeDeserializer::readArchive(
|
||||
type_resolver,
|
||||
obj_loader,
|
||||
device_,
|
||||
*reader_.get());
|
||||
*reader_.get(),
|
||||
nullptr);
|
||||
return ivalues;
|
||||
}
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
#include <caffe2/serialize/inline_container.h>
|
||||
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||
#include <torch/csrc/jit/mobile/observer.h>
|
||||
#include <torch/csrc/jit/mobile/type_parser.h>
|
||||
#include <torch/csrc/jit/runtime/instruction.h>
|
||||
#include <torch/csrc/jit/serialization/unpickler.h>
|
||||
#include <torch/custom_class.h>
|
||||
@ -14,11 +15,6 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace c10 {
|
||||
// std::string serializeType(const Type &t);
|
||||
TypePtr parseType(const std::string& pythonStr);
|
||||
} // namespace c10
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
using caffe2::serialize::IStreamAdapter;
|
||||
@ -151,7 +147,9 @@ c10::IValue BytecodeDeserializer::readArchive(
|
||||
std::move(obj_loader),
|
||||
std::move(read_record),
|
||||
// NOLINTNEXTLINE(performance-move-const-arg)
|
||||
std::move(device));
|
||||
std::move(device),
|
||||
false,
|
||||
nullptr);
|
||||
return unpickler.parse_ivalue();
|
||||
}
|
||||
|
||||
|
||||
@ -53,7 +53,8 @@ c10::IValue readArchive(
|
||||
type_resolver,
|
||||
obj_loader,
|
||||
device,
|
||||
stream_reader);
|
||||
stream_reader,
|
||||
nullptr);
|
||||
return ivalues;
|
||||
}
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
#include <torch/csrc/jit/mobile/type_parser.h>
|
||||
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <c10/util/string_view.h>
|
||||
#include <torch/csrc/jit/frontend/parser_constants.h>
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/core/dynamic_type.h>
|
||||
#include <ATen/core/jit_type.h>
|
||||
|
||||
|
||||
@ -384,8 +384,8 @@ void listAdd(Stack& stack) {
|
||||
}
|
||||
|
||||
void listInplaceAdd(Stack& stack) {
|
||||
c10::List<IValue> b = pop(stack).to<List<IValue>>();
|
||||
c10::List<IValue> a = pop(stack).to<List<IValue>>();
|
||||
c10::List<IValue> b = pop(stack).to<c10::List<IValue>>();
|
||||
c10::List<IValue> a = pop(stack).to<c10::List<IValue>>();
|
||||
a.append(std::move(b));
|
||||
push(stack, std::move(a));
|
||||
}
|
||||
|
||||
@ -973,7 +973,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
|
||||
TORCH_SELECTIVE_SCHEMA(
|
||||
"aten::index.Tensor_hacked_twin(Tensor self, Tensor[] indices) -> Tensor"),
|
||||
[](Stack& stack) {
|
||||
auto indices = pop(stack).to<List<c10::optional<at::Tensor>>>();
|
||||
auto indices = pop(stack).to<c10::List<c10::optional<at::Tensor>>>();
|
||||
auto self = pop(stack).toTensor();
|
||||
auto result = at::index(self, indices);
|
||||
push(stack, std::move(result));
|
||||
@ -986,7 +986,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
|
||||
auto unsafe = pop(stack).toBool();
|
||||
auto accumulate = pop(stack).toBool();
|
||||
auto values = pop(stack).toTensor();
|
||||
auto indices = pop(stack).to<List<c10::optional<at::Tensor>>>();
|
||||
auto indices = pop(stack).to<c10::List<c10::optional<at::Tensor>>>();
|
||||
auto self = pop(stack).toTensor();
|
||||
auto result =
|
||||
at::_index_put_impl_(self, indices, values, accumulate, unsafe);
|
||||
@ -999,7 +999,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
|
||||
[](Stack& stack) {
|
||||
auto accumulate = pop(stack).toBool();
|
||||
auto values = pop(stack).toTensor();
|
||||
auto indices = pop(stack).to<List<c10::optional<at::Tensor>>>();
|
||||
auto indices = pop(stack).to<c10::List<c10::optional<at::Tensor>>>();
|
||||
auto self = pop(stack).toTensor();
|
||||
auto result = at::index_put_(self, indices, values, accumulate);
|
||||
push(stack, std::move(result));
|
||||
@ -1011,7 +1011,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
|
||||
[](Stack& stack) {
|
||||
auto accumulate = pop(stack).toBool();
|
||||
auto values = pop(stack).toTensor();
|
||||
auto indices = pop(stack).to<List<c10::optional<at::Tensor>>>();
|
||||
auto indices = pop(stack).to<c10::List<c10::optional<at::Tensor>>>();
|
||||
auto self = pop(stack).toTensor();
|
||||
auto result = at::index_put_(self, indices, values, accumulate);
|
||||
push(stack, std::move(result));
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||
#include <torch/csrc/jit/mobile/type_parser.h>
|
||||
#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
|
||||
#include <torch/csrc/jit/serialization/pickle.h>
|
||||
|
||||
@ -214,7 +215,12 @@ ska::flat_hash_map<int64_t, DebugInfoTuple> CallStackDebugInfoUnpickler::
|
||||
size_t size,
|
||||
const ska::flat_hash_map<int64_t, SourceRange>& source_range_map,
|
||||
const std::shared_ptr<CompilationUnit>& cu) {
|
||||
auto ival = jit::unpickle(reinterpret_cast<const char*>(data.get()), size);
|
||||
auto ival = jit::unpickle(
|
||||
reinterpret_cast<const char*>(data.get()),
|
||||
size,
|
||||
nullptr,
|
||||
{},
|
||||
c10::parseType);
|
||||
ska::flat_hash_map<int64_t, DebugInfoTuple> callstack_ptrs;
|
||||
auto ivalues = std::move(*std::move(ival).toTuple()).elements();
|
||||
for (auto& val : ivalues) {
|
||||
|
||||
@ -176,6 +176,7 @@ IValue ScriptModuleDeserializer::readArchive(const std::string& archive_name) {
|
||||
obj_loader,
|
||||
device_,
|
||||
*reader_.get(),
|
||||
nullptr,
|
||||
storage_context_);
|
||||
}
|
||||
|
||||
|
||||
@ -12,6 +12,7 @@ IValue readArchiveAndTensors(
|
||||
c10::optional<ObjLoader> obj_loader,
|
||||
c10::optional<at::Device> device,
|
||||
caffe2::serialize::PyTorchStreamReader& stream_reader,
|
||||
c10::TypePtr (*type_parser)(const std::string&),
|
||||
std::shared_ptr<DeserializationStorageContext> storage_context) {
|
||||
std::string picklename = pickle_prefix + archive_name + ".pkl";
|
||||
at::DataPtr pickle_ptr;
|
||||
@ -47,6 +48,7 @@ IValue readArchiveAndTensors(
|
||||
std::move(read_record),
|
||||
device,
|
||||
false,
|
||||
type_parser,
|
||||
storage_context);
|
||||
unpickler.set_version(stream_reader.version());
|
||||
return unpickler.parse_ivalue();
|
||||
|
||||
@ -20,6 +20,8 @@ TORCH_API IValue readArchiveAndTensors(
|
||||
c10::optional<ObjLoader> obj_loader,
|
||||
c10::optional<at::Device> device,
|
||||
caffe2::serialize::PyTorchStreamReader& stream_reader,
|
||||
c10::TypePtr (*type_parser)(const std::string&) =
|
||||
Unpickler::defaultTypeParser,
|
||||
std::shared_ptr<DeserializationStorageContext> storage_context = nullptr);
|
||||
|
||||
bool check_zip_file(
|
||||
|
||||
@ -120,9 +120,10 @@ IValue pickle_load(const std::vector<char>& data) {
|
||||
IValue unpickle(
|
||||
std::function<size_t(char*, size_t)> reader,
|
||||
TypeResolver type_resolver,
|
||||
c10::ArrayRef<at::Tensor> tensor_table) {
|
||||
c10::ArrayRef<at::Tensor> tensor_table,
|
||||
c10::TypePtr (*type_parser)(const std::string&)) {
|
||||
Unpickler unpickler(
|
||||
std::move(reader), std::move(type_resolver), tensor_table);
|
||||
std::move(reader), std::move(type_resolver), tensor_table, type_parser);
|
||||
return unpickler.parse_ivalue();
|
||||
}
|
||||
|
||||
@ -130,7 +131,8 @@ IValue unpickle(
|
||||
const char* data,
|
||||
size_t size,
|
||||
TypeResolver type_resolver,
|
||||
c10::ArrayRef<at::Tensor> tensor_table) {
|
||||
c10::ArrayRef<at::Tensor> tensor_table,
|
||||
c10::TypePtr (*type_parser)(const std::string&)) {
|
||||
size_t bytes_read = 0;
|
||||
return unpickle(
|
||||
[&](char* buffer, size_t len) -> size_t {
|
||||
@ -145,7 +147,8 @@ IValue unpickle(
|
||||
return len;
|
||||
},
|
||||
std::move(type_resolver),
|
||||
tensor_table);
|
||||
tensor_table,
|
||||
type_parser);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
|
||||
@ -69,7 +69,9 @@ TORCH_API IValue pickle_load(const std::vector<char>& data);
|
||||
TORCH_API IValue unpickle(
|
||||
std::function<size_t(char*, size_t)> reader,
|
||||
TypeResolver type_resolver,
|
||||
c10::ArrayRef<at::Tensor> tensor_table);
|
||||
c10::ArrayRef<at::Tensor> tensor_table,
|
||||
c10::TypePtr (*type_parser)(const std::string&) =
|
||||
Unpickler::defaultTypeParser);
|
||||
|
||||
/// Decode a chunk of memory containing pickled data into its `torch::IValue`s.
|
||||
///
|
||||
@ -81,7 +83,9 @@ TORCH_API IValue unpickle(
|
||||
const char* data,
|
||||
size_t size,
|
||||
TypeResolver type_resolver = nullptr,
|
||||
c10::ArrayRef<at::Tensor> tensor_table = {});
|
||||
c10::ArrayRef<at::Tensor> tensor_table = {},
|
||||
c10::TypePtr (*type_parser)(const std::string&) =
|
||||
Unpickler::defaultTypeParser);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
#include <torch/csrc/jit/serialization/source_range_serialization.h>
|
||||
#include <torch/csrc/jit/serialization/source_range_serialization_impl.h>
|
||||
|
||||
#include <torch/csrc/jit/mobile/type_parser.h>
|
||||
#include <torch/csrc/jit/serialization/pickle.h>
|
||||
|
||||
namespace torch {
|
||||
@ -111,8 +112,13 @@ void ConcreteSourceRangeUnpickler::unpickle() {
|
||||
return;
|
||||
}
|
||||
|
||||
auto ivaluesTuple =
|
||||
jit::unpickle(reinterpret_cast<const char*>(data.get()), size).toTuple();
|
||||
auto ivaluesTuple = jit::unpickle(
|
||||
reinterpret_cast<const char*>(data.get()),
|
||||
size,
|
||||
nullptr,
|
||||
{},
|
||||
c10::parseType)
|
||||
.toTuple();
|
||||
const auto& ivalues = ivaluesTuple->elements();
|
||||
|
||||
unpickled_records = std::make_shared<SourceRangeRecords>();
|
||||
|
||||
@ -565,7 +565,7 @@ void Unpickler::readGlobal(
|
||||
if (type_resolver_ == nullptr) {
|
||||
// If we haven't injected a custom way of retrieving types from
|
||||
// names, use a barebones type parser.
|
||||
type = c10::parseType(type_str);
|
||||
type = type_parser_(type_str);
|
||||
} else {
|
||||
type = type_resolver_(type_str).type_;
|
||||
}
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
#include <c10/util/ArrayRef.h>
|
||||
#include <caffe2/serialize/inline_container.h>
|
||||
#include <torch/csrc/Export.h>
|
||||
#include <torch/csrc/jit/frontend/script_type_parser.h>
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
|
||||
namespace torch {
|
||||
@ -24,6 +25,8 @@ class DeserializationStorageContext;
|
||||
class TORCH_API Unpickler {
|
||||
TH_DISALLOW_COPY_AND_ASSIGN(Unpickler);
|
||||
|
||||
using TypeParserT = c10::TypePtr (*)(const std::string&);
|
||||
|
||||
public:
|
||||
// tensors inside the pickle are references to the tensor_table.
|
||||
// class_resolver is to resolve strong class type, type_resolver_ is
|
||||
@ -34,11 +37,13 @@ class TORCH_API Unpickler {
|
||||
Unpickler(
|
||||
std::function<size_t(char*, size_t)> reader,
|
||||
TypeResolver type_resolver,
|
||||
c10::ArrayRef<at::Tensor> tensor_table)
|
||||
c10::ArrayRef<at::Tensor> tensor_table,
|
||||
TypeParserT type_parser = defaultTypeParser)
|
||||
: reader_(std::move(reader)),
|
||||
tensor_table_(tensor_table),
|
||||
type_resolver_(std::move(type_resolver)),
|
||||
use_storage_device_(false),
|
||||
type_parser_(type_parser),
|
||||
version_(caffe2::serialize::kProducedFileFormatVersion) {}
|
||||
|
||||
// tensors inside the pickle contain meta-data, the raw tensor
|
||||
@ -51,6 +56,7 @@ class TORCH_API Unpickler {
|
||||
std::function<at::DataPtr(const std::string&)> read_record,
|
||||
c10::optional<at::Device> device,
|
||||
bool use_storage_device = false,
|
||||
TypeParserT type_parser = defaultTypeParser,
|
||||
std::shared_ptr<DeserializationStorageContext> storage_context = nullptr)
|
||||
: reader_(std::move(reader)),
|
||||
tensor_table_(),
|
||||
@ -60,6 +66,7 @@ class TORCH_API Unpickler {
|
||||
// NOLINTNEXTLINE(performance-move-const-arg)
|
||||
device_(std::move(device)),
|
||||
use_storage_device_(use_storage_device),
|
||||
type_parser_(type_parser),
|
||||
storage_context_(std::move(storage_context)),
|
||||
version_(caffe2::serialize::kProducedFileFormatVersion) {}
|
||||
|
||||
@ -83,6 +90,11 @@ class TORCH_API Unpickler {
|
||||
version_ = version_number;
|
||||
}
|
||||
|
||||
static c10::TypePtr defaultTypeParser(const std::string& str) {
|
||||
ScriptTypeParser parser;
|
||||
return parser.parseType(str);
|
||||
}
|
||||
|
||||
private:
|
||||
// No arguments ensures that a template argument must be specified
|
||||
// so that the number of bytes read / type read is explicit
|
||||
@ -156,6 +168,8 @@ class TORCH_API Unpickler {
|
||||
// value of this flag is false.
|
||||
const bool use_storage_device_;
|
||||
|
||||
TypeParserT type_parser_{defaultTypeParser};
|
||||
|
||||
// Used for torch.package to enable sharing of storages across
|
||||
// ScriptModules and eager modules
|
||||
std::shared_ptr<DeserializationStorageContext> storage_context_;
|
||||
|
||||
Reference in New Issue
Block a user