Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 04:40:06 +06:00)
Docstring check (#26052)
* Fix number of minimal calls to the Hub with peft integration
* Alternate design
* And this way?
* Revert
* Nits to fix
* Add util
* Print when changes are made
* Add list to ignore
* Add more rules
* Manual fixes
* deal with kwargs
* deal with enum defaults
* avoid many digits for floats
* Manual fixes
* Fix regex
* Fix regex
* Auto fix
* Style
* Apply script
* Add ignored list
* Add check that templates are filled
* Adding to CI checks
* Add back semi-fix
* Ignore more objects
* More auto-fixes
* Ignore missing objects
* Remove temp semi-fix
* Fixes
* Update src/transformers/models/pvt/configuration_pvt.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update utils/check_docstrings.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/utils/quantization_config.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Deal with float defaults
* Fix small defaults
* Address review comment
* Treat
* Post-rebase cleanup
* Address review comment
* Update src/transformers/models/deprecated/mctct/configuration_mctct.py

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>

* Address review comment

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
parent
122b2657f8
commit
03af4c42a6
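The core of this commit is the new `utils/check_docstrings.py` consistency check: it verifies that docstrings document the arguments in each object's signature, and that `*optional*` markers and `defaults to ...` values match the actual defaults. The snippet below is a minimal illustrative sketch of that kind of check, not the real script; the `collate` example function, the regex, and the `find_default_mismatches` helper are hypothetical stand-ins.

```python
import inspect
import re

# Hypothetical example: a function whose docstring has drifted from its
# signature, mirroring the kind of mismatch this commit fixes.
def collate(return_tensors: str = "pt", label_pad_token_id: int = -100):
    """
    Args:
        return_tensors (`str`):
            The type of Tensor to return.
        label_pad_token_id (`int`, *optional*, defaults to -1):
            The id to use when padding the labels.
    """


# Matches docstring lines of the form: name (`type`, *optional*, defaults to X):
_ARG_LINE = re.compile(r"^\s*(\w+) \(`[^`]*`(?:, \*optional\*)?(?:, defaults to (.+?))?\):$")


def find_default_mismatches(obj):
    """Return (name, documented_default, actual_default) for every drifted argument."""
    signature = inspect.signature(obj)
    mismatches = []
    for line in (obj.__doc__ or "").splitlines():
        match = _ARG_LINE.match(line)
        if match is None:
            continue
        name, documented = match.group(1), match.group(2)
        param = signature.parameters.get(name)
        if param is None or param.default is inspect.Parameter.empty:
            continue
        actual = repr(param.default)
        if documented is None or documented.strip("`\"'") != actual.strip("\"'"):
            mismatches.append((name, documented, actual))
    return mismatches


for name, documented, actual in find_default_mismatches(collate):
    print(f"{name}: docstring says {documented}, signature says {actual}")
# return_tensors: docstring says None, signature says 'pt'
# label_pad_token_id: docstring says -1, signature says -100
```

In the diff below, the real checker is wired into the `repo-consistency` Makefile target and the CircleCI job list, and `fix-copies` runs it with `--fix_and_overwrite` so trivially fixable argument headers are rewritten in place.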
@@ -209,6 +209,7 @@ jobs:
- run: make deps_table_check_updated
- run: python utils/update_metadata.py --check-only
- run: python utils/check_task_guides.py
- run: python utils/check_docstrings.py

workflows:
version: 2

Makefile
@@ -43,6 +43,7 @@ repo-consistency:
python utils/check_doctest_list.py
python utils/update_metadata.py --check-only
python utils/check_task_guides.py
python utils/check_docstrings.py

# this target runs checks on all files

@@ -82,6 +83,7 @@ fix-copies:
python utils/check_dummies.py --fix_and_overwrite
python utils/check_doctest_list.py --fix_and_overwrite
python utils/check_task_guides.py --fix_and_overwrite
python utils/check_docstrings.py --fix_and_overwrite

# Run tests for the library

@@ -124,6 +124,7 @@ This checks that:
- The translations of the READMEs and the index of the doc have the same model list as the main README (performed by `utils/check_copies.py`)
- The auto-generated tables in the documentation are up to date (performed by `utils/check_table.py`)
- The library has all objects available even if not all optional dependencies are installed (performed by `utils/check_dummies.py`)
- All docstrings properly document the arguments in the signature of the object (performed by `utils/check_docstrings.py`)

Should this check fail, the first two items require manual fixing, the last four can be fixed automatically for you by running the command

@@ -47,6 +47,7 @@ _re_configuration_file = re.compile(r"config\.(.*)\.json")


class PretrainedConfig(PushToHubMixin):
# no-format
r"""
Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
methods for loading/downloading/saving configurations.

@@ -90,7 +90,7 @@ class DefaultDataCollator(DataCollatorMixin):
helpful if you need to set a return_tensors value at initialization.

Args:
return_tensors (`str`):
return_tensors (`str`, *optional*, defaults to `"pt"`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
"""

@@ -235,7 +235,7 @@ class DataCollatorWithPadding:

This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
return_tensors (`str`):
return_tensors (`str`, *optional*, defaults to `"pt"`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
"""

@@ -288,7 +288,7 @@ class DataCollatorForTokenClassification(DataCollatorMixin):
7.5 (Volta).
label_pad_token_id (`int`, *optional*, defaults to -100):
The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
return_tensors (`str`):
return_tensors (`str`, *optional*, defaults to `"pt"`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
"""

@@ -521,7 +521,7 @@ class DataCollatorForSeq2Seq:
Args:
tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
The tokenizer used for encoding the data.
model ([`PreTrainedModel`]):
model ([`PreTrainedModel`], *optional*):
The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
prepare the *decoder_input_ids*

@@ -544,7 +544,7 @@ class DataCollatorForSeq2Seq:
7.5 (Volta).
label_pad_token_id (`int`, *optional*, defaults to -100):
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
return_tensors (`str`):
return_tensors (`str`, *optional*, defaults to `"pt"`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
"""

@@ -65,7 +65,7 @@ class BatchFeature(UserDict):
This class is derived from a python dictionary and can be used as a dictionary.

Args:
data (`dict`):
data (`dict`, *optional*):
Dictionary of lists/arrays/tensors returned by the __call__/pad methods ('input_values', 'attention_mask',
etc.).
tensor_type (`Union[None, str, TensorType]`, *optional*):

@@ -263,8 +263,9 @@ class DisjunctiveConstraint(Constraint):
A special [`Constraint`] that is fulfilled by fulfilling just one of several constraints.

Args:
nested_token_ids (`List[List[int]]`): a list of words, where each word is a list of ids. This constraint
is fulfilled by generating just one from the list of words.
nested_token_ids (`List[List[int]]`):
A list of words, where each word is a list of ids. This constraint is fulfilled by generating just one from
the list of words.
"""

def __init__(self, nested_token_ids: List[List[int]]):

@@ -152,7 +152,7 @@ class BeamSearchScorer(BeamScorer):
num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
The number of beam hypotheses that shall be returned upon calling
[`~transformer.BeamSearchScorer.finalize`].
num_beam_groups (`int`):
num_beam_groups (`int`, *optional*, defaults to 1):
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
max_length (`int`, *optional*):

@@ -437,7 +437,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
The number of beam hypotheses that shall be returned upon calling
[`~transformer.BeamSearchScorer.finalize`].
num_beam_groups (`int`):
num_beam_groups (`int`, *optional*, defaults to 1):
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
max_length (`int`, *optional*):

@@ -38,6 +38,7 @@ METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash"


class GenerationConfig(PushToHubMixin):
# no-format
r"""
Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
for text-decoder, text-to-text, speech-to-text, and vision-to-text models:

@@ -120,7 +120,7 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
filter_value (`float`, *optional*, defaults to -inf):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.

@@ -163,7 +163,7 @@ class FlaxTopKLogitsWarper(FlaxLogitsWarper):
Args:
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
filter_value (`float`, *optional*, defaults to -inf):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.

@@ -357,7 +357,7 @@ class TopPLogitsWarper(LogitsWarper):
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
filter_value (`float`, *optional*, defaults to -inf):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.

@@ -419,7 +419,7 @@ class TopKLogitsWarper(LogitsWarper):
Args:
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
filter_value (`float`, *optional*, defaults to -inf):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.

@@ -447,9 +447,9 @@ class TypicalLogitsWarper(LogitsWarper):
Generation](https://arxiv.org/abs/2202.00666) for more information.

Args:
mass (`float`):
mass (`float`, *optional*, defaults to 0.9):
Value of typical_p between 0 and 1 inclusive, defaults to 0.9.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
filter_value (`float`, *optional*, defaults to -inf):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.

@@ -499,7 +499,7 @@ class EpsilonLogitsWarper(LogitsWarper):
Args:
epsilon (`float`):
If set to > 0, only the most tokens with probabilities `epsilon` or higher are kept for generation.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
filter_value (`float`, *optional*, defaults to -inf):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.

@@ -572,7 +572,7 @@ class EtaLogitsWarper(LogitsWarper):
epsilon (`float`):
A float value in the range (0, 1). Hyperparameter used to calculate the dynamic cutoff value, `eta`. The
suggested values from the paper ranges from 3e-4 to 4e-3 depending on the size of the model.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
filter_value (`float`, *optional*, defaults to -inf):
All values that are found to be below the dynamic cutoff value, `eta`, are set to this float value. This
parameter is useful when logits need to be modified for very low probability tokens that should be excluded
from generation entirely.

@@ -1600,18 +1600,15 @@ class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
Higher guidance scale encourages the model to generate samples that are more closely linked to the input
prompt, usually at the expense of poorer quality. A value smaller than 1 has the opposite effect, while
making the negative prompt provided with negative_prompt_ids (if any) act as a positive prompt.
unconditional_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of input sequence tokens in the vocabulary for the unconditional branch. If unset, will default to
the last token of the prompt.
unconditional_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, **optional**):
Attention mask for unconditional_ids.
model (`PreTrainedModel`):
The model computing the unconditional scores. Supposedly the same as the one computing the conditional
scores. Both models must use the same tokenizer.
smooth_factor (`float`, **optional**):
The interpolation weight for CFG Rescale. 1 means no rescaling, 0 reduces to the conditional scores without
CFG. Turn it lower if the output degenerates.
use_cache (`bool`, **optional**):
unconditional_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of input sequence tokens in the vocabulary for the unconditional branch. If unset, will default to
the last token of the prompt.
unconditional_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Attention mask for unconditional_ids.
use_cache (`bool`, *optional*, defaults to `True`):
Whether to cache key/values during the negative prompt forward pass.


@@ -49,7 +49,7 @@ class MaxLengthCriteria(StoppingCriteria):
Args:
max_length (`int`):
The maximum length that the output sequence can have in number of tokens.
max_position_embeddings (`int`, `optional`):
max_position_embeddings (`int`, *optional*):
The maximum model length, as defined by the model's `config.max_position_embeddings` attribute.
"""

@@ -122,7 +122,7 @@ class TFTopKLogitsWarper(TFLogitsWarper):
Args:
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
filter_value (`float`, *optional*, defaults to -inf):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.

@@ -151,7 +151,7 @@ class TFTopPLogitsWarper(TFLogitsWarper):
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
filter_value (`float`, *optional*, defaults to -inf):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.

@ -71,6 +71,8 @@ class AlignTextConfig(PretrainedConfig):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
||||
The epsilon used by the layer normalization layers.
|
||||
pad_token_id (`int`, *optional*, defaults to 0):
|
||||
Padding token id.
|
||||
position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
|
||||
Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
|
||||
positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
|
||||
@ -80,8 +82,6 @@ class AlignTextConfig(PretrainedConfig):
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
pad_token_id (`int`, *optional*, defaults to 0)
|
||||
Padding token id.
|
||||
|
||||
Example:
|
||||
|
||||
|
@ -259,7 +259,7 @@ class AltCLIPConfig(PretrainedConfig):
|
||||
Dictionary of configuration options used to initialize [`AltCLIPTextConfig`].
|
||||
vision_config (`dict`, *optional*):
|
||||
Dictionary of configuration options used to initialize [`AltCLIPVisionConfig`].
|
||||
projection_dim (`int`, *optional*, defaults to 512):
|
||||
projection_dim (`int`, *optional*, defaults to 768):
|
||||
Dimentionality of text and vision projection layers.
|
||||
logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
|
||||
The inital value of the *logit_scale* paramter. Default is used as per the original CLIP implementation.
|
||||
|
@ -30,9 +30,9 @@ class AltCLIPProcessor(ProcessorMixin):
|
||||
the [`~AltCLIPProcessor.__call__`] and [`~AltCLIPProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
image_processor ([`CLIPImageProcessor`]):
|
||||
image_processor ([`CLIPImageProcessor`], *optional*):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`XLMRobertaTokenizerFast`]):
|
||||
tokenizer ([`XLMRobertaTokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
"""
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
|
@ -51,15 +51,15 @@ class ASTConfig(PretrainedConfig):
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
|
||||
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
||||
The epsilon used by the layer normalization layers.
|
||||
patch_size (`int`, *optional*, defaults to `16`):
|
||||
patch_size (`int`, *optional*, defaults to 16):
|
||||
The size (resolution) of each patch.
|
||||
qkv_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether to add a bias to the queries, keys and values.
|
||||
|
@ -38,7 +38,7 @@ class BarkProcessor(ProcessorMixin):
|
||||
Args:
|
||||
tokenizer ([`PreTrainedTokenizer`]):
|
||||
An instance of [`PreTrainedTokenizer`].
|
||||
speaker_embeddings (`Dict[Dict[str]]`, *optional*, defaults to `None`):
|
||||
speaker_embeddings (`Dict[Dict[str]]`, *optional*):
|
||||
Optional nested speaker embeddings dictionary. The first level contains voice preset names (e.g
|
||||
`"en_speaker_4"`). The second level contains `"semantic_prompt"`, `"coarse_prompt"` and `"fine_prompt"`
|
||||
embeddings. The values correspond to the path of the corresponding `np.ndarray`. See
|
||||
|
@ -97,8 +97,6 @@ class BarthezTokenizer(PreTrainedTokenizer):
|
||||
mask_token (`str`, *optional*, defaults to `"<mask>"`):
|
||||
The token used for masking values. This is the token used when training this model with masked language
|
||||
modeling. This is the token which the model will try to predict.
|
||||
additional_special_tokens (`List[str]`, *optional*, defaults to `["<s>NOTUSED", "</s>NOTUSED"]`):
|
||||
Additional special tokens used by the tokenizer.
|
||||
sp_model_kwargs (`dict`, *optional*):
|
||||
Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
|
||||
SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
|
||||
|
@ -92,8 +92,6 @@ class BartphoTokenizer(PreTrainedTokenizer):
|
||||
mask_token (`str`, *optional*, defaults to `"<mask>"`):
|
||||
The token used for masking values. This is the token used when training this model with masked language
|
||||
modeling. This is the token which the model will try to predict.
|
||||
additional_special_tokens (`List[str]`, *optional*, defaults to `["<s>NOTUSED", "</s>NOTUSED"]`):
|
||||
Additional special tokens used by the tokenizer.
|
||||
sp_model_kwargs (`dict`, *optional*):
|
||||
Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
|
||||
SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
|
||||
|
@ -41,7 +41,7 @@ class BeitConfig(PretrainedConfig):
|
||||
[microsoft/beit-base-patch16-224-pt22k](https://huggingface.co/microsoft/beit-base-patch16-224-pt22k) architecture.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 8092):
|
||||
vocab_size (`int`, *optional*, defaults to 8192):
|
||||
Vocabulary size of the BEiT model. Defines the number of different image tokens that can be used during
|
||||
pre-training.
|
||||
hidden_size (`int`, *optional*, defaults to 768):
|
||||
|
@ -57,7 +57,7 @@ class BeitImageProcessor(BaseImageProcessor):
|
||||
size (`Dict[str, int]` *optional*, defaults to `{"height": 256, "width": 256}`):
|
||||
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
|
||||
method.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
||||
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
|
||||
`preprocess` method.
|
||||
do_center_crop (`bool`, *optional*, defaults to `True`):
|
||||
@ -67,12 +67,12 @@ class BeitImageProcessor(BaseImageProcessor):
|
||||
crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
|
||||
Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
|
||||
Can be overridden by the `crop_size` parameter in the `preprocess` method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
|
||||
parameter in the `preprocess` method.
|
||||
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
|
||||
`preprocess` method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
|
||||
parameter in the `preprocess` method.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
|
||||
method.
|
||||
|
@ -77,7 +77,7 @@ class BertweetTokenizer(PreTrainedTokenizer):
|
||||
Path to the vocabulary file.
|
||||
merges_file (`str`):
|
||||
Path to the merges file.
|
||||
normalization (`bool`, *optional*, defaults to `False`)
|
||||
normalization (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to apply a normalization preprocess.
|
||||
bos_token (`str`, *optional*, defaults to `"<s>"`):
|
||||
The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
|
||||
|
@ -60,25 +60,25 @@ class BigBirdTokenizer(PreTrainedTokenizer):
|
||||
vocab_file (`str`):
|
||||
[SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
|
||||
contains the vocabulary necessary to instantiate a tokenizer.
|
||||
eos_token (`str`, *optional*, defaults to `"</s>"`):
|
||||
The end of sequence token.
|
||||
bos_token (`str`, *optional*, defaults to `"<s>"`):
|
||||
The begin of sequence token.
|
||||
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||
token instead.
|
||||
bos_token (`str`, *optional*, defaults to `"<s>"`):
|
||||
The begin of sequence token.
|
||||
eos_token (`str`, *optional*, defaults to `"</s>"`):
|
||||
The end of sequence token.
|
||||
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
||||
The token used for padding, for example when batching sequences of different lengths.
|
||||
sep_token (`str`, *optional*, defaults to `"[SEP]"`):
|
||||
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
|
||||
sequence classification or for a text and a question for question answering. It is also used as the last
|
||||
token of a sequence built with special tokens.
|
||||
cls_token (`str`, *optional*, defaults to `"[CLS]"`):
|
||||
The classifier token which is used when doing sequence classification (classification of the whole sequence
|
||||
instead of per-token classification). It is the first token of the sequence when built with special tokens.
|
||||
mask_token (`str`, *optional*, defaults to `"[MASK]"`):
|
||||
The token used for masking values. This is the token used when training this model with masked language
|
||||
modeling. This is the token which the model will try to predict.
|
||||
cls_token (`str`, *optional*, defaults to `"[CLS]"`):
|
||||
The classifier token which is used when doing sequence classification (classification of the whole sequence
|
||||
instead of per-token classification). It is the first token of the sequence when built with special tokens.
|
||||
sp_model_kwargs (`dict`, *optional*):
|
||||
Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
|
||||
SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
|
||||
|
@ -72,12 +72,13 @@ class BioGptConfig(PretrainedConfig):
|
||||
Please refer to the paper about LayerDrop: https://arxiv.org/abs/1909.11556 for further details
|
||||
activation_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for activations inside the fully connected layer.
|
||||
pad_token_id (`int`, *optional*, defaults to 1)
|
||||
pad_token_id (`int`, *optional*, defaults to 1):
|
||||
Padding token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 0)
|
||||
bos_token_id (`int`, *optional*, defaults to 0):
|
||||
Beginning of stream token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 2)
|
||||
eos_token_id (`int`, *optional*, defaults to 2):
|
||||
End of stream token id.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
|
@ -52,7 +52,7 @@ class BitConfig(BackboneConfigMixin, PretrainedConfig):
|
||||
are supported.
|
||||
global_padding (`str`, *optional*):
|
||||
Padding strategy to use for the convolutional layers. Can be either `"valid"`, `"same"`, or `None`.
|
||||
num_groups (`int`, *optional*, defaults to `32`):
|
||||
num_groups (`int`, *optional*, defaults to 32):
|
||||
Number of groups used for the `BitGroupNormActivation` layers.
|
||||
drop_path_rate (`float`, *optional*, defaults to 0.0):
|
||||
The drop path rate for the stochastic depth.
|
||||
|
@ -85,9 +85,9 @@ class BlenderbotSmallTokenizer(PreTrainedTokenizer):
|
||||
unk_token (`str`, *optional*, defaults to `"__unk__"`):
|
||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||
token instead.
|
||||
pad_token (`str`, *optional*, defaults to `"__pad__"`):
|
||||
pad_token (`str`, *optional*, defaults to `"__null__"`):
|
||||
The token used for padding, for example when batching sequences of different lengths.
|
||||
**kwargs
|
||||
kwargs (*optional*):
|
||||
Additional keyword arguments passed along to [`PreTrainedTokenizer`]
|
||||
"""
|
||||
|
||||
|
@ -295,7 +295,7 @@ class BlipConfig(PretrainedConfig):
|
||||
Dimentionality of text and vision projection layers.
|
||||
logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
|
||||
The inital value of the *logit_scale* paramter. Default is used as per the original BLIP implementation.
|
||||
image_text_hidden_size (`int`, *optional*, defaults to 768):
|
||||
image_text_hidden_size (`int`, *optional*, defaults to 256):
|
||||
Dimentionality of the hidden state of the image-text fusion layer.
|
||||
kwargs (*optional*):
|
||||
Dictionary of keyword arguments.
|
||||
|
@ -53,7 +53,7 @@ class BlipImageProcessor(BaseImageProcessor):
|
||||
size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
|
||||
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
|
||||
method.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
||||
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
|
||||
overridden by the `resample` parameter in the `preprocess` method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
|
@ -128,14 +128,14 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
|
||||
`do_resize` parameter in the `preprocess` method.
|
||||
size (`Dict[str, int]` *optional*, defaults to `288`):
|
||||
size (`Dict[str, int]` *optional*, defaults to 288):
|
||||
Resize the shorter side of the input to `size["shortest_edge"]`. The longer side will be limited to under
|
||||
`int((1333 / 800) * size["shortest_edge"])` while preserving the aspect ratio. Only has an effect if
|
||||
`do_resize` is set to `True`. Can be overridden by the `size` parameter in the `preprocess` method.
|
||||
size_divisor (`int`, *optional*, defaults to 32):
|
||||
The size by which to make sure both the height and width can be divided. Only has an effect if `do_resize`
|
||||
is set to `True`. Can be overridden by the `size_divisor` parameter in the `preprocess` method.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
||||
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
|
||||
overridden by the `resample` parameter in the `preprocess` method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
|
@ -31,7 +31,7 @@ class BrosProcessor(ProcessorMixin):
|
||||
[`~BrosProcessor.__call__`] and [`~BrosProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
tokenizer (`BertTokenizerFast`):
|
||||
tokenizer (`BertTokenizerFast`, *optional*):
|
||||
An instance of ['BertTokenizerFast`]. The tokenizer is a required input.
|
||||
"""
|
||||
attributes = ["tokenizer"]
|
||||
|
@ -48,7 +48,7 @@ class ByT5Tokenizer(PreTrainedTokenizer):
|
||||
token instead.
|
||||
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
||||
The token used for padding, for example when batching sequences of different lengths.
|
||||
extra_ids (`int`, *optional*, defaults to 100):
|
||||
extra_ids (`int`, *optional*, defaults to 125):
|
||||
Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are
|
||||
accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are
|
||||
indexed from the end of the vocabulary up to beginning ("<extra_id_0>" is the last token in the vocabulary
|
||||
|
@ -89,7 +89,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
|
||||
mask_token (`str`, *optional*, defaults to `"<mask>"`):
|
||||
The token used for masking values. This is the token used when training this model with masked language
|
||||
modeling. This is the token which the model will try to predict.
|
||||
additional_special_tokens (`List[str]`, *optional*, defaults to `["<s>NOTUSED", "</s>NOTUSED"]`):
|
||||
additional_special_tokens (`List[str]`, *optional*, defaults to `['<s>NOTUSED', '</s>NOTUSED']`):
|
||||
Additional special tokens used by the tokenizer.
|
||||
sp_model_kwargs (`dict`, *optional*):
|
||||
Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
|
||||
|
@ -31,9 +31,9 @@ class ChineseCLIPProcessor(ProcessorMixin):
|
||||
See the [`~ChineseCLIPProcessor.__call__`] and [`~ChineseCLIPProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
image_processor ([`ChineseCLIPImageProcessor`]):
|
||||
image_processor ([`ChineseCLIPImageProcessor`], *optional*):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`BertTokenizerFast`]):
|
||||
tokenizer ([`BertTokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
"""
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
|
@ -227,7 +227,7 @@ class ClapAudioConfig(PretrainedConfig):
|
||||
projection_hidden_act (`str`, *optional*, defaults to `"relu"`):
|
||||
The non-linear activation function (function or string) in the projection layer. If string, `"gelu"`,
|
||||
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
||||
layer_norm_eps (`[type]`, *optional*, defaults to `1e-5`):
|
||||
layer_norm_eps (`[type]`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the layer normalization layers.
|
||||
initializer_factor (`float`, *optional*, defaults to 1.0):
|
||||
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
|
||||
@ -345,10 +345,10 @@ class ClapConfig(PretrainedConfig):
|
||||
Dictionary of configuration options used to initialize [`ClapTextConfig`].
|
||||
audio_config (`dict`, *optional*):
|
||||
Dictionary of configuration options used to initialize [`ClapAudioConfig`].
|
||||
logit_scale_init_value (`float`, *optional*, defaults to 14.29):
|
||||
The inital value of the *logit_scale* paramter. Default is used as per the original CLAP implementation.
|
||||
projection_dim (`int`, *optional*, defaults to 512):
|
||||
Dimentionality of text and audio projection layers.
|
||||
logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
|
||||
The inital value of the *logit_scale* paramter. Default is used as per the original CLAP implementation.
|
||||
projection_hidden_act (`str`, *optional*, defaults to `"relu"`):
|
||||
Activation function for the projection layers.
|
||||
initializer_factor (`float`, *optional*, defaults to 1.0):
|
||||
|
@ -41,32 +41,32 @@ class ClapFeatureExtractor(SequenceFeatureExtractor):
|
||||
Fourier Transform* (STFT) which should match pytorch's `torch.stft` equivalent.
|
||||
|
||||
Args:
|
||||
feature_size (`int`, defaults to 64):
|
||||
feature_size (`int`, *optional*, defaults to 64):
|
||||
The feature dimension of the extracted Mel spectrograms. This corresponds to the number of mel filters
|
||||
(`n_mels`).
|
||||
sampling_rate (`int`, defaults to 48_000):
|
||||
sampling_rate (`int`, *optional*, defaults to 48000):
|
||||
The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). This only serves
|
||||
to warn users if the audio fed to the feature extractor does not have the same sampling rate.
|
||||
hop_length (`int`, defaults to 480):
|
||||
hop_length (`int`,*optional*, defaults to 480):
|
||||
Length of the overlaping windows for the STFT used to obtain the Mel Spectrogram. The audio will be split
|
||||
in smaller `frames` with a step of `hop_length` between each frame.
|
||||
max_length_s (`int`, defaults to 10):
|
||||
max_length_s (`int`, *optional*, defaults to 10):
|
||||
The maximum input length of the model in seconds. This is used to pad the audio.
|
||||
fft_window_size (`int`, defaults to 1024):
|
||||
fft_window_size (`int`, *optional*, defaults to 1024):
|
||||
Size of the window (in samples) on which the Fourier transform is applied. This controls the frequency
|
||||
resolution of the spectrogram. 400 means that the fourrier transform is computed on windows of 400 samples.
|
||||
padding_value (`float`, *optional*, defaults to 0.0):
|
||||
Padding value used to pad the audio. Should correspond to silences.
|
||||
return_attention_mask (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the model should return the attention masks coresponding to the input.
|
||||
frequency_min (`float`, *optional*, default to 0):
|
||||
frequency_min (`float`, *optional*, defaults to 0):
|
||||
The lowest frequency of interest. The STFT will not be computed for values below this.
|
||||
frequency_max (`float`, *optional*, default to 14_000):
|
||||
frequency_max (`float`, *optional*, defaults to 14000):
|
||||
The highest frequency of interest. The STFT will not be computed for values above this.
|
||||
top_db (`float`, *optional*):
|
||||
The highest decibel value used to convert the mel spectrogram to the log scale. For more details see the
|
||||
`audio_utils.power_to_db` function
|
||||
truncation (`str`, *optional*, default to `"fusions"`):
|
||||
truncation (`str`, *optional*, defaults to `"fusion"`):
|
||||
Truncation pattern for long audio inputs. Two patterns are available:
|
||||
- `fusion` will use `_random_mel_fusion`, which stacks 3 random crops from the mel spectrogram and a
|
||||
downsampled version of the entire mel spectrogram.
|
||||
|
@ -30,9 +30,9 @@ class CLIPProcessor(ProcessorMixin):
|
||||
[`~CLIPProcessor.__call__`] and [`~CLIPProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
image_processor ([`CLIPImageProcessor`]):
|
||||
image_processor ([`CLIPImageProcessor`], *optional*):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`CLIPTokenizerFast`]):
|
||||
tokenizer ([`CLIPTokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
"""
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
|
@ -255,7 +255,7 @@ class CLIPSegConfig(PretrainedConfig):
|
||||
Dimensionality of text and vision projection layers.
|
||||
logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
|
||||
The inital value of the *logit_scale* paramter. Default is used as per the original CLIPSeg implementation.
|
||||
extract_layers (`List[int]`, *optional*, defaults to [3, 6, 9]):
|
||||
extract_layers (`List[int]`, *optional*, defaults to `[3, 6, 9]`):
|
||||
Layers to extract when forwarding the query image through the frozen visual backbone of CLIP.
|
||||
reduce_dim (`int`, *optional*, defaults to 64):
|
||||
Dimensionality to reduce the CLIP vision embedding.
|
||||
|
@ -30,9 +30,9 @@ class CLIPSegProcessor(ProcessorMixin):
|
||||
[`~CLIPSegProcessor.__call__`] and [`~CLIPSegProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
image_processor ([`ViTImageProcessor`]):
|
||||
image_processor ([`ViTImageProcessor`], *optional*):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`CLIPTokenizerFast`]):
|
||||
tokenizer ([`CLIPTokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
"""
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
|
@ -64,7 +64,7 @@ class ConvNextImageProcessor(BaseImageProcessor):
|
||||
crop_pct (`float` *optional*, defaults to 224 / 256):
|
||||
Percentage of the image to crop. Only has an effect if `do_resize` is `True` and size < 384. Can be
|
||||
overriden by `crop_pct` in the `preprocess` method.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
||||
Resampling filter to use if resizing the image. Can be overriden by `resample` in the `preprocess` method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overriden by `do_rescale` in
|
||||
|
@ -50,15 +50,17 @@ class CpmAntConfig(PretrainedConfig):
|
||||
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 48):
|
||||
Number of layers of the Transformer encoder.
|
||||
dropout_p (`float`, *optional*, defaults to 0.1):
|
||||
dropout_p (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probabilitiy for all fully connected layers in the embeddings, encoder.
|
||||
position_bias_num_buckets (`int`, *optional*, defaults to 512):
|
||||
The number of position_bias buckets.
|
||||
position_bias_max_distance (`int`, *optional*, defaults to 2048):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
eps (`float`, *optional*, defaults to 1e-6):
|
||||
eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the layer normalization layers.
|
||||
init_std (`float`, *optional*, defaults to 1.0):
|
||||
Initialize parameters with std = init_std.
|
||||
prompt_types (`int`, *optional*, defaults to 32):
|
||||
The type of prompt.
|
||||
prompt_length (`int`, *optional*, defaults to 32):
|
||||
@ -67,8 +69,6 @@ class CpmAntConfig(PretrainedConfig):
|
||||
The type of segment.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use cache.
|
||||
init_std (`float`, *optional*, defaults to 1.0):
|
||||
Initialize parameters with std = init_std.
|
||||
|
||||
Example:
|
||||
|
||||
|
@ -54,7 +54,7 @@ class CTRLConfig(PretrainedConfig):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
embd_pdrop (`int`, *optional*, defaults to 0.1):
|
||||
The dropout ratio for the embeddings.
|
||||
layer_norm_epsilon (`float`, *optional*, defaults to 1e-6):
|
||||
layer_norm_epsilon (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon to use in the layer normalization layers
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
|
@ -99,9 +99,9 @@ class DebertaTokenizerFast(PreTrainedTokenizerFast):
|
||||
refer to this superclass for more information regarding those methods.
|
||||
|
||||
Args:
|
||||
vocab_file (`str`):
|
||||
vocab_file (`str`, *optional*):
|
||||
Path to the vocabulary file.
|
||||
merges_file (`str`):
|
||||
merges_file (`str`, *optional*):
|
||||
Path to the merges file.
|
||||
tokenizer_file (`str`, *optional*):
|
||||
The path to a tokenizer file to use instead of the vocab file.
|
||||
|
@ -58,23 +58,23 @@ class DeiTConfig(PretrainedConfig):
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
|
||||
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
||||
The epsilon used by the layer normalization layers.
|
||||
image_size (`int`, *optional*, defaults to `224`):
|
||||
image_size (`int`, *optional*, defaults to 224):
|
||||
The size (resolution) of each image.
|
||||
patch_size (`int`, *optional*, defaults to `16`):
|
||||
patch_size (`int`, *optional*, defaults to 16):
|
||||
The size (resolution) of each patch.
|
||||
num_channels (`int`, *optional*, defaults to `3`):
|
||||
num_channels (`int`, *optional*, defaults to 3):
|
||||
The number of input channels.
|
||||
qkv_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether to add a bias to the queries, keys and values.
|
||||
encoder_stride (`int`, `optional`, defaults to 16):
|
||||
encoder_stride (`int`, *optional*, defaults to 16):
|
||||
Factor to increase the spatial resolution by in the decoder head for masked image modeling.
|
||||
|
||||
Example:
|
||||
|
@ -52,19 +52,19 @@ class DeiTImageProcessor(BaseImageProcessor):
|
||||
`do_resize` in `preprocess`.
|
||||
size (`Dict[str, int]` *optional*, defaults to `{"height": 256, "width": 256}`):
|
||||
Size of the image after `resize`. Can be overridden by `size` in `preprocess`.
|
||||
resample (`PILImageResampling` filter, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||
resample (`PILImageResampling` filter, *optional*, defaults to `Resampling.BICUBIC`):
|
||||
Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
|
||||
do_center_crop (`bool`, *optional*, defaults to `True`):
|
||||
Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image
|
||||
is padded with 0's and then center cropped. Can be overridden by `do_center_crop` in `preprocess`.
|
||||
crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
|
||||
Desired output size when applying center-cropping. Can be overridden by `crop_size` in `preprocess`.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
|
||||
parameter in the `preprocess` method.
|
||||
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
|
||||
`preprocess` method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
|
||||
parameter in the `preprocess` method.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
|
||||
method.
|
||||
|
@ -53,7 +53,7 @@ class MCTCTConfig(PretrainedConfig):
|
||||
Dimensions of each attention head for each attention layer in the Transformer encoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 920):
|
||||
The maximum sequence length that this model might ever be used with (after log-mel spectrogram extraction).
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the layer normalization layers.
|
||||
layerdrop (`float`, *optional*, defaults to 0.3):
|
||||
The probability of dropping an encoder layer during training. The default 0.3 value is used in the original
|
||||
@ -63,9 +63,9 @@ class MCTCTConfig(PretrainedConfig):
|
||||
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.3):
|
||||
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
|
||||
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.3):
|
||||
The dropout ratio for the attention probabilities.
|
||||
pad_token_id (`int`, *optional*, defaults to 1):
|
||||
The tokenizer index of the pad token.
|
||||
@ -80,17 +80,17 @@ class MCTCTConfig(PretrainedConfig):
|
||||
The probability of randomly dropping the `Conv1dSubsampler` layer during training.
|
||||
num_conv_layers (`int`, *optional*, defaults to 1):
|
||||
Number of convolution layers before applying transformer encoder layers.
|
||||
conv_kernel (`List[int]`, *optional*, defaults to `[7]`):
|
||||
conv_kernel (`Sequence[int]`, *optional*, defaults to `(7,)`):
|
||||
The kernel size of the 1D convolution applied before transformer layers. `len(conv_kernel)` must be equal
|
||||
to `num_conv_layers`.
|
||||
conv_stride (`List[int]`, *optional*, defaults to `[3]`):
|
||||
conv_stride (`Sequence[int]`, *optional*, defaults to `(3,)`):
|
||||
The stride length of the 1D convolution applied before transformer layers. `len(conv_stride)` must be equal
|
||||
to `num_conv_layers`.
|
||||
input_feat_per_channel (`int`, *optional*, defaults to 80):
|
||||
Feature dimensions of the channels of the input to the Conv1D layer.
|
||||
input_channels (`int`, *optional*, defaults to 1):
|
||||
Number of input channels of the input to the Conv1D layer.
|
||||
conv_channels (`List[int]`, *optional*, defaults to None):
|
||||
conv_channels (`List[int]`, *optional*):
|
||||
Channel sizes of intermediate Conv1D layers.
|
||||
ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
|
||||
Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
|
||||
|
@ -57,9 +57,9 @@ class VanConfig(PretrainedConfig):
|
||||
`"selu"` and `"gelu_new"` are supported.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the layer normalization layers.
|
||||
layer_scale_init_value (`float`, *optional*, defaults to 1e-2):
|
||||
layer_scale_init_value (`float`, *optional*, defaults to 0.01):
|
||||
The initial value for layer scaling.
|
||||
drop_path_rate (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability for stochastic depth.
|
||||
|
@ -44,9 +44,9 @@ class DinatConfig(BackboneConfigMixin, PretrainedConfig):
|
||||
The number of input channels.
|
||||
embed_dim (`int`, *optional*, defaults to 64):
|
||||
Dimensionality of patch embedding.
|
||||
depths (`List[int]`, *optional*, defaults to `[2, 2, 6, 2]`):
|
||||
depths (`List[int]`, *optional*, defaults to `[3, 4, 6, 5]`):
|
||||
Number of layers in each level of the encoder.
|
||||
num_heads (`List[int]`, *optional*, defaults to `[3, 6, 12, 24]`):
|
||||
num_heads (`List[int]`, *optional*, defaults to `[2, 4, 8, 16]`):
|
||||
Number of attention heads in each layer of the Transformer encoder.
|
||||
kernel_size (`int`, *optional*, defaults to 7):
|
||||
Neighborhood Attention kernel size.
|
||||
@ -67,7 +67,7 @@ class DinatConfig(BackboneConfigMixin, PretrainedConfig):
|
||||
`"selu"` and `"gelu_new"` are supported.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the layer normalization layers.
|
||||
layer_scale_init_value (`float`, *optional*, defaults to 0.0):
|
||||
The initial value for the layer scale. Disabled if <=0.
|
||||
|
@ -60,7 +60,7 @@ class Dinov2Config(BackboneConfigMixin, PretrainedConfig):
|
||||
The dropout ratio for the attention probabilities.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the layer normalization layers.
|
||||
image_size (`int`, *optional*, defaults to 224):
|
||||
The size (resolution) of each image.
|
||||
|
@ -45,15 +45,15 @@ class DonutSwinConfig(PretrainedConfig):
|
||||
The number of input channels.
|
||||
embed_dim (`int`, *optional*, defaults to 96):
|
||||
Dimensionality of patch embedding.
|
||||
depths (`list(int)`, *optional*, defaults to [2, 2, 6, 2]):
|
||||
depths (`list(int)`, *optional*, defaults to `[2, 2, 6, 2]`):
|
||||
Depth of each layer in the Transformer encoder.
|
||||
num_heads (`list(int)`, *optional*, defaults to [3, 6, 12, 24]):
|
||||
num_heads (`list(int)`, *optional*, defaults to `[3, 6, 12, 24]`):
|
||||
Number of attention heads in each layer of the Transformer encoder.
|
||||
window_size (`int`, *optional*, defaults to 7):
|
||||
Size of windows.
|
||||
mlp_ratio (`float`, *optional*, defaults to 4.0):
|
||||
Ratio of MLP hidden dimensionality to embedding dimensionality.
|
||||
qkv_bias (`bool`, *optional*, defaults to True):
|
||||
qkv_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not a learnable bias should be added to the queries, keys and values.
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability for all fully connected layers in the embeddings and encoder.
|
||||
@ -64,11 +64,11 @@ class DonutSwinConfig(PretrainedConfig):
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
|
||||
`"selu"` and `"gelu_new"` are supported.
|
||||
use_absolute_embeddings (`bool`, *optional*, defaults to False):
|
||||
use_absolute_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to add absolute position embeddings to the patch embeddings.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the layer normalization layers.
|
||||
|
||||
Example:
|
||||
|
@ -32,9 +32,9 @@ class DonutProcessor(ProcessorMixin):
|
||||
[`~DonutProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
image_processor ([`DonutImageProcessor`]):
|
||||
image_processor ([`DonutImageProcessor`], *optional*):
|
||||
An instance of [`DonutImageProcessor`]. The image processor is a required input.
|
||||
tokenizer ([`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]):
|
||||
tokenizer ([`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`], *optional*):
|
||||
An instance of [`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]. The tokenizer is a required input.
|
||||
"""
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
|
@ -52,9 +52,9 @@ class DPTConfig(PretrainedConfig):
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
|
||||
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
@ -66,6 +66,8 @@ class DPTConfig(PretrainedConfig):
|
||||
The size (resolution) of each patch.
|
||||
num_channels (`int`, *optional*, defaults to 3):
|
||||
The number of input channels.
|
||||
is_hybrid (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a hybrid backbone. Useful in the context of loading DPT-Hybrid models.
|
||||
qkv_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether to add a bias to the queries, keys and values.
|
||||
backbone_out_indices (`List[int]`, *optional*, defaults to `[2, 5, 8, 11]`):
|
||||
@ -79,11 +81,9 @@ class DPTConfig(PretrainedConfig):
|
||||
- "project" passes information to the other tokens by concatenating the readout to all other tokens before
|
||||
projecting the
|
||||
representation to the original feature dimension D using a linear layer followed by a GELU non-linearity.
|
||||
is_hybrid (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a hybrid backbone. Useful in the context of loading DPT-Hybrid models.
|
||||
reassemble_factors (`List[int]`, *optional*, defaults to `[4, 2, 1, 0.5]`):
|
||||
The up/downsampling factors of the reassemble layers.
|
||||
neck_hidden_sizes (`List[str]`, *optional*, defaults to [96, 192, 384, 768]):
|
||||
neck_hidden_sizes (`List[str]`, *optional*, defaults to `[96, 192, 384, 768]`):
|
||||
The hidden sizes to project to for the feature maps of the backbone.
|
||||
fusion_hidden_size (`int`, *optional*, defaults to 256):
|
||||
The number of channels before fusion.
|
||||
|
@ -100,14 +100,14 @@ class DPTImageProcessor(BaseImageProcessor):
|
||||
Whether to resize the image's (height, width) dimensions. Can be overidden by `do_resize` in `preprocess`.
|
||||
size (`Dict[str, int]` *optional*, defaults to `{"height": 384, "width": 384}`):
|
||||
Size of the image after resizing. Can be overidden by `size` in `preprocess`.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
||||
Defines the resampling filter to use if resizing the image. Can be overidden by `resample` in `preprocess`.
|
||||
keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. Can
|
||||
be overidden by `keep_aspect_ratio` in `preprocess`.
|
||||
ensure_multiple_of (`int`, *optional*, defaults to 1):
|
||||
If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Can be overidden
|
||||
by `ensure_multiple_of` in `preprocess`.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
||||
Defines the resampling filter to use if resizing the image. Can be overidden by `resample` in `preprocess`.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overidden by `do_rescale` in
|
||||
`preprocess`.
|
||||
|
@ -52,22 +52,22 @@ class EfficientNetImageProcessor(BaseImageProcessor):
|
||||
`do_resize` in `preprocess`.
|
||||
size (`Dict[str, int]` *optional*, defaults to `{"height": 346, "width": 346}`):
|
||||
Size of the image after `resize`. Can be overridden by `size` in `preprocess`.
|
||||
resample (`PILImageResampling` filter, *optional*, defaults to `PILImageResampling.NEAREST`):
|
||||
resample (`PILImageResampling` filter, *optional*, defaults to 0):
|
||||
Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
|
||||
do_center_crop (`bool`, *optional*, defaults to `False`):
|
||||
Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image
|
||||
is padded with 0's and then center cropped. Can be overridden by `do_center_crop` in `preprocess`.
|
||||
crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 289, "width": 289}`):
|
||||
Desired output size when applying center-cropping. Can be overridden by `crop_size` in `preprocess`.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
|
||||
parameter in the `preprocess` method.
|
||||
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
|
||||
`preprocess` method.
|
||||
rescale_offset (`bool`, *optional*, defaults to `False`):
|
||||
Whether to rescale the image between [-scale_range, scale_range] instead of [0, scale_range]. Can be
|
||||
overridden by the `rescale_factor` parameter in the `preprocess` method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
|
||||
parameter in the `preprocess` method.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
|
||||
method.
|
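The surprising "defaults to 0" for `resample` in the EfficientNet hunk above is what you get when an enum default is evaluated to its underlying value: with a recent Pillow, `PILImageResampling` is effectively `PIL.Image.Resampling`, an `IntEnum` in which NEAREST is 0. A quick sanity check, assuming Pillow >= 9.1:

from PIL import Image

# Resampling is an IntEnum, so the documented integer is the enum member's value.
print(int(Image.Resampling.NEAREST))   # 0
print(int(Image.Resampling.BILINEAR))  # 2
print(int(Image.Resampling.BICUBIC))   # 3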
@ -46,13 +46,13 @@ class FalconConfig(PretrainedConfig):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 71):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the layer normalization layers.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether the model should return the last key/values attentions (not used by all models). Only relevant if
|
||||
`config.is_decoder=True`.
|
||||
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
|
||||
The epsilon used by the layer normalization layers.
|
||||
hidden_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability for MLP layers.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
|
@ -207,7 +207,7 @@ class FlaubertTokenizer(PreTrainedTokenizer):
|
||||
mask_token (`str`, *optional*, defaults to `"<special1>"`):
|
||||
The token used for masking values. This is the token used when training this model with masked language
|
||||
modeling. This is the token which the model will try to predict.
|
||||
additional_special_tokens (`List[str]`, *optional*, defaults to `["<special0>","<special1>","<special2>","<special3>","<special4>","<special5>","<special6>","<special7>","<special8>","<special9>"]`):
|
||||
additional_special_tokens (`List[str]`, *optional*, defaults to `['<special0>', '<special1>', '<special2>', '<special3>', '<special4>', '<special5>', '<special6>', '<special7>', '<special8>', '<special9>']`):
|
||||
List of additional special tokens.
|
||||
lang2id (`Dict[str, int]`, *optional*):
|
||||
Dictionary mapping languages string identifiers to their IDs.
|
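The re-quoting of `additional_special_tokens` in the Flaubert hunk (double quotes becoming single quotes) is consistent with the default being regenerated from the live Python object, since `repr()` of a list of strings uses single quotes:

tokens = [f"<special{i}>" for i in range(10)]
# Prints ['<special0>', '<special1>', ...], i.e. the single-quoted form now used
# in the docstring.
print(repr(tokens))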
@ -52,9 +52,9 @@ class FlavaImageConfig(PretrainedConfig):
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
|
||||
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
@ -291,7 +291,7 @@ class FlavaMultimodalConfig(PretrainedConfig):
|
||||
Args:
|
||||
hidden_size (`int`, *optional*, defaults to 768):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 12):
|
||||
num_hidden_layers (`int`, *optional*, defaults to 6):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
@ -300,9 +300,9 @@ class FlavaMultimodalConfig(PretrainedConfig):
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
|
||||
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
|
@ -33,8 +33,8 @@ class FlavaProcessor(ProcessorMixin):
|
||||
[`~FlavaProcessor.__call__`] and [`~FlavaProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
image_processor ([`FlavaImageProcessor`]): The image processor is a required input.
|
||||
tokenizer ([`BertTokenizerFast`]): The tokenizer is a required input.
|
||||
image_processor ([`FlavaImageProcessor`], *optional*): The image processor is a required input.
|
||||
tokenizer ([`BertTokenizerFast`], *optional*): The tokenizer is a required input.
|
||||
"""
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
image_processor_class = "FlavaImageProcessor"
|
||||
|
@ -67,7 +67,7 @@ class FocalNetConfig(BackboneConfigMixin, PretrainedConfig):
|
||||
Stochastic depth rate.
|
||||
use_layerscale (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use layer scale in the encoder.
|
||||
layerscale_value (`float`, *optional*, defaults to 1e-4):
|
||||
layerscale_value (`float`, *optional*, defaults to 0.0001):
|
||||
The initial value of the layer scale.
|
||||
use_post_layernorm (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use post layer normalization in the encoder.
|
||||
@ -77,9 +77,9 @@ class FocalNetConfig(BackboneConfigMixin, PretrainedConfig):
|
||||
Whether to normalize the modulator.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the layer normalization layers.
|
||||
encoder_stride (`int`, `optional`, defaults to 32):
|
||||
encoder_stride (`int`, *optional*, defaults to 32):
|
||||
Factor to increase the spatial resolution by in the decoder head for masked image modeling.
|
||||
out_features (`List[str]`, *optional*):
|
||||
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
|
||||
|
@ -146,13 +146,13 @@ class FSMTTokenizer(PreTrainedTokenizer):
|
||||
this superclass for more information regarding those methods.
|
||||
|
||||
Args:
|
||||
langs (`List[str]`):
|
||||
langs (`List[str]`, *optional*):
|
||||
A list of two languages to translate from and to, for instance `["en", "ru"]`.
|
||||
src_vocab_file (`str`):
|
||||
src_vocab_file (`str`, *optional*):
|
||||
File containing the vocabulary for the source language.
|
||||
tgt_vocab_file (`st`):
|
||||
tgt_vocab_file (`st`, *optional*):
|
||||
File containing the vocabulary for the target language.
|
||||
merges_file (`str`):
|
||||
merges_file (`str`, *optional*):
|
||||
File containing the merges.
|
||||
do_lower_case (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to lowercase the input when tokenizing.
|
||||
|
@ -81,7 +81,7 @@ class FunnelConfig(PretrainedConfig):
|
||||
The standard deviation of the *normal initializer* for initializing the embedding matrix and the weight of
|
||||
linear layers. Will default to 1 for the embedding matrix and the value given by Xavier initialization for
|
||||
linear layers.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-9):
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-09):
|
||||
The epsilon used by the layer normalization layers.
|
||||
pooling_type (`str`, *optional*, defaults to `"mean"`):
|
||||
Possible values are `"mean"` or `"max"`. The way pooling is performed at the beginning of each block.
|
||||
@ -90,10 +90,10 @@ class FunnelConfig(PretrainedConfig):
|
||||
is faster on TPU.
|
||||
separate_cls (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to separate the cls token when applying pooling.
|
||||
truncate_seq (`bool`, *optional*, defaults to `False`):
|
||||
truncate_seq (`bool`, *optional*, defaults to `True`):
|
||||
When using `separate_cls`, whether or not to truncate the last token when pooling, to avoid getting a
|
||||
sequence length that is not a multiple of 2.
|
||||
pool_q_only (`bool`, *optional*, defaults to `False`):
|
||||
pool_q_only (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to apply the pooling only to the query or to query, key and values for the attention layers.
|
||||
"""
|
||||
model_type = "funnel"
|
||||
|
@ -120,9 +120,9 @@ class FunnelTokenizer(PreTrainedTokenizer):
|
||||
mask_token (`str`, *optional*, defaults to `"<mask>"`):
|
||||
The token used for masking values. This is the token used when training this model with masked language
|
||||
modeling. This is the token which the model will try to predict.
|
||||
bos_token (`str`, `optional`, defaults to `"<s>"`):
|
||||
bos_token (`str`, *optional*, defaults to `"<s>"`):
|
||||
The beginning of sentence token.
|
||||
eos_token (`str`, `optional`, defaults to `"</s>"`):
|
||||
eos_token (`str`, *optional*, defaults to `"</s>"`):
|
||||
The end of sentence token.
|
||||
tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to tokenize Chinese characters.
|
||||
|
@ -51,7 +51,7 @@ class GLPNConfig(PretrainedConfig):
|
||||
Patch size before each encoder block.
|
||||
strides (`List[int]`, *optional*, defaults to `[4, 2, 2, 2]`):
|
||||
Stride before each encoder block.
|
||||
num_attention_heads (`List[int]`, *optional*, defaults to `[1, 2, 4, 8]`):
|
||||
num_attention_heads (`List[int]`, *optional*, defaults to `[1, 2, 5, 8]`):
|
||||
Number of attention heads for each attention layer in each block of the Transformer encoder.
|
||||
mlp_ratios (`List[int]`, *optional*, defaults to `[4, 4, 4, 4]`):
|
||||
Ratio of the size of the hidden layer compared to the size of the input layer of the Mix FFNs in the
|
||||
@ -67,9 +67,9 @@ class GLPNConfig(PretrainedConfig):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
drop_path_rate (`float`, *optional*, defaults to 0.1):
|
||||
The dropout probability for stochastic depth, used in the blocks of the Transformer encoder.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the layer normalization layers.
|
||||
decoder_hidden_size (`int`, *optional*, defaults to 32):
|
||||
decoder_hidden_size (`int`, *optional*, defaults to 64):
|
||||
The dimension of the decoder.
|
||||
max_depth (`int`, *optional*, defaults to 10):
|
||||
The maximum depth of the decoder.
|
||||
|
@ -48,7 +48,7 @@ class GLPNImageProcessor(BaseImageProcessor):
|
||||
size_divisor (`int`, *optional*, defaults to 32):
|
||||
When `do_resize` is `True`, images are resized so their height and width are rounded down to the closest
|
||||
multiple of `size_divisor`. Can be overridden by `size_divisor` in `preprocess`.
|
||||
resample (`PIL.Image` resampling filter, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
||||
resample (`PIL.Image` resampling filter, *optional*, defaults to `Resampling.BILINEAR`):
|
||||
Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Can be
|
||||
|
@ -54,7 +54,7 @@ class GPTNeoConfig(PretrainedConfig):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
num_layers (`int`, *optional*, defaults to 24):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
attention_types (`List`, *optional*, defaults to `[[["global", "local"], 12]]`):
|
||||
attention_types (`List`, *optional*, defaults to `[[['global', 'local'], 12]]`):
|
||||
The type of attention for each layer in a `List` of the following format `[[["attention_type"],
|
||||
num_layerss]]` e.g. for a 24 layer model `[[["global"], 24]]` or `[[["global", "local"], 12]]` Choose the
|
||||
value of `attention_type` from `["global", "local"]`
|
||||
@ -76,7 +76,7 @@ class GPTNeoConfig(PretrainedConfig):
|
||||
classifier_dropout (`float`, *optional*, defaults to 0.1):
|
||||
Argument used when doing token classification, used in the model [`GPTNeoForTokenClassification`]. The
|
||||
dropout ratio for the hidden layer.
|
||||
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
|
||||
layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the layer normalization layers.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
|
@ -64,17 +64,17 @@ class GPTSw3Tokenizer(PreTrainedTokenizer):
|
||||
Whether or not to strip the text when tokenizing (removing excess spaces before and after the string).
|
||||
keep_accents (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to keep accents when tokenizing.
|
||||
bos_token (`str`, *optional*):
|
||||
The beginning of sequence token that can be used for downstream task, was not seen during pretraining. If
|
||||
not provided, will default to '<s>' or '<|endoftext|>', depending on model size.
|
||||
eos_token (`str`, *optional*):
|
||||
The end of sequence token seen during pretraining. If not provided, will default to '<|endoftext|>'
|
||||
unk_token (`str`, *optional*):
|
||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||
token instead. If not provided, will default to '<unk>'.
|
||||
pad_token (`str`, *optional*):
|
||||
The token used for padding, for example when batching sequences of different lengths. If not provided, will
|
||||
default to '<pad>' or '<unk>' depending on model size.
|
||||
unk_token (`str`, *optional*):
|
||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||
token instead. If not provided, will default to '<unk>'.
|
||||
eos_token (`str`, *optional*):
|
||||
The end of sequence token seen during pretraining. If not provided, will default to '<|endoftext|>'
|
||||
bos_token (`str`, *optional*):
|
||||
The beginning of sequence token that can be used for downstream task, was not seen during pretraining. If
|
||||
not provided, will default to '<s>' or '<|endoftext|>', depending on model size.
|
||||
sp_model_kwargs (`dict`, *optional*):
|
||||
Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
|
||||
SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
|
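The GPTSw3Tokenizer hunk only reorders the special-token entries so the `Args:` section follows the constructor's parameter order. A plausible way to derive that expected order with `inspect`; this is an illustrative helper, not the actual implementation in `utils/check_docstrings.py`:

import inspect

def expected_arg_order(cls):
    # The documented order can be read straight off the signature; the docstring
    # entries are then expected to appear in this sequence.
    sig = inspect.signature(cls.__init__)
    return [
        name
        for name, param in sig.parameters.items()
        if name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
    ]

Whatever order the constructor declares is the order the descriptions get moved into.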
@ -139,7 +139,7 @@ class GPTSanJapaneseTokenizer(PreTrainedTokenizer):
|
||||
The token used for unknown charactor
|
||||
pad_token (`str`, *optional*, defaults to `"<|separator|>"`):
|
||||
The token used for padding
|
||||
bos_token (`str`, *optional*, defaults to `"<|startoftext|>""`):
|
||||
bos_token (`str`, *optional*, defaults to `"<|startoftext|>"`):
|
||||
The beginning of sequence token.
|
||||
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
||||
The end of sequence token.
|
||||
|
@ -53,10 +53,8 @@ class IdeficsImageProcessor(BaseImageProcessor):
|
||||
Constructs a Idefics image processor.
|
||||
|
||||
Args:
|
||||
image_size (`int`, *optional*, defaults to `224`):
|
||||
image_size (`int`, *optional*, defaults to 224):
|
||||
Resize to image size
|
||||
image_num_channels (`int`, *optional*, defaults to `3`):
|
||||
Number of image channels.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_MEAN`):
|
||||
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be
|
||||
@ -65,6 +63,8 @@ class IdeficsImageProcessor(BaseImageProcessor):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
image_num_channels (`int`, *optional*, defaults to 3):
|
||||
Number of image channels.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values"]
|
||||
|
@ -70,7 +70,7 @@ class ImageGPTImageProcessor(BaseImageProcessor):
|
||||
`do_resize` in `preprocess`.
|
||||
size (`Dict[str, int]` *optional*, defaults to `{"height": 256, "width": 256}`):
|
||||
Size of the image after resizing. Can be overridden by `size` in `preprocess`.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
||||
Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image pixel value to between [-1, 1]. Can be overridden by `do_normalize` in
|
||||
|
@ -57,7 +57,7 @@ class InstructBlipVisionConfig(PretrainedConfig):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported. to 1e-5): The epsilon used by the layer
|
||||
normalization layers.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the layer normalization layers.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
|
@ -83,8 +83,6 @@ class LayoutLMConfig(PretrainedConfig):
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
classifier_dropout (`float`, *optional*):
|
||||
The dropout ratio for the classification head.
|
||||
max_2d_position_embeddings (`int`, *optional*, defaults to 1024):
|
||||
The maximum value that the 2D position embedding might ever used. Typically set this to something large
|
||||
just in case (e.g., 1024).
|
||||
|
@ -100,7 +100,7 @@ class LayoutLMv2ImageProcessor(BaseImageProcessor):
|
||||
overridden by `do_resize` in `preprocess`.
|
||||
size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
|
||||
Size of the image after resizing. Can be overridden by `size` in `preprocess`.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
||||
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
|
||||
`preprocess` method.
|
||||
apply_ocr (`bool`, *optional*, defaults to `True`):
|
||||
@ -109,7 +109,7 @@ class LayoutLMv2ImageProcessor(BaseImageProcessor):
|
||||
ocr_lang (`str`, *optional*):
|
||||
The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
|
||||
used. Can be overridden by `ocr_lang` in `preprocess`.
|
||||
tesseract_config (`str`, *optional*):
|
||||
tesseract_config (`str`, *optional*, defaults to `""`):
|
||||
Any additional custom configuration flags that are forwarded to the `config` parameter when calling
|
||||
Tesseract. For example: '--psm 6'. Can be overridden by `tesseract_config` in `preprocess`.
|
||||
"""
|
||||
|
@ -38,9 +38,9 @@ class LayoutLMv2Processor(ProcessorMixin):
|
||||
into token-level `labels` for token classification tasks (such as FUNSD, CORD).
|
||||
|
||||
Args:
|
||||
image_processor (`LayoutLMv2ImageProcessor`):
|
||||
image_processor (`LayoutLMv2ImageProcessor`, *optional*):
|
||||
An instance of [`LayoutLMv2ImageProcessor`]. The image processor is a required input.
|
||||
tokenizer (`LayoutLMv2Tokenizer` or `LayoutLMv2TokenizerFast`):
|
||||
tokenizer (`LayoutLMv2Tokenizer` or `LayoutLMv2TokenizerFast`, *optional*):
|
||||
An instance of [`LayoutLMv2Tokenizer`] or [`LayoutLMv2TokenizerFast`]. The tokenizer is a required input.
|
||||
"""
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
|
@ -38,9 +38,9 @@ class LayoutLMv3Processor(ProcessorMixin):
|
||||
into token-level `labels` for token classification tasks (such as FUNSD, CORD).
|
||||
|
||||
Args:
|
||||
image_processor (`LayoutLMv3ImageProcessor`):
|
||||
image_processor (`LayoutLMv3ImageProcessor`, *optional*):
|
||||
An instance of [`LayoutLMv3ImageProcessor`]. The image processor is a required input.
|
||||
tokenizer (`LayoutLMv3Tokenizer` or `LayoutLMv3TokenizerFast`):
|
||||
tokenizer (`LayoutLMv3Tokenizer` or `LayoutLMv3TokenizerFast`, *optional*):
|
||||
An instance of [`LayoutLMv3Tokenizer`] or [`LayoutLMv3TokenizerFast`]. The tokenizer is a required input.
|
||||
"""
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
|
@ -253,7 +253,7 @@ class LayoutLMv3Tokenizer(PreTrainedTokenizer):
|
||||
mask_token (`str`, *optional*, defaults to `"<mask>"`):
|
||||
The token used for masking values. This is the token used when training this model with masked language
|
||||
modeling. This is the token which the model will try to predict.
|
||||
add_prefix_space (`bool`, *optional*, defaults to `False`):
|
||||
add_prefix_space (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to add an initial space to the input. This allows to treat the leading word just as any
|
||||
other word. (RoBERTa tokenizer detect beginning of words by the preceding space).
|
||||
cls_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
|
||||
|
@ -37,9 +37,9 @@ class LayoutXLMProcessor(ProcessorMixin):
|
||||
into token-level `labels` for token classification tasks (such as FUNSD, CORD).
|
||||
|
||||
Args:
|
||||
image_processor (`LayoutLMv2ImageProcessor`):
|
||||
image_processor (`LayoutLMv2ImageProcessor`, *optional*):
|
||||
An instance of [`LayoutLMv2ImageProcessor`]. The image processor is a required input.
|
||||
tokenizer (`LayoutXLMTokenizer` or `LayoutXLMTokenizerFast`):
|
||||
tokenizer (`LayoutXLMTokenizer` or `LayoutXLMTokenizerFast`, *optional*):
|
||||
An instance of [`LayoutXLMTokenizer`] or [`LayoutXLMTokenizerFast`]. The tokenizer is a required input.
|
||||
"""
|
||||
|
||||
|
@ -203,8 +203,6 @@ class LayoutXLMTokenizer(PreTrainedTokenizer):
|
||||
CrossEntropyLoss.
|
||||
only_label_first_subword (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to only label the first subword, in case word labels are provided.
|
||||
additional_special_tokens (`List[str]`, *optional*, defaults to `["<s>NOTUSED", "</s>NOTUSED"]`):
|
||||
Additional special tokens used by the tokenizer.
|
||||
sp_model_kwargs (`dict`, *optional*):
|
||||
Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
|
||||
SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
|
||||
|
@ -56,7 +56,7 @@ class LevitImageProcessor(BaseImageProcessor):
|
||||
edge value `c` is rescaled to `int(c * (256/224))`. The smaller edge of the image will be matched to this
|
||||
value i.e, if height > width, then image will be rescaled to `(size["shortest_egde"] * height / width,
|
||||
size["shortest_egde"])`. Can be overridden by the `size` parameter in the `preprocess` method.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
||||
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
|
||||
`preprocess` method.
|
||||
do_center_crop (`bool`, *optional*, defaults to `True`):
|
||||
@ -74,10 +74,10 @@ class LevitImageProcessor(BaseImageProcessor):
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the
|
||||
`preprocess` method.
|
||||
image_mean (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
|
||||
image_mean (`List[int]`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
|
||||
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
||||
image_std (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
|
||||
image_std (`List[int]`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
"""
|
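The LeViT fix swaps `image_mean` and `image_std` back into agreement with the standard ImageNet statistics. For reference, these are the usual per-channel values (local copies here; the library exposes equivalent constants):

# Conventional ImageNet normalization statistics, per RGB channel.
IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]

def normalize(value, channel):
    """Map a pixel in [0, 1] to the normalized range expected by LeViT-style models."""
    return (value - IMAGENET_DEFAULT_MEAN[channel]) / IMAGENET_DEFAULT_STD[channel]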
@ -43,14 +43,18 @@ class LxmertConfig(PretrainedConfig):
|
||||
`inputs_ids` passed when calling [`LxmertModel`] or [`TFLxmertModel`].
|
||||
hidden_size (`int`, *optional*, defaults to 768):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
r_layers (`int`, *optional*, defaults to 5):
|
||||
Number of hidden layers in the Transformer visual encoder.
|
||||
l_layers (`int`, *optional*, defaults to 9):
|
||||
Number of hidden layers in the Transformer language encoder.
|
||||
x_layers (`int`, *optional*, defaults to 5):
|
||||
Number of hidden layers in the Transformer cross modality encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 5):
|
||||
num_attention_heads (`int`, *optional*, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_qa_labels (`int`, *optional*, defaults to 9500):
|
||||
This represents the total number of different question answering (QA) labels there are. If using more than
|
||||
one dataset with QA, the user will need to account for the total number of labels that all of the datasets
|
||||
have in total.
|
||||
num_object_labels (`int`, *optional*, defaults to 1600):
|
||||
This represents the total number of semantically unique objects that lxmert will be able to classify a
|
||||
pooled-object feature as belonging too.
|
||||
num_attr_labels (`int`, *optional*, defaults to 400):
|
||||
This represents the total number of semantically unique attributes that lxmert will be able to classify a
|
||||
pooled-object feature as possessing.
|
||||
intermediate_size (`int`, *optional*, defaults to 3072):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
|
||||
hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
|
||||
@ -69,25 +73,21 @@ class LxmertConfig(PretrainedConfig):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
||||
The epsilon used by the layer normalization layers.
|
||||
l_layers (`int`, *optional*, defaults to 9):
|
||||
Number of hidden layers in the Transformer language encoder.
|
||||
x_layers (`int`, *optional*, defaults to 5):
|
||||
Number of hidden layers in the Transformer cross modality encoder.
|
||||
r_layers (`int`, *optional*, defaults to 5):
|
||||
Number of hidden layers in the Transformer visual encoder.
|
||||
visual_feat_dim (`int`, *optional*, defaults to 2048):
|
||||
This represents the last dimension of the pooled-object features used as input for the model, representing
|
||||
the size of each object feature itself.
|
||||
visual_pos_dim (`int`, *optional*, defaults to 4):
|
||||
This represents the number of spacial features that are mixed into the visual features. The default is set
|
||||
to 4 because most commonly this will represent the location of a bounding box. i.e., (x, y, width, height)
|
||||
visual_loss_normalizer (`float`, *optional*, defaults to 1/15):
|
||||
visual_loss_normalizer (`float`, *optional*, defaults to 6.67):
|
||||
This represents the scaling factor in which each visual loss is multiplied by if during pretraining, one
|
||||
decided to train with multiple vision-based loss objectives.
|
||||
num_qa_labels (`int`, *optional*, defaults to 9500):
|
||||
This represents the total number of different question answering (QA) labels there are. If using more than
|
||||
one dataset with QA, the user will need to account for the total number of labels that all of the datasets
|
||||
have in total.
|
||||
num_object_labels (`int`, *optional*, defaults to 1600):
|
||||
This represents the total number of semantically unique objects that lxmert will be able to classify a
|
||||
pooled-object feature as belonging too.
|
||||
num_attr_labels (`int`, *optional*, defaults to 400):
|
||||
This represents the total number of semantically unique attributes that lxmert will be able to classify a
|
||||
pooled-object feature as possessing.
|
||||
task_matched (`bool`, *optional*, defaults to `True`):
|
||||
This task is used for sentence-image matching. If the sentence correctly describes the image the label will
|
||||
be 1. If the sentence does not correctly describe the image, the label will be 0.
|
||||
@ -104,12 +104,6 @@ class LxmertConfig(PretrainedConfig):
|
||||
Whether or not to calculate the attribute-prediction loss objective
|
||||
visual_feat_loss (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to calculate the feature-regression loss objective
|
||||
output_attentions (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the model should return the attentions from the vision, language, and cross-modality layers
|
||||
should be returned.
|
||||
output_hidden_states (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the model should return the hidden states from the vision, language, and cross-modality
|
||||
layers should be returned.
|
||||
"""
|
||||
|
||||
model_type = "lxmert"
|
||||
|
@ -356,20 +356,17 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
|
||||
sequence like `(width, height)`, output size will be matched to this. If size is an int, smaller edge of
|
||||
the image will be matched to this number. i.e, if `height > width`, then image will be rescaled to `(size *
|
||||
height / width, size)`.
|
||||
max_size (`int`, *optional*, defaults to 1333):
|
||||
The largest size an image dimension can have (otherwise it's capped). Only has an effect if `do_resize` is
|
||||
set to `True`.
|
||||
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
|
||||
size_divisor (`int`, *optional*, defaults to 32):
|
||||
Some backbones need images divisible by a certain number. If not passed, it defaults to the value used in
|
||||
Swin Transformer.
|
||||
resample (`int`, *optional*, defaults to `Resampling.BILINEAR`):
|
||||
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
|
||||
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
|
||||
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
|
||||
to `True`.
|
||||
size_divisor (`int`, *optional*, defaults to 32):
|
||||
Some backbones need images divisible by a certain number. If not passed, it defaults to the value used in
|
||||
Swin Transformer.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the input to a certain `scale`.
|
||||
rescale_factor (`float`, *optional*, defaults to 1/ 255):
|
||||
rescale_factor (`float`, *optional*, defaults to `1/ 255`):
|
||||
Rescale the input by the given factor. Only has an effect if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to normalize the input with mean and standard deviation.
|
||||
|
@ -358,20 +358,17 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
||||
sequence like `(width, height)`, output size will be matched to this. If size is an int, smaller edge of
|
||||
the image will be matched to this number. i.e, if `height > width`, then image will be rescaled to `(size *
|
||||
height / width, size)`.
|
||||
max_size (`int`, *optional*, defaults to 1333):
|
||||
The largest size an image dimension can have (otherwise it's capped). Only has an effect if `do_resize` is
|
||||
set to `True`.
|
||||
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
|
||||
size_divisor (`int`, *optional*, defaults to 32):
|
||||
Some backbones need images divisible by a certain number. If not passed, it defaults to the value used in
|
||||
Swin Transformer.
|
||||
resample (`int`, *optional*, defaults to `Resampling.BILINEAR`):
|
||||
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
|
||||
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
|
||||
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
|
||||
to `True`.
|
||||
size_divisor (`int`, *optional*, defaults to 32):
|
||||
Some backbones need images divisible by a certain number. If not passed, it defaults to the value used in
|
||||
Swin Transformer.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the input to a certain `scale`.
|
||||
rescale_factor (`float`, *optional*, defaults to 1/ 255):
|
||||
rescale_factor (`float`, *optional*, defaults to `1/ 255`):
|
||||
Rescale the input by the given factor. Only has an effect if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to normalize the input with mean and standard deviation.
|
||||
|
@ -62,7 +62,7 @@ class MgpstrConfig(PretrainedConfig):
|
||||
Whether to add a bias to the queries, keys and values.
|
||||
distilled (`bool`, *optional*, defaults to `False`):
|
||||
Model includes a distillation token and head as in DeiT models.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the layer normalization layers.
|
||||
drop_rate (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder.
|
||||
|
@ -44,9 +44,9 @@ class MgpstrProcessor(ProcessorMixin):
|
||||
[`~MgpstrProcessor.__call__`] and [`~MgpstrProcessor.batch_decode`] for more information.
|
||||
|
||||
Args:
|
||||
image_processor (`ViTImageProcessor`):
|
||||
image_processor (`ViTImageProcessor`, *optional*):
|
||||
An instance of `ViTImageProcessor`. The image processor is a required input.
|
||||
tokenizer ([`MgpstrTokenizer`]):
|
||||
tokenizer ([`MgpstrTokenizer`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
"""
|
||||
attributes = ["image_processor", "char_tokenizer"]
|
||||
|
@ -52,7 +52,7 @@ class MgpstrTokenizer(PreTrainedTokenizer):
|
||||
The beginning of sequence token.
|
||||
eos_token (`str`, *optional*, defaults to `"[s]"`):
|
||||
The end of sequence token.
|
||||
pad_token (`str` or `tokenizers.AddedToken`, *optional*, , defaults to `"[GO]"`):
|
||||
pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"[GO]"`):
|
||||
A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
|
||||
attention mechanisms or loss computation.
|
||||
"""
|
||||
|
@ -55,7 +55,7 @@ class MobileNetV1Config(PretrainedConfig):
|
||||
All layers will have at least this many channels.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"relu6"`):
|
||||
The non-linear activation function (function or string) in the Transformer encoder and convolution layers.
|
||||
tf_padding (`bool`, `optional`, defaults to `True`):
|
||||
tf_padding (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use TensorFlow padding rules on the convolution layers.
|
||||
classifier_dropout_prob (`float`, *optional*, defaults to 0.999):
|
||||
The dropout ratio for attached classifiers.
|
||||
|
@ -64,16 +64,16 @@ class MobileNetV2Config(PretrainedConfig):
|
||||
the input dimensions by a factor of 32. If `output_stride` is 8 or 16, the model uses dilated convolutions
|
||||
on the depthwise layers instead of regular convolutions, so that the feature maps never become more than 8x
|
||||
or 16x smaller than the input image.
|
||||
first_layer_is_expansion (`bool`, `optional`, defaults to `True`):
|
||||
first_layer_is_expansion (`bool`, *optional*, defaults to `True`):
|
||||
True if the very first convolution layer is also the expansion layer for the first expansion block.
|
||||
finegrained_output (`bool`, `optional`, defaults to `True`):
|
||||
finegrained_output (`bool`, *optional*, defaults to `True`):
|
||||
If true, the number of output channels in the final convolution layer will stay large (1280) even if
|
||||
`depth_multiplier` is less than 1.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"relu6"`):
|
||||
The non-linear activation function (function or string) in the Transformer encoder and convolution layers.
|
||||
tf_padding (`bool`, `optional`, defaults to `True`):
|
||||
tf_padding (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use TensorFlow padding rules on the convolution layers.
|
||||
classifier_dropout_prob (`float`, *optional*, defaults to 0.999):
|
||||
classifier_dropout_prob (`float`, *optional*, defaults to 0.8):
|
||||
The dropout ratio for attached classifiers.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
@ -105,7 +105,7 @@ class MobileNetV2Config(PretrainedConfig):
|
||||
depth_multiplier=1.0,
|
||||
depth_divisible_by=8,
|
||||
min_depth=8,
|
||||
expand_ratio=6,
|
||||
expand_ratio=6.0,
|
||||
output_stride=32,
|
||||
first_layer_is_expansion=True,
|
||||
finegrained_output=True,
|
||||
|
@ -74,7 +74,7 @@ class MobileViTConfig(PretrainedConfig):
|
||||
The non-linear activation function (function or string) in the Transformer encoder and convolution layers.
|
||||
conv_kernel_size (`int`, *optional*, defaults to 3):
|
||||
The size of the convolutional kernel in the MobileViT layer.
|
||||
output_stride (`int`, `optional`, defaults to 32):
|
||||
output_stride (`int`, *optional*, defaults to 32):
|
||||
The ratio of the spatial resolution of the output to the resolution of the input image.
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
|
||||
The dropout probabilitiy for all fully connected layers in the Transformer encoder.
|
||||
@ -84,11 +84,11 @@ class MobileViTConfig(PretrainedConfig):
|
||||
The dropout ratio for attached classifiers.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the layer normalization layers.
|
||||
qkv_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether to add a bias to the queries, keys and values.
|
||||
aspp_out_channels (`int`, `optional`, defaults to 256):
|
||||
aspp_out_channels (`int`, *optional*, defaults to 256):
|
||||
Number of output channels used in the ASPP layer for semantic segmentation.
|
||||
atrous_rates (`List[int]`, *optional*, defaults to `[6, 12, 18]`):
|
||||
Dilation (atrous) factors used in the ASPP layer for semantic segmentation.
|
||||
|
@ -59,7 +59,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
|
||||
size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
|
||||
Controls the size of the output image after resizing. Can be overridden by the `size` parameter in the
|
||||
`preprocess` method.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
||||
Defines the resampling filter to use if resizing the image. Can be overridden by the `resample` parameter
|
||||
in the `preprocess` method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
|
@ -54,15 +54,15 @@ class MobileViTV2Config(PretrainedConfig):
|
||||
The non-linear activation function (function or string) in the Transformer encoder and convolution layers.
|
||||
conv_kernel_size (`int`, *optional*, defaults to 3):
|
||||
The size of the convolutional kernel in the MobileViTV2 layer.
|
||||
output_stride (`int`, `optional`, defaults to 32):
|
||||
output_stride (`int`, *optional*, defaults to 32):
|
||||
The ratio of the spatial resolution of the output to the resolution of the input image.
|
||||
classifier_dropout_prob (`float`, *optional*, defaults to 0.1):
|
||||
The dropout ratio for attached classifiers.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the layer normalization layers.
|
||||
aspp_out_channels (`int`, `optional`, defaults to 512):
|
||||
aspp_out_channels (`int`, *optional*, defaults to 512):
|
||||
Number of output channels used in the ASPP layer for semantic segmentation.
|
||||
atrous_rates (`List[int]`, *optional*, defaults to `[6, 12, 18]`):
|
||||
Dilation (atrous) factors used in the ASPP layer for semantic segmentation.
|
||||
@ -74,13 +74,13 @@ class MobileViTV2Config(PretrainedConfig):
|
||||
The number of attention blocks in each MobileViTV2Layer
|
||||
base_attn_unit_dims (`List[int]`, *optional*, defaults to `[128, 192, 256]`):
|
||||
The base multiplier for dimensions of attention blocks in each MobileViTV2Layer
|
||||
width_multiplier (`float`, *optional*, defaults to 1.0)
|
||||
width_multiplier (`float`, *optional*, defaults to 1.0):
|
||||
The width multiplier for MobileViTV2.
|
||||
ffn_multiplier (`int`, *optional*, defaults to 2)
|
||||
ffn_multiplier (`int`, *optional*, defaults to 2):
|
||||
The FFN multiplier for MobileViTV2.
|
||||
attn_dropout (`float`, *optional*, defaults to 0.0)
|
||||
attn_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout in the attention layer.
|
||||
ffn_dropout (`float`, *optional*, defaults to 0.0)
|
||||
ffn_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout between FFN layers.
|
||||
|
||||
Example:
|
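The MobileViTV2 hunks (and the OneFormer ones below) only add a missing trailing colon after the `defaults to ...` clause. A hedged sketch of the line shape a docstring parser would look for; the regex is illustrative, not the one used by the repository's tooling:

import re

# "name (`type`, *optional*, defaults to X):" with a closing colon required.
ARG_LINE = re.compile(r"^\s*(?P<name>\w+) \(.*?(?:, \*optional\*(?:, defaults to .+?)?)?\):$")

assert ARG_LINE.match("    width_multiplier (`float`, *optional*, defaults to 1.0):")
assert not ARG_LINE.match("    width_multiplier (`float`, *optional*, defaults to 1.0)")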
@ -145,17 +145,17 @@ class MptConfig(PretrainedConfig):
|
||||
the `inputs_ids` passed when calling [`MptModel`]. Check [this
|
||||
discussion](https://huggingface.co/bigscience/mpt/discussions/120#633d28389addb8530b406c2a) on how the
|
||||
`vocab_size` has been defined.
|
||||
resid_pdrop (`float`, *optional*, defaults to 0.1):
|
||||
resid_pdrop (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability applied to the attention output before combining with residual.
|
||||
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
|
||||
layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon to use in the layer normalization layers.
|
||||
emb_pdrop (`float`, *optional*, defaults to 0.1):
|
||||
emb_pdrop (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability for the embedding layer.
|
||||
learned_pos_emb (`bool`, *optional*, defaults to `False`):
|
||||
learned_pos_emb (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use learned positional embeddings.
|
||||
attn_config (`dict`, *optional*):
|
||||
A dictionary used to configure the model's attention module.
|
||||
init_device (`str`, *optional*):
|
||||
init_device (`str`, *optional*, defaults to `"cpu"`):
|
||||
The device to use for parameter initialization. Defined for backward compatibility
|
||||
logit_scale (`float`, *optional*):
|
||||
If not None, scale the logits by this value.
|
||||
@ -169,7 +169,7 @@ class MptConfig(PretrainedConfig):
|
||||
norm_type (`str`, *optional*, defaults to `"low_precision_layernorm"`):
|
||||
Type of layer norm to use. All MPT models uses the same layer norm implementation. Defined for backward
|
||||
compatibility.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
use_cache (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
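Several of the MptConfig defaults above (for example `resid_pdrop`, `learned_pos_emb`, `use_cache`) were documenting values the constructor never had. Pulling the real defaults out of the signature makes that kind of drift easy to spot; a small illustrative helper, again not the repository's actual checker:

import inspect

def signature_defaults(cls):
    """Return the constructor defaults a docstring is expected to quote."""
    return {
        name: param.default
        for name, param in inspect.signature(cls.__init__).parameters.items()
        if param.default is not inspect.Parameter.empty
    }

# Comparing this mapping against the documented values is how mismatches such as
# a documented resid_pdrop of 0.1 versus an actual default of 0.0 get caught.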
@ -44,9 +44,9 @@ class NatConfig(BackboneConfigMixin, PretrainedConfig):
|
||||
The number of input channels.
|
||||
embed_dim (`int`, *optional*, defaults to 64):
|
||||
Dimensionality of patch embedding.
|
||||
depths (`List[int]`, *optional*, defaults to `[2, 2, 6, 2]`):
|
||||
depths (`List[int]`, *optional*, defaults to `[3, 4, 6, 5]`):
|
||||
Number of layers in each level of the encoder.
|
||||
num_heads (`List[int]`, *optional*, defaults to `[3, 6, 12, 24]`):
|
||||
num_heads (`List[int]`, *optional*, defaults to `[2, 4, 8, 16]`):
|
||||
Number of attention heads in each layer of the Transformer encoder.
|
||||
kernel_size (`int`, *optional*, defaults to 7):
|
||||
Neighborhood Attention kernel size.
|
||||
@ -65,7 +65,7 @@ class NatConfig(BackboneConfigMixin, PretrainedConfig):
|
||||
`"selu"` and `"gelu_new"` are supported.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the layer normalization layers.
|
||||
layer_scale_init_value (`float`, *optional*, defaults to 0.0):
|
||||
The initial value for the layer scale. Disabled if <=0.
|
||||
|
@ -66,7 +66,7 @@ class NougatImageProcessor(BaseImageProcessor):
|
||||
`do_resize` in the `preprocess` method.
|
||||
size (`Dict[str, int]` *optional*, defaults to `{"height": 896, "width": 672}`):
|
||||
Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
||||
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
|
||||
do_thumbnail (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the image using thumbnail method.
|
||||
|
@ -383,10 +383,10 @@ class NougatTokenizerFast(PreTrainedTokenizerFast):
|
||||
methods for postprocessing the generated text.
|
||||
|
||||
Args:
|
||||
vocab_file (`str`):
|
||||
vocab_file (`str`, *optional*):
|
||||
[SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
|
||||
contains the vocabulary necessary to instantiate a tokenizer.
|
||||
tokenizer_file (`str`):
|
||||
tokenizer_file (`str`, *optional*):
|
||||
[tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
|
||||
contains everything needed to load the tokenizer.
|
||||
|
||||
@ -394,16 +394,16 @@ class NougatTokenizerFast(PreTrainedTokenizerFast):
|
||||
Wether to cleanup spaces after decoding, cleanup consists in removing potential artifacts like extra
|
||||
spaces.
|
||||
|
||||
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||
token instead.
|
||||
|
||||
bos_token (`str`, *optional*, defaults to `"<s>"`):
|
||||
The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
|
||||
|
||||
eos_token (`str`, *optional*, defaults to `"</s>"`):
|
||||
The end of sequence token.
|
||||
|
||||
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||
token instead.
|
||||
|
||||
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
||||
The token used for padding, for example when batching sequences of different lengths.
|
||||
"""
|
||||
|
@ -42,87 +42,87 @@ class OneFormerConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information.

Args:
backbone_config (`PretrainedConfig`, *optional*, defaults to `SwinConfig`)
backbone_config (`PretrainedConfig`, *optional*, defaults to `SwinConfig`):
The configuration of the backbone model.
ignore_value (`int`, *optional*, defaults to 255)
ignore_value (`int`, *optional*, defaults to 255):
Values to be ignored in GT label while calculating loss.
num_queries (`int`, *optional*, defaults to 150)
num_queries (`int`, *optional*, defaults to 150):
Number of object queries.
no_object_weight (`float`, *optional*, defaults to 0.1)
no_object_weight (`float`, *optional*, defaults to 0.1):
Weight for no-object class predictions.
class_weight (`float`, *optional*, defaults to 2.0)
class_weight (`float`, *optional*, defaults to 2.0):
Weight for Classification CE loss.
mask_weight (`float`, *optional*, defaults to 5.0)
mask_weight (`float`, *optional*, defaults to 5.0):
Weight for binary CE loss.
dice_weight (`float`, *optional*, defaults to 5.0)
dice_weight (`float`, *optional*, defaults to 5.0):
Weight for dice loss.
contrastive_weight (`float`, *optional*, defaults to 0.5)
contrastive_weight (`float`, *optional*, defaults to 0.5):
Weight for contrastive loss.
contrastive_temperature (`float`, *optional*, defaults to 0.07)
contrastive_temperature (`float`, *optional*, defaults to 0.07):
Initial value for scaling the contrastive logits.
train_num_points (`int`, *optional*, defaults to 12544)
train_num_points (`int`, *optional*, defaults to 12544):
Number of points to sample while calculating losses on mask predictions.
oversample_ratio (`float`, *optional*, defaults to 3.0)
oversample_ratio (`float`, *optional*, defaults to 3.0):
Ratio to decide how many points to oversample.
importance_sample_ratio (`float`, *optional*, defaults to 0.75)
importance_sample_ratio (`float`, *optional*, defaults to 0.75):
Ratio of points that are sampled via importance sampling.
init_std (`float`, *optional*, defaults to 0.02)
init_std (`float`, *optional*, defaults to 0.02):
Standard deviation for normal initialization.
init_xavier_std (`float`, *optional*, defaults to 0.02)
init_xavier_std (`float`, *optional*, defaults to 1.0):
Standard deviation for xavier uniform initialization.
layer_norm_eps (`float`, *optional*, defaults to 1e-05)
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
Epsilon for layer normalization.
is_training (`bool`, *optional*, defaults to False)
is_training (`bool`, *optional*, defaults to `False`):
Whether to run in training or inference mode.
use_auxiliary_loss (`bool`, *optional*, defaults to True)
use_auxiliary_loss (`bool`, *optional*, defaults to `True`):
Whether to calculate loss using intermediate predictions from transformer decoder.
output_auxiliary_logits (`bool`, *optional*, defaults to True)
output_auxiliary_logits (`bool`, *optional*, defaults to `True`):
Whether to return intermediate predictions from transformer decoder.
strides (`list`, *optional*, defaults to [4, 8, 16, 32])
strides (`list`, *optional*, defaults to `[4, 8, 16, 32]`):
List containing the strides for feature maps in the encoder.
task_seq_len (`int`, *optional*, defaults to 77)
task_seq_len (`int`, *optional*, defaults to 77):
Sequence length for tokenizing text list input.
text_encoder_width (`int`, *optional*, defaults to 256)
text_encoder_width (`int`, *optional*, defaults to 256):
Hidden size for text encoder.
text_encoder_context_length (`int`, *optional*, defaults to 77):
Input sequence length for text encoder.
text_encoder_num_layers (`int`, *optional*, defaults to 6)
text_encoder_num_layers (`int`, *optional*, defaults to 6):
Number of layers for transformer in text encoder.
text_encoder_vocab_size (`int`, *optional*, defaults to 49408)
text_encoder_vocab_size (`int`, *optional*, defaults to 49408):
Vocabulary size for tokenizer.
text_encoder_proj_layers (`int`, *optional*, defaults to 2)
text_encoder_proj_layers (`int`, *optional*, defaults to 2):
Number of layers in MLP for projecting text queries.
text_encoder_n_ctx (`int`, *optional*, defaults to 16)
text_encoder_n_ctx (`int`, *optional*, defaults to 16):
Number of learnable text context queries.
conv_dim (`int`, *optional*, defaults to 256)
conv_dim (`int`, *optional*, defaults to 256):
Feature map dimension to map outputs from the backbone.
mask_dim (`int`, *optional*, defaults to 256)
mask_dim (`int`, *optional*, defaults to 256):
Dimension for feature maps in pixel decoder.
hidden_dim (`int`, *optional*, defaults to 256)
hidden_dim (`int`, *optional*, defaults to 256):
Dimension for hidden states in transformer decoder.
encoder_feedforward_dim (`int`, *optional*, defaults to 1024)
encoder_feedforward_dim (`int`, *optional*, defaults to 1024):
Dimension for FFN layer in pixel decoder.
norm (`str`, *optional*, defaults to `GN`)
norm (`str`, *optional*, defaults to `"GN"`):
Type of normalization.
encoder_layers (`int`, *optional*, defaults to 6)
encoder_layers (`int`, *optional*, defaults to 6):
Number of layers in pixel decoder.
decoder_layers (`int`, *optional*, defaults to 10)
decoder_layers (`int`, *optional*, defaults to 10):
Number of layers in transformer decoder.
use_task_norm (`bool`, *optional*, defaults to `True`)
use_task_norm (`bool`, *optional*, defaults to `True`):
Whether to normalize the task token.
num_attention_heads (`int`, *optional*, defaults to 8)
num_attention_heads (`int`, *optional*, defaults to 8):
Number of attention heads in transformer layers in the pixel and transformer decoders.
dropout (`float`, *optional*, defaults to 0.1)
dropout (`float`, *optional*, defaults to 0.1):
Dropout probability for pixel and transformer decoders.
dim_feedforward (`int`, *optional*, defaults to 2048)
dim_feedforward (`int`, *optional*, defaults to 2048):
Dimension for FFN layer in transformer decoder.
pre_norm (`bool`, *optional*, defaults to `False`)
pre_norm (`bool`, *optional*, defaults to `False`):
Whether to normalize hidden states before attention layers in transformer decoder.
enforce_input_proj (`bool`, *optional*, defaults to `False`)
enforce_input_proj (`bool`, *optional*, defaults to `False`):
Whether to project hidden states in transformer decoder.
query_dec_layers (`int`, *optional*, defaults to 2)
query_dec_layers (`int`, *optional*, defaults to 2):
Number of layers in query transformer.
common_stride (`int`, *optional*, defaults to 4)
common_stride (`int`, *optional*, defaults to 4):
Common stride used for features in pixel decoder.

Examples:
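For reference, a minimal sketch of how the arguments documented above are used; the backbone defaults to a Swin configuration, and the overridden values below simply restate the documented defaults:

```python
from transformers import OneFormerConfig, OneFormerModel

# Build a configuration, spelling out a few of the documented defaults
config = OneFormerConfig(
    num_queries=150,          # number of object queries
    no_object_weight=0.1,     # weight for the no-object class
    class_weight=2.0,         # classification CE loss weight
    use_auxiliary_loss=True,  # supervise intermediate decoder predictions
)

# Initialize a randomly weighted model from that configuration
model = OneFormerModel(config)
```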
@ -361,17 +361,14 @@ class OneFormerImageProcessor(BaseImageProcessor):
sequence like `(width, height)`, output size will be matched to this. If size is an int, smaller edge of
the image will be matched to this number. i.e., if `height > width`, then image will be rescaled to `(size *
height / width, size)`.
max_size (`int`, *optional*, defaults to 1333):
The largest size an image dimension can have (otherwise it's capped). Only has an effect if `do_resize` is
set to `True`.
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
resample (`int`, *optional*, defaults to `Resampling.BILINEAR`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the input to a certain `scale`.
rescale_factor (`float`, *optional*, defaults to 1/ 255):
rescale_factor (`float`, *optional*, defaults to `1/ 255`):
Rescale the input by the given factor. Only has an effect if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with mean and standard deviation.
@ -387,9 +384,9 @@ class OneFormerImageProcessor(BaseImageProcessor):
Whether or not to decrement all label values of segmentation maps by 1. Usually used for datasets where 0
is used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k).
The background label will be replaced by `ignore_index`.
repo_path (`str`, defaults to `shi-labs/oneformer_demo`):
repo_path (`str`, defaults to `shi-labs/oneformer_demo`, *optional*, defaults to `"shi-labs/oneformer_demo"`):
Dataset repository on huggingface hub containing the JSON file with class information for the dataset.
class_info_file (`str`):
class_info_file (`str`, *optional*):
JSON file containing class information for the dataset. It is stored inside the `repo_path` dataset
repository.
num_text (`int`, *optional*):
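A hedged sketch of instantiating the image processor with the preprocessing arguments above; the `class_info_file` name is an assumption based on the ADE20k metadata published in the `shi-labs/oneformer_demo` dataset repository, and fetching it requires network access:

```python
from transformers import OneFormerImageProcessor

# The class-info JSON is resolved against the dataset repository given by `repo_path`
image_processor = OneFormerImageProcessor(
    do_resize=True,
    do_rescale=True,
    rescale_factor=1 / 255,
    repo_path="shi-labs/oneformer_demo",
    class_info_file="ade20k_panoptic.json",  # assumed file name for the ADE20k class metadata
)
```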
@ -56,7 +56,7 @@ class OpenAIGPTConfig(PretrainedConfig):
The dropout ratio for the embeddings.
attn_pdrop (`float`, *optional*, defaults to 0.1):
The dropout ratio for the attention.
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
The epsilon to use in the layer normalization layers.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
@ -91,8 +91,6 @@ class OpenAIGPTConfig(PretrainedConfig):
[`OpenAIGPTDoubleHeadsModel`].

The dropout ratio to be used after the projection and activation.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).

Examples:
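As a quick illustration of the configuration arguments above, a minimal sketch that spells out the documented defaults explicitly and builds a randomly initialized model from them:

```python
from transformers import OpenAIGPTConfig, OpenAIGPTModel

config = OpenAIGPTConfig(
    attn_pdrop=0.1,            # dropout on the attention probabilities
    layer_norm_epsilon=1e-05,  # epsilon used by the layer norms
    initializer_range=0.02,    # std of the truncated-normal initializer
)
model = OpenAIGPTModel(config)  # untrained model carrying this configuration
```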
@ -171,13 +171,13 @@ class OwlViTVisionConfig(PretrainedConfig):
hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the layer normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
initializer_factor (`float`, *optional*, defaults to 1):
initializer_factor (`float`, *optional*, defaults to 1.0):
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
testing).
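A small sketch of the vision configuration discussed above; the values restate the documented defaults and are only written out to show where the arguments go:

```python
from transformers import OwlViTVisionConfig, OwlViTVisionModel

vision_config = OwlViTVisionConfig(
    hidden_act="quick_gelu",  # one of "gelu", "relu", "selu", "gelu_new", "quick_gelu"
    layer_norm_eps=1e-05,
    attention_dropout=0.0,
)
vision_model = OwlViTVisionModel(vision_config)  # randomly initialized vision tower
```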
@ -102,7 +102,7 @@ class OwlViTImageProcessor(BaseImageProcessor):
The size to use for resizing the image. Only has an effect if `do_resize` is set to `True`. If `size` is a
sequence like (h, w), output size will be matched to this. If `size` is an int, then image will be resized
to (size, size).
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BICUBIC`):
resample (`int`, *optional*, defaults to `Resampling.BICUBIC`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
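A minimal sketch of how the resampling filter is passed; it assumes a recent Pillow (9.1+), where the resampling constants live in the `Image.Resampling` enum:

```python
from PIL import Image
from transformers import OwlViTImageProcessor

image_processor = OwlViTImageProcessor(
    do_resize=True,
    resample=Image.Resampling.BICUBIC,  # any PIL.Image.Resampling member works here
)
```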
@ -33,9 +33,9 @@ class OwlViTProcessor(ProcessorMixin):
[`~OwlViTProcessor.__call__`] and [`~OwlViTProcessor.decode`] for more information.

Args:
image_processor ([`OwlViTImageProcessor`]):
image_processor ([`OwlViTImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`CLIPTokenizer`, `CLIPTokenizerFast`]):
tokenizer ([`CLIPTokenizer`, `CLIPTokenizerFast`], *optional*):
The tokenizer is a required input.
"""
attributes = ["image_processor", "tokenizer"]
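A hedged sketch of using the processor that wraps these two components; the checkpoint name is the public OWL-ViT base model and the image URL is a sample COCO image, both assumed here purely for illustration:

```python
import requests
from PIL import Image
from transformers import OwlViTProcessor

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Tokenizes the text queries and preprocesses the image in one call
inputs = processor(text=[["a photo of a cat"]], images=image, return_tensors="pt")
```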
@ -65,7 +65,7 @@ class PerceiverConfig(PretrainedConfig):
v_channels (`int`, *optional*):
Dimension to project the values before applying attention in the cross-attention and self-attention layers
of the encoder. Will default to preserving the dimension of the queries if not specified.
cross_attention_shape_for_attention (`str`, *optional*, defaults to `'kv'`):
cross_attention_shape_for_attention (`str`, *optional*, defaults to `"kv"`):
Dimension to use when downsampling the queries and keys in the cross-attention layer of the encoder.
self_attention_widening_factor (`int`, *optional*, defaults to 1):
Dimension of the feed-forward layer in the cross-attention layer of the Transformer encoder.
@ -89,7 +89,7 @@ class PerceiverConfig(PretrainedConfig):
this to something large just in case (e.g., 512 or 1024 or 2048).
image_size (`int`, *optional*, defaults to 56):
Size of the images after preprocessing, for [`PerceiverForImageClassificationLearned`].
train_size (`List[int]`, *optional*, defaults to [368, 496]):
train_size (`List[int]`, *optional*, defaults to `[368, 496]`):
Training size of the images for the optical flow model.
num_frames (`int`, *optional*, defaults to 16):
Number of video frames used for the multimodal autoencoding model.
@ -97,11 +97,11 @@ class PerceiverConfig(PretrainedConfig):
Number of audio samples per frame for the multimodal autoencoding model.
samples_per_patch (`int`, *optional*, defaults to 16):
Number of audio samples per patch when preprocessing the audio for the multimodal autoencoding model.
output_num_channels (`int`, *optional*, defaults to 512):
Number of output channels for each modality decoder.
output_shape (`List[int]`, *optional*, defaults to `[1, 16, 224, 224]`):
Shape of the output (batch_size, num_frames, height, width) for the video decoder queries of the multimodal
autoencoding model. This excludes the channel dimension.
output_num_channels (`int`, *optional*, defaults to 512):
Number of output channels for each modality decoder.

Example:
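Following the `Example:` marker above, a minimal sketch that passes the multimodal-autoencoding arguments documented here; the values simply restate the documented defaults:

```python
from transformers import PerceiverConfig

config = PerceiverConfig(
    num_frames=16,                   # video frames for the multimodal autoencoder
    samples_per_patch=16,            # audio samples per patch
    output_num_channels=512,         # channels produced by each modality decoder
    output_shape=[1, 16, 224, 224],  # (batch_size, num_frames, height, width) for video queries
    train_size=[368, 496],           # training image size for the optical-flow model
)
```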
Some files were not shown because too many files have changed in this diff.