[AutoBackbone] Improve API (#20407)

* Add hidden states and attentions to backbone outputs

* Update ResNet

* Fix more tests

* Debug test

* Fix test_determinism

* Fix test_save_load

* Remove file

* Disable fx tests

* Test

* Add fx support for backbones

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
NielsRogge 2022-11-28 17:20:24 +01:00 committed by GitHub
parent 39a72125e7
commit 0bae286de9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 96 additions and 22 deletions

View File

@ -909,6 +909,7 @@ else:
[
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
"MODEL_FOR_AUDIO_XVECTOR_MAPPING",
"MODEL_FOR_BACKBONE_MAPPING",
"MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING",
"MODEL_FOR_CAUSAL_LM_MAPPING",
"MODEL_FOR_CTC_MAPPING",
@ -3981,6 +3982,7 @@ if TYPE_CHECKING:
from .models.auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
MODEL_FOR_AUDIO_XVECTOR_MAPPING,
MODEL_FOR_BACKBONE_MAPPING,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_CTC_MAPPING,

View File

@ -1273,6 +1273,20 @@ class BackboneOutput(ModelOutput):
Args:
feature_maps (`tuple(torch.FloatTensor)` of shape `(batch_size, num_channels, height, width)`):
Feature maps of the stages.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, num_channels, height, width)`,
depending on the backbone.
Hidden-states of the model at the output of each stage plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Only applicable if the backbone uses attention.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
feature_maps: Tuple[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None

View File

@ -45,6 +45,7 @@ else:
_import_structure["modeling_auto"] = [
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
"MODEL_FOR_AUDIO_XVECTOR_MAPPING",
"MODEL_FOR_BACKBONE_MAPPING",
"MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING",
"MODEL_FOR_CAUSAL_LM_MAPPING",
"MODEL_FOR_CTC_MAPPING",
@ -199,6 +200,7 @@ if TYPE_CHECKING:
from .modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
MODEL_FOR_AUDIO_XVECTOR_MAPPING,
MODEL_FOR_BACKBONE_MAPPING,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_CTC_MAPPING,

View File

@ -456,7 +456,9 @@ class ResNetBackbone(ResNetPreTrainedModel):
@add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
def forward(self, pixel_values: Optional[torch.FloatTensor] = None) -> BackboneOutput:
def forward(
self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
) -> BackboneOutput:
"""
Returns:
@ -478,6 +480,11 @@ class ResNetBackbone(ResNetPreTrainedModel):
>>> outputs = model(**inputs)
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
outputs = self.resnet(pixel_values, output_hidden_states=True, return_dict=True)
hidden_states = outputs.hidden_states
@ -487,4 +494,14 @@ class ResNetBackbone(ResNetPreTrainedModel):
if stage in self.out_features:
feature_maps += (hidden_states[idx],)
return BackboneOutput(feature_maps=feature_maps)
if not return_dict:
output = (feature_maps,)
if output_hidden_states:
output += (outputs.hidden_states,)
return output
return BackboneOutput(
feature_maps=feature_maps,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=None,
)

View File

@ -380,6 +380,9 @@ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = None
MODEL_FOR_AUDIO_XVECTOR_MAPPING = None
MODEL_FOR_BACKBONE_MAPPING = None
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = None

View File

@ -34,6 +34,7 @@ from .. import PretrainedConfig, PreTrainedModel, logging
from ..models.auto import get_values
from ..models.auto.modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_BACKBONE_MAPPING_NAMES,
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_CTC_MAPPING_NAMES,
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
@ -82,6 +83,7 @@ def _generate_supported_model_class_names(
"ctc": MODEL_FOR_CTC_MAPPING_NAMES,
"audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
"semantic-segmentation": MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
"backbone": MODEL_FOR_BACKBONE_MAPPING_NAMES,
}
if supported_tasks is None:
@ -713,6 +715,7 @@ class HFTracer(Tracer):
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
*get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES),
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES),
*get_values(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES),
]:
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)

View File

@ -141,7 +141,15 @@ class ResNetModelTest(ModelTesterMixin, unittest.TestCase):
attention_mask and seq_length.
"""
all_model_classes = (ResNetModel, ResNetForImageClassification) if is_torch_available() else ()
all_model_classes = (
(
ResNetModel,
ResNetForImageClassification,
ResNetBackbone,
)
if is_torch_available()
else ()
)
fx_compatible = True
test_pruning = False
@ -247,6 +255,10 @@ class ResNetModelTest(ModelTesterMixin, unittest.TestCase):
check_hidden_states_output(inputs_dict, config, model_class)
@unittest.skip(reason="ResNet does not use feedforward chunking")
def test_feed_forward_chunking(self):
pass
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)

View File

@ -93,6 +93,7 @@ if is_torch_available():
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
MODEL_FOR_AUDIO_XVECTOR_MAPPING,
MODEL_FOR_BACKBONE_MAPPING,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING,
@ -255,28 +256,35 @@ class ModelTesterMixin:
def test_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def check_save_load(out1, out2):
# make sure we don't have nans
out_2 = out2.cpu().numpy()
out_2[np.isnan(out_2)] = 0
out_1 = out1.cpu().numpy()
out_1[np.isnan(out_1)] = 0
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
out_2 = outputs[0].cpu().numpy()
out_2[np.isnan(out_2)] = 0
first = model(**self._prepare_for_class(inputs_dict, model_class))[0]
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
model.to(torch_device)
with torch.no_grad():
after_outputs = model(**self._prepare_for_class(inputs_dict, model_class))
second = model(**self._prepare_for_class(inputs_dict, model_class))[0]
# Make sure we don't have nans
out_1 = after_outputs[0].cpu().numpy()
out_1[np.isnan(out_1)] = 0
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
if isinstance(first, tuple) and isinstance(second, tuple):
for tensor1, tensor2 in zip(first, second):
check_save_load(tensor1, tensor2)
else:
check_save_load(first, second)
def test_save_load_keys_to_ignore_on_save(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -453,6 +461,15 @@ class ModelTesterMixin:
def test_determinism(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def check_determinism(first, second):
out_1 = first.cpu().numpy()
out_2 = second.cpu().numpy()
out_1 = out_1[~np.isnan(out_1)]
out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
@ -461,12 +478,11 @@ class ModelTesterMixin:
first = model(**self._prepare_for_class(inputs_dict, model_class))[0]
second = model(**self._prepare_for_class(inputs_dict, model_class))[0]
out_1 = first.cpu().numpy()
out_2 = second.cpu().numpy()
out_1 = out_1[~np.isnan(out_1)]
out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
if isinstance(first, tuple) and isinstance(second, tuple):
for tensor1, tensor2 in zip(first, second):
check_determinism(tensor1, tensor2)
else:
check_determinism(first, second)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@ -502,7 +518,10 @@ class ModelTesterMixin:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
if model_class in get_values(MODEL_MAPPING):
if model_class in [
*get_values(MODEL_MAPPING),
*get_values(MODEL_FOR_BACKBONE_MAPPING),
]:
continue
model = model_class(config)
@ -521,7 +540,10 @@ class ModelTesterMixin:
config.use_cache = False
config.return_dict = True
if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing:
if (
model_class in [*get_values(MODEL_MAPPING), *get_values(MODEL_FOR_BACKBONE_MAPPING)]
or not model_class.supports_gradient_checkpointing
):
continue
model = model_class(config)
model.to(torch_device)

View File

@ -47,7 +47,6 @@ PRIVATE_MODELS = [
# Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
# models to ignore for not tested
"ResNetBackbone", # Backbones have their own tests.
"CLIPSegDecoder", # Building part of bigger (tested) model.
"TableTransformerEncoder", # Building part of bigger (tested) model.
"TableTransformerDecoder", # Building part of bigger (tested) model.