add scripts

This commit is contained in:
Victor SANH 2020-05-27 18:24:39 -04:00
parent d489a6d3d5
commit fb8f4277b2
2 changed files with 224 additions and 0 deletions

View File

@ -0,0 +1,132 @@
# Copyright 2020-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Once a model has been fine-pruned, the weights that are masked during the forward pass can be pruned once for all.
For instance, once the a model from the :class:`~emmental.MaskedBertForSequenceClassification` is trained, it can be saved (and then loaded)
as a standard :class:`~transformers.BertForSequenceClassification`.
"""
import os
import shutil
import argparse
import torch
from emmental.modules import MagnitudeBinarizer, TopKBinarizer, ThresholdBinarizer
def main(args):
pruning_method = args.pruning_method
threshold = args.threshold
model_name_or_path = args.model_name_or_path.rstrip("/")
target_model_path = args.target_model_path
print(f"Load fine-pruned model from {model_name_or_path}")
model = torch.load(os.path.join(model_name_or_path, "pytorch_model.bin"))
pruned_model = {}
for name, tensor in model.items():
if "embeddings" in name or "LayerNorm" in name or "pooler" in name:
pruned_model[name] = tensor
print(f"Pruned layer {name}")
elif "classifier" in name or "qa_output" in name:
pruned_model[name] = tensor
print(f"Pruned layer {name}")
elif "bias" in name:
pruned_model[name] = tensor
print(f"Pruned layer {name}")
else:
if pruning_method == "magnitude":
mask = MagnitudeBinarizer.apply(inputs=tensor, threshold=threshold)
pruned_model[name] = tensor * mask
print(f"Pruned layer {name}")
elif pruning_method == "topK":
if "mask_scores" in name:
continue
prefix_ = name[:-6]
scores = model[f"{prefix_}mask_scores"]
mask = TopKBinarizer.apply(scores, threshold)
pruned_model[name] = tensor * mask
print(f"Pruned layer {name}")
elif pruning_method == "sigmoied_threshold":
if "mask_scores" in name:
continue
prefix_ = name[:-6]
scores = model[f"{prefix_}mask_scores"]
mask = ThresholdBinarizer.apply(scores, threshold, True)
pruned_model[name] = tensor * mask
print(f"Pruned layer {name}")
elif pruning_method == "l0":
if "mask_scores" in name:
continue
prefix_ = name[:-6]
scores = model[f"{prefix_}mask_scores"]
l, r = -0.1, 1.1
s = torch.sigmoid(scores)
s_bar = s * (r - l) + l
mask = s_bar.clamp(min=0.0, max=1.0)
pruned_model[name] = tensor * mask
print(f"Pruned layer {name}")
else:
raise ValueError("Unknown pruning method")
if target_model_path is None:
target_model_path = os.path.join(
os.path.dirname(model_name_or_path), f"bertarized_{os.path.basename(model_name_or_path)}"
)
if not os.path.isdir(target_model_path):
shutil.copytree(model_name_or_path, target_model_path)
print(f"\nCreated folder {target_model_path}")
torch.save(pruned_model, os.path.join(target_model_path, "pytorch_model.bin"))
print("\nPruned model saved! See you later!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--pruning_method",
choices=["l0", "magnitude", "topK", "sigmoied_threshold", ],
type=str,
required=True,
help="Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning, sigmoied_threshold = Soft movement pruning)",
)
parser.add_argument(
"--threshold",
type=float,
required=False,
help="For `magnitude` and `topK`, it is the level of remaining weights (in %) in the fine-pruned model."
"For `sigmoied_threshold`, it is the threshold \tau against which the (sigmoied) scores are compared."
"Not needed for `l0`",
)
parser.add_argument(
"--model_name_or_path",
type=str,
required=True,
help="Folder containing the model that was previously fine-pruned",
)
parser.add_argument(
"--target_model_path",
default=None,
type=str,
required=False,
help="Folder containing the model that was previously fine-pruned",
)
args = parser.parse_args()
main(args)

View File

@ -0,0 +1,92 @@
# Copyright 2020-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Count remaining (non-zero) weights in the encoder (i.e. the transformer layers).
Sparsity and remaining weights levels are equivalent: sparsity % = 100 - remaining weights %.
"""
import os
import argparse
import torch
from emmental.modules import TopKBinarizer, ThresholdBinarizer
def main(args):
serialization_dir = args.serialization_dir
pruning_method = args.pruning_method
threshold = args.threshold
st = torch.load(os.path.join(serialization_dir, "pytorch_model.bin"), map_location="cpu")
remaining_count = 0 # Number of remaining (not pruned) params in the encoder
encoder_count = 0 # Number of params in the encoder
print("name".ljust(60, " "), "Remaining Weights %", "Remaning Weight")
for name, param in st.items():
if "encoder" not in name:
continue
if "mask_scores" in name:
if pruning_method == "topK":
mask_ones = TopKBinarizer.apply(param, threshold).sum().item()
elif pruning_method == "sigmoied_threshold":
mask_ones = ThresholdBinarizer.apply(param, threshold, True).sum().item()
elif pruning_method == "l0":
l, r = -0.1, 1.1
s = torch.sigmoid(param)
s_bar = s * (r - l) + l
mask = s_bar.clamp(min=0.0, max=1.0)
mask_ones = (mask > 0.0).sum().item()
else:
raise ValueError("Unknown pruning method")
remaining_count += mask_ones
print(name.ljust(60, " "), str(round(100 * mask_ones / param.numel(), 3)).ljust(20, " "), str(mask_ones))
else:
encoder_count += param.numel()
if "bias" in name or "LayerNorm" in name:
remaining_count += param.numel()
print("")
print("Remaining Weights (global) %: ", 100 * remaining_count / encoder_count)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--pruning_method",
choices=["l0", "topK", "sigmoied_threshold", ],
type=str,
required=True,
help="Pruning Method (l0 = L0 regularization, topK = Movement pruning, sigmoied_threshold = Soft movement pruning)",
)
parser.add_argument(
"--threshold",
type=float,
required=False,
help="For `topK`, it is the level of remaining weights (in %) in the fine-pruned model."
"For `sigmoied_threshold`, it is the threshold \tau against which the (sigmoied) scores are compared."
"Not needed for `l0`",
)
parser.add_argument(
"--serialization_dir",
type=str,
required=True,
help="Folder containing the model that was previously fine-pruned",
)
args = parser.parse_args()
main(args)