mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This reverts commit bd72e28314d8d63bb347becb8309f5ac7761c6b5.
Reverted https://github.com/pytorch/pytorch/pull/128301 on behalf of https://github.com/huydhn due to: Sorry for reverting your change, but it fails the XLA build (bd72e28314). Please rebase your PR before relanding, because I think the failure is hidden by an unrelated broken trunk XLA failure from your current base commit ([comment](https://github.com/pytorch/pytorch/pull/128301#issuecomment-2169035822))
260 lines
8.5 KiB
C++
260 lines
8.5 KiB
C++
#include <torch/csrc/jit/serialization/source_range_serialization.h>
|
|
#include <torch/csrc/jit/serialization/source_range_serialization_impl.h>
|
|
|
|
#include <c10/util/Exception.h>
|
|
#include <c10/util/Flags.h>
|
|
#include <torch/csrc/jit/mobile/type_parser.h>
|
|
#include <torch/csrc/jit/serialization/pickle.h>
|
|
#include <algorithm>
|
|
|
|
namespace torch::jit {
|
|
|
|
// Per-thread switch selecting the debug_pkl layout emitted when a model is
// saved to a .pt file. When true, the compact layout backed by a string table
// is written: the file is smaller but cannot be loaded by old torch binaries.
// TODO(qihan) remove when all binaries are using string table.
thread_local bool should_use_format_with_string_table_{true};
|
|
|
|
class SourceRangeSerializer {
|
|
public:
|
|
// Serialize SourceRange as Tuple[SourceType, int, int]
|
|
// where SourceType = Tuple[int, int, int, List[int]],
|
|
// The first 2 ints are positions into the vector returned by textSaved
|
|
// after all the Ranges are processed. textSaved() returns a vector of str
|
|
// the serialized form of Source
|
|
c10::IValue serialize(const SourceRange& sr);
|
|
|
|
const std::vector<c10::IValue>& texts_saved() {
|
|
return texts_;
|
|
}
|
|
|
|
SourceRangeSerializer() {
|
|
texts_.emplace_back("");
|
|
text_to_idx_[texts_.back().toStringRef()] = 0;
|
|
}
|
|
|
|
private:
|
|
// Serialize Source as Tuple[str, Optional[str], int, List[int]]
|
|
// This caches serialized sources, since many SourceRanges can
|
|
// refer to the same one.
|
|
c10::IValue serialize_source(const std::shared_ptr<Source>& s);
|
|
std::unordered_map<std::shared_ptr<Source>, c10::IValue> serialized_sources;
|
|
|
|
int64_t store_text_and_get_index(const std::string& text_view);
|
|
|
|
std::vector<c10::IValue> texts_;
|
|
std::unordered_map<c10::string_view, int64_t> text_to_idx_;
|
|
};
|
|
|
|
// Rebuilds a SourceRange from its serialized form: a 3-tuple holding the
// serialized source, the start offset and the end offset, in that order.
// Asserts (internal assert) if the tuple does not have exactly 3 elements.
SourceRange SourceRangeDeserializer::deserialize(const c10::IValue& iv) {
  const auto& fields = iv.toTupleRef().elements();
  TORCH_INTERNAL_ASSERT(fields.size() == 3);

  // Decode in tuple order: source first, then the two offsets.
  std::shared_ptr<Source> src = deserialize_source(fields[0]);
  const int64_t range_begin = fields[1].toInt();
  const int64_t range_end = fields[2].toInt();

  return SourceRange(src, range_begin, range_end);
}
|
|
|
|
std::shared_ptr<Source> SourceRangeDeserializer::deserialize_source(
|
|
const c10::IValue& iv) {
|
|
auto tup = iv.toTuple();
|
|
auto it = cached_sources.find(tup);
|
|
if (it != cached_sources.end()) {
|
|
return it->second;
|
|
}
|
|
std::shared_ptr<Source> source;
|
|
const auto& tup_elems = tup->elements();
|
|
TORCH_INTERNAL_ASSERT(tup_elems.size() == 3);
|
|
if (!text_table_.empty()) {
|
|
const auto& textIndex = tup_elems[0].toIntList();
|
|
int64_t fnameIndex = tup_elems[1].toInt();
|
|
int64_t starting_line_no_ = tup_elems[2].toInt();
|
|
std::optional<std::string> filename = c10::nullopt;
|
|
|
|
TORCH_CHECK(
|
|
(uint64_t)fnameIndex < text_table_.size(),
|
|
"Text table index is out of range")
|
|
filename = *text_table_[fnameIndex];
|
|
|
|
std::vector<c10::string_view> pieces;
|
|
std::vector<std::shared_ptr<std::string>> strs;
|
|
|
|
for (int64_t i : textIndex) {
|
|
pieces.emplace_back(*text_table_[i]);
|
|
strs.emplace_back(text_table_[i]);
|
|
}
|
|
|
|
StringCordView str_cord(std::move(pieces), std::move(strs));
|
|
|
|
source = std::make_shared<Source>(str_cord, filename, starting_line_no_);
|
|
} else {
|
|
std::string text_ = tup_elems[0].toStringRef();
|
|
std::optional<std::string> filename_ =
|
|
tup_elems[1].toOptional<std::string>();
|
|
int64_t starting_line_no_ = tup_elems[2].toInt();
|
|
source = std::make_shared<Source>(
|
|
std::move(text_), std::move(filename_), starting_line_no_);
|
|
}
|
|
cached_sources[tup] = source;
|
|
return source;
|
|
}
|
|
|
|
// Serializes `sr` as a 3-tuple: (serialized source, start offset, end offset).
// The source part goes through serialize_source(), which caches per Source.
c10::IValue SourceRangeSerializer::serialize(const SourceRange& sr) {
  c10::IValue source_iv = serialize_source(sr.source());
  return c10::ivalue::Tuple::create(
      std::move(source_iv),
      static_cast<int64_t>(sr.start()),
      static_cast<int64_t>(sr.end()));
}
|
|
|
|
int64_t SourceRangeSerializer::store_text_and_get_index(
|
|
const std::string& text_view) {
|
|
auto text_iter = text_to_idx_.find(text_view);
|
|
if (text_iter == text_to_idx_.end()) {
|
|
int64_t text_pos = static_cast<int64_t>(texts_.size());
|
|
texts_.emplace_back(text_view);
|
|
text_to_idx_[texts_.back().toStringView()] = text_pos;
|
|
return text_pos;
|
|
} else {
|
|
return text_iter->second;
|
|
}
|
|
}
|
|
|
|
// Serializes a Source (possibly null) into a 3-tuple and caches the result so
// that later requests for the same Source return the same IValue.
//
// Layout depends on should_use_format_with_string_table_:
//  - string-table format: (List[int] with one text-table index per source
//    line, int filename index, int starting line number). A null source is
//    written as (empty list, 0, 0); a source without a filename uses
//    filename index 0.
//  - legacy format: (str full text, Optional[str] filename, int starting line
//    number). A null source is written as ("", "", 0).
c10::IValue SourceRangeSerializer::serialize_source(
    const std::shared_ptr<Source>& s) {
  const auto cached = serialized_sources.find(s);
  if (cached != serialized_sources.end()) {
    return cached->second;
  }

  c10::intrusive_ptr<c10::ivalue::Tuple> serialized;
  if (!should_use_format_with_string_table_) {
    if (s == nullptr) {
      serialized = c10::ivalue::Tuple::create({"", "", 0});
    } else {
      serialized = c10::ivalue::Tuple::create(
          {s->text_str().str(), s->filename(), (int64_t)s->starting_line_no()});
    }
  } else {
    c10::List<int64_t> lines;
    if (s == nullptr) {
      serialized = c10::ivalue::Tuple::create({lines, 0, 0});
    } else {
      // Line texts are registered before the filename, which determines the
      // order in which new entries are appended to the text table.
      for (size_t lineno = 0; lineno < s->num_lines(); ++lineno) {
        lines.push_back(store_text_and_get_index(s->get_line(lineno).str()));
      }

      const int64_t fname_pos = s->filename().has_value()
          ? store_text_and_get_index(*s->filename())
          : 0;
      serialized = c10::ivalue::Tuple::create(
          {lines, fname_pos, (int64_t)s->starting_line_no()});
    }
  }

  serialized_sources[s] = serialized;
  return serialized;
}
|
|
|
|
// Constructs the pickler with a newly allocated SourceRangeSerializer held in
// `srs`; pickle() uses it to serialize ranges and to obtain the text table.
SourceRangePickler::SourceRangePickler() : srs(new SourceRangeSerializer()) {}
|
|
|
|
std::vector<char> SourceRangePickler::pickle(
|
|
const SourceRangeRecords& ranges,
|
|
const SourceRangeTagMap& source_range_tags) {
|
|
std::vector<c10::IValue> ivalues;
|
|
for (const auto& range : ranges) {
|
|
int64_t source_range_tag{-1};
|
|
const auto& it = source_range_tags.find(range.range);
|
|
if (it != source_range_tags.end()) {
|
|
source_range_tag = it->second;
|
|
}
|
|
|
|
ivalues.emplace_back(c10::ivalue::Tuple::create(
|
|
{(int64_t)range.bytes,
|
|
srs->serialize(range.range),
|
|
static_cast<int64_t>(source_range_tag)}));
|
|
}
|
|
|
|
std::vector<at::Tensor> table;
|
|
auto textTable = c10::ivalue::Tuple::create(srs->texts_saved());
|
|
auto ivalue = c10::ivalue::Tuple::create(std::move(ivalues));
|
|
std::vector<char> result;
|
|
if (should_use_format_with_string_table_) {
|
|
result = jit::pickle(
|
|
c10::ivalue::Tuple::create({kFormatWithStringTable, textTable, ivalue}),
|
|
&table);
|
|
} else {
|
|
result = jit::pickle(ivalue, &table);
|
|
}
|
|
TORCH_CHECK(table.empty(), "Expected 0 tensors to be written");
|
|
return result;
|
|
}
|
|
|
|
ConcreteSourceRangeUnpickler::ConcreteSourceRangeUnpickler(
|
|
at::DataPtr&& data,
|
|
size_t size)
|
|
: data(std::move(data)),
|
|
size(size),
|
|
deserializer(nullptr),
|
|
unpickled_records(nullptr) {}
|
|
|
|
void ConcreteSourceRangeUnpickler::unpickle() {
|
|
std::lock_guard<std::mutex> guard(mutex);
|
|
if (unpickled_records) {
|
|
return;
|
|
}
|
|
|
|
auto ivaluesTuple = jit::unpickle(
|
|
reinterpret_cast<const char*>(data.get()),
|
|
size,
|
|
nullptr,
|
|
{},
|
|
c10::parseType)
|
|
.toTuple();
|
|
|
|
const auto& ivalues = ivaluesTuple->elements();
|
|
TORCH_CHECK(
|
|
ivalues.size(), "Invalid unpickle operation: empty ivalues tuple");
|
|
unpickled_records = std::make_shared<SourceRangeRecords>();
|
|
IValue lines;
|
|
if (ivalues[0].isString() &&
|
|
kFormatWithStringTable == ivalues[0].toStringRef()) {
|
|
deserializer.reset(new SourceRangeDeserializer(ivalues[1]));
|
|
lines = ivalues[2];
|
|
} else {
|
|
deserializer.reset(new SourceRangeDeserializer());
|
|
lines = ivaluesTuple;
|
|
}
|
|
for (auto& val : lines.toTuple()->elements()) {
|
|
const auto& tup_elems = val.toTupleRef().elements();
|
|
int64_t offset = tup_elems[kByteOffsetIndex].toInt();
|
|
auto source_range = deserializer->deserialize(tup_elems[kSourceRangeIndex]);
|
|
unpickled_records->emplace_back(offset, std::move(source_range));
|
|
}
|
|
}
|
|
|
|
std::optional<SourceRange> ConcreteSourceRangeUnpickler::
|
|
findSourceRangeThatGenerated(const SourceRange& range) {
|
|
unpickle();
|
|
|
|
auto query = TaggedRange(range.start(), SourceRange{});
|
|
auto entry = std::upper_bound(
|
|
unpickled_records->begin(),
|
|
unpickled_records->end(),
|
|
query,
|
|
[](const TaggedRange& a, const TaggedRange& b) -> bool {
|
|
return a.bytes < b.bytes;
|
|
});
|
|
|
|
// NB: must decrement iterator since upper_bound finds the element
|
|
// *greater than* the query.
|
|
if (entry != unpickled_records->begin()) {
|
|
return (entry - 1)->range;
|
|
}
|
|
|
|
return c10::nullopt;
|
|
}
|
|
|
|
// Chooses which debug_pkl layout subsequent serialization uses: pass true for
// the compact string-table format, false for the legacy format. The setting
// is stored in a thread_local flag, so it applies to the calling thread.
TORCH_API void setShouldUseFormatWithStringTable(
    bool should_use_format_with_string_table) {
  should_use_format_with_string_table_ = should_use_format_with_string_table;
}
|
|
|
|
} // namespace torch::jit
|