Mirror of https://github.com/huggingface/transformers.git
Fix arg names for our models (#20166)
* Fix arg names for our models
* Clean out the other uses of "residx" in infer()
* make fixup
This commit is contained in:
parent 6dda14dc47
commit 68187c4642
@@ -2248,8 +2248,7 @@ class EsmForProteinFolding(EsmPreTrainedModel):
     def infer(
         self,
         seqs: Union[str, List[str]],
-        residx=None,
-        with_mask: Optional[torch.Tensor] = None,
+        position_ids=None,
     ):
         if type(seqs) is str:
             lst = [seqs]
@@ -2272,17 +2271,17 @@ class EsmForProteinFolding(EsmPreTrainedModel):
             ]
         )  # B=1 x L
         mask = collate_dense_tensors([aatype.new_ones(len(seq)) for seq in lst])
-        residx = (
-            torch.arange(aatype.shape[1], device=device).expand(len(lst), -1) if residx is None else residx.to(device)
+        position_ids = (
+            torch.arange(aatype.shape[1], device=device).expand(len(lst), -1)
+            if position_ids is None
+            else position_ids.to(device)
         )
-        if residx.ndim == 1:
-            residx = residx.unsqueeze(0)
+        if position_ids.ndim == 1:
+            position_ids = position_ids.unsqueeze(0)
         return self.forward(
             aatype,
             mask,
-            mask_aa=with_mask is not None,
-            masking_pattern=with_mask,
-            residx=residx,
+            position_ids=position_ids,
         )

     @staticmethod
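For callers of EsmForProteinFolding, the practical effect of this rename is that custom residue indices are now passed as position_ids rather than residx; the with_mask argument (and the mask_aa/masking_pattern forwarding) is removed from infer() entirely. A minimal usage sketch follows. The checkpoint name and the +512 chain-break offset are illustrative assumptions, not part of this commit:

# Usage sketch for the renamed argument (assumptions: the checkpoint name
# and the +512 chain-break offset below are illustrative, not from this commit).
import torch
from transformers import EsmForProteinFolding

model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
model.eval()

seq = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"

with torch.no_grad():
    # Default behaviour: infer() builds position_ids with torch.arange internally.
    outputs = model.infer(seq)

    # Custom residue indices now go through the renamed keyword (previously residx);
    # a 1-D tensor is unsqueezed to a batch dimension inside infer().
    position_ids = torch.arange(len(seq))
    position_ids[16:] += 512  # hypothetical offset marking a chain break
    outputs = model.infer(seq, position_ids=position_ids)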