Add is_model_supported for fx (#28521)

* modify check_if_model_is_supported to return bool

* add is_model_supported and have check_if_model_is_supported use that

* Update src/transformers/utils/fx.py

Fantastic

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
inisis 2024-01-17 01:52:44 +08:00 committed by GitHub
parent 02f8738ef8
commit 7142bdfa90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1199,8 +1199,12 @@ def get_concrete_args(model: nn.Module, input_names: List[str]):
return {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
def is_model_supported(model: PreTrainedModel):
return model.__class__.__name__ in _SUPPORTED_MODELS
def check_if_model_is_supported(model: PreTrainedModel):
if model.__class__.__name__ not in _SUPPORTED_MODELS:
if not is_model_supported(model):
supported_model_names = ", ".join(_SUPPORTED_MODELS)
raise NotImplementedError(
f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"