Update code example (#11631)

* Update code example

* Code review
This commit is contained in:
NielsRogge 2021-05-10 07:48:43 +02:00 committed by GitHub
parent 7e406f4a65
commit f785c51692
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1069,6 +1069,7 @@ class LukeForEntityClassification(LukePreTrainedModel):
>>> logits = outputs.logits
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
Predicted class: person
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@ -1181,6 +1182,7 @@ class LukeForEntityPairClassification(LukePreTrainedModel):
>>> logits = outputs.logits
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
Predicted class: per:cities_of_residence
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@ -1309,8 +1311,12 @@ class LukeForEntitySpanClassification(LukePreTrainedModel):
>>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
>>> predicted_class_indices = logits.argmax(-1).squeeze().tolist()
>>> for span, predicted_class_idx in zip(entity_spans, predicted_class_indices):
... if predicted_class_idx != 0:
... print(text[span[0]:span[1]], model.config.id2label[predicted_class_idx])
Beyoncé PER
Los Angeles LOC
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict