Doc fixes in preparation for the docstyle PR (#8061)

* Fixes in preparation for doc styling

* More fixes

* Better syntax

* Fixes

* Style

* More fixes

* More fixes
This commit is contained in:
Sylvain Gugger 2020-10-26 15:01:09 -04:00 committed by GitHub
parent 8bbb74f211
commit 04a17f8550
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 179 additions and 237 deletions

View File

@ -112,7 +112,7 @@ Example usage
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Here is an example using the processors as well as the conversion method using data files:
Example::
.. code-block::
# Loading a V2 processor
processor = SquadV2Processor()
@ -133,7 +133,7 @@ Example::
Using `tensorflow_datasets` is as easy as using a data file:
Example::
.. code-block::
# tensorflow_datasets only handle Squad V1.
tfds_examples = tfds.load("squad")

View File

@ -47,7 +47,7 @@ Usage:
- Pretrained :class:`~transformers.EncoderDecoderModel` are also directly available in the model hub, e.g.,
:: code-block
.. code-block::
# instantiate sentence fusion model
sentence_fuser = EncoderDecoderModel.from_pretrained("google/roberta2roberta_L-24_discofuse")

View File

@ -28,7 +28,9 @@ Implementation Notes
Usage
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Model Usage:
Here is an example of model usage:
.. code-block::
>>> from transformers import BlenderbotSmallTokenizer, BlenderbotForConditionalGeneration
>>> mname = 'facebook/blenderbot-90M'
@ -40,7 +42,10 @@ Model Usage:
>>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in reply_ids])
See Config Values:
Here is how you can check out config values:
.. code-block::
>>> from transformers import BlenderbotConfig
>>> config_90 = BlenderbotConfig.from_pretrained("facebook/blenderbot-90M")

View File

@ -45,6 +45,8 @@ Note:
If you want to reproduce the original tokenization process of the `OpenAI GPT` paper, you will need to install
``ftfy`` and ``SpaCy``::
.. code-block:: bash
pip install spacy ftfy==4.4.3
python -m spacy download en

View File

@ -1,7 +1,7 @@
# This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
# Copyright by the AllenNLP authors.
"""
Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""
import copy

View File

@ -8,7 +8,8 @@ from ..utils import logging
def convert_command_factory(args: Namespace):
"""
Factory function used to convert a model TF 1.0 checkpoint in a PyTorch checkpoint.
:return: ServeCommand
Returns: ServeCommand
"""
return ConvertCommand(
args.model_type, args.tf_checkpoint, args.pytorch_dump_output, args.config, args.finetuning_task_name
@ -26,8 +27,9 @@ class ConvertCommand(BaseTransformersCLICommand):
def register_subcommand(parser: ArgumentParser):
"""
Register this command to argparse so it's available for the transformer-cli
:param parser: Root parser to register command-specific arguments
:return:
Args:
parser: Root parser to register command-specific arguments
"""
train_parser = parser.add_parser(
"convert",

View File

@ -31,7 +31,8 @@ logger = logging.get_logger("transformers-cli/serving")
def serve_command_factory(args: Namespace):
"""
Factory function used to instantiate serving server from provided command line arguments.
:return: ServeCommand
Returns: ServeCommand
"""
nlp = pipeline(
task=args.task,
@ -81,8 +82,9 @@ class ServeCommand(BaseTransformersCLICommand):
def register_subcommand(parser: ArgumentParser):
"""
Register this command to argparse so it's available for the transformer-cli
:param parser: Root parser to register command-specific arguments
:return:
Args:
parser: Root parser to register command-specific arguments
"""
serve_parser = parser.add_parser(
"serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."

View File

@ -19,7 +19,8 @@ USE_AMP = False
def train_command_factory(args: Namespace):
"""
Factory function used to instantiate training command from provided command line arguments.
:return: TrainCommand
Returns: TrainCommand
"""
return TrainCommand(args)
@ -29,8 +30,9 @@ class TrainCommand(BaseTransformersCLICommand):
def register_subcommand(parser: ArgumentParser):
"""
Register this command to argparse so it's available for the transformer-cli
:param parser: Root parser to register command-specific arguments
:return:
Args:
parser: Root parser to register command-specific arguments
"""
train_parser = parser.add_parser("train", help="CLI tool to train a model on a task.")

View File

@ -70,7 +70,7 @@ class BaseUserCommand:
class LoginCommand(BaseUserCommand):
def run(self):
print(
print( # docstyle-ignore
"""
_| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
_| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
@ -127,8 +127,9 @@ class ListObjsCommand(BaseUserCommand):
def tabulate(self, rows: List[List[Union[str, int]]], headers: List[str]) -> str:
"""
Inspired by:
stackoverflow.com/a/8356620/593036
stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
- stackoverflow.com/a/8356620/593036
- stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
"""
col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
row_format = ("{{:{}}} " * len(headers)).format(*col_widths)

View File

@ -28,19 +28,19 @@ from transformers import BertModel
def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):
"""
:param model:BertModel Pytorch model instance to be converted
:param ckpt_dir: Tensorflow model directory
:param model_name: model name
:return:
Args
model: BertModel Pytorch model instance to be converted
ckpt_dir: Tensorflow model directory
model_name: model name
Currently supported HF models:
Y BertModel
N BertForMaskedLM
N BertForPreTraining
N BertForMultipleChoice
N BertForNextSentencePrediction
N BertForSequenceClassification
N BertForQuestionAnswering
- Y BertModel
- N BertForMaskedLM
- N BertForPreTraining
- N BertForMultipleChoice
- N BertForNextSentencePrediction
- N BertForSequenceClassification
- N BertForQuestionAnswering
"""
tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")

View File

@ -28,11 +28,13 @@ LANG_CODE_PATH = "lang_code_data/language-codes-3b2.csv"
class TatoebaConverter:
"""Convert Tatoeba-Challenge models to huggingface format.
Steps:
(1) convert numpy state dict to hf format (same code as OPUS-MT-Train conversion).
(2) rename opus model to huggingface format. This means replace each alpha3 code with an alpha2 code if a unique one existes.
e.g. aav-eng -> aav-en, heb-eng -> he-en
(3) write a model card containing the original Tatoeba-Challenge/README.md and extra info about alpha3 group members.
1. convert numpy state dict to hf format (same code as OPUS-MT-Train conversion).
2. rename opus model to huggingface format. This means replace each alpha3 code with an alpha2 code if a unique one existes.
e.g. aav-eng -> aav-en, heb-eng -> he-en
3. write a model card containing the original Tatoeba-Challenge/README.md and extra info about alpha3 group members.
"""
def __init__(self, save_dir="marian_converted"):

View File

@ -19,14 +19,12 @@ DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str
def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
"""
Very simple data collator that:
- simply collates batches of dict-like objects
- Performs special handling for potential keys named:
Very simple data collator that simply collates batches of dict-like objects and erforms special handling for potential keys named:
- ``label``: handles a single value (int or float) per object
- ``label_ids``: handles a list of values per object
- does not do any additional preprocessing
i.e., Property names of the input object will be used as corresponding inputs to the model.
Des not do any additional preprocessing: property names of the input object will be used as corresponding inputs to the model.
See glue and ner for example of how it's useful.
"""
@ -425,6 +423,7 @@ class DataCollatorForPermutationLanguageModeling:
def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
0. Start from the beginning of the sequence by setting ``cur_len = 0`` (number of tokens processed so far).
1. Sample a ``span_length`` from the interval ``[1, max_span_length]`` (length of span of tokens to be masked)
2. Reserve a context of length ``context_length = span_length / plm_probability`` to surround span to be masked

View File

@ -289,6 +289,7 @@ def torch_only_method(fn):
return wrapper
# docstyle-ignore
DATASETS_IMPORT_ERROR = """
{0} requires the 🤗 Datasets library but it was not found in your enviromnent. You can install it with:
```
@ -306,6 +307,7 @@ that python file if that's the case.
"""
# docstyle-ignore
TOKENIZERS_IMPORT_ERROR = """
{0} requires the 🤗 Tokenizers library but it was not found in your enviromnent. You can install it with:
```
@ -318,6 +320,7 @@ In a notebook or a colab, you can install it by executing a cell with
"""
# docstyle-ignore
SENTENCEPIECE_IMPORT_ERROR = """
{0} requires the SentencePiece library but it was not found in your enviromnent. Checkout the instructions on the
installation page of its repo: https://github.com/google/sentencepiece#installation and follow the ones
@ -325,6 +328,7 @@ that match your enviromnent.
"""
# docstyle-ignore
FAISS_IMPORT_ERROR = """
{0} requires the faiss library but it was not found in your enviromnent. Checkout the instructions on the
installation page of its repo: https://github.com/facebookresearch/faiss/blob/master/INSTALL.md and follow the ones
@ -332,12 +336,14 @@ that match your enviromnent.
"""
# docstyle-ignore
PYTORCH_IMPORT_ERROR = """
{0} requires the PyTorch library but it was not found in your enviromnent. Checkout the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your enviromnent.
"""
# docstyle-ignore
SKLEARN_IMPORT_ERROR = """
{0} requires the scikit-learn library but it was not found in your enviromnent. You can install it with:
```
@ -350,12 +356,14 @@ In a notebook or a colab, you can install it by executing a cell with
"""
# docstyle-ignore
TENSORFLOW_IMPORT_ERROR = """
{0} requires the TensorFlow library but it was not found in your enviromnent. Checkout the instructions on the
installation page: https://www.tensorflow.org/install and follow the ones that match your enviromnent.
"""
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your enviromnent. Checkout the instructions on the
installation page: https://github.com/google/flax and follow the ones that match your enviromnent.

View File

@ -917,7 +917,7 @@ def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
def calc_banned_ngram_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
# Copied from fairseq for no_repeat_ngram in beam_search"""
# Copied from fairseq for no_repeat_ngram in beam_search
if cur_len + 1 < no_repeat_ngram_size:
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
return [[] for _ in range(num_hypos)]

View File

@ -857,16 +857,16 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
**kwargs,
):
r"""
mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input)
mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input):
Index of the classification token in each input sequence.
Selected in the range ``[0, input_ids.size(-1) - 1[``.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`)
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for language modeling.
Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
Indices are selected in ``[-1, 0, ..., config.vocab_size]``
All labels set to ``-100`` are ignored (masked), the loss is only
computed for labels in ``[0, ..., config.vocab_size]``
mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`)
mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`):
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)

View File

@ -105,8 +105,10 @@ def create_position_ids_from_input_ids(input_ids, padding_idx):
padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
`utils.make_positions`.
:param torch.Tensor x:
:return torch.Tensor:
Args:
x: torch.Tensor x:
Returns: torch.Tensor
"""
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
mask = input_ids.ne(padding_idx).int()
@ -176,8 +178,10 @@ class LongformerEmbeddings(nn.Module):
"""We are provided embeddings directly. We cannot infer which are padded so just generate
sequential position ids.
:param torch.Tensor inputs_embeds:
:return torch.Tensor:
Args:
inputs_embeds: torch.Tensor inputs_embeds:
Returns: torch.Tensor
"""
input_shape = inputs_embeds.size()[:-1]
sequence_length = input_shape[1]

View File

@ -647,16 +647,16 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
**kwargs
):
r"""
mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input)
mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input):
Index of the classification token in each input sequence.
Selected in the range ``[0, input_ids.size(-1) - 1]``.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`)
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for language modeling.
Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
Indices are selected in ``[-1, 0, ..., config.vocab_size]``
All labels set to ``-100`` are ignored (masked), the loss is only
computed for labels in ``[0, ..., config.vocab_size]``
mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`)
mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`):
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)

View File

@ -127,8 +127,10 @@ class RobertaEmbeddings(nn.Module):
"""We are provided embeddings directly. We cannot infer which are padded so just generate
sequential position ids.
:param torch.Tensor inputs_embeds:
:return torch.Tensor:
Args:
inputs_embeds: torch.Tensor
Returns: torch.Tensor
"""
input_shape = inputs_embeds.size()[:-1]
sequence_length = input_shape[1]
@ -1326,8 +1328,10 @@ def create_position_ids_from_input_ids(input_ids, padding_idx):
padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
`utils.make_positions`.
:param torch.Tensor x:
:return torch.Tensor:
Args:
x: torch.Tensor x:
Returns: torch.Tensor
"""
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
mask = input_ids.ne(padding_idx).int()

View File

@ -704,7 +704,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
training=False,
):
r"""
mc_token_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input)
mc_token_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input):
Index of the classification token in each input sequence.
Selected in the range ``[0, input_ids.size(-1) - 1[``.

View File

@ -166,8 +166,11 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
"""Replace non-padding symbols with their position numbers. Position numbers begin at
padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
`utils.make_positions`.
:param tf.Tensor x:
:return tf.Tensor:
Args:
x: tf.Tensor
Returns: tf.Tensor
"""
mask = tf.cast(tf.math.not_equal(x, self.padding_idx), dtype=tf.int32)
incremental_indicies = tf.math.cumsum(mask, axis=1) * mask
@ -177,8 +180,11 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
def create_position_ids_from_inputs_embeds(self, inputs_embeds):
"""We are provided embeddings directly. We cannot infer which are padded so just generate
sequential position ids.
:param tf.Tensor inputs_embeds:
:return tf.Tensor:
Args:
inputs_embeds: tf.Tensor
Returns: tf.Tensor
"""
seq_length = shape_list(inputs_embeds)[1]
position_ids = tf.range(self.padding_idx + 1, seq_length + self.padding_idx + 1, dtype=tf.int32)[tf.newaxis, :]

View File

@ -625,7 +625,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
training=False,
):
r"""
mc_token_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input)
mc_token_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input):
Index of the classification token in each input sequence.
Selected in the range ``[0, input_ids.size(-1) - 1]``.

View File

@ -111,8 +111,11 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer):
"""Replace non-padding symbols with their position numbers. Position numbers begin at
padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
`utils.make_positions`.
:param tf.Tensor x:
:return tf.Tensor:
Args:
x: tf.Tensor
Returns: tf.Tensor
"""
mask = tf.cast(tf.math.not_equal(x, self.padding_idx), dtype=tf.int32)
incremental_indicies = tf.math.cumsum(mask, axis=1) * mask
@ -122,8 +125,11 @@ class TFRobertaEmbeddings(tf.keras.layers.Layer):
def create_position_ids_from_inputs_embeds(self, inputs_embeds):
"""We are provided embeddings directly. We cannot infer which are padded so just generate
sequential position ids.
:param tf.Tensor inputs_embeds:
:return tf.Tensor:
Args:
inputs_embeds: tf.Tensor
Returns: tf.Tensor
"""
seq_length = shape_list(inputs_embeds)[1]
position_ids = tf.range(self.padding_idx + 1, seq_length + self.padding_idx + 1, dtype=tf.int32)[tf.newaxis, :]

View File

@ -1718,120 +1718,3 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
# @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
# the hidden-states output to compute `span start logits` and `span end logits`). """,
# XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
# class TFXLNetForQuestionAnswering(TFXLNetPreTrainedModel):
# r"""
# Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
# **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
# ``tf.Tensor`` of shape ``(batch_size, config.start_n_top)``
# Log probabilities for the top config.start_n_top start token possibilities (beam-search).
# **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
# ``tf.Tensor`` of shape ``(batch_size, config.start_n_top)``
# Indices for the top config.start_n_top start token possibilities (beam-search).
# **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
# ``tf.Tensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
# Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
# **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
# ``tf.Tensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
# Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
# **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
# ``tf.Tensor`` of shape ``(batch_size,)``
# Log probabilities for the ``is_impossible`` label of the answers.
# **mems**:
# list of ``tf.Tensor`` (one for each layer):
# that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
# if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
# See details in the docstring of the `mems` input above.
# **hidden_states**: (`optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``)
# list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
# of shape ``(batch_size, sequence_length, hidden_size)``:
# Hidden-states of the model at the output of each layer plus the initial embedding outputs.
# **attentions**: (`optional`, returned when ``output_attentions=True``)
# list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
# Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
# Examples::
# # For example purposes. Not runnable.
# tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
# model = XLMForQuestionAnswering.from_pretrained('xlnet-large-cased')
# input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
# start_positions = tf.constant([1])
# end_positions = tf.constant([3])
# outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
# loss, start_scores, end_scores = outputs[:2]
# """
# def __init__(self, config, *inputs, **kwargs):
# super().__init__(config, *inputs, **kwargs)
# self.start_n_top = config.start_n_top
# self.end_n_top = config.end_n_top
# self.transformer = TFXLNetMainLayer(config, name='transformer')
# self.start_logits = TFPoolerStartLogits(config, name='start_logits')
# self.end_logits = TFPoolerEndLogits(config, name='end_logits')
# self.answer_class = TFPoolerAnswerClass(config, name='answer_class')
# def call(self, inputs, training=False):
# transformer_outputs = self.transformer(inputs, training=training)
# hidden_states = transformer_outputs[0]
# start_logits = self.start_logits(hidden_states, p_mask=p_mask)
# outputs = transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
# if start_positions is not None and end_positions is not None:
# # If we are on multi-GPU, let's remove the dimension added by batch splitting
# for x in (start_positions, end_positions, cls_index, is_impossible):
# if x is not None and x.dim() > 1:
# x.squeeze_(-1)
# # during training, compute the end logits based on the ground truth of the start position
# end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
# loss_fct = CrossEntropyLoss()
# start_loss = loss_fct(start_logits, start_positions)
# end_loss = loss_fct(end_logits, end_positions)
# total_loss = (start_loss + end_loss) / 2
# if cls_index is not None and is_impossible is not None:
# # Predict answerability from the representation of CLS and START
# cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
# loss_fct_cls = nn.BCEWithLogitsLoss()
# cls_loss = loss_fct_cls(cls_logits, is_impossible)
# # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
# total_loss += cls_loss * 0.5
# outputs = (total_loss,) + outputs
# else:
# # during inference, compute the end logits based on beam search
# bsz, slen, hsz = hidden_states.size()
# start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
# start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top)
# start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
# start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
# start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
# hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz)
# p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
# end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
# end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
# end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1) # shape (bsz, end_n_top, start_n_top)
# end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
# end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
# start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) # get the representation of START as weighted sum of hidden states
# cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) # Shape (batch size,): one single `cls_logits` for each sample
# outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
# # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
# # or (if labels are provided) (total_loss,)
# return outputs

View File

@ -1487,7 +1487,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`)
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the sequence classification/regression loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),

View File

@ -352,22 +352,22 @@ class CaptureStd:
- out - capture stdout: True/False, default True
- err - capture stdout: True/False, default True
Examples:
Examples::
with CaptureStdout() as cs:
print("Secret message")
print(f"captured: {cs.out}")
with CaptureStdout() as cs:
print("Secret message")
print(f"captured: {cs.out}")
import sys
with CaptureStderr() as cs:
print("Warning: ", file=sys.stderr)
print(f"captured: {cs.err}")
import sys
with CaptureStderr() as cs:
print("Warning: ", file=sys.stderr)
print(f"captured: {cs.err}")
# to capture just one of the streams, but not the other
with CaptureStd(err=False) as cs:
print("Secret message")
print(f"captured: {cs.out}")
# but best use the stream-specific subclasses
# to capture just one of the streams, but not the other
with CaptureStd(err=False) as cs:
print("Secret message")
print(f"captured: {cs.out}")
# but best use the stream-specific subclasses
"""
@ -444,17 +444,17 @@ class CaptureLogger:
Results:
The captured output is available via `self.out`
Example:
Example::
>>> from transformers import logging
>>> from transformers.testing_utils import CaptureLogger
>>> from transformers import logging
>>> from transformers.testing_utils import CaptureLogger
>>> msg = "Testing 1, 2, 3"
>>> logging.set_verbosity_info()
>>> logger = logging.get_logger("transformers.tokenization_bart")
>>> with CaptureLogger(logger) as cl:
... logger.info(msg)
>>> assert cl.out, msg+"\n"
>>> msg = "Testing 1, 2, 3"
>>> logging.set_verbosity_info()
>>> logger = logging.get_logger("transformers.tokenization_bart")
>>> with CaptureLogger(logger) as cl:
... logger.info(msg)
>>> assert cl.out, msg+"\n"
"""
def __init__(self, logger):
@ -485,24 +485,36 @@ class TestCasePlus(unittest.TestCase):
of test, unless `after=False`.
# 1. create a unique temp dir, `tmp_dir` will contain the path to the created temp dir
def test_whatever(self):
tmp_dir = self.get_auto_remove_tmp_dir()
::
def test_whatever(self):
tmp_dir = self.get_auto_remove_tmp_dir()
# 2. create a temp dir of my choice and delete it at the end - useful for debug when you want to
# monitor a specific directory
def test_whatever(self):
tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test")
::
def test_whatever(self):
tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test")
# 3. create a temp dir of my choice and do not delete it at the end - useful for when you want
# to look at the temp results
def test_whatever(self):
tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test", after=False)
::
def test_whatever(self):
tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test", after=False)
# 4. create a temp dir of my choice and ensure to delete it right away - useful for when you
# disabled deletion in the previous test run and want to make sure the that tmp dir is empty
# before the new test is run
def test_whatever(self):
tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test", before=True)
::
def test_whatever(self):
tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test", before=True)
Note 1: In order to run the equivalent of `rm -r` safely, only subdirs of the
project repository checkout are allowed if an explicit `tmp_dir` is used, so

View File

@ -488,6 +488,7 @@ domains and tasks. The basic logic is this:
# This particular element is used in a couple ways, so we define it
# with a name:
# docstyle-ignore
EMOTICONS = r"""
(?:
[<>]?
@ -505,7 +506,7 @@ EMOTICONS = r"""
# URL pattern due to John Gruber, modified by Tom Winzig. See
# https://gist.github.com/winzig/8894715
# docstyle-ignore
URLS = r""" # Capture 1: entire matched URL
(?:
https?: # URL protocol and colon
@ -549,6 +550,7 @@ URLS = r""" # Capture 1: entire matched URL
)
"""
# docstyle-ignore
# The components of the tokenizer:
REGEXPS = (
URLS,
@ -628,18 +630,16 @@ def _replace_html_entities(text, keep=(), remove_illegal=True, encoding="utf-8")
Remove entities from text by converting them to their
corresponding unicode character.
:param text: a unicode string or a byte string encoded in the given
`encoding` (which defaults to 'utf-8').
Args:
text:
A unicode string or a byte string encoded in the given `encoding` (which defaults to 'utf-8').
keep (list):
List of entity names which should not be replaced. This supports both numeric entities (``&#nnnn;`` and ``&#hhhh;``)
and named entities (such as ``&nbsp;`` or ``&gt;``).
remove_illegal (bool):
If `True`, entities that can't be converted are removed. Otherwise, entities that can't be converted are kept "as is".
:param list keep: list of entity names which should not be replaced.\
This supports both numeric entities (``&#nnnn;`` and ``&#hhhh;``)
and named entities (such as ``&nbsp;`` or ``&gt;``).
:param bool remove_illegal: If `True`, entities that can't be converted are\
removed. Otherwise, entities that can't be converted are kept "as
is".
:returns: A unicode string with the entities removed.
Returns: A unicode string with the entities removed.
See https://github.com/scrapy/w3lib/blob/master/w3lib/html.py
@ -688,16 +688,16 @@ def _replace_html_entities(text, keep=(), remove_illegal=True, encoding="utf-8")
class TweetTokenizer:
r"""
Tokenizer for tweets.
Examples::
>>> # Tokenizer for tweets.
>>> from nltk.tokenize import TweetTokenizer
>>> tknzr = TweetTokenizer()
>>> s0 = "This is a cooool #dummysmiley: :-) :-P <3 and some arrows < > -> <--"
>>> tknzr.tokenize(s0)
['This', 'is', 'a', 'cooool', '#dummysmiley', ':', ':-)', ':-P', '<3', 'and', 'some', 'arrows', '<', '>', '->', '<--']
Examples using `strip_handles` and `reduce_len parameters`:
>>> # Examples using `strip_handles` and `reduce_len parameters`:
>>> tknzr = TweetTokenizer(strip_handles=True, reduce_len=True)
>>> s1 = '@remy: This is waaaaayyyy too much for you!!!!!!'
>>> tknzr.tokenize(s1)
@ -711,10 +711,11 @@ class TweetTokenizer:
def tokenize(self, text):
"""
:param text: str
:rtype: list(str)
:return: a tokenized list of strings; concatenating this list returns\
the original string if `preserve_case=False`
Args:
text: str
Returns: list(str)
A tokenized list of strings; concatenating this list returns the original string if `preserve_case=False`
"""
# Fix HTML character entities:
text = _replace_html_entities(text)

View File

@ -628,13 +628,16 @@ class DebertaTokenizer(PreTrainedTokenizer):
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
"""
Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
A BERT sequence pair mask has the following format:
0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
| first sequence | second sequence
Create a mask from the two sequences passed to be used in a sequence-pair classification task.
A DeBERTa sequence pair mask has the following format:
::
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).
if token_ids_1 is None, only returns the first portion of the mask (0's).
~
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs.