Adding model_parallel = False

This commit is contained in:
retarfi 2024-05-12 03:36:57 +09:00
parent e0c3cee170
commit ba1d99976a
2 changed files with 4 additions and 0 deletions

View File

@ -2189,6 +2189,8 @@ class MT5ForTokenClassification(MT5PreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
self.model_parallel = False
@add_start_docstrings_to_model_forward(MT5_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
# Copied from transformers.models.t5.modeling_t5.T5ForTokenClassification.forward with T5->MT5

View File

@ -2136,6 +2136,8 @@ class T5ForTokenClassification(T5PreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
self.model_parallel = False
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(