mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Fix the TF Trainer gradient accumulation and the TF NER example (#6713)
* Align TF NER example over the PT one * Fix Dataset call * Fix gradient accumulation training * Apply style * Address Sylvain's comments * Address Sylvain's comments * Apply style
This commit is contained in:
parent
41aa2b4ef1
commit
6f289dc97a
@ -18,6 +18,7 @@
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from importlib import import_module
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
@ -32,7 +33,7 @@ from transformers import (
|
||||
TFTrainer,
|
||||
TFTrainingArguments,
|
||||
)
|
||||
from utils_ner import Split, TFNerDataset, get_labels
|
||||
from utils_ner import Split, TFTokenClassificationDataset, TokenClassificationTask
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -50,6 +51,9 @@ class ModelArguments:
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
task_type: Optional[str] = field(
|
||||
default="NER", metadata={"help": "Task type to fine tune in training (e.g. NER, POS, etc)"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
@ -102,6 +106,17 @@ def main():
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
|
||||
module = import_module("tasks")
|
||||
|
||||
try:
|
||||
token_classification_task_clazz = getattr(module, model_args.task_type)
|
||||
token_classification_task: TokenClassificationTask = token_classification_task_clazz()
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
f"Task {model_args.task_type} needs to be defined as a TokenClassificationTask subclass in {module}. "
|
||||
f"Available tasks classes are: {TokenClassificationTask.__subclasses__()}"
|
||||
)
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
@ -117,7 +132,7 @@ def main():
|
||||
logger.info("Training/evaluation parameters %s", training_args)
|
||||
|
||||
# Prepare Token Classification task
|
||||
labels = get_labels(data_args.labels)
|
||||
labels = token_classification_task.get_labels(data_args.labels)
|
||||
label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)}
|
||||
num_labels = len(labels)
|
||||
|
||||
@ -150,7 +165,8 @@ def main():
|
||||
|
||||
# Get datasets
|
||||
train_dataset = (
|
||||
TFNerDataset(
|
||||
TFTokenClassificationDataset(
|
||||
token_classification_task=token_classification_task,
|
||||
data_dir=data_args.data_dir,
|
||||
tokenizer=tokenizer,
|
||||
labels=labels,
|
||||
@ -163,7 +179,8 @@ def main():
|
||||
else None
|
||||
)
|
||||
eval_dataset = (
|
||||
TFNerDataset(
|
||||
TFTokenClassificationDataset(
|
||||
token_classification_task=token_classification_task,
|
||||
data_dir=data_args.data_dir,
|
||||
tokenizer=tokenizer,
|
||||
labels=labels,
|
||||
@ -233,7 +250,8 @@ def main():
|
||||
|
||||
# Predict
|
||||
if training_args.do_predict:
|
||||
test_dataset = TFNerDataset(
|
||||
test_dataset = TFTokenClassificationDataset(
|
||||
token_classification_task=token_classification_task,
|
||||
data_dir=data_args.data_dir,
|
||||
tokenizer=tokenizer,
|
||||
labels=labels,
|
||||
|
@ -276,7 +276,7 @@ if is_torch_available():
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
class TFNerDataset:
|
||||
class TFTokenClassificationDataset:
|
||||
"""
|
||||
This will be superseded by a framework-agnostic approach
|
||||
soon.
|
||||
|
@ -174,7 +174,7 @@ class TFTokenClassificationLoss:
|
||||
)
|
||||
# make sure only labels that are not equal to -100
|
||||
# are taken into account as loss
|
||||
if tf.math.reduce_any(labels == -1).numpy() is True:
|
||||
if tf.math.reduce_any(labels == -1):
|
||||
warnings.warn("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
|
||||
active_loss = tf.reshape(labels, (-1,)) != -1
|
||||
else:
|
||||
|
@ -620,13 +620,22 @@ class TFTrainer:
|
||||
self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
|
||||
else:
|
||||
for _ in tf.range(self.args.gradient_accumulation_steps):
|
||||
reduced_features = features[: self.args.train_batch_size / self.args.n_replicas]
|
||||
reduced_labels = labels[: self.args.train_batch_size / self.args.n_replicas]
|
||||
reduced_features = {
|
||||
k: ft[: self.args.train_batch_size // self.args.n_replicas] for k, ft in features.items()
|
||||
}
|
||||
reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas]
|
||||
|
||||
self.training_step(reduced_features, reduced_labels)
|
||||
|
||||
features = tf.concat(
|
||||
[features[self.args.train_batch_size / self.args.n_replicas :], reduced_features], axis=0
|
||||
features = {
|
||||
k: tf.concat(
|
||||
[ft[self.args.train_batch_size // self.args.n_replicas :], reduced_features[k]], axis=0,
|
||||
)
|
||||
for k, ft in features.items()
|
||||
}
|
||||
|
||||
labels = tf.concat(
|
||||
[labels[self.args.train_batch_size // self.args.n_replicas :], reduced_labels], axis=0
|
||||
)
|
||||
|
||||
gradients = self.gradient_accumulator.gradients
|
||||
|
Loading…
Reference in New Issue
Block a user