Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-25 23:38:59 +06:00)
Fix E266 flake8 warning (x90).
This commit is contained in:
parent
2ab78325f0
commit
fa2ccbc081
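For context: flake8's E266 check ("too many leading '#' for block comment", inherited from pycodestyle) fires on block comments that open with more than one '#'. Every hunk below simply rewrites such comments to start with a single '#'. The snippet that follows is a minimal, hypothetical before/after sketch (the file name and arguments are made up for illustration and are not taken from this commit) of what the check reports and what the fixed form looks like:

# e266_demo.py -- hypothetical example, not part of the repository
import argparse

parser = argparse.ArgumentParser()

## Required parameters
# ^ flake8 reports: E266 too many leading '#' for block comment
parser.add_argument("--train_file", default=None, type=str, required=True)

# Other parameters   (single leading '#': the E266-clean form used throughout this commit)
parser.add_argument("--config_name", default="", type=str)

Running something like "flake8 --select=E266 e266_demo.py" (assuming flake8 is installed) would flag only the doubled-hash comment; once it is rewritten with a single '#', the check passes.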
@@ -487,7 +487,7 @@ def evaluate(args, model, tokenizer, prefix=""):
 def main():
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--train_file", default=None, type=str, required=True, help="SWAG csv for training. E.g., train.csv"
     )
@@ -520,7 +520,7 @@ def main():
         help="The output directory where the model checkpoints and predictions will be written.",
     )

-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
     )
@@ -430,7 +430,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
 def main():
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--train_file", default=None, type=str, required=True, help="SQuAD json for training. E.g., train-v1.1.json"
     )
@@ -486,7 +486,7 @@ def main():
         "--temperature", default=2.0, type=float, help="Distillation temperature. Only for distillation."
     )

-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
     )
@@ -43,7 +43,7 @@ if __name__ == "__main__":
     state_dict = model.state_dict()
     compressed_sd = {}

-    ### Embeddings ###
+    # Embeddings #
     if args.model_type == "gpt2":
         for param_name in ["wte.weight", "wpe.weight"]:
             compressed_sd[f"{prefix}.{param_name}"] = state_dict[f"{prefix}.{param_name}"]
@@ -55,7 +55,7 @@ if __name__ == "__main__":
         param_name = f"{prefix}.embeddings.LayerNorm.{w}"
         compressed_sd[param_name] = state_dict[param_name]

-    ### Transformer Blocks ###
+    # Transformer Blocks #
     std_idx = 0
     for teacher_idx in [0, 2, 4, 7, 9, 11]:
         if args.model_type == "gpt2":
@@ -82,7 +82,7 @@ if __name__ == "__main__":
         ]
         std_idx += 1

-    ### Language Modeling Head ###s
+    # Language Modeling Head ###s
     if args.model_type == "roberta":
         for layer in ["lm_head.decoder.weight", "lm_head.bias"]:
             compressed_sd[f"{layer}"] = state_dict[f"{layer}"]
@@ -219,7 +219,7 @@ def main():
     args = parser.parse_args()
     sanity_checks(args)

-    ## ARGS ##
+    # ARGS #
     init_gpu_params(args)
     set_seed(args)
     if args.is_master:
@@ -236,7 +236,7 @@ def main():
             os.makedirs(args.dump_path)
         logger.info(f"Experiment will be dumped and logged in {args.dump_path}")

-        ### SAVE PARAMS ###
+        # SAVE PARAMS #
         logger.info(f"Param: {args}")
         with open(os.path.join(args.dump_path, "parameters.json"), "w") as f:
             json.dump(vars(args), f, indent=4)
@@ -245,7 +245,7 @@ def main():
     student_config_class, student_model_class, _ = MODEL_CLASSES[args.student_type]
     teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[args.teacher_type]

-    ### TOKENIZER ###
+    # TOKENIZER #
     tokenizer = teacher_tokenizer_class.from_pretrained(args.teacher_name)
     special_tok_ids = {}
     for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
@@ -255,7 +255,7 @@ def main():
     args.special_tok_ids = special_tok_ids
     args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name]

-    ## DATA LOADER ##
+    # DATA LOADER #
     logger.info(f"Loading data from {args.data_file}")
     with open(args.data_file, "rb") as fp:
         data = pickle.load(fp)
@@ -275,7 +275,7 @@ def main():
     train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
     logger.info(f"Data loader created.")

-    ## STUDENT ##
+    # STUDENT #
     logger.info(f"Loading student config from {args.student_config}")
     stu_architecture_config = student_config_class.from_pretrained(args.student_config)
     stu_architecture_config.output_hidden_states = True
@@ -290,26 +290,26 @@ def main():
         student.to(f"cuda:{args.local_rank}")
     logger.info(f"Student loaded.")

-    ## TEACHER ##
+    # TEACHER #
     teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
     if args.n_gpu > 0:
         teacher.to(f"cuda:{args.local_rank}")
     logger.info(f"Teacher loaded from {args.teacher_name}.")

-    ## FREEZING ##
+    # FREEZING #
     if args.freeze_pos_embs:
         freeze_pos_embeddings(student, args)
     if args.freeze_token_type_embds:
         freeze_token_type_embeddings(student, args)

-    ## SANITY CHECKS ##
+    # SANITY CHECKS #
     assert student.config.vocab_size == teacher.config.vocab_size
     assert student.config.hidden_size == teacher.config.hidden_size
     assert student.config.max_position_embeddings == teacher.config.max_position_embeddings
     if args.mlm:
         assert token_probs.size(0) == stu_architecture_config.vocab_size

-    ## DISTILLER ##
+    # DISTILLER #
     torch.cuda.empty_cache()
     distiller = Distiller(
         params=args, dataset=train_lm_seq_dataset, token_probs=token_probs, student=student, teacher=teacher
@@ -344,7 +344,7 @@ def load_examples(args, tokenizer, evaluate=False):
 def main():
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--data_dir",
         default=None,
@@ -374,7 +374,7 @@ def main():
         help="The output directory where the model predictions and checkpoints will be written.",
     )

-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
     )
@@ -242,7 +242,7 @@ def prune_heads(args, model, eval_dataloader, head_mask):

 def main():
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--data_dir",
         default=None,
@@ -272,7 +272,7 @@ def main():
         help="The output directory where the model predictions and checkpoints will be written.",
     )

-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--config_name",
         default="",
@@ -410,7 +410,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
 def main():
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--data_dir",
         default=None,
@@ -447,7 +447,7 @@ def main():
         help="The output directory where the model predictions and checkpoints will be written.",
     )

-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
     )
@@ -422,7 +422,7 @@ def evaluate(args, model, tokenizer, prefix=""):
 def main():
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--train_data_file", default=None, type=str, required=True, help="The input training data file (a text file)."
     )
@@ -434,7 +434,7 @@ def main():
         help="The output directory where the model predictions and checkpoints will be written.",
     )

-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--eval_data_file",
         default=None,
@@ -385,7 +385,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False, test=False):
 def main():
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--data_dir",
         default=None,
@@ -422,7 +422,7 @@ def main():
         help="The output directory where the model predictions and checkpoints will be written.",
     )

-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
     )
@@ -385,7 +385,7 @@ def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode):
 def main():
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--data_dir",
         default=None,
@@ -415,7 +415,7 @@ def main():
         help="The output directory where the model predictions and checkpoints will be written.",
     )

-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--labels",
         default="",
@@ -377,7 +377,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
 def main():
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--data_dir",
         default=None,
@@ -417,7 +417,7 @@ def main():
         help="The output directory where the model predictions and checkpoints will be written.",
     )

-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
     )
@@ -401,7 +401,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
 def main():
     parser = argparse.ArgumentParser()

-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--train_file", default=None, type=str, required=True, help="SQuAD json for training. E.g., train-v1.1.json"
     )
@@ -434,7 +434,7 @@ def main():
         help="The output directory where the model checkpoints and predictions will be written.",
     )

-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
     )
@@ -43,7 +43,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
     )
@@ -43,7 +43,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pyt

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
     )
@@ -43,7 +43,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
     )
@@ -51,7 +51,7 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--gpt2_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
     )
@@ -51,7 +51,7 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--openai_checkpoint_folder_path",
         default=None,
@@ -410,7 +410,7 @@ def convert_all_pt_checkpoints_to_tf(

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--tf_dump_path", default=None, type=str, required=True, help="Path to the output Tensorflow dump file."
     )
@@ -94,7 +94,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
         layer: BertLayer = model.roberta.encoder.layer[i]
         roberta_layer: TransformerSentenceEncoderLayer = roberta_sent_encoder.layers[i]

-        ### self attention
+        # self attention
         self_attn: BertSelfAttention = layer.attention.self
         assert (
             roberta_layer.self_attn.k_proj.weight.data.shape
@@ -110,7 +110,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
         self_attn.value.weight.data = roberta_layer.self_attn.v_proj.weight
         self_attn.value.bias.data = roberta_layer.self_attn.v_proj.bias

-        ### self-attention output
+        # self-attention output
         self_output: BertSelfOutput = layer.attention.output
         assert self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape
         self_output.dense.weight = roberta_layer.self_attn.out_proj.weight
@@ -118,20 +118,20 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
         self_output.LayerNorm.weight = roberta_layer.self_attn_layer_norm.weight
         self_output.LayerNorm.bias = roberta_layer.self_attn_layer_norm.bias

-        ### intermediate
+        # intermediate
         intermediate: BertIntermediate = layer.intermediate
         assert intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape
         intermediate.dense.weight = roberta_layer.fc1.weight
         intermediate.dense.bias = roberta_layer.fc1.bias

-        ### output
+        # output
         bert_output: BertOutput = layer.output
         assert bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape
         bert_output.dense.weight = roberta_layer.fc2.weight
         bert_output.dense.bias = roberta_layer.fc2.bias
         bert_output.LayerNorm.weight = roberta_layer.final_layer_norm.weight
         bert_output.LayerNorm.bias = roberta_layer.final_layer_norm.bias
-        #### end of layer
+        # end of layer

     if classification_head:
         model.classifier.dense.weight = roberta.model.classification_heads["mnli"].dense.weight
@@ -170,7 +170,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--roberta_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
     )
@@ -43,7 +43,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
     )
@@ -70,7 +70,7 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--xlm_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
     )
@@ -82,7 +82,7 @@ def convert_xlnet_checkpoint_to_pytorch(

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
     )
@@ -47,7 +47,7 @@ DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
 }


-### UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE ###
+# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #
 def gelu(x):
     return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))

@@ -327,7 +327,7 @@ class Transformer(nn.Module):
         return outputs # last-layer hidden state, (all hidden states), (all attentions)


-### INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL ###
+# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
 class DistilBertPreTrainedModel(PreTrainedModel):
     """ An abstract class to handle weights initialization and
         a simple interface for downloading and loading pretrained models.
@@ -42,7 +42,7 @@ TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
 }


-### UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE ###
+# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #
 def gelu(x):
     """ Gaussian Error Linear Unit.
     Original Implementation of the gelu activation function in Google Bert repo when initially created.
@@ -463,7 +463,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
         return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions)


-### INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL ###
+# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
 class TFDistilBertPreTrainedModel(TFPreTrainedModel):
     """ An abstract class to handle weights initialization and
         a simple interface for downloading and loading pretrained models.
|
|||||||
|
|
||||||
|
|
||||||
#####################
|
#####################
|
||||||
### PyTorch => TF 2.0
|
# PyTorch => TF 2.0 #
|
||||||
|
#####################
|
||||||
|
|
||||||
|
|
||||||
def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None, allow_missing_keys=False):
|
def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None, allow_missing_keys=False):
|
||||||
@@ -197,7 +198,8 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a


 #####################
-### TF 2.0 => PyTorch
+# TF 2.0 => PyTorch #
+#####################


 def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False):
@@ -79,23 +79,23 @@ class TFPositionwiseFF(tf.keras.layers.Layer):

     def call(self, inp, training=False):
         if self.pre_lnorm:
-            ##### layer normalization + positionwise feed-forward
+            # layer normalization + positionwise feed-forward
             core_out = self.layer_norm(inp)
             core_out = self.layer_1(core_out)
             core_out = self.drop_1(core_out, training=training)
             core_out = self.layer_2(core_out)
             core_out = self.drop_2(core_out, training=training)

-            ##### residual connection
+            # residual connection
             output = core_out + inp
         else:
-            ##### positionwise feed-forward
+            # positionwise feed-forward
             core_out = self.layer_1(inp)
             core_out = self.drop_1(core_out, training=training)
             core_out = self.layer_2(core_out)
             core_out = self.drop_2(core_out, training=training)

-            ##### residual connection + layer normalization
+            # residual connection + layer normalization
             output = self.layer_norm(inp + core_out)

         return output
@@ -206,7 +206,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):

         r_head_k = tf.reshape(r_head_k, (rlen, self.n_head, self.d_head)) # qlen x n_head x d_head

-        #### compute attention score
+        # compute attention score
         rw_head_q = w_head_q + self.r_w_bias # qlen x bsz x n_head x d_head
         AC = tf.einsum("ibnd,jbnd->ijbn", rw_head_q, w_head_k) # qlen x klen x bsz x n_head

@@ -218,7 +218,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
         attn_score = AC + BD
         attn_score = attn_score * self.scale

-        #### compute attention probability
+        # compute attention probability
         if attn_mask is not None:
             attn_mask_t = attn_mask[:, :, None, None]
             attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t
@@ -231,22 +231,22 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
         if head_mask is not None:
             attn_prob = attn_prob * head_mask

-        #### compute attention vector
+        # compute attention vector
         attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, w_head_v)

         # [qlen x bsz x n_head x d_head]
         attn_vec_sizes = shape_list(attn_vec)
         attn_vec = tf.reshape(attn_vec, (attn_vec_sizes[0], attn_vec_sizes[1], self.n_head * self.d_head))

-        ##### linear projection
+        # linear projection
         attn_out = self.o_net(attn_vec)
         attn_out = self.drop(attn_out, training=training)

         if self.pre_lnorm:
-            ##### residual connection
+            # residual connection
             outputs = [w + attn_out]
         else:
-            ##### residual connection + layer normalization
+            # residual connection + layer normalization
             outputs = [self.layer_norm(w + attn_out)]

         if self.output_attentions:
@@ -190,7 +190,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
         (h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems, target_mapping, head_mask) = inputs

         if g is not None:
-            ###### Two-stream attention with relative positional encoding.
+            # Two-stream attention with relative positional encoding.
             # content based attention score
             if mems is not None and len(shape_list(mems)) > 1:
                 cat = tf.concat([mems, h], axis=0)
@@ -206,7 +206,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
             # position-based key head
             k_head_r = tf.einsum("ibh,hnd->ibnd", r, self.r)

-            ##### h-stream
+            # h-stream
             # content-stream query head
             q_head_h = tf.einsum("ibh,hnd->ibnd", h, self.q)

@@ -221,7 +221,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
             # post processing
             output_h = self.post_attention([h, attn_vec_h], training=training)

-            ##### g-stream
+            # g-stream
             # query-stream query head
             q_head_g = tf.einsum("ibh,hnd->ibnd", g, self.q)

@@ -251,7 +251,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
                 attn_prob = attn_prob_h, attn_prob_g

         else:
-            ###### Multi-head attention with relative positional encoding
+            # Multi-head attention with relative positional encoding
             if mems is not None and len(shape_list(mems)) > 1:
                 cat = tf.concat([mems, h], axis=0)
             else:
@@ -552,7 +552,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):

         dtype_float = tf.bfloat16 if self.use_bfloat16 else tf.float32

-        ##### Attention mask
+        # Attention mask
         # causal attention mask
         if self.attn_type == "uni":
             attn_mask = self.create_mask(qlen, mlen)
@@ -597,7 +597,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         else:
             non_tgt_mask = None

-        ##### Word embeddings and prepare h & g hidden states
+        # Word embeddings and prepare h & g hidden states
         if inputs_embeds is not None:
             word_emb_k = inputs_embeds
         else:
@@ -612,7 +612,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         else:
             output_g = None

-        ##### Segment embedding
+        # Segment embedding
         if token_type_ids is not None:
             # Convert `token_type_ids` to one-hot `seg_mat`
             mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
@@ -624,7 +624,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         else:
             seg_mat = None

-        ##### Positional encoding
+        # Positional encoding
         pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=dtype_float)
         pos_emb = self.dropout(pos_emb, training=training)

@@ -213,16 +213,16 @@ class PositionwiseFF(nn.Module):

     def forward(self, inp):
         if self.pre_lnorm:
-            ##### layer normalization + positionwise feed-forward
+            # layer normalization + positionwise feed-forward
             core_out = self.CoreNet(self.layer_norm(inp))

-            ##### residual connection
+            # residual connection
             output = core_out + inp
         else:
-            ##### positionwise feed-forward
+            # positionwise feed-forward
             core_out = self.CoreNet(inp)

-            ##### residual connection + layer normalization
+            # residual connection + layer normalization
             output = self.layer_norm(inp + core_out)

         return output
@@ -316,7 +316,7 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):

         r_head_k = r_head_k.view(rlen, self.n_head, self.d_head) # qlen x n_head x d_head

-        #### compute attention score
+        # compute attention score
         rw_head_q = w_head_q + self.r_w_bias # qlen x bsz x n_head x d_head
         AC = torch.einsum("ibnd,jbnd->ijbn", (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head

@@ -328,7 +328,7 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
         attn_score = AC + BD
         attn_score.mul_(self.scale)

-        #### compute attention probability
+        # compute attention probability
         if attn_mask is not None and torch.sum(attn_mask).item():
             attn_mask = attn_mask == 1 # Switch to bool
             if attn_mask.dim() == 2:
@@ -352,21 +352,21 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
         if head_mask is not None:
             attn_prob = attn_prob * head_mask

-        #### compute attention vector
+        # compute attention vector
         attn_vec = torch.einsum("ijbn,jbnd->ibnd", (attn_prob, w_head_v))

         # [qlen x bsz x n_head x d_head]
         attn_vec = attn_vec.contiguous().view(attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

-        ##### linear projection
+        # linear projection
         attn_out = self.o_net(attn_vec)
         attn_out = self.drop(attn_out)

         if self.pre_lnorm:
-            ##### residual connection
+            # residual connection
             outputs = [w + attn_out]
         else:
-            ##### residual connection + layer normalization
+            # residual connection + layer normalization
             outputs = [self.layer_norm(w + attn_out)]

         if self.output_attentions:
@@ -330,7 +330,7 @@ class XLNetRelativeAttention(nn.Module):

     def forward(self, h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems=None, target_mapping=None, head_mask=None):
         if g is not None:
-            ###### Two-stream attention with relative positional encoding.
+            # Two-stream attention with relative positional encoding.
             # content based attention score
             if mems is not None and mems.dim() > 1:
                 cat = torch.cat([mems, h], dim=0)
@@ -346,7 +346,7 @@ class XLNetRelativeAttention(nn.Module):
             # position-based key head
             k_head_r = torch.einsum("ibh,hnd->ibnd", r, self.r)

-            ##### h-stream
+            # h-stream
             # content-stream query head
             q_head_h = torch.einsum("ibh,hnd->ibnd", h, self.q)

@@ -361,7 +361,7 @@ class XLNetRelativeAttention(nn.Module):
            # post processing
            output_h = self.post_attention(h, attn_vec_h)

-            ##### g-stream
+            # g-stream
            # query-stream query head
            q_head_g = torch.einsum("ibh,hnd->ibnd", g, self.q)

@@ -391,7 +391,7 @@ class XLNetRelativeAttention(nn.Module):
                 attn_prob = attn_prob_h, attn_prob_g

         else:
-            ###### Multi-head attention with relative positional encoding
+            # Multi-head attention with relative positional encoding
             if mems is not None and mems.dim() > 1:
                 cat = torch.cat([mems, h], dim=0)
             else:
@@ -804,7 +804,7 @@ class XLNetModel(XLNetPreTrainedModel):
         dtype_float = next(self.parameters()).dtype
         device = next(self.parameters()).device

-        ##### Attention mask
+        # Attention mask
         # causal attention mask
         if self.attn_type == "uni":
             attn_mask = self.create_mask(qlen, mlen)
@@ -849,7 +849,7 @@ class XLNetModel(XLNetPreTrainedModel):
         else:
             non_tgt_mask = None

-        ##### Word embeddings and prepare h & g hidden states
+        # Word embeddings and prepare h & g hidden states
         if inputs_embeds is not None:
             word_emb_k = inputs_embeds
         else:
@@ -864,7 +864,7 @@ class XLNetModel(XLNetPreTrainedModel):
         else:
             output_g = None

-        ##### Segment embedding
+        # Segment embedding
         if token_type_ids is not None:
             # Convert `token_type_ids` to one-hot `seg_mat`
             if mlen > 0:
@@ -879,7 +879,7 @@ class XLNetModel(XLNetPreTrainedModel):
         else:
             seg_mat = None

-        ##### Positional encoding
+        # Positional encoding
         pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
         pos_emb = self.dropout(pos_emb)

@@ -178,7 +178,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
         return True


-## Inspired from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
+# Inspired from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
 class GradientAccumulator(object):
     """Distribution strategies-aware gradient accumulation utility."""
