Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-30 01:32:23 +06:00)
Make canine model exportable by removing unnecessarily complicated logic (#37124)
This commit is contained in:
parent 60b75d99b6
commit c0bd8048a5
@@ -1056,7 +1056,7 @@ class CanineModel(CaninePreTrainedModel):

         return molecule_attention_mask

-    def _repeat_molecules(self, molecules: torch.Tensor, char_seq_length: torch.Tensor) -> torch.Tensor:
+    def _repeat_molecules(self, molecules: torch.Tensor, char_seq_length: int) -> torch.Tensor:
         """Repeats molecules to make them the same length as the char sequence."""

         rate = self.config.downsampling_rate
@@ -1070,7 +1070,7 @@ class CanineModel(CaninePreTrainedModel):
         # n elements (n < `downsampling_rate`), i.e. the remainder of floor
         # division. We do this by repeating the last molecule a few extra times.
         last_molecule = molecules[:, -1:, :]
-        remainder_length = torch.fmod(torch.tensor(char_seq_length), torch.tensor(rate)).item()
+        remainder_length = char_seq_length % rate
         remainder_repeated = torch.repeat_interleave(
             last_molecule,
             # +1 molecule to compensate for truncation.
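
For reference, the old line routed an integer through two scratch tensors and .item() just to take a remainder, which shows up as a data-dependent scalar read during tracing; the new line computes the same value with plain Python arithmetic. The sketch below is a rough illustration of the effect; the checkpoint name, the inputs, and the use of torch.export are assumptions for demonstration, not part of this commit.

# Illustrative sketch only: the checkpoint name and the torch.export call are
# assumptions, not taken from this commit.
import torch
from transformers import CanineModel, CanineTokenizer

# Both remainder computations agree for plain Python ints; the old form just
# detours through tensors and a .item() read, which tracing has to specialize.
char_seq_length, rate = 2048, 4
old = torch.fmod(torch.tensor(char_seq_length), torch.tensor(rate)).item()
new = char_seq_length % rate
assert old == new

# The kind of export the simplification is aimed at; whether the full graph
# exports cleanly also depends on the rest of the model.
tokenizer = CanineTokenizer.from_pretrained("google/canine-s")
model = CanineModel.from_pretrained("google/canine-s").eval()
inputs = tokenizer("hello world", return_tensors="pt")

exported = torch.export.export(
    model,
    args=(inputs["input_ids"],),
    kwargs={"attention_mask": inputs["attention_mask"]},
)
print(exported.graph_module)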