diff --git a/docs/source/en/model_doc/eomt.md b/docs/source/en/model_doc/eomt.md
index 34842de2101..86816a475fb 100644
--- a/docs/source/en/model_doc/eomt.md
+++ b/docs/source/en/model_doc/eomt.md
@@ -74,20 +74,16 @@ inputs = processor(
return_tensors="pt",
)
-# Remove Patch Offsets from inputs — only used later for post-processing.
-patch_offsets = inputs.pop("patch_offsets")
-
with torch.inference_mode():
outputs = model(**inputs)
# Prepare the original image size in the format (height, width)
-original_image_sizes = [(image.height, image.width)]
+target_sizes = [(image.height, image.width)]
# Post-process the model outputs to get final segmentation prediction
preds = processor.post_process_semantic_segmentation(
outputs,
- patch_offsets=patch_offsets,
- original_image_sizes=original_image_sizes,
+ target_sizes=target_sizes,
)
# Visualize the segmentation mask
@@ -130,12 +126,12 @@ with torch.inference_mode():
outputs = model(**inputs)
# Prepare the original image size in the format (height, width)
-original_image_sizes = [(image.height, image.width)]
+target_sizes = [(image.height, image.width)]
# Post-process the model outputs to get final segmentation prediction
preds = processor.post_process_instance_segmentation(
outputs,
- original_image_sizes=original_image_sizes,
+ target_sizes=target_sizes,
)
# Visualize the segmentation mask
@@ -173,12 +169,12 @@ with torch.inference_mode():
outputs = model(**inputs)
# Prepare the original image size in the format (height, width)
-original_image_sizes = [(image.height, image.width)]
+target_sizes = [(image.height, image.width)]
# Post-process the model outputs to get final segmentation prediction
preds = processor.post_process_panoptic_segmentation(
outputs,
- original_image_sizes=original_image_sizes,
+ target_sizes=target_sizes,
)
# Visualize the panoptic segmentation mask
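For reference, a minimal sketch of the updated semantic segmentation flow shown in the hunks above: `patch_offsets` stay in the processor output and are carried through the model output, and post-processing now takes `target_sizes`. The checkpoint id is a placeholder, and using `EomtImageProcessor`/`EomtForUniversalSegmentation` directly (rather than the Auto classes) is an assumption here.

```py
import torch
from PIL import Image
from transformers import EomtForUniversalSegmentation, EomtImageProcessor

checkpoint = "<eomt-semantic-segmentation-checkpoint>"  # placeholder id
processor = EomtImageProcessor.from_pretrained(checkpoint)
model = EomtForUniversalSegmentation.from_pretrained(checkpoint)

image = Image.open("example.jpg")

# patch_offsets no longer need to be popped; they ride along into the model output
inputs = processor(images=image, return_tensors="pt")

with torch.inference_mode():
    outputs = model(**inputs)

preds = processor.post_process_semantic_segmentation(
    outputs,
    target_sizes=[(image.height, image.width)],
)
```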
diff --git a/docs/source/en/model_doc/vitpose.md b/docs/source/en/model_doc/vitpose.md
index 7a417cc2138..f9ed7265934 100644
--- a/docs/source/en/model_doc/vitpose.md
+++ b/docs/source/en/model_doc/vitpose.md
@@ -10,52 +10,39 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# ViTPose
-
-
-

+
+
+

+
-## Overview
+# ViTPose
-The ViTPose model was proposed in [ViTPose: Simple Vision Transformer Baselines for Human Pose Estimation](https://huggingface.co/papers/2204.12484) by Yufei Xu, Jing Zhang, Qiming Zhang, Dacheng Tao. ViTPose employs a standard, non-hierarchical [Vision Transformer](vit) as backbone for the task of keypoint estimation. A simple decoder head is added on top to predict the heatmaps from a given image. Despite its simplicity, the model gets state-of-the-art results on the challenging MS COCO Keypoint Detection benchmark. The model was further improved in [ViTPose++: Vision Transformer for Generic Body Pose Estimation](https://huggingface.co/papers/2212.04246) where the authors employ
-a mixture-of-experts (MoE) module in the ViT backbone along with pre-training on more data, which further enhances the performance.
+[ViTPose](https://huggingface.co/papers/2204.12484) is a vision transformer-based model for keypoint (pose) estimation. It uses a simple, non-hierarchical [ViT](./vit) backbone and a lightweight decoder head. This architecture simplifies model design, takes advantage of transformer scalability, and can be adapted to different training strategies.
-The abstract from the paper is the following:
-
-*Although no specific domain knowledge is considered in the design, plain vision transformers have shown excellent performance in visual recognition tasks. However, little effort has been made to reveal the potential of such simple structures for pose estimation tasks. In this paper, we show the surprisingly good capabilities of plain vision transformers for pose estimation from various aspects, namely simplicity in model structure, scalability in model size, flexibility in training paradigm, and transferability of knowledge between models, through a simple baseline model called ViTPose. Specifically, ViTPose employs plain and non-hierarchical vision transformers as backbones to extract features for a given person instance and a lightweight decoder for pose estimation. It can be scaled up from 100M to 1B parameters by taking the advantages of the scalable model capacity and high parallelism of transformers, setting a new Pareto front between throughput and performance. Besides, ViTPose is very flexible regarding the attention type, input resolution, pre-training and finetuning strategy, as well as dealing with multiple pose tasks. We also empirically demonstrate that the knowledge of large ViTPose models can be easily transferred to small ones via a simple knowledge token. Experimental results show that our basic ViTPose model outperforms representative methods on the challenging MS COCO Keypoint Detection benchmark, while the largest model sets a new state-of-the-art.*
+[ViTPose++](https://huggingface.co/papers/2212.04246) improves on ViTPose by incorporating a mixture-of-experts (MoE) module in the backbone and using more diverse pretraining data.

-
ViTPose architecture. Taken from the original paper.
+You can find all ViTPose and ViTPose++ checkpoints under the [ViTPose collection](https://huggingface.co/collections/usyd-community/vitpose-677fcfd0a0b2b5c8f79c4335).
-This model was contributed by [nielsr](https://huggingface.co/nielsr) and [sangbumchoi](https://github.com/SangbumChoi).
-The original code can be found [here](https://github.com/ViTAE-Transformer/ViTPose).
-
-## Usage Tips
-
-ViTPose is a so-called top-down keypoint detection model. This means that one first uses an object detector, like [RT-DETR](rt_detr.md), to detect people (or other instances) in an image. Next, ViTPose takes the cropped images as input and predicts the keypoints for each of them.
+The example below demonstrates pose estimation with the [`VitPoseForPoseEstimation`] class.
```py
import torch
import requests
import numpy as np
-
+import supervision as sv
from PIL import Image
-
from transformers import AutoProcessor, RTDetrForObjectDetection, VitPoseForPoseEstimation
device = "cuda" if torch.cuda.is_available() else "cpu"
-url = "http://images.cocodataset.org/val2017/000000000139.jpg"
+url = "https://www.fcbarcelona.com/fcbarcelona/photo/2021/01/31/3c55a19f-dfc1-4451-885e-afd14e890a11/mini_2021-01-31-BARCELONA-ATHLETIC-BILBAOI-30.JPG"
image = Image.open(requests.get(url, stream=True).raw)
-# ------------------------------------------------------------------------
-# Stage 1. Detect humans on the image
-# ------------------------------------------------------------------------
-
-# You can choose any detector of your choice
+# Detect humans in the image
person_image_processor = AutoProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
person_model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365", device_map=device)
@@ -67,7 +54,7 @@ with torch.no_grad():
results = person_image_processor.post_process_object_detection(
outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=0.3
)
-result = results[0] # take first image results
+result = results[0]
# Human label refers 0 index in COCO dataset
person_boxes = result["boxes"][result["labels"] == 0]
@@ -77,10 +64,7 @@ person_boxes = person_boxes.cpu().numpy()
person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]
-# ------------------------------------------------------------------------
-# Stage 2. Detect keypoints for each person found
-# ------------------------------------------------------------------------
-
+# Detect keypoints for each person found
image_processor = AutoProcessor.from_pretrained("usyd-community/vitpose-base-simple")
model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple", device_map=device)
@@ -90,54 +74,7 @@ with torch.no_grad():
outputs = model(**inputs)
pose_results = image_processor.post_process_pose_estimation(outputs, boxes=[person_boxes])
-image_pose_result = pose_results[0] # results for first image
-```
-
-### ViTPose++ models
-
-The best [checkpoints](https://huggingface.co/collections/usyd-community/vitpose-677fcfd0a0b2b5c8f79c4335) are those of the [ViTPose++ paper](https://huggingface.co/papers/2212.04246). ViTPose++ models employ a so-called [Mixture-of-Experts (MoE)](https://huggingface.co/blog/moe) architecture for the ViT backbone, resulting in better performance.
-
-The ViTPose+ checkpoints use 6 experts, hence 6 different dataset indices can be passed.
-An overview of the various dataset indices is provided below:
-
-- 0: [COCO validation 2017](https://cocodataset.org/#overview) dataset, using an object detector that gets 56 AP on the "person" class
-- 1: [AiC](https://github.com/fabbrimatteo/AiC-Dataset) dataset
-- 2: [MPII](https://www.mpi-inf.mpg.de/departments/computer-vision-and-machine-learning/software-and-datasets/mpii-human-pose-dataset) dataset
-- 3: [AP-10K](https://github.com/AlexTheBad/AP-10K) dataset
-- 4: [APT-36K](https://github.com/pandorgan/APT-36K) dataset
-- 5: [COCO-WholeBody](https://github.com/jin-s13/COCO-WholeBody) dataset
-
-Pass the `dataset_index` argument in the forward of the model to indicate which experts to use for each example in the batch. Example usage is shown below:
-
-```python
-image_processor = AutoProcessor.from_pretrained("usyd-community/vitpose-plus-base")
-model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-plus-base", device=device)
-
-inputs = image_processor(image, boxes=[person_boxes], return_tensors="pt").to(device)
-
-dataset_index = torch.tensor([0], device=device) # must be a tensor of shape (batch_size,)
-
-with torch.no_grad():
- outputs = model(**inputs, dataset_index=dataset_index)
-```
-
-The ViTPose+ checkpoints use 6 experts, hence 6 different dataset indices can be passed.
-An overview of the various dataset indices is provided below:
-
-- 0: [COCO validation 2017](https://cocodataset.org/#overview) dataset, using an object detector that gets 56 AP on the "person" class
-- 1: [AiC](https://github.com/fabbrimatteo/AiC-Dataset) dataset
-- 2: [MPII](https://www.mpi-inf.mpg.de/departments/computer-vision-and-machine-learning/software-and-datasets/mpii-human-pose-dataset) dataset
-- 3: [AP-10K](https://github.com/AlexTheBad/AP-10K) dataset
-- 4: [APT-36K](https://github.com/pandorgan/APT-36K) dataset
-- 5: [COCO-WholeBody](https://github.com/jin-s13/COCO-WholeBody) dataset
-
-
-### Visualization
-
-To visualize the various keypoints, one can either leverage the `supervision` [library](https://github.com/roboflow/supervision (requires `pip install supervision`):
-
-```python
-import supervision as sv
+image_pose_result = pose_results[0]
xy = torch.stack([pose_result['keypoints'] for pose_result in image_pose_result]).cpu().numpy()
scores = torch.stack([pose_result['scores'] for pose_result in image_pose_result]).cpu().numpy()
@@ -162,119 +99,192 @@ annotated_frame = vertex_annotator.annotate(
scene=annotated_frame,
key_points=key_points
)
+annotated_frame
```
-Alternatively, one can also visualize the keypoints using [OpenCV](https://opencv.org/) (requires `pip install opencv-python`):
+
+

+
-```python
-import math
-import cv2
+Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for the available quantization backends.
-def draw_points(image, keypoints, scores, pose_keypoint_color, keypoint_score_threshold, radius, show_keypoint_weight):
- if pose_keypoint_color is not None:
- assert len(pose_keypoint_color) == len(keypoints)
- for kid, (kpt, kpt_score) in enumerate(zip(keypoints, scores)):
- x_coord, y_coord = int(kpt[0]), int(kpt[1])
- if kpt_score > keypoint_score_threshold:
- color = tuple(int(c) for c in pose_keypoint_color[kid])
- if show_keypoint_weight:
- cv2.circle(image, (int(x_coord), int(y_coord)), radius, color, -1)
- transparency = max(0, min(1, kpt_score))
- cv2.addWeighted(image, transparency, image, 1 - transparency, 0, dst=image)
- else:
- cv2.circle(image, (int(x_coord), int(y_coord)), radius, color, -1)
+The example below uses [torchao](../quantization/torchao) to quantize only the weights to int4.
-def draw_links(image, keypoints, scores, keypoint_edges, link_colors, keypoint_score_threshold, thickness, show_keypoint_weight, stick_width = 2):
- height, width, _ = image.shape
- if keypoint_edges is not None and link_colors is not None:
- assert len(link_colors) == len(keypoint_edges)
- for sk_id, sk in enumerate(keypoint_edges):
- x1, y1, score1 = (int(keypoints[sk[0], 0]), int(keypoints[sk[0], 1]), scores[sk[0]])
- x2, y2, score2 = (int(keypoints[sk[1], 0]), int(keypoints[sk[1], 1]), scores[sk[1]])
- if (
- x1 > 0
- and x1 < width
- and y1 > 0
- and y1 < height
- and x2 > 0
- and x2 < width
- and y2 > 0
- and y2 < height
- and score1 > keypoint_score_threshold
- and score2 > keypoint_score_threshold
- ):
- color = tuple(int(c) for c in link_colors[sk_id])
+```py
+# pip install torchao
+import torch
+import requests
+import numpy as np
+from PIL import Image
+from transformers import AutoProcessor, RTDetrForObjectDetection, VitPoseForPoseEstimation, TorchAoConfig
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+url = "https://www.fcbarcelona.com/fcbarcelona/photo/2021/01/31/3c55a19f-dfc1-4451-885e-afd14e890a11/mini_2021-01-31-BARCELONA-ATHLETIC-BILBAOI-30.JPG"
+image = Image.open(requests.get(url, stream=True).raw)
+
+person_image_processor = AutoProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
+person_model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365", device_map=device)
+
+inputs = person_image_processor(images=image, return_tensors="pt").to(device)
+
+with torch.no_grad():
+ outputs = person_model(**inputs)
+
+results = person_image_processor.post_process_object_detection(
+ outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=0.3
+)
+result = results[0]
+
+person_boxes = result["boxes"][result["labels"] == 0]
+person_boxes = person_boxes.cpu().numpy()
+
+person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
+person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]
+
+quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
+
+image_processor = AutoProcessor.from_pretrained("usyd-community/vitpose-plus-huge")
+model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-plus-huge", device_map=device, quantization_config=quantization_config)
+
+inputs = image_processor(image, boxes=[person_boxes], return_tensors="pt").to(device)
+
+with torch.no_grad():
+ outputs = model(**inputs)
+
+pose_results = image_processor.post_process_pose_estimation(outputs, boxes=[person_boxes])
+image_pose_result = pose_results[0]
+```
+
+## Notes
+
+- Use [`AutoProcessor`] to automatically prepare bounding box and image inputs.
+- ViTPose is a top-down pose estimator. It uses an object detector to detect people first, and then predicts keypoints for each detected person.
+- ViTPose++ has 6 different MoE expert heads (COCO validation `0`, AiC `1`, MPII `2`, AP-10K `3`, APT-36K `4`, COCO-WholeBody `5`), one for each supported dataset. Pass the value corresponding to your dataset to the `dataset_index` argument to indicate which expert to use, as shown below.
+
+ ```py
+    import torch
+    from transformers import AutoProcessor, VitPoseForPoseEstimation
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ image_processor = AutoProcessor.from_pretrained("usyd-community/vitpose-plus-base")
+    model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-plus-base", device_map=device)
+
+ inputs = image_processor(image, boxes=[person_boxes], return_tensors="pt").to(device)
+ dataset_index = torch.tensor([0], device=device) # must be a tensor of shape (batch_size,)
+
+ with torch.no_grad():
+ outputs = model(**inputs, dataset_index=dataset_index)
+ ```
+
+- [OpenCV](https://opencv.org/) is an alternative to the [supervision](https://github.com/roboflow/supervision) library for visualizing the estimated keypoints.
+
+ ```py
+ # pip install opencv-python
+ import math
+ import cv2
+
+ def draw_points(image, keypoints, scores, pose_keypoint_color, keypoint_score_threshold, radius, show_keypoint_weight):
+ if pose_keypoint_color is not None:
+ assert len(pose_keypoint_color) == len(keypoints)
+ for kid, (kpt, kpt_score) in enumerate(zip(keypoints, scores)):
+ x_coord, y_coord = int(kpt[0]), int(kpt[1])
+ if kpt_score > keypoint_score_threshold:
+ color = tuple(int(c) for c in pose_keypoint_color[kid])
if show_keypoint_weight:
- X = (x1, x2)
- Y = (y1, y2)
- mean_x = np.mean(X)
- mean_y = np.mean(Y)
- length = ((Y[0] - Y[1]) ** 2 + (X[0] - X[1]) ** 2) ** 0.5
- angle = math.degrees(math.atan2(Y[0] - Y[1], X[0] - X[1]))
- polygon = cv2.ellipse2Poly(
- (int(mean_x), int(mean_y)), (int(length / 2), int(stick_width)), int(angle), 0, 360, 1
- )
- cv2.fillConvexPoly(image, polygon, color)
- transparency = max(0, min(1, 0.5 * (keypoints[sk[0], 2] + keypoints[sk[1], 2])))
+ cv2.circle(image, (int(x_coord), int(y_coord)), radius, color, -1)
+ transparency = max(0, min(1, kpt_score))
cv2.addWeighted(image, transparency, image, 1 - transparency, 0, dst=image)
else:
- cv2.line(image, (x1, y1), (x2, y2), color, thickness=thickness)
+ cv2.circle(image, (int(x_coord), int(y_coord)), radius, color, -1)
+ def draw_links(image, keypoints, scores, keypoint_edges, link_colors, keypoint_score_threshold, thickness, show_keypoint_weight, stick_width = 2):
+ height, width, _ = image.shape
+ if keypoint_edges is not None and link_colors is not None:
+ assert len(link_colors) == len(keypoint_edges)
+ for sk_id, sk in enumerate(keypoint_edges):
+ x1, y1, score1 = (int(keypoints[sk[0], 0]), int(keypoints[sk[0], 1]), scores[sk[0]])
+ x2, y2, score2 = (int(keypoints[sk[1], 0]), int(keypoints[sk[1], 1]), scores[sk[1]])
+ if (
+ x1 > 0
+ and x1 < width
+ and y1 > 0
+ and y1 < height
+ and x2 > 0
+ and x2 < width
+ and y2 > 0
+ and y2 < height
+ and score1 > keypoint_score_threshold
+ and score2 > keypoint_score_threshold
+ ):
+ color = tuple(int(c) for c in link_colors[sk_id])
+ if show_keypoint_weight:
+ X = (x1, x2)
+ Y = (y1, y2)
+ mean_x = np.mean(X)
+ mean_y = np.mean(Y)
+ length = ((Y[0] - Y[1]) ** 2 + (X[0] - X[1]) ** 2) ** 0.5
+ angle = math.degrees(math.atan2(Y[0] - Y[1], X[0] - X[1]))
+ polygon = cv2.ellipse2Poly(
+ (int(mean_x), int(mean_y)), (int(length / 2), int(stick_width)), int(angle), 0, 360, 1
+ )
+ cv2.fillConvexPoly(image, polygon, color)
+ transparency = max(0, min(1, 0.5 * (keypoints[sk[0], 2] + keypoints[sk[1], 2])))
+ cv2.addWeighted(image, transparency, image, 1 - transparency, 0, dst=image)
+ else:
+ cv2.line(image, (x1, y1), (x2, y2), color, thickness=thickness)
-# Note: keypoint_edges and color palette are dataset-specific
-keypoint_edges = model.config.edges
+ # Note: keypoint_edges and color palette are dataset-specific
+ keypoint_edges = model.config.edges
-palette = np.array(
- [
- [255, 128, 0],
- [255, 153, 51],
- [255, 178, 102],
- [230, 230, 0],
- [255, 153, 255],
- [153, 204, 255],
- [255, 102, 255],
- [255, 51, 255],
- [102, 178, 255],
- [51, 153, 255],
- [255, 153, 153],
- [255, 102, 102],
- [255, 51, 51],
- [153, 255, 153],
- [102, 255, 102],
- [51, 255, 51],
- [0, 255, 0],
- [0, 0, 255],
- [255, 0, 0],
- [255, 255, 255],
- ]
-)
+ palette = np.array(
+ [
+ [255, 128, 0],
+ [255, 153, 51],
+ [255, 178, 102],
+ [230, 230, 0],
+ [255, 153, 255],
+ [153, 204, 255],
+ [255, 102, 255],
+ [255, 51, 255],
+ [102, 178, 255],
+ [51, 153, 255],
+ [255, 153, 153],
+ [255, 102, 102],
+ [255, 51, 51],
+ [153, 255, 153],
+ [102, 255, 102],
+ [51, 255, 51],
+ [0, 255, 0],
+ [0, 0, 255],
+ [255, 0, 0],
+ [255, 255, 255],
+ ]
+ )
-link_colors = palette[[0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16]]
-keypoint_colors = palette[[16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0]]
+ link_colors = palette[[0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16]]
+ keypoint_colors = palette[[16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0]]
-numpy_image = np.array(image)
+ numpy_image = np.array(image)
-for pose_result in image_pose_result:
- scores = np.array(pose_result["scores"])
- keypoints = np.array(pose_result["keypoints"])
+ for pose_result in image_pose_result:
+ scores = np.array(pose_result["scores"])
+ keypoints = np.array(pose_result["keypoints"])
- # draw each point on image
- draw_points(numpy_image, keypoints, scores, keypoint_colors, keypoint_score_threshold=0.3, radius=4, show_keypoint_weight=False)
+ # draw each point on image
+ draw_points(numpy_image, keypoints, scores, keypoint_colors, keypoint_score_threshold=0.3, radius=4, show_keypoint_weight=False)
- # draw links
- draw_links(numpy_image, keypoints, scores, keypoint_edges, link_colors, keypoint_score_threshold=0.3, thickness=1, show_keypoint_weight=False)
+ # draw links
+ draw_links(numpy_image, keypoints, scores, keypoint_edges, link_colors, keypoint_score_threshold=0.3, thickness=1, show_keypoint_weight=False)
-pose_image = Image.fromarray(numpy_image)
-pose_image
-```
-

+ pose_image = Image.fromarray(numpy_image)
+ pose_image
+ ```
## Resources
-A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ViTPose. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
+Refer to the resources below to learn more about using ViTPose.
-- A demo of ViTPose on images and video can be found [here](https://huggingface.co/spaces/hysts/ViTPose-transformers).
-- A notebook illustrating inference and visualization can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/ViTPose/Inference_with_ViTPose_for_human_pose_estimation.ipynb).
+- This [notebook](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/ViTPose/Inference_with_ViTPose_for_body_pose_estimation.ipynb) demonstrates inference and visualization.
+- This [Space](https://huggingface.co/spaces/hysts/ViTPose-transformers) demonstrates ViTPose on images and video.
## VitPoseImageProcessor
diff --git a/src/transformers/commands/chat.py b/src/transformers/commands/chat.py
index 91979590046..8f6f49f26bc 100644
--- a/src/transformers/commands/chat.py
+++ b/src/transformers/commands/chat.py
@@ -333,6 +333,11 @@ class ChatCommand(BaseTransformersCLICommand):
)
args.host, args.port = args.model_name_or_path_or_address.rsplit(":", 1)
+
+ if args.model_name_or_path is None:
+ raise ValueError(
+ "When connecting to a server, please specify a model name with the --model_name_or_path flag."
+ )
else:
self.spawn_backend = True
args.model_name_or_path = args.model_name_or_path_or_address
diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py
index 9b886b27210..d8f61603692 100644
--- a/src/transformers/commands/serving.py
+++ b/src/transformers/commands/serving.py
@@ -347,7 +347,7 @@ class ServeCommand(BaseTransformersCLICommand):
if not req.stream:
return {"error": "Only streaming mode is supported."}
- update_model = req.model != self.loaded_model
+ update_model = self.canonicalized_model_name(req.model) != self.loaded_model
if update_model:
self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
@@ -402,7 +402,7 @@ class ServeCommand(BaseTransformersCLICommand):
if self.last_messages is None:
req_continues_last_messages = False
# The new request has fewer rounds of conversation: this is a new request
- elif len(self.last_messages) > len(req.messages):
+ elif len(self.last_messages) >= len(req.messages):
req_continues_last_messages = False
# Otherwise, check that the last messages are a subset of the new request
else:
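A toy re-implementation (not the actual `ServeCommand` state) of the continuation check after this change: a request with the same number of messages as the cached conversation is now treated as a new request rather than a continuation.

```py
last_messages = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello"},
]

def continues_last_messages(req_messages):
    if last_messages is None:
        return False
    # Same length or shorter -> treat as a brand-new conversation
    if len(last_messages) >= len(req_messages):
        return False
    # Otherwise the cached messages must be a prefix of the new request
    return req_messages[: len(last_messages)] == last_messages

print(continues_last_messages(last_messages))  # False (same length)
print(continues_last_messages(last_messages + [{"role": "user", "content": "and?"}]))  # True
```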
@@ -417,7 +417,7 @@ class ServeCommand(BaseTransformersCLICommand):
def generate(self, app):
@app.post("/v1/chat/completions")
def _serve(req: "ChatCompletionInput"):
- update_model = req.model != self.loaded_model
+ update_model = self.canonicalized_model_name(req.model) != self.loaded_model
if update_model:
self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
@@ -585,6 +585,11 @@ class ServeCommand(BaseTransformersCLICommand):
return quantization_config
+ def canonicalized_model_name(self, model_id: str) -> str:
+ if "@" in model_id:
+ return model_id
+ return f"{model_id}@main"
+
def load_model_and_tokenizer(
self, model_id_and_revision: str, args: ServeArguments
) -> tuple[PreTrainedModel, PreTrainedTokenizerFast]:
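A standalone sketch (toy code, not the server class) of the `model@revision` convention introduced above: canonicalizing both sides means a request for `gpt2` matches a loaded `gpt2@main` and does not trigger a spurious reload.

```py
def canonicalized_model_name(model_id: str) -> str:
    # Pin the default "main" revision when none is given
    return model_id if "@" in model_id else f"{model_id}@main"

loaded_model = "gpt2@main"  # as stored after loading

assert canonicalized_model_name("gpt2") == loaded_model            # same model, no reload
assert canonicalized_model_name("gpt2@refs/pr/1") != loaded_model  # different revision, reload
```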
@@ -621,9 +626,9 @@ class ServeCommand(BaseTransformersCLICommand):
if getattr(model, "hf_device_map", None) is None:
model = model.to(args.device)
- self.loaded_model = model_id_and_revision
+ self.loaded_model = f"{model_id}@{revision}"
- print("Loaded model", model_id_and_revision)
+ logger.warning(f"Loaded model {self.loaded_model}")
return model, tokenizer
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index de6ae44bb5a..13a1c83a719 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -3773,16 +3773,28 @@ class GenerationMixin(ContinuousMixin):
Beam Search stopping condition -- halts the generation loop if any of these conditions becomes False
"""
# a. Can the open beams improve the top completed scores?
- # early_stopping == False -> apply heuristic = always get the best score from
- # `cur_len - decoder_prompt_len`. See the discussion below for more details.
- # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
+ # early_stopping == False -> apply heuristic = always get the best score from `cur_len - decoder_prompt_len`.
# early_stopping == "never" -> compute the best score from `max_length` or `cur_len`, depending on the
# sign of `length_penalty`. Positive `length_penalty` favors longer sequences, thus we use
# `max_length` there.
+ # !!
+ # Be sure to check the docstring for `early_stopping` and `length_penalty`. The default parameterization
+ # does NOT correspond to a canonical beam search implementation, and tends to favor shorter output sequences
+            # compared to it (the heuristic active by default underestimates the maximum achievable score, and thus cuts
+ # generation short). Also, be mindful that length penalty > 0.0 actually favors longer sequences, despite
+ # its name. These modifications were empirically found in the past (prior to 2022) to produce better quality
+ # generations, and changing them is BC breaking.
+ # For a canonical beam search implementation, set `early_stopping="never"` and `length_penalty=0.0`.
+ # See the discussion below for more details.
+ # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
+ # !!
if early_stopping == "never" and length_penalty > 0.0:
best_hypothetical_length = max_length - decoder_prompt_len
else:
best_hypothetical_length = cur_len - decoder_prompt_len
+
+ # best-case scenario: the next tokens have logprobs=0 (probability=1), and the score stays the same before
+ # applying length penalty
best_possible_running_score = running_beam_scores[:, :1] / (best_hypothetical_length**length_penalty)
worst_finished_score = torch.where(is_sent_finished, torch.min(beam_scores, dim=1, keepdim=True)[0], -1.0e9)
improvement_possible = torch.any(best_possible_running_score > worst_finished_score)
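A small numeric illustration (toy numbers, not library code) of the convention described in the comments above, where `score = sum_logprobs / length**length_penalty`: because log-probabilities are negative, a positive `length_penalty` shrinks the magnitude of the score for longer sequences and therefore favors them, while `length_penalty=0.0` is length-neutral.

```py
sum_logprobs = -6.0          # identical cumulative log-probability for both hypotheses
short_len, long_len = 5, 10  # two hypothetical sequence lengths

for length_penalty in (0.0, 1.0):
    short_score = sum_logprobs / (short_len ** length_penalty)
    long_score = sum_logprobs / (long_len ** length_penalty)
    print(f"length_penalty={length_penalty}: short={short_score:.2f}, long={long_score:.2f}")

# length_penalty=0.0: short=-6.00, long=-6.00  -> length-neutral (canonical beam search)
# length_penalty=1.0: short=-1.20, long=-0.60  -> the longer hypothesis scores higher
```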
diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py
index c81b990b7de..48636496f99 100644
--- a/src/transformers/models/blip_2/modeling_blip_2.py
+++ b/src/transformers/models/blip_2/modeling_blip_2.py
@@ -415,6 +415,7 @@ class Blip2PreTrainedModel(PreTrainedModel):
_no_split_modules = [
"Blip2Attention",
"Blip2QFormerMultiHeadAttention",
+ "Blip2EncoderLayer",
"Blip2TextEmbeddings",
"T5Block",
"OPTDecoderLayer",
@@ -1262,6 +1263,7 @@ class Blip2Model(Blip2PreTrainedModel):
config_class = Blip2Config
main_input_name = "pixel_values"
_keep_in_fp32_modules = ["query_tokens", "qformer"]
+ _supports_flash_attn_2 = False # because self.qformer does not support FA2
def __init__(self, config: Blip2Config):
super().__init__(config)
@@ -1646,6 +1648,7 @@ class Blip2Model(Blip2PreTrainedModel):
class Blip2TextModelWithProjection(Blip2PreTrainedModel):
supports_gradient_checkpointing = False
_keep_in_fp32_modules = ["query_tokens", "qformer"]
+ _supports_flash_attn_2 = False # because self.qformer does not support FA2
def __init__(self, config: Blip2Config):
super().__init__(config)
@@ -1738,6 +1741,7 @@ class Blip2TextModelWithProjection(Blip2PreTrainedModel):
class Blip2VisionModelWithProjection(Blip2PreTrainedModel):
main_input_name = "pixel_values"
_keep_in_fp32_modules = ["query_tokens", "qformer"]
+ _supports_flash_attn_2 = False # because self.qformer does not support FA2
def __init__(self, config: Blip2Config):
super().__init__(config)
@@ -1857,6 +1861,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
_supports_quantized_cache = False # not all LM bacbones support (e.g. T5)
_keep_in_fp32_modules = ["query_tokens", "qformer"]
+ _supports_flash_attn_2 = False # because self.qformer does not support FA2
def __init__(self, config: Blip2Config):
super().__init__(config)
@@ -2086,9 +2091,13 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
else:
special_image_mask = input_ids == self.config.image_token_id
- special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
- language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
+ special_image_mask = (
+ special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device)
+ )
+ language_model_inputs = language_model_inputs.to(inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter(
+ special_image_mask, language_model_inputs
+ )
else:
logger.warning_once(
"Expanding inputs for image tokens in BLIP-2 should be done in processing. "
@@ -2234,9 +2243,15 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
else:
special_image_mask = input_ids == self.config.image_token_id
- special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
- language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
+ special_image_mask = (
+ special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device)
+ )
+ language_model_inputs = language_model_inputs.to(inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter(
+ special_image_mask, language_model_inputs
+ )
+
+ attention_mask = attention_mask.to(language_attention_mask.device)
else:
logger.warning_once(
"Expanding inputs for image tokens in BLIP-2 should be done in processing. "
@@ -2259,6 +2274,8 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
if not self.language_model.config.is_encoder_decoder:
+ if input_ids is not None:
+ input_ids = input_ids.to(language_model_inputs.device)
inputs["input_ids"] = input_ids
outputs = self.language_model.generate(**inputs, **generate_kwargs)
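A minimal, self-contained sketch of the `masked_scatter` pattern used above (toy shapes and a made-up image token id, not BLIP-2 code): projected image features are written into the positions of the text embedding sequence that hold the image placeholder token.

```py
import torch

IMAGE_TOKEN_ID = 99  # made-up placeholder id
hidden = 4

input_ids = torch.tensor([[IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 1, 2, 3, 4]])
inputs_embeds = torch.zeros(1, input_ids.shape[1], hidden)   # text embeddings
language_model_inputs = torch.ones(1, 2, hidden)             # projected image features

# Broadcast the token mask over the hidden dimension, then scatter the image features in
special_image_mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)

print(inputs_embeds[0, :3, 0])  # the first two positions now hold image features
```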
@@ -2275,6 +2292,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
class Blip2ForImageTextRetrieval(Blip2PreTrainedModel):
main_input_name = "pixel_values"
_keep_in_fp32_modules = ["query_tokens", "qformer"]
+ _supports_flash_attn_2 = False # because self.qformer does not support FA2
def __init__(self, config: Blip2Config):
super().__init__(config)
diff --git a/src/transformers/models/dab_detr/modeling_dab_detr.py b/src/transformers/models/dab_detr/modeling_dab_detr.py
index 0f88a06fe64..119a7a0b162 100644
--- a/src/transformers/models/dab_detr/modeling_dab_detr.py
+++ b/src/transformers/models/dab_detr/modeling_dab_detr.py
@@ -829,6 +829,9 @@ class DabDetrPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
@@ -841,6 +844,8 @@ class DabDetrPreTrainedModel(PreTrainedModel):
prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
bias_value = -math.log((1 - prior_prob) / prior_prob)
module.class_embed.bias.data.fill_(bias_value)
+ elif isinstance(module, nn.PReLU):
+ module.reset_parameters()
# Modified from transformers.models.detr.modeling_detr.DetrEncoder with Detr->DabDetr,DETR->ConditionalDETR
diff --git a/src/transformers/models/dac/modeling_dac.py b/src/transformers/models/dac/modeling_dac.py
index 191e7af89e3..398d258bef0 100644
--- a/src/transformers/models/dac/modeling_dac.py
+++ b/src/transformers/models/dac/modeling_dac.py
@@ -480,6 +480,12 @@ class DacPreTrainedModel(PreTrainedAudioTokenizerBase):
if isinstance(module, nn.Conv1d):
nn.init.trunc_normal_(module.weight, std=0.02)
nn.init.constant_(module.bias, 0)
+ elif isinstance(module, Snake1d):
+ module.alpha.data.fill_(1.0)
+ elif isinstance(module, nn.ConvTranspose1d):
+ module.reset_parameters()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=0.02)
def apply_weight_norm(self):
weight_norm = nn.utils.weight_norm
diff --git a/src/transformers/models/encodec/modeling_encodec.py b/src/transformers/models/encodec/modeling_encodec.py
index a74315ab4cc..6e610ba2953 100644
--- a/src/transformers/models/encodec/modeling_encodec.py
+++ b/src/transformers/models/encodec/modeling_encodec.py
@@ -235,7 +235,7 @@ class EncodecLSTM(nn.Module):
LSTM without worrying about the hidden state, nor the layout of the data. Expects input as convolutional layout.
"""
- def __init__(self, config, dimension):
+ def __init__(self, config: EncodecConfig, dimension: int):
super().__init__()
self.lstm = nn.LSTM(dimension, dimension, config.num_lstm_layers)
@@ -452,11 +452,7 @@ class EncodecPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights"""
- if isinstance(module, nn.Linear):
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+ if isinstance(module, nn.GroupNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Conv1d):
@@ -464,10 +460,8 @@ class EncodecPreTrainedModel(PreTrainedModel):
if module.bias is not None:
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
nn.init.uniform_(module.bias, a=-k, b=k)
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.ConvTranspose1d):
+ module.reset_parameters()
elif isinstance(module, nn.LSTM):
for name, param in module.named_parameters():
if "weight" in name:
@@ -659,7 +653,7 @@ class EncodecModel(EncodecPreTrainedModel):
def decode(
self,
- audio_codes: torch.Tensor,
+ audio_codes: torch.LongTensor,
audio_scales: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None,
@@ -708,10 +702,10 @@ class EncodecModel(EncodecPreTrainedModel):
@auto_docstring
def forward(
self,
- input_values: torch.Tensor,
- padding_mask: Optional[torch.Tensor] = None,
+ input_values: torch.FloatTensor,
+ padding_mask: Optional[torch.BoolTensor] = None,
bandwidth: Optional[float] = None,
- audio_codes: Optional[torch.Tensor] = None,
+ audio_codes: Optional[torch.LongTensor] = None,
audio_scales: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple[torch.Tensor, torch.Tensor], EncodecOutput]:
diff --git a/src/transformers/models/eomt/image_processing_eomt.py b/src/transformers/models/eomt/image_processing_eomt.py
index 73fe46034cd..e63a1be95fe 100644
--- a/src/transformers/models/eomt/image_processing_eomt.py
+++ b/src/transformers/models/eomt/image_processing_eomt.py
@@ -97,7 +97,7 @@ def get_size_with_aspect_ratio(image_size, size, max_size=None) -> tuple[int, in
Computes the output image size given the input image size and the desired output size.
Args:
- image_size (`Tuple[int, int]`):
+ image_size (`tuple[int, int]`):
The input image size.
size (`int`):
The desired output size.
@@ -531,13 +531,13 @@ class EomtImageProcessor(BaseImageProcessor):
Image or batch of images to preprocess.
segmentation_maps (`ImageInput`, *optional*):
The corresponding semantic segmentation maps with the pixel-wise annotations.
- instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*):
+ instance_id_to_semantic_id (`list[dict[int, int]]` or `dict[int, int]`, *optional*):
A mapping between object instance ids and class ids.
do_split_image (`bool`, *optional*, defaults to `self.do_split_image`):
Whether to split the input images into overlapping patches for semantic segmentation.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the input images.
- size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
Target size as a dictionary with `"shortest_edge"` and `"longest_edge"` keys.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use when resizing.
@@ -550,9 +550,9 @@ class EomtImageProcessor(BaseImageProcessor):
do_pad (`bool`, *optional*, defaults to `False`):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
- image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
Mean for normalization. Single value or list for each channel.
- image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
Standard deviation for normalization. Single value or list for each channel.
ignore_index (`int`, *optional*):
Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels
@@ -640,7 +640,7 @@ class EomtImageProcessor(BaseImageProcessor):
)
if do_split_image and patch_offsets:
- encoded_inputs["patch_offsets"] = patch_offsets
+ encoded_inputs["patch_offsets"] = [torch.tensor(offsets) for offsets in patch_offsets]
return encoded_inputs
@@ -663,8 +663,8 @@ class EomtImageProcessor(BaseImageProcessor):
each mask.
Args:
- pixel_values_list (`List[ImageInput]`):
- List of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height,
+ pixel_values_list (`list[ImageInput]`):
+ list of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height,
width)`.
segmentation_maps (`ImageInput`, *optional*):
@@ -678,7 +678,7 @@ class EomtImageProcessor(BaseImageProcessor):
- 1 for pixels that are real (i.e. **not masked**),
- 0 for pixels that are padding (i.e. **masked**).
- instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*):
+ instance_id_to_semantic_id (`list[dict[int, int]]` or `dict[int, int]`, *optional*):
A mapping between object instance ids and class ids. If passed, `segmentation_maps` is treated as an
instance segmentation map where each pixel represents an instance id. Can be provided as a single
dictionary with a global/dataset-level mapping or as a list of dictionaries (one per image), to map
@@ -740,7 +740,7 @@ class EomtImageProcessor(BaseImageProcessor):
self,
segmentation_logits: torch.Tensor,
patch_offsets: list[tuple[int, int, int]],
- original_image_sizes: list[tuple[int, int]],
+ target_sizes: list[tuple[int, int]],
size: dict[str, int],
) -> list[torch.Tensor]:
"""
@@ -750,28 +750,28 @@ class EomtImageProcessor(BaseImageProcessor):
segmentation_logits (`torch.Tensor`):
A tensor of shape `(num_patches, num_classes, patch_height, patch_width)` representing predicted logits
for each image patch.
- patch_offsets (`List[Tuple[int, int, int]]`):
+ patch_offsets (`list[tuple[int, int, int]]`):
A list of tuples where each tuple contains:
- `image_index` (int): Index of the original image this patch belongs to.
- `start` (int): Start pixel index of the patch along the long dimension (height or width).
- `end` (int): End pixel index of the patch along the long dimension.
- original_image_sizes (`List[Tuple[int, int]]`):
- List of original (height, width) dimensions for each image before preprocessing.
- size (`Dict[str, int]`):
+ target_sizes (`list[tuple[int, int]]`):
+ list of original (height, width) dimensions for each image before preprocessing.
+ size (`dict[str, int]`):
A size dict which was used to resize.
"""
num_classes = segmentation_logits.shape[1]
aggregated_logits = []
patch_counts = []
- for image_size in original_image_sizes:
+ for image_size in target_sizes:
height, width = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"])
aggregated_logits.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
patch_counts.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
# Stitch patches back into full-sized logit maps
for patch_idx, (image_idx, patch_start, patch_end) in enumerate(patch_offsets):
- if original_image_sizes[image_idx][0] > original_image_sizes[image_idx][1]:
+ if target_sizes[image_idx][0] > target_sizes[image_idx][1]:
aggregated_logits[image_idx][:, patch_start:patch_end, :] += segmentation_logits[patch_idx]
patch_counts[image_idx][:, patch_start:patch_end, :] += 1
else:
@@ -784,7 +784,7 @@ class EomtImageProcessor(BaseImageProcessor):
averaged_logits = logit_sum / count.clamp(min=1)
resized_logits = F.interpolate(
averaged_logits[None, ...],
- size=original_image_sizes[idx],
+ size=target_sizes[idx],
mode="bilinear",
align_corners=False,
)[0]
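A toy illustration (synthetic tensors, not the processor API) of the stitch-and-average scheme in `merge_image_patches`: overlapping patch logits are accumulated on a full-size canvas along the long image dimension and divided by per-pixel patch counts.

```py
import torch

num_classes, height, width = 2, 4, 6           # width is the long dimension here
canvas = torch.zeros(num_classes, height, width)
counts = torch.zeros(num_classes, height, width)

# Two patches of width 4 covering columns 0-3 and 2-5; they overlap on columns 2-3
patches = [
    (0, 4, torch.ones(num_classes, height, 4)),
    (2, 6, torch.full((num_classes, height, 4), 3.0)),
]

for start, end, patch_logits in patches:
    canvas[:, :, start:end] += patch_logits
    counts[:, :, start:end] += 1

averaged = canvas / counts.clamp(min=1)
print(averaged[0, 0])  # columns: 1, 1, 2, 2, 3, 3 -> the overlap averages to 2.0
```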
@@ -796,14 +796,14 @@ class EomtImageProcessor(BaseImageProcessor):
def unpad_image(
self,
segmentation_logits: torch.Tensor,
- original_image_sizes: list[tuple[int, int]],
+ target_sizes: list[tuple[int, int]],
size: dict[str, int],
) -> list[torch.Tensor]:
"""Restores panoptic segmentation logits to their original image resolutions."""
resized_logits = []
- for idx, original_size in enumerate(original_image_sizes):
+ for idx, original_size in enumerate(target_sizes):
target_height, target_width = get_size_with_aspect_ratio(
original_size, size["shortest_edge"], size["longest_edge"]
)
@@ -817,8 +817,7 @@ class EomtImageProcessor(BaseImageProcessor):
def post_process_semantic_segmentation(
self,
outputs,
- patch_offsets: list[tuple[int, int, int]],
- original_image_sizes: list[tuple[int, int]],
+ target_sizes: list[tuple[int, int]],
size: Optional[dict[str, int]] = None,
) -> np.ndarray:
"""Post-processes model outputs into final semantic segmentation prediction."""
@@ -827,6 +826,7 @@ class EomtImageProcessor(BaseImageProcessor):
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
+ patch_offsets = outputs.patch_offsets
output_size = get_target_size(size)
masks_queries_logits = F.interpolate(
@@ -841,15 +841,15 @@ class EomtImageProcessor(BaseImageProcessor):
segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
- output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, original_image_sizes, size)
+ output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size)
- preds = torch.stack(output_logits).argmax(dim=1)
+ preds = [logit.argmax(dim=0) for logit in output_logits]
return preds
def post_process_panoptic_segmentation(
self,
outputs,
- original_image_sizes: list[tuple[int, int]],
+ target_sizes: list[tuple[int, int]],
threshold: float = 0.8,
mask_threshold: float = 0.5,
overlap_mask_area_threshold: float = 0.8,
@@ -873,7 +873,7 @@ class EomtImageProcessor(BaseImageProcessor):
mode="bilinear",
)
- mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size)
+ mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size)
pred_scores_batch, pred_labels_batch = class_queries_logits.softmax(dim=-1).max(-1)
results: list = []
@@ -885,7 +885,7 @@ class EomtImageProcessor(BaseImageProcessor):
# No mask found
if mask_probs.shape[0] <= 0:
- height, width = original_image_sizes[i] if original_image_sizes is not None else mask_probs.shape[1:]
+ height, width = target_sizes[i] if target_sizes is not None else mask_probs.shape[1:]
segmentation = torch.zeros((height, width)) - 1
results.append({"segmentation": segmentation, "segments_info": []})
continue
@@ -897,16 +897,17 @@ class EomtImageProcessor(BaseImageProcessor):
stuff_classes=stuff_classes,
mask_threshold=mask_threshold,
overlap_mask_area_threshold=overlap_mask_area_threshold,
- target_size=original_image_sizes[i] if original_image_sizes is not None else None,
+ target_size=target_sizes[i] if target_sizes is not None else None,
)
results.append({"segmentation": segmentation, "segments_info": segments})
return results
+ @filter_out_non_signature_kwargs()
def post_process_instance_segmentation(
self,
outputs,
- original_image_sizes: list[tuple[int, int]],
+ target_sizes: list[tuple[int, int]],
threshold: float = 0.5,
size: Optional[dict[str, int]] = None,
):
@@ -924,7 +925,7 @@ class EomtImageProcessor(BaseImageProcessor):
mode="bilinear",
)
- mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size)
+ mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size)
device = masks_queries_logits.device
batch_size = class_queries_logits.shape[0]
@@ -946,7 +947,7 @@ class EomtImageProcessor(BaseImageProcessor):
)
pred_scores = scores * mask_scores
- segmentation = torch.zeros(original_image_sizes[i], device=device) - 1
+ segmentation = torch.zeros(target_sizes[i], device=device) - 1
instance_maps, segments = [], []
current_segment_id = 0
diff --git a/src/transformers/models/eomt/image_processing_eomt_fast.py b/src/transformers/models/eomt/image_processing_eomt_fast.py
index 04b53c418db..343c6ae2cf1 100644
--- a/src/transformers/models/eomt/image_processing_eomt_fast.py
+++ b/src/transformers/models/eomt/image_processing_eomt_fast.py
@@ -41,6 +41,7 @@ from ...processing_utils import Unpack
from ...utils import (
TensorType,
auto_docstring,
+ filter_out_non_signature_kwargs,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
@@ -268,7 +269,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
r"""
segmentation_maps (`ImageInput`, *optional*):
The segmentation maps to preprocess for corresponding images.
- instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*):
+ instance_id_to_semantic_id (`list[dict[int, int]]` or `dict[int, int]`, *optional*):
A mapping between object instance ids and class ids.
"""
# args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same
@@ -340,7 +341,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
outputs["class_labels"] = class_labels
if patch_offsets:
- outputs["patch_offsets"] = patch_offsets
+ outputs["patch_offsets"] = [torch.tensor(offsets) for offsets in patch_offsets]
return outputs
@@ -348,7 +349,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
self,
segmentation_logits: torch.Tensor,
patch_offsets: list[tuple[int, int, int]],
- original_image_sizes: list[tuple[int, int]],
+ target_sizes: list[tuple[int, int]],
size: dict[str, int],
) -> list[torch.Tensor]:
"""
@@ -358,28 +359,28 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
segmentation_logits (`torch.Tensor`):
A tensor of shape `(num_patches, num_classes, patch_height, patch_width)` representing predicted logits
for each image patch.
- patch_offsets (`List[Tuple[int, int, int]]`):
+ patch_offsets (`list[tuple[int, int, int]]`):
A list of tuples where each tuple contains:
- `image_index` (int): Index of the original image this patch belongs to.
- `start` (int): Start pixel index of the patch along the long dimension (height or width).
- `end` (int): End pixel index of the patch along the long dimension.
- original_image_sizes (`List[Tuple[int, int]]`):
- List of original (height, width) dimensions for each image before preprocessing.
- size (`Dict[str, int]`):
+ target_sizes (`list[tuple[int, int]]`):
+ list of original (height, width) dimensions for each image before preprocessing.
+ size (`dict[str, int]`):
A size dict which was used to resize.
"""
num_classes = segmentation_logits.shape[1]
aggregated_logits = []
patch_counts = []
- for image_size in original_image_sizes:
+ for image_size in target_sizes:
height, width = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"])
aggregated_logits.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
patch_counts.append(torch.zeros((num_classes, height, width), device=segmentation_logits.device))
# Stitch patches back into full-sized logit maps
for patch_idx, (image_idx, patch_start, patch_end) in enumerate(patch_offsets):
- if original_image_sizes[image_idx][0] > original_image_sizes[image_idx][1]:
+ if target_sizes[image_idx][0] > target_sizes[image_idx][1]:
aggregated_logits[image_idx][:, patch_start:patch_end, :] += segmentation_logits[patch_idx]
patch_counts[image_idx][:, patch_start:patch_end, :] += 1
else:
@@ -392,7 +393,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
averaged_logits = logit_sum / count.clamp(min=1)
resized_logits = torch.nn.functional.interpolate(
averaged_logits[None, ...],
- size=original_image_sizes[idx],
+ size=target_sizes[idx],
mode="bilinear",
align_corners=False,
)[0]
@@ -404,14 +405,14 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
def unpad_image(
self,
segmentation_logits: torch.Tensor,
- original_image_sizes: list[tuple[int, int]],
+ target_sizes: list[tuple[int, int]],
size: dict[str, int],
) -> list[torch.Tensor]:
"""Restores panoptic segmentation logits to their original image resolutions."""
resized_logits = []
- for idx, original_size in enumerate(original_image_sizes):
+ for idx, original_size in enumerate(target_sizes):
target_height, target_width = get_size_with_aspect_ratio(
original_size, size["shortest_edge"], size["longest_edge"]
)
@@ -425,8 +426,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
def post_process_semantic_segmentation(
self,
outputs,
- patch_offsets: list[tuple[int, int, int]],
- original_image_sizes: list[tuple[int, int]],
+ target_sizes: list[tuple[int, int]],
size: Optional[dict[str, int]] = None,
) -> np.ndarray:
"""Post-processes model outputs into final semantic segmentation prediction."""
@@ -435,6 +435,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
+ patch_offsets = outputs.patch_offsets
output_size = get_target_size(size)
masks_queries_logits = torch.nn.functional.interpolate(
@@ -449,15 +450,15 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
- output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, original_image_sizes, size)
+ output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size)
- preds = torch.stack(output_logits).argmax(dim=1)
+ preds = [logit.argmax(dim=0) for logit in output_logits]
return preds
def post_process_panoptic_segmentation(
self,
outputs,
- original_image_sizes: list[tuple[int, int]],
+ target_sizes: list[tuple[int, int]],
threshold: float = 0.8,
mask_threshold: float = 0.5,
overlap_mask_area_threshold: float = 0.8,
@@ -481,7 +482,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
mode="bilinear",
)
- mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size)
+ mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size)
pred_scores_batch, pred_labels_batch = class_queries_logits.softmax(dim=-1).max(-1)
results: list = []
@@ -493,7 +494,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
# No mask found
if mask_probs.shape[0] <= 0:
- height, width = original_image_sizes[i] if original_image_sizes is not None else mask_probs.shape[1:]
+ height, width = target_sizes[i] if target_sizes is not None else mask_probs.shape[1:]
segmentation = torch.zeros((height, width)) - 1
results.append({"segmentation": segmentation, "segments_info": []})
continue
@@ -505,16 +506,17 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
stuff_classes=stuff_classes,
mask_threshold=mask_threshold,
overlap_mask_area_threshold=overlap_mask_area_threshold,
- target_size=original_image_sizes[i] if original_image_sizes is not None else None,
+ target_size=target_sizes[i] if target_sizes is not None else None,
)
results.append({"segmentation": segmentation, "segments_info": segments})
return results
+ @filter_out_non_signature_kwargs()
def post_process_instance_segmentation(
self,
outputs,
- original_image_sizes: list[tuple[int, int]],
+ target_sizes: list[tuple[int, int]],
threshold: float = 0.8,
size: Optional[dict[str, int]] = None,
):
@@ -532,7 +534,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
mode="bilinear",
)
- mask_probs_batch = self.unpad_image(masks_queries_logits, original_image_sizes, size)
+ mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size)
device = masks_queries_logits.device
batch_size = class_queries_logits.shape[0]
@@ -554,7 +556,7 @@ class EomtImageProcessorFast(BaseImageProcessorFast):
)
pred_scores = scores * mask_scores
- segmentation = torch.zeros(original_image_sizes[i], device=device) - 1
+ segmentation = torch.zeros(target_sizes[i], device=device) - 1
instance_maps, segments = [], []
current_segment_id = 0
diff --git a/src/transformers/models/eomt/modeling_eomt.py b/src/transformers/models/eomt/modeling_eomt.py
index bbdd11e1f58..bc865988ca6 100644
--- a/src/transformers/models/eomt/modeling_eomt.py
+++ b/src/transformers/models/eomt/modeling_eomt.py
@@ -74,6 +74,8 @@ class EomtForUniversalSegmentationOutput(ModelOutput):
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Self and Cross Attentions weights from transformer decoder.
+ patch_offsets (`list[torch.Tensor]`, *optional*):
+            list of tuples indicating the image index and the start and end positions of patches for semantic segmentation.
"""
loss: Optional[torch.FloatTensor] = None
@@ -82,6 +84,7 @@ class EomtForUniversalSegmentationOutput(ModelOutput):
last_hidden_state: Optional[torch.FloatTensor] = None
hidden_states: Optional[tuple[torch.FloatTensor]] = None
attentions: Optional[tuple[torch.FloatTensor]] = None
+ patch_offsets: Optional[list[torch.Tensor]] = None
# Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
@@ -996,7 +999,7 @@ class EomtPreTrainedModel(PreTrainedModel):
base_model_prefix = "eomt"
main_input_name = "pixel_values"
supports_gradient_checkpointing = False
- _no_split_modules = ["EomtMLP"]
+ _no_split_modules = ["EomtLayer"]
_supports_sdpa = True
_supports_flash_attn_2 = True
@@ -1097,13 +1100,16 @@ class EomtForUniversalSegmentation(EomtPreTrainedModel):
class_labels: Optional[list[Tensor]] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
+ patch_offsets: Optional[list[Tensor]] = None,
) -> EomtForUniversalSegmentationOutput:
r"""
- mask_labels (`List[torch.Tensor]`, *optional*):
- List of mask labels of shape `(num_labels, height, width)` to be fed to a model
- class_labels (`List[torch.LongTensor]`, *optional*):
+ mask_labels (`list[torch.Tensor]`, *optional*):
+ list of mask labels of shape `(num_labels, height, width)` to be fed to a model
+ class_labels (`list[torch.LongTensor]`, *optional*):
list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the
labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`.
+ patch_offsets (`list[torch.Tensor]`, *optional*):
+ list of tuples indicating the image index and the start and end positions of patches for semantic segmentation.
"""
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1126,7 +1132,7 @@ class EomtForUniversalSegmentation(EomtPreTrainedModel):
all_hidden_states += (hidden_states,)
if idx == self.num_hidden_layers - self.config.num_blocks:
- query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1)
+ query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device)
hidden_states = torch.cat((query, hidden_states), dim=1)
if idx >= self.num_hidden_layers - self.config.num_blocks and (
@@ -1206,6 +1212,7 @@ class EomtForUniversalSegmentation(EomtPreTrainedModel):
last_hidden_state=sequence_output,
hidden_states=all_hidden_states,
attentions=all_attentions,
+ patch_offsets=patch_offsets,
)
def get_input_embeddings(self):
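Because `patch_offsets` now travel on the model output instead of being popped from the processor inputs, semantic post-processing no longer needs to shuttle them by hand. A rough sketch of the intended flow follows; the checkpoint id is an assumption, not taken from the patch.

```python
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, EomtForUniversalSegmentation

model_id = "tue-mps/ade20k_semantic_eomt_large_512"  # assumed checkpoint name
processor = AutoImageProcessor.from_pretrained(model_id)
model = EomtForUniversalSegmentation.from_pretrained(model_id)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="pt")
with torch.inference_mode():
    outputs = model(**inputs)  # patch_offsets pass through to the output unchanged

preds = processor.post_process_semantic_segmentation(
    outputs, target_sizes=[(image.height, image.width)]
)
print(preds[0].shape)  # (image.height, image.width)
```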
diff --git a/src/transformers/models/eomt/modular_eomt.py b/src/transformers/models/eomt/modular_eomt.py
index fc82836e4be..44ecb69eca6 100644
--- a/src/transformers/models/eomt/modular_eomt.py
+++ b/src/transformers/models/eomt/modular_eomt.py
@@ -226,6 +226,8 @@ class EomtForUniversalSegmentationOutput(ModelOutput):
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Self and Cross Attentions weights from transformer decoder.
+ patch_offsets (`list[torch.Tensor]`, *optional*):
+ list of tuples indicating the image index and the start and end positions of patches for semantic segmentation.
"""
loss: Optional[torch.FloatTensor] = None
@@ -234,6 +236,7 @@ class EomtForUniversalSegmentationOutput(ModelOutput):
last_hidden_state: Optional[torch.FloatTensor] = None
hidden_states: Optional[tuple[torch.FloatTensor]] = None
attentions: Optional[tuple[torch.FloatTensor]] = None
+ patch_offsets: Optional[list[torch.Tensor]] = None
class EomtLoss(Mask2FormerLoss):
@@ -368,7 +371,7 @@ class EomtPreTrainedModel(PreTrainedModel):
base_model_prefix = "eomt"
main_input_name = "pixel_values"
supports_gradient_checkpointing = False
- _no_split_modules = ["EomtMLP"]
+ _no_split_modules = ["EomtLayer"]
_supports_sdpa = True
_supports_flash_attn_2 = True
@@ -473,13 +476,16 @@ class EomtForUniversalSegmentation(Mask2FormerForUniversalSegmentation, nn.Modul
class_labels: Optional[list[Tensor]] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
+ patch_offsets: Optional[list[Tensor]] = None,
):
r"""
- mask_labels (`List[torch.Tensor]`, *optional*):
- List of mask labels of shape `(num_labels, height, width)` to be fed to a model
- class_labels (`List[torch.LongTensor]`, *optional*):
+ mask_labels (`list[torch.Tensor]`, *optional*):
+ list of mask labels of shape `(num_labels, height, width)` to be fed to a model
+ class_labels (`list[torch.LongTensor]`, *optional*):
list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the
labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`.
+ patch_offsets (`list[torch.Tensor]`, *optional*):
+ list of tuples indicating the image index and the start and end positions of patches for semantic segmentation.
"""
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -502,7 +508,7 @@ class EomtForUniversalSegmentation(Mask2FormerForUniversalSegmentation, nn.Modul
all_hidden_states += (hidden_states,)
if idx == self.num_hidden_layers - self.config.num_blocks:
- query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1)
+ query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device)
hidden_states = torch.cat((query, hidden_states), dim=1)
if idx >= self.num_hidden_layers - self.config.num_blocks and (
@@ -582,6 +588,7 @@ class EomtForUniversalSegmentation(Mask2FormerForUniversalSegmentation, nn.Modul
last_hidden_state=sequence_output,
hidden_states=all_hidden_states,
attentions=all_attentions,
+ patch_offsets=patch_offsets,
)
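The `.to(hidden_states.device)` addition matters once the model is sharded (for example with `device_map="auto"` and the new `EomtLayer` split points): the learned queries can then live on a different device than the activations they are concatenated with. A toy illustration with assumed shapes, runnable on CPU where the move is a no-op:

```python
import torch

# Toy sketch: if activations and the learned query embedding sit on different
# devices, the query must be moved before concatenation or torch.cat fails.
hidden_states = torch.randn(2, 16, 8)                      # imagine this on "cuda:1"
query = torch.nn.Embedding(4, 8).weight                    # parameters kept on "cuda:0"
query = query[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device)
hidden_states = torch.cat((query, hidden_states), dim=1)   # (2, 20, 8)
```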
diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
index 426e557d9d3..942053be3e7 100644
--- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
+++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
@@ -445,9 +445,16 @@ class FalconMambaPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights."""
+ std = self.config.initializer_range
if isinstance(module, FalconMambaMixer):
+ # S4D real initialization. These are not discretized!
+ # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
+ A = torch.arange(1, module.ssm_state_size + 1, dtype=torch.float32)[None, :]
+ A = A.expand(module.intermediate_size, -1).contiguous()
+ module.A_log.copy_(torch.log(A))
module.A_log._no_weight_decay = True
module.D._no_weight_decay = True
+ module.D.data.fill_(1.0)
dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
if self.config.time_step_init_scheme == "constant":
@@ -462,33 +469,39 @@ class FalconMambaPreTrainedModel(PreTrainedModel):
).clamp(min=self.config.time_step_floor)
# # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
- with torch.no_grad():
- module.dt_proj.bias.copy_(inv_dt)
+ module.dt_proj.bias.copy_(inv_dt)
module.dt_proj.bias._no_reinit = True
+ nn.init.kaiming_uniform_(module.conv1d.weight, a=math.sqrt(5))
+ if module.conv1d.bias is not None:
+ if not getattr(module.conv1d.bias, "_no_reinit", False):
+ nn.init.zeros_(module.conv1d.bias)
+ nn.init.kaiming_uniform_(module.out_proj.weight, a=math.sqrt(5))
+
+ if self.config.rescale_prenorm_residual:
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+ #
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+ # We need to reinit p since this code could be called multiple times
+ # Having just p *= scale would repeatedly scale it down
+ p = module.out_proj.weight
+ p /= math.sqrt(self.config.num_hidden_layers)
+
if isinstance(module, nn.Linear):
+ if not getattr(module.weight, "_no_reinit", False):
+ nn.init.normal_(module.weight, std=std)
if module.bias is not None:
if not getattr(module.bias, "_no_reinit", False):
nn.init.zeros_(module.bias)
+ elif isinstance(module, FalconMambaRMSNorm):
+ module.weight.data.fill_(1.0)
elif isinstance(module, nn.Embedding):
- nn.init.normal_(module.weight, std=self.config.initializer_range)
-
- if self.config.rescale_prenorm_residual:
- # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
- # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
- # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
- # > -- GPT-2 :: https://openai.com/blog/better-language-models/
- #
- # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
- for name, p in module.named_parameters():
- if name in ["out_proj.weight"]:
- # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
- # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
- # We need to reinit p since this code could be called multiple times
- # Having just p *= scale would repeatedly scale it down
- nn.init.kaiming_uniform_(p, a=math.sqrt(5))
- with torch.no_grad():
- p /= math.sqrt(self.config.num_hidden_layers)
+ nn.init.normal_(module.weight, std=std)
@dataclass
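The S4D-style values written into the mixer here can be reproduced standalone; a minimal sketch with arbitrary sizes (not tied to any real config) of what `_init_weights` now produces for `A_log`, `D`, the time-step bias, and the depthwise conv:

```python
import math
import torch
import torch.nn as nn

# A_log holds log(1..N) repeated per channel, D is all ones, and the time-step
# bias is the inverse softplus of a log-uniform sample. Sizes are arbitrary.
intermediate_size, ssm_state_size = 8, 4
A = torch.arange(1, ssm_state_size + 1, dtype=torch.float32)[None, :]
A_log = torch.log(A.expand(intermediate_size, -1).contiguous())
D = torch.ones(intermediate_size)

time_step_min, time_step_max, time_step_floor = 1e-3, 1e-1, 1e-4
dt = torch.exp(
    torch.rand(intermediate_size) * (math.log(time_step_max) - math.log(time_step_min))
    + math.log(time_step_min)
).clamp(min=time_step_floor)
dt_bias = dt + torch.log(-torch.expm1(-dt))  # inverse of softplus

conv1d = nn.Conv1d(intermediate_size, intermediate_size, kernel_size=4, groups=intermediate_size)
nn.init.kaiming_uniform_(conv1d.weight, a=math.sqrt(5))
```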
diff --git a/src/transformers/models/grounding_dino/modeling_grounding_dino.py b/src/transformers/models/grounding_dino/modeling_grounding_dino.py
index 31ccb4becd1..743f74a1215 100644
--- a/src/transformers/models/grounding_dino/modeling_grounding_dino.py
+++ b/src/transformers/models/grounding_dino/modeling_grounding_dino.py
@@ -1414,16 +1414,18 @@ class GroundingDinoPreTrainedModel(PreTrainedModel):
module.out_vision_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(module.out_text_proj.weight)
module.out_text_proj.bias.data.fill_(0)
- elif isinstance(module, (GroundingDinoEncoderLayer, GroundingDinoDecoderLayer)):
- for p in module.parameters():
- if p.dim() > 1:
- nn.init.normal_(p, mean=0.0, std=std)
+ elif isinstance(module, GroundingDinoFusionLayer):
+ module.vision_param.data.fill_(1e-4)
+ module.text_param.data.fill_(1e-4)
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
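A rough sketch (not the library code) of the dispatch style these `_init_weights` refactors move towards: norm layers are reset explicitly to weight=1 / bias=0 rather than being caught by a blanket loop over all layer parameters.

```python
import torch.nn as nn

def init_weights(module: nn.Module, std: float = 0.02) -> None:
    # Linear/conv weights get a normal init; norms get an explicit identity reset.
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
        module.weight.data.fill_(1.0)
        module.bias.data.zero_()

block = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))
block.apply(init_weights)
```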
diff --git a/src/transformers/models/llava_onevision/configuration_llava_onevision.py b/src/transformers/models/llava_onevision/configuration_llava_onevision.py
index 6e618b1ce59..f6f40c1bd83 100644
--- a/src/transformers/models/llava_onevision/configuration_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/configuration_llava_onevision.py
@@ -176,7 +176,7 @@ class LlavaOnevisionConfig(PretrainedConfig):
patch_size=14,
image_size=384,
num_hidden_layers=26,
- num_attention_heads=14,
+ num_attention_heads=16,
vision_use_head=False,
)
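The head-count correction matters because attention requires the hidden size to split evenly across heads. Assuming the vision tower is the SigLIP SoViT-400M variant with a hidden size of 1152 (an assumption, not stated in the patch), a quick check:

```python
# Divisibility check under the assumed hidden_size == 1152 for the vision tower.
hidden_size = 1152
assert hidden_size % 16 == 0   # 72-dim heads with the corrected value
assert hidden_size % 14 != 0   # the previous value could not split the hidden size evenly
print(hidden_size // 16)       # 72
```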
diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py
index f2347833db6..7da4ef57878 100644
--- a/src/transformers/models/mamba/modeling_mamba.py
+++ b/src/transformers/models/mamba/modeling_mamba.py
@@ -382,9 +382,16 @@ class MambaPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights."""
+ std = self.config.initializer_range
if isinstance(module, MambaMixer):
+ # S4D real initialization. These are not discretized!
+ # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
+ A = torch.arange(1, module.ssm_state_size + 1, dtype=torch.float32)[None, :]
+ A = A.expand(module.intermediate_size, -1).contiguous()
+ module.A_log.copy_(torch.log(A))
module.A_log._no_weight_decay = True
module.D._no_weight_decay = True
+ module.D.data.fill_(1.0)
dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
if self.config.time_step_init_scheme == "constant":
@@ -399,33 +406,39 @@ class MambaPreTrainedModel(PreTrainedModel):
).clamp(min=self.config.time_step_floor)
# # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
- with torch.no_grad():
- module.dt_proj.bias.copy_(inv_dt)
+ module.dt_proj.bias.copy_(inv_dt)
module.dt_proj.bias._no_reinit = True
+ nn.init.kaiming_uniform_(module.conv1d.weight, a=math.sqrt(5))
+ if module.conv1d.bias is not None:
+ if not getattr(module.conv1d.bias, "_no_reinit", False):
+ nn.init.zeros_(module.conv1d.bias)
+ nn.init.kaiming_uniform_(module.out_proj.weight, a=math.sqrt(5))
+
+ if self.config.rescale_prenorm_residual:
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+ #
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+ # We need to reinit p since this code could be called multiple times
+ # Having just p *= scale would repeatedly scale it down
+ p = module.out_proj.weight
+ p /= math.sqrt(self.config.num_hidden_layers)
+
if isinstance(module, nn.Linear):
+ if not getattr(module.weight, "_no_reinit", False):
+ nn.init.normal_(module.weight, std=std)
if module.bias is not None:
if not getattr(module.bias, "_no_reinit", False):
nn.init.zeros_(module.bias)
+ elif isinstance(module, MambaRMSNorm):
+ module.weight.data.fill_(1.0)
elif isinstance(module, nn.Embedding):
- nn.init.normal_(module.weight, std=self.config.initializer_range)
-
- if self.config.rescale_prenorm_residual:
- # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
- # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
- # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
- # > -- GPT-2 :: https://openai.com/blog/better-language-models/
- #
- # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
- for name, p in module.named_parameters():
- if name in ["out_proj.weight"]:
- # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
- # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
- # We need to reinit p since this code could be called multiple times
- # Having just p *= scale would repeatedly scale it down
- nn.init.kaiming_uniform_(p, a=math.sqrt(5))
- with torch.no_grad():
- p /= math.sqrt(self.config.num_hidden_layers)
+ nn.init.normal_(module.weight, std=std)
@dataclass
diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py
index 1f663462d5e..e601b4d8a69 100644
--- a/src/transformers/models/mamba2/modeling_mamba2.py
+++ b/src/transformers/models/mamba2/modeling_mamba2.py
@@ -721,9 +721,15 @@ class Mamba2PreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights."""
+ std = self.config.initializer_range
if isinstance(module, Mamba2Mixer):
+ # S4D real initialization. These are not discretized!
+ # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
+ A = torch.arange(1, self.config.num_heads + 1)
+ module.A_log.copy_(torch.log(A))
module.A_log._no_weight_decay = True
module.D._no_weight_decay = True
+ module.D.data.fill_(1.0)
dt = torch.exp(
torch.rand(self.config.num_heads)
@@ -733,33 +739,39 @@ class Mamba2PreTrainedModel(PreTrainedModel):
# # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
- with torch.no_grad():
- module.dt_bias.copy_(inv_dt)
+ module.dt_bias.copy_(inv_dt)
module.dt_bias._no_reinit = True
+ nn.init.kaiming_uniform_(module.conv1d.weight, a=math.sqrt(5))
+ if module.conv1d.bias is not None:
+ if not getattr(module.conv1d.bias, "_no_reinit", False):
+ nn.init.zeros_(module.conv1d.bias)
+ nn.init.kaiming_uniform_(module.out_proj.weight, a=math.sqrt(5))
+
+ if self.config.rescale_prenorm_residual:
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+ #
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+ # We need to reinit p since this code could be called multiple times
+ # Having just p *= scale would repeatedly scale it down
+ p = module.out_proj.weight
+ p /= math.sqrt(self.config.num_hidden_layers)
+
if isinstance(module, nn.Linear):
+ if not getattr(module.weight, "_no_reinit", False):
+ nn.init.normal_(module.weight, std=std)
if module.bias is not None:
if not getattr(module.bias, "_no_reinit", False):
nn.init.zeros_(module.bias)
+ elif isinstance(module, (Mamba2RMSNorm, MambaRMSNormGated)):
+ module.weight.data.fill_(1.0)
elif isinstance(module, nn.Embedding):
- nn.init.normal_(module.weight, std=self.config.initializer_range)
-
- if self.config.rescale_prenorm_residual:
- # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
- # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
- # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
- # > -- GPT-2 :: https://openai.com/blog/better-language-models/
- #
- # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
- for name, p in module.named_parameters():
- if name in ["out_proj.weight"]:
- # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
- # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
- # We need to reinit p since this code could be called multiple times
- # Having just p *= scale would repeatedly scale it down
- nn.init.kaiming_uniform_(p, a=math.sqrt(5))
- with torch.no_grad():
- p /= math.sqrt(self.config.num_hidden_layers)
+ nn.init.normal_(module.weight, std=std)
@dataclass
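The `inv_dt = dt + torch.log(-torch.expm1(-dt))` line these hunks keep touching is the closed-form inverse of softplus; a short numerical check, independent of any model code:

```python
import torch
import torch.nn.functional as F

# softplus(x) = log(1 + exp(x)); its inverse is log(expm1(y)) = y + log(-expm1(-y)).
# The check confirms that softplus(dt_bias) recovers the sampled time steps.
dt = torch.rand(8).clamp(min=1e-4)
inv_dt = dt + torch.log(-torch.expm1(-dt))
assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-6)
```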
diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py
index 8a0cb7dbf80..139256c7c71 100644
--- a/src/transformers/models/musicgen/modeling_musicgen.py
+++ b/src/transformers/models/musicgen/modeling_musicgen.py
@@ -440,10 +440,13 @@ class MusicgenPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
std = self.config.initializer_factor
- if isinstance(module, (nn.Linear, nn.Conv1d)):
+ if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
index e8aa032784f..55e28ca58f7 100644
--- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
+++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
@@ -406,10 +406,13 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
std = self.config.initializer_factor
- if isinstance(module, (nn.Linear, nn.Conv1d)):
+ if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
@@ -1286,7 +1289,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin):
The text encoder model that encodes text into hidden states for conditioning.
audio_encoder (`PreTrainedModel`, *optional*):
The audio encoder model that encodes audio into hidden states for conditioning.
- decoder (`MusicgenForCausalLM`, *optional*):
+ decoder (`MusicgenMelodyForCausalLM`, *optional*):
The decoder model that generates audio tokens based on conditioning signals.
"""
if config is None and None in (text_encoder, audio_encoder, decoder):
diff --git a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py
index 0dfbf833324..9bac40553d9 100644
--- a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py
+++ b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py
@@ -1006,10 +1006,15 @@ class OmDetTurboPreTrainedModel(PreTrainedModel):
nn.init.xavier_uniform_(module.query_position_head.layers[1].weight)
for layer in module.channel_projection_layers:
nn.init.xavier_uniform_(layer[0].weight)
+ elif isinstance(module, OmDetTurboLanguageBackbone):
+ nn.init.normal_(module.text_projection, std=self.config.text_projection_in_dim**-0.5)
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
module.weight.data.normal_(mean=0.0, std=self.config.init_std)
if module.bias is not None:
module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, OmDetTurboDecoder):
diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
index 45fcbe80495..f90f7ff9cf9 100644
--- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
+++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
@@ -283,6 +283,9 @@ class Qwen2AudioPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
diff --git a/src/transformers/models/seggpt/modeling_seggpt.py b/src/transformers/models/seggpt/modeling_seggpt.py
index 80a51fb5565..364483359ee 100644
--- a/src/transformers/models/seggpt/modeling_seggpt.py
+++ b/src/transformers/models/seggpt/modeling_seggpt.py
@@ -604,7 +604,7 @@ class SegGptPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["SegGptEmbeddings", "SegGptLayer"]
- def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+ def _init_weights(self, module: nn.Module) -> None:
"""Initialize the weights"""
std = self.config.initializer_range
if isinstance(module, (nn.Linear, nn.Conv2d)):
@@ -615,7 +615,7 @@ class SegGptPreTrainedModel(PreTrainedModel):
)
if module.bias is not None:
module.bias.data.zero_()
- elif isinstance(module, nn.LayerNorm):
+ elif isinstance(module, (nn.LayerNorm, SegGptLayerNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, SegGptAttention):
diff --git a/src/transformers/models/smolvlm/processing_smolvlm.py b/src/transformers/models/smolvlm/processing_smolvlm.py
index ada719a70e0..72f63c37ffd 100644
--- a/src/transformers/models/smolvlm/processing_smolvlm.py
+++ b/src/transformers/models/smolvlm/processing_smolvlm.py
@@ -434,6 +434,10 @@ class SmolVLMProcessor(ProcessorMixin):
if chat_template is None and has_video:
# re-assign to the correct default template for BC, if user is not requesting their own template
chat_template = DEFAULT_CHAT_TEMPLATE
+
+ kwargs.setdefault("num_frames", self.video_processor.num_frames)
+ kwargs.setdefault("fps", self.video_processor.fps)
+
return super().apply_chat_template(conversation, chat_template, **kwargs)
diff --git a/src/transformers/models/superglue/modeling_superglue.py b/src/transformers/models/superglue/modeling_superglue.py
index 33e50de7aa8..ce92e7b66bb 100644
--- a/src/transformers/models/superglue/modeling_superglue.py
+++ b/src/transformers/models/superglue/modeling_superglue.py
@@ -551,17 +551,18 @@ class SuperGluePreTrainedModel(PreTrainedModel):
def _init_weights(self, module: nn.Module) -> None:
"""Initialize the weights"""
- if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv1d)):
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
- elif isinstance(module, nn.LayerNorm):
+ elif isinstance(module, nn.BatchNorm1d):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
- elif isinstance(module, SuperGlueMultiLayerPerceptron):
- nn.init.constant_(module.linear.bias, 0.0)
+
+ if hasattr(module, "bin_score"):
+ module.bin_score.data.fill_(1.0)
@auto_docstring(
diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py
index 2a97cde3ccf..9dd9d9ce008 100644
--- a/src/transformers/processing_utils.py
+++ b/src/transformers/processing_utils.py
@@ -1097,9 +1097,13 @@ class ProcessorMixin(PushToHubMixin):
processor_config=processor_dict, valid_kwargs=accepted_args_and_kwargs
)
- # remove args that are in processor_dict to avoid duplicate arguments
- args_to_remove = [i for i, arg in enumerate(accepted_args_and_kwargs) if arg in processor_dict]
- args = [arg for i, arg in enumerate(args) if i not in args_to_remove]
+ # update args that are already in processor_dict to avoid duplicate arguments
+ args_to_update = {
+ i: valid_kwargs.pop(arg)
+ for i, arg in enumerate(accepted_args_and_kwargs)
+ if (arg in valid_kwargs and i < len(args))
+ }
+ args = [arg if i not in args_to_update else args_to_update[i] for i, arg in enumerate(args)]
# instantiate processor with used (and valid) kwargs only
processor = cls(*args, **valid_kwargs)
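A toy rendering of the new merge rule (hypothetical helper, not the actual `ProcessorMixin` code): positional arguments whose parameter already appears in the saved processor config are replaced by the saved value instead of being dropped, and the duplicate kwarg is consumed so `cls(*args, **kwargs)` never sees the argument twice.

```python
def merge_positional_args(accepted_args, args, saved_config):
    # Replace positionals that also exist in the saved config with the saved value.
    updates = {
        i: saved_config.pop(name)
        for i, name in enumerate(accepted_args)
        if name in saved_config and i < len(args)
    }
    return [updates.get(i, value) for i, value in enumerate(args)], saved_config

args, kwargs = merge_positional_args(
    accepted_args=["image_processor", "tokenizer", "chat_template"],
    args=["img_proc_obj", "tok_obj", None],
    saved_config={"chat_template": "saved-template"},
)
print(args)    # ['img_proc_obj', 'tok_obj', 'saved-template']
print(kwargs)  # {}
```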
diff --git a/tests/commands/test_chat.py b/tests/commands/test_chat.py
index 6ba3413fafa..e07df4a3938 100644
--- a/tests/commands/test_chat.py
+++ b/tests/commands/test_chat.py
@@ -29,12 +29,34 @@ class ChatCLITest(unittest.TestCase):
self.assertIn("chat interface", cs.out.lower())
@patch.object(ChatCommand, "run")
- def test_cli_dispatch(self, run_mock):
+ def test_cli_dispatch_model(self, run_mock):
+ """
+ Running transformers chat with just a model should work and spawn a serve process underneath
+ """
args = ["transformers", "chat", "hf-internal-testing/tiny-random-gpt2"]
with patch("sys.argv", args):
cli.main()
run_mock.assert_called_once()
+ def test_cli_dispatch_url(self):
+ """
+ Running transformers chat with just a URL should not work, as a model must additionally be specified
+ """
+ args = ["transformers", "chat", "localhost:8000"]
+ with self.assertRaises(ValueError):
+ with patch("sys.argv", args):
+ cli.main()
+
+ @patch.object(ChatCommand, "run")
+ def test_cli_dispatch_url_and_model(self, run_mock):
+ """
+ Running transformers chat with a URL and a model should work
+ """
+ args = ["transformers", "chat", "localhost:8000", "--model_name_or_path=hf-internal-testing/tiny-random-gpt2"]
+ with patch("sys.argv", args):
+ cli.main()
+ run_mock.assert_called_once()
+
def test_parsed_args(self):
with (
patch.object(ChatCommand, "__init__", return_value=None) as init_mock,
diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py
index af95bbb2c32..4cac2f38136 100644
--- a/tests/models/blip_2/test_modeling_blip_2.py
+++ b/tests/models/blip_2/test_modeling_blip_2.py
@@ -1786,7 +1786,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
# Test output
- self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
+ expected_ids = [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118]
+ self.assertEqual(predictions[0].tolist(), [50265] * 32 + expected_ids) # 50265 is the img token id
self.assertEqual("a woman sitting on the beach with a dog", generated_text)
# image and context
@@ -1797,10 +1798,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
# Test output
- self.assertEqual(
- predictions[0].tolist(),
- [2, 45641, 35, 61, 343, 16, 42, 116, 31652, 35, 24, 18, 45, 10, 343, 6, 24, 18, 10, 4105, 50118],
- )
+ expected_ids = [2, 45641, 35, 61, 343, 16, 42, 116, 31652, 35, 24, 18, 45, 10, 343, 6, 24, 18, 10, 4105, 50118]
+ self.assertEqual(predictions[0].tolist(), [50265] * 32 + expected_ids) # 50265 is the img token id
self.assertEqual(generated_text, "Question: which city is this? Answer: it's not a city, it's a beach")
@require_torch_multi_accelerator
@@ -1826,8 +1825,17 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
# Test output
- self.assertEqual(predictions[0].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
- self.assertEqual("woman playing with dog on the beach", generated_text)
+ expected_ids_and_text = Expectations(
+ {
+ ("cuda", None): ([0, 2335, 1556, 28, 1782, 30, 8, 2608, 1], "woman playing with dog on the beach"),
+ ("rocm", (9, 5)): (
+ [0, 3, 9, 2335, 19, 1556, 28, 160, 1782, 30, 8, 2608, 1],
+ "a woman is playing with her dog on the beach",
+ ),
+ }
+ ).get_expectation()
+ self.assertEqual(predictions[0].tolist(), expected_ids_and_text[0])
+ self.assertEqual(generated_text, expected_ids_and_text[1])
# image and context
prompt = "Question: which city is this? Answer:"
@@ -1837,11 +1845,17 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
# Test output
- self.assertEqual(
- predictions[0].tolist(),
- [0, 3, 7, 152, 67, 839, 1],
- )
- self.assertEqual(generated_text, "san diego")
+ expected_ids_and_text = Expectations(
+ {
+ ("cuda", None): ([0, 3, 7, 152, 67, 839, 1], "san diego"),
+ ("rocm", (9, 5)): (
+ [0, 3, 7, 152, 2515, 11389, 3523, 1],
+ "san francisco", # TODO: check if this is ok
+ ),
+ }
+ ).get_expectation()
+ self.assertEqual(predictions[0].tolist(), expected_ids_and_text[0])
+ self.assertEqual(generated_text, expected_ids_and_text[1])
def test_expansion_in_processing(self):
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
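The `Expectations` helper used throughout these test updates selects expected values by hardware. A minimal sketch of the lookup, reusing values from this patch; keys are `(device_type, version)` pairs such as `("cuda", 8)` or `("rocm", (9, 5))`, with `(None, None)` as the generic fallback:

```python
from transformers.testing_utils import Expectations

# get_expectation() returns the entry that best matches the current accelerator,
# falling back to (None, None) when nothing more specific applies.
expected_ids = Expectations(
    {
        (None, None): [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1],
        ("rocm", (9, 5)): [0, 3, 9, 2335, 19, 1556, 28, 160, 1782, 30, 8, 2608, 1],
    }
).get_expectation()
```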
diff --git a/tests/models/dpt/test_modeling_dpt.py b/tests/models/dpt/test_modeling_dpt.py
index 248b40121a5..eb968ad9f68 100644
--- a/tests/models/dpt/test_modeling_dpt.py
+++ b/tests/models/dpt/test_modeling_dpt.py
@@ -18,7 +18,7 @@ import unittest
from transformers import DPTConfig
from transformers.file_utils import is_torch_available, is_vision_available
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
-from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.testing_utils import Expectations, require_torch, require_vision, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
@@ -342,11 +342,15 @@ class DPTModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 384, 384))
self.assertEqual(predicted_depth.shape, expected_shape)
- expected_slice = torch.tensor(
- [[6.3199, 6.3629, 6.4148], [6.3850, 6.3615, 6.4166], [6.3519, 6.3176, 6.3575]]
- ).to(torch_device)
+ expectations = Expectations(
+ {
+ (None, None): [[6.3199, 6.3629, 6.4148], [6.3850, 6.3615, 6.4166], [6.3519, 6.3176, 6.3575]],
+ ("cuda", 8): [[6.3215, 6.3635, 6.4155], [6.3863, 6.3622, 6.4174], [6.3530, 6.3184, 6.3583]],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
- torch.testing.assert_close(outputs.predicted_depth[0, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)
+ torch.testing.assert_close(outputs.predicted_depth[0, :3, :3], expected_slice, rtol=2e-4, atol=2e-4)
def test_inference_semantic_segmentation(self):
image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-large-ade")
diff --git a/tests/models/dpt/test_modeling_dpt_auto_backbone.py b/tests/models/dpt/test_modeling_dpt_auto_backbone.py
index 5ef6c11c375..1505be27cf7 100644
--- a/tests/models/dpt/test_modeling_dpt_auto_backbone.py
+++ b/tests/models/dpt/test_modeling_dpt_auto_backbone.py
@@ -17,7 +17,7 @@ import unittest
from transformers import Dinov2Config, DPTConfig
from transformers.file_utils import is_torch_available, is_vision_available
-from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.testing_utils import Expectations, require_torch, require_vision, slow, torch_device
from transformers.utils.import_utils import get_torch_major_and_minor_version
from ...test_configuration_common import ConfigTester
@@ -267,11 +267,15 @@ class DPTModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 576, 736))
self.assertEqual(predicted_depth.shape, expected_shape)
- expected_slice = torch.tensor(
- [[6.0336, 7.1502, 7.4130], [6.8977, 7.2383, 7.2268], [7.9180, 8.0525, 8.0134]]
- ).to(torch_device)
+ expectations = Expectations(
+ {
+ (None, None): [[6.0336, 7.1502, 7.4130], [6.8977, 7.2383, 7.2268], [7.9180, 8.0525, 8.0134]],
+ ("cuda", 8): [[6.0350, 7.1518, 7.4144], [6.8992, 7.2396, 7.2280], [7.9194, 8.0538, 8.0145]],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
- torch.testing.assert_close(outputs.predicted_depth[0, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)
+ torch.testing.assert_close(outputs.predicted_depth[0, :3, :3], expected_slice, rtol=2e-4, atol=2e-4)
def test_inference_depth_estimation_beit(self):
image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-beit-base-384")
@@ -289,11 +293,23 @@ class DPTModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 384, 384))
self.assertEqual(predicted_depth.shape, expected_shape)
- expected_slice = torch.tensor(
- [[2669.7061, 2663.7144, 2674.9399], [2633.9326, 2650.9092, 2665.4270], [2621.8271, 2632.0129, 2637.2290]]
- ).to(torch_device)
+ expectations = Expectations(
+ {
+ (None, None): [
+ [2669.7061, 2663.7144, 2674.9399],
+ [2633.9326, 2650.9092, 2665.4270],
+ [2621.8271, 2632.0129, 2637.2290],
+ ],
+ ("cuda", 8): [
+ [2669.4292, 2663.4121, 2674.6233],
+ [2633.7400, 2650.7026, 2665.2085],
+ [2621.6572, 2631.8452, 2637.0525],
+ ],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
- torch.testing.assert_close(outputs.predicted_depth[0, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)
+ torch.testing.assert_close(outputs.predicted_depth[0, :3, :3], expected_slice, rtol=2e-4, atol=2e-4)
def test_inference_depth_estimation_swinv2(self):
image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")
@@ -311,8 +327,20 @@ class DPTModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 256, 256))
self.assertEqual(predicted_depth.shape, expected_shape)
- expected_slice = torch.tensor(
- [[1032.7719, 1025.1886, 1030.2661], [1023.7619, 1021.0075, 1024.9121], [1022.5667, 1018.8522, 1021.4145]]
- ).to(torch_device)
+ expectations = Expectations(
+ {
+ (None, None): [
+ [1032.7719, 1025.1886, 1030.2661],
+ [1023.7619, 1021.0075, 1024.9121],
+ [1022.5667, 1018.8522, 1021.4145],
+ ],
+ ("cuda", 8): [
+ [1032.7170, 1025.0629, 1030.1941],
+ [1023.7309, 1020.9786, 1024.8594],
+ [1022.5233, 1018.8235, 1021.3312],
+ ],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
- torch.testing.assert_close(outputs.predicted_depth[0, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)
+ torch.testing.assert_close(outputs.predicted_depth[0, :3, :3], expected_slice, rtol=2e-4, atol=2e-4)
diff --git a/tests/models/dpt/test_modeling_dpt_hybrid.py b/tests/models/dpt/test_modeling_dpt_hybrid.py
index fbdd88278ea..79cad886db4 100644
--- a/tests/models/dpt/test_modeling_dpt_hybrid.py
+++ b/tests/models/dpt/test_modeling_dpt_hybrid.py
@@ -194,6 +194,9 @@ class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_config(self):
self.config_tester.run_common_tests()
+ def test_batching_equivalence(self, atol=2e-5, rtol=2e-5):
+ super().test_batching_equivalence(atol=atol, rtol=rtol)
+
@unittest.skip(reason="DPT does not use inputs_embeds")
def test_inputs_embeds(self):
pass
diff --git a/tests/models/encodec/test_modeling_encodec.py b/tests/models/encodec/test_modeling_encodec.py
index 21e9ac10405..a429561b715 100644
--- a/tests/models/encodec/test_modeling_encodec.py
+++ b/tests/models/encodec/test_modeling_encodec.py
@@ -310,12 +310,13 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
def test_feed_forward_chunking(self):
(original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
torch.manual_seed(0)
config = copy.deepcopy(original_config)
config.chunk_length_s = None
config.overlap = None
- config.sampling_rate = 10
+ config.sampling_rate = 20
model = model_class(config)
model.to(torch_device)
@@ -326,9 +327,9 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
hidden_states_no_chunk = model(**inputs)[1]
torch.manual_seed(0)
- config.chunk_length_s = 1
+ config.chunk_length_s = 2
config.overlap = 0
- config.sampling_rate = 10
+ config.sampling_rate = 20
model = model_class(config)
model.to(torch_device)
diff --git a/tests/models/eomt/test_image_processing_eomt.py b/tests/models/eomt/test_image_processing_eomt.py
index 6d449453de6..594a1d9fe86 100644
--- a/tests/models/eomt/test_image_processing_eomt.py
+++ b/tests/models/eomt/test_image_processing_eomt.py
@@ -84,10 +84,11 @@ class EomtImageProcessingTester:
"num_labels": self.num_labels,
}
- def prepare_fake_eomt_outputs(self, batch_size):
+ def prepare_fake_eomt_outputs(self, batch_size, patch_offsets=None):
return EomtForUniversalSegmentationOutput(
masks_queries_logits=torch.randn((batch_size, self.num_queries, self.height, self.width)),
class_queries_logits=torch.randn((batch_size, self.num_queries, self.num_classes + 1)),
+ patch_offsets=patch_offsets,
)
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
@@ -263,13 +264,13 @@ class EomtImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = processor(images=image, do_split_image=True, return_tensors="pt")
- patch_offsets = inputs.pop("patch_offsets")
+ patch_offsets = inputs["patch_offsets"]
- original_sizes = [image.size[::-1]]
+ target_sizes = [image.size[::-1]]
# For semantic segmentation, the output batch size is 2 because two patches are created for the image.
- outputs = self.image_processor_tester.prepare_fake_eomt_outputs(inputs["pixel_values"].shape[0])
- segmentation = processor.post_process_semantic_segmentation(outputs, patch_offsets, original_sizes)
+ outputs = self.image_processor_tester.prepare_fake_eomt_outputs(inputs["pixel_values"].shape[0], patch_offsets)
+ segmentation = processor.post_process_semantic_segmentation(outputs, target_sizes)
self.assertEqual(segmentation[0].shape, (image.height, image.width))
diff --git a/tests/models/eomt/test_modeling_eomt.py b/tests/models/eomt/test_modeling_eomt.py
index c5260302506..c4b026cc18e 100644
--- a/tests/models/eomt/test_modeling_eomt.py
+++ b/tests/models/eomt/test_modeling_eomt.py
@@ -17,12 +17,13 @@ import unittest
import requests
-from transformers import AutoImageProcessor, EomtConfig, EomtForUniversalSegmentation
+from transformers import AutoImageProcessor, EomtConfig, EomtForUniversalSegmentation, pipeline
from transformers.testing_utils import require_torch, require_torch_accelerator, require_torch_fp16, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
@@ -100,8 +101,9 @@ class EomtForUniversalSegmentationTester:
@require_torch
-class EomtForUniversalSegmentationTest(ModelTesterMixin, unittest.TestCase):
+class EomtForUniversalSegmentationTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (EomtForUniversalSegmentation,) if is_torch_available() else ()
+ pipeline_model_mapping = {"image-segmentation": EomtForUniversalSegmentation} if is_torch_available() else {}
is_encoder_decoder = False
test_pruning = False
test_head_masking = False
@@ -340,7 +342,6 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = processor(images=image, return_tensors="pt").to(model.device)
- patch_offsets = inputs.pop("patch_offsets", None)
with torch.inference_mode():
outputs = model(**inputs)
@@ -348,11 +349,9 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
self.assertTrue(outputs.class_queries_logits.shape == (2, 100, 151))
self.assertTrue(outputs.masks_queries_logits.shape == (2, 100, 128, 128))
- preds = processor.post_process_semantic_segmentation(
- outputs, original_image_sizes=[(image.size[1], image.size[0])], patch_offsets=patch_offsets
- )
+ preds = processor.post_process_semantic_segmentation(outputs, target_sizes=[(image.size[1], image.size[0])])[0]
- self.assertTrue(preds.shape[1:] == (image.size[1], image.size[0]))
+ self.assertTrue(preds.shape == (image.size[1], image.size[0]))
# fmt: off
EXPECTED_SLICE = torch.tensor([
@@ -369,7 +368,7 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
], device=model.device)
# fmt: on
- output_slice = preds[0, :10, :10]
+ output_slice = preds[:10, :10]
torch.testing.assert_close(output_slice, EXPECTED_SLICE, rtol=1e-2, atol=1e-2)
@slow
@@ -387,9 +386,7 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
self.assertTrue(outputs.class_queries_logits.shape == (1, 200, 134))
self.assertTrue(outputs.masks_queries_logits.shape == (1, 200, 160, 160))
- preds = processor.post_process_panoptic_segmentation(
- outputs, original_image_sizes=[(image.size[1], image.size[0])]
- )[0]
+ preds = processor.post_process_panoptic_segmentation(outputs, target_sizes=[(image.size[1], image.size[0])])[0]
segmentation, segments_info = preds["segmentation"], preds["segments_info"]
# fmt: off
@@ -438,9 +435,7 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
self.assertTrue(outputs.class_queries_logits.shape == (1, 200, 81))
self.assertTrue(outputs.masks_queries_logits.shape == (1, 200, 160, 160))
- preds = processor.post_process_instance_segmentation(
- outputs, original_image_sizes=[(image.size[1], image.size[0])]
- )[0]
+ preds = processor.post_process_instance_segmentation(outputs, target_sizes=[(image.size[1], image.size[0])])[0]
segmentation, segments_info = preds["segmentation"], preds["segments_info"]
# fmt: off
@@ -473,3 +468,15 @@ class EomtForUniversalSegmentationIntegrationTest(unittest.TestCase):
self.assertEqual(actual["id"], expected["id"])
self.assertEqual(actual["label_id"], expected["label_id"])
self.assertAlmostEqual(actual["score"], expected["score"], delta=1e-3)
+
+ @slow
+ def test_segmentation_pipeline(self):
+ image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
+
+ pipe = pipeline(model=self.model_id, subtask="panoptic", device=torch_device)
+ output = pipe(image)
+
+ EXPECTED_OUTPUT_LABELS = ["cat", "cat", "couch", "remote", "remote"]
+
+ output_labels = [segment["label"] for segment in output]
+ self.assertEqual(output_labels, EXPECTED_OUTPUT_LABELS)
diff --git a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py
index e59787fb8c6..cada419ea03 100644
--- a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py
+++ b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py
@@ -33,7 +33,7 @@ from transformers.testing_utils import (
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, ids_tensor
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -359,9 +359,11 @@ class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
def test_initialization(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ config.rescale_prenorm_residual = True
+ configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
- model = model_class(config=config)
+ model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
if "dt_proj.bias" in name:
dt = torch.exp(
@@ -380,6 +382,19 @@ class FalconMambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
if param.requires_grad:
# check if it's a ones like
torch.testing.assert_close(param.data, torch.ones_like(param.data), rtol=1e-5, atol=1e-5)
+ else:
+ if param.requires_grad:
+ if (
+ "mixer.conv1d.weight" in name
+ or "mixer.dt_proj.weight" in name
+ or "mixer.out_proj.weight" in name
+ ):
+ continue
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
@slow
# Ignore copy
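The membership check added to `test_initialization` relies on a small rounding trick so that float noise does not break the exact comparison against 0.0 or 1.0; a standalone illustration:

```python
import torch

# With _config_zero_init the config's std is scaled to ~0, so config-initialized
# weights have a near-zero mean; the round-trip below snaps noise like 3e-10
# back to exactly 0.0 before the `assertIn(..., [0.0, 1.0])` check.
param_mean = torch.tensor(3e-10)
snapped = ((param_mean * 1e9).round() / 1e9).item()
assert snapped in [0.0, 1.0]
```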
diff --git a/tests/models/glm4v/test_modeling_glm4v.py b/tests/models/glm4v/test_modeling_glm4v.py
index 48d4a9b858e..a9901ded239 100644
--- a/tests/models/glm4v/test_modeling_glm4v.py
+++ b/tests/models/glm4v/test_modeling_glm4v.py
@@ -69,16 +69,15 @@ class Glm4vVisionText2TextModelTester:
is_training=True,
text_config={
"vocab_size": 99,
- "hidden_size": 32,
- "intermediate_size": 37,
- "num_hidden_layers": 4,
- "num_attention_heads": 4,
- "num_key_value_heads": 2,
+ "hidden_size": 16,
+ "intermediate_size": 22,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 1,
"output_channels": 64,
"hidden_act": "silu",
"max_position_embeddings": 512,
"rope_scaling": {"type": "default", "mrope_section": [2, 1, 1]},
- "max_window_layers": 3,
"rope_theta": 10000,
"tie_word_embeddings": True,
"bos_token_id": 0,
@@ -87,11 +86,10 @@ class Glm4vVisionText2TextModelTester:
},
vision_config={
"depth": 2,
- "embed_dim": 32,
"hidden_act": "silu",
- "hidden_size": 32,
- "mlp_ratio": 4,
- "num_heads": 4,
+ "hidden_size": 48,
+ "out_hidden_size": 16,
+ "intermediate_size": 22,
"patch_size": 14,
"spatial_merge_size": 1,
"temporal_patch_size": 2,
@@ -239,10 +237,6 @@ class Glm4vModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
def test_multi_gpu_data_parallel_forward(self):
pass
- @unittest.skip(reason="We cannot configure to output a smaller model.")
- def test_model_is_small(self):
- pass
-
@unittest.skip("Error with compilation")
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
diff --git a/tests/models/grounding_dino/test_modeling_grounding_dino.py b/tests/models/grounding_dino/test_modeling_grounding_dino.py
index 2afe3f0ef38..953255797b5 100644
--- a/tests/models/grounding_dino/test_modeling_grounding_dino.py
+++ b/tests/models/grounding_dino/test_modeling_grounding_dino.py
@@ -586,6 +586,8 @@ class GroundingDinoModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Tes
or "value_proj" in name
or "output_proj" in name
or "reference_points" in name
+ or "vision_proj" in name
+ or "text_proj" in name
):
continue
self.assertIn(
@@ -679,25 +681,48 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase):
expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.d_model))
self.assertEqual(outputs.logits.shape, expected_shape_logits)
- expected_boxes = torch.tensor(
- [[0.7674, 0.4136, 0.4572], [0.2566, 0.5463, 0.4760], [0.2585, 0.5442, 0.4641]]
- ).to(torch_device)
- expected_logits = torch.tensor(
- [[-4.8913, -0.1900, -0.2161], [-4.9653, -0.3719, -0.3950], [-5.9599, -3.3765, -3.3104]]
- ).to(torch_device)
+ expectations = Expectations(
+ {
+ (None, None): [[0.7674, 0.4136, 0.4572], [0.2566, 0.5463, 0.4760], [0.2585, 0.5442, 0.4641]],
+ ("cuda", 8): [[0.7674, 0.4135, 0.4571], [0.2566, 0.5463, 0.4760], [0.2585, 0.5442, 0.4640]],
+ }
+ )
+ expected_boxes = torch.tensor(expectations.get_expectation()).to(torch_device)
+
+ expectations = Expectations(
+ {
+ (None, None): [[-4.8913, -0.1900, -0.2161], [-4.9653, -0.3719, -0.3950], [-5.9599, -3.3765, -3.3104]],
+ ("cuda", 8): [[-4.8927, -0.1910, -0.2169], [-4.9657, -0.3748, -0.3980], [-5.9579, -3.3812, -3.3153]],
+ }
+ )
+ expected_logits = torch.tensor(expectations.get_expectation()).to(torch_device)
torch.testing.assert_close(outputs.logits[0, :3, :3], expected_logits, rtol=1e-3, atol=1e-3)
expected_shape_boxes = torch.Size((1, model.config.num_queries, 4))
self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes)
- torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_boxes, rtol=1e-4, atol=1e-4)
+ torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_boxes, rtol=2e-4, atol=2e-4)
# verify postprocessing
results = processor.image_processor.post_process_object_detection(
outputs, threshold=0.35, target_sizes=[(image.height, image.width)]
)[0]
- expected_scores = torch.tensor([0.4526, 0.4082]).to(torch_device)
- expected_slice_boxes = torch.tensor([344.8143, 23.1796, 637.4004, 373.8295]).to(torch_device)
+
+ expectations = Expectations(
+ {
+ (None, None): [0.4526, 0.4082],
+ ("cuda", 8): [0.4524, 0.4074],
+ }
+ )
+ expected_scores = torch.tensor(expectations.get_expectation()).to(torch_device)
+
+ expectations = Expectations(
+ {
+ (None, None): [344.8143, 23.1796, 637.4004, 373.8295],
+ ("cuda", 8): [344.8210, 23.1831, 637.3943, 373.8227],
+ }
+ )
+ expected_slice_boxes = torch.tensor(expectations.get_expectation()).to(torch_device)
self.assertEqual(len(results["scores"]), 2)
torch.testing.assert_close(results["scores"], expected_scores, rtol=1e-3, atol=1e-3)
diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py
index 840493648ff..b570d1a130b 100644
--- a/tests/models/mamba/test_modeling_mamba.py
+++ b/tests/models/mamba/test_modeling_mamba.py
@@ -24,7 +24,7 @@ from transformers.testing_utils import require_torch, require_torch_multi_gpu, s
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, ids_tensor
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -326,9 +326,11 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
def test_initialization(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ config.rescale_prenorm_residual = True
+ configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
- model = model_class(config=config)
+ model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
if "dt_proj.bias" in name:
dt = torch.exp(
@@ -347,6 +349,19 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
if param.requires_grad:
# check if it's a ones like
torch.testing.assert_close(param.data, torch.ones_like(param.data), rtol=1e-5, atol=1e-5)
+ else:
+ if param.requires_grad:
+ if (
+ "mixer.conv1d.weight" in name
+ or "mixer.dt_proj.weight" in name
+ or "mixer.out_proj.weight" in name
+ ):
+ continue
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
@slow
def test_model_from_pretrained(self):
diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py
index dfa8bca69ef..c9cec231e64 100644
--- a/tests/models/mamba2/test_modeling_mamba2.py
+++ b/tests/models/mamba2/test_modeling_mamba2.py
@@ -13,6 +13,7 @@
# limitations under the License.
+import math
import unittest
from transformers import AutoTokenizer, Mamba2Config, is_torch_available
@@ -28,7 +29,7 @@ from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, ids_tensor
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -276,14 +277,37 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
def test_initialization(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ config.rescale_prenorm_residual = True
+ configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
- model = model_class(config=config)
+ model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
- if "D" in name:
+ if "dt_proj.bias" in name:
+ dt = torch.exp(
+ torch.tensor([0, 1]) * (math.log(config.time_step_max) - math.log(config.time_step_min))
+ + math.log(config.time_step_min)
+ ).clamp(min=config.time_step_floor)
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
+ if param.requires_grad:
+ self.assertTrue(param.data.max().item() <= inv_dt[1])
+ self.assertTrue(param.data.min().item() >= inv_dt[0])
+ elif "A_log" in name:
+ A = torch.arange(1, config.num_heads + 1)
+ torch.testing.assert_close(param.data, torch.log(A), rtol=1e-5, atol=1e-5)
+ elif "D" in name:
if param.requires_grad:
# check if it's a ones like
torch.testing.assert_close(param.data, torch.ones_like(param.data), rtol=1e-5, atol=1e-5)
+ else:
+ if param.requires_grad:
+ if "mixer.conv1d.weight" in name or "mixer.dt_bias" in name or "mixer.out_proj.weight" in name:
+ continue
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
@unittest.skip(reason="Mamba 2 weights are not tied")
def test_tied_weights_keys(self):
diff --git a/tests/models/mask2former/test_modeling_mask2former.py b/tests/models/mask2former/test_modeling_mask2former.py
index 5762a1f6ffc..cf6521424bb 100644
--- a/tests/models/mask2former/test_modeling_mask2former.py
+++ b/tests/models/mask2former/test_modeling_mask2former.py
@@ -21,6 +21,7 @@ from tests.test_modeling_common import floats_tensor
from transformers import AutoModelForImageClassification, Mask2FormerConfig, is_torch_available, is_vision_available
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
from transformers.testing_utils import (
+ Expectations,
require_timm,
require_torch,
require_torch_accelerator,
@@ -403,7 +404,7 @@ class Mask2FormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
)
-TOLERANCE = 1e-4
+TOLERANCE = 2e-4
# We will verify our results on an image of cute cats
@@ -438,31 +439,52 @@ class Mask2FormerModelIntegrationTest(unittest.TestCase):
outputs = model(**inputs)
expected_slice_hidden_state = torch.tensor(
- [[-0.2790, -1.0717, -1.1668], [-0.5128, -0.3128, -0.4987], [-0.5832, 0.1971, -0.0197]]
+ [
+ [-0.2790, -1.0717, -1.1668],
+ [-0.5128, -0.3128, -0.4987],
+ [-0.5832, 0.1971, -0.0197],
+ ]
).to(torch_device)
- self.assertTrue(
- torch.allclose(
- outputs.encoder_last_hidden_state[0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE
- )
+ torch.testing.assert_close(
+ outputs.encoder_last_hidden_state[0, 0, :3, :3],
+ expected_slice_hidden_state,
+ atol=TOLERANCE,
+ rtol=TOLERANCE,
)
- expected_slice_hidden_state = torch.tensor(
- [[0.8973, 1.1847, 1.1776], [1.1934, 1.5040, 1.5128], [1.1153, 1.4486, 1.4951]]
- ).to(torch_device)
- self.assertTrue(
- torch.allclose(
- outputs.pixel_decoder_last_hidden_state[0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE
- )
+ expectations = Expectations(
+ {
+ (None, None): [
+ [0.8973, 1.1847, 1.1776],
+ [1.1934, 1.5040, 1.5128],
+ [1.1153, 1.4486, 1.4951],
+ ],
+ ("cuda", 8): [
+ [0.8974, 1.1848, 1.1777],
+ [1.1933, 1.5041, 1.5128],
+ [1.1154, 1.4487, 1.4950],
+ ],
+ }
)
+ expected_slice_hidden_state = torch.tensor(expectations.get_expectation()).to(torch_device)
+ torch.testing.assert_close(outputs.pixel_decoder_last_hidden_state[0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE, rtol=TOLERANCE) # fmt: skip
- expected_slice_hidden_state = torch.tensor(
- [[2.1152, 1.7000, -0.8603], [1.5808, 1.8004, -0.9353], [1.6043, 1.7495, -0.5999]]
- ).to(torch_device)
- self.assertTrue(
- torch.allclose(
- outputs.transformer_decoder_last_hidden_state[0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE
- )
+ expectations = Expectations(
+ {
+ (None, None): [
+ [2.1152, 1.7000, -0.8603],
+ [1.5808, 1.8004, -0.9353],
+ [1.6043, 1.7495, -0.5999],
+ ],
+ ("cuda", 8): [
+ [2.1153, 1.7004, -0.8604],
+ [1.5807, 1.8007, -0.9354],
+ [1.6040, 1.7498, -0.6001],
+ ],
+ }
)
+ expected_slice_hidden_state = torch.tensor(expectations.get_expectation()).to(torch_device)
+ torch.testing.assert_close(outputs.transformer_decoder_last_hidden_state[0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE, rtol=TOLERANCE) # fmt: skip
def test_inference_universal_segmentation_head(self):
model = Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
@@ -482,23 +504,40 @@ class Mask2FormerModelIntegrationTest(unittest.TestCase):
self.assertEqual(
masks_queries_logits.shape, (1, model.config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4)
)
- expected_slice = [
- [-8.7839, -9.0056, -8.8121],
- [-7.4104, -7.0313, -6.5401],
- [-6.6105, -6.3427, -6.4675],
- ]
- expected_slice = torch.tensor(expected_slice).to(torch_device)
+ expectations = Expectations(
+ {
+ (None, None): [
+ [-8.7839, -9.0056, -8.8121],
+ [-7.4104, -7.0313, -6.5401],
+ [-6.6105, -6.3427, -6.4675],
+ ],
+ ("cuda", 8): [
+ [-8.7809, -9.0041, -8.8087],
+ [-7.4075, -7.0307, -6.5385],
+ [-6.6088, -6.3417, -6.4627],
+ ],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
torch.testing.assert_close(masks_queries_logits[0, 0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
# class_queries_logits
class_queries_logits = outputs.class_queries_logits
self.assertEqual(class_queries_logits.shape, (1, model.config.num_queries, model.config.num_labels + 1))
- expected_slice = torch.tensor(
- [
- [1.8324, -8.0835, -4.1922],
- [0.8450, -9.0050, -3.6053],
- [0.3045, -7.7293, -3.0275],
- ]
- ).to(torch_device)
+ expectations = Expectations(
+ {
+ (None, None): [
+ [1.8324, -8.0835, -4.1922],
+ [0.8450, -9.0050, -3.6053],
+ [0.3045, -7.7293, -3.0275],
+ ],
+ ("cuda", 8): [
+ [1.8326, -8.0834, -4.1916],
+ [0.8446, -9.0048, -3.6048],
+ [0.3042, -7.7296, -3.0277],
+ ],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
torch.testing.assert_close(
outputs.class_queries_logits[0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE
)
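The Expectations helper introduced in these hunks maps a (device_type, major_version) key to device-specific reference values, with (None, None) serving as the generic fallback and get_expectation() picking the best match at runtime. A simplified stand-in for the lookup idea -- not the actual transformers.testing_utils implementation, just a sketch:

import torch

def pick_expectation(reference: dict, device_type: str = "cuda"):
    """Return the most specific reference values available on this machine.

    Tries (device_type, compute-capability major) first, then falls back to the
    generic (None, None) entry, mimicking Expectations.get_expectation().
    """
    major = None
    if device_type == "cuda" and torch.cuda.is_available():
        major = torch.cuda.get_device_capability()[0]
    return reference.get((device_type, major), reference[(None, None)])

# Placeholder numbers in the same shape as the hunks above.
reference = {
    (None, None): [[0.1000, 0.2000, 0.3000]],
    ("cuda", 8): [[0.1001, 0.2001, 0.3001]],  # e.g. capability-8 (A100-class) GPUs
}
expected = torch.tensor(pick_expectation(reference))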
diff --git a/tests/models/maskformer/test_modeling_maskformer.py b/tests/models/maskformer/test_modeling_maskformer.py
index 2f30d4dc3c6..8644439f4a8 100644
--- a/tests/models/maskformer/test_modeling_maskformer.py
+++ b/tests/models/maskformer/test_modeling_maskformer.py
@@ -21,6 +21,7 @@ import numpy as np
from tests.test_modeling_common import floats_tensor
from transformers import DetrConfig, MaskFormerConfig, SwinConfig, is_torch_available, is_vision_available
from transformers.testing_utils import (
+ Expectations,
require_timm,
require_torch,
require_torch_accelerator,
@@ -478,7 +479,7 @@ class MaskFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
self.assertEqual(model.model.pixel_level_module.encoder.out_indices, [1, 2, 3])
-TOLERANCE = 1e-4
+TOLERANCE = 2e-4
# We will verify our results on an image of cute cats
@@ -513,31 +514,43 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
outputs = model(**inputs)
expected_slice_hidden_state = torch.tensor(
- [[-0.0482, 0.9228, 0.4951], [-0.2547, 0.8017, 0.8527], [-0.0069, 0.3385, -0.0089]]
+ [
+ [-0.0482, 0.9228, 0.4951],
+ [-0.2547, 0.8017, 0.8527],
+ [-0.0069, 0.3385, -0.0089],
+ ]
).to(torch_device)
- self.assertTrue(
- torch.allclose(
- outputs.encoder_last_hidden_state[0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE
- )
- )
+ torch.testing.assert_close(outputs.encoder_last_hidden_state[0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE, rtol=TOLERANCE) # fmt: skip
- expected_slice_hidden_state = torch.tensor(
- [[-0.8422, -0.8434, -0.9718], [-1.0144, -0.5565, -0.4195], [-1.0038, -0.4484, -0.1961]]
- ).to(torch_device)
- self.assertTrue(
- torch.allclose(
- outputs.pixel_decoder_last_hidden_state[0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE
- )
+ expectations = Expectations(
+ {
+ (None, None): [[-0.8422, -0.8434, -0.9718], [-1.0144, -0.5565, -0.4195], [-1.0038, -0.4484, -0.1961]],
+ ("cuda", 8): [
+ [-0.8422, -0.8435, -0.9717],
+ [-1.0145, -0.5564, -0.4195],
+ [-1.0040, -0.4486, -0.1962],
+ ],
+ }
)
+ expected_slice_hidden_state = torch.tensor(expectations.get_expectation()).to(torch_device)
+ torch.testing.assert_close(outputs.pixel_decoder_last_hidden_state[0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE, rtol=TOLERANCE) # fmt: skip
- expected_slice_hidden_state = torch.tensor(
- [[0.2852, -0.0159, 0.9735], [0.6254, 0.1858, 0.8529], [-0.0680, -0.4116, 1.8413]]
- ).to(torch_device)
- self.assertTrue(
- torch.allclose(
- outputs.transformer_decoder_last_hidden_state[0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE
- )
+ expectations = Expectations(
+ {
+ (None, None): [
+ [0.2852, -0.0159, 0.9735],
+ [0.6254, 0.1858, 0.8529],
+ [-0.0680, -0.4116, 1.8413],
+ ],
+ ("cuda", 8): [
+ [0.2853, -0.0162, 0.9736],
+ [0.6256, 0.1856, 0.8530],
+ [-0.0679, -0.4118, 1.8416],
+ ],
+ }
)
+ expected_slice_hidden_state = torch.tensor(expectations.get_expectation()).to(torch_device)
+ torch.testing.assert_close(outputs.transformer_decoder_last_hidden_state[0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE, rtol=TOLERANCE) # fmt: skip
def test_inference_instance_segmentation_head(self):
model = (
@@ -562,25 +575,42 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
masks_queries_logits.shape,
(1, model.config.decoder_config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4),
)
- expected_slice = [
- [-1.3737124, -1.7724937, -1.9364233],
- [-1.5977281, -1.9867939, -2.1523695],
- [-1.5795398, -1.9269832, -2.093942],
- ]
- expected_slice = torch.tensor(expected_slice).to(torch_device)
+ expectations = Expectations(
+ {
+ (None, None): [
+ [-1.3737124, -1.7724937, -1.9364233],
+ [-1.5977281, -1.9867939, -2.1523695],
+ [-1.5795398, -1.9269832, -2.093942],
+ ],
+ ("cuda", 8): [
+ [-1.3737, -1.7727, -1.9367],
+ [-1.5979, -1.9871, -2.1527],
+ [-1.5797, -1.9271, -2.0941],
+ ],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
torch.testing.assert_close(masks_queries_logits[0, 0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
# class_queries_logits
class_queries_logits = outputs.class_queries_logits
self.assertEqual(
class_queries_logits.shape, (1, model.config.decoder_config.num_queries, model.config.num_labels + 1)
)
- expected_slice = torch.tensor(
- [
- [1.6512e00, -5.2572e00, -3.3519e00],
- [3.6169e-02, -5.9025e00, -2.9313e00],
- [1.0766e-04, -7.7630e00, -5.1263e00],
- ]
- ).to(torch_device)
+ expectations = Expectations(
+ {
+ (None, None): [
+ [1.6512e00, -5.2572e00, -3.3519e00],
+ [3.6169e-02, -5.9025e00, -2.9313e00],
+ [1.0766e-04, -7.7630e00, -5.1263e00],
+ ],
+ ("cuda", 8): [
+ [1.6507e00, -5.2568e00, -3.3520e00],
+ [3.5767e-02, -5.9023e00, -2.9313e00],
+ [-6.2712e-04, -7.7627e00, -5.1268e00],
+ ],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
torch.testing.assert_close(
outputs.class_queries_logits[0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE
)
@@ -608,17 +638,34 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
masks_queries_logits.shape,
(1, model.config.decoder_config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4),
)
- expected_slice = [[-0.9046, -2.6366, -4.6062], [-3.4179, -5.7890, -8.8057], [-4.9179, -7.6560, -10.7711]]
- expected_slice = torch.tensor(expected_slice).to(torch_device)
+ expectations = Expectations(
+ {
+ (None, None): [[-0.9046, -2.6366, -4.6062], [-3.4179, -5.7890, -8.8057], [-4.9179, -7.6560, -10.7711]],
+ ("cuda", 8): [[-0.9000, -2.6283, -4.5964], [-3.4123, -5.7789, -8.7919], [-4.9132, -7.6444, -10.7557]],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
torch.testing.assert_close(masks_queries_logits[0, 0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
# class_queries_logits
class_queries_logits = outputs.class_queries_logits
self.assertEqual(
class_queries_logits.shape, (1, model.config.decoder_config.num_queries, model.config.num_labels + 1)
)
- expected_slice = torch.tensor(
- [[4.7188, -3.2585, -2.8857], [6.6871, -2.9181, -1.2487], [7.2449, -2.2764, -2.1874]]
- ).to(torch_device)
+ expectations = Expectations(
+ {
+ (None, None): [
+ [4.7188, -3.2585, -2.8857],
+ [6.6871, -2.9181, -1.2487],
+ [7.2449, -2.2764, -2.1874],
+ ],
+ ("cuda", 8): [
+ [4.7177, -3.2586, -2.8853],
+ [6.6845, -2.9186, -1.2491],
+ [7.2443, -2.2760, -2.1858],
+ ],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
torch.testing.assert_close(
outputs.class_queries_logits[0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE
)
diff --git a/tests/models/mobilenet_v1/test_modeling_mobilenet_v1.py b/tests/models/mobilenet_v1/test_modeling_mobilenet_v1.py
index 7c05d4b41c9..688542c727a 100644
--- a/tests/models/mobilenet_v1/test_modeling_mobilenet_v1.py
+++ b/tests/models/mobilenet_v1/test_modeling_mobilenet_v1.py
@@ -16,7 +16,7 @@
import unittest
from transformers import MobileNetV1Config
-from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.testing_utils import Expectations, require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@@ -246,6 +246,12 @@ class MobileNetV1ModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 1001))
self.assertEqual(outputs.logits.shape, expected_shape)
- expected_slice = torch.tensor([-4.1739, -1.1233, 3.1205]).to(torch_device)
+ expectations = Expectations(
+ {
+ (None, None): [-4.1739, -1.1233, 3.1205],
+ ("cuda", 8): [-4.1725, -1.1238, 3.1191],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
- torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=1e-4, atol=1e-4)
+ torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4)
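For context on the widened tolerances here and in the following hunks: torch.testing.assert_close accepts a value when |actual - expected| <= atol + rtol * |expected|. A quick back-of-the-envelope check with the MobileNetV1 logits above shows why device-specific expectations are needed rather than simply loosening the tolerance further:

import torch

expected = torch.tensor([-4.1739, -1.1233, 3.1205])  # generic (None, None) reference
actual = torch.tensor([-4.1725, -1.1238, 3.1191])    # capability-8 CUDA reference

atol = rtol = 2e-4
budget = atol + rtol * expected.abs()  # at most ~1.0e-3 for these magnitudes
gap = (actual - expected).abs()        # up to ~1.4e-3 between the two devices

# The cross-device gap exceeds the 2e-4 budget, hence the per-device Expectations.
assert not torch.allclose(actual, expected, atol=atol, rtol=rtol)
print(gap, budget)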
diff --git a/tests/models/mobilenet_v2/test_modeling_mobilenet_v2.py b/tests/models/mobilenet_v2/test_modeling_mobilenet_v2.py
index 13c7698af5f..5f3807cda82 100644
--- a/tests/models/mobilenet_v2/test_modeling_mobilenet_v2.py
+++ b/tests/models/mobilenet_v2/test_modeling_mobilenet_v2.py
@@ -16,7 +16,7 @@
import unittest
from transformers import MobileNetV2Config
-from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.testing_utils import Expectations, require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@@ -301,9 +301,15 @@ class MobileNetV2ModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 1001))
self.assertEqual(outputs.logits.shape, expected_shape)
- expected_slice = torch.tensor([0.2445, -1.1993, 0.1905]).to(torch_device)
+ expectations = Expectations(
+ {
+ (None, None): [0.2445, -1.1993, 0.1905],
+ ("cuda", 8): [0.2445, -1.1970, 0.1868],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
- torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=1e-4, atol=1e-4)
+ torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4)
@slow
def test_inference_semantic_segmentation(self):
@@ -324,13 +330,20 @@ class MobileNetV2ModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 21, 65, 65))
self.assertEqual(logits.shape, expected_shape)
- expected_slice = torch.tensor(
- [
- [[17.5790, 17.7581, 18.3355], [18.3257, 18.4230, 18.8973], [18.6169, 18.8650, 19.2187]],
- [[-2.1595, -2.0977, -2.3741], [-2.4226, -2.3028, -2.6835], [-2.7819, -2.5991, -2.7706]],
- [[4.2058, 4.8317, 4.7638], [4.4136, 5.0361, 4.9383], [4.5028, 4.9644, 4.8734]],
- ],
- device=torch_device,
+ expectations = Expectations(
+ {
+ (None, None): [
+ [[17.5790, 17.7581, 18.3355], [18.3257, 18.4230, 18.8973], [18.6169, 18.8650, 19.2187]],
+ [[-2.1595, -2.0977, -2.3741], [-2.4226, -2.3028, -2.6835], [-2.7819, -2.5991, -2.7706]],
+ [[4.2058, 4.8317, 4.7638], [4.4136, 5.0361, 4.9383], [4.5028, 4.9644, 4.8734]],
+ ],
+ ("cuda", 8): [
+ [[17.5809, 17.7571, 18.3341], [18.3240, 18.4216, 18.8974], [18.6174, 18.8662, 19.2177]],
+ [[-2.1562, -2.0942, -2.3703], [-2.4199, -2.2999, -2.6818], [-2.7800, -2.5944, -2.7678]],
+ [[4.2092, 4.8356, 4.7694], [4.4181, 5.0401, 4.9409], [4.5089, 4.9700, 4.8802]],
+ ],
+ }
)
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
- torch.testing.assert_close(logits[0, :3, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)
+ torch.testing.assert_close(logits[0, :3, :3, :3], expected_slice, rtol=2e-4, atol=2e-4)
diff --git a/tests/models/mobilevit/test_modeling_mobilevit.py b/tests/models/mobilevit/test_modeling_mobilevit.py
index f6cc09edddd..43fb0d638eb 100644
--- a/tests/models/mobilevit/test_modeling_mobilevit.py
+++ b/tests/models/mobilevit/test_modeling_mobilevit.py
@@ -16,7 +16,7 @@
import unittest
from transformers import MobileViTConfig
-from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.testing_utils import Expectations, require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@@ -304,9 +304,15 @@ class MobileViTModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
- expected_slice = torch.tensor([-1.9364, -1.2327, -0.4653]).to(torch_device)
+ expectations = Expectations(
+ {
+ (None, None): [-1.9364, -1.2327, -0.4653],
+ ("cuda", 8): [-1.9401, -1.2384, -0.4702],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
- torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=1e-4, atol=1e-4)
+ torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4)
@slow
def test_inference_semantic_segmentation(self):
@@ -327,16 +333,23 @@ class MobileViTModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 21, 32, 32))
self.assertEqual(logits.shape, expected_shape)
- expected_slice = torch.tensor(
- [
- [[6.9713, 6.9786, 7.2422], [7.2893, 7.2825, 7.4446], [7.6580, 7.8797, 7.9420]],
- [[-10.6869, -10.3250, -10.3471], [-10.4228, -9.9868, -9.7132], [-11.0405, -11.0221, -10.7318]],
- [[-3.3089, -2.8539, -2.6740], [-3.2706, -2.5621, -2.5108], [-3.2534, -2.6615, -2.6651]],
- ],
- device=torch_device,
+ expectations = Expectations(
+ {
+ (None, None): [
+ [[6.9713, 6.9786, 7.2422], [7.2893, 7.2825, 7.4446], [7.6580, 7.8797, 7.9420]],
+ [[-10.6869, -10.3250, -10.3471], [-10.4228, -9.9868, -9.7132], [-11.0405, -11.0221, -10.7318]],
+ [[-3.3089, -2.8539, -2.6740], [-3.2706, -2.5621, -2.5108], [-3.2534, -2.6615, -2.6651]],
+ ],
+ ("cuda", 8): [
+ [[6.9661, 6.9753, 7.2386], [7.2864, 7.2785, 7.4429], [7.6577, 7.8770, 7.9387]],
+ [[-10.7046, -10.3411, -10.3641], [-10.4402, -10.0004, -9.7269], [-11.0579, -11.0358, -10.7459]],
+ [[-3.3022, -2.8465, -2.6661], [-3.2654, -2.5542, -2.5055], [-3.2477, -2.6544, -2.6562]],
+ ],
+ }
)
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
- torch.testing.assert_close(logits[0, :3, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)
+ torch.testing.assert_close(logits[0, :3, :3, :3], expected_slice, rtol=2e-4, atol=2e-4)
@slow
def test_post_processing_semantic_segmentation(self):
diff --git a/tests/models/mobilevitv2/test_modeling_mobilevitv2.py b/tests/models/mobilevitv2/test_modeling_mobilevitv2.py
index 7a0433f123b..daca2394be2 100644
--- a/tests/models/mobilevitv2/test_modeling_mobilevitv2.py
+++ b/tests/models/mobilevitv2/test_modeling_mobilevitv2.py
@@ -16,7 +16,14 @@
import unittest
from transformers import MobileViTV2Config
-from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
+from transformers.testing_utils import (
+ Expectations,
+ require_torch,
+ require_torch_multi_gpu,
+ require_vision,
+ slow,
+ torch_device,
+)
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@@ -317,9 +324,15 @@ class MobileViTV2ModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
- expected_slice = torch.tensor([-1.6336e00, -7.3204e-02, -5.1883e-01]).to(torch_device)
+ expectations = Expectations(
+ {
+ (None, None): [-1.6336e00, -7.3204e-02, -5.1883e-01],
+ ("cuda", 8): [-1.6341, -0.0665, -0.5158],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
- torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=1e-4, atol=1e-4)
+ torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4)
@slow
def test_inference_semantic_segmentation(self):
@@ -340,16 +353,23 @@ class MobileViTV2ModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 21, 32, 32))
self.assertEqual(logits.shape, expected_shape)
- expected_slice = torch.tensor(
- [
- [[7.0863, 7.1525, 6.8201], [6.6931, 6.8770, 6.8933], [6.2978, 7.0366, 6.9636]],
- [[-3.7134, -3.6712, -3.6675], [-3.5825, -3.3549, -3.4777], [-3.3435, -3.3979, -3.2857]],
- [[-2.9329, -2.8003, -2.7369], [-3.0564, -2.4780, -2.0207], [-2.6889, -1.9298, -1.7640]],
- ],
- device=torch_device,
+ expectations = Expectations(
+ {
+ (None, None): [
+ [[7.0863, 7.1525, 6.8201], [6.6931, 6.8770, 6.8933], [6.2978, 7.0366, 6.9636]],
+ [[-3.7134, -3.6712, -3.6675], [-3.5825, -3.3549, -3.4777], [-3.3435, -3.3979, -3.2857]],
+ [[-2.9329, -2.8003, -2.7369], [-3.0564, -2.4780, -2.0207], [-2.6889, -1.9298, -1.7640]],
+ ],
+ ("cuda", 8): [
+ [[7.0866, 7.1509, 6.8188], [6.6935, 6.8757, 6.8927], [6.2988, 7.0365, 6.9631]],
+ [[-3.7113, -3.6686, -3.6643], [-3.5801, -3.3516, -3.4739], [-3.3432, -3.3966, -3.2832]],
+ [[-2.9359, -2.8037, -2.7387], [-3.0595, -2.4798, -2.0222], [-2.6901, -1.9306, -1.7659]],
+ ],
+ }
)
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
- torch.testing.assert_close(logits[0, :3, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)
+ torch.testing.assert_close(logits[0, :3, :3, :3], expected_slice, rtol=2e-4, atol=2e-4)
@slow
def test_post_processing_semantic_segmentation(self):
diff --git a/tests/models/omdet_turbo/test_modeling_omdet_turbo.py b/tests/models/omdet_turbo/test_modeling_omdet_turbo.py
index 11568f66f4d..9d76ad392cc 100644
--- a/tests/models/omdet_turbo/test_modeling_omdet_turbo.py
+++ b/tests/models/omdet_turbo/test_modeling_omdet_turbo.py
@@ -629,6 +629,7 @@ class OmDetTurboModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
or "decoder.channel_projection_layers" in name
or "query_position_head" in name
or "decoder.encoder_vision_features" in name
+ or "language_backbone.text_projection" in name
):
continue
self.assertIn(
diff --git a/tests/models/oneformer/test_modeling_oneformer.py b/tests/models/oneformer/test_modeling_oneformer.py
index 58a93a8c4fa..670756a9bfa 100644
--- a/tests/models/oneformer/test_modeling_oneformer.py
+++ b/tests/models/oneformer/test_modeling_oneformer.py
@@ -21,6 +21,7 @@ import numpy as np
from tests.test_modeling_common import floats_tensor
from transformers import AutoModelForImageClassification, OneFormerConfig, is_torch_available, is_vision_available
from transformers.testing_utils import (
+ Expectations,
is_flaky,
require_timm,
require_torch,
@@ -528,7 +529,7 @@ class OneFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
self.assertEqual(model.model.pixel_level_module.encoder.out_indices, [1, 2, 3])
-TOLERANCE = 1e-4
+TOLERANCE = 2e-4
# We will verify our results on an image of cute cats
@@ -574,12 +575,15 @@ class OneFormerModelIntegrationTest(unittest.TestCase):
slice_hidden_state = outputs.pixel_decoder_hidden_states[0][0, 0, :3, :3]
torch.testing.assert_close(slice_hidden_state, expected_slice_hidden_state, atol=TOLERANCE, rtol=TOLERANCE)
- # fmt: off
- expected_slice_hidden_state = [[3.0668, -1.1833, -5.1103], [3.344, -3.362, -5.1101], [2.6017, -4.3613, -4.1444]]
- expected_slice_hidden_state = torch.tensor(expected_slice_hidden_state).to(torch_device)
+ expectations = Expectations(
+ {
+ (None, None): [[3.0668, -1.1833, -5.1103], [3.344, -3.362, -5.1101], [2.6017, -4.3613, -4.1444]],
+ ("cuda", 8): [[3.0590, -1.1903, -5.1119], [3.3919, -3.3547, -5.1469], [2.6041, -4.3592, -4.1406]],
+ }
+ )
+ expected_slice_hidden_state = torch.tensor(expectations.get_expectation()).to(torch_device)
slice_hidden_state = outputs.transformer_decoder_class_predictions[0, :3, :3]
torch.testing.assert_close(slice_hidden_state, expected_slice_hidden_state, atol=TOLERANCE, rtol=TOLERANCE)
- # fmt: on
def test_inference_universal_segmentation_head(self):
model = OneFormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
@@ -599,8 +603,13 @@ class OneFormerModelIntegrationTest(unittest.TestCase):
masks_queries_logits.shape,
(1, model.config.num_queries, inputs_shape[-2] // 4, (inputs_shape[-1] + 2) // 4),
)
- expected_slice = [[3.1848, 4.2141, 4.1993], [2.9000, 3.5721, 3.6603], [2.5358, 3.0883, 3.6168]]
- expected_slice = torch.tensor(expected_slice).to(torch_device)
+ expectations = Expectations(
+ {
+ (None, None): [[3.1848, 4.2141, 4.1993], [2.9000, 3.5721, 3.6603], [2.5358, 3.0883, 3.6168]],
+ ("cuda", 8): [[3.1687, 4.1893, 4.1742], [2.8768, 3.5380, 3.6257], [2.5121, 3.0552, 3.5822]],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
torch.testing.assert_close(masks_queries_logits[0, 0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
# class_queries_logits
@@ -609,8 +618,13 @@ class OneFormerModelIntegrationTest(unittest.TestCase):
class_queries_logits.shape,
(1, model.config.num_queries, model.config.num_labels + 1),
)
- expected_slice = [[3.0668, -1.1833, -5.1103], [3.3440, -3.3620, -5.1101], [2.6017, -4.3613, -4.1444]]
- expected_slice = torch.tensor(expected_slice).to(torch_device)
+ expectations = Expectations(
+ {
+ (None, None): [[3.0668, -1.1833, -5.1103], [3.3440, -3.3620, -5.1101], [2.6017, -4.3613, -4.1444]],
+ ("cuda", 8): [[3.0590, -1.1903, -5.1119], [3.3919, -3.3547, -5.1469], [2.6041, -4.3592, -4.1406]],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
torch.testing.assert_close(class_queries_logits[0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
@require_torch_accelerator
diff --git a/tests/models/poolformer/test_modeling_poolformer.py b/tests/models/poolformer/test_modeling_poolformer.py
index 0fee2b295f0..56300abbe8c 100644
--- a/tests/models/poolformer/test_modeling_poolformer.py
+++ b/tests/models/poolformer/test_modeling_poolformer.py
@@ -17,7 +17,7 @@ import unittest
from transformers import is_torch_available, is_vision_available
from transformers.models.auto import get_values
-from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.testing_utils import Expectations, require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -144,6 +144,9 @@ class PoolFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
+ def test_batching_equivalence(self, atol=2e-4, rtol=2e-4):
+ super().test_batching_equivalence(atol=atol, rtol=rtol)
+
@unittest.skip(reason="PoolFormer does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@@ -235,5 +238,11 @@ class PoolFormerModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
- expected_slice = torch.tensor([-0.6113, 0.1685, -0.0492]).to(torch_device)
- torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=1e-4, atol=1e-4)
+ expectations = Expectations(
+ {
+ (None, None): [-0.6113, 0.1685, -0.0492],
+ ("cuda", 8): [-0.6112, 0.1690, -0.0481],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
+ torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4)
diff --git a/tests/models/pvt/test_modeling_pvt.py b/tests/models/pvt/test_modeling_pvt.py
index d52348555ad..eeaabcbd608 100644
--- a/tests/models/pvt/test_modeling_pvt.py
+++ b/tests/models/pvt/test_modeling_pvt.py
@@ -17,6 +17,7 @@ import unittest
from transformers import is_torch_available, is_vision_available
from transformers.testing_utils import (
+ Expectations,
require_accelerate,
require_torch,
require_torch_accelerator,
@@ -153,6 +154,9 @@ class PvtModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
self.model_tester = PvtModelTester(self)
self.config_tester = PvtConfigTester(self, config_class=PvtConfig)
+ def test_batching_equivalence(self, atol=1e-4, rtol=1e-4):
+ super().test_batching_equivalence(atol=atol, rtol=rtol)
+
def test_config(self):
self.config_tester.run_common_tests()
@@ -257,9 +261,15 @@ class PvtModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, model.config.num_labels))
self.assertEqual(outputs.logits.shape, expected_shape)
- expected_slice = torch.tensor([-1.4192, -1.9158, -0.9702]).to(torch_device)
+ expectations = Expectations(
+ {
+ (None, None): [-1.4192, -1.9158, -0.9702],
+ ("cuda", 8): [-1.4194, -1.9161, -0.9705],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
- torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=1e-4, atol=1e-4)
+ torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4)
@slow
def test_inference_model(self):
@@ -278,11 +288,15 @@ class PvtModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 50, 512))
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
- expected_slice = torch.tensor(
- [[-0.3086, 1.0402, 1.1816], [-0.2880, 0.5781, 0.6124], [0.1480, 0.6129, -0.0590]]
- ).to(torch_device)
+ expectations = Expectations(
+ {
+ (None, None): [[-0.3086, 1.0402, 1.1816], [-0.2880, 0.5781, 0.6124], [0.1480, 0.6129, -0.0590]],
+ ("cuda", 8): [[-0.3084, 1.0402, 1.1816], [-0.2883, 0.5781, 0.6123], [0.1487, 0.6119, -0.0584]],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
- torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)
+ torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=2e-4, atol=2e-4)
@slow
@require_accelerate
diff --git a/tests/models/pvt_v2/test_modeling_pvt_v2.py b/tests/models/pvt_v2/test_modeling_pvt_v2.py
index d1a765b19d4..0aca4e6652b 100644
--- a/tests/models/pvt_v2/test_modeling_pvt_v2.py
+++ b/tests/models/pvt_v2/test_modeling_pvt_v2.py
@@ -167,6 +167,9 @@ class PvtV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
+ def test_batching_equivalence(self, atol=5e-4, rtol=5e-4):
+ super().test_batching_equivalence(atol=atol, rtol=rtol)
+
@unittest.skip(reason="Pvt-V2 does not use inputs_embeds")
def test_inputs_embeds(self):
pass
diff --git a/tests/models/regnet/test_modeling_regnet.py b/tests/models/regnet/test_modeling_regnet.py
index 9f88bc8c9c1..8fc8e452da9 100644
--- a/tests/models/regnet/test_modeling_regnet.py
+++ b/tests/models/regnet/test_modeling_regnet.py
@@ -17,7 +17,7 @@ import unittest
from transformers import RegNetConfig
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
-from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.testing_utils import Expectations, require_torch, require_vision, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -146,6 +146,9 @@ class RegNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_config(self):
self.config_tester.run_common_tests()
+ def test_batching_equivalence(self, atol=3e-5, rtol=3e-5):
+ super().test_batching_equivalence(atol=atol, rtol=rtol)
+
@unittest.skip(reason="RegNet does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@@ -248,6 +251,11 @@ class RegNetModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
- expected_slice = torch.tensor([-0.4180, -1.5051, -3.4836]).to(torch_device)
-
- torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=1e-4, atol=1e-4)
+ expectations = Expectations(
+ {
+ (None, None): [-0.4180, -1.5051, -3.4836],
+ ("cuda", 8): [-0.4168, -1.5056, -3.4836],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
+ torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4)
diff --git a/tests/models/resnet/test_modeling_resnet.py b/tests/models/resnet/test_modeling_resnet.py
index e63d617c0e8..3778bd40054 100644
--- a/tests/models/resnet/test_modeling_resnet.py
+++ b/tests/models/resnet/test_modeling_resnet.py
@@ -16,7 +16,7 @@
import unittest
from transformers import ResNetConfig
-from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.testing_utils import Expectations, require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_backbone_common import BackboneTesterMixin
@@ -301,9 +301,14 @@ class ResNetModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
- expected_slice = torch.tensor([-11.1069, -9.7877, -8.3777]).to(torch_device)
-
- torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=1e-4, atol=1e-4)
+ expectations = Expectations(
+ {
+ (None, None): [-11.1069, -9.7877, -8.3777],
+ ("cuda", 8): [-11.1112, -9.7916, -8.3788],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
+ torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=2e-4, atol=2e-4)
@require_torch
diff --git a/tests/models/rt_detr/test_modeling_rt_detr.py b/tests/models/rt_detr/test_modeling_rt_detr.py
index fa2938160d7..fad90934265 100644
--- a/tests/models/rt_detr/test_modeling_rt_detr.py
+++ b/tests/models/rt_detr/test_modeling_rt_detr.py
@@ -29,6 +29,7 @@ from transformers import (
is_vision_available,
)
from transformers.testing_utils import (
+ Expectations,
require_torch,
require_torch_accelerator,
require_vision,
@@ -732,45 +733,69 @@ class RTDetrModelIntegrationTest(unittest.TestCase):
expected_shape_logits = torch.Size((1, 300, model.config.num_labels))
self.assertEqual(outputs.logits.shape, expected_shape_logits)
- expected_logits = torch.tensor(
- [
- [-4.64763879776001, -5.001153945922852, -4.978509902954102],
- [-4.159348487854004, -4.703853607177734, -5.946484565734863],
- [-4.437461853027344, -4.65836238861084, -6.235235691070557],
- ]
- ).to(torch_device)
- expected_boxes = torch.tensor(
- [
- [0.1688060760498047, 0.19992263615131378, 0.21225441992282867],
- [0.768376350402832, 0.41226309537887573, 0.4636859893798828],
- [0.25953856110572815, 0.5483334064483643, 0.4777486026287079],
- ]
- ).to(torch_device)
+ expectations = Expectations(
+ {
+ (None, None): [
+ [-4.64763879776001, -5.001153945922852, -4.978509902954102],
+ [-4.159348487854004, -4.703853607177734, -5.946484565734863],
+ [-4.437461853027344, -4.65836238861084, -6.235235691070557],
+ ],
+ ("cuda", 8): [[-4.6471, -5.0008, -4.9786], [-4.1599, -4.7041, -5.9458], [-4.4374, -4.6582, -6.2340]],
+ }
+ )
+ expected_logits = torch.tensor(expectations.get_expectation()).to(torch_device)
- torch.testing.assert_close(outputs.logits[0, :3, :3], expected_logits, rtol=1e-4, atol=1e-4)
+ expectations = Expectations(
+ {
+ (None, None): [
+ [0.1688060760498047, 0.19992263615131378, 0.21225441992282867],
+ [0.768376350402832, 0.41226309537887573, 0.4636859893798828],
+ [0.25953856110572815, 0.5483334064483643, 0.4777486026287079],
+ ],
+ ("cuda", 8): [[0.1688, 0.1999, 0.2123], [0.7684, 0.4123, 0.4637], [0.2596, 0.5483, 0.4777]],
+ }
+ )
+ expected_boxes = torch.tensor(expectations.get_expectation()).to(torch_device)
+
+ torch.testing.assert_close(outputs.logits[0, :3, :3], expected_logits, rtol=2e-4, atol=2e-4)
expected_shape_boxes = torch.Size((1, 300, 4))
self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes)
- torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_boxes, rtol=1e-4, atol=1e-4)
+ torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_boxes, rtol=2e-4, atol=2e-4)
# verify postprocessing
results = image_processor.post_process_object_detection(
outputs, threshold=0.0, target_sizes=[image.size[::-1]]
)[0]
- expected_scores = torch.tensor(
- [0.9703017473220825, 0.9599503874778748, 0.9575679302215576, 0.9506784677505493], device=torch_device
- )
- expected_labels = [57, 15, 15, 65]
- expected_slice_boxes = torch.tensor(
- [
- [0.13774872, 0.37821293, 640.13074, 476.21088],
- [343.38132, 24.276838, 640.1404, 371.49573],
- [13.225126, 54.179348, 318.98422, 472.2207],
- [40.114475, 73.44104, 175.9573, 118.48469],
- ],
- device=torch_device,
- )
- torch.testing.assert_close(results["scores"][:4], expected_scores, rtol=1e-4, atol=1e-4)
+ expectations = Expectations(
+ {
+ (None, None): [0.9703017473220825, 0.9599503874778748, 0.9575679302215576, 0.9506784677505493],
+ ("cuda", 8): [0.9704, 0.9599, 0.9576, 0.9507],
+ }
+ )
+ expected_scores = torch.tensor(expectations.get_expectation()).to(torch_device)
+
+ expected_labels = [57, 15, 15, 65]
+
+ expectations = Expectations(
+ {
+ (None, None): [
+ [0.13774872, 0.37821293, 640.13074, 476.21088],
+ [343.38132, 24.276838, 640.1404, 371.49573],
+ [13.225126, 54.179348, 318.98422, 472.2207],
+ [40.114475, 73.44104, 175.9573, 118.48469],
+ ],
+ ("cuda", 8): [
+ [1.4183e-01, 3.8063e-01, 6.4013e02, 4.7621e02],
+ [3.4338e02, 2.4275e01, 6.4014e02, 3.7150e02],
+ [1.3236e01, 5.4179e01, 3.1899e02, 4.7222e02],
+ [4.0114e01, 7.3441e01, 1.7596e02, 1.1848e02],
+ ],
+ }
+ )
+ expected_slice_boxes = torch.tensor(expectations.get_expectation()).to(torch_device)
+
+ torch.testing.assert_close(results["scores"][:4], expected_scores, rtol=2e-4, atol=2e-4)
self.assertSequenceEqual(results["labels"][:4].tolist(), expected_labels)
- torch.testing.assert_close(results["boxes"][:4], expected_slice_boxes, rtol=1e-4, atol=1e-4)
+ torch.testing.assert_close(results["boxes"][:4], expected_slice_boxes, rtol=2e-4, atol=2e-4)
diff --git a/tests/models/rt_detr_v2/test_modeling_rt_detr_v2.py b/tests/models/rt_detr_v2/test_modeling_rt_detr_v2.py
index a78f11ea46c..79202d3cf71 100644
--- a/tests/models/rt_detr_v2/test_modeling_rt_detr_v2.py
+++ b/tests/models/rt_detr_v2/test_modeling_rt_detr_v2.py
@@ -28,6 +28,7 @@ from transformers import (
is_vision_available,
)
from transformers.testing_utils import (
+ Expectations,
require_torch,
require_torch_accelerator,
require_vision,
@@ -736,42 +737,60 @@ class RTDetrV2ModelIntegrationTest(unittest.TestCase):
expected_shape_logits = torch.Size((1, 300, model.config.num_labels))
self.assertEqual(outputs.logits.shape, expected_shape_logits)
- expected_logits = torch.tensor(
- [
- [-3.7047, -5.1914, -6.1787],
- [-4.0108, -9.3449, -5.2047],
- [-4.1287, -4.7461, -5.8633],
- ]
- ).to(torch_device)
- expected_boxes = torch.tensor(
- [
- [0.2582, 0.5497, 0.4764],
- [0.1684, 0.1985, 0.2120],
- [0.7665, 0.4146, 0.4669],
- ]
- ).to(torch_device)
+ expectations = Expectations(
+ {
+ (None, None): [[-3.7047, -5.1914, -6.1787], [-4.0108, -9.3449, -5.2047], [-4.1287, -4.7461, -5.8633]],
+ ("cuda", 8): [[-3.7039, -5.1923, -6.1787], [-4.0106, -9.3452, -5.2045], [-4.1285, -4.7468, -5.8641]],
+ }
+ )
+ expected_logits = torch.tensor(expectations.get_expectation()).to(torch_device)
- torch.testing.assert_close(outputs.logits[0, :3, :3], expected_logits, atol=1e-4, rtol=1e-4)
+ expectations = Expectations(
+ {
+ (None, None): [[0.2582, 0.5497, 0.4764], [0.1684, 0.1985, 0.2120], [0.7665, 0.4146, 0.4669]],
+ }
+ )
+ expected_boxes = torch.tensor(expectations.get_expectation()).to(torch_device)
+
+ torch.testing.assert_close(outputs.logits[0, :3, :3], expected_logits, atol=2e-4, rtol=2e-4)
expected_shape_boxes = torch.Size((1, 300, 4))
self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes)
- torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_boxes, atol=1e-4, rtol=1e-4)
+ torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_boxes, atol=2e-4, rtol=2e-4)
# verify postprocessing
results = image_processor.post_process_object_detection(
outputs, threshold=0.0, target_sizes=[image.size[::-1]]
)[0]
- expected_scores = torch.tensor([0.9652, 0.9599, 0.9462, 0.8613], device=torch_device)
- expected_labels = [15, 15, 65, 57]
- expected_slice_boxes = torch.tensor(
- [
- [3.4114e02, 2.5111e01, 6.3998e02, 3.7289e02],
- [1.2780e01, 5.6346e01, 3.1767e02, 4.7134e02],
- [3.9959e01, 7.3117e01, 1.7565e02, 1.1744e02],
- [-1.0521e-01, 2.9717e00, 6.3989e02, 4.7362e02],
- ],
- device=torch_device,
+
+ expectations = Expectations(
+ {
+ (None, None): [0.9652, 0.9599, 0.9462, 0.8613],
+ ("cuda", 8): [0.9652, 0.9599, 0.9461, 0.8613],
+ }
)
- self.assertTrue(torch.allclose(results["scores"][:4], expected_scores, atol=1e-3, rtol=1e-4))
+ expected_scores = torch.tensor(expectations.get_expectation()).to(torch_device)
+
+ expected_labels = [15, 15, 65, 57]
+
+ expectations = Expectations(
+ {
+ (None, None): [
+ [3.4114e02, 2.5111e01, 6.3998e02, 3.7289e02],
+ [1.2780e01, 5.6346e01, 3.1767e02, 4.7134e02],
+ [3.9959e01, 7.3117e01, 1.7565e02, 1.1744e02],
+ [-1.0521e-01, 2.9717e00, 6.3989e02, 4.7362e02],
+ ],
+ ("cuda", 8): [
+ [3.4115e02, 2.5109e01, 6.3997e02, 3.7290e02],
+ [1.2785e01, 5.6350e01, 3.1767e02, 4.7134e02],
+ [3.9959e01, 7.3117e01, 1.7565e02, 1.1744e02],
+ [-1.0471e-01, 2.9680e00, 6.3989e02, 4.7362e02],
+ ],
+ }
+ )
+ expected_slice_boxes = torch.tensor(expectations.get_expectation()).to(torch_device)
+
+ torch.testing.assert_close(results["scores"][:4], expected_scores, atol=1e-3, rtol=2e-4)
self.assertSequenceEqual(results["labels"][:4].tolist(), expected_labels)
- torch.testing.assert_close(results["boxes"][:4], expected_slice_boxes, atol=1e-3, rtol=1e-4)
+ torch.testing.assert_close(results["boxes"][:4], expected_slice_boxes, atol=1e-3, rtol=2e-4)
diff --git a/tests/models/segformer/test_modeling_segformer.py b/tests/models/segformer/test_modeling_segformer.py
index cd75545c62a..fcd6594217c 100644
--- a/tests/models/segformer/test_modeling_segformer.py
+++ b/tests/models/segformer/test_modeling_segformer.py
@@ -16,7 +16,7 @@
import unittest
from transformers import SegformerConfig, is_torch_available, is_vision_available
-from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.testing_utils import Expectations, require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -200,6 +200,9 @@ class SegformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_segmentation(*config_and_inputs)
+ def test_batching_equivalence(self, atol=2e-4, rtol=2e-4):
+ super().test_batching_equivalence(atol=atol, rtol=rtol)
+
@unittest.skip(reason="SegFormer does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@@ -367,14 +370,22 @@ class SegformerModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, model.config.num_labels, 128, 128))
self.assertEqual(outputs.logits.shape, expected_shape)
- expected_slice = torch.tensor(
- [
- [[-4.6310, -5.5232, -6.2356], [-5.1921, -6.1444, -6.5996], [-5.4424, -6.2790, -6.7574]],
- [[-12.1391, -13.3122, -13.9554], [-12.8732, -13.9352, -14.3563], [-12.9438, -13.8226, -14.2513]],
- [[-12.5134, -13.4686, -14.4915], [-12.8669, -14.4343, -14.7758], [-13.2523, -14.5819, -15.0694]],
- ]
- ).to(torch_device)
- torch.testing.assert_close(outputs.logits[0, :3, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)
+ expectations = Expectations(
+ {
+ (None, None): [
+ [[-4.6310, -5.5232, -6.2356], [-5.1921, -6.1444, -6.5996], [-5.4424, -6.2790, -6.7574]],
+ [[-12.1391, -13.3122, -13.9554], [-12.8732, -13.9352, -14.3563], [-12.9438, -13.8226, -14.2513]],
+ [[-12.5134, -13.4686, -14.4915], [-12.8669, -14.4343, -14.7758], [-13.2523, -14.5819, -15.0694]],
+ ],
+ ("cuda", 8): [
+ [[-4.6310, -5.5232, -6.2361], [-5.1918, -6.1445, -6.5996], [-5.4427, -6.2792, -6.7580]],
+ [[-12.1397, -13.3124, -13.9551], [-12.8736, -13.9347, -14.3569], [-12.9440, -13.8222, -14.2514]],
+ [[-12.5135, -13.4682, -14.4913], [-12.8670, -14.4339, -14.7766], [-13.2519, -14.5800, -15.0685]],
+ ],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
+ torch.testing.assert_close(outputs.logits[0, :3, :3, :3], expected_slice, rtol=2e-4, atol=2e-4)
@slow
def test_inference_image_segmentation_city(self):
@@ -396,13 +407,22 @@ class SegformerModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, model.config.num_labels, 128, 128))
self.assertEqual(outputs.logits.shape, expected_shape)
- expected_slice = torch.tensor(
- [
- [[-13.5748, -13.9111, -12.6500], [-14.3500, -15.3683, -14.2328], [-14.7532, -16.0424, -15.6087]],
- [[-17.1651, -15.8725, -12.9653], [-17.2580, -17.3718, -14.8223], [-16.6058, -16.8783, -16.7452]],
- [[-3.6456, -3.0209, -1.4203], [-3.0797, -3.1959, -2.0000], [-1.8757, -1.9217, -1.6997]],
- ]
- ).to(torch_device)
+ expectations = Expectations(
+ {
+ (None, None): [
+ [[-13.5748, -13.9111, -12.6500], [-14.3500, -15.3683, -14.2328], [-14.7532, -16.0424, -15.6087]],
+ [[-17.1651, -15.8725, -12.9653], [-17.2580, -17.3718, -14.8223], [-16.6058, -16.8783, -16.7452]],
+ [[-3.6456, -3.0209, -1.4203], [-3.0797, -3.1959, -2.0000], [-1.8757, -1.9217, -1.6997]],
+ ],
+ ("cuda", 8): [
+ [[-13.5728, -13.9089, -12.6492], [-14.3478, -15.3656, -14.2309], [-14.7512, -16.0394, -15.6065]],
+ [[-17.1642, -15.8704, -12.9641], [-17.2572, -17.3701, -14.8214], [-16.6043, -16.8761, -16.7425]],
+ [[-3.6444, -3.0189, -1.4195], [-3.0787, -3.1953, -1.9993], [-1.8755, -1.9219, -1.7002]],
+ ],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
+
torch.testing.assert_close(outputs.logits[0, :3, :3, :3], expected_slice, rtol=1e-1, atol=1e-1)
@slow
diff --git a/tests/models/seggpt/test_modeling_seggpt.py b/tests/models/seggpt/test_modeling_seggpt.py
index 1176613fa20..4083276e185 100644
--- a/tests/models/seggpt/test_modeling_seggpt.py
+++ b/tests/models/seggpt/test_modeling_seggpt.py
@@ -21,6 +21,7 @@ from datasets import load_dataset
from transformers import SegGptConfig
from transformers.testing_utils import (
+ Expectations,
require_torch,
require_vision,
slow,
@@ -379,15 +380,23 @@ class SegGptModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 3, 896, 448))
self.assertEqual(outputs.pred_masks.shape, expected_shape)
- expected_slice = torch.tensor(
- [
- [[-2.1208, -2.1190, -2.1198], [-2.1237, -2.1228, -2.1227], [-2.1232, -2.1226, -2.1228]],
- [[-2.0405, -2.0396, -2.0403], [-2.0434, -2.0434, -2.0433], [-2.0428, -2.0432, -2.0434]],
- [[-1.8102, -1.8088, -1.8099], [-1.8131, -1.8126, -1.8129], [-1.8130, -1.8128, -1.8131]],
- ]
- ).to(torch_device)
+ expectations = Expectations(
+ {
+ (None, None): [
+ [[-2.1208, -2.1190, -2.1198], [-2.1237, -2.1228, -2.1227], [-2.1232, -2.1226, -2.1228]],
+ [[-2.0405, -2.0396, -2.0403], [-2.0434, -2.0434, -2.0433], [-2.0428, -2.0432, -2.0434]],
+ [[-1.8102, -1.8088, -1.8099], [-1.8131, -1.8126, -1.8129], [-1.8130, -1.8128, -1.8131]],
+ ],
+ ("cuda", 8): [
+ [[-2.1208, -2.1189, -2.1198], [-2.1236, -2.1229, -2.1230], [-2.1233, -2.1227, -2.1228]],
+ [[-2.0408, -2.0398, -2.0405], [-2.0435, -2.0437, -2.0438], [-2.0431, -2.0435, -2.0436]],
+ [[-1.8101, -1.8086, -1.8098], [-1.8129, -1.8126, -1.8130], [-1.8128, -1.8128, -1.8130]],
+ ],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device)
- torch.testing.assert_close(outputs.pred_masks[0, :, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)
+ torch.testing.assert_close(outputs.pred_masks[0, :, :3, :3], expected_slice, rtol=2e-4, atol=2e-4)
result = image_processor.post_process_semantic_segmentation(outputs, [input_image.size[::-1]])[0]
diff --git a/tests/models/smolvlm/test_modeling_smolvlm.py b/tests/models/smolvlm/test_modeling_smolvlm.py
index 280399eb6b8..135043e9860 100644
--- a/tests/models/smolvlm/test_modeling_smolvlm.py
+++ b/tests/models/smolvlm/test_modeling_smolvlm.py
@@ -536,23 +536,24 @@ class SmolVLMForConditionalGenerationIntegrationTest(unittest.TestCase):
).content
)
)
- self.image2 = Image.open(
- BytesIO(requests.get("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg").content)
- )
- self.image3 = Image.open(
- BytesIO(
- requests.get(
- "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"
- ).content
- )
- )
+
+ self.video_messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "video",
+ "path": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/assisted-generation/gif_1_1080p.mov",
+ },
+ {"type": "text", "text": "Describe this video in detail"},
+ ],
+ },
+ ]
def tearDown(self):
cleanup(torch_device, gc_collect=True)
@slow
- # TODO (Orr?) this is a dummy test to check if the model generates things that make sense.
- # Needs to be expanded to a tiny video
def test_integration_test(self):
model = SmolVLMForConditionalGeneration.from_pretrained(
"HuggingFaceTB/SmolVLM2-256M-Video-Instruct",
@@ -571,3 +572,26 @@ class SmolVLMForConditionalGenerationIntegrationTest(unittest.TestCase):
expected_generated_text = "\n\n\n\nIn this image, we see a view of the Statue of Liberty and the"
self.assertEqual(generated_texts[0], expected_generated_text)
+
+ @slow
+ def test_integration_test_video(self):
+ model = SmolVLMForConditionalGeneration.from_pretrained(
+ "HuggingFaceTB/SmolVLM2-256M-Video-Instruct",
+ torch_dtype=torch.bfloat16,
+ device_map="auto",
+ )
+
+ # Create inputs
+ inputs = self.processor.apply_chat_template(
+ self.video_messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ ).to(device=torch_device, dtype=torch.bfloat16)
+
+ generated_ids = model.generate(**inputs, max_new_tokens=20)
+ generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
+
+ expected_generated_text = 'User: You are provided the following series of nine frames from a 0:00:09 [H:MM:SS] video.\n\nFrame from 00:00:\nFrame from 00:01:\nFrame from 00:02:\nFrame from 00:03:\nFrame from 00:04:\nFrame from 00:05:\nFrame from 00:06:\nFrame from 00:08:\nFrame from 00:09:\n\nDescribe this video in detail\nAssistant: The video depicts a large language model architecture, specifically a language model with a "quick brown" feature' # fmt: skip
+ self.assertEqual(generated_texts[0], expected_generated_text)
diff --git a/tests/models/swin2sr/test_modeling_swin2sr.py b/tests/models/swin2sr/test_modeling_swin2sr.py
index 125d5418e8e..a1767a0ab24 100644
--- a/tests/models/swin2sr/test_modeling_swin2sr.py
+++ b/tests/models/swin2sr/test_modeling_swin2sr.py
@@ -16,7 +16,7 @@
import unittest
from transformers import Swin2SRConfig
-from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.testing_utils import Expectations, require_torch, require_vision, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@@ -360,7 +360,12 @@ class Swin2SRModelIntegrationTest(unittest.TestCase):
# verify the logits
expected_shape = torch.Size([1, 3, 976, 1296])
self.assertEqual(outputs.reconstruction.shape, expected_shape)
- expected_slice = torch.tensor(
- [[0.5454, 0.5542, 0.5640], [0.5518, 0.5562, 0.5649], [0.5391, 0.5425, 0.5620]], dtype=model.dtype
- ).to(torch_device)
- torch.testing.assert_close(outputs.reconstruction[0, 0, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)
+
+ expectations = Expectations(
+ {
+ (None, None): [[0.5454, 0.5542, 0.5640], [0.5518, 0.5562, 0.5649], [0.5391, 0.5425, 0.5620]],
+ ("cuda", 8): [[0.5454, 0.5547, 0.5640], [0.5522, 0.5562, 0.5649], [0.5391, 0.5425, 0.5620]],
+ }
+ )
+ expected_slice = torch.tensor(expectations.get_expectation()).to(torch_device, dtype=model.dtype)
+ torch.testing.assert_close(outputs.reconstruction[0, 0, :3, :3], expected_slice, rtol=2e-4, atol=2e-4)
diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py
index 2b5eb30dcf4..67b59fef4ff 100644
--- a/tests/models/switch_transformers/test_modeling_switch_transformers.py
+++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py
@@ -19,6 +19,7 @@ import unittest
from transformers import SwitchTransformersConfig, is_torch_available
from transformers.testing_utils import (
+ Expectations,
require_tokenizers,
require_torch,
require_torch_accelerator,
@@ -1035,18 +1036,28 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
decoder_input_ids = torch.ones((32, 64), dtype=torch.long).to(torch_device)
# fmt: off
- EXPECTED_MEAN_LOGITS = torch.Tensor(
- [
- -0.204102, -0.193359, 0.523438, -0.296875, 0.108887,
- 0.0211182, 0.605469, -0.100586, -0.0551758, 0.296875,
- 0.0090332, 0.174805, 0.139648, -0.170898, -0.0981445,
- 0.0245361, 0.0373535, 0.050293, -0.212891, 0.129883,
- 0.390625, -0.203125, -0.122559, -0.180664, 0.0437012,
- -0.349609, -0.0250244, -0.104004, -0.15918, -0.133789
- ]
- ).to(torch.bfloat16)
+ expectations = Expectations(
+ {
+ (None, None): [
+ -0.204102, -0.193359, 0.523438, -0.296875, 0.108887,
+ 0.0211182, 0.605469, -0.100586, -0.0551758, 0.296875,
+ 0.0090332, 0.174805, 0.139648, -0.170898, -0.0981445,
+ 0.0245361, 0.0373535, 0.050293, -0.212891, 0.129883,
+ 0.390625, -0.203125, -0.122559, -0.180664, 0.0437012,
+ -0.349609, -0.0250244, -0.104004, -0.15918, -0.133789
+ ],
+ ("cuda", 8): [
+ -0.2051, -0.1914, 0.5352, -0.2988, 0.1108, 0.0200, 0.6094, -0.1025,
+ -0.0549, 0.2988, -0.0018, 0.1758, 0.1348, -0.1689, -0.1035, 0.0266,
+ 0.0383, 0.0493, -0.2119, 0.1328, 0.3906, -0.2041, -0.1240, -0.1836,
+ 0.0454, -0.3477, -0.0256, -0.1050, -0.1572, -0.1338
+ ],
+ }
+ )
+ EXPECTED_MEAN_LOGITS = torch.tensor(expectations.get_expectation()).to(torch_device, dtype=torch.bfloat16)
# fmt: on
- hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state.cpu()
+
+ hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state
hf_logits = hf_logits[0, 0, :30]
torch.testing.assert_close(hf_logits, EXPECTED_MEAN_LOGITS, rtol=6e-3, atol=9e-3)
diff --git a/tests/models/timm_wrapper/test_modeling_timm_wrapper.py b/tests/models/timm_wrapper/test_modeling_timm_wrapper.py
index f7f374ed574..3f103309a04 100644
--- a/tests/models/timm_wrapper/test_modeling_timm_wrapper.py
+++ b/tests/models/timm_wrapper/test_modeling_timm_wrapper.py
@@ -153,10 +153,18 @@ class TimmWrapperModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
def test_retain_grad_hidden_states_attentions(self):
pass
+ @unittest.skip(reason="TimmWrapper initialization is managed on the timm side")
+ def test_can_init_all_missing_weights(self):
+ pass
+
@unittest.skip(reason="TimmWrapper initialization is managed on the timm side")
def test_initialization(self):
pass
+ @unittest.skip(reason="TimmWrapper initialization is managed on the timm side")
+ def test_mismatched_shapes_have_properly_initialized_weights(self):
+ pass
+
@unittest.skip(reason="Need to use a timm model and there is no tiny model available.")
def test_model_is_small(self):
pass
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 0587c73bd9b..da48081d6bf 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -855,7 +855,7 @@ class ModelTesterMixin:
# For now, skip everything older than 2025 and "important models" (too much models to patch otherwise)
# Use `supports_cache_class` as a proxy to judge "important" models in order to prioritize them
# TODO: relax this as we patch more and more models
- if addition_year < 2025 and not model_class._supports_cache_class:
+ if addition_year < 2024 and not model_class._supports_cache_class:
self.skipTest(reason=f"{model_class} is not a priorited model for now.")
# Monkey patch the method to add a seed (we do it on PreTrainedModel._initialize_weights, which wraps
@@ -895,6 +895,11 @@ class ModelTesterMixin:
model_from_config.state_dict().items(), model_from_pretrained.state_dict().items()
):
self.assertEqual(k1, k2, "The keys from each model should be the same")
+
+ # In case using torch.nn.utils.parametrizations on a module, we should skip the resulting keys
+ if re.search(r"\.parametrizations\..*?\.original[01]", k1):
+ continue
+
# Since we added the seed, they should be exactly the same (i.e. using allclose maybe be wrong due
# to very low std in init function)
if not (v1 == v2).all():
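A quick illustration of which state dict keys the new parametrizations guard skips; the key names below are hypothetical examples of the <module>.parametrizations.<name>.original0/original1 layout produced by torch.nn.utils.parametrizations (e.g. weight_norm):

import re

PARAMETRIZED_KEY = re.compile(r"\.parametrizations\..*?\.original[01]")

keys = [
    "decoder.conv.parametrizations.weight.original0",  # skipped
    "decoder.conv.parametrizations.weight.original1",  # skipped
    "decoder.conv.weight",                             # still compared
]
for key in keys:
    action = "skipped" if PARAMETRIZED_KEY.search(key) else "compared"
    print(f"{key}: {action}")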
diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py
index ebede32f3e4..855bcaaf27a 100644
--- a/tests/test_processing_common.py
+++ b/tests/test_processing_common.py
@@ -351,6 +351,18 @@ class ProcessorTesterMixin:
return_tensors="pt",
)
+ def test_args_overlap_kwargs(self):
+ if "image_processor" not in self.processor_class.attributes:
+ self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+ processor_first = self.get_processor()
+ image_processor = processor_first.image_processor
+ image_processor.is_override = True
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ processor_first.save_pretrained(tmpdirname)
+ processor_second = self.processor_class.from_pretrained(tmpdirname, image_processor=image_processor)
+ self.assertTrue(processor_second.image_processor.is_override)
+
def test_structured_kwargs_nested(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")