Fix QA sample (#16648)

* fix QA sample

* For TF_QUESTION_ANSWERING_SAMPLE

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2022-04-08 15:31:43 +02:00 committed by GitHub
parent 9a24b97b7f
commit ab229663b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -207,7 +207,8 @@ PT_QUESTION_ANSWERING_SAMPLE = r"""
```python
>>> # target is "nice puppet"
>>> target_start_index, target_end_index = torch.tensor([14]), torch.tensor([15])
>>> target_start_index = torch.tensor([{qa_target_start_index}])
>>> target_end_index = torch.tensor([{qa_target_end_index}])
>>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
>>> loss = outputs.loss
@ -667,7 +668,8 @@ TF_QUESTION_ANSWERING_SAMPLE = r"""
```python
>>> # target is "nice puppet"
>>> target_start_index, target_end_index = tf.constant([14]), tf.constant([15])
>>> target_start_index = tf.constant([{qa_target_start_index}])
>>> target_end_index = tf.constant([{qa_target_end_index}])
>>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
>>> loss = tf.math.reduce_mean(outputs.loss)
@ -1054,6 +1056,8 @@ def add_code_sample_docstrings(
output_type=None,
config_class=None,
mask="[MASK]",
qa_target_start_index=14,
qa_target_end_index=15,
model_cls=None,
modality=None,
expected_output="",
@ -1078,6 +1082,8 @@ def add_code_sample_docstrings(
processor_class=processor_class,
checkpoint=checkpoint,
mask=mask,
qa_target_start_index=qa_target_start_index,
qa_target_end_index=qa_target_end_index,
expected_output=expected_output,
expected_loss=expected_loss,
)