mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
parent 7e406f4a65
commit f785c51692
@@ -1069,6 +1069,7 @@ class LukeForEntityClassification(LukePreTrainedModel):
 >>> logits = outputs.logits
 >>> predicted_class_idx = logits.argmax(-1).item()
 >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
 Predicted class: person
 """
 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
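The hunk above only shows the tail of the LukeForEntityClassification docstring example. For reference, a minimal runnable sketch of the full flow these context lines belong to, assuming the studio-ousia/luke-large-finetuned-open-entity checkpoint used in the LUKE docs (the setup lines and checkpoint name are not part of this diff):

    from transformers import LukeTokenizer, LukeForEntityClassification

    # Assumed checkpoint fine-tuned for entity typing; not shown in this hunk.
    tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-open-entity")
    model = LukeForEntityClassification.from_pretrained("studio-ousia/luke-large-finetuned-open-entity")

    text = "Beyoncé lives in Los Angeles."
    entity_spans = [(0, 7)]  # character-based span of "Beyoncé"

    inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
    outputs = model(**inputs)
    logits = outputs.logits  # shape: (1, num_labels)
    predicted_class_idx = logits.argmax(-1).item()
    print("Predicted class:", model.config.id2label[predicted_class_idx])  # e.g. "Predicted class: person"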
@@ -1181,6 +1182,7 @@ class LukeForEntityPairClassification(LukePreTrainedModel):
 >>> logits = outputs.logits
 >>> predicted_class_idx = logits.argmax(-1).item()
 >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
 Predicted class: per:cities_of_residence
 """
 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
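The relation-extraction example follows the same pattern with two entity spans (head and tail). A minimal sketch of the surrounding setup, assuming the studio-ousia/luke-large-finetuned-tacred checkpoint (again not part of this diff):

    from transformers import LukeTokenizer, LukeForEntityPairClassification

    # Assumed checkpoint fine-tuned for relation classification on TACRED.
    tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-tacred")
    model = LukeForEntityPairClassification.from_pretrained("studio-ousia/luke-large-finetuned-tacred")

    text = "Beyoncé lives in Los Angeles."
    entity_spans = [(0, 7), (17, 28)]  # head and tail spans: "Beyoncé", "Los Angeles"

    inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
    outputs = model(**inputs)
    logits = outputs.logits  # shape: (1, num_labels)
    predicted_class_idx = logits.argmax(-1).item()
    print("Predicted class:", model.config.id2label[predicted_class_idx])  # e.g. "Predicted class: per:cities_of_residence"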
@@ -1309,8 +1311,12 @@ class LukeForEntitySpanClassification(LukePreTrainedModel):
 >>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
 >>> outputs = model(**inputs)
 >>> logits = outputs.logits
->>> predicted_class_idx = logits.argmax(-1).item()
->>> print("Predicted class:", model.config.id2label[predicted_class_idx])
+>>> predicted_class_indices = logits.argmax(-1).squeeze().tolist()
+>>> for span, predicted_class_idx in zip(entity_spans, predicted_class_indices):
+...     if predicted_class_idx != 0:
+...         print(text[span[0]:span[1]], model.config.id2label[predicted_class_idx])
+Beyoncé PER
+Los Angeles LOC
 """
 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
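This hunk replaces the copy-pasted single-prediction snippet with a per-span loop: span classification returns one set of logits per candidate span, so argmax(-1).item() does not apply. A minimal sketch of the full example the new lines belong to, assuming the studio-ousia/luke-large-finetuned-conll-2003 checkpoint and the candidate-span enumeration used in the LUKE docs (both are assumptions, not part of this diff):

    from transformers import LukeTokenizer, LukeForEntitySpanClassification

    # Assumed checkpoint fine-tuned for named entity recognition on CoNLL-2003.
    tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-conll-2003")
    model = LukeForEntitySpanClassification.from_pretrained("studio-ousia/luke-large-finetuned-conll-2003")

    text = "Beyoncé lives in Los Angeles"

    # Enumerate all word-level spans as candidate entities (character offsets).
    word_start_positions = [0, 8, 14, 17, 21]  # starts of "Beyoncé", "lives", "in", "Los", "Angeles"
    word_end_positions = [7, 13, 16, 20, 28]   # ends of the same words
    entity_spans = []
    for i, start_pos in enumerate(word_start_positions):
        for end_pos in word_end_positions[i:]:
            entity_spans.append((start_pos, end_pos))

    inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
    outputs = model(**inputs)
    logits = outputs.logits  # shape: (1, num_spans, num_labels)

    # One prediction per candidate span; label 0 means "not an entity".
    predicted_class_indices = logits.argmax(-1).squeeze().tolist()
    for span, predicted_class_idx in zip(entity_spans, predicted_class_indices):
        if predicted_class_idx != 0:
            print(text[span[0]:span[1]], model.config.id2label[predicted_class_idx])
    # e.g. "Beyoncé PER" and "Los Angeles LOC"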