transformers/src/transformers/modeling_utils.py
Yih-Dar 563a8d58db
Delete state_dict to release memory as early as possible (#18832)
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2022-09-01 10:55:30 +02:00

3042 lines
141 KiB
Python

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import json
import os
import re
import shutil
import tempfile
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from packaging import version
from torch import Tensor, device, nn
from torch.nn import CrossEntropyLoss
from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from transformers.utils.import_utils import is_sagemaker_mp_enabled
from .activations import get_activation
from .configuration_utils import PretrainedConfig
from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
from .dynamic_module_utils import custom_object_save
from .generation_utils import GenerationMixin
from .pytorch_utils import ( # noqa: F401
Conv1D,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_conv1d_layer,
prune_layer,
prune_linear_layer,
)
from .utils import (
DUMMY_INPUTS,
FLAX_WEIGHTS_NAME,
TF2_WEIGHTS_NAME,
TF_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
ContextManagers,
ModelOutput,
PushToHubMixin,
cached_file,
copy_func,
has_file,
is_accelerate_available,
is_bitsandbytes_available,
is_offline_mode,
logging,
replace_return_docstrings,
)
from .utils.versions import require_version_core
if is_accelerate_available():
from accelerate import __version__ as accelerate_version
from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
from accelerate.utils import (
load_offloaded_weights,
offload_weight,
save_offload_index,
set_module_tensor_to_device,
)
if version.parse(accelerate_version) > version.parse("0.11.0"):
from accelerate.utils import get_balanced_memory
else:
get_balanced_memory = None
if is_bitsandbytes_available():
from .utils.bitsandbytes import get_key_to_not_convert, replace_8bit_linear, set_module_8bit_tensor_to_device
logger = logging.get_logger(__name__)
_init_weights = True
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
from smdistributed.modelparallel import __version__ as SMP_VERSION
IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
else:
IS_SAGEMAKER_MP_POST_1_10 = False
@contextmanager
def no_init_weights(_enable=True):
    """
    Context manager that globally switches off weight initialization, which speeds up the loading of large models.

    While the `with` body runs, the module-level `_init_weights` flag is `False` (only when `_enable` is truthy). The
    value the flag had on entry is restored on exit, including when the body raises.

    TODO(Patrick): Delete safety argument `_enable=True` at next major version.
    """
    global _init_weights
    previous_flag = _init_weights
    # When `_enable` is falsy the flag is simply left at the value it already had.
    _init_weights = False if _enable else previous_flag
    try:
        yield
    finally:
        _init_weights = previous_flag
try:
    from torch.nn import Identity
except ImportError:
    # Fallback for PyTorch builds that do not provide `torch.nn.Identity`.
    class Identity(nn.Module):
        r"""Pass-through module: any constructor arguments are accepted and ignored, `forward` returns its input."""

        def __init__(self, *args, **kwargs):
            # Constructor arguments are deliberately discarded.
            super().__init__()

        def forward(self, input):
            return input
def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
    """
    Returns the device of the first parameter of `parameter`.

    When the module yields no parameters, the device of the first tensor found among the modules' instance attributes
    is returned instead (nn.DataParallel compatibility in PyTorch 1.5).
    """
    for first_parameter in parameter.parameters():
        return first_parameter.device

    # No parameter was yielded: look at plain tensor attributes stored in each module's `__dict__`.
    def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
        return [(name, value) for name, value in module.__dict__.items() if torch.is_tensor(value)]

    tensor_members = parameter._named_members(get_members_fn=find_tensor_attributes)
    return next(tensor_members)[1].device
def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
    """
    Returns the dtype of the first parameter (which may be a non-floating dtype).

    When the module yields no parameters, the dtype of the first tensor found among the modules' instance attributes
    is returned instead; if there is none either, the underlying `next` call fails.
    """
    for first_parameter in parameter.parameters():
        return first_parameter.dtype

    # For nn.DataParallel compatibility in PyTorch > 1.5: fall back to tensors stored in each module's `__dict__`.
    def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
        return [(name, value) for name, value in module.__dict__.items() if torch.is_tensor(value)]

    tensor_members = parameter._named_members(get_members_fn=find_tensor_attributes)
    return next(tensor_members)[1].dtype
def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
    """
    Returns the dtype of the first floating-point parameter if there is one, otherwise the dtype of the last
    parameter that was inspected.

    When the module yields no parameters at all, the same search is run over the tensors stored as plain instance
    attributes of the modules.
    """
    fallback_dtype = None
    for param in parameter.parameters():
        if param.is_floating_point():
            return param.dtype
        fallback_dtype = param.dtype

    if fallback_dtype is not None:
        # No floating dtype among the parameters: report the last dtype that was seen.
        return fallback_dtype

    # For nn.DataParallel compatibility in PyTorch > 1.5
    def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
        return [(name, value) for name, value in module.__dict__.items() if torch.is_tensor(value)]

    last_member = None
    for member in parameter._named_members(get_members_fn=find_tensor_attributes):
        last_member = member
        if member[1].is_floating_point():
            return member[1].dtype

    # Fall back to the dtype of the last tensor attribute that was inspected.
    return last_member[1].dtype
def get_state_dict_float_dtype(state_dict):
    """
    Returns the dtype of the first floating-point tensor in `state_dict`.

    Raises:
        `ValueError`: if none of the values of `state_dict` is a floating-point tensor.
    """
    floating_dtypes = (tensor.dtype for tensor in state_dict.values() if tensor.is_floating_point())
    first_floating = next(floating_dtypes, None)
    if first_floating is None:
        raise ValueError("couldn't find any floating point dtypes in state_dict")
    return first_floating
def get_state_dict_dtype(state_dict):
    """
    Returns the first found floating dtype in `state_dict` if there is one, otherwise returns the first dtype.

    Args:
        state_dict (`Dict[str, torch.Tensor]`): The state dictionary to inspect.

    Returns:
        `torch.dtype`: The dtype of the first floating-point tensor, or the dtype of the first tensor when no value is
        floating point.
    """
    for t in state_dict.values():
        if t.is_floating_point():
            return t.dtype

    # No floating dtype was found: return the dtype of the first entry. `dict.values()` is a view, not an iterator,
    # so it has to go through `iter()` before `next()` can be applied to it.
    return next(iter(state_dict.values())).dtype
def dtype_byte_size(dtype):
    """
    Returns the size (in bytes) occupied by one parameter of type `dtype`.

    The size is derived from the trailing number of `str(dtype)` (e.g. `32` in `torch.float32`); `torch.bool` is
    special-cased and reported as `1 / 8`.

    Example:

    ```py
    >>> dtype_byte_size(torch.float32)
    4
    ```

    Raises:
        `ValueError`: if `str(dtype)` does not end with a number of bits.
    """
    if dtype == torch.bool:
        return 1 / 8

    trailing_bits = re.search(r"[^\d](\d+)$", str(dtype))
    if trailing_bits is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    return int(trailing_bits.group(1)) // 8
def shard_checkpoint(state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB"):
    """
    Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
    given size.

    The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no
    optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the
    limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB],
    [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].

    <Tip warning={true}>

    If one of the model's weight is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will
    have a size greater than `max_shard_size`.

    </Tip>

    Args:
        state_dict (`Dict[str, torch.Tensor]`): The state dictionary of a model to save.
        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
            The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
            (like `"5MB"`).

    Returns:
        `tuple`: `(shards, index)`. `shards` maps each checkpoint file name to the sub-state-dict stored in it. `index`
        is `None` when everything fits in a single shard, otherwise a dict with a `"metadata"` entry (holding
        `"total_size"`) and a `"weight_map"` entry mapping each weight name to its shard file.
    """
    max_shard_size = convert_file_size_to_int(max_shard_size)

    sharded_state_dicts = []
    current_block = {}
    current_block_size = 0
    total_size = 0

    for key, weight in state_dict.items():
        weight_size = weight.numel() * dtype_byte_size(weight.dtype)

        # If this weight is going to tip up over the maximal size, we split, but only when the current block already
        # holds a weight: otherwise an oversized weight would leave an empty shard behind it.
        if current_block and current_block_size + weight_size > max_shard_size:
            sharded_state_dicts.append(current_block)
            current_block = {}
            current_block_size = 0

        current_block[key] = weight
        current_block_size += weight_size
        total_size += weight_size

    # Add the last block
    sharded_state_dicts.append(current_block)

    # If we only have one shard, we return it
    if len(sharded_state_dicts) == 1:
        return {WEIGHTS_NAME: sharded_state_dicts[0]}, None

    # Otherwise, let's build the index
    weight_map = {}
    shards = {}
    for idx, shard in enumerate(sharded_state_dicts):
        shard_file = WEIGHTS_NAME.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
        shards[shard_file] = shard
        for key in shard.keys():
            weight_map[key] = shard_file

    # Add the metadata
    metadata = {"total_size": total_size}
    index = {"metadata": metadata, "weight_map": weight_map}
    return shards, index
def load_sharded_checkpoint(model, folder, strict=True):
    """
    This is the same as
    [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
    but for a sharded checkpoint.

    This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
    loaded in the model.

    Args:
        model (`torch.nn.Module`): The model in which to load the checkpoint.
        folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
        strict (`bool`, *optional*, defaults to `True`):
            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.

    Returns:
        `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields
            - `missing_keys` is a list of str containing the missing keys
            - `unexpected_keys` is a list of str containing the unexpected keys

    Raises:
        `ValueError`: if `folder` does not contain a checkpoint index file.
        `RuntimeError`: if `strict=True` and the keys of the index do not match the keys of the model.
    """
    # Load the index
    index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
    if not os.path.isfile(index_file):
        raise ValueError(f"Can't find a checkpoint index ({WEIGHTS_INDEX_NAME}) in {folder}.")

    with open(index_file, "r", encoding="utf-8") as f:
        index = json.load(f)

    shard_files = list(set(index["weight_map"].values()))

    # If strict=True, error before loading any of the state dicts.
    loaded_keys = index["weight_map"].keys()
    model_keys = model.state_dict().keys()
    missing_keys = [key for key in model_keys if key not in loaded_keys]
    unexpected_keys = [key for key in loaded_keys if key not in model_keys]
    if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
        error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
        if len(missing_keys) > 0:
            str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
            error_message += f"\nMissing key(s): {str_missing_keys}."
        if len(unexpected_keys) > 0:
            str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
            # These keys are in the checkpoint but not in the model, so they must not be reported as missing.
            error_message += f"\nUnexpected key(s): {str_unexpected_keys}."
        raise RuntimeError(error_message)

    for shard_file in shard_files:
        # Deserialize on CPU so that a shard saved from an accelerator is staged in RAM (as documented above) and can
        # be read on a machine without that accelerator; `load_state_dict` copies into the model's own tensors.
        state_dict = torch.load(os.path.join(folder, shard_file), map_location="cpu")
        # `strict=False` because each shard only holds a subset of the keys; strictness was enforced above.
        model.load_state_dict(state_dict, strict=False)

        # Make sure memory is freed before we load the next state dict.
        del state_dict
        gc.collect()

    # Return the same thing as PyTorch load_state_dict function.
    return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
    """
    Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.

    Args:
        checkpoint_file (`str` or `os.PathLike`): Path to the checkpoint to load. It is always mapped to the CPU.

    Returns:
        Whatever `torch.load` returns for `checkpoint_file`.

    Raises:
        `OSError`: if the file cannot be loaded. A dedicated message is used when the file looks like a git-lfs
        pointer (a text file starting with `version`) rather than actual weights.
    """
    try:
        return torch.load(checkpoint_file, map_location="cpu")
    except Exception as e:
        try:
            with open(checkpoint_file) as f:
                # Only the first characters are needed to recognize a git-lfs pointer file, so do not read (and
                # decode) a potentially huge checkpoint in full.
                if f.read(7) == "version":
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please install "
                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                        "you cloned."
                    )
                else:
                    raise ValueError(
                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
                        "model. Make sure you have saved the model properly."
                    ) from e
        except (UnicodeDecodeError, ValueError):
            # Chain to the original `torch.load` failure so the root cause stays visible in the traceback.
            raise OSError(
                f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
                f"at '{checkpoint_file}'. "
                "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
            ) from e
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
    """
    Loads `state_dict` into `model_to_load`, module by module, and returns the list of error messages collected by
    `_load_from_state_dict`.

    Keys containing `gamma`/`beta` (old format) are first renamed to `weight`/`bias`; this renaming is applied to the
    `state_dict` object that was passed in. `start_prefix` is the prefix used for the keys of the top-level module.
    """
    # Collect the renames first so the dict is not modified while it is being iterated.
    renames = []
    for key in state_dict.keys():
        renamed = None
        if "gamma" in key:
            renamed = key.replace("gamma", "weight")
        if "beta" in key:
            renamed = key.replace("beta", "bias")
        if renamed:
            renames.append((key, renamed))
    for legacy_key, renamed in renames:
        state_dict[renamed] = state_dict.pop(legacy_key)

    # Work on a copy so `_load_from_state_dict` can modify it; carry the `_metadata` attribute over if there is one.
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def load(module: nn.Module, state_dict, prefix=""):
        local_metadata = metadata.get(prefix[:-1], {}) if metadata is not None else {}
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
        if not is_deepspeed_zero3_enabled():
            module._load_from_state_dict(*args)
        else:
            import deepspeed

            # because zero3 puts placeholders in model params, this context
            # manager gathers (unpartitions) the params of the current layer, then loads from
            # the state dict and then re-partitions them again
            own_parameters = list(module.parameters(recurse=False))
            with deepspeed.zero.GatheredParameters(own_parameters, modifier_rank=0):
                if torch.distributed.get_rank() == 0:
                    module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is None:
                continue
            load(child, state_dict, prefix + name + ".")

    load(model_to_load, state_dict, prefix=start_prefix)
    # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
    # it's safe to delete it.
    del state_dict

    return error_msgs
def find_submodule_and_param_name(model, long_key, start_prefix):
    """
    Helper that walks the dotted `long_key` down from `model` and returns `(submodule, name)`: the last sub-module
    that was reached and the remaining param/buffer name.

    If `start_prefix` is non-empty and `long_key` starts with it, the first dotted component of the key is dropped
    before walking. The returned sub-module is `None` when an attribute along the path does not exist (the returned
    name is then the component that could not be resolved) or when the walk never left `model`.
    """
    if len(start_prefix) > 0 and long_key.startswith(start_prefix):
        # Drop everything up to and including the first dot.
        long_key = long_key.partition(".")[2]

    components = long_key.split(".")
    last_index = len(components) - 1
    submodule = model
    position = 0
    while position < last_index:
        attribute = components[position]
        if not hasattr(submodule, attribute):
            submodule = None
            break
        submodule = getattr(submodule, attribute)
        position += 1

    if submodule == model:
        submodule = None
    return submodule, components[position]
def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
    """
    Moves the params/buffers of `model` named by `loaded_state_dict_keys` to the meta device, which frees up the
    memory taken by them.

    `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
    `bert.pooler.dense.weight`
    """
    # meta device was added in pt=1.9
    require_version_core("torch>=1.9")

    # Only the entries that are about to be replaced from the state dict are dematerialized.
    for key in loaded_state_dict_keys:
        submodule, param_name = find_submodule_and_param_name(model, key, start_prefix)
        if submodule is None:
            continue

        # There is no in-place `to_` for tensors, so the attribute is rebuilt on the meta device and reassigned.
        current_value = getattr(submodule, param_name)
        # isinstance returns False for Params on meta device, so the check is done before switching
        was_parameter = isinstance(current_value, torch.nn.Parameter)
        meta_value = current_value.to("meta")
        if was_parameter:
            meta_value = torch.nn.Parameter(meta_value)
        setattr(submodule, param_name, meta_value)
def _load_state_dict_into_meta_model(
    model,
    state_dict,
    loaded_state_dict_keys,  # left for now but could be removed, see below
    start_prefix,
    expected_keys,
    device_map=None,
    offload_folder=None,
    offload_index=None,
    state_dict_folder=None,
    state_dict_index=None,
    dtype=None,
    load_in_8bit=False,
):
    """
    Counterpart of `_load_state_dict_into_model` for a model that has some or all of its params on a `meta` device.
    Each entry of `state_dict` that is in both `loaded_state_dict_keys` and `expected_keys` is either set on the model
    (on the device selected through `device_map`, `"cpu"` when no map is given) or offloaded to a folder.

    `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
    `bert.pooler.dense.weight`

    Returns a tuple `(error_msgs, offload_index, state_dict_index)`; `error_msgs` is always empty for now.
    """
    # XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
    # - deepspeed zero 3 support
    # - need to copy metadata if any - see _load_state_dict_into_model
    # - handling error_msgs - mimicking the error handling in module._load_from_state_dict()
    # - Is there a situation where some keys aren't in `loaded_state_dict_keys` and in which case
    #   they won't get loaded.

    error_msgs = []

    # Rename old-format keys (`gamma`/`beta`); collected first so the dict is not modified while iterated.
    renames = []
    for key in state_dict.keys():
        renamed = None
        if "gamma" in key:
            renamed = key.replace("gamma", "weight")
        if "beta" in key:
            renamed = key.replace("beta", "bias")
        if renamed:
            renames.append((key, renamed))
    for legacy_key, renamed in renames:
        state_dict[renamed] = state_dict.pop(legacy_key)

    for param_name, param in state_dict.items():
        # First part of the test is always true as load_state_dict_keys always contains state_dict keys.
        if not (param_name in loaded_state_dict_keys and param_name in expected_keys):
            continue

        if param_name.startswith(start_prefix):
            param_name = param_name[len(start_prefix) :]

        # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
        # in int/uint/bool and not cast them.
        if dtype is not None and torch.is_floating_point(param):
            param = param.to(dtype)

        if device_map is None:
            param_device = "cpu"
        else:
            # find next higher level module that is defined in device_map:
            # bert.lm_head.weight -> bert.lm_head -> bert -> ''
            module_name = param_name
            while len(module_name) > 0 and module_name not in device_map:
                module_name = module_name.rpartition(".")[0]
            if module_name == "" and "" not in device_map:
                # TODO: group all errors and raise at the end.
                raise ValueError(f"{param_name} doesn't have any device set.")
            param_device = device_map[module_name]

        if param_device == "disk":
            offload_index = offload_weight(param, param_name, offload_folder, offload_index)
        elif param_device == "cpu" and state_dict_index is not None:
            state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
        elif load_in_8bit:
            set_module_8bit_tensor_to_device(model, param_name, param_device, value=param)
        else:
            set_module_tensor_to_device(model, param_name, param_device, value=param)

    return error_msgs, offload_index, state_dict_index
class ModuleUtilsMixin:
    """
    A few utilities for `torch.nn.Modules`, to be used as a mixin.
    """

    @staticmethod
    def _hook_rss_memory_pre_forward(module, *args, **kwargs):
        """Forward pre-hook: stores the resident set size of the current process in `module.mem_rss_pre_forward`."""
        try:
            import psutil
        except ImportError:
            raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")

        process = psutil.Process(os.getpid())
        mem = process.memory_info()
        module.mem_rss_pre_forward = mem.rss
        return None

    @staticmethod
    def _hook_rss_memory_post_forward(module, *args, **kwargs):
        """
        Forward hook: stores the resident set size of the current process in `module.mem_rss_post_forward` and adds
        the increase since the pre-forward hook to `module.mem_rss_diff`.
        """
        try:
            import psutil
        except ImportError:
            raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")

        process = psutil.Process(os.getpid())
        mem = process.memory_info()
        module.mem_rss_post_forward = mem.rss
        mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
        module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0)
        return None

    def add_memory_hooks(self):
        """
        Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.

        Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero
        with `model.reset_memory_hooks_state()`.
        """
        for module in self.modules():
            module.register_forward_pre_hook(self._hook_rss_memory_pre_forward)
            module.register_forward_hook(self._hook_rss_memory_post_forward)
        self.reset_memory_hooks_state()

    def reset_memory_hooks_state(self):
        """
        Reset the `mem_rss_diff` attribute of each module (see [`~modeling_utils.ModuleUtilsMixin.add_memory_hooks`]).
        """
        for module in self.modules():
            module.mem_rss_diff = 0
            module.mem_rss_post_forward = 0
            module.mem_rss_pre_forward = 0

    @property
    def device(self) -> torch.device:
        """
        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
        device).
        """
        return get_parameter_device(self)

    @property
    def dtype(self) -> torch.dtype:
        """
        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        return get_parameter_dtype(self)

    def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
        """
        Invert an attention mask (e.g., switches 0. and 1.).

        Args:
            encoder_attention_mask (`torch.Tensor`): An attention mask with 2 or 3 dimensions.

        Returns:
            `torch.Tensor`: The inverted attention mask, with one extra dimension inserted after the first one (two
            for a 2-dimensional input), in the dtype of the module.

        Raises:
            `ValueError`: if `encoder_attention_mask` is neither 2- nor 3-dimensional.
        """
        if encoder_attention_mask.dim() == 3:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
        elif encoder_attention_mask.dim() == 2:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
        else:
            # Fail with an explicit message instead of an `UnboundLocalError` further down.
            raise ValueError(f"Wrong shape for encoder_attention_mask (shape {encoder_attention_mask.shape})")
        # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
        # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
        # /transformer/transformer_layers.py#L270
        # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
        # encoder_extended_attention_mask.transpose(-1, -2))
        encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(self.dtype).min

        return encoder_extended_attention_mask

    @staticmethod
    def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device=None):
        """
        Combines a causal (lower-triangular) mask with a padding mask.

        Args:
            input_shape (`Tuple[int]`): `(batch_size, seq_length)` of the input.
            attention_mask (`torch.Tensor`): The 2-dimensional padding mask; it may be longer than `seq_length`.
            device (`torch.device`, *optional*):
                Deprecated. When not given, the device of `attention_mask` is used.

        Returns:
            `torch.Tensor`: A 4-dimensional mask (one broadcast dimension inserted after the batch dimension) in the
            dtype of `attention_mask`.
        """
        if device is not None:
            warnings.warn(
                "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
            )
        else:
            device = attention_mask.device
        batch_size, seq_length = input_shape
        seq_ids = torch.arange(seq_length, device=device)
        causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
        # in case past_key_values are used we need to add a prefix ones mask to the causal mask
        # causal and attention masks must have same type with pytorch version < 1.3
        causal_mask = causal_mask.to(attention_mask.dtype)

        if causal_mask.shape[1] < attention_mask.shape[1]:
            prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
            causal_mask = torch.cat(
                [
                    torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
                    causal_mask,
                ],
                dim=-1,
            )

        extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
        return extended_attention_mask

    def get_extended_attention_mask(
        self,
        attention_mask: Tensor,
        input_shape: Tuple[int],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int]`):
                The shape of the input to the model.
            device (`torch.device`, *optional*):
                Deprecated; passing it only triggers a `FutureWarning`.
            dtype (`torch.dtype`, *optional*):
                The dtype of the returned mask. Defaults to the dtype of the module.

        Returns:
            `torch.Tensor` The extended attention mask, in `dtype`: `0.0` for positions to attend to and the most
            negative value representable in `dtype` for masked positions.
        """
        if dtype is None:
            dtype = self.dtype

        if not (attention_mask.dim() == 2 and self.config.is_decoder):
            # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
            if device is not None:
                warnings.warn(
                    "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
                )
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if self.config.is_decoder:
                extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
                    input_shape, attention_mask, device
                )
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and `torch.finfo(dtype).min` for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
        return extended_attention_mask

    def get_head_mask(
        self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False
    ) -> Tensor:
        """
        Prepare the head mask if needed.

        Args:
            head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
                The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
            num_hidden_layers (`int`):
                The number of hidden layers in the model.
            is_attention_chunked: (`bool`, *optional*, defaults to `False`):
                Whether or not the attentions scores are computed by chunks or not.

        Returns:
            `torch.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with
            `[None]` for each layer.
        """
        if head_mask is not None:
            head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
            if is_attention_chunked is True:
                head_mask = head_mask.unsqueeze(-1)
        else:
            head_mask = [None] * num_hidden_layers

        return head_mask

    def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
        """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
        assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
        head_mask = head_mask.to(dtype=self.dtype)  # switch to float if need + fp16 compatibility
        return head_mask

    def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
        """
        Get number of (optionally, trainable or non-embeddings) parameters in the module.

        Args:
            only_trainable (`bool`, *optional*, defaults to `False`):
                Whether or not to return only the number of trainable parameters

            exclude_embeddings (`bool`, *optional*, defaults to `False`):
                Whether or not to return only the number of non-embeddings parameters

        Returns:
            `int`: The number of parameters.
        """
        if exclude_embeddings:
            # A set keeps the per-parameter membership test cheap.
            embedding_param_names = {
                f"{name}.weight" for name, module in self.named_modules() if isinstance(module, nn.Embedding)
            }
            counted_parameters = (
                parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
            )
        else:
            counted_parameters = self.parameters()
        return sum(p.numel() for p in counted_parameters if p.requires_grad or not only_trainable)

    def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int:
        """
        Helper function to estimate the total number of tokens from the model inputs.

        Args:
            input_dict (`dict`): The model inputs.

        Returns:
            `int`: The total number of tokens, i.e. the number of elements of the entry stored under
            `self.main_input_name`, or `0` (with a warning, issued only once) when that entry is absent.
        """
        if not hasattr(self, "warnings_issued"):
            self.warnings_issued = {}
        if self.main_input_name in input_dict:
            return input_dict[self.main_input_name].numel()
        elif "estimate_tokens" not in self.warnings_issued:
            logger.warning(
                "Could not estimate the number of tokens of the input, floating-point operations will not be computed"
            )
            self.warnings_issued["estimate_tokens"] = True
        return 0

    def floating_point_ops(
        self, input_dict: Dict[str, Union[torch.Tensor, Any]], exclude_embeddings: bool = True
    ) -> int:
        """
        Get number of (optionally, non-embeddings) floating-point operations for the forward and backward passes of a
        batch with this transformer model. Default approximation neglects the quadratic dependency on the number of
        tokens (valid if `sequence_length << 12 * d_model`) as laid out in [this
        paper](https://arxiv.org/pdf/2001.08361.pdf) section 2.1. Should be overridden for transformers with parameter
        re-use e.g. Albert or Universal Transformers, or if doing long-range modeling with very high sequence lengths.

        Args:
            input_dict (`dict`):
                The model inputs; the number of tokens is estimated from them with `estimate_tokens`.

            exclude_embeddings (`bool`, *optional*, defaults to `True`):
                Whether or not to leave the embedding parameters out of the parameter count.

        Returns:
            `int`: The number of floating-point operations, computed as `6 * number of tokens * number of parameters`.
        """

        return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
r"""
Base class for all models.
[`PreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
downloading and saving models as well as a few methods common to all models to:
- resize the input embeddings,
- prune heads in the self-attention heads.
Class attributes (overridden by derived classes):
- **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
for this model architecture.
- **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
taking as arguments:
- **model** ([`PreTrainedModel`]) -- An instance of the model on which to load the TensorFlow checkpoint.
- **config** ([`PreTrainedConfig`]) -- An instance of the configuration associated to the model.
- **path** (`str`) -- A path to the TensorFlow checkpoint.
- **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
classes of the same architecture adding modules on top of the base model.
- **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization.
- **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
models, `pixel_values` for vision models and `input_values` for speech models).
"""
config_class = None
base_model_prefix = ""
main_input_name = "input_ids"
_auto_class = None
_no_split_modules = None
# a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
# keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
_keys_to_ignore_on_load_missing = None
# a list of `re` patterns of `state_dict` keys that should be removed from the list of
# unexpected keys we find (keys inside the checkpoint but not the model) and avoid unnecessary
# warnings.
_keys_to_ignore_on_load_unexpected = None
# a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't
# trained, but which are either deterministic or tied variables)
_keys_to_ignore_on_save = None
is_parallelizable = False
supports_gradient_checkpointing = False
@property
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
    """
    `Dict[str, torch.Tensor]`: A single-entry dictionary mapping `"input_ids"` to a tensor built from
    `DUMMY_INPUTS`, usable to run a forward pass through the network.
    """
    input_ids = torch.tensor(DUMMY_INPUTS)
    return {"input_ids": input_ids}
@property
def framework(self) -> str:
    """
    :str: Short identifier of the framework backing this model, always `"pt"` (PyTorch).
    """
    framework_name = "pt"
    return framework_name
def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
    """
    Store the configuration on the model.

    Args:
        config ([`PretrainedConfig`]): Configuration of the model. Anything that is not a `PretrainedConfig`
            instance is rejected.
        inputs, kwargs: Accepted for signature compatibility; not used by this constructor.

    Raises:
        `ValueError`: If `config` is not an instance of `PretrainedConfig`.
    """
    super().__init__()
    class_name = self.__class__.__name__
    if not isinstance(config, PretrainedConfig):
        raise ValueError(
            f"Parameter config in `{class_name}(config)` should be an instance of class "
            "`PretrainedConfig`. To create a model from a pretrained model use "
            f"`model = {class_name}.from_pretrained(PRETRAINED_MODEL_NAME)`"
        )

    # Keep the config, and the origin of the pretrained weights if the config carries one
    self.config = config
    self.name_or_path = config.name_or_path
    # Starts empty; populated elsewhere
    self.warnings_issued = {}
def post_init(self):
    """
    Finalization hook to run at the very end of a Transformer model's initialization, once all of its modules
    exist: it first runs `init_weights()` and then the backward-compatibility handling of the
    `gradient_checkpointing` config flag.
    """
    # Step 1: weights (needs every submodule to be created already)
    self.init_weights()
    # Step 2: honor a legacy `gradient_checkpointing` entry in the config
    self._backward_compatibility_gradient_checkpointing()
def _backward_compatibility_gradient_checkpointing(self):
    """
    Enable gradient checkpointing when the model supports it and the config carries a truthy
    `gradient_checkpointing` attribute, then remove that attribute from the config.
    """
    if not self.supports_gradient_checkpointing:
        return
    if not getattr(self.config, "gradient_checkpointing", False):
        return

    self.gradient_checkpointing_enable()
    # Remove the attribute now that it has been consumed, so it is not saved in the config.
    delattr(self.config, "gradient_checkpointing")
@classmethod
def _from_config(cls, config, **kwargs):
    """
    Instantiate the model from `config`. All context managers that the model should be initialized under go here.

    Args:
        config:
            Configuration passed as first argument to the model constructor.
        torch_dtype (`torch.dtype`, *optional*):
            Override the default `torch.dtype` and load the model under this dtype.
        kwargs:
            Remaining keyword arguments, forwarded to the model constructor.

    Returns:
        The newly created model instance.
    """
    torch_dtype = kwargs.pop("torch_dtype", None)

    # override default dtype if needed
    dtype_orig = None
    if torch_dtype is not None:
        dtype_orig = cls._set_default_torch_dtype(torch_dtype)

    try:
        if is_deepspeed_zero3_enabled():
            import deepspeed

            logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
            # this immediately partitions the model across all gpus, to avoid the overhead in time
            # and memory copying it on CPU or each GPU first
            with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
                model = cls(config, **kwargs)
        else:
            model = cls(config, **kwargs)
    finally:
        # restore default dtype if it was modified; done in `finally` so that a failing constructor does not
        # leave the process-wide default dtype changed
        if dtype_orig is not None:
            torch.set_default_dtype(dtype_orig)

    return model
@classmethod
def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
    """
    Switch the default torch dtype to `dtype` and hand back the one that was active before. This is needed when
    wanting to instantiate the model under a specific dtype.

    Args:
        dtype (`torch.dtype`):
            a floating dtype to set to.

    Returns:
        `torch.dtype`: the default dtype that was in effect before the call, to be passed to
        `torch.set_default_dtype` in order to restore it.

    Raises:
        `ValueError`: if `dtype` is not a floating point dtype (e.g. `torch.int64`); the default dtype is left
        untouched in that case.
    """
    if dtype.is_floating_point:
        logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
        previous_dtype = torch.get_default_dtype()
        torch.set_default_dtype(dtype)
        return previous_dtype

    raise ValueError(
        f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
    )
@property
def base_model(self) -> nn.Module:
    """
    `torch.nn.Module`: The main body of the model, i.e. the attribute named by `base_model_prefix`, or the model
    itself when it has no such attribute.
    """
    prefix = self.base_model_prefix
    return getattr(self, prefix, self)
def get_input_embeddings(self) -> nn.Module:
    """
    Returns the model's input embeddings by delegating to the base model.

    Returns:
        `nn.Module`: A torch module mapping vocabulary to hidden states.

    Raises:
        `NotImplementedError`: If the model has no distinct base model to delegate to.
    """
    backbone = getattr(self, self.base_model_prefix, self)
    if backbone is self:
        raise NotImplementedError
    return backbone.get_input_embeddings()
def set_input_embeddings(self, value: nn.Module):
    """
    Set model's input embeddings by delegating to the base model.

    Args:
        value (`nn.Module`): A module mapping vocabulary to hidden states.

    Raises:
        `NotImplementedError`: If the model has no distinct base model to delegate to.
    """
    backbone = getattr(self, self.base_model_prefix, self)
    if backbone is self:
        raise NotImplementedError
    backbone.set_input_embeddings(value)
def get_output_embeddings(self) -> nn.Module:
    """
    Returns the model's output embeddings.

    Returns:
        `nn.Module`: A torch module mapping hidden states to vocabulary. This default implementation returns
        `None`.
    """
    # Overwrite for models with output embeddings
    return None
def _init_weights(self, module):
    """
    Initialize the weights of `module`. This placeholder always raises; derived classes must override it.
    """
    model_class = self.__class__
    raise NotImplementedError(f"Make sure `_init_weights` is implemented for {model_class}")
def tie_weights(self):
    """
    Tie the weights between the input embeddings and the output embeddings.

    If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
    weights instead.

    When the config sets both `is_encoder_decoder` and `tie_encoder_decoder`, the encoder weights are also tied to
    the decoder weights. Finally, `_tie_weights()` is called on every submodule of the model that defines it.
    """
    if getattr(self.config, "tie_word_embeddings", True):
        output_embeddings = self.get_output_embeddings()
        if output_embeddings is not None:
            self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())

    if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
        # The encoder/decoder pair lives on the base model when there is one. Use a dedicated local rather than
        # rebinding `self`, so that the loop below still visits every module of the full model (heads included).
        body = getattr(self, self.base_model_prefix, self)
        body._tie_encoder_decoder_weights(body.encoder, body.decoder, body.base_model_prefix)

    for module in self.modules():
        if hasattr(module, "_tie_weights"):
            module._tie_weights()
@staticmethod
def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str):
    """
    Make the encoder share the decoder's parameters by walking both module trees in parallel and assigning the
    decoder's `weight` (and `bias`, when present) to the matching encoder module.

    Args:
        encoder (`nn.Module`): Module whose `weight`/`bias` attributes get replaced.
        decoder (`nn.Module`): Module providing the parameters.
        base_model_prefix (`str`): Root of the `/`-separated names used to report encoder modules that could not
            be tied.

    Encoder submodules that end up without a decoder counterpart are reported in a single warning.
    """
    # Accumulates the names of encoder submodules that could not be matched; filled in place by the recursion.
    uninitialized_encoder_weights: List[str] = []
    if decoder.__class__ != encoder.__class__:
        logger.info(
            f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder"
            " weights are correctly initialized."
        )

    def tie_encoder_to_decoder_recursively(
        decoder_pointer: nn.Module,
        encoder_pointer: nn.Module,
        module_name: str,
        uninitialized_encoder_weights: List[str],
        depth=0,
    ):
        assert isinstance(decoder_pointer, nn.Module) and isinstance(
            encoder_pointer, nn.Module
        ), f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module"
        # A module exposing `weight` is treated as a leaf: share the parameters and stop descending.
        if hasattr(decoder_pointer, "weight"):
            assert hasattr(encoder_pointer, "weight")
            encoder_pointer.weight = decoder_pointer.weight
            if hasattr(decoder_pointer, "bias"):
                assert hasattr(encoder_pointer, "bias")
                encoder_pointer.bias = decoder_pointer.bias
            return

        encoder_modules = encoder_pointer._modules
        decoder_modules = decoder_pointer._modules
        if len(decoder_modules) > 0:
            assert (
                len(encoder_modules) > 0
            ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"

            # Start from every direct encoder child; each one that gets matched below is removed, so what is left
            # at the end are the children that were not tied.
            all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
            # Offset applied to numeric decoder child names to find the corresponding encoder child.
            encoder_layer_pos = 0
            for name, module in decoder_modules.items():
                if name.isdigit():
                    encoder_name = str(int(name) + encoder_layer_pos)
                    decoder_name = name
                    if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len(
                        encoder_modules
                    ) != len(decoder_modules):
                        # this can happen if the name corresponds to the position in a list module list of layers
                        # in this case the decoder has added a cross-attention that the encoder does not have
                        # thus skip this step and subtract one layer pos from encoder
                        encoder_layer_pos -= 1
                        continue
                elif name not in encoder_modules:
                    # Decoder-only child: nothing to tie on the encoder side.
                    continue
                elif depth > 500:
                    # Note: this guard is only evaluated for non-numeric child names present in the encoder.
                    raise ValueError(
                        "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is"
                        " a circular dependency between two or more `nn.Modules` of your model."
                    )
                else:
                    decoder_name = encoder_name = name
                tie_encoder_to_decoder_recursively(
                    decoder_modules[decoder_name],
                    encoder_modules[encoder_name],
                    module_name + "/" + name,
                    uninitialized_encoder_weights,
                    depth=depth + 1,
                )
                # This encoder child now has a decoder counterpart.
                all_encoder_weights.remove(module_name + "/" + encoder_name)

            # In-place extension so the caller's list receives the unmatched names.
            uninitialized_encoder_weights += list(all_encoder_weights)

    # tie weights recursively
    tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights)
    if len(uninitialized_encoder_weights) > 0:
        logger.warning(
            f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
        )
def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
    """
    Tie or clone module weights depending of whether we are using TorchScript or not.

    With `config.torchscript` set, the output module receives a fresh `nn.Parameter` holding a clone of the input
    weight; otherwise both modules share the very same weight. A non-`None` output bias is then zero-padded at its
    end by the difference between the first weight dimension and the bias length, and `out_features` is set to
    `num_embeddings` when both modules expose those attributes.
    """
    output_embeddings.weight = (
        nn.Parameter(input_embeddings.weight.clone()) if self.config.torchscript else input_embeddings.weight
    )

    bias = getattr(output_embeddings, "bias", None)
    if bias is not None:
        # Number of trailing entries to add to the bias
        missing_entries = output_embeddings.weight.shape[0] - bias.shape[0]
        bias.data = nn.functional.pad(bias.data, (0, missing_entries), "constant", 0)

    can_sync_size = hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings")
    if can_sync_size:
        output_embeddings.out_features = input_embeddings.num_embeddings
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
    """
    Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.

    When a size is given, `config.vocab_size` and `self.vocab_size` are updated and `tie_weights()` is called
    again afterwards.

    Arguments:
        new_num_tokens (`int`, *optional*):
            The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
            vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
            returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.

    Return:
        `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
    """
    model_embeds = self._resize_token_embeddings(new_num_tokens)

    if new_num_tokens is not None:
        # Keep the recorded vocabulary size in sync on both the config and the model
        self.config.vocab_size = new_num_tokens
        self.vocab_size = new_num_tokens
        # Re-tie, the embedding modules may have been replaced
        self.tie_weights()

    return model_embeds
def _resize_token_embeddings(self, new_num_tokens):
    """
    Replace the input embeddings by a resized copy and, when the model has output embeddings that are not tied to
    the input ones, replace the lm head by a resized copy as well.

    Returns the model's input embeddings after the update.
    """
    resized_embeddings = self._get_resized_embeddings(self.get_input_embeddings(), new_num_tokens)
    self.set_input_embeddings(resized_embeddings)

    # if word embeddings are not tied, make sure that lm head is resized as well
    if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
        resized_head = self._get_resized_lm_head(self.get_output_embeddings(), new_num_tokens)
        self.set_output_embeddings(resized_head)

    return self.get_input_embeddings()
def _get_resized_embeddings(
    self, old_embeddings: nn.Embedding, new_num_tokens: Optional[int] = None
) -> nn.Embedding:
    """
    Create an Embedding Module with `new_num_tokens` rows out of an existing one. Growing adds newly initialized
    vectors at the end, shrinking drops vectors from the end.

    Args:
        old_embeddings (`torch.nn.Embedding`):
            Old embeddings to be resized.
        new_num_tokens (`int`, *optional*):
            New number of tokens in the embedding matrix. If not provided or `None`, `old_embeddings` is returned
            untouched.

    Return:
        `torch.nn.Embedding`: The resized Embedding Module, or `old_embeddings` itself if `new_num_tokens` is
        `None` or already equal to the current number of tokens.

    Raises:
        `TypeError`: If a resize is needed and `old_embeddings` is not an `nn.Embedding`.
    """
    if new_num_tokens is None:
        return old_embeddings

    # Under DeepSpeed ZeRO-3 the weight is read inside a `GatheredParameters` context
    if is_deepspeed_zero3_enabled():
        import deepspeed

        with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None):
            current_num_tokens, embedding_dim = old_embeddings.weight.size()
    else:
        current_num_tokens, embedding_dim = old_embeddings.weight.size()

    if current_num_tokens == new_num_tokens:
        return old_embeddings

    if not isinstance(old_embeddings, nn.Embedding):
        raise TypeError(
            f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. You"
            " should either use a different resize function or make sure that `old_embeddings` are an instance of"
            f" {nn.Embedding}."
        )

    # Fresh module on the same device and with the same dtype as the old weight
    source_weight = old_embeddings.weight
    resized = nn.Embedding(new_num_tokens, embedding_dim)
    resized.to(source_weight.device, dtype=source_weight.dtype)

    # Initialize every row first (in particular the added tokens); the kept rows are overwritten right below
    self._init_weights(resized)

    # Number of leading rows carried over from the previous weights
    rows_to_keep = min(current_num_tokens, new_num_tokens)
    if is_deepspeed_zero3_enabled():
        import deepspeed

        with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=0):
            if torch.distributed.get_rank() == 0:
                resized.weight.data[:rows_to_keep, :] = old_embeddings.weight.data[:rows_to_keep, :]
    else:
        resized.weight.data[:rows_to_keep, :] = old_embeddings.weight.data[:rows_to_keep, :]

    return resized
def _get_resized_lm_head(
    self, old_lm_head: nn.Linear, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False
) -> nn.Linear:
    """
    Create a Linear Module sized for `new_num_tokens` out of an existing lm head. Growing adds newly initialized
    vectors at the end, shrinking drops vectors from the end.

    Args:
        old_lm_head (`torch.nn.Linear`):
            Old lm head linear layer to be resized.
        new_num_tokens (`int`, *optional*):
            New number of tokens in the linear matrix. If not provided or `None`, `old_lm_head` is returned
            untouched.
        transposed (`bool`, *optional*, defaults to `False`):
            Whether `old_lm_head` is transposed or not. If True `old_lm_head.size()` is `lm_head_dim, vocab_size`
            else `vocab_size, lm_head_dim`.

    Return:
        `torch.nn.Linear`: The resized Linear Module, or `old_lm_head` itself if `new_num_tokens` is `None` or
        already equal to the current number of tokens.

    Raises:
        `TypeError`: If a resize is needed and `old_lm_head` is not an `nn.Linear`.
    """
    if new_num_tokens is None:
        return old_lm_head

    # Under DeepSpeed ZeRO-3 the weight is read inside a `GatheredParameters` context
    if is_deepspeed_zero3_enabled():
        import deepspeed

        with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None):
            old_num_tokens, old_lm_head_dim = (
                old_lm_head.weight.t().size() if transposed else old_lm_head.weight.size()
            )
    else:
        old_num_tokens, old_lm_head_dim = (
            old_lm_head.weight.t().size() if transposed else old_lm_head.weight.size()
        )

    if old_num_tokens == new_num_tokens:
        return old_lm_head

    if not isinstance(old_lm_head, nn.Linear):
        raise TypeError(
            f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Linear}. You"
            " should either use a different resize function or make sure that `old_lm_head` are an instance of"
            f" {nn.Linear}."
        )

    # Build new lm head on the same device and with the same dtype as the old weight
    if transposed:
        new_lm_head_shape = (new_num_tokens, old_lm_head_dim)
    else:
        new_lm_head_shape = (old_lm_head_dim, new_num_tokens)
    has_new_lm_head_bias = old_lm_head.bias is not None
    new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias)
    new_lm_head = new_lm_head.to(old_lm_head.weight.device, dtype=old_lm_head.weight.dtype)

    # initialize new lm head (in particular added tokens)
    self._init_weights(new_lm_head)

    num_tokens_to_copy = min(old_num_tokens, new_num_tokens)

    def copy_old_into_new():
        # Carry the kept part of the old weight (and bias, if any) over to the new lm head
        if transposed:
            new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]
        else:
            new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]

        if has_new_lm_head_bias:
            new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]

    if is_deepspeed_zero3_enabled():
        import deepspeed

        params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
        with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
            # Only rank 0 performs the copy inside the gathered context
            if torch.distributed.get_rank() == 0:
                copy_old_into_new()
    else:
        copy_old_into_new()

    return new_lm_head
def resize_position_embeddings(self, new_num_position_embeddings: int):
    """
    Resize the position embeddings of the model. Not implemented in this base class: models that support it must
    overwrite this method.

    Args:
        new_num_position_embeddings (`int`): The requested number of position embeddings.

    Raises:
        `NotImplementedError`: Always, in this base implementation.
    """
    model_class = self.__class__
    # `__module__` is already the full dotted module path, so it is reported as is instead of being wrapped in a
    # made-up `modeling_<...>.py` file name.
    raise NotImplementedError(
        f"`resize_position_embeddings` is not implemented for {model_class}. To implement it, you should "
        f"overwrite this method in the class {model_class} defined in module `{model_class.__module__}`"
    )
def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
    """
    Return the position embeddings of the model. Not implemented in this base class: models that support it must
    overwrite this method.

    Raises:
        `NotImplementedError`: Always, in this base implementation.
    """
    model_class = self.__class__
    # `__module__` is already the full dotted module path, so it is reported as is instead of being wrapped in a
    # made-up `modeling_<...>.py` file name.
    raise NotImplementedError(
        f"`get_position_embeddings` is not implemented for {model_class}. To implement it, you should "
        f"overwrite this method in the class {model_class} defined in module `{model_class.__module__}`"
    )
def init_weights(self):
    """
    Prune the heads listed in `config.pruned_heads` (if any) and, unless weight initialization is switched off
    through the module-level `_init_weights` flag, initialize all weights and tie them.
    """
    # Prune heads if needed
    heads_to_prune = self.config.pruned_heads
    if heads_to_prune:
        self.prune_heads(heads_to_prune)

    if not _init_weights:
        # Tie weights should be skipped when not initializing all weights
        # since from_pretrained(...) calls tie weights anyways
        return

    # Initialize weights
    self.apply(self._init_weights)
    self.tie_weights()
def prune_heads(self, heads_to_prune: Dict[int, List[int]]):
    """
    Prunes heads of the base model and records them in `config.pruned_heads`.

    Arguments:
        heads_to_prune (`Dict[int, List[int]]`):
            Dictionary with keys being selected layer indices (`int`) and associated values being the list of heads
            to prune in said layer (list of `int`). For instance {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on
            layer 1 and heads 2 and 3 on layer 2.
    """
    recorded = self.config.pruned_heads
    for layer_idx, new_heads in heads_to_prune.items():
        # Merge what was already pruned on this layer with the new request
        merged = set(recorded.get(layer_idx, [])) | set(new_heads)
        # Stored as a list because the config has to stay JSON-serializable
        recorded[layer_idx] = list(merged)

    self.base_model._prune_heads(heads_to_prune)
def gradient_checkpointing_enable(self):
    """
    Turns gradient checkpointing on for the current model by applying `_set_gradient_checkpointing(..., value=True)`
    to its modules.

    Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
    activations".

    Raises:
        `ValueError`: If the model class does not support gradient checkpointing.
    """
    if not self.supports_gradient_checkpointing:
        raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")

    switch_on = partial(self._set_gradient_checkpointing, value=True)
    self.apply(switch_on)
def gradient_checkpointing_disable(self):
    """
    Turns gradient checkpointing off for the current model by applying
    `_set_gradient_checkpointing(..., value=False)` to its modules. Does nothing for models that do not support
    gradient checkpointing.

    Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
    activations".
    """
    if not self.supports_gradient_checkpointing:
        return

    switch_off = partial(self._set_gradient_checkpointing, value=False)
    self.apply(switch_off)
@property
def is_gradient_checkpointing(self) -> bool:
    """
    Whether gradient checkpointing is activated for this model or not, i.e. whether at least one of its modules
    has a truthy `gradient_checkpointing` attribute.

    Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
    activations".
    """
    for submodule in self.modules():
        if hasattr(submodule, "gradient_checkpointing") and submodule.gradient_checkpointing:
            return True
    return False
def save_pretrained(
    self,
    save_directory: Union[str, os.PathLike],
    is_main_process: bool = True,
    state_dict: Optional[dict] = None,
    save_function: Callable = torch.save,
    push_to_hub: bool = False,
    max_shard_size: Union[int, str] = "10GB",
    **kwargs,
):
    """
    Save a model and its configuration file to a directory, so that it can be re-loaded using the
    [`~PreTrainedModel.from_pretrained`] class method.

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Directory to which to save. Will be created if it doesn't exist.
        is_main_process (`bool`, *optional*, defaults to `True`):
            Whether the process calling this is the main process or not. Useful when in distributed training like
            TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
            the main process to avoid race conditions.
        state_dict (nested dictionary of `torch.Tensor`):
            The state dictionary of the model to save. Will default to `self.state_dict()`, but can be used to only
            save parts of the model or if special precautions need to be taken when recovering the state dictionary
            of a model (like when using model parallelism).
        save_function (`Callable`):
            The function to use to save the state dictionary. Useful on distributed training like TPUs when one
            need to replace `torch.save` by another method.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
            repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
            namespace).
        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
            The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
            lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).

            <Tip warning={true}>

            If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
            which will be bigger than `max_shard_size`.

            </Tip>

        kwargs:
            Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
    """
    if "save_config" in kwargs:
        warnings.warn(
            "`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
        )
        is_main_process = kwargs.pop("save_config")

    if os.path.isfile(save_directory):
        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
        return

    os.makedirs(save_directory, exist_ok=True)

    if push_to_hub:
        commit_message = kwargs.pop("commit_message", None)
        # `save_directory` may be an `os.PathLike` (which has no `.split`), so convert it to `str` before
        # deriving the default repo name from its last path component.
        repo_id = kwargs.pop("repo_id", str(save_directory).split(os.path.sep)[-1])
        repo_id, token = self._create_repo(repo_id, **kwargs)
        files_timestamps = self._get_files_timestamps(save_directory)

    # Only save the model itself if we are using distributed training
    model_to_save = unwrap_model(self)

    # save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
    # we currently don't use this setting automatically, but may start to use with v5
    dtype = get_parameter_dtype(model_to_save)
    model_to_save.config.torch_dtype = str(dtype).split(".")[1]

    # Attach architecture to the config
    model_to_save.config.architectures = [model_to_save.__class__.__name__]

    # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
    # loaded from the Hub.
    if self._auto_class is not None:
        custom_object_save(self, save_directory, config=self.config)

    # Save the config
    if is_main_process:
        model_to_save.config.save_pretrained(save_directory)

    # Save the model
    if state_dict is None:
        state_dict = model_to_save.state_dict()

    # Translate state_dict from smp to hf if saving with smp >= 1.10
    if IS_SAGEMAKER_MP_POST_1_10:
        for smp_to_hf, _ in smp.state.module_manager.translate_functions:
            state_dict = smp_to_hf(state_dict)

    # Handle the case where some state_dict keys shouldn't be saved
    if self._keys_to_ignore_on_save is not None:
        for ignore_key in self._keys_to_ignore_on_save:
            # Keys that are absent from the state dict are simply skipped
            state_dict.pop(ignore_key, None)

    # Shard the model if it is too big.
    shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size)

    # Clean the folder from a previous save
    for filename in os.listdir(save_directory):
        full_filename = os.path.join(save_directory, filename)
        # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
        # in distributed settings to avoid race conditions.
        if (
            filename.startswith(WEIGHTS_NAME[:-4])
            and os.path.isfile(full_filename)
            and filename not in shards.keys()
            and is_main_process
        ):
            os.remove(full_filename)

    # Save the model
    for shard_file, shard in shards.items():
        save_function(shard, os.path.join(save_directory, shard_file))

    if index is None:
        logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
    else:
        save_index_file = os.path.join(save_directory, WEIGHTS_INDEX_NAME)
        # Save the index as well
        with open(save_index_file, "w", encoding="utf-8") as f:
            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
            f.write(content)
        logger.info(
            f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
            f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
            f"index located at {save_index_file}."
        )

    if push_to_hub:
        self._upload_modified_files(
            save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
        )
def get_memory_footprint(self, return_buffers=True):
    r"""
    Compute the memory footprint of the current model, in bytes, as the sum over its parameters (and optionally
    its buffers) of number of elements times element size. Useful to benchmark the memory footprint of the current
    model and design some tests. Solution inspired from the PyTorch discussions:
    https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2

    Arguments:
        return_buffers (`bool`, *optional*, defaults to `True`):
            Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers
            are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch
            norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
    """
    total_bytes = sum(param.nelement() * param.element_size() for param in self.parameters())
    if return_buffers:
        total_bytes += sum(buf.nelement() * buf.element_size() for buf in self.buffers())
    return total_bytes
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
r"""
Instantiate a pretrained pytorch model from a pre-trained model configuration.
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
the model, you should first set it back in training mode with `model.train()`.
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
task.
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
weights are discarded.
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
user or organization name, like `dbmdz/bert-base-german-cased`.
- A path to a *directory* containing model weights saved using
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
- A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
this case, `from_tf` should be set to `True` and a configuration object should be provided as
`config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g,
`./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to
`True`.
- `None` if you are both providing the configuration and state dictionary (resp. with keyword
arguments `config` and `state_dict`).
model_args (sequence of positional arguments, *optional*):
All remaining positional arguments will be passed to the underlying model's `__init__` method.
config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*):
Can be either:
- an instance of a class derived from [`PretrainedConfig`],
- a string or path valid as input to [`~PretrainedConfig.from_pretrained`].
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
be automatically loaded when:
- The model is a model provided by the library (loaded with the *model id* string of a pretrained
model).
- The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
save directory.
- The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
configuration JSON file named *config.json* is found in the directory.
state_dict (`Dict[str, torch.Tensor]`, *optional*):
A state dictionary to use instead of a state dictionary loaded from saved weights file.
This option can be used if you want to create a model from a pretrained configuration but load your own
weights. In this case though, you should check if using [`~PreTrainedModel.save_pretrained`] and
[`~PreTrainedModel.from_pretrained`] is not a simpler option.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
from_tf (`bool`, *optional*, defaults to `False`):
Load the model weights from a TensorFlow checkpoint save file (see docstring of
`pretrained_model_name_or_path` argument).
from_flax (`bool`, *optional*, defaults to `False`):
Load the model weights from a Flax checkpoint save file (see docstring of
`pretrained_model_name_or_path` argument).
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
checkpoint with 3 labels).
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
mirror (`str`, *optional*):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information.
_fast_init(`bool`, *optional*, defaults to `True`):
Whether or not to disable fast initialization.
<Tip warning={true}>
One should only disable *_fast_init* to ensure backwards compatibility with `transformers.__version__ <
4.6.0` for seeded model initialization. This argument will be removed at the next major version. See
[pull request 11471](https://github.com/huggingface/transformers/pull/11471) for more information.
</Tip>
> Parameters for big model inference
low_cpu_mem_usage(`bool`, *optional*):
Tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
This is an experimental feature and a subject to change at any moment.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
will be automatically derived from the model's weights.
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be refined to each
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
same device.
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
more information about each option see [designing a device
map](https://hf.co/docs/accelerate/main/big_modeling#designing-a-device-map).
max_memory (`Dict`, *optional*):
A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
GPU and the available CPU RAM if unset.
offload_folder (`str` or `os.PathLike`, *optional*):
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
offload_state_dict (`bool`, *optional*):
If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
`True` when there is some disk offload.
load_in_8bit (`bool`, *optional*, defaults to `False`):
If `True`, will convert the loaded model into mixed-8bit quantized model. To use this feature please
install `bitsandbytes` compiled with your CUDA version by running `pip install -i
https://test.pypi.org/simple/ bitsandbytes-cudaXXX` where XXX is your CUDA version (e.g. 11.6 = 116).
Make also sure that you have enough GPU RAM to store half of the model size since the 8bit modules are
not compiled and adapted for CPUs.
int8_threshold (`float`, *optional*, defaults to 6):
Works together with `load_in_8bit`. This corresponds to the outlier threshold for outlier detection as
described in `GPT3.int8() : 8-bit Matrix Multiplication for Transformers at Scale` paper. Any hidden
states value that is above this threshold will be considered an outlier and the operation on those
values will be done in fp16. Values are usually normally distributed, that is, most values are in the
range [-3.5, 3.5], but there are some exceptional systematic outliers that are very differently
distributed for large models. These outliers are often in the interval [-60, -6] or [6, 60]. Int8
quantization works well for values of magnitude ~5, but beyond that, there is a significant performance
penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models
(small models, fine-tuning).
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
automatically loaded:
- If a configuration is provided with `config`, `**kwargs` will be directly passed to the
underlying model's `__init__` method (we assume all relevant updates to the configuration have
already been done)
- If a configuration is not provided, `kwargs` will be first passed to the configuration class
initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
corresponds to a configuration attribute will be used to override said attribute with the
supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
will be passed to the underlying model's `__init__` function.
<Tip>
Passing `use_auth_token=True` is required when you want to use a private model.
</Tip>
<Tip>
Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
use this method in a firewalled environment.
</Tip>
Examples:
```python
>>> from transformers import BertConfig, BertModel
>>> # Download model and configuration from huggingface.co and cache.
>>> model = BertModel.from_pretrained("bert-base-uncased")
>>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
>>> model = BertModel.from_pretrained("./test/saved_model/")
>>> # Update configuration during loading.
>>> model = BertModel.from_pretrained("bert-base-uncased", output_attentions=True)
>>> assert model.config.output_attentions == True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
>>> config = BertConfig.from_json_file("./tf_model/my_tf_model_config.json")
>>> model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config)
>>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower)
>>> model = BertModel.from_pretrained("bert-base-uncased", from_flax=True)
```
* `low_cpu_mem_usage` algorithm:
This is an experimental function that loads the model using ~1x model size CPU memory
Here is how it works:
1. save which state_dict keys we have
2. drop state_dict before the model is created, since the latter takes 1x model size CPU memory
3. after the model has been instantiated switch to the meta device all params/buffers that
are going to be replaced from the loaded state_dict
4. load state_dict 2nd time
5. replace the params/buffers from the state_dict
Currently, it can't handle deepspeed ZeRO stage 3 and ignores loading errors
"""
config = kwargs.pop("config", None)
state_dict = kwargs.pop("state_dict", None)
cache_dir = kwargs.pop("cache_dir", None)
from_tf = kwargs.pop("from_tf", False)
from_flax = kwargs.pop("from_flax", False)
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
trust_remote_code = kwargs.pop("trust_remote_code", None)
_ = kwargs.pop("mirror", None)
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
_fast_init = kwargs.pop("_fast_init", True)
torch_dtype = kwargs.pop("torch_dtype", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", None)
device_map = kwargs.pop("device_map", None)
max_memory = kwargs.pop("max_memory", None)
offload_folder = kwargs.pop("offload_folder", None)
offload_state_dict = kwargs.pop("offload_state_dict", False)
load_in_8bit = kwargs.pop("load_in_8bit", False)
int8_threshold = kwargs.pop("int8_threshold", 6.0)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
if trust_remote_code is True:
logger.warning(
"The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
" ignored."
)
if device_map is not None:
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
elif not low_cpu_mem_usage:
raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
if low_cpu_mem_usage:
# low_cpu_mem_usage requires PyTorch >= 1.9 to have the meta device.
require_version_core("torch>=1.9")
if is_deepspeed_zero3_enabled():
raise ValueError(
"DeepSpeed Zero-3 is not compatible with `low_cpu_mem_usage=True` or with passing a `device_map`."
)
elif not is_accelerate_available():
raise ImportError(
"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install accelerate`"
)
if load_in_8bit:
if not (is_accelerate_available() and is_bitsandbytes_available()):
raise ImportError(
"Using `load_in_8bit=True` requires Accelerate: `pip install accelerate` and the latest version of"
" bitsandbytes `pip install -i https://test.pypi.org/simple/ bitsandbytes` or"
" pip install bitsandbytes` "
)
if torch_dtype == "auto" or torch_dtype != torch.float16:
# We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
torch_dtype = torch.float16
logger.info("Loading the model in mixed int8 - forcing the weights to be casted in float16")
if device_map is None:
raise ValueError(
"A device map needs to be passed to run convert models into mixed-int8 format. Please run"
"`.from_pretrained` with `device_map='auto'`"
)
if from_tf or from_flax:
raise ValueError(
"Converting into mixed 8-bit weights from tf/flax weights is currently not supported, please make"
" sure the weights are in PyTorch format."
)
from_pt = not (from_tf | from_flax)
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
if from_pipeline is not None:
user_agent["using_pipeline"] = from_pipeline
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
# Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig):
config_path = config if config is not None else pretrained_model_name_or_path
config, model_kwargs = cls.config_class.from_pretrained(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
**kwargs,
)
else:
model_kwargs = kwargs
if commit_hash is None:
commit_hash = getattr(config, "_commit_hash", None)
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# index of the files.
is_sharded = False
sharded_metadata = None
# Load model
loading_info = None
if pretrained_model_name_or_path is not None:
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
if is_local:
if from_tf and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
):
# Load from a TF 1.0 checkpoint in priority if from_tf
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
elif from_tf and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)
):
# Load from a TF 2.0 checkpoint in priority if from_tf
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)
elif from_flax and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
):
# Load from a Flax checkpoint in priority if from_flax
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)):
# Load from a sharded PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
is_sharded = True
# At this stage we don't have a weight file so we will raise an error.
elif os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
) or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)):
raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} but "
"there is a file for TensorFlow weights. Use `from_tf=True` to load this model from those "
"weights."
)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} but "
"there is a file for Flax weights. Use `from_flax=True` to load this model from those "
"weights."
)
else:
raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or "
f"{FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
)
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
archive_file = pretrained_model_name_or_path
is_local = True
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")):
if not from_tf:
raise ValueError(
f"We found a TensorFlow checkpoint at {pretrained_model_name_or_path + '.index'}, please set "
"from_tf to True to load from this checkpoint."
)
archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index")
is_local = True
else:
# set correct filename
if from_tf:
filename = TF2_WEIGHTS_NAME
elif from_flax:
filename = FLAX_WEIGHTS_NAME
else:
filename = WEIGHTS_NAME
try:
# Load from URL or cache if already cached
cached_file_kwargs = dict(
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_commit_hash=commit_hash,
)
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
# result when internet is up, the repo and revision exist, but the file does not.
if resolved_archive_file is None and filename == WEIGHTS_NAME:
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
resolved_archive_file = cached_file(
pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
)
if resolved_archive_file is not None:
is_sharded = True
if resolved_archive_file is None:
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
# message.
has_file_kwargs = {
"revision": revision,
"proxies": proxies,
"use_auth_token": use_auth_token,
}
if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs):
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named"
f" {WEIGHTS_NAME} but there is a file for TensorFlow weights. Use `from_tf=True` to"
" load this model from those weights."
)
elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs):
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named"
f" {WEIGHTS_NAME} but there is a file for Flax weights. Use `from_flax=True` to load"
" this model from those weights."
)
else:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME},"
f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
)
except EnvironmentError:
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
# to the original exception.
raise
except Exception:
# For any other exception, we throw a generic error.
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
" from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or"
f" {FLAX_WEIGHTS_NAME}."
)
if is_local:
logger.info(f"loading weights file {archive_file}")
resolved_archive_file = archive_file
else:
logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
else:
resolved_archive_file = None
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
if is_sharded:
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
pretrained_model_name_or_path,
resolved_archive_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_commit_hash=commit_hash,
)
# load pt weights early so that we know which dtype to init the model under
if from_pt:
if not is_sharded and state_dict is None:
# Time to load the checkpoint
state_dict = load_state_dict(resolved_archive_file)
# set dtype to instantiate the model under:
# 1. If torch_dtype is not None, we use that dtype
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
# weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
# we also may have config.torch_dtype available, but we won't rely on it till v5
dtype_orig = None
if torch_dtype is not None:
if isinstance(torch_dtype, str):
if torch_dtype == "auto":
if is_sharded and "dtype" in sharded_metadata:
torch_dtype = sharded_metadata["dtype"]
elif not is_sharded:
torch_dtype = get_state_dict_dtype(state_dict)
else:
one_state_dict = load_state_dict(resolved_archive_file[0])
torch_dtype = get_state_dict_dtype(one_state_dict)
del one_state_dict # free CPU memory
else:
raise ValueError(
f"`torch_dtype` can be either a `torch.dtype` or `auto`, but received {torch_dtype}"
)
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
if is_sharded:
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
else:
loaded_state_dict_keys = [k for k in state_dict.keys()]
if low_cpu_mem_usage:
state_dict = None
config.name_or_path = pretrained_model_name_or_path
# Instantiate model.
init_contexts = [no_init_weights(_enable=_fast_init)]
if is_deepspeed_zero3_enabled():
import deepspeed
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts
elif load_in_8bit or low_cpu_mem_usage:
init_contexts.append(init_empty_weights())
with ContextManagers(init_contexts):
model = cls(config, *model_args, **model_kwargs)
if load_in_8bit:
logger.info("Detected 8-bit loading: activating 8-bit loading for this model")
# We never convert lm_head or any last modules for numerical stability reasons
modules_to_not_convert = get_key_to_not_convert(model)
model = replace_8bit_linear(model, threshold=int8_threshold, modules_to_not_convert=modules_to_not_convert)
if isinstance(device_map, str):
if model._no_split_modules is None:
raise ValueError(f"{model.__class__.__name__} does not support `device_map='{device_map}'` yet.")
no_split_modules = model._no_split_modules
if device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
raise ValueError(
"If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
"'sequential'."
)
elif device_map in ["balanced", "balanced_low_0"] and get_balanced_memory is None:
raise ValueError(f"`device_map={device_map}` requires a source install of Accelerate.")
if device_map != "sequential" and get_balanced_memory is not None:
max_memory = get_balanced_memory(
model,
max_memory=max_memory,
no_split_module_classes=no_split_modules,
dtype=torch_dtype,
low_zero=(device_map == "balanced_low_0"),
)
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
device_map = infer_auto_device_map(
model,
no_split_module_classes=no_split_modules,
dtype=torch_dtype if not load_in_8bit else torch.int8,
max_memory=max_memory,
)
if load_in_8bit:
# The LM head can stay on disk / CPU
device_map_without_lm_head = {
key: device_map[key] for key in device_map.keys() if key != modules_to_not_convert
}
if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
raise ValueError("8-bit operations on `bitsandbytes` are not supported under CPU!")
del device_map_without_lm_head
if from_tf:
if resolved_archive_file.endswith(".index"):
# Load from a TensorFlow 1.X checkpoint - provided by original authors
model = cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index'
else:
# Load from our TensorFlow 2.0 checkpoints
try:
from .modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model
model, loading_info = load_tf2_checkpoint_in_pytorch_model(
model, resolved_archive_file, allow_missing_keys=True, output_loading_info=True
)
except ImportError:
logger.error(
"Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed."
" Please see https://pytorch.org/ and https://www.tensorflow.org/install/ for installation"
" instructions."
)
raise
elif from_flax:
try:
from .modeling_flax_pytorch_utils import load_flax_checkpoint_in_pytorch_model
model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file)
except ImportError:
logger.error(
"Loading a Flax model in PyTorch, requires both PyTorch and Flax to be installed. Please see"
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for"
" installation instructions."
)
raise
elif from_pt:
# restore default dtype
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
state_dict,
loaded_state_dict_keys, # XXX: rename?
resolved_archive_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
low_cpu_mem_usage=low_cpu_mem_usage,
device_map=device_map,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
load_in_8bit=load_in_8bit,
)
# make sure token embedding weights are still tied if needed
model.tie_weights()
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
# Dispatch model with hooks on all devices if necessary
if device_map is not None:
dispatch_model(model, device_map=device_map, offload_dir=offload_folder)
if output_loading_info:
if loading_info is None:
loading_info = {
"missing_keys": missing_keys,
"unexpected_keys": unexpected_keys,
"mismatched_keys": mismatched_keys,
"error_msgs": error_msgs,
}
return model, loading_info
return model
@classmethod
def _load_pretrained_model(
    cls,
    model,
    state_dict,
    loaded_keys,
    resolved_archive_file,
    pretrained_model_name_or_path,
    ignore_mismatched_sizes=False,
    sharded_metadata=None,
    _fast_init=True,
    low_cpu_mem_usage=False,
    device_map=None,
    offload_folder=None,
    offload_state_dict=None,
    dtype=None,
    load_in_8bit=False,
):
    """
    Load checkpoint weights into an already instantiated `model` and report how the checkpoint keys matched.

    Args:
        model:
            The instantiated model that receives the weights.
        state_dict:
            The checkpoint held in memory, or `None` to read the weights from `resolved_archive_file`, one file
            at a time.
        loaded_keys:
            Names of all the weights contained in the checkpoint (for a sharded checkpoint, across all shards).
        resolved_archive_file:
            Checkpoint file, or list of checkpoint shard files. Only read when `state_dict` is `None`.
        pretrained_model_name_or_path:
            Only used in the log messages emitted by this method.
        ignore_mismatched_sizes:
            When `True`, checkpoint weights whose shape differs from the shape in the model are removed from the
            state dict before loading and reported in the returned `mismatched_keys`.
        sharded_metadata:
            Not used by this method.
        _fast_init:
            When `True`, the modules owning at least one missing key are initialized with `model._init_weights`.
        low_cpu_mem_usage:
            When `True`, missing weights that sit on the meta device are replaced by empty CPU tensors and each
            checkpoint file is loaded with `_load_state_dict_into_meta_model`.
        device_map:
            Optional mapping whose values are devices. If `"disk"` is one of the values, `offload_folder` is
            required.
        offload_folder:
            Folder used for the weights offloaded to disk; created if it does not exist.
        offload_state_dict:
            When truthy, a temporary folder is handed to `_load_state_dict_into_meta_model`, read back with
            `load_offloaded_weights` and then deleted. Set to `True` when it is `None` and `device_map` contains
            `"disk"`.
        dtype:
            Forwarded to `_load_state_dict_into_meta_model`.
        load_in_8bit:
            Forwarded to `_load_state_dict_into_meta_model`; also selects the 8-bit variant of the helper used to
            put missing meta weights on CPU.

    Returns:
        `tuple`: `(model, missing_keys, unexpected_keys, mismatched_keys, error_msgs)`.

    Raises:
        ValueError: If `device_map` contains `"disk"` and no `offload_folder` is given, or if the checkpoint holds
            un-prefixed keys that belong to the head of the model while being loaded into the base model.
        RuntimeError: If loading the state dict produced any error message.
    """
    if device_map is not None and "disk" in device_map.values():
        if offload_folder is None:
            raise ValueError(
                "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
                " for them."
            )
        os.makedirs(offload_folder, exist_ok=True)
        if offload_state_dict is None:
            offload_state_dict = True

    # Retrieve missing & unexpected_keys
    model_state_dict = model.state_dict()
    expected_keys = list(model_state_dict.keys())
    prefix = model.base_model_prefix

    def _fix_key(key):
        # Map the legacy parameter names `beta`/`gamma` to `bias`/`weight`.
        if "beta" in key:
            return key.replace("beta", "bias")
        if "gamma" in key:
            return key.replace("gamma", "weight")
        return key

    original_loaded_keys = loaded_keys
    loaded_keys = [_fix_key(key) for key in loaded_keys]

    if len(prefix) > 0:
        has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
        expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
    else:
        has_prefix_module = False
        expects_prefix_module = False

    # key re-naming operations are never done on the keys
    # that are loaded, but always on the keys of the newly initialized model
    remove_prefix_from_model = not has_prefix_module and expects_prefix_module
    add_prefix_to_model = has_prefix_module and not expects_prefix_module

    # Stays empty unless the prefix is stripped below, so that the corruption check further down can never hit
    # an unbound name.
    expected_keys_not_prefixed = []
    if remove_prefix_from_model:
        expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(prefix)]
        expected_keys = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys]
    elif add_prefix_to_model:
        expected_keys = [".".join([prefix, s]) for s in expected_keys]

    missing_keys = list(set(expected_keys) - set(loaded_keys))
    unexpected_keys = list(set(loaded_keys) - set(expected_keys))

    # Some models may have keys that are not in the state by design, removing them before needlessly warning
    # the user.
    if cls._keys_to_ignore_on_load_missing is not None:
        for pat in cls._keys_to_ignore_on_load_missing:
            missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

    if cls._keys_to_ignore_on_load_unexpected is not None:
        for pat in cls._keys_to_ignore_on_load_unexpected:
            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

    # retrieve weights on meta device and put them back on CPU.
    # This is not ideal in terms of memory, but if we don't do that now, we can't initialize them in the next step
    if low_cpu_mem_usage:
        for key in missing_keys:
            if key.startswith(prefix):
                key = ".".join(key.split(".")[1:])
            param = model_state_dict[key]
            if param.device == torch.device("meta"):
                if not load_in_8bit:
                    set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size()))
                else:
                    set_module_8bit_tensor_to_device(model, key, "cpu", torch.empty(*param.size()))

    # retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights.
    if _fast_init:
        uninitialized_modules = model.retrieve_modules_from_names(
            missing_keys, add_prefix=add_prefix_to_model, remove_prefix=remove_prefix_from_model
        )
        for module in uninitialized_modules:
            model._init_weights(module)

    # Make sure we are able to load base models as well as derived models (with heads)
    start_prefix = ""
    model_to_load = model
    if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module:
        start_prefix = cls.base_model_prefix + "."
    if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
        model_to_load = getattr(model, cls.base_model_prefix)
        if any(key in expected_keys_not_prefixed for key in loaded_keys):
            raise ValueError(
                "The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
                "properly saved?"
            )
        if device_map is not None:
            device_map = {k.replace(f"{cls.base_model_prefix}.", ""): v for k, v in device_map.items()}

    def _find_mismatched_keys(
        state_dict,
        model_state_dict,
        loaded_keys,
        add_prefix_to_model,
        remove_prefix_from_model,
        ignore_mismatched_sizes,
    ):
        # Returns `(checkpoint_key, checkpoint_shape, model_shape)` tuples and removes the offending entries from
        # `state_dict` so that they are not loaded afterwards. Only active when `ignore_mismatched_sizes` is set.
        mismatched_keys = []
        if ignore_mismatched_sizes:
            for checkpoint_key in loaded_keys:
                # `loaded_keys` lists the keys of the whole checkpoint, whereas `state_dict` may hold a single
                # shard: keys stored in another shard are handled when that shard is loaded.
                if checkpoint_key not in state_dict:
                    continue

                model_key = checkpoint_key
                if remove_prefix_from_model:
                    # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
                    model_key = f"{prefix}.{checkpoint_key}"
                elif add_prefix_to_model:
                    # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
                    model_key = ".".join(checkpoint_key.split(".")[1:])

                if (
                    model_key in model_state_dict
                    and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
                ):
                    mismatched_keys.append(
                        (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
                    )
                    del state_dict[checkpoint_key]
        return mismatched_keys

    if state_dict is not None:
        # Whole checkpoint
        mismatched_keys = _find_mismatched_keys(
            state_dict,
            model_state_dict,
            original_loaded_keys,
            add_prefix_to_model,
            remove_prefix_from_model,
            ignore_mismatched_sizes,
        )
        error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
    else:
        # Sharded checkpoint or whole but low_cpu_mem_usage==True

        # This should always be a list but, just to be sure.
        if not isinstance(resolved_archive_file, list):
            resolved_archive_file = [resolved_archive_file]

        error_msgs = []
        mismatched_keys = []
        offload_index = {} if device_map is not None and "disk" in device_map.values() else None
        if offload_state_dict:
            state_dict_folder = tempfile.mkdtemp()
            state_dict_index = {}
        else:
            state_dict_folder = None
            state_dict_index = None

        for shard_file in resolved_archive_file:
            state_dict = load_state_dict(shard_file)

            # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
            # matching the weights in the model.
            mismatched_keys += _find_mismatched_keys(
                state_dict,
                model_state_dict,
                original_loaded_keys,
                add_prefix_to_model,
                remove_prefix_from_model,
                ignore_mismatched_sizes,
            )

            if low_cpu_mem_usage:
                new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
                    model_to_load,
                    state_dict,
                    loaded_keys,
                    start_prefix,
                    expected_keys,
                    device_map=device_map,
                    offload_folder=offload_folder,
                    offload_index=offload_index,
                    state_dict_folder=state_dict_folder,
                    state_dict_index=state_dict_index,
                    dtype=dtype,
                    load_in_8bit=load_in_8bit,
                )
                error_msgs += new_error_msgs
            else:
                error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)

            # force memory release
            del state_dict
            gc.collect()

        if offload_index is not None and len(offload_index) > 0:
            if model != model_to_load:
                # We need to add the prefix of the base model
                prefix = cls.base_model_prefix
                for weight_name in offload_index:
                    shutil.move(
                        os.path.join(offload_folder, f"{weight_name}.dat"),
                        os.path.join(offload_folder, f"{prefix}.{weight_name}.dat"),
                    )
                offload_index = {f"{prefix}.{key}": value for key, value in offload_index.items()}
            save_offload_index(offload_index, offload_folder)

        if offload_state_dict:
            # Load back temporarily offloaded state dict
            load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder)
            shutil.rmtree(state_dict_folder)

    if len(error_msgs) > 0:
        error_msg = "\n\t".join(error_msgs)
        if "size mismatch" in error_msg:
            error_msg += (
                "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
            )
        raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

    if len(unexpected_keys) > 0:
        logger.warning(
            f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
            f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
            f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
            " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
            " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
            f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
            " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
        )
    else:
        logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
            " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
    elif len(mismatched_keys) == 0:
        logger.info(
            f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
            f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
            " training."
        )
    if len(mismatched_keys) > 0:
        mismatched_warning = "\n".join(
            [
                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                for key, shape1, shape2 in mismatched_keys
            ]
        )
        logger.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
            " to use it for predictions and inference."
        )

    return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
    """
    Return the sub-modules of this model that own at least one of the weight names in `names`.

    Args:
        names: Dotted weight names, e.g. `"encoder.layer.0.weight"`.
        add_prefix: If `True` (and `remove_prefix` is falsy), module names are prefixed with
            `self.base_model_prefix` before being compared to the names derived from `names`.
        remove_prefix: If `True`, the first dotted component is dropped from every module name that starts with
            `self.base_model_prefix` before the comparison. Takes precedence over `add_prefix`.

    Returns:
        `list`: The matching modules, in the order yielded by `self.named_modules()`.
    """
    # The owner of a weight is everything before the last dot of its name.
    direct_owners = {weight_name.rpartition(".")[0] for weight_name in names}
    # torch.nn.ParameterList is a special case where two parameter keywords
    # are appended to the module name, *e.g.* bert.special_embeddings.0
    list_owners = {
        ".".join(weight_name.split(".")[:-2]) for weight_name in names if weight_name[-1].isdigit()
    }
    wanted = direct_owners | list_owners

    # keep every module that has at least one of the requested weight names
    matches = []
    for module_name, module in self.named_modules():
        if remove_prefix:
            if module_name.startswith(self.base_model_prefix):
                module_name = module_name.partition(".")[2]
        elif add_prefix:
            if len(module_name) > 0:
                module_name = ".".join((self.base_model_prefix, module_name))
            else:
                module_name = self.base_model_prefix

        if module_name in wanted:
            matches.append(module)

    return matches
@staticmethod
def _load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file, start_prefix=""):
    """
    Experimental loader that aims at using only ~1.x model size of CPU memory.

    The caller is expected to have already:

    1. recorded which state_dict keys are available (`loaded_state_dict_keys`)
    2. dropped the state_dict before the model was created, since the latter takes 1x model size memory

    This function then:

    3. switches to the meta device all params/buffers that are going to be replaced from the loaded state_dict
    4. loads the state_dict a second time
    5. replaces the params/buffers with the ones from the state_dict

    It currently does not handle missing_keys, unexpected_keys or mismatched_keys, and it can't handle deepspeed.
    """
    _move_model_to_meta(model, loaded_state_dict_keys, start_prefix)
    checkpoint = load_state_dict(resolved_archive_file)
    # NOTE(review): `_load_pretrained_model` passes an extra `expected_keys` argument to this helper and unpacks
    # three values from its result -- confirm this call and the returned value still match the helper.
    return _load_state_dict_into_meta_model(model, checkpoint, loaded_state_dict_keys, start_prefix)
@classmethod
def register_for_auto_class(cls, auto_class="AutoModel"):
    """
    Record on this class the name of the auto class it should be registered with. Only custom models need this,
    since the models of the library are already mapped with an auto class.

    <Tip warning={true}>

    This API is experimental and may have some slight breaking changes in the next releases.

    </Tip>

    Args:
        auto_class (`str` or `type`, *optional*, defaults to `"AutoModel"`):
            The auto class to register this new model with.

    Raises:
        ValueError: If `transformers.models.auto` has no attribute with the requested name.
    """
    # Accept either the auto class itself or its name; only the name is stored.
    auto_class_name = auto_class if isinstance(auto_class, str) else auto_class.__name__

    import transformers.models.auto as auto_module

    if hasattr(auto_module, auto_class_name):
        cls._auto_class = auto_class_name
        return

    raise ValueError(f"{auto_class_name} is not a valid auto class.")
# Rebind `push_to_hub` to the result of `copy_func` and then fill in the `object`, `object_class` and
# `object_files` placeholders of its docstring. Presumably the copy keeps the formatted docstring from leaking
# onto other owners of the same function -- TODO confirm against `copy_func`.
PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
PreTrainedModel.push_to_hub.__doc__ = PreTrainedModel.push_to_hub.__doc__.format(
object="model", object_class="AutoModel", object_files="model file"
)
class PoolerStartLogits(nn.Module):
    """
    Head producing the SQuAD start logits from the hidden states of a sequence.

    Args:
        config ([`PretrainedConfig`]):
            Model configuration; its `hidden_size` gives the input width of the projection.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        # One score per position.
        self.dense = nn.Linear(config.hidden_size, 1)

    def forward(
        self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None
    ) -> torch.FloatTensor:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
                The final hidden states of the model.
            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
                Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means
                token should be masked.

        Returns:
            `torch.FloatTensor`: The start logits for SQuAD.
        """
        logits = self.dense(hidden_states).squeeze(-1)
        if p_mask is None:
            return logits

        # Masked positions are pushed to a very negative value; float16 parameters get a smaller magnitude
        # (65500 stays within the float16 range, 1e30 does not).
        mask_value = 65500 if get_parameter_dtype(self) == torch.float16 else 1e30
        return logits * (1 - p_mask) - mask_value * p_mask
class PoolerEndLogits(nn.Module):
    """
    Compute SQuAD end logits from the sequence hidden states and the hidden state of a start token.

    Args:
        config ([`PretrainedConfig`]):
            Configuration of the model; its `hidden_size` sizes the layers and its `layer_norm_eps` is passed to
            the layer normalization.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dense_1 = nn.Linear(config.hidden_size, 1)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        start_states: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        p_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
                The final hidden states of the model.
            start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
                The hidden states of the first tokens for the labeled span.
            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                The position of the first token for the labeled span.
            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
                Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means
                the token is masked.

        <Tip>

        At least one of `start_states` and `start_positions` must be given. When both are set, the states gathered
        at `start_positions` replace `start_states`.

        </Tip>

        Returns:
            `torch.FloatTensor`: The end logits for SQuAD.
        """
        assert not (
            start_states is None and start_positions is None
        ), "One of start_states, start_positions should be not None"

        if start_positions is not None:
            seq_len, hidden_dim = hidden_states.shape[-2:]
            gather_index = start_positions[:, None, None].expand(-1, -1, hidden_dim)  # shape (bsz, 1, hsz)
            # Pick the start token of every sample, then repeat it along the sequence axis.
            start_states = hidden_states.gather(-2, gather_index).expand(-1, seq_len, -1)  # shape (bsz, slen, hsz)

        features = torch.cat([hidden_states, start_states], dim=-1)
        logits = self.dense_1(self.LayerNorm(self.activation(self.dense_0(features)))).squeeze(-1)
        if p_mask is None:
            return logits

        # 1e30 is not representable in float16, so half-precision modules use a value just below the float16 maximum.
        penalty = 65500 if get_parameter_dtype(self) == torch.float16 else 1e30
        return logits * (1 - p_mask) - penalty * p_mask
class PoolerAnswerClass(nn.Module):
    """
    Compute the SQuAD 2.0 answer class score from the hidden states of the classification and start tokens.

    Args:
        config ([`PretrainedConfig`]):
            Configuration of the model; only its `hidden_size` is read, to size the layers.
    """

    def __init__(self, config):
        super().__init__()
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        start_states: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        cls_index: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
                The final hidden states of the model.
            start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
                The hidden states of the first tokens for the labeled span.
            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                The position of the first token for the labeled span.
            cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Position of the CLS token for each sentence in the batch. If `None`, the last token is used.

        <Tip>

        At least one of `start_states` and `start_positions` must be given. When both are set, the states gathered
        at `start_positions` replace `start_states`.

        </Tip>

        Returns:
            `torch.FloatTensor`: The SQuAD 2.0 answer class.
        """
        # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample.
        hidden_dim = hidden_states.shape[-1]
        assert not (
            start_states is None and start_positions is None
        ), "One of start_states, start_positions should be not None"

        def select_token(positions):
            # Gather one token per sample: index shape (bsz, 1, hsz) -> result shape (bsz, hsz).
            index = positions[:, None, None].expand(-1, -1, hidden_dim)
            return hidden_states.gather(-2, index).squeeze(-2)

        if start_positions is not None:
            start_states = select_token(start_positions)

        if cls_index is None:
            cls_token_state = hidden_states[:, -1, :]  # shape (bsz, hsz)
        else:
            cls_token_state = select_token(cls_index)

        pooled = torch.cat([start_states, cls_token_state], dim=-1)
        return self.dense_1(self.activation(self.dense_0(pooled))).squeeze(-1)
@dataclass
class SquadHeadOutput(ModelOutput):
    """
    Base class for outputs of question answering models using a [`~modeling_utils.SQuADHead`].

    When both `start_positions` and `end_positions` are given to the head, only `loss` is populated; otherwise
    `loss` stays `None` and the five beam-search fields are populated instead.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided):
            Classification loss as the sum of start token, end token (and is_impossible if provided) classification
            losses.
        start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Scores of the top `config.start_n_top` start token possibilities (beam-search). Despite the name, the
            head computes them with `softmax`, so they are probabilities rather than log probabilities.
        start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Indices for the top `config.start_n_top` start token possibilities (beam-search).
        end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Scores of the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search). Despite
            the name, the head computes them with `softmax`, so they are probabilities rather than log
            probabilities.
        end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search).
        cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Raw (un-normalized) logits for the `is_impossible` label of the answers.
    """

    # Every field defaults to `None`; which ones are populated depends on the branch taken by the head.
    loss: Optional[torch.FloatTensor] = None
    start_top_log_probs: Optional[torch.FloatTensor] = None
    start_top_index: Optional[torch.LongTensor] = None
    end_top_log_probs: Optional[torch.FloatTensor] = None
    end_top_index: Optional[torch.LongTensor] = None
    cls_logits: Optional[torch.FloatTensor] = None
class SQuADHead(nn.Module):
    r"""
    A SQuAD head inspired by XLNet.

    It combines a start-logit pooler, an end-logit pooler conditioned on the start token, and an answerability
    classifier. With labels it returns a loss; without labels it returns beam-search candidates.

    Args:
        config ([`PretrainedConfig`]):
            The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
            to use. `start_n_top` and `end_n_top` are read as the beam sizes for start and end candidates.
    """

    def __init__(self, config):
        super().__init__()
        self.start_n_top = config.start_n_top
        self.end_n_top = config.end_n_top

        self.start_logits = PoolerStartLogits(config)
        self.end_logits = PoolerEndLogits(config)
        self.answer_class = PoolerAnswerClass(config)

    @replace_return_docstrings(output_type=SquadHeadOutput, config_class=PretrainedConfig)
    def forward(
        self,
        hidden_states: torch.FloatTensor,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        cls_index: Optional[torch.LongTensor] = None,
        is_impossible: Optional[torch.LongTensor] = None,
        p_mask: Optional[torch.FloatTensor] = None,
        return_dict: bool = False,
    ) -> Union[SquadHeadOutput, Tuple[torch.FloatTensor]]:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
                Final hidden states of the model on the sequence tokens.
            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Positions of the first token for the labeled span.
            end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Positions of the last token for the labeled span.
            cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Position of the CLS token for each sentence in the batch. If `None`, takes the last token.
            is_impossible (`torch.LongTensor` or `torch.FloatTensor` of shape `(batch_size,)`, *optional*):
                Whether the question has a possible answer in the paragraph or not. Integer labels are cast to the
                dtype of the answerability logits before the binary cross-entropy loss is computed.
            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
                Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
                should be masked.
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:
        """
        start_logits = self.start_logits(hidden_states, p_mask=p_mask)

        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, let's remove the dimension added by batch splitting. The squeeze is done
            # out-of-place on local names so the tensors owned by the caller keep their shape.
            start_positions, end_positions, cls_index, is_impossible = (
                label.squeeze(-1) if label is not None and label.dim() > 1 else label
                for label in (start_positions, end_positions, cls_index, is_impossible)
            )

            # during training, compute the end logits based on the ground truth of the start position
            end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)

            loss_fct = CrossEntropyLoss()
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

            if cls_index is not None and is_impossible is not None:
                # Predict answerability from the representation of CLS and START
                cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
                loss_fct_cls = nn.BCEWithLogitsLoss()
                # BCEWithLogitsLoss needs a floating-point target of the same dtype as the logits; the documented
                # integer `is_impossible` labels would otherwise make it raise.
                cls_loss = loss_fct_cls(cls_logits, is_impossible.to(cls_logits.dtype))

                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
                total_loss = total_loss + cls_loss * 0.5

            return SquadHeadOutput(loss=total_loss) if return_dict else (total_loss,)

        # during inference, compute the end logits based on beam search
        _, slen, hsz = hidden_states.size()
        # NOTE: despite the `log_probs` names these are plain softmax probabilities; kept as-is for compatibility.
        start_log_probs = nn.functional.softmax(start_logits, dim=-1)  # shape (bsz, slen)

        start_top_log_probs, start_top_index = torch.topk(
            start_log_probs, self.start_n_top, dim=-1
        )  # shape (bsz, start_n_top)
        start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
        start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
        start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)

        hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
            start_states
        )  # shape (bsz, slen, start_n_top, hsz)
        # Add a trailing axis so the mask broadcasts over the start_n_top candidates.
        p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
        end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
        end_log_probs = nn.functional.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)

        end_top_log_probs, end_top_index = torch.topk(
            end_log_probs, self.end_n_top, dim=1
        )  # shape (bsz, end_n_top, start_n_top)
        end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
        end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)

        # Probability-weighted average of the hidden states stands in for the (unknown) start token.
        start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
        cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)

        if not return_dict:
            return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)

        return SquadHeadOutput(
            start_top_log_probs=start_top_log_probs,
            start_top_index=start_top_index,
            end_top_log_probs=end_top_log_probs,
            end_top_index=end_top_index,
            cls_logits=cls_logits,
        )
class SequenceSummary(nn.Module):
    r"""
    Compute a single vector summary of a sequence hidden states.

    Args:
        config ([`PretrainedConfig`]):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:

                - `"last"` -- Take the last token hidden state (like XLNet)
                - `"first"` -- Take the first token hidden state (like Bert)
                - `"mean"` -- Take the mean of all tokens hidden states
                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - `"attn"` -- Not implemented now, use multi-head attention

              Defaults to `"last"` when the config has no such attribute.
            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
              (otherwise to `config.hidden_size`).
            - **summary_activation** (`Optional[str]`) -- Name of the activation applied after the projection,
              resolved through `get_activation`; `None` or an empty string adds no activation.
            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()

        self.summary_type = getattr(config, "summary_type", "last")
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        # Sub-modules are assigned in a fixed order (summary, activation, first_dropout, last_dropout) so that module
        # registration order is stable.
        self.summary = Identity()
        if getattr(config, "summary_use_proj", False):
            if getattr(config, "summary_proj_to_labels", False) and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        activation_string = getattr(config, "summary_activation", None)
        self.activation: Callable = get_activation(activation_string) if activation_string else Identity()

        self.first_dropout = Identity()
        if getattr(config, "summary_first_dropout", 0) > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        self.last_dropout = Identity()
        if getattr(config, "summary_last_dropout", 0) > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)

    def forward(
        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
    ) -> torch.FloatTensor:
        """
        Compute a single vector summary of a sequence hidden states.

        Args:
            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                Only used if `summary_type == "cls_index"`: position of the classification token in each sequence.
                If `None`, the last token of the sequence is used as the classification token.

        Returns:
            `torch.FloatTensor`: The summary of the sequence hidden states.

        Raises:
            NotImplementedError: If `summary_type` is `"attn"`.
            ValueError: If `summary_type` is not one of the supported values.
        """
        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = hidden_states.mean(dim=1)
        elif self.summary_type == "cls_index":
            if cls_index is None:
                # Default to the last position of the sequence axis for every sample.
                cls_index = torch.full_like(
                    hidden_states[..., :1, :],
                    hidden_states.shape[-2] - 1,
                    dtype=torch.long,
                )
            else:
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError
        else:
            # Without this branch `output` would be unbound and the failure would surface as an UnboundLocalError.
            raise ValueError(
                f"Unsupported summary_type {self.summary_type!r}; expected one of 'last', 'first', 'mean', "
                "'cls_index'."
            )

        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)

        return output
def unwrap_model(model: nn.Module) -> nn.Module:
    """
    Strip every wrapping container (as used in distributed training) from a model.

    Args:
        model (`torch.nn.Module`): The model to unwrap.

    Returns:
        `torch.nn.Module`: The innermost object reached by following `.module` attributes, i.e. the first one that
        has no `module` attribute. `model` itself is returned when it is not wrapped.
    """
    # Wrappers can be nested, so keep descending until no `module` attribute is left.
    innermost = model
    while hasattr(innermost, "module"):
        innermost = innermost.module
    return innermost