mirror of https://github.com/huggingface/transformers.git
[FlaxSpeechEncoderDecoderModel] More Rigorous PT-Flax Equivalence Tests (#16589)
parent c65633156b
commit 8d57c424e0
```diff
@@ -413,28 +413,22 @@ class FlaxEncoderDecoderMixin:
         pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}

         with torch.no_grad():
-            pt_outputs = pt_model(**pt_inputs)
-            pt_logits = pt_outputs.logits
-            pt_outputs = pt_outputs.to_tuple()
-
-        fx_outputs = fx_model(**inputs_dict)
-        fx_logits = fx_outputs.logits
-        fx_outputs = fx_outputs.to_tuple()
-        self.assert_almost_equals(fx_logits, pt_logits.numpy(), 4e-2)
+            pt_outputs = pt_model(**pt_inputs).to_tuple()
+
+        fx_outputs = fx_model(**inputs_dict).to_tuple()
+        self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
         for fx_output, pt_output in zip(fx_outputs, pt_outputs):
             self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)

         # PT -> Flax
         with tempfile.TemporaryDirectory() as tmpdirname:
             pt_model.save_pretrained(tmpdirname)
             fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)

-        fx_outputs_loaded = fx_model_loaded(**inputs_dict)
-        fx_logits_loaded = fx_outputs_loaded.logits
-        fx_outputs_loaded = fx_outputs_loaded.to_tuple()
-        self.assert_almost_equals(fx_logits_loaded, pt_logits.numpy(), 4e-2)
+        fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
+        self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
         for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
             self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5)

         # Flax -> PT
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -445,12 +439,11 @@ class FlaxEncoderDecoderMixin:
         pt_model_loaded.eval()

         with torch.no_grad():
-            pt_outputs_loaded = pt_model_loaded(**pt_inputs)
-            pt_logits_loaded = pt_outputs_loaded.logits
-            pt_outputs_loaded = pt_outputs_loaded.to_tuple()
-        self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
+            pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
+
+        self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
         for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
             self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5)

     def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):

```