transformers/src/transformers/models/__init__.py
Daniel Stancl a72f1c9f5b
Add LongT5 model (#16792)
* Initial commit

* Make some fixes

* Make PT model full forward pass

* Drop TF & Flax implementation, fix copies etc

* Add Flax model and update some corresponding stuff

* Drop some TF things

* Update config and flax local attn

* Add encoder_attention_type to config

* .

* Update docs

* Do some cleanup

* Fix some issues -> make style; add some docs

* Fix position_bias + mask addition + Update tests

* Fix repo consistency

* Fix model consistency by removing flax operation over attn_mask

* [WIP] Add PT TGlobal LongT5

* .

* [WIP] Add flax tglobal model

* [WIP] Update flax model to use the right attention type in the encoder

* Fix flax tglobal model forward pass

* Make use of global_relative_attention_bias

* Add test suites for TGlobal model

* Fix minor bugs, clean code

* Fix PT-Flax equivalence, though not yet convinced of its correctness

* Fix LocalAttn implementation to match the original impl. + update READMEs

* Few updates

* Update: [Flax] improve large model init and loading #16148

* Add ckpt conversion script according to #16853 + handle torch device placement

* Minor updates to conversion script.

* Typo: AutoModelForSeq2SeqLM -> FlaxAutoModelForSeq2SeqLM

* gpu support + dtype fix

* Apply some suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Remove (de)parallelize stuff
* Edit shape comments
* Update README.md
* make fix-copies

* Remove caching logic for local & tglobal attention

* Apply another batch of suggestions from code review

* Add missing checkpoints
* Format converting scripts
* Drop (de)parallelize links from longT5 mdx

* Fix converting script + revert config file change

* Revert "Remove caching logic for local & tglobal attention"

This reverts commit 2a619828f6ddc3e65bd9bb1725a12b77fa883a46.

* Stash caching logic in Flax model

* Always use side relative bias

* Drop caching logic in PT model

* Return side bias as it was

* Drop all remaining model parallel logic

* Remove clamp statements

* Move test files to the proper place

* Update docs with new version of hf-doc-builder

* Fix test imports

* Make some minor improvements

* Add missing checkpoints to docs
* Make TGlobal model compatible with torch.onnx.export
* Replace some np.ndarray with jnp.ndarray

* Fix TGlobal for ONNX conversion + update docs

* fix _make_global_fixed_block_ids and masked neg value

* update flax model

* style and quality

* fix imports

* remove load_tf_weights_in_longt5 from init and fix copies

* add slow test for TGlobal model

* typo fix

* Drop obsolete is_parallelizable and one warning

* Update __init__ files to fix repo-consistency

* fix pipeline test

* Fix some device placements

* [wip]: Update tests -- need to generate summaries to update expected_summary

* Fix quality

* Update LongT5 model card

* Update (slow) summarization tests

* make style

* rename checkpoints

* finish

* fix flax tests

Co-authored-by: phungvanduy <pvduy23@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: patil-suraj <surajp815@gmail.com>
2022-06-13 22:36:58 +02:00
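
After this commit, LongT5 can be loaded through the standard transformers API. A minimal usage sketch (the google/long-t5-tglobal-base checkpoint name is assumed here; any LongT5 checkpoint from the Hub should behave the same):

from transformers import AutoTokenizer, LongT5ForConditionalGeneration

ckpt = "google/long-t5-tglobal-base"  # assumed checkpoint name
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = LongT5ForConditionalGeneration.from_pretrained(ckpt)

# LongT5 is built for long inputs; cap the demo input at a few thousand tokens.
text = "summarize: " + "A very long input document. " * 200
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=4096)
summary_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))

The Flax variant added in the same commit loads analogously via FlaxLongT5ForConditionalGeneration.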


# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
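# Import every model subpackage so that each one is reachable as
# transformers.models.<model_name>.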
from . import (
    albert,
    auto,
    bart,
    barthez,
    bartpho,
    beit,
    bert,
    bert_generation,
    bert_japanese,
    bertweet,
    big_bird,
    bigbird_pegasus,
    blenderbot,
    blenderbot_small,
    bloom,
    bort,
    byt5,
    camembert,
    canine,
    clip,
    convbert,
    convnext,
    cpm,
    ctrl,
    cvt,
    data2vec,
    deberta,
    deberta_v2,
    decision_transformer,
    deit,
    detr,
    dialogpt,
    distilbert,
    dit,
    dpr,
    dpt,
    electra,
    encoder_decoder,
    flaubert,
    flava,
    fnet,
    fsmt,
    funnel,
    glpn,
    gpt2,
    gpt_neo,
    gpt_neox,
    gptj,
    herbert,
    hubert,
    ibert,
    imagegpt,
    layoutlm,
    layoutlmv2,
    layoutlmv3,
    layoutxlm,
    led,
    levit,
    longformer,
    longt5,
    luke,
    lxmert,
    m2m_100,
    marian,
    maskformer,
    mbart,
    mbart50,
    mctct,
    megatron_bert,
    megatron_gpt2,
    mluke,
    mmbt,
    mobilebert,
    mpnet,
    mt5,
    nystromformer,
    openai,
    opt,
    pegasus,
    perceiver,
    phobert,
    plbart,
    poolformer,
    prophetnet,
    qdqbert,
    rag,
    realm,
    reformer,
    regnet,
    rembert,
    resnet,
    retribert,
    roberta,
    roformer,
    segformer,
    sew,
    sew_d,
    speech_encoder_decoder,
    speech_to_text,
    speech_to_text_2,
    splinter,
    squeezebert,
    swin,
    t5,
    tapas,
    tapex,
    trajectory_transformer,
    transfo_xl,
    trocr,
    unispeech,
    unispeech_sat,
    van,
    vilt,
    vision_encoder_decoder,
    vision_text_dual_encoder,
    visual_bert,
    vit,
    vit_mae,
    wav2vec2,
    wav2vec2_conformer,
    wav2vec2_phoneme,
    wav2vec2_with_lm,
    wavlm,
    xglm,
    xlm,
    xlm_prophetnet,
    xlm_roberta,
    xlm_roberta_xl,
    xlnet,
    yolos,
    yoso,
)
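
Because every subpackage is imported above, downstream code can reach any architecture through the transformers.models namespace without an explicit submodule import; a small sketch using the entry added by this commit:

from transformers import models

# Each name in the import list becomes an attribute of transformers.models.
print(models.longt5.__name__)  # transformers.models.longt5
print(models.t5.__name__)      # transformers.models.t5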