[BE] Migrate dtype_abbrs into one location (#152229)

Namely `torch.utils._dtype_abbrs.dtype_abbrs`

Before that it was defined in various forms of completeness in
c02edba863/torch/fx/graph.py (L215),
c02edba863/torch/testing/_internal/common_utils.py (L5226)
 and c02edba863/torch/testing/_internal/logging_tensor.py (L17)

TODO:
 - Add linter that `torch.testing._internal` module is not referenced from any of the public facing APIs, as it can have extra dependencies such as `expect_test`

Fixes https://github.com/pytorch/pytorch/issues/152225

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152229
Approved by: https://github.com/clee2000, https://github.com/Skylion007
This commit is contained in:
Nikita Shulga
2025-04-28 03:52:47 +00:00
committed by PyTorch MergeBot
parent 899eec665c
commit 13966d0bf5
8 changed files with 36 additions and 69 deletions

View File

@ -30,7 +30,6 @@ from torch.testing._internal.common_device_type import (
)
from torch.testing._internal.common_methods_invocations import op_db, skipOps
from torch.testing._internal.common_utils import (
dtype_abbrs,
IS_MACOS,
IS_X86,
skipCUDAMemoryLeakCheckIf,
@ -49,6 +48,7 @@ from torch.testing._internal.inductor_utils import (
HAS_XPU,
maybe_skip_size_asserts,
)
from torch.utils._dtype_abbrs import dtype_abbrs
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map

View File

@ -7,6 +7,7 @@ import os
import numpy as np
from enum import Enum
from torch.overrides import resolve_name
from torch.utils._dtype_abbrs import dtype_abbrs
from torch.utils._pytree import tree_map, tree_map_only, tree_flatten, tree_unflatten
from torch.utils import _pytree as pytree
from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq, is_sparse_any
@ -22,7 +23,6 @@ from torch.testing._internal.common_utils import (
suppress_warnings,
TEST_WITH_TORCHDYNAMO,
run_tests,
dtype_abbrs,
parametrize,
xfailIfTorchDynamo,
)

View File

@ -780,7 +780,7 @@ def aot_graph_input_parser(
forward(**kwargs)
"""
from torch.fx.graph import dtype_abbrs
from torch.utils._dtype_abbrs import dtype_abbrs
dtype_map = {value: key for key, value in dtype_abbrs.items()}
dtype_pattern = "|".join(dtype_abbrs.values())

View File

@ -791,7 +791,7 @@ def _register_logging_hooks_on_whole_graph(
def fmt(t: Optional[torch.Tensor]) -> str:
# Avoid circular import
from torch.testing._internal.common_utils import dtype_abbrs
from torch.utils._dtype_abbrs import dtype_abbrs
if t is None:
return "None"

View File

@ -20,6 +20,7 @@ from typing import Any, Callable, Literal, NamedTuple, Optional, TYPE_CHECKING
import torch
import torch.utils._pytree as pytree
from torch._C import _fx_map_arg as map_arg, _NodeIter
from torch.utils._dtype_abbrs import dtype_abbrs
from . import _pytree as fx_pytree
from ._compatibility import compatibility
@ -212,34 +213,6 @@ class _Namespace:
self._used_names.add(name)
dtype_abbrs = {
torch.bfloat16: "bf16",
torch.float64: "f64",
torch.float32: "f32",
torch.float16: "f16",
torch.float8_e4m3fn: "f8e4m3fn",
torch.float8_e5m2: "f8e5m2",
torch.float8_e4m3fnuz: "f8e4m3fnuz",
torch.float8_e5m2fnuz: "f8e5m2fnuz",
torch.float8_e8m0fnu: "f8e8m0fnu",
torch.float4_e2m1fn_x2: "f4e2m1fnx2",
torch.complex32: "c32",
torch.complex64: "c64",
torch.complex128: "c128",
torch.int8: "i8",
torch.int16: "i16",
torch.int32: "i32",
torch.int64: "i64",
torch.bool: "b8",
torch.uint8: "u8",
torch.uint16: "u16",
torch.uint32: "u32",
torch.uint64: "u64",
torch.bits16: "b16",
torch.bits1x8: "b1x8",
}
@compatibility(is_backward_compatible=True)
@dataclass
class PythonCode:

View File

@ -5223,23 +5223,6 @@ def dtype_name(dtype):
return str(dtype).split('.')[1]
dtype_abbrs = {
torch.bfloat16: 'bf16',
torch.float64: 'f64',
torch.float32: 'f32',
torch.float16: 'f16',
torch.complex32: 'c32',
torch.complex64: 'c64',
torch.complex128: 'c128',
torch.int8: 'i8',
torch.int16: 'i16',
torch.int32: 'i32',
torch.int64: 'i64',
torch.bool: 'b8',
torch.uint8: 'u8',
}
@functools.lru_cache
def get_cycles_per_ms() -> float:
"""Measure and return approximate number of cycles per millisecond for torch.cuda._sleep

View File

@ -7,6 +7,7 @@ from collections.abc import Iterator
import logging
import contextlib
import itertools
from torch.utils._dtype_abbrs import dtype_abbrs as _dtype_abbrs
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.weak import WeakTensorKeyDictionary
import functools
@ -14,26 +15,6 @@ from torch._C._profiler import gather_traceback, symbolize_tracebacks
logger = logging.getLogger("LoggingTensor")
_dtype_abbrs = {
torch.bfloat16: "bf16",
torch.float64: "f64",
torch.float32: "f32",
torch.float16: "f16",
torch.complex32: "c32",
torch.complex64: "c64",
torch.complex128: "c128",
torch.int8: "i8",
torch.int16: "i16",
torch.int32: "i32",
torch.int64: "i64",
torch.bool: "b8",
torch.uint8: "u8",
torch.float8_e4m3fn: "f8e4m3fn",
torch.float8_e5m2: "f8e5m2",
torch.float8_e4m3fnuz: "f8e4m3fnuz",
torch.float8_e5m2fnuz: "f8e5m2fnuz",
}
# How the chain of calls works for LoggingTensor:
# 1. Call torch.sin
# 2. Attempt __torch_function__. In LoggingTensor torch function is disabled so we bypass it entirely

View File

@ -0,0 +1,30 @@
import torch
# Used for testing and logging
dtype_abbrs = {
torch.bfloat16: "bf16",
torch.float64: "f64",
torch.float32: "f32",
torch.float16: "f16",
torch.float8_e4m3fn: "f8e4m3fn",
torch.float8_e5m2: "f8e5m2",
torch.float8_e4m3fnuz: "f8e4m3fnuz",
torch.float8_e5m2fnuz: "f8e5m2fnuz",
torch.float8_e8m0fnu: "f8e8m0fnu",
torch.float4_e2m1fn_x2: "f4e2m1fnx2",
torch.complex32: "c32",
torch.complex64: "c64",
torch.complex128: "c128",
torch.int8: "i8",
torch.int16: "i16",
torch.int32: "i32",
torch.int64: "i64",
torch.bool: "b8",
torch.uint8: "u8",
torch.uint16: "u16",
torch.uint32: "u32",
torch.uint64: "u64",
torch.bits16: "b16",
torch.bits1x8: "b1x8",
}