Fix assert in src/transformers/data/datasets/language_modeling.py (#14077)
* replace assertion with ValueError

* fix code style

Co-authored-by: skpig <1900012999@pku.edu.cn>
parent 0106826a65
commit 31560f6397
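Every hunk below applies the same pattern: an `assert` that is silently stripped when Python runs with `-O` is replaced by an explicit check that raises a `ValueError` carrying the same message. A minimal standalone sketch of the pattern (the `load_text` helper is hypothetical, for illustration only, not part of the diff):

import os


def load_text(file_path: str) -> str:
    # Hypothetical helper, illustrating the commit's pattern.
    # Before: assert os.path.isfile(file_path), f"Input file path {file_path} not found"
    # The assert disappears under `python -O`; the explicit check below always runs.
    if not os.path.isfile(file_path):
        raise ValueError(f"Input file path {file_path} not found")
    with open(file_path, encoding="utf-8") as f:
        return f.read()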
src/transformers/data/datasets/language_modeling.py

@@ -57,7 +57,8 @@ class TextDataset(Dataset):
             ),
             FutureWarning,
         )
-        assert os.path.isfile(file_path), f"Input file path {file_path} not found"
+        if os.path.isfile(file_path) is False:
+            raise ValueError(f"Input file path {file_path} not found")
 
         block_size = block_size - tokenizer.num_special_tokens_to_add(pair=False)
 
@@ -123,7 +124,8 @@ class LineByLineTextDataset(Dataset):
             ),
             FutureWarning,
         )
-        assert os.path.isfile(file_path), f"Input file path {file_path} not found"
+        if os.path.isfile(file_path) is False:
+            raise ValueError(f"Input file path {file_path} not found")
         # Here, we do not cache the features, operating under the assumption
         # that we will soon use fast multithreaded tokenizers from the
         # `tokenizers` repo everywhere =)
@@ -155,8 +157,10 @@ class LineByLineWithRefDataset(Dataset):
             ),
             FutureWarning,
         )
-        assert os.path.isfile(file_path), f"Input file path {file_path} not found"
-        assert os.path.isfile(ref_path), f"Ref file path {file_path} not found"
+        if os.path.isfile(file_path) is False:
+            raise ValueError(f"Input file path {file_path} not found")
+        if os.path.isfile(ref_path) is False:
+            raise ValueError(f"Ref file path {file_path} not found")
         # Here, we do not cache the features, operating under the assumption
         # that we will soon use fast multithreaded tokenizers from the
         # `tokenizers` repo everywhere =)
@@ -168,7 +172,11 @@ class LineByLineWithRefDataset(Dataset):
         # Get ref inf from file
         with open(ref_path, encoding="utf-8") as f:
             ref = [json.loads(line) for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
-        assert len(data) == len(ref)
+        if len(data) != len(ref):
+            raise ValueError(
+                f"Length of Input file should be equal to Ref file. But the length of {file_path} is {len(data)} "
+                f"while length of {ref_path} is {len(ref)}"
+            )
 
         batch_encoding = tokenizer(data, add_special_tokens=True, truncation=True, max_length=block_size)
         self.examples = batch_encoding["input_ids"]
@@ -197,14 +205,16 @@ class LineByLineWithSOPTextDataset(Dataset):
             ),
             FutureWarning,
         )
-        assert os.path.isdir(file_dir)
+        if os.path.isdir(file_dir) is False:
+            raise ValueError(f"{file_dir} is not a directory")
         logger.info(f"Creating features from dataset file folder at {file_dir}")
         self.examples = []
         # TODO: randomness could apply a random seed, ex. rng = random.Random(random_seed)
         # file path looks like ./dataset/wiki_1, ./dataset/wiki_2
         for file_name in os.listdir(file_dir):
             file_path = os.path.join(file_dir, file_name)
-            assert os.path.isfile(file_path)
+            if os.path.isfile(file_path) is False:
+                raise ValueError(f"{file_path} is not a file")
             article_open = False
             with open(file_path, encoding="utf-8") as f:
                 original_lines = f.readlines()
@@ -297,7 +307,8 @@ class LineByLineWithSOPTextDataset(Dataset):
                             if total_length <= max_num_tokens:
                                 break
                             trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
-                            assert len(trunc_tokens) >= 1
+                            if not (len(trunc_tokens) >= 1):
+                                raise ValueError("Sequence length to be truncated must be no less than one")
                             # We want to sometimes truncate from the front and sometimes from the
                             # back to add more randomness and avoid biases.
                             if random.random() < 0.5:
@@ -306,8 +317,10 @@ class LineByLineWithSOPTextDataset(Dataset):
                                 trunc_tokens.pop()
 
                     truncate_seq_pair(tokens_a, tokens_b, max_num_tokens)
-                    assert len(tokens_a) >= 1
-                    assert len(tokens_b) >= 1
+                    if not (len(tokens_a) >= 1):
+                        raise ValueError(f"Length of sequence a is {len(tokens_a)} which must be no less than 1")
+                    if not (len(tokens_b) >= 1):
+                        raise ValueError(f"Length of sequence b is {len(tokens_b)} which must be no less than 1")
 
                     # add special tokens
                     input_ids = tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
@@ -352,7 +365,8 @@ class TextDatasetForNextSentencePrediction(Dataset):
             ),
             FutureWarning,
         )
-        assert os.path.isfile(file_path), f"Input file path {file_path} not found"
+        if not os.path.isfile(file_path):
+            raise ValueError(f"Input file path {file_path} not found")
 
         self.short_seq_probability = short_seq_probability
         self.nsp_probability = nsp_probability
@@ -488,8 +502,10 @@ class TextDatasetForNextSentencePrediction(Dataset):
                         for j in range(a_end, len(current_chunk)):
                             tokens_b.extend(current_chunk[j])
 
-                    assert len(tokens_a) >= 1
-                    assert len(tokens_b) >= 1
+                    if not (len(tokens_a) >= 1):
+                        raise ValueError(f"Length of sequence a is {len(tokens_a)} which must be no less than 1")
+                    if not (len(tokens_b) >= 1):
+                        raise ValueError(f"Length of sequence b is {len(tokens_b)} which must be no less than 1")
 
                     # add special tokens
                     input_ids = self.tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
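For context, one way to observe the new behavior after this commit is a quick check against the deprecated `TextDataset` class touched in the first hunk. This is a sketch, assuming a transformers install that includes this change and network or cache access to the `gpt2` tokenizer; the file name is made up, and the call still emits the FutureWarning seen in the diff:

from transformers import AutoTokenizer, TextDataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")
try:
    # A path that does not exist; before this commit this raised AssertionError.
    TextDataset(tokenizer=tokenizer, file_path="missing.txt", block_size=128)
except ValueError as err:
    print(err)  # -> Input file path missing.txt not found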