Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 21:00:08 +06:00)

* add: the contrastive search for generation_utils
* add: testing scripts for contrastive search under examples/text-generation
* improve code quality
* revise the docstring; add the generation_contrastive_search.py script
* revise examples/pytorch/text-generation/run_generation_contrastive_search.py to the auto-API format
* revise the necessary documents
* fix: revise the docstring of generation_contrastive_search.py
* fix the code indentation
* fix: revise the nits and examples in the contrastive_search docstring
* fix the copyright
* delete generation_contrastive_search.py
* revise the logic in contrastive_search
* update the integration test and the docstring
* re-run the tests
* add the slow decorator to the contrastive_search integration test
* add more tests
* run the style, quality, and consistency checks
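
The commits above wire contrastive search into `generate()`. As a minimal standalone sketch of the resulting API (model name, prompt, and lengths are placeholders; assumes a transformers release that includes this PR):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
    model = AutoModelForCausalLM.from_pretrained("gpt2-large")

    inputs = tokenizer("DeepMind Company is", return_tensors="pt")
    # penalty_alpha > 0 together with top_k > 1 selects the contrastive search path
    outputs = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_length=64)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
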
139 lines · 5.0 KiB · Python · Executable File
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 University of Cambridge, Tencent AI Lab, DeepMind and The University of Hong Kong Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" The examples of running contrastive search on the auto-APIs;
|
|
|
|
Running this example:
|
|
python run_generation_contrastive_search.py --model_name_or_path=gpt2-large --penalty_alpha=0.6 --k=4 --length=256
|
|
"""
|
|
|
|
|
|
import argparse
import logging

import numpy as np
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


def set_seed(args):
    # Seed numpy, torch (CPU), and all visible CUDA devices for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
    )
    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
    # Several of the following arguments (temperature, repetition_penalty, p, prefix,
    # padding_text, xlm_language) are kept for parity with run_generation.py and are
    # not passed to `generate()` below.
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="a temperature of 1.0 has no effect; lower values tend toward greedy sampling",
    )
    parser.add_argument(
        "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
    )
    parser.add_argument("--k", type=int, default=0)
    parser.add_argument("--penalty_alpha", type=float, default=0.0)
    parser.add_argument("--p", type=float, default=0.9)

    parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.")
    parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.")
    parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")

    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to run the model in 16-bit (half) precision instead of 32-bit",
    )
    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()

    logger.warning(f"device: {args.device}, n_gpu: {args.n_gpu}, 16-bit inference: {args.fp16}")

    set_seed(args)

    # Initialize the model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)

    # Model-specific classes work as well, e.g.:
    # tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
    # model = OPTForCausalLM.from_pretrained(args.model_name_or_path)
    model.to(args.device)

    if args.fp16:
        model.half()

    logger.info(args)
    prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")

    inputs = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False)
    inputs = {key: value.to(args.device) for key, value in inputs.items()}

    # Passing penalty_alpha > 0 together with top_k > 1 selects the contrastive
    # search decoding path in `generate()`.
    output_sequences = model.generate(
        **inputs,
        max_length=args.length + len(inputs["input_ids"][0]),
        penalty_alpha=args.penalty_alpha,
        top_k=args.k,
    )

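    # For reference, contrastive search (Su et al., 2022) picks each next token by
    # re-ranking the top-k candidates with a degeneration penalty. A sketch of the
    # per-step scoring rule (simplified; these are not the library's internal names):
    #
    #     score(v) = (1 - alpha) * p(v | x_<t) - alpha * max_j cos_sim(h_v, h_{x_j})
    #
    # where h_v is the candidate's hidden state and h_{x_j} are the hidden states of
    # the context tokens, so candidates too similar to prior context are penalized.
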
    generated_sequences = []
    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
        generated_sequence = generated_sequence.tolist()

        # Decode text
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, add_special_tokens=False)

        # Remove all text after the stop token
        text = text[: text.find(args.stop_token) if args.stop_token else None]

        # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
        total_sequence = (
            prompt_text + text[len(tokenizer.decode(inputs["input_ids"][0], clean_up_tokenization_spaces=True)) :]
        )

        generated_sequences.append(total_sequence)
        print(total_sequence)

    return generated_sequences


if __name__ == "__main__":
    main()
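
A quick way to exercise the script non-interactively is to pass the prompt on the command line (the prompt text here is just an illustration; any prompt works):

    python run_generation_contrastive_search.py --model_name_or_path=gpt2-large --penalty_alpha=0.6 --k=4 --length=256 --prompt="DeepMind Company is"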