Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-16 19:18:24 +06:00)
Fix tests imports dpr (#5576)
* fix test imports
* fix max_length
* style
* fix tests
This commit is contained in:
parent d4886173b2
commit 4fedc1256c
@@ -157,13 +157,13 @@ CUSTOM_DPR_READER_DOCSTRING = r"""
             The passages titles to be encoded. This can be a string, a list of strings if there are several passages.
         texts (:obj:`str`, :obj:`List[str]`):
             The passages texts to be encoded. This can be a string, a list of strings if there are several passages.
-        padding (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`True`):
+        padding (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`False`):
             Activate and control padding. Accepts the following values:

             * `True` or `'longest'`: pad to the longest sequence in the batch (or no padding if only a single sequence if provided),
             * `'max_length'`: pad to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`)
             * `False` or `'do_not_pad'` (default): No padding (i.e. can output batch with sequences of uneven lengths)
-        truncation (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`True`):
+        truncation (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`False`):
             Activate and control truncation. Accepts the following values:

             * `True` or `'only_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`).
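The docstring edit mirrors the new call defaults: `padding` and `truncation` now default to `False`, so callers get unpadded, untruncated sequences unless they ask otherwise. Below is a minimal sketch of what the change means for callers; it is not part of the diff, and the checkpoint name is only an example of a DPR reader checkpoint.

from transformers import DPRReaderTokenizer

tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")

question = "What is the capital of France?"
titles = ["Paris", "France"]
texts = ["Paris is the capital of France.", "France is a country in Europe."]

# New defaults: padding=False, truncation=False -> encoded passages may have uneven lengths.
unpadded = tokenizer(question, titles, texts)

# The old padded/truncated behaviour now has to be requested explicitly.
padded = tokenizer(question, titles, texts, padding="longest", truncation=True, max_length=512)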
@@ -203,15 +203,37 @@ class CustomDPRReaderTokenizerMixin:
     def __call__(
         self,
         questions,
-        titles,
-        texts,
-        padding: Union[bool, str] = True,
-        truncation: Union[bool, str] = True,
-        max_length: Optional[int] = 512,
+        titles: Optional[str] = None,
+        texts: Optional[str] = None,
+        padding: Union[bool, str] = False,
+        truncation: Union[bool, str] = False,
+        max_length: Optional[int] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         return_attention_mask: Optional[bool] = None,
         **kwargs
     ) -> BatchEncoding:
+        if titles is None and texts is None:
+            return super().__call__(
+                questions,
+                padding=padding,
+                truncation=truncation,
+                max_length=max_length,
+                return_tensors=return_tensors,
+                return_attention_mask=return_attention_mask,
+                **kwargs,
+            )
+        elif titles is None or texts is None:
+            text_pair = titles if texts is None else texts
+            return super().__call__(
+                questions,
+                text_pair,
+                padding=padding,
+                truncation=truncation,
+                max_length=max_length,
+                return_tensors=return_tensors,
+                return_attention_mask=return_attention_mask,
+                **kwargs,
+            )
         titles = titles if not isinstance(titles, str) else [titles]
         texts = texts if not isinstance(texts, str) else [texts]
         n_passages = len(titles)
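Beyond the new defaults, the added branches make `titles` and `texts` optional: with neither argument the call degenerates to a plain single-sequence encoding, and with only one of them the available field is passed as the text pair. The following hedged sketch shows the three call patterns this enables; the checkpoint name is illustrative only.

from transformers import DPRReaderTokenizer

tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")

# No titles/texts: falls back to the ordinary single-sequence __call__ of the base tokenizer.
question_only = tokenizer("What is the capital of France?")

# Only texts: the passage text is encoded as the text pair of the question.
question_and_text = tokenizer(
    "What is the capital of France?",
    texts="Paris is the capital and most populous city of France.",
)

# Titles and texts: builds the custom [CLS] question [SEP] title [SEP] text layout per passage.
full = tokenizer(
    "What is the capital of France?",
    titles=["Paris"],
    texts=["Paris is the capital and most populous city of France."],
)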
@@ -17,10 +17,10 @@
 import unittest

 from transformers import is_torch_available
+from transformers.testing_utils import require_torch, slow, torch_device

 from .test_configuration_common import ConfigTester
 from .test_modeling_common import ModelTesterMixin, ids_tensor
-from .utils import require_torch, slow, torch_device


 if is_torch_available():
@@ -14,6 +14,7 @@
 # limitations under the License.


+from transformers.testing_utils import slow
 from transformers.tokenization_dpr import (
     DPRContextEncoderTokenizer,
     DPRContextEncoderTokenizerFast,
@@ -26,7 +27,6 @@ from transformers.tokenization_dpr import (
 from transformers.tokenization_utils_base import BatchEncoding

 from .test_tokenization_bert import BertTokenizationTest
-from .utils import slow


 class DPRContextEncoderTokenizationTest(BertTokenizationTest):
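Both test-file hunks above make the same import move: the helpers `require_torch`, `slow`, and `torch_device` now come from `transformers.testing_utils` rather than the tests' local `.utils` module. Below is a minimal, self-contained sketch of a test using the relocated helpers; the class and test names are illustrative and not taken from the diff.

import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device


if is_torch_available():
    import torch


@require_torch
class ExampleDPRTest(unittest.TestCase):
    def test_tensor_on_device(self):
        # Only runs when torch is installed; torch_device resolves to "cuda" or "cpu".
        tensor = torch.ones(2, 2, device=torch_device)
        self.assertEqual(tuple(tensor.shape), (2, 2))

    @slow
    def test_marked_slow(self):
        # Skipped unless slow tests are enabled (RUN_SLOW=1 in the environment).
        self.assertTrue(True)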