fix func signature (#10271)

This commit is contained in:
Stas Bekman 2021-02-18 16:44:42 -08:00 committed by GitHub
parent c6fe17557e
commit d9a81fc0c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -146,7 +146,7 @@ if is_torch_available():
self.double_output = double_output
self.config = None
def forward(self, input_x=None, labels=None, **kwargs):
def forward(self, input_x, labels=None, **kwargs):
y = input_x * self.a + self.b
if labels is None:
return (y, y) if self.double_output else (y,)
@ -160,7 +160,7 @@ if is_torch_available():
self.b = torch.nn.Parameter(torch.tensor(b).float())
self.config = None
def forward(self, input_x=None, labels=None, **kwargs):
def forward(self, input_x, labels=None, **kwargs):
y = input_x * self.a + self.b
result = {"output": y}
if labels is not None:
@ -177,7 +177,7 @@ if is_torch_available():
self.b = torch.nn.Parameter(torch.tensor(config.b).float())
self.double_output = config.double_output
def forward(self, input_x=None, labels=None, **kwargs):
def forward(self, input_x, labels=None, **kwargs):
y = input_x * self.a + self.b
if labels is None:
return (y, y) if self.double_output else (y,)