Files
pytorch/docs/source/scripts/build_opsets.py
Xuehai Pan b6bdb67f82 [BE][Easy] use pathlib.Path instead of dirname / ".." / pardir (#129374)
Changes by apply order:

1. Replace all `".."` and `os.pardir` usage with `os.path.dirname(...)`.
2. Replace nested `os.path.dirname(os.path.dirname(...))` call with `str(Path(...).parent.parent)`.
3. Reorder `.absolute()` ~/ `.resolve()`~ and `.parent`: always resolve the path first.

    `.parent{...}.absolute()` -> `.absolute().parent{...}`

4. Replace chained `.parent x N` with `.parents[${N - 1}]`: the code is easier to read (see 5.)

    `.parent.parent.parent.parent` -> `.parents[3]`

5. ~Replace `.parents[${N - 1}]` with `.parents[${N} - 1]`: the code is easier to read and does not introduce any runtime overhead.~

    ~`.parents[3]` -> `.parents[4 - 1]`~

6. ~Replace `.parents[2 - 1]` with `.parent.parent`: because the code is shorter and easier to read.~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129374
Approved by: https://github.com/justinchuby, https://github.com/malfet
2024-12-29 17:23:13 +00:00

76 lines
2.0 KiB
Python

import csv
import os
from collections import OrderedDict
from pathlib import Path

import torch
import torch._prims as prims
from torchgen.gen import parse_native_yaml
# ``parents[3]`` is the directory three levels above the one containing this
# script (presumably the repository root; confirm against the script location).
ROOT = Path(__file__).absolute().parents[3]
# YAML inputs handed to ``parse_native_yaml`` in ``get_aten``.
NATIVE_FUNCTION_YAML_PATH = ROOT / "aten/src/ATen/native/native_functions.yaml"
TAGS_YAML_PATH = ROOT / "aten/src/ATen/native/tags.yaml"
# Output directory. It is a relative path, so it resolves against the current
# working directory of the process, not against ``ROOT``.
BUILD_DIR = "build/ir"
# File names of the CSV tables that ``main`` writes inside ``BUILD_DIR``.
ATEN_OPS_CSV_FILE = "aten_ops.csv"
PRIMS_OPS_CSV_FILE = "prims_ops.csv"
def get_aten():
    """Return ``(name, schema)`` pairs for the core-tagged ATen operators.

    The native-functions YAML is parsed together with the tags YAML, and only
    functions whose ``tags`` contain ``"core"`` are kept.

    Returns:
        A list of ``(operator name, schema)`` tuples sorted by operator name.
        Each name carries an ``aten.`` prefix, and every ``*`` in the schema
        string is backslash-escaped (presumably so reST does not read it as
        markup; confirm against the consumer of the CSV).
    """
    parsed = parse_native_yaml(NATIVE_FUNCTION_YAML_PATH, TAGS_YAML_PATH)
    # Keyed by operator name; if two functions share a name, the later one wins.
    core_functions = OrderedDict(
        (str(fn.func.name), fn)
        for fn in parsed.native_functions
        if "core" in fn.tags
    )
    return [
        (f"aten.{name}", str(core_functions[name].func).replace("*", r"\*"))
        for name in sorted(core_functions)
    ]
def get_prims():
    """Return ``(name, schema)`` pairs for the operators exported by ``torch._prims``.

    Every name in ``prims.__all__`` is looked up on the module; only attributes
    that are ``torch._ops.OpOverload`` instances are reported.

    Returns:
        A list of ``(operator name, schema)`` tuples in ``prims.__all__`` order.
        The name is ``str(overload)`` with ``".default"`` removed, and the
        schema is the overload packet's ``schema`` attribute with every ``*``
        backslash-escaped.
    """
    pairs = []
    for exported_name in prims.__all__:
        candidate = getattr(prims, exported_name, None)
        # Exports that are missing or are not operator overloads are skipped.
        if isinstance(candidate, torch._ops.OpOverload):
            display_name = str(candidate).replace(".default", "")
            escaped_schema = candidate.overloadpacket.schema.replace("*", r"\*")
            pairs.append((display_name, escaped_schema))
    return pairs
def _write_ops_csv(path, op_schema_pairs):
    """Write ``(operator name, schema)`` pairs to ``path`` as a two-column CSV.

    Args:
        path: Destination file; it is created or truncated.
        op_schema_pairs: Iterable of ``(name, schema)`` string tuples.

    Each data row is ``"``name``","schema"``. Quoting is delegated to the
    ``csv`` module so that a double quote inside a schema is escaped (doubled)
    instead of terminating the field early, which hand-built f-string rows
    could not guarantee.
    """
    # newline="" leaves line endings to the csv writer, as the csv docs advise;
    # the explicit encoding keeps the output independent of the locale.
    with open(path, "w", encoding="utf-8", newline="") as f:
        # The header row is written verbatim so it stays unquoted.
        f.write("Operator,Schema\n")
        writer = csv.writer(f, quoting=csv.QUOTE_ALL, lineterminator="\n")
        writer.writerows(
            (f"``{name}``", schema) for name, schema in op_schema_pairs
        )


def main():
    """Generate the ATen and prims operator CSV tables under ``BUILD_DIR``.

    Both operator lists are collected first; then ``BUILD_DIR`` is created if
    needed and one CSV file per operator set is written into it.
    """
    aten_ops_list = get_aten()
    prims_ops_list = get_prims()
    os.makedirs(BUILD_DIR, exist_ok=True)
    _write_ops_csv(os.path.join(BUILD_DIR, ATEN_OPS_CSV_FILE), aten_ops_list)
    _write_ops_csv(os.path.join(BUILD_DIR, PRIMS_OPS_CSV_FILE), prims_ops_list)
# Generate the CSV files only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()