[Wav2Vec2ProcessorWithLM] improve multi processing (#15247)

* [Wav2Vec2ProcessorWithLM] improve multi processing

* close pool
This commit is contained in:
Patrick von Platen 2022-01-21 18:30:10 +01:00 committed by GitHub
parent 4cff3fae11
commit 80af1048cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -18,7 +18,7 @@ Speech processor class for Wav2Vec2
import os import os
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from multiprocessing import Pool from multiprocessing import get_context
from typing import TYPE_CHECKING, Iterable, List, Optional, Union from typing import TYPE_CHECKING, Iterable, List, Optional, Union
import numpy as np import numpy as np
@ -300,7 +300,7 @@ class Wav2Vec2ProcessorWithLM:
# create multiprocessing pool and list numpy arrays # create multiprocessing pool and list numpy arrays
logits_list = [array for array in logits] logits_list = [array for array in logits]
pool = Pool(num_processes) pool = get_context("fork").Pool(num_processes)
# pyctcdecode # pyctcdecode
decoded_beams = self.decoder.decode_beams_batch( decoded_beams = self.decoder.decode_beams_batch(
@ -313,6 +313,9 @@ class Wav2Vec2ProcessorWithLM:
hotword_weight=hotword_weight, hotword_weight=hotword_weight,
) )
# clone multi-processing pool
pool.close()
# extract text # extract text
batch_texts = [d[0][0] for d in decoded_beams] batch_texts = [d[0][0] for d in decoded_beams]