mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Set the dataset format used by test_trainer
to float32 (#28920)
Co-authored-by: unit_test <test@unit.com>
This commit is contained in:
parent
7252e8d937
commit
69ca640dd6
@ -176,8 +176,8 @@ class DynamicShapesDataset:
|
||||
np.random.seed(seed)
|
||||
sizes = np.random.randint(1, 20, (length // batch_size,))
|
||||
# For easy batching, we make every batch_size consecutive samples the same size.
|
||||
self.xs = [np.random.normal(size=(s,)) for s in sizes.repeat(batch_size)]
|
||||
self.ys = [np.random.normal(size=(s,)) for s in sizes.repeat(batch_size)]
|
||||
self.xs = [np.random.normal(size=(s,)).astype(np.float32) for s in sizes.repeat(batch_size)]
|
||||
self.ys = [np.random.normal(size=(s,)).astype(np.float32) for s in sizes.repeat(batch_size)]
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
@ -547,7 +547,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
|
||||
np.random.seed(42)
|
||||
x = np.random.normal(size=(64,)).astype(np.float32)
|
||||
y = 2.0 * x + 3.0 + np.random.normal(scale=0.1, size=(64,))
|
||||
y = 2.0 * x + 3.0 + np.random.normal(scale=0.1, size=(64,)).astype(np.float32)
|
||||
train_dataset = datasets.Dataset.from_dict({"input_x": x, "label": y})
|
||||
|
||||
# Base training. Should have the same results as test_reproducible_training
|
||||
|
Loading…
Reference in New Issue
Block a user