This commit is contained in:
Francesco Saverio Zuppichini 2022-03-09 15:51:56 +01:00 committed by GitHub
parent 38bce1d4cf
commit 1e8f37992f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 8 deletions

View File

@ -2313,16 +2313,16 @@ class MaskFormerModel(MaskFormerPreTrainedModel):
)
queries = transformer_module_output.last_hidden_state
encoder_hidden_states = None
pixel_decoder_hidden_states = None
transformer_decoder_hidden_states = None
hidden_states = None
if output_hidden_states:
encoder_hidden_states = pixel_level_module_output.encoder_hidden_states
pixel_decoder_hidden_states = pixel_level_module_output.decoder_hidden_states
transformer_decoder_hidden_states = transformer_module_output.hidden_states
hidden_states = encoder_hidden_states + pixel_decoder_hidden_states + transformer_decoder_hidden_states
else:
encoder_hidden_states = None
pixel_decoder_hidden_states = None
transformer_decoder_hidden_states = None
hidden_states = None
output = MaskFormerModelOutput(
encoder_last_hidden_state=image_features,
@ -2463,7 +2463,6 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
>>> # you can pass them to feature_extractor for postprocessing
>>> output = feature_extractor.post_process_segmentation(outputs)
>>> output = feature_extractor.post_process_semantic_segmentation(outputs)
>>> output = feature_extractor.post_process_panoptic_segmentation(outputs)
```
"""
@ -2477,7 +2476,7 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
outputs: MaskFormerModelOutput = self.model(
pixel_values,
pixel_mask,
output_hidden_states=output_hidden_states,
output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss,
return_dict=True,
output_attentions=output_attentions,
)

View File

@ -139,7 +139,7 @@ class MaskFormerModelTester:
def comm_check_on_output(result):
# let's still check that all the required stuff is there
self.parent.assertTrue(result.transformer_decoder_hidden_states is not None)
self.parent.assertTrue(result.transformer_decoder_last_hidden_state is not None)
self.parent.assertTrue(result.pixel_decoder_last_hidden_state is not None)
self.parent.assertTrue(result.encoder_last_hidden_state is not None)
# okay, now we need to check the logits shape