
Pyramid Vision Transformer V2 (PVTv2)
Overview
The PVTv2 model was proposed in PVT v2: Improved Baselines with Pyramid Vision Transformer by Wenhai Wang, Enze Xie, Xiang Li, Deng-Ping Fan, Kaitao Song, Ding Liang, Tong Lu, Ping Luo, and Ling Shao. As an improved variant of PVT, it eschews position embeddings, relying instead on positional information encoded through zero-padding and overlapping patch embeddings. Dropping position embeddings simplifies the architecture and enables inference at any resolution without having to interpolate them.
The PVTv2 encoder structure has been successfully deployed to achieve state-of-the-art scores in Segformer for semantic segmentation, GLPN for monocular depth, and Panoptic Segformer for panoptic segmentation.
PVTv2 belongs to a family of models called hierarchical transformers, which adapt the transformer layers to generate multi-scale feature maps. Unlike the columnar structure of the Vision Transformer (ViT), which loses fine-grained detail, multi-scale feature maps are known to preserve this detail and aid performance in dense prediction tasks. In PVTv2, this is achieved by generating image patch tokens with 2D convolutions that use overlapping kernels in each encoder layer.
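As a rough sketch of how overlapping patch embeddings yield a feature pyramid, the snippet below applies the standard convolution output-size formula stage by stage. The kernel sizes (7, 3, 3, 3), strides (4, 2, 2, 2) and zero-padding of `kernel // 2` are assumptions meant to resemble a typical PVTv2 configuration; check the `PvtV2Config` of your checkpoint for the actual values. The helper names are hypothetical.

```python
def conv_output_size(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a 2D convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def stage_resolutions(image_size: int,
                      patch_sizes=(7, 3, 3, 3),  # assumed kernel size per stage
                      strides=(4, 2, 2, 2)):     # assumed stride per stage
    """Side length of the feature map produced by each overlapping patch embedding."""
    sizes = []
    size = image_size
    for kernel, stride in zip(patch_sizes, strides):
        # Overlapping embedding: kernel > stride, zero-padded by kernel // 2 (assumed)
        size = conv_output_size(size, kernel, stride, padding=kernel // 2)
        sizes.append(size)
    return sizes


print(stage_resolutions(224))  # [56, 28, 14, 7]
print(stage_resolutions(320))  # [80, 40, 20, 10]
```

Under these assumptions the four maps have strides of 4, 8, 16 and 32 relative to the input, the same strides as ResNet's C2 to C5 stages, and a 320-pixel input simply yields proportionally larger maps with no position embeddings to interpolate.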
The multi-scale features of hierarchical transformers allow them to be easily swapped in for traditional workhorse computer vision backbone models like ResNet in larger architectures. Both Segformer and Panoptic Segformer demonstrated that configurations using PVTv2 for a backbone consistently outperformed those with similarly sized ResNet backbones.
Another powerful feature of PVTv2 is the complexity reduction in the self-attention layers called Spatial Reduction Attention (SRA), which uses 2D convolution layers to project the hidden states to a smaller resolution before attending to them with the queries. With $R$ being the spatial reduction ratio (`sr_ratio`, i.e. the kernel size and stride of the 2D convolution), the convolution shrinks both height and width by a factor of $R$, so the keys and values contain $n/R^2$ tokens instead of $n$. This reduces the $O(n^2)$ complexity of self-attention to $O(n^2/R^2)$.
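The NumPy sketch below illustrates the idea under simplifying assumptions: a single attention head, random matrices in place of learned weights, and no LayerNorm after the reduction (PVT applies one). It is not the `transformers` implementation; it only shows how a convolution whose kernel size and stride both equal R shrinks the key/value sequence by R², so the score matrix has n × n/R² entries. The function names are hypothetical.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def spatial_reduction_attention(x, height, width, ratio, seed=0):
    """Single-head SRA sketch; x holds height * width tokens of dimension c."""
    n, c = x.shape
    assert n == height * width and height % ratio == 0 and width % ratio == 0
    rng = np.random.default_rng(seed)
    # Random matrices stand in for learned weights (illustration only).
    w_q, w_k, w_v = (rng.standard_normal((c, c)) / np.sqrt(c) for _ in range(3))
    w_sr = rng.standard_normal((ratio * ratio * c, c)) / (ratio * np.sqrt(c))
    # A conv with kernel_size == stride == ratio equals flattening each
    # non-overlapping ratio x ratio patch and applying one shared linear map.
    grid = x.reshape(height // ratio, ratio, width // ratio, ratio, c)
    patches = grid.transpose(0, 2, 1, 3, 4).reshape(-1, ratio * ratio * c)
    reduced = patches @ w_sr                 # (n / ratio**2, c)
    q = x @ w_q                              # queries keep full resolution: (n, c)
    k, v = reduced @ w_k, reduced @ w_v      # keys/values: (n / ratio**2, c)
    scores = q @ k.T / np.sqrt(c)            # (n, n / ratio**2) instead of (n, n)
    return softmax(scores) @ v, scores


x = np.random.default_rng(1).standard_normal((56 * 56, 32))
out, scores = spatial_reduction_attention(x, 56, 56, ratio=8)
print(out.shape, scores.shape)  # (3136, 32) (3136, 49)
```

With R = 8 on a 56 × 56 map, each query attends to 49 reduced tokens instead of 3136, a 64-fold (R²) reduction in score entries.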
SRA was introduced in PVT and is the default attention complexity reduction method used in PVTv2. However, PVTv2 also introduced the option of using a self-attention mechanism with linear complexity with respect to image size, which the authors call "Linear SRA". This method uses average pooling to reduce the hidden states to a fixed size that is invariant to their original resolution (although this is inherently more lossy than regular SRA). This option can be enabled by setting `linear_attention` to `True` in the `PvtV2Config`.
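Linear SRA can be sketched the same way. This NumPy example (single head, random weights, attention scores only) pools the token grid to a fixed 7 × 7 grid with PyTorch-style adaptive average pooling before computing the keys. The pool size of 7 is an assumption following the paper's setting, any projection, normalization or activation applied after the pooling is omitted, and the function names are hypothetical.

```python
import numpy as np


def adaptive_avg_pool2d(x: np.ndarray, out_size: int) -> np.ndarray:
    """Average-pool an (h, w, c) map to (out_size, out_size, c)."""
    h, w, c = x.shape
    pooled = np.empty((out_size, out_size, c))
    for i in range(out_size):
        # Bin i spans [floor(i*h/P), ceil((i+1)*h/P)), as in nn.AdaptiveAvgPool2d
        h0, h1 = (i * h) // out_size, -(-((i + 1) * h) // out_size)
        for j in range(out_size):
            w0, w1 = (j * w) // out_size, -(-((j + 1) * w) // out_size)
            pooled[i, j] = x[h0:h1, w0:w1].mean(axis=(0, 1))
    return pooled


def linear_sra_scores(x, height, width, pool_size=7, seed=0):
    """Single-head Linear SRA sketch: keys come from a fixed pool_size x pool_size grid."""
    n, c = x.shape
    rng = np.random.default_rng(seed)
    w_q, w_k = (rng.standard_normal((c, c)) / np.sqrt(c) for _ in range(2))
    pooled = adaptive_avg_pool2d(x.reshape(height, width, c), pool_size)
    keys = pooled.reshape(-1, c) @ w_k          # always pool_size**2 keys
    return (x @ w_q) @ keys.T / np.sqrt(c)      # (n, pool_size**2): linear in n


for side in (56, 112):
    x = np.random.default_rng(side).standard_normal((side * side, 16))
    print(side, linear_sra_scores(x, side, side).shape)
# 56 (3136, 49)
# 112 (12544, 49)
```

Doubling the side length quadruples the number of queries but leaves the number of keys at 49, so the score matrix grows linearly with the number of tokens.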
Abstract from the paper:
Transformer recently has presented encouraging progress in computer vision. In this work, we present new baselines by improving the original Pyramid Vision Transformer (PVT v1) by adding three designs, including (1) linear complexity attention layer, (2) overlapping patch embedding, and (3) convolutional feed-forward network. With these modifications, PVT v2 reduces the computational complexity of PVT v1 to linear and achieves significant improvements on fundamental vision tasks such as classification, detection, and segmentation. Notably, the proposed PVT v2 achieves comparable or better performances than recent works such as Swin Transformer. We hope this work will facilitate state-of-the-art Transformer researches in computer vision. Code is available at https://github.com/whai362/PVT.
This model was contributed by FoamoftheSea. The original code can be found at https://github.com/whai362/PVT.
Usage tips
- PVTv2 is a hierarchical transformer model that has demonstrated strong performance in image classification and several other tasks: it has been used as a backbone for semantic segmentation in Segformer, monocular depth estimation in GLPN, and panoptic segmentation in Panoptic Segformer, consistently outperforming similar ResNet configurations.
- Hierarchical transformers like PVTv2 achieve superior data and parameter efficiency on image data compared with pure transformer architectures by incorporating design elements of convolutional neural networks (CNNs) into their encoders. This creates a best-of-both-worlds architecture that infuses the useful inductive biases of CNNs like translation equivariance and locality into the network while still enjoying the benefits of dynamic data response and global relationship modeling provided by the self-attention mechanism of transformers.
- PVTv2 uses overlapping patch embeddings to create multi-scale feature maps, which are infused with location information using zero-padding and depth-wise convolutions.
- To reduce the complexity in the attention layers, PVTv2 performs a spatial reduction on the hidden states using either strided 2D convolution (SRA) or fixed-size average pooling (Linear SRA). Although inherently more lossy, Linear SRA provides impressive performance with a linear complexity with respect to image size. To use Linear SRA in the self-attention layers, set `linear_attention=True` in the `PvtV2Config`.
- [`PvtV2Model`] is the hierarchical transformer encoder (which is also often referred to as Mix Transformer or MiT in the literature). [`PvtV2ForImageClassification`] adds a simple classifier head on top to perform image classification. [`PvtV2Backbone`] can be used with the [`AutoBackbone`] system in larger architectures like Deformable DETR.
- ImageNet pretrained weights for all model sizes can be found on the hub.
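The way zero-padding injects location information can be seen in a small NumPy sketch of a depth-wise 3 × 3 convolution (one filter per channel, stride 1, padding 1; these settings and the function name are assumptions for illustration, not the library's layer). Feeding in a completely uniform map shows that corner and edge tokens receive different responses from interior ones, so the output carries information about where each token sits.

```python
import numpy as np


def depthwise_conv3x3(x: np.ndarray, kernels: np.ndarray) -> np.ndarray:
    """Depth-wise 3x3 convolution, stride 1, zero padding 1.

    x: (h, w, c) feature map; kernels: (3, 3, c), one filter per channel.
    """
    h, w, c = x.shape
    padded = np.pad(x, ((1, 1), (1, 1), (0, 0)))  # zero-pad height and width only
    out = np.zeros_like(x, dtype=float)
    for di in range(3):
        for dj in range(3):
            # Each channel is scaled only by its own kernel entry (no channel mixing)
            out += padded[di:di + h, dj:dj + w, :] * kernels[di, dj, :]
    return out


# A perfectly uniform input: any position dependence must come from the padding.
x = np.ones((5, 5, 2))
kernels = np.ones((3, 3, 2))
y = depthwise_conv3x3(x, kernels)
print(y[:, :, 0])  # corners: 4, edges: 6, interior: 9
```

Because the response depends on how many padded zeros fall inside the window, the network can infer whether a token lies near the border without any explicit position embedding.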
The best way to get started with PVTv2 is to load the pretrained checkpoint of your chosen size using `AutoModelForImageClassification`:
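```python
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Load the ImageNet-pretrained B0 checkpoint and its matching image processor
model = AutoModelForImageClassification.from_pretrained("OpenGVLab/pvt_v2_b0")
image_processor = AutoImageProcessor.from_pretrained("OpenGVLab/pvt_v2_b0")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Resize and normalize the image, returning a batched PyTorch tensor
inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
# Map the highest-scoring logit to its class name
predicted_class = outputs.logits.argmax(-1).item()
print(model.config.id2label[predicted_class])
```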
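To use PVTv2 as a backbone for more complex architectures like Deformable DETR, you can use the AutoBackbone system by passing the PVTv2 config as `backbone_config`. The resulting model needs to be trained before use, since `from_config` builds the architecture with randomly initialized weights instead of loading pretrained ones: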
```python
import requests
import torch
from PIL import Image
from transformers import AutoConfig, AutoImageProcessor, AutoModelForObjectDetection

# Build Deformable DETR from its config with a PVTv2-B5 backbone config swapped in.
# from_config creates the architecture only; all weights are randomly initialized.
model = AutoModelForObjectDetection.from_config(
    config=AutoConfig.from_pretrained(
        "SenseTime/deformable-detr",
        backbone_config=AutoConfig.from_pretrained("OpenGVLab/pvt_v2_b5"),
        use_timm_backbone=False,
    ),
)
image_processor = AutoImageProcessor.from_pretrained("SenseTime/deformable-detr")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# The processor returns pixel_values (and a pixel_mask) as PyTorch tensors
inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
```
PVTv2 performance on ImageNet-1K by model size (B0-B5):
| Method | Input size | Acc@1 (%) | #Params (M) |
|---|---|---|---|
| PVT-V2-B0 | 224 | 70.5 | 3.7 |
| PVT-V2-B1 | 224 | 78.7 | 14.0 |
| PVT-V2-B2-Linear | 224 | 82.1 | 22.6 |
| PVT-V2-B2 | 224 | 82.0 | 25.4 |
| PVT-V2-B3 | 224 | 83.1 | 45.2 |
| PVT-V2-B4 | 224 | 83.6 | 62.6 |
| PVT-V2-B5 | 224 | 83.8 | 82.0 |
PvtV2Config
[[autodoc]] PvtV2Config
PvtV2ForImageClassification
[[autodoc]] PvtV2ForImageClassification
    - forward
PvtV2Model
[[autodoc]] PvtV2Model
    - forward