mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 21:30:07 +06:00

* Result of black 23.1 * Update target to Python 3.7 * Switch flake8 to ruff * Configure isort * Configure isort * Apply isort with line limit * Put the right black version * adapt black in check copies * Fix copies
230 lines
9.1 KiB
Python
230 lines
9.1 KiB
Python
# coding=utf-8
|
|
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
""" Testing suite for the PyTorch GPTNeoX model. """
|
|
|
|
|
|
import unittest
|
|
|
|
from transformers import GPTNeoXConfig, is_torch_available
|
|
from transformers.testing_utils import require_torch, torch_device
|
|
|
|
from ...test_configuration_common import ConfigTester
|
|
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
|
|
from transformers import GPTNeoXForCausalLM, GPTNeoXModel
|
|
|
|
|
|
class GPTNeoXModelTester:
    """Helper that builds a GPTNeoX config plus inputs and runs model checks.

    Every constructor argument is stored unchanged as an attribute of the same
    name. `parent` is the test case whose `assert*` methods are used to report
    results; the other arguments size the inputs created by
    `prepare_config_and_inputs` and the config created by `get_config`.
    """

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        scope=None,
    ):
        self.parent = parent

        # Input sizing and switches read by `prepare_config_and_inputs`.
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.num_labels = num_labels

        # Values forwarded to `GPTNeoXConfig` by `get_config`.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range

        # Not read by any method of this class; kept as attributes so code
        # outside this class that inspects the tester can still find them.
        self.is_training = is_training
        self.use_token_type_ids = use_token_type_ids
        self.type_sequence_label_size = type_sequence_label_size
        self.num_choices = num_choices
        self.scope = scope

    @staticmethod
    def _eval_model_on_device(model_class, config):
        """Instantiate `model_class` from `config`, move it to `torch_device` and put it in eval mode."""
        model = model_class(config=config)
        model.to(torch_device)
        model.eval()
        return model

    def prepare_config_and_inputs(self):
        """Create a config and a matching set of inputs.

        Returns:
            A tuple `(config, input_ids, input_mask, token_labels)`.
            `input_mask` is `None` when `use_input_mask` is false and
            `token_labels` is `None` when `use_labels` is false.
        """
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_labels = None
        if self.use_labels:
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)

        return self.get_config(), input_ids, input_mask, token_labels

    def get_config(self):
        """Return a `GPTNeoXConfig` built from this tester's attributes, with `is_decoder=False`."""
        return GPTNeoXConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
        )

    def prepare_config_and_inputs_for_decoder(self):
        """Same as `prepare_config_and_inputs`, but with `config.is_decoder` set to `True`."""
        config, input_ids, input_mask, token_labels = self.prepare_config_and_inputs()
        config.is_decoder = True
        return config, input_ids, input_mask, token_labels

    def create_and_check_model(self, config, input_ids, input_mask):
        """Run `GPTNeoXModel` with and without a mask and check the output shape."""
        model = self._eval_model_on_device(GPTNeoXModel, config)

        # The masked call only has to run without raising; its output is discarded.
        _ = model(input_ids, attention_mask=input_mask)
        result = model(input_ids)

        expected_shape = (self.batch_size, self.seq_length, self.hidden_size)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)

    def create_and_check_model_as_decoder(self, config, input_ids, input_mask):
        """Run `GPTNeoXModel` with `add_cross_attention` enabled and check the output shape.

        Note that `config` is modified in place.
        """
        config.add_cross_attention = True
        model = self._eval_model_on_device(GPTNeoXModel, config)

        result = model(input_ids, attention_mask=input_mask)

        expected_shape = (self.batch_size, self.seq_length, self.hidden_size)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)

    def create_and_check_for_causal_lm(self, config, input_ids, input_mask, token_labels):
        """Run `GPTNeoXForCausalLM` with labels and check the shape of the logits."""
        model = self._eval_model_on_device(GPTNeoXForCausalLM, config)

        result = model(input_ids, attention_mask=input_mask, labels=token_labels)

        expected_shape = (self.batch_size, self.seq_length, self.vocab_size)
        self.parent.assertEqual(result.logits.shape, expected_shape)

    def create_and_check_decoder_model_past_large_inputs(self, config, input_ids, input_mask):
        """Check that feeding new tokens with `past_key_values` matches a full forward pass.

        The model is first run on `input_ids` with `use_cache=True`. Three extra
        tokens are then processed twice: once appended to the full sequence
        without a cache, and once on their own together with the cached
        `past_key_values`. A single feature index of both results is compared.

        `input_mask` may be `None` (as returned by `prepare_config_and_inputs`
        when `use_input_mask` is false); no attention mask is passed in that case.
        Note that `config` is modified in place.
        """
        config.is_decoder = True
        model = self._eval_model_on_device(GPTNeoXForCausalLM, config)

        # First forward pass: run the prefix once and keep its cache.
        outputs = model(input_ids, attention_mask=input_mask, use_cache=True)
        past_key_values = outputs.past_key_values

        # Create several hypothetical next tokens to extend the prefix with.
        num_new_tokens = 3
        next_tokens = ids_tensor((self.batch_size, num_new_tokens), config.vocab_size)
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)

        # Extend the attention mask to cover the new tokens. Without an input
        # mask there is nothing to concatenate onto, so no mask is used at all.
        if input_mask is None:
            next_attention_mask = None
        else:
            next_mask = ids_tensor((self.batch_size, num_new_tokens), vocab_size=2)
            next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)

        # NOTE(review): both results are read from index 0 of `hidden_states`,
        # which is presumably the embedding output rather than the last layer;
        # confirm this is the tensor that was meant to be compared.
        output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask, output_hidden_states=True)
        output_from_no_past = output_from_no_past["hidden_states"][0]
        output_from_past = model(
            next_tokens,
            attention_mask=next_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
        )["hidden_states"][0]

        # Compare a single feature index, restricted to the new positions.
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -num_new_tokens:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        # `assertEqual` reports both lengths on failure, unlike `assertTrue(a == b)`.
        self.parent.assertEqual(output_from_past_slice.shape[1], next_tokens.shape[1])

        # Test that outputs are equal for the slice.
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, inputs_dict)` where `inputs_dict` has `input_ids` and `attention_mask` keys."""
        config, input_ids, input_mask, _ = self.prepare_config_and_inputs()
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict
|
|
|
|
|
|
@require_torch
class GPTNeoXModelTest(ModelTesterMixin, unittest.TestCase):
    """Unit tests for GPTNeoX, combining `ModelTesterMixin` with the checks defined on `GPTNeoXModelTester`."""

    # Model classes under test; both tuples are empty when torch is unavailable.
    all_model_classes = (GPTNeoXModel, GPTNeoXForCausalLM) if is_torch_available() else ()
    all_generative_model_classes = (GPTNeoXForCausalLM,) if is_torch_available() else ()

    # Switches presumably read by `ModelTesterMixin` to disable some of its tests - TODO confirm.
    test_pruning = False
    test_missing_keys = False
    test_model_parallel = False
    test_head_masking = False

    def setUp(self):
        """Create the model tester and the config tester used by the tests below."""
        self.model_tester = GPTNeoXModelTester(self)
        self.config_tester = ConfigTester(self, config_class=GPTNeoXConfig, hidden_size=37)

    def test_config(self):
        """Run the checks provided by `ConfigTester`."""
        self.config_tester.run_common_tests()

    def test_model(self):
        """Check the base model on freshly prepared inputs."""
        tester = self.model_tester
        config, input_ids, input_mask, _ = tester.prepare_config_and_inputs()
        tester.create_and_check_model(config, input_ids, input_mask)

    def test_model_as_decoder(self):
        """Check the base model with a decoder config and an explicit attention mask."""
        tester = self.model_tester
        config, input_ids, input_mask, _ = tester.prepare_config_and_inputs_for_decoder()
        tester.create_and_check_model_as_decoder(config, input_ids, input_mask)

    def test_model_as_decoder_with_default_input_mask(self):
        """Check the base model with a decoder config when no attention mask is supplied."""
        # This regression test was failing with PyTorch < 1.3
        tester = self.model_tester
        config, input_ids, _, _ = tester.prepare_config_and_inputs_for_decoder()

        # Deliberately drop the prepared mask so the model falls back to its default.
        tester.create_and_check_model_as_decoder(config, input_ids, None)

    def test_decoder_model_past_large_inputs(self):
        """Check that cached `past_key_values` give the same result as a full forward pass."""
        tester = self.model_tester
        config, input_ids, input_mask, _ = tester.prepare_config_and_inputs()
        tester.create_and_check_decoder_model_past_large_inputs(config, input_ids, input_mask)

    def test_model_for_causal_lm(self):
        """Check the causal LM head on freshly prepared inputs and labels."""
        tester = self.model_tester
        config, input_ids, input_mask, token_labels = tester.prepare_config_and_inputs()
        tester.create_and_check_for_causal_lm(config, input_ids, input_mask, token_labels)

    @unittest.skip(reason="Feed forward chunking is not implemented")
    def test_feed_forward_chunking(self):
        """Skipped: see the reason given in the decorator."""
        pass
|