From fb8f4277b236d67bcdcc8cfb9dba8b17abc6d4ab Mon Sep 17 00:00:00 2001
From: Victor SANH
Date: Wed, 27 May 2020 18:24:39 -0400
Subject: [PATCH] add scripts

---
 examples/movement-pruning/bertarize.py        | 132 ++++++++++++++++++
 .../movement-pruning/counts_parameters.py     |  92 ++++++++++++
 2 files changed, 224 insertions(+)
 create mode 100644 examples/movement-pruning/bertarize.py
 create mode 100644 examples/movement-pruning/counts_parameters.py

diff --git a/examples/movement-pruning/bertarize.py b/examples/movement-pruning/bertarize.py
new file mode 100644
index 00000000000..80643687b7c
--- /dev/null
+++ b/examples/movement-pruning/bertarize.py
@@ -0,0 +1,132 @@
+# Copyright 2020-present, the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Once a model has been fine-pruned, the weights that are masked during the forward pass can be pruned once and for all.
+For instance, once a model such as :class:`~emmental.MaskedBertForSequenceClassification` is trained, it can be saved (and then loaded)
+as a standard :class:`~transformers.BertForSequenceClassification`.
+"""
+
+import os
+import shutil
+import argparse
+
+import torch
+
+from emmental.modules import MagnitudeBinarizer, TopKBinarizer, ThresholdBinarizer
+
+
+def main(args):
+    pruning_method = args.pruning_method
+    threshold = args.threshold
+
+    model_name_or_path = args.model_name_or_path.rstrip("/")
+    target_model_path = args.target_model_path
+
+    print(f"Load fine-pruned model from {model_name_or_path}")
+    model = torch.load(os.path.join(model_name_or_path, "pytorch_model.bin"))
+    pruned_model = {}
+
+    for name, tensor in model.items():
+        if "embeddings" in name or "LayerNorm" in name or "pooler" in name:
+            pruned_model[name] = tensor
+            print(f"Copied layer {name}")
+        elif "classifier" in name or "qa_output" in name:
+            pruned_model[name] = tensor
+            print(f"Copied layer {name}")
+        elif "bias" in name:
+            pruned_model[name] = tensor
+            print(f"Copied layer {name}")
+        else:
+            if pruning_method == "magnitude":
+                mask = MagnitudeBinarizer.apply(inputs=tensor, threshold=threshold)
+                pruned_model[name] = tensor * mask
+                print(f"Pruned layer {name}")
+            elif pruning_method == "topK":
+                if "mask_scores" in name:
+                    continue
+                prefix_ = name[:-6]
+                scores = model[f"{prefix_}mask_scores"]
+                mask = TopKBinarizer.apply(scores, threshold)
+                pruned_model[name] = tensor * mask
+                print(f"Pruned layer {name}")
+            elif pruning_method == "sigmoied_threshold":
+                if "mask_scores" in name:
+                    continue
+                prefix_ = name[:-6]
+                scores = model[f"{prefix_}mask_scores"]
+                mask = ThresholdBinarizer.apply(scores, threshold, True)
+                pruned_model[name] = tensor * mask
+                print(f"Pruned layer {name}")
+            elif pruning_method == "l0":
+                if "mask_scores" in name:
+                    continue
+                prefix_ = name[:-6]
+                scores = model[f"{prefix_}mask_scores"]
+                l, r = -0.1, 1.1  # stretch interval of the hard-concrete distribution
+                s = torch.sigmoid(scores)
+                s_bar = s * (r - l) + l
+                mask = s_bar.clamp(min=0.0, max=1.0)  # deterministic L0 mask
+                pruned_model[name] = tensor * mask
+                print(f"Pruned layer {name}")
+            else:
+                raise ValueError("Unknown pruning method")
method") + + if target_model_path is None: + target_model_path = os.path.join( + os.path.dirname(model_name_or_path), f"bertarized_{os.path.basename(model_name_or_path)}" + ) + + if not os.path.isdir(target_model_path): + shutil.copytree(model_name_or_path, target_model_path) + print(f"\nCreated folder {target_model_path}") + + torch.save(pruned_model, os.path.join(target_model_path, "pytorch_model.bin")) + print("\nPruned model saved! See you later!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--pruning_method", + choices=["l0", "magnitude", "topK", "sigmoied_threshold", ], + type=str, + required=True, + help="Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning, sigmoied_threshold = Soft movement pruning)", + ) + parser.add_argument( + "--threshold", + type=float, + required=False, + help="For `magnitude` and `topK`, it is the level of remaining weights (in %) in the fine-pruned model." + "For `sigmoied_threshold`, it is the threshold \tau against which the (sigmoied) scores are compared." + "Not needed for `l0`", + ) + parser.add_argument( + "--model_name_or_path", + type=str, + required=True, + help="Folder containing the model that was previously fine-pruned", + ) + parser.add_argument( + "--target_model_path", + default=None, + type=str, + required=False, + help="Folder containing the model that was previously fine-pruned", + ) + + args = parser.parse_args() + + main(args) diff --git a/examples/movement-pruning/counts_parameters.py b/examples/movement-pruning/counts_parameters.py new file mode 100644 index 00000000000..5b5dac95061 --- /dev/null +++ b/examples/movement-pruning/counts_parameters.py @@ -0,0 +1,92 @@ +# Copyright 2020-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Count remaining (non-zero) weights in the encoder (i.e. the transformer layers). +Sparsity and remaining weights levels are equivalent: sparsity % = 100 - remaining weights %. 
+""" +import os +import argparse + +import torch + +from emmental.modules import TopKBinarizer, ThresholdBinarizer + + +def main(args): + serialization_dir = args.serialization_dir + pruning_method = args.pruning_method + threshold = args.threshold + + st = torch.load(os.path.join(serialization_dir, "pytorch_model.bin"), map_location="cpu") + + remaining_count = 0 # Number of remaining (not pruned) params in the encoder + encoder_count = 0 # Number of params in the encoder + + print("name".ljust(60, " "), "Remaining Weights %", "Remaning Weight") + for name, param in st.items(): + if "encoder" not in name: + continue + + if "mask_scores" in name: + if pruning_method == "topK": + mask_ones = TopKBinarizer.apply(param, threshold).sum().item() + elif pruning_method == "sigmoied_threshold": + mask_ones = ThresholdBinarizer.apply(param, threshold, True).sum().item() + elif pruning_method == "l0": + l, r = -0.1, 1.1 + s = torch.sigmoid(param) + s_bar = s * (r - l) + l + mask = s_bar.clamp(min=0.0, max=1.0) + mask_ones = (mask > 0.0).sum().item() + else: + raise ValueError("Unknown pruning method") + remaining_count += mask_ones + print(name.ljust(60, " "), str(round(100 * mask_ones / param.numel(), 3)).ljust(20, " "), str(mask_ones)) + else: + encoder_count += param.numel() + if "bias" in name or "LayerNorm" in name: + remaining_count += param.numel() + + print("") + print("Remaining Weights (global) %: ", 100 * remaining_count / encoder_count) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--pruning_method", + choices=["l0", "topK", "sigmoied_threshold", ], + type=str, + required=True, + help="Pruning Method (l0 = L0 regularization, topK = Movement pruning, sigmoied_threshold = Soft movement pruning)", + ) + parser.add_argument( + "--threshold", + type=float, + required=False, + help="For `topK`, it is the level of remaining weights (in %) in the fine-pruned model." + "For `sigmoied_threshold`, it is the threshold \tau against which the (sigmoied) scores are compared." + "Not needed for `l0`", + ) + parser.add_argument( + "--serialization_dir", + type=str, + required=True, + help="Folder containing the model that was previously fine-pruned", + ) + + args = parser.parse_args() + + main(args)