# Source: pytorch/tools/linter/adapters/import_linter.py
# Snapshot: 2025-09-23 23:22:53 +00:00 (139 lines, 3.7 KiB, Python)

"""
Checks files to make sure there are no imports from disallowed third party
libraries.
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import token
from enum import Enum
from pathlib import Path
from typing import NamedTuple, TYPE_CHECKING
# Directory containing this file and the interpreter's module search path,
# both made absolute so that they can be compared entry by entry.
_PARENT = Path(__file__).parent.absolute()
_PATH = [Path(p).absolute() for p in sys.path]

if TYPE_CHECKING or _PARENT not in _PATH:
    # This directory is not on sys.path (or a type checker is analysing the
    # file), so `_linter` is reached through the package-relative import.
    from . import _linter
else:
    # This directory is on sys.path (e.g. the file is executed directly as a
    # script), so `_linter` is importable as a top-level module.
    import _linter
class LintSeverity(str, Enum):
    """Severity levels that a lint message can carry.

    The ``str`` mix-in makes every member a real string, so a member
    serialises through ``json.dumps`` as its plain value (e.g. ``"error"``).
    """

    ERROR = "error"
    WARNING = "warning"
    ADVICE = "advice"
    DISABLED = "disabled"
class LintMessage(NamedTuple):
    """A single lint finding.

    The script entry point of this file prints each message as one JSON
    object built from ``_asdict()``, so the field names below are also the
    keys of the emitted JSON.
    """

    # File the finding refers to, or None.
    path: str | None
    # Line of the finding within `path`, or None.
    line: int | None
    # Character offset of the finding, or None.
    char: int | None
    # Short code identifying the reporting linter.
    code: str
    # How serious the finding is.
    severity: LintSeverity
    # Short title of the finding.
    name: str
    # Original text and its suggested replacement, or None when no fix is
    # being proposed.
    original: str | None
    replacement: str | None
    # Longer explanation of the finding, or None.
    description: str | None
LINTER_CODE = "IMPORT_LINTER"
CURRENT_FILE_NAME = os.path.basename(__file__)
_MODULE_NAME_ALLOW_LIST: set[str] = set()
# Add builtin modules.
_MODULE_NAME_ALLOW_LIST.update(sys.stdlib_module_names)
# Add the allowed third party libraries. Please avoid updating this unless you
# understand the risks -- see `_ERROR_MESSAGE` for why.
_MODULE_NAME_ALLOW_LIST.update(
[
"sympy",
"einops",
"libfb",
"torch",
"tvm",
"_pytest",
"tabulate",
"optree",
"typing_extensions",
"triton",
"functorch",
"torchrec",
"numpy",
"torch_xla",
]
)
_ERROR_MESSAGE = """
Please do not import third-party modules in PyTorch unless they're explicit
requirements of PyTorch. Imports of a third-party library may have side effects
and other unintentional behavior. If you're just checking if a module exists,
use sys.modules.get("torchrec") or the like.
"""
def check_file(filepath: str) -> list[LintMessage]:
    """Report statements in a file that import a module outside the allow list.

    For each entry of the file's ``token_lines``, the first token after any
    leading block-structure tokens is examined. When it is the keyword
    ``import`` or ``from`` and is directly followed by a name, that name is
    looked up in ``_MODULE_NAME_ALLOW_LIST``. Relative imports such as
    ``from . import x`` are not reported, because the token after ``from`` is
    an operator rather than a name. Only the first module named by a
    statement is checked.

    Args:
        filepath: Path of the Python source file to lint.

    Returns:
        One ``LintMessage`` with severity ``ERROR`` per offending statement,
        in the order the statements appear in ``token_lines``. The list is
        empty when nothing is disallowed.
    """
    path = Path(filepath)
    file = _linter.PythonFile("import_linter", path)
    lint_messages = []

    # The tokenizer puts DEDENT tokens at the *start* of the logical line that
    # follows a closed block, so both kinds must be skipped; skipping only
    # INDENT would hide any import placed right after a dedent.
    block_structure = (token.INDENT, token.DEDENT)

    for line_of_tokens in file.token_lines:
        idx = 0
        while (
            idx < len(line_of_tokens)
            and line_of_tokens[idx].type in block_structure
        ):
            idx += 1

        # Need at least a keyword and the token following it.
        if idx + 1 >= len(line_of_tokens):
            continue
        tok0 = line_of_tokens[idx]
        tok1 = line_of_tokens[idx + 1]

        # Look for either "import foo..." or "from foo..."
        if tok0.type != token.NAME or tok0.string not in {"import", "from"}:
            continue
        if tok1.type != token.NAME:
            continue
        if tok1.string in _MODULE_NAME_ALLOW_LIST:
            continue

        lint_messages.append(
            LintMessage(
                path=filepath,
                # Report the keyword's real, 1-based source row. The position
                # of the statement within token_lines is 0-based and is not a
                # source line number.
                # NOTE(review): assumes the tokens are tokenize.TokenInfo
                # tuples (they already expose .type/.string) -- confirm
                # against _linter.PythonFile.
                line=tok0.start[0],
                char=None,
                code="IMPORT",
                severity=LintSeverity.ERROR,
                name="Disallowed import",
                original=None,
                replacement=None,
                description=_ERROR_MESSAGE,
            )
        )
    return lint_messages
if __name__ == "__main__":
    # Script entry point: lint every path given on the command line and print
    # each finding as one JSON object per line on stdout.
    parser = argparse.ArgumentParser(
        # The description previously read "native functions linter", copied
        # from another adapter; this tool checks imports.
        description="disallowed third-party import linter",
        # An argument of the form "@file" is replaced by the arguments read
        # from that file.
        fromfile_prefix_chars="@",
    )
    parser.add_argument(
        "filepaths",
        nargs="+",
        help="paths of files to lint",
    )
    args = parser.parse_args()

    # Check all files.
    all_lint_messages = []
    for filepath in args.filepaths:
        lint_messages = check_file(filepath)
        all_lint_messages.extend(lint_messages)

    # Print out lint messages.
    for lint_message in all_lint_messages:
        print(json.dumps(lint_message._asdict()), flush=True)