Fix E266 flake8 warning (x90).

Aymeric Augustin 2019-12-21 21:22:55 +01:00
parent 2ab78325f0
commit fa2ccbc081
30 changed files with 92 additions and 90 deletions
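For context, pycodestyle's E266 ("too many leading '#' for block comment") is raised when a block comment starts with more than one '#'. A minimal sketch of the pattern fixed throughout this commit; the argument below is illustrative rather than taken from a specific file, and the flake8 invocation assumes flake8 is installed with E266 enabled in its select list:

# Before (flagged by flake8 as E266):
#     ## Required parameters
# After (accepted):
#     # Required parameters
#
# A hypothetical way to reproduce the warning locally:
#     flake8 --select=E266 .
import argparse

parser = argparse.ArgumentParser()

# Required parameters
parser.add_argument("--train_file", default=None, type=str, required=True, help="Training file.")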

View File

@@ -487,7 +487,7 @@ def evaluate(args, model, tokenizer, prefix=""):
 def main():
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--train_file", default=None, type=str, required=True, help="SWAG csv for training. E.g., train.csv"
     )
@@ -520,7 +520,7 @@ def main():
         help="The output directory where the model checkpoints and predictions will be written.",
     )
-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
     )

View File

@@ -430,7 +430,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
 def main():
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--train_file", default=None, type=str, required=True, help="SQuAD json for training. E.g., train-v1.1.json"
     )
@@ -486,7 +486,7 @@ def main():
         "--temperature", default=2.0, type=float, help="Distillation temperature. Only for distillation."
     )
-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
     )

View File

@@ -43,7 +43,7 @@ if __name__ == "__main__":
     state_dict = model.state_dict()
     compressed_sd = {}
-    ### Embeddings ###
+    # Embeddings #
     if args.model_type == "gpt2":
         for param_name in ["wte.weight", "wpe.weight"]:
             compressed_sd[f"{prefix}.{param_name}"] = state_dict[f"{prefix}.{param_name}"]
@@ -55,7 +55,7 @@ if __name__ == "__main__":
             param_name = f"{prefix}.embeddings.LayerNorm.{w}"
             compressed_sd[param_name] = state_dict[param_name]
-    ### Transformer Blocks ###
+    # Transformer Blocks #
     std_idx = 0
     for teacher_idx in [0, 2, 4, 7, 9, 11]:
         if args.model_type == "gpt2":
@@ -82,7 +82,7 @@ if __name__ == "__main__":
                 ]
         std_idx += 1
-    ### Language Modeling Head ###s
+    # Language Modeling Head ###s
     if args.model_type == "roberta":
         for layer in ["lm_head.decoder.weight", "lm_head.bias"]:
             compressed_sd[f"{layer}"] = state_dict[f"{layer}"]

View File

@@ -219,7 +219,7 @@ def main():
     args = parser.parse_args()
     sanity_checks(args)
-    ## ARGS ##
+    # ARGS #
     init_gpu_params(args)
     set_seed(args)
     if args.is_master:
@@ -236,7 +236,7 @@ def main():
         os.makedirs(args.dump_path)
     logger.info(f"Experiment will be dumped and logged in {args.dump_path}")
-    ### SAVE PARAMS ###
+    # SAVE PARAMS #
     logger.info(f"Param: {args}")
     with open(os.path.join(args.dump_path, "parameters.json"), "w") as f:
         json.dump(vars(args), f, indent=4)
@@ -245,7 +245,7 @@ def main():
     student_config_class, student_model_class, _ = MODEL_CLASSES[args.student_type]
     teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[args.teacher_type]
-    ### TOKENIZER ###
+    # TOKENIZER #
     tokenizer = teacher_tokenizer_class.from_pretrained(args.teacher_name)
     special_tok_ids = {}
     for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
@@ -255,7 +255,7 @@ def main():
     args.special_tok_ids = special_tok_ids
     args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name]
-    ## DATA LOADER ##
+    # DATA LOADER #
    logger.info(f"Loading data from {args.data_file}")
     with open(args.data_file, "rb") as fp:
         data = pickle.load(fp)
@@ -275,7 +275,7 @@ def main():
     train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
     logger.info(f"Data loader created.")
-    ## STUDENT ##
+    # STUDENT #
     logger.info(f"Loading student config from {args.student_config}")
     stu_architecture_config = student_config_class.from_pretrained(args.student_config)
     stu_architecture_config.output_hidden_states = True
@@ -290,26 +290,26 @@ def main():
         student.to(f"cuda:{args.local_rank}")
     logger.info(f"Student loaded.")
-    ## TEACHER ##
+    # TEACHER #
     teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
     if args.n_gpu > 0:
         teacher.to(f"cuda:{args.local_rank}")
     logger.info(f"Teacher loaded from {args.teacher_name}.")
-    ## FREEZING ##
+    # FREEZING #
     if args.freeze_pos_embs:
         freeze_pos_embeddings(student, args)
     if args.freeze_token_type_embds:
         freeze_token_type_embeddings(student, args)
-    ## SANITY CHECKS ##
+    # SANITY CHECKS #
     assert student.config.vocab_size == teacher.config.vocab_size
     assert student.config.hidden_size == teacher.config.hidden_size
     assert student.config.max_position_embeddings == teacher.config.max_position_embeddings
     if args.mlm:
         assert token_probs.size(0) == stu_architecture_config.vocab_size
-    ## DISTILLER ##
+    # DISTILLER #
     torch.cuda.empty_cache()
     distiller = Distiller(
         params=args, dataset=train_lm_seq_dataset, token_probs=token_probs, student=student, teacher=teacher

View File

@@ -344,7 +344,7 @@ def load_examples(args, tokenizer, evaluate=False):
 def main():
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--data_dir",
         default=None,
@@ -374,7 +374,7 @@ def main():
         help="The output directory where the model predictions and checkpoints will be written.",
     )
-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
     )

View File

@@ -242,7 +242,7 @@ def prune_heads(args, model, eval_dataloader, head_mask):
 def main():
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--data_dir",
         default=None,
@@ -272,7 +272,7 @@ def main():
         help="The output directory where the model predictions and checkpoints will be written.",
     )
-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--config_name",
         default="",

View File

@@ -410,7 +410,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
 def main():
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--data_dir",
         default=None,
@@ -447,7 +447,7 @@ def main():
         help="The output directory where the model predictions and checkpoints will be written.",
     )
-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
     )

View File

@@ -422,7 +422,7 @@ def evaluate(args, model, tokenizer, prefix=""):
 def main():
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--train_data_file", default=None, type=str, required=True, help="The input training data file (a text file)."
     )
@@ -434,7 +434,7 @@ def main():
         help="The output directory where the model predictions and checkpoints will be written.",
     )
-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--eval_data_file",
         default=None,

View File

@@ -385,7 +385,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False, test=False):
 def main():
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--data_dir",
         default=None,
@@ -422,7 +422,7 @@ def main():
         help="The output directory where the model predictions and checkpoints will be written.",
     )
-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
     )

View File

@@ -385,7 +385,7 @@ def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode):
 def main():
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--data_dir",
         default=None,
@@ -415,7 +415,7 @@ def main():
         help="The output directory where the model predictions and checkpoints will be written.",
     )
-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--labels",
         default="",

View File

@@ -377,7 +377,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
 def main():
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--data_dir",
         default=None,
@@ -417,7 +417,7 @@ def main():
         help="The output directory where the model predictions and checkpoints will be written.",
     )
-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
     )

View File

@@ -401,7 +401,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
 def main():
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--train_file", default=None, type=str, required=True, help="SQuAD json for training. E.g., train-v1.1.json"
     )
@@ -434,7 +434,7 @@ def main():
         help="The output directory where the model checkpoints and predictions will be written.",
     )
-    ## Other parameters
+    # Other parameters
     parser.add_argument(
         "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
     )

View File

@@ -43,7 +43,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
     )

View File

@@ -43,7 +43,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pyt
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
     )

View File

@@ -43,7 +43,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
     )

View File

@@ -51,7 +51,7 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--gpt2_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
     )

View File

@@ -51,7 +51,7 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--openai_checkpoint_folder_path",
         default=None,

View File

@@ -410,7 +410,7 @@ def convert_all_pt_checkpoints_to_tf(
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--tf_dump_path", default=None, type=str, required=True, help="Path to the output Tensorflow dump file."
     )

View File

@@ -94,7 +94,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
         layer: BertLayer = model.roberta.encoder.layer[i]
         roberta_layer: TransformerSentenceEncoderLayer = roberta_sent_encoder.layers[i]
-        ### self attention
+        # self attention
         self_attn: BertSelfAttention = layer.attention.self
         assert (
             roberta_layer.self_attn.k_proj.weight.data.shape
@@ -110,7 +110,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
         self_attn.value.weight.data = roberta_layer.self_attn.v_proj.weight
         self_attn.value.bias.data = roberta_layer.self_attn.v_proj.bias
-        ### self-attention output
+        # self-attention output
         self_output: BertSelfOutput = layer.attention.output
         assert self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape
         self_output.dense.weight = roberta_layer.self_attn.out_proj.weight
@@ -118,20 +118,20 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
         self_output.LayerNorm.weight = roberta_layer.self_attn_layer_norm.weight
         self_output.LayerNorm.bias = roberta_layer.self_attn_layer_norm.bias
-        ### intermediate
+        # intermediate
         intermediate: BertIntermediate = layer.intermediate
         assert intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape
         intermediate.dense.weight = roberta_layer.fc1.weight
         intermediate.dense.bias = roberta_layer.fc1.bias
-        ### output
+        # output
         bert_output: BertOutput = layer.output
         assert bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape
         bert_output.dense.weight = roberta_layer.fc2.weight
         bert_output.dense.bias = roberta_layer.fc2.bias
         bert_output.LayerNorm.weight = roberta_layer.final_layer_norm.weight
         bert_output.LayerNorm.bias = roberta_layer.final_layer_norm.bias
-        #### end of layer
+        # end of layer
     if classification_head:
         model.classifier.dense.weight = roberta.model.classification_heads["mnli"].dense.weight
@@ -170,7 +170,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--roberta_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
     )

View File

@@ -43,7 +43,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
     )

View File

@@ -70,7 +70,7 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--xlm_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
     )

View File

@@ -82,7 +82,7 @@ def convert_xlnet_checkpoint_to_pytorch(
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
+    # Required parameters
     parser.add_argument(
         "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
     )

View File

@@ -47,7 +47,7 @@ DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
 }
-### UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE ###
+# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #
 def gelu(x):
     return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
@@ -327,7 +327,7 @@ class Transformer(nn.Module):
         return outputs  # last-layer hidden state, (all hidden states), (all attentions)
-### INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL ###
+# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
 class DistilBertPreTrainedModel(PreTrainedModel):
     """ An abstract class to handle weights initialization and
         a simple interface for downloading and loading pretrained models.

View File

@@ -42,7 +42,7 @@ TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
 }
-### UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE ###
+# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #
 def gelu(x):
     """ Gaussian Error Linear Unit.
     Original Implementation of the gelu activation function in Google Bert repo when initially created.
@@ -463,7 +463,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
         return tfmr_output  # last-layer hidden-state, (all hidden_states), (all attentions)
-### INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL ###
+# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
 class TFDistilBertPreTrainedModel(TFPreTrainedModel):
     """ An abstract class to handle weights initialization and
         a simple interface for downloading and loading pretrained models.

View File

@@ -67,7 +67,8 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="")
 #####################
-### PyTorch => TF 2.0
+# PyTorch => TF 2.0 #
+#####################
 def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None, allow_missing_keys=False):
@@ -197,7 +198,8 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
 #####################
-### TF 2.0 => PyTorch
+# TF 2.0 => PyTorch #
+#####################
 def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False):

View File

@@ -79,23 +79,23 @@ class TFPositionwiseFF(tf.keras.layers.Layer):
     def call(self, inp, training=False):
         if self.pre_lnorm:
-            ##### layer normalization + positionwise feed-forward
+            # layer normalization + positionwise feed-forward
             core_out = self.layer_norm(inp)
             core_out = self.layer_1(core_out)
             core_out = self.drop_1(core_out, training=training)
             core_out = self.layer_2(core_out)
             core_out = self.drop_2(core_out, training=training)
-            ##### residual connection
+            # residual connection
             output = core_out + inp
         else:
-            ##### positionwise feed-forward
+            # positionwise feed-forward
             core_out = self.layer_1(inp)
             core_out = self.drop_1(core_out, training=training)
             core_out = self.layer_2(core_out)
             core_out = self.drop_2(core_out, training=training)
-            ##### residual connection + layer normalization
+            # residual connection + layer normalization
             output = self.layer_norm(inp + core_out)
         return output
@@ -206,7 +206,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
         r_head_k = tf.reshape(r_head_k, (rlen, self.n_head, self.d_head))  # qlen x n_head x d_head
-        #### compute attention score
+        # compute attention score
         rw_head_q = w_head_q + self.r_w_bias  # qlen x bsz x n_head x d_head
         AC = tf.einsum("ibnd,jbnd->ijbn", rw_head_q, w_head_k)  # qlen x klen x bsz x n_head
@@ -218,7 +218,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
         attn_score = AC + BD
         attn_score = attn_score * self.scale
-        #### compute attention probability
+        # compute attention probability
         if attn_mask is not None:
             attn_mask_t = attn_mask[:, :, None, None]
             attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t
@@ -231,22 +231,22 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
         if head_mask is not None:
             attn_prob = attn_prob * head_mask
-        #### compute attention vector
+        # compute attention vector
         attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, w_head_v)
         # [qlen x bsz x n_head x d_head]
         attn_vec_sizes = shape_list(attn_vec)
         attn_vec = tf.reshape(attn_vec, (attn_vec_sizes[0], attn_vec_sizes[1], self.n_head * self.d_head))
-        ##### linear projection
+        # linear projection
         attn_out = self.o_net(attn_vec)
         attn_out = self.drop(attn_out, training=training)
         if self.pre_lnorm:
-            ##### residual connection
+            # residual connection
             outputs = [w + attn_out]
         else:
-            ##### residual connection + layer normalization
+            # residual connection + layer normalization
             outputs = [self.layer_norm(w + attn_out)]
         if self.output_attentions:

View File

@@ -190,7 +190,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
         (h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems, target_mapping, head_mask) = inputs
         if g is not None:
-            ###### Two-stream attention with relative positional encoding.
+            # Two-stream attention with relative positional encoding.
             # content based attention score
             if mems is not None and len(shape_list(mems)) > 1:
                 cat = tf.concat([mems, h], axis=0)
@@ -206,7 +206,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
             # position-based key head
             k_head_r = tf.einsum("ibh,hnd->ibnd", r, self.r)
-            ##### h-stream
+            # h-stream
             # content-stream query head
             q_head_h = tf.einsum("ibh,hnd->ibnd", h, self.q)
@@ -221,7 +221,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
             # post processing
             output_h = self.post_attention([h, attn_vec_h], training=training)
-            ##### g-stream
+            # g-stream
             # query-stream query head
             q_head_g = tf.einsum("ibh,hnd->ibnd", g, self.q)
@@ -251,7 +251,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
                 attn_prob = attn_prob_h, attn_prob_g
         else:
-            ###### Multi-head attention with relative positional encoding
+            # Multi-head attention with relative positional encoding
             if mems is not None and len(shape_list(mems)) > 1:
                 cat = tf.concat([mems, h], axis=0)
             else:
@@ -552,7 +552,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         dtype_float = tf.bfloat16 if self.use_bfloat16 else tf.float32
-        ##### Attention mask
+        # Attention mask
         # causal attention mask
         if self.attn_type == "uni":
             attn_mask = self.create_mask(qlen, mlen)
@@ -597,7 +597,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         else:
             non_tgt_mask = None
-        ##### Word embeddings and prepare h & g hidden states
+        # Word embeddings and prepare h & g hidden states
         if inputs_embeds is not None:
             word_emb_k = inputs_embeds
         else:
@@ -612,7 +612,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         else:
             output_g = None
-        ##### Segment embedding
+        # Segment embedding
         if token_type_ids is not None:
             # Convert `token_type_ids` to one-hot `seg_mat`
             mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
@@ -624,7 +624,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         else:
             seg_mat = None
-        ##### Positional encoding
+        # Positional encoding
         pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=dtype_float)
         pos_emb = self.dropout(pos_emb, training=training)

View File

@@ -213,16 +213,16 @@ class PositionwiseFF(nn.Module):
     def forward(self, inp):
         if self.pre_lnorm:
-            ##### layer normalization + positionwise feed-forward
+            # layer normalization + positionwise feed-forward
             core_out = self.CoreNet(self.layer_norm(inp))
-            ##### residual connection
+            # residual connection
             output = core_out + inp
         else:
-            ##### positionwise feed-forward
+            # positionwise feed-forward
             core_out = self.CoreNet(inp)
-            ##### residual connection + layer normalization
+            # residual connection + layer normalization
             output = self.layer_norm(inp + core_out)
         return output
@@ -316,7 +316,7 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
         r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)  # qlen x n_head x d_head
-        #### compute attention score
+        # compute attention score
         rw_head_q = w_head_q + self.r_w_bias  # qlen x bsz x n_head x d_head
         AC = torch.einsum("ibnd,jbnd->ijbn", (rw_head_q, w_head_k))  # qlen x klen x bsz x n_head
@@ -328,7 +328,7 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
         attn_score = AC + BD
         attn_score.mul_(self.scale)
-        #### compute attention probability
+        # compute attention probability
         if attn_mask is not None and torch.sum(attn_mask).item():
             attn_mask = attn_mask == 1  # Switch to bool
             if attn_mask.dim() == 2:
@@ -352,21 +352,21 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
         if head_mask is not None:
             attn_prob = attn_prob * head_mask
-        #### compute attention vector
+        # compute attention vector
         attn_vec = torch.einsum("ijbn,jbnd->ibnd", (attn_prob, w_head_v))
         # [qlen x bsz x n_head x d_head]
         attn_vec = attn_vec.contiguous().view(attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
-        ##### linear projection
+        # linear projection
         attn_out = self.o_net(attn_vec)
         attn_out = self.drop(attn_out)
         if self.pre_lnorm:
-            ##### residual connection
+            # residual connection
             outputs = [w + attn_out]
         else:
-            ##### residual connection + layer normalization
+            # residual connection + layer normalization
             outputs = [self.layer_norm(w + attn_out)]
         if self.output_attentions:

View File

@@ -330,7 +330,7 @@ class XLNetRelativeAttention(nn.Module):
     def forward(self, h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems=None, target_mapping=None, head_mask=None):
         if g is not None:
-            ###### Two-stream attention with relative positional encoding.
+            # Two-stream attention with relative positional encoding.
             # content based attention score
             if mems is not None and mems.dim() > 1:
                 cat = torch.cat([mems, h], dim=0)
@@ -346,7 +346,7 @@ class XLNetRelativeAttention(nn.Module):
             # position-based key head
             k_head_r = torch.einsum("ibh,hnd->ibnd", r, self.r)
-            ##### h-stream
+            # h-stream
             # content-stream query head
             q_head_h = torch.einsum("ibh,hnd->ibnd", h, self.q)
@@ -361,7 +361,7 @@ class XLNetRelativeAttention(nn.Module):
             # post processing
             output_h = self.post_attention(h, attn_vec_h)
-            ##### g-stream
+            # g-stream
             # query-stream query head
             q_head_g = torch.einsum("ibh,hnd->ibnd", g, self.q)
@@ -391,7 +391,7 @@ class XLNetRelativeAttention(nn.Module):
                 attn_prob = attn_prob_h, attn_prob_g
         else:
-            ###### Multi-head attention with relative positional encoding
+            # Multi-head attention with relative positional encoding
             if mems is not None and mems.dim() > 1:
                 cat = torch.cat([mems, h], dim=0)
             else:
@@ -804,7 +804,7 @@ class XLNetModel(XLNetPreTrainedModel):
         dtype_float = next(self.parameters()).dtype
         device = next(self.parameters()).device
-        ##### Attention mask
+        # Attention mask
         # causal attention mask
         if self.attn_type == "uni":
             attn_mask = self.create_mask(qlen, mlen)
@@ -849,7 +849,7 @@ class XLNetModel(XLNetPreTrainedModel):
         else:
             non_tgt_mask = None
-        ##### Word embeddings and prepare h & g hidden states
+        # Word embeddings and prepare h & g hidden states
         if inputs_embeds is not None:
             word_emb_k = inputs_embeds
         else:
@@ -864,7 +864,7 @@ class XLNetModel(XLNetPreTrainedModel):
         else:
             output_g = None
-        ##### Segment embedding
+        # Segment embedding
         if token_type_ids is not None:
             # Convert `token_type_ids` to one-hot `seg_mat`
             if mlen > 0:
@@ -879,7 +879,7 @@ class XLNetModel(XLNetPreTrainedModel):
         else:
             seg_mat = None
-        ##### Positional encoding
+        # Positional encoding
         pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
         pos_emb = self.dropout(pos_emb)

View File

@@ -178,7 +178,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
         return True
-## Inspired from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
+# Inspired from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
 class GradientAccumulator(object):
     """Distribution strategies-aware gradient accumulation utility."""