Experimental symbolic tracing feature with torch.fx for BERT, ELECTRA and T5 (#11475)
Symbolic tracing feature for BERT, ELECTRA and T5

Co-authored-by: Michael Benayoun <michael@huggingface.co>
Co-authored-by: Stas Bekman <stas@stason.org>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 94a2348706
commit 86d5fb0b36
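At a high level, this commit lets supported models be captured as a torch.fx GraphModule via a new symbolic_trace helper. A minimal usage sketch based on the docstring and tests below (the checkpoint name is illustrative, not part of the commit):

from transformers import BertModel
from transformers.modeling_fx_utils import symbolic_trace

# "bert-base-uncased" is just an example checkpoint; any supported BERT, ELECTRA or T5 model works the same way.
model = BertModel.from_pretrained("bert-base-uncased")
traced = symbolic_trace(
    model,
    input_names=["input_ids", "attention_mask", "token_type_ids"],
    batch_size=1,
    sequence_length=128,
)
# `traced` is a torch.fx.GraphModule; the tests added below call it with the same
# keyword arguments as the original model and compare the outputs.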
src/transformers/file_utils.py
@@ -265,6 +265,15 @@ def is_torch_cuda_available():
        return False


_torch_fx_available = False
if _torch_available:
    _torch_fx_available = version.parse(_torch_version) >= version.parse("1.8")


def is_torch_fx_available():
    return _torch_fx_available


def is_tf_available():
    return _tf_available
@@ -1597,11 +1606,21 @@ def tf_required(func):
    return wrapper


def is_torch_fx_proxy(x):
    if is_torch_fx_available():
        import torch.fx

        return isinstance(x, torch.fx.Proxy)
    return False


def is_tensor(x):
    """
    Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor`, :obj:`jaxlib.xla_extension.DeviceArray` or
    :obj:`np.ndarray`.
    """
    if is_torch_fx_proxy(x):
        return True
    if is_torch_available():
        import torch
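Why is_tensor needs the new proxy check: while torch.fx traces a model, the values flowing through forward are torch.fx.Proxy objects rather than real tensors. A toy sketch (plain torch.fx, not from this commit) showing that behaviour:

import torch
import torch.fx


class Toy(torch.nn.Module):
    def forward(self, x):
        # During torch.fx tracing this prints <class 'torch.fx.proxy.Proxy'>,
        # because x is a Proxy recording operations, not a torch.Tensor.
        print(type(x))
        return x + 1


traced = torch.fx.symbolic_trace(Toy())   # forward runs once with Proxy inputs
print(traced(torch.ones(2)))               # tensor([2., 2.]) on real tensors

Utility code such as is_tensor() therefore has to treat proxies like tensors, which is what is_torch_fx_proxy() enables.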
src/transformers/modeling_fx_utils.py (new file, 253 lines)
@@ -0,0 +1,253 @@
import dis
import inspect
from typing import List, Optional, Union

import torch
from torch.fx import GraphModule, Node, Proxy, Tracer

from . import PreTrainedModel

class HFProxy(Proxy):
    """
    Proxy that is able to provide the proper ranks, shapes and boolean values during symbolic tracing by implementing
    the dim, size and __bool__ methods. It can be easily extended by either adding new methods or extending the
    existing ones.
    """

    def __init__(self, node: Node, tracer: Optional[Tracer] = None):
        super().__init__(node, tracer=tracer)
        if hasattr(self, "tracer") and self.tracer is not None:
            self.device = self.tracer.root.device
            self.dtype = next(self.tracer.root.parameters()).dtype

    def dim(self):
        return len(self.tracer.encoder_shape)

    def _shape(self, calling_frame):
        module = calling_frame.f_locals.get("self", None)
        is_decoder = hasattr(module, "is_decoder") and module.is_decoder
        return list(self.tracer.decoder_shape) if is_decoder else list(self.tracer.encoder_shape)

    def size(self, dim=None):
        frame = inspect.currentframe()
        calling_frame = frame.f_back

        # self.size can be called through the shape property, in which case we need to get the outer
        # frame, containing the meaningful information.
        if calling_frame.f_code.co_name == "shape":
            calling_frame = calling_frame.f_back

        instructions = list(reversed(list(dis.get_instructions(calling_frame.f_code))[: calling_frame.f_lasti]))
        code_context = inspect.getframeinfo(calling_frame).code_context[0].strip()

        shape = self._shape(calling_frame)

        if calling_frame.f_code.co_name == "transpose_for_scores":
            # Provides the proper "x.size()" for:
            # new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
            shape = shape + [-1]
        elif "context_layer" in calling_frame.f_locals:
            # Provides the proper "context_layer.size()" for:
            # new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
            shape = shape + [-1, -1]
        elif calling_frame.f_locals.get("do_cross_attention", False):
            # Provides the proper shape for:
            # query_length = present_key_value_state[0].shape[2]
            # (modeling_t5.py)
            shape = list(self.tracer.encoder_shape)
            shape = shape[:1] + [-1] + shape[1:2]
        elif "key_length" in code_context or "encoder_seq_length" in code_context:
            shape = list(self.tracer.encoder_shape)
        elif "lm_logits.size(-1)" in code_context:
            shape = [self.tracer.root.config.vocab_size]
        elif "start_positions" in code_context or "end_positions" in code_context:
            # For question answering tasks.
            shape = [1]
        elif "num_choices" in code_context:
            if self.tracer.num_choices <= 0:
                raise ValueError("num_choices must be given to the CustomTracer for MultipleChoice tasks.")
            shape = shape[:1] + [self.tracer.num_choices] + shape[1:]
        else:
            # Default case:
            #   - If self.size is called for an unpacking, retrieves the corresponding unpacking
            #     instruction, and returns the shape padded as much as necessary to match the expected
            #     number of items.
            #   - If self.size is called outside of an unpacking context, simply returns the shape.
            is_unpack = False

            for inst in instructions:
                if inst.opname == "UNPACK_SEQUENCE":
                    is_unpack = True
                    break

            if is_unpack and inst.argval >= 3:
                shape += [self.tracer.root.config.hidden_size]
                dummy_values = [1] * (inst.argval - 3)
                shape += dummy_values

        if dim is not None:
            return shape[dim]

        return tuple(shape)

    @property
    def shape(self):
        return self.size()

    def __bool__(self) -> bool:
        frame = inspect.currentframe()
        calling_frame = frame.f_back
        code_context = inspect.getframeinfo(calling_frame).code_context[0].strip()
        if calling_frame.f_code.co_name == "apply_chunking_to_forward":
            # Returning True to every assertion in "apply_chunking_to_forward"
            return True
        elif "assert" in code_context:
            # Returning True to any assertion.
            return True
        elif calling_frame.f_code.co_name == "get_extended_attention_mask":
            # Corresponding to:
            # if causal_mask.shape[1] < attention_mask.shape[1]:
            return calling_frame.f_back.f_locals["past_key_values"][0] is not None
        raise NotImplementedError("__bool__ was called for CustomProxy, but this case is not covered yet.")

    def __setitem__(self, key, value):
        pass

    def __contains__(self, key):
        return False

class HFTracer(Tracer):
    """
    Tracer that is able to symbolically trace models from the library (currently BERT, ELECTRA and T5). To do that, it
    uses the HFProxy instead of the regular PyTorch torch.fx.Proxy.
    """

    def __init__(self, batch_size=1, sequence_length=[128, 128], num_choices=-1):
        super().__init__()
        encoder_sequence_length = sequence_length[0] if isinstance(sequence_length, (list, tuple)) else sequence_length
        decoder_sequence_length = sequence_length[1] if isinstance(sequence_length, (list, tuple)) else -1
        self.encoder_shape = [batch_size, encoder_sequence_length]
        self.decoder_shape = (
            [batch_size, decoder_sequence_length] if decoder_sequence_length > 0 else list(self.encoder_shape)
        )
        self.num_choices = num_choices
        if self.num_choices > 0:
            self.encoder_shape[0] *= self.num_choices

        self.prev_module = None

    def proxy(self, node: Node):
        return HFProxy(node, self)

    def _insert_module_as_submodule(self, mod):
        """
        Helper method which tries to insert a module that was not declared as submodule.
        """
        # First, retrieve the parent module.
        if self.prev_module is None:
            return None
        parent_path = self.prev_module.rsplit(".", 1)[0]
        parent_mod = None
        for path, module in self.root.named_modules():
            if path == parent_path:
                parent_mod = module
                break
        if parent_mod is None:
            return None

        # If retrieving the parent module was possible, set the module not declared as a submodule
        # as a parent module attribute.
        path = None
        for var_name, var_val in inspect.currentframe().f_back.f_locals.items():
            if mod is var_val:
                setattr(parent_mod, var_name, mod)
                path = f"{parent_path}.{var_name}"
                break

        return path

    def path_of_module(self, mod: torch.nn.Module) -> str:
        """
        Helper method to find the qualified name of ``mod`` in the Module hierarchy of ``root``. For example, if
        ``root`` has a submodule named ``foo``, which has a submodule named ``bar``, passing ``bar`` into this function
        will return the string "foo.bar".

        Args:
            mod (str): The ``Module`` to retrieve the qualified name for.
        """
        # Prefer the O(1) algorithm
        if hasattr(self, "submodule_paths") and self.submodule_paths:
            path = self.submodule_paths.get(mod)
            if path is None:
                path = self._insert_module_as_submodule(mod)
            if path is None:
                raise NameError("module is not installed as a submodule")
            self.prev_module = path
            return path

        # O(N^2) fallback in the case that we didn't store the submodule
        # paths.
        else:
            for n, p in self.root.named_modules():
                if mod is p:
                    self.prev_module = n
                    return n
            path = self._insert_module_as_submodule(mod)
            if path is None:
                raise NameError("module is not installed as a submodule")
            self.prev_module = path
            return path

def symbolic_trace(
    model: PreTrainedModel,
    input_names: Optional[List[str]] = None,
    batch_size: int = 1,
    sequence_length: Union[int, List[int]] = [128, 128],
    num_choices: int = -1,
) -> GraphModule:

    """
    Performs symbolic tracing on the model.

    Args:
        model (:obj:`PretrainedModel`):
            The model to trace.
        input_names (:obj:`List[str]`, `optional`):
            The names of the inputs of the traced model. If unset, model.dummy_inputs().keys() are used instead.
        batch_size (:obj:`int`, `optional`, defaults to 1):
            The batch size of the traced model inputs.
        sequence_length (:obj:`int` or :obj:`List[int]`):
            The sequence length of the traced model inputs. For sequence-to-sequence models with different sequence
            lengths between the encoder and the decoder inputs, this must be :obj:`[encoder_sequence_length,
            decoder_sequence_length]`.
        num_choices (:obj:`int`, `optional`, defaults to -1):
            The number of possible choices for a multiple choice task.

    Returns:
        :obj:`torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.

    Example::

        from transformers.modeling_fx_utils import symbolic_trace
        traced_model = symbolic_trace(
            model,
            input_names=["input_ids", "attention_mask", "token_type_ids"],
            batch_size=1,
            sequence_length=128,
        )
    """
    if input_names is None:
        input_names = model.dummy_inputs.keys()

    sig = inspect.signature(model.forward)
    # TODO: how to handle the case of the "return_dict" parameter.
    concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}

    tracer = HFTracer(batch_size=batch_size, sequence_length=sequence_length, num_choices=num_choices)
    traced_graph = tracer.trace(model, concrete_args=concrete_args)
    traced = torch.fx.GraphModule(model, traced_graph)

    return traced
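Beyond the Example in the docstring above, the returned object is an ordinary torch.fx.GraphModule, so it can be inspected as well as executed. A short sketch, assuming `traced_model` was produced as in that example:

print(traced_model.code)    # the generated Python source of the recorded forward pass
print(traced_model.graph)   # the underlying torch.fx graph (placeholder / call_module / call_function nodes)
# It is still a torch.nn.Module, so it can be called with the traced input names:
# outputs = traced_model(input_ids=..., attention_mask=..., token_type_ids=...)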
src/transformers/models/t5/modeling_t5.py
@@ -32,6 +32,7 @@ from ...file_utils import (
    DUMMY_MASK,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_torch_fx_proxy,
    replace_return_docstrings,
)
from ...modeling_outputs import (
@@ -776,9 +777,14 @@ class T5PreTrainedModel(PreTrainedModel):
        ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"

        # shift inputs to the right
-       shifted_input_ids = input_ids.new_zeros(input_ids.shape)
-       shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
-       shifted_input_ids[..., 0] = decoder_start_token_id
+       if is_torch_fx_proxy(input_ids):
+           # Item assignment is not supported natively for proxies.
+           shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
+           shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
+       else:
+           shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+           shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+           shifted_input_ids[..., 0] = decoder_start_token_id

        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
        # replace possible -100 values in labels by `pad_token_id`
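For intuition, a small self-contained check (not part of the commit) that the proxy-friendly torch.full/torch.cat formulation produces the same result as the original in-place version on a concrete tensor; the token ids are made up:

import torch

input_ids = torch.tensor([[5, 6, 7, 8]])
decoder_start_token_id = 0

# In-place version: fine on real tensors, but item assignment is not traceable on fx proxies.
shifted_a = input_ids.new_zeros(input_ids.shape)
shifted_a[..., 1:] = input_ids[..., :-1].clone()
shifted_a[..., 0] = decoder_start_token_id

# Functional version used while tracing: build the leading column, then concatenate.
shifted_b = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
shifted_b = torch.cat([shifted_b, input_ids[..., :-1]], dim=-1)

assert shifted_a.tolist() == shifted_b.tolist()  # both give [[0, 5, 6, 7]]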
@@ -439,6 +439,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
        else ()
    )
    all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
    fx_ready_model_classes = all_model_classes
    test_sequence_classification_problem_types = True

    # special case for ForPreTraining model
@@ -25,7 +25,7 @@ from typing import List, Tuple
from huggingface_hub import HfApi
from requests.exceptions import HTTPError
from transformers import is_torch_available, logging
-from transformers.file_utils import WEIGHTS_NAME
+from transformers.file_utils import WEIGHTS_NAME, is_torch_fx_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
    ENDPOINT_STAGING,
@@ -64,6 +64,9 @@ if is_torch_available():
        T5ForConditionalGeneration,
    )

    if is_torch_fx_available():
        from transformers.modeling_fx_utils import symbolic_trace


def _config_zero_init(config):
    configs_no_init = copy.deepcopy(config)
@@ -82,6 +85,7 @@ class ModelTesterMixin:
    model_tester = None
    all_model_classes = ()
    all_generative_model_classes = ()
    fx_ready_model_classes = ()
    test_torchscript = True
    test_pruning = True
    test_resize_embeddings = True
@@ -565,6 +569,88 @@ class ModelTesterMixin:

        self.assertTrue(models_equal)

    def test_torch_fx(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        self._create_and_check_torch_fx_tracing(config, inputs_dict)

    def test_torch_fx_output_loss(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True)

    def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
        if not is_torch_fx_available():
            return

        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        configs_no_init.return_dict = False

        for model_class in self.fx_ready_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()
            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)

            try:
                if model.config.is_encoder_decoder:
                    model.config.use_cache = False  # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
                    input_ids = inputs["input_ids"]
                    decoder_attention_mask = inputs["decoder_attention_mask"]
                    labels = inputs.get("labels", None)
                    input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
                    if labels is not None:
                        input_names.append("labels")
                    prepared_inputs = {k: v for (k, v) in inputs.items() if k in input_names}

                    model_output = model(**prepared_inputs)

                    batch_size = input_ids.shape[0]
                    encoder_sequence_length = input_ids.shape[1]
                    decoder_sequence_length = decoder_attention_mask.shape[1]

                    traced_model = symbolic_trace(
                        model,
                        input_names,
                        batch_size=batch_size,
                        sequence_length=[encoder_sequence_length, decoder_sequence_length],
                    )

                    traced_output = traced_model(**prepared_inputs)

                else:
                    input_ids = inputs["input_ids"]
                    labels = inputs.get("labels", None)
                    input_names = ["input_ids", "attention_mask", "token_type_ids"]
                    if labels is not None:
                        input_names.append("labels")
                    prepared_inputs = {k: v for (k, v) in inputs.items() if k in input_names}

                    model_output = model(**prepared_inputs)

                    batch_size = input_ids.shape[0]

                    if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
                        sequence_length = input_ids.shape[2]
                        num_choices = input_ids.shape[1]
                    else:
                        sequence_length = input_ids.shape[1]
                        num_choices = -1

                    traced_model = symbolic_trace(
                        model,
                        input_names,
                        batch_size=batch_size,
                        sequence_length=sequence_length,
                        num_choices=num_choices,
                    )
                    traced_output = traced_model(**prepared_inputs)

            except RuntimeError:
                self.fail("Couldn't trace module.")

            num_outputs = len(model_output)
            outputs_are_close = all(torch.allclose(model_output[i], traced_output[i]) for i in range(num_outputs))
            self.assertTrue(outputs_are_close)

    def test_headmasking(self):
        if not self.test_head_masking:
            return
@@ -287,6 +287,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
        if is_torch_available()
        else ()
    )
    fx_ready_model_classes = all_model_classes
    test_sequence_classification_problem_types = True

    # special case for ForPreTraining model
@@ -488,6 +488,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):

    all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
    all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
    fx_ready_model_classes = all_model_classes
    all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
    test_pruning = False
    test_torchscript = True