This commit is contained in:
Sylvain Gugger 2021-03-08 16:04:46 -05:00
parent b35e7b68ca
commit a8ec52efc2
2 changed files with 2 additions and 2 deletions

View File

@ -992,7 +992,7 @@ class TrainerIntegrationTest(unittest.TestCase):
# should be about half of fp16_init
# perfect world: fp32_init/2 == fp16_eval
self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
def test_no_wd_param_group(self):
model = torch.nn.Sequential(TstLayer(128), torch.nn.ModuleList([TstLayer(128), TstLayer(128)]))
trainer = Trainer(model=model)

View File

@ -30,7 +30,7 @@ if is_torch_available():
DistributedTensorGatherer,
LabelSmoother,
LengthGroupedSampler,
get_parameter_names
get_parameter_names,
)
class TstLayer(torch.nn.Module):