Fix E266 flake8 warning (x90).

This commit is contained in:
Aymeric Augustin 2019-12-21 21:22:55 +01:00
parent 2ab78325f0
commit fa2ccbc081
30 changed files with 92 additions and 90 deletions
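For readers unfamiliar with the rule, flake8's E266 ("too many leading '#' for block comment") fires whenever a block comment starts with more than one '#', so the fix repeated 90 times below is simply dropping the extra '#' characters. A minimal, hypothetical sketch of the warning and the fix (not taken from the changed files; the argument names are only borrowed from the diffs):

# e266_demo.py -- hypothetical file; check it with: flake8 --select=E266 e266_demo.py
import argparse

parser = argparse.ArgumentParser()

## Required parameters
# The block comment just above starts with two '#' and is reported as E266.
parser.add_argument("--train_file", default=None, type=str, required=True)

# Other parameters
# This block comment starts with a single '#' followed by a space and passes the check.
parser.add_argument("--config_name", default="", type=str)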

@@ -487,7 +487,7 @@ def evaluate(args, model, tokenizer, prefix=""):
def main():
parser = argparse.ArgumentParser()
## Required parameters
# Required parameters
parser.add_argument(
"--train_file", default=None, type=str, required=True, help="SWAG csv for training. E.g., train.csv"
)
@@ -520,7 +520,7 @@ def main():
help="The output directory where the model checkpoints and predictions will be written.",
)
## Other parameters
# Other parameters
parser.add_argument(
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
)

@@ -430,7 +430,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
def main():
parser = argparse.ArgumentParser()
## Required parameters
# Required parameters
parser.add_argument(
"--train_file", default=None, type=str, required=True, help="SQuAD json for training. E.g., train-v1.1.json"
)
@@ -486,7 +486,7 @@ def main():
"--temperature", default=2.0, type=float, help="Distillation temperature. Only for distillation."
)
## Other parameters
# Other parameters
parser.add_argument(
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
)

@@ -43,7 +43,7 @@ if __name__ == "__main__":
state_dict = model.state_dict()
compressed_sd = {}
### Embeddings ###
# Embeddings #
if args.model_type == "gpt2":
for param_name in ["wte.weight", "wpe.weight"]:
compressed_sd[f"{prefix}.{param_name}"] = state_dict[f"{prefix}.{param_name}"]
@@ -55,7 +55,7 @@ if __name__ == "__main__":
param_name = f"{prefix}.embeddings.LayerNorm.{w}"
compressed_sd[param_name] = state_dict[param_name]
### Transformer Blocks ###
# Transformer Blocks #
std_idx = 0
for teacher_idx in [0, 2, 4, 7, 9, 11]:
if args.model_type == "gpt2":
@@ -82,7 +82,7 @@ if __name__ == "__main__":
]
std_idx += 1
### Language Modeling Head ###s
# Language Modeling Head ###s
if args.model_type == "roberta":
for layer in ["lm_head.decoder.weight", "lm_head.bias"]:
compressed_sd[f"{layer}"] = state_dict[f"{layer}"]

@@ -219,7 +219,7 @@ def main():
args = parser.parse_args()
sanity_checks(args)
## ARGS ##
# ARGS #
init_gpu_params(args)
set_seed(args)
if args.is_master:
@@ -236,7 +236,7 @@ def main():
os.makedirs(args.dump_path)
logger.info(f"Experiment will be dumped and logged in {args.dump_path}")
### SAVE PARAMS ###
# SAVE PARAMS #
logger.info(f"Param: {args}")
with open(os.path.join(args.dump_path, "parameters.json"), "w") as f:
json.dump(vars(args), f, indent=4)
@@ -245,7 +245,7 @@ def main():
student_config_class, student_model_class, _ = MODEL_CLASSES[args.student_type]
teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[args.teacher_type]
### TOKENIZER ###
# TOKENIZER #
tokenizer = teacher_tokenizer_class.from_pretrained(args.teacher_name)
special_tok_ids = {}
for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
@@ -255,7 +255,7 @@ def main():
args.special_tok_ids = special_tok_ids
args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name]
## DATA LOADER ##
# DATA LOADER #
logger.info(f"Loading data from {args.data_file}")
with open(args.data_file, "rb") as fp:
data = pickle.load(fp)
@@ -275,7 +275,7 @@ def main():
train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
logger.info(f"Data loader created.")
## STUDENT ##
# STUDENT #
logger.info(f"Loading student config from {args.student_config}")
stu_architecture_config = student_config_class.from_pretrained(args.student_config)
stu_architecture_config.output_hidden_states = True
@@ -290,26 +290,26 @@ def main():
student.to(f"cuda:{args.local_rank}")
logger.info(f"Student loaded.")
## TEACHER ##
# TEACHER #
teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
if args.n_gpu > 0:
teacher.to(f"cuda:{args.local_rank}")
logger.info(f"Teacher loaded from {args.teacher_name}.")
## FREEZING ##
# FREEZING #
if args.freeze_pos_embs:
freeze_pos_embeddings(student, args)
if args.freeze_token_type_embds:
freeze_token_type_embeddings(student, args)
## SANITY CHECKS ##
# SANITY CHECKS #
assert student.config.vocab_size == teacher.config.vocab_size
assert student.config.hidden_size == teacher.config.hidden_size
assert student.config.max_position_embeddings == teacher.config.max_position_embeddings
if args.mlm:
assert token_probs.size(0) == stu_architecture_config.vocab_size
## DISTILLER ##
# DISTILLER #
torch.cuda.empty_cache()
distiller = Distiller(
params=args, dataset=train_lm_seq_dataset, token_probs=token_probs, student=student, teacher=teacher

@@ -344,7 +344,7 @@ def load_examples(args, tokenizer, evaluate=False):
def main():
parser = argparse.ArgumentParser()
## Required parameters
# Required parameters
parser.add_argument(
"--data_dir",
default=None,
@@ -374,7 +374,7 @@ def main():
help="The output directory where the model predictions and checkpoints will be written.",
)
## Other parameters
# Other parameters
parser.add_argument(
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
)

@@ -242,7 +242,7 @@ def prune_heads(args, model, eval_dataloader, head_mask):
def main():
parser = argparse.ArgumentParser()
## Required parameters
# Required parameters
parser.add_argument(
"--data_dir",
default=None,
@@ -272,7 +272,7 @@ def main():
help="The output directory where the model predictions and checkpoints will be written.",
)
## Other parameters
# Other parameters
parser.add_argument(
"--config_name",
default="",

@@ -410,7 +410,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
def main():
parser = argparse.ArgumentParser()
## Required parameters
# Required parameters
parser.add_argument(
"--data_dir",
default=None,
@@ -447,7 +447,7 @@ def main():
help="The output directory where the model predictions and checkpoints will be written.",
)
## Other parameters
# Other parameters
parser.add_argument(
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
)

@@ -422,7 +422,7 @@ def evaluate(args, model, tokenizer, prefix=""):
def main():
parser = argparse.ArgumentParser()
## Required parameters
# Required parameters
parser.add_argument(
"--train_data_file", default=None, type=str, required=True, help="The input training data file (a text file)."
)
@@ -434,7 +434,7 @@ def main():
help="The output directory where the model predictions and checkpoints will be written.",
)
## Other parameters
# Other parameters
parser.add_argument(
"--eval_data_file",
default=None,

@@ -385,7 +385,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False, test=False):
def main():
parser = argparse.ArgumentParser()
## Required parameters
# Required parameters
parser.add_argument(
"--data_dir",
default=None,
@@ -422,7 +422,7 @@ def main():
help="The output directory where the model predictions and checkpoints will be written.",
)
## Other parameters
# Other parameters
parser.add_argument(
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
)

@@ -385,7 +385,7 @@ def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode):
def main():
parser = argparse.ArgumentParser()
## Required parameters
# Required parameters
parser.add_argument(
"--data_dir",
default=None,
@@ -415,7 +415,7 @@ def main():
help="The output directory where the model predictions and checkpoints will be written.",
)
## Other parameters
# Other parameters
parser.add_argument(
"--labels",
default="",

@@ -377,7 +377,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
def main():
parser = argparse.ArgumentParser()
## Required parameters
# Required parameters
parser.add_argument(
"--data_dir",
default=None,
@@ -417,7 +417,7 @@ def main():
help="The output directory where the model predictions and checkpoints will be written.",
)
## Other parameters
# Other parameters
parser.add_argument(
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
)

@@ -401,7 +401,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
def main():
parser = argparse.ArgumentParser()
## Required parameters
# Required parameters
parser.add_argument(
"--train_file", default=None, type=str, required=True, help="SQuAD json for training. E.g., train-v1.1.json"
)
@@ -434,7 +434,7 @@ def main():
help="The output directory where the model checkpoints and predictions will be written.",
)
## Other parameters
# Other parameters
parser.add_argument(
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
)

@@ -43,7 +43,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
# Required parameters
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)

@@ -43,7 +43,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pyt
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
# Required parameters
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)

@@ -43,7 +43,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
# Required parameters
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)

@@ -51,7 +51,7 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
# Required parameters
parser.add_argument(
"--gpt2_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)

@@ -51,7 +51,7 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
# Required parameters
parser.add_argument(
"--openai_checkpoint_folder_path",
default=None,

@@ -410,7 +410,7 @@ def convert_all_pt_checkpoints_to_tf(
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
# Required parameters
parser.add_argument(
"--tf_dump_path", default=None, type=str, required=True, help="Path to the output Tensorflow dump file."
)

@@ -94,7 +94,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
layer: BertLayer = model.roberta.encoder.layer[i]
roberta_layer: TransformerSentenceEncoderLayer = roberta_sent_encoder.layers[i]
### self attention
# self attention
self_attn: BertSelfAttention = layer.attention.self
assert (
roberta_layer.self_attn.k_proj.weight.data.shape
@@ -110,7 +110,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
self_attn.value.weight.data = roberta_layer.self_attn.v_proj.weight
self_attn.value.bias.data = roberta_layer.self_attn.v_proj.bias
### self-attention output
# self-attention output
self_output: BertSelfOutput = layer.attention.output
assert self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape
self_output.dense.weight = roberta_layer.self_attn.out_proj.weight
@@ -118,20 +118,20 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
self_output.LayerNorm.weight = roberta_layer.self_attn_layer_norm.weight
self_output.LayerNorm.bias = roberta_layer.self_attn_layer_norm.bias
### intermediate
# intermediate
intermediate: BertIntermediate = layer.intermediate
assert intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape
intermediate.dense.weight = roberta_layer.fc1.weight
intermediate.dense.bias = roberta_layer.fc1.bias
### output
# output
bert_output: BertOutput = layer.output
assert bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape
bert_output.dense.weight = roberta_layer.fc2.weight
bert_output.dense.bias = roberta_layer.fc2.bias
bert_output.LayerNorm.weight = roberta_layer.final_layer_norm.weight
bert_output.LayerNorm.bias = roberta_layer.final_layer_norm.bias
#### end of layer
# end of layer
if classification_head:
model.classifier.dense.weight = roberta.model.classification_heads["mnli"].dense.weight
@@ -170,7 +170,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
# Required parameters
parser.add_argument(
"--roberta_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
)

@@ -43,7 +43,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
# Required parameters
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)

@@ -70,7 +70,7 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
# Required parameters
parser.add_argument(
"--xlm_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
)

@@ -82,7 +82,7 @@ def convert_xlnet_checkpoint_to_pytorch(
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
# Required parameters
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)

@@ -47,7 +47,7 @@ DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
}
### UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE ###
# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #
def gelu(x):
return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
@@ -327,7 +327,7 @@ class Transformer(nn.Module):
return outputs # last-layer hidden state, (all hidden states), (all attentions)
### INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL ###
# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
class DistilBertPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.

@@ -42,7 +42,7 @@ TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
}
### UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE ###
# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #
def gelu(x):
""" Gaussian Error Linear Unit.
Original Implementation of the gelu activation function in Google Bert repo when initially created.
@@ -463,7 +463,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions)
### INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL ###
# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
class TFDistilBertPreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.

@@ -67,7 +67,8 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="")
#####################
### PyTorch => TF 2.0
# PyTorch => TF 2.0 #
#####################
def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None, allow_missing_keys=False):
@@ -197,7 +198,8 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
#####################
### TF 2.0 => PyTorch
# TF 2.0 => PyTorch #
#####################
def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False):

@@ -79,23 +79,23 @@ class TFPositionwiseFF(tf.keras.layers.Layer):
def call(self, inp, training=False):
if self.pre_lnorm:
##### layer normalization + positionwise feed-forward
# layer normalization + positionwise feed-forward
core_out = self.layer_norm(inp)
core_out = self.layer_1(core_out)
core_out = self.drop_1(core_out, training=training)
core_out = self.layer_2(core_out)
core_out = self.drop_2(core_out, training=training)
##### residual connection
# residual connection
output = core_out + inp
else:
##### positionwise feed-forward
# positionwise feed-forward
core_out = self.layer_1(inp)
core_out = self.drop_1(core_out, training=training)
core_out = self.layer_2(core_out)
core_out = self.drop_2(core_out, training=training)
##### residual connection + layer normalization
# residual connection + layer normalization
output = self.layer_norm(inp + core_out)
return output
@@ -206,7 +206,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
r_head_k = tf.reshape(r_head_k, (rlen, self.n_head, self.d_head)) # qlen x n_head x d_head
#### compute attention score
# compute attention score
rw_head_q = w_head_q + self.r_w_bias # qlen x bsz x n_head x d_head
AC = tf.einsum("ibnd,jbnd->ijbn", rw_head_q, w_head_k) # qlen x klen x bsz x n_head
@@ -218,7 +218,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
attn_score = AC + BD
attn_score = attn_score * self.scale
#### compute attention probability
# compute attention probability
if attn_mask is not None:
attn_mask_t = attn_mask[:, :, None, None]
attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t
@@ -231,22 +231,22 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
if head_mask is not None:
attn_prob = attn_prob * head_mask
#### compute attention vector
# compute attention vector
attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, w_head_v)
# [qlen x bsz x n_head x d_head]
attn_vec_sizes = shape_list(attn_vec)
attn_vec = tf.reshape(attn_vec, (attn_vec_sizes[0], attn_vec_sizes[1], self.n_head * self.d_head))
##### linear projection
# linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out, training=training)
if self.pre_lnorm:
##### residual connection
# residual connection
outputs = [w + attn_out]
else:
##### residual connection + layer normalization
# residual connection + layer normalization
outputs = [self.layer_norm(w + attn_out)]
if self.output_attentions:

@@ -190,7 +190,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
(h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems, target_mapping, head_mask) = inputs
if g is not None:
###### Two-stream attention with relative positional encoding.
# Two-stream attention with relative positional encoding.
# content based attention score
if mems is not None and len(shape_list(mems)) > 1:
cat = tf.concat([mems, h], axis=0)
@@ -206,7 +206,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
# position-based key head
k_head_r = tf.einsum("ibh,hnd->ibnd", r, self.r)
##### h-stream
# h-stream
# content-stream query head
q_head_h = tf.einsum("ibh,hnd->ibnd", h, self.q)
@@ -221,7 +221,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
# post processing
output_h = self.post_attention([h, attn_vec_h], training=training)
##### g-stream
# g-stream
# query-stream query head
q_head_g = tf.einsum("ibh,hnd->ibnd", g, self.q)
@@ -251,7 +251,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
attn_prob = attn_prob_h, attn_prob_g
else:
###### Multi-head attention with relative positional encoding
# Multi-head attention with relative positional encoding
if mems is not None and len(shape_list(mems)) > 1:
cat = tf.concat([mems, h], axis=0)
else:
@@ -552,7 +552,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
dtype_float = tf.bfloat16 if self.use_bfloat16 else tf.float32
##### Attention mask
# Attention mask
# causal attention mask
if self.attn_type == "uni":
attn_mask = self.create_mask(qlen, mlen)
@@ -597,7 +597,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
else:
non_tgt_mask = None
##### Word embeddings and prepare h & g hidden states
# Word embeddings and prepare h & g hidden states
if inputs_embeds is not None:
word_emb_k = inputs_embeds
else:
@@ -612,7 +612,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
else:
output_g = None
##### Segment embedding
# Segment embedding
if token_type_ids is not None:
# Convert `token_type_ids` to one-hot `seg_mat`
mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
@@ -624,7 +624,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
else:
seg_mat = None
##### Positional encoding
# Positional encoding
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=dtype_float)
pos_emb = self.dropout(pos_emb, training=training)

@@ -213,16 +213,16 @@ class PositionwiseFF(nn.Module):
def forward(self, inp):
if self.pre_lnorm:
##### layer normalization + positionwise feed-forward
# layer normalization + positionwise feed-forward
core_out = self.CoreNet(self.layer_norm(inp))
##### residual connection
# residual connection
output = core_out + inp
else:
##### positionwise feed-forward
# positionwise feed-forward
core_out = self.CoreNet(inp)
##### residual connection + layer normalization
# residual connection + layer normalization
output = self.layer_norm(inp + core_out)
return output
@@ -316,7 +316,7 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
r_head_k = r_head_k.view(rlen, self.n_head, self.d_head) # qlen x n_head x d_head
#### compute attention score
# compute attention score
rw_head_q = w_head_q + self.r_w_bias # qlen x bsz x n_head x d_head
AC = torch.einsum("ibnd,jbnd->ijbn", (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head
@@ -328,7 +328,7 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
attn_score = AC + BD
attn_score.mul_(self.scale)
#### compute attention probability
# compute attention probability
if attn_mask is not None and torch.sum(attn_mask).item():
attn_mask = attn_mask == 1 # Switch to bool
if attn_mask.dim() == 2:
@@ -352,21 +352,21 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
if head_mask is not None:
attn_prob = attn_prob * head_mask
#### compute attention vector
# compute attention vector
attn_vec = torch.einsum("ijbn,jbnd->ibnd", (attn_prob, w_head_v))
# [qlen x bsz x n_head x d_head]
attn_vec = attn_vec.contiguous().view(attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
##### linear projection
# linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out)
if self.pre_lnorm:
##### residual connection
# residual connection
outputs = [w + attn_out]
else:
##### residual connection + layer normalization
# residual connection + layer normalization
outputs = [self.layer_norm(w + attn_out)]
if self.output_attentions:

@@ -330,7 +330,7 @@ class XLNetRelativeAttention(nn.Module):
def forward(self, h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems=None, target_mapping=None, head_mask=None):
if g is not None:
###### Two-stream attention with relative positional encoding.
# Two-stream attention with relative positional encoding.
# content based attention score
if mems is not None and mems.dim() > 1:
cat = torch.cat([mems, h], dim=0)
@@ -346,7 +346,7 @@ class XLNetRelativeAttention(nn.Module):
# position-based key head
k_head_r = torch.einsum("ibh,hnd->ibnd", r, self.r)
##### h-stream
# h-stream
# content-stream query head
q_head_h = torch.einsum("ibh,hnd->ibnd", h, self.q)
@@ -361,7 +361,7 @@ class XLNetRelativeAttention(nn.Module):
# post processing
output_h = self.post_attention(h, attn_vec_h)
##### g-stream
# g-stream
# query-stream query head
q_head_g = torch.einsum("ibh,hnd->ibnd", g, self.q)
@@ -391,7 +391,7 @@ class XLNetRelativeAttention(nn.Module):
attn_prob = attn_prob_h, attn_prob_g
else:
###### Multi-head attention with relative positional encoding
# Multi-head attention with relative positional encoding
if mems is not None and mems.dim() > 1:
cat = torch.cat([mems, h], dim=0)
else:
@@ -804,7 +804,7 @@ class XLNetModel(XLNetPreTrainedModel):
dtype_float = next(self.parameters()).dtype
device = next(self.parameters()).device
##### Attention mask
# Attention mask
# causal attention mask
if self.attn_type == "uni":
attn_mask = self.create_mask(qlen, mlen)
@@ -849,7 +849,7 @@ class XLNetModel(XLNetPreTrainedModel):
else:
non_tgt_mask = None
##### Word embeddings and prepare h & g hidden states
# Word embeddings and prepare h & g hidden states
if inputs_embeds is not None:
word_emb_k = inputs_embeds
else:
@@ -864,7 +864,7 @@ class XLNetModel(XLNetPreTrainedModel):
else:
output_g = None
##### Segment embedding
# Segment embedding
if token_type_ids is not None:
# Convert `token_type_ids` to one-hot `seg_mat`
if mlen > 0:
@@ -879,7 +879,7 @@ class XLNetModel(XLNetPreTrainedModel):
else:
seg_mat = None
##### Positional encoding
# Positional encoding
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
pos_emb = self.dropout(pos_emb)

@@ -178,7 +178,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
return True
## Inspired from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
# Inspired from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
class GradientAccumulator(object):
"""Distribution strategies-aware gradient accumulation utility."""