Support: Leverage Accelerate for object detection/segmentation models (#28312)

* made changes for object detection models
* added support for segmentation models
* Made changes for segmentation models
* Changed import statements
* solving conflicts
* removed conflicts
* Resolving commits
* Removed conflicts
* Fix: Pixel_mask_value set to False
parent aee11fe427
commit 0eb408551c
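Every loss class touched in the diffs below gets the same normalization change: when Accelerate is in use, the per-process target count is aggregated across processes with `reduce` and divided by the number of processes before clamping; otherwise single-process behaviour is kept. A minimal sketch of that pattern, assuming `accelerate` is installed (the helper name `normalize_target_count` is invented for illustration; `PartialState` and `reduce` are the `accelerate` utilities used in the diffs):

import torch

from accelerate import PartialState
from accelerate.utils import reduce


def normalize_target_count(count: torch.Tensor) -> float:
    # Hypothetical helper mirroring the pattern added in this commit.
    # PartialState._shared_state is only populated once an Accelerator /
    # PartialState has been created in the process (e.g. by the Trainer),
    # so plain single-process runs keep world_size = 1.
    world_size = 1
    if PartialState._shared_state != {}:
        count = reduce(count)                      # aggregate the count across processes
        world_size = PartialState().num_processes  # number of participating processes
    return torch.clamp(count / world_size, min=1).item()

In single-process use the shared state stays empty, so the result matches the previous behaviour; under a distributed launch the normalization now accounts for all processes.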
src/transformers/models/conditional_detr/modeling_conditional_detr.py

@@ -30,6 +30,7 @@ from ...utils import (
     ModelOutput,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_accelerate_available,
     is_scipy_available,
     is_timm_available,
     is_vision_available,
@@ -41,6 +42,10 @@ from ...utils.backbone_utils import load_backbone
 from .configuration_conditional_detr import ConditionalDetrConfig


+if is_accelerate_available():
+    from accelerate import PartialState
+    from accelerate.utils import reduce
+
 if is_scipy_available():
     from scipy.optimize import linear_sum_assignment

@@ -2507,11 +2512,12 @@ class ConditionalDetrLoss(nn.Module):
         # Compute the average number of target boxes across all nodes, for normalization purposes
         num_boxes = sum(len(t["class_labels"]) for t in targets)
         num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
-        # (Niels): comment out function below, distributed training to be added
-        # if is_dist_avail_and_initialized():
-        #     torch.distributed.all_reduce(num_boxes)
-        # (Niels) in original implementation, num_boxes is divided by get_world_size()
-        num_boxes = torch.clamp(num_boxes, min=1).item()
+        world_size = 1
+        if PartialState._shared_state != {}:
+            num_boxes = reduce(num_boxes)
+            world_size = PartialState().num_processes
+
+        num_boxes = torch.clamp(num_boxes / world_size, min=1).item()

         # Compute all the requested losses
         losses = {}
src/transformers/models/deformable_detr/modeling_deformable_detr.py

@@ -43,7 +43,7 @@ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import meshgrid
-from ...utils import is_ninja_available, logging
+from ...utils import is_accelerate_available, is_ninja_available, logging
 from ...utils.backbone_utils import load_backbone
 from .configuration_deformable_detr import DeformableDetrConfig
 from .load_custom import load_cuda_kernels
@@ -65,6 +65,10 @@ else:
 if is_vision_available():
     from transformers.image_transforms import center_to_corners_format


+if is_accelerate_available():
+    from accelerate import PartialState
+    from accelerate.utils import reduce
+
 class MultiScaleDeformableAttentionFunction(Function):
     @staticmethod
@@ -2246,11 +2250,11 @@ class DeformableDetrLoss(nn.Module):
         # Compute the average number of target boxes accross all nodes, for normalization purposes
         num_boxes = sum(len(t["class_labels"]) for t in targets)
         num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
-        # (Niels): comment out function below, distributed training to be added
-        # if is_dist_avail_and_initialized():
-        #     torch.distributed.all_reduce(num_boxes)
-        # (Niels) in original implementation, num_boxes is divided by get_world_size()
-        num_boxes = torch.clamp(num_boxes, min=1).item()
+        world_size = 1
+        if PartialState._shared_state != {}:
+            num_boxes = reduce(num_boxes)
+            world_size = PartialState().num_processes
+        num_boxes = torch.clamp(num_boxes / world_size, min=1).item()

         # Compute all the requested losses
         losses = {}
src/transformers/models/detr/modeling_detr.py

@@ -30,6 +30,7 @@ from ...utils import (
     ModelOutput,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_accelerate_available,
     is_scipy_available,
     is_timm_available,
     is_vision_available,
@@ -41,6 +42,10 @@ from ...utils.backbone_utils import load_backbone
 from .configuration_detr import DetrConfig


+if is_accelerate_available():
+    from accelerate import PartialState
+    from accelerate.utils import reduce
+
 if is_scipy_available():
     from scipy.optimize import linear_sum_assignment

@@ -2204,11 +2209,11 @@ class DetrLoss(nn.Module):
         # Compute the average number of target boxes across all nodes, for normalization purposes
         num_boxes = sum(len(t["class_labels"]) for t in targets)
         num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
-        # (Niels): comment out function below, distributed training to be added
-        # if is_dist_avail_and_initialized():
-        #     torch.distributed.all_reduce(num_boxes)
-        # (Niels) in original implementation, num_boxes is divided by get_world_size()
-        num_boxes = torch.clamp(num_boxes, min=1).item()
+        world_size = 1
+        if PartialState._shared_state != {}:
+            num_boxes = reduce(num_boxes)
+            world_size = PartialState().num_processes
+        num_boxes = torch.clamp(num_boxes / world_size, min=1).item()

         # Compute all the requested losses
         losses = {}
src/transformers/models/mask2former/modeling_mask2former.py

@@ -34,7 +34,7 @@ from ...file_utils import (
 )
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions
 from ...modeling_utils import PreTrainedModel
-from ...utils import logging
+from ...utils import is_accelerate_available, logging
 from ...utils.backbone_utils import load_backbone
 from .configuration_mask2former import Mask2FormerConfig

@@ -42,6 +42,10 @@ from .configuration_mask2former import Mask2FormerConfig
 if is_scipy_available():
     from scipy.optimize import linear_sum_assignment

+if is_accelerate_available():
+    from accelerate import PartialState
+    from accelerate.utils import reduce
+
 logger = logging.get_logger(__name__)


@@ -788,6 +792,12 @@ class Mask2FormerLoss(nn.Module):
         """
         num_masks = sum([len(classes) for classes in class_labels])
         num_masks_pt = torch.as_tensor(num_masks, dtype=torch.float, device=device)
+        world_size = 1
+        if PartialState._shared_state != {}:
+            num_masks_pt = reduce(num_masks_pt)
+            world_size = PartialState().num_processes
+
+        num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1)
         return num_masks_pt

src/transformers/models/maskformer/modeling_maskformer.py

@@ -31,6 +31,7 @@ from ...utils import (
     ModelOutput,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_accelerate_available,
     is_scipy_available,
     logging,
     replace_return_docstrings,
@@ -42,6 +43,10 @@ from .configuration_maskformer import MaskFormerConfig
 from .configuration_maskformer_swin import MaskFormerSwinConfig


+if is_accelerate_available():
+    from accelerate import PartialState
+    from accelerate.utils import reduce
+
 if is_scipy_available():
     from scipy.optimize import linear_sum_assignment

@@ -1194,6 +1199,12 @@ class MaskFormerLoss(nn.Module):
         """
         num_masks = sum([len(classes) for classes in class_labels])
         num_masks_pt = torch.as_tensor(num_masks, dtype=torch.float, device=device)
+        world_size = 1
+        if PartialState._shared_state != {}:
+            num_masks_pt = reduce(num_masks_pt)
+            world_size = PartialState().num_processes
+
+        num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1)
         return num_masks_pt

src/transformers/models/oneformer/modeling_oneformer.py

@@ -31,6 +31,7 @@ from ...utils import (
     ModelOutput,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_accelerate_available,
     is_scipy_available,
     logging,
     replace_return_docstrings,
@@ -40,6 +41,10 @@ from ...utils.backbone_utils import load_backbone
 from .configuration_oneformer import OneFormerConfig


+if is_accelerate_available():
+    from accelerate import PartialState
+    from accelerate.utils import reduce
+
 logger = logging.get_logger(__name__)


@@ -723,6 +728,12 @@ class OneFormerLoss(nn.Module):
         """
         num_masks = sum([len(classes) for classes in class_labels])
         num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device)
+        world_size = 1
+        if PartialState._shared_state != {}:
+            num_masks_pt = reduce(num_masks_pt)
+            world_size = PartialState().num_processes
+
+        num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1)
         return num_masks_pt

src/transformers/models/table_transformer/modeling_table_transformer.py

@@ -30,6 +30,7 @@ from ...utils import (
     ModelOutput,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_accelerate_available,
     is_scipy_available,
     is_timm_available,
     is_vision_available,
@@ -50,6 +51,10 @@ if is_timm_available():
 if is_vision_available():
     from transformers.image_transforms import center_to_corners_format

+if is_accelerate_available():
+    from accelerate import PartialState
+    from accelerate.utils import reduce
+
 logger = logging.get_logger(__name__)

 _CONFIG_FOR_DOC = "TableTransformerConfig"
@@ -1751,11 +1756,11 @@ class TableTransformerLoss(nn.Module):
         # Compute the average number of target boxes across all nodes, for normalization purposes
         num_boxes = sum(len(t["class_labels"]) for t in targets)
         num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
-        # (Niels): comment out function below, distributed training to be added
-        # if is_dist_avail_and_initialized():
-        #     torch.distributed.all_reduce(num_boxes)
-        # (Niels) in original implementation, num_boxes is divided by get_world_size()
-        num_boxes = torch.clamp(num_boxes, min=1).item()
+        world_size = 1
+        if PartialState._shared_state != {}:
+            num_boxes = reduce(num_boxes)
+            world_size = PartialState().num_processes
+        num_boxes = torch.clamp(num_boxes / world_size, min=1).item()

         # Compute all the requested losses
         losses = {}
src/transformers/models/yolos/image_processing_yolos.py

@@ -1297,7 +1297,7 @@ class YolosImageProcessor(BaseImageProcessor):
             encoded_inputs = self.pad(
                 images,
                 annotations=annotations,
-                return_pixel_mask=True,
+                return_pixel_mask=False,
                 data_format=data_format,
                 input_data_format=input_data_format,
                 update_bboxes=do_convert_annotations,
src/transformers/models/yolos/modeling_yolos.py

@@ -33,6 +33,7 @@ from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_accelerate_available,
     is_scipy_available,
     is_vision_available,
     logging,
@@ -48,6 +49,9 @@ if is_scipy_available():
 if is_vision_available():
     from transformers.image_transforms import center_to_corners_format

+if is_accelerate_available():
+    from accelerate import PartialState
+    from accelerate.utils import reduce

 logger = logging.get_logger(__name__)

@@ -1074,11 +1078,11 @@ class YolosLoss(nn.Module):
         # Compute the average number of target boxes across all nodes, for normalization purposes
         num_boxes = sum(len(t["class_labels"]) for t in targets)
         num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
-        # (Niels): comment out function below, distributed training to be added
-        # if is_dist_avail_and_initialized():
-        #     torch.distributed.all_reduce(num_boxes)
-        # (Niels) in original implementation, num_boxes is divided by get_world_size()
-        num_boxes = torch.clamp(num_boxes, min=1).item()
+        world_size = 1
+        if PartialState._shared_state != {}:
+            num_boxes = reduce(num_boxes)
+            world_size = PartialState().num_processes
+        num_boxes = torch.clamp(num_boxes / world_size, min=1).item()

         # Compute all the requested losses
         losses = {}