Initial support for symbolic tracing with torch.fx allowing dynamic axes (#13579)

* Symbolic trace dynamic axes support for BERT-like models (albert, bert, distilbert, mobilebert, electra, megatron-bert)
* Sanity checks before tracing to make sure the model being traced is supported
* Adapted to PyTorch 1.9

Co-authored-by: Michael Benayoun <michael@huggingface.co>
Michael Benayoun 2021-10-05 14:19:47 +02:00 committed by GitHub
parent 46efc58024
commit d4e4efce68
11 changed files with 571 additions and 17 deletions
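
As a rough usage sketch (illustrative, not part of this diff): dynamic axes are requested by passing -1 for batch_size and/or sequence_length to symbolic_trace, the function modified in this commit. The model class, config and inputs below are assumptions, and symbolic_trace is assumed to be in scope (imported from the fx utilities module changed below).

import torch
from transformers import BertConfig, BertForSequenceClassification

# Any of the BERT-like architectures listed above should work with dynamic axes.
model = BertForSequenceClassification(BertConfig())
model.eval()

# batch_size=-1 and sequence_length=-1 ask the tracer to produce a graph with dynamic axes.
traced = symbolic_trace(
    model,
    input_names=["input_ids", "attention_mask", "token_type_ids"],
    batch_size=-1,
    sequence_length=-1,
)

# The resulting GraphModule accepts inputs of varying batch size and sequence length.
for bs, seq_len in [(2, 16), (8, 64)]:
    input_ids = torch.zeros(bs, seq_len, dtype=torch.long)
    outputs = traced(
        input_ids=input_ids,
        attention_mask=torch.ones_like(input_ids),
        token_type_ids=torch.zeros_like(input_ids),
    )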


@ -280,7 +280,7 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HUGGINGFACE_CO_RESOLVE_ENDPOIN
HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"
# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
TORCH_FX_REQUIRED_VERSION = version.parse("1.8")
TORCH_FX_REQUIRED_VERSION = version.parse("1.9")
TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION = version.parse("1.8")
_is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False


@ -204,7 +204,7 @@ class MultiHeadSelfAttention(nn.Module):
q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)
mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length)
scores.masked_fill_(mask, -float("inf")) # (bs, n_heads, q_length, k_length)
scores = scores.masked_fill(mask, -float("inf")) # (bs, n_heads, q_length, k_length)
weights = nn.Softmax(dim=-1)(scores) # (bs, n_heads, q_length, k_length)
weights = self.dropout(weights) # (bs, n_heads, q_length, k_length)


@ -1,7 +1,8 @@
import copy
import functools
import inspect
from typing import Any, Dict, List, Optional, Union
import random
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import torch
from packaging import version
@ -9,9 +10,8 @@ from torch import nn
from torch.fx import Graph, GraphModule, Node, Proxy, Tracer
from torch.fx.node import Argument
from transformers.file_utils import TORCH_FX_REQUIRED_VERSION, importlib_metadata, is_torch_fx_available
from .. import (
CONFIG_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
@ -22,16 +22,106 @@ from .. import (
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_MAPPING,
GPT2DoubleHeadsModel,
PretrainedConfig,
PreTrainedModel,
logging,
)
from ..file_utils import TORCH_FX_REQUIRED_VERSION, importlib_metadata, is_torch_fx_available
from ..models.auto import get_values
from .fx_transformations import (
_cache_attributes,
_patch_arguments_,
_restore_attributes_,
transform_to_dynamic_input_,
transformation,
)
logger = logging.get_logger(__name__)
def _generate_supported_model_classes(
model_name: str,
supported_tasks: Optional[Union[str, List[str]]] = None,
) -> List[Type[PreTrainedModel]]:
model_config_class = CONFIG_MAPPING[model_name]
task_mapping = {
"default": MODEL_MAPPING,
"pretraining": MODEL_FOR_PRETRAINING_MAPPING,
"next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
"masked-lm": MODEL_FOR_MASKED_LM_MAPPING,
"causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING,
"seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
"multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
"question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING,
"sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
}
if supported_tasks is None:
supported_tasks = task_mapping.keys()
if isinstance(supported_tasks, str):
supported_tasks = [supported_tasks]
model_classes = []
for task in supported_tasks:
model_class = task_mapping[task].get(model_config_class, None)
if model_class:
model_classes.append(model_class)
return model_classes
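# Illustration (not part of this diff): for the "bert" model name and the "masked-lm" task, the
# task mapping resolves through MODEL_FOR_MASKED_LM_MAPPING to the masked-LM head class; the
# expected result below is an assumption, not something asserted by this commit.
bert_masked_lm_classes = _generate_supported_model_classes("bert", supported_tasks="masked-lm")
# -> [BertForMaskedLM]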
_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
"albert",
"bert",
"distilbert",
"mobilebert",
"electra",
"megatron-bert",
"gpt2",
"gptj",
"gpt_neo",
"t5",
]
_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS_FOR_DYNAMIC_AXES = [
"albert",
"bert",
"distilbert",
"mobilebert",
"electra",
"megatron-bert",
]
_REGULAR_SUPPORTED_MODELS = []
for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
if isinstance(item, dict):
_REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_classes(**item))
else:
_REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_classes(item))
_SPECIAL_SUPPORTED_MODELS = [
GPT2DoubleHeadsModel,
]
_SUPPORTED_MODELS = tuple(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)
_REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = []
for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS_FOR_DYNAMIC_AXES:
if isinstance(item, dict):
_REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES.extend(_generate_supported_model_classes(**item))
else:
_REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES.extend(_generate_supported_model_classes(item))
_SPECIAL_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = []
_SUPPORTED_MODELS_FOR_DYNAMIC_AXES = tuple(
_REGULAR_SUPPORTED_MODELS_FOR_DYNAMIC_AXES + _SPECIAL_SUPPORTED_MODELS_FOR_DYNAMIC_AXES
)
class HFProxy(Proxy):
"""
Proxy that is able to provide the proper ranks, shapes and boolean values during symbolic tracing by implementing
@ -228,7 +318,7 @@ class HFTracer(Tracer):
if method_names is None:
method_names = self.default_methods_to_record
inputs = dict()
inputs = {}
for input_name in input_names:
inputs.update(self._generate_dummy_input(model, input_name))
@ -251,6 +341,22 @@ class HFTracer(Tracer):
for cache_name in self.recorded_methods.values():
setattr(model, cache_name, getattr(clone, cache_name))
def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
if isinstance(attr_val, torch.nn.Parameter):
for n, p in self.root.named_parameters():
if attr_val is p:
if n not in parameter_proxy_cache:
parameter_proxy_cache[n] = self.create_proxy("get_attr", n, (), {})
return parameter_proxy_cache[n]
# TODO: condition this on whether dynamic axes were requested.
if isinstance(attr_val, torch.Tensor):
for n, p in self.root.named_buffers():
if attr_val is p:
if n not in parameter_proxy_cache:
parameter_proxy_cache[n] = self.create_proxy("get_attr", n, (), {})
return parameter_proxy_cache[n]
return attr_val
def trace(self, root: PreTrainedModel, concrete_args: Optional[Dict[str, Any]] = None, method_names=None) -> Graph:
sig = inspect.signature(root.forward)
input_names = sig.parameters.keys() - concrete_args.keys()
@ -264,6 +370,19 @@ class HFTracer(Tracer):
_reset_tensor_methods(self.original_methods)
# TODO: keep this until it is no longer necessary.
# This is necessary because concrete args are added as inputs to the traced module since
# https://github.com/pytorch/pytorch/pull/55888.
# A PR that solves this was posted (https://github.com/pytorch/pytorch/pull/59569) but has not been merged yet.
for node in graph.nodes:
if node.op == "placeholder":
# Removing default values for inputs as the forward pass will fail with them.
if node.target in input_names:
node.args = ()
# It is a concrete arg so it is not used and should be removed.
else:
graph.erase_node(node)
return graph
def _insert_module_as_submodule(self, mod):
@ -295,7 +414,7 @@ class HFTracer(Tracer):
if path is None:
path = self._insert_module_as_submodule(mod)
if path is None:
raise NameError("module is not installed as a submodule")
raise NameError(f"Module named {mod._get_name()} is not installed as a submodule")
self.prev_module = path
return path
@ -308,7 +427,7 @@ class HFTracer(Tracer):
return n
path = self._insert_module_as_submodule(mod)
if path is None:
raise NameError("module is not installed as a submodule")
raise NameError(f"Module {mod._get_name()} is not installed as a submodule")
self.prev_module = path
return path
@ -318,11 +437,65 @@ class HFTracer(Tracer):
return super().create_arg(a)
@transformation
def prepare_for_retracing(gm: GraphModule) -> Tuple[GraphModule, Dict[str, Any]]:
"""
Prepares a GraphModule produced by symbolic_trace for retracing by:
- Caching all the attributes specific to the way the model was initially traced
- Patching back the model to a "static input shapes" version if it was traced to accept dynamic input shapes
For instance, the need to retrace a GraphModule can happen when applying quantization.
"""
attributes = _cache_attributes(gm)
_patch_arguments_(gm, gm.dynamic2static)
return gm, attributes
def restore_after_retracing_(gm: GraphModule, attributes: Dict[str, Any]):
"""Restores a GraphModule that was retraced to its initial state in terms of static / dynamic input shapes."""
_restore_attributes_(gm, attributes)
# transform_to_dynamic_input_ will override the static2dynamic and dynamic2static dictionaries, which is the
# desired behaviour, as the previously restored dictionaries contain nodes from the original GraphModule as values.
transform_to_dynamic_input_(gm, is_retracing=True)
_patch_arguments_(gm, gm.static2dynamic)
return gm
def retrace_graph_with(
gm: GraphModule, tracer: Tracer = None, func: Callable[[GraphModule], GraphModule] = None
) -> GraphModule:
"""
Retraces a GraphModule by either using a tracer or a function using a tracer (for instance
torch.quantization.quantize_fx.prepare_fx). It takes care of preparing the model for retracing, retracing it and
restoring anything necessary after the retrace.
"""
if tracer is None and func is None:
raise ValueError("Either a tracer or a function using a tracer must be provided.")
elif tracer is not None and func is not None:
raise ValueError("Either provide a tracer or a function using a tracer, but not both.")
else:
gm, attributes = prepare_for_retracing(gm)
tracing_func = tracer.trace if tracer else func
traced = tracing_func(gm)
restore_after_retracing_(traced, attributes)
return traced
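# Usage sketch (illustrative, not part of this diff): retracing a traced model for FX quantization,
# which the docstring above mentions as a typical use case. The qconfig handling below is an
# assumption about how torch.quantization.quantize_fx.prepare_fx would be driven, and `traced`
# stands for a GraphModule previously returned by symbolic_trace.
import functools
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx

prepare_func = functools.partial(prepare_fx, qconfig_dict={"": get_default_qconfig("fbgemm")})
prepared = retrace_graph_with(traced, func=prepare_func)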
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
if forbidden_values is None:
forbidden_values = []
value = random.randint(low, high)
while value in forbidden_values:
value = random.randint(low, high)
return value
def symbolic_trace(
model: PreTrainedModel,
input_names: Optional[List[str]] = None,
batch_size: int = 1,
sequence_length: Union[int, List[int]] = [128, 128],
sequence_length: Union[int, List[int], Tuple[int]] = (128, 128),
num_choices: int = -1,
) -> GraphModule:
@ -360,12 +533,61 @@ def symbolic_trace(
input_names = model.dummy_inputs.keys()
sig = inspect.signature(model.forward)
# TODO: how to handle the case of the "return_dict" parameter.
concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
# Preparing HFTracer batch_size and sequence_length values for potential dynamic axes.
use_dynamic_batch_size = batch_size <= 0
if isinstance(sequence_length, (list, tuple)):
use_dynamic_sequence_length = sequence_length[0] <= 0 or sequence_length[1] <= 0
else:
use_dynamic_sequence_length = sequence_length <= 0
if use_dynamic_batch_size or use_dynamic_sequence_length:
forbidden_values = [
model.config.num_attention_heads,
model.config.hidden_size,
model.config.hidden_size // model.config.num_attention_heads,
]
if use_dynamic_batch_size:
batch_size = _generate_random_int(forbidden_values=forbidden_values)
forbidden_values.append(batch_size)
if use_dynamic_sequence_length:
encoder_sequence_length = _generate_random_int(forbidden_values=forbidden_values)
forbidden_values.append(encoder_sequence_length)
decoder_sequence_length = _generate_random_int(forbidden_values=forbidden_values)
sequence_length = [encoder_sequence_length, decoder_sequence_length]
if not isinstance(model, _SUPPORTED_MODELS):
supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS))
raise NotImplementedError(
f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
)
if (use_dynamic_batch_size or use_dynamic_sequence_length) and not isinstance(
model, _SUPPORTED_MODELS_FOR_DYNAMIC_AXES
):
supported_model_names = ", ".join((cls.__name__ for cls in _SUPPORTED_MODELS_FOR_DYNAMIC_AXES))
raise NotImplementedError(
f"Dynamic axes are not supported for {model.__class__.__name__} yet, supported models: {supported_model_names}"
)
# Tracing.
tracer = HFTracer(batch_size=batch_size, sequence_length=sequence_length, num_choices=num_choices)
traced_graph = tracer.trace(model, concrete_args=concrete_args)
traced = torch.fx.GraphModule(model, traced_graph)
traced.config = copy.deepcopy(model.config)
traced.num_choices = num_choices
traced.dummy_inputs = {}
for name in input_names:
traced.dummy_inputs.update(tracer._generate_dummy_input(model, name))
traced.use_dynamic_batch_size = use_dynamic_batch_size
traced.use_dynamic_sequence_length = use_dynamic_sequence_length
traced.static_batch_size = batch_size
traced.static_sequence_length = sequence_length
transform_to_dynamic_input_(traced)
return traced


@ -0,0 +1,321 @@
import copy
import functools
import operator
from inspect import signature
from typing import Any, Callable, Dict, Optional, Union
import torch
from torch.fx import Graph, GraphModule, Node
# Torch FX transformation convention:
# - transformations that are supposed to act on a copy of the original GraphModule are decorated with @transformation
# - transformations that are inplace have a name ending with "_"
def _cache_attributes(gm: GraphModule) -> Dict[str, Any]:
attributes_to_keep = [
"config",
"num_choices",
"dummy_inputs",
"use_dynamic_batch_size",
"use_dynamic_sequence_length",
"static_batch_size",
"static_sequence_length",
"static2dynamic",
"dynamic2static",
]
attributes = {k: getattr(gm, k, None) for k in attributes_to_keep}
return attributes
def _restore_attributes_(gm: GraphModule, attributes: Dict[str, Any]):
for name, attr in attributes.items():
setattr(gm, name, attr)
def deepcopy_graph(gm: GraphModule) -> GraphModule:
"""
Performs a deepcopy of the GraphModule while also copying the relevant attributes to know whether the model was
traced with dynamic axes, and what the values were if that is the case.
"""
# First, create a copy of the module without the graph.
graph = gm.__dict__.pop("_graph")
fake_mod = torch.nn.Module()
fake_mod.__dict__ = copy.deepcopy(gm.__dict__)
gm.__dict__["_graph"] = graph
# Then, copy the graph.
val_map = {}
graph_clone = Graph()
output_val = graph_clone.graph_copy(graph, val_map=val_map)
graph_clone.output(output_val)
# Finally create a new GraphModule (or a subclass of GraphModule) from the module and the graph copies.
# gm.__class__ is used to take into account that gm can be an instance of a subclass of GraphModule.
clone = gm.__class__(fake_mod, graph_clone)
# Restore the dynamic axes related attributes to the clone.
attributes = _cache_attributes(gm)
attributes["dynamic2static"] = {val_map.get(k, k): v for k, v in attributes["dynamic2static"].items()}
attributes["static2dynamic"] = {v: k for k, v in attributes["dynamic2static"].items()}
_restore_attributes_(clone, attributes)
return clone
def transformation(func):
"""
Decorator that wraps a torch.fx transformation by feeding it a copy of the GraphModule to transform instead of the
original.
"""
def map_fn(arg):
if isinstance(arg, GraphModule):
return deepcopy_graph(arg)
return arg
@functools.wraps(func)
def wrapper(*args, **kwargs):
new_args = tuple(map_fn(arg) for arg in args)
new_kwargs = {k: map_fn(v) for k, v in kwargs.items()}
return func(*new_args, **new_kwargs)
wrapper._is_transformation = True
return wrapper
def compose_transformations(
*args: Callable[[GraphModule], Optional[GraphModule]], inplace: bool = False
) -> GraphModule:
"""
Allows composing transformations together and takes care of:
1. Performing the transformations on a copy of the GraphModule if inplace is set to False; transformations that
are decorated with @transformation (which means they do not modify the original GraphModule) are unwrapped to
make them inplace.
2. Linting and recompiling only at the end of the composition for performance purposes.
"""
args = list(args)
if not inplace:
args.insert(0, deepcopy_graph)
for i, transformation in enumerate(args[:-1]):
sig = signature(transformation)
# Unwrapping @transformation decorated transformations as performing the transformations inplace or on a copy is
# already handled by this function.
if getattr(transformation, "_is_transformation", False):
transformation = transformation.__wrapped__
# Lint and recompile only after the last transformation is applied, to make composition efficient.
if "lint_and_recompile" in sig.parameters:
args[i] = functools.partial(transformation, lint_and_recompile=False)
def reduce_func(f, g):
def compose_f_and_g(gm):
output_g = g(gm)
if output_g is None:
output_g = gm
output_f = f(output_g)
if output_f is None:
output_f = gm
return output_f
return compose_f_and_g
return functools.reduce(reduce_func, reversed(args), lambda x: x)
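# Usage sketch (illustrative, not part of this diff): composing transformations into a pipeline.
# With inplace=False the pipeline first deep-copies the GraphModule, applies each transformation
# in order, and only lints/recompiles at the very end. A single transformation is used here for
# brevity; `gm` stands for an existing GraphModule (assumption).
pipeline = compose_transformations(remove_unused_nodes_, inplace=False)
pruned_copy = pipeline(gm)  # the original `gm` is left untouched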
def remove_unused_nodes_(gm: GraphModule, lint_and_recompile: bool = True):
"""Removes all the unused nodes in a GraphModule."""
graph = gm.graph
for node in graph.nodes:
if not node.users and node.op not in ["placeholder", "output"]:
graph.erase_node(node)
if lint_and_recompile:
graph.lint()
gm.recompile()
def _insert_batch_size_node_(gm: GraphModule, lint_and_recompile: bool = True) -> Node:
"""Inserts a node that retrieves the batch size dynamically from the input of the model."""
graph = gm.graph
input_names = set(gm.dummy_inputs.keys())
batch_size_node = None
for node in graph.nodes:
if node.op == "placeholder" and node.name in input_names:
with graph.inserting_after(node):
batch_size_node = graph.call_method("size", args=(node, 0))
if batch_size_node is None:
raise ValueError("Could not insert the node that computes the batch size")
if lint_and_recompile:
graph.lint()
gm.recompile()
# Useful when retracing for quantization.
if hasattr(gm, "_qconfig_map"):
gm._qconfig_map[batch_size_node.name] = None
return batch_size_node
def _insert_encoder_sequence_length_node_(gm: GraphModule, lint_and_recompile: bool = True) -> Node:
"""Inserts a node that retrieves the encoder sequence length dynamically from the input of the model."""
graph = gm.graph
input_names = set(gm.dummy_inputs.keys())
encoder_sequence_length_node = None
for node in graph.nodes:
if node.op == "placeholder" and node.name in input_names and "decoder" not in node.name:
with graph.inserting_after(node):
# There are two cases to handle:
# 1. num_choices < 0, meaning that the model is not performing a "multiple choice" task, in this case the
# input shape is [batch_size, sequence_length] => index 1
# 2. num_choices > 0, meaning the model is performing a "multiple choice" task, in this case the input
# shape is [batch_size, num_choices, sequence_length] => index 2
encoder_sequence_length_node = graph.call_method("size", args=(node, 1 if gm.num_choices < 0 else 2))
if encoder_sequence_length_node is None:
raise ValueError("Could not insert the node that computes the encoder sequence length")
if lint_and_recompile:
graph.lint()
gm.recompile()
# Useful when retracing for quantization.
if hasattr(gm, "_qconfig_map"):
gm._qconfig_map[encoder_sequence_length_node.name] = None
return encoder_sequence_length_node
def _change_view_methods_(
gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True
):
"""
Changes arguments of view ops that refer to static batch size / sequence lengths to make them refer to the
batch_size / sequence_length nodes.
"""
graph = gm.graph
for node in graph.nodes:
if node.op == "call_method" and node.target == "view":
if isinstance(node.args[1], tuple):
node.args = (node.args[0], *node.args[1])
node.args = tuple((mapping.get(arg, arg) for arg in node.args))
if lint_and_recompile:
graph.lint()
gm.recompile()
def _patch_getitem_(
gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True
):
"""Patches getitem nodes by replacing current arguments to their corresponding values in mapping."""
# TODO: combine this with the _patch_arguments_ function, which seems to do almost the same thing.
graph = gm.graph
for node in graph.nodes:
if node.op == "call_function" and node.target == operator.getitem:
indices = node.args[1]
if isinstance(indices, tuple):
new_indices = []
for idx in indices:
if isinstance(idx, slice):
new_indices.append(
slice(
mapping.get(idx.start, idx.start),
mapping.get(idx.stop, idx.stop),
mapping.get(idx.step, idx.step),
)
)
elif isinstance(idx, int):
new_indices.append(mapping.get(idx, idx))
else:
new_indices.append(idx)
node.args = (node.args[0], tuple(new_indices))
else:
node.args = (node.args[0], mapping.get(node.args[1], node.args[1]))
if lint_and_recompile:
graph.lint()
gm.recompile()
def _patch_arguments_(
gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True
):
"""
Patches nodes by replacing their arguments with their corresponding values in mapping (supports regular types,
tuples and slices).
"""
def _patch_slice(s, mapping):
return slice(mapping.get(s.start, s.start), mapping.get(s.stop, s.stop), mapping.get(s.step, s.step))
graph = gm.graph
supported_types = (Node, str, int, float)
for node in graph.nodes:
new_args = []
for arg in node.args:
if isinstance(arg, tuple):
new_arg = []
for a in arg:
if isinstance(a, slice):
new_arg.append(_patch_slice(a, mapping))
else:
new_arg.append(mapping.get(a, a))
new_args.append(tuple(new_arg))
elif isinstance(arg, slice):
new_args.append(_patch_slice(arg, mapping))
elif isinstance(arg, supported_types):
new_args.append(mapping.get(arg, arg))
else:
new_args.append(arg)
node.args = tuple(new_args)
if lint_and_recompile:
graph.lint()
gm.recompile()
def transform_to_dynamic_input_(gm: GraphModule, is_retracing: bool = False):
"""Transformation that enables traced models to perform inference on dynamic input shapes."""
graph = gm.graph
static2dynamic = {}
# Inserting the nodes that will fetch the batch size and sequence lengths dynamically.
if gm.use_dynamic_batch_size:
batch_size_node = _insert_batch_size_node_(gm, lint_and_recompile=False)
static2dynamic[gm.static_batch_size] = batch_size_node
if gm.num_choices > 0:
with graph.inserting_after(batch_size_node):
static2dynamic[gm.static_batch_size * gm.num_choices] = graph.call_function(
operator.mul, args=(batch_size_node, gm.num_choices)
)
# Useful when retracing for quantization.
if hasattr(gm, "_qconfig_map"):
gm._qconfig_map[static2dynamic[gm.static_batch_size * gm.num_choices]] = None
if gm.use_dynamic_sequence_length:
encoder_sequence_length_node = _insert_encoder_sequence_length_node_(gm, lint_and_recompile=False)
static2dynamic[gm.static_sequence_length[0]] = encoder_sequence_length_node
# TODO: do the same for the decoder.
pass
_change_view_methods_(gm, static2dynamic, lint_and_recompile=False)
_patch_getitem_(gm, static2dynamic, lint_and_recompile=False)
remove_unused_nodes_(gm, lint_and_recompile=False)
graph.lint()
gm.recompile()
gm.static2dynamic = static2dynamic
gm.dynamic2static = {v: k for (k, v) in static2dynamic.items()}
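# Illustration (not part of this diff): suppose the model was traced with a random static batch
# size of 13 and sequence length of 17. A node such as
#     x.view(13, 17, num_heads, head_dim)
# has its static dimensions replaced, via static2dynamic, by the inserted size() nodes, i.e. it
# effectively becomes
#     x.view(input_ids.size(0), input_ids.size(1), num_heads, head_dim)
# so the recompiled forward works for any batch size / sequence length.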


@ -232,6 +232,7 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
else ()
)
fx_ready_model_classes = all_model_classes
fx_dynamic_ready_model_classes = all_model_classes
test_sequence_classification_problem_types = True


@ -445,6 +445,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
)
all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
fx_ready_model_classes = all_model_classes
fx_dynamic_ready_model_classes = all_model_classes
test_sequence_classification_problem_types = True
# special case for ForPreTraining model


@ -92,6 +92,7 @@ class ModelTesterMixin:
all_model_classes = ()
all_generative_model_classes = ()
fx_ready_model_classes = ()
fx_dynamic_ready_model_classes = ()
test_torchscript = True
test_pruning = True
test_resize_embeddings = True
@ -607,14 +608,19 @@ class ModelTesterMixin:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True)
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
def test_torch_fx_dynamic_axes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
self._create_and_check_torch_fx_tracing(config, inputs_dict, dynamic_axes=True)
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False, dynamic_axes=False):
if not is_torch_fx_available():
return
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.return_dict = False
for model_class in self.fx_ready_model_classes:
model_classes = self.fx_ready_model_classes if not dynamic_axes else self.fx_dynamic_ready_model_classes
for model_class in model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
@ -640,12 +646,11 @@ class ModelTesterMixin:
traced_model = symbolic_trace(
model,
input_names,
batch_size=batch_size,
sequence_length=[encoder_sequence_length, decoder_sequence_length],
batch_size=batch_size if not dynamic_axes else -1,
sequence_length=[encoder_sequence_length, decoder_sequence_length] if not dynamic_axes else -1,
)
traced_output = traced_model(**filtered_inputs)
else:
input_names = ["input_ids", "attention_mask", "token_type_ids"]
input_ids = inputs["input_ids"]
@ -679,8 +684,8 @@ class ModelTesterMixin:
traced_model = symbolic_trace(
model,
input_names,
batch_size=batch_size,
sequence_length=sequence_length,
batch_size=batch_size if not dynamic_axes else -1,
sequence_length=sequence_length if not dynamic_axes else -1,
num_choices=num_choices,
)
traced_output = traced_model(**filtered_inputs)


@ -210,6 +210,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
else None
)
fx_ready_model_classes = all_model_classes
fx_dynamic_ready_model_classes = all_model_classes
test_pruning = True
test_torchscript = True
test_resize_embeddings = True


@ -290,6 +290,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
else ()
)
fx_ready_model_classes = all_model_classes
fx_dynamic_ready_model_classes = all_model_classes
test_sequence_classification_problem_types = True
# special case for ForPreTraining model


@ -284,6 +284,7 @@ class MegatronBertModelTest(ModelTesterMixin, unittest.TestCase):
else ()
)
fx_ready_model_classes = all_model_classes
fx_dynamic_ready_model_classes = all_model_classes
# test_resize_embeddings = False
test_head_masking = False


@ -270,6 +270,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
else ()
)
fx_ready_model_classes = all_model_classes
fx_dynamic_ready_model_classes = all_model_classes
test_sequence_classification_problem_types = True
# special case for ForPreTraining model