mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix CTRL test_torchscrip_xxx
CI by updating _create_and_check_torchscript
(#19786)
* Run inputs before trace * Run inputs before trace Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
31565ff0fd
commit
3a1aeea3c5
@ -658,6 +658,7 @@ class ModelTesterMixin:
|
||||
attention_mask = inputs["attention_mask"]
|
||||
decoder_input_ids = inputs["decoder_input_ids"]
|
||||
decoder_attention_mask = inputs["decoder_attention_mask"]
|
||||
model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
|
||||
traced_model = torch.jit.trace(
|
||||
model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
|
||||
)
|
||||
@ -665,11 +666,13 @@ class ModelTesterMixin:
|
||||
input_ids = inputs["input_ids"]
|
||||
bbox = inputs["bbox"]
|
||||
image = inputs["image"].tensor
|
||||
model(input_ids, bbox, image)
|
||||
traced_model = torch.jit.trace(
|
||||
model, (input_ids, bbox, image), check_trace=False
|
||||
) # when traced model is checked, an error is produced due to name mangling
|
||||
else:
|
||||
main_input = inputs[main_input_name]
|
||||
model(main_input)
|
||||
traced_model = torch.jit.trace(model, main_input)
|
||||
except RuntimeError:
|
||||
self.fail("Couldn't trace module.")
|
||||
|
Loading…
Reference in New Issue
Block a user