mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Fix arg names for our models (#20166)
* Fix arg names for our models * Clean out the other uses of "residx" in infer() * make fixup
This commit is contained in:
parent
6dda14dc47
commit
68187c4642
@ -2248,8 +2248,7 @@ class EsmForProteinFolding(EsmPreTrainedModel):
|
||||
def infer(
|
||||
self,
|
||||
seqs: Union[str, List[str]],
|
||||
residx=None,
|
||||
with_mask: Optional[torch.Tensor] = None,
|
||||
position_ids=None,
|
||||
):
|
||||
if type(seqs) is str:
|
||||
lst = [seqs]
|
||||
@ -2272,17 +2271,17 @@ class EsmForProteinFolding(EsmPreTrainedModel):
|
||||
]
|
||||
) # B=1 x L
|
||||
mask = collate_dense_tensors([aatype.new_ones(len(seq)) for seq in lst])
|
||||
residx = (
|
||||
torch.arange(aatype.shape[1], device=device).expand(len(lst), -1) if residx is None else residx.to(device)
|
||||
position_ids = (
|
||||
torch.arange(aatype.shape[1], device=device).expand(len(lst), -1)
|
||||
if position_ids is None
|
||||
else position_ids.to(device)
|
||||
)
|
||||
if residx.ndim == 1:
|
||||
residx = residx.unsqueeze(0)
|
||||
if position_ids.ndim == 1:
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
return self.forward(
|
||||
aatype,
|
||||
mask,
|
||||
mask_aa=with_mask is not None,
|
||||
masking_pattern=with_mask,
|
||||
residx=residx,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
Loading…
Reference in New Issue
Block a user