mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
# Add whole word mask support for lm fine-tune (#7925)
* ADD: add whole word mask proxy for both eng and chinese * MOD: adjust format * MOD: reformat code * MOD: update import * MOD: fix bug * MOD: add import * MOD: fix bug * MOD: decouple code and update readme * MOD: reformat code * Update examples/language-modeling/README.md Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update examples/language-modeling/README.md Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update examples/language-modeling/run_language_modeling.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update examples/language-modeling/run_language_modeling.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update examples/language-modeling/run_language_modeling.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update examples/language-modeling/run_language_modeling.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * change wwm to whole_word_mask * reformat code * reformat * format * Code quality * ADD: update chinese ref readme * MOD: small changes * MOD: small changes2 * update readme Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
This commit is contained in:
parent
64b4d25cf3
commit
a16e568f22
@ -45,6 +45,8 @@ slightly slower (over-fitting takes more epochs).
|
||||
|
||||
We use the `--mlm` flag so that the script may change its loss function.
|
||||
|
||||
If using whole-word masking, use both the`--mlm` and `--wwm` flags.
|
||||
|
||||
```bash
|
||||
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
|
||||
export TEST_FILE=/path/to/dataset/wiki.test.raw
|
||||
@ -57,7 +59,55 @@ python run_language_modeling.py \
|
||||
--train_data_file=$TRAIN_FILE \
|
||||
--do_eval \
|
||||
--eval_data_file=$TEST_FILE \
|
||||
--mlm
|
||||
--mlm \
|
||||
--wwm
|
||||
```
|
||||
|
||||
For Chinese models, it's same with English model with only --mlm`. If using whole-word masking, we need to generate a reference files, case it's char level.
|
||||
|
||||
**Q :** Why ref file ?
|
||||
|
||||
**A :** Suppose we have a Chinese sentence like : `我喜欢你` The original Chinese-BERT will tokenize it as `['我','喜','欢','你']` in char level.
|
||||
Actually, `喜欢` is a whole word. For whole word mask proxy, We need res like `['我','喜','##欢','你']`.
|
||||
So we need a ref file to tell model which pos of BERT original token should be added `##`.
|
||||
|
||||
**Q :** Why LTP ?
|
||||
|
||||
**A :** Cause the best known Chinese WWM BERT is [Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm) by HIT. It works well on so many Chines Task like CLUE (Chinese GLUE).
|
||||
They use LTP, so if we want to fine-tune their model, we need LTP.
|
||||
|
||||
```bash
|
||||
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
|
||||
export LTP_RESOURCE=/path/to/ltp/tokenizer
|
||||
export BERT_RESOURCE=/path/to/bert/tokenizer
|
||||
export SAVE_PATH=/path/to/data/ref.txt
|
||||
|
||||
python chinese_ref.py \
|
||||
--file_name=$TRAIN_FILE \
|
||||
--ltp=$LTP_RESOURCE
|
||||
--bert=$BERT_RESOURCE \
|
||||
--save_path=$SAVE_PATH
|
||||
```
|
||||
Now Chinese Ref is only supported by `LineByLineWithRefDataset` Class, so we need add `line_by_line` flag:
|
||||
|
||||
|
||||
```bash
|
||||
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
|
||||
export TEST_FILE=/path/to/dataset/wiki.test.raw
|
||||
export REF_FILE=/path/to/ref.txt
|
||||
|
||||
python run_language_modeling.py \
|
||||
--output_dir=output \
|
||||
--model_type=roberta \
|
||||
--model_name_or_path=roberta-base \
|
||||
--do_train \
|
||||
--train_data_file=$TRAIN_FILE \
|
||||
--chinese_ref_file=$REF_FILE \
|
||||
--do_eval \
|
||||
--eval_data_file=$TEST_FILE \
|
||||
--mlm \
|
||||
--line_by_line \
|
||||
--wwm
|
||||
```
|
||||
|
||||
### XLNet and permutation language modeling
|
||||
|
147
examples/language-modeling/chinese_ref.py
Normal file
147
examples/language-modeling/chinese_ref.py
Normal file
@ -0,0 +1,147 @@
|
||||
import argparse
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
from ltp import LTP
|
||||
from transformers.tokenization_bert import BertTokenizer
|
||||
|
||||
|
||||
def _is_chinese_char(cp):
|
||||
"""Checks whether CP is the codepoint of a CJK character."""
|
||||
# This defines a "chinese character" as anything in the CJK Unicode block:
|
||||
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
||||
#
|
||||
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
||||
# despite its name. The modern Korean Hangul alphabet is a different block,
|
||||
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
||||
# space-separated words, so they are not treated specially and handled
|
||||
# like the all of the other languages.
|
||||
if (
|
||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
||||
): #
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def is_chinese(word: str):
|
||||
# word like '180' or '身高' or '神'
|
||||
for char in word:
|
||||
char = ord(char)
|
||||
if not _is_chinese_char(char):
|
||||
return 0
|
||||
return 1
|
||||
|
||||
|
||||
def get_chinese_word(tokens: List[str]):
|
||||
word_set = set()
|
||||
|
||||
for token in tokens:
|
||||
chinese_word = len(token) > 1 and is_chinese(token)
|
||||
if chinese_word:
|
||||
word_set.add(token)
|
||||
word_list = list(word_set)
|
||||
return word_list
|
||||
|
||||
|
||||
def add_sub_symbol(bert_tokens: List[str], chinese_word_set: set()):
|
||||
if not chinese_word_set:
|
||||
return bert_tokens
|
||||
max_word_len = max([len(w) for w in chinese_word_set])
|
||||
|
||||
bert_word = bert_tokens
|
||||
start, end = 0, len(bert_word)
|
||||
while start < end:
|
||||
single_word = True
|
||||
if is_chinese(bert_word[start]):
|
||||
l = min(end - start, max_word_len)
|
||||
for i in range(l, 1, -1):
|
||||
whole_word = "".join(bert_word[start : start + i])
|
||||
if whole_word in chinese_word_set:
|
||||
for j in range(start + 1, start + i):
|
||||
bert_word[j] = "##" + bert_word[j]
|
||||
start = start + i
|
||||
single_word = False
|
||||
break
|
||||
if single_word:
|
||||
start += 1
|
||||
return bert_word
|
||||
|
||||
|
||||
def prepare_ref(lines: List[str], ltp_tokenizer: LTP, bert_tokenizer: BertTokenizer):
|
||||
ltp_res = []
|
||||
|
||||
for i in range(0, len(lines), 100):
|
||||
res = ltp_tokenizer.seg(lines[i : i + 100])[0]
|
||||
res = [get_chinese_word(r) for r in res]
|
||||
ltp_res.extend(res)
|
||||
assert len(ltp_res) == len(lines)
|
||||
|
||||
bert_res = []
|
||||
for i in range(0, len(lines), 100):
|
||||
res = bert_tokenizer(lines[i : i + 100], add_special_tokens=True, truncation=True, max_length=512)
|
||||
bert_res.extend(res["input_ids"])
|
||||
assert len(bert_res) == len(lines)
|
||||
|
||||
ref_ids = []
|
||||
for input_ids, chinese_word in zip(bert_res, ltp_res):
|
||||
|
||||
input_tokens = []
|
||||
for id in input_ids:
|
||||
token = bert_tokenizer._convert_id_to_token(id)
|
||||
input_tokens.append(token)
|
||||
input_tokens = add_sub_symbol(input_tokens, chinese_word)
|
||||
ref_id = []
|
||||
# We only save pos of chinese subwords start with ##, which mean is part of a whole word.
|
||||
for i, token in enumerate(input_tokens):
|
||||
if token[:2] == "##":
|
||||
clean_token = token[2:]
|
||||
# save chinese tokens' pos
|
||||
if len(clean_token) == 1 and _is_chinese_char(ord(clean_token)):
|
||||
ref_id.append(i)
|
||||
ref_ids.append(ref_id)
|
||||
|
||||
assert len(ref_ids) == len(bert_res)
|
||||
|
||||
return ref_ids
|
||||
|
||||
|
||||
def main(args):
|
||||
# For Chinese (Ro)Bert, the best result is from : RoBERTa-wwm-ext (https://github.com/ymcui/Chinese-BERT-wwm)
|
||||
# If we want to fine-tune these model, we have to use same tokenizer : LTP (https://github.com/HIT-SCIR/ltp)
|
||||
with open(args.file_name, "r", encoding="utf-8") as f:
|
||||
data = f.readlines()
|
||||
|
||||
ltp_tokenizer = LTP(args.ltp) # faster in GPU device
|
||||
bert_tokenizer = BertTokenizer.from_pretrained(args.bert)
|
||||
|
||||
ref_ids = prepare_ref(data, ltp_tokenizer, bert_tokenizer)
|
||||
|
||||
with open(args.save_path, "w", encoding="utf-8") as f:
|
||||
data = [json.dumps(ref) + "\n" for ref in ref_ids]
|
||||
f.writelines(data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="prepare_chinese_ref")
|
||||
parser.add_argument(
|
||||
"--file_name",
|
||||
type=str,
|
||||
default="./resources/chinese-demo.txt",
|
||||
help="file need process, same as training data in lm",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ltp", type=str, default="./resources/ltp", help="resources for LTP tokenizer, usually a path"
|
||||
)
|
||||
parser.add_argument("--bert", type=str, default="./resources/robert", help="resources for Bert tokenizer")
|
||||
parser.add_argument("--save_path", type=str, default="./resources/ref.txt", help="path to save res")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
@ -37,8 +37,10 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
DataCollatorForLanguageModeling,
|
||||
DataCollatorForPermutationLanguageModeling,
|
||||
DataCollatorForWholeWordMask,
|
||||
HfArgumentParser,
|
||||
LineByLineTextDataset,
|
||||
LineByLineWithRefDataset,
|
||||
PreTrainedTokenizer,
|
||||
TextDataset,
|
||||
Trainer,
|
||||
@ -101,6 +103,10 @@ class DataTrainingArguments:
|
||||
default=None,
|
||||
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
||||
)
|
||||
chinese_ref_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "An optional input ref data file for whole word mask in Chinees."},
|
||||
)
|
||||
line_by_line: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
|
||||
@ -109,6 +115,7 @@ class DataTrainingArguments:
|
||||
mlm: bool = field(
|
||||
default=False, metadata={"help": "Train with masked-language modeling loss instead of language modeling."}
|
||||
)
|
||||
whole_word_mask: bool = field(default=False, metadata={"help": "Whether ot not to use whole word mask."})
|
||||
mlm_probability: float = field(
|
||||
default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
|
||||
)
|
||||
@ -143,6 +150,16 @@ def get_dataset(
|
||||
):
|
||||
def _dataset(file_path):
|
||||
if args.line_by_line:
|
||||
if args.chinese_ref_file is not None:
|
||||
if not args.whole_word_mask or not args.mlm:
|
||||
raise ValueError("You need to set world whole masking and mlm to True for Chinese Whole Word Mask")
|
||||
return LineByLineWithRefDataset(
|
||||
tokenizer=tokenizer,
|
||||
file_path=file_path,
|
||||
block_size=args.block_size,
|
||||
ref_path=args.chinese_ref_file,
|
||||
)
|
||||
|
||||
return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
|
||||
else:
|
||||
return TextDataset(
|
||||
@ -174,7 +191,6 @@ def main():
|
||||
"Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
|
||||
"or remove the --do_eval argument."
|
||||
)
|
||||
|
||||
if (
|
||||
os.path.exists(training_args.output_dir)
|
||||
and os.listdir(training_args.output_dir)
|
||||
@ -270,9 +286,14 @@ def main():
|
||||
max_span_length=data_args.max_span_length,
|
||||
)
|
||||
else:
|
||||
data_collator = DataCollatorForLanguageModeling(
|
||||
tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
|
||||
)
|
||||
if data_args.mlm and data_args.whole_word_mask:
|
||||
data_collator = DataCollatorForWholeWordMask(
|
||||
tokenizer=tokenizer, mlm_probability=data_args.mlm_probability
|
||||
)
|
||||
else:
|
||||
data_collator = DataCollatorForLanguageModeling(
|
||||
tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
|
||||
)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Trainer(
|
||||
|
@ -284,6 +284,7 @@ if is_torch_available():
|
||||
DataCollatorForNextSentencePrediction,
|
||||
DataCollatorForPermutationLanguageModeling,
|
||||
DataCollatorForSOP,
|
||||
DataCollatorForWholeWordMask,
|
||||
DataCollatorWithPadding,
|
||||
default_data_collator,
|
||||
)
|
||||
@ -291,6 +292,7 @@ if is_torch_available():
|
||||
GlueDataset,
|
||||
GlueDataTrainingArguments,
|
||||
LineByLineTextDataset,
|
||||
LineByLineWithRefDataset,
|
||||
LineByLineWithSOPTextDataset,
|
||||
SquadDataset,
|
||||
SquadDataTrainingArguments,
|
||||
|
@ -1,3 +1,4 @@
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
|
||||
|
||||
@ -195,6 +196,124 @@ class DataCollatorForLanguageModeling:
|
||||
return inputs, labels
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
|
||||
"""
|
||||
Data collator used for language modeling.
|
||||
- collates batches of tensors, honoring their tokenizer's pad_token
|
||||
- preprocesses batches for masked language modeling
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
if isinstance(examples[0], (dict, BatchEncoding)):
|
||||
input_ids = [e["input_ids"] for e in examples]
|
||||
else:
|
||||
input_ids = examples
|
||||
examples = [{"input_ids": e} for e in examples]
|
||||
|
||||
batch_input = self._tensorize_batch(input_ids)
|
||||
|
||||
mask_labels = []
|
||||
for e in examples:
|
||||
ref_tokens = []
|
||||
for id in e["input_ids"].tolist():
|
||||
token = self.tokenizer._convert_id_to_token(id)
|
||||
ref_tokens.append(token)
|
||||
|
||||
# For Chinese tokens, we need extra inf to mark sub-word, e.g [喜,欢]-> [喜,##欢]
|
||||
if "chinese_ref" in e:
|
||||
ref_pos = e["chinese_ref"].tolist()
|
||||
len_seq = e["input_ids"].size(0)
|
||||
for i in range(len_seq):
|
||||
if i in ref_pos:
|
||||
ref_tokens[i] = "##" + ref_tokens[i]
|
||||
mask_labels.append(self._whole_word_mask(ref_tokens))
|
||||
batch_mask = self._tensorize_batch(mask_labels)
|
||||
inputs, labels = self.mask_tokens(batch_input, batch_mask)
|
||||
return {"input_ids": inputs, "labels": labels}
|
||||
|
||||
def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
|
||||
"""
|
||||
Get 0/1 labels for masked tokens with whole word mask proxy
|
||||
"""
|
||||
|
||||
cand_indexes = []
|
||||
for (i, token) in enumerate(input_tokens):
|
||||
if token == "[CLS]" or token == "[SEP]":
|
||||
continue
|
||||
|
||||
if len(cand_indexes) >= 1 and token.startswith("##"):
|
||||
cand_indexes[-1].append(i)
|
||||
else:
|
||||
cand_indexes.append([i])
|
||||
|
||||
random.shuffle(cand_indexes)
|
||||
num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability))))
|
||||
masked_lms = []
|
||||
covered_indexes = set()
|
||||
for index_set in cand_indexes:
|
||||
if len(masked_lms) >= num_to_predict:
|
||||
break
|
||||
# If adding a whole-word mask would exceed the maximum number of
|
||||
# predictions, then just skip this candidate.
|
||||
if len(masked_lms) + len(index_set) > num_to_predict:
|
||||
continue
|
||||
is_any_index_covered = False
|
||||
for index in index_set:
|
||||
if index in covered_indexes:
|
||||
is_any_index_covered = True
|
||||
break
|
||||
if is_any_index_covered:
|
||||
continue
|
||||
for index in index_set:
|
||||
covered_indexes.add(index)
|
||||
masked_lms.append(index)
|
||||
|
||||
assert len(covered_indexes) == len(masked_lms)
|
||||
mask_labels = [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]
|
||||
return mask_labels
|
||||
|
||||
def mask_tokens(self, inputs: torch.Tensor, mask_labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
|
||||
Set 'mask_labels' means we use whole word mask (wwm), we directly mask idxs according to it's ref.
|
||||
"""
|
||||
|
||||
if self.tokenizer.mask_token is None:
|
||||
raise ValueError(
|
||||
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
|
||||
)
|
||||
labels = inputs.clone()
|
||||
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
||||
|
||||
probability_matrix = mask_labels
|
||||
|
||||
special_tokens_mask = [
|
||||
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
|
||||
]
|
||||
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
|
||||
if self.tokenizer._pad_token is not None:
|
||||
padding_mask = labels.eq(self.tokenizer.pad_token_id)
|
||||
probability_matrix.masked_fill_(padding_mask, value=0.0)
|
||||
|
||||
masked_indices = probability_matrix.bool()
|
||||
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
||||
|
||||
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
||||
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
|
||||
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
|
||||
|
||||
# 10% of the time, we replace masked input tokens with random word
|
||||
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
|
||||
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
|
||||
inputs[indices_random] = random_words[indices_random]
|
||||
|
||||
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
||||
return inputs, labels
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForSOP(DataCollatorForLanguageModeling):
|
||||
"""
|
||||
|
@ -5,6 +5,7 @@
|
||||
from .glue import GlueDataset, GlueDataTrainingArguments
|
||||
from .language_modeling import (
|
||||
LineByLineTextDataset,
|
||||
LineByLineWithRefDataset,
|
||||
LineByLineWithSOPTextDataset,
|
||||
TextDataset,
|
||||
TextDatasetForNextSentencePrediction,
|
||||
|
@ -1,3 +1,4 @@
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
@ -106,12 +107,48 @@ class LineByLineTextDataset(Dataset):
|
||||
|
||||
batch_encoding = tokenizer(lines, add_special_tokens=True, truncation=True, max_length=block_size)
|
||||
self.examples = batch_encoding["input_ids"]
|
||||
self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.examples)
|
||||
|
||||
def __getitem__(self, i) -> torch.Tensor:
|
||||
return torch.tensor(self.examples[i], dtype=torch.long)
|
||||
def __getitem__(self, i) -> Dict[str, torch.tensor]:
|
||||
return self.examples[i]
|
||||
|
||||
|
||||
class LineByLineWithRefDataset(Dataset):
|
||||
"""
|
||||
This will be superseded by a framework-agnostic approach
|
||||
soon.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, ref_path: str):
|
||||
assert os.path.isfile(file_path), f"Input file path {file_path} not found"
|
||||
assert os.path.isfile(ref_path), f"Ref file path {file_path} not found"
|
||||
# Here, we do not cache the features, operating under the assumption
|
||||
# that we will soon use fast multithreaded tokenizers from the
|
||||
# `tokenizers` repo everywhere =)
|
||||
logger.info("Creating features from dataset file at %s", file_path)
|
||||
logger.info("Use ref segment results at %s", ref_path)
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
data = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
|
||||
batch_encoding = tokenizer(data, add_special_tokens=True, truncation=True, max_length=block_size)
|
||||
self.examples = batch_encoding["input_ids"]
|
||||
self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
|
||||
|
||||
# Get ref inf from file
|
||||
with open(ref_path, encoding="utf-8") as f:
|
||||
ref = [json.loads(line) for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
|
||||
assert len(data) == len(ref)
|
||||
n = len(self.examples)
|
||||
for i in range(n):
|
||||
self.examples[i]["chinese_ref"] = torch.tensor(ref[i], dtype=torch.long)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.examples)
|
||||
|
||||
def __getitem__(self, i) -> Dict[str, torch.tensor]:
|
||||
return self.examples[i]
|
||||
|
||||
|
||||
class LineByLineWithSOPTextDataset(Dataset):
|
||||
|
@ -45,6 +45,11 @@ class DataCollatorForSOP:
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class DataCollatorForWholeWordMask:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class DataCollatorWithPadding:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
@ -69,6 +74,11 @@ class LineByLineTextDataset:
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class LineByLineWithRefDataset:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class LineByLineWithSOPTextDataset:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
Loading…
Reference in New Issue
Block a user