mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix label datatype in TF Trainer (#9616)
* Fix label datatype * Apply style
This commit is contained in:
parent
76f36e183a
commit
12f0d7e8e0
@ -638,7 +638,15 @@ class TFTrainer:
|
||||
reduced_features = {
|
||||
k: ft[: self.args.train_batch_size // self.args.n_replicas] for k, ft in features.items()
|
||||
}
|
||||
reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas]
|
||||
|
||||
if tf.is_tensor(labels):
|
||||
reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas]
|
||||
elif isinstance(labels, dict):
|
||||
reduced_labels = {
|
||||
k: lbl[: self.args.train_batch_size // self.args.n_replicas] for k, lbl in labels.items()
|
||||
}
|
||||
else:
|
||||
raise ValueError("The labels must be either a tf.Tensor or a dict.")
|
||||
|
||||
self.training_step(reduced_features, reduced_labels, nb_instances_in_global_batch)
|
||||
|
||||
@ -650,9 +658,20 @@ class TFTrainer:
|
||||
for k, ft in features.items()
|
||||
}
|
||||
|
||||
labels = tf.concat(
|
||||
[labels[self.args.train_batch_size // self.args.n_replicas :], reduced_labels], axis=0
|
||||
)
|
||||
if tf.is_tensor(labels):
|
||||
labels = tf.concat(
|
||||
[labels[self.args.train_batch_size // self.args.n_replicas :], reduced_labels], axis=0
|
||||
)
|
||||
elif isinstance(labels, dict):
|
||||
labels = {
|
||||
k: tf.concat(
|
||||
[lbl[self.args.train_batch_size // self.args.n_replicas :], reduced_labels[k]],
|
||||
axis=0,
|
||||
)
|
||||
for k, lbl in labels.items()
|
||||
}
|
||||
else:
|
||||
raise ValueError("The labels must be either a tf.Tensor or a dict.")
|
||||
|
||||
gradients = self.gradient_accumulator.gradients
|
||||
gradients = [
|
||||
|
Loading…
Reference in New Issue
Block a user