mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-23 22:38:58 +06:00
Fix E266 flake8 warning (x90).
This commit is contained in:
parent
2ab78325f0
commit
fa2ccbc081
@ -487,7 +487,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--train_file", default=None, type=str, required=True, help="SWAG csv for training. E.g., train.csv"
|
||||
)
|
||||
@ -520,7 +520,7 @@ def main():
|
||||
help="The output directory where the model checkpoints and predictions will be written.",
|
||||
)
|
||||
|
||||
## Other parameters
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
|
@ -430,7 +430,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--train_file", default=None, type=str, required=True, help="SQuAD json for training. E.g., train-v1.1.json"
|
||||
)
|
||||
@ -486,7 +486,7 @@ def main():
|
||||
"--temperature", default=2.0, type=float, help="Distillation temperature. Only for distillation."
|
||||
)
|
||||
|
||||
## Other parameters
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
|
@ -43,7 +43,7 @@ if __name__ == "__main__":
|
||||
state_dict = model.state_dict()
|
||||
compressed_sd = {}
|
||||
|
||||
### Embeddings ###
|
||||
# Embeddings #
|
||||
if args.model_type == "gpt2":
|
||||
for param_name in ["wte.weight", "wpe.weight"]:
|
||||
compressed_sd[f"{prefix}.{param_name}"] = state_dict[f"{prefix}.{param_name}"]
|
||||
@ -55,7 +55,7 @@ if __name__ == "__main__":
|
||||
param_name = f"{prefix}.embeddings.LayerNorm.{w}"
|
||||
compressed_sd[param_name] = state_dict[param_name]
|
||||
|
||||
### Transformer Blocks ###
|
||||
# Transformer Blocks #
|
||||
std_idx = 0
|
||||
for teacher_idx in [0, 2, 4, 7, 9, 11]:
|
||||
if args.model_type == "gpt2":
|
||||
@ -82,7 +82,7 @@ if __name__ == "__main__":
|
||||
]
|
||||
std_idx += 1
|
||||
|
||||
### Language Modeling Head ###s
|
||||
# Language Modeling Head ###s
|
||||
if args.model_type == "roberta":
|
||||
for layer in ["lm_head.decoder.weight", "lm_head.bias"]:
|
||||
compressed_sd[f"{layer}"] = state_dict[f"{layer}"]
|
||||
|
@ -219,7 +219,7 @@ def main():
|
||||
args = parser.parse_args()
|
||||
sanity_checks(args)
|
||||
|
||||
## ARGS ##
|
||||
# ARGS #
|
||||
init_gpu_params(args)
|
||||
set_seed(args)
|
||||
if args.is_master:
|
||||
@ -236,7 +236,7 @@ def main():
|
||||
os.makedirs(args.dump_path)
|
||||
logger.info(f"Experiment will be dumped and logged in {args.dump_path}")
|
||||
|
||||
### SAVE PARAMS ###
|
||||
# SAVE PARAMS #
|
||||
logger.info(f"Param: {args}")
|
||||
with open(os.path.join(args.dump_path, "parameters.json"), "w") as f:
|
||||
json.dump(vars(args), f, indent=4)
|
||||
@ -245,7 +245,7 @@ def main():
|
||||
student_config_class, student_model_class, _ = MODEL_CLASSES[args.student_type]
|
||||
teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[args.teacher_type]
|
||||
|
||||
### TOKENIZER ###
|
||||
# TOKENIZER #
|
||||
tokenizer = teacher_tokenizer_class.from_pretrained(args.teacher_name)
|
||||
special_tok_ids = {}
|
||||
for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
|
||||
@ -255,7 +255,7 @@ def main():
|
||||
args.special_tok_ids = special_tok_ids
|
||||
args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name]
|
||||
|
||||
## DATA LOADER ##
|
||||
# DATA LOADER #
|
||||
logger.info(f"Loading data from {args.data_file}")
|
||||
with open(args.data_file, "rb") as fp:
|
||||
data = pickle.load(fp)
|
||||
@ -275,7 +275,7 @@ def main():
|
||||
train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
|
||||
logger.info(f"Data loader created.")
|
||||
|
||||
## STUDENT ##
|
||||
# STUDENT #
|
||||
logger.info(f"Loading student config from {args.student_config}")
|
||||
stu_architecture_config = student_config_class.from_pretrained(args.student_config)
|
||||
stu_architecture_config.output_hidden_states = True
|
||||
@ -290,26 +290,26 @@ def main():
|
||||
student.to(f"cuda:{args.local_rank}")
|
||||
logger.info(f"Student loaded.")
|
||||
|
||||
## TEACHER ##
|
||||
# TEACHER #
|
||||
teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
|
||||
if args.n_gpu > 0:
|
||||
teacher.to(f"cuda:{args.local_rank}")
|
||||
logger.info(f"Teacher loaded from {args.teacher_name}.")
|
||||
|
||||
## FREEZING ##
|
||||
# FREEZING #
|
||||
if args.freeze_pos_embs:
|
||||
freeze_pos_embeddings(student, args)
|
||||
if args.freeze_token_type_embds:
|
||||
freeze_token_type_embeddings(student, args)
|
||||
|
||||
## SANITY CHECKS ##
|
||||
# SANITY CHECKS #
|
||||
assert student.config.vocab_size == teacher.config.vocab_size
|
||||
assert student.config.hidden_size == teacher.config.hidden_size
|
||||
assert student.config.max_position_embeddings == teacher.config.max_position_embeddings
|
||||
if args.mlm:
|
||||
assert token_probs.size(0) == stu_architecture_config.vocab_size
|
||||
|
||||
## DISTILLER ##
|
||||
# DISTILLER #
|
||||
torch.cuda.empty_cache()
|
||||
distiller = Distiller(
|
||||
params=args, dataset=train_lm_seq_dataset, token_probs=token_probs, student=student, teacher=teacher
|
||||
|
@ -344,7 +344,7 @@ def load_examples(args, tokenizer, evaluate=False):
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
@ -374,7 +374,7 @@ def main():
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
|
||||
## Other parameters
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
|
@ -242,7 +242,7 @@ def prune_heads(args, model, eval_dataloader, head_mask):
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
@ -272,7 +272,7 @@ def main():
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
|
||||
## Other parameters
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--config_name",
|
||||
default="",
|
||||
|
@ -410,7 +410,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
@ -447,7 +447,7 @@ def main():
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
|
||||
## Other parameters
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
|
@ -422,7 +422,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--train_data_file", default=None, type=str, required=True, help="The input training data file (a text file)."
|
||||
)
|
||||
@ -434,7 +434,7 @@ def main():
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
|
||||
## Other parameters
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--eval_data_file",
|
||||
default=None,
|
||||
|
@ -385,7 +385,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False, test=False):
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
@ -422,7 +422,7 @@ def main():
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
|
||||
## Other parameters
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
|
@ -385,7 +385,7 @@ def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode):
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
@ -415,7 +415,7 @@ def main():
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
|
||||
## Other parameters
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--labels",
|
||||
default="",
|
||||
|
@ -377,7 +377,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
@ -417,7 +417,7 @@ def main():
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
|
||||
## Other parameters
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
|
@ -401,7 +401,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--train_file", default=None, type=str, required=True, help="SQuAD json for training. E.g., train-v1.1.json"
|
||||
)
|
||||
@ -434,7 +434,7 @@ def main():
|
||||
help="The output directory where the model checkpoints and predictions will be written.",
|
||||
)
|
||||
|
||||
## Other parameters
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
|
@ -43,7 +43,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
|
@ -43,7 +43,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pyt
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
|
@ -43,7 +43,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
|
@ -51,7 +51,7 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--gpt2_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
|
@ -51,7 +51,7 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--openai_checkpoint_folder_path",
|
||||
default=None,
|
||||
|
@ -410,7 +410,7 @@ def convert_all_pt_checkpoints_to_tf(
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_dump_path", default=None, type=str, required=True, help="Path to the output Tensorflow dump file."
|
||||
)
|
||||
|
@ -94,7 +94,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
||||
layer: BertLayer = model.roberta.encoder.layer[i]
|
||||
roberta_layer: TransformerSentenceEncoderLayer = roberta_sent_encoder.layers[i]
|
||||
|
||||
### self attention
|
||||
# self attention
|
||||
self_attn: BertSelfAttention = layer.attention.self
|
||||
assert (
|
||||
roberta_layer.self_attn.k_proj.weight.data.shape
|
||||
@ -110,7 +110,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
||||
self_attn.value.weight.data = roberta_layer.self_attn.v_proj.weight
|
||||
self_attn.value.bias.data = roberta_layer.self_attn.v_proj.bias
|
||||
|
||||
### self-attention output
|
||||
# self-attention output
|
||||
self_output: BertSelfOutput = layer.attention.output
|
||||
assert self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape
|
||||
self_output.dense.weight = roberta_layer.self_attn.out_proj.weight
|
||||
@ -118,20 +118,20 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
||||
self_output.LayerNorm.weight = roberta_layer.self_attn_layer_norm.weight
|
||||
self_output.LayerNorm.bias = roberta_layer.self_attn_layer_norm.bias
|
||||
|
||||
### intermediate
|
||||
# intermediate
|
||||
intermediate: BertIntermediate = layer.intermediate
|
||||
assert intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape
|
||||
intermediate.dense.weight = roberta_layer.fc1.weight
|
||||
intermediate.dense.bias = roberta_layer.fc1.bias
|
||||
|
||||
### output
|
||||
# output
|
||||
bert_output: BertOutput = layer.output
|
||||
assert bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape
|
||||
bert_output.dense.weight = roberta_layer.fc2.weight
|
||||
bert_output.dense.bias = roberta_layer.fc2.bias
|
||||
bert_output.LayerNorm.weight = roberta_layer.final_layer_norm.weight
|
||||
bert_output.LayerNorm.bias = roberta_layer.final_layer_norm.bias
|
||||
#### end of layer
|
||||
# end of layer
|
||||
|
||||
if classification_head:
|
||||
model.classifier.dense.weight = roberta.model.classification_heads["mnli"].dense.weight
|
||||
@ -170,7 +170,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--roberta_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
|
||||
)
|
||||
|
@ -43,7 +43,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
|
@ -70,7 +70,7 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--xlm_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
|
||||
)
|
||||
|
@ -82,7 +82,7 @@ def convert_xlnet_checkpoint_to_pytorch(
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
|
@ -47,7 +47,7 @@ DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
}
|
||||
|
||||
|
||||
### UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE ###
|
||||
# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #
|
||||
def gelu(x):
|
||||
return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
||||
|
||||
@ -327,7 +327,7 @@ class Transformer(nn.Module):
|
||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||
|
||||
|
||||
### INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL ###
|
||||
# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
|
||||
class DistilBertPreTrainedModel(PreTrainedModel):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for downloading and loading pretrained models.
|
||||
|
@ -42,7 +42,7 @@ TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
}
|
||||
|
||||
|
||||
### UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE ###
|
||||
# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #
|
||||
def gelu(x):
|
||||
""" Gaussian Error Linear Unit.
|
||||
Original Implementation of the gelu activation function in Google Bert repo when initially created.
|
||||
@ -463,7 +463,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
|
||||
return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions)
|
||||
|
||||
|
||||
### INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL ###
|
||||
# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
|
||||
class TFDistilBertPreTrainedModel(TFPreTrainedModel):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for downloading and loading pretrained models.
|
||||
|
@ -67,7 +67,8 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="")
|
||||
|
||||
|
||||
#####################
|
||||
### PyTorch => TF 2.0
|
||||
# PyTorch => TF 2.0 #
|
||||
#####################
|
||||
|
||||
|
||||
def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None, allow_missing_keys=False):
|
||||
@ -197,7 +198,8 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
|
||||
|
||||
|
||||
#####################
|
||||
### TF 2.0 => PyTorch
|
||||
# TF 2.0 => PyTorch #
|
||||
#####################
|
||||
|
||||
|
||||
def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False):
|
||||
|
@ -79,23 +79,23 @@ class TFPositionwiseFF(tf.keras.layers.Layer):
|
||||
|
||||
def call(self, inp, training=False):
|
||||
if self.pre_lnorm:
|
||||
##### layer normalization + positionwise feed-forward
|
||||
# layer normalization + positionwise feed-forward
|
||||
core_out = self.layer_norm(inp)
|
||||
core_out = self.layer_1(core_out)
|
||||
core_out = self.drop_1(core_out, training=training)
|
||||
core_out = self.layer_2(core_out)
|
||||
core_out = self.drop_2(core_out, training=training)
|
||||
|
||||
##### residual connection
|
||||
# residual connection
|
||||
output = core_out + inp
|
||||
else:
|
||||
##### positionwise feed-forward
|
||||
# positionwise feed-forward
|
||||
core_out = self.layer_1(inp)
|
||||
core_out = self.drop_1(core_out, training=training)
|
||||
core_out = self.layer_2(core_out)
|
||||
core_out = self.drop_2(core_out, training=training)
|
||||
|
||||
##### residual connection + layer normalization
|
||||
# residual connection + layer normalization
|
||||
output = self.layer_norm(inp + core_out)
|
||||
|
||||
return output
|
||||
@ -206,7 +206,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
|
||||
|
||||
r_head_k = tf.reshape(r_head_k, (rlen, self.n_head, self.d_head)) # qlen x n_head x d_head
|
||||
|
||||
#### compute attention score
|
||||
# compute attention score
|
||||
rw_head_q = w_head_q + self.r_w_bias # qlen x bsz x n_head x d_head
|
||||
AC = tf.einsum("ibnd,jbnd->ijbn", rw_head_q, w_head_k) # qlen x klen x bsz x n_head
|
||||
|
||||
@ -218,7 +218,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
|
||||
attn_score = AC + BD
|
||||
attn_score = attn_score * self.scale
|
||||
|
||||
#### compute attention probability
|
||||
# compute attention probability
|
||||
if attn_mask is not None:
|
||||
attn_mask_t = attn_mask[:, :, None, None]
|
||||
attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t
|
||||
@ -231,22 +231,22 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
|
||||
if head_mask is not None:
|
||||
attn_prob = attn_prob * head_mask
|
||||
|
||||
#### compute attention vector
|
||||
# compute attention vector
|
||||
attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, w_head_v)
|
||||
|
||||
# [qlen x bsz x n_head x d_head]
|
||||
attn_vec_sizes = shape_list(attn_vec)
|
||||
attn_vec = tf.reshape(attn_vec, (attn_vec_sizes[0], attn_vec_sizes[1], self.n_head * self.d_head))
|
||||
|
||||
##### linear projection
|
||||
# linear projection
|
||||
attn_out = self.o_net(attn_vec)
|
||||
attn_out = self.drop(attn_out, training=training)
|
||||
|
||||
if self.pre_lnorm:
|
||||
##### residual connection
|
||||
# residual connection
|
||||
outputs = [w + attn_out]
|
||||
else:
|
||||
##### residual connection + layer normalization
|
||||
# residual connection + layer normalization
|
||||
outputs = [self.layer_norm(w + attn_out)]
|
||||
|
||||
if self.output_attentions:
|
||||
|
@ -190,7 +190,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
|
||||
(h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems, target_mapping, head_mask) = inputs
|
||||
|
||||
if g is not None:
|
||||
###### Two-stream attention with relative positional encoding.
|
||||
# Two-stream attention with relative positional encoding.
|
||||
# content based attention score
|
||||
if mems is not None and len(shape_list(mems)) > 1:
|
||||
cat = tf.concat([mems, h], axis=0)
|
||||
@ -206,7 +206,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
|
||||
# position-based key head
|
||||
k_head_r = tf.einsum("ibh,hnd->ibnd", r, self.r)
|
||||
|
||||
##### h-stream
|
||||
# h-stream
|
||||
# content-stream query head
|
||||
q_head_h = tf.einsum("ibh,hnd->ibnd", h, self.q)
|
||||
|
||||
@ -221,7 +221,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
|
||||
# post processing
|
||||
output_h = self.post_attention([h, attn_vec_h], training=training)
|
||||
|
||||
##### g-stream
|
||||
# g-stream
|
||||
# query-stream query head
|
||||
q_head_g = tf.einsum("ibh,hnd->ibnd", g, self.q)
|
||||
|
||||
@ -251,7 +251,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
|
||||
attn_prob = attn_prob_h, attn_prob_g
|
||||
|
||||
else:
|
||||
###### Multi-head attention with relative positional encoding
|
||||
# Multi-head attention with relative positional encoding
|
||||
if mems is not None and len(shape_list(mems)) > 1:
|
||||
cat = tf.concat([mems, h], axis=0)
|
||||
else:
|
||||
@ -552,7 +552,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
dtype_float = tf.bfloat16 if self.use_bfloat16 else tf.float32
|
||||
|
||||
##### Attention mask
|
||||
# Attention mask
|
||||
# causal attention mask
|
||||
if self.attn_type == "uni":
|
||||
attn_mask = self.create_mask(qlen, mlen)
|
||||
@ -597,7 +597,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
else:
|
||||
non_tgt_mask = None
|
||||
|
||||
##### Word embeddings and prepare h & g hidden states
|
||||
# Word embeddings and prepare h & g hidden states
|
||||
if inputs_embeds is not None:
|
||||
word_emb_k = inputs_embeds
|
||||
else:
|
||||
@ -612,7 +612,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
else:
|
||||
output_g = None
|
||||
|
||||
##### Segment embedding
|
||||
# Segment embedding
|
||||
if token_type_ids is not None:
|
||||
# Convert `token_type_ids` to one-hot `seg_mat`
|
||||
mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
|
||||
@ -624,7 +624,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
else:
|
||||
seg_mat = None
|
||||
|
||||
##### Positional encoding
|
||||
# Positional encoding
|
||||
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=dtype_float)
|
||||
pos_emb = self.dropout(pos_emb, training=training)
|
||||
|
||||
|
@ -213,16 +213,16 @@ class PositionwiseFF(nn.Module):
|
||||
|
||||
def forward(self, inp):
|
||||
if self.pre_lnorm:
|
||||
##### layer normalization + positionwise feed-forward
|
||||
# layer normalization + positionwise feed-forward
|
||||
core_out = self.CoreNet(self.layer_norm(inp))
|
||||
|
||||
##### residual connection
|
||||
# residual connection
|
||||
output = core_out + inp
|
||||
else:
|
||||
##### positionwise feed-forward
|
||||
# positionwise feed-forward
|
||||
core_out = self.CoreNet(inp)
|
||||
|
||||
##### residual connection + layer normalization
|
||||
# residual connection + layer normalization
|
||||
output = self.layer_norm(inp + core_out)
|
||||
|
||||
return output
|
||||
@ -316,7 +316,7 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
|
||||
|
||||
r_head_k = r_head_k.view(rlen, self.n_head, self.d_head) # qlen x n_head x d_head
|
||||
|
||||
#### compute attention score
|
||||
# compute attention score
|
||||
rw_head_q = w_head_q + self.r_w_bias # qlen x bsz x n_head x d_head
|
||||
AC = torch.einsum("ibnd,jbnd->ijbn", (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head
|
||||
|
||||
@ -328,7 +328,7 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
|
||||
attn_score = AC + BD
|
||||
attn_score.mul_(self.scale)
|
||||
|
||||
#### compute attention probability
|
||||
# compute attention probability
|
||||
if attn_mask is not None and torch.sum(attn_mask).item():
|
||||
attn_mask = attn_mask == 1 # Switch to bool
|
||||
if attn_mask.dim() == 2:
|
||||
@ -352,21 +352,21 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
|
||||
if head_mask is not None:
|
||||
attn_prob = attn_prob * head_mask
|
||||
|
||||
#### compute attention vector
|
||||
# compute attention vector
|
||||
attn_vec = torch.einsum("ijbn,jbnd->ibnd", (attn_prob, w_head_v))
|
||||
|
||||
# [qlen x bsz x n_head x d_head]
|
||||
attn_vec = attn_vec.contiguous().view(attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
|
||||
|
||||
##### linear projection
|
||||
# linear projection
|
||||
attn_out = self.o_net(attn_vec)
|
||||
attn_out = self.drop(attn_out)
|
||||
|
||||
if self.pre_lnorm:
|
||||
##### residual connection
|
||||
# residual connection
|
||||
outputs = [w + attn_out]
|
||||
else:
|
||||
##### residual connection + layer normalization
|
||||
# residual connection + layer normalization
|
||||
outputs = [self.layer_norm(w + attn_out)]
|
||||
|
||||
if self.output_attentions:
|
||||
|
@ -330,7 +330,7 @@ class XLNetRelativeAttention(nn.Module):
|
||||
|
||||
def forward(self, h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems=None, target_mapping=None, head_mask=None):
|
||||
if g is not None:
|
||||
###### Two-stream attention with relative positional encoding.
|
||||
# Two-stream attention with relative positional encoding.
|
||||
# content based attention score
|
||||
if mems is not None and mems.dim() > 1:
|
||||
cat = torch.cat([mems, h], dim=0)
|
||||
@ -346,7 +346,7 @@ class XLNetRelativeAttention(nn.Module):
|
||||
# position-based key head
|
||||
k_head_r = torch.einsum("ibh,hnd->ibnd", r, self.r)
|
||||
|
||||
##### h-stream
|
||||
# h-stream
|
||||
# content-stream query head
|
||||
q_head_h = torch.einsum("ibh,hnd->ibnd", h, self.q)
|
||||
|
||||
@ -361,7 +361,7 @@ class XLNetRelativeAttention(nn.Module):
|
||||
# post processing
|
||||
output_h = self.post_attention(h, attn_vec_h)
|
||||
|
||||
##### g-stream
|
||||
# g-stream
|
||||
# query-stream query head
|
||||
q_head_g = torch.einsum("ibh,hnd->ibnd", g, self.q)
|
||||
|
||||
@ -391,7 +391,7 @@ class XLNetRelativeAttention(nn.Module):
|
||||
attn_prob = attn_prob_h, attn_prob_g
|
||||
|
||||
else:
|
||||
###### Multi-head attention with relative positional encoding
|
||||
# Multi-head attention with relative positional encoding
|
||||
if mems is not None and mems.dim() > 1:
|
||||
cat = torch.cat([mems, h], dim=0)
|
||||
else:
|
||||
@ -804,7 +804,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
dtype_float = next(self.parameters()).dtype
|
||||
device = next(self.parameters()).device
|
||||
|
||||
##### Attention mask
|
||||
# Attention mask
|
||||
# causal attention mask
|
||||
if self.attn_type == "uni":
|
||||
attn_mask = self.create_mask(qlen, mlen)
|
||||
@ -849,7 +849,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
else:
|
||||
non_tgt_mask = None
|
||||
|
||||
##### Word embeddings and prepare h & g hidden states
|
||||
# Word embeddings and prepare h & g hidden states
|
||||
if inputs_embeds is not None:
|
||||
word_emb_k = inputs_embeds
|
||||
else:
|
||||
@ -864,7 +864,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
else:
|
||||
output_g = None
|
||||
|
||||
##### Segment embedding
|
||||
# Segment embedding
|
||||
if token_type_ids is not None:
|
||||
# Convert `token_type_ids` to one-hot `seg_mat`
|
||||
if mlen > 0:
|
||||
@ -879,7 +879,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
else:
|
||||
seg_mat = None
|
||||
|
||||
##### Positional encoding
|
||||
# Positional encoding
|
||||
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
|
||||
pos_emb = self.dropout(pos_emb)
|
||||
|
||||
|
@ -178,7 +178,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
|
||||
return True
|
||||
|
||||
|
||||
## Inspired from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
|
||||
# Inspired from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
|
||||
class GradientAccumulator(object):
|
||||
"""Distribution strategies-aware gradient accumulation utility."""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user