mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add sv starts/ends_with (#139261)
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/139261 Approved by: https://github.com/Skylion007 Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
2a309c0997
commit
8ace3e8023
@ -598,6 +598,20 @@ constexpr inline void swap(
|
||||
}
|
||||
using string_view = basic_string_view<char>;
|
||||
|
||||
// NOTE: In C++20, this function should be replaced by str.starts_with
|
||||
constexpr bool string_view_starts_with(
|
||||
std::string_view str,
|
||||
std::string_view prefix) noexcept {
|
||||
return str.size() >= prefix.size() && str.substr(0, prefix.size()) == prefix;
|
||||
}
|
||||
|
||||
// NOTE: In C++20, this function should be replaced by str.ends_with
|
||||
constexpr bool string_view_ends_with(
|
||||
std::string_view str,
|
||||
std::string_view suffix) noexcept {
|
||||
return str.size() >= suffix.size() &&
|
||||
str.substr(str.size() - suffix.size()) == suffix;
|
||||
}
|
||||
} // namespace c10
|
||||
|
||||
namespace std {
|
||||
|
@ -29,7 +29,7 @@
|
||||
|
||||
namespace caffe2 {
|
||||
namespace serialize {
|
||||
constexpr c10::string_view kDebugPklSuffix(".debug_pkl");
|
||||
constexpr std::string_view kDebugPklSuffix(".debug_pkl");
|
||||
|
||||
struct MzZipReaderIterWrapper {
|
||||
MzZipReaderIterWrapper(mz_zip_reader_extract_iter_state* iter) : impl(iter) {}
|
||||
@ -283,7 +283,7 @@ size_t getPadding(
|
||||
bool PyTorchStreamReader::hasRecord(const std::string& name) {
|
||||
std::lock_guard<std::mutex> guard(reader_lock_);
|
||||
|
||||
if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
|
||||
if ((!load_debug_symbol_) && c10::string_view_ends_with(std::string_view(name), kDebugPklSuffix)) {
|
||||
return false;
|
||||
}
|
||||
std::string ss = archive_name_plus_slash_ + name;
|
||||
@ -320,7 +320,7 @@ std::vector<std::string> PyTorchStreamReader::getAllRecords() {
|
||||
buf);
|
||||
}
|
||||
if ((load_debug_symbol_) ||
|
||||
(!c10::string_view(buf + archive_name_plus_slash_.size()).ends_with(kDebugPklSuffix))) {
|
||||
(!c10::string_view_ends_with(std::string_view(buf + archive_name_plus_slash_.size()),kDebugPklSuffix))) {
|
||||
// NOLINTNEXTLINE(modernize-use-emplace)
|
||||
out.push_back(buf + archive_name_plus_slash_.size());
|
||||
}
|
||||
@ -343,7 +343,7 @@ size_t PyTorchStreamReader::getRecordID(const std::string& name) {
|
||||
// return dataptr, size
|
||||
std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string& name) {
|
||||
std::lock_guard<std::mutex> guard(reader_lock_);
|
||||
if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
|
||||
if ((!load_debug_symbol_) && c10::string_view_ends_with(name, kDebugPklSuffix)) {
|
||||
at::DataPtr retval;
|
||||
return std::make_tuple(std::move(retval), 0);
|
||||
}
|
||||
@ -424,7 +424,7 @@ PyTorchStreamReader::getRecord(const std::string& name,
|
||||
return getRecord(name);
|
||||
}
|
||||
|
||||
if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
|
||||
if ((!load_debug_symbol_) && c10::string_view_ends_with(name, kDebugPklSuffix)) {
|
||||
at::DataPtr retval;
|
||||
return std::make_tuple(std::move(retval), 0);
|
||||
}
|
||||
@ -448,7 +448,7 @@ PyTorchStreamReader::getRecord(const std::string& name,
|
||||
size_t
|
||||
PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n) {
|
||||
std::lock_guard<std::mutex> guard(reader_lock_);
|
||||
if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
|
||||
if ((!load_debug_symbol_) && c10::string_view_ends_with(name, kDebugPklSuffix)) {
|
||||
return 0;
|
||||
}
|
||||
size_t key = getRecordID(name);
|
||||
@ -508,7 +508,7 @@ size_t PyTorchStreamReader::getRecord(
|
||||
void* buf,
|
||||
const std::function<void(void*, const void*, size_t)>& memcpy_func) {
|
||||
std::lock_guard<std::mutex> guard(reader_lock_);
|
||||
if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
|
||||
if ((!load_debug_symbol_) && c10::string_view_ends_with(name, kDebugPklSuffix)) {
|
||||
return 0;
|
||||
}
|
||||
if (chunk_size <= 0) {
|
||||
|
@ -115,9 +115,9 @@ MobileDebugTable::MobileDebugTable(
|
||||
const std::shared_ptr<CompilationUnit>& cu) {
|
||||
ska::flat_hash_map<int64_t, SourceRange> source_range_map;
|
||||
const std::vector<std::string>& record_names = reader->getAllRecords();
|
||||
const c10::string_view suffix(".debug_pkl");
|
||||
constexpr std::string_view suffix(".debug_pkl");
|
||||
for (const auto& record_name : record_names) {
|
||||
if (c10::string_view(record_name).ends_with(suffix)) {
|
||||
if (c10::string_view_ends_with(std::string_view(record_name), suffix)) {
|
||||
auto [debug_data, debug_size] = reader->getRecord(record_name);
|
||||
auto ivalueTuple = jit::unpickle(
|
||||
reinterpret_cast<const char*>(debug_data.get()),
|
||||
|
@ -70,10 +70,9 @@ static_assert(
|
||||
|
||||
namespace {
|
||||
|
||||
static constexpr c10::string_view kCustomClassPrefix =
|
||||
"__torch__.torch.classes";
|
||||
static constexpr c10::string_view kTorchPrefix = "__torch__";
|
||||
static constexpr c10::string_view kJitPrefix = "torch.jit";
|
||||
static constexpr auto kCustomClassPrefix = "__torch__.torch.classes";
|
||||
static constexpr auto kTorchPrefix = "__torch__";
|
||||
static constexpr auto kJitPrefix = "torch.jit";
|
||||
|
||||
class FlatbufferLoader final {
|
||||
public:
|
||||
@ -188,13 +187,14 @@ TypePtr resolveType(
|
||||
const std::string& type_string,
|
||||
const std::shared_ptr<CompilationUnit>& cu) {
|
||||
TypePtr type;
|
||||
c10::string_view type_str(type_string);
|
||||
if (type_str.starts_with(kCustomClassPrefix)) {
|
||||
std::string_view type_str(type_string);
|
||||
if (c10::string_view_starts_with(type_str, kCustomClassPrefix)) {
|
||||
type = getCustomClass(type_string);
|
||||
TORCH_CHECK(
|
||||
type, "The implementation of class ", type_string, " cannot be found.");
|
||||
} else if (
|
||||
type_str.starts_with(kTorchPrefix) || type_str.starts_with(kJitPrefix)) {
|
||||
c10::string_view_starts_with(type_str, kTorchPrefix) ||
|
||||
c10::string_view_starts_with(type_str, kJitPrefix)) {
|
||||
c10::QualifiedName qn(type_string);
|
||||
if (cu->get_class(qn) == nullptr) {
|
||||
auto classtype = ClassType::create(qn, cu, true);
|
||||
@ -607,9 +607,10 @@ ClassTypePtr FlatbufferLoader::getOrCreateClassTypeForObject(
|
||||
const mobile::serialization::ObjectType* obj_type =
|
||||
module_->object_types()->Get(object->type_index());
|
||||
if (cls == nullptr) {
|
||||
c10::string_view qn_str(
|
||||
std::string_view qn_str(
|
||||
obj_type->type_name()->c_str(), obj_type->type_name()->size());
|
||||
if (qn_str.starts_with(kTorchPrefix) || qn_str.starts_with(kJitPrefix)) {
|
||||
if (c10::string_view_starts_with(qn_str, kTorchPrefix) ||
|
||||
c10::string_view_starts_with(qn_str, kJitPrefix)) {
|
||||
c10::QualifiedName qn(obj_type->type_name()->str());
|
||||
cls = cu_->get_class(qn);
|
||||
if (cls == nullptr) {
|
||||
|
Reference in New Issue
Block a user