Nicolas Patry 2021-11-08 10:58:22 +01:00
parent 5b03214158
commit 61d4bb333f
7 changed files with 11 additions and 7 deletions

@@ -133,7 +133,9 @@ class DebugUnderflowOverflow:
     """
-    def __init__(self, model, max_frames_to_save=21, trace_batch_nums=[], abort_after_batch_num=None):
+    def __init__(self, model, max_frames_to_save=21, trace_batch_nums=None, abort_after_batch_num=None):
+        if trace_batch_nums is None:
+            trace_batch_nums = []
         self.model = model
         self.trace_batch_nums = trace_batch_nums
         self.abort_after_batch_num = abort_after_batch_num

@@ -235,7 +235,9 @@ class HFTracer(Tracer):
     default_methods_to_record = {"__bool__", "size", "dim"}
-    def __init__(self, batch_size=1, sequence_length=[128, 128], num_choices=-1):
+    def __init__(self, batch_size=1, sequence_length=None, num_choices=-1):
+        if sequence_length is None:
+            sequence_length = [128, 128]
         super().__init__()
         if not is_torch_fx_available():

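The two hunks above replace mutable default arguments (trace_batch_nums=[] and sequence_length=[128, 128]) with a None sentinel. Python evaluates a default value once, when the def statement runs, so every call that falls back on the default shares the same list object, and a mutation made through one instance leaks into every other one. A minimal sketch of the pitfall and of the pattern applied in this commit; Tracker is a hypothetical stand-in, not a class from the repository:

# Minimal sketch of the mutable-default pitfall; Tracker is a hypothetical example class.
class Tracker:
    def __init__(self, batch_nums=[]):   # default list is created once, at definition time
        self.batch_nums = batch_nums

a = Tracker()
b = Tracker()
a.batch_nums.append(1)
print(b.batch_nums)                      # [1] -- both instances share the same list


# The pattern used in this commit: default to None and build a fresh list per call.
class SafeTracker:
    def __init__(self, batch_nums=None):
        if batch_nums is None:
            batch_nums = []
        self.batch_nums = batch_nums

c = SafeTracker()
d = SafeTracker()
c.batch_nums.append(1)
print(d.batch_nums)                      # [] -- each instance gets its own list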

@@ -363,7 +363,7 @@ class CanineModelTest(ModelTesterMixin, unittest.TestCase):
             t[t != t] = 0
             return t
-        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
+        def check_equivalence(model, tuple_inputs, dict_inputs, **additional_kwargs):
             with torch.no_grad():
                 tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
                 dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()

@@ -1366,7 +1366,7 @@ class ModelTesterMixin:
             t[t != t] = 0
             return t
-        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
+        def check_equivalence(model, tuple_inputs, dict_inputs, **additional_kwargs):
             with torch.no_grad():
                 tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
                 dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()

@@ -137,7 +137,7 @@ class FlaxModelTesterMixin:
             t[t != t] = 0
             return t
-        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
+        def check_equivalence(model, tuple_inputs, dict_inputs, **additional_kwargs):
             tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
             dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()

@@ -313,7 +313,7 @@ class FNetModelTest(ModelTesterMixin, unittest.TestCase):
             t[t != t] = 0
             return t
-        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
+        def check_equivalence(model, tuple_inputs, dict_inputs, **additional_kwargs):
             with torch.no_grad():
                 tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
                 dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()

@@ -875,7 +875,7 @@ class TFModelTesterMixin:
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
+        def check_equivalence(model, tuple_inputs, dict_inputs, **additional_kwargs):
             tuple_output = model(tuple_inputs, return_dict=False, **additional_kwargs)
             dict_output = model(dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
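
The five check_equivalence hunks fix the same antipattern in the test helpers, but instead of a None sentinel the default dict additional_kwargs={} is dropped in favour of **additional_kwargs, so extra options reach the helper as keyword arguments and no shared dict is ever created. A simplified sketch of the new calling convention follows; it is not the helper itself (the framework-specific details such as torch.no_grad are omitted), and the output_hidden_states option is only an illustration, not taken from the diff:

# Simplified stand-in for the updated helper, showing only the signature change.
def check_equivalence(model, tuple_inputs, dict_inputs, **additional_kwargs):
    # Extra options arrive as keyword arguments and are forwarded unchanged.
    tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
    dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
    return tuple_output, dict_output

# Hypothetical call sites under the new signature:
#     check_equivalence(model, tuple_inputs, dict_inputs)                             # no extras
#     check_equivalence(model, tuple_inputs, dict_inputs, output_hidden_states=True)  # one option
#     check_equivalence(model, tuple_inputs, dict_inputs, **{"output_hidden_states": True})
# A pre-built dict of options can still be forwarded by unpacking it with **, as in the last line.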