Fix arg names for our models (#20166)

* Fix arg names for our models

* Clean out the other uses of "residx" in infer()

* make fixup
This commit is contained in:
Matt 2022-11-10 16:47:58 +00:00 committed by GitHub
parent 6dda14dc47
commit 68187c4642
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2248,8 +2248,7 @@ class EsmForProteinFolding(EsmPreTrainedModel):
def infer(
self,
seqs: Union[str, List[str]],
residx=None,
with_mask: Optional[torch.Tensor] = None,
position_ids=None,
):
if type(seqs) is str:
lst = [seqs]
@ -2272,17 +2271,17 @@ class EsmForProteinFolding(EsmPreTrainedModel):
]
) # B=1 x L
mask = collate_dense_tensors([aatype.new_ones(len(seq)) for seq in lst])
residx = (
torch.arange(aatype.shape[1], device=device).expand(len(lst), -1) if residx is None else residx.to(device)
position_ids = (
torch.arange(aatype.shape[1], device=device).expand(len(lst), -1)
if position_ids is None
else position_ids.to(device)
)
if residx.ndim == 1:
residx = residx.unsqueeze(0)
if position_ids.ndim == 1:
position_ids = position_ids.unsqueeze(0)
return self.forward(
aatype,
mask,
mask_aa=with_mask is not None,
masking_pattern=with_mask,
residx=residx,
position_ids=position_ids,
)
@staticmethod