[GPTNeoX] Nit in config (#24349)

* raise a ValueError when the hidden size is not divisible by the number of attention heads

* nits to fix test_config

* style
Arthur 2023-06-20 19:19:19 +02:00 committed by GitHub
parent c2882403c4
commit e5c760d636
3 changed files with 9 additions and 1 deletion

src/transformers/models/gpt_neox/configuration_gpt_neox.py

@@ -126,3 +126,7 @@ class GPTNeoXConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.tie_word_embeddings = tie_word_embeddings
         self.use_parallel_residual = use_parallel_residual
+        if self.hidden_size % self.num_attention_heads != 0:
+            raise ValueError(
+                "The hidden size is not divisible by the number of attention heads! Make sure to update them!"
+            )
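For readers skimming the diff, a minimal sketch of what the new check means in practice (the values 37, 64, and 8 come from this PR's test change; the snippet itself is illustrative): a config whose hidden size does not split evenly across the attention heads now fails as soon as it is constructed.

from transformers import GPTNeoXConfig

# A mismatched pair is now rejected at config-construction time
# instead of surfacing later as a shape problem inside the attention module.
try:
    GPTNeoXConfig(hidden_size=37, num_attention_heads=8)  # 37 % 8 != 0
except ValueError as err:
    print(err)  # "The hidden size is not divisible by the number of attention heads! ..."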

src/transformers/models/gpt_neox/modeling_gpt_neox.py

@@ -88,6 +88,10 @@ class GPTNeoXAttention(nn.Module):
         super().__init__()
         self.num_attention_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
+        if self.hidden_size % self.num_attention_heads != 0:
+            raise ValueError(
+                "The hidden size is not divisible by the number of attention heads! Make sure to update them"
+            )
         self.head_size = self.hidden_size // self.num_attention_heads
         self.rotary_ndims = int(self.head_size * config.rotary_pct)
         max_positions = config.max_position_embeddings
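A short note on why the same guard appears in GPTNeoXAttention: head_size is computed with integer division, so without the check a non-divisible pair silently truncates instead of erroring. A plain-Python sketch of the arithmetic:

# head_size uses floor division, so a non-divisible pair loses dimensions silently:
hidden_size, num_attention_heads = 37, 8
head_size = hidden_size // num_attention_heads   # 4
print(head_size * num_attention_heads)           # 32, not 37 -> projections no longer add up

# A divisible pair reconstructs the hidden size exactly:
hidden_size, num_attention_heads = 64, 8
assert (hidden_size // num_attention_heads) * num_attention_heads == hidden_size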

tests/models/gpt_neox/test_modeling_gpt_neox.py

@@ -253,7 +253,7 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     def setUp(self):
         self.model_tester = GPTNeoXModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=GPTNeoXConfig, hidden_size=37)
+        self.config_tester = ConfigTester(self, config_class=GPTNeoXConfig, hidden_size=64, num_attention_heads=8)

     def test_config(self):
         self.config_tester.run_common_tests()
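Why the tester values moved from hidden_size=37 to the 64/8 pair (my reading of the diff, assuming GPTNeoXConfig's default of 64 attention heads): overriding only hidden_size leaves the default head count in place, and 37 does not divide by 64, so the common config tests would now hit the new ValueError; 64 across 8 heads divides cleanly into a head size of 8.

from transformers import GPTNeoXConfig

# Old override: 37 hidden dims against the default head count is now rejected.
try:
    GPTNeoXConfig(hidden_size=37)
except ValueError:
    print("hidden_size=37 cannot be split across the default number of heads")

# New override pair used by ConfigTester: 64 / 8 = 8 dims per head, accepted.
GPTNeoXConfig(hidden_size=64, num_attention_heads=8)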