mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
finishing model test
This commit is contained in:
parent
d69b0b0e90
commit
87da161c2a
@ -16,16 +16,13 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import six
|
||||
import unittest
|
||||
import collections
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
import modeling as modeling
|
||||
import modeling
|
||||
|
||||
|
||||
class BertModelTest(unittest.TestCase):
|
||||
@ -124,9 +121,6 @@ class BertModelTest(unittest.TestCase):
|
||||
output_result = tester.create_model()
|
||||
tester.check_output(output_result)
|
||||
|
||||
# TODO Find PyTorch equivalent of assert_all_tensors_reachable() if necessary
|
||||
# self.assert_all_tensors_reachable(sess, [init_op, ops])
|
||||
|
||||
@classmethod
|
||||
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
|
||||
"""Creates a random int32 tensor of the shape within the vocab size."""
|
||||
@ -141,120 +135,7 @@ class BertModelTest(unittest.TestCase):
|
||||
for _ in range(total_dims):
|
||||
values.append(rng.randint(0, vocab_size - 1))
|
||||
|
||||
# TODO Solve : the returned tensors provoke index out of range errors when passed to the model
|
||||
return torch.tensor(data=values, dtype=torch.int32)
|
||||
|
||||
def assert_all_tensors_reachable(self, sess, outputs):
|
||||
"""Checks that all the tensors in the graph are reachable from outputs."""
|
||||
graph = sess.graph
|
||||
|
||||
ignore_strings = [
|
||||
"^.*/dilation_rate$",
|
||||
"^.*/Tensordot/concat$",
|
||||
"^.*/Tensordot/concat/axis$",
|
||||
"^testing/.*$",
|
||||
]
|
||||
|
||||
ignore_regexes = [re.compile(x) for x in ignore_strings]
|
||||
|
||||
unreachable = self.get_unreachable_ops(graph, outputs)
|
||||
filtered_unreachable = []
|
||||
for x in unreachable:
|
||||
do_ignore = False
|
||||
for r in ignore_regexes:
|
||||
m = r.match(x.name)
|
||||
if m is not None:
|
||||
do_ignore = True
|
||||
if do_ignore:
|
||||
continue
|
||||
filtered_unreachable.append(x)
|
||||
unreachable = filtered_unreachable
|
||||
|
||||
self.assertEqual(
|
||||
len(unreachable), 0, "The following ops are unreachable: %s" %
|
||||
(" ".join([x.name for x in unreachable])))
|
||||
|
||||
@classmethod
|
||||
def get_unreachable_ops(cls, graph, outputs):
|
||||
"""Finds all of the tensors in graph that are unreachable from outputs."""
|
||||
outputs = cls.flatten_recursive(outputs)
|
||||
output_to_op = collections.defaultdict(list)
|
||||
op_to_all = collections.defaultdict(list)
|
||||
assign_out_to_in = collections.defaultdict(list)
|
||||
|
||||
for op in graph.get_operations():
|
||||
for x in op.inputs:
|
||||
op_to_all[op.name].append(x.name)
|
||||
for y in op.outputs:
|
||||
output_to_op[y.name].append(op.name)
|
||||
op_to_all[op.name].append(y.name)
|
||||
if str(op.type) == "Assign":
|
||||
for y in op.outputs:
|
||||
for x in op.inputs:
|
||||
assign_out_to_in[y.name].append(x.name)
|
||||
|
||||
assign_groups = collections.defaultdict(list)
|
||||
for out_name in assign_out_to_in.keys():
|
||||
name_group = assign_out_to_in[out_name]
|
||||
for n1 in name_group:
|
||||
assign_groups[n1].append(out_name)
|
||||
for n2 in name_group:
|
||||
if n1 != n2:
|
||||
assign_groups[n1].append(n2)
|
||||
|
||||
seen_tensors = {}
|
||||
stack = [x.name for x in outputs]
|
||||
while stack:
|
||||
name = stack.pop()
|
||||
if name in seen_tensors:
|
||||
continue
|
||||
seen_tensors[name] = True
|
||||
|
||||
if name in output_to_op:
|
||||
for op_name in output_to_op[name]:
|
||||
if op_name in op_to_all:
|
||||
for input_name in op_to_all[op_name]:
|
||||
if input_name not in stack:
|
||||
stack.append(input_name)
|
||||
|
||||
expanded_names = []
|
||||
if name in assign_groups:
|
||||
for assign_name in assign_groups[name]:
|
||||
expanded_names.append(assign_name)
|
||||
|
||||
for expanded_name in expanded_names:
|
||||
if expanded_name not in stack:
|
||||
stack.append(expanded_name)
|
||||
|
||||
unreachable_ops = []
|
||||
for op in graph.get_operations():
|
||||
is_unreachable = False
|
||||
all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs]
|
||||
for name in all_names:
|
||||
if name not in seen_tensors:
|
||||
is_unreachable = True
|
||||
if is_unreachable:
|
||||
unreachable_ops.append(op)
|
||||
return unreachable_ops
|
||||
|
||||
@classmethod
|
||||
def flatten_recursive(cls, item):
|
||||
"""Flattens (potentially nested) a tuple/dictionary/list to a list."""
|
||||
output = []
|
||||
if isinstance(item, list):
|
||||
output.extend(item)
|
||||
elif isinstance(item, tuple):
|
||||
output.extend(list(item))
|
||||
elif isinstance(item, dict):
|
||||
for (_, v) in six.iteritems(item):
|
||||
output.append(v)
|
||||
else:
|
||||
return [item]
|
||||
|
||||
flat_output = []
|
||||
for x in output:
|
||||
flat_output.extend(cls.flatten_recursive(x))
|
||||
return flat_output
|
||||
return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user