Copied from for test files (#26713)

* Add `Copied from` statements for test files

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Yih-Dar, 2023-10-11 14:12:09 +02:00, committed by GitHub
commit 5334796d20 (parent 9f40639292)
14 changed files with 127 additions and 45 deletions
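
For context, the `# Copied from` mechanism lets `utils/check_copies.py` verify that duplicated code stays in sync with the code it was copied from. This commit extends the checker to test files: a test helper can now declare its origin, plus optional `Old->New` rename patterns, with a statement like the one below (abridged from the Mistral diff in this commit):

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Mistral
    def create_and_check_model(self, config, input_ids, ...):
        ...

The per-model diffs below add such statements; the final file, utils/check_copies.py, teaches the checker to resolve the `tests.` prefix and to include `tests/models` in its scan.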

View File

@@ -386,7 +386,7 @@ class BioGptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
model = BioGptModel.from_pretrained(model_name)
self.assertIsNotNone(model)
# Copied from tests.models.opt.test_modeling_opt.OPTModelTest with OPT->BioGpt, prepare_config_and_inputs-> prepare_config_and_inputs_for_common
# Copied from tests.models.opt.test_modeling_opt.OPTModelTest.test_opt_sequence_classification_model with OPT->BioGpt,opt->biogpt,prepare_config_and_inputs->prepare_config_and_inputs_for_common
def test_biogpt_sequence_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
@@ -399,7 +399,7 @@ class BioGptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
# Copied from tests.models.opt.test_modeling_opt.OPTModelTest with OPT->BioGpt, prepare_config_and_inputs-> prepare_config_and_inputs_for_common
# Copied from tests.models.opt.test_modeling_opt.OPTModelTest.test_opt_sequence_classification_model_for_multi_label with OPT->BioGpt,opt->biogpt,prepare_config_and_inputs->prepare_config_and_inputs_for_common
def test_biogpt_sequence_classification_model_for_multi_label(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3

View File

@@ -19,6 +19,7 @@ import random
import unittest
import numpy as np
from datasets import load_dataset
from transformers import ClapFeatureExtractor
from transformers.testing_utils import require_torch, require_torchaudio
@@ -110,10 +111,10 @@ class ClapFeatureExtractionTester(unittest.TestCase):
@require_torch
@require_torchaudio
# Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest with Whisper->Clap
class ClapFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
feature_extraction_class = ClapFeatureExtractor
# Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest.setUp with Whisper->Clap
def setUp(self):
self.feat_extract_tester = ClapFeatureExtractionTester(self)
@@ -147,6 +148,7 @@ class ClapFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Tes
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
# Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest.test_double_precision_pad
def test_double_precision_pad(self):
import torch
@@ -160,9 +162,8 @@ class ClapFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Tes
pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
self.assertTrue(pt_processed.input_features.dtype == torch.float32)
# Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest._load_datasamples
def _load_datasamples(self, num_samples):
from datasets import load_dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]

View File

@@ -341,7 +341,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
@unittest.skip("LLaMA buffers include complex numbers, which breaks this test")
@unittest.skip("Llama buffers include complex numbers, which breaks this test")
def test_save_load_fast_init_from_base(self):
pass

View File

@@ -27,7 +27,6 @@ from transformers.testing_utils import require_tokenizers, slow
from ...test_tokenization_common import TokenizerTesterMixin
# Copied from transformers.tests.roberta.test_modeling_roberta.py with Roberta->Longformer
@require_tokenizers
class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = LongformerTokenizer
@@ -72,19 +71,23 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
with open(self.merges_file, "w", encoding="utf-8") as fp:
fp.write("\n".join(merges))
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.get_tokenizer
def get_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.get_rust_tokenizer
def get_rust_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.get_input_output_texts
def get_input_output_texts(self, tokenizer):
input_text = "lower newer"
output_text = "lower newer"
return input_text, output_text
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_full_tokenizer
def test_full_tokenizer(self):
tokenizer = self.tokenizer_class(self.vocab_file, self.merges_file, **self.special_tokens_map)
text = "lower newer"
@@ -96,6 +99,7 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
input_bpe_tokens = [0, 1, 2, 15, 10, 9, 3, 2, 15, 19]
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.roberta_dict_integration_testing with roberta->longformer
def longformer_dict_integration_testing(self):
tokenizer = self.get_tokenizer()
@@ -106,6 +110,7 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
)
@slow
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_sequence_builders with roberta-base->allenai/longformer-base-4096
def test_sequence_builders(self):
tokenizer = self.tokenizer_class.from_pretrained("allenai/longformer-base-4096")
@@ -125,6 +130,7 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
assert encoded_sentence == encoded_text_from_decode
assert encoded_pair == encoded_pair_from_decode
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_space_encoding
def test_space_encoding(self):
tokenizer = self.get_tokenizer()
@@ -165,9 +171,11 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
first_char = tokenizer.convert_ids_to_tokens(encoded[mask_loc + 1])[0]
self.assertNotEqual(first_char, space_encoding)
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_pretokenized_inputs
def test_pretokenized_inputs(self):
pass
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_embeded_special_tokens
def test_embeded_special_tokens(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
@@ -200,6 +208,7 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokens_r_str, ["<s>", "A", ",", "<mask>", "ĠAllen", "N", "LP", "Ġsentence", ".", "</s>"]
)
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_change_add_prefix_space_and_trim_offsets_args
def test_change_add_prefix_space_and_trim_offsets_args(self):
for trim_offsets, add_prefix_space in itertools.product([True, False], repeat=2):
tokenizer_r = self.rust_tokenizer_class.from_pretrained(
@@ -214,6 +223,7 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertEqual(post_processor_state["add_prefix_space"], add_prefix_space)
self.assertEqual(post_processor_state["trim_offsets"], trim_offsets)
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_offsets_mapping_with_different_add_prefix_space_and_trim_space_arguments
def test_offsets_mapping_with_different_add_prefix_space_and_trim_space_arguments(self):
# Test which aims to verify that the offsets are well adapted to the argument `add_prefix_space` and
# `trim_offsets`

View File

@@ -39,7 +39,6 @@ if is_torch_available():
)
# Copied from transformers.tests.mistral.test_modelling_mistral.MistralModelTest with Llama->Mistral
class MistralModelTester:
def __init__(
self,
@@ -93,6 +92,7 @@ class MistralModelTester:
self.pad_token_id = pad_token_id
self.scope = scope
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -134,6 +134,7 @@ class MistralModelTester:
pad_token_id=self.pad_token_id,
)
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Mistral
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
@@ -144,6 +145,7 @@ class MistralModelTester:
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model_as_decoder with Llama->Mistral
def create_and_check_model_as_decoder(
self,
config,
@@ -174,6 +176,7 @@ class MistralModelTester:
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->Mistral
def create_and_check_for_causal_lm(
self,
config,
@@ -192,6 +195,7 @@ class MistralModelTester:
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_decoder_model_past_large_inputs with Llama->Mistral
def create_and_check_decoder_model_past_large_inputs(
self,
config,
@@ -254,6 +258,7 @@ class MistralModelTester:
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(

View File

@@ -32,7 +32,6 @@ from transformers.testing_utils import require_tokenizers, slow
from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english
# Copied from transformers.tests.models.bert.test_modeling_bert.py with Bert->MobileBert and pathfix
@require_tokenizers
class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = MobileBertTokenizer
@@ -71,11 +70,13 @@ class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
for tokenizer_def in self.tokenizers_list
]
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.get_input_output_texts
def get_input_output_texts(self, tokenizer):
input_text = "UNwant\u00E9d,running"
output_text = "unwanted, running"
return input_text, output_text
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_full_tokenizer
def test_full_tokenizer(self):
tokenizer = self.tokenizer_class(self.vocab_file)
@@ -83,6 +84,7 @@ class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [9, 6, 7, 12, 10, 11])
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_rust_and_python_full_tokenizers
def test_rust_and_python_full_tokenizers(self):
if not self.test_rust_tokenizer:
return
@@ -124,11 +126,13 @@ class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
rust_ids = rust_tokenizer.encode(sequence)
self.assertListEqual(ids, rust_ids)
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_chinese
def test_chinese(self):
tokenizer = BasicTokenizer()
self.assertListEqual(tokenizer.tokenize("ah\u535A\u63A8zz"), ["ah", "\u535A", "\u63A8", "zz"])
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower
def test_basic_tokenizer_lower(self):
tokenizer = BasicTokenizer(do_lower_case=True)
@@ -137,6 +141,7 @@ class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower_strip_accents_false
def test_basic_tokenizer_lower_strip_accents_false(self):
tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=False)
@@ -145,6 +150,7 @@ class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["h\u00E9llo"])
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower_strip_accents_true
def test_basic_tokenizer_lower_strip_accents_true(self):
tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=True)
@@ -153,6 +159,7 @@ class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower_strip_accents_default
def test_basic_tokenizer_lower_strip_accents_default(self):
tokenizer = BasicTokenizer(do_lower_case=True)
@@ -161,6 +168,7 @@ class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_no_lower
def test_basic_tokenizer_no_lower(self):
tokenizer = BasicTokenizer(do_lower_case=False)
@@ -168,6 +176,7 @@ class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]
)
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_no_lower_strip_accents_false
def test_basic_tokenizer_no_lower_strip_accents_false(self):
tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=False)
@@ -175,6 +184,7 @@ class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["HäLLo", "!", "how", "Are", "yoU", "?"]
)
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_no_lower_strip_accents_true
def test_basic_tokenizer_no_lower_strip_accents_true(self):
tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=True)
@@ -182,6 +192,7 @@ class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["HaLLo", "!", "how", "Are", "yoU", "?"]
)
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_respects_never_split_tokens
def test_basic_tokenizer_respects_never_split_tokens(self):
tokenizer = BasicTokenizer(do_lower_case=False, never_split=["[UNK]"])
@@ -189,6 +200,7 @@ class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]"), ["HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"]
)
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_wordpiece_tokenizer
def test_wordpiece_tokenizer(self):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]
@@ -203,6 +215,7 @@ class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertListEqual(tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_is_whitespace
def test_is_whitespace(self):
self.assertTrue(_is_whitespace(" "))
self.assertTrue(_is_whitespace("\t"))
@@ -213,6 +226,7 @@ class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertFalse(_is_whitespace("A"))
self.assertFalse(_is_whitespace("-"))
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_is_control
def test_is_control(self):
self.assertTrue(_is_control("\u0005"))
@@ -221,6 +235,7 @@ class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertFalse(_is_control("\t"))
self.assertFalse(_is_control("\r"))
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_is_punctuation
def test_is_punctuation(self):
self.assertTrue(_is_punctuation("-"))
self.assertTrue(_is_punctuation("$"))
@@ -230,6 +245,7 @@ class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertFalse(_is_punctuation("A"))
self.assertFalse(_is_punctuation(" "))
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_clean_text
def test_clean_text(self):
tokenizer = self.get_tokenizer()
rust_tokenizer = self.get_rust_tokenizer()
@@ -242,6 +258,7 @@ class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
)
@slow
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_sequence_builders with bert-base-uncased->google/mobilebert-uncased
def test_sequence_builders(self):
tokenizer = self.tokenizer_class.from_pretrained("google/mobilebert-uncased")
@@ -254,6 +271,7 @@ class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
assert encoded_sentence == [101] + text + [102]
assert encoded_pair == [101] + text + [102] + text_2 + [102]
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_offsets_with_special_characters
def test_offsets_with_special_characters(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
@@ -306,6 +324,7 @@ class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
)
self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"])
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_change_tokenize_chinese_chars
def test_change_tokenize_chinese_chars(self):
list_of_commun_chinese_char = ["的", "人", "有"]
text_with_chinese_char = "".join(list_of_commun_chinese_char)

View File

@@ -39,7 +39,7 @@ if is_torch_available():
)
# Copied from transformers.tests.llama.test_modelling_llama.LlamaModelTest with Llama->Persimmon
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester with Llama->Persimmon
class PersimmonModelTester:
def __init__(
self,
@@ -266,7 +266,6 @@ class PersimmonModelTester:
return config, inputs_dict
# Copied from transformers.tests.llama.test_modelling_llama.LlamaModelTest with Llama->Persimmon
@require_torch
class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
@@ -288,23 +287,28 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
test_headmasking = False
test_pruning = False
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.setUp with Llama->Persimmon
def setUp(self):
self.model_tester = PersimmonModelTester(self)
self.config_tester = ConfigTester(self, config_class=PersimmonConfig, hidden_size=37)
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_config
def test_config(self):
self.config_tester.run_common_tests()
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_various_embeddings
def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model with Llama->Persimmon,llama->persimmon
def test_persimmon_sequence_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
@@ -317,6 +321,7 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_single_label with Llama->Persimmon,llama->persimmon
def test_persimmon_sequence_classification_model_for_single_label(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
@@ -330,6 +335,7 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_multi_label with Llama->Persimmon,llama->persimmon
def test_persimmon_sequence_classification_model_for_multi_label(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
@@ -346,10 +352,12 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
@unittest.skip("Persimmon buffers include complex numbers, which breaks this test")
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_save_load_fast_init_from_base
def test_save_load_fast_init_from_base(self):
pass
@parameterized.expand([("linear",), ("dynamic",)])
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling with Llama->Persimmon
def test_model_rope_scaling(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size)

View File

@@ -76,8 +76,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def get_rust_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return RobertaTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
return RobertaTokenizerFast(self.vocab_file, self.merges_file, **kwargs)
return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self, tokenizer):
input_text = "lower newer"

View File

@@ -36,7 +36,7 @@ if is_flax_available():
)
# Copied from tests.models.roberta.test_modelling_flax_roberta.FlaxRobertaModelTester with Roberta->RobertaPreLayerNorm
# Copied from tests.models.roberta.test_modeling_flax_roberta.FlaxRobertaModelTester with Roberta->RobertaPreLayerNorm
class FlaxRobertaPreLayerNormModelTester(unittest.TestCase):
def __init__(
self,
@@ -134,7 +134,7 @@ class FlaxRobertaPreLayerNormModelTester(unittest.TestCase):
@require_flax
# Copied from tests.models.roberta.test_modelling_flax_roberta.FlaxRobertaPreLayerNormModelTest with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta-base->andreasmadsen/efficient_mlm_m0.40
# Copied from tests.models.roberta.test_modeling_flax_roberta.FlaxRobertaModelTest with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta-base->andreasmadsen/efficient_mlm_m0.40
class FlaxRobertaPreLayerNormModelTest(FlaxModelTesterMixin, unittest.TestCase):
test_head_masking = True

View File

@@ -44,7 +44,7 @@ if is_torch_available():
)
# Copied from tests.models.roberta.test_modelling_roberta.RobertaModelTester with Roberta->RobertaPreLayerNorm
# Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTester with Roberta->RobertaPreLayerNorm
class RobertaPreLayerNormModelTester:
def __init__(
self,
@@ -365,7 +365,6 @@ class RobertaPreLayerNormModelTester:
@require_torch
# Copied from tests.models.roberta.test_modelling_roberta.RobertaPreLayerNormModelTest with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm
class RobertaPreLayerNormModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(
@@ -397,27 +396,33 @@ class RobertaPreLayerNormModelTest(ModelTesterMixin, GenerationTesterMixin, Pipe
fx_compatible = False
model_split_percents = [0.5, 0.8, 0.9]
# Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.setUp with Roberta->RobertaPreLayerNorm
def setUp(self):
self.model_tester = RobertaPreLayerNormModelTester(self)
self.config_tester = ConfigTester(self, config_class=RobertaPreLayerNormConfig, hidden_size=37)
# Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_config
def test_config(self):
self.config_tester.run_common_tests()
# Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_model
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
# Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_model_various_embeddings
def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
# Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_model_as_decoder
def test_model_as_decoder(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
# Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_model_as_decoder_with_default_input_mask
def test_model_as_decoder_with_default_input_mask(self):
# This regression test was failing with PyTorch < 1.3
(
@@ -446,42 +451,50 @@ class RobertaPreLayerNormModelTest(ModelTesterMixin, GenerationTesterMixin, Pipe
encoder_attention_mask,
)
# Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_for_causal_lm
def test_for_causal_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
# Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_decoder_model_past_with_large_inputs
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
# Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_for_masked_lm
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
# Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_for_token_classification
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
# Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_for_multiple_choice
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
# Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_for_question_answering
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
@slow
# Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_model_from_pretrained with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm
def test_model_from_pretrained(self):
for model_name in ROBERTA_PRELAYERNORM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = RobertaPreLayerNormModel.from_pretrained(model_name)
self.assertIsNotNone(model)
# Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_create_position_ids_respects_padding_index with Roberta->RobertaPreLayerNorm
def test_create_position_ids_respects_padding_index(self):
"""Ensure that the default position ids only assign a sequential . This is a regression
test for https://github.com/huggingface/transformers/issues/1761
The position ids should be masked with the embedding object's padding index. Therefore, the
first available non-padding position index is RobertaPreLayerNormEmbeddings.padding_idx + 1
The position ids should be masked with the embedding object's padding index. Therefore, the first available
non-padding position index is RobertaPreLayerNormEmbeddings.padding_idx + 1
"""
config = self.model_tester.prepare_config_and_inputs()[0]
model = RobertaPreLayerNormEmbeddings(config=config)
@@ -495,12 +508,13 @@ class RobertaPreLayerNormModelTest(ModelTesterMixin, GenerationTesterMixin, Pipe
self.assertEqual(position_ids.shape, expected_positions.shape)
self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))
# Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_create_position_ids_from_inputs_embeds with Roberta->RobertaPreLayerNorm
def test_create_position_ids_from_inputs_embeds(self):
"""Ensure that the default position ids only assign a sequential . This is a regression
test for https://github.com/huggingface/transformers/issues/1761
The position ids should be masked with the embedding object's padding index. Therefore, the
first available non-padding position index is RobertaPreLayerNormEmbeddings.padding_idx + 1
The position ids should be masked with the embedding object's padding index. Therefore, the first available
non-padding position index is RobertaPreLayerNormEmbeddings.padding_idx + 1
"""
config = self.model_tester.prepare_config_and_inputs()[0]
embeddings = RobertaPreLayerNormEmbeddings(config=config)

View File

@@ -42,7 +42,7 @@ if is_tf_available():
)
# Copied from tests.models.roberta.test_modelling_tf_roberta.TFRobertaModelTester with Roberta->RobertaPreLayerNorm
# Copied from tests.models.roberta.test_modeling_tf_roberta.TFRobertaModelTester with Roberta->RobertaPreLayerNorm
class TFRobertaPreLayerNormModelTester:
def __init__(
self,
@@ -551,7 +551,7 @@ class TFRobertaPreLayerNormModelTester:
@require_tf
# Copied from tests.models.roberta.test_modelling_tf_roberta.TFRobertaPreLayerNormModelTest with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm
# Copied from tests.models.roberta.test_modeling_tf_roberta.TFRobertaModelTest with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm
class TFRobertaPreLayerNormModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(

View File

@@ -68,13 +68,13 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertListEqual(tokenizer.convert_tokens_to_shape_ids(tokens), [5, 6, 2, 5, 7, 8])
self.assertListEqual(tokenizer.convert_tokens_to_pronunciation_ids(tokens), [5, 6, 2, 5, 7, 8])
# Copied from tests.models.bert.test_tokenization_bert.test_chinese with BasicTokenizer->RoCBertBertBasicTokenizer
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_chinese with BasicTokenizer->RoCBertBasicTokenizer
def test_chinese(self):
tokenizer = RoCBertBasicTokenizer()
self.assertListEqual(tokenizer.tokenize("ah\u535A\u63A8zz"), ["ah", "\u535A", "\u63A8", "zz"])
# Copied from tests.models.bert.test_tokenization_bert.test_basic_tokenizer_lower with BasicTokenizer->RoCBertBertBasicTokenizer
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower with BasicTokenizer->RoCBertBasicTokenizer
def test_basic_tokenizer_lower(self):
tokenizer = RoCBertBasicTokenizer(do_lower_case=True)
@@ -83,7 +83,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
# Copied from tests.models.bert.test_tokenization_bert.test_basic_tokenizer_lower_strip_accents_false with BasicTokenizer->RoCBertBertBasicTokenizer
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower_strip_accents_false with BasicTokenizer->RoCBertBasicTokenizer
def test_basic_tokenizer_lower_strip_accents_false(self):
tokenizer = RoCBertBasicTokenizer(do_lower_case=True, strip_accents=False)
@@ -92,7 +92,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["h\u00E9llo"])
# Copied from tests.models.bert.test_tokenization_bert.test_basic_tokenizer_lower_strip_accents_true with BertBasicTokenizer->RoCBertBertBasicTokenizer
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower_strip_accents_true with BasicTokenizer->RoCBertBasicTokenizer
def test_basic_tokenizer_lower_strip_accents_true(self):
tokenizer = RoCBertBasicTokenizer(do_lower_case=True, strip_accents=True)
@@ -101,7 +101,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
# Copied from tests.models.bert.test_tokenization_bert.test_basic_tokenizer_lower_strip_accents_default with BasicTokenizer->RoCBertBertBasicTokenizer
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower_strip_accents_default with BasicTokenizer->RoCBertBasicTokenizer
def test_basic_tokenizer_lower_strip_accents_default(self):
tokenizer = RoCBertBasicTokenizer(do_lower_case=True)
@@ -110,7 +110,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
# Copied from tests.models.bert.test_tokenization_bert.test_basic_tokenizer_no_lower with BasicTokenizer->RoCBertBertBasicTokenizer
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_no_lower with BasicTokenizer->RoCBertBasicTokenizer
def test_basic_tokenizer_no_lower(self):
tokenizer = RoCBertBasicTokenizer(do_lower_case=False)
@@ -118,7 +118,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]
)
# Copied from tests.models.bert.test_tokenization_bert.test_basic_tokenizer_no_lower_strip_accents_false with BertBasicTokenizer->RoCBertBertBasicTokenizer
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_no_lower_strip_accents_false with BasicTokenizer->RoCBertBasicTokenizer
def test_basic_tokenizer_no_lower_strip_accents_false(self):
tokenizer = RoCBertBasicTokenizer(do_lower_case=False, strip_accents=False)
@@ -126,7 +126,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["HäLLo", "!", "how", "Are", "yoU", "?"]
)
# Copied from tests.models.bert.test_tokenization_bert.test_basic_tokenizer_no_lower_strip_accents_true with BasicTokenizer->RoCBertBertBasicTokenizer
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_no_lower_strip_accents_true with BasicTokenizer->RoCBertBasicTokenizer
def test_basic_tokenizer_no_lower_strip_accents_true(self):
tokenizer = RoCBertBasicTokenizer(do_lower_case=False, strip_accents=True)
@@ -134,7 +134,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["HaLLo", "!", "how", "Are", "yoU", "?"]
)
# Copied from tests.models.bert.test_tokenization_bert.test_basic_tokenizer_respects_never_split_tokens with BasicTokenizer->RoCBertBertBasicTokenizer
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_respects_never_split_tokens with BasicTokenizer->RoCBertBasicTokenizer
def test_basic_tokenizer_respects_never_split_tokens(self):
tokenizer = RoCBertBasicTokenizer(do_lower_case=False, never_split=["[UNK]"])
@@ -142,7 +142,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]"), ["HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"]
)
# Copied from tests.models.bert.test_tokenization_bert.test_wordpiece_tokenizer with WordpieceTokenizer->RoCBertWordpieceTokenizer
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_wordpiece_tokenizer with WordpieceTokenizer->RoCBertWordpieceTokenizer
def test_wordpiece_tokenizer(self):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]
@@ -157,7 +157,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertListEqual(tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
# Copied from tests.models.bert.test_tokenization_bert.test_is_whitespace
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_is_whitespace
def test_is_whitespace(self):
self.assertTrue(_is_whitespace(" "))
self.assertTrue(_is_whitespace("\t"))
@@ -168,7 +168,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertFalse(_is_whitespace("A"))
self.assertFalse(_is_whitespace("-"))
# Copied from tests.models.bert.test_tokenization_bert.test_is_control
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_is_control
def test_is_control(self):
self.assertTrue(_is_control("\u0005"))
@@ -177,7 +177,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertFalse(_is_control("\t"))
self.assertFalse(_is_control("\r"))
# Copied from tests.models.bert.test_tokenization_bert.test_is_punctuation
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_is_punctuation
def test_is_punctuation(self):
self.assertTrue(_is_punctuation("-"))
self.assertTrue(_is_punctuation("$"))
@@ -199,7 +199,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
[rust_tokenizer.tokenize(t) for t in ["Test", "\xad", "test"]], [["[UNK]"], [], ["[UNK]"]]
)
# Copied from tests.models.bert.test_tokenization_bert. test_offsets_with_special_characters
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_offsets_with_special_characters
def test_offsets_with_special_characters(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
@@ -252,7 +252,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
)
self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"])
# Copied from tests.models.bert.test_tokenization_bert. test_change_tokenize_chinese_chars
# Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_change_tokenize_chinese_chars
def test_change_tokenize_chinese_chars(self):
list_of_commun_chinese_char = ["的", "人", "有"]
text_with_chinese_char = "".join(list_of_commun_chinese_char)

View File

@@ -376,7 +376,7 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
def test_vocab_size(self):
self.assertEqual(self.tokenizer.vocab_size, 50257)
# Copied from transformers.tests.speech_to_test.test_tokenization_speech_to_text.py
# Copied from tests.models.speech_to_text.test_tokenization_speech_to_text.SpeechToTextTokenizerMultilinguialTest.test_tokenizer_decode_ignores_language_codes
def test_tokenizer_decode_ignores_language_codes(self):
self.assertIn(ES_CODE, self.tokenizer.all_special_ids)
generated_ids = [ES_CODE, 4, 1601, 47, 7647, 2]

View File

@@ -51,6 +51,7 @@ from transformers.utils import direct_transformers_import
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_copies.py
TRANSFORMERS_PATH = "src/transformers"
MODEL_TEST_PATH = "tests/models"
PATH_TO_DOCS = "docs/source/en"
REPO_PATH = "."
@@ -132,12 +133,15 @@ def _should_continue(line: str, indent: str) -> bool:
return line.startswith(indent) or len(line.strip()) == 0 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
def find_code_in_transformers(object_name: str) -> str:
def find_code_in_transformers(object_name: str, base_path: str = None) -> str:
"""
Find and return the source code of an object.
Args:
object_name (`str`): The name of the object we want the source code of.
object_name (`str`):
The name of the object we want the source code of.
base_path (`str`, *optional*):
The path to the base folder where files are checked. If not set, it will be set to `TRANSFORMERS_PATH`.
Returns:
`str`: The source code of the object.
@@ -145,9 +149,21 @@ def find_code_in_transformers(object_name: str) -> str:
parts = object_name.split(".")
i = 0
# We can't set this as the default value in the argument, otherwise `CopyCheckTester` will fail, as it uses a
# patched temp directory.
if base_path is None:
base_path = TRANSFORMERS_PATH
# Detail: the `Copied from` statement is originally designed to work with the last part of `TRANSFORMERS_PATH`,
# (which is `transformers`). The same should be applied for `MODEL_TEST_PATH`. However, its last part is `models`
# (to only check and search in it) which is a bit confusing. So we keep the copied statement starting with
# `tests.models.` and change it to `tests` here.
if base_path == MODEL_TEST_PATH:
base_path = "tests"
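# For example: with object_name = "models.llama.test_modeling_llama.LlamaModelTester" and
# base_path = "tests", the module-resolution loop below probes "tests/models.py", then
# "tests/models/llama.py", and finally finds "tests/models/llama/test_modeling_llama.py",
# from which the object's source is extracted.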
# First let's find the module where our object lives.
module = parts[i]
while i < len(parts) and not os.path.isfile(os.path.join(TRANSFORMERS_PATH, f"{module}.py")):
while i < len(parts) and not os.path.isfile(os.path.join(base_path, f"{module}.py")):
i += 1
if i < len(parts):
module = os.path.join(module, parts[i])
@@ -156,7 +172,7 @@ def find_code_in_transformers(object_name: str) -> str:
f"`object_name` should begin with the name of a module of transformers but got {object_name}."
)
with open(os.path.join(TRANSFORMERS_PATH, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f:
with open(os.path.join(base_path, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
# Now let's find the class / func in the code!
@@ -186,6 +202,7 @@ def find_code_in_transformers(object_name: str) -> str:
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)")
_re_copy_warning_for_test_file = re.compile(r"^(\s*)#\s*Copied from\s+tests\.(\S+\.\S+)\s*($|\S.*$)")
_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
_re_fill_pattern = re.compile(r"<FILL\s+[^>]*>")
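# For instance, _re_copy_warning_for_test_file matches a line like
#   # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Mistral
# capturing the leading indent, the dotted object path after `tests.`, and the optional
# `with Old->New` replacement clause.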
@@ -284,14 +301,20 @@ def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[
line_index = 0
# Not a for loop cause `lines` is going to change (if `overwrite=True`).
while line_index < len(lines):
search = _re_copy_warning.search(lines[line_index])
search_re = _re_copy_warning
if filename.startswith("tests"):
search_re = _re_copy_warning_for_test_file
search = search_re.search(lines[line_index])
if search is None:
line_index += 1
continue
# There is some copied code here, let's retrieve the original.
indent, object_name, replace_pattern = search.groups()
theoretical_code = find_code_in_transformers(object_name)
base_path = TRANSFORMERS_PATH if not filename.startswith("tests") else MODEL_TEST_PATH
theoretical_code = find_code_in_transformers(object_name, base_path=base_path)
theoretical_indent = get_indent(theoretical_code)
start_index = line_index + 1 if indent == theoretical_indent else line_index
@@ -357,6 +380,9 @@ def check_copies(overwrite: bool = False):
Whether or not to overwrite the copies when they don't match.
"""
all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
all_test_files = glob.glob(os.path.join(MODEL_TEST_PATH, "**/*.py"), recursive=True)
all_files = list(all_files) + list(all_test_files)
diffs = []
for filename in all_files:
new_diffs = is_copy_consistent(filename, overwrite)
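
As before, the checker runs from the repository root with `python utils/check_copies.py`, and with this change it walks both `src/transformers` and `tests/models`. Passing the script's existing `--fix_and_overwrite` flag rewrites out-of-sync copies in place instead of only reporting them.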