mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Changes, in order of application: 1. Replace all `".."` and `os.pardir` usage with `os.path.dirname(...)`. 2. Replace nested `os.path.dirname(os.path.dirname(...))` calls with `str(Path(...).parent.parent)`. 3. Reorder `.absolute()` ~/ `.resolve()`~ and `.parent` so that the path is always resolved first: `.parent{...}.absolute()` -> `.absolute().parent{...}`. 4. Replace chained `.parent` x N with `.parents[${N - 1}]`, which is easier to read (see 5.): `.parent.parent.parent.parent` -> `.parents[3]`. 5. ~Replace `.parents[${N - 1}]` with `.parents[${N} - 1]`: the code is easier to read and does not introduce any runtime overhead.~ ~`.parents[3]` -> `.parents[4 - 1]`~ 6. ~Replace `.parents[2 - 1]` with `.parent.parent`: the code is shorter and easier to read.~ Pull Request resolved: https://github.com/pytorch/pytorch/pull/129374 Approved by: https://github.com/justinchuby, https://github.com/malfet
76 lines
2.0 KiB
Python
76 lines
2.0 KiB
Python
import os
|
|
from collections import OrderedDict
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch._prims as prims
|
|
from torchgen.gen import parse_native_yaml
|
|
|
|
|
|
ROOT = Path(__file__).absolute().parents[3]
|
|
NATIVE_FUNCTION_YAML_PATH = ROOT / "aten/src/ATen/native/native_functions.yaml"
|
|
TAGS_YAML_PATH = ROOT / "aten/src/ATen/native/tags.yaml"
|
|
|
|
BUILD_DIR = "build/ir"
|
|
ATEN_OPS_CSV_FILE = "aten_ops.csv"
|
|
PRIMS_OPS_CSV_FILE = "prims_ops.csv"
|
|
|
|
|
|
def get_aten():
    """Return ``(name, schema)`` pairs for every ATen function tagged ``core``.

    ``native_functions.yaml`` is parsed together with ``tags.yaml`` via
    torchgen's ``parse_native_yaml``.  Only functions whose ``tags`` contain
    ``"core"`` are kept.  The result is sorted by the stringified function
    name.  Each name is prefixed with ``aten.`` and each schema is
    ``str(func)`` with every ``*`` backslash-escaped (presumably so it is not
    read as reST emphasis when the table is rendered).
    """
    native_functions = parse_native_yaml(
        NATIVE_FUNCTION_YAML_PATH, TAGS_YAML_PATH
    ).native_functions

    # Keyed by stringified function name; a later function with the same name
    # replaces an earlier one.
    core_by_name = {
        str(fn.func.name): fn for fn in native_functions if "core" in fn.tags
    }

    return [
        (f"aten.{name}", str(core_by_name[name].func).replace("*", r"\*"))
        for name in sorted(core_by_name)
    ]
|
|
|
|
|
|
def get_prims():
    """Return ``(name, schema)`` pairs for the prims listed in ``prims.__all__``.

    Exported names that are missing from the module, or that are not
    ``torch._ops.OpOverload`` instances, are skipped.  For each remaining
    overload the name is ``str(overload)`` with ``".default"`` removed, and
    the schema is the ``schema`` attribute of the overload's packet with every
    ``*`` backslash-escaped.  Pairs keep the order of ``prims.__all__``.
    """
    pairs = []
    for exported_name in prims.__all__:
        candidate = getattr(prims, exported_name, None)
        if isinstance(candidate, torch._ops.OpOverload):
            # NOTE(review): ``schema`` is read from the overload *packet*; the
            # ``.replace`` call implies it is a str, presumably attached by
            # torch._prims when the prim is registered — confirm.
            packet = candidate.overloadpacket
            display_name = str(candidate).replace(".default", "")
            pairs.append((display_name, packet.schema.replace("*", r"\*")))
    return pairs
|
|
|
|
|
|
def _write_ops_csv(path, op_schema_pairs):
    """Write ``(name, schema)`` pairs to ``path`` as a two-column CSV table.

    The header row is ``Operator,Schema``.  Every data field is quoted, and
    the operator name is wrapped in double backticks (reST inline literal).
    Rows go through :mod:`csv` so that a ``"`` inside a schema (e.g. in a
    string default value) is doubled rather than terminating the field early
    and corrupting the row, as hand-built f-string rows would.
    """
    import csv

    # newline="" is what the csv module expects; the explicit lineterminator
    # keeps the "\n" line endings the file used before.
    with open(path, "w", newline="", encoding="utf-8") as f:
        f.write("Operator,Schema\n")
        writer = csv.writer(f, quoting=csv.QUOTE_ALL, lineterminator="\n")
        writer.writerows(
            (f"``{name}``", schema) for name, schema in op_schema_pairs
        )


def main():
    """Generate the core-ATen and prims operator CSV tables.

    Both tables are written into ``BUILD_DIR``, which is created if it does
    not exist.  It is a relative path, so it is resolved against the current
    working directory.
    """
    aten_ops_list = get_aten()
    prims_ops_list = get_prims()

    os.makedirs(BUILD_DIR, exist_ok=True)

    _write_ops_csv(os.path.join(BUILD_DIR, ATEN_OPS_CSV_FILE), aten_ops_list)
    _write_ops_csv(os.path.join(BUILD_DIR, PRIMS_OPS_CSV_FILE), prims_ops_list)


if __name__ == "__main__":
    main()
|