mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
fixing transfo eval script
This commit is contained in:
parent
973926431e
commit
ed47cb6cba
@ -111,7 +111,7 @@ def evaluate(eval_iter):
|
|||||||
mems = tuple()
|
mems = tuple()
|
||||||
for idx, (data, target, seq_len) in enumerate(eval_iter):
|
for idx, (data, target, seq_len) in enumerate(eval_iter):
|
||||||
ret = model(data, target, *mems)
|
ret = model(data, target, *mems)
|
||||||
loss, mems = ret[0], ret[1:]
|
loss, mems = ret
|
||||||
loss = loss.mean()
|
loss = loss.mean()
|
||||||
total_loss += seq_len * loss.item()
|
total_loss += seq_len * loss.item()
|
||||||
total_len += seq_len
|
total_len += seq_len
|
||||||
|
@ -1215,7 +1215,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
|||||||
# So, have to initialize size(0) mems inside the model forward.
|
# So, have to initialize size(0) mems inside the model forward.
|
||||||
# Moreover, have to return new_mems to allow nn.DataParallel to piece
|
# Moreover, have to return new_mems to allow nn.DataParallel to piece
|
||||||
# them together.
|
# them together.
|
||||||
if not mems: mems = self.init_mems(data)
|
if not mems:
|
||||||
|
mems = self.init_mems(data)
|
||||||
|
|
||||||
hidden, new_mems = self._forward(data, mems=mems)
|
hidden, new_mems = self._forward(data, mems=mems)
|
||||||
if target is None:
|
if target is None:
|
||||||
|
Loading…
Reference in New Issue
Block a user