mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
fix func signature (#10271)
This commit is contained in:
parent
c6fe17557e
commit
d9a81fc0c5
@ -146,7 +146,7 @@ if is_torch_available():
|
||||
self.double_output = double_output
|
||||
self.config = None
|
||||
|
||||
def forward(self, input_x=None, labels=None, **kwargs):
|
||||
def forward(self, input_x, labels=None, **kwargs):
|
||||
y = input_x * self.a + self.b
|
||||
if labels is None:
|
||||
return (y, y) if self.double_output else (y,)
|
||||
@ -160,7 +160,7 @@ if is_torch_available():
|
||||
self.b = torch.nn.Parameter(torch.tensor(b).float())
|
||||
self.config = None
|
||||
|
||||
def forward(self, input_x=None, labels=None, **kwargs):
|
||||
def forward(self, input_x, labels=None, **kwargs):
|
||||
y = input_x * self.a + self.b
|
||||
result = {"output": y}
|
||||
if labels is not None:
|
||||
@ -177,7 +177,7 @@ if is_torch_available():
|
||||
self.b = torch.nn.Parameter(torch.tensor(config.b).float())
|
||||
self.double_output = config.double_output
|
||||
|
||||
def forward(self, input_x=None, labels=None, **kwargs):
|
||||
def forward(self, input_x, labels=None, **kwargs):
|
||||
y = input_x * self.a + self.b
|
||||
if labels is None:
|
||||
return (y, y) if self.double_output else (y,)
|
||||
|
Loading…
Reference in New Issue
Block a user