Remove jnp.DeviceArray since it is deprecated. (#24875)

* Remove jnp.DeviceArray since it is deprecated.

* Replace all instances of jnp.DeviceArray with jax.Array

* Update src/transformers/models/bert/modeling_flax_bert.py

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Author: mariecwhite
Date: 2023-08-05 03:36:57 +10:00 (committed by GitHub)
Commit: a6e6b1c622 (parent: fdd81aea12)
24 changed files with 38 additions and 38 deletions
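
For context: jnp.DeviceArray was the concrete on-device array class in older JAX releases. Since the array-type unification in JAX 0.4, jax.Array is the single public type to use in annotations and isinstance checks, and the jnp.DeviceArray alias was subsequently removed. The sketch below illustrates the annotation pattern this commit standardizes on; apply_mask is an illustrative helper, not a function from the diff:

    from typing import Optional

    import jax
    import jax.numpy as jnp


    # The pattern changed across these files:
    #     attention_mask: Optional[jnp.DeviceArray] = None   # old, removed alias
    #     attention_mask: Optional[jax.Array] = None         # new, unified type
    def apply_mask(logits: jax.Array, attention_mask: Optional[jax.Array] = None) -> jax.Array:
        """Illustrative helper: mask out padded positions before a softmax."""
        if attention_mask is not None:
            logits = jnp.where(attention_mask, logits, jnp.finfo(logits.dtype).min)
        return logits


    logits = jnp.zeros((2, 4))
    mask = jnp.array([[1, 1, 0, 0], [1, 1, 1, 0]], dtype=bool)
    out = apply_mask(logits, mask)
    # Every array produced by jax.numpy is an instance of jax.Array:
    assert isinstance(out, jax.Array)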

src/transformers/models/bart/modeling_flax_bart.py

@@ -1467,8 +1467,8 @@ class FlaxBartForConditionalGeneration(FlaxBartPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
@@ -1960,7 +1960,7 @@ class FlaxBartForCausalLMModule(nn.Module):
 class FlaxBartForCausalLM(FlaxBartDecoderPreTrainedModel):
     module_class = FlaxBartForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
 

src/transformers/models/bert/modeling_flax_bert.py

@@ -1677,7 +1677,7 @@ class FlaxBertForCausalLMModule(nn.Module):
 class FlaxBertForCausalLM(FlaxBertPreTrainedModel):
     module_class = FlaxBertForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
 

src/transformers/models/big_bird/modeling_flax_big_bird.py

@@ -2599,7 +2599,7 @@ class FlaxBigBirdForCausalLMModule(nn.Module):
 class FlaxBigBirdForCausalLM(FlaxBigBirdPreTrainedModel):
     module_class = FlaxBigBirdForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
 

src/transformers/models/blenderbot/modeling_flax_blenderbot.py

@@ -1443,8 +1443,8 @@ class FlaxBlenderbotForConditionalGeneration(FlaxBlenderbotPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):

src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py

@@ -1441,8 +1441,8 @@ class FlaxBlenderbotSmallForConditionalGeneration(FlaxBlenderbotSmallPreTrainedM
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):

src/transformers/models/electra/modeling_flax_electra.py

@@ -1565,7 +1565,7 @@ class FlaxElectraForCausalLMModule(nn.Module):
 class FlaxElectraForCausalLM(FlaxElectraPreTrainedModel):
     module_class = FlaxElectraForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
 

src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py

@@ -722,8 +722,8 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):

src/transformers/models/gpt2/modeling_flax_gpt2.py

@@ -742,7 +742,7 @@ class FlaxGPT2LMHeadModule(nn.Module):
 class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel):
     module_class = FlaxGPT2LMHeadModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
 

src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py

@@ -654,7 +654,7 @@ class FlaxGPTNeoForCausalLMModule(nn.Module):
 class FlaxGPTNeoForCausalLM(FlaxGPTNeoPreTrainedModel):
     module_class = FlaxGPTNeoForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
 

src/transformers/models/gptj/modeling_flax_gptj.py

@@ -683,7 +683,7 @@ class FlaxGPTJForCausalLMModule(nn.Module):
 class FlaxGPTJForCausalLM(FlaxGPTJPreTrainedModel):
     module_class = FlaxGPTJForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
 

src/transformers/models/longt5/modeling_flax_longt5.py

@@ -2388,8 +2388,8 @@ class FlaxLongT5ForConditionalGeneration(FlaxLongT5PreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):

src/transformers/models/marian/modeling_flax_marian.py

@@ -1436,8 +1436,8 @@ class FlaxMarianMTModel(FlaxMarianPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):

src/transformers/models/mbart/modeling_flax_mbart.py

@@ -1502,8 +1502,8 @@ class FlaxMBartForConditionalGeneration(FlaxMBartPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):

src/transformers/models/opt/modeling_flax_opt.py

@@ -763,7 +763,7 @@ class FlaxOPTForCausalLMModule(nn.Module):
 class FlaxOPTForCausalLM(FlaxOPTPreTrainedModel):
     module_class = FlaxOPTForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
 

src/transformers/models/pegasus/modeling_flax_pegasus.py

@@ -1450,8 +1450,8 @@ class FlaxPegasusForConditionalGeneration(FlaxPegasusPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):

src/transformers/models/roberta/modeling_flax_roberta.py

@@ -1452,7 +1452,7 @@ class FlaxRobertaForCausalLMModule(nn.Module):
 class FlaxRobertaForCausalLM(FlaxRobertaPreTrainedModel):
     module_class = FlaxRobertaForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
 

src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py

@@ -1478,7 +1478,7 @@ class FlaxRobertaPreLayerNormForCausalLMModule(nn.Module):
 class FlaxRobertaPreLayerNormForCausalLM(FlaxRobertaPreLayerNormPreTrainedModel):
     module_class = FlaxRobertaPreLayerNormForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
 

src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py

@@ -745,8 +745,8 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):

src/transformers/models/t5/modeling_flax_t5.py

@@ -1740,8 +1740,8 @@ class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):

src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py

@@ -688,7 +688,7 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):

src/transformers/models/whisper/modeling_flax_whisper.py

@@ -1448,8 +1448,8 @@ class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):

src/transformers/models/xglm/modeling_flax_xglm.py

@@ -766,7 +766,7 @@ class FlaxXGLMForCausalLMModule(nn.Module):
 class FlaxXGLMForCausalLM(FlaxXGLMPreTrainedModel):
     module_class = FlaxXGLMForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
 

src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py

@@ -1469,7 +1469,7 @@ class FlaxXLMRobertaForCausalLMModule(nn.Module):
 class FlaxXLMRobertaForCausalLM(FlaxXLMRobertaPreTrainedModel):
     module_class = FlaxXLMRobertaForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
 

templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py

@@ -1469,7 +1469,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
 class Flax{{cookiecutter.camelcase_modelname}}ForCausalLM(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
     module_class = Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule
 
-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
 
@@ -2969,8 +2969,8 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(Flax{{coo
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs
     ):
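
Since the commit message says all instances were replaced, a reviewer could sanity-check the tree with a short script along these lines (a sketch; it assumes it is run from the repository root and that the src/ and templates/ directories above exist):

    from pathlib import Path

    # Collect any Python file under src/ or templates/ that still mentions the removed alias.
    stale = [
        str(p)
        for root in ("src", "templates")
        for p in Path(root).rglob("*.py")
        if "jnp.DeviceArray" in p.read_text(encoding="utf-8")
    ]
    assert not stale, f"files still referencing jnp.DeviceArray: {stale}"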