From 61d4bb333f6d39a7fbe31d161b8bd14787ceec2e Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 8 Nov 2021 10:58:22 +0100 Subject: [PATCH] XX. --- src/transformers/debug_utils.py | 4 +++- src/transformers/utils/fx.py | 4 +++- tests/test_modeling_canine.py | 2 +- tests/test_modeling_common.py | 2 +- tests/test_modeling_flax_common.py | 2 +- tests/test_modeling_fnet.py | 2 +- tests/test_modeling_tf_common.py | 2 +- 7 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/transformers/debug_utils.py b/src/transformers/debug_utils.py index 4588ca58f5f..029475f64ab 100644 --- a/src/transformers/debug_utils.py +++ b/src/transformers/debug_utils.py @@ -133,7 +133,9 @@ class DebugUnderflowOverflow: """ - def __init__(self, model, max_frames_to_save=21, trace_batch_nums=[], abort_after_batch_num=None): + def __init__(self, model, max_frames_to_save=21, trace_batch_nums=None, abort_after_batch_num=None): + if trace_batch_nums is None: + trace_batch_nums = [] self.model = model self.trace_batch_nums = trace_batch_nums self.abort_after_batch_num = abort_after_batch_num diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 38f2caa904b..c17a9100e27 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -235,7 +235,9 @@ class HFTracer(Tracer): default_methods_to_record = {"__bool__", "size", "dim"} - def __init__(self, batch_size=1, sequence_length=[128, 128], num_choices=-1): + def __init__(self, batch_size=1, sequence_length=None, num_choices=-1): + if sequence_length is None: + sequence_length = [128, 128] super().__init__() if not is_torch_fx_available(): diff --git a/tests/test_modeling_canine.py b/tests/test_modeling_canine.py index 888e1d33e6f..d710b6e2a59 100644 --- a/tests/test_modeling_canine.py +++ b/tests/test_modeling_canine.py @@ -363,7 +363,7 @@ class CanineModelTest(ModelTesterMixin, unittest.TestCase): t[t != t] = 0 return t - def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): + def check_equivalence(model, tuple_inputs, dict_inputs, **additional_kwargs): with torch.no_grad(): tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs) dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple() diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 9dfe9275fda..b4aaf114fd8 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1366,7 +1366,7 @@ class ModelTesterMixin: t[t != t] = 0 return t - def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): + def check_equivalence(model, tuple_inputs, dict_inputs, **additional_kwargs): with torch.no_grad(): tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs) dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple() diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 228084a0dfa..e071670c529 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -137,7 +137,7 @@ class FlaxModelTesterMixin: t[t != t] = 0 return t - def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): + def check_equivalence(model, tuple_inputs, dict_inputs, **additional_kwargs): tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs) dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple() diff --git a/tests/test_modeling_fnet.py b/tests/test_modeling_fnet.py index eaa61f779f0..217db0db860 100644 --- a/tests/test_modeling_fnet.py +++ b/tests/test_modeling_fnet.py @@ -313,7 +313,7 @@ class FNetModelTest(ModelTesterMixin, unittest.TestCase): t[t != t] = 0 return t - def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): + def check_equivalence(model, tuple_inputs, dict_inputs, **additional_kwargs): with torch.no_grad(): tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs) dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple() diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 64ca24eeb6f..9423b7494a8 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -875,7 +875,7 @@ class TFModelTesterMixin: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): + def check_equivalence(model, tuple_inputs, dict_inputs, **additional_kwargs): tuple_output = model(tuple_inputs, return_dict=False, **additional_kwargs) dict_output = model(dict_inputs, return_dict=True, **additional_kwargs).to_tuple()