Patch model parallel test (#8920)

* Patch model parallel test

* Remove line

* Remove `ci_*` from scheduled branches
This commit is contained in:
Lysandre Debut 2020-12-03 17:15:47 -05:00 committed by GitHub
parent 0c5615af66
commit aa60b230ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 12 deletions

View File

@@ -4,7 +4,7 @@ on:
push:
branches:
- master
- model-templates
- ci_*
paths:
- "src/**"
- "tests/**"

View File

@@ -6,9 +6,6 @@
name: Self-hosted runner (scheduled)
on:
push:
branches:
- ci_*
repository_dispatch:
schedule:
- cron: "0 0 * * *"

View File

@@ -1141,22 +1141,22 @@ class ModelTesterMixin:
for model_class in self.all_parallelizable_model_classes:
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
output = model(**inputs_dict)
model.parallelize()
def cast_to_gpu(dictionary):
def cast_to_device(dictionary, device):
output = {}
for k, v in dictionary.items():
if isinstance(v, torch.Tensor):
output[k] = v.to("cuda:0")
output[k] = v.to(device)
else:
output[k] = v
return output
parallel_output = model(**cast_to_gpu(inputs_dict))
model = model_class(config)
output = model(**cast_to_device(inputs_dict, "cpu"))
model.parallelize()
parallel_output = model(**cast_to_device(inputs_dict, "cuda:0"))
for value, parallel_value in zip(output, parallel_output):
if isinstance(value, torch.Tensor):