mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
Fixes torch jit tracing for LayoutLMv2 model (re-open) (#18313)
* Fixes torch jit tracing for the LayoutLMv2 model. PyTorch seems to reuse memory for input_shape, which caused a shape mismatch later in the forward pass. * Fixed code quality * Avoid unneeded allocation of vector for shape
This commit is contained in:
parent
1d71ad8905
commit
70e7d1d656
@ -805,6 +805,16 @@ class LayoutLMv2Model(LayoutLMv2PreTrainedModel):
|
|||||||
|
|
||||||
return visual_bbox
|
return visual_bbox
|
||||||
|
|
||||||
|
def _get_input_shape(self, input_ids=None, inputs_embeds=None):
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
return input_ids.size()
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
return inputs_embeds.size()[:-1]
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(LAYOUTLMV2_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
@add_start_docstrings_to_model_forward(LAYOUTLMV2_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||||
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
@ -857,21 +867,14 @@ class LayoutLMv2Model(LayoutLMv2PreTrainedModel):
|
|||||||
)
|
)
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
input_shape = self._get_input_shape(input_ids, inputs_embeds)
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
|
||||||
elif input_ids is not None:
|
|
||||||
input_shape = input_ids.size()
|
|
||||||
elif inputs_embeds is not None:
|
|
||||||
input_shape = inputs_embeds.size()[:-1]
|
|
||||||
else:
|
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
|
||||||
|
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
|
||||||
visual_shape = list(input_shape)
|
visual_shape = list(input_shape)
|
||||||
visual_shape[1] = self.config.image_feature_pool_shape[0] * self.config.image_feature_pool_shape[1]
|
visual_shape[1] = self.config.image_feature_pool_shape[0] * self.config.image_feature_pool_shape[1]
|
||||||
visual_shape = torch.Size(visual_shape)
|
visual_shape = torch.Size(visual_shape)
|
||||||
final_shape = list(input_shape)
|
# needs a new copy of input_shape for tracing. Otherwise wrong dimensions will occur
|
||||||
|
final_shape = list(self._get_input_shape(input_ids, inputs_embeds))
|
||||||
final_shape[1] += visual_shape[1]
|
final_shape[1] += visual_shape[1]
|
||||||
final_shape = torch.Size(final_shape)
|
final_shape = torch.Size(final_shape)
|
||||||
|
|
||||||
|
@ -260,7 +260,7 @@ class LayoutLMv2ModelTester:
|
|||||||
class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase):
|
class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
test_torchscript = True
|
||||||
test_mismatched_shapes = False
|
test_mismatched_shapes = False
|
||||||
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
|
@ -648,6 +648,13 @@ class ModelTesterMixin:
|
|||||||
traced_model = torch.jit.trace(
|
traced_model = torch.jit.trace(
|
||||||
model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
|
model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
|
||||||
)
|
)
|
||||||
|
elif "bbox" in inputs and "image" in inputs: # LayoutLMv2 requires additional inputs
|
||||||
|
input_ids = inputs["input_ids"]
|
||||||
|
bbox = inputs["bbox"]
|
||||||
|
image = inputs["image"].tensor
|
||||||
|
traced_model = torch.jit.trace(
|
||||||
|
model, (input_ids, bbox, image), check_trace=False
|
||||||
|
) # when traced model is checked, an error is produced due to name mangling
|
||||||
else:
|
else:
|
||||||
main_input = inputs[main_input_name]
|
main_input = inputs[main_input_name]
|
||||||
traced_model = torch.jit.trace(model, main_input)
|
traced_model = torch.jit.trace(model, main_input)
|
||||||
|
Loading…
Reference in New Issue
Block a user