Fixes torch jit tracing for LayoutLMv2 model (re-open) (#18313)

* Fixes torch jit tracing for LayoutLMv2 model.
Pytorch seems to reuse memory for input_shape which caused a mismatch in shapes later in the forward pass.

* Fixed code quality

* avoid unneeded allocation of vector for shape
This commit is contained in:
Mikkel Denker 2022-07-27 12:38:40 +02:00 committed by GitHub
parent 1d71ad8905
commit 70e7d1d656
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 21 additions and 11 deletions

View File

@ -805,6 +805,16 @@ class LayoutLMv2Model(LayoutLMv2PreTrainedModel):
return visual_bbox
def _get_input_shape(self, input_ids=None, inputs_embeds=None):
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
return input_ids.size()
elif inputs_embeds is not None:
return inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
@add_start_docstrings_to_model_forward(LAYOUTLMV2_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
@ -857,21 +867,14 @@ class LayoutLMv2Model(LayoutLMv2PreTrainedModel):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
input_shape = self._get_input_shape(input_ids, inputs_embeds)
device = input_ids.device if input_ids is not None else inputs_embeds.device
visual_shape = list(input_shape)
visual_shape[1] = self.config.image_feature_pool_shape[0] * self.config.image_feature_pool_shape[1]
visual_shape = torch.Size(visual_shape)
final_shape = list(input_shape)
# needs a new copy of input_shape for tracing. Otherwise wrong dimensions will occur
final_shape = list(self._get_input_shape(input_ids, inputs_embeds))
final_shape[1] += visual_shape[1]
final_shape = torch.Size(final_shape)

View File

@ -260,7 +260,7 @@ class LayoutLMv2ModelTester:
class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False
test_torchscript = False
test_torchscript = True
test_mismatched_shapes = False
all_model_classes = (

View File

@ -648,6 +648,13 @@ class ModelTesterMixin:
traced_model = torch.jit.trace(
model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
)
elif "bbox" in inputs and "image" in inputs: # LayoutLMv2 requires additional inputs
input_ids = inputs["input_ids"]
bbox = inputs["bbox"]
image = inputs["image"].tensor
traced_model = torch.jit.trace(
model, (input_ids, bbox, image), check_trace=False
) # when traced model is checked, an error is produced due to name mangling
else:
main_input = inputs[main_input_name]
traced_model = torch.jit.trace(model, main_input)