ESM openfold_utils type hints (#20544)

* add type annotations for esm chunk_utils

use the isinstance builtin instead of 'type(x) is y'; add assertions to aid type inference; use bools instead of ints in _get_minimal_slice_set for improved type clarity; refactor to avoid re-assigning the same variable with a different type
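
For reference, a condensed sketch of these patterns, mirroring the chunk_utils hunks further down (bodies trimmed; the last helper is a hypothetical stand-in for the assert-based narrowing used in upper()/lower()):

from typing import List, Optional, Sequence, Tuple, Union

import torch


def _fetch_dims(tree: Union[dict, list, tuple, torch.Tensor]) -> List[Tuple[int, ...]]:
    shapes: List[Tuple[int, ...]] = []
    if isinstance(tree, dict):  # was: type(tree) is dict
        for v in tree.values():
            shapes.extend(_fetch_dims(v))
    elif isinstance(tree, (list, tuple)):  # one check covers both sequence types
        for t in tree:
            shapes.extend(_fetch_dims(t))
    elif isinstance(tree, torch.Tensor):
        shapes.append(tree.shape)
    else:
        raise ValueError("Not supported")
    return shapes


def reduce_edge_list(l: List[bool]) -> None:
    # bools instead of ints: the running tally and the list elements keep a single type
    tally = True
    for i in range(len(l)):
        reversed_idx = -1 * (i + 1)
        l[reversed_idx] &= tally
        tally = l[reversed_idx]


def _narrowing_example(start_edges: Optional[Sequence[bool]]) -> int:
    # hypothetical helper: the assertion narrows Optional[...] for the type checker
    assert start_edges is not None
    return len(start_edges)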

* add type annotations for esm data_transforms

refactor to avoid re-assigning the same variable with a different type
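
The pattern, visible in the make_atom14_masks hunks below: accumulate into a plain list under a *_list name, then bind the tensor to the original name once, so each variable keeps a single static type. A minimal sketch with stand-in data:

import torch

# stand-in data for the per-restype loop in make_atom14_masks
atom_names_per_restype = [["N", "CA", "C", ""], ["N", "CA", "C", "CB"]]

restype_atom14_mask_list = []  # stays a List[List[float]] throughout
for atom_names in atom_names_per_restype:
    restype_atom14_mask_list.append([(1.0 if name else 0.0) for name in atom_names])

# the tensor keeps the original name; the list keeps its *_list name
restype_atom14_mask = torch.tensor(restype_atom14_mask_list, dtype=torch.float32)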

* add type annotations for esm feats utils

refactor to avoid re-assigning the same variable with a different type

* add type annotations for esm loss utils

* add/fix type annotations for esm rigid_utils

refactor to avoid re-assigning the same variable with a different type; fix Callable and Tuple type hints; match the conditional structure to the other methods; fix the return types of Rotation.cat and Rotation.unsqueeze
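
The Callable and Tuple fixes amount to the bracketed parameter-list form and the variadic-tuple form; the aliases below are hypothetical, added only to show the shape of the annotations:

from typing import Callable, Tuple

import torch

# Callable takes a bracketed list of parameter types:
TensorFn = Callable[[torch.Tensor], torch.Tensor]  # was: Callable[torch.Tensor, torch.Tensor]

# Tuple[int] means "a 1-tuple of int"; arbitrary-length shapes need the ellipsis form:
BatchDims = Tuple[int, ...]  # was: Tuple[int]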

* add type annotations for esm tensor_utils

add overloads for tree_map; use the isinstance builtin instead of 'type(x) is y'; export dict_multimap, flatten_final_dims, and permute_final_dims from openfold_utils
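
The tree_map overloads follow the standard typing.overload pattern; a condensed sketch of the signatures added in tensor_utils:

from typing import Any, Callable, Type, TypeVar, overload

T = TypeVar("T")


@overload
def tree_map(fn: Callable[[T], Any], tree: T, leaf_type: Type[T]) -> Any:
    ...


@overload
def tree_map(fn: Callable[[T], Any], tree: dict, leaf_type: Type[T]) -> dict:
    ...


@overload
def tree_map(fn: Callable[[T], Any], tree: list, leaf_type: Type[T]) -> list:
    ...


@overload
def tree_map(fn: Callable[[T], Any], tree: tuple, leaf_type: Type[T]) -> tuple:
    ...


def tree_map(fn, tree, leaf_type):
    # single untyped runtime implementation; dispatch stays isinstance-based (see the hunk below)
    ...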

* add type annotations for esm protein utils

add FIXME for attempted string mutation; add missing None check in get_pdb_headers; fix potentially unbound variable 'chain_tag' in to_pdb; modify get_pdb_headers return type
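
The None-check fix is a guard on two Optional fields before they are zipped; a hypothetical condensation of get_pdb_headers (the unbound-variable fix in to_pdb is the analogous `chain_tag = None` initialisation before the residue loop):

from typing import List, Optional


def get_pdb_headers_sketch(
    parents: Optional[List[str]], parents_chain_index: Optional[List[int]], chain_id: int = 0
) -> List[str]:
    # both Optionals must be checked, not just parents_chain_index
    if parents is not None and parents_chain_index is not None:
        parents = [p for i, p in zip(parents_chain_index, parents) if i == chain_id]
    if parents is None or len(parents) == 0:
        parents = ["N/A"]
    return [f"PARENT {' '.join(parents)}"]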

* add type annotations for esm residue constants

add type hints to collection constants; remove magic trailing commas to reduce the line count; change list -> tuple for rigid_group_atom_positions for improved hinting
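
On the list -> tuple point: a tuple annotation carries fixed, per-position element types that an invariant list cannot, which is what improves the hinting. A hypothetical miniature entry in the style of rigid_group_atom_positions:

from typing import Tuple

# (atom name, rigid group index, idealised position) -- illustrative values only
atom_entry: Tuple[str, int, Tuple[float, float, float]] = ("N", 0, (-0.525, 1.363, 0.0))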

* code style fixup

Co-authored-by: Matt <rocketknight1@gmail.com>
Matthew Hoffman 2022-12-05 10:23:15 -06:00 committed by GitHub
parent 8ea6694d92
commit afe2a466bb
9 changed files with 491 additions and 760 deletions


@ -6,3 +6,4 @@ from .loss import compute_predicted_aligned_error, compute_tm
from .protein import Protein as OFProtein from .protein import Protein as OFProtein
from .protein import to_pdb from .protein import to_pdb
from .rigid_utils import Rigid, Rotation from .rigid_utils import Rigid, Rotation
from .tensor_utils import dict_multimap, flatten_final_dims, permute_final_dims


@ -14,23 +14,22 @@
import logging import logging
import math import math
from functools import partial from functools import partial
from typing import Any, Callable, Dict, Optional, Sequence, Tuple from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import torch import torch
from .tensor_utils import tensor_tree_map, tree_map from .tensor_utils import tensor_tree_map, tree_map
def _fetch_dims(tree): def _fetch_dims(tree: Union[dict, list, tuple, torch.Tensor]) -> List[Tuple[int, ...]]:
shapes = [] shapes = []
tree_type = type(tree) if isinstance(tree, dict):
if tree_type is dict:
for v in tree.values(): for v in tree.values():
shapes.extend(_fetch_dims(v)) shapes.extend(_fetch_dims(v))
elif tree_type is list or tree_type is tuple: elif isinstance(tree, (list, tuple)):
for t in tree: for t in tree:
shapes.extend(_fetch_dims(t)) shapes.extend(_fetch_dims(t))
elif tree_type is torch.Tensor: elif isinstance(tree, torch.Tensor):
shapes.append(tree.shape) shapes.append(tree.shape)
else: else:
raise ValueError("Not supported") raise ValueError("Not supported")
@ -39,10 +38,7 @@ def _fetch_dims(tree):
@torch.jit.ignore @torch.jit.ignore
def _flat_idx_to_idx( def _flat_idx_to_idx(flat_idx: int, dims: Tuple[int, ...]) -> Tuple[int, ...]:
flat_idx: int,
dims: Tuple[int],
) -> Tuple[int]:
idx = [] idx = []
for d in reversed(dims): for d in reversed(dims):
idx.append(flat_idx % d) idx.append(flat_idx % d)
@ -55,10 +51,10 @@ def _flat_idx_to_idx(
def _get_minimal_slice_set( def _get_minimal_slice_set(
start: Sequence[int], start: Sequence[int],
end: Sequence[int], end: Sequence[int],
dims: int, dims: Sequence[int],
start_edges: Optional[Sequence[bool]] = None, start_edges: Optional[Sequence[bool]] = None,
end_edges: Optional[Sequence[bool]] = None, end_edges: Optional[Sequence[bool]] = None,
) -> Sequence[Tuple[int]]: ) -> List[Tuple[slice, ...]]:
""" """
Produces an ordered sequence of tensor slices that, when used in sequence on a tensor with shape dims, yields Produces an ordered sequence of tensor slices that, when used in sequence on a tensor with shape dims, yields
tensors that contain every leaf in the contiguous range [start, end]. Care is taken to yield a short sequence of tensors that contain every leaf in the contiguous range [start, end]. Care is taken to yield a short sequence of
@ -69,11 +65,11 @@ def _get_minimal_slice_set(
# start_edges and end_edges both indicate whether, starting from any given # start_edges and end_edges both indicate whether, starting from any given
# dimension, the start/end index is at the top/bottom edge of the # dimension, the start/end index is at the top/bottom edge of the
# corresponding tensor, modeled as a tree # corresponding tensor, modeled as a tree
def reduce_edge_list(l): def reduce_edge_list(l: List[bool]) -> None:
tally = 1 tally = True
for i in range(len(l)): for i in range(len(l)):
reversed_idx = -1 * (i + 1) reversed_idx = -1 * (i + 1)
l[reversed_idx] *= tally l[reversed_idx] &= tally
tally = l[reversed_idx] tally = l[reversed_idx]
if start_edges is None: if start_edges is None:
@ -90,48 +86,54 @@ def _get_minimal_slice_set(
elif len(start) == 1: elif len(start) == 1:
return [(slice(start[0], end[0] + 1),)] return [(slice(start[0], end[0] + 1),)]
slices = [] slices: List[Tuple[slice, ...]] = []
path = [] path_list: List[slice] = []
# Dimensions common to start and end can be selected directly # Dimensions common to start and end can be selected directly
for s, e in zip(start, end): for s, e in zip(start, end):
if s == e: if s == e:
path.append(slice(s, s + 1)) path_list.append(slice(s, s + 1))
else: else:
break break
path = tuple(path) path: Tuple[slice, ...] = tuple(path_list)
divergence_idx = len(path) divergence_idx = len(path)
# start == end, and we're done # start == end, and we're done
if divergence_idx == len(dims): if divergence_idx == len(dims):
return [tuple(path)] return [path]
def upper() -> Tuple[Tuple[slice, ...], ...]:
assert start_edges is not None
assert end_edges is not None
def upper():
sdi = start[divergence_idx] sdi = start[divergence_idx]
return [ return tuple(
path + (slice(sdi, sdi + 1),) + s path + (slice(sdi, sdi + 1),) + s
for s in _get_minimal_slice_set( for s in _get_minimal_slice_set(
start[divergence_idx + 1 :], start[divergence_idx + 1 :],
[d - 1 for d in dims[divergence_idx + 1 :]], [d - 1 for d in dims[divergence_idx + 1 :]],
dims[divergence_idx + 1 :], dims[divergence_idx + 1 :],
start_edges=start_edges[divergence_idx + 1 :], start_edges=start_edges[divergence_idx + 1 :],
end_edges=[1 for _ in end_edges[divergence_idx + 1 :]], end_edges=[True for _ in end_edges[divergence_idx + 1 :]],
) )
] )
def lower() -> Tuple[Tuple[slice, ...], ...]:
assert start_edges is not None
assert end_edges is not None
def lower():
edi = end[divergence_idx] edi = end[divergence_idx]
return [ return tuple(
path + (slice(edi, edi + 1),) + s path + (slice(edi, edi + 1),) + s
for s in _get_minimal_slice_set( for s in _get_minimal_slice_set(
[0 for _ in start[divergence_idx + 1 :]], [0 for _ in start[divergence_idx + 1 :]],
end[divergence_idx + 1 :], end[divergence_idx + 1 :],
dims[divergence_idx + 1 :], dims[divergence_idx + 1 :],
start_edges=[1 for _ in start_edges[divergence_idx + 1 :]], start_edges=[True for _ in start_edges[divergence_idx + 1 :]],
end_edges=end_edges[divergence_idx + 1 :], end_edges=end_edges[divergence_idx + 1 :],
) )
] )
# If both start and end are at the edges of the subtree rooted at # If both start and end are at the edges of the subtree rooted at
# divergence_idx, we can just select the whole subtree at once # divergence_idx, we can just select the whole subtree at once
@ -156,16 +158,11 @@ def _get_minimal_slice_set(
slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)) slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx]),))
slices.extend(lower()) slices.extend(lower())
return [tuple(s) for s in slices] return slices
@torch.jit.ignore @torch.jit.ignore
def _chunk_slice( def _chunk_slice(t: torch.Tensor, flat_start: int, flat_end: int, no_batch_dims: int) -> torch.Tensor:
t: torch.Tensor,
flat_start: int,
flat_end: int,
no_batch_dims: int,
) -> torch.Tensor:
""" """
Equivalent to Equivalent to
@ -232,7 +229,7 @@ def chunk_layer(
initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)] initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)]) orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
def _prep_inputs(t): def _prep_inputs(t: torch.Tensor) -> torch.Tensor:
if not low_mem: if not low_mem:
if not sum(t.shape[:no_batch_dims]) == no_batch_dims: if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:]) t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
@ -241,7 +238,7 @@ def chunk_layer(
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:]) t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
return t return t
prepped_inputs = tensor_tree_map(_prep_inputs, inputs) prepped_inputs: Dict[str, Any] = tensor_tree_map(_prep_inputs, inputs)
prepped_outputs = None prepped_outputs = None
if _out is not None: if _out is not None:
prepped_outputs = tensor_tree_map(lambda t: t.view([-1] + list(t.shape[no_batch_dims:])), _out) prepped_outputs = tensor_tree_map(lambda t: t.view([-1] + list(t.shape[no_batch_dims:])), _out)
@ -252,7 +249,7 @@ def chunk_layer(
no_chunks = flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0) no_chunks = flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0)
def _select_chunk(t): def _select_chunk(t: torch.Tensor) -> torch.Tensor:
return t[i : i + chunk_size] if t.shape[0] != 1 else t return t[i : i + chunk_size] if t.shape[0] != 1 else t
i = 0 i = 0
@ -269,7 +266,7 @@ def chunk_layer(
no_batch_dims=len(orig_batch_dims), no_batch_dims=len(orig_batch_dims),
) )
chunks = tensor_tree_map(select_chunk, prepped_inputs) chunks: Dict[str, Any] = tensor_tree_map(select_chunk, prepped_inputs)
# Run the layer on the chunk # Run the layer on the chunk
output_chunk = layer(**chunks) output_chunk = layer(**chunks)
@ -279,12 +276,11 @@ def chunk_layer(
out = tensor_tree_map(lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:]), output_chunk) out = tensor_tree_map(lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:]), output_chunk)
# Put the chunk in its pre-allocated space # Put the chunk in its pre-allocated space
out_type = type(output_chunk) if isinstance(output_chunk, dict):
if out_type is dict:
def assign(d1, d2): def assign(d1: dict, d2: dict) -> None:
for k, v in d1.items(): for k, v in d1.items():
if type(v) is dict: if isinstance(v, dict):
assign(v, d2[k]) assign(v, d2[k])
else: else:
if _add_into_out: if _add_into_out:
@ -293,13 +289,13 @@ def chunk_layer(
v[i : i + chunk_size] = d2[k] v[i : i + chunk_size] = d2[k]
assign(out, output_chunk) assign(out, output_chunk)
elif out_type is tuple: elif isinstance(output_chunk, tuple):
for x1, x2 in zip(out, output_chunk): for x1, x2 in zip(out, output_chunk):
if _add_into_out: if _add_into_out:
x1[i : i + chunk_size] += x2 x1[i : i + chunk_size] += x2
else: else:
x1[i : i + chunk_size] = x2 x1[i : i + chunk_size] = x2
elif out_type is torch.Tensor: elif isinstance(output_chunk, torch.Tensor):
if _add_into_out: if _add_into_out:
out[i : i + chunk_size] += output_chunk out[i : i + chunk_size] += output_chunk
else: else:
@ -319,24 +315,24 @@ class ChunkSizeTuner:
self, self,
# Heuristically, runtimes for most of the modules in the network # Heuristically, runtimes for most of the modules in the network
# plateau earlier than this on all GPUs I've run the model on. # plateau earlier than this on all GPUs I've run the model on.
max_chunk_size=512, max_chunk_size: int = 512,
): ):
self.max_chunk_size = max_chunk_size self.max_chunk_size = max_chunk_size
self.cached_chunk_size = None self.cached_chunk_size: Optional[int] = None
self.cached_arg_data = None self.cached_arg_data: Optional[tuple] = None
def _determine_favorable_chunk_size(self, fn, args, min_chunk_size): def _determine_favorable_chunk_size(self, fn: Callable, args: tuple, min_chunk_size: int) -> int:
logging.info("Tuning chunk size...") logging.info("Tuning chunk size...")
if min_chunk_size >= self.max_chunk_size: if min_chunk_size >= self.max_chunk_size:
return min_chunk_size return min_chunk_size
candidates = [2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)] candidates: List[int] = [2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)]
candidates = [c for c in candidates if c > min_chunk_size] candidates = [c for c in candidates if c > min_chunk_size]
candidates = [min_chunk_size] + candidates candidates = [min_chunk_size] + candidates
candidates[-1] += 4 candidates[-1] += 4
def test_chunk_size(chunk_size): def test_chunk_size(chunk_size: int) -> bool:
try: try:
with torch.no_grad(): with torch.no_grad():
fn(*args, chunk_size=chunk_size) fn(*args, chunk_size=chunk_size)
@ -356,13 +352,13 @@ class ChunkSizeTuner:
return candidates[min_viable_chunk_size_index] return candidates[min_viable_chunk_size_index]
def _compare_arg_caches(self, ac1, ac2): def _compare_arg_caches(self, ac1: Iterable, ac2: Iterable) -> bool:
consistent = True consistent = True
for a1, a2 in zip(ac1, ac2): for a1, a2 in zip(ac1, ac2):
assert type(ac1) == type(ac2) assert type(ac1) == type(ac2)
if type(ac1) is list or type(ac1) is tuple: if isinstance(ac1, (list, tuple)):
consistent &= self._compare_arg_caches(a1, a2) consistent &= self._compare_arg_caches(a1, a2)
elif type(ac1) is dict: elif isinstance(ac1, dict):
a1_items = [v for _, v in sorted(a1.items(), key=lambda x: x[0])] a1_items = [v for _, v in sorted(a1.items(), key=lambda x: x[0])]
a2_items = [v for _, v in sorted(a2.items(), key=lambda x: x[0])] a2_items = [v for _, v in sorted(a2.items(), key=lambda x: x[0])]
consistent &= self._compare_arg_caches(a1_items, a2_items) consistent &= self._compare_arg_caches(a1_items, a2_items)
@ -374,11 +370,11 @@ class ChunkSizeTuner:
def tune_chunk_size( def tune_chunk_size(
self, self,
representative_fn: Callable, representative_fn: Callable,
args: Tuple[Any], args: tuple,
min_chunk_size: int, min_chunk_size: int,
) -> int: ) -> int:
consistent = True consistent = True
arg_data = tree_map(lambda a: a.shape if type(a) is torch.Tensor else a, args, object) arg_data: tuple = tree_map(lambda a: a.shape if isinstance(a, torch.Tensor) else a, args, object)
if self.cached_arg_data is not None: if self.cached_arg_data is not None:
# If args have changed shape/value, we need to re-tune # If args have changed shape/value, we need to re-tune
assert len(self.cached_arg_data) == len(arg_data) assert len(self.cached_arg_data) == len(arg_data)
@ -395,4 +391,6 @@ class ChunkSizeTuner:
) )
self.cached_arg_data = arg_data self.cached_arg_data = arg_data
assert self.cached_chunk_size is not None
return self.cached_chunk_size return self.cached_chunk_size


@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict
import numpy as np import numpy as np
import torch import torch
@ -20,39 +22,39 @@ from . import residue_constants as rc
from .tensor_utils import tensor_tree_map, tree_map from .tensor_utils import tensor_tree_map, tree_map
def make_atom14_masks(protein): def make_atom14_masks(protein: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Construct denser atom positions (14 dimensions instead of 37).""" """Construct denser atom positions (14 dimensions instead of 37)."""
restype_atom14_to_atom37 = [] restype_atom14_to_atom37_list = []
restype_atom37_to_atom14 = [] restype_atom37_to_atom14_list = []
restype_atom14_mask = [] restype_atom14_mask_list = []
for rt in rc.restypes: for rt in rc.restypes:
atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]] atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]
restype_atom14_to_atom37.append([(rc.atom_order[name] if name else 0) for name in atom_names]) restype_atom14_to_atom37_list.append([(rc.atom_order[name] if name else 0) for name in atom_names])
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append( restype_atom37_to_atom14_list.append(
[(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) for name in rc.atom_types] [(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) for name in rc.atom_types]
) )
restype_atom14_mask.append([(1.0 if name else 0.0) for name in atom_names]) restype_atom14_mask_list.append([(1.0 if name else 0.0) for name in atom_names])
# Add dummy mapping for restype 'UNK' # Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37.append([0] * 14) restype_atom14_to_atom37_list.append([0] * 14)
restype_atom37_to_atom14.append([0] * 37) restype_atom37_to_atom14_list.append([0] * 37)
restype_atom14_mask.append([0.0] * 14) restype_atom14_mask_list.append([0.0] * 14)
restype_atom14_to_atom37 = torch.tensor( restype_atom14_to_atom37 = torch.tensor(
restype_atom14_to_atom37, restype_atom14_to_atom37_list,
dtype=torch.int32, dtype=torch.int32,
device=protein["aatype"].device, device=protein["aatype"].device,
) )
restype_atom37_to_atom14 = torch.tensor( restype_atom37_to_atom14 = torch.tensor(
restype_atom37_to_atom14, restype_atom37_to_atom14_list,
dtype=torch.int32, dtype=torch.int32,
device=protein["aatype"].device, device=protein["aatype"].device,
) )
restype_atom14_mask = torch.tensor( restype_atom14_mask = torch.tensor(
restype_atom14_mask, restype_atom14_mask_list,
dtype=torch.float32, dtype=torch.float32,
device=protein["aatype"].device, device=protein["aatype"].device,
) )
@ -85,8 +87,7 @@ def make_atom14_masks(protein):
return protein return protein
def make_atom14_masks_np(batch): def make_atom14_masks_np(batch: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:
batch = tree_map(lambda n: torch.tensor(n, device=batch["aatype"].device), batch, np.ndarray) batch = tree_map(lambda n: torch.tensor(n, device=batch["aatype"].device), batch, np.ndarray)
out = make_atom14_masks(batch) out = tensor_tree_map(lambda t: np.array(t), make_atom14_masks(batch))
out = tensor_tree_map(lambda t: np.array(t), out)
return out return out


@ -13,14 +13,29 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict, Tuple, overload
import torch import torch
import torch.nn as nn import torch.types
from torch import nn
from . import residue_constants as rc from . import residue_constants as rc
from .rigid_utils import Rigid, Rotation from .rigid_utils import Rigid, Rotation
from .tensor_utils import batched_gather from .tensor_utils import batched_gather
@overload
def pseudo_beta_fn(aatype: torch.Tensor, all_atom_positions: torch.Tensor, all_atom_masks: None) -> torch.Tensor:
...
@overload
def pseudo_beta_fn(
aatype: torch.Tensor, all_atom_positions: torch.Tensor, all_atom_masks: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
...
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
is_gly = aatype == rc.restype_order["G"] is_gly = aatype == rc.restype_order["G"]
ca_idx = rc.atom_order["CA"] ca_idx = rc.atom_order["CA"]
@ -42,7 +57,7 @@ def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
return pseudo_beta return pseudo_beta
def atom14_to_atom37(atom14, batch): def atom14_to_atom37(atom14: torch.Tensor, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
atom37_data = batched_gather( atom37_data = batched_gather(
atom14, atom14,
batch["residx_atom37_to_atom14"], batch["residx_atom37_to_atom14"],
@ -55,7 +70,7 @@ def atom14_to_atom37(atom14, batch):
return atom37_data return atom37_data
def build_template_angle_feat(template_feats): def build_template_angle_feat(template_feats: Dict[str, torch.Tensor]) -> torch.Tensor:
template_aatype = template_feats["template_aatype"] template_aatype = template_feats["template_aatype"]
torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"] torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"]
alt_torsion_angles_sin_cos = template_feats["template_alt_torsion_angles_sin_cos"] alt_torsion_angles_sin_cos = template_feats["template_alt_torsion_angles_sin_cos"]
@ -73,7 +88,15 @@ def build_template_angle_feat(template_feats):
return template_angle_feat return template_angle_feat
def build_template_pair_feat(batch, min_bin, max_bin, no_bins, use_unit_vector=False, eps=1e-20, inf=1e8): def build_template_pair_feat(
batch: Dict[str, torch.Tensor],
min_bin: torch.types.Number,
max_bin: torch.types.Number,
no_bins: int,
use_unit_vector: bool = False,
eps: float = 1e-20,
inf: float = 1e8,
) -> torch.Tensor:
template_mask = batch["template_pseudo_beta_mask"] template_mask = batch["template_pseudo_beta_mask"]
template_mask_2d = template_mask[..., None] * template_mask[..., None, :] template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
@ -86,7 +109,7 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, use_unit_vector=F
to_concat = [dgram, template_mask_2d[..., None]] to_concat = [dgram, template_mask_2d[..., None]]
aatype_one_hot = nn.functional.one_hot( aatype_one_hot: torch.LongTensor = nn.functional.one_hot(
batch["template_aatype"], batch["template_aatype"],
rc.restype_num + 2, rc.restype_num + 2,
) )
@ -126,8 +149,8 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, use_unit_vector=F
return act return act
def build_extra_msa_feat(batch): def build_extra_msa_feat(batch: Dict[str, torch.Tensor]) -> torch.Tensor:
msa_1hot = nn.functional.one_hot(batch["extra_msa"], 23) msa_1hot: torch.LongTensor = nn.functional.one_hot(batch["extra_msa"], 23)
msa_feat = [ msa_feat = [
msa_1hot, msa_1hot,
batch["extra_has_deletion"].unsqueeze(-1), batch["extra_has_deletion"].unsqueeze(-1),
@ -141,7 +164,7 @@ def torsion_angles_to_frames(
alpha: torch.Tensor, alpha: torch.Tensor,
aatype: torch.Tensor, aatype: torch.Tensor,
rrgdf: torch.Tensor, rrgdf: torch.Tensor,
): ) -> Rigid:
# [*, N, 8, 4, 4] # [*, N, 8, 4, 4]
default_4x4 = rrgdf[aatype, ...] default_4x4 = rrgdf[aatype, ...]
@ -172,9 +195,7 @@ def torsion_angles_to_frames(
all_rots[..., 1, 2] = -alpha[..., 0] all_rots[..., 1, 2] = -alpha[..., 0]
all_rots[..., 2, 1:] = alpha all_rots[..., 2, 1:] = alpha
all_rots = Rigid(Rotation(rot_mats=all_rots), None) all_frames = default_r.compose(Rigid(Rotation(rot_mats=all_rots), None))
all_frames = default_r.compose(all_rots)
chi2_frame_to_frame = all_frames[..., 5] chi2_frame_to_frame = all_frames[..., 5]
chi3_frame_to_frame = all_frames[..., 6] chi3_frame_to_frame = all_frames[..., 6]
@ -203,22 +224,22 @@ def torsion_angles_to_frames(
def frames_and_literature_positions_to_atom14_pos( def frames_and_literature_positions_to_atom14_pos(
r: Rigid, r: Rigid,
aatype: torch.Tensor, aatype: torch.Tensor,
default_frames, default_frames: torch.Tensor,
group_idx, group_idx: torch.Tensor,
atom_mask, atom_mask: torch.Tensor,
lit_positions, lit_positions: torch.Tensor,
): ) -> torch.Tensor:
# [*, N, 14] # [*, N, 14]
group_mask = group_idx[aatype, ...] group_mask = group_idx[aatype, ...]
# [*, N, 14, 8] # [*, N, 14, 8]
group_mask = nn.functional.one_hot( group_mask_one_hot: torch.LongTensor = nn.functional.one_hot(
group_mask, group_mask,
num_classes=default_frames.shape[-3], num_classes=default_frames.shape[-3],
) )
# [*, N, 14, 8] # [*, N, 14, 8]
t_atoms_to_global = r[..., None, :] * group_mask t_atoms_to_global = r[..., None, :] * group_mask_one_hot
# [*, N, 14] # [*, N, 14]
t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1)) t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1))


@ -18,7 +18,7 @@ from typing import Dict, Optional, Tuple
import torch import torch
def _calculate_bin_centers(boundaries: torch.Tensor): def _calculate_bin_centers(boundaries: torch.Tensor) -> torch.Tensor:
step = boundaries[1] - boundaries[0] step = boundaries[1] - boundaries[0]
bin_centers = boundaries + step / 2 bin_centers = boundaries + step / 2
bin_centers = torch.cat([bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0) bin_centers = torch.cat([bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0)


@ -17,7 +17,7 @@
import dataclasses import dataclasses
import re import re
import string import string
from typing import Any, Mapping, Optional, Sequence from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple
import numpy as np import numpy as np
@ -69,10 +69,10 @@ class Protein:
def from_proteinnet_string(proteinnet_str: str) -> Protein: def from_proteinnet_string(proteinnet_str: str) -> Protein:
tag_re = r"(\[[A-Z]+\]\n)" tag_re = r"(\[[A-Z]+\]\n)"
tags = [tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0] tags: List[str] = [tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0]
groups = zip(tags[0::2], [l.split("\n") for l in tags[1::2]]) groups: Iterator[Tuple[str, List[str]]] = zip(tags[0::2], [l.split("\n") for l in tags[1::2]])
atoms = ["N", "CA", "C"] atoms: List[str] = ["N", "CA", "C"]
aatype = None aatype = None
atom_positions = None atom_positions = None
atom_mask = None atom_mask = None
@ -81,12 +81,12 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
seq = g[1][0].strip() seq = g[1][0].strip()
for i in range(len(seq)): for i in range(len(seq)):
if seq[i] not in residue_constants.restypes: if seq[i] not in residue_constants.restypes:
seq[i] = "X" seq[i] = "X" # FIXME: strings are immutable
aatype = np.array( aatype = np.array(
[residue_constants.restype_order.get(res_symbol, residue_constants.restype_num) for res_symbol in seq] [residue_constants.restype_order.get(res_symbol, residue_constants.restype_num) for res_symbol in seq]
) )
elif "[TERTIARY]" == g[0]: elif "[TERTIARY]" == g[0]:
tertiary = [] tertiary: List[List[float]] = []
for axis in range(3): for axis in range(3):
tertiary.append(list(map(float, g[1][axis].split()))) tertiary.append(list(map(float, g[1][axis].split())))
tertiary_np = np.array(tertiary) tertiary_np = np.array(tertiary)
@ -106,6 +106,8 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
atom_mask[:, residue_constants.atom_order[atom]] = 1 atom_mask[:, residue_constants.atom_order[atom]] = 1
atom_mask *= mask[..., None] atom_mask *= mask[..., None]
assert aatype is not None
return Protein( return Protein(
atom_positions=atom_positions, atom_positions=atom_positions,
atom_mask=atom_mask, atom_mask=atom_mask,
@ -115,8 +117,8 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
) )
def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]: def get_pdb_headers(prot: Protein, chain_id: int = 0) -> List[str]:
pdb_headers = [] pdb_headers: List[str] = []
remark = prot.remark remark = prot.remark
if remark is not None: if remark is not None:
@ -124,7 +126,7 @@ def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]:
parents = prot.parents parents = prot.parents
parents_chain_index = prot.parents_chain_index parents_chain_index = prot.parents_chain_index
if parents_chain_index is not None: if parents is not None and parents_chain_index is not None:
parents = [p for i, p in zip(parents_chain_index, parents) if i == chain_id] parents = [p for i, p in zip(parents_chain_index, parents) if i == chain_id]
if parents is None or len(parents) == 0: if parents is None or len(parents) == 0:
@ -139,18 +141,18 @@ def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
"""Add pdb headers to an existing PDB string. Useful during multi-chain """Add pdb headers to an existing PDB string. Useful during multi-chain
recycling recycling
""" """
out_pdb_lines = [] out_pdb_lines: List[str] = []
lines = pdb_str.split("\n") lines = pdb_str.split("\n")
remark = prot.remark remark = prot.remark
if remark is not None: if remark is not None:
out_pdb_lines.append(f"REMARK {remark}") out_pdb_lines.append(f"REMARK {remark}")
parents_per_chain = None parents_per_chain: List[List[str]]
if prot.parents is not None and len(prot.parents) > 0: if prot.parents is not None and len(prot.parents) > 0:
parents_per_chain = [] parents_per_chain = []
if prot.parents_chain_index is not None: if prot.parents_chain_index is not None:
parent_dict = {} parent_dict: Dict[str, List[str]] = {}
for p, i in zip(prot.parents, prot.parents_chain_index): for p, i in zip(prot.parents, prot.parents_chain_index):
parent_dict.setdefault(str(i), []) parent_dict.setdefault(str(i), [])
parent_dict[str(i)].append(p) parent_dict[str(i)].append(p)
@ -160,11 +162,11 @@ def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
chain_parents = parent_dict.get(str(i), ["N/A"]) chain_parents = parent_dict.get(str(i), ["N/A"])
parents_per_chain.append(chain_parents) parents_per_chain.append(chain_parents)
else: else:
parents_per_chain.append(prot.parents) parents_per_chain.append(list(prot.parents))
else: else:
parents_per_chain = [["N/A"]] parents_per_chain = [["N/A"]]
def make_parent_line(p): def make_parent_line(p: Sequence[str]) -> str:
return f"PARENT {' '.join(p)}" return f"PARENT {' '.join(p)}"
out_pdb_lines.append(make_parent_line(parents_per_chain[0])) out_pdb_lines.append(make_parent_line(parents_per_chain[0]))
@ -196,12 +198,12 @@ def to_pdb(prot: Protein) -> str:
""" """
restypes = residue_constants.restypes + ["X"] restypes = residue_constants.restypes + ["X"]
def res_1to3(r): def res_1to3(r: int) -> str:
return residue_constants.restype_1to3.get(restypes[r], "UNK") return residue_constants.restype_1to3.get(restypes[r], "UNK")
atom_types = residue_constants.atom_types atom_types = residue_constants.atom_types
pdb_lines = [] pdb_lines: List[str] = []
atom_mask = prot.atom_mask atom_mask = prot.atom_mask
aatype = prot.aatype aatype = prot.aatype
@ -221,6 +223,7 @@ def to_pdb(prot: Protein) -> str:
atom_index = 1 atom_index = 1
prev_chain_index = 0 prev_chain_index = 0
chain_tags = string.ascii_uppercase chain_tags = string.ascii_uppercase
chain_tag = None
# Add all atom sites. # Add all atom sites.
for i in range(n): for i in range(n):
res_name_3 = res_1to3(aatype[i]) res_name_3 = res_1to3(aatype[i])
@ -313,15 +316,12 @@ def from_prediction(
Returns: Returns:
A protein instance. A protein instance.
""" """
if b_factors is None:
b_factors = np.zeros_like(result["final_atom_mask"])
return Protein( return Protein(
aatype=features["aatype"], aatype=features["aatype"],
atom_positions=result["final_atom_positions"], atom_positions=result["final_atom_positions"],
atom_mask=result["final_atom_mask"], atom_mask=result["final_atom_mask"],
residue_index=features["residue_index"] + 1, residue_index=features["residue_index"] + 1,
b_factors=b_factors, b_factors=b_factors if b_factors is not None else np.zeros_like(result["final_atom_mask"]),
chain_index=chain_index, chain_index=chain_index,
remark=remark, remark=remark,
parents=parents, parents=parents,


@ -16,7 +16,7 @@
from __future__ import annotations from __future__ import annotations
from functools import lru_cache from functools import lru_cache
from typing import Any, Callable, Optional, Sequence, Tuple from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
import numpy as np import numpy as np
import torch import torch
@ -33,7 +33,7 @@ def rot_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
The product ab The product ab
""" """
def row_mul(i): def row_mul(i: int) -> torch.Tensor:
return torch.stack( return torch.stack(
[ [
a[..., i, 0] * b[..., 0, 0] + a[..., i, 1] * b[..., 1, 0] + a[..., i, 2] * b[..., 2, 0], a[..., i, 0] * b[..., 0, 0] + a[..., i, 1] * b[..., 1, 0] + a[..., i, 2] * b[..., 2, 0],
@ -76,7 +76,7 @@ def rot_vec_mul(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def identity_rot_mats( def identity_rot_mats(
batch_dims: Tuple[int], batch_dims: Tuple[int, ...],
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
requires_grad: bool = True, requires_grad: bool = True,
@ -91,7 +91,7 @@ def identity_rot_mats(
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def identity_trans( def identity_trans(
batch_dims: Tuple[int], batch_dims: Tuple[int, ...],
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
requires_grad: bool = True, requires_grad: bool = True,
@ -102,7 +102,7 @@ def identity_trans(
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def identity_quats( def identity_quats(
batch_dims: Tuple[int], batch_dims: Tuple[int, ...],
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
requires_grad: bool = True, requires_grad: bool = True,
@ -115,15 +115,14 @@ def identity_quats(
return quat return quat
_quat_elements = ["a", "b", "c", "d"] _quat_elements: List[str] = ["a", "b", "c", "d"]
_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements] _qtr_keys: List[str] = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)} _qtr_ind_dict: Dict[str, int] = {key: ind for ind, key in enumerate(_qtr_keys)}
def _to_mat(pairs): def _to_mat(pairs: List[Tuple[str, int]]) -> np.ndarray:
mat = np.zeros((4, 4)) mat = np.zeros((4, 4))
for pair in pairs: for key, value in pairs:
key, value = pair
ind = _qtr_ind_dict[key] ind = _qtr_ind_dict[key]
mat[ind // 4][ind % 4] = value mat[ind // 4][ind % 4] = value
@ -165,14 +164,11 @@ def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
return torch.sum(quat, dim=(-3, -4)) return torch.sum(quat, dim=(-3, -4))
def rot_to_quat( def rot_to_quat(rot: torch.Tensor) -> torch.Tensor:
rot: torch.Tensor,
):
if rot.shape[-2:] != (3, 3): if rot.shape[-2:] != (3, 3):
raise ValueError("Input rotation is incorrectly shaped") raise ValueError("Input rotation is incorrectly shaped")
rot = [[rot[..., i, j] for j in range(3)] for i in range(3)] [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = [[rot[..., i, j] for j in range(3)] for i in range(3)]
[[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot
k = [ k = [
[ [
@ -201,9 +197,7 @@ def rot_to_quat(
], ],
] ]
k = (1.0 / 3.0) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2) _, vectors = torch.linalg.eigh((1.0 / 3.0) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2))
_, vectors = torch.linalg.eigh(k)
return vectors[..., -1] return vectors[..., -1]
@ -218,7 +212,7 @@ _QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], [0, 0, 1, 0], [0, -1, 0, 0], [1, 0, 0,
_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :] _QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
_CACHED_QUATS = { _CACHED_QUATS: Dict[str, np.ndarray] = {
"_QTR_MAT": _QTR_MAT, "_QTR_MAT": _QTR_MAT,
"_QUAT_MULTIPLY": _QUAT_MULTIPLY, "_QUAT_MULTIPLY": _QUAT_MULTIPLY,
"_QUAT_MULTIPLY_BY_VEC": _QUAT_MULTIPLY_BY_VEC, "_QUAT_MULTIPLY_BY_VEC": _QUAT_MULTIPLY_BY_VEC,
@ -226,29 +220,29 @@ _CACHED_QUATS = {
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def _get_quat(quat_key, dtype, device): def _get_quat(quat_key: str, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
return torch.tensor(_CACHED_QUATS[quat_key], dtype=dtype, device=device) return torch.tensor(_CACHED_QUATS[quat_key], dtype=dtype, device=device)
def quat_multiply(quat1, quat2): def quat_multiply(quat1: torch.Tensor, quat2: torch.Tensor) -> torch.Tensor:
"""Multiply a quaternion by another quaternion.""" """Multiply a quaternion by another quaternion."""
mat = _get_quat("_QUAT_MULTIPLY", dtype=quat1.dtype, device=quat1.device) mat = _get_quat("_QUAT_MULTIPLY", dtype=quat1.dtype, device=quat1.device)
reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape) reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape)
return torch.sum(reshaped_mat * quat1[..., :, None, None] * quat2[..., None, :, None], dim=(-3, -2)) return torch.sum(reshaped_mat * quat1[..., :, None, None] * quat2[..., None, :, None], dim=(-3, -2))
def quat_multiply_by_vec(quat, vec): def quat_multiply_by_vec(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
"""Multiply a quaternion by a pure-vector quaternion.""" """Multiply a quaternion by a pure-vector quaternion."""
mat = _get_quat("_QUAT_MULTIPLY_BY_VEC", dtype=quat.dtype, device=quat.device) mat = _get_quat("_QUAT_MULTIPLY_BY_VEC", dtype=quat.dtype, device=quat.device)
reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape) reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)
return torch.sum(reshaped_mat * quat[..., :, None, None] * vec[..., None, :, None], dim=(-3, -2)) return torch.sum(reshaped_mat * quat[..., :, None, None] * vec[..., None, :, None], dim=(-3, -2))
def invert_rot_mat(rot_mat: torch.Tensor): def invert_rot_mat(rot_mat: torch.Tensor) -> torch.Tensor:
return rot_mat.transpose(-1, -2) return rot_mat.transpose(-1, -2)
def invert_quat(quat: torch.Tensor): def invert_quat(quat: torch.Tensor) -> torch.Tensor:
quat_prime = quat.clone() quat_prime = quat.clone()
quat_prime[..., 1:] *= -1 quat_prime[..., 1:] *= -1
inv = quat_prime / torch.sum(quat**2, dim=-1, keepdim=True) inv = quat_prime / torch.sum(quat**2, dim=-1, keepdim=True)
@ -361,10 +355,7 @@ class Rotation:
else: else:
raise ValueError("Both rotations are None") raise ValueError("Both rotations are None")
def __mul__( def __mul__(self, right: torch.Tensor) -> Rotation:
self,
right: torch.Tensor,
) -> Rotation:
""" """
Pointwise left multiplication of the rotation with a tensor. Can be used to e.g. mask the Rotation. Pointwise left multiplication of the rotation with a tensor. Can be used to e.g. mask the Rotation.
@ -386,10 +377,7 @@ class Rotation:
else: else:
raise ValueError("Both rotations are None") raise ValueError("Both rotations are None")
def __rmul__( def __rmul__(self, left: torch.Tensor) -> Rotation:
self,
left: torch.Tensor,
) -> Rotation:
""" """
Reverse pointwise multiplication of the rotation with a tensor. Reverse pointwise multiplication of the rotation with a tensor.
@ -413,13 +401,12 @@ class Rotation:
Returns: Returns:
The virtual shape of the rotation object The virtual shape of the rotation object
""" """
s = None if self._rot_mats is not None:
if self._quats is not None: return self._rot_mats.shape[:-2]
s = self._quats.shape[:-1] elif self._quats is not None:
return self._quats.shape[:-1]
else: else:
s = self._rot_mats.shape[:-2] raise ValueError("Both rotations are None")
return s
@property @property
def dtype(self) -> torch.dtype: def dtype(self) -> torch.dtype:
@ -473,14 +460,12 @@ class Rotation:
Returns: Returns:
The rotation as a rotation matrix tensor The rotation as a rotation matrix tensor
""" """
rot_mats = self._rot_mats if self._rot_mats is not None:
if rot_mats is None: return self._rot_mats
if self._quats is None: elif self._quats is not None:
raise ValueError("Both rotations are None") return quat_to_rot(self._quats)
else: else:
rot_mats = quat_to_rot(self._quats) raise ValueError("Both rotations are None")
return rot_mats
def get_quats(self) -> torch.Tensor: def get_quats(self) -> torch.Tensor:
""" """
@ -491,14 +476,12 @@ class Rotation:
Returns: Returns:
The rotation as a quaternion tensor. The rotation as a quaternion tensor.
""" """
quats = self._quats if self._rot_mats is not None:
if quats is None: return rot_to_quat(self._rot_mats)
if self._rot_mats is None: elif self._quats is not None:
raise ValueError("Both rotations are None") return self._quats
else: else:
quats = rot_to_quat(self._rot_mats) raise ValueError("Both rotations are None")
return quats
def get_cur_rot(self) -> torch.Tensor: def get_cur_rot(self) -> torch.Tensor:
""" """
@ -618,10 +601,7 @@ class Rotation:
# "Tensor" stuff # "Tensor" stuff
def unsqueeze( def unsqueeze(self, dim: int) -> Rotation:
self,
dim: int,
) -> Rigid:
""" """
Analogous to torch.unsqueeze. The dimension is relative to the shape of the Rotation object. Analogous to torch.unsqueeze. The dimension is relative to the shape of the Rotation object.
@ -643,10 +623,7 @@ class Rotation:
raise ValueError("Both rotations are None") raise ValueError("Both rotations are None")
@staticmethod @staticmethod
def cat( def cat(rs: Sequence[Rotation], dim: int) -> Rotation:
rs: Sequence[Rotation],
dim: int,
) -> Rigid:
""" """
Concatenates rotations along one of the batch dimensions. Analogous to torch.cat(). Concatenates rotations along one of the batch dimensions. Analogous to torch.cat().
@ -661,12 +638,14 @@ class Rotation:
Returns: Returns:
A concatenated Rotation object in rotation matrix format A concatenated Rotation object in rotation matrix format
""" """
rot_mats = [r.get_rot_mats() for r in rs] rot_mats = torch.cat(
rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2) [r.get_rot_mats() for r in rs],
dim=dim if dim >= 0 else dim - 2,
)
return Rotation(rot_mats=rot_mats, quats=None) return Rotation(rot_mats=rot_mats, quats=None)
def map_tensor_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rotation: def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rotation:
""" """
Apply a Tensor -> Tensor function to underlying rotation tensors, mapping over the rotation dimension(s). Can Apply a Tensor -> Tensor function to underlying rotation tensors, mapping over the rotation dimension(s). Can
be used e.g. to sum out a one-hot batch dimension. be used e.g. to sum out a one-hot batch dimension.
@ -754,11 +733,7 @@ class Rigid:
dimensions of its component parts. dimensions of its component parts.
""" """
def __init__( def __init__(self, rots: Optional[Rotation], trans: Optional[torch.Tensor]):
self,
rots: Optional[Rotation],
trans: Optional[torch.Tensor],
):
""" """
Args: Args:
rots: A [*, 3, 3] rotation tensor rots: A [*, 3, 3] rotation tensor
@ -795,6 +770,9 @@ class Rigid:
requires_grad, requires_grad,
) )
assert rots is not None
assert trans is not None
if (rots.shape != trans.shape[:-1]) or (rots.device != trans.device): if (rots.shape != trans.shape[:-1]) or (rots.device != trans.device):
raise ValueError("Rots and trans incompatible") raise ValueError("Rots and trans incompatible")
@ -806,7 +784,7 @@ class Rigid:
@staticmethod @staticmethod
def identity( def identity(
shape: Tuple[int], shape: Tuple[int, ...],
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
requires_grad: bool = True, requires_grad: bool = True,
@ -832,10 +810,7 @@ class Rigid:
identity_trans(shape, dtype, device, requires_grad), identity_trans(shape, dtype, device, requires_grad),
) )
def __getitem__( def __getitem__(self, index: Any) -> Rigid:
self,
index: Any,
) -> Rigid:
""" """
Indexes the affine transformation with PyTorch-style indices. The index is applied to the shared dimensions of Indexes the affine transformation with PyTorch-style indices. The index is applied to the shared dimensions of
both the rotation and the translation. both the rotation and the translation.
@ -860,10 +835,7 @@ class Rigid:
self._trans[index + (slice(None),)], self._trans[index + (slice(None),)],
) )
def __mul__( def __mul__(self, right: torch.Tensor) -> Rigid:
self,
right: torch.Tensor,
) -> Rigid:
""" """
Pointwise left multiplication of the transformation with a tensor. Can be used to e.g. mask the Rigid. Pointwise left multiplication of the transformation with a tensor. Can be used to e.g. mask the Rigid.
@ -881,10 +853,7 @@ class Rigid:
return Rigid(new_rots, new_trans) return Rigid(new_rots, new_trans)
def __rmul__( def __rmul__(self, left: torch.Tensor) -> Rigid:
self,
left: torch.Tensor,
) -> Rigid:
""" """
Reverse pointwise multiplication of the transformation with a tensor. Reverse pointwise multiplication of the transformation with a tensor.
@ -904,8 +873,7 @@ class Rigid:
Returns: Returns:
The shape of the transformation The shape of the transformation
""" """
s = self._trans.shape[:-1] return self._trans.shape[:-1]
return s
@property @property
def device(self) -> torch.device: def device(self) -> torch.device:
@ -935,10 +903,7 @@ class Rigid:
""" """
return self._trans return self._trans
def compose_q_update_vec( def compose_q_update_vec(self, q_update_vec: torch.Tensor) -> Rigid:
self,
q_update_vec: torch.Tensor,
) -> Rigid:
""" """
Composes the transformation with a quaternion update vector of shape [*, 6], where the final 6 columns Composes the transformation with a quaternion update vector of shape [*, 6], where the final 6 columns
represent the x, y, and z values of a quaternion of form (1, x, y, z) followed by a 3D translation. represent the x, y, and z values of a quaternion of form (1, x, y, z) followed by a 3D translation.
@ -956,10 +921,7 @@ class Rigid:
return Rigid(new_rots, new_translation) return Rigid(new_rots, new_translation)
def compose( def compose(self, r: Rigid) -> Rigid:
self,
r: Rigid,
) -> Rigid:
""" """
Composes the current rigid object with another. Composes the current rigid object with another.
@ -973,10 +935,7 @@ class Rigid:
new_trans = self._rots.apply(r._trans) + self._trans new_trans = self._rots.apply(r._trans) + self._trans
return Rigid(new_rot, new_trans) return Rigid(new_rot, new_trans)
def apply( def apply(self, pts: torch.Tensor) -> torch.Tensor:
self,
pts: torch.Tensor,
) -> torch.Tensor:
""" """
Applies the transformation to a coordinate tensor. Applies the transformation to a coordinate tensor.
@ -1012,7 +971,7 @@ class Rigid:
return Rigid(rot_inv, -1 * trn_inv) return Rigid(rot_inv, -1 * trn_inv)
def map_tensor_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rigid: def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rigid:
""" """
Apply a Tensor -> Tensor function to underlying translation and rotation tensors, mapping over the Apply a Tensor -> Tensor function to underlying translation and rotation tensors, mapping over the
translation/rotation dimensions respectively. translation/rotation dimensions respectively.
@ -1074,10 +1033,7 @@ class Rigid:
return tensor return tensor
@staticmethod @staticmethod
def from_tensor_7( def from_tensor_7(t: torch.Tensor, normalize_quats: bool = False) -> Rigid:
t: torch.Tensor,
normalize_quats: bool = False,
) -> Rigid:
if t.shape[-1] != 7: if t.shape[-1] != 7:
raise ValueError("Incorrectly shaped input tensor") raise ValueError("Incorrectly shaped input tensor")
@ -1102,18 +1058,18 @@ class Rigid:
Returns: Returns:
A transformation object of shape [*] A transformation object of shape [*]
""" """
p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1) p_neg_x_axis_unbound = torch.unbind(p_neg_x_axis, dim=-1)
origin = torch.unbind(origin, dim=-1) origin_unbound = torch.unbind(origin, dim=-1)
p_xy_plane = torch.unbind(p_xy_plane, dim=-1) p_xy_plane_unbound = torch.unbind(p_xy_plane, dim=-1)
e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)] e0 = [c1 - c2 for c1, c2 in zip(origin_unbound, p_neg_x_axis_unbound)]
e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)] e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane_unbound, origin_unbound)]
denom = torch.sqrt(sum((c * c for c in e0)) + eps) denom = torch.sqrt(sum(c * c for c in e0) + eps * torch.ones_like(e0[0]))
e0 = [c / denom for c in e0] e0 = [c / denom for c in e0]
dot = sum((c1 * c2 for c1, c2 in zip(e0, e1))) dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)] e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
denom = torch.sqrt(sum((c * c for c in e1)) + eps) denom = torch.sqrt(sum((c * c for c in e1)) + eps * torch.ones_like(e1[0]))
e1 = [c / denom for c in e1] e1 = [c / denom for c in e1]
e2 = [ e2 = [
e0[1] * e1[2] - e0[2] * e1[1], e0[1] * e1[2] - e0[2] * e1[1],
@ -1126,12 +1082,9 @@ class Rigid:
rot_obj = Rotation(rot_mats=rots, quats=None) rot_obj = Rotation(rot_mats=rots, quats=None)
return Rigid(rot_obj, torch.stack(origin, dim=-1)) return Rigid(rot_obj, torch.stack(origin_unbound, dim=-1))
def unsqueeze( def unsqueeze(self, dim: int) -> Rigid:
self,
dim: int,
) -> Rigid:
""" """
Analogous to torch.unsqueeze. The dimension is relative to the shared dimensions of the rotation/translation. Analogous to torch.unsqueeze. The dimension is relative to the shared dimensions of the rotation/translation.
@ -1148,10 +1101,7 @@ class Rigid:
return Rigid(rots, trans) return Rigid(rots, trans)
@staticmethod @staticmethod
def cat( def cat(ts: Sequence[Rigid], dim: int) -> Rigid:
ts: Sequence[Rigid],
dim: int,
) -> Rigid:
""" """
Concatenates transformations along a new dimension. Concatenates transformations along a new dimension.
@ -1168,7 +1118,7 @@ class Rigid:
return Rigid(rots, trans) return Rigid(rots, trans)
def apply_rot_fn(self, fn: Callable[Rotation, Rotation]) -> Rigid: def apply_rot_fn(self, fn: Callable[[Rotation], Rotation]) -> Rigid:
""" """
Applies a Rotation -> Rotation function to the stored rotation object. Applies a Rotation -> Rotation function to the stored rotation object.
@ -1179,7 +1129,7 @@ class Rigid:
""" """
return Rigid(fn(self._rots), self._trans) return Rigid(fn(self._rots), self._trans)
def apply_trans_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rigid: def apply_trans_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rigid:
""" """
Applies a Tensor -> Tensor function to the stored translation. Applies a Tensor -> Tensor function to the stored translation.
@ -1213,7 +1163,9 @@ class Rigid:
return self.apply_rot_fn(lambda r: r.detach()) return self.apply_rot_fn(lambda r: r.detach())
@staticmethod @staticmethod
def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20): def make_transform_from_reference(
n_xyz: torch.Tensor, ca_xyz: torch.Tensor, c_xyz: torch.Tensor, eps: float = 1e-20
) -> Rigid:
""" """
Returns a transformation object from reference coordinates. Returns a transformation object from reference coordinates.


@ -14,13 +14,14 @@
# limitations under the License. # limitations under the License.
from functools import partial from functools import partial
from typing import List from typing import Any, Callable, Dict, List, Type, TypeVar, Union, overload
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.types
def add(m1, m2, inplace): def add(m1: torch.Tensor, m2: torch.Tensor, inplace: bool) -> torch.Tensor:
# The first operation in a checkpoint can't be in-place, but it's # The first operation in a checkpoint can't be in-place, but it's
# nice to have in-place addition during inference. Thus... # nice to have in-place addition during inference. Thus...
if not inplace: if not inplace:
@ -31,33 +32,35 @@ def add(m1, m2, inplace):
return m1 return m1
def permute_final_dims(tensor: torch.Tensor, inds: List[int]): def permute_final_dims(tensor: torch.Tensor, inds: List[int]) -> torch.Tensor:
zero_index = -1 * len(inds) zero_index = -1 * len(inds)
first_inds = list(range(len(tensor.shape[:zero_index]))) first_inds = list(range(len(tensor.shape[:zero_index])))
return tensor.permute(first_inds + [zero_index + i for i in inds]) return tensor.permute(first_inds + [zero_index + i for i in inds])
def flatten_final_dims(t: torch.Tensor, no_dims: int): def flatten_final_dims(t: torch.Tensor, no_dims: int) -> torch.Tensor:
return t.reshape(t.shape[:-no_dims] + (-1,)) return t.reshape(t.shape[:-no_dims] + (-1,))
def masked_mean(mask, value, dim, eps=1e-4): def masked_mean(mask: torch.Tensor, value: torch.Tensor, dim: int, eps: float = 1e-4) -> torch.Tensor:
mask = mask.expand(*value.shape) mask = mask.expand(*value.shape)
return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64): def pts_to_distogram(
pts: torch.Tensor, min_bin: torch.types.Number = 2.3125, max_bin: torch.types.Number = 21.6875, no_bins: int = 64
) -> torch.Tensor:
boundaries = torch.linspace(min_bin, max_bin, no_bins - 1, device=pts.device) boundaries = torch.linspace(min_bin, max_bin, no_bins - 1, device=pts.device)
dists = torch.sqrt(torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1)) dists = torch.sqrt(torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1))
return torch.bucketize(dists, boundaries) return torch.bucketize(dists, boundaries)
def dict_multimap(fn, dicts): def dict_multimap(fn: Callable[[list], Any], dicts: List[dict]) -> dict:
first = dicts[0] first = dicts[0]
new_dict = {} new_dict = {}
for k, v in first.items(): for k, v in first.items():
all_v = [d[k] for d in dicts] all_v = [d[k] for d in dicts]
if type(v) is dict: if isinstance(v, dict):
new_dict[k] = dict_multimap(fn, all_v) new_dict[k] = dict_multimap(fn, all_v)
else: else:
new_dict[k] = fn(all_v) new_dict[k] = fn(all_v)
@ -65,21 +68,21 @@ def dict_multimap(fn, dicts):
return new_dict return new_dict
def one_hot(x, v_bins): def one_hot(x: torch.Tensor, v_bins: torch.Tensor) -> torch.Tensor:
reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),)) reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
diffs = x[..., None] - reshaped_bins diffs = x[..., None] - reshaped_bins
am = torch.argmin(torch.abs(diffs), dim=-1) am = torch.argmin(torch.abs(diffs), dim=-1)
return nn.functional.one_hot(am, num_classes=len(v_bins)).float() return nn.functional.one_hot(am, num_classes=len(v_bins)).float()
def batched_gather(data, inds, dim=0, no_batch_dims=0): def batched_gather(data: torch.Tensor, inds: torch.Tensor, dim: int = 0, no_batch_dims: int = 0) -> torch.Tensor:
ranges = [] ranges: List[Union[slice, torch.Tensor]] = []
for i, s in enumerate(data.shape[:no_batch_dims]): for i, s in enumerate(data.shape[:no_batch_dims]):
r = torch.arange(s) r = torch.arange(s)
r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
ranges.append(r) ranges.append(r)
remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)] remaining_dims: List[Union[slice, torch.Tensor]] = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
ranges.extend(remaining_dims) ranges.extend(remaining_dims)
# Matt note: Editing this to get around the behaviour of using a list as an array index changing # Matt note: Editing this to get around the behaviour of using a list as an array index changing
@ -87,11 +90,16 @@ def batched_gather(data, inds, dim=0, no_batch_dims=0):
return data[tuple(ranges)] return data[tuple(ranges)]
T = TypeVar("T")
# With tree_map, a poor man's JAX tree_map # With tree_map, a poor man's JAX tree_map
def dict_map(fn, dic, leaf_type): def dict_map(
new_dict = {} fn: Callable[[T], Any], dic: Dict[Any, Union[dict, list, tuple, T]], leaf_type: Type[T]
) -> Dict[Any, Union[dict, list, tuple, Any]]:
new_dict: Dict[Any, Union[dict, list, tuple, Any]] = {}
for k, v in dic.items(): for k, v in dic.items():
if type(v) is dict: if isinstance(v, dict):
new_dict[k] = dict_map(fn, v, leaf_type) new_dict[k] = dict_map(fn, v, leaf_type)
else: else:
new_dict[k] = tree_map(fn, v, leaf_type) new_dict[k] = tree_map(fn, v, leaf_type)
@ -99,13 +107,33 @@ def dict_map(fn, dic, leaf_type):
return new_dict return new_dict
@overload
def tree_map(fn: Callable[[T], Any], tree: T, leaf_type: Type[T]) -> Any:
...
@overload
def tree_map(fn: Callable[[T], Any], tree: dict, leaf_type: Type[T]) -> dict:
...
@overload
def tree_map(fn: Callable[[T], Any], tree: list, leaf_type: Type[T]) -> list:
...
@overload
def tree_map(fn: Callable[[T], Any], tree: tuple, leaf_type: Type[T]) -> tuple:
...
def tree_map(fn, tree, leaf_type): def tree_map(fn, tree, leaf_type):
if isinstance(tree, dict): if isinstance(tree, dict):
return dict_map(fn, tree, leaf_type) return dict_map(fn, tree, leaf_type)
elif isinstance(tree, list): elif isinstance(tree, list):
return [tree_map(fn, x, leaf_type) for x in tree] return [tree_map(fn, x, leaf_type) for x in tree]
elif isinstance(tree, tuple): elif isinstance(tree, tuple):
return tuple([tree_map(fn, x, leaf_type) for x in tree]) return tuple(tree_map(fn, x, leaf_type) for x in tree)
elif isinstance(tree, leaf_type): elif isinstance(tree, leaf_type):
return fn(tree) return fn(tree)
else: else: