Files
pytorch/tools/test/test_max_tokens_pragma.py
Elton Leander Pinto 711ded688d Add a script to codemod max_tokens_total pragmas to C/C++ files (#61369)
Summary:
This PR adds a new script: `max_tokens_pragma.py`

This is a utility script that can add/remove `max_tokens_total` pragmas from the codebase.

- [x] Implement script and test manually
- [x] Write test script

Examples:
First, change into the script's directory:
```bash
cd tools/linter/clang_tidy
```

Then run the following:
```bash
cat << EOF > test/test1.cpp
// File without any prior pragmas

int main() {
    for (int i = 0; i < 10; i++);
    return 0;
}
EOF

cat << EOF > test/test2.cpp
// File with prior pragmas

#pragma clang max_tokens_total 1

int main() {
    for (int i = 0; i < 10; i++);
    return 0;
}
EOF

cat << EOF > test/test3.cpp
// File with multiple prior pragmas

#pragma clang max_tokens_total 1

// Different pragma; script should ignore this
#pragma clang max_tokens_here 20

int main() {
    for (int i = 0; i < 10; i++);
    return 0;
}

#pragma clang max_tokens_total 1
EOF

# Add pragmas to some files
python3 max_tokens_pragma.py --num-max-tokens 42 test/*.cpp
grep "#pragma clang max_tokens_total 42" test/*.cpp

# Remove pragmas from files
python3 max_tokens_pragma.py --strip test/*.cpp
grep "#pragma clang max_tokens_total 42" test/*.cpp # should fail

# Ignore files
python3 max_tokens_pragma.py --num-max-tokens 42 test/*.cpp --ignores test/test2.cpp
grep "#pragma clang max_tokens_total 42" test/*.cpp # should not list `test/test2.cpp`
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61369

Test Plan: `tools/linter/clang_tidy/test/test_max_tokens_pragma.py`

Reviewed By: malfet

Differential Revision: D29604291

Pulled By: 1ntEgr8

fbshipit-source-id: 3efe52573583769041a07e6776161d4d5bbf16a7
2021-07-09 13:30:52 -07:00

133 lines
3.1 KiB
Python

import unittest
from tools.linter.clang_tidy.max_tokens_pragma import (
add_max_tokens_pragma,
strip_max_tokens_pragmas,
)
def compare_code(a: str, b: str) -> bool:
    """Return True when *a* and *b* are equal line by line, ignoring indentation.

    Both strings are split into lines and every line has its leading and
    trailing whitespace removed before the two line lists are compared.
    Blank lines are kept, so the number of lines must match as well.
    """
    return list(map(str.strip, a.splitlines())) == list(map(str.strip, b.splitlines()))
class TestMaxTokensPragma(unittest.TestCase):
    """Check adding and stripping ``max_tokens_total`` pragmas on C++ snippets."""

    def _assert_add_then_strip(
        self, source: str, with_pragma: str, without_pragma: str
    ) -> None:
        """Add a pragma with a limit of 42 to *source*, then strip it again.

        The result of each step is compared (ignoring per-line surrounding
        whitespace) against *with_pragma* and *without_pragma* respectively.
        Stripping is applied to the output of the add step, not to *source*.
        """
        added = add_max_tokens_pragma(source, 42)
        self.assertTrue(compare_code(added, with_pragma))
        removed = strip_max_tokens_pragmas(added)
        self.assertTrue(compare_code(removed, without_pragma))

    def test_no_prior_pragmas(self) -> None:
        """A pragma-free file should gain one first line; stripping restores it."""
        source = """\
// File without any prior pragmas
int main() {
for (int i = 0; i < 10; i++);
return 0;
}
"""
        with_pragma = """\
#pragma clang max_tokens_total 42
// File without any prior pragmas
int main() {
for (int i = 0; i < 10; i++);
return 0;
}
"""
        # Stripping is expected to give back exactly the original text.
        self._assert_add_then_strip(source, with_pragma, source)

    def test_single_prior_pragma(self) -> None:
        """An existing pragma should be rewritten in place, then removed by strip."""
        source = """\
// File with prior pragmas
#pragma clang max_tokens_total 1
int main() {
for (int i = 0; i < 10; i++);
return 0;
}
"""
        with_pragma = """\
// File with prior pragmas
#pragma clang max_tokens_total 42
int main() {
for (int i = 0; i < 10; i++);
return 0;
}
"""
        without_pragma = """\
// File with prior pragmas
int main() {
for (int i = 0; i < 10; i++);
return 0;
}
"""
        self._assert_add_then_strip(source, with_pragma, without_pragma)

    def test_multiple_prior_pragmas(self) -> None:
        """Every ``max_tokens_total`` pragma changes; ``max_tokens_here`` is untouched."""
        source = """\
// File with multiple prior pragmas
#pragma clang max_tokens_total 1
// Different pragma; script should ignore this
#pragma clang max_tokens_here 20
int main() {
for (int i = 0; i < 10; i++);
return 0;
}
#pragma clang max_tokens_total 1
"""
        with_pragma = """\
// File with multiple prior pragmas
#pragma clang max_tokens_total 42
// Different pragma; script should ignore this
#pragma clang max_tokens_here 20
int main() {
for (int i = 0; i < 10; i++);
return 0;
}
#pragma clang max_tokens_total 42
"""
        without_pragma = """\
// File with multiple prior pragmas
// Different pragma; script should ignore this
#pragma clang max_tokens_here 20
int main() {
for (int i = 0; i < 10; i++);
return 0;
}
"""
        self._assert_add_then_strip(source, with_pragma, without_pragma)
# Run the test cases only when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()