diff --git a/README.md b/README.md
index 04ce7d45eda..f1f9b89f0b8 100644
--- a/README.md
+++ b/README.md
@@ -89,15 +89,18 @@
 BERT_MODEL_CLASSES = [BertModel, BertForPreTraining, BertForMaskedLM, BertForNextSentencePrediction,
                       BertForSequenceClassification, BertForMultipleChoice, BertForTokenClassification,
                       BertForQuestionAnswering]
 
-# All the classes for an architecture can be loaded from pretrained weights for this architecture
-# Note that additional weights added for fine-tuning are only initialized and need to be trained on the down-stream task
+# All the classes for an architecture can be initiated from pretrained weights for this architecture
+# Note that additional weights added for fine-tuning are only initialized
+# and need to be trained on the down-stream task
 tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
 for model_class in BERT_MODEL_CLASSES:
     # Load pretrained model/tokenizer
     model = model_class.from_pretrained('bert-base-uncased')
 
     # Models can return full list of hidden-states & attentions weights at each layer
-    model = model_class.from_pretrained(pretrained_weights, output_hidden_states=True, output_attentions=True)
+    model = model_class.from_pretrained(pretrained_weights,
+                                        output_hidden_states=True,
+                                        output_attentions=True)
     input_ids = torch.tensor([tokenizer.encode("Let's see all hidden-states and attentions on this text")])
     all_hidden_states, all_attentions = model(input_ids)[-2:]
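
For reference, below is a minimal, self-contained sketch of the snippet as it reads after this change, restricted to `BertModel` for brevity. It assumes the `pytorch_transformers` package (where `model(...)` returns plain tuples) and that `pretrained_weights` is set to `'bert-base-uncased'` as earlier in the README; with the later `transformers` package the import path and return types differ.

```python
import torch
from pytorch_transformers import BertModel, BertTokenizer

# Assumed to be defined earlier in the README
pretrained_weights = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(pretrained_weights)

# Extra config flags make the model also return all hidden-states and attention weights
model = BertModel.from_pretrained(pretrained_weights,
                                  output_hidden_states=True,
                                  output_attentions=True)

input_ids = torch.tensor([tokenizer.encode("Let's see all hidden-states and attentions on this text")])
# The last two elements of the output tuple are the hidden-states and attentions
all_hidden_states, all_attentions = model(input_ids)[-2:]

# One hidden-state tensor per layer (plus the embedding output), one attention tensor per layer
print(len(all_hidden_states), all_hidden_states[0].shape)
print(len(all_attentions), all_attentions[0].shape)
```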