add FlashAttentionKwargs and seq_idx to flat collator (#36456)

* add flash attn kwargs to flattening collator

* add return_seq_idx option

* doc string edits

* cleaner max len updates

* various fixes

* temp testing code

* return int32 seq_idx and FlashAttnKwargs

* DataCollatorIntegrationTest impl

* fix batch dims and dtypes

* fill out remaining collator tests

* test name change and fmt

* rm unused var

* fmt

* minor change

* fmt

* add missing pos_ids check

* consistent {np,pt,tf} tests

* split pt tests into 3, like np/tf tests

* mv comment, rename fa test

* remove batch dim comment

* simply wrapping

* compute cu_seq_len/max_length once

* fmt

* remove tf code

* rm warning

* move separator_id back to 2nd pos

* use cleaner lists in tests

* ret -> batch

* fmt

* attr ordering

* use py ints for max_length_{k,q}
This commit is contained in:
Garrett Goon 2025-04-16 09:45:03 -04:00 committed by GitHub
parent 9ddcf5fce5
commit 503541d7ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 318 additions and 13 deletions

View File

@ -1974,9 +1974,11 @@ class DataCollatorWithFlattening(DefaultDataCollator):
"""
Data collator used for padding free approach. Does the following:
- concatate the entire mini batch into single long sequence [1, total_tokens]
- concatenates the entire mini batch into single long sequence of shape [1, total_tokens]
- uses `separator_id` to separate sequences within the concatenated `labels`, default value is -100
- no padding will be added, returns `input_ids`, `labels` and `position_ids`
- no padding will be added, returns `input_ids`, `labels` and `position_ids` by default
- optionally returns the kwargs contained in FlashAttentionKwargs
- optionally returns seq_idx indicating which sequence each token belongs to
<Tip warning={true}>
@ -1986,10 +1988,23 @@ class DataCollatorWithFlattening(DefaultDataCollator):
</Tip>
"""
def __init__(self, *args, return_position_ids=True, separator_id=-100, **kwargs):
def __init__(
self,
*args,
return_position_ids=True,
separator_id=-100,
return_flash_attn_kwargs=False,
return_seq_idx=False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.return_position_ids = return_position_ids
self.separator_id = separator_id
self.return_flash_attn_kwargs = return_flash_attn_kwargs
self.return_seq_idx = return_seq_idx
self._int_64_keys = {"labels", "position_ids", "input_ids"}
self._batch_dim_keys = {"labels", "position_ids", "input_ids", "seq_idx"}
self._py_int_keys = {"max_length_q", "max_length_k"}
def __call__(self, features, return_tensors=None, separator_id=None):
if return_tensors is None:
@ -1997,15 +2012,52 @@ class DataCollatorWithFlattening(DefaultDataCollator):
if separator_id is None:
separator_id = self.separator_id
is_labels_provided = "labels" in features[0]
ret = {"input_ids": [], "labels": []}
batch = {"input_ids": [], "labels": []}
if self.return_position_ids:
ret.update({"position_ids": []})
for idx in range(0, len(features)):
ret["input_ids"] += features[idx]["input_ids"]
batch.update({"position_ids": []})
if self.return_seq_idx:
batch.update({"seq_idx": []})
if self.return_flash_attn_kwargs:
cu_seq_lens = [0]
max_length = 0
for seq_idx, sample in enumerate(features):
input_ids = sample["input_ids"]
batch["input_ids"] += input_ids
if is_labels_provided:
ret["labels"] += [separator_id] + features[idx]["labels"][1:]
batch["labels"] += [separator_id] + sample["labels"][1:]
else:
ret["labels"] += [separator_id] + features[idx]["input_ids"][1:]
batch["labels"] += [separator_id] + input_ids[1:]
if self.return_position_ids:
ret["position_ids"] += list(range(len(features[idx]["input_ids"])))
return default_data_collator([ret], return_tensors)
batch["position_ids"] += list(range(len(input_ids)))
if self.return_seq_idx:
batch["seq_idx"] += [seq_idx for _ in range(len(input_ids))]
if self.return_flash_attn_kwargs:
cu_seq_lens.append(cu_seq_lens[-1] + len(input_ids))
max_length = max(max_length, len(input_ids))
if self.return_flash_attn_kwargs:
batch["cu_seq_lens_q"] = batch["cu_seq_lens_k"] = cu_seq_lens
batch["max_length_q"] = batch["max_length_k"] = max_length
# FlashAttentionKwargs and seq_idx are expected to be int32s.
if return_tensors == "pt":
import torch
data_cls = torch.tensor
dtype_64 = torch.int64
dtype_32 = torch.int32
elif return_tensors == "np":
data_cls = np.array
dtype_64 = np.int64
dtype_32 = np.int32
else:
raise ValueError(f'return_tensors must be one of ("pt", "np"), {return_tensors=} not suported')
for k, v in batch.items():
if k in self._batch_dim_keys:
v = [v]
# Flash attention max_len_{q,k} are python ints
if k not in self._py_int_keys:
batch[k] = data_cls(v, dtype=dtype_64 if k in self._int_64_keys else dtype_32)
return batch

View File

@ -34,6 +34,7 @@ from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoModelForSequenceClassification,
DataCollatorWithFlattening,
PretrainedConfig,
PreTrainedModel,
is_torch_available,
@ -4170,6 +4171,78 @@ class ModelTesterMixin:
tol = torch.finfo(torch.float16).eps
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
self.skipTest("Model dummy inputs should contain padding in their attention mask")
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
if "position_ids" not in inspect.signature(model.forward).parameters:
self.skipTest("Model does not support position_ids")
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
# ensure left padding, to adapt for some models
if 0 in inputs_dict["attention_mask"][:, -1]:
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
dummy_attention_mask = inputs_dict["attention_mask"]
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
model = (
model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
)
.to(torch_device)
.eval()
)
# flatten
features = [
{"input_ids": i[a.bool()].tolist()}
for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
]
# add position_ids + fa_kwargs
data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True)
batch = data_collator(features)
batch_cuda = {k: t.cuda() if torch.is_tensor(t) else t for k, t in batch.items()}
res_padded = model(**inputs_dict)
res_padfree = model(**batch_cuda)
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
logits_padfree = res_padfree.logits[0]
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
# acceptable numerical instability
tol = torch.finfo(torch.float16).eps
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test

View File

@ -126,6 +126,104 @@ class DataCollatorIntegrationTest(unittest.TestCase):
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 8]))
def test_data_collator_with_flattening(self):
features = [
{"input_ids": [10, 11, 12]},
{"input_ids": [20, 21, 22, 23, 24, 25]},
{"input_ids": [30, 31, 32, 33, 34, 35, 36]},
]
data_collator = DataCollatorWithFlattening(return_tensors="pt")
batch = data_collator(features)
for unexpected_key in [
"attention_mask",
"cu_seq_lens_k",
"cu_seq_lens_q",
"max_length_k",
"max_length_q",
"seq_idx",
]:
self.assertNotIn(unexpected_key, batch)
self.assertIn("position_ids", batch)
self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16]))
self.assertEqual(
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
)
self.assertEqual(batch["position_ids"].shape, torch.Size([1, 16]))
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
def test_data_collator_with_flattening_flash_attn_kwargs(self):
features = [
{"input_ids": [10, 11, 12]},
{"input_ids": [20, 21, 22, 23, 24, 25]},
{"input_ids": [30, 31, 32, 33, 34, 35, 36]},
]
data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True)
batch = data_collator(features)
for unexpected_key in [
"attention_mask",
"seq_idx",
]:
self.assertNotIn(unexpected_key, batch)
for expected_key in [
"position_ids",
"cu_seq_lens_k",
"cu_seq_lens_q",
"max_length_k",
"max_length_q",
]:
self.assertIn(expected_key, batch)
self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16]))
self.assertEqual(
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
)
self.assertEqual(batch["position_ids"].shape, torch.Size([1, 16]))
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
self.assertEqual(batch["cu_seq_lens_k"].shape, torch.Size([4]))
self.assertEqual(batch["cu_seq_lens_k"].tolist(), [0, 3, 9, 16])
self.assertEqual(batch["cu_seq_lens_q"].shape, torch.Size([4]))
self.assertEqual(batch["cu_seq_lens_q"].tolist(), [0, 3, 9, 16])
# The flash attn max_length_{k,q} are simple python ints
self.assertEqual(batch["max_length_k"], 7)
self.assertEqual(batch["max_length_q"], 7)
def test_data_collator_with_flattening_seq_idx(self):
features = [
{"input_ids": [10, 11, 12]},
{"input_ids": [20, 21, 22, 23, 24, 25]},
{"input_ids": [30, 31, 32, 33, 34, 35, 36]},
]
data_collator = DataCollatorWithFlattening(return_tensors="pt", return_seq_idx=True)
batch = data_collator(features)
for unexpected_key in [
"attention_mask",
"cu_seq_lens_k",
"cu_seq_lens_q",
"max_length_k",
"max_length_q",
]:
self.assertNotIn(unexpected_key, batch)
for expected_key in [
"position_ids",
"seq_idx",
]:
self.assertIn(expected_key, batch)
self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16]))
self.assertEqual(
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
)
self.assertEqual(batch["position_ids"].shape, torch.Size([1, 16]))
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
self.assertEqual(batch["seq_idx"].shape, batch["input_ids"].shape)
self.assertEqual(batch["seq_idx"][0].tolist(), [0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2])
def test_data_collator_for_token_classification(self):
tokenizer = BertTokenizer(self.vocab_file)
features = [
@ -1803,15 +1901,97 @@ class NumpyDataCollatorIntegrationTest(unittest.TestCase):
data_collator = DataCollatorWithFlattening(return_tensors="np")
batch = data_collator(features)
for unexpected_key in [
"attention_mask",
"cu_seq_lens_k",
"cu_seq_lens_q",
"max_length_k",
"max_length_q",
"seq_idx",
]:
self.assertNotIn(unexpected_key, batch)
self.assertIn("position_ids", batch)
self.assertEqual(batch["input_ids"].shape, (1, 16))
self.assertEqual(
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
)
self.assertNotIn("attention_mask", batch)
self.assertIn("position_ids", batch)
self.assertEqual(batch["position_ids"].shape, (1, 16))
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
def test_data_collator_with_flattening_flash_attn_kwargs(self):
features = [
{"input_ids": [10, 11, 12]},
{"input_ids": [20, 21, 22, 23, 24, 25]},
{"input_ids": [30, 31, 32, 33, 34, 35, 36]},
]
data_collator = DataCollatorWithFlattening(return_tensors="np", return_flash_attn_kwargs=True)
batch = data_collator(features)
for unexpected_key in [
"attention_mask",
"seq_idx",
]:
self.assertNotIn(unexpected_key, batch)
for expected_key in [
"position_ids",
"cu_seq_lens_k",
"cu_seq_lens_q",
"max_length_k",
"max_length_q",
]:
self.assertIn(expected_key, batch)
self.assertEqual(batch["input_ids"].shape, (1, 16))
self.assertEqual(
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
)
self.assertEqual(batch["position_ids"].shape, (1, 16))
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
self.assertEqual(batch["cu_seq_lens_k"].shape, (4,))
self.assertEqual(batch["cu_seq_lens_k"].tolist(), [0, 3, 9, 16])
self.assertEqual(batch["cu_seq_lens_q"].shape, (4,))
self.assertEqual(batch["cu_seq_lens_q"].tolist(), [0, 3, 9, 16])
# The flash attn max_length_{k,q} are simple python ints
self.assertEqual(batch["max_length_k"], 7)
self.assertEqual(batch["max_length_q"], 7)
def test_data_collator_with_flattening_seq_idx(self):
features = [
{"input_ids": [10, 11, 12]},
{"input_ids": [20, 21, 22, 23, 24, 25]},
{"input_ids": [30, 31, 32, 33, 34, 35, 36]},
]
data_collator = DataCollatorWithFlattening(return_tensors="np", return_seq_idx=True)
batch = data_collator(features)
for unexpected_key in [
"attention_mask",
"cu_seq_lens_k",
"cu_seq_lens_q",
"max_length_k",
"max_length_q",
]:
self.assertNotIn(unexpected_key, batch)
for expected_key in [
"position_ids",
"seq_idx",
]:
self.assertIn(expected_key, batch)
self.assertEqual(batch["input_ids"].shape, (1, 16))
self.assertEqual(
batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
)
self.assertEqual(batch["position_ids"].shape, (1, 16))
self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
self.assertEqual(batch["seq_idx"].shape, batch["input_ids"].shape)
self.assertEqual(batch["seq_idx"][0].tolist(), [0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2])
def test_data_collator_for_token_classification(self):
tokenizer = BertTokenizer(self.vocab_file)
features = [