mirror of https://github.com/huggingface/transformers.git
add FlashAttentionKwargs and seq_idx to flat collator (#36456)
* add flash attn kwargs to flattening collator
* add return_seq_idx option
* doc string edits
* cleaner max len updates
* various fixes
* temp testing code
* return int32 seq_idx and FlashAttnKwargs
* DataCollatorIntegrationTest impl
* fix batch dims and dtypes
* fill out remaining collator tests
* test name change and fmt
* rm unused var
* fmt
* minor change
* fmt
* add missing pos_ids check
* consistent {np,pt,tf} tests
* split pt tests into 3, like np/tf tests
* mv comment, rename fa test
* remove batch dim comment
* simply wrapping
* compute cu_seq_len/max_length once
* fmt
* remove tf code
* rm warning
* move separator_id back to 2nd pos
* use cleaner lists in tests
* ret -> batch
* fmt
* attr ordering
* use py ints for max_length_{k,q}
This commit is contained in:
parent 9ddcf5fce5
commit 503541d7ef
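Before the diff, a minimal usage sketch of the two new collator options (illustrative only; the toy token ids are arbitrary, and the returned keys follow the tests further down):

from transformers import DataCollatorWithFlattening

# Pre-tokenized features, as produced by a tokenizer without padding.
features = [
    {"input_ids": [10, 11, 12]},
    {"input_ids": [20, 21, 22, 23, 24, 25]},
]

collator = DataCollatorWithFlattening(
    return_tensors="pt",
    return_flash_attn_kwargs=True,  # adds cu_seq_lens_{q,k} and max_length_{q,k}
    return_seq_idx=True,            # adds a per-token sequence index
)
batch = collator(features)
# batch now holds: input_ids, labels, position_ids, seq_idx,
# cu_seq_lens_q, cu_seq_lens_k (int32), max_length_q, max_length_k (python ints)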
@@ -1974,9 +1974,11 @@ class DataCollatorWithFlattening(DefaultDataCollator):
     """
     Data collator used for padding free approach. Does the following:
 
-    - concatate the entire mini batch into single long sequence [1, total_tokens]
+    - concatenates the entire mini batch into single long sequence of shape [1, total_tokens]
     - uses `separator_id` to separate sequences within the concatenated `labels`, default value is -100
-    - no padding will be added, returns `input_ids`, `labels` and `position_ids`
+    - no padding will be added, returns `input_ids`, `labels` and `position_ids` by default
+    - optionally returns the kwargs contained in FlashAttentionKwargs
+    - optionally returns seq_idx indicating which sequence each token belongs to
 
     <Tip warning={true}>
 
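A concrete illustration of the bullets above (sketch only; the ids are arbitrary and the expected values follow from the rules just listed):

from transformers import DataCollatorWithFlattening

features = [{"input_ids": [5, 6, 7]}, {"input_ids": [8, 9, 10, 11]}]
batch = DataCollatorWithFlattening(return_tensors="pt", return_seq_idx=True)(features)

# Expected contents, shape [1, 7] (one flattened sequence of total_tokens=7):
# input_ids    -> [[5, 6, 7, 8, 9, 10, 11]]
# labels       -> [[-100, 6, 7, -100, 9, 10, 11]]  # separator_id replaces each sequence's first label
# position_ids -> [[0, 1, 2, 0, 1, 2, 3]]          # restart at every sequence boundary
# seq_idx      -> [[0, 0, 0, 1, 1, 1, 1]]          # which source sequence each token belongs to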
@@ -1986,10 +1988,23 @@ class DataCollatorWithFlattening(DefaultDataCollator):
     </Tip>
     """
 
-    def __init__(self, *args, return_position_ids=True, separator_id=-100, **kwargs):
+    def __init__(
+        self,
+        *args,
+        return_position_ids=True,
+        separator_id=-100,
+        return_flash_attn_kwargs=False,
+        return_seq_idx=False,
+        **kwargs,
+    ):
         super().__init__(*args, **kwargs)
         self.return_position_ids = return_position_ids
         self.separator_id = separator_id
+        self.return_flash_attn_kwargs = return_flash_attn_kwargs
+        self.return_seq_idx = return_seq_idx
+        self._int_64_keys = {"labels", "position_ids", "input_ids"}
+        self._batch_dim_keys = {"labels", "position_ids", "input_ids", "seq_idx"}
+        self._py_int_keys = {"max_length_q", "max_length_k"}
 
     def __call__(self, features, return_tensors=None, separator_id=None):
         if return_tensors is None:
@@ -1997,15 +2012,52 @@ class DataCollatorWithFlattening(DefaultDataCollator):
         if separator_id is None:
             separator_id = self.separator_id
         is_labels_provided = "labels" in features[0]
-        ret = {"input_ids": [], "labels": []}
+        batch = {"input_ids": [], "labels": []}
         if self.return_position_ids:
-            ret.update({"position_ids": []})
-        for idx in range(0, len(features)):
-            ret["input_ids"] += features[idx]["input_ids"]
+            batch.update({"position_ids": []})
+        if self.return_seq_idx:
+            batch.update({"seq_idx": []})
+        if self.return_flash_attn_kwargs:
+            cu_seq_lens = [0]
+            max_length = 0
+        for seq_idx, sample in enumerate(features):
+            input_ids = sample["input_ids"]
+            batch["input_ids"] += input_ids
             if is_labels_provided:
-                ret["labels"] += [separator_id] + features[idx]["labels"][1:]
+                batch["labels"] += [separator_id] + sample["labels"][1:]
             else:
-                ret["labels"] += [separator_id] + features[idx]["input_ids"][1:]
+                batch["labels"] += [separator_id] + input_ids[1:]
             if self.return_position_ids:
-                ret["position_ids"] += list(range(len(features[idx]["input_ids"])))
-        return default_data_collator([ret], return_tensors)
+                batch["position_ids"] += list(range(len(input_ids)))
+            if self.return_seq_idx:
+                batch["seq_idx"] += [seq_idx for _ in range(len(input_ids))]
+            if self.return_flash_attn_kwargs:
+                cu_seq_lens.append(cu_seq_lens[-1] + len(input_ids))
+                max_length = max(max_length, len(input_ids))
+
+        if self.return_flash_attn_kwargs:
+            batch["cu_seq_lens_q"] = batch["cu_seq_lens_k"] = cu_seq_lens
+            batch["max_length_q"] = batch["max_length_k"] = max_length
+
+        # FlashAttentionKwargs and seq_idx are expected to be int32s.
+        if return_tensors == "pt":
+            import torch
+
+            data_cls = torch.tensor
+            dtype_64 = torch.int64
+            dtype_32 = torch.int32
+        elif return_tensors == "np":
+            data_cls = np.array
+            dtype_64 = np.int64
+            dtype_32 = np.int32
+        else:
+            raise ValueError(f'return_tensors must be one of ("pt", "np"), {return_tensors=} not suported')
+
+        for k, v in batch.items():
+            if k in self._batch_dim_keys:
+                v = [v]
+            # Flash attention max_len_{q,k} are python ints
+            if k not in self._py_int_keys:
+                batch[k] = data_cls(v, dtype=dtype_64 if k in self._int_64_keys else dtype_32)
+
+        return batch
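For context on why exactly these keys are produced: variable-length (padding-free) flash attention kernels take cumulative sequence lengths and a maximum sequence length in place of an attention mask. A rough sketch of the downstream call, assuming the external flash-attn package and a `batch` produced by the collator with return_flash_attn_kwargs=True; none of this is part of the diff:

import torch
from flash_attn import flash_attn_varlen_func  # assumed external dependency, not used in this commit

# Packed, padding-free projections: [total_tokens, num_heads, head_dim], fp16 on GPU.
total_tokens, num_heads, head_dim = 16, 8, 64
q = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=batch["cu_seq_lens_q"].cuda(),  # int32 tensor, e.g. [0, 3, 9, 16]
    cu_seqlens_k=batch["cu_seq_lens_k"].cuda(),
    max_seqlen_q=batch["max_length_q"],          # plain python int
    max_seqlen_k=batch["max_length_k"],
    causal=True,
)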
@@ -34,6 +34,7 @@ from transformers import (
     AutoModel,
     AutoModelForCausalLM,
     AutoModelForSequenceClassification,
+    DataCollatorWithFlattening,
     PretrainedConfig,
     PreTrainedModel,
     is_torch_available,
@@ -4170,6 +4171,78 @@ class ModelTesterMixin:
                 tol = torch.finfo(torch.float16).eps
                 torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
 
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    @slow
+    def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
+        if not self.has_attentions:
+            self.skipTest(reason="Model architecture does not support attentions")
+
+        max_new_tokens = 30
+
+        for model_class in self.all_generative_model_classes:
+            if not model_class._supports_flash_attn_2:
+                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
+
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
+                self.skipTest("Model dummy inputs should contain padding in their attention mask")
+
+            dummy_input = inputs_dict[model_class.main_input_name]
+            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+                dummy_input = dummy_input.to(torch.float16)
+
+            # make sure that all models have enough positions for generation
+            if hasattr(config, "max_position_embeddings"):
+                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+            model = model_class(config)
+            if "position_ids" not in inspect.signature(model.forward).parameters:
+                self.skipTest("Model does not support position_ids")
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+
+                # ensure left padding, to adapt for some models
+                if 0 in inputs_dict["attention_mask"][:, -1]:
+                    inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
+                dummy_attention_mask = inputs_dict["attention_mask"]
+                inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
+
+                model = (
+                    model_class.from_pretrained(
+                        tmpdirname,
+                        torch_dtype=torch.float16,
+                        attn_implementation="flash_attention_2",
+                        low_cpu_mem_usage=True,
+                    )
+                    .to(torch_device)
+                    .eval()
+                )
+
+                # flatten
+                features = [
+                    {"input_ids": i[a.bool()].tolist()}
+                    for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
+                ]
+
+                # add position_ids + fa_kwargs
+                data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True)
+                batch = data_collator(features)
+                batch_cuda = {k: t.cuda() if torch.is_tensor(t) else t for k, t in batch.items()}
+
+                res_padded = model(**inputs_dict)
+                res_padfree = model(**batch_cuda)
+
+                logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
+                logits_padfree = res_padfree.logits[0]
+
+                torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
+                # acceptable numerical instability
+                tol = torch.finfo(torch.float16).eps
+                torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
+
     @require_flash_attn
     @require_torch_gpu
     @mark.flash_attn_test
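The same padded-vs-padding-free comparison pattern outside the test harness, as a sketch: strip padding with the attention mask, collate padding-free, and check that the two forward passes agree. Here `model`, `input_ids` and `attention_mask` are assumed to exist already (a padded [batch, seq_len] input pair and a causal LM loaded in fp16 with attn_implementation="flash_attention_2" on CUDA):

import torch
from transformers import DataCollatorWithFlattening

input_ids, attention_mask = input_ids.cuda(), attention_mask.cuda()
features = [{"input_ids": ids[mask.bool()].tolist()} for ids, mask in zip(input_ids, attention_mask)]

collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True)
batch = collator(features)
# max_length_{q,k} are python ints; only move the tensors to the GPU.
batch = {k: v.cuda() if torch.is_tensor(v) else v for k, v in batch.items()}

logits_padded = model(input_ids=input_ids, attention_mask=attention_mask).logits
logits_padfree = model(**batch).logits  # [1, total_tokens, vocab_size]

# Non-padding positions of the padded run line up 1:1 with the flattened run.
torch.testing.assert_close(
    logits_padded[attention_mask.bool()], logits_padfree[0], rtol=1e-3, atol=1e-3
)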
@@ -126,6 +126,104 @@ class DataCollatorIntegrationTest(unittest.TestCase):
         batch = data_collator(features)
         self.assertEqual(batch["input_ids"].shape, torch.Size([2, 8]))
 
+    def test_data_collator_with_flattening(self):
+        features = [
+            {"input_ids": [10, 11, 12]},
+            {"input_ids": [20, 21, 22, 23, 24, 25]},
+            {"input_ids": [30, 31, 32, 33, 34, 35, 36]},
+        ]
+
+        data_collator = DataCollatorWithFlattening(return_tensors="pt")
+        batch = data_collator(features)
+
+        for unexpected_key in [
+            "attention_mask",
+            "cu_seq_lens_k",
+            "cu_seq_lens_q",
+            "max_length_k",
+            "max_length_q",
+            "seq_idx",
+        ]:
+            self.assertNotIn(unexpected_key, batch)
+        self.assertIn("position_ids", batch)
+
+        self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16]))
+        self.assertEqual(
+            batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
+        )
+        self.assertEqual(batch["position_ids"].shape, torch.Size([1, 16]))
+        self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
+
+    def test_data_collator_with_flattening_flash_attn_kwargs(self):
+        features = [
+            {"input_ids": [10, 11, 12]},
+            {"input_ids": [20, 21, 22, 23, 24, 25]},
+            {"input_ids": [30, 31, 32, 33, 34, 35, 36]},
+        ]
+        data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True)
+        batch = data_collator(features)
+
+        for unexpected_key in [
+            "attention_mask",
+            "seq_idx",
+        ]:
+            self.assertNotIn(unexpected_key, batch)
+        for expected_key in [
+            "position_ids",
+            "cu_seq_lens_k",
+            "cu_seq_lens_q",
+            "max_length_k",
+            "max_length_q",
+        ]:
+            self.assertIn(expected_key, batch)
+
+        self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16]))
+        self.assertEqual(
+            batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
+        )
+        self.assertEqual(batch["position_ids"].shape, torch.Size([1, 16]))
+        self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
+
+        self.assertEqual(batch["cu_seq_lens_k"].shape, torch.Size([4]))
+        self.assertEqual(batch["cu_seq_lens_k"].tolist(), [0, 3, 9, 16])
+        self.assertEqual(batch["cu_seq_lens_q"].shape, torch.Size([4]))
+        self.assertEqual(batch["cu_seq_lens_q"].tolist(), [0, 3, 9, 16])
+        # The flash attn max_length_{k,q} are simple python ints
+        self.assertEqual(batch["max_length_k"], 7)
+        self.assertEqual(batch["max_length_q"], 7)
+
+    def test_data_collator_with_flattening_seq_idx(self):
+        features = [
+            {"input_ids": [10, 11, 12]},
+            {"input_ids": [20, 21, 22, 23, 24, 25]},
+            {"input_ids": [30, 31, 32, 33, 34, 35, 36]},
+        ]
+        data_collator = DataCollatorWithFlattening(return_tensors="pt", return_seq_idx=True)
+        batch = data_collator(features)
+
+        for unexpected_key in [
+            "attention_mask",
+            "cu_seq_lens_k",
+            "cu_seq_lens_q",
+            "max_length_k",
+            "max_length_q",
+        ]:
+            self.assertNotIn(unexpected_key, batch)
+        for expected_key in [
+            "position_ids",
+            "seq_idx",
+        ]:
+            self.assertIn(expected_key, batch)
+
+        self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16]))
+        self.assertEqual(
+            batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
+        )
+        self.assertEqual(batch["position_ids"].shape, torch.Size([1, 16]))
+        self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
+        self.assertEqual(batch["seq_idx"].shape, batch["input_ids"].shape)
+        self.assertEqual(batch["seq_idx"][0].tolist(), [0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2])
+
     def test_data_collator_for_token_classification(self):
         tokenizer = BertTokenizer(self.vocab_file)
         features = [
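A quick cross-check of the expected values asserted above: for sequence lengths 3, 6 and 7, the cumulative sequence lengths and maximum length work out as follows (plain Python, no framework needed):

from itertools import accumulate

lengths = [3, 6, 7]                      # lengths of the three toy sequences
cu_seq_lens = [0, *accumulate(lengths)]  # [0, 3, 9, 16]: boundaries of each sequence in the flat batch
max_length = max(lengths)                # 7

assert cu_seq_lens == [0, 3, 9, 16]
assert max_length == 7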
@@ -1803,15 +1901,97 @@ class NumpyDataCollatorIntegrationTest(unittest.TestCase):
 
         data_collator = DataCollatorWithFlattening(return_tensors="np")
         batch = data_collator(features)
+
+        for unexpected_key in [
+            "attention_mask",
+            "cu_seq_lens_k",
+            "cu_seq_lens_q",
+            "max_length_k",
+            "max_length_q",
+            "seq_idx",
+        ]:
+            self.assertNotIn(unexpected_key, batch)
+        self.assertIn("position_ids", batch)
+
         self.assertEqual(batch["input_ids"].shape, (1, 16))
         self.assertEqual(
             batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
         )
-        self.assertNotIn("attention_mask", batch)
-        self.assertIn("position_ids", batch)
         self.assertEqual(batch["position_ids"].shape, (1, 16))
         self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
 
+    def test_data_collator_with_flattening_flash_attn_kwargs(self):
+        features = [
+            {"input_ids": [10, 11, 12]},
+            {"input_ids": [20, 21, 22, 23, 24, 25]},
+            {"input_ids": [30, 31, 32, 33, 34, 35, 36]},
+        ]
+
+        data_collator = DataCollatorWithFlattening(return_tensors="np", return_flash_attn_kwargs=True)
+        batch = data_collator(features)
+
+        for unexpected_key in [
+            "attention_mask",
+            "seq_idx",
+        ]:
+            self.assertNotIn(unexpected_key, batch)
+        for expected_key in [
+            "position_ids",
+            "cu_seq_lens_k",
+            "cu_seq_lens_q",
+            "max_length_k",
+            "max_length_q",
+        ]:
+            self.assertIn(expected_key, batch)
+
+        self.assertEqual(batch["input_ids"].shape, (1, 16))
+        self.assertEqual(
+            batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
+        )
+        self.assertEqual(batch["position_ids"].shape, (1, 16))
+        self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
+
+        self.assertEqual(batch["cu_seq_lens_k"].shape, (4,))
+        self.assertEqual(batch["cu_seq_lens_k"].tolist(), [0, 3, 9, 16])
+        self.assertEqual(batch["cu_seq_lens_q"].shape, (4,))
+        self.assertEqual(batch["cu_seq_lens_q"].tolist(), [0, 3, 9, 16])
+        # The flash attn max_length_{k,q} are simple python ints
+        self.assertEqual(batch["max_length_k"], 7)
+        self.assertEqual(batch["max_length_q"], 7)
+
+    def test_data_collator_with_flattening_seq_idx(self):
+        features = [
+            {"input_ids": [10, 11, 12]},
+            {"input_ids": [20, 21, 22, 23, 24, 25]},
+            {"input_ids": [30, 31, 32, 33, 34, 35, 36]},
+        ]
+
+        data_collator = DataCollatorWithFlattening(return_tensors="np", return_seq_idx=True)
+        batch = data_collator(features)
+
+        for unexpected_key in [
+            "attention_mask",
+            "cu_seq_lens_k",
+            "cu_seq_lens_q",
+            "max_length_k",
+            "max_length_q",
+        ]:
+            self.assertNotIn(unexpected_key, batch)
+        for expected_key in [
+            "position_ids",
+            "seq_idx",
+        ]:
+            self.assertIn(expected_key, batch)
+
+        self.assertEqual(batch["input_ids"].shape, (1, 16))
+        self.assertEqual(
+            batch["input_ids"][0].tolist(), [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36]
+        )
+        self.assertEqual(batch["position_ids"].shape, (1, 16))
+        self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])
+        self.assertEqual(batch["seq_idx"].shape, batch["input_ids"].shape)
+        self.assertEqual(batch["seq_idx"][0].tolist(), [0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2])
+
     def test_data_collator_for_token_classification(self):
         tokenizer = BertTokenizer(self.vocab_file)
         features = [
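Finally, a small sketch of the dtype contract the new code path enforces (matching the int32/int64 handling in __call__ above and the python-int max lengths; checked with NumPy to mirror these tests):

import numpy as np
from transformers import DataCollatorWithFlattening

features = [{"input_ids": [10, 11, 12]}, {"input_ids": [20, 21, 22, 23, 24, 25]}]
batch = DataCollatorWithFlattening(
    return_tensors="np", return_flash_attn_kwargs=True, return_seq_idx=True
)(features)

assert batch["input_ids"].dtype == np.int64     # input_ids / labels / position_ids stay int64
assert batch["seq_idx"].dtype == np.int32       # seq_idx and the FlashAttentionKwargs are int32
assert batch["cu_seq_lens_q"].dtype == np.int32
assert isinstance(batch["max_length_q"], int)   # max_length_{q,k} remain plain python ints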