mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[BE] Migrate dtype_abbrs into one location (#152229)
Namely `torch.utils._dtype_abbrs.dtype_abbrs`. Before this change it was defined, in varying degrees of completeness, in c02edba863/torch/fx/graph.py (L215), c02edba863/torch/testing/_internal/common_utils.py (L5226) and c02edba863/torch/testing/_internal/logging_tensor.py (L17).

TODO:
- Add a linter that checks that the `torch.testing._internal` module is not referenced from any public-facing API, as it can have extra dependencies such as `expect_test`.

Fixes https://github.com/pytorch/pytorch/issues/152225

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152229
Approved by: https://github.com/clee2000, https://github.com/Skylion007
commit 13966d0bf5 (parent 899eec665c), committed by PyTorch MergeBot
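For reference, a minimal usage sketch of consuming the abbreviations from their new single location (this assumes a PyTorch build that already includes this PR):

# Minimal sketch: look up the short dtype names from the new location.
# Assumes a PyTorch build that already contains this change.
import torch
from torch.utils._dtype_abbrs import dtype_abbrs

print(dtype_abbrs[torch.float32])   # "f32"
print(dtype_abbrs[torch.bfloat16])  # "bf16"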
@@ -30,7 +30,6 @@ from torch.testing._internal.common_device_type import (
 )
 from torch.testing._internal.common_methods_invocations import op_db, skipOps
 from torch.testing._internal.common_utils import (
-    dtype_abbrs,
     IS_MACOS,
     IS_X86,
     skipCUDAMemoryLeakCheckIf,
@@ -49,6 +48,7 @@ from torch.testing._internal.inductor_utils import (
     HAS_XPU,
     maybe_skip_size_asserts,
 )
+from torch.utils._dtype_abbrs import dtype_abbrs
 from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils._pytree import tree_map
@@ -7,6 +7,7 @@ import os
 import numpy as np
 from enum import Enum
 from torch.overrides import resolve_name
+from torch.utils._dtype_abbrs import dtype_abbrs
 from torch.utils._pytree import tree_map, tree_map_only, tree_flatten, tree_unflatten
 from torch.utils import _pytree as pytree
 from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq, is_sparse_any
@@ -22,7 +23,6 @@ from torch.testing._internal.common_utils import (
     suppress_warnings,
     TEST_WITH_TORCHDYNAMO,
     run_tests,
-    dtype_abbrs,
     parametrize,
     xfailIfTorchDynamo,
 )
@@ -780,7 +780,7 @@ def aot_graph_input_parser(
         forward(**kwargs)
     """

-    from torch.fx.graph import dtype_abbrs
+    from torch.utils._dtype_abbrs import dtype_abbrs

     dtype_map = {value: key for key, value in dtype_abbrs.items()}
     dtype_pattern = "|".join(dtype_abbrs.values())
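For context, the reverse map and regex alternation built in the hunk above can be used to recover a torch.dtype from a printed tensor annotation such as "f32[8, 16]". A minimal sketch follows; parse_annotation is a hypothetical helper for illustration, not code from this PR:

# Minimal sketch: turn a "dtype[shape]" annotation back into (torch.dtype, shape).
# parse_annotation is a hypothetical helper, not part of this PR.
import re
import torch
from torch.utils._dtype_abbrs import dtype_abbrs

dtype_map = {value: key for key, value in dtype_abbrs.items()}  # "f32" -> torch.float32
dtype_pattern = "|".join(dtype_abbrs.values())                  # "bf16|f64|f32|..."

def parse_annotation(annotation: str) -> tuple[torch.dtype, list[int]]:
    match = re.match(rf"({dtype_pattern})\[(.*)\]", annotation)
    assert match is not None, f"unrecognized annotation: {annotation}"
    dtype = dtype_map[match.group(1)]
    shape = [int(dim) for dim in match.group(2).split(",") if dim.strip()]
    return dtype, shape

print(parse_annotation("f32[8, 16]"))  # (torch.float32, [8, 16])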
@@ -791,7 +791,7 @@ def _register_logging_hooks_on_whole_graph(

     def fmt(t: Optional[torch.Tensor]) -> str:
         # Avoid circular import
-        from torch.testing._internal.common_utils import dtype_abbrs
+        from torch.utils._dtype_abbrs import dtype_abbrs

         if t is None:
             return "None"
@@ -20,6 +20,7 @@ from typing import Any, Callable, Literal, NamedTuple, Optional, TYPE_CHECKING
 import torch
 import torch.utils._pytree as pytree
 from torch._C import _fx_map_arg as map_arg, _NodeIter
+from torch.utils._dtype_abbrs import dtype_abbrs

 from . import _pytree as fx_pytree
 from ._compatibility import compatibility
@@ -212,34 +213,6 @@ class _Namespace:
         self._used_names.add(name)


-dtype_abbrs = {
-    torch.bfloat16: "bf16",
-    torch.float64: "f64",
-    torch.float32: "f32",
-    torch.float16: "f16",
-    torch.float8_e4m3fn: "f8e4m3fn",
-    torch.float8_e5m2: "f8e5m2",
-    torch.float8_e4m3fnuz: "f8e4m3fnuz",
-    torch.float8_e5m2fnuz: "f8e5m2fnuz",
-    torch.float8_e8m0fnu: "f8e8m0fnu",
-    torch.float4_e2m1fn_x2: "f4e2m1fnx2",
-    torch.complex32: "c32",
-    torch.complex64: "c64",
-    torch.complex128: "c128",
-    torch.int8: "i8",
-    torch.int16: "i16",
-    torch.int32: "i32",
-    torch.int64: "i64",
-    torch.bool: "b8",
-    torch.uint8: "u8",
-    torch.uint16: "u16",
-    torch.uint32: "u32",
-    torch.uint64: "u64",
-    torch.bits16: "b16",
-    torch.bits1x8: "b1x8",
-}
-
-
 @compatibility(is_backward_compatible=True)
 @dataclass
 class PythonCode:
@@ -5223,23 +5223,6 @@ def dtype_name(dtype):
     return str(dtype).split('.')[1]


-dtype_abbrs = {
-    torch.bfloat16: 'bf16',
-    torch.float64: 'f64',
-    torch.float32: 'f32',
-    torch.float16: 'f16',
-    torch.complex32: 'c32',
-    torch.complex64: 'c64',
-    torch.complex128: 'c128',
-    torch.int8: 'i8',
-    torch.int16: 'i16',
-    torch.int32: 'i32',
-    torch.int64: 'i64',
-    torch.bool: 'b8',
-    torch.uint8: 'u8',
-}
-
-
 @functools.lru_cache
 def get_cycles_per_ms() -> float:
     """Measure and return approximate number of cycles per millisecond for torch.cuda._sleep
@@ -7,6 +7,7 @@ from collections.abc import Iterator
 import logging
 import contextlib
 import itertools
+from torch.utils._dtype_abbrs import dtype_abbrs as _dtype_abbrs
 from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils.weak import WeakTensorKeyDictionary
 import functools
@@ -14,26 +15,6 @@ from torch._C._profiler import gather_traceback, symbolize_tracebacks

 logger = logging.getLogger("LoggingTensor")

-_dtype_abbrs = {
-    torch.bfloat16: "bf16",
-    torch.float64: "f64",
-    torch.float32: "f32",
-    torch.float16: "f16",
-    torch.complex32: "c32",
-    torch.complex64: "c64",
-    torch.complex128: "c128",
-    torch.int8: "i8",
-    torch.int16: "i16",
-    torch.int32: "i32",
-    torch.int64: "i64",
-    torch.bool: "b8",
-    torch.uint8: "u8",
-    torch.float8_e4m3fn: "f8e4m3fn",
-    torch.float8_e5m2: "f8e5m2",
-    torch.float8_e4m3fnuz: "f8e4m3fnuz",
-    torch.float8_e5m2fnuz: "f8e5m2fnuz",
-}
-
 # How the chain of calls works for LoggingTensor:
 # 1. Call torch.sin
 # 2. Attempt __torch_function__. In LoggingTensor torch function is disabled so we bypass it entirely
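For context, the alias imported above is used to render compact "dtype[shape]" tensor summaries in the logs. A minimal sketch of that formatting style follows; tensor_summary is an illustrative helper, not the function actually defined in logging_tensor.py:

# Minimal sketch of the compact "dtype[shape]" summary style built from these
# abbreviations. tensor_summary is a hypothetical helper for illustration only.
import torch
from torch.utils._dtype_abbrs import dtype_abbrs as _dtype_abbrs

def tensor_summary(t: torch.Tensor) -> str:
    shape = ", ".join(str(dim) for dim in t.shape)
    return f"{_dtype_abbrs[t.dtype]}[{shape}]"

print(tensor_summary(torch.zeros(2, 3)))                    # f32[2, 3]
print(tensor_summary(torch.ones(4, dtype=torch.bfloat16)))  # bf16[4]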
torch/utils/_dtype_abbrs.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+import torch
+
+
+# Used for testing and logging
+dtype_abbrs = {
+    torch.bfloat16: "bf16",
+    torch.float64: "f64",
+    torch.float32: "f32",
+    torch.float16: "f16",
+    torch.float8_e4m3fn: "f8e4m3fn",
+    torch.float8_e5m2: "f8e5m2",
+    torch.float8_e4m3fnuz: "f8e4m3fnuz",
+    torch.float8_e5m2fnuz: "f8e5m2fnuz",
+    torch.float8_e8m0fnu: "f8e8m0fnu",
+    torch.float4_e2m1fn_x2: "f4e2m1fnx2",
+    torch.complex32: "c32",
+    torch.complex64: "c64",
+    torch.complex128: "c128",
+    torch.int8: "i8",
+    torch.int16: "i16",
+    torch.int32: "i32",
+    torch.int64: "i64",
+    torch.bool: "b8",
+    torch.uint8: "u8",
+    torch.uint16: "u16",
+    torch.uint32: "u32",
+    torch.uint64: "u64",
+    torch.bits16: "b16",
+    torch.bits1x8: "b1x8",
+}