mirror of https://github.com/huggingface/transformers.git
Attention Quantization with FBGemm & TP (#37384)
* fix
* keep fused
* contiguous
* rm print
* update
* update
* rm print
parent c5c648dd74
commit f834ca2c19
@@ -50,7 +50,7 @@ class FbgemmFp8Linear(torch.nn.Linear):
         # x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
         # https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
         x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
-            x.view(-1, x.shape[-1]), scale_ub=self.input_scale_ub
+            x.view(-1, x.shape[-1]).contiguous(), scale_ub=self.input_scale_ub
         )
         # moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works
         # x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)
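A note on the `.contiguous()` change above: the FBGEMM row-wise FP8 kernel works on densely packed rows (hence the fix), and a `.view(-1, hidden)` flattening can succeed yet still hand it a strided tensor, for example when the input is a narrowed slice of a larger buffer. A minimal sketch of that situation, with purely illustrative shapes and variable names:

import torch

big = torch.randn(2, 4, 16)        # stand-in for a wider activation buffer
x = big[..., :8]                   # narrowing the last dim leaves strided storage
flat = x.view(-1, x.shape[-1])     # the view itself is legal here
print(flat.is_contiguous())        # False -> rows are not densely packed
print(flat.contiguous().is_contiguous())  # True -> safe to hand to a row-wise kernel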
@@ -207,9 +207,6 @@ def _replace_with_fbgemm_fp8_linear(
                 (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
             ):
                 with init_empty_weights(include_buffers=True):
-                    tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj_scale")] = tp_plan[
-                        re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj")
-                    ]
                     tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None
                     model._modules[name] = FbgemmFp8Llama4TextExperts(
                         config.text_config,
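For context on the `re.sub(r"\d+", "*", ...)` call kept in this hunk: it normalizes a concrete module path (with layer indices) into the wildcard form used as a TP-plan key. A quick illustration, with a made-up module path:

import re

current_key_name_str = "model.layers.3.feed_forward.experts"  # illustrative path
print(re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale"))
# model.layers.*.feed_forward.experts.down_proj_scale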
@@ -219,7 +219,7 @@ class GatherParallel(TensorParallelLayer):
 
     @staticmethod
     def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
-        if isinstance(inputs[0], DTensor):
+        if inputs and isinstance(inputs[0], DTensor):
             inputs = inputs[0].to_local()
         return inputs
 
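The added `inputs and` guard above is plain short-circuit evaluation: if the hooked module is called without positional arguments, `inputs` is an empty tuple and `inputs[0]` would raise IndexError. A tiny sketch, using a stand-in type instead of DTensor:

inputs = ()                                   # module invoked with keyword arguments only
if inputs and isinstance(inputs[0], tuple):   # `and` short-circuits, so no IndexError
    inputs = inputs[0]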
@@ -220,7 +220,6 @@ class HfQuantizer(ABC):
         """
         model.is_quantized = True
         model.quantization_method = self.quantization_config.quant_method
-        print("self.pre_quantized", self.pre_quantized)
         if self.pre_quantized:
             self._convert_model_for_quantization(model)
         return self._process_model_before_weight_loading(model, **kwargs)
@@ -345,6 +344,9 @@ class SequentialLlama4TextExperts(ModuleList):
 MODULES_TO_PATCH_FOR_QUANTIZATION = {
     "Llama4TextExperts": {
         "module_name": SequentialLlama4TextExperts,
-        "quantization_methods": [QuantizationMethod.COMPRESSED_TENSORS, QuantizationMethod.BITS_AND_BYTES],
+        "quantization_methods": [
+            QuantizationMethod.COMPRESSED_TENSORS,
+            QuantizationMethod.BITS_AND_BYTES,
+        ],
     }
 }
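The reformatted entry above keeps the same two quantization methods. As a hedged sketch (not the transformers implementation) of how a patch table shaped like this might be consulted, a lookup keyed by module class name and active quantization method could look as follows; `should_patch` and `patch_table` are hypothetical names:

def should_patch(module_cls_name, quant_method, patch_table):
    # Return the replacement module class if this module type must be patched
    # for the given quantization method, else None.
    entry = patch_table.get(module_cls_name)
    if entry is not None and quant_method in entry["quantization_methods"]:
        return entry["module_name"]
    return None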
@@ -241,6 +241,42 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
                 not_missing_keys.append(missing)
         return [k for k in missing_keys if k not in not_missing_keys]
 
+    def update_tp_plan(self, config):
+        text_plan = {
+            "layers.*.self_attn.q_proj.weight": "local_colwise",
+            "layers.*.self_attn.q_proj.weight_scale": "local_colwise",
+            "layers.*.self_attn.k_proj.weight": "local_colwise",
+            "layers.*.self_attn.k_proj.weight_scale": "local_colwise",
+            "layers.*.self_attn.v_proj.weight": "local_colwise",
+            "layers.*.self_attn.v_proj.weight_scale": "local_colwise",
+            "layers.*.self_attn.o_proj.weight": "local_rowwise",
+            "layers.*.self_attn": "gather",
+            "layers.*.input_layernorm.weight": "sequence_parallel",
+            "layers.*.post_attention_layernorm.weight": "sequence_parallel",
+            "norm.weight": "sequence_parallel",
+            "layers.*.feed_forward.shared_expert.gate_proj.weight": "local_colwise",
+            "layers.*.feed_forward.shared_expert.gate_proj.weight_scale": "local_colwise",
+            "layers.*.feed_forward.shared_expert.up_proj.weight": "local_colwise",
+            "layers.*.feed_forward.shared_expert.up_proj.weight_scale": "local_colwise",
+            "layers.*.feed_forward.shared_expert.down_proj.weight": "local_rowwise",
+            "layers.*.feed_forward.experts": "local",
+            "layers.*.feed_forward": "gather",
+            "layers.*.feed_forward.experts.*.gate_proj.weight": "local_colwise",
+            "layers.*.feed_forward.experts.*.gate_proj.weight_scale": "local_colwise",
+            "layers.*.feed_forward.experts.*.up_proj.weight": "local_colwise",
+            "layers.*.feed_forward.experts.*.up_proj.weight_scale": "local_colwise",
+            "layers.*.feed_forward.experts.*.down_proj.weight": "local_rowwise",
+            # For Fused implementation
+            "layers.*.feed_forward.experts.gate_up_proj": "local_packed_rowwise",
+            "layers.*.feed_forward.experts.gate_up_proj_scale": "local_packed_rowwise",
+            "layers.*.feed_forward.experts.down_proj": "local_colwise",
+        }
+        if config.get_text_config() is not None:
+            config.get_text_config().base_model_tp_plan = text_plan
+        else:
+            config.base_model_tp_plan = text_plan
+        return config
+
     def is_serializable(self, safe_serialization=None):
         return True
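A hedged sketch of how the wildcard keys in `text_plan` relate to concrete parameter names; the `fnmatch`-based lookup below is purely illustrative and not the matching logic transformers uses internally. The point it shows is that each FP8 weight and its `*_scale` companion resolve to the same sharding strategy, while the fused-expert keys (`gate_up_proj`, `gate_up_proj_scale`) resolve to `local_packed_rowwise`:

import fnmatch

text_plan = {  # a small excerpt of the plan added above
    "layers.*.self_attn.q_proj.weight": "local_colwise",
    "layers.*.self_attn.q_proj.weight_scale": "local_colwise",
    "layers.*.self_attn.o_proj.weight": "local_rowwise",
    "layers.*.feed_forward.experts.gate_up_proj_scale": "local_packed_rowwise",
}

def plan_for(param_name):
    # Illustrative lookup: return the strategy of the first matching wildcard key.
    for pattern, strategy in text_plan.items():
        if fnmatch.fnmatch(param_name, pattern):
            return strategy
    return None

print(plan_for("layers.12.self_attn.q_proj.weight"))                  # local_colwise
print(plan_for("layers.12.self_attn.q_proj.weight_scale"))            # local_colwise
print(plan_for("layers.12.feed_forward.experts.gate_up_proj_scale"))  # local_packed_rowwise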