This will reduce "Already borrowed" errors (#12550)

* This will reduce "Already borrowed" errors:

Original issue https://github.com/huggingface/tokenizers/issues/537

The original issue is caused by transformers calling mutable functions
on the Rust tokenizers many times.
Rust needs to guarantee that only one agent holds a mutable reference
to a piece of memory at any given time (for many reasons which don't
need explaining here). Usually, the Rust compiler can prove that this
property holds at compile time.

Unfortunately, Python cannot provide that guarantee at compile time, so
PyO3, the bridge between Rust and Python used by `tokenizers`, replaces
the compile-time guarantee with a dynamic one: if multiple agents try
to take mutable borrows at the same time, the runtime raises
"Already borrowed".
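
As a loose analogy (illustrative Python only, not how PyO3 actually
implements it), the dynamic guarantee behaves like a guard that fails
fast when a second mutable borrow is requested while one is still held:

    import threading

    class BorrowGuard:
        # Toy stand-in for a runtime (dynamic) mutable-borrow check.
        def __init__(self):
            self._lock = threading.Lock()

        def mutate(self, fn):
            # A second concurrent "mutable borrow" does not wait for the
            # first one to finish; it errors out immediately, which is
            # the behaviour behind the "Already borrowed" message.
            if not self._lock.acquire(blocking=False):
                raise RuntimeError("Already borrowed")
            try:
                return fn()
            finally:
                self._lock.release()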

The fix proposed here in transformers is simply to reduce the number of
calls that actually need a mutable borrow. By reducing them, we reduce
the risk of running into the "Already borrowed" error.
The caveat is that we now add a call to read the current configuration
of the `_tokenizer`, so in the worst case we make 2 calls instead of 1,
and in the best case we make 1 call plus a Python comparison of a dict
(which should be negligible).
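
Put differently, the pattern is: read the current backend state (which
needs no mutable borrow), compare it with the target configuration, and
only call the mutating setter when something actually changes. A
minimal sketch of that idea, with illustrative names (the real change
is in the diff below):

    def set_truncation_if_needed(backend, max_length, stride, strategy):
        # Reading the current configuration does not take a mutable borrow.
        current = backend.truncation
        target = {"max_length": max_length, "stride": stride, "strategy": strategy}
        # The call that needs a mutable borrow only happens when required.
        if current != target:
            backend.enable_truncation(**target)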

* Adding a test.

* trivial error :(.

* Update tests/test_tokenization_fast.py

Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>

* Adding reference to original issues in the tests.

* Update the tests with fast tokenizer.

Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>
Author: Nicolas Patry (committed by GitHub)
Date: 2021-07-09 09:36:05 +02:00
Parent: 8fe836af5a
Commit: cc12e1dbf6
2 changed files with 40 additions and 13 deletions

src/transformers/tokenization_utils_fast.py

@@ -338,23 +338,32 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
                 If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
                 the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
         """
+        _truncation = self._tokenizer.truncation
+        _padding = self._tokenizer.padding
         # Set truncation and padding on the backend tokenizer
-        if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
-            self._tokenizer.enable_truncation(max_length, stride=stride, strategy=truncation_strategy.value)
+        if truncation_strategy == TruncationStrategy.DO_NOT_TRUNCATE:
+            if _truncation is not None:
+                self._tokenizer.no_truncation()
         else:
-            self._tokenizer.no_truncation()
+            target = {"max_length": max_length, "stride": stride, "strategy": truncation_strategy.value}
+            if _truncation != target:
+                self._tokenizer.enable_truncation(**target)
 
-        if padding_strategy != PaddingStrategy.DO_NOT_PAD:
-            self._tokenizer.enable_padding(
-                length=max_length if padding_strategy == PaddingStrategy.MAX_LENGTH else None,
-                direction=self.padding_side,
-                pad_id=self.pad_token_id,
-                pad_type_id=self.pad_token_type_id,
-                pad_token=self.pad_token,
-                pad_to_multiple_of=pad_to_multiple_of,
-            )
+        if padding_strategy == PaddingStrategy.DO_NOT_PAD:
+            if _padding is not None:
+                self._tokenizer.no_padding()
         else:
-            self._tokenizer.no_padding()
+            length = max_length if padding_strategy == PaddingStrategy.MAX_LENGTH else None
+            target = {
+                "length": length,
+                "direction": self.padding_side,
+                "pad_id": self.pad_token_id,
+                "pad_token": self.pad_token,
+                "pad_type_id": self.pad_token_type_id,
+                "pad_to_multiple_of": pad_to_multiple_of,
+            }
+            if _padding != target:
+                self._tokenizer.enable_padding(**target)
 
     def _batch_encode_plus(
         self,
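
With this in place, repeated calls that reuse the same truncation/padding
arguments take the read-only path after the first call. A usage sketch
(assumes the `bert-base-uncased` checkpoint; any fast tokenizer would
behave the same way):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)

    # The first call may configure truncation/padding on the Rust backend
    # (mutable borrow); later identical calls only read the current state
    # and compare it to the target dict, so no mutable borrow is taken.
    for _ in range(3):
        tok("hello world", truncation=True, padding="max_length", max_length=8)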

tests/test_tokenization_fast.py

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import concurrent.futures
 import shutil
 import tempfile
 import unittest
@@ -90,3 +91,20 @@ class PreTrainedTokenizationFastTest(TokenizerTesterMixin, unittest.TestCase):
         # is restored
         shutil.rmtree(self.tmpdirname)
         self.tmpdirname = tmpdirname_orig
+
+
+@require_tokenizers
+class ReduceMutableBorrowTests(unittest.TestCase):
+    def test_async_share_tokenizer(self):
+        # See https://github.com/huggingface/transformers/pull/12550
+        # and https://github.com/huggingface/tokenizers/issues/537
+        tokenizer = PreTrainedTokenizerFast.from_pretrained("robot-test/dummy-tokenizer-wordlevel")
+        text = "The Matrix is a 1999 science fiction action film."
+
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            futures = [executor.submit(self.fetch, tokenizer, text) for i in range(10)]
+            return_value = [future.result() for future in futures]
+            self.assertEqual(return_value, [[1, 10, 0, 8, 0, 18, 0, 0, 0, 2] for i in range(10)])
+
+    def fetch(self, tokenizer, text):
+        return tokenizer.encode(text, truncation="longest_first", padding="longest")
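
The new test can be run on its own with a standard pytest invocation,
for example (the exact command is not part of this commit):

    python -m pytest tests/test_tokenization_fast.py -k test_async_share_tokenizer -v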