Fix arg names for our models (#20166)

* Fix arg names for our models

* Clean out the other uses of "residx" in infer()

* make fixup
This commit is contained in:
Matt 2022-11-10 16:47:58 +00:00 committed by GitHub
parent 6dda14dc47
commit 68187c4642
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2248,8 +2248,7 @@ class EsmForProteinFolding(EsmPreTrainedModel):
def infer( def infer(
self, self,
seqs: Union[str, List[str]], seqs: Union[str, List[str]],
residx=None, position_ids=None,
with_mask: Optional[torch.Tensor] = None,
): ):
if type(seqs) is str: if type(seqs) is str:
lst = [seqs] lst = [seqs]
@ -2272,17 +2271,17 @@ class EsmForProteinFolding(EsmPreTrainedModel):
] ]
) # B=1 x L ) # B=1 x L
mask = collate_dense_tensors([aatype.new_ones(len(seq)) for seq in lst]) mask = collate_dense_tensors([aatype.new_ones(len(seq)) for seq in lst])
residx = ( position_ids = (
torch.arange(aatype.shape[1], device=device).expand(len(lst), -1) if residx is None else residx.to(device) torch.arange(aatype.shape[1], device=device).expand(len(lst), -1)
if position_ids is None
else position_ids.to(device)
) )
if residx.ndim == 1: if position_ids.ndim == 1:
residx = residx.unsqueeze(0) position_ids = position_ids.unsqueeze(0)
return self.forward( return self.forward(
aatype, aatype,
mask, mask,
mask_aa=with_mask is not None, position_ids=position_ids,
masking_pattern=with_mask,
residx=residx,
) )
@staticmethod @staticmethod