mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
TF: add test for PushToHubCallback
(#20231)
* test hub tf callback * create repo before cloning it
This commit is contained in:
parent
3a780cc57a
commit
2062c28552
@ -9,7 +9,7 @@ import tensorflow as tf
|
||||
from packaging.version import parse
|
||||
from tensorflow.keras.callbacks import Callback
|
||||
|
||||
from huggingface_hub import Repository
|
||||
from huggingface_hub import Repository, create_repo
|
||||
|
||||
from . import IntervalStrategy, PreTrainedTokenizerBase
|
||||
from .modelcard import TrainingSummary
|
||||
@ -339,11 +339,13 @@ class PushToHubCallback(Callback):
|
||||
|
||||
self.output_dir = output_dir
|
||||
self.hub_model_id = hub_model_id
|
||||
create_repo(self.hub_model_id, exist_ok=True)
|
||||
self.repo = Repository(
|
||||
str(self.output_dir),
|
||||
clone_from=self.hub_model_id,
|
||||
use_auth_token=hub_token if hub_token else True,
|
||||
)
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
self.last_job = None
|
||||
self.checkpoint = checkpoint
|
||||
@ -394,17 +396,22 @@ class PushToHubCallback(Callback):
|
||||
)
|
||||
|
||||
def on_train_end(self, logs=None):
|
||||
# Makes sure the latest version of the model is uploaded
|
||||
if self.last_job is not None and not self.last_job.is_done:
|
||||
self.last_job._process.terminate() # Gotta go fast
|
||||
logging.info("Pushing the last epoch to the Hub, this may take a while...")
|
||||
while not self.last_job.is_done:
|
||||
sleep(1)
|
||||
self.model.save_pretrained(self.output_dir)
|
||||
if self.tokenizer is not None:
|
||||
self.tokenizer.save_pretrained(self.output_dir)
|
||||
train_summary = TrainingSummary.from_keras(
|
||||
model=self.model, model_name=self.hub_model_id, keras_history=self.training_history, **self.model_card_args
|
||||
)
|
||||
model_card = train_summary.to_model_card()
|
||||
with (self.output_dir / "README.md").open("w") as f:
|
||||
f.write(model_card)
|
||||
self.repo.push_to_hub(commit_message="End of training", blocking=True)
|
||||
else:
|
||||
self.model.save_pretrained(self.output_dir)
|
||||
if self.tokenizer is not None:
|
||||
self.tokenizer.save_pretrained(self.output_dir)
|
||||
train_summary = TrainingSummary.from_keras(
|
||||
model=self.model,
|
||||
model_name=self.hub_model_id,
|
||||
keras_history=self.training_history,
|
||||
**self.model_card_args,
|
||||
)
|
||||
model_card = train_summary.to_model_card()
|
||||
with (self.output_dir / "README.md").open("w") as f:
|
||||
f.write(model_card)
|
||||
self.repo.push_to_hub(commit_message="End of training", blocking=True)
|
||||
|
@ -399,6 +399,7 @@ class BlenderbotTokenizer(PreTrainedTokenizer):
|
||||
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
||||
adding special tokens. A Blenderbot sequence has the following format:
|
||||
- single sequence: ` X </s>`
|
||||
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs to which the special tokens will be added
|
||||
|
@ -284,6 +284,7 @@ class BlenderbotTokenizerFast(PreTrainedTokenizerFast):
|
||||
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
||||
adding special tokens. A Blenderbot sequence has the following format:
|
||||
- single sequence: ` X </s>`
|
||||
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs to which the special tokens will be added
|
||||
|
@ -428,6 +428,7 @@ class MarkupLMTokenizer(PreTrainedTokenizer):
|
||||
adding special tokens. A RoBERTa sequence has the following format:
|
||||
- single sequence: `<s> X </s>`
|
||||
- pair of sequences: `<s> A </s></s> B </s>`
|
||||
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs to which the special tokens will be added.
|
||||
|
@ -883,6 +883,7 @@ class MarkupLMTokenizerFast(PreTrainedTokenizerFast):
|
||||
adding special tokens. A RoBERTa sequence has the following format:
|
||||
- single sequence: `<s> X </s>`
|
||||
- pair of sequences: `<s> A </s></s> B </s>`
|
||||
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs to which the special tokens will be added.
|
||||
|
@ -342,6 +342,7 @@ class TapexTokenizer(PreTrainedTokenizer):
|
||||
adding special tokens. A TAPEX sequence has the following format:
|
||||
- single sequence: `<s> X </s>`
|
||||
- pair of sequences: `<s> A </s></s> B </s>`
|
||||
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs to which the special tokens will be added.
|
||||
|
@ -78,9 +78,11 @@ if is_tf_available():
|
||||
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
BertConfig,
|
||||
PushToHubCallback,
|
||||
RagRetriever,
|
||||
TFAutoModel,
|
||||
TFAutoModelForSequenceClassification,
|
||||
TFBertForMaskedLM,
|
||||
TFBertModel,
|
||||
TFRagModel,
|
||||
TFSharedEmbeddings,
|
||||
@ -2359,6 +2361,11 @@ class TFModelPushToHubTester(unittest.TestCase):
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
delete_repo(token=cls._token, repo_id="test-model-tf-callback")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
delete_repo(token=cls._token, repo_id="valid_org/test-model-tf-org")
|
||||
except HTTPError:
|
||||
@ -2378,13 +2385,14 @@ class TFModelPushToHubTester(unittest.TestCase):
|
||||
model.push_to_hub("test-model-tf", use_auth_token=self._token)
|
||||
logging.set_verbosity_warning()
|
||||
# Check the model card was created and uploaded.
|
||||
self.assertIn("Uploading README.md to __DUMMY_TRANSFORMERS_USER__/test-model-tf", cl.out)
|
||||
self.assertIn("Uploading the following files to __DUMMY_TRANSFORMERS_USER__/test-model-tf", cl.out)
|
||||
|
||||
new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
|
||||
models_equal = True
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
if not tf.math.reduce_all(p1 == p2):
|
||||
models_equal = False
|
||||
break
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
# Reset repo
|
||||
@ -2397,8 +2405,32 @@ class TFModelPushToHubTester(unittest.TestCase):
|
||||
new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
|
||||
models_equal = True
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
if not tf.math.reduce_all(p1 == p2):
|
||||
models_equal = False
|
||||
break
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
def test_push_to_hub_callback(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
model = TFBertForMaskedLM(config)
|
||||
model.compile()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
push_to_hub_callback = PushToHubCallback(
|
||||
output_dir=tmp_dir,
|
||||
hub_model_id="test-model-tf-callback",
|
||||
hub_token=self._token,
|
||||
)
|
||||
model.fit(model.dummy_inputs, model.dummy_inputs, epochs=1, callbacks=[push_to_hub_callback])
|
||||
|
||||
new_model = TFBertForMaskedLM.from_pretrained(f"{USER}/test-model-tf-callback")
|
||||
models_equal = True
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
if not tf.math.reduce_all(p1 == p2):
|
||||
models_equal = False
|
||||
break
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
@ -2414,8 +2446,9 @@ class TFModelPushToHubTester(unittest.TestCase):
|
||||
new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org")
|
||||
models_equal = True
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
if not tf.math.reduce_all(p1 == p2):
|
||||
models_equal = False
|
||||
break
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
# Reset repo
|
||||
@ -2430,6 +2463,7 @@ class TFModelPushToHubTester(unittest.TestCase):
|
||||
new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org")
|
||||
models_equal = True
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
|
||||
if not tf.math.reduce_all(p1 == p2):
|
||||
models_equal = False
|
||||
break
|
||||
self.assertTrue(models_equal)
|
||||
|
Loading…
Reference in New Issue
Block a user