Example doc for token classification of Llama and Dependent/Copied Models (#34139)

* Added an example docstring for token classification to all token classification models copied from Llama (a sketch of the resulting usage example is included below, after the changed-files summary)

* Refactor code to add code sample docstrings for Gemma and Gemma2 models (including modular Gemma)

* Refactor code to update model checkpoint names for Qwen2 models
Vijay 2024-10-22 22:56:16 +05:30 committed by GitHub
parent 644d5287b2
commit 049682a5a6
13 changed files with 90 additions and 3 deletions
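
For context, the `add_code_sample_docstrings` decorator added in each file below injects the library's standard token-classification usage sample into the model's `forward` docstring, parameterized by that file's `_CHECKPOINT_FOR_DOC`, `TokenClassifierOutput`, and `_CONFIG_FOR_DOC`. The following is only a minimal sketch of the kind of example this surfaces, using the Llama checkpoint named in the diff; the exact generated docstring may differ, and the gated meta-llama checkpoint can be swapped for any compatible one:

    from transformers import AutoTokenizer, LlamaForTokenClassification
    import torch

    checkpoint = "meta-llama/Llama-2-7b-hf"  # value of _CHECKPOINT_FOR_DOC in the Llama diff below
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = LlamaForTokenClassification.from_pretrained(checkpoint)

    inputs = tokenizer("HuggingFace is a company based in Paris and New York", return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits  # shape: (batch, sequence_length, num_labels)

    # One predicted label id per token, mapped to label names via the config.
    # Note: the classification head is randomly initialized unless the checkpoint was fine-tuned for this task.
    predicted_ids = logits.argmax(-1)
    predicted_labels = [model.config.id2label[i.item()] for i in predicted_ids[0]]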

src/transformers/models/gemma/modeling_gemma.py

@@ -39,6 +39,7 @@ from ...modeling_outputs import (
 )
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
+    add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_greater_or_equal_2_10,
@@ -48,6 +49,9 @@ from ...utils import (
 from .configuration_gemma import GemmaConfig
+_CHECKPOINT_FOR_DOC = "google/gemma-7b"
 class GemmaRMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
@@ -1233,6 +1237,11 @@ class GemmaForTokenClassification(GemmaPreTrainedModel):
         self.model.embed_tokens = value
     @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,

src/transformers/models/gemma/modular_gemma.py

@@ -49,6 +49,7 @@ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
 SPIECE_UNDERLINE = "▁"
+_CHECKPOINT_FOR_DOC = "google/gemma-7b"
 logger = logging.get_logger(__name__)

src/transformers/models/gemma2/modeling_gemma2.py

@@ -37,6 +37,7 @@ from ...modeling_outputs import (
 )
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
+    add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_greater_or_equal,
@@ -47,6 +48,9 @@ from ...utils import (
 from .configuration_gemma2 import Gemma2Config
+_CHECKPOINT_FOR_DOC = "google/gemma2-7b"
 class Gemma2RMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
@@ -1292,6 +1296,11 @@ class Gemma2ForTokenClassification(Gemma2PreTrainedModel):
         self.model.embed_tokens = value
     @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,

src/transformers/models/gemma2/modular_gemma2.py

@@ -50,6 +50,8 @@ if is_flash_attn_2_available():
     from ...modeling_flash_attention_utils import _flash_attention_forward
+_CHECKPOINT_FOR_DOC = "google/gemma2-7b"
 logger = logging.get_logger(__name__)

src/transformers/models/llama/modeling_llama.py

@@ -41,6 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import ALL_LAYERNORM_LAYERS
 from ...utils import (
+    add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_greater_or_equal_2_10,
@@ -52,6 +53,7 @@ from .configuration_llama import LlamaConfig
 logger = logging.get_logger(__name__)
+_CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf"
 _CONFIG_FOR_DOC = "LlamaConfig"
@@ -1446,6 +1448,11 @@ class LlamaForTokenClassification(LlamaPreTrainedModel):
         self.model.embed_tokens = value
     @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,

src/transformers/models/mistral/modeling_mistral.py

@@ -40,6 +40,7 @@ from ...modeling_outputs import (
 )
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
+    add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
@@ -55,6 +56,7 @@ if is_flash_attn_2_available():
 logger = logging.get_logger(__name__)
+_CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1"
 _CONFIG_FOR_DOC = "MistralConfig"
@@ -1242,6 +1244,11 @@ class MistralForTokenClassification(MistralPreTrainedModel):
         self.model.embed_tokens = value
     @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,

src/transformers/models/mixtral/modeling_mixtral.py

@@ -41,6 +41,7 @@ from ...modeling_outputs import (
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
 from ...utils import (
+    add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
@@ -65,6 +66,7 @@ if is_torch_fx_available():
 logger = logging.get_logger(__name__)
+_CHECKPOINT_FOR_DOC = "mistralai/Mixtral-8x7B-v0.1"
 _CONFIG_FOR_DOC = "MixtralConfig"
@@ -1468,6 +1470,11 @@ class MixtralForTokenClassification(MixtralPreTrainedModel):
         self.model.embed_tokens = value
     @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,

src/transformers/models/nemotron/modeling_nemotron.py

@@ -39,6 +39,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import ALL_LAYERNORM_LAYERS
 from ...utils import (
+    add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_greater_or_equal_2_10,
@@ -50,6 +51,7 @@ from .configuration_nemotron import NemotronConfig
 logger = logging.get_logger(__name__)
+_CHECKPOINT_FOR_DOC = "nvidia/nemotron-3-8b-base-4k-hf"
 _CONFIG_FOR_DOC = "NemotronConfig"
@@ -1323,6 +1325,11 @@ class NemotronForTokenClassification(NemotronPreTrainedModel):
         self.model.embed_tokens = value
     @add_start_docstrings_to_model_forward(NEMOTRON_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,

src/transformers/models/persimmon/modeling_persimmon.py

@@ -39,12 +39,19 @@ from ...modeling_outputs import (
 )
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import PreTrainedModel
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ...utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
 from .configuration_persimmon import PersimmonConfig
 logger = logging.get_logger(__name__)
+_CHECKPOINT_FOR_DOC = "adept/persimmon-8b-base"
 _CONFIG_FOR_DOC = "PersimmonConfig"
@@ -1120,6 +1127,11 @@ class PersimmonForTokenClassification(PersimmonPreTrainedModel):
         self.model.embed_tokens = value
     @add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,

src/transformers/models/qwen2/modeling_qwen2.py

@@ -41,6 +41,7 @@ from ...modeling_outputs import (
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
+    add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
@@ -58,7 +59,7 @@ if is_flash_attn_2_available():
 logger = logging.get_logger(__name__)
-_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
+_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B"
 _CONFIG_FOR_DOC = "Qwen2Config"
@@ -1348,6 +1349,11 @@ class Qwen2ForTokenClassification(Qwen2PreTrainedModel):
         self.model.embed_tokens = value
     @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,

src/transformers/models/qwen2_moe/modeling_qwen2_moe.py

@@ -41,6 +41,7 @@ from ...modeling_outputs import (
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
+    add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
@@ -56,7 +57,7 @@ if is_flash_attn_2_available():
 logger = logging.get_logger(__name__)
-_CHECKPOINT_FOR_DOC = "Qwen/Qwen1.5-MoE-A2.7B"
+_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-57B-A14B"
 _CONFIG_FOR_DOC = "Qwen2MoeConfig"
@@ -1533,6 +1534,11 @@ class Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel):
         self.model.embed_tokens = value
     @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,

src/transformers/models/stablelm/modeling_stablelm.py

@@ -40,6 +40,7 @@ from ...modeling_outputs import (
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
+    add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
@@ -56,6 +57,7 @@ if is_flash_attn_2_available():
 logger = logging.get_logger(__name__)
+_CHECKPOINT_FOR_DOC = "stabilityai/stablelm-3b-4e1t"
 _CONFIG_FOR_DOC = "StableLmConfig"
@@ -1396,6 +1398,11 @@ class StableLmForTokenClassification(StableLmPreTrainedModel):
         self.model.embed_tokens = value
     @add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,

src/transformers/models/starcoder2/modeling_starcoder2.py

@@ -40,6 +40,7 @@ from ...modeling_outputs import (
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
+    add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
@@ -56,6 +57,7 @@ if is_flash_attn_2_available():
 logger = logging.get_logger(__name__)
+_CHECKPOINT_FOR_DOC = "bigcode/starcoder2-7b"
 _CONFIG_FOR_DOC = "Starcoder2Config"
@@ -1316,6 +1318,11 @@ class Starcoder2ForTokenClassification(Starcoder2PreTrainedModel):
         self.model.embed_tokens = value
     @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,