mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
ViT and Swin symbolic tracing with torch.fx (#17182)
* Support tracing for ViT * Swin support * Fix copies * Fix type annotation issue * Removed unused import
This commit is contained in:
parent
1a688709b3
commit
8c7481f35c
@ -168,7 +168,7 @@ class DeiTSelfAttention(nn.Module):
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(*new_x_shape)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(
|
||||
@ -200,7 +200,7 @@ class DeiTSelfAttention(nn.Module):
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
|
@ -177,7 +177,7 @@ class DPTViTSelfAttention(nn.Module):
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(*new_x_shape)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(
|
||||
@ -209,7 +209,7 @@ class DPTViTSelfAttention(nn.Module):
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
|
@ -496,7 +496,7 @@ def window_reverse(windows, window_size, height, width):
|
||||
"""
|
||||
Merges windows to produce higher resolution features.
|
||||
"""
|
||||
batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
|
||||
batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size))
|
||||
windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
|
||||
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
|
||||
return windows
|
||||
@ -697,7 +697,7 @@ class MaskFormerSwinSelfAttention(nn.Module):
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(*new_x_shape)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(
|
||||
@ -750,7 +750,7 @@ class MaskFormerSwinSelfAttention(nn.Module):
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
|
@ -226,7 +226,7 @@ def window_reverse(windows, window_size, height, width):
|
||||
"""
|
||||
Merges windows to produce higher resolution features.
|
||||
"""
|
||||
batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
|
||||
batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size))
|
||||
windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
|
||||
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
|
||||
return windows
|
||||
@ -435,7 +435,7 @@ class SwinSelfAttention(nn.Module):
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(*new_x_shape)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(
|
||||
@ -488,7 +488,7 @@ class SwinSelfAttention(nn.Module):
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
@ -1071,7 +1071,7 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
|
||||
# Reshape to (batch_size, num_channels, height, width)
|
||||
sequence_output = sequence_output.transpose(1, 2)
|
||||
batch_size, num_channels, sequence_length = sequence_output.shape
|
||||
height = width = int(sequence_length**0.5)
|
||||
height = width = math.floor(sequence_length**0.5)
|
||||
sequence_output = sequence_output.reshape(batch_size, num_channels, height, width)
|
||||
|
||||
# Reconstruct pixel values
|
||||
|
@ -213,7 +213,7 @@ class ViTSelfAttention(nn.Module):
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(*new_x_shape)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(
|
||||
@ -245,7 +245,7 @@ class ViTSelfAttention(nn.Module):
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
@ -687,7 +687,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
|
||||
# Reshape to (batch_size, num_channels, height, width)
|
||||
sequence_output = sequence_output[:, 1:]
|
||||
batch_size, sequence_length, num_channels = sequence_output.shape
|
||||
height = width = int(sequence_length**0.5)
|
||||
height = width = math.floor(sequence_length**0.5)
|
||||
sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
|
||||
|
||||
# Reconstruct pixel values
|
||||
|
@ -342,7 +342,7 @@ class ViTMAESelfAttention(nn.Module):
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(*new_x_shape)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(
|
||||
@ -374,7 +374,7 @@ class ViTMAESelfAttention(nn.Module):
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
|
@ -280,7 +280,7 @@ class YolosSelfAttention(nn.Module):
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(*new_x_shape)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(
|
||||
@ -312,7 +312,7 @@ class YolosSelfAttention(nn.Module):
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
|
@ -14,12 +14,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import collections
|
||||
import functools
|
||||
import inspect
|
||||
import math
|
||||
import random
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
@ -31,6 +31,7 @@ from .. import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
@ -71,6 +72,7 @@ def _generate_supported_model_classes(
|
||||
"question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
"sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
"masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
|
||||
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
}
|
||||
|
||||
@ -100,6 +102,8 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
|
||||
"gpt_neo",
|
||||
"t5",
|
||||
"roberta",
|
||||
"vit",
|
||||
"swin",
|
||||
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
|
||||
# "layoutlm",
|
||||
# "xlnet",
|
||||
@ -276,6 +280,31 @@ def torch_tensor_index_select(self, dim, index):
|
||||
return torch_tensor_index_select(self, dim, index)
|
||||
|
||||
|
||||
def torch_roll(input, shifts, dims=None):
|
||||
return input
|
||||
|
||||
|
||||
def torch_nn_conv2d(self, input):
|
||||
h_in, w_in = input.shape[-2:]
|
||||
shape = None
|
||||
padding = self.padding
|
||||
if padding == "valid":
|
||||
padding = (0, 0)
|
||||
if padding == "same":
|
||||
shape = list(input.shape)
|
||||
if shape is None:
|
||||
shape = list(input.shape)
|
||||
h_out = math.floor(
|
||||
(h_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
|
||||
)
|
||||
w_out = math.floor(
|
||||
(w_in + 2 * padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
|
||||
)
|
||||
shape[-2:] = [h_out, w_out]
|
||||
shape[-3] = self.out_channels
|
||||
return torch.empty(shape, device="meta")
|
||||
|
||||
|
||||
def torch_nn_mseloss(self, input, target):
|
||||
if self.reduction == "none":
|
||||
shape = target.shape
|
||||
@ -317,9 +346,11 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
|
||||
torch.Tensor.mul: torch_tensor_mul_override,
|
||||
torch.matmul: torch_matmul_override,
|
||||
torch.Tensor.repeat: torch_tensor_repeat_override,
|
||||
torch.roll: torch_roll,
|
||||
# TODO: those might not be needed.
|
||||
# torch.index_select: torch_index_select,
|
||||
# torch.Tensor.index_select: torch_tensor_index_select,
|
||||
torch.nn.Conv2d: torch_nn_conv2d,
|
||||
torch.nn.MSELoss: torch_nn_mseloss,
|
||||
torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
|
||||
torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
|
||||
@ -368,6 +399,9 @@ class HFProxy(Proxy):
|
||||
# we peephole optimize to the method invocation
|
||||
return HFAttribute(self, k)
|
||||
|
||||
def __setitem__(self, indices, values):
|
||||
return self.tracer.create_proxy("call_method", "__setitem__", (self, indices, values), {})
|
||||
|
||||
def __contains__(self, key):
|
||||
# To handle cases such as :
|
||||
# `"some_key" in kwargs`
|
||||
@ -521,6 +555,15 @@ class HFTracer(Tracer):
|
||||
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
|
||||
else:
|
||||
raise NotImplementedError(f"{model_class} not supported yet.")
|
||||
elif "pixel_values" in input_name:
|
||||
batch_size = shape[0]
|
||||
image_size = model.config.image_size
|
||||
if not isinstance(image_size, collections.abc.Iterable):
|
||||
image_size = (image_size, image_size)
|
||||
height, width = image_size
|
||||
inputs_dict[input_name] = torch.zeros(
|
||||
batch_size, model.config.num_channels, height, width, dtype=torch.float32, device=device
|
||||
)
|
||||
|
||||
elif "mask" in input_name or "ids" in input_name:
|
||||
inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
|
||||
@ -663,6 +706,11 @@ class HFTracer(Tracer):
|
||||
else:
|
||||
self.graph.erase_node(node)
|
||||
|
||||
# TODO: solves GraphModule creation.
|
||||
# Without this, return type annotation "Tuple" is causing code execution failure.
|
||||
if node.op == "output":
|
||||
node.type = None
|
||||
|
||||
return self.graph
|
||||
|
||||
def _stateless_mod_instanciation_depends_on_proxies(self, mod: nn.Module) -> bool:
|
||||
@ -761,12 +809,4 @@ def symbolic_trace(
|
||||
traced_graph = tracer.trace(model, concrete_args=concrete_args)
|
||||
traced = torch.fx.GraphModule(model, traced_graph)
|
||||
|
||||
# Copy all the original attributes to the traced GraphModule.
|
||||
regular_module_attributes = dir(nn.Module())
|
||||
for name in dir(model):
|
||||
attr = getattr(model, name)
|
||||
if name.startswith("_") or name in regular_module_attributes:
|
||||
continue
|
||||
setattr(traced, name, deepcopy(attr))
|
||||
|
||||
return traced
|
||||
|
@ -175,6 +175,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
fx_compatible = True
|
||||
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
|
@ -155,6 +155,7 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
fx_compatible = True
|
||||
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
|
@ -738,8 +738,7 @@ class ModelTesterMixin:
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
else:
|
||||
input_names = ["input_ids", "attention_mask", "token_type_ids"]
|
||||
input_ids = inputs["input_ids"]
|
||||
input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"]
|
||||
|
||||
labels = inputs.get("labels", None)
|
||||
start_positions = inputs.get("start_positions", None)
|
||||
@ -756,12 +755,6 @@ class ModelTesterMixin:
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
rank = len(input_ids.shape)
|
||||
if rank not in [2, 3]:
|
||||
raise NotImplementedError(
|
||||
f"symbolic_trace automatic parameters inference not implemented for input of rank {rank}."
|
||||
)
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user