Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-02 19:21:31 +06:00.
[ESM] Fix `accelerate` tests for ESMFold (#20387) — * fix `accelerate` tests for esmfold * cleaner solution
This commit is contained in:
parent
d2357a0133
commit
2e17db8a86
@ -638,7 +638,7 @@ class EsmPreTrainedModel(PreTrainedModel):
|
||||
|
||||
config_class = EsmConfig
|
||||
base_model_prefix = "esm"
|
||||
_no_split_modules = ["EsmLayer"]
|
||||
_no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock"]
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
|
||||
def _init_weights(self, module):
|
||||
|
@ -1956,9 +1956,9 @@ class EsmFoldingTrunk(nn.Module):
|
||||
for recycle_idx in range(no_recycles):
|
||||
with ContextManagers([] if recycle_idx == no_recycles - 1 else [torch.no_grad()]):
|
||||
# === Recycling ===
|
||||
recycle_s = self.recycle_s_norm(recycle_s.detach())
|
||||
recycle_z = self.recycle_z_norm(recycle_z.detach())
|
||||
recycle_z += self.recycle_disto(recycle_bins.detach())
|
||||
recycle_s = self.recycle_s_norm(recycle_s.detach()).to(device)
|
||||
recycle_z = self.recycle_z_norm(recycle_z.detach()).to(device)
|
||||
recycle_z += self.recycle_disto(recycle_bins.detach()).to(device)
|
||||
|
||||
s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask)
|
||||
|
||||
@ -2207,6 +2207,9 @@ class EsmForProteinFolding(EsmPreTrainedModel):
|
||||
return EsmForProteinFoldingOutput(**structure)
|
||||
|
||||
def af2_idx_to_esm_idx(self, aa, mask):
    """Translate AlphaFold2 amino-acid indices into ESM vocabulary indices.

    Looks each entry of ``aa`` up in the ``self.af2_to_esm`` mapping tensor
    after shifting indices by one, so that masked-out positions (where
    ``mask`` is not 1) can be routed to slot 0 of the table.

    Args:
        aa: integer tensor of AF2 amino-acid indices.
        mask: tensor of the same shape; positions not equal to 1 are
            treated as padding and mapped through index 0.

    Returns:
        Tensor of the same shape as ``aa`` containing ESM token indices.
    """
    # Move the lookup table to the input's device once, so that the
    # advanced indexing below never crosses devices (the `accelerate`
    # fix this method exists for).
    if self.af2_to_esm.device != aa.device:
        self.af2_to_esm = self.af2_to_esm.to(aa.device)
    # Shift by one, then send padding positions to the reserved 0 slot.
    shifted = (aa + 1).masked_fill(mask != 1, 0)
    return self.af2_to_esm[shifted]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user