[Misc] rename torch_dtype to dtype (#26695)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Authored by wangxiyuan on 2025-10-15 20:11:48 +08:00; committed by GitHub
parent f93e348010
commit 8f4b313c37
30 changed files with 52 additions and 55 deletions
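The rename is mechanical: every `torch_dtype` keyword or config attribute becomes `dtype`. As a minimal before/after sketch of what this looks like to callers of the Transformers API (reusing the Qwen/Qwen3-0.6B model that appears in the docs hunks below; the snippet is illustrative, not part of the commit):

```python
from transformers import AutoModelForCausalLM

# Before this change (old keyword):
# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", torch_dtype="auto")

# After this change (the spelling used with transformers >= 4.56.0):
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", dtype="auto")
```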

View File

@@ -631,7 +631,7 @@ def main(args: argparse.Namespace):
else:
ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
shard_intermediate_size = 2 * intermediate_size // args.tp_size
-dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
+dtype = torch.float16 if current_platform.is_rocm() else config.dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"
block_quant_shape = get_weight_block_size_safety(config)

View File

@@ -344,7 +344,7 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok
hidden_size = config.hidden_size
-dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
+dtype = torch.float16 if current_platform.is_rocm() else config.dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"
use_customized_permute = args.use_customized_permute

View File

@@ -58,7 +58,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound
model_name = "Qwen/Qwen3-0.6B"
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
+model = AutoModelForCausalLM.from_pretrained(model_name, dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
bits, group_size, sym = 4, 128, True

View File

@@ -43,7 +43,7 @@ MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
device_map="auto",
-torch_dtype="auto",
+dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
```

View File

@@ -41,7 +41,7 @@ MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
device_map="auto",
-torch_dtype="auto",
+dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
```

View File

@@ -46,7 +46,7 @@ MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
device_map="auto",
-torch_dtype="auto",
+dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
```

View File

@@ -82,7 +82,7 @@ Here's a complete example using `meta-llama/Llama-3.1-8B-Instruct` (most models
# Select model and load it
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
-model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", torch_dtype="auto")
+model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Select calibration dataset

View File

@@ -50,7 +50,7 @@ to fetch model and tokenizer.
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
device_map="auto",
-torch_dtype="auto",
+dtype="auto",
)
model.eval()

View File

@@ -27,7 +27,7 @@ You can quantize your own huggingface model with torchao, e.g. [transformers](ht
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
quantized_model = AutoModelForCausalLM.from_pretrained(
model_name,
-torch_dtype="auto",
+dtype="auto",
device_map="auto",
quantization_config=quantization_config
)

View File

@@ -7,7 +7,7 @@ requests >= 2.26.0
tqdm
blake3
py-cpuinfo
-transformers >= 4.55.2
+transformers >= 4.56.0
tokenizers >= 0.21.1 # Required for fast incremental detokenization.
protobuf # Required by LlamaTokenizer.
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.

View File

@@ -334,7 +334,7 @@ class HfRunner:
trust_remote_code=trust_remote_code,
)
self.device = self.get_default_device()
-self.dtype = torch_dtype = _get_and_verify_dtype(
+self.dtype = dtype = _get_and_verify_dtype(
self.model_name,
self.config,
dtype=dtype,
@@ -342,7 +342,7 @@ class HfRunner:
)
model_kwargs = model_kwargs if model_kwargs is not None else {}
-model_kwargs.setdefault("torch_dtype", torch_dtype)
+model_kwargs.setdefault("dtype", dtype)
if is_sentence_transformer:
# Lazy init required for AMD CI
@@ -388,7 +388,7 @@ class HfRunner:
if not skip_tokenizer_init:
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
-torch_dtype=torch_dtype,
+dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -398,7 +398,7 @@ class HfRunner:
self.processor = AutoProcessor.from_pretrained(
model_name,
-torch_dtype=torch_dtype,
+dtype=dtype,
trust_remote_code=trust_remote_code,
)
if skip_tokenizer_init:

View File

@@ -38,7 +38,7 @@ def run_intern_vit_test(
config.norm_type = "rms_norm"
hf_model = AutoModel.from_pretrained(
-model, torch_dtype=torch_dtype, trust_remote_code=True
+model, dtype=torch_dtype, trust_remote_code=True
).to("cuda")
hf_outputs_per_image = [
hf_model(pixel_value.to("cuda")).last_hidden_state

View File

@@ -45,7 +45,7 @@ def run_radio_test(
hf_model = AutoModel.from_pretrained(
model_id,
config=config,
-torch_dtype=torch_dtype,
+dtype=torch_dtype,
trust_remote_code=True,
).to("cuda")
hf_model.eval()

View File

@@ -251,7 +251,7 @@ def run_hf(
disable_detokenize: bool = False,
) -> float:
llm = AutoModelForCausalLM.from_pretrained(
-model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
+model, dtype=torch.float16, trust_remote_code=trust_remote_code
)
if llm.config.model_type == "llama":
# To enable padding in the HF backend.

View File

@@ -1837,18 +1837,18 @@ def _find_dtype(
*,
revision: str | None,
):
-# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
-# because config.torch_dtype can be None.
-config_dtype = getattr(config, "torch_dtype", None)
+# NOTE: getattr(config, "dtype", torch.float32) is not correct
+# because config.dtype can be None.
+config_dtype = getattr(config, "dtype", None)
# Fallbacks for multi-modal models if the root config
-# does not define torch_dtype
+# does not define dtype
if config_dtype is None:
-config_dtype = getattr(config.get_text_config(), "torch_dtype", None)
+config_dtype = getattr(config.get_text_config(), "dtype", None)
if config_dtype is None and hasattr(config, "vision_config"):
-config_dtype = getattr(config.vision_config, "torch_dtype", None)
+config_dtype = getattr(config.vision_config, "dtype", None)
if config_dtype is None and hasattr(config, "encoder_config"):
-config_dtype = getattr(config.encoder_config, "torch_dtype", None)
+config_dtype = getattr(config.encoder_config, "dtype", None)
# Try to read the dtype of the weights if they are in safetensors format
if config_dtype is None:
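For readability, here is the post-rename lookup order that the `_find_dtype` hunk above implements, collapsed into straight-line form (a sketch of the hunk's "+" side only, not the verbatim vLLM source):

```python
# Post-rename dtype resolution order, condensed from the hunk above.
# `config` is a transformers.PretrainedConfig-like object.
config_dtype = getattr(config, "dtype", None)  # may legitimately be None
if config_dtype is None:
    # Fallbacks for multi-modal models whose root config lacks `dtype`.
    config_dtype = getattr(config.get_text_config(), "dtype", None)
if config_dtype is None and hasattr(config, "vision_config"):
    config_dtype = getattr(config.vision_config, "dtype", None)
if config_dtype is None and hasattr(config, "encoder_config"):
    config_dtype = getattr(config.encoder_config, "dtype", None)
# If still None, the surrounding code falls back to the dtype recorded in the
# safetensors weights (continuation not shown in this hunk).
```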

View File

@@ -117,9 +117,8 @@ class LLM:
execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently,
we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
-the `torch_dtype` attribute specified in the model config file.
-However, if the `torch_dtype` in the config is `float32`, we will
-use `float16` instead.
+the `dtype` attribute of the Transformers model's config. However,
+if the `dtype` in the config is `float32`, we will use `float16` instead.
quantization: The method used to quantize the model weights. Currently,
we support "awq", "gptq", and "fp8" (experimental).
If None, we first check the `quantization_config` attribute in the
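Since this hunk only touches the docstring, a short usage sketch of the documented `dtype` behaviour may help; the model name is an example chosen here, not part of the commit:

```python
from vllm import LLM, SamplingParams

# dtype="auto" defers to the `dtype` field of the Hugging Face config,
# substituting float16 when the config says float32 (per the docstring above).
llm = LLM(model="facebook/opt-125m", dtype="auto")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```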

View File

@@ -518,7 +518,7 @@ def init_tensorizer_model(
) -> nn.Module:
assert tensorizer_config.hf_config is not None
model_args = tensorizer_config.hf_config
-model_args.torch_dtype = tensorizer_config.dtype
+model_args.dtype = tensorizer_config.dtype
assert tensorizer_config.model_class is not None
# TODO: Do we need to consider old-style model class?
with meta_tensor_mode(), set_current_vllm_config(vllm_config, check_compile=True):

View File

@@ -999,7 +999,7 @@ class ChameleonForConditionalGeneration(
return []
assert self.model.vqmodel is not None
image_tokens = self.model.get_image_tokens(
image_input["data"].to(self.config.torch_dtype)
image_input["data"].to(self.config.dtype)
)
vision_embeddings = self.model.get_input_embeddings(image_tokens)
return vision_embeddings

View File

@@ -1089,7 +1089,7 @@ class Ernie4_5VLMultiModalProcessor(BaseMultiModalProcessor[Ernie4_5_VLProcessin
pixel_values = (
rescale_factor * pixel_values.to(torch.float32) - image_mean_tensor
) / image_std_tensor
-pixel_values = pixel_values.to(hf_config.torch_dtype)
+pixel_values = pixel_values.to(hf_config.dtype)
return pixel_values
def _call_hf_processor(

View File

@@ -615,7 +615,7 @@ class GLM4VForCausalLM(
return None
def _process_image_input(self, image_input: GLMVImagePixelInputs) -> torch.Tensor:
pixel_values = image_input["data"].to(dtype=self.config.torch_dtype)
pixel_values = image_input["data"].to(dtype=self.config.dtype)
return self.transformer.vision(pixel_values)

View File

@@ -114,7 +114,7 @@ class FlashConfig(PretrainedConfig):
attention_dropout=0.0,
mla_scale_q_lora=False,
mla_scale_kv_lora=False,
torch_dtype="bfloat16",
dtype="bfloat16",
params_dtype="bfloat16",
router_dtype="float32",
router_bias=False,
@@ -130,7 +130,7 @@ class FlashConfig(PretrainedConfig):
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
-torch_dtype=torch_dtype,
+dtype=dtype,
params_dtype=params_dtype,
router_dtype=router_dtype,
topk_method=topk_method,

View File

@@ -987,7 +987,7 @@ class NemotronH_Nano_VL_V2(
prefix=maybe_prefix(prefix, "language_model"),
)
self.vision_model = self.get_vit_model_from_radio_config(config).to(
-self.language_model.config.torch_dtype
+self.language_model.config.dtype
)
# Construct the vision projection.
@@ -1008,7 +1008,7 @@ class NemotronH_Nano_VL_V2(
ReLUSquaredActivation(),
nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False),
)
-self.mlp1 = self.mlp1.to(self.language_model.config.torch_dtype)
+self.mlp1 = self.mlp1.to(self.language_model.config.dtype)
self.config = config
self.model_config = vllm_config.model_config

View File

@@ -338,7 +338,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
group_size=None,
norm_before_gate=True,
device=current_platform.current_device(),
-dtype=config.torch_dtype,
+dtype=config.dtype,
)
self.out_proj = RowParallelLinear(
@@ -847,7 +847,7 @@ class Qwen3NextDecoderLayer(nn.Module):
1,
1,
config.hidden_size,
-dtype=config.torch_dtype,
+dtype=config.dtype,
),
)
self.ffn_layer_scale = torch.nn.Parameter(
@@ -855,7 +855,7 @@ class Qwen3NextDecoderLayer(nn.Module):
1,
1,
config.hidden_size,
-dtype=config.torch_dtype,
+dtype=config.dtype,
),
)

View File

@@ -530,7 +530,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
with init_on_device_without_buffers("meta"):
self.model: PreTrainedModel = AutoModel.from_config(
self.config,
-torch_dtype=self.model_config.dtype,
+dtype=self.model_config.dtype,
trust_remote_code=self.model_config.trust_remote_code,
)

View File

@@ -157,7 +157,7 @@ class TransformersForSequenceClassification(TransformersPoolingBase):
with torch.device("meta"):
seq_cls_model = AutoModelForSequenceClassification.from_config(
self.config,
-torch_dtype=self.model_config.dtype,
+dtype=self.model_config.dtype,
trust_remote_code=self.model_config.trust_remote_code,
)

View File

@@ -500,8 +500,8 @@ class CudaPlatformBase(Platform):
return supported
@classmethod
-def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
-if torch_dtype == torch.bfloat16: # noqa: SIM102
+def check_if_supports_dtype(cls, dtype: torch.dtype):
+if dtype == torch.bfloat16: # noqa: SIM102
if not cls.has_device_capability(80):
capability = cls.get_device_capability()
gpu_name = cls.get_device_name()

View File

@@ -563,7 +563,7 @@ class Platform:
return False
@classmethod
-def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
+def check_if_supports_dtype(cls, dtype: torch.dtype):
"""
Check if the dtype is supported by the current platform.
"""

View File

@@ -484,8 +484,8 @@ class RocmPlatform(Platform):
return True
@classmethod
-def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
-if torch_dtype == torch.bfloat16: # noqa: SIM102
+def check_if_supports_dtype(cls, dtype: torch.dtype):
+if dtype == torch.bfloat16: # noqa: SIM102
if not cls.has_device_capability(80):
capability = cls.get_device_capability()
gpu_name = cls.get_device_name()

View File

@@ -236,8 +236,8 @@ class XPUPlatform(Platform):
return torch.xpu.device_count()
@classmethod
-def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
-if torch_dtype == torch.bfloat16: # noqa: SIM102
+def check_if_supports_dtype(cls, dtype: torch.dtype):
+if dtype == torch.bfloat16: # noqa: SIM102
device_name = cls.get_device_name().lower()
# client gpu a770
if device_name.count("a770") > 0:

View File

@@ -806,7 +806,7 @@ def create_kv_caches_with_random_flash(
current_platform.seed_everything(seed)
-torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
+dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
assert cache_layout in ("NHD", "HND")
stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4)
@@ -819,7 +819,7 @@ def create_kv_caches_with_random_flash(
for _ in range(num_layers):
key_value_cache = torch.empty(
-size=kv_cache_allocation_shape, dtype=torch_dtype, device=device
+size=kv_cache_allocation_shape, dtype=dtype, device=device
).permute(*stride_order)
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
key_value_cache.uniform_(-scale, scale)
@@ -851,14 +851,14 @@ def create_kv_caches_with_random(
current_platform.seed_everything(seed)
-torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
+dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
scale = head_size**-0.5
-x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
+x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
key_caches: list[torch.Tensor] = []
for _ in range(num_layers):
-key_cache = torch.empty(size=key_cache_shape, dtype=torch_dtype, device=device)
+key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device)
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
key_cache.uniform_(-scale, scale)
elif cache_dtype == "fp8":
@@ -870,9 +870,7 @@ def create_kv_caches_with_random(
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
value_caches: list[torch.Tensor] = []
for _ in range(num_layers):
-value_cache = torch.empty(
-size=value_cache_shape, dtype=torch_dtype, device=device
-)
+value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device)
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
value_cache.uniform_(-scale, scale)
elif cache_dtype == "fp8":