Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)

Commit: d9f0a8a304 ("camembert")
Parent: 5a05d5a5bc
src/transformers/models/camembert/modular_camembert.py (new file, 589 lines)
@@ -0,0 +1,589 @@
# coding=utf-8
# Copyright 2019 Inria, Facebook AI Research and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch CamemBERT model."""

from typing import Optional, Union

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...modeling_outputs import (
    CausalLMOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...utils import auto_docstring
from ..roberta.modeling_roberta import (
    RobertaForCausalLM,
    RobertaForMaskedLM,
    RobertaForMultipleChoice,
    RobertaForQuestionAnswering,
    RobertaForSequenceClassification,
    RobertaForTokenClassification,
    RobertaModel,
    RobertaPreTrainedModel,
)
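

# This is a "modular" definition file: the classes below only declare how CamemBERT
# differs from RoBERTa (attribute naming, checkpoints, and docstrings, not architecture),
# and the standalone modeling_camembert.py is generated from it by the repository's
# modular converter. A sketch of the usual invocation (script name and flag assumed from
# the repo's tooling; check utils/ for the current entry point):
#
#     python utils/modular_model_converter.py \
#         --files_to_parse src/transformers/models/camembert/modular_camembert.py
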
class CamembertPreTrainedModel(RobertaPreTrainedModel):
    base_model_prefix = "roberta"


class CamembertModel(RobertaModel):
    pass


class CamembertForMaskedLM(RobertaForMaskedLM):
    def __init__(self, config):
        super().__init__(config)
        del self.camembert

        self.roberta = CamembertModel(config, add_pooling_layer=False)

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
        r"""
        token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            This parameter can only be used when the model is initialized with `type_vocab_size` >= 2. All values in
            this tensor should be < `type_vocab_size`.

            [What are token type IDs?](../glossary#token-type-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for tokens with labels in `[0, ..., config.vocab_size]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(prediction_scores.device)
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
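
# Usage sketch (illustrative, not part of the generated modeling code): a minimal
# fill-mask forward pass, assuming the `almanach/camembert-base` checkpoint already
# referenced in this file.
#
#     >>> from transformers import AutoTokenizer, CamembertForMaskedLM
#     >>> import torch
#
#     >>> tokenizer = AutoTokenizer.from_pretrained("almanach/camembert-base")
#     >>> model = CamembertForMaskedLM.from_pretrained("almanach/camembert-base")
#
#     >>> inputs = tokenizer("Le camembert est <mask> :)", return_tensors="pt")
#     >>> with torch.no_grad():
#     ...     logits = model(**inputs).logits
#
#     >>> # pick the highest-scoring token at the <mask> position
#     >>> mask_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
#     >>> tokenizer.decode(logits[0, mask_index].argmax(dim=-1))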


class CamembertForSequenceClassification(RobertaForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)
        del self.camembert

        self.roberta = CamembertModel(config, add_pooling_layer=False)

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            This parameter can only be used when the model is initialized with `type_vocab_size` >= 2. All values in
            this tensor should be < `type_vocab_size`.

            [What are token type IDs?](../glossary#token-type-ids)
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is computed (mean-square loss); if
            `config.num_labels > 1`, a classification loss is computed (cross-entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
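
# Usage sketch (illustrative): single-label classification. With the base
# `almanach/camembert-base` checkpoint the classification head is freshly initialized,
# so the prediction below is only meaningful after fine-tuning.
#
#     >>> from transformers import AutoTokenizer, CamembertForSequenceClassification
#
#     >>> tokenizer = AutoTokenizer.from_pretrained("almanach/camembert-base")
#     >>> model = CamembertForSequenceClassification.from_pretrained(
#     ...     "almanach/camembert-base", num_labels=2
#     ... )
#     >>> inputs = tokenizer("Ce film était excellent.", return_tensors="pt")
#     >>> predicted_class_id = model(**inputs).logits.argmax(dim=-1).item()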


class CamembertForMultipleChoice(RobertaForMultipleChoice):
    def __init__(self, config):
        super().__init__(config)
        del self.camembert

        self.roberta = CamembertModel(config)

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            This parameter can only be used when the model is initialized with `type_vocab_size` >= 2. All values in
            this tensor should be < `type_vocab_size`.

            [What are token type IDs?](../glossary#token-type-ids)
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices - 1]`, where `num_choices` is the size of the second dimension of the input tensors (see
            `input_ids` above).
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids`, you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        flat_inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.roberta(
            flat_input_ids,
            position_ids=flat_position_ids,
            token_type_ids=flat_token_type_ids,
            attention_mask=flat_attention_mask,
            head_mask=head_mask,
            inputs_embeds=flat_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(reshaped_logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
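
# Shape note: choices arrive stacked as `(batch_size, num_choices, seq_len)`, are
# flattened to `(batch_size * num_choices, seq_len)` for the encoder, and the per-choice
# scores are reshaped back to `(batch_size, num_choices)` so CrossEntropyLoss selects one
# choice per example. A minimal sketch of preparing such inputs (illustrative; tokenizer
# and model loaded as in the sketches above):
#
#     >>> prompt = "La tour Eiffel se trouve à"
#     >>> choices = ["Paris.", "Marseille."]
#     >>> encoding = tokenizer([prompt, prompt], choices, return_tensors="pt", padding=True)
#     >>> # add the num_choices dimension expected by the model: (1, 2, seq_len)
#     >>> outputs = model(**{k: v.unsqueeze(0) for k, v in encoding.items()})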


class CamembertForTokenClassification(RobertaForTokenClassification):
    def __init__(self, config):
        super().__init__(config)
        del self.camembert

        self.roberta = CamembertModel(config, add_pooling_layer=False)

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            This parameter can only be used when the model is initialized with `type_vocab_size` >= 2. All values in
            this tensor should be < `type_vocab_size`.

            [What are token type IDs?](../glossary#token-type-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
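
# Usage sketch (illustrative): with a token-classification fine-tuned checkpoint,
# per-token predictions come from the argmax over the label dimension and can be mapped
# back through `model.config.id2label` (tokenizer and model loaded as above):
#
#     >>> inputs = tokenizer("George Washington est allé à Washington.", return_tensors="pt")
#     >>> predictions = model(**inputs).logits.argmax(dim=-1)
#     >>> predicted_tags = [model.config.id2label[i] for i in predictions[0].tolist()]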


class CamembertForQuestionAnswering(RobertaForQuestionAnswering):
    def __init__(self, config):
        super().__init__(config)
        del self.camembert

        self.roberta = CamembertModel(config, add_pooling_layer=False)

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        r"""
        token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            This parameter can only be used when the model is initialized with `type_vocab_size` >= 2. All values in
            this tensor should be < `type_vocab_size`.

            [What are token type IDs?](../glossary#token-type-ids)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split adds a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
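
# Usage sketch (illustrative): the predicted answer span is recovered by taking the
# argmax of the start and end logits and decoding the tokens in between. Assumes a
# QA fine-tuned checkpoint; `question` and `context` are plain strings.
#
#     >>> inputs = tokenizer(question, context, return_tensors="pt")
#     >>> outputs = model(**inputs)
#     >>> start = outputs.start_logits.argmax()
#     >>> end = outputs.end_logits.argmax()
#     >>> answer = tokenizer.decode(inputs.input_ids[0, start : end + 1])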


class CamembertForCausalLM(RobertaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        del self.camembert

        self.roberta = CamembertModel(config, add_pooling_layer=False)

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            This parameter can only be used when the model is initialized with `type_vocab_size` >= 2. All values in
            this tensor should be < `type_vocab_size`.

            [What are token type IDs?](../glossary#token-type-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are
            ignored (masked); the loss is only computed for tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, CamembertForCausalLM, AutoConfig
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("almanach/camembert-base")
        >>> config = AutoConfig.from_pretrained("almanach/camembert-base")
        >>> config.is_decoder = True
        >>> model = CamembertForCausalLM.from_pretrained("almanach/camembert-base", config=config)

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        lm_loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(prediction_scores.device)
            lm_loss = self.loss_function(
                prediction_scores,
                labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


__all__ = [
    "CamembertForCausalLM",
    "CamembertForMaskedLM",
    "CamembertForMultipleChoice",
    "CamembertForQuestionAnswering",
    "CamembertForSequenceClassification",
    "CamembertForTokenClassification",
    "CamembertModel",
    "CamembertPreTrainedModel",
]