Files
pytorch/docs/source/scripts/build_quantization_configs.py
Xuehai Pan b6bdb67f82 [BE][Easy] use pathlib.Path instead of dirname / ".." / pardir (#129374)
Changes by apply order (see the combined sketch after this list):

1. Replace all `".."` and `os.pardir` usage with `os.path.dirname(...)`.
2. Replace nested `os.path.dirname(os.path.dirname(...))` call with `str(Path(...).parent.parent)`.
3. Reorder `.absolute()` ~/ `.resolve()`~ and `.parent`: always resolve the path first.

    `.parent{...}.absolute()` -> `.absolute().parent{...}`

4. Replace chained `.parent x N` with `.parents[${N - 1}]`: the code is easier to read (see 5.)

    `.parent.parent.parent.parent` -> `.parents[3]`

5. ~Replace `.parents[${N - 1}]` with `.parents[${N} - 1]`: the code is easier to read and does not introduce any runtime overhead.~

    ~`.parents[3]` -> `.parents[4 - 1]`~

6. ~Replace `.parents[2 - 1]` with `.parent.parent`: because the code is shorter and easier to read.~
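
A combined sketch of items 2-4 above (the directory layout and variable names are invented for illustration; this is not a hunk from the PR diff):

    import os.path
    from pathlib import Path

    # before: nested dirname calls climb three directories, which is hard to count
    root_old = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
    )

    # after: resolve the path first (item 3), then climb once with .parents[N - 1]
    # (item 4); three chained .parent hops correspond to .parents[2]
    root_new = str(Path(__file__).resolve().parents[2])

    # both expressions produce the same directory string for a regular file path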

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129374
Approved by: https://github.com/justinchuby, https://github.com/malfet
2024-12-29 17:23:13 +00:00

65 lines
1.9 KiB
Python

"""
This script will generate default values of quantization configs.
These are for use in the documentation.
"""

import os.path

import torch
from torch.ao.quantization.backend_config import get_native_backend_config_dict
from torch.ao.quantization.backend_config.utils import (
    entry_to_pretty_str,
    remove_boolean_dispatch_from_name,
)


# Create a directory for the images, if it doesn't exist
QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH = os.path.join(
    os.path.realpath(os.path.dirname(__file__)), "quantization_backend_configs"
)

if not os.path.exists(QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH):
    os.mkdir(QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH)

output_path = os.path.join(
    QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH, "default_backend_config.txt"
)

with open(output_path, "w") as f:
    native_backend_config_dict = get_native_backend_config_dict()
    configs = native_backend_config_dict["configs"]

    def _sort_key_func(entry):
        pattern = entry["pattern"]
        while isinstance(pattern, tuple):
            pattern = pattern[-1]
        pattern = remove_boolean_dispatch_from_name(pattern)
        if not isinstance(pattern, str):
            # methods are already strings
            pattern = torch.typename(pattern)
        # we want
        #
        # torch.nn.modules.pooling.AdaptiveAvgPool1d
        #
        # and
        #
        # torch._VariableFunctionsClass.adaptive_avg_pool1d
        #
        # to be next to each other, so convert to all lower case
        # and remove the underscores, and compare the last part
        # of the string
        pattern_str_normalized = pattern.lower().replace("_", "")
        key = pattern_str_normalized.split(".")[-1]
        return key

    configs.sort(key=_sort_key_func)

    entries = []
    for entry in configs:
        entries.append(entry_to_pretty_str(entry))
    entries = ",\n".join(entries)
    f.write(entries)
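
The script above still builds its paths with os.path. A rough pathlib-style equivalent of its directory setup, in the spirit of the patterns listed in the commit message, might look like this (a sketch only, not the committed code; the variable name image_dir is invented):

    from pathlib import Path

    # resolve first, then take .parent, mirroring item 3 in the change list
    image_dir = Path(__file__).resolve().parent / "quantization_backend_configs"
    image_dir.mkdir(exist_ok=True)  # replaces the os.path.exists / os.mkdir pair
    output_path = image_dir / "default_backend_config.txt"

Since open() accepts a Path directly, the rest of the script would work unchanged with output_path as a pathlib.Path.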