[Longformer] Fix longformer documentation (#7016)
* fix longformer
* allow position ids to not be initialized
parent 5c4eb4b1ac
commit 120176ea29
@@ -795,6 +795,7 @@ class LongformerPreTrainedModel(PreTrainedModel):
 
     config_class = LongformerConfig
    base_model_prefix = "longformer"
+    authorized_missing_keys = [r"position_ids"]
 
     def _init_weights(self, module):
         """ Initialize the weights """
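The added `authorized_missing_keys` entry lets `from_pretrained` load checkpoints whose state dict lacks the registered `position_ids` buffer without warning about the missing key, since the buffer is recreated at initialization. A minimal sketch of the effect, assuming the public allenai/longformer-base-4096 checkpoint and that the buffer lives under the key `embeddings.position_ids` (both are assumptions, not part of this diff):

import torch
from transformers import LongformerModel

model = LongformerModel.from_pretrained("allenai/longformer-base-4096")

# Simulate an older checkpoint that was saved before position_ids became a buffer.
state_dict = model.state_dict()
state_dict.pop("embeddings.position_ids", None)  # assumed key name, for illustration only

# Because r"position_ids" is whitelisted in authorized_missing_keys,
# reloading from this stripped state dict does not flag the key as missing.
reloaded = LongformerModel.from_pretrained(
    "allenai/longformer-base-4096", state_dict=state_dict
)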
@@ -1019,11 +1020,13 @@ class LongformerModel(LongformerPreTrainedModel):
 
-        >>> # Attention mask values -- 0: no attention, 1: local attention, 2: global attention
         >>> attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device) # initialize to local attention
-        >>> attention_mask[:, [1, 4, 21,]] = 2  # Set global attention based on the task. For example,
+        >>> global_attention_mask = torch.zeros(input_ids.shape, dtype=torch.long, device=input_ids.device) # initialize to global attention to be deactivated for all tokens
+        >>> global_attention_mask[:, [1, 4, 21,]] = 1  # Set global attention to random tokens for the sake of this example
+        ...                                     # Usually, set global attention based on the task. For example,
         ...                                     # classification: the <s> token
         ...                                     # QA: question tokens
         ...                                     # LM: potentially on the beginning of sentences and paragraphs
-        >>> outputs = model(input_ids, attention_mask=attention_mask)
+
+        >>> outputs = model(input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask)
         >>> sequence_output = outputs.last_hidden_state
         >>> pooled_output = outputs.pooler_output
         """
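Assembled into a standalone script, the updated docstring example runs roughly as below. The checkpoint name, tokenizer, and sample text are carried over from the surrounding docstring and are assumptions here, not part of this hunk:

import torch
from transformers import LongformerModel, LongformerTokenizer

model = LongformerModel.from_pretrained("allenai/longformer-base-4096", return_dict=True)
tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")

sample_text = " ".join(["Hello world!"] * 1000)  # long input document, still under the 4096-token limit
input_ids = torch.tensor(tokenizer.encode(sample_text)).unsqueeze(0)  # batch of size 1

# attention_mask: 1 = attend (local sliding-window attention), 0 = padding.
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)

# global_attention_mask: 1 = token attends to, and is attended by, every other position.
# Deactivated everywhere by default; here a few arbitrary tokens are made global.
global_attention_mask = torch.zeros(input_ids.shape, dtype=torch.long, device=input_ids.device)
global_attention_mask[:, [1, 4, 21]] = 1  # e.g. the <s> token for classification, question tokens for QA

outputs = model(input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask)
sequence_output = outputs.last_hidden_state
pooled_output = outputs.pooler_output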