mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
XX.
This commit is contained in:
parent
5b03214158
commit
61d4bb333f
@ -133,7 +133,9 @@ class DebugUnderflowOverflow:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, model, max_frames_to_save=21, trace_batch_nums=[], abort_after_batch_num=None):
|
||||
def __init__(self, model, max_frames_to_save=21, trace_batch_nums=None, abort_after_batch_num=None):
|
||||
if trace_batch_nums is None:
|
||||
trace_batch_nums = []
|
||||
self.model = model
|
||||
self.trace_batch_nums = trace_batch_nums
|
||||
self.abort_after_batch_num = abort_after_batch_num
|
||||
|
@ -235,7 +235,9 @@ class HFTracer(Tracer):
|
||||
|
||||
default_methods_to_record = {"__bool__", "size", "dim"}
|
||||
|
||||
def __init__(self, batch_size=1, sequence_length=[128, 128], num_choices=-1):
|
||||
def __init__(self, batch_size=1, sequence_length=None, num_choices=-1):
|
||||
if sequence_length is None:
|
||||
sequence_length = [128, 128]
|
||||
super().__init__()
|
||||
|
||||
if not is_torch_fx_available():
|
||||
|
@ -363,7 +363,7 @@ class CanineModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
t[t != t] = 0
|
||||
return t
|
||||
|
||||
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
|
||||
def check_equivalence(model, tuple_inputs, dict_inputs, **additional_kwargs):
|
||||
with torch.no_grad():
|
||||
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
|
||||
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
||||
|
@ -1366,7 +1366,7 @@ class ModelTesterMixin:
|
||||
t[t != t] = 0
|
||||
return t
|
||||
|
||||
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
|
||||
def check_equivalence(model, tuple_inputs, dict_inputs, **additional_kwargs):
|
||||
with torch.no_grad():
|
||||
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
|
||||
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
||||
|
@ -137,7 +137,7 @@ class FlaxModelTesterMixin:
|
||||
t[t != t] = 0
|
||||
return t
|
||||
|
||||
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
|
||||
def check_equivalence(model, tuple_inputs, dict_inputs, **additional_kwargs):
|
||||
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
|
||||
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
||||
|
||||
|
@ -313,7 +313,7 @@ class FNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
t[t != t] = 0
|
||||
return t
|
||||
|
||||
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
|
||||
def check_equivalence(model, tuple_inputs, dict_inputs, **additional_kwargs):
|
||||
with torch.no_grad():
|
||||
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
|
||||
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
||||
|
@ -875,7 +875,7 @@ class TFModelTesterMixin:
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
|
||||
def check_equivalence(model, tuple_inputs, dict_inputs, **additional_kwargs):
|
||||
tuple_output = model(tuple_inputs, return_dict=False, **additional_kwargs)
|
||||
dict_output = model(dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user