mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
parent
7e406f4a65
commit
f785c51692
@ -1069,6 +1069,7 @@ class LukeForEntityClassification(LukePreTrainedModel):
|
|||||||
>>> logits = outputs.logits
|
>>> logits = outputs.logits
|
||||||
>>> predicted_class_idx = logits.argmax(-1).item()
|
>>> predicted_class_idx = logits.argmax(-1).item()
|
||||||
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
|
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
|
||||||
|
Predicted class: person
|
||||||
"""
|
"""
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
@ -1181,6 +1182,7 @@ class LukeForEntityPairClassification(LukePreTrainedModel):
|
|||||||
>>> logits = outputs.logits
|
>>> logits = outputs.logits
|
||||||
>>> predicted_class_idx = logits.argmax(-1).item()
|
>>> predicted_class_idx = logits.argmax(-1).item()
|
||||||
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
|
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
|
||||||
|
Predicted class: per:cities_of_residence
|
||||||
"""
|
"""
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
@ -1309,8 +1311,12 @@ class LukeForEntitySpanClassification(LukePreTrainedModel):
|
|||||||
>>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
|
>>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
|
||||||
>>> outputs = model(**inputs)
|
>>> outputs = model(**inputs)
|
||||||
>>> logits = outputs.logits
|
>>> logits = outputs.logits
|
||||||
>>> predicted_class_idx = logits.argmax(-1).item()
|
>>> predicted_class_indices = logits.argmax(-1).squeeze().tolist()
|
||||||
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
|
>>> for span, predicted_class_idx in zip(entity_spans, predicted_class_indices):
|
||||||
|
... if predicted_class_idx != 0:
|
||||||
|
... print(text[span[0]:span[1]], model.config.id2label[predicted_class_idx])
|
||||||
|
Beyoncé PER
|
||||||
|
Los Angeles LOC
|
||||||
"""
|
"""
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user