mirror of https://github.com/huggingface/transformers.git
* Changed max_position_embeddings default value from 2048 to 4096

* force push

* Fixed formatting issues. Fixed missing argument in write_model.

* Reverted to the default value 2048 in the Llama config. Added comments for the llama_version argument.

* Fixed issue with the default value of max_position_embeddings in the docstring

* Updated help message for llama versions

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent 2749e479f3
commit de11e654c9
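For context, the change threads a new llama_version argument from the command line into write_model, where it selects the default context length: 2048 for Llama 1, 4096 for Llama 2. Below is a minimal sketch of a conversion after this change; the paths and the "7B" size are placeholders, the CLI line assumes the converter script's usual name (convert_llama_weights_to_hf.py), and the direct import assumes the script is importable from the working directory:

# Equivalent CLI invocation (placeholder paths, assumed script name):
#   python convert_llama_weights_to_hf.py \
#       --input_dir /path/to/llama2_checkpoint --model_size 7B \
#       --output_dir /path/to/hf_model --llama_version 2
from convert_llama_weights_to_hf import write_model  # assumes the script is on the Python path

write_model(
    model_path="/path/to/hf_model",                # where the converted HF checkpoint is written
    input_base_path="/path/to/llama2_checkpoint",  # original Meta checkpoint directory
    model_size="7B",
    tokenizer_path="/path/to/llama2_checkpoint/tokenizer.model",
    safe_serialization=True,
    llama_version=2,  # new argument: selects max_position_embeddings = 4096
)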
@@ -80,7 +80,9 @@ def write_json(text, path):
         json.dump(text, f)
 
 
-def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True):
+def write_model(
+    model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True, llama_version=1
+):
     # for backward compatibility, before you needed the repo to be called `my_repo/model_size`
     if not os.path.isfile(os.path.join(input_base_path, "params.json")):
         input_base_path = os.path.join(input_base_path, model_size)
@@ -102,7 +104,16 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True):
     if base > 10000.0:
         max_position_embeddings = 16384
     else:
-        max_position_embeddings = 2048
+        # Depending on the Llama version, the default max_position_embeddings has different values.
+        if llama_version == 1:
+            max_position_embeddings = 2048
+        elif llama_version == 2:
+            max_position_embeddings = 4096
+        else:
+            raise NotImplementedError(
+                f"Version {llama_version} of llama is not supported yet. "
+                "Current supported versions of llama are [1, 2]."
+            )
 
     tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
     if tokenizer_path is not None:
@@ -301,6 +312,14 @@ def main():
         help="Location to write HF model and tokenizer",
     )
     parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
+    # Different Llama versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
+    parser.add_argument(
+        "--llama_version",
+        choices=[1, 2],
+        default=1,
+        type=int,
+        help="Version of the Llama model to convert. Currently supports Llama1 and Llama2. Controls the context size",
+    )
     args = parser.parse_args()
     spm_path = os.path.join(args.input_dir, "tokenizer.model")
     if args.model_size != "tokenizer_only":
@@ -310,6 +329,7 @@ def main():
             model_size=args.model_size,
             safe_serialization=args.safe_serialization,
             tokenizer_path=spm_path,
+            llama_version=args.llama_version,
         )
     else:
         write_tokenizer(args.output_dir, spm_path)
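The selected version ends up in the saved config, so a converted checkpoint can be checked directly. A quick sanity check, reusing the placeholder output path from the sketch above; AutoConfig is the standard transformers config loader:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("/path/to/hf_model")  # placeholder output directory
# Expect 4096 for --llama_version 2 and 2048 for the default --llama_version 1
print(config.max_position_embeddings)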