Ensure PT model is in evaluation mode and a lightweight forward pass is done (#17970)

This commit is contained in:
amyeroberts 2022-07-01 19:33:47 +01:00 committed by GitHub
parent d6cec45801
commit 009171d1ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -145,7 +145,7 @@ class PTtoTFCommand(BaseTransformersCLICommand):
# If the current attribute is a tensor, it is a leaf and we make the comparison. Otherwise, we will dig in
# recursively, keeping the name of the attribute.
if isinstance(pt_out, torch.Tensor):
tensor_difference = np.max(np.abs(pt_out.detach().numpy() - tf_out.numpy()))
tensor_difference = np.max(np.abs(pt_out.numpy() - tf_out.numpy()))
differences[attr_name] = tensor_difference
else:
root_name = attr_name
@ -270,9 +270,13 @@ class PTtoTFCommand(BaseTransformersCLICommand):
# Load models and acquire a basic input compatible with the model.
pt_model = pt_class.from_pretrained(self._local_dir)
pt_model.eval()
tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
pt_input, tf_input = self.get_inputs(pt_model, config)
pt_outputs = pt_model(**pt_input, output_hidden_states=True)
with torch.no_grad():
pt_outputs = pt_model(**pt_input, output_hidden_states=True)
del pt_model # will no longer be used, and may have a large memory footprint
tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)