Mirror of https://github.com/huggingface/transformers.git
Fixes torch jit tracing for LayoutLMv2 model (re-open) (#18313)
* Fixes torch jit tracing for LayoutLMv2 model. PyTorch seems to reuse memory for input_shape, which caused a mismatch in shapes later in the forward pass.
* Fixed code quality
* Avoid unneeded allocation of a vector for the shape
parent: 1d71ad8905
commit: 70e7d1d656
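
The fix targets a tracing-specific aliasing issue: under torch.jit.trace, the torch.Size returned for input_shape seems to have its memory reused, so a list built from it and mutated for the visual tokens could leak into the later final_shape computation. Below is a minimal sketch of the fixed pattern, assuming a hypothetical helper name and a 7 * 7 image feature pool (in the model, the pool size comes from config.image_feature_pool_shape):

import torch

# Sketch only: shapes_for_forward and the 7 * 7 pool size are illustrative
# assumptions, not names from the commit.
def shapes_for_forward(input_ids):
    visual_shape = list(input_ids.size())
    visual_shape[1] = 7 * 7  # number of visual tokens from the image feature pool

    # Before the fix, final_shape was built as list(input_shape) from the same
    # torch.Size that visual_shape came from; under tracing this could pick up
    # the mutated dimension. Re-deriving the shape from the tensor gives
    # final_shape its own, un-aliased copy.
    final_shape = list(input_ids.size())
    final_shape[1] += visual_shape[1]
    return torch.Size(visual_shape), torch.Size(final_shape)

print(shapes_for_forward(torch.zeros(2, 16, dtype=torch.long)))
# (torch.Size([2, 49]), torch.Size([2, 65]))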
src/transformers/models/layoutlmv2/modeling_layoutlmv2.py
@@ -805,6 +805,16 @@ class LayoutLMv2Model(LayoutLMv2PreTrainedModel):
         return visual_bbox
 
+    def _get_input_shape(self, input_ids=None, inputs_embeds=None):
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            return input_ids.size()
+        elif inputs_embeds is not None:
+            return inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
     @add_start_docstrings_to_model_forward(LAYOUTLMV2_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
@@ -857,21 +867,14 @@ class LayoutLMv2Model(LayoutLMv2PreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
-            input_shape = input_ids.size()
-        elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
-        else:
-            raise ValueError("You have to specify either input_ids or inputs_embeds")
-
+        input_shape = self._get_input_shape(input_ids, inputs_embeds)
         device = input_ids.device if input_ids is not None else inputs_embeds.device
 
         visual_shape = list(input_shape)
         visual_shape[1] = self.config.image_feature_pool_shape[0] * self.config.image_feature_pool_shape[1]
         visual_shape = torch.Size(visual_shape)
-        final_shape = list(input_shape)
+        # needs a new copy of input_shape for tracing. Otherwise wrong dimensions will occur
+        final_shape = list(self._get_input_shape(input_ids, inputs_embeds))
         final_shape[1] += visual_shape[1]
         final_shape = torch.Size(final_shape)
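
Calling _get_input_shape a second time, rather than copying the existing input_shape, is the core of the fix: it trades a trivial revalidation for a fresh torch.Size, which is what the in-diff comment means by needing "a new copy of input_shape for tracing".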
tests/models/layoutlmv2/test_modeling_layoutlmv2.py
@@ -260,7 +260,7 @@ class LayoutLMv2ModelTester:
 class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase):
 
     test_pruning = False
-    test_torchscript = False
+    test_torchscript = True
     test_mismatched_shapes = False
 
     all_model_classes = (
tests/test_modeling_common.py
@@ -648,6 +648,13 @@ class ModelTesterMixin:
                     traced_model = torch.jit.trace(
                         model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
                     )
+                elif "bbox" in inputs and "image" in inputs:  # LayoutLMv2 requires additional inputs
+                    input_ids = inputs["input_ids"]
+                    bbox = inputs["bbox"]
+                    image = inputs["image"].tensor
+                    traced_model = torch.jit.trace(
+                        model, (input_ids, bbox, image), check_trace=False
+                    )  # when traced model is checked, an error is produced due to name mangling
                 else:
                     main_input = inputs[main_input_name]
                     traced_model = torch.jit.trace(model, main_input)
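
With test_torchscript re-enabled, the common test now traces LayoutLMv2 with its three required inputs. A hedged end-to-end sketch of the same call follows; the checkpoint name and dummy input sizes are assumptions, and LayoutLMv2 additionally requires detectron2 to be installed:

import torch
from transformers import LayoutLMv2Model

# Assumed checkpoint; any LayoutLMv2 checkpoint should behave the same.
model = LayoutLMv2Model.from_pretrained("microsoft/layoutlmv2-base-uncased")
model.eval()

input_ids = torch.randint(0, model.config.vocab_size, (1, 8))
bbox = torch.zeros(1, 8, 4, dtype=torch.long)  # per-token boxes, coordinates in [0, 1000]
image = torch.zeros(1, 3, 224, 224)  # page image for the visual backbone

# check_trace=False mirrors the test change above: the trace check fails
# due to name mangling, not because of the shape bug fixed here.
traced = torch.jit.trace(model, (input_ids, bbox, image), check_trace=False)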