This commit is contained in:
Patrick von Platen 2021-11-03 11:31:40 +01:00 committed by GitHub
parent bd21ed4099
commit 89766b3d44
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -153,7 +153,7 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
if value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape:
if value.shape != feature_extractor.conv_layers[layer_id].conv.weight.data.shape:
raise ValueError(
f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
)
@ -163,14 +163,14 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if "bias" in name:
if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape:
raise ValueError(
f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape:
raise ValueError(
f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")