Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 11:11:05 +06:00)
PPL guide minor code snippet fix (#7938)
commit 13842e413c (parent 0e24e4c136)
@@ -125,18 +125,19 @@ are 512 preceding tokens available to condition on).
 lls = []
 for i in tqdm(range(0, encodings.input_ids.size(1), stride)):
     begin_loc = max(i + stride - max_length, 0)
-    end_loc = i + stride
+    end_loc = min(i + stride, encodings.input_ids.size(1))
+    trg_len = end_loc - i  # may be different from stride on last loop
     input_ids = encodings.input_ids[:,begin_loc:end_loc].to(device)
     target_ids = input_ids.clone()
-    target_ids[:,:-stride] = -100
+    target_ids[:,:-trg_len] = -100

     with torch.no_grad():
         outputs = model(input_ids, labels=target_ids)
-        log_likelihood = outputs[0] * stride
+        log_likelihood = outputs[0] * trg_len

     lls.append(log_likelihood)

-ppl = torch.exp(torch.stack(lls).sum() / i)
+ppl = torch.exp(torch.stack(lls).sum() / end_loc)

 Running this with the stride length equal to the max input length is
 equivalent to the suboptimal, non-sliding-window strategy we discussed above.
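For context, here is the corrected snippet assembled into a self-contained, runnable sketch. Only the loop and the final ppl line come from this diff; the surrounding setup below (GPT-2 model, tokenizer, device handling, sample text) is an assumption based on the perplexity guide this commit patches.

# Minimal runnable sketch of the fixed sliding-window perplexity loop.
# Setup (GPT-2, device, sample text) is assumed, not part of this diff.
import torch
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

device = "cuda" if torch.cuda.is_available() else "cpu"
model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# The guide evaluates on WikiText-2; any long text works for illustration.
text = "some long evaluation text " * 1000
encodings = tokenizer(text, return_tensors="pt")

max_length = model.config.n_positions  # 1024 for GPT-2
stride = 512

lls = []
for i in tqdm(range(0, encodings.input_ids.size(1), stride)):
    begin_loc = max(i + stride - max_length, 0)
    # Clamp so the final window does not run past the end of the input.
    end_loc = min(i + stride, encodings.input_ids.size(1))
    trg_len = end_loc - i  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    # Only the last trg_len tokens are scored; -100 masks the context tokens.
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)
        # The returned loss is a mean over the trg_len target tokens,
        # so multiply back to a summed log-likelihood.
        log_likelihood = outputs[0] * trg_len

    lls.append(log_likelihood)

# Divide by end_loc (total tokens scored), not the last loop index i.
ppl = torch.exp(torch.stack(lls).sum() / end_loc)
print(ppl.item())

The fix matters on the last window: it is typically shorter than stride, so scaling by stride and dividing by i both over- and under-count tokens, while trg_len and end_loc track the actual number of tokens scored.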