Replace assertions with valueError Exeptions (#14117)

* Replace assertions with valueError Exeptions

* Reformatted
This commit is contained in:
Jayesh Dewangan 2021-10-22 17:15:32 +05:30 committed by GitHub
parent 9f53f049c6
commit fa4abdb3ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -75,9 +75,10 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
else:
hf_shape = hf_pointer.shape
assert (
hf_shape == value.shape
), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
if hf_shape != value.shape:
raise ValueError(
f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
)
if weight_type == "weight":
hf_pointer.weight.data = value
@ -145,28 +146,32 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if type_id == 0:
if "bias" in name:
assert (
value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape
), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
if value.shape != feature_extractor.conv_layers[layer_id].conv.bias.data.shape:
raise ValueError(
f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
assert (
value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape
), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
if value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape:
raise ValueError(
f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
if "bias" in name:
assert (
value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape
), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape:
raise ValueError(
f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
assert (
value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape
), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape:
raise ValueError(
f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
else: