[s2s] distill: --normalize_hidden --supervise_forward (#6834)

This commit is contained in:
Sam Shleifer 2020-09-04 14:05:56 -04:00 committed by GitHub
parent c5d43a872f
commit 6078b12098
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 45 additions and 18 deletions

View File

@ -5,10 +5,9 @@ export WANDB_PROJECT=dmar
python distillation.py \
--learning_rate=3e-4 \
--do_train \
--do_predict \
--fp16 \
--val_check_interval 0.25 \
--teacher Helsinki-NLP/opus-mt-en-ro --data_dir $ENRO_DIR \
--teacher Helsinki-NLP/opus-mt-en-ro \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
--student_decoder_layers 3 --student_encoder_layers 6 \
--freeze_encoder --freeze_embeds \
@ -16,6 +15,6 @@ python distillation.py \
--alpha_hid=3. \
--train_batch_size=$BS --eval_batch_size=$BS \
--tokenizer_name Helsinki-NLP/opus-mt-en-ro \
--warmup_steps 500 --sortish_sampler --logger_name wandb \
--gpus 1 --fp16_opt_level O1 --task translation \
--warmup_steps 500 --logger_name wandb \
--fp16_opt_level O1 --task translation --normalize_hidden \
"$@"

View File

@ -87,10 +87,19 @@ class BartSummarizationDistiller(SummarizationModule):
}
if hparams.length_penalty != -1:
student_updates["length_penalty"] = hparams.length_penalty
d_layers_to_copy: List = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
hparams.d_layer_to_copy = d_layers_to_copy
hparams.e_layer_to_copy = e_layers_to_copy
d_layers_to_copy: List = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
if hparams.supervise_forward:
hparams.d_matches = get_layers_to_supervise(
student_updates["decoder_layers"], teacher.config.decoder_layers
)
else:
hparams.d_matches = d_layers_to_copy
hparams.d_layer_to_copy = d_layers_to_copy
kw = teacher.config.to_diff_dict()
kw.update(student_updates)
# Copy weights
@ -221,7 +230,7 @@ class BartSummarizationDistiller(SummarizationModule):
dec_mask = decoder_input_ids.ne(pad_token_id)
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
if self.alpha_hid > 0:
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy)
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_matches)
blended_loss = (
self.alpha_ce * loss_ce
@ -237,12 +246,14 @@ class BartSummarizationDistiller(SummarizationModule):
assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}"
mask = attention_mask.to(hidden_states[0])
valid_count = mask.sum() * hidden_states[0].size(-1)
hidden_losses = [
(F.mse_loss(hidden_states[i], hidden_states_T[j], reduction="none") * mask.unsqueeze(-1)).sum()
/ valid_count
for i, j in enumerate(matches)
]
return sum(hidden_losses)
student_states = torch.stack([hidden_states[i] for i in range(len(matches))])
teacher_states = torch.stack([hidden_states_T[j] for j in matches])
if self.hparams.normalize_hidden:
student_states = F.layer_norm(student_states, student_states.shape[1:])
teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:])
mse = F.mse_loss(student_states, teacher_states, reduction="none")
masked_mse = (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count
return masked_mse
def add_distill_args(parser):
@ -255,6 +266,8 @@ def add_distill_args(parser):
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
parser.add_argument("--no_teacher", action="store_true", default=False)
parser.add_argument("--length_penalty", type=float, default=-1)
parser.add_argument("--supervise_forward", action="store_true", default=False)
parser.add_argument("--normalize_hidden", action="store_true", default=False)
class BartTranslationDistiller(BartSummarizationDistiller):
@ -389,7 +402,7 @@ class T5SummarizationDistiller(BartSummarizationDistiller):
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits)
if self.alpha_hid > 0:
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy)
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_matches)
blended_loss = (
self.alpha_ce * loss_ce
@ -463,15 +476,28 @@ LAYERS_TO_COPY = {
},
6: {1: [0], 2: [0, 5], 3: [0, 2, 5], 4: [0, 1, 3, 5], 6: list(range(6))},
}
LAYERS_TO_SUPERVISE = {
12: {1: [11], 2: [5, 11], 3: [3, 7, 11], 6: [1, 3, 5, 8, 10, 11]},
16: {1: [15], 4: [4, 9, 12, 15], 8: [1, 3, 5, 7, 9, 11, 13, 15]},
6: {1: [5], 2: [3, 5], 3: [1, 4, 5], 4: [1, 2, 4, 5]},
2: {1: [1], 2: [0, 1]},
}
def get_layers_to_supervise(n_student, n_teacher):
return LAYERS_TO_SUPERVISE[n_teacher][n_student]
def get_layers_to_copy(n_student, n_teacher):
try:
return LAYERS_TO_COPY[n_teacher][n_student]
val = LAYERS_TO_COPY[n_teacher][n_student]
assert len(LAYERS_TO_SUPERVISE[n_teacher][n_student]) == len(val) == n_student
return val
except KeyError:
warnings.warn(
f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}"
)
if n_student != n_teacher:
warnings.warn(
f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}"
)
return list(range(n_student))

View File

@ -31,6 +31,8 @@ logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
CUDA_AVAILABLE = torch.cuda.is_available()
CHEAP_ARGS = {
"supervise_forward": True,
"normalize_hidden": True,
"label_smoothing": 0.2,
"eval_beams": 1,
"val_metric": "loss",