mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-18 12:08:22 +06:00
finishing model test
This commit is contained in:
parent
d69b0b0e90
commit
87da161c2a
@ -16,16 +16,13 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import six
|
|
||||||
import unittest
|
import unittest
|
||||||
import collections
|
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import re
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import modeling as modeling
|
import modeling
|
||||||
|
|
||||||
|
|
||||||
class BertModelTest(unittest.TestCase):
|
class BertModelTest(unittest.TestCase):
|
||||||
@ -124,9 +121,6 @@ class BertModelTest(unittest.TestCase):
|
|||||||
output_result = tester.create_model()
|
output_result = tester.create_model()
|
||||||
tester.check_output(output_result)
|
tester.check_output(output_result)
|
||||||
|
|
||||||
# TODO Find PyTorch equivalent of assert_all_tensors_reachable() if necessary
|
|
||||||
# self.assert_all_tensors_reachable(sess, [init_op, ops])
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
|
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
|
||||||
"""Creates a random int32 tensor of the shape within the vocab size."""
|
"""Creates a random int32 tensor of the shape within the vocab size."""
|
||||||
@ -141,120 +135,7 @@ class BertModelTest(unittest.TestCase):
|
|||||||
for _ in range(total_dims):
|
for _ in range(total_dims):
|
||||||
values.append(rng.randint(0, vocab_size - 1))
|
values.append(rng.randint(0, vocab_size - 1))
|
||||||
|
|
||||||
# TODO Solve : the returned tensors provoke index out of range errors when passed to the model
|
return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
|
||||||
return torch.tensor(data=values, dtype=torch.int32)
|
|
||||||
|
|
||||||
def assert_all_tensors_reachable(self, sess, outputs):
|
|
||||||
"""Checks that all the tensors in the graph are reachable from outputs."""
|
|
||||||
graph = sess.graph
|
|
||||||
|
|
||||||
ignore_strings = [
|
|
||||||
"^.*/dilation_rate$",
|
|
||||||
"^.*/Tensordot/concat$",
|
|
||||||
"^.*/Tensordot/concat/axis$",
|
|
||||||
"^testing/.*$",
|
|
||||||
]
|
|
||||||
|
|
||||||
ignore_regexes = [re.compile(x) for x in ignore_strings]
|
|
||||||
|
|
||||||
unreachable = self.get_unreachable_ops(graph, outputs)
|
|
||||||
filtered_unreachable = []
|
|
||||||
for x in unreachable:
|
|
||||||
do_ignore = False
|
|
||||||
for r in ignore_regexes:
|
|
||||||
m = r.match(x.name)
|
|
||||||
if m is not None:
|
|
||||||
do_ignore = True
|
|
||||||
if do_ignore:
|
|
||||||
continue
|
|
||||||
filtered_unreachable.append(x)
|
|
||||||
unreachable = filtered_unreachable
|
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
len(unreachable), 0, "The following ops are unreachable: %s" %
|
|
||||||
(" ".join([x.name for x in unreachable])))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_unreachable_ops(cls, graph, outputs):
|
|
||||||
"""Finds all of the tensors in graph that are unreachable from outputs."""
|
|
||||||
outputs = cls.flatten_recursive(outputs)
|
|
||||||
output_to_op = collections.defaultdict(list)
|
|
||||||
op_to_all = collections.defaultdict(list)
|
|
||||||
assign_out_to_in = collections.defaultdict(list)
|
|
||||||
|
|
||||||
for op in graph.get_operations():
|
|
||||||
for x in op.inputs:
|
|
||||||
op_to_all[op.name].append(x.name)
|
|
||||||
for y in op.outputs:
|
|
||||||
output_to_op[y.name].append(op.name)
|
|
||||||
op_to_all[op.name].append(y.name)
|
|
||||||
if str(op.type) == "Assign":
|
|
||||||
for y in op.outputs:
|
|
||||||
for x in op.inputs:
|
|
||||||
assign_out_to_in[y.name].append(x.name)
|
|
||||||
|
|
||||||
assign_groups = collections.defaultdict(list)
|
|
||||||
for out_name in assign_out_to_in.keys():
|
|
||||||
name_group = assign_out_to_in[out_name]
|
|
||||||
for n1 in name_group:
|
|
||||||
assign_groups[n1].append(out_name)
|
|
||||||
for n2 in name_group:
|
|
||||||
if n1 != n2:
|
|
||||||
assign_groups[n1].append(n2)
|
|
||||||
|
|
||||||
seen_tensors = {}
|
|
||||||
stack = [x.name for x in outputs]
|
|
||||||
while stack:
|
|
||||||
name = stack.pop()
|
|
||||||
if name in seen_tensors:
|
|
||||||
continue
|
|
||||||
seen_tensors[name] = True
|
|
||||||
|
|
||||||
if name in output_to_op:
|
|
||||||
for op_name in output_to_op[name]:
|
|
||||||
if op_name in op_to_all:
|
|
||||||
for input_name in op_to_all[op_name]:
|
|
||||||
if input_name not in stack:
|
|
||||||
stack.append(input_name)
|
|
||||||
|
|
||||||
expanded_names = []
|
|
||||||
if name in assign_groups:
|
|
||||||
for assign_name in assign_groups[name]:
|
|
||||||
expanded_names.append(assign_name)
|
|
||||||
|
|
||||||
for expanded_name in expanded_names:
|
|
||||||
if expanded_name not in stack:
|
|
||||||
stack.append(expanded_name)
|
|
||||||
|
|
||||||
unreachable_ops = []
|
|
||||||
for op in graph.get_operations():
|
|
||||||
is_unreachable = False
|
|
||||||
all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs]
|
|
||||||
for name in all_names:
|
|
||||||
if name not in seen_tensors:
|
|
||||||
is_unreachable = True
|
|
||||||
if is_unreachable:
|
|
||||||
unreachable_ops.append(op)
|
|
||||||
return unreachable_ops
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def flatten_recursive(cls, item):
|
|
||||||
"""Flattens (potentially nested) a tuple/dictionary/list to a list."""
|
|
||||||
output = []
|
|
||||||
if isinstance(item, list):
|
|
||||||
output.extend(item)
|
|
||||||
elif isinstance(item, tuple):
|
|
||||||
output.extend(list(item))
|
|
||||||
elif isinstance(item, dict):
|
|
||||||
for (_, v) in six.iteritems(item):
|
|
||||||
output.append(v)
|
|
||||||
else:
|
|
||||||
return [item]
|
|
||||||
|
|
||||||
flat_output = []
|
|
||||||
for x in output:
|
|
||||||
flat_output.extend(cls.flatten_recursive(x))
|
|
||||||
return flat_output
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user