Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Use head_dim if in config for RoPE (#32495)
* use head_dim if in config for RoPE
* typo
* simplify with getattr
commit 5fd7ca7bc9
parent c215523528
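The gist of the change, as a minimal standalone sketch (the TinyConfig class and its values are hypothetical, not taken from the repository): when a config defines head_dim explicitly, the RoPE dimension now follows it, with the old hidden_size // num_attention_heads derivation kept as the getattr fallback.

from dataclasses import dataclass


@dataclass
class TinyConfig:
    """Hypothetical stand-in for a PretrainedConfig; values chosen so the two derivations differ."""

    hidden_size: int = 2304
    num_attention_heads: int = 8
    head_dim: int = 256
    partial_rotary_factor: float = 1.0


config = TinyConfig()

# Before the patch: the head dimension was always re-derived from hidden_size.
old_dim = int((config.hidden_size // config.num_attention_heads) * config.partial_rotary_factor)

# After the patch: prefer an explicit head_dim, with the old derivation as the getattr fallback.
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
new_dim = int(head_dim * config.partial_rotary_factor)

print(old_dim, new_dim)  # 288 256 -- the two diverge whenever head_dim != hidden_size // num_attention_heads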
@@ -58,7 +58,8 @@ def _compute_default_rope_parameters(
     elif config is not None:
         base = config.rope_theta
         partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-        dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        dim = int(head_dim * partial_rotary_factor)
 
     attention_factor = 1.0  # Unused in this type of RoPE
 
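For orientation, here is how the dim computed above typically enters the default RoPE parameters; this is a sketch of the standard inverse-frequency formula, not the exact body of _compute_default_rope_parameters.

import torch


def default_inv_freq(base: float, dim: int) -> torch.Tensor:
    # Standard RoPE inverse frequencies: 1 / base**(2i / dim) for i in [0, dim // 2).
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))


# With an explicit head_dim, the number of frequency pairs follows head_dim rather than
# hidden_size // num_attention_heads (the values below are illustrative).
print(default_inv_freq(10000.0, 256).shape)  # torch.Size([128])
print(default_inv_freq(10000.0, 288).shape)  # torch.Size([144])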
@@ -143,7 +144,8 @@ def _compute_dynamic_ntk_parameters(
     elif config is not None:
         base = config.rope_theta
         partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-        dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        dim = int(head_dim * partial_rotary_factor)
         max_position_embeddings = config.max_position_embeddings
         factor = config.rope_scaling["factor"]
 
@@ -185,7 +187,8 @@ def _compute_yarn_parameters(
 
     base = config.rope_theta
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-    dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
     max_position_embeddings = config.max_position_embeddings
     factor = config.rope_scaling["factor"]
 
@@ -265,7 +268,8 @@ def _compute_longrope_parameters(
 
     base = config.rope_theta
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-    dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
     long_factor = config.rope_scaling["long_factor"]
     short_factor = config.rope_scaling["short_factor"]
     factor = config.rope_scaling.get("factor")
@@ -450,7 +454,8 @@ def _validate_longrope_parameters(config: PretrainedConfig):
     _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
 
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-    dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
 
     short_factor = rope_scaling.get("short_factor")
     if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
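A rough note on why the fix also matters in _validate_longrope_parameters: longrope configs carry per-frequency short_factor/long_factor lists whose expected length is tied to dim // 2, so deriving dim from the wrong head dimension can make well-formed factor lists look invalid. The arithmetic below is a hypothetical illustration, not the library's validation code.

# Hypothetical config values where head_dim differs from hidden_size // num_attention_heads.
hidden_size, num_attention_heads, head_dim = 2304, 8, 256

old_dim = hidden_size // num_attention_heads  # 288, the pre-patch derivation
new_dim = head_dim                            # 256, used once head_dim is read from the config

# The expected factor-list length is dim // 2; only the head_dim-based value matches a model
# whose rotary factors were sized for its real head dimension.
print(old_dim // 2, new_dim // 2)  # 144 128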