Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
ESM openfold_utils type hints (#20544)
* add type annotations for esm chunk_utils
  use isinstance builtin instead of 'type(x) is y'; add assertions to aid in type inference; use bools instead of ints in _get_minimal_slice_set for improved type clarity; refactor to avoid re-assigning to the same variable with a different type
* add type annotations for esm data_transforms
  refactor to avoid re-assigning to the same variable with a different type
* add type annotations for esm feats utils
  refactor to avoid re-assigning to the same variable with a different type
* add type annotations for esm loss utils
* add/fix type annotations for esm rigid_utils
  refactor to avoid re-assigning to the same variable with a different type; fix Callable, Tuple type hints; match conditional structure to other methods; fix return type on Rotation.cat and Rotation.unsqueeze
* add type annotations for esm tensor_utils
  overload for tree_map; use isinstance builtin instead of 'type(x) is y'; export dict_multimap, flatten_final_dims, permute_final_dims in openfold_utils
* add type annotations for esm protein utils
  add FIXME for attempted string mutation; add missing None check in get_pdb_headers; fix potentially unbound variable 'chain_tag' in to_pdb; modify get_pdb_headers return type
* add type annotations for esm residue constants
  hints on collection constants; remove magic trailing comma to reduce number of lines; change list -> tuple for rigid_group_atom_positions for improved hinting
* code style fixup

Co-authored-by: Matt <rocketknight1@gmail.com>
This commit is contained in:
parent 8ea6694d92
commit afe2a466bb
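
As a quick, self-contained illustration of the annotation patterns this commit applies throughout openfold_utils (a minimal sketch only; the names fetch_shapes and tree_apply are invented for this example and do not exist in the repository): isinstance() replaces 'type(x) is y' so the checker can narrow union types, containers get explicit element annotations, and typing.overload declares return types that depend on the argument type.

from typing import Any, Callable, List, Tuple, Type, TypeVar, Union, overload

import torch

T = TypeVar("T")


def fetch_shapes(tree: Union[dict, list, tuple, torch.Tensor]) -> List[Tuple[int, ...]]:
    # isinstance() narrows the union for the type checker, unlike 'type(tree) is dict'
    shapes: List[Tuple[int, ...]] = []
    if isinstance(tree, dict):
        for v in tree.values():
            shapes.extend(fetch_shapes(v))
    elif isinstance(tree, (list, tuple)):
        for t in tree:
            shapes.extend(fetch_shapes(t))
    elif isinstance(tree, torch.Tensor):
        shapes.append(tuple(tree.shape))
    else:
        raise ValueError("Not supported")
    return shapes


@overload
def tree_apply(fn: Callable[[T], Any], tree: dict, leaf_type: Type[T]) -> dict:
    ...


@overload
def tree_apply(fn: Callable[[T], Any], tree: list, leaf_type: Type[T]) -> list:
    ...


def tree_apply(fn, tree, leaf_type):
    # The overloads above tell the checker that a dict maps to a dict and a list to a list;
    # the real tree_map in tensor_utils also covers tuples and bare leaves.
    if isinstance(tree, dict):
        return {k: tree_apply(fn, v, leaf_type) for k, v in tree.items()}
    elif isinstance(tree, list):
        return [tree_apply(fn, x, leaf_type) for x in tree]
    elif isinstance(tree, leaf_type):
        return fn(tree)
    else:
        raise ValueError("Not supported")

For instance, tree_apply(lambda t: t.shape, {"a": [torch.zeros(2, 3)]}, torch.Tensor) returns {"a": [torch.Size([2, 3])]} and the checker knows the result is a dict. The diff below applies the same ideas to the real _fetch_dims, tree_map, and related helpers.
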
@ -6,3 +6,4 @@ from .loss import compute_predicted_aligned_error, compute_tm
|
||||
from .protein import Protein as OFProtein
|
||||
from .protein import to_pdb
|
||||
from .rigid_utils import Rigid, Rotation
|
||||
from .tensor_utils import dict_multimap, flatten_final_dims, permute_final_dims
|
||||
|
@ -14,23 +14,22 @@
|
||||
import logging
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, Optional, Sequence, Tuple
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from .tensor_utils import tensor_tree_map, tree_map
|
||||
|
||||
|
||||
def _fetch_dims(tree):
|
||||
def _fetch_dims(tree: Union[dict, list, tuple, torch.Tensor]) -> List[Tuple[int, ...]]:
|
||||
shapes = []
|
||||
tree_type = type(tree)
|
||||
if tree_type is dict:
|
||||
if isinstance(tree, dict):
|
||||
for v in tree.values():
|
||||
shapes.extend(_fetch_dims(v))
|
||||
elif tree_type is list or tree_type is tuple:
|
||||
elif isinstance(tree, (list, tuple)):
|
||||
for t in tree:
|
||||
shapes.extend(_fetch_dims(t))
|
||||
elif tree_type is torch.Tensor:
|
||||
elif isinstance(tree, torch.Tensor):
|
||||
shapes.append(tree.shape)
|
||||
else:
|
||||
raise ValueError("Not supported")
|
||||
@ -39,10 +38,7 @@ def _fetch_dims(tree):
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def _flat_idx_to_idx(
|
||||
flat_idx: int,
|
||||
dims: Tuple[int],
|
||||
) -> Tuple[int]:
|
||||
def _flat_idx_to_idx(flat_idx: int, dims: Tuple[int, ...]) -> Tuple[int, ...]:
|
||||
idx = []
|
||||
for d in reversed(dims):
|
||||
idx.append(flat_idx % d)
|
||||
@ -55,10 +51,10 @@ def _flat_idx_to_idx(
|
||||
def _get_minimal_slice_set(
|
||||
start: Sequence[int],
|
||||
end: Sequence[int],
|
||||
dims: int,
|
||||
dims: Sequence[int],
|
||||
start_edges: Optional[Sequence[bool]] = None,
|
||||
end_edges: Optional[Sequence[bool]] = None,
|
||||
) -> Sequence[Tuple[int]]:
|
||||
) -> List[Tuple[slice, ...]]:
|
||||
"""
|
||||
Produces an ordered sequence of tensor slices that, when used in sequence on a tensor with shape dims, yields
|
||||
tensors that contain every leaf in the contiguous range [start, end]. Care is taken to yield a short sequence of
|
||||
@ -69,11 +65,11 @@ def _get_minimal_slice_set(
|
||||
# start_edges and end_edges both indicate whether, starting from any given
|
||||
# dimension, the start/end index is at the top/bottom edge of the
|
||||
# corresponding tensor, modeled as a tree
|
||||
def reduce_edge_list(l):
|
||||
tally = 1
|
||||
def reduce_edge_list(l: List[bool]) -> None:
|
||||
tally = True
|
||||
for i in range(len(l)):
|
||||
reversed_idx = -1 * (i + 1)
|
||||
l[reversed_idx] *= tally
|
||||
l[reversed_idx] &= tally
|
||||
tally = l[reversed_idx]
|
||||
|
||||
if start_edges is None:
|
||||
@ -90,48 +86,54 @@ def _get_minimal_slice_set(
|
||||
elif len(start) == 1:
|
||||
return [(slice(start[0], end[0] + 1),)]
|
||||
|
||||
slices = []
|
||||
path = []
|
||||
slices: List[Tuple[slice, ...]] = []
|
||||
path_list: List[slice] = []
|
||||
|
||||
# Dimensions common to start and end can be selected directly
|
||||
for s, e in zip(start, end):
|
||||
if s == e:
|
||||
path.append(slice(s, s + 1))
|
||||
path_list.append(slice(s, s + 1))
|
||||
else:
|
||||
break
|
||||
|
||||
path = tuple(path)
|
||||
path: Tuple[slice, ...] = tuple(path_list)
|
||||
divergence_idx = len(path)
|
||||
|
||||
# start == end, and we're done
|
||||
if divergence_idx == len(dims):
|
||||
return [tuple(path)]
|
||||
return [path]
|
||||
|
||||
def upper() -> Tuple[Tuple[slice, ...], ...]:
|
||||
assert start_edges is not None
|
||||
assert end_edges is not None
|
||||
|
||||
def upper():
|
||||
sdi = start[divergence_idx]
|
||||
return [
|
||||
return tuple(
|
||||
path + (slice(sdi, sdi + 1),) + s
|
||||
for s in _get_minimal_slice_set(
|
||||
start[divergence_idx + 1 :],
|
||||
[d - 1 for d in dims[divergence_idx + 1 :]],
|
||||
dims[divergence_idx + 1 :],
|
||||
start_edges=start_edges[divergence_idx + 1 :],
|
||||
end_edges=[1 for _ in end_edges[divergence_idx + 1 :]],
|
||||
end_edges=[True for _ in end_edges[divergence_idx + 1 :]],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
def lower() -> Tuple[Tuple[slice, ...], ...]:
|
||||
assert start_edges is not None
|
||||
assert end_edges is not None
|
||||
|
||||
def lower():
|
||||
edi = end[divergence_idx]
|
||||
return [
|
||||
return tuple(
|
||||
path + (slice(edi, edi + 1),) + s
|
||||
for s in _get_minimal_slice_set(
|
||||
[0 for _ in start[divergence_idx + 1 :]],
|
||||
end[divergence_idx + 1 :],
|
||||
dims[divergence_idx + 1 :],
|
||||
start_edges=[1 for _ in start_edges[divergence_idx + 1 :]],
|
||||
start_edges=[True for _ in start_edges[divergence_idx + 1 :]],
|
||||
end_edges=end_edges[divergence_idx + 1 :],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# If both start and end are at the edges of the subtree rooted at
|
||||
# divergence_idx, we can just select the whole subtree at once
|
||||
@ -156,16 +158,11 @@ def _get_minimal_slice_set(
|
||||
slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx]),))
|
||||
slices.extend(lower())
|
||||
|
||||
return [tuple(s) for s in slices]
|
||||
return slices
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def _chunk_slice(
|
||||
t: torch.Tensor,
|
||||
flat_start: int,
|
||||
flat_end: int,
|
||||
no_batch_dims: int,
|
||||
) -> torch.Tensor:
|
||||
def _chunk_slice(t: torch.Tensor, flat_start: int, flat_end: int, no_batch_dims: int) -> torch.Tensor:
|
||||
"""
|
||||
Equivalent to
|
||||
|
||||
@ -232,7 +229,7 @@ def chunk_layer(
|
||||
initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
|
||||
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
|
||||
|
||||
def _prep_inputs(t):
|
||||
def _prep_inputs(t: torch.Tensor) -> torch.Tensor:
|
||||
if not low_mem:
|
||||
if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
|
||||
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
|
||||
@ -241,7 +238,7 @@ def chunk_layer(
|
||||
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
|
||||
return t
|
||||
|
||||
prepped_inputs = tensor_tree_map(_prep_inputs, inputs)
|
||||
prepped_inputs: Dict[str, Any] = tensor_tree_map(_prep_inputs, inputs)
|
||||
prepped_outputs = None
|
||||
if _out is not None:
|
||||
prepped_outputs = tensor_tree_map(lambda t: t.view([-1] + list(t.shape[no_batch_dims:])), _out)
|
||||
@ -252,7 +249,7 @@ def chunk_layer(
|
||||
|
||||
no_chunks = flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0)
|
||||
|
||||
def _select_chunk(t):
|
||||
def _select_chunk(t: torch.Tensor) -> torch.Tensor:
|
||||
return t[i : i + chunk_size] if t.shape[0] != 1 else t
|
||||
|
||||
i = 0
|
||||
@ -269,7 +266,7 @@ def chunk_layer(
|
||||
no_batch_dims=len(orig_batch_dims),
|
||||
)
|
||||
|
||||
chunks = tensor_tree_map(select_chunk, prepped_inputs)
|
||||
chunks: Dict[str, Any] = tensor_tree_map(select_chunk, prepped_inputs)
|
||||
|
||||
# Run the layer on the chunk
|
||||
output_chunk = layer(**chunks)
|
||||
@ -279,12 +276,11 @@ def chunk_layer(
|
||||
out = tensor_tree_map(lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:]), output_chunk)
|
||||
|
||||
# Put the chunk in its pre-allocated space
|
||||
out_type = type(output_chunk)
|
||||
if out_type is dict:
|
||||
if isinstance(output_chunk, dict):
|
||||
|
||||
def assign(d1, d2):
|
||||
def assign(d1: dict, d2: dict) -> None:
|
||||
for k, v in d1.items():
|
||||
if type(v) is dict:
|
||||
if isinstance(v, dict):
|
||||
assign(v, d2[k])
|
||||
else:
|
||||
if _add_into_out:
|
||||
@ -293,13 +289,13 @@ def chunk_layer(
|
||||
v[i : i + chunk_size] = d2[k]
|
||||
|
||||
assign(out, output_chunk)
|
||||
elif out_type is tuple:
|
||||
elif isinstance(output_chunk, tuple):
|
||||
for x1, x2 in zip(out, output_chunk):
|
||||
if _add_into_out:
|
||||
x1[i : i + chunk_size] += x2
|
||||
else:
|
||||
x1[i : i + chunk_size] = x2
|
||||
elif out_type is torch.Tensor:
|
||||
elif isinstance(output_chunk, torch.Tensor):
|
||||
if _add_into_out:
|
||||
out[i : i + chunk_size] += output_chunk
|
||||
else:
|
||||
@ -319,24 +315,24 @@ class ChunkSizeTuner:
|
||||
self,
|
||||
# Heuristically, runtimes for most of the modules in the network
|
||||
# plateau earlier than this on all GPUs I've run the model on.
|
||||
max_chunk_size=512,
|
||||
max_chunk_size: int = 512,
|
||||
):
|
||||
self.max_chunk_size = max_chunk_size
|
||||
self.cached_chunk_size = None
|
||||
self.cached_arg_data = None
|
||||
self.cached_chunk_size: Optional[int] = None
|
||||
self.cached_arg_data: Optional[tuple] = None
|
||||
|
||||
def _determine_favorable_chunk_size(self, fn, args, min_chunk_size):
|
||||
def _determine_favorable_chunk_size(self, fn: Callable, args: tuple, min_chunk_size: int) -> int:
|
||||
logging.info("Tuning chunk size...")
|
||||
|
||||
if min_chunk_size >= self.max_chunk_size:
|
||||
return min_chunk_size
|
||||
|
||||
candidates = [2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)]
|
||||
candidates: List[int] = [2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)]
|
||||
candidates = [c for c in candidates if c > min_chunk_size]
|
||||
candidates = [min_chunk_size] + candidates
|
||||
candidates[-1] += 4
|
||||
|
||||
def test_chunk_size(chunk_size):
|
||||
def test_chunk_size(chunk_size: int) -> bool:
|
||||
try:
|
||||
with torch.no_grad():
|
||||
fn(*args, chunk_size=chunk_size)
|
||||
@ -356,13 +352,13 @@ class ChunkSizeTuner:
|
||||
|
||||
return candidates[min_viable_chunk_size_index]
|
||||
|
||||
def _compare_arg_caches(self, ac1, ac2):
|
||||
def _compare_arg_caches(self, ac1: Iterable, ac2: Iterable) -> bool:
|
||||
consistent = True
|
||||
for a1, a2 in zip(ac1, ac2):
|
||||
assert type(ac1) == type(ac2)
|
||||
if type(ac1) is list or type(ac1) is tuple:
|
||||
if isinstance(ac1, (list, tuple)):
|
||||
consistent &= self._compare_arg_caches(a1, a2)
|
||||
elif type(ac1) is dict:
|
||||
elif isinstance(ac1, dict):
|
||||
a1_items = [v for _, v in sorted(a1.items(), key=lambda x: x[0])]
|
||||
a2_items = [v for _, v in sorted(a2.items(), key=lambda x: x[0])]
|
||||
consistent &= self._compare_arg_caches(a1_items, a2_items)
|
||||
@ -374,11 +370,11 @@ class ChunkSizeTuner:
|
||||
def tune_chunk_size(
|
||||
self,
|
||||
representative_fn: Callable,
|
||||
args: Tuple[Any],
|
||||
args: tuple,
|
||||
min_chunk_size: int,
|
||||
) -> int:
|
||||
consistent = True
|
||||
arg_data = tree_map(lambda a: a.shape if type(a) is torch.Tensor else a, args, object)
|
||||
arg_data: tuple = tree_map(lambda a: a.shape if isinstance(a, torch.Tensor) else a, args, object)
|
||||
if self.cached_arg_data is not None:
|
||||
# If args have changed shape/value, we need to re-tune
|
||||
assert len(self.cached_arg_data) == len(arg_data)
|
||||
@ -395,4 +391,6 @@ class ChunkSizeTuner:
|
||||
)
|
||||
self.cached_arg_data = arg_data
|
||||
|
||||
assert self.cached_chunk_size is not None
|
||||
|
||||
return self.cached_chunk_size
|
||||
|
@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@ -20,39 +22,39 @@ from . import residue_constants as rc
|
||||
from .tensor_utils import tensor_tree_map, tree_map
|
||||
|
||||
|
||||
def make_atom14_masks(protein):
|
||||
def make_atom14_masks(protein: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""Construct denser atom positions (14 dimensions instead of 37)."""
|
||||
restype_atom14_to_atom37 = []
|
||||
restype_atom37_to_atom14 = []
|
||||
restype_atom14_mask = []
|
||||
restype_atom14_to_atom37_list = []
|
||||
restype_atom37_to_atom14_list = []
|
||||
restype_atom14_mask_list = []
|
||||
|
||||
for rt in rc.restypes:
|
||||
atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]
|
||||
restype_atom14_to_atom37.append([(rc.atom_order[name] if name else 0) for name in atom_names])
|
||||
restype_atom14_to_atom37_list.append([(rc.atom_order[name] if name else 0) for name in atom_names])
|
||||
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
|
||||
restype_atom37_to_atom14.append(
|
||||
restype_atom37_to_atom14_list.append(
|
||||
[(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) for name in rc.atom_types]
|
||||
)
|
||||
|
||||
restype_atom14_mask.append([(1.0 if name else 0.0) for name in atom_names])
|
||||
restype_atom14_mask_list.append([(1.0 if name else 0.0) for name in atom_names])
|
||||
|
||||
# Add dummy mapping for restype 'UNK'
|
||||
restype_atom14_to_atom37.append([0] * 14)
|
||||
restype_atom37_to_atom14.append([0] * 37)
|
||||
restype_atom14_mask.append([0.0] * 14)
|
||||
restype_atom14_to_atom37_list.append([0] * 14)
|
||||
restype_atom37_to_atom14_list.append([0] * 37)
|
||||
restype_atom14_mask_list.append([0.0] * 14)
|
||||
|
||||
restype_atom14_to_atom37 = torch.tensor(
|
||||
restype_atom14_to_atom37,
|
||||
restype_atom14_to_atom37_list,
|
||||
dtype=torch.int32,
|
||||
device=protein["aatype"].device,
|
||||
)
|
||||
restype_atom37_to_atom14 = torch.tensor(
|
||||
restype_atom37_to_atom14,
|
||||
restype_atom37_to_atom14_list,
|
||||
dtype=torch.int32,
|
||||
device=protein["aatype"].device,
|
||||
)
|
||||
restype_atom14_mask = torch.tensor(
|
||||
restype_atom14_mask,
|
||||
restype_atom14_mask_list,
|
||||
dtype=torch.float32,
|
||||
device=protein["aatype"].device,
|
||||
)
|
||||
@ -85,8 +87,7 @@ def make_atom14_masks(protein):
|
||||
return protein
|
||||
|
||||
|
||||
def make_atom14_masks_np(batch):
|
||||
def make_atom14_masks_np(batch: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:
|
||||
batch = tree_map(lambda n: torch.tensor(n, device=batch["aatype"].device), batch, np.ndarray)
|
||||
out = make_atom14_masks(batch)
|
||||
out = tensor_tree_map(lambda t: np.array(t), out)
|
||||
out = tensor_tree_map(lambda t: np.array(t), make_atom14_masks(batch))
|
||||
return out
|
||||
|
@ -13,14 +13,29 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, Tuple, overload
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.types
|
||||
from torch import nn
|
||||
|
||||
from . import residue_constants as rc
|
||||
from .rigid_utils import Rigid, Rotation
|
||||
from .tensor_utils import batched_gather
|
||||
|
||||
|
||||
@overload
|
||||
def pseudo_beta_fn(aatype: torch.Tensor, all_atom_positions: torch.Tensor, all_atom_masks: None) -> torch.Tensor:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def pseudo_beta_fn(
|
||||
aatype: torch.Tensor, all_atom_positions: torch.Tensor, all_atom_masks: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
...
|
||||
|
||||
|
||||
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
|
||||
is_gly = aatype == rc.restype_order["G"]
|
||||
ca_idx = rc.atom_order["CA"]
|
||||
@ -42,7 +57,7 @@ def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
|
||||
return pseudo_beta
|
||||
|
||||
|
||||
def atom14_to_atom37(atom14, batch):
|
||||
def atom14_to_atom37(atom14: torch.Tensor, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
atom37_data = batched_gather(
|
||||
atom14,
|
||||
batch["residx_atom37_to_atom14"],
|
||||
@ -55,7 +70,7 @@ def atom14_to_atom37(atom14, batch):
|
||||
return atom37_data
|
||||
|
||||
|
||||
def build_template_angle_feat(template_feats):
|
||||
def build_template_angle_feat(template_feats: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
template_aatype = template_feats["template_aatype"]
|
||||
torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"]
|
||||
alt_torsion_angles_sin_cos = template_feats["template_alt_torsion_angles_sin_cos"]
|
||||
@ -73,7 +88,15 @@ def build_template_angle_feat(template_feats):
|
||||
return template_angle_feat
|
||||
|
||||
|
||||
def build_template_pair_feat(batch, min_bin, max_bin, no_bins, use_unit_vector=False, eps=1e-20, inf=1e8):
|
||||
def build_template_pair_feat(
|
||||
batch: Dict[str, torch.Tensor],
|
||||
min_bin: torch.types.Number,
|
||||
max_bin: torch.types.Number,
|
||||
no_bins: int,
|
||||
use_unit_vector: bool = False,
|
||||
eps: float = 1e-20,
|
||||
inf: float = 1e8,
|
||||
) -> torch.Tensor:
|
||||
template_mask = batch["template_pseudo_beta_mask"]
|
||||
template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
|
||||
|
||||
@ -86,7 +109,7 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, use_unit_vector=F
|
||||
|
||||
to_concat = [dgram, template_mask_2d[..., None]]
|
||||
|
||||
aatype_one_hot = nn.functional.one_hot(
|
||||
aatype_one_hot: torch.LongTensor = nn.functional.one_hot(
|
||||
batch["template_aatype"],
|
||||
rc.restype_num + 2,
|
||||
)
|
||||
@ -126,8 +149,8 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, use_unit_vector=F
|
||||
return act
|
||||
|
||||
|
||||
def build_extra_msa_feat(batch):
|
||||
msa_1hot = nn.functional.one_hot(batch["extra_msa"], 23)
|
||||
def build_extra_msa_feat(batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
msa_1hot: torch.LongTensor = nn.functional.one_hot(batch["extra_msa"], 23)
|
||||
msa_feat = [
|
||||
msa_1hot,
|
||||
batch["extra_has_deletion"].unsqueeze(-1),
|
||||
@ -141,7 +164,7 @@ def torsion_angles_to_frames(
|
||||
alpha: torch.Tensor,
|
||||
aatype: torch.Tensor,
|
||||
rrgdf: torch.Tensor,
|
||||
):
|
||||
) -> Rigid:
|
||||
# [*, N, 8, 4, 4]
|
||||
default_4x4 = rrgdf[aatype, ...]
|
||||
|
||||
@ -172,9 +195,7 @@ def torsion_angles_to_frames(
|
||||
all_rots[..., 1, 2] = -alpha[..., 0]
|
||||
all_rots[..., 2, 1:] = alpha
|
||||
|
||||
all_rots = Rigid(Rotation(rot_mats=all_rots), None)
|
||||
|
||||
all_frames = default_r.compose(all_rots)
|
||||
all_frames = default_r.compose(Rigid(Rotation(rot_mats=all_rots), None))
|
||||
|
||||
chi2_frame_to_frame = all_frames[..., 5]
|
||||
chi3_frame_to_frame = all_frames[..., 6]
|
||||
@ -203,22 +224,22 @@ def torsion_angles_to_frames(
|
||||
def frames_and_literature_positions_to_atom14_pos(
|
||||
r: Rigid,
|
||||
aatype: torch.Tensor,
|
||||
default_frames,
|
||||
group_idx,
|
||||
atom_mask,
|
||||
lit_positions,
|
||||
):
|
||||
default_frames: torch.Tensor,
|
||||
group_idx: torch.Tensor,
|
||||
atom_mask: torch.Tensor,
|
||||
lit_positions: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# [*, N, 14]
|
||||
group_mask = group_idx[aatype, ...]
|
||||
|
||||
# [*, N, 14, 8]
|
||||
group_mask = nn.functional.one_hot(
|
||||
group_mask_one_hot: torch.LongTensor = nn.functional.one_hot(
|
||||
group_mask,
|
||||
num_classes=default_frames.shape[-3],
|
||||
)
|
||||
|
||||
# [*, N, 14, 8]
|
||||
t_atoms_to_global = r[..., None, :] * group_mask
|
||||
t_atoms_to_global = r[..., None, :] * group_mask_one_hot
|
||||
|
||||
# [*, N, 14]
|
||||
t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1))
|
||||
|
@ -18,7 +18,7 @@ from typing import Dict, Optional, Tuple
|
||||
import torch
|
||||
|
||||
|
||||
def _calculate_bin_centers(boundaries: torch.Tensor):
|
||||
def _calculate_bin_centers(boundaries: torch.Tensor) -> torch.Tensor:
|
||||
step = boundaries[1] - boundaries[0]
|
||||
bin_centers = boundaries + step / 2
|
||||
bin_centers = torch.cat([bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0)
|
||||
|
@ -17,7 +17,7 @@
|
||||
import dataclasses
|
||||
import re
|
||||
import string
|
||||
from typing import Any, Mapping, Optional, Sequence
|
||||
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -69,10 +69,10 @@ class Protein:
|
||||
|
||||
def from_proteinnet_string(proteinnet_str: str) -> Protein:
|
||||
tag_re = r"(\[[A-Z]+\]\n)"
|
||||
tags = [tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0]
|
||||
groups = zip(tags[0::2], [l.split("\n") for l in tags[1::2]])
|
||||
tags: List[str] = [tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0]
|
||||
groups: Iterator[Tuple[str, List[str]]] = zip(tags[0::2], [l.split("\n") for l in tags[1::2]])
|
||||
|
||||
atoms = ["N", "CA", "C"]
|
||||
atoms: List[str] = ["N", "CA", "C"]
|
||||
aatype = None
|
||||
atom_positions = None
|
||||
atom_mask = None
|
||||
@ -81,12 +81,12 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
|
||||
seq = g[1][0].strip()
|
||||
for i in range(len(seq)):
|
||||
if seq[i] not in residue_constants.restypes:
|
||||
seq[i] = "X"
|
||||
seq[i] = "X" # FIXME: strings are immutable
|
||||
aatype = np.array(
|
||||
[residue_constants.restype_order.get(res_symbol, residue_constants.restype_num) for res_symbol in seq]
|
||||
)
|
||||
elif "[TERTIARY]" == g[0]:
|
||||
tertiary = []
|
||||
tertiary: List[List[float]] = []
|
||||
for axis in range(3):
|
||||
tertiary.append(list(map(float, g[1][axis].split())))
|
||||
tertiary_np = np.array(tertiary)
|
||||
@ -106,6 +106,8 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
|
||||
atom_mask[:, residue_constants.atom_order[atom]] = 1
|
||||
atom_mask *= mask[..., None]
|
||||
|
||||
assert aatype is not None
|
||||
|
||||
return Protein(
|
||||
atom_positions=atom_positions,
|
||||
atom_mask=atom_mask,
|
||||
@ -115,8 +117,8 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
|
||||
)
|
||||
|
||||
|
||||
def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]:
|
||||
pdb_headers = []
|
||||
def get_pdb_headers(prot: Protein, chain_id: int = 0) -> List[str]:
|
||||
pdb_headers: List[str] = []
|
||||
|
||||
remark = prot.remark
|
||||
if remark is not None:
|
||||
@ -124,7 +126,7 @@ def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]:
|
||||
|
||||
parents = prot.parents
|
||||
parents_chain_index = prot.parents_chain_index
|
||||
if parents_chain_index is not None:
|
||||
if parents is not None and parents_chain_index is not None:
|
||||
parents = [p for i, p in zip(parents_chain_index, parents) if i == chain_id]
|
||||
|
||||
if parents is None or len(parents) == 0:
|
||||
@ -139,18 +141,18 @@ def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
|
||||
"""Add pdb headers to an existing PDB string. Useful during multi-chain
|
||||
recycling
|
||||
"""
|
||||
out_pdb_lines = []
|
||||
out_pdb_lines: List[str] = []
|
||||
lines = pdb_str.split("\n")
|
||||
|
||||
remark = prot.remark
|
||||
if remark is not None:
|
||||
out_pdb_lines.append(f"REMARK {remark}")
|
||||
|
||||
parents_per_chain = None
|
||||
parents_per_chain: List[List[str]]
|
||||
if prot.parents is not None and len(prot.parents) > 0:
|
||||
parents_per_chain = []
|
||||
if prot.parents_chain_index is not None:
|
||||
parent_dict = {}
|
||||
parent_dict: Dict[str, List[str]] = {}
|
||||
for p, i in zip(prot.parents, prot.parents_chain_index):
|
||||
parent_dict.setdefault(str(i), [])
|
||||
parent_dict[str(i)].append(p)
|
||||
@ -160,11 +162,11 @@ def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
|
||||
chain_parents = parent_dict.get(str(i), ["N/A"])
|
||||
parents_per_chain.append(chain_parents)
|
||||
else:
|
||||
parents_per_chain.append(prot.parents)
|
||||
parents_per_chain.append(list(prot.parents))
|
||||
else:
|
||||
parents_per_chain = [["N/A"]]
|
||||
|
||||
def make_parent_line(p):
|
||||
def make_parent_line(p: Sequence[str]) -> str:
|
||||
return f"PARENT {' '.join(p)}"
|
||||
|
||||
out_pdb_lines.append(make_parent_line(parents_per_chain[0]))
|
||||
@ -196,12 +198,12 @@ def to_pdb(prot: Protein) -> str:
|
||||
"""
|
||||
restypes = residue_constants.restypes + ["X"]
|
||||
|
||||
def res_1to3(r):
|
||||
def res_1to3(r: int) -> str:
|
||||
return residue_constants.restype_1to3.get(restypes[r], "UNK")
|
||||
|
||||
atom_types = residue_constants.atom_types
|
||||
|
||||
pdb_lines = []
|
||||
pdb_lines: List[str] = []
|
||||
|
||||
atom_mask = prot.atom_mask
|
||||
aatype = prot.aatype
|
||||
@ -221,6 +223,7 @@ def to_pdb(prot: Protein) -> str:
|
||||
atom_index = 1
|
||||
prev_chain_index = 0
|
||||
chain_tags = string.ascii_uppercase
|
||||
chain_tag = None
|
||||
# Add all atom sites.
|
||||
for i in range(n):
|
||||
res_name_3 = res_1to3(aatype[i])
|
||||
@ -313,15 +316,12 @@ def from_prediction(
|
||||
Returns:
|
||||
A protein instance.
|
||||
"""
|
||||
if b_factors is None:
|
||||
b_factors = np.zeros_like(result["final_atom_mask"])
|
||||
|
||||
return Protein(
|
||||
aatype=features["aatype"],
|
||||
atom_positions=result["final_atom_positions"],
|
||||
atom_mask=result["final_atom_mask"],
|
||||
residue_index=features["residue_index"] + 1,
|
||||
b_factors=b_factors,
|
||||
b_factors=b_factors if b_factors is not None else np.zeros_like(result["final_atom_mask"]),
|
||||
chain_index=chain_index,
|
||||
remark=remark,
|
||||
parents=parents,
|
||||
|
File diff suppressed because it is too large
@ -16,7 +16,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Any, Callable, Optional, Sequence, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -33,7 +33,7 @@ def rot_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
||||
The product ab
|
||||
"""
|
||||
|
||||
def row_mul(i):
|
||||
def row_mul(i: int) -> torch.Tensor:
|
||||
return torch.stack(
|
||||
[
|
||||
a[..., i, 0] * b[..., 0, 0] + a[..., i, 1] * b[..., 1, 0] + a[..., i, 2] * b[..., 2, 0],
|
||||
@ -76,7 +76,7 @@ def rot_vec_mul(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def identity_rot_mats(
|
||||
batch_dims: Tuple[int],
|
||||
batch_dims: Tuple[int, ...],
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
requires_grad: bool = True,
|
||||
@ -91,7 +91,7 @@ def identity_rot_mats(
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def identity_trans(
|
||||
batch_dims: Tuple[int],
|
||||
batch_dims: Tuple[int, ...],
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
requires_grad: bool = True,
|
||||
@ -102,7 +102,7 @@ def identity_trans(
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def identity_quats(
|
||||
batch_dims: Tuple[int],
|
||||
batch_dims: Tuple[int, ...],
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
requires_grad: bool = True,
|
||||
@ -115,15 +115,14 @@ def identity_quats(
|
||||
return quat
|
||||
|
||||
|
||||
_quat_elements = ["a", "b", "c", "d"]
|
||||
_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
|
||||
_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}
|
||||
_quat_elements: List[str] = ["a", "b", "c", "d"]
|
||||
_qtr_keys: List[str] = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
|
||||
_qtr_ind_dict: Dict[str, int] = {key: ind for ind, key in enumerate(_qtr_keys)}
|
||||
|
||||
|
||||
def _to_mat(pairs):
|
||||
def _to_mat(pairs: List[Tuple[str, int]]) -> np.ndarray:
|
||||
mat = np.zeros((4, 4))
|
||||
for pair in pairs:
|
||||
key, value = pair
|
||||
for key, value in pairs:
|
||||
ind = _qtr_ind_dict[key]
|
||||
mat[ind // 4][ind % 4] = value
|
||||
|
||||
@ -165,14 +164,11 @@ def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
|
||||
return torch.sum(quat, dim=(-3, -4))
|
||||
|
||||
|
||||
def rot_to_quat(
|
||||
rot: torch.Tensor,
|
||||
):
|
||||
def rot_to_quat(rot: torch.Tensor) -> torch.Tensor:
|
||||
if rot.shape[-2:] != (3, 3):
|
||||
raise ValueError("Input rotation is incorrectly shaped")
|
||||
|
||||
rot = [[rot[..., i, j] for j in range(3)] for i in range(3)]
|
||||
[[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot
|
||||
[[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = [[rot[..., i, j] for j in range(3)] for i in range(3)]
|
||||
|
||||
k = [
|
||||
[
|
||||
@ -201,9 +197,7 @@ def rot_to_quat(
|
||||
],
|
||||
]
|
||||
|
||||
k = (1.0 / 3.0) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2)
|
||||
|
||||
_, vectors = torch.linalg.eigh(k)
|
||||
_, vectors = torch.linalg.eigh((1.0 / 3.0) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2))
|
||||
return vectors[..., -1]
|
||||
|
||||
|
||||
@ -218,7 +212,7 @@ _QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], [0, 0, 1, 0], [0, -1, 0, 0], [1, 0, 0,
|
||||
|
||||
_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
|
||||
|
||||
_CACHED_QUATS = {
|
||||
_CACHED_QUATS: Dict[str, np.ndarray] = {
|
||||
"_QTR_MAT": _QTR_MAT,
|
||||
"_QUAT_MULTIPLY": _QUAT_MULTIPLY,
|
||||
"_QUAT_MULTIPLY_BY_VEC": _QUAT_MULTIPLY_BY_VEC,
|
||||
@ -226,29 +220,29 @@ _CACHED_QUATS = {
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def _get_quat(quat_key, dtype, device):
|
||||
def _get_quat(quat_key: str, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
|
||||
return torch.tensor(_CACHED_QUATS[quat_key], dtype=dtype, device=device)
|
||||
|
||||
|
||||
def quat_multiply(quat1, quat2):
|
||||
def quat_multiply(quat1: torch.Tensor, quat2: torch.Tensor) -> torch.Tensor:
|
||||
"""Multiply a quaternion by another quaternion."""
|
||||
mat = _get_quat("_QUAT_MULTIPLY", dtype=quat1.dtype, device=quat1.device)
|
||||
reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape)
|
||||
return torch.sum(reshaped_mat * quat1[..., :, None, None] * quat2[..., None, :, None], dim=(-3, -2))
|
||||
|
||||
|
||||
def quat_multiply_by_vec(quat, vec):
|
||||
def quat_multiply_by_vec(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
|
||||
"""Multiply a quaternion by a pure-vector quaternion."""
|
||||
mat = _get_quat("_QUAT_MULTIPLY_BY_VEC", dtype=quat.dtype, device=quat.device)
|
||||
reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)
|
||||
return torch.sum(reshaped_mat * quat[..., :, None, None] * vec[..., None, :, None], dim=(-3, -2))
|
||||
|
||||
|
||||
def invert_rot_mat(rot_mat: torch.Tensor):
|
||||
def invert_rot_mat(rot_mat: torch.Tensor) -> torch.Tensor:
|
||||
return rot_mat.transpose(-1, -2)
|
||||
|
||||
|
||||
def invert_quat(quat: torch.Tensor):
|
||||
def invert_quat(quat: torch.Tensor) -> torch.Tensor:
|
||||
quat_prime = quat.clone()
|
||||
quat_prime[..., 1:] *= -1
|
||||
inv = quat_prime / torch.sum(quat**2, dim=-1, keepdim=True)
|
||||
@ -361,10 +355,7 @@ class Rotation:
|
||||
else:
|
||||
raise ValueError("Both rotations are None")
|
||||
|
||||
def __mul__(
|
||||
self,
|
||||
right: torch.Tensor,
|
||||
) -> Rotation:
|
||||
def __mul__(self, right: torch.Tensor) -> Rotation:
|
||||
"""
|
||||
Pointwise left multiplication of the rotation with a tensor. Can be used to e.g. mask the Rotation.
|
||||
|
||||
@ -386,10 +377,7 @@ class Rotation:
|
||||
else:
|
||||
raise ValueError("Both rotations are None")
|
||||
|
||||
def __rmul__(
|
||||
self,
|
||||
left: torch.Tensor,
|
||||
) -> Rotation:
|
||||
def __rmul__(self, left: torch.Tensor) -> Rotation:
|
||||
"""
|
||||
Reverse pointwise multiplication of the rotation with a tensor.
|
||||
|
||||
@ -413,13 +401,12 @@ class Rotation:
|
||||
Returns:
|
||||
The virtual shape of the rotation object
|
||||
"""
|
||||
s = None
|
||||
if self._quats is not None:
|
||||
s = self._quats.shape[:-1]
|
||||
if self._rot_mats is not None:
|
||||
return self._rot_mats.shape[:-2]
|
||||
elif self._quats is not None:
|
||||
return self._quats.shape[:-1]
|
||||
else:
|
||||
s = self._rot_mats.shape[:-2]
|
||||
|
||||
return s
|
||||
raise ValueError("Both rotations are None")
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
@ -473,14 +460,12 @@ class Rotation:
|
||||
Returns:
|
||||
The rotation as a rotation matrix tensor
|
||||
"""
|
||||
rot_mats = self._rot_mats
|
||||
if rot_mats is None:
|
||||
if self._quats is None:
|
||||
raise ValueError("Both rotations are None")
|
||||
else:
|
||||
rot_mats = quat_to_rot(self._quats)
|
||||
|
||||
return rot_mats
|
||||
if self._rot_mats is not None:
|
||||
return self._rot_mats
|
||||
elif self._quats is not None:
|
||||
return quat_to_rot(self._quats)
|
||||
else:
|
||||
raise ValueError("Both rotations are None")
|
||||
|
||||
def get_quats(self) -> torch.Tensor:
|
||||
"""
|
||||
@ -491,14 +476,12 @@ class Rotation:
|
||||
Returns:
|
||||
The rotation as a quaternion tensor.
|
||||
"""
|
||||
quats = self._quats
|
||||
if quats is None:
|
||||
if self._rot_mats is None:
|
||||
raise ValueError("Both rotations are None")
|
||||
else:
|
||||
quats = rot_to_quat(self._rot_mats)
|
||||
|
||||
return quats
|
||||
if self._rot_mats is not None:
|
||||
return rot_to_quat(self._rot_mats)
|
||||
elif self._quats is not None:
|
||||
return self._quats
|
||||
else:
|
||||
raise ValueError("Both rotations are None")
|
||||
|
||||
def get_cur_rot(self) -> torch.Tensor:
|
||||
"""
|
||||
@ -618,10 +601,7 @@ class Rotation:
|
||||
|
||||
# "Tensor" stuff
|
||||
|
||||
def unsqueeze(
|
||||
self,
|
||||
dim: int,
|
||||
) -> Rigid:
|
||||
def unsqueeze(self, dim: int) -> Rotation:
|
||||
"""
|
||||
Analogous to torch.unsqueeze. The dimension is relative to the shape of the Rotation object.
|
||||
|
||||
@ -643,10 +623,7 @@ class Rotation:
|
||||
raise ValueError("Both rotations are None")
|
||||
|
||||
@staticmethod
|
||||
def cat(
|
||||
rs: Sequence[Rotation],
|
||||
dim: int,
|
||||
) -> Rigid:
|
||||
def cat(rs: Sequence[Rotation], dim: int) -> Rotation:
|
||||
"""
|
||||
Concatenates rotations along one of the batch dimensions. Analogous to torch.cat().
|
||||
|
||||
@ -661,12 +638,14 @@ class Rotation:
|
||||
Returns:
|
||||
A concatenated Rotation object in rotation matrix format
|
||||
"""
|
||||
rot_mats = [r.get_rot_mats() for r in rs]
|
||||
rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2)
|
||||
rot_mats = torch.cat(
|
||||
[r.get_rot_mats() for r in rs],
|
||||
dim=dim if dim >= 0 else dim - 2,
|
||||
)
|
||||
|
||||
return Rotation(rot_mats=rot_mats, quats=None)
|
||||
|
||||
def map_tensor_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rotation:
|
||||
def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rotation:
|
||||
"""
|
||||
Apply a Tensor -> Tensor function to underlying rotation tensors, mapping over the rotation dimension(s). Can
|
||||
be used e.g. to sum out a one-hot batch dimension.
|
||||
@ -754,11 +733,7 @@ class Rigid:
|
||||
dimensions of its component parts.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rots: Optional[Rotation],
|
||||
trans: Optional[torch.Tensor],
|
||||
):
|
||||
def __init__(self, rots: Optional[Rotation], trans: Optional[torch.Tensor]):
|
||||
"""
|
||||
Args:
|
||||
rots: A [*, 3, 3] rotation tensor
|
||||
@ -795,6 +770,9 @@ class Rigid:
|
||||
requires_grad,
|
||||
)
|
||||
|
||||
assert rots is not None
|
||||
assert trans is not None
|
||||
|
||||
if (rots.shape != trans.shape[:-1]) or (rots.device != trans.device):
|
||||
raise ValueError("Rots and trans incompatible")
|
||||
|
||||
@ -806,7 +784,7 @@ class Rigid:
|
||||
|
||||
@staticmethod
|
||||
def identity(
|
||||
shape: Tuple[int],
|
||||
shape: Tuple[int, ...],
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
requires_grad: bool = True,
|
||||
@ -832,10 +810,7 @@ class Rigid:
|
||||
identity_trans(shape, dtype, device, requires_grad),
|
||||
)
|
||||
|
||||
def __getitem__(
|
||||
self,
|
||||
index: Any,
|
||||
) -> Rigid:
|
||||
def __getitem__(self, index: Any) -> Rigid:
|
||||
"""
|
||||
Indexes the affine transformation with PyTorch-style indices. The index is applied to the shared dimensions of
|
||||
both the rotation and the translation.
|
||||
@ -860,10 +835,7 @@ class Rigid:
|
||||
self._trans[index + (slice(None),)],
|
||||
)
|
||||
|
||||
def __mul__(
|
||||
self,
|
||||
right: torch.Tensor,
|
||||
) -> Rigid:
|
||||
def __mul__(self, right: torch.Tensor) -> Rigid:
|
||||
"""
|
||||
Pointwise left multiplication of the transformation with a tensor. Can be used to e.g. mask the Rigid.
|
||||
|
||||
@ -881,10 +853,7 @@ class Rigid:
|
||||
|
||||
return Rigid(new_rots, new_trans)
|
||||
|
||||
def __rmul__(
|
||||
self,
|
||||
left: torch.Tensor,
|
||||
) -> Rigid:
|
||||
def __rmul__(self, left: torch.Tensor) -> Rigid:
|
||||
"""
|
||||
Reverse pointwise multiplication of the transformation with a tensor.
|
||||
|
||||
@ -904,8 +873,7 @@ class Rigid:
|
||||
Returns:
|
||||
The shape of the transformation
|
||||
"""
|
||||
s = self._trans.shape[:-1]
|
||||
return s
|
||||
return self._trans.shape[:-1]
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
@ -935,10 +903,7 @@ class Rigid:
|
||||
"""
|
||||
return self._trans
|
||||
|
||||
def compose_q_update_vec(
|
||||
self,
|
||||
q_update_vec: torch.Tensor,
|
||||
) -> Rigid:
|
||||
def compose_q_update_vec(self, q_update_vec: torch.Tensor) -> Rigid:
|
||||
"""
|
||||
Composes the transformation with a quaternion update vector of shape [*, 6], where the final 6 columns
|
||||
represent the x, y, and z values of a quaternion of form (1, x, y, z) followed by a 3D translation.
|
||||
@ -956,10 +921,7 @@ class Rigid:
|
||||
|
||||
return Rigid(new_rots, new_translation)
|
||||
|
||||
def compose(
|
||||
self,
|
||||
r: Rigid,
|
||||
) -> Rigid:
|
||||
def compose(self, r: Rigid) -> Rigid:
|
||||
"""
|
||||
Composes the current rigid object with another.
|
||||
|
||||
@ -973,10 +935,7 @@ class Rigid:
|
||||
new_trans = self._rots.apply(r._trans) + self._trans
|
||||
return Rigid(new_rot, new_trans)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
pts: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
def apply(self, pts: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Applies the transformation to a coordinate tensor.
|
||||
|
||||
@ -1012,7 +971,7 @@ class Rigid:
|
||||
|
||||
return Rigid(rot_inv, -1 * trn_inv)
|
||||
|
||||
def map_tensor_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rigid:
|
||||
def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rigid:
|
||||
"""
|
||||
Apply a Tensor -> Tensor function to underlying translation and rotation tensors, mapping over the
|
||||
translation/rotation dimensions respectively.
|
||||
@ -1074,10 +1033,7 @@ class Rigid:
|
||||
return tensor
|
||||
|
||||
@staticmethod
|
||||
def from_tensor_7(
|
||||
t: torch.Tensor,
|
||||
normalize_quats: bool = False,
|
||||
) -> Rigid:
|
||||
def from_tensor_7(t: torch.Tensor, normalize_quats: bool = False) -> Rigid:
|
||||
if t.shape[-1] != 7:
|
||||
raise ValueError("Incorrectly shaped input tensor")
|
||||
|
||||
@ -1102,18 +1058,18 @@ class Rigid:
|
||||
Returns:
|
||||
A transformation object of shape [*]
|
||||
"""
|
||||
p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1)
|
||||
origin = torch.unbind(origin, dim=-1)
|
||||
p_xy_plane = torch.unbind(p_xy_plane, dim=-1)
|
||||
p_neg_x_axis_unbound = torch.unbind(p_neg_x_axis, dim=-1)
|
||||
origin_unbound = torch.unbind(origin, dim=-1)
|
||||
p_xy_plane_unbound = torch.unbind(p_xy_plane, dim=-1)
|
||||
|
||||
e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)]
|
||||
e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)]
|
||||
e0 = [c1 - c2 for c1, c2 in zip(origin_unbound, p_neg_x_axis_unbound)]
|
||||
e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane_unbound, origin_unbound)]
|
||||
|
||||
denom = torch.sqrt(sum((c * c for c in e0)) + eps)
|
||||
denom = torch.sqrt(sum(c * c for c in e0) + eps * torch.ones_like(e0[0]))
|
||||
e0 = [c / denom for c in e0]
|
||||
dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
|
||||
e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
|
||||
denom = torch.sqrt(sum((c * c for c in e1)) + eps)
|
||||
denom = torch.sqrt(sum((c * c for c in e1)) + eps * torch.ones_like(e1[0]))
|
||||
e1 = [c / denom for c in e1]
|
||||
e2 = [
|
||||
e0[1] * e1[2] - e0[2] * e1[1],
|
||||
@ -1126,12 +1082,9 @@ class Rigid:
|
||||
|
||||
rot_obj = Rotation(rot_mats=rots, quats=None)
|
||||
|
||||
return Rigid(rot_obj, torch.stack(origin, dim=-1))
|
||||
return Rigid(rot_obj, torch.stack(origin_unbound, dim=-1))
|
||||
|
||||
def unsqueeze(
|
||||
self,
|
||||
dim: int,
|
||||
) -> Rigid:
|
||||
def unsqueeze(self, dim: int) -> Rigid:
|
||||
"""
|
||||
Analogous to torch.unsqueeze. The dimension is relative to the shared dimensions of the rotation/translation.
|
||||
|
||||
@ -1148,10 +1101,7 @@ class Rigid:
|
||||
return Rigid(rots, trans)
|
||||
|
||||
@staticmethod
|
||||
def cat(
|
||||
ts: Sequence[Rigid],
|
||||
dim: int,
|
||||
) -> Rigid:
|
||||
def cat(ts: Sequence[Rigid], dim: int) -> Rigid:
|
||||
"""
|
||||
Concatenates transformations along a new dimension.
|
||||
|
||||
@ -1168,7 +1118,7 @@ class Rigid:
|
||||
|
||||
return Rigid(rots, trans)
|
||||
|
||||
def apply_rot_fn(self, fn: Callable[Rotation, Rotation]) -> Rigid:
|
||||
def apply_rot_fn(self, fn: Callable[[Rotation], Rotation]) -> Rigid:
|
||||
"""
|
||||
Applies a Rotation -> Rotation function to the stored rotation object.
|
||||
|
||||
@ -1179,7 +1129,7 @@ class Rigid:
|
||||
"""
|
||||
return Rigid(fn(self._rots), self._trans)
|
||||
|
||||
def apply_trans_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rigid:
|
||||
def apply_trans_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rigid:
|
||||
"""
|
||||
Applies a Tensor -> Tensor function to the stored translation.
|
||||
|
||||
@ -1213,7 +1163,9 @@ class Rigid:
|
||||
return self.apply_rot_fn(lambda r: r.detach())
|
||||
|
||||
@staticmethod
|
||||
def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
|
||||
def make_transform_from_reference(
|
||||
n_xyz: torch.Tensor, ca_xyz: torch.Tensor, c_xyz: torch.Tensor, eps: float = 1e-20
|
||||
) -> Rigid:
|
||||
"""
|
||||
Returns a transformation object from reference coordinates.
|
||||
|
||||
|
@ -14,13 +14,14 @@
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
from typing import List
|
||||
from typing import Any, Callable, Dict, List, Type, TypeVar, Union, overload
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.types
|
||||
|
||||
|
||||
def add(m1, m2, inplace):
|
||||
def add(m1: torch.Tensor, m2: torch.Tensor, inplace: bool) -> torch.Tensor:
|
||||
# The first operation in a checkpoint can't be in-place, but it's
|
||||
# nice to have in-place addition during inference. Thus...
|
||||
if not inplace:
|
||||
@ -31,33 +32,35 @@ def add(m1, m2, inplace):
|
||||
return m1
|
||||
|
||||
|
||||
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
|
||||
def permute_final_dims(tensor: torch.Tensor, inds: List[int]) -> torch.Tensor:
|
||||
zero_index = -1 * len(inds)
|
||||
first_inds = list(range(len(tensor.shape[:zero_index])))
|
||||
return tensor.permute(first_inds + [zero_index + i for i in inds])
|
||||
|
||||
|
||||
def flatten_final_dims(t: torch.Tensor, no_dims: int):
|
||||
def flatten_final_dims(t: torch.Tensor, no_dims: int) -> torch.Tensor:
|
||||
return t.reshape(t.shape[:-no_dims] + (-1,))
|
||||
|
||||
|
||||
def masked_mean(mask, value, dim, eps=1e-4):
|
||||
def masked_mean(mask: torch.Tensor, value: torch.Tensor, dim: int, eps: float = 1e-4) -> torch.Tensor:
|
||||
mask = mask.expand(*value.shape)
|
||||
return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
|
||||
|
||||
|
||||
def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
|
||||
def pts_to_distogram(
|
||||
pts: torch.Tensor, min_bin: torch.types.Number = 2.3125, max_bin: torch.types.Number = 21.6875, no_bins: int = 64
|
||||
) -> torch.Tensor:
|
||||
boundaries = torch.linspace(min_bin, max_bin, no_bins - 1, device=pts.device)
|
||||
dists = torch.sqrt(torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1))
|
||||
return torch.bucketize(dists, boundaries)
|
||||
|
||||
|
||||
def dict_multimap(fn, dicts):
|
||||
def dict_multimap(fn: Callable[[list], Any], dicts: List[dict]) -> dict:
|
||||
first = dicts[0]
|
||||
new_dict = {}
|
||||
for k, v in first.items():
|
||||
all_v = [d[k] for d in dicts]
|
||||
if type(v) is dict:
|
||||
if isinstance(v, dict):
|
||||
new_dict[k] = dict_multimap(fn, all_v)
|
||||
else:
|
||||
new_dict[k] = fn(all_v)
|
||||
@ -65,21 +68,21 @@ def dict_multimap(fn, dicts):
|
||||
return new_dict
|
||||
|
||||
|
||||
def one_hot(x, v_bins):
|
||||
def one_hot(x: torch.Tensor, v_bins: torch.Tensor) -> torch.Tensor:
|
||||
reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
|
||||
diffs = x[..., None] - reshaped_bins
|
||||
am = torch.argmin(torch.abs(diffs), dim=-1)
|
||||
return nn.functional.one_hot(am, num_classes=len(v_bins)).float()
|
||||
|
||||
|
||||
def batched_gather(data, inds, dim=0, no_batch_dims=0):
|
||||
ranges = []
|
||||
def batched_gather(data: torch.Tensor, inds: torch.Tensor, dim: int = 0, no_batch_dims: int = 0) -> torch.Tensor:
|
||||
ranges: List[Union[slice, torch.Tensor]] = []
|
||||
for i, s in enumerate(data.shape[:no_batch_dims]):
|
||||
r = torch.arange(s)
|
||||
r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
|
||||
ranges.append(r)
|
||||
|
||||
remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
|
||||
remaining_dims: List[Union[slice, torch.Tensor]] = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
|
||||
remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
|
||||
ranges.extend(remaining_dims)
|
||||
# Matt note: Editing this to get around the behaviour of using a list as an array index changing
|
||||
@ -87,11 +90,16 @@ def batched_gather(data, inds, dim=0, no_batch_dims=0):
|
||||
return data[tuple(ranges)]
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
# With tree_map, a poor man's JAX tree_map
|
||||
def dict_map(fn, dic, leaf_type):
|
||||
new_dict = {}
|
||||
def dict_map(
|
||||
fn: Callable[[T], Any], dic: Dict[Any, Union[dict, list, tuple, T]], leaf_type: Type[T]
|
||||
) -> Dict[Any, Union[dict, list, tuple, Any]]:
|
||||
new_dict: Dict[Any, Union[dict, list, tuple, Any]] = {}
|
||||
for k, v in dic.items():
|
||||
if type(v) is dict:
|
||||
if isinstance(v, dict):
|
||||
new_dict[k] = dict_map(fn, v, leaf_type)
|
||||
else:
|
||||
new_dict[k] = tree_map(fn, v, leaf_type)
|
||||
@ -99,13 +107,33 @@ def dict_map(fn, dic, leaf_type):
|
||||
return new_dict
|
||||
|
||||
|
||||
@overload
|
||||
def tree_map(fn: Callable[[T], Any], tree: T, leaf_type: Type[T]) -> Any:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def tree_map(fn: Callable[[T], Any], tree: dict, leaf_type: Type[T]) -> dict:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def tree_map(fn: Callable[[T], Any], tree: list, leaf_type: Type[T]) -> list:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def tree_map(fn: Callable[[T], Any], tree: tuple, leaf_type: Type[T]) -> tuple:
|
||||
...
|
||||
|
||||
|
||||
def tree_map(fn, tree, leaf_type):
|
||||
if isinstance(tree, dict):
|
||||
return dict_map(fn, tree, leaf_type)
|
||||
elif isinstance(tree, list):
|
||||
return [tree_map(fn, x, leaf_type) for x in tree]
|
||||
elif isinstance(tree, tuple):
|
||||
return tuple([tree_map(fn, x, leaf_type) for x in tree])
|
||||
return tuple(tree_map(fn, x, leaf_type) for x in tree)
|
||||
elif isinstance(tree, leaf_type):
|
||||
return fn(tree)
|
||||
else:
|
||||
|