mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Ensure PT model is in evaluation mode and lightweight forward pass done (#17970)
This commit is contained in:
parent
d6cec45801
commit
009171d1ba
@ -145,7 +145,7 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
||||
# If the current attribute is a tensor, it is a leaf and we make the comparison. Otherwise, we will dig in
|
||||
# recursively, keeping the name of the attribute.
|
||||
if isinstance(pt_out, torch.Tensor):
|
||||
tensor_difference = np.max(np.abs(pt_out.detach().numpy() - tf_out.numpy()))
|
||||
tensor_difference = np.max(np.abs(pt_out.numpy() - tf_out.numpy()))
|
||||
differences[attr_name] = tensor_difference
|
||||
else:
|
||||
root_name = attr_name
|
||||
@ -270,9 +270,13 @@ class PTtoTFCommand(BaseTransformersCLICommand):
|
||||
|
||||
# Load models and acquire a basic input compatible with the model.
|
||||
pt_model = pt_class.from_pretrained(self._local_dir)
|
||||
pt_model.eval()
|
||||
|
||||
tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
|
||||
pt_input, tf_input = self.get_inputs(pt_model, config)
|
||||
pt_outputs = pt_model(**pt_input, output_hidden_states=True)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_input, output_hidden_states=True)
|
||||
del pt_model # will no longer be used, and may have a large memory footprint
|
||||
|
||||
tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
|
||||
|
Loading…
Reference in New Issue
Block a user