mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-30 17:52:35 +06:00
[Longformer] Major Refactor (#5219)
* refactor naming * add small slow test * refactor * refactor naming * rename selected to extra * big global attention refactor * make style * refactor naming * save intermed * refactor functions * finish function refactor * fix tests * fix longformer * fix longformer * fix longformer * fix all tests but one * finish longformer * address sams and izs comments * fix transpose
This commit is contained in:
parent
e0d58ddb65
commit
d697b6ca75
File diff suppressed because it is too large
Load Diff
@ -811,7 +811,7 @@ class ModelTesterMixin:
|
||||
# Wrap model in nn.DataParallel
|
||||
model = torch.nn.DataParallel(model)
|
||||
with torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
@ -115,6 +115,18 @@ class LongformerModelTester:
|
||||
def check_loss_output(self, result):
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
|
||||
def create_and_check_attention_mask_determinism(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = LongformerModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||
output_with_mask = model(input_ids, attention_mask=attention_mask)[0]
|
||||
output_without_mask = model(input_ids)[0]
|
||||
self.parent.assertTrue(torch.allclose(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], atol=1e-4))
|
||||
|
||||
def create_and_check_longformer_model(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
@ -134,6 +146,36 @@ class LongformerModelTester:
|
||||
)
|
||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||
|
||||
def create_and_check_longformer_model_with_global_attention_mask(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = LongformerModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
global_attention_mask = input_mask.clone()
|
||||
global_attention_mask[:, input_mask.shape[-1] // 2] = 0
|
||||
global_attention_mask = global_attention_mask.to(torch_device)
|
||||
|
||||
sequence_output, pooled_output = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
)
|
||||
sequence_output, pooled_output = model(
|
||||
input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask
|
||||
)
|
||||
sequence_output, pooled_output = model(input_ids, global_attention_mask=global_attention_mask)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
"pooled_output": pooled_output,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||
|
||||
def create_and_check_longformer_for_masked_lm(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
@ -243,7 +285,13 @@ class LongformerModelTester:
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
|
||||
global_attention_mask = torch.zeros_like(input_ids)
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"token_type_ids": token_type_ids,
|
||||
"attention_mask": input_mask,
|
||||
"global_attention_mask": global_attention_mask,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
def prepare_config_and_inputs_for_question_answering(self):
|
||||
@ -277,11 +325,10 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
(
|
||||
LongformerModel,
|
||||
LongformerForMaskedLM,
|
||||
# TODO: make tests pass for those models
|
||||
# LongformerForSequenceClassification,
|
||||
# LongformerForQuestionAnswering,
|
||||
# LongformerForTokenClassification,
|
||||
# LongformerForMultipleChoice,
|
||||
LongformerForSequenceClassification,
|
||||
LongformerForQuestionAnswering,
|
||||
LongformerForTokenClassification,
|
||||
LongformerForMultipleChoice,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
@ -298,6 +345,14 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_longformer_model(*config_and_inputs)
|
||||
|
||||
def test_longformer_model_attention_mask_determinism(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_attention_mask_determinism(*config_and_inputs)
|
||||
|
||||
def test_longformer_model_global_attention_mask(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_longformer_model_with_global_attention_mask(*config_and_inputs)
|
||||
|
||||
def test_longformer_for_masked_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs)
|
||||
@ -325,15 +380,31 @@ class LongformerModelIntegrationTest(unittest.TestCase):
|
||||
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
|
||||
model.to(torch_device)
|
||||
|
||||
# 'Hello world!'
|
||||
input_ids = torch.tensor([[0, 20920, 232, 328, 1437, 2]], dtype=torch.long, device=torch_device)
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||
output = model(input_ids, attention_mask=attention_mask)[0]
|
||||
output_without_mask = model(input_ids)[0]
|
||||
|
||||
expected_output_slice = torch.tensor([0.0549, 0.1087, -0.1119, -0.0368, 0.0250], device=torch_device)
|
||||
self.assertTrue(torch.allclose(output[0, 0, -5:], expected_output_slice, atol=1e-4))
|
||||
self.assertTrue(torch.allclose(output_without_mask[0, 0, -5:], expected_output_slice, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_inference_no_head_long(self):
|
||||
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
|
||||
model.to(torch_device)
|
||||
|
||||
# 'Hello world! ' repeated 1000 times
|
||||
input_ids = torch.tensor(
|
||||
[[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=torch.long, device=torch_device
|
||||
) # long input
|
||||
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
|
||||
attention_mask[:, [1, 4, 21]] = 2 # Set global attention on a few random positions
|
||||
global_attention_mask = torch.zeros(input_ids.shape, dtype=torch.long, device=input_ids.device)
|
||||
global_attention_mask[:, [1, 4, 21]] = 1 # Set global attention on a few random positions
|
||||
|
||||
output = model(input_ids, attention_mask=attention_mask)[0]
|
||||
output = model(input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask)[0]
|
||||
|
||||
expected_output_sum = torch.tensor(74585.8594, device=torch_device)
|
||||
expected_output_mean = torch.tensor(0.0243, device=torch_device)
|
||||
@ -341,7 +412,7 @@ class LongformerModelIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(output.mean(), expected_output_mean, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_inference_masked_lm(self):
|
||||
def test_inference_masked_lm_long(self):
|
||||
model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")
|
||||
model.to(torch_device)
|
||||
|
||||
@ -352,9 +423,9 @@ class LongformerModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
loss, prediction_scores = model(input_ids, labels=input_ids)
|
||||
|
||||
expected_loss = torch.tensor(0.0620, device=torch_device)
|
||||
expected_prediction_scores_sum = torch.tensor(-6.1599e08, device=torch_device)
|
||||
expected_prediction_scores_mean = torch.tensor(-3.0622, device=torch_device)
|
||||
expected_loss = torch.tensor(0.0074, device=torch_device)
|
||||
expected_prediction_scores_sum = torch.tensor(-6.1048e08, device=torch_device)
|
||||
expected_prediction_scores_mean = torch.tensor(-3.0348, device=torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(loss, expected_loss, atol=1e-4))
|
||||
|
Loading…
Reference in New Issue
Block a user