Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 11:11:05 +06:00)
Substantially reduce memory usage in _update_causal_mask for large batches by using .expand instead of .repeat [needs tests+sanity check] (#29413)
* try to fix gemma mem use
* fix: handle attention mask dim==2 case
* remove logits=logits.float()
* clean up + add llama
* apply formatting
* readability edit: swap order of items being multiplied
* revert change unrelated to PR
* revert black autoformat
* switch to one .to
* Accept style edits

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
commit 2a939f20ff (parent 965cf67769)
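For context, a minimal standalone sketch of the memory difference the title refers to. This is not part of the commit; the tensor sizes are arbitrary, "base" is a made-up stand-in for the model's causal_mask buffer, and the storage inspection uses the PyTorch 2.x untyped_storage API. .repeat materializes one copy of the mask per batch element, while .expand returns a broadcasted view over the same storage.

import torch

batch_size, seq_len = 32, 2048          # arbitrary sizes, for illustration only
dtype = torch.float16
min_dtype = torch.finfo(dtype).min

# (1, 1, seq_len, seq_len) causal template, playing the role of self.causal_mask[None, None, :, :]
base = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)[None, None, :, :].to(dtype) * min_dtype

repeated = base.repeat(batch_size, 1, 1, 1)    # allocates batch_size full copies of the mask
expanded = base.expand(batch_size, 1, -1, -1)  # a view: no new memory is allocated

print(repeated.untyped_storage().nbytes() // 2**20, "MiB")  # ~256 MiB at these sizes
print(expanded.untyped_storage().nbytes() // 2**20, "MiB")  # ~8 MiB, shared with base
print(expanded.stride()[0])                                 # 0: the batch dimension is broadcast

Because .expand only returns a view, any later in-place edit of the mask needs a contiguous copy, which is why the diff below adds a .clone() before the padding-mask write.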
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ PyTorch Gemma model."""
+
 import math
 import warnings
 from typing import List, Optional, Tuple, Union
@@ -971,10 +972,11 @@ class GemmaModel(GemmaPreTrainedModel):
 
         # We use the current dtype to avoid any overflows
         min_dtype = torch.finfo(dtype).min
-        causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
-
-        causal_mask = causal_mask.to(dtype=dtype, device=device)
+        causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
+        causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
         if attention_mask is not None and attention_mask.dim() == 2:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             mask_length = attention_mask.shape[-1]
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
             causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
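A short sketch of why the added line clones before the in-place masked_fill write. This is not from the commit; the sizes are toy values, and base and attention_mask are made-up stand-ins for the model's causal_mask buffer and a real 2D padding mask. PyTorch refuses in-place writes to an expanded view, because many logical elements share a single memory location.

import torch

batch_size, seq_len = 4, 8
dtype = torch.float16
min_dtype = torch.finfo(dtype).min

base = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).to(dtype) * min_dtype
causal_mask = base[None, None, :, :].expand(batch_size, 1, -1, -1)

attention_mask = torch.ones(batch_size, seq_len)
attention_mask[:, -2:] = 0  # pretend the last two positions of every row are padding

mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)

try:
    # writing into the broadcasted view is rejected
    causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
except RuntimeError as err:
    print("in-place write on the expanded view fails:", err)

causal_mask = causal_mask.clone()  # per-batch contiguous copy, as in the diff
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
print(causal_mask[0, 0])  # padded columns now hold min_dtype in every row

The copy is only made inside the dim==2 branch, so the common decoding path keeps the cheap broadcasted view.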
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -17,7 +17,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch LLaMA model."""
+"""PyTorch LLaMA model."""
+
 import math
 import warnings
 from typing import List, Optional, Tuple, Union
@@ -1083,10 +1084,10 @@ class LlamaModel(LlamaPreTrainedModel):
 
         # We use the current dtype to avoid any overflows
         min_dtype = torch.finfo(dtype).min
-        causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
-
-        causal_mask = causal_mask.to(dtype=dtype, device=device)
+        causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
+        causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
         if attention_mask is not None and attention_mask.dim() == 2:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             mask_length = attention_mask.shape[-1]
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
             causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
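The title flags "[needs tests+sanity check]", so here is a hypothetical equivalence check, not part of the commit: old_path and new_path are made-up helpers that mirror the removed (.repeat) and added (.expand + .clone) lines on toy inputs.

import torch

def old_path(causal_buf, attention_mask, batch_size, dtype, device):
    # mirrors the removed lines (.repeat)
    min_dtype = torch.finfo(dtype).min
    causal_mask = causal_buf[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
    causal_mask = causal_mask.to(dtype=dtype, device=device)
    if attention_mask is not None and attention_mask.dim() == 2:
        mask_length = attention_mask.shape[-1]
        padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
        causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
    return causal_mask

def new_path(causal_buf, attention_mask, batch_size, dtype, device):
    # mirrors the added lines (.expand + .clone)
    min_dtype = torch.finfo(dtype).min
    causal_mask = causal_buf[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
    causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
    if attention_mask is not None and attention_mask.dim() == 2:
        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
        mask_length = attention_mask.shape[-1]
        padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
        causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
    return causal_mask

batch_size, seq_len = 3, 6
causal_buf = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)  # toy stand-in for self.causal_mask
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]] * batch_size)   # two padded positions per row
torch.testing.assert_close(
    old_path(causal_buf, attention_mask, batch_size, torch.float32, "cpu"),
    new_path(causal_buf, attention_mask, batch_size, torch.float32, "cpu"),
)
print("the .repeat and .expand paths build identical masks")

The resulting masks match; the difference is purely memory, since the new path only allocates a per-batch copy when a 2D attention mask forces the in-place edit.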