ViT and Swin symbolic tracing with torch.fx (#17182)

* Support tracing for ViT

* Swin support

* Fix copies

* Fix type annotation issue

* Removed unused import
Authored by Michael Benayoun on 2022-05-12 10:42:27 +02:00, committed via GitHub
parent 1a688709b3
commit 8c7481f35c
11 changed files with 70 additions and 35 deletions
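
In practice, the change lets a vision model be traced the same way the text models already could. A minimal usage sketch (illustrative, not taken from the diff; it assumes a randomly initialized ViT and the symbolic_trace helper touched further down):

    import torch
    from transformers import ViTConfig, ViTModel
    from transformers.utils.fx import symbolic_trace

    # Illustrative only: trace a randomly initialized ViT and run the traced GraphModule.
    model = ViTModel(ViTConfig())
    model.eval()

    traced = symbolic_trace(model, input_names=["pixel_values"])  # returns a torch.fx.GraphModule

    pixel_values = torch.rand(1, 3, 224, 224)
    with torch.no_grad():
        eager_out = model(pixel_values=pixel_values)
        traced_out = traced(pixel_values=pixel_values)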


@@ -168,7 +168,7 @@ class DeiTSelfAttention(nn.Module):
    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
@@ -200,7 +200,7 @@ class DeiTSelfAttention(nn.Module):
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
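
The view(*shape) -> view(shape) change above repeats across the files below: Tensor.view accepts the shape tuple directly, and not unpacking it keeps the expression traceable, since during fx tracing the computed shape is a proxy object that cannot be star-unpacked. Eager behavior is unchanged, as this illustrative check shows:

    import torch

    x = torch.randn(2, 4, 12)
    new_x_shape = x.size()[:-1] + (3, 4)  # torch.Size([2, 4, 3, 4])
    assert torch.equal(x.view(new_x_shape), x.view(*new_x_shape))  # identical results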


@@ -177,7 +177,7 @@ class DPTViTSelfAttention(nn.Module):
    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
@@ -209,7 +209,7 @@ class DPTViTSelfAttention(nn.Module):
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)


@@ -496,7 +496,7 @@ def window_reverse(windows, window_size, height, width):
    """
    Merges windows to produce higher resolution features.
    """
-    batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
+    batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size))
    windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
    windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
    return windows
@@ -697,7 +697,7 @@ class MaskFormerSwinSelfAttention(nn.Module):
    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
@@ -750,7 +750,7 @@ class MaskFormerSwinSelfAttention(nn.Module):
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
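
Likewise, int(...) becomes math.floor(...) in window_reverse, presumably because int() cannot be applied to a value produced during tracing. For the non-negative ratio computed here, truncation and floor agree, so eager results do not change (illustrative Swin-like sizes):

    import math

    height = width = 56
    window_size = 7
    batch_size = 8
    num_windows = batch_size * (height // window_size) * (width // window_size)  # plays the role of windows.shape[0]
    ratio = num_windows / (height * width / window_size / window_size)           # 8.0
    assert int(ratio) == math.floor(ratio) == batch_size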


@@ -226,7 +226,7 @@ def window_reverse(windows, window_size, height, width):
    """
    Merges windows to produce higher resolution features.
    """
-    batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
+    batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size))
    windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
    windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
    return windows
@@ -435,7 +435,7 @@ class SwinSelfAttention(nn.Module):
    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
@@ -488,7 +488,7 @@ class SwinSelfAttention(nn.Module):
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
@@ -1071,7 +1071,7 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
        # Reshape to (batch_size, num_channels, height, width)
        sequence_output = sequence_output.transpose(1, 2)
        batch_size, num_channels, sequence_length = sequence_output.shape
-        height = width = int(sequence_length**0.5)
+        height = width = math.floor(sequence_length**0.5)
        sequence_output = sequence_output.reshape(batch_size, num_channels, height, width)

        # Reconstruct pixel values


@@ -213,7 +213,7 @@ class ViTSelfAttention(nn.Module):
    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
@@ -245,7 +245,7 @@ class ViTSelfAttention(nn.Module):
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
@@ -687,7 +687,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
        # Reshape to (batch_size, num_channels, height, width)
        sequence_output = sequence_output[:, 1:]
        batch_size, sequence_length, num_channels = sequence_output.shape
-        height = width = int(sequence_length**0.5)
+        height = width = math.floor(sequence_length**0.5)
        sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)

        # Reconstruct pixel values
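
The same pattern appears in the masked-image-modeling heads: the patch grid is recovered from the token count, and because the grid is square the root is exact, so math.floor returns what int did. Illustrative numbers for a ViT-base-style configuration (assumed values):

    import math

    image_size, patch_size = 224, 16                   # assumed ViT-base-style values
    sequence_length = (image_size // patch_size) ** 2  # 196 patch tokens (CLS token already dropped above)
    height = width = math.floor(sequence_length**0.5)
    assert (height, width) == (14, 14)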


@@ -342,7 +342,7 @@ class ViTMAESelfAttention(nn.Module):
    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
@@ -374,7 +374,7 @@ class ViTMAESelfAttention(nn.Module):
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)


@@ -280,7 +280,7 @@ class YolosSelfAttention(nn.Module):
    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
@@ -312,7 +312,7 @@ class YolosSelfAttention(nn.Module):
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)


@@ -14,12 +14,12 @@
# limitations under the License.
import builtins
import collections
import functools
import inspect
import math
import random
import warnings
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
import torch
@@ -31,6 +31,7 @@ from .. import (
    CONFIG_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
+    MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
    MODEL_FOR_MASKED_LM_MAPPING,
    MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
    MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
@@ -71,6 +72,7 @@ def _generate_supported_model_classes(
        "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+        "masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
        "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
    }
@@ -100,6 +102,8 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
    "gpt_neo",
    "t5",
    "roberta",
+    "vit",
+    "swin",
    # TODO: add support for them as it should be quite easy to do so (small blocking issues).
    # "layoutlm",
    # "xlnet",
@@ -276,6 +280,31 @@ def torch_tensor_index_select(self, dim, index):
    return torch_tensor_index_select(self, dim, index)


+def torch_roll(input, shifts, dims=None):
+    return input
+
+
+def torch_nn_conv2d(self, input):
+    h_in, w_in = input.shape[-2:]
+    shape = None
+    padding = self.padding
+    if padding == "valid":
+        padding = (0, 0)
+    if padding == "same":
+        shape = list(input.shape)
+    if shape is None:
+        shape = list(input.shape)
+        h_out = math.floor(
+            (h_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+        )
+        w_out = math.floor(
+            (w_in + 2 * padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
+        )
+        shape[-2:] = [h_out, w_out]
+    shape[-3] = self.out_channels
+    return torch.empty(shape, device="meta")
+
+
def torch_nn_mseloss(self, input, target):
    if self.reduction == "none":
        shape = target.shape
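
A quick eager-mode check of the shape rule that torch_nn_conv2d implements above (illustrative values, not part of the diff):

    import math
    import torch

    conv = torch.nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1, dilation=1)
    x = torch.randn(2, 3, 224, 224)

    h_out = math.floor((224 + 2 * 1 - 1 * (3 - 1) - 1) / 2 + 1)  # 112, same formula as above
    assert conv(x).shape == torch.Size([2, 8, h_out, h_out])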
@@ -317,9 +346,11 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
    torch.Tensor.mul: torch_tensor_mul_override,
    torch.matmul: torch_matmul_override,
    torch.Tensor.repeat: torch_tensor_repeat_override,
+    torch.roll: torch_roll,
    # TODO: those might not be needed.
    # torch.index_select: torch_index_select,
    # torch.Tensor.index_select: torch_tensor_index_select,
+    torch.nn.Conv2d: torch_nn_conv2d,
    torch.nn.MSELoss: torch_nn_mseloss,
    torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
    torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
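
The torch.roll entry can map to an override that returns its input unchanged because rolling only moves elements and never changes the shape, which is all the meta pass needs. An eager sanity check (illustrative sizes):

    import torch

    x = torch.randn(4, 56, 56, 96)  # arbitrary example shape
    assert torch.roll(x, shifts=(3, 3), dims=(1, 2)).shape == x.shape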
@@ -368,6 +399,9 @@ class HFProxy(Proxy):
        # we peephole optimize to the method invocation
        return HFAttribute(self, k)

+    def __setitem__(self, indices, values):
+        return self.tracer.create_proxy("call_method", "__setitem__", (self, indices, values), {})
+
    def __contains__(self, key):
        # To handle cases such as :
        # `"some_key" in kwargs`
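
The new __setitem__ lets in-place item assignment be recorded as a graph node; Swin's shifted-window attention mask is built exactly this way. The eager pattern that now traces cleanly (illustrative sizes):

    import torch

    img_mask = torch.zeros(1, 56, 56, 1)  # (1, height, width, 1), as in Swin's attention-mask construction
    img_mask[:, -7:, :, :] = 1.0          # slice assignment, recorded as a "__setitem__" call_method node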
@@ -521,6 +555,15 @@ class HFTracer(Tracer):
                inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
            else:
                raise NotImplementedError(f"{model_class} not supported yet.")
+        elif "pixel_values" in input_name:
+            batch_size = shape[0]
+            image_size = model.config.image_size
+            if not isinstance(image_size, collections.abc.Iterable):
+                image_size = (image_size, image_size)
+            height, width = image_size
+            inputs_dict[input_name] = torch.zeros(
+                batch_size, model.config.num_channels, height, width, dtype=torch.float32, device=device
+            )
        elif "mask" in input_name or "ids" in input_name:
            inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
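
The pixel_values branch above relies on the config exposing image_size and num_channels; with the stock ViTConfig defaults the tracer therefore feeds a zero-filled dummy tensor of shape (batch_size, 3, 224, 224) (illustrative):

    from transformers import ViTConfig

    config = ViTConfig()  # defaults: image_size=224, num_channels=3
    print(config.image_size, config.num_channels)  # 224 3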
@@ -663,6 +706,11 @@ class HFTracer(Tracer):
            else:
                self.graph.erase_node(node)

+            # TODO: solves GraphModule creation.
+            # Without this, return type annotation "Tuple" is causing code execution failure.
+            if node.op == "output":
+                node.type = None
+
        return self.graph

    def _stateless_mod_instanciation_depends_on_proxies(self, mod: nn.Module) -> bool:
@@ -761,12 +809,4 @@ def symbolic_trace(
    traced_graph = tracer.trace(model, concrete_args=concrete_args)
    traced = torch.fx.GraphModule(model, traced_graph)

-    # Copy all the original attributes to the traced GraphModule.
-    regular_module_attributes = dir(nn.Module())
-    for name in dir(model):
-        attr = getattr(model, name)
-        if name.startswith("_") or name in regular_module_attributes:
-            continue
-        setattr(traced, name, deepcopy(attr))
-
    return traced


@@ -175,6 +175,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
        if is_torch_available()
        else ()
    )
+    fx_compatible = True
    test_pruning = False
    test_resize_embeddings = False


@@ -155,6 +155,7 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
        if is_torch_available()
        else ()
    )
+    fx_compatible = True
    test_pruning = False
    test_resize_embeddings = False


@@ -738,8 +738,7 @@ class ModelTesterMixin:
                    traced_model = symbolic_trace(model, input_names)
                    traced_output = traced_model(**filtered_inputs)
                else:
-                    input_names = ["input_ids", "attention_mask", "token_type_ids"]
-                    input_ids = inputs["input_ids"]
+                    input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"]

                    labels = inputs.get("labels", None)
                    start_positions = inputs.get("start_positions", None)
@@ -756,12 +755,6 @@ class ModelTesterMixin:
                    model_output = model(**filtered_inputs)

-                    rank = len(input_ids.shape)
-                    if rank not in [2, 3]:
-                        raise NotImplementedError(
-                            f"symbolic_trace automatic parameters inference not implemented for input of rank {rank}."
-                        )
-
                    traced_model = symbolic_trace(model, input_names)
                    traced_output = traced_model(**filtered_inputs)