Fix DeBERTa + Conversational pipeline slow tests (#10743)

* Fix DeBERTa-v2 variable assignment

* Fix conversational pipeline test
This commit is contained in:
Lysandre Debut 2021-03-16 11:18:20 -04:00 committed by GitHub
parent d3d388b934
commit 1449222217
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 2 deletions

View File

@ -573,7 +573,7 @@ class DisentangledSelfAttention(torch.nn.Module):
self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
self.share_att_key = getattr(config, "share_att_key", False)
self.pos_att_type = config.pos_att_type
self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
self.relative_attention = getattr(config, "relative_attention", False)
if self.relative_attention:

View File

@ -340,6 +340,6 @@ class ConversationalPipeline(Pipeline):
# If the tokenizer cannot handle conversations, we default to only the old version
input_ids = [self._legacy_parse_and_tokenize(conversation) for conversation in conversations]
inputs = self.tokenizer.pad(
{"input_ids": input_ids}, padding="longest", return_attention_mask=True, return_tensors="pt"
{"input_ids": input_ids}, padding="longest", return_attention_mask=True, return_tensors=self.framework
)
return inputs