Fix the TF Trainer gradient accumulation and the TF NER example (#6713)

* Align the TF NER example with the PT one

* Fix Dataset call

* Fix gradient accumulation training

* Apply style

* Address Sylvain's comments

* Address Sylvain's comments

* Apply style
Julien Plu 2020-08-27 14:45:34 +02:00 committed by GitHub
parent 41aa2b4ef1
commit 6f289dc97a
4 changed files with 38 additions and 11 deletions


@@ -18,6 +18,7 @@
import logging
import os
from dataclasses import dataclass, field
from importlib import import_module
from typing import Dict, List, Optional, Tuple
import numpy as np
@@ -32,7 +33,7 @@ from transformers import (
TFTrainer,
TFTrainingArguments,
)
from utils_ner import Split, TFNerDataset, get_labels
from utils_ner import Split, TFTokenClassificationDataset, TokenClassificationTask
logger = logging.getLogger(__name__)
@@ -50,6 +51,9 @@ class ModelArguments:
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
task_type: Optional[str] = field(
default="NER", metadata={"help": "Task type to fine tune in training (e.g. NER, POS, etc)"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
@@ -102,6 +106,17 @@ def main():
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
)
module = import_module("tasks")
try:
token_classification_task_clazz = getattr(module, model_args.task_type)
token_classification_task: TokenClassificationTask = token_classification_task_clazz()
except AttributeError:
raise ValueError(
f"Task {model_args.task_type} needs to be defined as a TokenClassificationTask subclass in {module}. "
f"Available tasks classes are: {TokenClassificationTask.__subclasses__()}"
)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -117,7 +132,7 @@ def main():
logger.info("Training/evaluation parameters %s", training_args)
# Prepare Token Classification task
labels = get_labels(data_args.labels)
labels = token_classification_task.get_labels(data_args.labels)
label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)}
num_labels = len(labels)
@@ -150,7 +165,8 @@ def main():
# Get datasets
train_dataset = (
TFNerDataset(
TFTokenClassificationDataset(
token_classification_task=token_classification_task,
data_dir=data_args.data_dir,
tokenizer=tokenizer,
labels=labels,
@@ -163,7 +179,8 @@
else None
)
eval_dataset = (
TFNerDataset(
TFTokenClassificationDataset(
token_classification_task=token_classification_task,
data_dir=data_args.data_dir,
tokenizer=tokenizer,
labels=labels,
@@ -233,7 +250,8 @@ def main():
# Predict
if training_args.do_predict:
test_dataset = TFNerDataset(
test_dataset = TFTokenClassificationDataset(
token_classification_task=token_classification_task,
data_dir=data_args.data_dir,
tokenizer=tokenizer,
labels=labels,
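A note on the new --task_type plumbing in this file: the script resolves a class named after the task from a local tasks module and only relies on it subclassing TokenClassificationTask and exposing get_labels. The sketch below is purely illustrative, not the repository's tasks.py; the real TokenClassificationTask defines more than get_labels, and the fallback label list is an assumption.

# tasks.py -- minimal hypothetical sketch, only what run_tf_ner.py above actually calls
from typing import List

class TokenClassificationTask:
    def get_labels(self, path: str) -> List[str]:
        raise NotImplementedError

class NER(TokenClassificationTask):
    def get_labels(self, path: str) -> List[str]:
        if path:
            # one label per line, e.g. a labels file shipped with the dataset
            with open(path, "r") as f:
                labels = f.read().splitlines()
            return labels if "O" in labels else ["O"] + labels
        # fallback set is an assumption; adjust to the dataset at hand
        return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]

With such a module importable, getattr(import_module("tasks"), "NER") returns the NER class, so --task_type NER selects it at runtime and the error message above can list TokenClassificationTask.__subclasses__().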


@@ -276,7 +276,7 @@ if is_torch_available():
if is_tf_available():
import tensorflow as tf
class TFNerDataset:
class TFTokenClassificationDataset:
"""
This will be superseded by a framework-agnostic approach
soon.


@@ -174,7 +174,7 @@ class TFTokenClassificationLoss:
)
# make sure only labels that are not equal to -100
# are taken into account as loss
if tf.math.reduce_any(labels == -1).numpy() is True:
if tf.math.reduce_any(labels == -1):
warnings.warn("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
active_loss = tf.reshape(labels, (-1,)) != -1
else:
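Why the one-line change above matters: reduce_any(...).numpy() returns a numpy.bool_, which is never the Python singleton True, so the old `is True` comparison always evaluated to False and the -1 handling never ran; .numpy() is also unavailable once the loss executes inside a tf.function. A minimal eager-mode sketch of the pitfall (the labels are made up):

import tensorflow as tf

labels = tf.constant([[-1, 2], [3, 4]])

flag = tf.math.reduce_any(labels == -1).numpy()
print(type(flag), flag)   # <class 'numpy.bool_'> True
print(flag is True)       # False: numpy.bool_ is not the Python bool singleton
print(bool(flag))         # True

# Evaluating the tensor directly, as the fixed code does, takes its truth value instead:
if tf.math.reduce_any(labels == -1):
    print("labels contain -1")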


@@ -620,13 +620,22 @@ class TFTrainer:
self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
else:
for _ in tf.range(self.args.gradient_accumulation_steps):
reduced_features = features[: self.args.train_batch_size / self.args.n_replicas]
reduced_labels = labels[: self.args.train_batch_size / self.args.n_replicas]
reduced_features = {
k: ft[: self.args.train_batch_size // self.args.n_replicas] for k, ft in features.items()
}
reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas]
self.training_step(reduced_features, reduced_labels)
features = tf.concat(
[features[self.args.train_batch_size / self.args.n_replicas :], reduced_features], axis=0
features = {
k: tf.concat(
[ft[self.args.train_batch_size // self.args.n_replicas :], reduced_features[k]], axis=0,
)
for k, ft in features.items()
}
labels = tf.concat(
[labels[self.args.train_batch_size // self.args.n_replicas :], reduced_labels], axis=0
)
gradients = self.gradient_accumulator.gradients
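
The hunk above fixes two things in the accumulation loop: features is a dict of tensors, so it has to be sliced per key rather than as a single tensor, and train_batch_size / n_replicas is a float that cannot be used as a slice index, hence the switch to //. A rough standalone sketch of the resulting slice-and-rotate pattern (feature names and shapes are invented):

import tensorflow as tf

# Invented shapes: a global batch of 8 split across 2 replicas -> micro-batches of 4 rows.
features = {
    "input_ids": tf.random.uniform((8, 16), maxval=100, dtype=tf.int32),
    "attention_mask": tf.ones((8, 16), dtype=tf.int32),
}
labels = tf.random.uniform((8, 16), maxval=9, dtype=tf.int32)

micro = 8 // 2  # integer division: 8 / 2 == 4.0 is a float and cannot index a tensor

# Slice the first micro-batch out of every feature tensor (a dict cannot be sliced directly)...
reduced_features = {k: ft[:micro] for k, ft in features.items()}
reduced_labels = labels[:micro]

# ...then rotate each tensor so the next accumulation step consumes the remaining rows.
features = {k: tf.concat([ft[micro:], reduced_features[k]], axis=0) for k, ft in features.items()}
labels = tf.concat([labels[micro:], reduced_labels], axis=0)

print({k: v.shape for k, v in features.items()}, labels.shape)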