Mirror of https://github.com/huggingface/transformers.git
better error messages

commit a5997dd81a (parent 43a237f15e)
@@ -107,7 +107,8 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
     return logits
 
 
-def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, is_xlnet=False, xlm_lang=None, device='cpu'):
+def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
+                    is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu'):
     context = torch.tensor(context, dtype=torch.long, device=device)
     context = context.unsqueeze(0).repeat(num_samples, 1)
     generated = context
@@ -125,10 +126,16 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
                 target_mapping[0, 0, -1] = 1.0  # predict last token
                 inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
 
+            if is_xlm_mlm and xlm_mask_token:
+                # XLM MLM models are direct models (predict same token, not next token)
+                # => need one additional dummy token in the input (will be masked and guessed)
+                input_ids = torch.cat((generated, torch.full((1, 1), xlm_mask_token, dtype=torch.long, device=device)), dim=1)
+                inputs = {'input_ids': input_ids}
+
             if xlm_lang is not None:
                 inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)
 
-            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
+            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
             next_token_logits = outputs[0][0, -1, :] / (temperature if temperature > 0 else 1.)
 
             # reptition penalty from CTRL (https://arxiv.org/abs/1909.05858)
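The branch added above can be exercised in isolation. Below is a minimal sketch of the mask-token trick, using a toy vocabulary size, a made-up mask id, and random logits in place of a real XLM MLM checkpoint (all three are assumptions, not values from this diff):

import torch

vocab_size = 10           # hypothetical vocabulary size
xlm_mask_token = 5        # hypothetical tokenizer.mask_token_id
generated = torch.tensor([[1, 2, 3]], dtype=torch.long)  # tokens generated so far

# An MLM predicts the token *at* a position rather than the next token,
# so append one dummy <mask> slot for the model to fill in.
mask = torch.full((1, 1), xlm_mask_token, dtype=torch.long)
input_ids = torch.cat((generated, mask), dim=1)           # shape (1, 4)

# With a real model this would be: logits = model(input_ids=input_ids)[0]
logits = torch.randn(1, input_ids.shape[1], vocab_size)

# As in the script, the prediction is read from the last (masked) position.
next_token_logits = logits[0, -1, :]
next_token = torch.argmax(next_token_logits).item()
print(next_token)

In the generation loop itself, the sampled token is appended to generated and a fresh mask slot is added again on the next iteration.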
@@ -167,10 +174,7 @@ def main():
     parser.add_argument('--stop_token', type=str, default=None,
                         help="Token at which text generation is stopped")
     args = parser.parse_args()
-    if args.model_type in ["ctrl"]:
-        if args.temperature > 0.7 :
-            print('CTRL typically works better with lower temperatures (and lower top_k).')
 
 
     args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
     args.n_gpu = torch.cuda.device_count()
@@ -191,6 +195,10 @@ def main():
         args.length = MAX_LENGTH  # avoid infinite loop
 
     print(args)
+    if args.model_type in ["ctrl"]:
+        if args.temperature > 0.7 :
+            logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
+
     while True:
         xlm_lang = None
         # XLM Language usage detailed in the issues #1414
@@ -204,6 +212,13 @@ def main():
                     language = input("Using XLM. Select language in " + str(list(tokenizer.lang2id.keys())) + " >>> ")
             xlm_lang = tokenizer.lang2id[language]
 
+        # XLM masked-language modeling (MLM) models need masked token (see details in sample_sequence)
+        is_xlm_mlm = args.model_type in ["xlm"] and 'mlm' in args.model_name_or_path
+        if is_xlm_mlm:
+            xlm_mask_token = tokenizer.mask_token_id
+        else:
+            xlm_mask_token = None
+
         raw_text = args.prompt if args.prompt else input("Model prompt >>> ")
         if args.model_type in ["transfo-xl", "xlnet"]:
             # Models with memory likes to have a long prompt for short inputs.
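A hedged sketch of how the new flag and mask token are derived, using stand-ins for the parsed arguments and the loaded tokenizer (the checkpoint name and mask id below are illustrative, not taken from the diff):

class FakeTokenizer:
    mask_token_id = 5     # a real XLM tokenizer exposes this attribute

args_model_type = "xlm"
args_model_name_or_path = "xlm-mlm-en-2048"   # example MLM-style checkpoint name
tokenizer = FakeTokenizer()

# Same test as in the diff: the model type must be XLM and the checkpoint
# name must contain 'mlm'; CLM checkpoints do not need the mask trick.
is_xlm_mlm = args_model_type in ["xlm"] and 'mlm' in args_model_name_or_path
xlm_mask_token = tokenizer.mask_token_id if is_xlm_mlm else None
print(is_xlm_mlm, xlm_mask_token)   # -> True 5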
@@ -218,6 +233,8 @@ def main():
             top_p=args.top_p,
             repetition_penalty=args.repetition_penalty,
             is_xlnet=bool(args.model_type == "xlnet"),
+            is_xlm_mlm=is_xlm_mlm,
+            xlm_mask_token=xlm_mask_token,
             xlm_lang=xlm_lang,
             device=args.device,
         )

@@ -130,20 +130,19 @@ class PretrainedConfig(object):
         # redirect to the cache, if necessary
         try:
             resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
-        except EnvironmentError as e:
+        except EnvironmentError:
             if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
-                logger.error(
-                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
-                        config_file))
+                msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
+                    config_file)
             else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find any file "
-                    "associated to this path or url.".format(
+                msg = "Model name '{}' was not found in model name list ({}). " \
+                      "We assumed '{}' was a path or url to a configuration file named {} or " \
+                      "a directory containing such a file but couldn't find any such file at this path or url.".format(
                         pretrained_model_name_or_path,
                         ', '.join(cls.pretrained_config_archive_map.keys()),
-                        config_file))
-            raise e
+                        config_file, CONFIG_NAME)
+            raise EnvironmentError(msg)
 
         if resolved_config_file == config_file:
             logger.info("loading configuration file {}".format(config_file))
         else:

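The practical effect of this hunk is that the explanation now travels with the exception instead of only being logged before re-raising the original error. A rough sketch of how a caller sees it, assuming a transformers version that includes this change and a deliberately invalid model name:

from transformers import GPT2Config

try:
    # "this-model-does-not-exist" is intentionally bogus, just to trigger the path.
    GPT2Config.from_pretrained("this-model-does-not-exist")
except EnvironmentError as err:
    # The message itself now says which names are valid and that a file named
    # CONFIG_NAME (or a directory containing one) was expected at that path or url.
    print(err)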
@@ -316,20 +316,20 @@ class PreTrainedModel(nn.Module):
         # redirect to the cache, if necessary
         try:
             resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
-        except EnvironmentError as e:
+        except EnvironmentError:
             if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
-                logger.error(
-                    "Couldn't reach server at '{}' to download pretrained weights.".format(
-                        archive_file))
+                msg = "Couldn't reach server at '{}' to download pretrained weights.".format(
+                    archive_file)
             else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find any file "
-                    "associated to this path or url.".format(
+                msg = "Model name '{}' was not found in model name list ({}). " \
+                      "We assumed '{}' was a path or url to model weight files named one of {} but " \
+                      "couldn't find any such file at this path or url.".format(
                         pretrained_model_name_or_path,
                         ', '.join(cls.pretrained_model_archive_map.keys()),
-                        archive_file))
-            raise e
+                        archive_file,
+                        [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME])
+            raise EnvironmentError(msg)
 
         if resolved_archive_file == archive_file:
             logger.info("loading weights file {}".format(archive_file))
         else:

@@ -337,13 +337,13 @@ class PreTrainedTokenizer(object):
                 vocab_files[file_id] = full_file_name
 
         if all(full_file_name is None for full_file_name in vocab_files.values()):
-            logger.error(
-                "Model name '{}' was not found in model name list ({}). "
-                "We assumed '{}' was a path or url but couldn't find tokenizer files"
-                "at this path or url.".format(
+            raise EnvironmentError(
+                "Model name '{}' was not found in tokenizers model name list ({}). "
+                "We assumed '{}' was a path or url to a directory containing vocabulary files "
+                "named {} but couldn't find such vocabulary files at this path or url.".format(
                     pretrained_model_name_or_path, ', '.join(s3_models),
-                    pretrained_model_name_or_path, ))
-            return None
+                    pretrained_model_name_or_path,
+                    list(cls.vocab_files_names.values())))
 
         # Get files from url, cache, or disk depending on the case
         try:
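Here the change is not only wording: returning None pushed the failure onto the caller, who would typically crash later with an unrelated AttributeError. A small sketch of the two failure modes (the helper and names below are illustrative, not from the diff):

# Old behaviour: the classmethod effectively handed back None on a bad name,
# so the error surfaced far away from its cause.
tokenizer = None
try:
    tokenizer.encode("hello world")
except AttributeError as err:
    print("old failure mode:", err)   # 'NoneType' object has no attribute 'encode'

# New behaviour: the lookup itself raises, with a message listing the expected
# vocabulary files, so the bad model name is reported immediately.
def load_vocab_or_raise(name, known_names=("bert-base-uncased",)):
    # toy stand-in for the raise added in the diff
    if name not in known_names:
        raise EnvironmentError(
            "Model name '{}' was not found in tokenizers model name list ({}).".format(
                name, ", ".join(known_names)))

try:
    load_vocab_or_raise("not-a-real-model")
except EnvironmentError as err:
    print("new failure mode:", err)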
@@ -353,17 +353,18 @@ class PreTrainedTokenizer(object):
                     resolved_vocab_files[file_id] = None
                 else:
                     resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
-        except EnvironmentError as e:
+        except EnvironmentError:
             if pretrained_model_name_or_path in s3_models:
-                logger.error("Couldn't reach server to download vocabulary.")
+                msg = "Couldn't reach server at '{}' to download vocabulary files."
             else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find files {} "
-                    "at this path or url.".format(
+                msg = "Model name '{}' was not found in tokenizers model name list ({}). " \
+                      "We assumed '{}' was a path or url to a directory containing vocabulary files " \
+                      "named {}, but couldn't find such vocabulary files at this path or url.".format(
                         pretrained_model_name_or_path, ', '.join(s3_models),
-                        pretrained_model_name_or_path, str(vocab_files.keys())))
-            raise e
+                        pretrained_model_name_or_path,
+                        list(cls.vocab_files_names.values()))
+
+            raise EnvironmentError(msg)
 
         for file_id, file_path in vocab_files.items():
             if file_path == resolved_vocab_files[file_id]: