mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
pass langs parameter to certain XLM models (#2734)
* pass langs parameter to certain XLM models: adds an argument that specifies the language the SQuAD dataset is in, so that language-sensitive XLM models (e.g. `xlm-mlm-tlm-xnli15-1024`) don't default to language `0`. Allows resolution of issue #1799. * fixing from `make style` * fixing style (again)
This commit is contained in:
parent
9e5b549b4d
commit
d1ab1fab1b
@ -219,6 +219,11 @@ def train(args, train_dataset, model, tokenizer):
|
||||
inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
|
||||
if args.version_2_with_negative:
|
||||
inputs.update({"is_impossible": batch[7]})
|
||||
if hasattr(model, "config") and hasattr(model.config, "lang2id"):
|
||||
inputs.update(
|
||||
{"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)}
|
||||
)
|
||||
|
||||
outputs = model(**inputs)
|
||||
# model outputs are always tuple in transformers (see doc)
|
||||
loss = outputs[0]
|
||||
@ -330,6 +335,11 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
# XLNet and XLM use more arguments for their predictions
|
||||
if args.model_type in ["xlnet", "xlm"]:
|
||||
inputs.update({"cls_index": batch[4], "p_mask": batch[5]})
|
||||
# for lang_id-sensitive xlm models
|
||||
if hasattr(model, "config") and hasattr(model.config, "lang2id"):
|
||||
inputs.update(
|
||||
{"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)}
|
||||
)
|
||||
|
||||
outputs = model(**inputs)
|
||||
|
||||
@ -635,6 +645,12 @@ def main():
|
||||
help="If true, all of the warnings related to data processing will be printed. "
|
||||
"A number of warnings are expected for a normal SQuAD evaluation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lang_id",
|
||||
default=0,
|
||||
type=int,
|
||||
help="language id of input for language-specific xlm models (see tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)",
|
||||
)
|
||||
|
||||
parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
|
||||
parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
|
||||
|
Loading…
Reference in New Issue
Block a user