diff --git a/transformers/commands/train.py b/transformers/commands/train.py
index fc89d48594e..878ad21037c 100644
--- a/transformers/commands/train.py
+++ b/transformers/commands/train.py
@@ -83,7 +83,7 @@ class TrainCommand(BaseTransformersCLICommand):
         self.logger.info('Loading model {}'.format(args.model_name))
         self.model_name = args.model_name
-        self.tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+        self.pipeline = TextClassificationPipeline.from_pretrained(args.model_name)
         if args.task == 'text_classification':
             self.model = SequenceClassifModel.from_pretrained(args.model_name)
         elif args.task == 'token_classification':
diff --git a/transformers/data/processors/utils.py b/transformers/data/processors/utils.py
index 75bed86042d..61b139c02bb 100644
--- a/transformers/data/processors/utils.py
+++ b/transformers/data/processors/utils.py
@@ -222,7 +222,7 @@ class SingleSentenceClassificationProcessor(DataProcessor):
         batch_length = max(len(input_ids) for input_ids in all_input_ids)
 
         features = []
-        for (ex_index, (input_ids, example)) in enumerate(zip(all_input_ids, examples)):
+        for (ex_index, (input_ids, example)) in enumerate(zip(all_input_ids, self.examples)):
             if ex_index % 10000 == 0:
                 logger.info("Writing example %d", ex_index)
             # The mask has 1 for real tokens and 0 for padding tokens. Only real
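Note on the utils.py hunk above: inside get_features, all_input_ids is built from self.examples, so the pairing loop must iterate the same attribute for the (input_ids, example) pairs to stay aligned. A minimal, self-contained sketch of the pattern; ToyProcessor and toy_tokenize are illustrative stand-ins, not library code:

class ToyProcessor:
    def __init__(self, examples):
        self.examples = examples  # raw strings in this sketch

    def get_features(self, tokenize):
        # Token ids are derived from self.examples...
        all_input_ids = [tokenize(text) for text in self.examples]
        batch_length = max(len(ids) for ids in all_input_ids)
        features = []
        # ...so the zip below must also read self.examples, or the
        # (input_ids, example) pairs lose their alignment guarantee.
        for input_ids, example in zip(all_input_ids, self.examples):
            padding = batch_length - len(input_ids)
            # The mask has 1 for real tokens and 0 for padding tokens.
            attention_mask = [1] * len(input_ids) + [0] * padding
            features.append((input_ids + [0] * padding, attention_mask, example))
        return features

def toy_tokenize(text):
    # Hypothetical whitespace "tokenizer" standing in for a real one.
    return [hash(word) % 1000 for word in text.split()]

processor = ToyProcessor(["a short example", "a slightly longer example here"])
for ids, mask, example in processor.get_features(toy_tokenize):
    print(example, ids, mask)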
diff --git a/transformers/pipeline.py b/transformers/pipeline.py
index f2c55def92e..dc7bcaeac30 100644
--- a/transformers/pipeline.py
+++ b/transformers/pipeline.py
@@ -109,7 +109,32 @@ class TextClassificationPipeline(object):
         self.tokenizer.save_pretrained(save_directory)
 
 
-    def compile(self, learning_rate=3e-5, epsilon=1e-8):
+    def prepare_data(self, train_samples_text, train_samples_labels,
+                     valid_samples_text=None, valid_samples_labels=None,
+                     validation_split=0.1, **kwargs):
+        dataset = SingleSentenceClassificationProcessor.create_from_examples(train_samples_text,
+                                                                             train_samples_labels)
+        num_data_samples = len(dataset)
+        if valid_samples_text is not None and valid_samples_labels is not None:
+            valid_dataset = SingleSentenceClassificationProcessor.create_from_examples(valid_samples_text,
+                                                                                       valid_samples_labels)
+            num_valid_samples = len(valid_dataset)
+            train_dataset = dataset
+            num_train_samples = num_data_samples
+        else:
+            assert 0.0 <= validation_split <= 1.0, "validation_split should be between 0.0 and 1.0"
+            num_valid_samples = int(num_data_samples * validation_split)
+            num_train_samples = num_data_samples - num_valid_samples
+            train_dataset = dataset[:num_train_samples]
+            valid_dataset = dataset[num_train_samples:]
+
+        logger.info('Tokenizing and processing dataset')
+        train_dataset = train_dataset.get_features(self.tokenizer, return_tensors=self.framework)
+        valid_dataset = valid_dataset.get_features(self.tokenizer, return_tensors=self.framework)
+        return train_dataset, valid_dataset, num_train_samples, num_valid_samples
+
+
+    def compile(self, learning_rate=3e-5, epsilon=1e-8, **kwargs):
         if self.framework == 'tf':
             logger.info('Preparing model')
             # Prepare training: Compile tf.keras model with optimizer, loss and learning rate schedule
@@ -125,39 +150,20 @@ class TextClassificationPipeline(object):
 
         self.is_compiled = True
 
 
-    def prepare_data(self, train_samples_text, train_samples_labels,
-                     valid_samples_text=None, valid_samples_labels=None,
-                     validation_split=0.1):
-        dataset = SingleSentenceClassificationProcessor.create_from_examples(train_samples_text,
-                                                                             train_samples_labels)
-        num_data_samples = len(dataset)
-        if valid_samples_text is not None and valid_samples_labels is not None:
-            valid_dataset = SingleSentenceClassificationProcessor.create_from_examples(valid_samples_text,
-                                                                                       valid_samples_labels)
-            num_valid_samples = len(valid_dataset)
-            train_dataset = dataset
-            num_train_samples = num_data_samples
-        else:
-            assert 0.0 < validation_split < 1.0, "validation_split should be between 0.0 and 1.0"
-            num_valid_samples = int(num_data_samples * validation_split)
-            num_train_samples = num_data_samples - num_valid_samples
-            train_dataset = dataset[:num_train_samples]
-            valid_dataset = dataset[num_train_samples:]
-
-        logger.info('Tokenizing and processing dataset')
-        train_dataset = train_dataset.get_features(self.tokenizer, return_tensors=self.framework)
-        valid_dataset = valid_dataset.get_features(self.tokenizer, return_tensors=self.framework)
-        return train_dataset, valid_dataset, num_train_samples, num_valid_samples
-
-
-    def fit(self, train_samples_text, train_samples_labels,
+    def fit(self, train_samples_text=None, train_samples_labels=None,
             valid_samples_text=None, valid_samples_labels=None,
             train_batch_size=None, valid_batch_size=None,
             validation_split=0.1, **kwargs):
+        # Generic compatibility with sklearn and Keras
+        if 'y' in kwargs and train_samples_labels is None:
+            train_samples_labels = kwargs.pop('y')
+        if 'X' in kwargs and train_samples_text is None:
+            train_samples_text = kwargs.pop('X')
+
         if not self.is_compiled:
-            self.compile()
+            self.compile(**kwargs)
 
         datasets = self.prepare_data(train_samples_text, train_samples_labels,
                                      valid_samples_text, valid_samples_labels,
@@ -180,11 +186,32 @@ class TextClassificationPipeline(object):
 
         self.is_trained = True
 
-    def __call__(self, text):
+    def fit_transform(self, *texts, **kwargs):
+        # Generic compatibility with sklearn and Keras
+        self.fit(*texts, **kwargs)
+        return self(*texts, **kwargs)
+
+
+    def transform(self, *texts, **kwargs):
+        # Generic compatibility with sklearn and Keras
+        return self(*texts, **kwargs)
+
+
+    def predict(self, *texts, **kwargs):
+        # Generic compatibility with sklearn and Keras
+        return self(*texts, **kwargs)
+
+
+    def __call__(self, *texts, **kwargs):
+        # Generic compatibility with sklearn and Keras
+        if 'X' in kwargs and not texts:
+            texts = kwargs.pop('X')
+
         if not self.is_trained:
             logger.error("Some weights of the model are not trained. Please fine-tune the model on a classification task before using it.")
 
-        inputs = self.tokenizer.encode_plus(text, add_special_tokens=True, return_tensors=self.framework)
+        inputs = self.tokenizer.batch_encode_plus(texts, add_special_tokens=True, return_tensors=self.framework)
+
         if self.framework == 'tf':
             # TODO trace model
             predictions = self.model(**inputs)[0]
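Taken together, the pipeline.py hunks give TextClassificationPipeline an sklearn/Keras-style surface (fit, fit_transform, transform, predict, plus X/y keyword aliases). A hedged usage sketch follows: the checkpoint name and data are placeholders, and a from_pretrained constructor is assumed as the counterpart of the save_pretrained call visible in the diff; only the method names and keyword handling are taken from this patch:

from transformers.pipeline import TextClassificationPipeline

# Assumption: from_pretrained mirrors the save_pretrained seen above.
pipeline = TextClassificationPipeline.from_pretrained('bert-base-uncased')

texts = ['the movie was great', 'terrible service', 'solid, would recommend']
labels = [1, 0, 1]

# fit() pops the sklearn/Keras-style X/y aliases out of kwargs, then
# prepare_data() holds out a validation_split fraction (default 0.1).
pipeline.fit(X=texts, y=labels)

# predict() and transform() both forward to __call__, which batch-encodes
# all positional texts at once via tokenizer.batch_encode_plus.
predictions = pipeline.predict('an instant classic', 'not worth the ticket')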