mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
[Wav2Vec2ProcessorWithLM] improve multi processing (#15247)
* [Wav2Vec2ProcessorWithLM] improve multi processing * close pool
This commit is contained in:
parent
4cff3fae11
commit
80af1048cf
@ -18,7 +18,7 @@ Speech processor class for Wav2Vec2
|
|||||||
import os
|
import os
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from multiprocessing import Pool
|
from multiprocessing import get_context
|
||||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Union
|
from typing import TYPE_CHECKING, Iterable, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -300,7 +300,7 @@ class Wav2Vec2ProcessorWithLM:
|
|||||||
|
|
||||||
# create multiprocessing pool and list numpy arrays
|
# create multiprocessing pool and list numpy arrays
|
||||||
logits_list = [array for array in logits]
|
logits_list = [array for array in logits]
|
||||||
pool = Pool(num_processes)
|
pool = get_context("fork").Pool(num_processes)
|
||||||
|
|
||||||
# pyctcdecode
|
# pyctcdecode
|
||||||
decoded_beams = self.decoder.decode_beams_batch(
|
decoded_beams = self.decoder.decode_beams_batch(
|
||||||
@ -313,6 +313,9 @@ class Wav2Vec2ProcessorWithLM:
|
|||||||
hotword_weight=hotword_weight,
|
hotword_weight=hotword_weight,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# close multi-processing pool
|
||||||
|
pool.close()
|
||||||
|
|
||||||
# extract text
|
# extract text
|
||||||
batch_texts = [d[0][0] for d in decoded_beams]
|
batch_texts = [d[0][0] for d in decoded_beams]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user