compatibility with sklearn and keras

thomwolf 2019-10-17 13:17:05 +02:00 committed by Morgan Funtowicz
parent b81ab431f2
commit 7c1697562a
3 changed files with 59 additions and 32 deletions

View File

@@ -83,7 +83,7 @@ class TrainCommand(BaseTransformersCLICommand):
         self.logger.info('Loading model {}'.format(args.model_name))
         self.model_name = args.model_name
-        self.tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+        self.pipeline = AutoTokenizer.from_pretrained(args.model_name)
         if args.task == 'text_classification':
            self.model = SequenceClassifModel.from_pretrained(args.model_name)
         elif args.task == 'token_classification':
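For reference, `AutoTokenizer.from_pretrained` resolves a checkpoint name to the matching tokenizer class, which is all the command needs before dispatching on `args.task`. A minimal sketch of the call pattern; the checkpoint name here is illustrative:

    from transformers import AutoTokenizer

    # Illustrative checkpoint; any hub checkpoint with tokenizer files works.
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    ids = tokenizer.encode("hello world", add_special_tokens=True)
    # ids is a list of vocabulary indices, wrapped in [CLS] ... [SEP]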

View File

@@ -222,7 +222,7 @@ class SingleSentenceClassificationProcessor(DataProcessor):
         batch_length = max(len(input_ids) for input_ids in all_input_ids)
         features = []
-        for (ex_index, (input_ids, example)) in enumerate(zip(all_input_ids, examples)):
+        for (ex_index, (input_ids, example)) in enumerate(zip(all_input_ids, self.examples)):
             if ex_index % 10000 == 0:
                 logger.info("Writing example %d", ex_index)
             # The mask has 1 for real tokens and 0 for padding tokens. Only real
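The loop that follows pads every example to `batch_length` and builds the attention mask described in the comment. A standalone sketch of that masking logic with hypothetical names (the real method works on the processor's stored examples):

    # Hypothetical helper mirroring the per-example padding and masking.
    def pad_and_mask(input_ids, batch_length, pad_token_id=0):
        padding_length = batch_length - len(input_ids)
        attention_mask = [1] * len(input_ids) + [0] * padding_length  # 1 = real token, 0 = padding
        input_ids = input_ids + [pad_token_id] * padding_length
        return input_ids, attention_mask

    ids, mask = pad_and_mask([101, 2023, 102], batch_length=5)
    # ids  == [101, 2023, 102, 0, 0]
    # mask == [1, 1, 1, 0, 0]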

View File

@@ -109,7 +109,32 @@ class TextClassificationPipeline(object):
         self.tokenizer.save_pretrained(save_directory)
 
-    def compile(self, learning_rate=3e-5, epsilon=1e-8):
+    def prepare_data(self, train_samples_text, train_samples_labels,
+                     valid_samples_text=None, valid_samples_labels=None,
+                     validation_split=0.1, **kwargs):
+        dataset = SingleSentenceClassificationProcessor.create_from_examples(train_samples_text,
+                                                                             train_samples_labels)
+        num_data_samples = len(dataset)
+        if valid_samples_text is not None and valid_samples_labels is not None:
+            valid_dataset = SingleSentenceClassificationProcessor.create_from_examples(valid_samples_text,
+                                                                                       valid_samples_labels)
+            num_valid_samples = len(valid_dataset)
+            train_dataset = dataset
+            num_train_samples = num_data_samples
+        else:
+            assert 0.0 <= validation_split <= 1.0, "validation_split should be between 0.0 and 1.0"
+            num_valid_samples = int(num_data_samples * validation_split)
+            num_train_samples = num_data_samples - num_valid_samples
+            train_dataset = dataset[num_train_samples]
+            valid_dataset = dataset[num_valid_samples]
+
+        logger.info('Tokenizing and processing dataset')
+        train_dataset = train_dataset.get_features(self.tokenizer, return_tensors=self.framework)
+        valid_dataset = valid_dataset.get_features(self.tokenizer, return_tensors=self.framework)
+        return train_dataset, valid_dataset, num_train_samples, num_valid_samples
+
+    def compile(self, learning_rate=3e-5, epsilon=1e-8, **kwargs):
         if self.framework == 'tf':
             logger.info('Preparing model')
 
             # Prepare training: Compile tf.keras model with optimizer, loss and learning rate schedule
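When no explicit validation set is supplied, `prepare_data` derives the train/validation sizes from `validation_split`. A sketch of that arithmetic on a plain Python list; note the sketch expresses the split with slices, whereas the hunk indexes the processor with the computed counts:

    # Split arithmetic as in prepare_data, shown on a plain list.
    samples = list(range(100))
    validation_split = 0.1
    assert 0.0 <= validation_split <= 1.0

    num_valid_samples = int(len(samples) * validation_split)  # 10
    num_train_samples = len(samples) - num_valid_samples      # 90

    train_samples = samples[:num_train_samples]  # first 90 samples
    valid_samples = samples[num_train_samples:]  # remaining 10 samples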
@@ -125,39 +150,20 @@ class TextClassificationPipeline(object):
         self.is_compiled = True
 
-    def prepare_data(self, train_samples_text, train_samples_labels,
-                     valid_samples_text=None, valid_samples_labels=None,
-                     validation_split=0.1):
-        dataset = SingleSentenceClassificationProcessor.create_from_examples(train_samples_text,
-                                                                             train_samples_labels)
-        num_data_samples = len(dataset)
-        if valid_samples_text is not None and valid_samples_labels is not None:
-            valid_dataset = SingleSentenceClassificationProcessor.create_from_examples(valid_samples_text,
-                                                                                       valid_samples_labels)
-            num_valid_samples = len(valid_dataset)
-            train_dataset = dataset
-            num_train_samples = num_data_samples
-        else:
-            assert 0.0 < validation_split < 1.0, "validation_split should be between 0.0 and 1.0"
-            num_valid_samples = int(num_data_samples * validation_split)
-            num_train_samples = num_data_samples - num_valid_samples
-            train_dataset = dataset[num_train_samples]
-            valid_dataset = dataset[num_valid_samples]
-
-        logger.info('Tokenizing and processing dataset')
-        train_dataset = train_dataset.get_features(self.tokenizer, return_tensors=self.framework)
-        valid_dataset = valid_dataset.get_features(self.tokenizer, return_tensors=self.framework)
-        return train_dataset, valid_dataset, num_train_samples, num_valid_samples
-
-    def fit(self, train_samples_text, train_samples_labels,
+    def fit(self, train_samples_text=None, train_samples_labels=None,
             valid_samples_text=None, valid_samples_labels=None,
             train_batch_size=None, valid_batch_size=None,
             validation_split=0.1,
             **kwargs):
+        # Generic compatibility with sklearn and Keras
+        if 'y' in kwargs and train_samples_labels is None:
+            train_samples_labels = kwargs.pop('y')
+
+        if 'X' in kwargs and train_samples_text is None:
+            train_samples_text = kwargs.pop('X')
+
         if not self.is_compiled:
-            self.compile()
+            self.compile(**kwargs)
 
         datasets = self.prepare_data(train_samples_text, train_samples_labels,
                                      valid_samples_text, valid_samples_labels,
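The `X`/`y` handling above is what lets `fit` be called with Keras- or sklearn-style keyword arguments. A hypothetical usage sketch, assuming `pipeline` is an already-constructed `TextClassificationPipeline`:

    texts = ["great movie", "terrible plot"]
    labels = [1, 0]

    # Keras/sklearn-style keywords...
    pipeline.fit(X=texts, y=labels, validation_split=0.1)

    # ...are equivalent to the pipeline's native argument names.
    pipeline.fit(train_samples_text=texts, train_samples_labels=labels,
                 validation_split=0.1)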
@@ -180,11 +186,32 @@ class TextClassificationPipeline(object):
         self.is_trained = True
 
-    def __call__(self, text):
+    def fit_transform(self, *texts, **kwargs):
+        # Generic compatibility with sklearn and Keras
+        self.fit(*texts, **kwargs)
+        return self(*texts, **kwargs)
+
+    def transform(self, *texts, **kwargs):
+        # Generic compatibility with sklearn and Keras
+        return self(*texts, **kwargs)
+
+    def predict(self, *texts, **kwargs):
+        # Generic compatibility with sklearn and Keras
+        return self(*texts, **kwargs)
+
+    def __call__(self, *texts, **kwargs):
+        # Generic compatibility with sklearn and Keras
+        if 'X' in kwargs and not texts:
+            texts = kwargs.pop('X')
+
         if not self.is_trained:
             logger.error("Some weights of the model are not trained. Please fine-tune the model on a classification task before using it.")
 
-        inputs = self.tokenizer.encode_plus(text, add_special_tokens=True, return_tensors=self.framework)
+        inputs = self.tokenizer.batch_encode_plus(texts, add_special_tokens=True, return_tensors=self.framework)
 
         if self.framework == 'tf':
             # TODO trace model
             predictions = self.model(**inputs)[0]
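With `predict`, `transform`, and `fit_transform` all delegating to `__call__`, a trained pipeline duck-types as an sklearn estimator at inference time. A hypothetical sketch, using the same `pipeline` object as above:

    # All three routes go through __call__, and therefore through
    # tokenizer.batch_encode_plus on the batch of texts.
    preds = pipeline("great movie", "terrible plot")                # direct call, *texts
    preds = pipeline.predict("great movie", "terrible plot")        # sklearn-style
    preds = pipeline.transform(X=["great movie", "terrible plot"])  # X keyword route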