Attention Quantization with FBGemm & TP (#37384)

* fix

* keep fused

* contiguous

* rm print

* update

* update

* rm print
Mohamed Mekkouri 2025-04-09 18:45:42 +02:00 committed by GitHub
parent c5c648dd74
commit f834ca2c19
4 changed files with 42 additions and 7 deletions


@@ -50,7 +50,7 @@ class FbgemmFp8Linear(torch.nn.Linear):
         # x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
         # https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
         x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
-            x.view(-1, x.shape[-1]), scale_ub=self.input_scale_ub
+            x.view(-1, x.shape[-1]).contiguous(), scale_ub=self.input_scale_ub
         )
         # moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works
         # x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)
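
Note on the .contiguous() fix above: x.view(-1, x.shape[-1]) can succeed and still return a non-contiguous tensor, while a row-wise FP8 quantization kernel expects a dense, row-major 2-D input. Below is a minimal standalone sketch in plain PyTorch (not the transformers code, no fbgemm required) that reproduces such a view; the strided slice stands in for the kind of non-contiguous activation that can reach FbgemmFp8Linear under tensor parallelism.

import torch

# Sketch only: a strided slice stands in for a non-contiguous activation.
x = torch.randn(2, 3, 16)[..., :8]      # shape (2, 3, 8), strides (48, 16, 1): not contiguous
flat = x.view(-1, x.shape[-1])          # view succeeds: shape (6, 8), strides (16, 1)

assert not flat.is_contiguous()         # the kernel would see gaps between rows
packed = flat.contiguous()              # dense row-major copy, safe for row-wise quantization
assert packed.is_contiguous() and packed.shape == (6, 8)
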
@@ -207,9 +207,6 @@ def _replace_with_fbgemm_fp8_linear(
                 (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
             ):
                 with init_empty_weights(include_buffers=True):
-                    tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj_scale")] = tp_plan[
-                        re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj")
-                    ]
                     tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None
                     model._modules[name] = FbgemmFp8Llama4TextExperts(
                         config.text_config,

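The three deleted lines above copied the gate_up_proj tensor-parallel entry onto gate_up_proj_scale at module-replacement time; that mapping now comes from FbgemmFp8HfQuantizer.update_tp_plan, added in the last hunk of this commit. For reference, a small standalone sketch of the re.sub(r"\d+", "*", ...) wildcarding these lines used — the module path below is illustrative only:

import re

# Concrete parameter paths contain layer indices; the tp_plan uses wildcard
# keys, so every digit run is replaced by "*".
current_key_name_str = "model.layers.3.feed_forward.experts"   # illustrative path
wildcard_key = re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj_scale")
print(wildcard_key)   # model.layers.*.feed_forward.experts.gate_up_proj_scale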

@@ -219,7 +219,7 @@ class GatherParallel(TensorParallelLayer):
     @staticmethod
     def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
-        if isinstance(inputs[0], DTensor):
+        if inputs and isinstance(inputs[0], DTensor):
             inputs = inputs[0].to_local()
         return inputs

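A standalone sketch of what the added `inputs and` guard changes. A hypothetical stand-in class is used because building a real DTensor requires an initialized process group; the point is only that the truthiness check short-circuits before inputs[0] is evaluated, so an empty positional-argument tuple no longer raises IndexError.

class _FakeDTensor:
    """Hypothetical stand-in for torch.distributed.tensor.DTensor (illustration only)."""
    def __init__(self, local):
        self._local = local
    def to_local(self):
        return self._local

def prepare_inputs(inputs):
    # Same guard shape as the hunk above.
    if inputs and isinstance(inputs[0], _FakeDTensor):
        inputs = inputs[0].to_local()
    return inputs

print(prepare_inputs(()))                         # () -- previously inputs[0] raised IndexError
print(prepare_inputs((_FakeDTensor([1, 2]),)))    # [1, 2]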

@@ -220,7 +220,6 @@ class HfQuantizer(ABC):
         """
         model.is_quantized = True
        model.quantization_method = self.quantization_config.quant_method
-        print("self.pre_quantized", self.pre_quantized)
         if self.pre_quantized:
             self._convert_model_for_quantization(model)
         return self._process_model_before_weight_loading(model, **kwargs)
@@ -345,6 +344,9 @@ class SequentialLlama4TextExperts(ModuleList):
 MODULES_TO_PATCH_FOR_QUANTIZATION = {
     "Llama4TextExperts": {
         "module_name": SequentialLlama4TextExperts,
-        "quantization_methods": [QuantizationMethod.COMPRESSED_TENSORS, QuantizationMethod.BITS_AND_BYTES],
+        "quantization_methods": [
+            QuantizationMethod.COMPRESSED_TENSORS,
+            QuantizationMethod.BITS_AND_BYTES,
+        ],
     }
 }

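For context, a hedged sketch of how a table like MODULES_TO_PATCH_FOR_QUANTIZATION can be consulted (presumably by _convert_model_for_quantization, shown in the previous hunk). All names below are stand-ins redeclared so the snippet is self-contained; the enum values are illustrative.

from enum import Enum

class QuantMethod(str, Enum):
    # Stand-in for transformers' QuantizationMethod; values are illustrative.
    COMPRESSED_TENSORS = "compressed-tensors"
    BITS_AND_BYTES = "bitsandbytes"
    FBGEMM_FP8 = "fbgemm_fp8"

PATCH_TABLE = {
    "Llama4TextExperts": {
        "module_name": "SequentialLlama4TextExperts",   # placeholder for the replacement class
        "quantization_methods": [QuantMethod.COMPRESSED_TENSORS, QuantMethod.BITS_AND_BYTES],
    }
}

def should_patch(module_cls_name, active_method):
    # A module class is swapped only if it is listed AND the active quantization
    # method is one of the methods the patch applies to.
    entry = PATCH_TABLE.get(module_cls_name)
    return entry is not None and active_method in entry["quantization_methods"]

print(should_patch("Llama4TextExperts", QuantMethod.BITS_AND_BYTES))  # True
print(should_patch("Llama4TextExperts", QuantMethod.FBGEMM_FP8))      # False

That FBGEMM FP8 is not listed here matches this commit keeping the fused FbgemmFp8Llama4TextExperts replacement instead (see the first file above and the "keep fused" note in the commit message).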

@@ -241,6 +241,42 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
                         not_missing_keys.append(missing)
         return [k for k in missing_keys if k not in not_missing_keys]
+
+    def update_tp_plan(self, config):
+        text_plan = {
+            "layers.*.self_attn.q_proj.weight": "local_colwise",
+            "layers.*.self_attn.q_proj.weight_scale": "local_colwise",
+            "layers.*.self_attn.k_proj.weight": "local_colwise",
+            "layers.*.self_attn.k_proj.weight_scale": "local_colwise",
+            "layers.*.self_attn.v_proj.weight": "local_colwise",
+            "layers.*.self_attn.v_proj.weight_scale": "local_colwise",
+            "layers.*.self_attn.o_proj.weight": "local_rowwise",
+            "layers.*.self_attn": "gather",
+            "layers.*.input_layernorm.weight": "sequence_parallel",
+            "layers.*.post_attention_layernorm.weight": "sequence_parallel",
+            "norm.weight": "sequence_parallel",
+            "layers.*.feed_forward.shared_expert.gate_proj.weight": "local_colwise",
+            "layers.*.feed_forward.shared_expert.gate_proj.weight_scale": "local_colwise",
+            "layers.*.feed_forward.shared_expert.up_proj.weight": "local_colwise",
+            "layers.*.feed_forward.shared_expert.up_proj.weight_scale": "local_colwise",
+            "layers.*.feed_forward.shared_expert.down_proj.weight": "local_rowwise",
+            "layers.*.feed_forward.experts": "local",
+            "layers.*.feed_forward": "gather",
+            "layers.*.feed_forward.experts.*.gate_proj.weight": "local_colwise",
+            "layers.*.feed_forward.experts.*.gate_proj.weight_scale": "local_colwise",
+            "layers.*.feed_forward.experts.*.up_proj.weight": "local_colwise",
+            "layers.*.feed_forward.experts.*.up_proj.weight_scale": "local_colwise",
+            "layers.*.feed_forward.experts.*.down_proj.weight": "local_rowwise",
+            # For Fused implementation
+            "layers.*.feed_forward.experts.gate_up_proj": "local_packed_rowwise",
+            "layers.*.feed_forward.experts.gate_up_proj_scale": "local_packed_rowwise",
+            "layers.*.feed_forward.experts.down_proj": "local_colwise",
+        }
+        if config.get_text_config() is not None:
+            config.get_text_config().base_model_tp_plan = text_plan
+        else:
+            config.base_model_tp_plan = text_plan
+        return config
 
     def is_serializable(self, safe_serialization=None):
         return True
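
Why the plan above gives every "local_colwise" weight a matching weight_scale entry while the "local_rowwise" projections (o_proj, the down_proj weights) have none: a sketch under the assumption that FP8 row-wise quantization stores one scale per output row, shape [out_features, 1]. The scale then has to be sharded only when the output dimension is the sharded one; torch.chunk stands in for the actual DTensor placement.

import torch

# Assumption: weight [out_features, in_features], weight_scale [out_features, 1].
out_features, in_features, world_size = 8, 16, 2
weight = torch.randn(out_features, in_features)
weight_scale = torch.rand(out_features, 1)

# "local_colwise": the output dimension is sharded, so the per-row scale must be
# sharded identically -- hence the paired weight / weight_scale plan entries.
w_col = weight.chunk(world_size, dim=0)
s_col = weight_scale.chunk(world_size, dim=0)
assert w_col[0].shape[0] == s_col[0].shape[0] == out_features // world_size

# "local_rowwise" (o_proj, down_proj): the input dimension is sharded, each rank
# keeps every output row, so the full per-row scale is replicated and needs no
# separate plan entry.
w_row = weight.chunk(world_size, dim=1)
assert w_row[0].shape[0] == weight_scale.shape[0] == out_features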