Update quality tooling for formatting (#21480)

* Result of black 23.1

* Update target to Python 3.7

* Switch flake8 to ruff

* Configure isort

* Configure isort

* Apply isort with line limit

* Pin the correct black version

* Adapt black formatting in check_copies

* Fix copies
Sylvain Gugger 2023-02-06 18:10:56 -05:00 committed by GitHub
parent b7bb2b59f7
commit 6f79d26442
1211 changed files with 1532 additions and 2687 deletions


@ -134,11 +134,10 @@ jobs:
command: pip freeze | tee installed.txt
- store_artifacts:
path: ~/transformers/installed.txt
- run: black --check --preview examples tests src utils
- run: isort --check-only examples tests src utils
- run: black --check examples tests src utils
- run: ruff examples tests src utils
- run: python utils/custom_init_isort.py --check_only
- run: python utils/sort_auto_mappings.py --check_only
- run: flake8 examples tests src utils
- run: doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source
- run: python utils/check_doc_toc.py


@ -9,9 +9,8 @@ modified_only_fixup:
$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
@if test -n "$(modified_py_files)"; then \
echo "Checking/fixing $(modified_py_files)"; \
black --preview $(modified_py_files); \
isort $(modified_py_files); \
flake8 $(modified_py_files); \
black $(modified_py_files); \
ruff $(modified_py_files) --fix; \
else \
echo "No library .py files were modified"; \
fi
@ -48,11 +47,10 @@ repo-consistency:
# this target runs checks on all files
quality:
black --check --preview $(check_dirs)
isort --check-only $(check_dirs)
black --check $(check_dirs)
python utils/custom_init_isort.py --check_only
python utils/sort_auto_mappings.py --check_only
flake8 $(check_dirs)
ruff $(check_dirs)
doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source
python utils/check_doc_toc.py
@ -67,8 +65,8 @@ extra_style_checks:
# this target runs checks on all files and potentially modifies some of them
style:
black --preview $(check_dirs)
isort $(check_dirs)
black $(check_dirs)
ruff $(check_dirs) --fix
${MAKE} autogenerate_code
${MAKE} extra_style_checks


@ -96,7 +96,7 @@ while True:
queues.append(rq)
strings
outs = pipe(strings, batch_size=len(strings))
for (rq, out) in zip(queues, outs):
for rq, out in zip(queues, outs):
await rq.put(out)
```
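Changes of this shape account for a large share of the 1211 touched files: reformatting with black 23.1 drops the redundant parentheses around the unpacking targets of a `for` loop, as in the hunk above. A minimal before/after sketch (the names are placeholders, not from the repository):

```python
keys = ["a", "b"]
values = [1, 2]

# Previously formatted as:
#     for (key, value) in zip(keys, values):
# black 23.1 removes the redundant parentheses around the unpacking targets:
for key, value in zip(keys, values):
    print(key, value)
```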


@ -166,7 +166,6 @@ Unlike other data collators, this specific data collator needs to apply a differ
>>> @dataclass
... class DataCollatorCTCWithPadding:
... processor: AutoProcessor
... padding: Union[bool, str] = "longest"


@ -213,7 +213,6 @@ The `image_processor` expects the annotations to be in the following format: `{'
```py
>>> def formatted_anns(image_id, category, area, bbox):
... annotations = []
... for i in range(0, len(category)):
... new_ann = {
@ -399,6 +398,7 @@ First, prepare the `cppe5["test"]` set: format the annotations and save the data
```py
>>> import json
>>> # format annotations the same as for training, no need for data augmentation
>>> def val_formatted_anns(image_id, objects):
... annotations = []


@ -159,7 +159,6 @@ A diferencia de otros collators de datos, este tiene que aplicarle un método de
>>> @dataclass
... class DataCollatorCTCWithPadding:
... processor: AutoProcessor
... padding: Union[bool, str] = "longest"


@ -29,23 +29,23 @@ from pathlib import Path
from typing import Callable, Optional
import datasets
import nltk # Here to have a nice missing dependency error message early on
import numpy as np
from datasets import Dataset, load_dataset
from PIL import Image
from tqdm import tqdm
import evaluate
import jax
import jax.numpy as jnp
import nltk # Here to have a nice missing dependency error message early on
import numpy as np
import optax
import transformers
from datasets import Dataset, load_dataset
from filelock import FileLock
from flax import jax_utils, traverse_util
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository, create_repo
from PIL import Image
from tqdm import tqdm
import transformers
from transformers import (
AutoImageProcessor,
AutoTokenizer,

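Most of the remaining hunks are import reorders like the one above, produced by the new isort configuration now enforced through ruff: within the third-party block, plain `import` statements come before `from ... import` statements, each alphabetized; local helper modules such as `utils_qa` sort into that same block; and imports from `transformers` itself are grouped last. A minimal sketch of the resulting layout, assuming `transformers` is the only package configured as first-party (the imports are placeholders):

```python
import logging  # standard library first
import os

import datasets  # third party: plain imports, alphabetized ...
import numpy as np
import torch
from datasets import load_dataset  # ... then "from" imports, alphabetized
from torch.utils.data import DataLoader
from tqdm import tqdm

import transformers  # transformers itself grouped last (assumed first-party)
from transformers import AutoModel, AutoTokenizer
```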

@ -32,20 +32,20 @@ from itertools import chain
from pathlib import Path
from typing import Dict, List, Optional
import nltk
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import flax
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
from datasets import load_dataset
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository, create_repo
from tqdm import tqdm
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,


@ -34,19 +34,19 @@ from pathlib import Path
from typing import Callable, Optional
import datasets
import numpy as np
from datasets import Dataset, load_dataset
from tqdm import tqdm
import jax
import jax.numpy as jnp
import numpy as np
import optax
import transformers
from datasets import Dataset, load_dataset
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository, create_repo
from tqdm import tqdm
import transformers
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,


@ -34,19 +34,19 @@ from itertools import chain
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
from datasets import load_dataset
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository, create_repo
from tqdm import tqdm
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,


@ -33,19 +33,19 @@ from itertools import chain
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
from datasets import load_dataset
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository, create_repo
from tqdm import tqdm
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,


@ -31,20 +31,21 @@ from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple
import datasets
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import evaluate
import jax
import jax.numpy as jnp
import numpy as np
import optax
import transformers
from datasets import load_dataset
from flax import struct, traverse_util
from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository, create_repo
from tqdm import tqdm
from utils_qa import postprocess_qa_predictions
import transformers
from transformers import (
AutoConfig,
AutoTokenizer,
@ -55,7 +56,6 @@ from transformers import (
is_tensorboard_available,
)
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from utils_qa import postprocess_qa_predictions
logger = logging.getLogger(__name__)
@ -301,6 +301,7 @@ class DataTrainingArguments:
# endregion
# region Create a train state
def create_train_state(
model: FlaxAutoModelForQuestionAnswering,
@ -387,6 +388,7 @@ def create_learning_rate_fn(
# endregion
# region train data iterator
def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
"""Returns shuffled batches of size `batch_size` from truncated `train dataset`, sharded over all local devices."""
@ -405,6 +407,7 @@ def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
# endregion
# region eval data iterator
def eval_data_collator(dataset: Dataset, batch_size: int):
"""Returns batches of size `batch_size` from `eval dataset`. Sharding handled by `pad_shard_unpad` in the eval loop."""
@ -934,7 +937,6 @@ def main():
total_steps = step_per_epoch * num_epochs
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
train_start = time.time()
train_metrics = []
@ -975,7 +977,6 @@ def main():
and (cur_step % training_args.eval_steps == 0 or cur_step % step_per_epoch == 0)
and cur_step > 0
):
eval_metrics = {}
all_start_logits = []
all_end_logits = []


@ -31,22 +31,22 @@ from pathlib import Path
from typing import Callable, Optional
import datasets
import nltk # Here to have a nice missing dependency error message early on
import numpy as np
from datasets import Dataset, load_dataset
from tqdm import tqdm
import evaluate
import jax
import jax.numpy as jnp
import nltk # Here to have a nice missing dependency error message early on
import numpy as np
import optax
import transformers
from datasets import Dataset, load_dataset
from filelock import FileLock
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository, create_repo
from tqdm import tqdm
import transformers
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,


@ -26,20 +26,20 @@ from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple
import datasets
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import evaluate
import jax
import jax.numpy as jnp
import numpy as np
import optax
import transformers
from datasets import load_dataset
from flax import struct, traverse_util
from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository, create_repo
from tqdm import tqdm
import transformers
from transformers import (
AutoConfig,
AutoTokenizer,
@ -586,7 +586,6 @@ def main():
total_steps = steps_per_epoch * num_epochs
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (0/{num_epochs})", position=0)
for epoch in epochs:
train_start = time.time()
train_metrics = []
@ -623,7 +622,6 @@ def main():
train_metrics = []
if (cur_step % training_args.eval_steps == 0 or cur_step % steps_per_epoch == 0) and cur_step > 0:
# evaluate
eval_loader = glue_eval_data_collator(eval_dataset, eval_batch_size)
for batch in tqdm(


@ -28,20 +28,20 @@ from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple
import datasets
import numpy as np
from datasets import ClassLabel, load_dataset
from tqdm import tqdm
import evaluate
import jax
import jax.numpy as jnp
import numpy as np
import optax
import transformers
from datasets import ClassLabel, load_dataset
from flax import struct, traverse_util
from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository, create_repo
from tqdm import tqdm
import transformers
from transformers import (
AutoConfig,
AutoTokenizer,
@ -695,7 +695,6 @@ def main():
total_steps = step_per_epoch * num_epochs
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
train_start = time.time()
train_metrics = []
@ -731,7 +730,6 @@ def main():
train_metrics = []
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
eval_metrics = {}
# evaluate
for batch in tqdm(


@ -29,21 +29,22 @@ from enum import Enum
from pathlib import Path
from typing import Callable, Optional
import jax
import jax.numpy as jnp
import optax
# for dataset and preprocessing
import torch
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import jax
import jax.numpy as jnp
import optax
import transformers
from flax import jax_utils
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository, create_repo
from tqdm import tqdm
import transformers
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,


@ -22,6 +22,7 @@ from dataclasses import dataclass, field
from typing import Dict, Optional
import numpy as np
from utils_multiple_choice import MultipleChoiceDataset, Split, processors
import transformers
from transformers import (
@ -36,7 +37,6 @@ from transformers import (
set_seed,
)
from transformers.trainer_utils import is_main_process
from utils_multiple_choice import MultipleChoiceDataset, Split, processors
logger = logging.getLogger(__name__)


@ -26,8 +26,8 @@ from enum import Enum
from typing import List, Optional
import tqdm
from filelock import FileLock
from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available
@ -112,7 +112,6 @@ if is_torch_available():
# and the others will use the cache.
lock_path = cached_features_file + ".lock"
with FileLock(lock_path):
if os.path.exists(cached_features_file) and not overwrite_cache:
logger.info(f"Loading features from cached file {cached_features_file}")
self.features = torch.load(cached_features_file)


@ -69,7 +69,7 @@ class BaseTransformer(pl.LightningModule):
config=None,
tokenizer=None,
model=None,
**config_kwargs
**config_kwargs,
):
"""Initialize a model, tokenizer and config."""
super().__init__()
@ -346,7 +346,7 @@ def generic_train(
extra_callbacks=[],
checkpoint_callback=None,
logging_callback=None,
**extra_train_kwargs
**extra_train_kwargs,
):
pl.seed_everything(args.seed)
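The `**config_kwargs` → `**config_kwargs,` edits above add the trailing comma the formatter now emits after the last parameter of a multi-line signature, including after `**kwargs`. This may be tied to the "Update target to Python 3.7" bullet, since that comma is not valid syntax on some of the older Python versions a formatter would otherwise have to accommodate. A small illustration with a made-up signature:

```python
def build_model(
    config=None,
    tokenizer=None,
    **model_kwargs,  # trailing comma now kept after **kwargs
):
    return config, tokenizer, model_kwargs


print(build_model(tokenizer="placeholder", hidden_size=128))
```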


@ -7,21 +7,19 @@ from argparse import Namespace
import numpy as np
import torch
from lightning_base import BaseTransformer, add_generic_args, generic_train
from torch.utils.data import DataLoader, TensorDataset
from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import glue_compute_metrics as compute_metrics
from transformers import glue_convert_examples_to_features as convert_examples_to_features
from transformers import glue_output_modes
from transformers import glue_output_modes, glue_tasks_num_labels
from transformers import glue_processors as processors
from transformers import glue_tasks_num_labels
logger = logging.getLogger(__name__)
class GLUETransformer(BaseTransformer):
mode = "sequence-classification"
def __init__(self, hparams):


@ -7,11 +7,10 @@ from importlib import import_module
import numpy as np
import torch
from lightning_base import BaseTransformer, add_generic_args, generic_train
from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, TensorDataset
from lightning_base import BaseTransformer, add_generic_args, generic_train
from utils_ner import TokenClassificationTask


@ -172,7 +172,6 @@ def train(args, train_dataset, model, tokenizer):
for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
for step, batch in enumerate(epoch_iterator):
# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1


@ -30,9 +30,10 @@ from transformers import (
DataCollatorWithPadding,
HfArgumentParser,
SquadDataset,
Trainer,
TrainingArguments,
)
from transformers import SquadDataTrainingArguments as DataTrainingArguments
from transformers import Trainer, TrainingArguments
from transformers.trainer_utils import is_main_process


@ -4,6 +4,7 @@ import json
from typing import List
from ltp import LTP
from transformers import BertTokenizer
@ -93,7 +94,6 @@ def prepare_ref(lines: List[str], ltp_tokenizer: LTP, bert_tokenizer: BertTokeni
ref_ids = []
for input_ids, chinese_word in zip(bert_res, ltp_res):
input_tokens = []
for id in input_ids:
token = bert_tokenizer._convert_id_to_token(id)


@ -19,9 +19,10 @@ import sys
from dataclasses import dataclass, field
from typing import Optional
import transformers
from seq2seq_trainer import Seq2SeqTrainer
from seq2seq_training_args import Seq2SeqTrainingArguments
import transformers
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
@ -337,7 +338,6 @@ def main():
metrics["val_loss"] = round(metrics["val_loss"], 4)
if trainer.is_world_process_zero():
handle_metrics("val", metrics, training_args.output_dir)
all_metrics.update(metrics)


@ -16,8 +16,8 @@ from collections import defaultdict
from pathlib import Path
import pandas as pd
from rouge_cli import calculate_rouge_path
from utils import calculate_rouge
@ -87,7 +87,6 @@ def test_single_sent_scores_dont_depend_on_newline_sep():
def test_pegasus_newline():
pred = [
"""" "a person who has such a video needs to immediately give it to the investigators," prosecutor says .<n> "it is a very disturbing scene," editor-in-chief of bild online tells "erin burnett: outfront" """
]


@ -17,11 +17,11 @@ from pathlib import Path
import numpy as np
import pytest
from torch.utils.data import DataLoader
from pack_dataset import pack_data_dir
from parameterized import parameterized
from save_len_file import save_len_file
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from transformers.models.mbart.modeling_mbart import shift_tokens_right
from transformers.testing_utils import TestCasePlus, slow


@ -18,6 +18,7 @@ import json
import unittest
from parameterized import parameterized
from transformers import FSMTForConditionalGeneration, FSMTTokenizer
from transformers.testing_utils import get_tests_dir, require_torch, slow, torch_device
from utils import calculate_bleu


@ -21,6 +21,7 @@ from unittest.mock import patch
from parameterized import parameterized
from run_eval import run_generate
from run_eval_search import run_search
from transformers.testing_utils import CaptureStdout, TestCasePlus, slow
from utils import ROUGE_KEYS


@ -29,7 +29,6 @@ from transformers import AutoTokenizer
def pack_examples(tok, src_examples, tgt_examples, max_tokens=1024):
finished_src, finished_tgt = [], []
sorted_examples = list(zip(src_examples, tgt_examples))


@ -20,6 +20,7 @@ import sys
from collections import OrderedDict
from run_eval import datetime_now, run_generate
from utils import ROUGE_KEYS


@ -17,6 +17,7 @@ from dataclasses import dataclass, field
from typing import Optional
from seq2seq_trainer import arg_to_scheduler
from transformers import TrainingArguments


@ -29,10 +29,10 @@ import torch
import torch.distributed as dist
from rouge_score import rouge_scorer, scoring
from sacrebleu import corpus_bleu
from sentence_splitter import add_newline_to_end_of_each_sentence
from torch import nn
from torch.utils.data import Dataset, Sampler
from sentence_splitter import add_newline_to_end_of_each_sentence
from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
from transformers.models.bart.modeling_bart import shift_tokens_right
from transformers.utils import cached_property
@ -132,7 +132,7 @@ class AbstractSeq2SeqDataset(Dataset):
type_path="train",
n_obs=None,
prefix="",
**dataset_kwargs
**dataset_kwargs,
):
super().__init__()
self.src_file = Path(data_dir).joinpath(type_path + ".source")


@ -24,6 +24,7 @@ from typing import Dict, List, Optional, Tuple
import numpy as np
from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
from torch import nn
from utils_ner import Split, TokenClassificationDataset, TokenClassificationTask
import transformers
from transformers import (
@ -38,7 +39,6 @@ from transformers import (
set_seed,
)
from transformers.trainer_utils import is_main_process
from utils_ner import Split, TokenClassificationDataset, TokenClassificationTask
logger = logging.getLogger(__name__)


@ -24,6 +24,7 @@ from typing import Dict, List, Optional, Tuple
import numpy as np
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
from utils_ner import Split, TFTokenClassificationDataset, TokenClassificationTask
from transformers import (
AutoConfig,
@ -35,7 +36,6 @@ from transformers import (
TFTrainingArguments,
)
from transformers.utils import logging as hf_logging
from utils_ner import Split, TFTokenClassificationDataset, TokenClassificationTask
hf_logging.set_verbosity_info()


@ -3,7 +3,6 @@ import os
from typing import List, TextIO, Union
from conllu import parse_incr
from utils_ner import InputExample, Split, TokenClassificationTask


@ -23,6 +23,7 @@ from enum import Enum
from typing import List, Optional, Union
from filelock import FileLock
from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available
@ -240,7 +241,6 @@ if is_torch_available():
# and the others will use the cache.
lock_path = cached_features_file + ".lock"
with FileLock(lock_path):
if os.path.exists(cached_features_file) and not overwrite_cache:
logger.info(f"Loading features from cached file {cached_features_file}")
self.features = torch.load(cached_features_file)


@ -23,10 +23,10 @@ from random import randint
from typing import Optional
import datasets
import evaluate
import numpy as np
from datasets import DatasetDict, load_dataset
import evaluate
import transformers
from transformers import (
AutoConfig,


@ -19,6 +19,7 @@ import sys
from dataclasses import dataclass, field
from typing import Optional
import evaluate
import numpy as np
import torch
from datasets import load_dataset
@ -33,7 +34,6 @@ from torchvision.transforms import (
ToTensor,
)
import evaluate
import transformers
from transformers import (
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,


@ -21,8 +21,13 @@ import os
from pathlib import Path
import datasets
import evaluate
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from huggingface_hub import Repository, create_repo
from torch.utils.data import DataLoader
from torchvision.transforms import (
CenterCrop,
@ -35,12 +40,7 @@ from torchvision.transforms import (
)
from tqdm.auto import tqdm
import evaluate
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from huggingface_hub import Repository, create_repo
from transformers import AutoConfig, AutoImageProcessor, AutoModelForImageClassification, SchedulerType, get_scheduler
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version


@ -30,10 +30,10 @@ from itertools import chain
from typing import Optional
import datasets
import evaluate
import torch
from datasets import load_dataset
import evaluate
import transformers
from transformers import (
CONFIG_MAPPING,


@ -33,15 +33,15 @@ from pathlib import Path
import datasets
import torch
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from huggingface_hub import Repository, create_repo
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from huggingface_hub import Repository, create_repo
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,


@ -30,9 +30,9 @@ from itertools import chain
from typing import Optional
import datasets
import evaluate
from datasets import load_dataset
import evaluate
import transformers
from transformers import (
CONFIG_MAPPING,


@ -33,15 +33,15 @@ from pathlib import Path
import datasets
import torch
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from huggingface_hub import Repository, create_repo
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from huggingface_hub import Repository, create_repo
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,


@ -30,17 +30,17 @@ from pathlib import Path
from typing import Optional, Union
import datasets
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import evaluate
import transformers
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from huggingface_hub import Repository, create_repo
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,


@ -25,11 +25,12 @@ from dataclasses import dataclass, field
from typing import Optional
import datasets
from datasets import load_dataset
import evaluate
import transformers
from datasets import load_dataset
from trainer_qa import QuestionAnsweringTrainer
from utils_qa import postprocess_qa_predictions
import transformers
from transformers import (
AutoConfig,
AutoModelForQuestionAnswering,
@ -45,7 +46,6 @@ from transformers import (
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
from utils_qa import postprocess_qa_predictions
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.


@ -25,11 +25,12 @@ from dataclasses import dataclass, field
from typing import Optional
import datasets
from datasets import load_dataset
import evaluate
import transformers
from datasets import load_dataset
from trainer_qa import QuestionAnsweringTrainer
from utils_qa import postprocess_qa_predictions_with_beam_search
import transformers
from transformers import (
DataCollatorWithPadding,
EvalPrediction,
@ -44,7 +45,6 @@ from transformers import (
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
from utils_qa import postprocess_qa_predictions_with_beam_search
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.


@ -27,18 +27,19 @@ import random
from pathlib import Path
import datasets
import evaluate
import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import evaluate
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from huggingface_hub import Repository, create_repo
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from utils_qa import postprocess_qa_predictions_with_beam_search
import transformers
from transformers import (
AdamW,
DataCollatorWithPadding,
@ -52,7 +53,6 @@ from transformers import (
)
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
from utils_qa import postprocess_qa_predictions_with_beam_search
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.


@ -27,18 +27,19 @@ import random
from pathlib import Path
import datasets
import evaluate
import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import evaluate
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from huggingface_hub import Repository, create_repo
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from utils_qa import postprocess_qa_predictions
import transformers
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
@ -53,7 +54,6 @@ from transformers import (
)
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
from utils_qa import postprocess_qa_predictions
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.


@ -25,11 +25,11 @@ from dataclasses import dataclass, field
from typing import List, Optional, Tuple
import datasets
from datasets import load_dataset
import evaluate
import transformers
from datasets import load_dataset
from trainer_seq2seq_qa import QuestionAnsweringSeq2SeqTrainer
import transformers
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,


@ -21,17 +21,17 @@ import sys
from dataclasses import dataclass, field
from typing import Optional
import evaluate
import numpy as np
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from PIL import Image
from torch import nn
from torchvision import transforms
from torchvision.transforms import functional
import evaluate
import transformers
from huggingface_hub import hf_hub_download
from transformers import (
AutoConfig,
AutoImageProcessor,


@ -22,21 +22,21 @@ import random
from pathlib import Path
import datasets
import evaluate
import numpy as np
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from huggingface_hub import Repository, create_repo, hf_hub_download
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms import functional
from tqdm.auto import tqdm
import evaluate
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from huggingface_hub import Repository, create_repo, hf_hub_download
from transformers import (
AutoConfig,
AutoImageProcessor,


@ -24,14 +24,14 @@ from typing import Dict, List, Optional, Union
import datasets
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import DatasetDict, concatenate_datasets, load_dataset
from huggingface_hub import Repository, create_repo
from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from huggingface_hub import Repository, create_repo
from transformers import (
AdamW,
SchedulerType,
@ -641,7 +641,6 @@ def main():
# update step
if (step + 1) % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
# compute grad norm for monitoring
scale = (
accelerator.scaler._scale.item()


@ -26,11 +26,11 @@ from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union
import datasets
import evaluate
import numpy as np
import torch
from datasets import DatasetDict, load_dataset
import evaluate
import transformers
from transformers import (
AutoConfig,
@ -708,7 +708,6 @@ def main():
# Training
if training_args.do_train:
# use last checkpoint if exist
if last_checkpoint is not None:
checkpoint = last_checkpoint


@ -26,10 +26,10 @@ from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
import datasets
import evaluate
import torch
from datasets import DatasetDict, load_dataset
import evaluate
import transformers
from transformers import (
AutoConfig,


@ -25,13 +25,13 @@ from dataclasses import dataclass, field
from typing import Optional
import datasets
import evaluate
import nltk # Here to have a nice missing dependency error message early on
import numpy as np
from datasets import load_dataset
import evaluate
import transformers
from filelock import FileLock
import transformers
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,


@ -27,20 +27,20 @@ import random
from pathlib import Path
import datasets
import evaluate
import nltk
import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import evaluate
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from filelock import FileLock
from huggingface_hub import Repository, create_repo
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,


@ -24,8 +24,8 @@ import tempfile
from unittest import mock
import torch
from accelerate.utils import write_basic_config
from transformers.testing_utils import TestCasePlus, get_gpu_count, run_command, slow, torch_device
from transformers.utils import is_apex_available


@ -24,10 +24,10 @@ from dataclasses import dataclass, field
from typing import Optional
import datasets
import evaluate
import numpy as np
from datasets import load_dataset
import evaluate
import transformers
from transformers import (
AutoConfig,


@ -22,17 +22,17 @@ import random
from pathlib import Path
import datasets
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import evaluate
import transformers
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from huggingface_hub import Repository, create_repo
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,


@ -25,10 +25,10 @@ from dataclasses import dataclass, field
from typing import Optional
import datasets
import evaluate
import numpy as np
from datasets import load_dataset
import evaluate
import transformers
from transformers import (
AutoConfig,


@ -26,10 +26,10 @@ from dataclasses import dataclass, field
from typing import Optional
import datasets
import evaluate
import numpy as np
from datasets import ClassLabel, load_dataset
import evaluate
import transformers
from transformers import (
AutoConfig,


@ -27,17 +27,17 @@ import random
from pathlib import Path
import datasets
import torch
from datasets import ClassLabel, load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import evaluate
import transformers
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import ClassLabel, load_dataset
from huggingface_hub import Repository, create_repo
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,


@ -25,10 +25,10 @@ from dataclasses import dataclass, field
from typing import Optional
import datasets
import evaluate
import numpy as np
from datasets import load_dataset
import evaluate
import transformers
from transformers import (
AutoConfig,


@ -27,18 +27,18 @@ import random
from pathlib import Path
import datasets
import evaluate
import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import evaluate
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from huggingface_hub import Repository, create_repo
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
@ -69,7 +69,6 @@ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
# Parsing input arguments
def parse_args():
parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
parser.add_argument(
"--dataset_name",
@ -751,5 +750,4 @@ def main():
if __name__ == "__main__":
main()


@ -22,6 +22,7 @@ from typing import Dict, List, Optional
import numpy as np
import torch
from utils_hans import HansDataset, InputFeatures, hans_processors, hans_tasks_num_labels
import transformers
from transformers import (
@ -35,7 +36,6 @@ from transformers import (
set_seed,
)
from transformers.trainer_utils import is_main_process
from utils_hans import HansDataset, InputFeatures, hans_processors, hans_tasks_num_labels
logger = logging.getLogger(__name__)


@ -20,8 +20,8 @@ from dataclasses import dataclass
from typing import List, Optional, Union
import tqdm
from filelock import FileLock
from transformers import (
BartTokenizer,
BartTokenizerFast,
@ -134,7 +134,6 @@ if is_torch_available():
# and the others will use the cache.
lock_path = cached_features_file + ".lock"
with FileLock(lock_path):
if os.path.exists(cached_features_file) and not overwrite_cache:
logger.info(f"Loading features from cached file {cached_features_file}")
self.features = torch.load(cached_features_file)


@ -25,14 +25,14 @@ import random
import numpy as np
import torch
from pabee.modeling_pabee_albert import AlbertForSequenceClassificationWithPabee
from pabee.modeling_pabee_bert import BertForSequenceClassificationWithPabee
from torch import nn
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
import transformers
from pabee.modeling_pabee_albert import AlbertForSequenceClassificationWithPabee
from pabee.modeling_pabee_bert import BertForSequenceClassificationWithPabee
from transformers import (
WEIGHTS_NAME,
AdamW,
@ -173,7 +173,6 @@ def train(args, train_dataset, model, tokenizer):
for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
for step, batch in enumerate(epoch_iterator):
# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
@ -263,7 +262,6 @@ def train(args, train_dataset, model, tokenizer):
def evaluate(args, model, tokenizer, prefix="", patience=0):
if args.model_type == "albert":
model.albert.set_regression_threshold(args.regression_threshold)
model.albert.set_patience(patience)
@ -736,7 +734,6 @@ def main():
logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints:
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""


@ -4,6 +4,7 @@ import sys
from unittest.mock import patch
import run_glue_with_pabee
from transformers.testing_utils import TestCasePlus


@ -24,9 +24,9 @@ import logging
from collections import namedtuple
import torch
from model_bertabs import BertAbsSummarizer
from models.model_builder import AbsSummarizer # The authors' implementation
from transformers import BertTokenizer


@ -24,10 +24,10 @@ import math
import numpy as np
import torch
from configuration_bertabs import BertAbsConfig
from torch import nn
from torch.nn.init import xavier_uniform_
from configuration_bertabs import BertAbsConfig
from transformers import BertConfig, BertModel, PreTrainedModel


@ -6,10 +6,10 @@ import sys
from collections import namedtuple
import torch
from modeling_bertabs import BertAbs, build_predictor
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm
from modeling_bertabs import BertAbs, build_predictor
from transformers import BertTokenizer
from .utils_summarization import (
@ -45,7 +45,6 @@ def evaluate(args):
generated_summaries = []
import nltk
import rouge
nltk.download("punkt")


@ -3,8 +3,8 @@ from copy import deepcopy
import numpy as np
from datasets import ClassLabel, DatasetDict, load_dataset
from evaluate import load
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,


@ -1,7 +1,7 @@
from arguments import TokenizerTrainingArguments
from datasets import load_dataset
from tqdm import tqdm
from arguments import TokenizerTrainingArguments
from transformers import AutoTokenizer, HfArgumentParser
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode


@ -6,16 +6,16 @@ from pathlib import Path
import datasets
import torch
from accelerate import Accelerator, DistributedType
from arguments import TrainingArguments
from datasets import load_dataset
from huggingface_hub import Repository
from torch.optim import AdamW
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe
import transformers
from accelerate import Accelerator, DistributedType
from arguments import TrainingArguments
from huggingface_hub import Repository
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed


@ -5,15 +5,15 @@ import re
from collections import defaultdict
import torch
from accelerate import Accelerator
from accelerate.utils import set_seed
from arguments import HumanEvalArguments
from datasets import load_dataset, load_metric
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
import transformers
from accelerate import Accelerator
from accelerate.utils import set_seed
from arguments import HumanEvalArguments
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, StoppingCriteria, StoppingCriteriaList


@ -1,4 +1,5 @@
from arguments import InitializationArguments
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser


@ -6,10 +6,9 @@ from functools import partial
from typing import Dict, List, Optional, Set, Tuple, Type
from datasets import Dataset
from tqdm import tqdm
from datasketch import MinHash, MinHashLSH
from dpu_utils.utils.iterators import ThreadedIterator
from tqdm import tqdm
NON_ALPHA = re.compile("[^A-Za-z_0-9]")


@ -9,10 +9,10 @@ import time
from pathlib import Path
import numpy as np
from datasets import load_dataset
from arguments import PreprocessingArguments
from datasets import load_dataset
from minhash_deduplication import deduplicate_dataset
from transformers import AutoTokenizer, HfArgumentParser


@ -1,9 +1,9 @@
import multiprocessing
import time
from arguments import PretokenizationArguments
from datasets import load_dataset
from arguments import PretokenizationArguments
from transformers import AutoTokenizer, HfArgumentParser


@ -1,7 +1,6 @@
from unittest import TestCase
from datasets import Dataset
from minhash_deduplication import deduplicate_dataset, make_duplicate_clusters


@ -1,12 +1,12 @@
import logging
import torch
from accelerate import Accelerator
from arguments import EvaluationArguments
from datasets import load_dataset
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader
from accelerate import Accelerator
from arguments import EvaluationArguments
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, set_seed


@ -1,8 +1,8 @@
import gym
import numpy as np
import torch
import gym
from mujoco_py import GlfwContext
from transformers import DecisionTransformerModel


@ -229,7 +229,10 @@ class DeeBertModel(BertPreTrainedModel):
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
outputs = (sequence_output, pooled_output,) + encoder_outputs[
outputs = (
sequence_output,
pooled_output,
) + encoder_outputs[
1:
] # add hidden_states and attentions if they are here
return outputs # sequence_output, pooled_output, (hidden_states), (attentions), highway exits


@ -19,7 +19,6 @@ from .modeling_highway_bert import BertPreTrainedModel, DeeBertModel, HighwayExc
ROBERTA_START_DOCSTRING,
)
class DeeRobertaModel(DeeBertModel):
config_class = RobertaConfig
base_model_prefix = "roberta"
@ -36,7 +35,6 @@ class DeeRobertaModel(DeeBertModel):
ROBERTA_START_DOCSTRING,
)
class DeeRobertaForSequenceClassification(BertPreTrainedModel):
config_class = RobertaConfig
base_model_prefix = "roberta"


@ -4,6 +4,7 @@ import sys
from unittest.mock import patch
import run_glue_deebert
from transformers.testing_utils import TestCasePlus, get_gpu_count, require_torch_non_multi_gpu, slow
@ -45,7 +46,6 @@ class DeeBertTests(TestCasePlus):
@slow
@require_torch_non_multi_gpu
def test_glue_deebert_train(self):
train_args = """
--model_type roberta
--model_name_or_path roberta-base


@ -21,14 +21,14 @@ import time
import psutil
import torch
from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
from lm_seqs_dataset import LmSeqsDataset
from torch import nn
from torch.optim import AdamW
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
from lm_seqs_dataset import LmSeqsDataset
from transformers import get_linear_schedule_with_warmup
from utils import logger


@ -189,7 +189,6 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
for step, batch in enumerate(epoch_iterator):
# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1


@ -24,9 +24,9 @@ import shutil
import numpy as np
import torch
from distiller import Distiller
from lm_seqs_dataset import LmSeqsDataset
from transformers import (
BertConfig,
BertForMaskedLM,


@ -5,13 +5,13 @@ import copy
import logging
import random
import joblib
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import joblib
from transformers import AdamW, GPT2LMHeadModel, get_linear_schedule_with_warmup
@ -119,7 +119,6 @@ def recopy_gpt2(orig_model, device, max_steps):
def intermittent_save(contexts, real_perps, past_perps, filename):
"""
save the perplexity differences to filename
@ -152,7 +151,6 @@ def collect_objective_set(
filename="dev.jbl",
recopy_model=recopy_gpt2,
):
"""
Collect individual IGF values from pre-trained transformer model
max_steps samples of training data to train secondary model
@ -271,7 +269,6 @@ def generate_datasets(
def train_secondary_learner(
secondary_learner, train_dataset, max_epochs, batch_size, eval_freq=50, igf_model_path="secondary_learner.pt"
):
"""
Train the secondary learner (igf_model)


@ -28,11 +28,9 @@ Last, a plot is generated to compare the performance of IGF to standard fine-tun
import argparse
import random
import joblib
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler
import joblib
from igf.igf import (
SecondaryLearner,
collect_objective_set,
@ -43,6 +41,8 @@ from igf.igf import (
set_seed,
train_secondary_learner,
)
from torch.utils.data import DataLoader, RandomSampler
from transformers import GPT2LMHeadModel
@ -55,7 +55,6 @@ def generate_n_pairs(
data_file="data/tokenized_stories_train_wikitext103.jbl",
igf_data_file="igf_context_pairs.jbl",
):
"""
Collecting *n* pairs for training the secondary learner
Args:


@ -4,8 +4,6 @@ from dataclasses import dataclass
from functools import partial
from typing import Callable
from tqdm.auto import tqdm
import flax.linen as nn
import jax
import jax.numpy as jnp
@ -16,6 +14,8 @@ from flax import jax_utils, struct, traverse_util
from flax.serialization import from_bytes, to_bytes
from flax.training import train_state
from flax.training.common_utils import shard
from tqdm.auto import tqdm
from transformers import BigBirdConfig, FlaxBigBirdForQuestionAnswering
from transformers.models.big_bird.modeling_flax_big_bird import FlaxBigBirdForQuestionAnsweringModule
@ -98,7 +98,6 @@ class Args:
@dataclass
class DataCollator:
pad_id: int
max_length: int = 4096 # no dynamic padding on TPUs


@ -1,8 +1,8 @@
from datasets import load_from_disk
import jax
import jax.numpy as jnp
from bigbird_flax import FlaxBigBirdForNaturalQuestions
from datasets import load_from_disk
from transformers import BigBirdTokenizerFast


@ -1,10 +1,9 @@
import os
import jsonlines
import numpy as np
from tqdm import tqdm
import jsonlines
DOC_STRIDE = 2048
MAX_LENGTH = 4096


@ -1,12 +1,12 @@
import os
from dataclasses import replace
from datasets import load_dataset
import jax
import wandb
from bigbird_flax import Args, DataCollator, FlaxBigBirdForNaturalQuestions, Trainer, build_tx, train_step, val_step
from datasets import load_dataset
from flax import jax_utils
from transformers import BigBirdTokenizerFast


@ -32,17 +32,17 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple
import datasets
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
from datasets import load_dataset
from flax import jax_utils, traverse_util
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from tqdm import tqdm
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,


@ -20,6 +20,7 @@ import jax
import jax.numpy as jnp
from configuration_hybrid_clip import HybridCLIPConfig
from flax.core.frozen_dict import FrozenDict
from transformers import FLAX_MODEL_MAPPING, FlaxCLIPVisionModel
from transformers.modeling_flax_utils import FlaxPreTrainedModel
from transformers.models.clip.modeling_flax_clip import FlaxCLIPOutput
@ -132,7 +133,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
input_shape: Optional[Tuple] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
**kwargs
**kwargs,
):
if input_shape is None:
input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))


@ -32,22 +32,22 @@ from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Optional
import jax
import jax.numpy as jnp
import optax
import torch
from flax import jax_utils
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, shard, shard_prng_key
from modeling_hybrid_clip import FlaxHybridCLIP
from torchvision.datasets import VisionDataset
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
import jax
import jax.numpy as jnp
import optax
import transformers
from flax import jax_utils
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, shard, shard_prng_key
from modeling_hybrid_clip import FlaxHybridCLIP
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, is_tensorboard_available, set_seed


@ -28,19 +28,19 @@ from pathlib import Path
from typing import Callable, Optional
import datasets
import numpy as np
from datasets import Dataset, load_dataset
from tqdm import tqdm
import jax
import jax.numpy as jnp
import numpy as np
import optax
import transformers
from datasets import Dataset, load_dataset
from flax.core.frozen_dict import freeze, unfreeze
from flax.training.common_utils import onehot, stack_forest
from jax.experimental.maps import mesh
from jax.experimental.pjit import pjit
from partitions import set_partitions
from tqdm import tqdm
import transformers
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,


@ -6,18 +6,18 @@ from dataclasses import field
from pathlib import Path
from typing import Dict, List, Optional, Union
import numpy as np
from datasets import DatasetDict, load_dataset
from tqdm import tqdm
import flax
import jax
import jax.numpy as jnp
import librosa
import numpy as np
import optax
from datasets import DatasetDict, load_dataset
from flax import jax_utils, traverse_util
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from tqdm import tqdm
from transformers import (
FlaxWav2Vec2ForPreTraining,
HfArgumentParser,

Some files were not shown because too many files have changed in this diff.