mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Changes by apply order: 1. Replace all `".."` and `os.pardir` usage with `os.path.dirname(...)`. 2. Replace nested `os.path.dirname(os.path.dirname(...))` call with `str(Path(...).parent.parent)`. 3. Reorder `.absolute()` ~/ `.resolve()`~ and `.parent`: always resolve the path first. `.parent{...}.absolute()` -> `.absolute().parent{...}` 4. Replace chained `.parent x N` with `.parents[${N - 1}]`: the code is easier to read (see 5.) `.parent.parent.parent.parent` -> `.parents[3]` 5. ~Replace `.parents[${N - 1}]` with `.parents[${N} - 1]`: the code is easier to read and does not introduce any runtime overhead.~ ~`.parents[3]` -> `.parents[4 - 1]`~ 6. ~Replace `.parents[2 - 1]` with `.parent.parent`: because the code is shorter and easier to read.~ Pull Request resolved: https://github.com/pytorch/pytorch/pull/129374 Approved by: https://github.com/justinchuby, https://github.com/malfet
65 lines
1.9 KiB
Python
65 lines
1.9 KiB
Python
"""
This script generates a text dump of the default (native) quantization
backend configs. The output is used in the documentation.
"""
|
|
|
|
import os.path
|
|
|
|
import torch
|
|
from torch.ao.quantization.backend_config import get_native_backend_config_dict
|
|
from torch.ao.quantization.backend_config.utils import (
|
|
entry_to_pretty_str,
|
|
remove_boolean_dispatch_from_name,
|
|
)
|
|
|
|
|
|
# Output directory, located next to this script. The constant's name mentions
# images, but this script writes a plain ``.txt`` file into it.
QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH = os.path.join(
    os.path.realpath(os.path.dirname(__file__)), "quantization_backend_configs"
)

# Create the directory if it does not exist yet. ``exist_ok=True`` avoids the
# race between a separate exists() check and mkdir(), and makes re-runs no-ops.
os.makedirs(QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH, exist_ok=True)
|
|
|
|
output_path = os.path.join(
    QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH, "default_backend_config.txt"
)


def _sort_key_func(entry):
    """Return a sort key that places related modules and functions together.

    The key is derived from ``entry["pattern"]``:

    * tuple patterns (possibly nested) are unwrapped by repeatedly taking
      their last element;
    * the result is passed through ``remove_boolean_dispatch_from_name``;
    * anything that is still not a string is converted with
      ``torch.typename`` (methods are already strings);
    * the name is lower-cased, stripped of underscores and reduced to its
      last dotted component.

    For example, ``torch.nn.modules.pooling.AdaptiveAvgPool1d`` and
    ``torch._VariableFunctionsClass.adaptive_avg_pool1d`` both map to
    ``"adaptiveavgpool1d"``, so they sort next to each other.

    Args:
        entry: one element of the backend config dict's ``"configs"`` list.

    Returns:
        The normalized string key.
    """
    pattern = entry["pattern"]
    # Tuple patterns may be nested; key on the innermost last element.
    while isinstance(pattern, tuple):
        pattern = pattern[-1]

    pattern = remove_boolean_dispatch_from_name(pattern)
    if not isinstance(pattern, str):
        # methods are already strings
        pattern = torch.typename(pattern)

    # Lower-case and drop underscores so the module and functional spellings
    # of the same op compare equal, then keep only the last dotted part.
    pattern_str_normalized = pattern.lower().replace("_", "")
    return pattern_str_normalized.split(".")[-1]


native_backend_config_dict = get_native_backend_config_dict()
configs = native_backend_config_dict["configs"]
configs.sort(key=_sort_key_func)
entries = ",\n".join(entry_to_pretty_str(entry) for entry in configs)

# Build the complete text before opening the file. Opening with "w" truncates
# the file, so a failure in the work above would otherwise leave it empty.
# An explicit encoding keeps the output independent of the platform locale.
with open(output_path, "w", encoding="utf-8") as f:
    f.write(entries)
|