Fix GPT-NeoX-20B past handling, attention computation (#17811)

* Fix GPT-NeoX-20B past handling, swap attention computation to hopefully avoid NaN, update docs

* 20B tests
Jason Phang, 2022-06-30 05:47:40 -07:00 (committed by GitHub)
parent 692e61e91a
commit 205bc4152c
3 changed files with 27 additions and 18 deletions

configuration_gpt_neox.py

@@ -38,32 +38,28 @@ class GPTNeoXConfig(PretrainedConfig):
Args:
- vocab_size (`int`, *optional*, defaults to 30522):
+ vocab_size (`int`, *optional*, defaults to 50432):
Vocabulary size of the GPTNeoX model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`GPTNeoXModel`].
- hidden_size (`int`, *optional*, defaults to 768):
+ hidden_size (`int`, *optional*, defaults to 6144):
Dimension of the encoder layers and the pooler layer.
- num_hidden_layers (`int`, *optional*, defaults to 12):
+ num_hidden_layers (`int`, *optional*, defaults to 44):
Number of hidden layers in the Transformer encoder.
- num_attention_heads (`int`, *optional*, defaults to 12):
+ num_attention_heads (`int`, *optional*, defaults to 64):
Number of attention heads for each attention layer in the Transformer encoder.
- intermediate_size (`int`, *optional*, defaults to 3072):
+ intermediate_size (`int`, *optional*, defaults to 24576):
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` are supported.
- hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
-     The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
- attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
-     The dropout ratio for the attention probabilities.
rotary_pct (`float`, *optional*, defaults to 0.25):
percentage of hidden dimensions to allocate to rotary embeddings
rotary_emb_base (`int`, *optional*, defaults to 10000)
base for computing rotary embeddings frequency
- max_position_embeddings (`int`, *optional*, defaults to 512):
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
- initializer_range (`float`, *optional*, defaults to 0.02):
+ initializer_range (`float`, *optional*, defaults to 1e-5):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
@@ -94,8 +90,6 @@ class GPTNeoXConfig(PretrainedConfig):
num_attention_heads=64,
intermediate_size=24576,
hidden_act="gelu",
- hidden_dropout_prob=0.1,
- attention_probs_dropout_prob=0.1,
rotary_pct=0.25,
rotary_emb_base=10000,
max_position_embeddings=2048,
@@ -115,8 +109,6 @@ class GPTNeoXConfig(PretrainedConfig):
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
- self.hidden_dropout_prob = hidden_dropout_prob
- self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.rotary_pct = rotary_pct
self.rotary_emb_base = rotary_emb_base
self.initializer_range = initializer_range
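
For reference, a minimal sketch of instantiating the configuration with the 20B-scale defaults documented above (values are taken from the updated signature, so passing them explicitly is redundant; the dropout arguments are gone because the model does not use them):

from transformers import GPTNeoXConfig

# Defaults now match the released GPT-NeoX-20B checkpoint.
config = GPTNeoXConfig(
    vocab_size=50432,
    hidden_size=6144,
    num_hidden_layers=44,
    num_attention_heads=64,
    intermediate_size=24576,
    hidden_act="gelu",
    rotary_pct=0.25,
    rotary_emb_base=10000,
    max_position_embeddings=2048,
)
print(config.hidden_size)  # 6144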

modeling_gpt_neox.py

@@ -195,7 +195,20 @@ class GPTNeoXAttention(nn.Module):
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
- attn_scores = torch.einsum("bik,bjk->bij", query, key) / self.norm_factor
+ attn_scores = torch.zeros(
+     batch_size * num_attention_heads,
+     query_length,
+     key_length,
+     dtype=query.dtype,
+     device=key.device,
+ )
+ attn_scores = torch.baddbmm(
+     attn_scores,
+     query,
+     key.transpose(1, 2),
+     beta=1.0,
+     alpha=(1.0 / self.norm_factor),
+ )
attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
mask_value = torch.finfo(attn_scores.dtype).min
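
The einsum-then-divide score computation is replaced by a fused torch.baddbmm into a zero-initialized buffer, with the 1/norm_factor scaling folded in via alpha. In float16 the fused GEMM applies the scaling before the result is written back at low precision, which is less prone to overflowing to inf/NaN than dividing after an unscaled float16 matmul. A standalone sketch of the two equivalent formulations (toy shapes and names, not the module's own):

import torch

batch_heads, query_length, key_length, head_size = 8, 5, 5, 16  # (batch * heads, q_len, k_len, head_dim)
norm_factor = head_size ** 0.5
query = torch.randn(batch_heads, query_length, head_size)
key = torch.randn(batch_heads, key_length, head_size)

# Old path: matmul first, divide by the norm factor afterwards.
scores_old = torch.einsum("bik,bjk->bij", query, key) / norm_factor

# New path: accumulate query @ key^T into a zero buffer, scaling via alpha.
buffer = torch.zeros(batch_heads, query_length, key_length, dtype=query.dtype, device=key.device)
scores_new = torch.baddbmm(buffer, query, key.transpose(1, 2), beta=1.0, alpha=1.0 / norm_factor)

print(torch.allclose(scores_old, scores_new, atol=1e-5))  # True up to float rounding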
@@ -637,7 +650,7 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past is used
- if past is not None:
+ if past and past[0] is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}

test_modeling_gpt_neox.py

@@ -226,6 +226,10 @@ class GPTNeoXModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
+ @unittest.skip(reason="Feed forward chunking is not implemented")
+ def test_feed_forward_chunking(self):
+     pass
@slow
def test_model_from_pretrained(self):
for model_name in GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
@@ -247,7 +251,7 @@ class GPTNeoXModelIntegrationTest(unittest.TestCase):
self.assertEqual(output.shape, expected_shape)
expected_slice = torch.tensor(
- [[[33.8045, 2.3958, 34.2816], [63.7805, 4.8332, 63.5882], [66.9116, 5.2198, 63.1185]]]
+ [[[33.5938, 2.3789, 34.0312], [63.4688, 4.8164, 63.3438], [66.8750, 5.2422, 63.0625]]]
)
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
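
The integration test compares a 3x3 slice of the 20B model's logits against hard-coded half-precision values within a tolerance; the updated expected slice reflects the baddbmm-based computation. A rough end-to-end sketch that exercises the fixed cached-generation path (assumes a machine with enough memory for the fp16 20B checkpoint; a smaller GPT-NeoX-style checkpoint runs the same code path):

import torch
from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast

tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", torch_dtype=torch.float16)
# Move to a GPU if available (e.g. model.to("cuda")); float16 ops are poorly supported on CPU.

inputs = tokenizer("GPTNeoX20B is", return_tensors="pt")
with torch.no_grad():
    # generate() hits the fixed past handling: after the first step only the
    # last token is fed in together with the cached key/values.
    output_ids = model.generate(**inputs, max_new_tokens=20, use_cache=True)
print(tokenizer.decode(output_ids[0]))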