Mirror of https://github.com/huggingface/transformers.git
Synced 2025-07-03 04:40:06 +06:00

Update quality tooling for formatting (#21480)

* Result of black 23.1
* Update target to Python 3.7
* Switch flake8 to ruff
* Configure isort
* Configure isort
* Apply isort with line limit
* Put the right black version
* adapt black in check copies
* Fix copies

This commit is contained in:
parent b7bb2b59f7
commit 6f79d26442
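Most of the hunks below are mechanical reformatting. As a rough, hypothetical illustration (the function and argument names in this snippet are invented, not taken from the repository), the two style adjustments that recur throughout the diff, a trailing comma after `**kwargs` in multi-line signatures and removal of redundant parentheses around for-loop targets, look like this:

```py
# Hypothetical example, for illustration only; none of these names come from
# the transformers codebase.

# Before this commit the formatter accepted:
#
#     def run_training(
#         model,
#         tokenizer=None,
#         **train_kwargs
#     ):
#         for (name, value) in train_kwargs.items():
#             print(name, value)

# After it, a trailing comma follows `**train_kwargs` and the parentheses
# around the for-loop target are gone:
def run_training(
    model,
    tokenizer=None,
    **train_kwargs,
):
    for name, value in train_kwargs.items():
        print(name, value)


run_training("my-model", tokenizer=None, epochs=3, learning_rate=1e-4)
```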
@@ -134,11 +134,10 @@ jobs:
command: pip freeze | tee installed.txt
- store_artifacts:
path: ~/transformers/installed.txt
- run: black --check --preview examples tests src utils
- run: isort --check-only examples tests src utils
- run: black --check examples tests src utils
- run: ruff examples tests src utils
- run: python utils/custom_init_isort.py --check_only
- run: python utils/sort_auto_mappings.py --check_only
- run: flake8 examples tests src utils
- run: doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source
- run: python utils/check_doc_toc.py
Makefile (14 changed lines)
@@ -9,9 +9,8 @@ modified_only_fixup:
$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
@if test -n "$(modified_py_files)"; then \
echo "Checking/fixing $(modified_py_files)"; \
black --preview $(modified_py_files); \
isort $(modified_py_files); \
flake8 $(modified_py_files); \
black $(modified_py_files); \
ruff $(modified_py_files) --fix; \
else \
echo "No library .py files were modified"; \
fi
@@ -48,11 +47,10 @@ repo-consistency:
# this target runs checks on all files

quality:
black --check --preview $(check_dirs)
isort --check-only $(check_dirs)
black --check $(check_dirs)
python utils/custom_init_isort.py --check_only
python utils/sort_auto_mappings.py --check_only
flake8 $(check_dirs)
ruff $(check_dirs)
doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source
python utils/check_doc_toc.py
@@ -67,8 +65,8 @@ extra_style_checks:
# this target runs checks on all files and potentially modifies some of them

style:
black --preview $(check_dirs)
isort $(check_dirs)
black $(check_dirs)
ruff $(check_dirs) --fix
${MAKE} autogenerate_code
${MAKE} extra_style_checks
@@ -96,7 +96,7 @@ while True:
queues.append(rq)
strings
outs = pipe(strings, batch_size=len(strings))
for (rq, out) in zip(queues, outs):
for rq, out in zip(queues, outs):
await rq.put(out)
```
@@ -166,7 +166,6 @@ Unlike other data collators, this specific data collator needs to apply a differ

>>> @dataclass
... class DataCollatorCTCWithPadding:

... processor: AutoProcessor
... padding: Union[bool, str] = "longest"
@@ -213,7 +213,6 @@ The `image_processor` expects the annotations to be in the following format: `{'

```py
>>> def formatted_anns(image_id, category, area, bbox):

... annotations = []
... for i in range(0, len(category)):
... new_ann = {
@@ -399,6 +398,7 @@ First, prepare the `cppe5["test"]` set: format the annotations and save the data
```py
>>> import json


>>> # format annotations the same as for training, no need for data augmentation
>>> def val_formatted_anns(image_id, objects):
... annotations = []
@@ -159,7 +159,6 @@ A diferencia de otros collators de datos, este tiene que aplicarle un método de

>>> @dataclass
... class DataCollatorCTCWithPadding:

... processor: AutoProcessor
... padding: Union[bool, str] = "longest"
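From here on, almost every hunk is the same import reshuffle produced by the new isort/ruff settings. As a hedged toy sketch of the visible result (this is not the repository's actual tooling or configuration), the reordered blocks end up following roughly this rule: plain `import x` statements first, then `from x import y` statements, each alphabetised case-insensitively:

```py
# Toy illustration only -- not the project's real isort/ruff logic.
def sort_import_block(lines):
    """Order a block the way the reshuffled hunks below end up looking:
    plain imports first, then from-imports, alphabetised case-insensitively."""

    def key(line):
        is_from = line.startswith("from ")
        module = line.split()[1]
        return (is_from, module.lower())

    return sorted(lines, key=key)


old_block = [
    "from tqdm import tqdm",
    "import numpy as np",
    "from PIL import Image",
    "import jax",
    "from datasets import Dataset, load_dataset",
    "import evaluate",
]
print("\n".join(sort_import_block(old_block)))
# import evaluate
# import jax
# import numpy as np
# from datasets import Dataset, load_dataset
# from PIL import Image
# from tqdm import tqdm
```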
@@ -29,23 +29,23 @@ from pathlib import Path
from typing import Callable, Optional

import datasets
import nltk # Here to have a nice missing dependency error message early on
import numpy as np
from datasets import Dataset, load_dataset
from PIL import Image
from tqdm import tqdm

import evaluate
import jax
import jax.numpy as jnp
import nltk # Here to have a nice missing dependency error message early on
import numpy as np
import optax
import transformers
from datasets import Dataset, load_dataset
from filelock import FileLock
from flax import jax_utils, traverse_util
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository, create_repo
from PIL import Image
from tqdm import tqdm

import transformers
from transformers import (
AutoImageProcessor,
AutoTokenizer,
@@ -32,20 +32,20 @@ from itertools import chain
from pathlib import Path
from typing import Dict, List, Optional

import nltk
import numpy as np
from datasets import load_dataset
from tqdm import tqdm

import flax
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
from datasets import load_dataset
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository, create_repo
from tqdm import tqdm

from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
@@ -34,19 +34,19 @@ from pathlib import Path
from typing import Callable, Optional

import datasets
import numpy as np
from datasets import Dataset, load_dataset
from tqdm import tqdm

import jax
import jax.numpy as jnp
import numpy as np
import optax
import transformers
from datasets import Dataset, load_dataset
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository, create_repo
from tqdm import tqdm

import transformers
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
@@ -34,19 +34,19 @@ from itertools import chain
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
from datasets import load_dataset
from tqdm import tqdm

import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
from datasets import load_dataset
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository, create_repo
from tqdm import tqdm

from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
@@ -33,19 +33,19 @@ from itertools import chain
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
from datasets import load_dataset
from tqdm import tqdm

import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
from datasets import load_dataset
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository, create_repo
from tqdm import tqdm

from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
@@ -31,20 +31,21 @@ from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple

import datasets
import numpy as np
from datasets import load_dataset
from tqdm import tqdm

import evaluate
import jax
import jax.numpy as jnp
import numpy as np
import optax
import transformers
from datasets import load_dataset
from flax import struct, traverse_util
from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository, create_repo
from tqdm import tqdm
from utils_qa import postprocess_qa_predictions

import transformers
from transformers import (
AutoConfig,
AutoTokenizer,
@@ -55,7 +56,6 @@ from transformers import (
is_tensorboard_available,
)
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from utils_qa import postprocess_qa_predictions


logger = logging.getLogger(__name__)
@ -301,6 +301,7 @@ class DataTrainingArguments:
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Create a train state
|
||||
def create_train_state(
|
||||
model: FlaxAutoModelForQuestionAnswering,
|
||||
@ -387,6 +388,7 @@ def create_learning_rate_fn(
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region train data iterator
|
||||
def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
|
||||
"""Returns shuffled batches of size `batch_size` from truncated `train dataset`, sharded over all local devices."""
|
||||
@ -405,6 +407,7 @@ def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region eval data iterator
|
||||
def eval_data_collator(dataset: Dataset, batch_size: int):
|
||||
"""Returns batches of size `batch_size` from `eval dataset`. Sharding handled by `pad_shard_unpad` in the eval loop."""
|
||||
@ -934,7 +937,6 @@ def main():
|
||||
total_steps = step_per_epoch * num_epochs
|
||||
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
||||
for epoch in epochs:
|
||||
|
||||
train_start = time.time()
|
||||
train_metrics = []
|
||||
|
||||
@ -975,7 +977,6 @@ def main():
|
||||
and (cur_step % training_args.eval_steps == 0 or cur_step % step_per_epoch == 0)
|
||||
and cur_step > 0
|
||||
):
|
||||
|
||||
eval_metrics = {}
|
||||
all_start_logits = []
|
||||
all_end_logits = []
|
||||
|
@ -31,22 +31,22 @@ from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
import datasets
|
||||
import nltk # Here to have a nice missing dependency error message early on
|
||||
import numpy as np
|
||||
from datasets import Dataset, load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
import evaluate
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import nltk # Here to have a nice missing dependency error message early on
|
||||
import numpy as np
|
||||
import optax
|
||||
import transformers
|
||||
from datasets import Dataset, load_dataset
|
||||
from filelock import FileLock
|
||||
from flax import jax_utils, traverse_util
|
||||
from flax.jax_utils import pad_shard_unpad, unreplicate
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
||||
from huggingface_hub import Repository, create_repo
|
||||
from tqdm import tqdm
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
|
@ -26,20 +26,20 @@ from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
import evaluate
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import optax
|
||||
import transformers
|
||||
from datasets import load_dataset
|
||||
from flax import struct, traverse_util
|
||||
from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, onehot, shard
|
||||
from huggingface_hub import Repository, create_repo
|
||||
from tqdm import tqdm
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
@ -586,7 +586,6 @@ def main():
|
||||
total_steps = steps_per_epoch * num_epochs
|
||||
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (0/{num_epochs})", position=0)
|
||||
for epoch in epochs:
|
||||
|
||||
train_start = time.time()
|
||||
train_metrics = []
|
||||
|
||||
@ -623,7 +622,6 @@ def main():
|
||||
train_metrics = []
|
||||
|
||||
if (cur_step % training_args.eval_steps == 0 or cur_step % steps_per_epoch == 0) and cur_step > 0:
|
||||
|
||||
# evaluate
|
||||
eval_loader = glue_eval_data_collator(eval_dataset, eval_batch_size)
|
||||
for batch in tqdm(
|
||||
|
@ -28,20 +28,20 @@ from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
from datasets import ClassLabel, load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
import evaluate
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import optax
|
||||
import transformers
|
||||
from datasets import ClassLabel, load_dataset
|
||||
from flax import struct, traverse_util
|
||||
from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, onehot, shard
|
||||
from huggingface_hub import Repository, create_repo
|
||||
from tqdm import tqdm
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
@ -695,7 +695,6 @@ def main():
|
||||
total_steps = step_per_epoch * num_epochs
|
||||
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
||||
for epoch in epochs:
|
||||
|
||||
train_start = time.time()
|
||||
train_metrics = []
|
||||
|
||||
@ -731,7 +730,6 @@ def main():
|
||||
train_metrics = []
|
||||
|
||||
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
|
||||
|
||||
eval_metrics = {}
|
||||
# evaluate
|
||||
for batch in tqdm(
|
||||
|
@ -29,21 +29,22 @@ from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
|
||||
# for dataset and preprocessing
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
from tqdm import tqdm
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
import transformers
|
||||
from flax import jax_utils
|
||||
from flax.jax_utils import pad_shard_unpad, unreplicate
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
||||
from huggingface_hub import Repository, create_repo
|
||||
from tqdm import tqdm
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
|
@ -22,6 +22,7 @@ from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
from utils_multiple_choice import MultipleChoiceDataset, Split, processors
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
@ -36,7 +37,6 @@ from transformers import (
|
||||
set_seed,
|
||||
)
|
||||
from transformers.trainer_utils import is_main_process
|
||||
from utils_multiple_choice import MultipleChoiceDataset, Split, processors
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -26,8 +26,8 @@ from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
import tqdm
|
||||
|
||||
from filelock import FileLock
|
||||
|
||||
from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available
|
||||
|
||||
|
||||
@ -112,7 +112,6 @@ if is_torch_available():
|
||||
# and the others will use the cache.
|
||||
lock_path = cached_features_file + ".lock"
|
||||
with FileLock(lock_path):
|
||||
|
||||
if os.path.exists(cached_features_file) and not overwrite_cache:
|
||||
logger.info(f"Loading features from cached file {cached_features_file}")
|
||||
self.features = torch.load(cached_features_file)
|
||||
|
@ -69,7 +69,7 @@ class BaseTransformer(pl.LightningModule):
|
||||
config=None,
|
||||
tokenizer=None,
|
||||
model=None,
|
||||
**config_kwargs
|
||||
**config_kwargs,
|
||||
):
|
||||
"""Initialize a model, tokenizer and config."""
|
||||
super().__init__()
|
||||
@ -346,7 +346,7 @@ def generic_train(
|
||||
extra_callbacks=[],
|
||||
checkpoint_callback=None,
|
||||
logging_callback=None,
|
||||
**extra_train_kwargs
|
||||
**extra_train_kwargs,
|
||||
):
|
||||
pl.seed_everything(args.seed)
|
||||
|
||||
|
@ -7,21 +7,19 @@ from argparse import Namespace
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
|
||||
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
||||
from transformers import glue_compute_metrics as compute_metrics
|
||||
from transformers import glue_convert_examples_to_features as convert_examples_to_features
|
||||
from transformers import glue_output_modes
|
||||
from transformers import glue_output_modes, glue_tasks_num_labels
|
||||
from transformers import glue_processors as processors
|
||||
from transformers import glue_tasks_num_labels
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GLUETransformer(BaseTransformer):
|
||||
|
||||
mode = "sequence-classification"
|
||||
|
||||
def __init__(self, hparams):
|
||||
|
@ -7,11 +7,10 @@ from importlib import import_module
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
||||
from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
|
||||
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
||||
from utils_ner import TokenClassificationTask
|
||||
|
||||
|
||||
|
@ -172,7 +172,6 @@ def train(args, train_dataset, model, tokenizer):
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
|
||||
# Skip past any already trained steps if resuming training
|
||||
if steps_trained_in_current_epoch > 0:
|
||||
steps_trained_in_current_epoch -= 1
|
||||
|
@ -30,9 +30,10 @@ from transformers import (
|
||||
DataCollatorWithPadding,
|
||||
HfArgumentParser,
|
||||
SquadDataset,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
)
|
||||
from transformers import SquadDataTrainingArguments as DataTrainingArguments
|
||||
from transformers import Trainer, TrainingArguments
|
||||
from transformers.trainer_utils import is_main_process
|
||||
|
||||
|
||||
|
@ -4,6 +4,7 @@ import json
|
||||
from typing import List
|
||||
|
||||
from ltp import LTP
|
||||
|
||||
from transformers import BertTokenizer
|
||||
|
||||
|
||||
@ -93,7 +94,6 @@ def prepare_ref(lines: List[str], ltp_tokenizer: LTP, bert_tokenizer: BertTokeni
|
||||
|
||||
ref_ids = []
|
||||
for input_ids, chinese_word in zip(bert_res, ltp_res):
|
||||
|
||||
input_tokens = []
|
||||
for id in input_ids:
|
||||
token = bert_tokenizer._convert_id_to_token(id)
|
||||
|
@ -19,9 +19,10 @@ import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import transformers
|
||||
from seq2seq_trainer import Seq2SeqTrainer
|
||||
from seq2seq_training_args import Seq2SeqTrainingArguments
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSeq2SeqLM,
|
||||
@ -337,7 +338,6 @@ def main():
|
||||
metrics["val_loss"] = round(metrics["val_loss"], 4)
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
|
||||
handle_metrics("val", metrics, training_args.output_dir)
|
||||
all_metrics.update(metrics)
|
||||
|
||||
|
@ -16,8 +16,8 @@ from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from rouge_cli import calculate_rouge_path
|
||||
|
||||
from utils import calculate_rouge
|
||||
|
||||
|
||||
@ -87,7 +87,6 @@ def test_single_sent_scores_dont_depend_on_newline_sep():
|
||||
|
||||
|
||||
def test_pegasus_newline():
|
||||
|
||||
pred = [
|
||||
"""" "a person who has such a video needs to immediately give it to the investigators," prosecutor says .<n> "it is a very disturbing scene," editor-in-chief of bild online tells "erin burnett: outfront" """
|
||||
]
|
||||
|
@ -17,11 +17,11 @@ from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from pack_dataset import pack_data_dir
|
||||
from parameterized import parameterized
|
||||
from save_len_file import save_len_file
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.mbart.modeling_mbart import shift_tokens_right
|
||||
from transformers.testing_utils import TestCasePlus, slow
|
||||
|
@ -18,6 +18,7 @@ import json
|
||||
import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import FSMTForConditionalGeneration, FSMTTokenizer
|
||||
from transformers.testing_utils import get_tests_dir, require_torch, slow, torch_device
|
||||
from utils import calculate_bleu
|
||||
|
@ -21,6 +21,7 @@ from unittest.mock import patch
|
||||
from parameterized import parameterized
|
||||
from run_eval import run_generate
|
||||
from run_eval_search import run_search
|
||||
|
||||
from transformers.testing_utils import CaptureStdout, TestCasePlus, slow
|
||||
from utils import ROUGE_KEYS
|
||||
|
||||
|
@ -29,7 +29,6 @@ from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def pack_examples(tok, src_examples, tgt_examples, max_tokens=1024):
|
||||
|
||||
finished_src, finished_tgt = [], []
|
||||
|
||||
sorted_examples = list(zip(src_examples, tgt_examples))
|
||||
|
@ -20,6 +20,7 @@ import sys
|
||||
from collections import OrderedDict
|
||||
|
||||
from run_eval import datetime_now, run_generate
|
||||
|
||||
from utils import ROUGE_KEYS
|
||||
|
||||
|
||||
|
@ -17,6 +17,7 @@ from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from seq2seq_trainer import arg_to_scheduler
|
||||
|
||||
from transformers import TrainingArguments
|
||||
|
||||
|
||||
|
@ -29,10 +29,10 @@ import torch
|
||||
import torch.distributed as dist
|
||||
from rouge_score import rouge_scorer, scoring
|
||||
from sacrebleu import corpus_bleu
|
||||
from sentence_splitter import add_newline_to_end_of_each_sentence
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset, Sampler
|
||||
|
||||
from sentence_splitter import add_newline_to_end_of_each_sentence
|
||||
from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
|
||||
from transformers.models.bart.modeling_bart import shift_tokens_right
|
||||
from transformers.utils import cached_property
|
||||
@ -132,7 +132,7 @@ class AbstractSeq2SeqDataset(Dataset):
|
||||
type_path="train",
|
||||
n_obs=None,
|
||||
prefix="",
|
||||
**dataset_kwargs
|
||||
**dataset_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.src_file = Path(data_dir).joinpath(type_path + ".source")
|
||||
|
@ -24,6 +24,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
import numpy as np
|
||||
from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
|
||||
from torch import nn
|
||||
from utils_ner import Split, TokenClassificationDataset, TokenClassificationTask
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
@ -38,7 +39,6 @@ from transformers import (
|
||||
set_seed,
|
||||
)
|
||||
from transformers.trainer_utils import is_main_process
|
||||
from utils_ner import Split, TokenClassificationDataset, TokenClassificationTask
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -24,6 +24,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
|
||||
from utils_ner import Split, TFTokenClassificationDataset, TokenClassificationTask
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
@ -35,7 +36,6 @@ from transformers import (
|
||||
TFTrainingArguments,
|
||||
)
|
||||
from transformers.utils import logging as hf_logging
|
||||
from utils_ner import Split, TFTokenClassificationDataset, TokenClassificationTask
|
||||
|
||||
|
||||
hf_logging.set_verbosity_info()
|
||||
|
@ -3,7 +3,6 @@ import os
|
||||
from typing import List, TextIO, Union
|
||||
|
||||
from conllu import parse_incr
|
||||
|
||||
from utils_ner import InputExample, Split, TokenClassificationTask
|
||||
|
||||
|
||||
|
@ -23,6 +23,7 @@ from enum import Enum
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from filelock import FileLock
|
||||
|
||||
from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available
|
||||
|
||||
|
||||
@ -240,7 +241,6 @@ if is_torch_available():
|
||||
# and the others will use the cache.
|
||||
lock_path = cached_features_file + ".lock"
|
||||
with FileLock(lock_path):
|
||||
|
||||
if os.path.exists(cached_features_file) and not overwrite_cache:
|
||||
logger.info(f"Loading features from cached file {cached_features_file}")
|
||||
self.features = torch.load(cached_features_file)
|
||||
|
@ -23,10 +23,10 @@ from random import randint
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
import numpy as np
|
||||
from datasets import DatasetDict, load_dataset
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
|
@ -19,6 +19,7 @@ import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
@ -33,7 +34,6 @@ from torchvision.transforms import (
|
||||
ToTensor,
|
||||
)
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from transformers import (
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
|
@ -21,8 +21,13 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import Repository, create_repo
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import (
|
||||
CenterCrop,
|
||||
@ -35,12 +40,7 @@ from torchvision.transforms import (
|
||||
)
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from huggingface_hub import Repository, create_repo
|
||||
from transformers import AutoConfig, AutoImageProcessor, AutoModelForImageClassification, SchedulerType, get_scheduler
|
||||
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
|
||||
from transformers.utils.versions import require_version
|
||||
|
@ -30,10 +30,10 @@ from itertools import chain
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
|
@ -33,15 +33,15 @@ from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import Repository, create_repo
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from huggingface_hub import Repository, create_repo
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
|
@ -30,9 +30,9 @@ from itertools import chain
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
from datasets import load_dataset
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
|
@ -33,15 +33,15 @@ from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import Repository, create_repo
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from huggingface_hub import Repository, create_repo
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
|
@ -30,17 +30,17 @@ from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import Repository, create_repo
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
|
@ -25,11 +25,12 @@ from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
from datasets import load_dataset
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from datasets import load_dataset
|
||||
from trainer_qa import QuestionAnsweringTrainer
|
||||
from utils_qa import postprocess_qa_predictions
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForQuestionAnswering,
|
||||
@ -45,7 +46,6 @@ from transformers import (
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from transformers.utils import check_min_version, send_example_telemetry
|
||||
from transformers.utils.versions import require_version
|
||||
from utils_qa import postprocess_qa_predictions
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
|
@ -25,11 +25,12 @@ from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
from datasets import load_dataset
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from datasets import load_dataset
|
||||
from trainer_qa import QuestionAnsweringTrainer
|
||||
from utils_qa import postprocess_qa_predictions_with_beam_search
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
DataCollatorWithPadding,
|
||||
EvalPrediction,
|
||||
@ -44,7 +45,6 @@ from transformers import (
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from transformers.utils import check_min_version, send_example_telemetry
|
||||
from transformers.utils.versions import require_version
|
||||
from utils_qa import postprocess_qa_predictions_with_beam_search
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
|
@ -27,18 +27,19 @@ import random
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import Repository, create_repo
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
from utils_qa import postprocess_qa_predictions_with_beam_search
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AdamW,
|
||||
DataCollatorWithPadding,
|
||||
@ -52,7 +53,6 @@ from transformers import (
|
||||
)
|
||||
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
|
||||
from transformers.utils.versions import require_version
|
||||
from utils_qa import postprocess_qa_predictions_with_beam_search
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
|
@ -27,18 +27,19 @@ import random
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import Repository, create_repo
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
from utils_qa import postprocess_qa_predictions
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
@ -53,7 +54,6 @@ from transformers import (
|
||||
)
|
||||
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
|
||||
from transformers.utils.versions import require_version
|
||||
from utils_qa import postprocess_qa_predictions
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
|
@ -25,11 +25,11 @@ from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import datasets
|
||||
from datasets import load_dataset
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from datasets import load_dataset
|
||||
from trainer_seq2seq_qa import QuestionAnsweringSeq2SeqTrainer
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSeq2SeqLM,
|
||||
|
@ -21,17 +21,17 @@ import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import functional
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoImageProcessor,
|
||||
|
@ -22,21 +22,21 @@ import random
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import Repository, create_repo, hf_hub_download
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import functional
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from huggingface_hub import Repository, create_repo, hf_hub_download
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoImageProcessor,
|
||||
|
@ -24,14 +24,14 @@ from typing import Dict, List, Optional, Union
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import DatasetDict, concatenate_datasets, load_dataset
|
||||
from huggingface_hub import Repository, create_repo
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from huggingface_hub import Repository, create_repo
|
||||
from transformers import (
|
||||
AdamW,
|
||||
SchedulerType,
|
||||
@ -641,7 +641,6 @@ def main():
|
||||
|
||||
# update step
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
|
||||
|
||||
# compute grad norm for monitoring
|
||||
scale = (
|
||||
accelerator.scaler._scale.item()
|
||||
|
@ -26,11 +26,11 @@ from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import DatasetDict, load_dataset
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
@ -708,7 +708,6 @@ def main():
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
|
||||
# use last checkpoint if exist
|
||||
if last_checkpoint is not None:
|
||||
checkpoint = last_checkpoint
|
||||
|
@ -26,10 +26,10 @@ from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
import torch
|
||||
from datasets import DatasetDict, load_dataset
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
|
@ -25,13 +25,13 @@ from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
import nltk # Here to have a nice missing dependency error message early on
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from filelock import FileLock
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSeq2SeqLM,
|
||||
|
@ -27,20 +27,20 @@ import random
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
import nltk
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from datasets import load_dataset
|
||||
from filelock import FileLock
|
||||
from huggingface_hub import Repository, create_repo
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
|
@ -24,8 +24,8 @@ import tempfile
|
||||
from unittest import mock
|
||||
|
||||
import torch
|
||||
|
||||
from accelerate.utils import write_basic_config
|
||||
|
||||
from transformers.testing_utils import TestCasePlus, get_gpu_count, run_command, slow, torch_device
|
||||
from transformers.utils import is_apex_available
|
||||
|
||||
|
@ -24,10 +24,10 @@ from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
|
@ -22,17 +22,17 @@ import random
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import Repository, create_repo
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSequenceClassification,
|
||||
|
@ -25,10 +25,10 @@ from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
|
@ -26,10 +26,10 @@ from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
import numpy as np
|
||||
from datasets import ClassLabel, load_dataset
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
|
@ -27,17 +27,17 @@ import random
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
from datasets import ClassLabel, load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from datasets import ClassLabel, load_dataset
|
||||
from huggingface_hub import Repository, create_repo
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
|
@ -25,10 +25,10 @@ from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
|
@ -27,18 +27,18 @@ import random
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import evaluate
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import Repository, create_repo
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
@ -69,7 +69,6 @@ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
# Parsing input arguments
|
||||
def parse_args():
|
||||
|
||||
parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
@ -751,5 +750,4 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
main()
|
||||
|
@ -22,6 +22,7 @@ from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from utils_hans import HansDataset, InputFeatures, hans_processors, hans_tasks_num_labels
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
@ -35,7 +36,6 @@ from transformers import (
|
||||
set_seed,
|
||||
)
|
||||
from transformers.trainer_utils import is_main_process
|
||||
from utils_hans import HansDataset, InputFeatures, hans_processors, hans_tasks_num_labels
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -20,8 +20,8 @@ from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import tqdm
|
||||
|
||||
from filelock import FileLock
|
||||
|
||||
from transformers import (
|
||||
BartTokenizer,
|
||||
BartTokenizerFast,
|
||||
@ -134,7 +134,6 @@ if is_torch_available():
|
||||
# and the others will use the cache.
|
||||
lock_path = cached_features_file + ".lock"
|
||||
with FileLock(lock_path):
|
||||
|
||||
if os.path.exists(cached_features_file) and not overwrite_cache:
|
||||
logger.info(f"Loading features from cached file {cached_features_file}")
|
||||
self.features = torch.load(cached_features_file)
|
||||
|
@ -25,14 +25,14 @@ import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pabee.modeling_pabee_albert import AlbertForSequenceClassificationWithPabee
|
||||
from pabee.modeling_pabee_bert import BertForSequenceClassificationWithPabee
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
import transformers
|
||||
from pabee.modeling_pabee_albert import AlbertForSequenceClassificationWithPabee
|
||||
from pabee.modeling_pabee_bert import BertForSequenceClassificationWithPabee
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
@ -173,7 +173,6 @@ def train(args, train_dataset, model, tokenizer):
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
|
||||
# Skip past any already trained steps if resuming training
|
||||
if steps_trained_in_current_epoch > 0:
|
||||
steps_trained_in_current_epoch -= 1
|
||||
@ -263,7 +262,6 @@ def train(args, train_dataset, model, tokenizer):
|
||||
|
||||
|
||||
def evaluate(args, model, tokenizer, prefix="", patience=0):
|
||||
|
||||
if args.model_type == "albert":
|
||||
model.albert.set_regression_threshold(args.regression_threshold)
|
||||
model.albert.set_patience(patience)
|
||||
@ -736,7 +734,6 @@ def main():
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
|
||||
for checkpoint in checkpoints:
|
||||
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||
|
||||
|
@ -4,6 +4,7 @@ import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
import run_glue_with_pabee
|
||||
|
||||
from transformers.testing_utils import TestCasePlus
|
||||
|
||||
|
||||
|
@ -24,9 +24,9 @@ import logging
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
|
||||
from model_bertabs import BertAbsSummarizer
|
||||
from models.model_builder import AbsSummarizer # The authors' implementation
|
||||
|
||||
from transformers import BertTokenizer
|
||||
|
||||
|
||||
|
@ -24,10 +24,10 @@ import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from configuration_bertabs import BertAbsConfig
|
||||
from torch import nn
|
||||
from torch.nn.init import xavier_uniform_
|
||||
|
||||
from configuration_bertabs import BertAbsConfig
|
||||
from transformers import BertConfig, BertModel, PreTrainedModel
|
||||
|
||||
|
||||
|
@ -6,10 +6,10 @@ import sys
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
from modeling_bertabs import BertAbs, build_predictor
|
||||
from torch.utils.data import DataLoader, SequentialSampler
|
||||
from tqdm import tqdm
|
||||
|
||||
from modeling_bertabs import BertAbs, build_predictor
|
||||
from transformers import BertTokenizer
|
||||
|
||||
from .utils_summarization import (
|
||||
@ -45,7 +45,6 @@ def evaluate(args):
|
||||
generated_summaries = []
|
||||
|
||||
import nltk
|
||||
|
||||
import rouge
|
||||
|
||||
nltk.download("punkt")
|
||||
|
@ -3,8 +3,8 @@ from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
from datasets import ClassLabel, DatasetDict, load_dataset
|
||||
|
||||
from evaluate import load
|
||||
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
|
@ -1,7 +1,7 @@
|
||||
from arguments import TokenizerTrainingArguments
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from arguments import TokenizerTrainingArguments
|
||||
from transformers import AutoTokenizer, HfArgumentParser
|
||||
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
|
||||
|
||||
|
@ -6,16 +6,16 @@ from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from arguments import TrainingArguments
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import Repository
|
||||
from torch.optim import AdamW
|
||||
from torch.utils.data import IterableDataset
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe
|
||||
|
||||
import transformers
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from arguments import TrainingArguments
|
||||
from huggingface_hub import Repository
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed
|
||||
|
||||
|
||||
|
@ -5,15 +5,15 @@ import re
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import set_seed
|
||||
from arguments import HumanEvalArguments
|
||||
from datasets import load_dataset, load_metric
|
||||
from torch.utils.data import IterableDataset
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import set_seed
|
||||
from arguments import HumanEvalArguments
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, StoppingCriteria, StoppingCriteriaList
|
||||
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
from arguments import InitializationArguments
|
||||
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
|
||||
|
||||
|
||||
|
@ -6,10 +6,9 @@ from functools import partial
|
||||
from typing import Dict, List, Optional, Set, Tuple, Type
|
||||
|
||||
from datasets import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from datasketch import MinHash, MinHashLSH
|
||||
from dpu_utils.utils.iterators import ThreadedIterator
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
NON_ALPHA = re.compile("[^A-Za-z_0-9]")
|
||||
|
@ -9,10 +9,10 @@ import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
|
||||
from arguments import PreprocessingArguments
|
||||
from datasets import load_dataset
|
||||
from minhash_deduplication import deduplicate_dataset
|
||||
|
||||
from transformers import AutoTokenizer, HfArgumentParser
|
||||
|
||||
|
||||
|
@ -1,9 +1,9 @@
|
||||
import multiprocessing
|
||||
import time
|
||||
|
||||
from arguments import PretokenizationArguments
|
||||
from datasets import load_dataset
|
||||
|
||||
from arguments import PretokenizationArguments
|
||||
from transformers import AutoTokenizer, HfArgumentParser
|
||||
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
from unittest import TestCase
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
from minhash_deduplication import deduplicate_dataset, make_duplicate_clusters
|
||||
|
||||
|
||||
|
@ -1,12 +1,12 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from arguments import EvaluationArguments
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import IterableDataset
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
|
||||
from accelerate import Accelerator
|
||||
from arguments import EvaluationArguments
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, set_seed
|
||||
|
||||
|
||||
|
@ -1,8 +1,8 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import gym
|
||||
from mujoco_py import GlfwContext
|
||||
|
||||
from transformers import DecisionTransformerModel
|
||||
|
||||
|
||||
|
@ -229,7 +229,10 @@ class DeeBertModel(BertPreTrainedModel):
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
|
||||
outputs = (sequence_output, pooled_output,) + encoder_outputs[
|
||||
outputs = (
|
||||
sequence_output,
|
||||
pooled_output,
|
||||
) + encoder_outputs[
|
||||
1:
|
||||
] # add hidden_states and attentions if they are here
|
||||
return outputs # sequence_output, pooled_output, (hidden_states), (attentions), highway exits
|
||||
|
@ -19,7 +19,6 @@ from .modeling_highway_bert import BertPreTrainedModel, DeeBertModel, HighwayExc
|
||||
ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class DeeRobertaModel(DeeBertModel):
|
||||
|
||||
config_class = RobertaConfig
|
||||
base_model_prefix = "roberta"
|
||||
|
||||
@ -36,7 +35,6 @@ class DeeRobertaModel(DeeBertModel):
|
||||
ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class DeeRobertaForSequenceClassification(BertPreTrainedModel):
|
||||
|
||||
config_class = RobertaConfig
|
||||
base_model_prefix = "roberta"
|
||||
|
||||
|
@ -4,6 +4,7 @@ import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
import run_glue_deebert
|
||||
|
||||
from transformers.testing_utils import TestCasePlus, get_gpu_count, require_torch_non_multi_gpu, slow
|
||||
|
||||
|
||||
@ -45,7 +46,6 @@ class DeeBertTests(TestCasePlus):
|
||||
@slow
|
||||
@require_torch_non_multi_gpu
|
||||
def test_glue_deebert_train(self):
|
||||
|
||||
train_args = """
|
||||
--model_type roberta
|
||||
--model_name_or_path roberta-base
|
||||
|
@ -21,14 +21,14 @@ import time

import psutil
import torch
from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
from lm_seqs_dataset import LmSeqsDataset
from torch import nn
from torch.optim import AdamW
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
from lm_seqs_dataset import LmSeqsDataset
from transformers import get_linear_schedule_with_warmup
from utils import logger

@ -189,7 +189,6 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1

@ -24,9 +24,9 @@ import shutil

import numpy as np
import torch

from distiller import Distiller
from lm_seqs_dataset import LmSeqsDataset

from transformers import (
    BertConfig,
    BertForMaskedLM,

@ -5,13 +5,13 @@ import copy
import logging
import random

import joblib
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

import joblib
from transformers import AdamW, GPT2LMHeadModel, get_linear_schedule_with_warmup

@ -119,7 +119,6 @@ def recopy_gpt2(orig_model, device, max_steps):


def intermittent_save(contexts, real_perps, past_perps, filename):

    """
    save the perplexity differences to filename

@ -152,7 +151,6 @@ def collect_objective_set(
    filename="dev.jbl",
    recopy_model=recopy_gpt2,
):

    """
    Collect individual IGF values from pre-trained transformer model
    max_steps samples of training data to train secondary model

@ -271,7 +269,6 @@ def generate_datasets(
def train_secondary_learner(
    secondary_learner, train_dataset, max_epochs, batch_size, eval_freq=50, igf_model_path="secondary_learner.pt"
):

    """
    Train the secondary learner (igf_model)

@ -28,11 +28,9 @@ Last, a plot is generated to compare the performance of IGF to standard fine-tun
import argparse
import random

import joblib
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler

import joblib
from igf.igf import (
    SecondaryLearner,
    collect_objective_set,

@ -43,6 +41,8 @@ from igf.igf import (
    set_seed,
    train_secondary_learner,
)
from torch.utils.data import DataLoader, RandomSampler

from transformers import GPT2LMHeadModel

@ -55,7 +55,6 @@ def generate_n_pairs(
    data_file="data/tokenized_stories_train_wikitext103.jbl",
    igf_data_file="igf_context_pairs.jbl",
):

    """
    Collecting *n* pairs for training the secondary learner
    Args:

@ -4,8 +4,6 @@ from dataclasses import dataclass
from functools import partial
from typing import Callable

from tqdm.auto import tqdm

import flax.linen as nn
import jax
import jax.numpy as jnp

@ -16,6 +14,8 @@ from flax import jax_utils, struct, traverse_util
from flax.serialization import from_bytes, to_bytes
from flax.training import train_state
from flax.training.common_utils import shard
from tqdm.auto import tqdm

from transformers import BigBirdConfig, FlaxBigBirdForQuestionAnswering
from transformers.models.big_bird.modeling_flax_big_bird import FlaxBigBirdForQuestionAnsweringModule

@ -98,7 +98,6 @@ class Args:

@dataclass
class DataCollator:

    pad_id: int
    max_length: int = 4096  # no dynamic padding on TPUs

@ -1,8 +1,8 @@
from datasets import load_from_disk

import jax
import jax.numpy as jnp
from bigbird_flax import FlaxBigBirdForNaturalQuestions
from datasets import load_from_disk

from transformers import BigBirdTokenizerFast

@ -1,10 +1,9 @@
import os

import jsonlines
import numpy as np
from tqdm import tqdm

import jsonlines


DOC_STRIDE = 2048
MAX_LENGTH = 4096

@ -1,12 +1,12 @@
import os
from dataclasses import replace

from datasets import load_dataset

import jax
import wandb
from bigbird_flax import Args, DataCollator, FlaxBigBirdForNaturalQuestions, Trainer, build_tx, train_step, val_step
from datasets import load_dataset
from flax import jax_utils

from transformers import BigBirdTokenizerFast

@ -32,17 +32,17 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple

import datasets
import numpy as np
from datasets import load_dataset
from tqdm import tqdm

import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
from datasets import load_dataset
from flax import jax_utils, traverse_util
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from tqdm import tqdm

from transformers import (
    CONFIG_MAPPING,
    FLAX_MODEL_FOR_MASKED_LM_MAPPING,

@ -20,6 +20,7 @@ import jax
import jax.numpy as jnp
from configuration_hybrid_clip import HybridCLIPConfig
from flax.core.frozen_dict import FrozenDict

from transformers import FLAX_MODEL_MAPPING, FlaxCLIPVisionModel
from transformers.modeling_flax_utils import FlaxPreTrainedModel
from transformers.models.clip.modeling_flax_clip import FlaxCLIPOutput

@ -132,7 +133,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
        input_shape: Optional[Tuple] = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs
        **kwargs,
    ):
        if input_shape is None:
            input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))

@ -32,22 +32,22 @@ from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Optional

import jax
import jax.numpy as jnp
import optax
import torch
from flax import jax_utils
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, shard, shard_prng_key
from modeling_hybrid_clip import FlaxHybridCLIP
from torchvision.datasets import VisionDataset
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm

import jax
import jax.numpy as jnp
import optax
import transformers
from flax import jax_utils
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, shard, shard_prng_key
from modeling_hybrid_clip import FlaxHybridCLIP
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, is_tensorboard_available, set_seed

@ -28,19 +28,19 @@ from pathlib import Path
from typing import Callable, Optional

import datasets
import numpy as np
from datasets import Dataset, load_dataset
from tqdm import tqdm

import jax
import jax.numpy as jnp
import numpy as np
import optax
import transformers
from datasets import Dataset, load_dataset
from flax.core.frozen_dict import freeze, unfreeze
from flax.training.common_utils import onehot, stack_forest
from jax.experimental.maps import mesh
from jax.experimental.pjit import pjit
from partitions import set_partitions
from tqdm import tqdm

import transformers
from transformers import (
    CONFIG_MAPPING,
    FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,

@ -6,18 +6,18 @@ from dataclasses import field
from pathlib import Path
from typing import Dict, List, Optional, Union

import numpy as np
from datasets import DatasetDict, load_dataset
from tqdm import tqdm

import flax
import jax
import jax.numpy as jnp
import librosa
import numpy as np
import optax
from datasets import DatasetDict, load_dataset
from flax import jax_utils, traverse_util
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from tqdm import tqdm

from transformers import (
    FlaxWav2Vec2ForPreTraining,
    HfArgumentParser,

Some files were not shown because too many files have changed in this diff.