Compare commits

...

8 Commits

Author SHA1 Message Date
7960f1d9d7 Merge branch 'fix_fbgemm_tp' into merging_to_test 2025-04-09 10:58:37 +00:00
2024e3b5e6 contiguous 2025-04-09 10:58:08 +00:00
091657e23f Merge branch 'fix_fbgemm_tp' into merging_to_test 2025-04-09 10:30:08 +00:00
634c9c6661 keep fused 2025-04-09 10:28:30 +00:00
4500320ab4 Merge branch 'flex-fix-regression' into merging_to_test 2025-04-09 09:47:38 +00:00
b3e08ec108 fix 2025-04-09 06:53:08 +00:00
c6f442aed6 First fix flex 2025-04-08 18:40:16 +02:00
c7d1e731ec the fix that did not get in 2025-04-08 15:22:10 +00:00
6 changed files with 43 additions and 18 deletions

View File

@ -50,7 +50,7 @@ class FbgemmFp8Linear(torch.nn.Linear):
# x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
# https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
x.view(-1, x.shape[-1]), scale_ub=self.input_scale_ub
x.view(-1, x.shape[-1]).contiguous(), scale_ub=self.input_scale_ub
)
# moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works
# x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)
@ -207,9 +207,6 @@ def _replace_with_fbgemm_fp8_linear(
(key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
):
with init_empty_weights(include_buffers=True):
tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj_scale")] = tp_plan[
re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj")
]
tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None
model._modules[name] = FbgemmFp8Llama4TextExperts(
config.text_config,

View File

@ -70,7 +70,7 @@ class WrappedFlexAttention:
flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
)
else:
self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
self._compiled_flex_attention = torch.compile(flex_attention, fullgraph=True)
self._is_flex_compiled = True
def __call__(self):

View File

@ -219,7 +219,7 @@ class GatherParallel(TensorParallelLayer):
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
if isinstance(inputs[0], DTensor):
if inputs and isinstance(inputs[0], DTensor):
inputs = inputs[0].to_local()
return inputs

View File

@ -738,18 +738,8 @@ class Llama4TextModel(Llama4PreTrainedModel):
else:
full_cache_length = attention_mask.shape[-1] if attention_mask is not None else sequence_length
# to avoid graph break, we introduce this hack
cond1 = first_cache_position >= attention_chunk_size
cond2 = (first_cache_position < attention_chunk_size) & (
first_cache_position + sequence_length > attention_chunk_size
)
key_length = (
torch.where(
cond1,
attention_chunk_size + sequence_length - 1,
torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size),
)
sequence_length if sequence_length > attention_chunk_size else attention_chunk_size
if use_cache
else full_cache_length
)

View File

@ -342,6 +342,9 @@ class SequentialLlama4TextExperts(ModuleList):
MODULES_TO_PATCH_FOR_QUANTIZATION = {
"Llama4TextExperts": {
"module_name": SequentialLlama4TextExperts,
"quantization_methods": [QuantizationMethod.COMPRESSED_TENSORS, QuantizationMethod.BITS_AND_BYTES],
"quantization_methods": [
QuantizationMethod.COMPRESSED_TENSORS,
QuantizationMethod.BITS_AND_BYTES,
],
}
}

View File

@ -241,6 +241,41 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
not_missing_keys.append(missing)
return [k for k in missing_keys if k not in not_missing_keys]
def update_tp_plan(self, config):
additional_text_plan = {
"layers.*.self_attn.q_proj.weight": "local_colwise",
"layers.*.self_attn.q_proj.weight_scale": "local_colwise",
"layers.*.self_attn.k_proj.weight": "local_colwise",
"layers.*.self_attn.k_proj.weight_scale": "local_colwise",
"layers.*.self_attn.v_proj.weight": "local_colwise",
"layers.*.self_attn.v_proj.weight_scale": "local_colwise",
"layers.*.self_attn.o_proj.weight": "local_rowwise",
"layers.*.self_attn": "gather",
"layers.*.input_layernorm.weight": "sequence_parallel",
"layers.*.post_attention_layernorm.weight": "sequence_parallel",
"norm.weight": "sequence_parallel",
"layers.*.feed_forward.shared_expert.gate_proj.weight": "local_colwise",
"layers.*.feed_forward.shared_expert.gate_proj.weight_scale": "local_colwise",
"layers.*.feed_forward.shared_expert.up_proj.weight": "local_colwise",
"layers.*.feed_forward.shared_expert.up_proj.weight_scale": "local_colwise",
"layers.*.feed_forward.shared_expert.down_proj.weight": "local_rowwise",
"layers.*.feed_forward.experts": "local",
"layers.*.feed_forward": "gather",
"layers.*.feed_forward.experts.*.gate_proj.weight": "local_colwise",
"layers.*.feed_forward.experts.*.gate_proj.weight_scale": "local_colwise",
"layers.*.feed_forward.experts.*.up_proj.weight": "local_colwise",
"layers.*.feed_forward.experts.*.up_proj.weight_scale": "local_colwise",
"layers.*.feed_forward.experts.*.down_proj.weight": "local_rowwise",
# For Fused implementation
"layers.*.feed_forward.experts.gate_up_proj": "local_packed_rowwise",
"layers.*.feed_forward.experts.gate_up_proj_scale": "local_packed_rowwise",
"layers.*.feed_forward.experts.down_proj": "local_colwise",
}
if config.get_text_config() is not None and config.get_text_config().base_model_tp_plan is not None:
config.get_text_config().base_model_tp_plan = additional_text_plan
return config
def is_serializable(self, safe_serialization=None):
return True