Experimental symbolic tracing feature with torch.fx for BERT, ELECTRA and T5 (#11475)

Symbolic tracing feature for BERT, ELECTRA and T5

Co-authored-by: Michael Benayoun <michael@huggingface.co>
Co-authored-by: Stas Bekman <stas@stason.org>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Michael Benayoun authored on 2021-05-14 20:57:30 +02:00, committed by GitHub
parent 94a2348706
commit 86d5fb0b36
7 changed files with 371 additions and 4 deletions


@@ -265,6 +265,15 @@ def is_torch_cuda_available():
return False
_torch_fx_available = False
if _torch_available:
_torch_fx_available = version.parse(_torch_version) >= version.parse("1.8")
def is_torch_fx_available():
return _torch_fx_available
def is_tf_available():
return _tf_available
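
A minimal usage sketch (not part of the diff): the new flag is meant to gate torch.fx-specific imports at call sites, exactly as the test file later in this commit does.

from transformers.file_utils import is_torch_fx_available

if is_torch_fx_available():
    # torch.fx only exists for torch >= 1.8, so guard the import behind the new flag.
    from transformers.modeling_fx_utils import symbolic_trace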
@@ -1597,11 +1606,21 @@ def tf_required(func):
return wrapper
def is_torch_fx_proxy(x):
if is_torch_fx_available():
import torch.fx
return isinstance(x, torch.fx.Proxy)
return False
def is_tensor(x):
"""
Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor`, :obj:`jaxlib.xla_extension.DeviceArray` or
:obj:`np.ndarray`.
"""
if is_torch_fx_proxy(x):
return True
if is_torch_available():
import torch
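
For illustration, a hedged sketch of what the new ``is_torch_fx_proxy`` check buys: during symbolic tracing, model inputs are ``torch.fx.Proxy`` objects, and ``is_tensor`` now accepts them. ``TinyModule`` is a made-up toy module, not something from this commit.

import torch
from torch.fx import Tracer
from transformers.file_utils import is_tensor, is_torch_fx_proxy

class TinyModule(torch.nn.Module):  # hypothetical module, only for illustration
    def forward(self, x):
        # While Tracer.trace runs, ``x`` is a torch.fx.Proxy rather than a real tensor,
        # yet is_tensor(x) returns True thanks to the new is_torch_fx_proxy check.
        print(is_torch_fx_proxy(x), is_tensor(x))  # True True during tracing
        return x + 1

Tracer().trace(TinyModule())
print(is_torch_fx_proxy(torch.ones(1)))  # False on a concrete tensor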


@@ -0,0 +1,253 @@
import dis
import inspect
from typing import List, Optional, Union
import torch
from torch.fx import GraphModule, Node, Proxy, Tracer
from . import PreTrainedModel
class HFProxy(Proxy):
"""
Proxy that is able to provide the proper ranks, shapes and boolean values during symbolic tracing by implementing
the dim, size and __bool__ methods. It can be easily extended by either adding new methods or extending the
existing ones.
"""
def __init__(self, node: Node, tracer: Optional[Tracer] = None):
super().__init__(node, tracer=tracer)
if hasattr(self, "tracer") and self.tracer is not None:
self.device = self.tracer.root.device
self.dtype = next(self.tracer.root.parameters()).dtype
def dim(self):
return len(self.tracer.encoder_shape)
def _shape(self, calling_frame):
module = calling_frame.f_locals.get("self", None)
is_decoder = hasattr(module, "is_decoder") and module.is_decoder
return list(self.tracer.decoder_shape) if is_decoder else list(self.tracer.encoder_shape)
def size(self, dim=None):
frame = inspect.currentframe()
calling_frame = frame.f_back
# self.size can be called through the shape property, in which case we need to get the outer
# frame, containing the meaningful information.
if calling_frame.f_code.co_name == "shape":
calling_frame = calling_frame.f_back
instructions = list(reversed(list(dis.get_instructions(calling_frame.f_code))[: calling_frame.f_lasti]))
code_context = inspect.getframeinfo(calling_frame).code_context[0].strip()
shape = self._shape(calling_frame)
if calling_frame.f_code.co_name == "transpose_for_scores":
# Provides the proper "x.size()" for:
# new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
shape = shape + [-1]
elif "context_layer" in calling_frame.f_locals:
# Provides the proper "context_layer.size()" for:
# new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
shape = shape + [-1, -1]
elif calling_frame.f_locals.get("do_cross_attention", False):
# Provides the proper shape for:
# query_length = present_key_value_state[0].shape[2]
# (modeling_t5.py)
shape = list(self.tracer.encoder_shape)
shape = shape[:1] + [-1] + shape[1:2]
elif "key_length" in code_context or "encoder_seq_length" in code_context:
shape = list(self.tracer.encoder_shape)
elif "lm_logits.size(-1)" in code_context:
shape = [self.tracer.root.config.vocab_size]
elif "start_positions" in code_context or "end_positions" in code_context:
# For question answering tasks.
shape = [1]
elif "num_choices" in code_context:
if self.tracer.num_choices <= 0:
raise ValueError("num_choices must be given to the HFTracer for MultipleChoice tasks.")
shape = shape[:1] + [self.tracer.num_choices] + shape[1:]
else:
# Default case:
# - If self.size is called for an unpacking, retrieves the corresponding unpacking
# instruction, and returns the shape padded as much as necessary to match the expected
# number of items.
# - If self.size is called outside of an unpacking context, simply return the shape.
is_unpack = False
for inst in instructions:
if inst.opname == "UNPACK_SEQUENCE":
is_unpack = True
break
if is_unpack and inst.argval >= 3:
shape += [self.tracer.root.config.hidden_size]
dummy_values = [1] * (inst.argval - 3)
shape += dummy_values
if dim is not None:
return shape[dim]
return tuple(shape)
@property
def shape(self):
return self.size()
def __bool__(self) -> bool:
frame = inspect.currentframe()
calling_frame = frame.f_back
code_context = inspect.getframeinfo(calling_frame).code_context[0].strip()
if calling_frame.f_code.co_name == "apply_chunking_to_forward":
# Returning True to every assertion in "apply_chunking_to_forward"
return True
elif "assert" in code_context:
# Returning True to any assertion.
return True
elif calling_frame.f_code.co_name == "get_extended_attention_mask":
# Corresponding to:
# if causal_mask.shape[1] < attention_mask.shape[1]:
return calling_frame.f_back.f_locals["past_key_values"][0] is not None
raise NotImplementedError("__bool__ was called for HFProxy, but this case is not covered yet.")
def __setitem__(self, key, value):
pass
def __contains__(self, key):
return False
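
To see why ``size`` has to return concrete values, here is a simplified, paraphrased version of the BERT shape arithmetic referenced in the comments above (an illustrative standalone function, not code from this commit): with a stock ``torch.fx.Proxy``, ``x.size()`` would itself be a proxy and the tuple slicing/concatenation below would fail, while ``HFProxy.size()`` hands back a plain tuple such as ``(batch_size, sequence_length, -1)``.

import torch

def transpose_for_scores(x, num_attention_heads, attention_head_size):
    # Paraphrased from modeling_bert.py: the new shape is built out of x.size(),
    # which only works during tracing because HFProxy.size() returns a real tuple.
    new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
    x = x.view(*new_x_shape)
    return x.permute(0, 2, 1, 3)

print(transpose_for_scores(torch.randn(2, 128, 768), 12, 64).shape)  # torch.Size([2, 12, 128, 64])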
class HFTracer(Tracer):
"""
Tracer that is able to symbolically trace models from the library (currently BERT, ELECTRA and T5). To do that, it
uses the HFProxy instead of the regular PyTorch torch.fx.Proxy.
"""
def __init__(self, batch_size=1, sequence_length=[128, 128], num_choices=-1):
super().__init__()
encoder_sequence_length = sequence_length[0] if isinstance(sequence_length, (list, tuple)) else sequence_length
decoder_sequence_length = sequence_length[1] if isinstance(sequence_length, (list, tuple)) else -1
self.encoder_shape = [batch_size, encoder_sequence_length]
self.decoder_shape = (
[batch_size, decoder_sequence_length] if decoder_sequence_length > 0 else list(self.encoder_shape)
)
self.num_choices = num_choices
if self.num_choices > 0:
self.encoder_shape[0] *= self.num_choices
self.prev_module = None
def proxy(self, node: Node):
return HFProxy(node, self)
def _insert_module_as_submodule(self, mod):
"""
Helper method which tries to insert a module that was not declared as a submodule.
"""
# First, retrieve the parent module.
if self.prev_module is None:
return None
parent_path = self.prev_module.rsplit(".", 1)[0]
parent_mod = None
for path, module in self.root.named_modules():
if path == parent_path:
parent_mod = module
break
if parent_mod is None:
return None
# If retrieving the parent module was possible, set the module not declared as a submodule
# as a parent module attribute.
path = None
for var_name, var_val in inspect.currentframe().f_back.f_locals.items():
if mod is var_val:
setattr(parent_mod, var_name, mod)
path = f"{parent_path}.{var_name}"
break
return path
def path_of_module(self, mod: torch.nn.Module) -> str:
"""
Helper method to find the qualified name of ``mod`` in the Module hierarchy of ``root``. For example, if
``root`` has a submodule named ``foo``, which has a submodule named ``bar``, passing ``bar`` into this function
will return the string "foo.bar".
Args:
mod (:obj:`torch.nn.Module`): The ``Module`` to retrieve the qualified name for.
"""
# Prefer the O(1) algorithm
if hasattr(self, "submodule_paths") and self.submodule_paths:
path = self.submodule_paths.get(mod)
if path is None:
path = self._insert_module_as_submodule(mod)
if path is None:
raise NameError("module is not installed as a submodule")
self.prev_module = path
return path
# O(N^2) fallback in the case that we didn't store the submodule
# paths.
else:
for n, p in self.root.named_modules():
if mod is p:
self.prev_module = n
return n
path = self._insert_module_as_submodule(mod)
if path is None:
raise NameError("module is not installed as a submodule")
self.prev_module = path
return path
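
For orientation, a small sketch of the shape bookkeeping done in ``HFTracer.__init__`` above (values are illustrative):

from transformers.modeling_fx_utils import HFTracer

# Encoder and decoder sequence lengths can differ for sequence-to-sequence models.
tracer = HFTracer(batch_size=2, sequence_length=[128, 64])
print(tracer.encoder_shape, tracer.decoder_shape)  # [2, 128] [2, 64]

# For multiple choice tasks the effective encoder batch is batch_size * num_choices.
mc_tracer = HFTracer(batch_size=2, sequence_length=128, num_choices=4)
print(mc_tracer.encoder_shape)  # [8, 128]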
def symbolic_trace(
model: PreTrainedModel,
input_names: Optional[List[str]] = None,
batch_size: int = 1,
sequence_length: Union[int, List[int]] = [128, 128],
num_choices: int = -1,
) -> GraphModule:
"""
Performs symbolic tracing on the model.
Args:
model (:obj:`PreTrainedModel`):
The model to trace.
input_names (:obj:`List[str]`, `optional`):
The names of the inputs of the traced model. If unset, ``model.dummy_inputs.keys()`` are used instead.
batch_size (:obj:`int`, `optional`, defaults to 1):
The batch size of the traced model inputs.
sequence_length (:obj:`int` or :obj:`List[int]`, `optional`, defaults to :obj:`[128, 128]`):
The sequence length of the traced model inputs. For sequence-to-sequence models with different sequence
lengths between the encoder and the decoder inputs, this must be :obj:`[encoder_sequence_length,
decoder_sequence_length]`.
num_choices (:obj:`int`, `optional`, defaults to -1):
The number of possible choices for a multiple choice task.
Returns:
:obj:`torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.
Example::
from transformers.modeling_fx_utils import symbolic_trace
traced_model = symbolic_trace(
model,
input_names=["input_ids", "attention_mask", "token_type_ids"],
batch_size=1,
sequence_length=128,
)
"""
if input_names is None:
input_names = model.dummy_inputs.keys()
sig = inspect.signature(model.forward)
# TODO: decide how to handle the "return_dict" parameter.
concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
tracer = HFTracer(batch_size=batch_size, sequence_length=sequence_length, num_choices=num_choices)
traced_graph = tracer.trace(model, concrete_args=concrete_args)
traced = torch.fx.GraphModule(model, traced_graph)
return traced
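
As a follow-up to the docstring example, a hedged sketch of what can be done with the returned GraphModule; everything used on ``traced`` below is standard torch.fx machinery, not an API added by this commit.

from transformers import BertModel
from transformers.modeling_fx_utils import symbolic_trace

model = BertModel.from_pretrained("bert-base-uncased")
traced = symbolic_trace(
    model,
    input_names=["input_ids", "attention_mask", "token_type_ids"],
    batch_size=1,
    sequence_length=128,
)

print(traced.code)   # the Python generated for the recorded forward pass
traced.graph.lint()  # sanity-check that the captured graph is well formed
for node in traced.graph.nodes:
    print(node.op, node.target)  # walk the recorded operations one by one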


@@ -32,6 +32,7 @@ from ...file_utils import (
DUMMY_MASK,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_fx_proxy,
replace_return_docstrings,
)
from ...modeling_outputs import (
@@ -776,9 +777,14 @@ class T5PreTrainedModel(PreTrainedModel):
), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"
# shift inputs to the right
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = decoder_start_token_id
if is_torch_fx_proxy(input_ids):
# Item assignment is not supported natively for proxies.
shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
else:
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = decoder_start_token_id
assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
# replace possible -100 values in labels by `pad_token_id`
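
A small worked check of the proxy-friendly branch above (illustrative values, not part of the diff): concatenating a column of ``decoder_start_token_id`` with ``input_ids[..., :-1]`` produces the same right-shift as the in-place assignments of the else branch.

import torch

input_ids = torch.tensor([[10, 11, 12, 13]])
decoder_start_token_id = 0

# Proxy-friendly version: build the shifted tensor with cat instead of item assignment.
start = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
shifted_cat = torch.cat([start, input_ids[..., :-1]], dim=-1)

# Original version with item assignment.
shifted_assign = input_ids.new_zeros(input_ids.shape)
shifted_assign[..., 1:] = input_ids[..., :-1].clone()
shifted_assign[..., 0] = decoder_start_token_id

print(shifted_cat)                               # tensor([[ 0, 10, 11, 12]])
print(torch.equal(shifted_cat, shifted_assign))  # True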


@@ -439,6 +439,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
else ()
)
all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
fx_ready_model_classes = all_model_classes
test_sequence_classification_problem_types = True
# special case for ForPreTraining model


@@ -25,7 +25,7 @@ from typing import List, Tuple
from huggingface_hub import HfApi
from requests.exceptions import HTTPError
from transformers import is_torch_available, logging
from transformers.file_utils import WEIGHTS_NAME
from transformers.file_utils import WEIGHTS_NAME, is_torch_fx_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
ENDPOINT_STAGING,
@@ -64,6 +64,9 @@ if is_torch_available():
T5ForConditionalGeneration,
)
if is_torch_fx_available():
from transformers.modeling_fx_utils import symbolic_trace
def _config_zero_init(config):
configs_no_init = copy.deepcopy(config)
@@ -82,6 +85,7 @@ class ModelTesterMixin:
model_tester = None
all_model_classes = ()
all_generative_model_classes = ()
fx_ready_model_classes = ()
test_torchscript = True
test_pruning = True
test_resize_embeddings = True
@@ -565,6 +569,88 @@ class ModelTesterMixin:
self.assertTrue(models_equal)
def test_torch_fx(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
self._create_and_check_torch_fx_tracing(config, inputs_dict)
def test_torch_fx_output_loss(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True)
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
if not is_torch_fx_available():
return
configs_no_init = _config_zero_init(config) # To be sure we have no NaN
configs_no_init.return_dict = False
for model_class in self.fx_ready_model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
try:
if model.config.is_encoder_decoder:
model.config.use_cache = False # FSMT still requires this hack -> FSMT should probably be refactored similarly to BART afterward
input_ids = inputs["input_ids"]
decoder_attention_mask = inputs["decoder_attention_mask"]
labels = inputs.get("labels", None)
input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
if labels is not None:
input_names.append("labels")
prepared_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
model_output = model(**prepared_inputs)
batch_size = input_ids.shape[0]
encoder_sequence_length = input_ids.shape[1]
decoder_sequence_length = decoder_attention_mask.shape[1]
traced_model = symbolic_trace(
model,
input_names,
batch_size=batch_size,
sequence_length=[encoder_sequence_length, decoder_sequence_length],
)
traced_output = traced_model(**prepared_inputs)
else:
input_ids = inputs["input_ids"]
labels = inputs.get("labels", None)
input_names = ["input_ids", "attention_mask", "token_type_ids"]
if labels is not None:
input_names.append("labels")
prepared_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
model_output = model(**prepared_inputs)
batch_size = input_ids.shape[0]
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
sequence_length = input_ids.shape[2]
num_choices = input_ids.shape[1]
else:
sequence_length = input_ids.shape[1]
num_choices = -1
traced_model = symbolic_trace(
model,
input_names,
batch_size=batch_size,
sequence_length=sequence_length,
num_choices=num_choices,
)
traced_output = traced_model(**prepared_inputs)
except RuntimeError:
self.fail("Couldn't trace module.")
num_outputs = len(model_output)
outputs_are_close = all(torch.allclose(model_output[i], traced_output[i]) for i in range(num_outputs))
self.assertTrue(outputs_are_close)
def test_headmasking(self):
if not self.test_head_masking:
return


@@ -287,6 +287,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
fx_ready_model_classes = all_model_classes
test_sequence_classification_problem_types = True
# special case for ForPreTraining model


@@ -488,6 +488,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
fx_ready_model_classes = all_model_classes
all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
test_pruning = False
test_torchscript = True