mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Replace assertions with valueError Exeptions (#14117)
* Replace assertions with valueError Exeptions * Reformatted
This commit is contained in:
parent
9f53f049c6
commit
fa4abdb3ea
@ -75,9 +75,10 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
|
||||
else:
|
||||
hf_shape = hf_pointer.shape
|
||||
|
||||
assert (
|
||||
hf_shape == value.shape
|
||||
), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
|
||||
if hf_shape != value.shape:
|
||||
raise ValueError(
|
||||
f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
|
||||
)
|
||||
|
||||
if weight_type == "weight":
|
||||
hf_pointer.weight.data = value
|
||||
@ -145,28 +146,32 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
|
||||
|
||||
if type_id == 0:
|
||||
if "bias" in name:
|
||||
assert (
|
||||
value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape
|
||||
), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
|
||||
if value.shape != feature_extractor.conv_layers[layer_id].conv.bias.data.shape:
|
||||
raise ValueError(
|
||||
f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
|
||||
)
|
||||
feature_extractor.conv_layers[layer_id].conv.bias.data = value
|
||||
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
|
||||
elif "weight" in name:
|
||||
assert (
|
||||
value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape
|
||||
), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
|
||||
if value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape:
|
||||
raise ValueError(
|
||||
f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
|
||||
)
|
||||
feature_extractor.conv_layers[layer_id].conv.weight.data = value
|
||||
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
|
||||
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
|
||||
if "bias" in name:
|
||||
assert (
|
||||
value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape
|
||||
), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
|
||||
if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape:
|
||||
raise ValueError(
|
||||
f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
|
||||
)
|
||||
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
|
||||
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
|
||||
elif "weight" in name:
|
||||
assert (
|
||||
value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape
|
||||
), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
|
||||
if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape:
|
||||
raise ValueError(
|
||||
f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
|
||||
)
|
||||
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
|
||||
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user