fix issue #31759 (allow valid ASCII python identifiers as dimnames) (#40871)

Summary:
Fixes issue https://github.com/pytorch/pytorch/issues/31759:
- Changes is_valid_identifier check on named tensor dimensions to allow digits if they are not at the beginning of the name (this allows exactly the ASCII subset of [valid python identifiers](https://docs.python.org/3/reference/lexical_analysis.html#identifiers)).
- Updates error message for illegal dimension names.
- Updates and adds relevant tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40871

Reviewed By: pbelevich

Differential Revision: D22357314

Pulled By: zou3519

fbshipit-source-id: 9550a1136dd0673dd30a5cd5ade28069ba4c9086
This commit is contained in:
Adam Teichert
2020-07-02 11:32:35 -07:00
committed by Facebook GitHub Bot
parent 5db5a0f2bb
commit 6aabd12390
3 changed files with 24 additions and 5 deletions

View File

@ -16,12 +16,18 @@ std::ostream& operator<<(std::ostream& out, const Dimname& dimname) {
}
bool Dimname::isValidName(const std::string& name) {
// allow valid ASCII python identifiers: "uppercase and lowercase
// letters A through Z, the underscore _ and, except for the first
// character, the digits 0 through 9" (at least length 1)
// https://docs.python.org/3/reference/lexical_analysis.html#identifiers
if (name.length() == 0) {
return false;
}
for (auto it = name.begin(); it != name.end(); ++it) {
if (std::isalpha(*it) || *it == '_') {
continue;
} else if (it != name.begin() && std::isdigit(*it)) {
continue;
}
return false;
}
@ -31,7 +37,8 @@ bool Dimname::isValidName(const std::string& name) {
static void check_valid_identifier(const std::string& name) {
TORCH_CHECK(
Dimname::isValidName(name),
"Invalid name: a valid identifier must contain alphabetical characters and/or underscore, got: '",
"Invalid name: a valid identifier contains only digits, alphabetical "
"characters, and/or underscore and starts with a non-digit. got: '",
name, "'.");
}

View File

@ -14,14 +14,19 @@ TEST(DimnameTest, isValidIdentifier) {
ASSERT_TRUE(Dimname::isValidName("N"));
ASSERT_TRUE(Dimname::isValidName("CHANNELS"));
ASSERT_TRUE(Dimname::isValidName("foo_bar_baz"));
ASSERT_TRUE(Dimname::isValidName("batch1"));
ASSERT_TRUE(Dimname::isValidName("batch_9"));
ASSERT_TRUE(Dimname::isValidName("_"));
ASSERT_TRUE(Dimname::isValidName("_1"));
ASSERT_FALSE(Dimname::isValidName(""));
ASSERT_FALSE(Dimname::isValidName(" "));
ASSERT_FALSE(Dimname::isValidName(" a "));
ASSERT_FALSE(Dimname::isValidName("batch1"));
ASSERT_FALSE(Dimname::isValidName("foo_bar_1"));
ASSERT_FALSE(Dimname::isValidName("1batch"));
ASSERT_FALSE(Dimname::isValidName("?"));
ASSERT_FALSE(Dimname::isValidName("-"));
ASSERT_FALSE(Dimname::isValidName("1"));
ASSERT_FALSE(Dimname::isValidName("01"));
}
TEST(DimnameTest, wildcardName) {
@ -36,7 +41,7 @@ TEST(DimnameTest, createNormalName) {
ASSERT_EQ(dimname.type(), NameType::BASIC);
ASSERT_EQ(dimname.symbol(), foo);
ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("inva.lid")), c10::Error);
ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("invalid1")), c10::Error);
ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("1invalid")), c10::Error);
}
static void check_unify_and_match(

View File

@ -116,8 +116,15 @@ class TestNamedTensor(TestCase):
x = factory(1, 2, 3, names=('N', None, 'D'), device=device)
self.assertEqual(x.names, ('N', None, 'D'))
x = factory(1, 2, 3, names=('_1', 'batch9', 'BATCH_5'), device=device)
self.assertEqual(x.names, ('_1', 'batch9', 'BATCH_5'))
with self.assertRaisesRegex(RuntimeError,
'must contain alphabetical characters and/or underscore'):
'a valid identifier contains only'):
x = factory(2, names=('1',), device=device)
with self.assertRaisesRegex(RuntimeError,
'a valid identifier contains only'):
x = factory(2, names=('?',), device=device)
with self.assertRaisesRegex(RuntimeError, 'Number of names'):