mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This reverts commit 2293fe1024812d6349f6e2b3b7de82c6b73f11e4. Reverted https://github.com/pytorch/pytorch/pull/129374 on behalf of https://github.com/malfet due to failing internal ROCM builds with error: ModuleNotFoundError: No module named hipify ([comment](https://github.com/pytorch/pytorch/pull/129374#issuecomment-2562973920))
65 lines
1.9 KiB
Python
65 lines
1.9 KiB
Python
"""
|
|
This script will generate default values of quantization configs.
|
|
These are for use in the documentation.
|
|
"""
|
|
|
|
import os.path
|
|
|
|
import torch
|
|
from torch.ao.quantization.backend_config import get_native_backend_config_dict
|
|
from torch.ao.quantization.backend_config.utils import (
|
|
entry_to_pretty_str,
|
|
remove_boolean_dispatch_from_name,
|
|
)
|
|
|
|
|
|
# Directory that receives the generated backend-config text. It is resolved
# from this script's real location (symlinks resolved), so the output lands
# next to the script regardless of the current working directory.
QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH = os.path.join(
    os.path.realpath(os.path.join(__file__, "..")), "quantization_backend_configs"
)

# Create the directory if needed. exist_ok=True makes this idempotent and
# avoids the race between a separate os.path.exists() check and os.mkdir().
os.makedirs(QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH, exist_ok=True)

# Text file that the pretty-printed native backend config is written to.
output_path = os.path.join(
    QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH, "default_backend_config.txt"
)
|
|
|
|
# Collect, sort and format every entry BEFORE opening the output file.
# open(..., "w") truncates immediately, so doing the work inside the `with`
# would leave an empty/truncated file behind if any step below raised.
native_backend_config_dict = get_native_backend_config_dict()

configs = native_backend_config_dict["configs"]


def _sort_key_func(entry):
    """Return a normalized name used to order backend config entries.

    The entry's ``"pattern"`` is unwrapped: while it is a tuple, it is
    replaced by its last element. It is then passed through
    ``remove_boolean_dispatch_from_name``. Non-string results are converted
    with ``torch.typename``. The key is the lower-cased, underscore-free text
    after the final ``"."``.

    Raises KeyError if ``entry`` has no ``"pattern"`` key.
    NOTE(review): assumes every config entry carries "pattern" -- confirm.
    """
    pattern = entry["pattern"]
    while isinstance(pattern, tuple):
        pattern = pattern[-1]

    pattern = remove_boolean_dispatch_from_name(pattern)
    if not isinstance(pattern, str):
        # methods are already strings
        pattern = torch.typename(pattern)

    # we want
    #
    #   torch.nn.modules.pooling.AdaptiveAvgPool1d
    #
    # and
    #
    #   torch._VariableFunctionsClass.adaptive_avg_pool1d
    #
    # to be next to each other, so convert to all lower case
    # and remove the underscores, and compare the last part
    # of the string
    pattern_str_normalized = pattern.lower().replace("_", "")
    return pattern_str_normalized.split(".")[-1]


# Sorts in place; list.sort is stable, so entries sharing a key keep
# their original relative order.
configs.sort(key=_sort_key_func)

entries = ",\n".join(entry_to_pretty_str(entry) for entry in configs)

# Explicit encoding so the generated doc file does not depend on the
# platform's default locale encoding.
with open(output_path, "w", encoding="utf-8") as f:
    f.write(entries)
|