From 14f0e8e55734456b71ad8b3c7d94e9d006d7fb8d Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 19 Jun 2019 15:29:28 +0200 Subject: [PATCH] fix cuda --- examples/bertology.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/bertology.py b/examples/bertology.py index 888d95e5c64..6997b9e26de 100644 --- a/examples/bertology.py +++ b/examples/bertology.py @@ -209,8 +209,8 @@ def run_model(): attn_entropy, head_importance, _, _ = compute_heads_importance(args, model, eval_dataloader) # Print/save matrices - np.save(os.path.join(args.output_dir, 'attn_entropy.npy'), attn_entropy) - np.save(os.path.join(args.output_dir, 'head_importance.npy'), head_importance) + np.save(os.path.join(args.output_dir, 'attn_entropy.npy'), attn_entropy.detach().cpu().numpy()) + np.save(os.path.join(args.output_dir, 'head_importance.npy'), head_importance.detach().cpu().numpy()) logger.info("Attention entropies") print_2d_tensor(attn_entropy)