Fix assert in src/transformers/data/datasets/language_modeling.py (#14077)
* replace assertion with ValueError
* fix code style

Co-authored-by: skpig <1900012999@pku.edu.cn>
This commit is contained in:
parent 0106826a65
commit 31560f6397
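Every hunk below applies the same pattern: a bare `assert` is replaced with an explicit check that raises `ValueError`, so the validation still runs under `python -O` (which strips assert statements) and callers get a catchable, descriptive error. A minimal sketch of the pattern, using an illustrative helper name (`_require_file`) that is not part of the commit:

import os


def _require_file(file_path: str) -> None:
    # Explicit check instead of `assert`: still enforced under `python -O`
    # and raises a descriptive, catchable ValueError.
    if os.path.isfile(file_path) is False:
        raise ValueError(f"Input file path {file_path} not found")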
@@ -57,7 +57,8 @@ class TextDataset(Dataset):
             ),
             FutureWarning,
         )
-        assert os.path.isfile(file_path), f"Input file path {file_path} not found"
+        if os.path.isfile(file_path) is False:
+            raise ValueError(f"Input file path {file_path} not found")
 
         block_size = block_size - tokenizer.num_special_tokens_to_add(pair=False)
 
@@ -123,7 +124,8 @@ class LineByLineTextDataset(Dataset):
             ),
             FutureWarning,
         )
-        assert os.path.isfile(file_path), f"Input file path {file_path} not found"
+        if os.path.isfile(file_path) is False:
+            raise ValueError(f"Input file path {file_path} not found")
         # Here, we do not cache the features, operating under the assumption
         # that we will soon use fast multithreaded tokenizers from the
         # `tokenizers` repo everywhere =)
@@ -155,8 +157,10 @@ class LineByLineWithRefDataset(Dataset):
             ),
             FutureWarning,
         )
-        assert os.path.isfile(file_path), f"Input file path {file_path} not found"
-        assert os.path.isfile(ref_path), f"Ref file path {file_path} not found"
+        if os.path.isfile(file_path) is False:
+            raise ValueError(f"Input file path {file_path} not found")
+        if os.path.isfile(ref_path) is False:
+            raise ValueError(f"Ref file path {file_path} not found")
         # Here, we do not cache the features, operating under the assumption
         # that we will soon use fast multithreaded tokenizers from the
         # `tokenizers` repo everywhere =)
@@ -168,7 +172,11 @@ class LineByLineWithRefDataset(Dataset):
         # Get ref inf from file
         with open(ref_path, encoding="utf-8") as f:
             ref = [json.loads(line) for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
-        assert len(data) == len(ref)
+        if len(data) != len(ref):
+            raise ValueError(
+                f"Length of Input file should be equal to Ref file. But the length of {file_path} is {len(data)} "
+                f"while length of {ref_path} is {len(ref)}"
+            )
 
         batch_encoding = tokenizer(data, add_special_tokens=True, truncation=True, max_length=block_size)
         self.examples = batch_encoding["input_ids"]
@@ -197,14 +205,16 @@ class LineByLineWithSOPTextDataset(Dataset):
             ),
             FutureWarning,
         )
-        assert os.path.isdir(file_dir)
+        if os.path.isdir(file_dir) is False:
+            raise ValueError(f"{file_dir} is not a directory")
         logger.info(f"Creating features from dataset file folder at {file_dir}")
         self.examples = []
         # TODO: randomness could apply a random seed, ex. rng = random.Random(random_seed)
         # file path looks like ./dataset/wiki_1, ./dataset/wiki_2
         for file_name in os.listdir(file_dir):
             file_path = os.path.join(file_dir, file_name)
-            assert os.path.isfile(file_path)
+            if os.path.isfile(file_path) is False:
+                raise ValueError(f"{file_path} is not a file")
             article_open = False
             with open(file_path, encoding="utf-8") as f:
                 original_lines = f.readlines()
@@ -297,7 +307,8 @@ class LineByLineWithSOPTextDataset(Dataset):
                             if total_length <= max_num_tokens:
                                 break
                             trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
-                            assert len(trunc_tokens) >= 1
+                            if not (len(trunc_tokens) >= 1):
+                                raise ValueError("Sequence length to be truncated must be no less than one")
                             # We want to sometimes truncate from the front and sometimes from the
                             # back to add more randomness and avoid biases.
                             if random.random() < 0.5:
@@ -306,8 +317,10 @@ class LineByLineWithSOPTextDataset(Dataset):
                                 trunc_tokens.pop()
 
                     truncate_seq_pair(tokens_a, tokens_b, max_num_tokens)
-                    assert len(tokens_a) >= 1
-                    assert len(tokens_b) >= 1
+                    if not (len(tokens_a) >= 1):
+                        raise ValueError(f"Length of sequence a is {len(tokens_a)} which must be no less than 1")
+                    if not (len(tokens_b) >= 1):
+                        raise ValueError(f"Length of sequence b is {len(tokens_b)} which must be no less than 1")
 
                     # add special tokens
                     input_ids = tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
@@ -352,7 +365,8 @@ class TextDatasetForNextSentencePrediction(Dataset):
             ),
             FutureWarning,
         )
-        assert os.path.isfile(file_path), f"Input file path {file_path} not found"
+        if not os.path.isfile(file_path):
+            raise ValueError(f"Input file path {file_path} not found")
 
         self.short_seq_probability = short_seq_probability
         self.nsp_probability = nsp_probability
@@ -488,8 +502,10 @@ class TextDatasetForNextSentencePrediction(Dataset):
                         for j in range(a_end, len(current_chunk)):
                             tokens_b.extend(current_chunk[j])
 
-                    assert len(tokens_a) >= 1
-                    assert len(tokens_b) >= 1
+                    if not (len(tokens_a) >= 1):
+                        raise ValueError(f"Length of sequence a is {len(tokens_a)} which must be no less than 1")
+                    if not (len(tokens_b) >= 1):
+                        raise ValueError(f"Length of sequence b is {len(tokens_b)} which must be no less than 1")
 
                     # add special tokens
                     input_ids = self.tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
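With the asserts gone, a missing input now surfaces from these dataset constructors as a ValueError rather than an AssertionError, and the checks are no longer stripped under `python -O`. A hypothetical caller, assuming a `tokenizer` instance is already in scope (not part of this commit), could handle it like this:

from transformers import TextDataset

try:
    dataset = TextDataset(tokenizer=tokenizer, file_path="missing.txt", block_size=128)
except ValueError as err:
    print(f"Could not build dataset: {err}")  # e.g. "Input file path missing.txt not found"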