mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-30 17:52:35 +06:00
fix cuda
This commit is contained in:
parent
34d706a0e1
commit
14f0e8e557
@ -209,8 +209,8 @@ def run_model():
|
||||
attn_entropy, head_importance, _, _ = compute_heads_importance(args, model, eval_dataloader)
|
||||
|
||||
# Print/save matrices
|
||||
np.save(os.path.join(args.output_dir, 'attn_entropy.npy'), attn_entropy)
|
||||
np.save(os.path.join(args.output_dir, 'head_importance.npy'), head_importance)
|
||||
np.save(os.path.join(args.output_dir, 'attn_entropy.npy'), attn_entropy.detach().cpu().numpy())
|
||||
np.save(os.path.join(args.output_dir, 'head_importance.npy'), head_importance.detach().cpu().numpy())
|
||||
|
||||
logger.info("Attention entropies")
|
||||
print_2d_tensor(attn_entropy)
|
||||
|
Loading…
Reference in New Issue
Block a user