diff --git a/docs/source/en/generation_strategies.md b/docs/source/en/generation_strategies.md index 9e2cbf485c4..6453669f689 100644 --- a/docs/source/en/generation_strategies.md +++ b/docs/source/en/generation_strategies.md @@ -468,9 +468,17 @@ def generate(model, input_ids, generation_config=None, left_padding=None, **kwar Follow the recommended practices below to ensure your custom decoding method works as expected. - Feel free to reuse the logic for validation and input preparation in the original [`~GenerationMixin.generate`]. - Pin the `transformers` version in the requirements if you use any private method/attribute in `model`. -- You can add other files in the `custom_generate` folder, and use relative imports. - Consider adding model validation, input validation, or even a separate test file to help users sanity-check your code in their environment. +Your custom `generate` method can use relative imports to load code from the `custom_generate` folder. For example, if you have a `utils.py` file, you can import it like this: + +```py +from .utils import some_function +``` + +Only relative imports from the same-level `custom_generate` folder are supported. Parent/sibling folder imports are not valid. The `custom_generate` argument also works locally with any directory that contains a `custom_generate` structure. This is the recommended workflow for developing your custom decoding method. + + #### requirements.txt You can optionally specify additional Python requirements in a `requirements.txt` file inside the `custom_generate` folder. These are checked at runtime and an exception will be thrown if they're missing, nudging users to update their environment accordingly. 
diff --git a/src/transformers/dynamic_module_utils.py b/src/transformers/dynamic_module_utils.py index 6a88859e0aa..7a498721a91 100644 --- a/src/transformers/dynamic_module_utils.py +++ b/src/transformers/dynamic_module_utils.py @@ -402,10 +402,11 @@ def get_cached_module_file( if not (submodule_path / module_file).exists() or not filecmp.cmp( resolved_module_file, str(submodule_path / module_file) ): + (submodule_path / module_file).parent.mkdir(parents=True, exist_ok=True) shutil.copy(resolved_module_file, submodule_path / module_file) importlib.invalidate_caches() for module_needed in modules_needed: - module_needed = f"{module_needed}.py" + module_needed = Path(module_file).parent / f"{module_needed}.py" module_needed_file = os.path.join(pretrained_model_name_or_path, module_needed) if not (submodule_path / module_needed).exists() or not filecmp.cmp( module_needed_file, str(submodule_path / module_needed)