Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 02:31:11 +06:00
Remove jnp.DeviceArray since it is deprecated. (#24875)
* Remove jnp.DeviceArray since it is deprecated.

* Replace all instances of jnp.DeviceArray with jax.Array

* Update src/transformers/models/bert/modeling_flax_bert.py

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
commit a6e6b1c622 (parent fdd81aea12)
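The change itself is mechanical: `jnp.DeviceArray` was deprecated and then removed from JAX in favor of the unified `jax.Array` type, so every annotation in the diff below is swapped one-for-one. A minimal sketch of the pattern, with an illustrative class name and an elided body (not taken from the diff):

```python
from typing import Optional

import jax


class FlaxCausalLMSketch:  # hypothetical stand-in for the Flax model classes below
    def prepare_inputs_for_generation(
        self,
        input_ids,
        max_length,
        # was: attention_mask: Optional[jnp.DeviceArray] = None
        attention_mask: Optional[jax.Array] = None,
    ):
        # initializing the cache
        batch_size, seq_length = input_ids.shape
        ...
```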
@@ -1467,8 +1467,8 @@ class FlaxBartForConditionalGeneration(FlaxBartPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
@@ -1960,7 +1960,7 @@ class FlaxBartForCausalLMModule(nn.Module):
 class FlaxBartForCausalLM(FlaxBartDecoderPreTrainedModel):
     module_class = FlaxBartForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
@@ -1677,7 +1677,7 @@ class FlaxBertForCausalLMModule(nn.Module):
 class FlaxBertForCausalLM(FlaxBertPreTrainedModel):
     module_class = FlaxBertForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
@@ -2599,7 +2599,7 @@ class FlaxBigBirdForCausalLMModule(nn.Module):
 class FlaxBigBirdForCausalLM(FlaxBigBirdPreTrainedModel):
     module_class = FlaxBigBirdForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
@@ -1443,8 +1443,8 @@ class FlaxBlenderbotForConditionalGeneration(FlaxBlenderbotPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
@@ -1441,8 +1441,8 @@ class FlaxBlenderbotSmallForConditionalGeneration(FlaxBlenderbotSmallPreTrainedM
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
@@ -1565,7 +1565,7 @@ class FlaxElectraForCausalLMModule(nn.Module):
 class FlaxElectraForCausalLM(FlaxElectraPreTrainedModel):
     module_class = FlaxElectraForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
@@ -722,8 +722,8 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
@@ -742,7 +742,7 @@ class FlaxGPT2LMHeadModule(nn.Module):
 class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel):
     module_class = FlaxGPT2LMHeadModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
@@ -654,7 +654,7 @@ class FlaxGPTNeoForCausalLMModule(nn.Module):
 class FlaxGPTNeoForCausalLM(FlaxGPTNeoPreTrainedModel):
     module_class = FlaxGPTNeoForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
@@ -683,7 +683,7 @@ class FlaxGPTJForCausalLMModule(nn.Module):
 class FlaxGPTJForCausalLM(FlaxGPTJPreTrainedModel):
     module_class = FlaxGPTJForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
@@ -2388,8 +2388,8 @@ class FlaxLongT5ForConditionalGeneration(FlaxLongT5PreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
@@ -1436,8 +1436,8 @@ class FlaxMarianMTModel(FlaxMarianPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
@@ -1502,8 +1502,8 @@ class FlaxMBartForConditionalGeneration(FlaxMBartPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
@@ -763,7 +763,7 @@ class FlaxOPTForCausalLMModule(nn.Module):
 class FlaxOPTForCausalLM(FlaxOPTPreTrainedModel):
     module_class = FlaxOPTForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
@@ -1450,8 +1450,8 @@ class FlaxPegasusForConditionalGeneration(FlaxPegasusPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
@@ -1452,7 +1452,7 @@ class FlaxRobertaForCausalLMModule(nn.Module):
 class FlaxRobertaForCausalLM(FlaxRobertaPreTrainedModel):
     module_class = FlaxRobertaForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
@@ -1478,7 +1478,7 @@ class FlaxRobertaPreLayerNormForCausalLMModule(nn.Module):
 class FlaxRobertaPreLayerNormForCausalLM(FlaxRobertaPreLayerNormPreTrainedModel):
     module_class = FlaxRobertaPreLayerNormForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
@@ -745,8 +745,8 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
@@ -1740,8 +1740,8 @@ class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
@@ -688,7 +688,7 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
@@ -1448,8 +1448,8 @@ class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
@@ -766,7 +766,7 @@ class FlaxXGLMForCausalLMModule(nn.Module):
 class FlaxXGLMForCausalLM(FlaxXGLMPreTrainedModel):
     module_class = FlaxXGLMForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
@@ -1469,7 +1469,7 @@ class FlaxXLMRobertaForCausalLMModule(nn.Module):
 class FlaxXLMRobertaForCausalLM(FlaxXLMRobertaPreTrainedModel):
     module_class = FlaxXLMRobertaForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
@@ -1469,7 +1469,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
 class Flax{{cookiecutter.camelcase_modelname}}ForCausalLM(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
     module_class = Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape
@@ -2969,8 +2969,8 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(Flax{{coo
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs
     ):
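The one-for-one swap is safe because, on JAX 0.4 and later, every array a JAX program produces is an instance of `jax.Array` regardless of backend, so the annotations above stay accurate and any `isinstance` checks keep passing. A quick sanity check, assuming a JAX 0.4+ install:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((2, 3))
# jax.Array is the single unified array type in JAX 0.4+;
# jnp.DeviceArray no longer exists in the JAX versions this PR targets.
assert isinstance(x, jax.Array)
```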