Substantially reduce memory usage in _update_causal_mask for large batches by using .expand instead of .repeat [needs tests+sanity check] (#29413)

* try to fix gemma mem use

* fix: handle attention mask dim==2 case

* remove logits=logits.float()

* clean up + add llama

* apply formatting

* readability edit: swap order of items being multiplied

* revert change unrelated to PR

* revert black autoformat

* switch to one .to

* Accept style edits

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Glen Taggart, 2024-03-06 15:56:25 -08:00, committed by GitHub
commit 2a939f20ff (parent 965cf67769)
2 changed files with 9 additions and 6 deletions
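Why this saves memory: torch.Tensor.repeat materializes batch_size physical copies of the [1, 1, seq_len, seq_len] causal mask, whereas torch.Tensor.expand returns a stride-0 broadcasted view over the original storage. A minimal sketch of the difference (not taken from the diff below; the shapes and the triu template are illustrative):

import torch

batch_size, seq_len = 8, 2048
dtype = torch.float32
min_dtype = torch.finfo(dtype).min

# Illustrative causal template: min_dtype above the diagonal, 0 elsewhere,
# shaped [1, 1, seq_len, seq_len] like self.causal_mask[None, None, :, :].
base = torch.triu(torch.full((seq_len, seq_len), min_dtype, dtype=dtype), diagonal=1)
base = base[None, None, :, :]

repeated = base.repeat(batch_size, 1, 1, 1)    # allocates batch_size full copies
expanded = base.expand(batch_size, 1, -1, -1)  # stride-0 view, no new allocation

# Storage sizes (untyped_storage needs PyTorch >= 2.0): ~134 MB vs ~16.8 MB here.
print(repeated.untyped_storage().nbytes())
print(expanded.untyped_storage().nbytes())
print(expanded.data_ptr() == base.data_ptr())  # True: the view shares base's memory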

@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ PyTorch Gemma model."""
 import math
 import warnings
 from typing import List, Optional, Tuple, Union
@@ -971,10 +972,11 @@ class GemmaModel(GemmaPreTrainedModel):
         # We use the current dtype to avoid any overflows
         min_dtype = torch.finfo(dtype).min
-        causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
-        causal_mask = causal_mask.to(dtype=dtype, device=device)
+        causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
+        causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
         if attention_mask is not None and attention_mask.dim() == 2:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             mask_length = attention_mask.shape[-1]
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
             causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
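The one behavioral subtlety is the added causal_mask.clone() before the padding branch: an expanded tensor is a non-contiguous view whose batch rows alias the same memory, so the in-place slice assignment below it cannot be applied directly. A small standalone sketch of that failure mode and the fix (toy shapes, not from the PR):

import torch

base = torch.zeros(1, 1, 4, 4)
view = base.expand(2, 1, -1, -1)      # batch dimension has stride 0
print(view.is_contiguous())           # False
pad = torch.ones(2, 1, 4, 2, dtype=torch.bool)

try:
    # In-place write into the overlapping view: PyTorch normally raises here
    # ("more than one element ... refers to a single memory location").
    view[..., :2].masked_fill_(pad, -1.0)
except RuntimeError as err:
    print("rejected:", err)

safe = view.clone()                   # per-batch contiguous copy, as in the PR
safe[..., :2] = safe[..., :2].masked_fill(pad, -1.0)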

@@ -17,7 +17,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch LLaMA model."""
+"""PyTorch LLaMA model."""
 import math
 import warnings
 from typing import List, Optional, Tuple, Union
@@ -1083,10 +1084,10 @@ class LlamaModel(LlamaPreTrainedModel):
         # We use the current dtype to avoid any overflows
         min_dtype = torch.finfo(dtype).min
-        causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
-        causal_mask = causal_mask.to(dtype=dtype, device=device)
+        causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
+        causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
         if attention_mask is not None and attention_mask.dim() == 2:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             mask_length = attention_mask.shape[-1]
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
             causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
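The title flags "[needs tests+sanity check]"; one way to sanity-check the refactor is to compare the old repeat-based construction against the new expand-based one on random 2-D padding masks. A hedged, standalone sketch (the helper names and the boolean triu stand-in for self.causal_mask are invented for illustration, not transformers APIs):

import torch

def mask_with_repeat(template, attention_mask, batch_size, dtype, min_dtype):
    # Old construction: materialize batch_size copies, then edit in place.
    causal_mask = template[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
    if attention_mask is not None and attention_mask.dim() == 2:
        mask_length = attention_mask.shape[-1]
        padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
        causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
    return causal_mask

def mask_with_expand(template, attention_mask, batch_size, dtype, min_dtype):
    # New construction: broadcasted view, cloned only when the in-place edit is needed.
    causal_mask = template[None, None, :, :].to(dtype=dtype) * min_dtype
    causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
    if attention_mask is not None and attention_mask.dim() == 2:
        causal_mask = causal_mask.clone()
        mask_length = attention_mask.shape[-1]
        padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
        causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
    return causal_mask

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
template = torch.triu(torch.ones(16, 16, dtype=torch.bool), diagonal=1)  # stand-in for self.causal_mask
attention_mask = (torch.rand(4, 16) > 0.25).long()                       # random 2-D padding mask
assert torch.equal(
    mask_with_repeat(template, attention_mask, 4, dtype, min_dtype),
    mask_with_expand(template, attention_mask, 4, dtype, min_dtype),
)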