Mirror of https://github.com/huggingface/transformers.git
[DocTests Speech] Add doc tests for all speech models (#15031)
* fix_torch_device_generate_test
* remove @
* doc tests
* up
* up
* fix doctests
* adapt files
* finish refactor
* up
* save intermediate
* add more logic
* new change
* improve
* next try
* next try
* next try
* next try
* fix final spaces
* fix final spaces
* improve
* renaming
* correct more bugs
* finish wavlm
* add comment
* run on test runner
* finish all speech models
* adapt
* finish
This commit is contained in:
parent
4df69506a8
commit
9f831bdeaf
14 .github/workflows/doctests.yml (vendored)
@@ -19,7 +19,7 @@ env:

 jobs:
   run_doctests:
-    runs-on: [self-hosted, docker-gpu, single-gpu]
+    runs-on: [self-hosted, docker-gpu-test, single-gpu]
     container:
       image: pytorch/pytorch:1.9.0-cuda11.1-cudnn8-runtime
       options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
@@ -35,8 +35,16 @@ jobs:
         run: |
           apt -y update && apt install -y libsndfile1-dev
           pip install --upgrade pip
-          pip install .[dev]
+          pip install .[testing,torch-speech]
+
+      - name: Prepare files for doctests
+        run: |
+          python utils/prepare_for_doc_test.py src docs

       - name: Run doctests
         run: |
-          pytest --doctest-modules $(cat utils/documentation_tests.txt) -sv --doctest-continue-on-failure
+          pytest --doctest-modules $(cat utils/documentation_tests.txt) -sv --doctest-continue-on-failure --doctest-glob="*.mdx"
+
+      - name: Clean files after doctests
+        run: |
+          python utils/prepare_for_doc_test.py src docs --remove_new_line
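For local debugging, the same three steps can be chained from a transformers checkout. A minimal sketch (an illustration, not part of the commit) that mirrors the workflow's commands:

```python
# Runs the same prepare / test / clean cycle as the workflow above.
# Assumes the current working directory is a transformers checkout.
import subprocess

subprocess.run(["python", "utils/prepare_for_doc_test.py", "src", "docs"], check=True)
subprocess.run(
    "pytest --doctest-modules $(cat utils/documentation_tests.txt) -sv "
    '--doctest-continue-on-failure --doctest-glob="*.mdx"',
    shell=True,  # shell needed for the $(cat ...) expansion
    check=True,
)
subprocess.run(["python", "utils/prepare_for_doc_test.py", "src", "docs", "--remove_new_line"], check=True)
```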
@@ -1127,9 +1127,11 @@ PT_SPEECH_BASE_MODEL_SAMPLE = r"""

     ```python
     >>> from transformers import {processor_class}, {model_class}
+    >>> import torch
     >>> from datasets import load_dataset

     >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
     >>> dataset = dataset.sort("id")
+    >>> sampling_rate = dataset.features["audio"].sampling_rate

     >>> processor = {processor_class}.from_pretrained("{checkpoint}")
@@ -1137,9 +1139,12 @@ PT_SPEECH_BASE_MODEL_SAMPLE = r"""

     >>> # audio file is decoded on the fly
-    >>> inputs = processor(dataset[0]["audio"]["array"], return_tensors="pt")
-    >>> outputs = model(**inputs)
+    >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
+    >>> with torch.no_grad():
+    ...     outputs = model(**inputs)

     >>> last_hidden_states = outputs.last_hidden_state
+    >>> list(last_hidden_states.shape)
+    {expected_output}
     ```
 """
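Filled in with the Wav2Vec2 values that appear later in this diff (checkpoint facebook/wav2vec2-base-960h, expected shape [1, 292, 768]), the base-model template expands to an example along these lines (a sketch for illustration; the real docstring is generated by the decorator):

```python
import torch
from datasets import load_dataset

from transformers import Wav2Vec2Processor, Wav2Vec2Model

dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
dataset = dataset.sort("id")
sampling_rate = dataset.features["audio"].sampling_rate

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

# audio file is decoded on the fly
inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# the doctest compares this printed value against the decorator's `expected_output`
print(list(outputs.last_hidden_state.shape))  # [1, 292, 768]
```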
@@ -1152,6 +1157,7 @@ PT_SPEECH_CTC_SAMPLE = r"""
     >>> import torch

     >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
     >>> dataset = dataset.sort("id")
+    >>> sampling_rate = dataset.features["audio"].sampling_rate

     >>> processor = {processor_class}.from_pretrained("{checkpoint}")
@@ -1159,17 +1165,24 @@ PT_SPEECH_CTC_SAMPLE = r"""

     >>> # audio file is decoded on the fly
-    >>> inputs = processor(dataset[0]["audio"]["array"], return_tensors="pt")
-    >>> logits = model(**inputs).logits
+    >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
+    >>> with torch.no_grad():
+    ...     logits = model(**inputs).logits
     >>> predicted_ids = torch.argmax(logits, dim=-1)

     >>> # transcribe speech
     >>> transcription = processor.batch_decode(predicted_ids)
+    >>> transcription[0]
+    {expected_output}
     ```

-    >>> # compute loss
+    ```python
     >>> with processor.as_target_processor():
     ...     inputs["labels"] = processor(dataset[0]["text"], return_tensors="pt").input_ids

+    >>> # compute loss
     >>> loss = model(**inputs).loss
+    >>> round(loss.item(), 2)
+    {expected_loss}
     ```
 """
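One detail worth noting: `>>> transcription[0]` echoes the repr of a string, quotes included, which is why the `_CTC_EXPECTED_OUTPUT` constants later in this diff wrap the transcription in an extra pair of quotes. A quick illustration:

```python
# In a doctest, a bare expression is compared against its repr, so the
# expected output for a string must itself include the quotes.
transcription = ["MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES"]
print(repr(transcription[0]))
# 'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES'
```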
@@ -1182,21 +1195,31 @@ PT_SPEECH_SEQ_CLASS_SAMPLE = r"""
     >>> import torch

     >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
     >>> dataset = dataset.sort("id")
+    >>> sampling_rate = dataset.features["audio"].sampling_rate

     >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
     >>> model = {model_class}.from_pretrained("{checkpoint}")

     >>> # audio file is decoded on the fly
-    >>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt")
-    >>> logits = model(**inputs).logits
-    >>> predicted_class_ids = torch.argmax(logits, dim=-1)
-    >>> predicted_label = model.config.id2label[predicted_class_ids]
+    >>> inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
+
+    >>> with torch.no_grad():
+    ...     logits = model(**inputs).logits
+
+    >>> predicted_class_ids = torch.argmax(logits, dim=-1).item()
+    >>> predicted_label = model.config.id2label[predicted_class_ids]
+    >>> predicted_label
+    {expected_output}
     ```

+    ```python
     >>> # compute loss - target_label is e.g. "down"
     >>> target_label = model.config.id2label[0]
     >>> inputs["labels"] = torch.tensor([model.config.label2id[target_label]])
     >>> loss = model(**inputs).loss
+    >>> round(loss.item(), 2)
+    {expected_loss}
     ```
 """
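The `.item()` call added above is more than doctest cosmetics: `config.id2label` is a plain dict keyed by Python ints, so indexing it with a 0-dim tensor fails. A small demonstration with illustrative values:

```python
import torch

id2label = {0: "down", 1: "up"}  # stand-in for model.config.id2label
logits = torch.tensor([0.1, 0.9])

predicted_class_ids = torch.argmax(logits, dim=-1)
print(id2label[predicted_class_ids.item()])  # "up"
# id2label[predicted_class_ids] would raise a KeyError: tensors hash by identity
```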
@@ -1210,17 +1233,22 @@ PT_SPEECH_FRAME_CLASS_SAMPLE = r"""
     >>> import torch

     >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
     >>> dataset = dataset.sort("id")
+    >>> sampling_rate = dataset.features["audio"].sampling_rate

     >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
     >>> model = {model_class}.from_pretrained("{checkpoint}")

     >>> # audio file is decoded on the fly
-    >>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt")
-    >>> logits = model(**inputs).logits
+    >>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt", sampling_rate=sampling_rate)
+    >>> with torch.no_grad():
+    ...     logits = model(**inputs).logits

     >>> probabilities = torch.sigmoid(logits[0])
     >>> # labels is a one-hot array of shape (num_frames, num_speakers)
     >>> labels = (probabilities > 0.5).long()
+    >>> labels[0].tolist()
+    {expected_output}
     ```
 """
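The frame-classification template treats speaker diarization as a multi-label problem: each frame gets an independent sigmoid per speaker rather than a softmax across speakers. A torch-only sketch with random stand-in logits:

```python
import torch

logits = torch.randn(200, 2)  # (num_frames, num_speakers), stand-in for the model head
probabilities = torch.sigmoid(logits)  # independent probability per speaker and frame

# labels is a one-hot array of shape (num_frames, num_speakers)
labels = (probabilities > 0.5).long()
print(labels[0].tolist())  # e.g. [0, 1]
```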
@@ -1234,14 +1262,19 @@ PT_SPEECH_XVECTOR_SAMPLE = r"""
     >>> import torch

     >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
     >>> dataset = dataset.sort("id")
+    >>> sampling_rate = dataset.features["audio"].sampling_rate

     >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
     >>> model = {model_class}.from_pretrained("{checkpoint}")

     >>> # audio file is decoded on the fly
-    >>> inputs = feature_extractor(dataset[:2]["audio"]["array"], return_tensors="pt")
-    >>> embeddings = model(**inputs).embeddings
+    >>> inputs = feature_extractor(
+    ...     [d["array"] for d in dataset[:2]["audio"]], sampling_rate=sampling_rate, return_tensors="pt", padding=True
+    ... )
+    >>> with torch.no_grad():
+    ...     embeddings = model(**inputs).embeddings

     >>> embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu()

     >>> # the resulting embeddings can be used for cosine similarity-based retrieval
@@ -1250,6 +1283,8 @@ PT_SPEECH_XVECTOR_SAMPLE = r"""
     >>> threshold = 0.7  # the optimal threshold is dataset-dependent
     >>> if similarity < threshold:
     ...     print("Speakers are not the same!")
+    >>> round(similarity.item(), 2)
+    {expected_output}
     ```
 """
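The verification logic the x-vector template exercises boils down to a cosine similarity between two L2-normalized embeddings compared against a tuned threshold. A self-contained sketch with random embeddings (the `cosine_sim` helper mirrors template lines outside this hunk):

```python
import torch

# stand-ins for two speaker embeddings produced by the model
emb1 = torch.nn.functional.normalize(torch.randn(1, 512), dim=-1)
emb2 = torch.nn.functional.normalize(torch.randn(1, 512), dim=-1)

cosine_sim = torch.nn.CosineSimilarity(dim=-1)
similarity = cosine_sim(emb1, emb2)

threshold = 0.7  # the optimal threshold is dataset-dependent
if similarity < threshold:
    print("Speakers are not the same!")
print(round(similarity.item(), 2))
```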
@@ -1553,9 +1588,11 @@ def add_code_sample_docstrings(
     checkpoint=None,
     output_type=None,
     config_class=None,
-    mask=None,
+    mask="[MASK]",
     model_cls=None,
-    modality=None
+    modality=None,
+    expected_output="",
+    expected_loss="",
 ):
     def docstring_decorator(fn):
         # model_class defaults to function's class if not specified otherwise
@@ -1568,7 +1605,17 @@ def add_code_sample_docstrings(
         else:
             sample_docstrings = PT_SAMPLE_DOCSTRINGS

-        doc_kwargs = dict(model_class=model_class, processor_class=processor_class, checkpoint=checkpoint)
+        # putting all kwargs for docstrings in a dict to be used
+        # with the `.format(**doc_kwargs)`. Note that string might
+        # be formatted with non-existing keys, which is fine.
+        doc_kwargs = dict(
+            model_class=model_class,
+            processor_class=processor_class,
+            checkpoint=checkpoint,
+            mask=mask,
+            expected_output=expected_output,
+            expected_loss=expected_loss,
+        )

         if "SequenceClassification" in model_class and modality == "audio":
             code_sample = sample_docstrings["AudioClassification"]
@@ -1581,7 +1628,6 @@ def add_code_sample_docstrings(
         elif "MultipleChoice" in model_class:
             code_sample = sample_docstrings["MultipleChoice"]
         elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]:
-            doc_kwargs["mask"] = "[MASK]" if mask is None else mask
             code_sample = sample_docstrings["MaskedLM"]
         elif "LMHead" in model_class or "CausalLM" in model_class:
             code_sample = sample_docstrings["LMHead"]
src/transformers/models/hubert/modeling_hubert.py
@@ -40,15 +40,29 @@ from .configuration_hubert import HubertConfig

 logger = logging.get_logger(__name__)

-_CONFIG_FOR_DOC = "HubertConfig"
-_CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft"
-_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
-_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
-
-_SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks"
-
 _HIDDEN_STATES_START_POSITION = 1

+# General docstring
+_CONFIG_FOR_DOC = "HubertConfig"
+_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft"
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
+_CTC_EXPECTED_LOSS = 22.68
+
+# Audio class docstring
+_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
+_SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks"
+_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
+_SEQ_CLASS_EXPECTED_LOSS = 8.53


 HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "facebook/hubert-base-ls960",
@@ -1098,6 +1112,8 @@ class HubertForCTC(HubertPreTrainedModel):
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=CausalLMOutput,
         config_class=_CONFIG_FOR_DOC,
+        expected_output=_CTC_EXPECTED_OUTPUT,
+        expected_loss=_CTC_EXPECTED_LOSS,
     )
     def forward(
         self,
@@ -1228,6 +1244,8 @@ class HubertForSequenceClassification(HubertPreTrainedModel):
         output_type=SequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
+        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
     def forward(
         self,
src/transformers/models/sew/modeling_sew.py
@@ -36,16 +36,33 @@ from .configuration_sew import SEWConfig

 logger = logging.get_logger(__name__)

-_CONFIG_FOR_DOC = "SEWConfig"
-_CHECKPOINT_FOR_DOC = "asapp/sew-tiny-100k"
-_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
-_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
-
-_SEQ_CLASS_CHECKPOINT = "asapp/sew-tiny-100k"
-
 _HIDDEN_STATES_START_POSITION = 1


+# General docstring
+_CONFIG_FOR_DOC = "SEWConfig"
+_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "asapp/sew-tiny-100k-ft-ls100h"
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 512]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = (
+    "'MISTER QUILTER IS THE APPOSTILE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPOLLE'"
+)
+_CTC_EXPECTED_LOSS = 0.42
+
+# Audio class docstring
+_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
+_SEQ_CLASS_CHECKPOINT = "anton-l/sew-mid-100k-ft-keyword-spotting"
+_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
+_SEQ_CLASS_EXPECTED_LOSS = 9.52

 SEW_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "asapp/sew-tiny-100k",
     "asapp/sew-small-100k",
@@ -879,6 +896,7 @@ class SEWModel(SEWPreTrainedModel):
         output_type=BaseModelOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
     )
     def forward(
         self,
@@ -978,6 +996,8 @@ class SEWForCTC(SEWPreTrainedModel):
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=CausalLMOutput,
         config_class=_CONFIG_FOR_DOC,
+        expected_output=_CTC_EXPECTED_OUTPUT,
+        expected_loss=_CTC_EXPECTED_LOSS,
     )
     def forward(
         self,
@@ -1108,6 +1128,8 @@ class SEWForSequenceClassification(SEWPreTrainedModel):
         output_type=SequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
+        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
     def forward(
         self,
src/transformers/models/sew_d/modeling_sew_d.py
@@ -37,15 +37,28 @@ from .configuration_sew_d import SEWDConfig

 logger = logging.get_logger(__name__)

-_CONFIG_FOR_DOC = "SEWDConfig"
-_CHECKPOINT_FOR_DOC = "asapp/sew-d-tiny-100k"
-_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
-_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
-
-_SEQ_CLASS_CHECKPOINT = "asapp/sew-d-tiny-100k"
-
 _HIDDEN_STATES_START_POSITION = 1


+# General docstring
+_CONFIG_FOR_DOC = "SEWDConfig"
+_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "asapp/sew-d-tiny-100k-ft-ls100h"
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 384]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTIL OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
+_CTC_EXPECTED_LOSS = 0.21
+
+# Audio class docstring
+_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
+_SEQ_CLASS_CHECKPOINT = "anton-l/sew-d-mid-400k-ft-keyword-spotting"
+_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
+_SEQ_CLASS_EXPECTED_LOSS = 3.16

 SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "asapp/sew-d-tiny-100k",
     "asapp/sew-d-small-100k",
@@ -1415,6 +1428,7 @@ class SEWDModel(SEWDPreTrainedModel):
         output_type=BaseModelOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
     )
     def forward(
         self,
@@ -1514,6 +1528,8 @@ class SEWDForCTC(SEWDPreTrainedModel):
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=CausalLMOutput,
         config_class=_CONFIG_FOR_DOC,
+        expected_output=_CTC_EXPECTED_OUTPUT,
+        expected_loss=_CTC_EXPECTED_LOSS,
     )
     def forward(
         self,
@@ -1644,6 +1660,8 @@ class SEWDForSequenceClassification(SEWDPreTrainedModel):
         output_type=SequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
+        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
     def forward(
         self,
src/transformers/models/unispeech/modeling_unispeech.py
@@ -42,15 +42,27 @@ from .configuration_unispeech import UniSpeechConfig

 logger = logging.get_logger(__name__)

-_CONFIG_FOR_DOC = "UniSpeechConfig"
-_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
-_CHECKPOINT_FOR_DOC = "microsoft/unispeech-large-1500h-cv"
-_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
-
-_SEQ_CLASS_CHECKPOINT = "microsoft/unispeech-large-1500h-cv"
-
 _HIDDEN_STATES_START_POSITION = 2

+# General docstring
+_CONFIG_FOR_DOC = "UniSpeechConfig"
+_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "patrickvonplaten/unispeech-large-1500h-cv-timit"
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'mister quilter is the apposl of the midle classes and weare glad to welcom his gosepl'"
+_CTC_EXPECTED_LOSS = 17.17
+
+# Audio class docstring
+_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
+_SEQ_CLASS_CHECKPOINT = "hf-internal-testing/tiny-random-unispeech"
+_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'"  # TODO(anton) - could you quickly fine-tune a KS WavLM Model
+_SEQ_CLASS_EXPECTED_LOSS = 0.66  # TODO(anton) - could you quickly fine-tune a KS WavLM Model

 UNISPEECH_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "microsoft/unispeech-large-1500h-cv",
     "microsoft/unispeech-large-multi-lingual-1500h-cv",
@@ -1129,6 +1141,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
         output_type=UniSpeechBaseModelOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
     )
     def forward(
         self,
@@ -1266,44 +1279,14 @@ class UniSpeechForPreTraining(UniSpeechPreTrainedModel):

     ```python
     >>> import torch
-    >>> from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining
-    >>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
-    >>> from datasets import load_dataset
-    >>> import soundfile as sf
+    >>> from transformers import Wav2Vec2FeatureExtractor, UniSpeechForPreTraining
+    >>> from transformers.models.unispeech.modeling_unispeech import _compute_mask_indices

-    >>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("patrickvonplaten/wav2vec2-base")
-    >>> model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
-
-
-    >>> def map_to_array(batch):
-    ...     speech, _ = sf.read(batch["file"])
-    ...     batch["speech"] = speech
-    ...     return batch
-
-
-    >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
-    >>> ds = ds.map(map_to_array)
-
-    >>> input_values = feature_extractor(ds["speech"][0], return_tensors="pt").input_values  # Batch size 1
-
-    >>> # compute masked indices
-    >>> batch_size, raw_sequence_length = input_values.shape
-    >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
-    >>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
-    >>> mask_time_indices = torch.tensor(mask_time_indices, device=input_values.device, dtype=torch.long)
-
-    >>> with torch.no_grad():
-    ...     outputs = model(input_values, mask_time_indices=mask_time_indices)
-
-    >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
-    >>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
-
-    >>> # show that cosine similarity is much higher than random
-    >>> assert cosine_sim[mask_time_indices].mean() > 0.5
-
-    >>> # for contrastive loss training model should be put into train mode
-    >>> model.train()
-    >>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
+    >>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+    ...     "hf-internal-testing/tiny-random-unispeech-sat"
+    ... )
+    >>> model = UniSpeechForPreTraining.from_pretrained("microsoft/unispeech-large-1500h-cv")
+    >>> # TODO: Add full pretraining example
     ```"""

     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1406,6 +1389,8 @@ class UniSpeechForCTC(UniSpeechPreTrainedModel):
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=CausalLMOutput,
         config_class=_CONFIG_FOR_DOC,
+        expected_output=_CTC_EXPECTED_OUTPUT,
+        expected_loss=_CTC_EXPECTED_LOSS,
     )
     def forward(
         self,
@@ -1536,6 +1521,8 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
         output_type=SequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
+        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
     def forward(
         self,
src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
@@ -43,16 +43,33 @@ from .configuration_unispeech_sat import UniSpeechSatConfig
 logger = logging.get_logger(__name__)


+_HIDDEN_STATES_START_POSITION = 2
+
+# General docstring
 _CONFIG_FOR_DOC = "UniSpeechSatConfig"
 _PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
-_CHECKPOINT_FOR_DOC = "microsoft/unispeech-sat-base-plus"

+# Base docstring
+_CHECKPOINT_FOR_DOC = "microsoft/unispeech-sat-base-100h-libri-ft"
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'MISTER QUILDER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
+_CTC_EXPECTED_LOSS = 39.88
+
+# Audio class docstring
+_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
+_SEQ_CLASS_CHECKPOINT = "hf-internal-testing/tiny-random-unispeech-sat"
+_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"  # TODO(anton) - could you quickly fine-tune a KS WavLM Model
+_SEQ_CLASS_EXPECTED_LOSS = 0.71  # TODO(anton) - could you quickly fine-tune a KS WavLM Model

-_SEQ_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus"
+# Frame class docstring
 _FRAME_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sd"
-_XVECTOR_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sv"
+_FRAME_EXPECTED_OUTPUT = [0, 0]

-_HIDDEN_STATES_START_POSITION = 2
+# Speaker Verification docstring
+_XVECTOR_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sv"
+_XVECTOR_EXPECTED_OUTPUT = 0.97

 UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST = [
     # See all UniSpeechSat models at https://huggingface.co/models?filter=unispeech_sat
@@ -1163,6 +1180,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
         output_type=UniSpeechSatBaseModelOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
     )
     def forward(
         self,
@@ -1300,42 +1318,10 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
     >>> import torch
     >>> from transformers import Wav2Vec2FeatureExtractor, UniSpeechSatForPreTraining
     >>> from transformers.models.unispeech_sat.modeling_unispeech_sat import _compute_mask_indices
-    >>> from datasets import load_dataset
-    >>> import soundfile as sf

     >>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/unispeech-sat-base")
     >>> model = UniSpeechSatForPreTraining.from_pretrained("microsoft/unispeech-sat-base")
-
-
-    >>> def map_to_array(batch):
-    ...     speech, _ = sf.read(batch["file"])
-    ...     batch["speech"] = speech
-    ...     return batch
-
-
-    >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
-    >>> ds = ds.map(map_to_array)
-
-    >>> input_values = feature_extractor(ds["speech"][0], return_tensors="pt").input_values  # Batch size 1
-
-    >>> # compute masked indices
-    >>> batch_size, raw_sequence_length = input_values.shape
-    >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
-    >>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
-    >>> mask_time_indices = torch.tensor(mask_time_indices, device=input_values.device, dtype=torch.long)
-
-    >>> with torch.no_grad():
-    ...     outputs = model(input_values, mask_time_indices=mask_time_indices)
-
-    >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
-    >>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
-
-    >>> # show that cosine similarity is much higher than random
-    >>> assert cosine_sim[mask_time_indices].mean() > 0.5
-
-    >>> # for contrastive loss training model should be put into train mode
-    >>> model.train()
-    >>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
+    >>> # TODO: Add full pretraining example
     ```"""

     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1431,6 +1417,8 @@ class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel):
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=CausalLMOutput,
         config_class=_CONFIG_FOR_DOC,
+        expected_output=_CTC_EXPECTED_OUTPUT,
+        expected_loss=_CTC_EXPECTED_LOSS,
     )
     def forward(
         self,
@@ -1561,6 +1549,8 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
         output_type=SequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
+        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
     def forward(
         self,
@@ -1677,6 +1667,7 @@ class UniSpeechSatForAudioFrameClassification(UniSpeechSatPreTrainedModel):
         output_type=TokenClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
+        expected_output=_FRAME_EXPECTED_OUTPUT,
     )
     def forward(
         self,
@@ -1853,6 +1844,7 @@ class UniSpeechSatForXVector(UniSpeechSatPreTrainedModel):
         output_type=XVectorOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
+        expected_output=_XVECTOR_EXPECTED_OUTPUT,
     )
     def forward(
         self,
src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -48,17 +48,35 @@ from .configuration_wav2vec2 import Wav2Vec2Config

 logger = logging.get_logger(__name__)

-_CONFIG_FOR_DOC = "Wav2Vec2Config"
-_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h"
-_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
-_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
-
-_SEQ_CLASS_CHECKPOINT = "superb/wav2vec2-base-superb-ks"
-_FRAME_CLASS_CHECKPOINT = "superb/wav2vec2-base-superb-sd"
-_XVECTOR_CHECKPOINT = "superb/wav2vec2-base-superb-sv"
-
 _HIDDEN_STATES_START_POSITION = 2

+# General docstring
+_CONFIG_FOR_DOC = "Wav2Vec2Config"
+_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h"
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
+_CTC_EXPECTED_LOSS = 53.48
+
+# Audio class docstring
+_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
+_SEQ_CLASS_CHECKPOINT = "superb/wav2vec2-base-superb-ks"
+_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
+_SEQ_CLASS_EXPECTED_LOSS = 6.54
+
+# Frame class docstring
+_FRAME_CLASS_CHECKPOINT = "anton-l/wav2vec2-base-superb-sd"
+_FRAME_EXPECTED_OUTPUT = [0, 0]
+
+# Speaker Verification docstring
+_XVECTOR_CHECKPOINT = "anton-l/wav2vec2-base-superb-sv"
+_XVECTOR_EXPECTED_OUTPUT = 0.98


 WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "facebook/wav2vec2-base-960h",
@@ -1294,6 +1312,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
         output_type=Wav2Vec2BaseModelOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
     )
     def forward(
         self,
@@ -1469,10 +1488,11 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
     >>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)

     >>> # show that cosine similarity is much higher than random
-    >>> assert cosine_sim[mask_time_indices].mean() > 0.5
+    >>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5
+    tensor(True)

     >>> # for contrastive loss training model should be put into train mode
-    >>> model.train()
+    >>> model = model.train()
     >>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
     ```"""
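The replacement of `assert` with a bare comparison is deliberate doctest style: an expression's printed repr (`tensor(True)`) both documents the expectation and lets pytest verify it, whereas a passing `assert` produces no output to check. In miniature:

```python
import torch

cosine_sim = torch.tensor([0.9, 0.8, 0.95])
print(cosine_sim.mean() > 0.5)  # tensor(True) - visible to doctest, unlike a silent assert
```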
@@ -1697,6 +1717,8 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=CausalLMOutput,
         config_class=_CONFIG_FOR_DOC,
+        expected_output=_CTC_EXPECTED_OUTPUT,
+        expected_loss=_CTC_EXPECTED_LOSS,
     )
     def forward(
         self,
@@ -1826,6 +1848,8 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
         output_type=SequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
+        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
     def forward(
         self,
@@ -1941,6 +1965,7 @@ class Wav2Vec2ForAudioFrameClassification(Wav2Vec2PreTrainedModel):
         output_type=TokenClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
+        expected_output=_FRAME_EXPECTED_OUTPUT,
    )
     def forward(
         self,
@@ -2114,6 +2139,7 @@ class Wav2Vec2ForXVector(Wav2Vec2PreTrainedModel):
         output_type=XVectorOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
+        expected_output=_XVECTOR_EXPECTED_OUTPUT,
     )
     def forward(
         self,
src/transformers/models/wavlm/modeling_wavlm.py
@@ -42,19 +42,35 @@ from .configuration_wavlm import WavLMConfig

 logger = logging.get_logger(__name__)

-_CONFIG_FOR_DOC = "WavLMConfig"
-_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
-_CHECKPOINT_FOR_DOC = "patrickvonplaten/wavlm-libri-clean-100h-base-plus"
-
-_SEQ_CLASS_CHECKPOINT = "microsoft/wavlm-base"
-_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
-
-_SEQ_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus"
-_FRAME_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus-sd"
-_XVECTOR_CHECKPOINT = "microsoft/wavlm-base-plus-sv"
-
 _HIDDEN_STATES_START_POSITION = 2

+# General docstring
+_CONFIG_FOR_DOC = "WavLMConfig"
+_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "patrickvonplaten/wavlm-libri-clean-100h-base-plus"
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'mister quilter is the aposle of the middle classes and we are glad to welcome his gospel'"
+_CTC_EXPECTED_LOSS = 12.51
+
+# Audio class docstring
+_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
+_SEQ_CLASS_CHECKPOINT = "hf-internal-testing/tiny-random-wavlm"
+_SEQ_CLASS_EXPECTED_OUTPUT = "'no'"  # TODO(anton) - could you quickly fine-tune a KS WavLM Model
+_SEQ_CLASS_EXPECTED_LOSS = 0.7  # TODO(anton) - could you quickly fine-tune a KS WavLM Model
+
+# Frame class docstring
+_FRAME_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus-sd"
+_FRAME_EXPECTED_OUTPUT = [0, 0]
+
+# Speaker Verification docstring
+_XVECTOR_CHECKPOINT = "microsoft/wavlm-base-plus-sv"
+_XVECTOR_EXPECTED_OUTPUT = 0.97

 WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "microsoft/wavlm-base",
     "microsoft/wavlm-base-plus",
@@ -1247,6 +1263,7 @@ class WavLMModel(WavLMPreTrainedModel):
         output_type=WavLMBaseModelOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
     )
     def forward(
         self,
@@ -1350,6 +1367,8 @@ class WavLMForCTC(WavLMPreTrainedModel):
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=CausalLMOutput,
         config_class=_CONFIG_FOR_DOC,
+        expected_output=_CTC_EXPECTED_OUTPUT,
+        expected_loss=_CTC_EXPECTED_LOSS,
     )
     def forward(
         self,
@@ -1480,6 +1499,8 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel):
         output_type=SequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
+        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
     def forward(
         self,
@@ -1596,6 +1617,7 @@ class WavLMForAudioFrameClassification(WavLMPreTrainedModel):
         output_type=TokenClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
+        expected_output=_FRAME_EXPECTED_OUTPUT,
     )
     def forward(
         self,
@@ -1772,6 +1794,7 @@ class WavLMForXVector(WavLMPreTrainedModel):
         output_type=XVectorOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
+        expected_output=_XVECTOR_EXPECTED_OUTPUT,
     )
     def forward(
         self,
utils/documentation_tests.txt
@@ -1,2 +1,9 @@
 docs/source/quicktour.rst
 docs/source/task_summary.rst
+src/transformers/models/wav2vec2/modeling_wav2vec2.py
+src/transformers/models/hubert/modeling_hubert.py
+src/transformers/models/wavlm/modeling_wavlm.py
+src/transformers/models/unispeech/modeling_unispeech.py
+src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
+src/transformers/models/sew/modeling_sew.py
+src/transformers/models/sew_d/modeling_sew_d.py
145 utils/prepare_for_doc_test.py Normal file
@@ -0,0 +1,145 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Style utils to preprocess files for doc tests.

The doc preprocessing function can be run on a list of files and/or
directories of files. It will recursively check if the files have
a Python code snippet by looking for a ```python or ```py syntax.
In the default mode - `remove_new_line==False` - the script will
add a new line before every python code ending ``` line to make
the docstrings ready for pytest doctests.
However, we don't want to have empty lines displayed in the
official documentation, which is why the added new line can be
reverted with the flag `--remove_new_line`, which sets
`remove_new_line==True`.

When debugging the doc tests locally, please make sure to
always run:

```python utils/prepare_for_doc_test.py src docs```

before running the doc tests:

```pytest --doctest-modules $(cat utils/documentation_tests.txt) -sv --doctest-continue-on-failure --doctest-glob="*.mdx"```

Afterwards you should revert the changes by running

```python utils/prepare_for_doc_test.py src docs --remove_new_line```
"""

import argparse
import os


def process_code_block(code, add_new_line=True):
    if add_new_line:
        return maybe_append_new_line(code)
    else:
        return maybe_remove_new_line(code)


def maybe_append_new_line(code):
    """
    Append a new line if the code snippet is a Python code snippet.
    """
    lines = code.split("\n")

    if lines[0] in ["py", "python"]:
        # add new line before last line being ```
        last_line = lines[-1]
        lines.pop()
        lines.append("\n" + last_line)

    return "\n".join(lines)


def maybe_remove_new_line(code):
    """
    Remove the new line if the code snippet is a Python code snippet.
    """
    lines = code.split("\n")

    if lines[0] in ["py", "python"]:
        # remove new line before last line being ```
        lines = lines[:-2] + lines[-1:]

    return "\n".join(lines)


def process_doc_file(code_file, add_new_line=True):
    """
    Process given file.

    Args:
        code_file (`str` or `os.PathLike`): The file in which we want to style the docstring.
    """
    with open(code_file, "r", encoding="utf-8", newline="\n") as f:
        code = f.read()

    # fmt: off
    splits = code.split("```")
    splits = [s if i % 2 == 0 else process_code_block(s, add_new_line=add_new_line) for i, s in enumerate(splits)]
    clean_code = "```".join(splits)
    # fmt: on

    diff = clean_code != code
    if diff:
        print(f"Overwriting content of {code_file}.")
        with open(code_file, "w", encoding="utf-8", newline="\n") as f:
            f.write(clean_code)


def process_doc_files(*files, add_new_line=True):
    """
    Applies doc styling to a list of files or directories, recursing into folders.

    Args:
        files (several `str` or `os.PathLike`): The files or directories to treat.
        add_new_line (`bool`): Whether to add a new line before each closing ``` (or, if `False`, remove it again).
    """
    for file in files:
        # Treat folders
        if os.path.isdir(file):
            files = [os.path.join(file, f) for f in os.listdir(file)]
            files = [f for f in files if os.path.isdir(f) or f.endswith(".mdx") or f.endswith(".py")]
            process_doc_files(*files, add_new_line=add_new_line)
        else:
            try:
                process_doc_file(file, add_new_line=add_new_line)
            except Exception:
                print(f"There is a problem in {file}.")
                raise


def main(*files, add_new_line=True):
    process_doc_files(*files, add_new_line=add_new_line)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("files", nargs="+", help="The file(s) or folder(s) to restyle.")
    parser.add_argument(
        "--remove_new_line",
        action="store_true",
        help="Whether to remove the new line after each python code block instead of adding one.",
    )
    args = parser.parse_args()

    main(*args.files, add_new_line=not args.remove_new_line)
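A round-trip sanity check for the two helpers above (an illustration, assuming the module is importable as `utils.prepare_for_doc_test` from the repo root): a Python segment between ``` fences gains a blank line before the closing fence, and `--remove_new_line` restores the original.

```python
from utils.prepare_for_doc_test import maybe_append_new_line, maybe_remove_new_line

# the segment between the ``` fences, as produced by code.split("```")
snippet = "python\n>>> 1 + 1\n2\n"

padded = maybe_append_new_line(snippet)
print(repr(padded))  # 'python\n>>> 1 + 1\n2\n\n' - blank line before the closing fence

assert maybe_remove_new_line(padded) == snippet
```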