mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Fixes issue https://github.com/pytorch/pytorch/issues/31759:
- Changes the is_valid_identifier check on named tensor dimensions to allow digits if they are not at the beginning of the name (this allows exactly the ASCII subset of [valid Python identifiers](https://docs.python.org/3/reference/lexical_analysis.html#identifiers)).
- Updates the error message for illegal dimension names.
- Updates and adds relevant tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/40871
Reviewed By: pbelevich
Differential Revision: D22357314
Pulled By: zou3519
fbshipit-source-id: 9550a1136dd0673dd30a5cd5ade28069ba4c9086
This commit is contained in:
committed by
Facebook GitHub Bot
parent
5db5a0f2bb
commit
6aabd12390
@ -16,12 +16,18 @@ std::ostream& operator<<(std::ostream& out, const Dimname& dimname) {
|
||||
}
|
||||
|
||||
bool Dimname::isValidName(const std::string& name) {
|
||||
// allow valid ASCII python identifiers: "uppercase and lowercase
|
||||
// letters A through Z, the underscore _ and, except for the first
|
||||
// character, the digits 0 through 9" (at least length 1)
|
||||
// https://docs.python.org/3/reference/lexical_analysis.html#identifiers
|
||||
if (name.length() == 0) {
|
||||
return false;
|
||||
}
|
||||
for (auto it = name.begin(); it != name.end(); ++it) {
|
||||
if (std::isalpha(*it) || *it == '_') {
|
||||
continue;
|
||||
} else if (it != name.begin() && std::isdigit(*it)) {
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@ -31,7 +37,8 @@ bool Dimname::isValidName(const std::string& name) {
|
||||
static void check_valid_identifier(const std::string& name) {
|
||||
TORCH_CHECK(
|
||||
Dimname::isValidName(name),
|
||||
"Invalid name: a valid identifier must contain alphabetical characters and/or underscore, got: '",
|
||||
"Invalid name: a valid identifier contains only digits, alphabetical "
|
||||
"characters, and/or underscore and starts with a non-digit. got: '",
|
||||
name, "'.");
|
||||
}
|
||||
|
||||
|
||||
@ -14,14 +14,19 @@ TEST(DimnameTest, isValidIdentifier) {
|
||||
ASSERT_TRUE(Dimname::isValidName("N"));
|
||||
ASSERT_TRUE(Dimname::isValidName("CHANNELS"));
|
||||
ASSERT_TRUE(Dimname::isValidName("foo_bar_baz"));
|
||||
ASSERT_TRUE(Dimname::isValidName("batch1"));
|
||||
ASSERT_TRUE(Dimname::isValidName("batch_9"));
|
||||
ASSERT_TRUE(Dimname::isValidName("_"));
|
||||
ASSERT_TRUE(Dimname::isValidName("_1"));
|
||||
|
||||
ASSERT_FALSE(Dimname::isValidName(""));
|
||||
ASSERT_FALSE(Dimname::isValidName(" "));
|
||||
ASSERT_FALSE(Dimname::isValidName(" a "));
|
||||
ASSERT_FALSE(Dimname::isValidName("batch1"));
|
||||
ASSERT_FALSE(Dimname::isValidName("foo_bar_1"));
|
||||
ASSERT_FALSE(Dimname::isValidName("1batch"));
|
||||
ASSERT_FALSE(Dimname::isValidName("?"));
|
||||
ASSERT_FALSE(Dimname::isValidName("-"));
|
||||
ASSERT_FALSE(Dimname::isValidName("1"));
|
||||
ASSERT_FALSE(Dimname::isValidName("01"));
|
||||
}
|
||||
|
||||
TEST(DimnameTest, wildcardName) {
|
||||
@ -36,7 +41,7 @@ TEST(DimnameTest, createNormalName) {
|
||||
ASSERT_EQ(dimname.type(), NameType::BASIC);
|
||||
ASSERT_EQ(dimname.symbol(), foo);
|
||||
ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("inva.lid")), c10::Error);
|
||||
ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("invalid1")), c10::Error);
|
||||
ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("1invalid")), c10::Error);
|
||||
}
|
||||
|
||||
static void check_unify_and_match(
|
||||
|
||||
@ -116,8 +116,15 @@ class TestNamedTensor(TestCase):
|
||||
x = factory(1, 2, 3, names=('N', None, 'D'), device=device)
|
||||
self.assertEqual(x.names, ('N', None, 'D'))
|
||||
|
||||
x = factory(1, 2, 3, names=('_1', 'batch9', 'BATCH_5'), device=device)
|
||||
self.assertEqual(x.names, ('_1', 'batch9', 'BATCH_5'))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
'must contain alphabetical characters and/or underscore'):
|
||||
'a valid identifier contains only'):
|
||||
x = factory(2, names=('1',), device=device)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
'a valid identifier contains only'):
|
||||
x = factory(2, names=('?',), device=device)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, 'Number of names'):
|
||||
|
||||
Reference in New Issue
Block a user