Initial support for symbolic tracing with torch.fx allowing dynamic axes (#13579)
* Symbolic trace dynamic axes support for BERT like models (albert, bert, distilbert, mobilebert, electra, megatron-bert)
* Sanity checks before tracing that make sure the model to trace is supported
* Adapted to PyTorch 1.9

Co-authored-by: Michael Benayoun <michael@huggingface.co>
parent 46efc58024
commit d4e4efce68
@@ -280,7 +280,7 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HUGGINGFACE_CO_RESOLVE_ENDPOIN
 HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"

 # This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
-TORCH_FX_REQUIRED_VERSION = version.parse("1.8")
+TORCH_FX_REQUIRED_VERSION = version.parse("1.9")
 TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION = version.parse("1.8")

 _is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False
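Note: TORCH_FX_REQUIRED_VERSION is the constant the fx availability helper compares the installed torch against. A minimal sketch of such a version gate, assuming a major.minor comparison (the real check lives in is_torch_fx_available and may differ in detail):

    import torch
    from packaging import version

    TORCH_FX_REQUIRED_VERSION = version.parse("1.9")

    def is_torch_fx_available() -> bool:
        # Hedged sketch: compare only major.minor of the installed torch.
        installed = version.parse(version.parse(torch.__version__).base_version)
        return (installed.major, installed.minor) == (TORCH_FX_REQUIRED_VERSION.major, TORCH_FX_REQUIRED_VERSION.minor)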
@@ -204,7 +204,7 @@ class MultiHeadSelfAttention(nn.Module):
         q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
         scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
         mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
-        scores.masked_fill_(mask, -float("inf"))  # (bs, n_heads, q_length, k_length)
+        scores = scores.masked_fill(mask, -float("inf"))  # (bs, n_heads, q_length, k_length)

         weights = nn.Softmax(dim=-1)(scores)  # (bs, n_heads, q_length, k_length)
         weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)
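The switch to out-of-place masked_fill keeps the masking step friendly to the torch.fx proxies and graph rewrites this PR introduces; numerically the two forms are identical, as a quick standalone check shows:

    import torch

    scores = torch.randn(2, 4, 5, 5)
    mask = torch.zeros(2, 1, 5, 5, dtype=torch.bool)
    mask[..., 0] = True  # mask out the first key position

    filled_inplace = scores.clone()
    filled_inplace.masked_fill_(mask, -float("inf"))
    filled = scores.masked_fill(mask, -float("inf"))
    assert torch.equal(filled_inplace, filled)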
@@ -1,7 +1,8 @@
 import copy
 import functools
 import inspect
-from typing import Any, Dict, List, Optional, Union
+import random
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

 import torch
 from packaging import version
@@ -9,9 +10,8 @@ from torch import nn
 from torch.fx import Graph, GraphModule, Node, Proxy, Tracer
 from torch.fx.node import Argument

-from transformers.file_utils import TORCH_FX_REQUIRED_VERSION, importlib_metadata, is_torch_fx_available
-
 from .. import (
     CONFIG_MAPPING,
     MODEL_FOR_CAUSAL_LM_MAPPING,
+    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
     MODEL_FOR_MASKED_LM_MAPPING,
@@ -22,16 +22,106 @@ from .. import (
     MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
     MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
     MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
     MODEL_MAPPING,
     GPT2DoubleHeadsModel,
     PretrainedConfig,
     PreTrainedModel,
     logging,
 )
+from ..file_utils import TORCH_FX_REQUIRED_VERSION, importlib_metadata, is_torch_fx_available
 from ..models.auto import get_values
+from .fx_transformations import (
+    _cache_attributes,
+    _patch_arguments_,
+    _restore_attributes_,
+    transform_to_dynamic_input_,
+    transformation,
+)


 logger = logging.get_logger(__name__)


+def _generate_supported_model_classes(
+    model_name: str,
+    supported_tasks: Optional[Union[str, List[str]]] = None,
+) -> List[Type[PreTrainedModel]]:
+    model_config_class = CONFIG_MAPPING[model_name]
+    task_mapping = {
+        "default": MODEL_MAPPING,
+        "pretraining": MODEL_FOR_PRETRAINING_MAPPING,
+        "next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
+        "masked-lm": MODEL_FOR_MASKED_LM_MAPPING,
+        "causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING,
+        "seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+        "multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
+        "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+        "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+        "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+        "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
+    }
+
+    if supported_tasks is None:
+        supported_tasks = task_mapping.keys()
+    if isinstance(supported_tasks, str):
+        supported_tasks = [supported_tasks]
+
+    model_classes = []
+    for task in supported_tasks:
+        model_class = task_mapping[task].get(model_config_class, None)
+        if model_class:
+            model_classes.append(model_class)
+
+    return model_classes
+
+
+_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
+    "albert",
+    "bert",
+    "distilbert",
+    "mobilebert",
+    "electra",
+    "megatron-bert",
+    "gpt2",
+    "gptj",
+    "gpt_neo",
+    "t5",
+]
+
+_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS_FOR_DYNAMIC_AXES = [
+    "albert",
+    "bert",
+    "distilbert",
+    "mobilebert",
+    "electra",
+    "megatron-bert",
+]
+
+_REGULAR_SUPPORTED_MODELS = []
+for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
+    if isinstance(item, dict):
+        _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_classes(**item))
+    else:
+        _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_classes(item))
+
+_SPECIAL_SUPPORTED_MODELS = [
+    GPT2DoubleHeadsModel,
+]
+_SUPPORTED_MODELS = tuple(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)
+
+_REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = []
+for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS_FOR_DYNAMIC_AXES:
+    if isinstance(item, dict):
+        _REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES.extend(_generate_supported_model_classes(**item))
+    else:
+        _REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES.extend(_generate_supported_model_classes(item))
+
+_SPECIAL_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = []
+_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = tuple(
+    _REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES + _SPECIAL_SUPPORTED_MODELS_FOR_DYNAMIC_AXES
+)
+
+
 class HFProxy(Proxy):
     """
     Proxy that is able to provide the proper ranks, shapes and boolean values during symbolic tracing by implementing
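As a hypothetical usage of the helper added above (assuming this module is importable as transformers.utils.fx, matching the sibling file path below):

    from transformers.utils.fx import _generate_supported_model_classes

    # Every BERT class that has an entry in one of the task mappings:
    bert_classes = _generate_supported_model_classes("bert")
    # Restrict to a single task:
    bert_mlm_only = _generate_supported_model_classes("bert", supported_tasks="masked-lm")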
@@ -228,7 +318,7 @@ class HFTracer(Tracer):
         if method_names is None:
             method_names = self.default_methods_to_record

-        inputs = dict()
+        inputs = {}
         for input_name in input_names:
             inputs.update(self._generate_dummy_input(model, input_name))

@@ -251,6 +341,22 @@ class HFTracer(Tracer):
         for cache_name in self.recorded_methods.values():
             setattr(model, cache_name, getattr(clone, cache_name))

+    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
+        if isinstance(attr_val, torch.nn.Parameter):
+            for n, p in self.root.named_parameters():
+                if attr_val is p:
+                    if n not in parameter_proxy_cache:
+                        parameter_proxy_cache[n] = self.create_proxy("get_attr", n, (), {})
+                    return parameter_proxy_cache[n]
+        # TODO: condition this on whether dynamic axes were requested.
+        if isinstance(attr_val, torch.Tensor):
+            for n, p in self.root.named_buffers():
+                if attr_val is p:
+                    if n not in parameter_proxy_cache:
+                        parameter_proxy_cache[n] = self.create_proxy("get_attr", n, (), {})
+                    return parameter_proxy_cache[n]
+        return attr_val
+
     def trace(self, root: PreTrainedModel, concrete_args: Optional[Dict[str, Any]] = None, method_names=None) -> Graph:
         sig = inspect.signature(root.forward)
         input_names = sig.parameters.keys() - concrete_args.keys()
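For background on what _module_getattr is proxying: in torch.fx, parameter and buffer accesses surface as "get_attr" nodes, which is what the cache above creates proxies for. A standalone illustration with vanilla torch.fx (output shown approximately):

    import torch
    from torch.fx import symbolic_trace as fx_symbolic_trace

    gm = fx_symbolic_trace(torch.nn.Linear(4, 4))
    print([(node.op, node.target) for node in gm.graph.nodes])
    # Roughly: [('placeholder', 'input'), ('get_attr', 'weight'), ('get_attr', 'bias'),
    #           ('call_function', <linear>), ('output', 'output')]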
@@ -264,6 +370,19 @@ class HFTracer(Tracer):

         _reset_tensor_methods(self.original_methods)

+        # TODO: remove this when no longer necessary.
+        # This is necessary because concrete args are added as input to the traced module since
+        # https://github.com/pytorch/pytorch/pull/55888.
+        # A PR that solves this was posted: https://github.com/pytorch/pytorch/pull/59569 but it has not been merged yet.
+        for node in graph.nodes:
+            if node.op == "placeholder":
+                # Removing default values for inputs as the forward pass will fail with them.
+                if node.target in input_names:
+                    node.args = ()
+                # It is a concrete arg, so it is not used and should be removed.
+                else:
+                    graph.erase_node(node)
+
         return graph

     def _insert_module_as_submodule(self, mod):
@@ -295,7 +414,7 @@ class HFTracer(Tracer):
             if path is None:
                 path = self._insert_module_as_submodule(mod)
             if path is None:
-                raise NameError("module is not installed as a submodule")
+                raise NameError(f"Module named {mod._get_name()} is not installed as a submodule")
             self.prev_module = path
             return path
@@ -308,7 +427,7 @@ class HFTracer(Tracer):
             return n
         path = self._insert_module_as_submodule(mod)
         if path is None:
-            raise NameError(f"module is not installed as a submodule")
+            raise NameError(f"Module {mod._get_name()} is not installed as a submodule")
         self.prev_module = path
         return path
@@ -318,11 +437,65 @@ class HFTracer(Tracer):
         return super().create_arg(a)


+@transformation
+def prepare_for_retracing(gm: GraphModule) -> Tuple[GraphModule, Dict[str, Any]]:
+    """
+    Prepares a GraphModule produced by symbolic_trace for retracing by:
+
+        - Caching all the attributes specific to the way the model was initially traced
+        - Patching back the model to a "static input shapes" version if it was traced to accept dynamic input shapes
+
+    For instance, the need to retrace a GraphModule can happen when applying quantization.
+    """
+    attributes = _cache_attributes(gm)
+    _patch_arguments_(gm, gm.dynamic2static)
+
+    return gm, attributes
+
+
+def restore_after_retracing_(gm: GraphModule, attributes: Dict[str, Any]):
+    """Restores a GraphModule that was retraced to its initial state in terms of static / dynamic input shapes."""
+    _restore_attributes_(gm, attributes)
+    # transform_to_dynamic_input_ will override the static2dynamic and dynamic2static dictionaries, which is the
+    # desired behaviour as the previously restored dictionaries contain nodes from the original GraphModule as values.
+    transform_to_dynamic_input_(gm, is_retracing=True)
+    _patch_arguments_(gm, gm.static2dynamic)
+    return gm
+
+
+def retrace_graph_with(
+    gm: GraphModule, tracer: Tracer = None, func: Callable[[GraphModule], GraphModule] = None
+) -> GraphModule:
+    """
+    Retraces a GraphModule using either a tracer or a function that uses a tracer (for instance
+    torch.quantization.quantize_fx.prepare_fx). It takes care of preparing the model for retracing, retracing it and
+    restoring anything necessary after the retrace.
+    """
+    if tracer is None and func is None:
+        raise ValueError("Either a tracer or a function using a tracer must be provided.")
+    elif tracer is not None and func is not None:
+        raise ValueError("Either provide a tracer or a function using a tracer, but not both.")
+    else:
+        gm, attributes = prepare_for_retracing(gm)
+        tracing_func = tracer.trace if tracer else func
+        traced = tracing_func(gm)
+        restore_after_retracing_(traced, attributes)
+    return traced
+
+
+def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
+    if forbidden_values is None:
+        forbidden_values = []
+    value = random.randint(low, high)
+    while value in forbidden_values:
+        value = random.randint(low, high)
+    return value
+
+
 def symbolic_trace(
     model: PreTrainedModel,
     input_names: Optional[List[str]] = None,
     batch_size: int = 1,
-    sequence_length: Union[int, List[int]] = [128, 128],
+    sequence_length: Union[int, List[int], Tuple[int]] = (128, 128),
     num_choices: int = -1,
 ) -> GraphModule:
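The retrace helpers above are meant for flows like fx quantization, where the docstring points at torch.quantization.quantize_fx.prepare_fx. A hedged, untested sketch of that path, assuming `traced` is a GraphModule produced by symbolic_trace below:

    import functools

    from torch.quantization import get_default_qconfig
    from torch.quantization.quantize_fx import prepare_fx

    from transformers.utils.fx import retrace_graph_with

    qconfig_dict = {"": get_default_qconfig("fbgemm")}
    # prepare_fx is "a function using a tracer", so it goes through the func= branch.
    prepare = functools.partial(prepare_fx, qconfig_dict=qconfig_dict)
    prepared = retrace_graph_with(traced, func=prepare)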
@@ -360,12 +533,61 @@ def symbolic_trace(
         input_names = model.dummy_inputs.keys()

     sig = inspect.signature(model.forward)
     # TODO: how to handle the case of the "return_dict" parameter.
     concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}

+    # Preparing HFTracer batch_size and sequence_length values for potential dynamic axes.
+    use_dynamic_batch_size = batch_size <= 0
+    if isinstance(sequence_length, (list, tuple)):
+        use_dynamic_sequence_length = sequence_length[0] <= 0 or sequence_length[1] <= 0
+    else:
+        use_dynamic_sequence_length = sequence_length <= 0
+
+    if use_dynamic_batch_size or use_dynamic_sequence_length:
+        forbidden_values = [
+            model.config.num_attention_heads,
+            model.config.hidden_size,
+            model.config.hidden_size // model.config.num_attention_heads,
+        ]
+        if use_dynamic_batch_size:
+            batch_size = _generate_random_int(forbidden_values=forbidden_values)
+        forbidden_values.append(batch_size)
+        if use_dynamic_sequence_length:
+            encoder_sequence_length = _generate_random_int(forbidden_values=forbidden_values)
+            forbidden_values.append(encoder_sequence_length)
+            decoder_sequence_length = _generate_random_int(forbidden_values=forbidden_values)
+            sequence_length = [encoder_sequence_length, decoder_sequence_length]
+
+    if not isinstance(model, _SUPPORTED_MODELS):
+        supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS))
+        raise NotImplementedError(
+            f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
+        )
+    if (use_dynamic_batch_size or use_dynamic_sequence_length) and not isinstance(
+        model, _SUPPORTED_MODELS_FOR_DYNAMIC_AXES
+    ):
+        supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS_FOR_DYNAMIC_AXES))
+        raise NotImplementedError(
+            f"Dynamic axes are not supported for {model.__class__.__name__} yet, supported models: {supported_model_names}"
+        )
+
+    # Tracing.
     tracer = HFTracer(batch_size=batch_size, sequence_length=sequence_length, num_choices=num_choices)

     traced_graph = tracer.trace(model, concrete_args=concrete_args)
     traced = torch.fx.GraphModule(model, traced_graph)

+    traced.config = copy.deepcopy(model.config)
+    traced.num_choices = num_choices
+    traced.dummy_inputs = {}
+
+    for name in input_names:
+        traced.dummy_inputs.update(tracer._generate_dummy_input(model, name))
+
+    traced.use_dynamic_batch_size = use_dynamic_batch_size
+    traced.use_dynamic_sequence_length = use_dynamic_sequence_length
+    traced.static_batch_size = batch_size
+    traced.static_sequence_length = sequence_length
+
+    transform_to_dynamic_input_(traced)
+
     return traced
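End-to-end, dynamic axes are requested by passing non-positive sizes, exactly as the updated tests below do. A small usage sketch (import path assumed to be transformers.utils.fx; the tiny config is only for illustration):

    from transformers import BertConfig, BertForSequenceClassification
    from transformers.utils.fx import symbolic_trace

    model = BertForSequenceClassification(BertConfig(num_hidden_layers=2))

    traced = symbolic_trace(
        model,
        input_names=["input_ids", "attention_mask", "token_type_ids"],
        batch_size=-1,       # <= 0 turns on use_dynamic_batch_size
        sequence_length=-1,  # <= 0 turns on use_dynamic_sequence_length
    )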
src/transformers/utils/fx_transformations.py (new file, 321 additions)
@@ -0,0 +1,321 @@
import copy
import functools
import operator
from inspect import signature
from typing import Any, Callable, Dict, Optional, Union

import torch
from torch.fx import Graph, GraphModule, Node


# Torch FX transformation convention:
#   - transformations that are supposed to act on a copy of the original GraphModule are decorated with @transformation
#   - transformations that are inplace have a name ending with "_"


def _cache_attributes(gm: GraphModule) -> Dict[str, Any]:
    attributes_to_keep = [
        "config",
        "num_choices",
        "dummy_inputs",
        "use_dynamic_batch_size",
        "use_dynamic_sequence_length",
        "static_batch_size",
        "static_sequence_length",
        "static2dynamic",
        "dynamic2static",
    ]
    attributes = {k: getattr(gm, k, None) for k in attributes_to_keep}
    return attributes


def _restore_attributes_(gm: GraphModule, attributes: Dict[str, Any]):
    for name, attr in attributes.items():
        setattr(gm, name, attr)


def deepcopy_graph(gm: GraphModule) -> GraphModule:
    """
    Performs a deepcopy of the GraphModule while also copying the relevant attributes to know whether the model was
    traced with dynamic axes, and what the values were if that is the case.
    """
    # First, create a copy of the module without the graph.
    graph = gm.__dict__.pop("_graph")
    fake_mod = torch.nn.Module()
    fake_mod.__dict__ = copy.deepcopy(gm.__dict__)
    gm.__dict__["_graph"] = graph

    # Then, copy the graph.
    val_map = {}
    graph_clone = Graph()
    output_val = graph_clone.graph_copy(graph, val_map=val_map)
    graph_clone.output(output_val)

    # Finally, create a new GraphModule (or a subclass of GraphModule) from the module and the graph copies.
    # gm.__class__ is used to take into account that gm can be an instance of a subclass of GraphModule.
    clone = gm.__class__(fake_mod, graph_clone)

    # Restore the dynamic axes related attributes to the clone.
    attributes = _cache_attributes(gm)
    attributes["dynamic2static"] = {val_map.get(k, k): v for k, v in attributes["dynamic2static"].items()}
    attributes["static2dynamic"] = {v: k for k, v in attributes["dynamic2static"].items()}
    _restore_attributes_(clone, attributes)

    return clone


def transformation(func):
    """
    Decorator that wraps a torch.fx transformation by feeding it a copy of the GraphModule to transform instead of the
    original.
    """

    def map_fn(arg):
        if isinstance(arg, GraphModule):
            return deepcopy_graph(arg)
        return arg

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        new_args = tuple(map_fn(arg) for arg in args)
        new_kwargs = {k: map_fn(v) for k, v in kwargs.items()}
        return func(*new_args, **new_kwargs)

    wrapper._is_transformation = True

    return wrapper


def compose_transformations(
    *args: Callable[[GraphModule], Optional[GraphModule]], inplace: bool = False
) -> GraphModule:
    """
    Composes transformations together and takes care of:

    1. Performing the transformations on a copy of the GraphModule if inplace is set to False; transformations
       decorated with @transformation (which therefore do not modify the original GraphModule) are unwrapped to make
       them inplace.
    2. Linting and recompiling only at the end of the composition, for performance purposes.
    """
    args = list(args)
    if not inplace:
        args.insert(0, deepcopy_graph)

    for i, transformation in enumerate(args[:-1]):
        sig = signature(transformation)

        # Unwrapping @transformation decorated transformations, as performing the transformations inplace or on a
        # copy is already handled by this function.
        if getattr(transformation, "_is_transformation", False):
            transformation = transformation.__wrapped__

        # Linting and recompiling only after the last transformation applied to make composition efficient.
        if "lint_and_recompile" in sig.parameters:
            args[i] = functools.partial(transformation, lint_and_recompile=False)

    def reduce_func(f, g):
        def compose_f_and_g(gm):
            output_g = g(gm)
            if output_g is None:
                output_g = gm
            output_f = f(output_g)
            if output_f is None:
                output_f = output_g
            return output_f

        return compose_f_and_g

    return functools.reduce(reduce_func, reversed(args), lambda x: x)


def remove_unused_nodes_(gm: GraphModule, lint_and_recompile: bool = True):
    """Removes all the unused nodes in a GraphModule."""
    graph = gm.graph
    for node in graph.nodes:
        if not node.users and node.op not in ["placeholder", "output"]:
            graph.erase_node(node)

    if lint_and_recompile:
        graph.lint()
        gm.recompile()


def _insert_batch_size_node_(gm: GraphModule, lint_and_recompile: bool = True) -> Node:
    """Inserts a node that retrieves the batch size dynamically from the input of the model."""
    graph = gm.graph
    input_names = set(gm.dummy_inputs.keys())
    batch_size_node = None
    for node in graph.nodes:
        if node.op == "placeholder" and node.name in input_names:
            with graph.inserting_after(node):
                batch_size_node = graph.call_method("size", args=(node, 0))

    if batch_size_node is None:
        raise ValueError("Could not insert the node that computes the batch size")

    if lint_and_recompile:
        graph.lint()
        gm.recompile()

    # Useful when retracing for quantization.
    if hasattr(gm, "_qconfig_map"):
        gm._qconfig_map[batch_size_node.name] = None

    return batch_size_node


def _insert_encoder_sequence_length_node_(gm: GraphModule, lint_and_recompile: bool = True) -> Node:
    """Inserts a node that retrieves the encoder sequence length dynamically from the input of the model."""
    graph = gm.graph
    input_names = set(gm.dummy_inputs.keys())
    encoder_sequence_length_node = None
    for node in graph.nodes:
        if node.op == "placeholder" and node.name in input_names and "decoder" not in node.name:
            with graph.inserting_after(node):
                # There are two cases to handle:
                # 1. num_choices < 0, meaning that the model is not performing a "multiple choice" task, in this case
                #    the input shape is [batch_size, sequence_length] => index 1
                # 2. num_choices > 0, meaning that the model is performing a "multiple choice" task, in this case the
                #    input shape is [batch_size, num_choices, sequence_length] => index 2
                encoder_sequence_length_node = graph.call_method("size", args=(node, 1 if gm.num_choices < 0 else 2))

    if encoder_sequence_length_node is None:
        raise ValueError("Could not insert the node that computes the encoder sequence length")

    if lint_and_recompile:
        graph.lint()
        gm.recompile()

    # Useful when retracing for quantization.
    if hasattr(gm, "_qconfig_map"):
        gm._qconfig_map[encoder_sequence_length_node.name] = None

    return encoder_sequence_length_node


def _change_view_methods_(
    gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True
):
    """
    Changes arguments of view ops that refer to static batch sizes / sequence lengths to make them refer to the
    batch_size / sequence_length nodes.
    """
    graph = gm.graph
    for node in graph.nodes:
        if node.op == "call_method" and node.target == "view":
            if isinstance(node.args[1], tuple):
                node.args = (node.args[0], *node.args[1])
            node.args = tuple(mapping.get(arg, arg) for arg in node.args)

    if lint_and_recompile:
        graph.lint()
        gm.recompile()


def _patch_getitem_(
    gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True
):
    """Patches getitem nodes by replacing their current arguments with the corresponding values in mapping."""
    # TODO: combine this with the _patch_arguments_ function, which seems to do almost the same thing.
    graph = gm.graph
    for node in graph.nodes:
        if node.op == "call_function" and node.target == operator.getitem:
            indices = node.args[1]
            if isinstance(indices, tuple):
                new_indices = []
                for idx in indices:
                    if isinstance(idx, slice):
                        new_indices.append(
                            slice(
                                mapping.get(idx.start, idx.start),
                                mapping.get(idx.stop, idx.stop),
                                mapping.get(idx.step, idx.step),
                            )
                        )
                    elif isinstance(idx, int):
                        new_indices.append(mapping.get(idx, idx))
                    else:
                        new_indices.append(idx)

                node.args = (node.args[0], tuple(new_indices))
            else:
                node.args = (node.args[0], mapping.get(node.args[1], node.args[1]))

    if lint_and_recompile:
        graph.lint()
        gm.recompile()


def _patch_arguments_(
    gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True
):
    """
    Patches nodes by replacing their arguments with the corresponding values in mapping (supports regular types,
    tuples and slices).
    """

    def _patch_slice(s, mapping):
        return slice(mapping.get(s.start, s.start), mapping.get(s.stop, s.stop), mapping.get(s.step, s.step))

    graph = gm.graph
    supported_types = (Node, str, int, float)
    for node in graph.nodes:
        new_args = []
        for arg in node.args:
            if isinstance(arg, tuple):
                new_arg = []
                for a in arg:
                    if isinstance(a, slice):
                        new_arg.append(_patch_slice(a, mapping))
                    else:
                        new_arg.append(mapping.get(a, a))
                new_args.append(tuple(new_arg))
            elif isinstance(arg, slice):
                new_args.append(_patch_slice(arg, mapping))
            elif isinstance(arg, supported_types):
                new_args.append(mapping.get(arg, arg))
            else:
                new_args.append(arg)
        node.args = tuple(new_args)

    if lint_and_recompile:
        graph.lint()
        gm.recompile()


def transform_to_dynamic_input_(gm: GraphModule, is_retracing: bool = False):
    """Transformation that enables traced models to perform inference on dynamic input shapes."""
    graph = gm.graph
    static2dynamic = {}

    # Inserting the nodes that will fetch the batch size and sequence lengths dynamically.
    if gm.use_dynamic_batch_size:
        batch_size_node = _insert_batch_size_node_(gm, lint_and_recompile=False)
        static2dynamic[gm.static_batch_size] = batch_size_node
        if gm.num_choices > 0:
            with graph.inserting_after(batch_size_node):
                static2dynamic[gm.static_batch_size * gm.num_choices] = graph.call_function(
                    operator.mul, args=(batch_size_node, gm.num_choices)
                )
            # Useful when retracing for quantization.
            if hasattr(gm, "_qconfig_map"):
                gm._qconfig_map[static2dynamic[gm.static_batch_size * gm.num_choices]] = None

    if gm.use_dynamic_sequence_length:
        encoder_sequence_length_node = _insert_encoder_sequence_length_node_(gm, lint_and_recompile=False)
        static2dynamic[gm.static_sequence_length[0]] = encoder_sequence_length_node

        # TODO: do the same for the decoder.
        pass

    _change_view_methods_(gm, static2dynamic, lint_and_recompile=False)
    _patch_getitem_(gm, static2dynamic, lint_and_recompile=False)

    remove_unused_nodes_(gm, lint_and_recompile=False)

    graph.lint()
    gm.recompile()

    gm.static2dynamic = static2dynamic
    gm.dynamic2static = {v: k for (k, v) in static2dynamic.items()}
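The core trick of transform_to_dynamic_input_ (swapping a size baked in at trace time for a node that reads it from the input at runtime) can be reproduced on a toy module with vanilla torch.fx:

    import torch
    from torch.fx import symbolic_trace as fx_symbolic_trace

    class Toy(torch.nn.Module):
        def forward(self, x):
            return x.view(4, -1)  # 4 is a batch size frozen in at trace time

    gm = fx_symbolic_trace(Toy())
    graph = gm.graph
    placeholder = next(node for node in graph.nodes if node.op == "placeholder")
    with graph.inserting_after(placeholder):
        # Equivalent of _insert_batch_size_node_: x.size(0) computed at runtime.
        batch_size_node = graph.call_method("size", args=(placeholder, 0))
    for node in graph.nodes:
        # Equivalent of _change_view_methods_: replace the static 4 with the size node.
        if node.op == "call_method" and node.target == "view":
            node.args = tuple(batch_size_node if arg == 4 else arg for arg in node.args)
    graph.lint()
    gm.recompile()

    assert gm(torch.randn(7, 3)).shape == (7, 3)  # now works for any batch size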
@@ -232,6 +232,7 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
         else ()
     )
     fx_ready_model_classes = all_model_classes
+    fx_dynamic_ready_model_classes = all_model_classes

     test_sequence_classification_problem_types = True
@@ -445,6 +445,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     )
     all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
     fx_ready_model_classes = all_model_classes
+    fx_dynamic_ready_model_classes = all_model_classes
     test_sequence_classification_problem_types = True

     # special case for ForPreTraining model
@@ -92,6 +92,7 @@ class ModelTesterMixin:
     all_model_classes = ()
     all_generative_model_classes = ()
     fx_ready_model_classes = ()
+    fx_dynamic_ready_model_classes = ()
     test_torchscript = True
     test_pruning = True
     test_resize_embeddings = True
@@ -607,14 +608,19 @@ class ModelTesterMixin:
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True)

-    def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
+    def test_torch_fx_dynamic_axes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        self._create_and_check_torch_fx_tracing(config, inputs_dict, dynamic_axes=True)
+
+    def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False, dynamic_axes=False):
         if not is_torch_fx_available():
             return

         configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
         configs_no_init.return_dict = False

-        for model_class in self.fx_ready_model_classes:
+        model_classes = self.fx_ready_model_classes if not dynamic_axes else self.fx_dynamic_ready_model_classes
+        for model_class in model_classes:
             model = model_class(config=configs_no_init)
             model.to(torch_device)
             model.eval()
@@ -640,12 +646,11 @@ class ModelTesterMixin:
                 traced_model = symbolic_trace(
                     model,
                     input_names,
-                    batch_size=batch_size,
-                    sequence_length=[encoder_sequence_length, decoder_sequence_length],
+                    batch_size=batch_size if not dynamic_axes else -1,
+                    sequence_length=[encoder_sequence_length, decoder_sequence_length] if not dynamic_axes else -1,
                 )
-
                 traced_output = traced_model(**filtered_inputs)

             else:
                 input_names = ["input_ids", "attention_mask", "token_type_ids"]
                 input_ids = inputs["input_ids"]
@@ -679,8 +684,8 @@ class ModelTesterMixin:
                 traced_model = symbolic_trace(
                     model,
                     input_names,
-                    batch_size=batch_size,
-                    sequence_length=sequence_length,
+                    batch_size=batch_size if not dynamic_axes else -1,
+                    sequence_length=sequence_length if not dynamic_axes else -1,
                     num_choices=num_choices,
                 )
                 traced_output = traced_model(**filtered_inputs)
@@ -210,6 +210,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
         else None
     )
     fx_ready_model_classes = all_model_classes
+    fx_dynamic_ready_model_classes = all_model_classes
     test_pruning = True
     test_torchscript = True
     test_resize_embeddings = True
@@ -290,6 +290,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
         else ()
     )
     fx_ready_model_classes = all_model_classes
+    fx_dynamic_ready_model_classes = all_model_classes
     test_sequence_classification_problem_types = True

     # special case for ForPreTraining model
@@ -284,6 +284,7 @@ class MegatronBertModelTest(ModelTesterMixin, unittest.TestCase):
         else ()
     )
     fx_ready_model_classes = all_model_classes
+    fx_dynamic_ready_model_classes = all_model_classes

     # test_resize_embeddings = False
     test_head_masking = False
@@ -270,6 +270,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
         else ()
     )
     fx_ready_model_classes = all_model_classes
+    fx_dynamic_ready_model_classes = all_model_classes
     test_sequence_classification_problem_types = True

     # special case for ForPreTraining model