Fix assert in src/transformers/data/datasets/language_modeling.py (#14077)

* replace assertion with ValueError

* fix code style

Co-authored-by: skpig <1900012999@pku.edu.cn>
Authored by Baizhou Huang on 2021-10-20 19:54:39 +08:00, committed by GitHub
parent 0106826a65
commit 31560f6397
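
The change is the same in every hunk below: a bare `assert` is replaced by an explicit check that raises `ValueError`. Assertions are stripped entirely when Python runs with the `-O` flag, so the explicit check guarantees the validation always executes and raises an exception type that callers can sensibly catch. A minimal sketch of the pattern, using a hypothetical `file_path` value for illustration:

```python
import os

file_path = "data/train.txt"  # hypothetical path, for illustration only

# Before: silently skipped when Python runs with assertions disabled (python -O)
# assert os.path.isfile(file_path), f"Input file path {file_path} not found"

# After: always executed, and raises a ValueError that callers can catch
if not os.path.isfile(file_path):
    raise ValueError(f"Input file path {file_path} not found")
```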

@@ -57,7 +57,8 @@ class TextDataset(Dataset):
 ),
 FutureWarning,
 )
-assert os.path.isfile(file_path), f"Input file path {file_path} not found"
+if os.path.isfile(file_path) is False:
+    raise ValueError(f"Input file path {file_path} not found")
 block_size = block_size - tokenizer.num_special_tokens_to_add(pair=False)
@@ -123,7 +124,8 @@ class LineByLineTextDataset(Dataset):
 ),
 FutureWarning,
 )
-assert os.path.isfile(file_path), f"Input file path {file_path} not found"
+if os.path.isfile(file_path) is False:
+    raise ValueError(f"Input file path {file_path} not found")
 # Here, we do not cache the features, operating under the assumption
 # that we will soon use fast multithreaded tokenizers from the
 # `tokenizers` repo everywhere =)
@@ -155,8 +157,10 @@ class LineByLineWithRefDataset(Dataset):
 ),
 FutureWarning,
 )
-assert os.path.isfile(file_path), f"Input file path {file_path} not found"
-assert os.path.isfile(ref_path), f"Ref file path {file_path} not found"
+if os.path.isfile(file_path) is False:
+    raise ValueError(f"Input file path {file_path} not found")
+if os.path.isfile(ref_path) is False:
+    raise ValueError(f"Ref file path {file_path} not found")
 # Here, we do not cache the features, operating under the assumption
 # that we will soon use fast multithreaded tokenizers from the
 # `tokenizers` repo everywhere =)
@@ -168,7 +172,11 @@ class LineByLineWithRefDataset(Dataset):
 # Get ref inf from file
 with open(ref_path, encoding="utf-8") as f:
     ref = [json.loads(line) for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
-assert len(data) == len(ref)
+if len(data) != len(ref):
+    raise ValueError(
+        f"Length of Input file should be equal to Ref file. But the length of {file_path} is {len(data)} "
+        f"while length of {ref_path} is {len(ref)}"
+    )
 batch_encoding = tokenizer(data, add_special_tokens=True, truncation=True, max_length=block_size)
 self.examples = batch_encoding["input_ids"]
@@ -197,14 +205,16 @@ class LineByLineWithSOPTextDataset(Dataset):
 ),
 FutureWarning,
 )
-assert os.path.isdir(file_dir)
+if os.path.isdir(file_dir) is False:
+    raise ValueError(f"{file_dir} is not a directory")
 logger.info(f"Creating features from dataset file folder at {file_dir}")
 self.examples = []
 # TODO: randomness could apply a random seed, ex. rng = random.Random(random_seed)
 # file path looks like ./dataset/wiki_1, ./dataset/wiki_2
 for file_name in os.listdir(file_dir):
     file_path = os.path.join(file_dir, file_name)
-    assert os.path.isfile(file_path)
+    if os.path.isfile(file_path) is False:
+        raise ValueError(f"{file_path} is not a file")
     article_open = False
     with open(file_path, encoding="utf-8") as f:
         original_lines = f.readlines()
@@ -297,7 +307,8 @@ class LineByLineWithSOPTextDataset(Dataset):
 if total_length <= max_num_tokens:
     break
 trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
-assert len(trunc_tokens) >= 1
+if not (len(trunc_tokens) >= 1):
+    raise ValueError("Sequence length to be truncated must be no less than one")
 # We want to sometimes truncate from the front and sometimes from the
 # back to add more randomness and avoid biases.
 if random.random() < 0.5:
@ -306,8 +317,10 @@ class LineByLineWithSOPTextDataset(Dataset):
trunc_tokens.pop()
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens)
assert len(tokens_a) >= 1
assert len(tokens_b) >= 1
if not (len(tokens_a) >= 1):
raise ValueError(f"Length of sequence a is {len(tokens_a)} which must be no less than 1")
if not (len(tokens_b) >= 1):
raise ValueError(f"Length of sequence b is {len(tokens_b)} which must be no less than 1")
# add special tokens
input_ids = tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
@@ -352,7 +365,8 @@ class TextDatasetForNextSentencePrediction(Dataset):
 ),
 FutureWarning,
 )
-assert os.path.isfile(file_path), f"Input file path {file_path} not found"
+if not os.path.isfile(file_path):
+    raise ValueError(f"Input file path {file_path} not found")
 self.short_seq_probability = short_seq_probability
 self.nsp_probability = nsp_probability
@@ -488,8 +502,10 @@ class TextDatasetForNextSentencePrediction(Dataset):
 for j in range(a_end, len(current_chunk)):
     tokens_b.extend(current_chunk[j])
-assert len(tokens_a) >= 1
-assert len(tokens_b) >= 1
+if not (len(tokens_a) >= 1):
+    raise ValueError(f"Length of sequence a is {len(tokens_a)} which must be no less than 1")
+if not (len(tokens_b) >= 1):
+    raise ValueError(f"Length of sequence b is {len(tokens_b)} which must be no less than 1")
 # add special tokens
 input_ids = self.tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
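
With this change, constructing one of these datasets against a missing file surfaces as a `ValueError` instead of an `AssertionError`. A hedged sketch of how that looks from the caller's side (the checkpoint name, file path, and block size are placeholders; note that `TextDataset` is deprecated and also emits the `FutureWarning` shown in the hunks above):

```python
from transformers import AutoTokenizer, TextDataset

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder checkpoint

try:
    # "missing.txt" does not exist; block_size=128 is an arbitrary value for the sketch.
    dataset = TextDataset(tokenizer=tokenizer, file_path="missing.txt", block_size=128)
except ValueError as err:
    print(err)  # -> Input file path missing.txt not found
```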