mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-19 20:48:22 +06:00
Fix custom generate from local directory (#38916)
Some checks failed
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Waiting to run
Build documentation / build (push) Waiting to run
Slow tests on important models (on Push - A10) / Get all modified files (push) Waiting to run
Slow tests on important models (on Push - A10) / Slow & FA2 tests (push) Blocked by required conditions
Secret Leaks / trufflehog (push) Waiting to run
Update Transformers metadata / build_and_package (push) Waiting to run
New model PR merged notification / Notify new model (push) Has been cancelled
Self-hosted runner (push-caller) / Check if setup was changed (push) Has been cancelled
Self-hosted runner (push-caller) / build-docker-containers (push) Has been cancelled
Self-hosted runner (push-caller) / Trigger Push CI (push) Has been cancelled
Some checks failed
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Waiting to run
Build documentation / build (push) Waiting to run
Slow tests on important models (on Push - A10) / Get all modified files (push) Waiting to run
Slow tests on important models (on Push - A10) / Slow & FA2 tests (push) Blocked by required conditions
Secret Leaks / trufflehog (push) Waiting to run
Update Transformers metadata / build_and_package (push) Waiting to run
New model PR merged notification / Notify new model (push) Has been cancelled
Self-hosted runner (push-caller) / Check if setup was changed (push) Has been cancelled
Self-hosted runner (push-caller) / build-docker-containers (push) Has been cancelled
Self-hosted runner (push-caller) / Trigger Push CI (push) Has been cancelled
Fix custom generate from local directory: 1. Create parent dirs before copying files (custom_generate dir) 2. Correctly copy relative imports to the submodule file. 3. Update docs.
This commit is contained in:
parent
3d34b92116
commit
166e823f77
@ -468,9 +468,17 @@ def generate(model, input_ids, generation_config=None, left_padding=None, **kwar
|
||||
Follow the recommended practices below to ensure your custom decoding method works as expected.
|
||||
- Feel free to reuse the logic for validation and input preparation in the original [`~GenerationMixin.generate`].
|
||||
- Pin the `transformers` version in the requirements if you use any private method/attribute in `model`.
|
||||
- You can add other files in the `custom_generate` folder, and use relative imports.
|
||||
- Consider adding model validation, input validation, or even a separate test file to help users sanity-check your code in their environment.
|
||||
|
||||
Your custom `generate` method can relative import code from the `custom_generate` folder. For example, if you have a `utils.py` file, you can import it like this:
|
||||
|
||||
```py
|
||||
from .utils import some_function
|
||||
```
|
||||
|
||||
Only relative imports from the same-level `custom_generate` folder are supported. Parent/sibling folder imports are not valid. The `custom_generate` argument also works locally with any directory that contains a `custom_generate` structure. This is the recommended workflow for developing your custom decoding method.
|
||||
|
||||
|
||||
#### requirements.txt
|
||||
|
||||
You can optionally specify additional Python requirements in a `requirements.txt` file inside the `custom_generate` folder. These are checked at runtime and an exception will be thrown if they're missing, nudging users to update their environment accordingly.
|
||||
|
@ -402,10 +402,11 @@ def get_cached_module_file(
|
||||
if not (submodule_path / module_file).exists() or not filecmp.cmp(
|
||||
resolved_module_file, str(submodule_path / module_file)
|
||||
):
|
||||
(submodule_path / module_file).parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(resolved_module_file, submodule_path / module_file)
|
||||
importlib.invalidate_caches()
|
||||
for module_needed in modules_needed:
|
||||
module_needed = f"{module_needed}.py"
|
||||
module_needed = Path(module_file).parent / f"{module_needed}.py"
|
||||
module_needed_file = os.path.join(pretrained_model_name_or_path, module_needed)
|
||||
if not (submodule_path / module_needed).exists() or not filecmp.cmp(
|
||||
module_needed_file, str(submodule_path / module_needed)
|
||||
|
Loading…
Reference in New Issue
Block a user