Fix CTRL test_torchscrip_xxx CI by updating _create_and_check_torchscript (#19786)

* Run inputs before trace

* Run inputs before trace

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2022-10-21 16:23:13 +02:00 committed by GitHub
parent 31565ff0fd
commit 3a1aeea3c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -658,6 +658,7 @@ class ModelTesterMixin:
attention_mask = inputs["attention_mask"]
decoder_input_ids = inputs["decoder_input_ids"]
decoder_attention_mask = inputs["decoder_attention_mask"]
model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
traced_model = torch.jit.trace(
model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
)
@ -665,11 +666,13 @@ class ModelTesterMixin:
input_ids = inputs["input_ids"]
bbox = inputs["bbox"]
image = inputs["image"].tensor
model(input_ids, bbox, image)
traced_model = torch.jit.trace(
model, (input_ids, bbox, image), check_trace=False
) # when traced model is checked, an error is produced due to name mangling
else:
main_input = inputs[main_input_name]
model(main_input)
traced_model = torch.jit.trace(model, main_input)
except RuntimeError:
self.fail("Couldn't trace module.")