Fix label datatype in TF Trainer (#9616)

* Fix label datatype

* Apply style
This commit is contained in:
Julien Plu 2021-01-20 12:08:00 +01:00 committed by GitHub
parent 76f36e183a
commit 12f0d7e8e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -638,7 +638,15 @@ class TFTrainer:
reduced_features = {
k: ft[: self.args.train_batch_size // self.args.n_replicas] for k, ft in features.items()
}
reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas]
if tf.is_tensor(labels):
reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas]
elif isinstance(labels, dict):
reduced_labels = {
k: lbl[: self.args.train_batch_size // self.args.n_replicas] for k, lbl in labels.items()
}
else:
raise ValueError("The labels must be either a tf.Tensor or a dict.")
self.training_step(reduced_features, reduced_labels, nb_instances_in_global_batch)
@ -650,9 +658,20 @@ class TFTrainer:
for k, ft in features.items()
}
labels = tf.concat(
[labels[self.args.train_batch_size // self.args.n_replicas :], reduced_labels], axis=0
)
if tf.is_tensor(labels):
labels = tf.concat(
[labels[self.args.train_batch_size // self.args.n_replicas :], reduced_labels], axis=0
)
elif isinstance(labels, dict):
labels = {
k: tf.concat(
[lbl[self.args.train_batch_size // self.args.n_replicas :], reduced_labels[k]],
axis=0,
)
for k, lbl in labels.items()
}
else:
raise ValueError("The labels must be either a tf.Tensor or a dict.")
gradients = self.gradient_accumulator.gradients
gradients = [