diff --git a/tests/modeling_test.py b/tests/modeling_test.py index 60b7666723a..d3d937a06e1 100644 --- a/tests/modeling_test.py +++ b/tests/modeling_test.py @@ -16,16 +16,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import six import unittest -import collections import json import random -import re import torch -import modeling as modeling +import modeling class BertModelTest(unittest.TestCase): @@ -124,9 +121,6 @@ class BertModelTest(unittest.TestCase): output_result = tester.create_model() tester.check_output(output_result) - # TODO Find PyTorch equivalent of assert_all_tensors_reachable() if necessary - # self.assert_all_tensors_reachable(sess, [init_op, ops]) - @classmethod def ids_tensor(cls, shape, vocab_size, rng=None, name=None): """Creates a random int32 tensor of the shape within the vocab size.""" @@ -141,120 +135,7 @@ class BertModelTest(unittest.TestCase): for _ in range(total_dims): values.append(rng.randint(0, vocab_size - 1)) - # TODO Solve : the returned tensors provoke index out of range errors when passed to the model - return torch.tensor(data=values, dtype=torch.int32) - - def assert_all_tensors_reachable(self, sess, outputs): - """Checks that all the tensors in the graph are reachable from outputs.""" - graph = sess.graph - - ignore_strings = [ - "^.*/dilation_rate$", - "^.*/Tensordot/concat$", - "^.*/Tensordot/concat/axis$", - "^testing/.*$", - ] - - ignore_regexes = [re.compile(x) for x in ignore_strings] - - unreachable = self.get_unreachable_ops(graph, outputs) - filtered_unreachable = [] - for x in unreachable: - do_ignore = False - for r in ignore_regexes: - m = r.match(x.name) - if m is not None: - do_ignore = True - if do_ignore: - continue - filtered_unreachable.append(x) - unreachable = filtered_unreachable - - self.assertEqual( - len(unreachable), 0, "The following ops are unreachable: %s" % - (" ".join([x.name for x in unreachable]))) - - @classmethod - def get_unreachable_ops(cls, graph, outputs): - """Finds all of the tensors in graph that are unreachable from outputs.""" - outputs = cls.flatten_recursive(outputs) - output_to_op = collections.defaultdict(list) - op_to_all = collections.defaultdict(list) - assign_out_to_in = collections.defaultdict(list) - - for op in graph.get_operations(): - for x in op.inputs: - op_to_all[op.name].append(x.name) - for y in op.outputs: - output_to_op[y.name].append(op.name) - op_to_all[op.name].append(y.name) - if str(op.type) == "Assign": - for y in op.outputs: - for x in op.inputs: - assign_out_to_in[y.name].append(x.name) - - assign_groups = collections.defaultdict(list) - for out_name in assign_out_to_in.keys(): - name_group = assign_out_to_in[out_name] - for n1 in name_group: - assign_groups[n1].append(out_name) - for n2 in name_group: - if n1 != n2: - assign_groups[n1].append(n2) - - seen_tensors = {} - stack = [x.name for x in outputs] - while stack: - name = stack.pop() - if name in seen_tensors: - continue - seen_tensors[name] = True - - if name in output_to_op: - for op_name in output_to_op[name]: - if op_name in op_to_all: - for input_name in op_to_all[op_name]: - if input_name not in stack: - stack.append(input_name) - - expanded_names = [] - if name in assign_groups: - for assign_name in assign_groups[name]: - expanded_names.append(assign_name) - - for expanded_name in expanded_names: - if expanded_name not in stack: - stack.append(expanded_name) - - unreachable_ops = [] - for op in graph.get_operations(): - is_unreachable = False - all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs] - for name in all_names: - if name not in seen_tensors: - is_unreachable = True - if is_unreachable: - unreachable_ops.append(op) - return unreachable_ops - - @classmethod - def flatten_recursive(cls, item): - """Flattens (potentially nested) a tuple/dictionary/list to a list.""" - output = [] - if isinstance(item, list): - output.extend(item) - elif isinstance(item, tuple): - output.extend(list(item)) - elif isinstance(item, dict): - for (_, v) in six.iteritems(item): - output.append(v) - else: - return [item] - - flat_output = [] - for x in output: - flat_output.extend(cls.flatten_recursive(x)) - return flat_output + return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous() if __name__ == "__main__":