Support for Falcon2-11B (#30771)
* remove unrelated changes
* remove unrelated changes on phi and stable LM
* add: Test for Falcon 10B
* fix: formatting
* fix: loading the falcon 10B in 8 bit precision using bitsandbytes.
* fix: device placement
* fix: broken tests.
* fix: backwards compatibility for falcon 1B architecture.
* chore: updated test.
* chore: test_modeling_falcon.py to use the 11B model.
* chore: minor edit
* chore: formatting.

---------

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
Co-authored-by: ArthurZucker <arthur.zucker@gmail.com>
parent f63d822242 · commit e52741f601
@@ -42,6 +42,9 @@ class FalconConfig(PretrainedConfig):
             Number of hidden layers in the Transformer decoder.
         num_attention_heads (`int`, *optional*, defaults to 71):
             Number of attention heads for each attention layer in the Transformer encoder.
+        num_ln_in_parallel_attn (`int`, *optional*):
+            Set to 2 if separate layer norms are to be used for the MLP and the attention output when using parallel
+            attention, otherwise, 1.
         layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
             The epsilon used by the layer normalization layers.
         initializer_range (`float`, *optional*, defaults to 0.02):
@@ -115,6 +118,7 @@ class FalconConfig(PretrainedConfig):
         hidden_size=4544,
         num_hidden_layers=32,
         num_attention_heads=71,
+        num_ln_in_parallel_attn=None,
         layer_norm_epsilon=1e-5,
         initializer_range=0.02,
         use_cache=True,
@@ -154,6 +158,7 @@ class FalconConfig(PretrainedConfig):
         self.multi_query = multi_query  # Ignored when new_decoder_architecture is True
         self.parallel_attn = parallel_attn
         self.bias = bias
+        self.num_ln_in_parallel_attn = num_ln_in_parallel_attn
         self.max_position_embeddings = max_position_embeddings
         self.rope_theta = rope_theta
         self.rope_scaling = rope_scaling
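The configuration side of the change is small: one new optional argument, stored on the config. A minimal sketch of how it might be set follows; the two config objects below are illustrative defaults-only examples, not the released checkpoint configs.

from transformers import FalconConfig

# New-decoder-architecture model with the previous behaviour:
# leaving num_ln_in_parallel_attn unset keeps two parallel layer norms.
config_two_ln = FalconConfig(new_decoder_architecture=True)

# Falcon2-11B-style model: new decoder architecture but a single layer norm
# shared by the attention and MLP branches.
config_one_ln = FalconConfig(new_decoder_architecture=True, num_ln_in_parallel_attn=1)

print(config_two_ln.num_ln_in_parallel_attn)  # None until a decoder layer resolves it to 2
print(config_one_ln.num_ln_in_parallel_attn)  # 1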
@@ -767,15 +767,20 @@ class FalconDecoderLayer(nn.Module):
         self.hidden_dropout = config.hidden_dropout
         self.config = config

-        if config.new_decoder_architecture:
-            # The layer norm before self-attention
-            self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-            # The layer norm before the MLP
-            self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        else:
-            self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-            if not config.parallel_attn:
-                self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        if config.num_ln_in_parallel_attn is None and config.new_decoder_architecture:
+            config.num_ln_in_parallel_attn = 2
+
+        if not config.parallel_attn:
+            self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+            self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        else:
+            if config.num_ln_in_parallel_attn == 2:
+                # The layer norm before self-attention
+                self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+                # The layer norm before the MLP
+                self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+            else:
+                self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

     def forward(
         self,
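The net effect of the constructor change is easier to read outside the diff. The helper below is an illustrative sketch, not a library API; it mirrors the selection logic above and lists which layer-norm attributes a FalconDecoderLayer ends up with.

def expected_layernorms(new_decoder_architecture: bool, parallel_attn: bool, num_ln_in_parallel_attn=None):
    """Illustrative sketch of the layer-norm selection above (not part of transformers)."""
    # Backwards compatibility: the new decoder architecture defaults to two parallel layer norms.
    if num_ln_in_parallel_attn is None and new_decoder_architecture:
        num_ln_in_parallel_attn = 2

    if not parallel_attn:
        return ["input_layernorm", "post_attention_layernorm"]
    if num_ln_in_parallel_attn == 2:
        return ["ln_attn", "ln_mlp"]
    return ["input_layernorm"]


# New architecture, default -> ['ln_attn', 'ln_mlp']
print(expected_layernorms(new_decoder_architecture=True, parallel_attn=True))
# New architecture, single shared norm (Falcon2-11B style) -> ['input_layernorm']
print(expected_layernorms(new_decoder_architecture=True, parallel_attn=True, num_ln_in_parallel_attn=1))
# Old architecture with parallel attention (Falcon-7B style) -> ['input_layernorm']
print(expected_layernorms(new_decoder_architecture=False, parallel_attn=True))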
@@ -796,7 +801,7 @@ class FalconDecoderLayer(nn.Module):

         residual = hidden_states

-        if self.config.new_decoder_architecture:
+        if self.config.new_decoder_architecture and self.config.num_ln_in_parallel_attn == 2:
             attention_layernorm_out = self.ln_attn(hidden_states)
             mlp_layernorm_out = self.ln_mlp(hidden_states)
         else:
@@ -826,6 +831,13 @@ class FalconDecoderLayer(nn.Module):
                 )
                 mlp_layernorm_out = self.post_attention_layernorm(residual)

+        if (
+            self.config.new_decoder_architecture
+            and self.config.parallel_attn
+            and self.config.num_ln_in_parallel_attn == 1
+        ):
+            mlp_layernorm_out = attention_layernorm_out
+
         outputs = attn_outputs[1:]

         # MLP.
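Together, the two forward-pass hunks route the layer-norm outputs as follows in the parallel-attention case. The snippet is a standalone sketch with arbitrary shapes, using plain nn.LayerNorm modules in place of the real decoder layer.

import torch
from torch import nn

hidden_size, config_num_ln = 16, 1  # set to 2 to mimic the two-layer-norm path
ln_attn, ln_mlp, input_layernorm = nn.LayerNorm(hidden_size), nn.LayerNorm(hidden_size), nn.LayerNorm(hidden_size)
hidden_states = torch.randn(1, 4, hidden_size)

if config_num_ln == 2:
    # Two separate norms: one before attention, one before the MLP.
    attention_layernorm_out = ln_attn(hidden_states)
    mlp_layernorm_out = ln_mlp(hidden_states)
else:
    # Single norm (Falcon2-11B style): the MLP reuses the attention branch's
    # normalized input, which is what the added `if` block above does.
    attention_layernorm_out = input_layernorm(hidden_states)
    mlp_layernorm_out = attention_layernorm_out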
@@ -602,6 +602,25 @@ class FalconLanguageGenerationTest(unittest.TestCase):

         self.assertEqual(output_str, EXPECTED_OUTPUT)

+    @slow
+    @require_bitsandbytes
+    def test_lm_generate_falcon_11b(self):
+        tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-11B", padding_side="left")
+        model = FalconForCausalLM.from_pretrained(
+            "tiiuae/falcon-11B", device_map={"": torch_device}, load_in_8bit=True
+        )
+        model.eval()
+        inputs = tokenizer(
+            "Two roads diverged in a yellow wood,", return_tensors="pt", return_token_type_ids=False
+        ).to(torch_device)
+
+        EXPECTED_OUTPUT = "Two roads diverged in a yellow wood,\nAnd sorry I could not travel both\n"
+
+        output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=9)
+        output_str = tokenizer.batch_decode(output_ids)[0]
+
+        self.assertEqual(output_str, EXPECTED_OUTPUT)
+
     @slow
     def test_lm_generation_big_models(self):
         # The big models are way too big for the CI, so we use tiny random models that resemble their
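The same setup works outside the test suite. Below is a sketch of loading the checkpoint for inference, assuming a CUDA device with enough memory for the 11B weights in 8-bit and bitsandbytes installed; the explicit quantization_config is equivalent to the load_in_8bit=True shortcut used in the test.

import torch
from transformers import AutoTokenizer, BitsAndBytesConfig, FalconForCausalLM

tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-11B", padding_side="left")
model = FalconForCausalLM.from_pretrained(
    "tiiuae/falcon-11B",
    device_map={"": 0},  # pin the whole model to GPU 0
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
model.eval()

inputs = tokenizer(
    "Two roads diverged in a yellow wood,", return_tensors="pt", return_token_type_ids=False
).to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=9)
# Greedy decoding should reproduce EXPECTED_OUTPUT from the test above.
print(tokenizer.batch_decode(output_ids)[0])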
@@ -647,7 +666,7 @@ class FalconLanguageGenerationTest(unittest.TestCase):
         tokenizer.pad_token = tokenizer.eos_token
         model = AutoModelForCausalLM.from_pretrained(
             "tiiuae/falcon-7b",
-            device_map="auto",
+            device_map={"": torch_device},
             load_in_4bit=True,
         )
