TF: add test for PushToHubCallback (#20231)

* test hub tf callback

* create repo before cloning it
Joao Gante 2022-11-17 12:33:44 +00:00 committed by GitHub
parent 3a780cc57a
commit 2062c28552
7 changed files with 63 additions and 17 deletions


@@ -9,7 +9,7 @@ import tensorflow as tf
 from packaging.version import parse
 from tensorflow.keras.callbacks import Callback
 
-from huggingface_hub import Repository
+from huggingface_hub import Repository, create_repo
 
 from . import IntervalStrategy, PreTrainedTokenizerBase
 from .modelcard import TrainingSummary
@@ -339,11 +339,13 @@ class PushToHubCallback(Callback):
         self.output_dir = output_dir
         self.hub_model_id = hub_model_id
+        create_repo(self.hub_model_id, exist_ok=True)
         self.repo = Repository(
             str(self.output_dir),
             clone_from=self.hub_model_id,
             use_auth_token=hub_token if hub_token else True,
         )
         self.tokenizer = tokenizer
         self.last_job = None
         self.checkpoint = checkpoint
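
Note: the point of the new `create_repo` call is to guarantee the remote repo exists before `Repository` tries to clone it. A minimal sketch of that pattern with `huggingface_hub` directly (repo id and local path are illustrative; a configured token, e.g. via `huggingface-cli login`, is assumed):

    from huggingface_hub import Repository, create_repo

    # No-op if the repo already exists, so it is safe to call unconditionally.
    create_repo("my-user/my-model", exist_ok=True)  # placeholder repo id
    repo = Repository("local_checkout", clone_from="my-user/my-model")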
@@ -394,17 +396,22 @@ class PushToHubCallback(Callback):
             )
 
     def on_train_end(self, logs=None):
         # Makes sure the latest version of the model is uploaded
         if self.last_job is not None and not self.last_job.is_done:
-            self.last_job._process.terminate()  # Gotta go fast
+            logging.info("Pushing the last epoch to the Hub, this may take a while...")
             while not self.last_job.is_done:
                 sleep(1)
-        self.model.save_pretrained(self.output_dir)
-        if self.tokenizer is not None:
-            self.tokenizer.save_pretrained(self.output_dir)
-        train_summary = TrainingSummary.from_keras(
-            model=self.model, model_name=self.hub_model_id, keras_history=self.training_history, **self.model_card_args
-        )
-        model_card = train_summary.to_model_card()
-        with (self.output_dir / "README.md").open("w") as f:
-            f.write(model_card)
-        self.repo.push_to_hub(commit_message="End of training", blocking=True)
+        else:
+            self.model.save_pretrained(self.output_dir)
+            if self.tokenizer is not None:
+                self.tokenizer.save_pretrained(self.output_dir)
+            train_summary = TrainingSummary.from_keras(
+                model=self.model,
+                model_name=self.hub_model_id,
+                keras_history=self.training_history,
+                **self.model_card_args,
+            )
+            model_card = train_summary.to_model_card()
+            with (self.output_dir / "README.md").open("w") as f:
+                f.write(model_card)
+            self.repo.push_to_hub(commit_message="End of training", blocking=True)
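
For orientation, a minimal usage sketch of the callback this commit tests. The tiny config mirrors the new test further down; the repo id and token are placeholders:

    import tempfile

    from transformers import BertConfig, PushToHubCallback, TFBertForMaskedLM

    config = BertConfig(
        vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
    )
    model = TFBertForMaskedLM(config)
    model.compile()  # the model supplies its own loss

    with tempfile.TemporaryDirectory() as tmp_dir:
        callback = PushToHubCallback(
            output_dir=tmp_dir,
            hub_model_id="my-user/my-test-model",  # placeholder repo id
            hub_token="hf_...",  # placeholder token
        )
        # Checkpoints are pushed asynchronously during training; on_train_end then
        # waits for the last job or does a final blocking push (see the hunk above).
        model.fit(model.dummy_inputs, model.dummy_inputs, epochs=1, callbacks=[callback])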


@@ -399,6 +399,7 @@ class BlenderbotTokenizer(PreTrainedTokenizer):
         Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
         adding special tokens. A Blenderbot sequence has the following format:
         - single sequence: ` X </s>`
+
         Args:
             token_ids_0 (`List[int]`):
                 List of IDs to which the special tokens will be added


@@ -284,6 +284,7 @@ class BlenderbotTokenizerFast(PreTrainedTokenizerFast):
         Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
         adding special tokens. A Blenderbot sequence has the following format:
         - single sequence: ` X </s>`
+
         Args:
             token_ids_0 (`List[int]`):
                 List of IDs to which the special tokens will be added


@@ -428,6 +428,7 @@ class MarkupLMTokenizer(PreTrainedTokenizer):
         adding special tokens. A RoBERTa sequence has the following format:
         - single sequence: `<s> X </s>`
         - pair of sequences: `<s> A </s></s> B </s>`
+
         Args:
             token_ids_0 (`List[int]`):
                 List of IDs to which the special tokens will be added.


@@ -883,6 +883,7 @@ class MarkupLMTokenizerFast(PreTrainedTokenizerFast):
         adding special tokens. A RoBERTa sequence has the following format:
         - single sequence: `<s> X </s>`
         - pair of sequences: `<s> A </s></s> B </s>`
+
         Args:
             token_ids_0 (`List[int]`):
                 List of IDs to which the special tokens will be added.


@@ -342,6 +342,7 @@ class TapexTokenizer(PreTrainedTokenizer):
         adding special tokens. A TAPEX sequence has the following format:
         - single sequence: `<s> X </s>`
         - pair of sequences: `<s> A </s></s> B </s>`
+
         Args:
             token_ids_0 (`List[int]`):
                 List of IDs to which the special tokens will be added.
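
The five docstring hunks above each appear to add a missing blank line before `Args:` in the docstring of the same helper, `build_inputs_with_special_tokens`. A quick sketch of the documented format for a RoBERTa-style tokenizer (uses the public `roberta-base` checkpoint, so it needs network access):

    from transformers import RobertaTokenizer

    tok = RobertaTokenizer.from_pretrained("roberta-base")
    ids_a = tok.encode("hello", add_special_tokens=False)
    ids_b = tok.encode("world", add_special_tokens=False)

    # single sequence: <s> X </s>
    print(tok.build_inputs_with_special_tokens(ids_a))
    # pair of sequences: <s> A </s></s> B </s>
    print(tok.build_inputs_with_special_tokens(ids_a, ids_b))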


@@ -78,9 +78,11 @@ if is_tf_available():
         TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
         TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
         BertConfig,
+        PushToHubCallback,
         RagRetriever,
         TFAutoModel,
         TFAutoModelForSequenceClassification,
+        TFBertForMaskedLM,
         TFBertModel,
         TFRagModel,
         TFSharedEmbeddings,
@@ -2359,6 +2361,11 @@ class TFModelPushToHubTester(unittest.TestCase):
         except HTTPError:
             pass
 
+        try:
+            delete_repo(token=cls._token, repo_id="test-model-tf-callback")
+        except HTTPError:
+            pass
+
         try:
             delete_repo(token=cls._token, repo_id="valid_org/test-model-tf-org")
         except HTTPError:
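
Teardown context: `delete_repo` removes the Hub repos the tests create, and the `HTTPError` guard tolerates a repo that was never pushed. A one-line sketch (the repo id matches the new test; the token is a placeholder):

    from huggingface_hub import delete_repo

    delete_repo(repo_id="test-model-tf-callback", token="hf_...")  # placeholder token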
@@ -2378,13 +2385,14 @@ class TFModelPushToHubTester(unittest.TestCase):
             model.push_to_hub("test-model-tf", use_auth_token=self._token)
         logging.set_verbosity_warning()
         # Check the model card was created and uploaded.
-        self.assertIn("Uploading README.md to __DUMMY_TRANSFORMERS_USER__/test-model-tf", cl.out)
+        self.assertIn("Uploading the following files to __DUMMY_TRANSFORMERS_USER__/test-model-tf", cl.out)
 
         new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
         models_equal = True
         for p1, p2 in zip(model.weights, new_model.weights):
-            if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
+            if not tf.math.reduce_all(p1 == p2):
                 models_equal = False
+                break
         self.assertTrue(models_equal)
 
         # Reset repo
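
Aside on the comparison change: summing absolute differences can silently pass when weights contain NaNs (the sum becomes NaN and `NaN > 0` is False), while element-wise equality flags them; the added `break` also stops at the first mismatch. A standalone sketch with illustrative tensors:

    import tensorflow as tf

    p1 = tf.constant([1.0, 2.0, float("nan")])
    p2 = tf.constant([1.0, 2.0, float("nan")])

    # Old check: the NaN makes the sum NaN, and NaN > 0 is False -> reported "equal"
    print(bool(tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0))  # False
    # New check: NaN != NaN, so the mismatch is flagged -> reported "not equal"
    print(bool(tf.math.reduce_all(p1 == p2)))  # False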
@@ -2397,8 +2405,32 @@ class TFModelPushToHubTester(unittest.TestCase):
         new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
         models_equal = True
         for p1, p2 in zip(model.weights, new_model.weights):
-            if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
+            if not tf.math.reduce_all(p1 == p2):
                 models_equal = False
+                break
         self.assertTrue(models_equal)
 
+    def test_push_to_hub_callback(self):
+        config = BertConfig(
+            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
+        )
+        model = TFBertForMaskedLM(config)
+        model.compile()
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            push_to_hub_callback = PushToHubCallback(
+                output_dir=tmp_dir,
+                hub_model_id="test-model-tf-callback",
+                hub_token=self._token,
+            )
+            model.fit(model.dummy_inputs, model.dummy_inputs, epochs=1, callbacks=[push_to_hub_callback])
+
+        new_model = TFBertForMaskedLM.from_pretrained(f"{USER}/test-model-tf-callback")
+        models_equal = True
+        for p1, p2 in zip(model.weights, new_model.weights):
+            if not tf.math.reduce_all(p1 == p2):
+                models_equal = False
+                break
+        self.assertTrue(models_equal)
+
     def test_push_to_hub_in_organization(self):
@@ -2414,8 +2446,9 @@ class TFModelPushToHubTester(unittest.TestCase):
         new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org")
         models_equal = True
         for p1, p2 in zip(model.weights, new_model.weights):
-            if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
+            if not tf.math.reduce_all(p1 == p2):
                 models_equal = False
+                break
         self.assertTrue(models_equal)
 
         # Reset repo
@@ -2430,6 +2463,7 @@ class TFModelPushToHubTester(unittest.TestCase):
         new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org")
         models_equal = True
         for p1, p2 in zip(model.weights, new_model.weights):
-            if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
+            if not tf.math.reduce_all(p1 == p2):
                 models_equal = False
+                break
         self.assertTrue(models_equal)