From d9a81fc0c5d8339357a42435009a5be3a190b305 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 18 Feb 2021 16:44:42 -0800 Subject: [PATCH] fix func signature (#10271) --- tests/test_trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 30cd08d0095..dc8209ab64e 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -146,7 +146,7 @@ if is_torch_available(): self.double_output = double_output self.config = None - def forward(self, input_x=None, labels=None, **kwargs): + def forward(self, input_x, labels=None, **kwargs): y = input_x * self.a + self.b if labels is None: return (y, y) if self.double_output else (y,) @@ -160,7 +160,7 @@ if is_torch_available(): self.b = torch.nn.Parameter(torch.tensor(b).float()) self.config = None - def forward(self, input_x=None, labels=None, **kwargs): + def forward(self, input_x, labels=None, **kwargs): y = input_x * self.a + self.b result = {"output": y} if labels is not None: @@ -177,7 +177,7 @@ if is_torch_available(): self.b = torch.nn.Parameter(torch.tensor(config.b).float()) self.double_output = config.double_output - def forward(self, input_x=None, labels=None, **kwargs): + def forward(self, input_x, labels=None, **kwargs): y = input_x * self.a + self.b if labels is None: return (y, y) if self.double_output else (y,)