diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py index 69474e06d18e..cdc805d5a4b5 100644 --- a/tools/autograd/gen_autograd_functions.py +++ b/tools/autograd/gen_autograd_functions.py @@ -555,8 +555,7 @@ def gen_autograd_functions_lib( fname, lambda: { "generated_comment": "@" - + f"generated from {fm.template_dir_for_comments()}/" - + fname, + + f"generated from {fm.template_dir_for_comments()}/{fname}", "autograd_function_declarations": declarations, "autograd_function_definitions": definitions, }, diff --git a/tools/autograd/gen_view_funcs.py b/tools/autograd/gen_view_funcs.py index e6600106dca9..8cc8a2ffcecc 100644 --- a/tools/autograd/gen_view_funcs.py +++ b/tools/autograd/gen_view_funcs.py @@ -331,8 +331,7 @@ def gen_view_funcs( fname, lambda: { "generated_comment": "@" - + f"generated from {fm.template_dir_for_comments()}/" - + fname, + + f"generated from {fm.template_dir_for_comments()}/{fname}", "view_func_declarations": declarations, "view_func_definitions": definitions, "ops_headers": ops_headers, diff --git a/torchgen/api/cpp.py b/torchgen/api/cpp.py index 554bfa4a5c79..c619ec45d2f8 100644 --- a/torchgen/api/cpp.py +++ b/torchgen/api/cpp.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING +from typing_extensions import assert_never from torchgen import local from torchgen.api.types import ( @@ -48,7 +49,6 @@ from torchgen.model import ( TensorOptionsArguments, Type, ) -from torchgen.utils import assert_never if TYPE_CHECKING: diff --git a/torchgen/api/dispatcher.py b/torchgen/api/dispatcher.py index 4cc6186d7e0e..fcca7a60fec1 100644 --- a/torchgen/api/dispatcher.py +++ b/torchgen/api/dispatcher.py @@ -2,6 +2,7 @@ from __future__ import annotations import itertools from typing import TYPE_CHECKING +from typing_extensions import assert_never from torchgen.api import cpp from torchgen.api.types import ArgName, Binding, CType, NamedCType @@ -13,7 +14,7 @@ from torchgen.model import ( TensorOptionsArguments, Type, ) -from torchgen.utils import assert_never, concatMap +from torchgen.utils import concatMap if TYPE_CHECKING: diff --git a/torchgen/api/native.py b/torchgen/api/native.py index 82bc051a6832..632216704d2d 100644 --- a/torchgen/api/native.py +++ b/torchgen/api/native.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING +from typing_extensions import assert_never from torchgen import local from torchgen.api import cpp @@ -29,7 +30,6 @@ from torchgen.model import ( TensorOptionsArguments, Type, ) -from torchgen.utils import assert_never if TYPE_CHECKING: diff --git a/torchgen/api/structured.py b/torchgen/api/structured.py index 93a72eb2b4a5..a0e14e5b69e6 100644 --- a/torchgen/api/structured.py +++ b/torchgen/api/structured.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing_extensions import assert_never + from torchgen.api import cpp from torchgen.api.types import ( ArgName, @@ -30,7 +32,6 @@ from torchgen.model import ( TensorOptionsArguments, Type, ) -from torchgen.utils import assert_never # This file describes the translation of JIT schema to the structured functions API. diff --git a/torchgen/dest/register_dispatch_key.py b/torchgen/dest/register_dispatch_key.py index 5b7feef83237..af3d4c9ca740 100644 --- a/torchgen/dest/register_dispatch_key.py +++ b/torchgen/dest/register_dispatch_key.py @@ -4,6 +4,7 @@ import itertools import textwrap from dataclasses import dataclass from typing import Literal, TYPE_CHECKING +from typing_extensions import assert_never import torchgen.api.cpp as cpp import torchgen.api.meta as meta @@ -36,7 +37,7 @@ from torchgen.model import ( SchemaKind, TensorOptionsArguments, ) -from torchgen.utils import assert_never, mapMaybe, Target +from torchgen.utils import mapMaybe, Target if TYPE_CHECKING: diff --git a/torchgen/executorch/api/et_cpp.py b/torchgen/executorch/api/et_cpp.py index 72b0551d029d..081a3d4ece1c 100644 --- a/torchgen/executorch/api/et_cpp.py +++ b/torchgen/executorch/api/et_cpp.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING +from typing_extensions import assert_never from torchgen import local from torchgen.api.types import ( @@ -37,7 +38,6 @@ from torchgen.model import ( TensorOptionsArguments, Type, ) -from torchgen.utils import assert_never if TYPE_CHECKING: diff --git a/torchgen/executorch/model.py b/torchgen/executorch/model.py index 6be7501ebead..310c5968ec0d 100644 --- a/torchgen/executorch/model.py +++ b/torchgen/executorch/model.py @@ -7,6 +7,7 @@ import itertools from collections import defaultdict, namedtuple from dataclasses import dataclass from enum import IntEnum +from typing_extensions import assert_never from torchgen.model import ( BackendIndex, @@ -16,7 +17,6 @@ from torchgen.model import ( NativeFunctionsGroup, OperatorName, ) -from torchgen.utils import assert_never KERNEL_KEY_VERSION = 1 diff --git a/torchgen/gen.py b/torchgen/gen.py index 609d338887e6..1c27d32be727 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -9,6 +9,7 @@ from collections import defaultdict, namedtuple, OrderedDict from dataclasses import dataclass, field from pathlib import Path from typing import Any, Callable, Literal, TYPE_CHECKING, TypeVar +from typing_extensions import assert_never import yaml @@ -84,7 +85,6 @@ from torchgen.native_function_generation import ( ) from torchgen.selective_build.selector import SelectiveBuilder from torchgen.utils import ( - assert_never, concatMap, context, FileManager, diff --git a/torchgen/model.py b/torchgen/model.py index b6509b57748b..4715759d322e 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -6,8 +6,9 @@ import re from dataclasses import dataclass from enum import auto, Enum from typing import Callable, Optional, TYPE_CHECKING +from typing_extensions import assert_never -from torchgen.utils import assert_never, NamespaceHelper, OrderedSet +from torchgen.utils import NamespaceHelper, OrderedSet if TYPE_CHECKING: diff --git a/torchgen/utils.py b/torchgen/utils.py index 2d760a51145b..905d6fd0c0b6 100644 --- a/torchgen/utils.py +++ b/torchgen/utils.py @@ -11,7 +11,7 @@ from dataclasses import fields, is_dataclass from enum import auto, Enum from pathlib import Path from typing import Any, Callable, Generic, Literal, NoReturn, TYPE_CHECKING, TypeVar -from typing_extensions import Self +from typing_extensions import assert_never, deprecated, Self from torchgen.code_template import CodeTemplate @@ -21,7 +21,8 @@ if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Sequence -REPO_ROOT = Path(__file__).absolute().parent.parent +TORCHGEN_ROOT = Path(__file__).absolute().parent +REPO_ROOT = TORCHGEN_ROOT.parent # Many of these functions share logic for defining both the definition @@ -96,11 +97,13 @@ def context(msg_fn: Callable[[], str]) -> Iterator[None]: raise -# A little trick from https://github.com/python/mypy/issues/6366 -# for getting mypy to do exhaustiveness checking -# TODO: put this somewhere else, maybe -def assert_never(x: NoReturn) -> NoReturn: - raise AssertionError(f"Unhandled type: {type(x).__name__}") +if TYPE_CHECKING: + # A little trick from https://github.com/python/mypy/issues/6366 + # for getting mypy to do exhaustiveness checking + # TODO: put this somewhere else, maybe + @deprecated("Use typing_extensions.assert_never instead") + def assert_never(x: NoReturn) -> NoReturn: # type: ignore[misc] # noqa: F811 + raise AssertionError(f"Unhandled type: {type(x).__name__}") @functools.cache @@ -118,39 +121,47 @@ def string_stable_hash(s: str) -> int: # of what files have been written (so you can write out a list of output # files) class FileManager: - install_dir: str - template_dir: str - dry_run: bool - filenames: set[str] - - def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None: - self.install_dir = install_dir - self.template_dir = template_dir - self.filenames = set() + def __init__( + self, + install_dir: str | Path, + template_dir: str | Path, + dry_run: bool, + ) -> None: + self.install_dir = Path(install_dir) + self.template_dir = Path(template_dir) + self.files: set[Path] = set() self.dry_run = dry_run - def _write_if_changed(self, filename: str, contents: str) -> None: - old_contents: str | None + @property + def filenames(self) -> frozenset[str]: + return frozenset({file.as_posix() for file in self.files}) + + def _write_if_changed(self, filename: str | Path, contents: str) -> None: + file = Path(filename) + old_contents: str | None = None try: - with open(filename) as f: - old_contents = f.read() + old_contents = file.read_text(encoding="utf-8") except OSError: - old_contents = None + pass if contents != old_contents: # Create output directory if it doesn't exist - os.makedirs(os.path.dirname(filename), exist_ok=True) - with open(filename, "w") as f: - f.write(contents) + file.parent.mkdir(parents=True, exist_ok=True) + file.write_text(contents, encoding="utf-8") # Read from template file and replace pattern with callable (type could be dict or str). def substitute_with_template( - self, template_fn: str, env_callable: Callable[[], str | dict[str, Any]] + self, + template_fn: str | Path, + env_callable: Callable[[], str | dict[str, Any]], ) -> str: - template_path = os.path.join(self.template_dir, template_fn) + assert not Path(template_fn).is_absolute(), ( + f"template_fn must be relative: {template_fn}" + ) + template_path = self.template_dir / template_fn env = env_callable() if isinstance(env, dict): if "generated_comment" not in env: - generator_default = REPO_ROOT / "torchgen" / "gen.py" + generator_default = TORCHGEN_ROOT / "gen.py" try: generator = Path( sys.modules["__main__"].__file__ or generator_default @@ -170,38 +181,56 @@ class FileManager: ), } template = _read_template(template_path) - return template.substitute(env) - elif isinstance(env, str): + substitute_out = template.substitute(env) + # Ensure an extra blank line between the class/function definition + # and the docstring of the previous class/function definition. + # NB: It is generally not recommended to have docstrings in pyi stub + # files. But if there are any, we need to ensure that the file + # is properly formatted. + return re.sub( + r''' + (""")\n+ # match triple quotes + ( + (\s*@.+\n)* # match decorators if any + \s*(class|def) # match class/function definition + ) + ''', + r"\g<1>\n\n\g<2>", + substitute_out, + flags=re.VERBOSE, + ) + if isinstance(env, str): return env - else: - assert_never(env) + assert_never(env) def write_with_template( self, - filename: str, - template_fn: str, + filename: str | Path, + template_fn: str | Path, env_callable: Callable[[], str | dict[str, Any]], ) -> None: - filename = f"{self.install_dir}/{filename}" - assert filename not in self.filenames, "duplicate file write {filename}" - self.filenames.add(filename) + filename = Path(filename) + assert not filename.is_absolute(), f"filename must be relative: {filename}" + file = self.install_dir / filename + assert file not in self.files, f"duplicate file write {file}" + self.files.add(file) if not self.dry_run: substitute_out = self.substitute_with_template( template_fn=template_fn, env_callable=env_callable, ) - self._write_if_changed(filename=filename, contents=substitute_out) + self._write_if_changed(filename=file, contents=substitute_out) def write( self, - filename: str, + filename: str | Path, env_callable: Callable[[], str | dict[str, Any]], ) -> None: self.write_with_template(filename, filename, env_callable) def write_sharded( self, - filename: str, + filename: str | Path, items: Iterable[T], *, key_fn: Callable[[T], str], @@ -223,8 +252,8 @@ class FileManager: def write_sharded_with_template( self, - filename: str, - template_fn: str, + filename: str | Path, + template_fn: str | Path, items: Iterable[T], *, key_fn: Callable[[T], str], @@ -233,6 +262,8 @@ class FileManager: base_env: dict[str, Any] | None = None, sharded_keys: set[str], ) -> None: + file = Path(filename) + assert not file.is_absolute(), f"filename must be relative: {filename}" everything: dict[str, Any] = {"shard_id": "Everything"} shards: list[dict[str, Any]] = [ {"shard_id": f"_{i}"} for i in range(num_shards) @@ -270,31 +301,27 @@ class FileManager: merge_env(shards[sid], env) merge_env(everything, env) - dot_pos = filename.rfind(".") - if dot_pos == -1: - dot_pos = len(filename) - base_filename = filename[:dot_pos] - extension = filename[dot_pos:] - for shard in all_shards: shard_id = shard["shard_id"] self.write_with_template( - f"{base_filename}{shard_id}{extension}", + file.with_stem(f"{file.stem}{shard_id}"), template_fn, lambda: shard, ) # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled - self.filenames.discard( - f"{self.install_dir}/{base_filename}Everything{extension}" - ) + self.files.discard(self.install_dir / file.with_stem(f"{file.stem}Everything")) - def write_outputs(self, variable_name: str, filename: str) -> None: - """Write a file containing the list of all outputs which are - generated by this script.""" - content = "set({}\n {})".format( - variable_name, - "\n ".join('"' + name + '"' for name in sorted(self.filenames)), + def write_outputs(self, variable_name: str, filename: str | Path) -> None: + """Write a file containing the list of all outputs which are generated by this script.""" + content = "\n".join( + ( + "set(", + variable_name, + # Use POSIX paths to avoid invalid escape sequences on Windows + *(f' "{file.as_posix()}"' for file in sorted(self.files)), + ")", + ) ) self._write_if_changed(filename, content) @@ -309,12 +336,15 @@ class FileManager: # Helper function to generate file manager def make_file_manager( - options: Namespace, install_dir: str | None = None + options: Namespace, + install_dir: str | Path | None = None, ) -> FileManager: template_dir = os.path.join(options.source_path, "templates") install_dir = install_dir if install_dir else options.install_dir return FileManager( - install_dir=install_dir, template_dir=template_dir, dry_run=options.dry_run + install_dir=install_dir, + template_dir=template_dir, + dry_run=options.dry_run, ) @@ -437,7 +467,10 @@ class NamespaceHelper: """ def __init__( - self, namespace_str: str, entity_name: str = "", max_level: int = 2 + self, + namespace_str: str, + entity_name: str = "", + max_level: int = 2, ) -> None: # cpp_namespace can be a colon joined string such as torch::lazy cpp_namespaces = namespace_str.split("::") @@ -454,7 +487,8 @@ class NamespaceHelper: @staticmethod def from_namespaced_entity( - namespaced_entity: str, max_level: int = 2 + namespaced_entity: str, + max_level: int = 2, ) -> NamespaceHelper: """ Generate helper from nested namespaces as long as class/function name. E.g.: "torch::lazy::add"