Fix multi_gpu_data_parallel_forward for MusicgenTest (#29632)

update

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2024-03-13 19:12:20 +01:00 committed by GitHub
parent 5ac264d8a8
commit fe085560d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -103,7 +103,7 @@ class MusicgenDecoderTester:
def __init__(
self,
parent,
batch_size=3, # need batch_size != num_hidden_layers
batch_size=4, # need batch_size != num_hidden_layers
seq_length=7,
is_training=False,
use_labels=False,
@ -441,7 +441,7 @@ class MusicgenTester:
def __init__(
self,
parent,
batch_size=3, # need batch_size != num_hidden_layers
batch_size=4, # need batch_size != num_hidden_layers
seq_length=7,
is_training=False,
use_labels=False,