Add sv starts/ends_with (#139261)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139261
Approved by: https://github.com/Skylion007

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
This commit is contained in:
cyyever
2024-11-01 01:17:42 +00:00
committed by PyTorch MergeBot
parent 2a309c0997
commit 8ace3e8023
4 changed files with 33 additions and 18 deletions

View File

@ -598,6 +598,20 @@ constexpr inline void swap(
}
using string_view = basic_string_view<char>;
// NOTE: In C++20, this function should be replaced by str.starts_with
constexpr bool string_view_starts_with(
std::string_view str,
std::string_view prefix) noexcept {
return str.size() >= prefix.size() && str.substr(0, prefix.size()) == prefix;
}
// NOTE: In C++20, this function should be replaced by str.ends_with
constexpr bool string_view_ends_with(
std::string_view str,
std::string_view suffix) noexcept {
return str.size() >= suffix.size() &&
str.substr(str.size() - suffix.size()) == suffix;
}
} // namespace c10
namespace std {

View File

@ -29,7 +29,7 @@
namespace caffe2 {
namespace serialize {
constexpr c10::string_view kDebugPklSuffix(".debug_pkl");
constexpr std::string_view kDebugPklSuffix(".debug_pkl");
struct MzZipReaderIterWrapper {
MzZipReaderIterWrapper(mz_zip_reader_extract_iter_state* iter) : impl(iter) {}
@ -283,7 +283,7 @@ size_t getPadding(
bool PyTorchStreamReader::hasRecord(const std::string& name) {
std::lock_guard<std::mutex> guard(reader_lock_);
if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
if ((!load_debug_symbol_) && c10::string_view_ends_with(std::string_view(name), kDebugPklSuffix)) {
return false;
}
std::string ss = archive_name_plus_slash_ + name;
@ -320,7 +320,7 @@ std::vector<std::string> PyTorchStreamReader::getAllRecords() {
buf);
}
if ((load_debug_symbol_) ||
(!c10::string_view(buf + archive_name_plus_slash_.size()).ends_with(kDebugPklSuffix))) {
(!c10::string_view_ends_with(std::string_view(buf + archive_name_plus_slash_.size()),kDebugPklSuffix))) {
// NOLINTNEXTLINE(modernize-use-emplace)
out.push_back(buf + archive_name_plus_slash_.size());
}
@ -343,7 +343,7 @@ size_t PyTorchStreamReader::getRecordID(const std::string& name) {
// return dataptr, size
std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string& name) {
std::lock_guard<std::mutex> guard(reader_lock_);
if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
if ((!load_debug_symbol_) && c10::string_view_ends_with(name, kDebugPklSuffix)) {
at::DataPtr retval;
return std::make_tuple(std::move(retval), 0);
}
@ -424,7 +424,7 @@ PyTorchStreamReader::getRecord(const std::string& name,
return getRecord(name);
}
if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
if ((!load_debug_symbol_) && c10::string_view_ends_with(name, kDebugPklSuffix)) {
at::DataPtr retval;
return std::make_tuple(std::move(retval), 0);
}
@ -448,7 +448,7 @@ PyTorchStreamReader::getRecord(const std::string& name,
size_t
PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n) {
std::lock_guard<std::mutex> guard(reader_lock_);
if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
if ((!load_debug_symbol_) && c10::string_view_ends_with(name, kDebugPklSuffix)) {
return 0;
}
size_t key = getRecordID(name);
@ -508,7 +508,7 @@ size_t PyTorchStreamReader::getRecord(
void* buf,
const std::function<void(void*, const void*, size_t)>& memcpy_func) {
std::lock_guard<std::mutex> guard(reader_lock_);
if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
if ((!load_debug_symbol_) && c10::string_view_ends_with(name, kDebugPklSuffix)) {
return 0;
}
if (chunk_size <= 0) {

View File

@ -115,9 +115,9 @@ MobileDebugTable::MobileDebugTable(
const std::shared_ptr<CompilationUnit>& cu) {
ska::flat_hash_map<int64_t, SourceRange> source_range_map;
const std::vector<std::string>& record_names = reader->getAllRecords();
const c10::string_view suffix(".debug_pkl");
constexpr std::string_view suffix(".debug_pkl");
for (const auto& record_name : record_names) {
if (c10::string_view(record_name).ends_with(suffix)) {
if (c10::string_view_ends_with(std::string_view(record_name), suffix)) {
auto [debug_data, debug_size] = reader->getRecord(record_name);
auto ivalueTuple = jit::unpickle(
reinterpret_cast<const char*>(debug_data.get()),

View File

@ -70,10 +70,9 @@ static_assert(
namespace {
static constexpr c10::string_view kCustomClassPrefix =
"__torch__.torch.classes";
static constexpr c10::string_view kTorchPrefix = "__torch__";
static constexpr c10::string_view kJitPrefix = "torch.jit";
static constexpr auto kCustomClassPrefix = "__torch__.torch.classes";
static constexpr auto kTorchPrefix = "__torch__";
static constexpr auto kJitPrefix = "torch.jit";
class FlatbufferLoader final {
public:
@ -188,13 +187,14 @@ TypePtr resolveType(
const std::string& type_string,
const std::shared_ptr<CompilationUnit>& cu) {
TypePtr type;
c10::string_view type_str(type_string);
if (type_str.starts_with(kCustomClassPrefix)) {
std::string_view type_str(type_string);
if (c10::string_view_starts_with(type_str, kCustomClassPrefix)) {
type = getCustomClass(type_string);
TORCH_CHECK(
type, "The implementation of class ", type_string, " cannot be found.");
} else if (
type_str.starts_with(kTorchPrefix) || type_str.starts_with(kJitPrefix)) {
c10::string_view_starts_with(type_str, kTorchPrefix) ||
c10::string_view_starts_with(type_str, kJitPrefix)) {
c10::QualifiedName qn(type_string);
if (cu->get_class(qn) == nullptr) {
auto classtype = ClassType::create(qn, cu, true);
@ -607,9 +607,10 @@ ClassTypePtr FlatbufferLoader::getOrCreateClassTypeForObject(
const mobile::serialization::ObjectType* obj_type =
module_->object_types()->Get(object->type_index());
if (cls == nullptr) {
c10::string_view qn_str(
std::string_view qn_str(
obj_type->type_name()->c_str(), obj_type->type_name()->size());
if (qn_str.starts_with(kTorchPrefix) || qn_str.starts_with(kJitPrefix)) {
if (c10::string_view_starts_with(qn_str, kTorchPrefix) ||
c10::string_view_starts_with(qn_str, kJitPrefix)) {
c10::QualifiedName qn(obj_type->type_name()->str());
cls = cu_->get_class(qn);
if (cls == nullptr) {