mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-21 09:44:02 +08:00
Compare commits
8 Commits
v4.56.2
...
merging_to
Author | SHA1 | Date | |
---|---|---|---|
7960f1d9d7 | |||
2024e3b5e6 | |||
091657e23f | |||
634c9c6661 | |||
4500320ab4 | |||
b3e08ec108 | |||
c6f442aed6 | |||
c7d1e731ec |
@ -50,7 +50,7 @@ class FbgemmFp8Linear(torch.nn.Linear):
|
||||
# x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
|
||||
# https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
|
||||
x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
|
||||
x.view(-1, x.shape[-1]), scale_ub=self.input_scale_ub
|
||||
x.view(-1, x.shape[-1]).contiguous(), scale_ub=self.input_scale_ub
|
||||
)
|
||||
# moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works
|
||||
# x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)
|
||||
@ -207,9 +207,6 @@ def _replace_with_fbgemm_fp8_linear(
|
||||
(key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
|
||||
):
|
||||
with init_empty_weights(include_buffers=True):
|
||||
tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj_scale")] = tp_plan[
|
||||
re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj")
|
||||
]
|
||||
tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None
|
||||
model._modules[name] = FbgemmFp8Llama4TextExperts(
|
||||
config.text_config,
|
||||
|
@ -70,7 +70,7 @@ class WrappedFlexAttention:
|
||||
flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
|
||||
)
|
||||
else:
|
||||
self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
|
||||
self._compiled_flex_attention = torch.compile(flex_attention, fullgraph=True)
|
||||
self._is_flex_compiled = True
|
||||
|
||||
def __call__(self):
|
||||
|
@ -219,7 +219,7 @@ class GatherParallel(TensorParallelLayer):
|
||||
|
||||
@staticmethod
|
||||
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
|
||||
if isinstance(inputs[0], DTensor):
|
||||
if inputs and isinstance(inputs[0], DTensor):
|
||||
inputs = inputs[0].to_local()
|
||||
return inputs
|
||||
|
||||
|
@ -738,18 +738,8 @@ class Llama4TextModel(Llama4PreTrainedModel):
|
||||
else:
|
||||
full_cache_length = attention_mask.shape[-1] if attention_mask is not None else sequence_length
|
||||
|
||||
# to avoid graph break, we introduce this hack
|
||||
cond1 = first_cache_position >= attention_chunk_size
|
||||
cond2 = (first_cache_position < attention_chunk_size) & (
|
||||
first_cache_position + sequence_length > attention_chunk_size
|
||||
)
|
||||
|
||||
key_length = (
|
||||
torch.where(
|
||||
cond1,
|
||||
attention_chunk_size + sequence_length - 1,
|
||||
torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size),
|
||||
)
|
||||
sequence_length if sequence_length > attention_chunk_size else attention_chunk_size
|
||||
if use_cache
|
||||
else full_cache_length
|
||||
)
|
||||
|
@ -342,6 +342,9 @@ class SequentialLlama4TextExperts(ModuleList):
|
||||
MODULES_TO_PATCH_FOR_QUANTIZATION = {
|
||||
"Llama4TextExperts": {
|
||||
"module_name": SequentialLlama4TextExperts,
|
||||
"quantization_methods": [QuantizationMethod.COMPRESSED_TENSORS, QuantizationMethod.BITS_AND_BYTES],
|
||||
"quantization_methods": [
|
||||
QuantizationMethod.COMPRESSED_TENSORS,
|
||||
QuantizationMethod.BITS_AND_BYTES,
|
||||
],
|
||||
}
|
||||
}
|
||||
|
@ -241,6 +241,41 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
|
||||
not_missing_keys.append(missing)
|
||||
return [k for k in missing_keys if k not in not_missing_keys]
|
||||
|
||||
def update_tp_plan(self, config):
|
||||
additional_text_plan = {
|
||||
"layers.*.self_attn.q_proj.weight": "local_colwise",
|
||||
"layers.*.self_attn.q_proj.weight_scale": "local_colwise",
|
||||
"layers.*.self_attn.k_proj.weight": "local_colwise",
|
||||
"layers.*.self_attn.k_proj.weight_scale": "local_colwise",
|
||||
"layers.*.self_attn.v_proj.weight": "local_colwise",
|
||||
"layers.*.self_attn.v_proj.weight_scale": "local_colwise",
|
||||
"layers.*.self_attn.o_proj.weight": "local_rowwise",
|
||||
"layers.*.self_attn": "gather",
|
||||
"layers.*.input_layernorm.weight": "sequence_parallel",
|
||||
"layers.*.post_attention_layernorm.weight": "sequence_parallel",
|
||||
"norm.weight": "sequence_parallel",
|
||||
"layers.*.feed_forward.shared_expert.gate_proj.weight": "local_colwise",
|
||||
"layers.*.feed_forward.shared_expert.gate_proj.weight_scale": "local_colwise",
|
||||
"layers.*.feed_forward.shared_expert.up_proj.weight": "local_colwise",
|
||||
"layers.*.feed_forward.shared_expert.up_proj.weight_scale": "local_colwise",
|
||||
"layers.*.feed_forward.shared_expert.down_proj.weight": "local_rowwise",
|
||||
"layers.*.feed_forward.experts": "local",
|
||||
"layers.*.feed_forward": "gather",
|
||||
"layers.*.feed_forward.experts.*.gate_proj.weight": "local_colwise",
|
||||
"layers.*.feed_forward.experts.*.gate_proj.weight_scale": "local_colwise",
|
||||
"layers.*.feed_forward.experts.*.up_proj.weight": "local_colwise",
|
||||
"layers.*.feed_forward.experts.*.up_proj.weight_scale": "local_colwise",
|
||||
"layers.*.feed_forward.experts.*.down_proj.weight": "local_rowwise",
|
||||
# For Fused implementation
|
||||
"layers.*.feed_forward.experts.gate_up_proj": "local_packed_rowwise",
|
||||
"layers.*.feed_forward.experts.gate_up_proj_scale": "local_packed_rowwise",
|
||||
"layers.*.feed_forward.experts.down_proj": "local_colwise",
|
||||
}
|
||||
if config.get_text_config() is not None and config.get_text_config().base_model_tp_plan is not None:
|
||||
config.get_text_config().base_model_tp_plan = additional_text_plan
|
||||
|
||||
return config
|
||||
|
||||
def is_serializable(self, safe_serialization=None):
|
||||
return True
|
||||
|
||||
|
Reference in New Issue
Block a user