Fix StringCoordView::substr after D73379178 / #151810 (#152304)

Received complaint that we broke something. After a bunch of debugging, landed on this test + fix.

Differential Revision: [D73754877](https://our.internmc.facebook.com/intern/diff/D73754877/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D73754877/)!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152304
Approved by: https://github.com/Skylion007
This commit is contained in:
Scott Wolchok
2025-04-27 22:37:38 -07:00
committed by PyTorch MergeBot
parent ad11d6378c
commit 520366e102
3 changed files with 19 additions and 2 deletions

View File

@ -1,4 +1,6 @@
#include <gtest/gtest.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/frontend/source_range.h>
using namespace ::testing;
@ -33,6 +35,12 @@ TEST(SourceRangeTest, test_substr) {
auto x = view.substr(4, 10).str();
EXPECT_EQ(x, view.str().substr(4, 10));
EXPECT_EQ(view.substr(0, view.size()).str(), view.str());
for (const auto start : c10::irange(view.size())) {
for (const auto size : c10::irange(view.size())) {
EXPECT_EQ(
view.substr(start, size).str(), view.str().substr(start, size));
}
}
}
}

View File

@ -80,8 +80,8 @@ StringCordView StringCordView::substr(size_t start, size_t size) const {
if (start + size >= this->size()) {
size = this->size() - start;
}
IteratorImpl begin = IteratorImpl(this) + start;
IteratorImpl end = begin + size;
IteratorImpl begin = iter_impl_for_pos(start);
IteratorImpl end = iter_impl_for_pos(start + size);
if (begin.line_ == end.line_) {
// same line
@ -136,6 +136,14 @@ StringCordView::Iterator StringCordView::iter_for_pos(size_t pos) const {
return begin() + pos;
}
StringCordView::IteratorImpl StringCordView::iter_impl_for_pos(
size_t pos) const {
if (pos >= size()) {
return end_impl();
}
return begin_impl() + pos;
}
StringCordView::IteratorImpl& StringCordView::IteratorImpl::operator+=(
size_t num) {
if (!has_next()) {

View File

@ -313,6 +313,7 @@ struct TORCH_API StringCordView {
IteratorImpl end_impl() const {
return IteratorImpl(this, pieces_.size(), 0, 0);
}
IteratorImpl iter_impl_for_pos(size_t pos) const;
std::vector<std::string_view> pieces_;
std::vector<size_t> accumulated_sizes_;
std::vector<std::shared_ptr<std::string>> owned_strings_;