mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
Replace build() with build_in_name_scope() for some TF tests (#28046)
Replace build() with build_in_name_scope() for some tests
This commit is contained in:
parent
050e0b44f6
commit
3060899be5
@ -304,7 +304,7 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester
|
|||||||
old_total_size = config.vocab_size
|
old_total_size = config.vocab_size
|
||||||
new_total_size = old_total_size + new_tokens_size
|
new_total_size = old_total_size + new_tokens_size
|
||||||
model = model_class(config=copy.deepcopy(config)) # `resize_token_embeddings` mutates `config`
|
model = model_class(config=copy.deepcopy(config)) # `resize_token_embeddings` mutates `config`
|
||||||
model.build()
|
model.build_in_name_scope()
|
||||||
model.resize_token_embeddings(new_total_size)
|
model.resize_token_embeddings(new_total_size)
|
||||||
|
|
||||||
# fetch the output for an input exclusively made of new members of the vocabulary
|
# fetch the output for an input exclusively made of new members of the vocabulary
|
||||||
|
@ -225,7 +225,7 @@ class TFCTRLModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
|||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model.build() # may be needed for the get_bias() call below
|
model.build_in_name_scope() # may be needed for the get_bias() call below
|
||||||
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
|
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
|
||||||
|
|
||||||
if model_class in list_lm_models:
|
if model_class in list_lm_models:
|
||||||
|
@ -316,7 +316,7 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
with tf.Graph().as_default() as g:
|
with tf.Graph().as_default() as g:
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model.build()
|
model.build_in_name_scope()
|
||||||
|
|
||||||
for op in g.get_operations():
|
for op in g.get_operations():
|
||||||
model_op_names.add(op.node_def.op)
|
model_op_names.add(op.node_def.op)
|
||||||
@ -346,7 +346,7 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
for model_class in self.all_model_classes[:2]:
|
for model_class in self.all_model_classes[:2]:
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model.build()
|
model.build_in_name_scope()
|
||||||
|
|
||||||
onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)
|
onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)
|
||||||
|
|
||||||
@ -1088,7 +1088,7 @@ class TFModelTesterMixin:
|
|||||||
def _get_word_embedding_weight(model, embedding_layer):
|
def _get_word_embedding_weight(model, embedding_layer):
|
||||||
if isinstance(embedding_layer, tf.keras.layers.Embedding):
|
if isinstance(embedding_layer, tf.keras.layers.Embedding):
|
||||||
# builds the embeddings layer
|
# builds the embeddings layer
|
||||||
model.build()
|
model.build_in_name_scope()
|
||||||
return embedding_layer.embeddings
|
return embedding_layer.embeddings
|
||||||
else:
|
else:
|
||||||
return model._get_word_embedding_weight(embedding_layer)
|
return model._get_word_embedding_weight(embedding_layer)
|
||||||
@ -1151,7 +1151,7 @@ class TFModelTesterMixin:
|
|||||||
old_total_size = config.vocab_size
|
old_total_size = config.vocab_size
|
||||||
new_total_size = old_total_size + new_tokens_size
|
new_total_size = old_total_size + new_tokens_size
|
||||||
model = model_class(config=copy.deepcopy(config)) # `resize_token_embeddings` mutates `config`
|
model = model_class(config=copy.deepcopy(config)) # `resize_token_embeddings` mutates `config`
|
||||||
model.build()
|
model.build_in_name_scope()
|
||||||
model.resize_token_embeddings(new_total_size)
|
model.resize_token_embeddings(new_total_size)
|
||||||
|
|
||||||
# fetch the output for an input exclusively made of new members of the vocabulary
|
# fetch the output for an input exclusively made of new members of the vocabulary
|
||||||
|
@ -402,8 +402,8 @@ class TFModelUtilsTest(unittest.TestCase):
|
|||||||
# Finally, check the model can be reloaded
|
# Finally, check the model can be reloaded
|
||||||
new_model = TFBertModel.from_pretrained(tmp_dir)
|
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
model.build()
|
model.build_in_name_scope()
|
||||||
new_model.build()
|
new_model.build_in_name_scope()
|
||||||
|
|
||||||
for p1, p2 in zip(model.weights, new_model.weights):
|
for p1, p2 in zip(model.weights, new_model.weights):
|
||||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||||
@ -632,7 +632,7 @@ class TFModelPushToHubTester(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
model = TFBertModel(config)
|
model = TFBertModel(config)
|
||||||
# Make sure model is properly initialized
|
# Make sure model is properly initialized
|
||||||
model.build()
|
model.build_in_name_scope()
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
logger = logging.get_logger("transformers.utils.hub")
|
logger = logging.get_logger("transformers.utils.hub")
|
||||||
@ -701,7 +701,7 @@ class TFModelPushToHubTester(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
model = TFBertModel(config)
|
model = TFBertModel(config)
|
||||||
# Make sure model is properly initialized
|
# Make sure model is properly initialized
|
||||||
model.build()
|
model.build_in_name_scope()
|
||||||
|
|
||||||
model.push_to_hub("valid_org/test-model-tf-org", token=self._token)
|
model.push_to_hub("valid_org/test-model-tf-org", token=self._token)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user