mirror of https://github.com/huggingface/transformers.git
Patch model parallel test (#8920)
* Patch model parallel test
* Remove line
* Remove `ci_*` from scheduled branches
parent 0c5615af66
commit aa60b230ec
.github/workflows/self-push.yml (2 changes)

@@ -4,7 +4,7 @@ on:
  push:
    branches:
      - master
      - model-templates
      - ci_*
    paths:
      - "src/**"
      - "tests/**"
.github/workflows/self-scheduled.yml (3 changes)

@@ -6,9 +6,6 @@
 name: Self-hosted runner (scheduled)
 
 on:
-  push:
-    branches:
-      - ci_*
   repository_dispatch:
   schedule:
     - cron: "0 0 * * *"
@@ -1141,22 +1141,22 @@ class ModelTesterMixin:
         for model_class in self.all_parallelizable_model_classes:
             inputs_dict = self._prepare_for_class(inputs_dict, model_class)
 
-            model = model_class(config)
-            output = model(**inputs_dict)
-
-            model.parallelize()
-
-            def cast_to_gpu(dictionary):
+            def cast_to_device(dictionary, device):
                 output = {}
                 for k, v in dictionary.items():
                     if isinstance(v, torch.Tensor):
-                        output[k] = v.to("cuda:0")
+                        output[k] = v.to(device)
                     else:
                         output[k] = v
 
                 return output
 
-            parallel_output = model(**cast_to_gpu(inputs_dict))
+            model = model_class(config)
+            output = model(**cast_to_device(inputs_dict, "cpu"))
+
+            model.parallelize()
+
+            parallel_output = model(**cast_to_device(inputs_dict, "cuda:0"))
 
             for value, parallel_value in zip(output, parallel_output):
                 if isinstance(value, torch.Tensor):
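For context, the pattern this hunk puts in place (reference forward pass on CPU, then `parallelize()` and a second pass with inputs on `cuda:0`) can be sketched outside the test harness roughly as follows. This is a minimal sketch, not the test itself: it assumes the `gpt2` checkpoint, at least one CUDA device, and reuses the `cast_to_device` helper name from the diff.

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer


def cast_to_device(dictionary, device):
    # Move every tensor in the input dict to the target device; leave other values untouched.
    return {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in dictionary.items()}


tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

inputs = dict(tokenizer("Model parallelism test", return_tensors="pt"))

# Reference forward pass on CPU, before the model is sharded across devices.
with torch.no_grad():
    cpu_output = model(**cast_to_device(inputs, "cpu"))

# Shard the model across the available GPUs; inputs then go to the first device.
model.parallelize()
with torch.no_grad():
    parallel_output = model(**cast_to_device(inputs, "cuda:0"))

# The patched test compares the two outputs tensor by tensor; here we spot-check the logits.
print(torch.allclose(cpu_output.logits, parallel_output.logits.to("cpu"), atol=1e-4))
```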