[GPTNeoX] Nit in config (#24349)
* add raise value error for attention size
* nits to fix test_config
* style
parent c2882403c4
commit e5c760d636
@@ -126,3 +126,7 @@ class GPTNeoXConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.tie_word_embeddings = tie_word_embeddings
         self.use_parallel_residual = use_parallel_residual
+        if self.hidden_size % self.num_attention_heads != 0:
+            raise ValueError(
+                "The hidden size is not divisible by the number of attention heads! Make sure to update them!"
+            )
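As a quick illustration of the new config-side check (a minimal sketch; the sizes below are made-up values, not taken from this commit), a GPTNeoXConfig whose hidden_size does not divide evenly across num_attention_heads now fails at construction time instead of surfacing later inside the model:

    from transformers import GPTNeoXConfig

    # illustrative values only: 100 is not divisible by 7
    try:
        GPTNeoXConfig(hidden_size=100, num_attention_heads=7)
    except ValueError as err:
        print(err)  # the new error message from the config check

    # a divisible pair still works as before
    config = GPTNeoXConfig(hidden_size=64, num_attention_heads=8)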
@@ -88,6 +88,10 @@ class GPTNeoXAttention(nn.Module):
         super().__init__()
         self.num_attention_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
+        if self.hidden_size % self.num_attention_heads != 0:
+            raise ValueError(
+                "The hidden size is not divisible by the number of attention heads! Make sure to update them"
+            )
         self.head_size = self.hidden_size // self.num_attention_heads
         self.rotary_ndims = int(self.head_size * config.rotary_pct)
         max_positions = config.max_position_embeddings
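The same guard is duplicated in GPTNeoXAttention itself, so a config that is mutated (or built some other way) after its own validation still fails fast before any layers are created. A hedged sketch, assuming GPTNeoXAttention can be constructed directly from a config via the modeling module import path:

    from transformers import GPTNeoXConfig
    from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXAttention

    config = GPTNeoXConfig(hidden_size=64, num_attention_heads=8)  # passes the config check
    config.num_attention_heads = 7                                 # mutated afterwards, bypassing it
    try:
        GPTNeoXAttention(config)
    except ValueError as err:
        print(err)  # raised by the module-level guard

With a valid head count, head_size is then simply hidden_size // num_attention_heads (64 // 8 = 8 in this sketch), as in the context lines above.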
@@ -253,7 +253,7 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin

     def setUp(self):
         self.model_tester = GPTNeoXModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=GPTNeoXConfig, hidden_size=37)
+        self.config_tester = ConfigTester(self, config_class=GPTNeoXConfig, hidden_size=64, num_attention_heads=8)

     def test_config(self):
         self.config_tester.run_common_tests()
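For context on the test nit: with only hidden_size overridden, the tester would hit the new check, since 37 is not divisible by GPTNeoXConfig's documented default of 64 attention heads (or by any head count greater than 1), so the tester now pins a compatible pair. A rough illustrative check, assuming the default head count of 64:

    from transformers import GPTNeoXConfig

    # the old tester value would now be rejected under the default head count
    try:
        GPTNeoXConfig(hidden_size=37)
    except ValueError:
        pass  # 37 % 64 != 0

    # the new tester values satisfy the check
    GPTNeoXConfig(hidden_size=64, num_attention_heads=8)  # 64 % 8 == 0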