Mirror of https://github.com/huggingface/transformers.git
compatibility with sklearn and keras
commit 7c1697562a
parent b81ab431f2
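In short: this change adds sklearn- and Keras-style entry points to TextClassificationPipeline (fit_transform, transform, predict, and X/y keyword aliases on fit and __call__), makes __call__ batch-aware via batch_encode_plus, and moves prepare_data ahead of compile.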
@@ -83,7 +83,7 @@ class TrainCommand(BaseTransformersCLICommand):
         self.logger.info('Loading model {}'.format(args.model_name))
         self.model_name = args.model_name
-        self.tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+        self.pipeline = AutoTokenizer.from_pretrained(args.model_name)
         if args.task == 'text_classification':
             self.model = SequenceClassifModel.from_pretrained(args.model_name)
         elif args.task == 'token_classification':
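For context, a minimal standalone sketch of the Auto-class loading this hunk relies on; AutoTokenizer.from_pretrained is the real transformers API, and 'bert-base-uncased' is only an illustrative checkpoint name:

# Example (not part of the diff): resolving a checkpoint name with AutoTokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # illustrative checkpoint
print(tokenizer.encode('Hello world', add_special_tokens=True))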
@@ -222,7 +222,7 @@ class SingleSentenceClassificationProcessor(DataProcessor):
         batch_length = max(len(input_ids) for input_ids in all_input_ids)

         features = []
-        for (ex_index, (input_ids, example)) in enumerate(zip(all_input_ids, examples)):
+        for (ex_index, (input_ids, example)) in enumerate(zip(all_input_ids, self.examples)):
             if ex_index % 10000 == 0:
                 logger.info("Writing example %d", ex_index)
             # The mask has 1 for real tokens and 0 for padding tokens. Only real
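The comment on the last context line refers to attention-mask construction during padding; a plain-Python sketch of that logic, with made-up token ids (this is not the processor's actual code):

# Example (not part of the diff): pad to batch_length, mark real tokens with 1.
all_input_ids = [[101, 2023, 102], [101, 2023, 2003, 1037, 102]]  # made-up ids
batch_length = max(len(input_ids) for input_ids in all_input_ids)

features = []
for input_ids in all_input_ids:
    padding = batch_length - len(input_ids)
    attention_mask = [1] * len(input_ids) + [0] * padding
    features.append({'input_ids': input_ids + [0] * padding,
                     'attention_mask': attention_mask})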
@@ -109,7 +109,32 @@ class TextClassificationPipeline(object):
         self.tokenizer.save_pretrained(save_directory)

-    def compile(self, learning_rate=3e-5, epsilon=1e-8):
+    def prepare_data(self, train_samples_text, train_samples_labels,
+                     valid_samples_text=None, valid_samples_labels=None,
+                     validation_split=0.1, **kwargs):
+        dataset = SingleSentenceClassificationProcessor.create_from_examples(train_samples_text,
+                                                                             train_samples_labels)
+        num_data_samples = len(dataset)
+        if valid_samples_text is not None and valid_samples_labels is not None:
+            valid_dataset = SingleSentenceClassificationProcessor.create_from_examples(valid_samples_text,
+                                                                                       valid_samples_labels)
+            num_valid_samples = len(valid_dataset)
+            train_dataset = dataset
+            num_train_samples = num_data_samples
+        else:
+            assert 0.0 <= validation_split <= 1.0, "validation_split should be between 0.0 and 1.0"
+            num_valid_samples = int(num_data_samples * validation_split)
+            num_train_samples = num_data_samples - num_valid_samples
+            train_dataset = dataset[num_train_samples]
+            valid_dataset = dataset[num_valid_samples]
+
+        logger.info('Tokenizing and processing dataset')
+        train_dataset = train_dataset.get_features(self.tokenizer, return_tensors=self.framework)
+        valid_dataset = valid_dataset.get_features(self.tokenizer, return_tensors=self.framework)
+        return train_dataset, valid_dataset, num_train_samples, num_valid_samples
+
+    def compile(self, learning_rate=3e-5, epsilon=1e-8, **kwargs):
         if self.framework == 'tf':
             logger.info('Preparing model')
             # Prepare training: Compile tf.keras model with optimizer, loss and learning rate schedule
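To make the new prepare_data split concrete, here is its validation_split arithmetic in isolation (the sample count is hypothetical):

# Example (not part of the diff): validation_split bookkeeping.
num_data_samples = 1000   # hypothetical dataset size
validation_split = 0.1

num_valid_samples = int(num_data_samples * validation_split)  # 100
num_train_samples = num_data_samples - num_valid_samples      # 900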
@@ -125,39 +150,20 @@ class TextClassificationPipeline(object):
         self.is_compiled = True

-    def prepare_data(self, train_samples_text, train_samples_labels,
-                     valid_samples_text=None, valid_samples_labels=None,
-                     validation_split=0.1):
-        dataset = SingleSentenceClassificationProcessor.create_from_examples(train_samples_text,
-                                                                             train_samples_labels)
-        num_data_samples = len(dataset)
-        if valid_samples_text is not None and valid_samples_labels is not None:
-            valid_dataset = SingleSentenceClassificationProcessor.create_from_examples(valid_samples_text,
-                                                                                       valid_samples_labels)
-            num_valid_samples = len(valid_dataset)
-            train_dataset = dataset
-            num_train_samples = num_data_samples
-        else:
-            assert 0.0 < validation_split < 1.0, "validation_split should be between 0.0 and 1.0"
-            num_valid_samples = int(num_data_samples * validation_split)
-            num_train_samples = num_data_samples - num_valid_samples
-            train_dataset = dataset[num_train_samples]
-            valid_dataset = dataset[num_valid_samples]
-
-        logger.info('Tokenizing and processing dataset')
-        train_dataset = train_dataset.get_features(self.tokenizer, return_tensors=self.framework)
-        valid_dataset = valid_dataset.get_features(self.tokenizer, return_tensors=self.framework)
-        return train_dataset, valid_dataset, num_train_samples, num_valid_samples
-
-    def fit(self, train_samples_text, train_samples_labels,
+    def fit(self, train_samples_text=None, train_samples_labels=None,
             valid_samples_text=None, valid_samples_labels=None,
             train_batch_size=None, valid_batch_size=None,
             validation_split=0.1,
             **kwargs):

+        # Generic compatibility with sklearn and Keras
+        if 'y' in kwargs and train_samples_labels is None:
+            train_samples_labels = kwargs.pop('y')
+        if 'X' in kwargs and train_samples_text is None:
+            train_samples_text = kwargs.pop('X')
+
         if not self.is_compiled:
-            self.compile()
+            self.compile(**kwargs)

         datasets = self.prepare_data(train_samples_text, train_samples_labels,
                                      valid_samples_text, valid_samples_labels,
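The X/y aliasing added to fit is the core of the sklearn/Keras compatibility; a self-contained sketch of the same pattern (the bare function here is a stand-in, not the pipeline's actual method):

# Example (not part of the diff): mapping sklearn's X/y names onto
# the pipeline's own argument names.
def fit(train_samples_text=None, train_samples_labels=None, **kwargs):
    if 'y' in kwargs and train_samples_labels is None:
        train_samples_labels = kwargs.pop('y')
    if 'X' in kwargs and train_samples_text is None:
        train_samples_text = kwargs.pop('X')
    return train_samples_text, train_samples_labels

print(fit(X=['great movie', 'terrible plot'], y=[1, 0]))
# -> (['great movie', 'terrible plot'], [1, 0])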
@@ -180,11 +186,32 @@ class TextClassificationPipeline(object):
         self.is_trained = True

-    def __call__(self, text):
+    def fit_transform(self, *texts, **kwargs):
+        # Generic compatibility with sklearn and Keras
+        self.fit(*texts, **kwargs)
+        return self(*texts, **kwargs)
+
+    def transform(self, *texts, **kwargs):
+        # Generic compatibility with sklearn and Keras
+        return self(*texts, **kwargs)
+
+    def predict(self, *texts, **kwargs):
+        # Generic compatibility with sklearn and Keras
+        return self(*texts, **kwargs)
+
+    def __call__(self, *texts, **kwargs):
+        # Generic compatibility with sklearn and Keras
+        if 'X' in kwargs and not texts:
+            texts = kwargs.pop('X')
+
         if not self.is_trained:
             logger.error("Some weights of the model are not trained. Please fine-tune the model on a classification task before using it.")

-        inputs = self.tokenizer.encode_plus(text, add_special_tokens=True, return_tensors=self.framework)
+        inputs = self.tokenizer.batch_encode_plus(texts, add_special_tokens=True, return_tensors=self.framework)

         if self.framework == 'tf':
             # TODO trace model
             predictions = self.model(**inputs)[0]
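The fit_transform/transform/predict aliases all route through __call__, which is what lets the pipeline duck-type as an sklearn estimator; a minimal sketch of that delegation with a dummy stand-in class (DummyPipeline is hypothetical, not from this diff):

# Example (not part of the diff): the delegation pattern in miniature.
class DummyPipeline:
    def __call__(self, *texts, **kwargs):
        if 'X' in kwargs and not texts:
            texts = kwargs.pop('X')
        return [0 for _ in texts]  # placeholder predictions

    def predict(self, *texts, **kwargs):
        # Generic compatibility with sklearn and Keras
        return self(*texts, **kwargs)

    def transform(self, *texts, **kwargs):
        # Generic compatibility with sklearn and Keras
        return self(*texts, **kwargs)

print(DummyPipeline().predict(X=['a', 'b']))  # [0, 0]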