mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
fix typo
This commit is contained in:
parent
2d07f945ad
commit
ee0308f79d
@ -238,7 +238,7 @@ def bertForSequenceClassification(*args, **kwargs):
|
||||
seq_classif_logits = model(tokens_tensor, segments_tensors)
|
||||
# Or get the sequence classification loss
|
||||
>>> labels = torch.tensor([1])
|
||||
>>> seq_classif_loss = model(tokens_tensor, segments_tensors, labels=labels)
|
||||
>>> seq_classif_loss = model(tokens_tensor, segments_tensors, labels=labels) # set model.train() before if training this loss
|
||||
"""
|
||||
model = BertForSequenceClassification.from_pretrained(*args, **kwargs)
|
||||
return model
|
||||
@ -272,7 +272,7 @@ def bertForMultipleChoice(*args, **kwargs):
|
||||
multiple_choice_logits = model(tokens_tensor, segments_tensors)
|
||||
# Or get the multiple choice loss
|
||||
>>> labels = torch.tensor([1])
|
||||
>>> multiple_choice_loss = model(tokens_tensor, segments_tensors, labels=labels)
|
||||
>>> multiple_choice_loss = model(tokens_tensor, segments_tensors, labels=labels) # set model.train() before if training this loss
|
||||
"""
|
||||
model = BertForMultipleChoice.from_pretrained(*args, **kwargs)
|
||||
return model
|
||||
@ -304,6 +304,7 @@ def bertForQuestionAnswering(*args, **kwargs):
|
||||
start_logits, end_logits = model(tokens_tensor, segments_tensors)
|
||||
# Or get the total loss which is the sum of the CrossEntropy loss for the start and end token positions
|
||||
>>> start_positions, end_positions = torch.tensor([12]), torch.tensor([14])
|
||||
# set model.train() before if training this loss
|
||||
>>> multiple_choice_loss = model(tokens_tensor, segments_tensors, start_positions=start_positions, end_positions=end_positions)
|
||||
"""
|
||||
model = BertForQuestionAnswering.from_pretrained(*args, **kwargs)
|
||||
@ -341,7 +342,7 @@ def bertForTokenClassification(*args, **kwargs):
|
||||
classif_logits = model(tokens_tensor, segments_tensors)
|
||||
# Or get the token classification loss
|
||||
>>> labels = torch.tensor([[0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0]])
|
||||
>>> classif_loss = model(tokens_tensor, segments_tensors, labels=labels)
|
||||
>>> classif_loss = model(tokens_tensor, segments_tensors, labels=labels) # set model.train() before if training this loss
|
||||
"""
|
||||
model = BertForTokenClassification.from_pretrained(*args, **kwargs)
|
||||
return model
|
||||
|
Loading…
Reference in New Issue
Block a user