Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 05:10:06 +06:00)

Update quality tooling for formatting (#21480)

* Result of black 23.1
* Update target to Python 3.7
* Switch flake8 to ruff
* Configure isort
* Configure isort
* Apply isort with line limit
* Put the right black version
* adapt black in check copies
* Fix copies

commit 6f79d26442 (parent b7bb2b59f7)
@@ -134,11 +134,10 @@ jobs:
 command: pip freeze | tee installed.txt
 - store_artifacts:
 path: ~/transformers/installed.txt
-- run: black --check --preview examples tests src utils
+- run: black --check examples tests src utils
-- run: isort --check-only examples tests src utils
+- run: ruff examples tests src utils
 - run: python utils/custom_init_isort.py --check_only
 - run: python utils/sort_auto_mappings.py --check_only
-- run: flake8 examples tests src utils
 - run: doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source
 - run: python utils/check_doc_toc.py

Makefile (14 changed lines)
@@ -9,9 +9,8 @@ modified_only_fixup:
 $(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
 @if test -n "$(modified_py_files)"; then \
 echo "Checking/fixing $(modified_py_files)"; \
-black --preview $(modified_py_files); \
+black $(modified_py_files); \
-isort $(modified_py_files); \
+ruff $(modified_py_files) --fix; \
-flake8 $(modified_py_files); \
 else \
 echo "No library .py files were modified"; \
 fi

@@ -48,11 +47,10 @@ repo-consistency:
 # this target runs checks on all files

 quality:
-black --check --preview $(check_dirs)
+black --check $(check_dirs)
-isort --check-only $(check_dirs)
 python utils/custom_init_isort.py --check_only
 python utils/sort_auto_mappings.py --check_only
-flake8 $(check_dirs)
+ruff $(check_dirs)
 doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source
 python utils/check_doc_toc.py

@@ -67,8 +65,8 @@ extra_style_checks:
 # this target runs checks on all files and potentially modifies some of them

 style:
-black --preview $(check_dirs)
+black $(check_dirs)
-isort $(check_dirs)
+ruff $(check_dirs) --fix
 ${MAKE} autogenerate_code
 ${MAKE} extra_style_checks
@@ -96,7 +96,7 @@ while True:
 queues.append(rq)
 strings
 outs = pipe(strings, batch_size=len(strings))
-for (rq, out) in zip(queues, outs):
+for rq, out in zip(queues, outs):
 await rq.put(out)
 ```

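The loop in the hunk above (from the pipeline batching docs) collects queued requests, runs one batched pipeline call, and uses `zip` to hand each output back to the queue of the request that produced it. Below is a minimal, self-contained sketch of that pattern with a stand-in `fake_pipe` instead of a real `transformers` pipeline; everything not visible in the hunk is an assumption for illustration.

```python
import asyncio


async def fake_pipe(strings, batch_size):
    # Stand-in for a real pipeline call; purely illustrative.
    return [s.upper() for s in strings]


async def server_loop(request_queue):
    while True:
        string, rq = await request_queue.get()
        strings, queues = [string], [rq]
        # Drain whatever else is already waiting so it rides in the same batch.
        while not request_queue.empty():
            string, rq = request_queue.get_nowait()
            strings.append(string)
            queues.append(rq)
        outs = await fake_pipe(strings, batch_size=len(strings))
        # Fan the batched outputs back out to the per-request queues.
        for rq, out in zip(queues, outs):
            await rq.put(out)


async def main():
    request_queue = asyncio.Queue()
    asyncio.create_task(server_loop(request_queue))
    response_queue = asyncio.Queue()
    await request_queue.put(("hello", response_queue))
    print(await response_queue.get())  # -> "HELLO"


asyncio.run(main())
```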
@@ -166,7 +166,6 @@ Unlike other data collators, this specific data collator needs to apply a different

 >>> @dataclass
 ... class DataCollatorCTCWithPadding:
-
 ... processor: AutoProcessor
 ... padding: Union[bool, str] = "longest"

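The hunk only shows the head of the CTC data collator from the ASR docs. As a rough sketch of what the rest of such a collator typically does (pad audio inputs and label ids separately, then mask padded label positions so the loss ignores them), assuming a Wav2Vec2-style processor whose feature extractor and tokenizer both expose `pad`; this is illustrative, not the docs' exact code:

```python
from dataclasses import dataclass
from typing import Dict, List, Union

import torch
from transformers import AutoProcessor


@dataclass
class DataCollatorCTCWithPadding:
    processor: AutoProcessor
    padding: Union[bool, str] = "longest"

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Audio inputs and text labels have different lengths, so pad them separately.
        input_features = [{"input_values": f["input_values"]} for f in features]
        label_features = [{"input_ids": f["labels"]} for f in features]

        batch = self.processor.feature_extractor.pad(input_features, padding=self.padding, return_tensors="pt")
        labels_batch = self.processor.tokenizer.pad(label_features, padding=self.padding, return_tensors="pt")

        # Padded label positions become -100 so they are ignored by the CTC loss.
        labels = labels_batch["input_ids"].masked_fill(labels_batch["attention_mask"].ne(1), -100)
        batch["labels"] = labels
        return batch
```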
@@ -213,7 +213,6 @@ The `image_processor` expects the annotations to be in the following format: `{'

 ```py
 >>> def formatted_anns(image_id, category, area, bbox):
-
 ... annotations = []
 ... for i in range(0, len(category)):
 ... new_ann = {
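The hunk cuts off inside `formatted_anns`. A hedged sketch of how the loop body plausibly continues, building one COCO-style annotation per object from the `image_id`, `category`, `area`, and `bbox` columns; the exact key names are assumptions, not quoted from the file:

```python
def formatted_anns(image_id, category, area, bbox):
    # One COCO-style annotation dict per object in the example.
    annotations = []
    for i in range(0, len(category)):
        new_ann = {
            "image_id": image_id,        # which image the box belongs to
            "category_id": category[i],  # object class id
            "isCrowd": 0,                # assumed default, as in COCO
            "area": area[i],
            "bbox": list(bbox[i]),
        }
        annotations.append(new_ann)
    return annotations
```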
@@ -399,6 +398,7 @@ First, prepare the `cppe5["test"]` set: format the annotations and save the data
 ```py
 >>> import json

+
 >>> # format annotations the same as for training, no need for data augmentation
 >>> def val_formatted_anns(image_id, objects):
 ... annotations = []
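The evaluation-side helper `val_formatted_anns` is likewise truncated here. A sketch of how it might continue, mirroring the training-side format but reading fields from the dataset's `objects` column; the key names are assumptions for illustration:

```python
def val_formatted_anns(image_id, objects):
    # Same idea as formatted_anns, but per-object ids are kept for COCO-style evaluation.
    annotations = []
    for i in range(0, len(objects["id"])):
        new_ann = {
            "id": objects["id"][i],
            "category_id": objects["category"][i],
            "iscrowd": 0,
            "image_id": image_id,
            "area": objects["area"][i],
            "bbox": objects["bbox"][i],
        }
        annotations.append(new_ann)
    return annotations
```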
@@ -159,7 +159,6 @@ A diferencia de otros collators de datos, este tiene que aplicarle un método de

 >>> @dataclass
 ... class DataCollatorCTCWithPadding:
-
 ... processor: AutoProcessor
 ... padding: Union[bool, str] = "longest"

@@ -29,23 +29,23 @@ from pathlib import Path
 from typing import Callable, Optional

 import datasets
-import nltk  # Here to have a nice missing dependency error message early on
-import numpy as np
-from datasets import Dataset, load_dataset
-from PIL import Image
-from tqdm import tqdm

 import evaluate
 import jax
 import jax.numpy as jnp
+import nltk  # Here to have a nice missing dependency error message early on
+import numpy as np
 import optax
-import transformers
+from datasets import Dataset, load_dataset
 from filelock import FileLock
 from flax import jax_utils, traverse_util
 from flax.jax_utils import unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
 from huggingface_hub import Repository, create_repo
+from PIL import Image
+from tqdm import tqdm

+import transformers
 from transformers import (
     AutoImageProcessor,
     AutoTokenizer,
@@ -32,20 +32,20 @@ from itertools import chain
 from pathlib import Path
 from typing import Dict, List, Optional

-import nltk
-import numpy as np
-from datasets import load_dataset
-from tqdm import tqdm

 import flax
 import jax
 import jax.numpy as jnp
+import nltk
+import numpy as np
 import optax
+from datasets import load_dataset
 from flax import jax_utils, traverse_util
 from flax.jax_utils import pad_shard_unpad
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
 from huggingface_hub import Repository, create_repo
+from tqdm import tqdm

 from transformers import (
     CONFIG_MAPPING,
     FLAX_MODEL_FOR_MASKED_LM_MAPPING,

@@ -34,19 +34,19 @@ from pathlib import Path
 from typing import Callable, Optional

 import datasets
-import numpy as np
-from datasets import Dataset, load_dataset
-from tqdm import tqdm

 import jax
 import jax.numpy as jnp
+import numpy as np
 import optax
-import transformers
+from datasets import Dataset, load_dataset
 from flax import jax_utils, traverse_util
 from flax.jax_utils import pad_shard_unpad, unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
 from huggingface_hub import Repository, create_repo
+from tqdm import tqdm

+import transformers
 from transformers import (
     CONFIG_MAPPING,
     FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,

@@ -34,19 +34,19 @@ from itertools import chain
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple

-import numpy as np
-from datasets import load_dataset
-from tqdm import tqdm

 import flax
 import jax
 import jax.numpy as jnp
+import numpy as np
 import optax
+from datasets import load_dataset
 from flax import jax_utils, traverse_util
 from flax.jax_utils import pad_shard_unpad
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
 from huggingface_hub import Repository, create_repo
+from tqdm import tqdm

 from transformers import (
     CONFIG_MAPPING,
     FLAX_MODEL_FOR_MASKED_LM_MAPPING,

@@ -33,19 +33,19 @@ from itertools import chain
 from pathlib import Path
 from typing import Dict, List, Optional

-import numpy as np
-from datasets import load_dataset
-from tqdm import tqdm

 import flax
 import jax
 import jax.numpy as jnp
+import numpy as np
 import optax
+from datasets import load_dataset
 from flax import jax_utils, traverse_util
 from flax.jax_utils import pad_shard_unpad
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
 from huggingface_hub import Repository, create_repo
+from tqdm import tqdm

 from transformers import (
     CONFIG_MAPPING,
     FLAX_MODEL_FOR_MASKED_LM_MAPPING,
@@ -31,20 +31,21 @@ from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Tuple

 import datasets
-import numpy as np
-from datasets import load_dataset
-from tqdm import tqdm

 import evaluate
 import jax
 import jax.numpy as jnp
+import numpy as np
 import optax
-import transformers
+from datasets import load_dataset
 from flax import struct, traverse_util
 from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
 from huggingface_hub import Repository, create_repo
+from tqdm import tqdm
+from utils_qa import postprocess_qa_predictions

+import transformers
 from transformers import (
     AutoConfig,
     AutoTokenizer,

@@ -55,7 +56,6 @@ from transformers import (
     is_tensorboard_available,
 )
 from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
-from utils_qa import postprocess_qa_predictions


 logger = logging.getLogger(__name__)
@@ -301,6 +301,7 @@ class DataTrainingArguments:

 # endregion

+
 # region Create a train state
 def create_train_state(
     model: FlaxAutoModelForQuestionAnswering,

@@ -387,6 +388,7 @@ def create_learning_rate_fn(

 # endregion

+
 # region train data iterator
 def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
     """Returns shuffled batches of size `batch_size` from truncated `train dataset`, sharded over all local devices."""

@@ -405,6 +407,7 @@ def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):

 # endregion

+
 # region eval data iterator
 def eval_data_collator(dataset: Dataset, batch_size: int):
     """Returns batches of size `batch_size` from `eval dataset`. Sharding handled by `pad_shard_unpad` in the eval loop."""
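The docstrings in the hunks above describe two iterators: shuffled training batches sharded across the local devices, and plain sequential eval batches. A rough standalone sketch of that pattern (not this file's code), assuming the dataset is a dict of NumPy arrays and the per-step batch size divides evenly across devices:

```python
import jax
import numpy as np


def train_batches(rng: np.random.Generator, dataset: dict, batch_size: int):
    """Yield shuffled batches reshaped to (local_devices, per_device_batch, ...)."""
    num_examples = len(next(iter(dataset.values())))
    num_devices = jax.local_device_count()
    perm = rng.permutation(num_examples)
    for step in range(num_examples // batch_size):  # drop the last incomplete batch
        idx = perm[step * batch_size : (step + 1) * batch_size]
        batch = {k: v[idx] for k, v in dataset.items()}
        yield {k: v.reshape((num_devices, -1) + v.shape[1:]) for k, v in batch.items()}


def eval_batches(dataset: dict, batch_size: int):
    """Yield sequential batches; padding and sharding are left to the eval loop."""
    num_examples = len(next(iter(dataset.values())))
    for start in range(0, num_examples, batch_size):
        yield {k: v[start : start + batch_size] for k, v in dataset.items()}
```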
@@ -934,7 +937,6 @@ def main():
 total_steps = step_per_epoch * num_epochs
 epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
 for epoch in epochs:
-
 train_start = time.time()
 train_metrics = []

@@ -975,7 +977,6 @@ def main():
 and (cur_step % training_args.eval_steps == 0 or cur_step % step_per_epoch == 0)
 and cur_step > 0
 ):
-
 eval_metrics = {}
 all_start_logits = []
 all_end_logits = []

@@ -31,22 +31,22 @@ from pathlib import Path
 from typing import Callable, Optional

 import datasets
-import nltk  # Here to have a nice missing dependency error message early on
-import numpy as np
-from datasets import Dataset, load_dataset
-from tqdm import tqdm

 import evaluate
 import jax
 import jax.numpy as jnp
+import nltk  # Here to have a nice missing dependency error message early on
+import numpy as np
 import optax
-import transformers
+from datasets import Dataset, load_dataset
 from filelock import FileLock
 from flax import jax_utils, traverse_util
 from flax.jax_utils import pad_shard_unpad, unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
 from huggingface_hub import Repository, create_repo
+from tqdm import tqdm

+import transformers
 from transformers import (
     CONFIG_MAPPING,
     FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
@@ -26,20 +26,20 @@ from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Tuple

 import datasets
-import numpy as np
-from datasets import load_dataset
-from tqdm import tqdm

 import evaluate
 import jax
 import jax.numpy as jnp
+import numpy as np
 import optax
-import transformers
+from datasets import load_dataset
 from flax import struct, traverse_util
 from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
 from huggingface_hub import Repository, create_repo
+from tqdm import tqdm

+import transformers
 from transformers import (
     AutoConfig,
     AutoTokenizer,

@@ -586,7 +586,6 @@ def main():
 total_steps = steps_per_epoch * num_epochs
 epochs = tqdm(range(num_epochs), desc=f"Epoch ... (0/{num_epochs})", position=0)
 for epoch in epochs:
-
 train_start = time.time()
 train_metrics = []

@@ -623,7 +622,6 @@ def main():
 train_metrics = []

 if (cur_step % training_args.eval_steps == 0 or cur_step % steps_per_epoch == 0) and cur_step > 0:
-
 # evaluate
 eval_loader = glue_eval_data_collator(eval_dataset, eval_batch_size)
 for batch in tqdm(
@@ -28,20 +28,20 @@ from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Tuple

 import datasets
-import numpy as np
-from datasets import ClassLabel, load_dataset
-from tqdm import tqdm

 import evaluate
 import jax
 import jax.numpy as jnp
+import numpy as np
 import optax
-import transformers
+from datasets import ClassLabel, load_dataset
 from flax import struct, traverse_util
 from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
 from huggingface_hub import Repository, create_repo
+from tqdm import tqdm

+import transformers
 from transformers import (
     AutoConfig,
     AutoTokenizer,

@@ -695,7 +695,6 @@ def main():
 total_steps = step_per_epoch * num_epochs
 epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
 for epoch in epochs:
-
 train_start = time.time()
 train_metrics = []

@@ -731,7 +730,6 @@ def main():
 train_metrics = []

 if cur_step % training_args.eval_steps == 0 and cur_step > 0:
-
 eval_metrics = {}
 # evaluate
 for batch in tqdm(
@@ -29,21 +29,22 @@ from enum import Enum
 from pathlib import Path
 from typing import Callable, Optional

+import jax
+import jax.numpy as jnp
+import optax

 # for dataset and preprocessing
 import torch
 import torchvision
 import torchvision.transforms as transforms
-from tqdm import tqdm

-import jax
-import jax.numpy as jnp
-import optax
-import transformers
 from flax import jax_utils
 from flax.jax_utils import pad_shard_unpad, unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
 from huggingface_hub import Repository, create_repo
+from tqdm import tqdm

+import transformers
 from transformers import (
     CONFIG_MAPPING,
     FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,

@@ -22,6 +22,7 @@ from dataclasses import dataclass, field
 from typing import Dict, Optional

 import numpy as np
+from utils_multiple_choice import MultipleChoiceDataset, Split, processors

 import transformers
 from transformers import (

@@ -36,7 +37,6 @@ from transformers import (
     set_seed,
 )
 from transformers.trainer_utils import is_main_process
-from utils_multiple_choice import MultipleChoiceDataset, Split, processors


 logger = logging.getLogger(__name__)
@@ -26,8 +26,8 @@ from enum import Enum
 from typing import List, Optional

 import tqdm

 from filelock import FileLock

 from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available


@@ -112,7 +112,6 @@ if is_torch_available():
 # and the others will use the cache.
 lock_path = cached_features_file + ".lock"
 with FileLock(lock_path):
-
     if os.path.exists(cached_features_file) and not overwrite_cache:
         logger.info(f"Loading features from cached file {cached_features_file}")
         self.features = torch.load(cached_features_file)
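The context in this hunk shows the lock-guarded feature cache these dataset classes use: the first process builds and saves the features while the others wait on the lock and then load the cached file. A standalone sketch of the same pattern, with a hypothetical `build_features()` standing in for the real example-to-feature conversion:

```python
import os

import torch
from filelock import FileLock


def build_features():
    # Hypothetical stand-in for the real conversion step.
    return [{"input_ids": [101, 102]}]


def load_or_build_features(cached_features_file: str, overwrite_cache: bool = False):
    lock_path = cached_features_file + ".lock"
    # Only one process builds the cache; the others block here and then reuse it.
    with FileLock(lock_path):
        if os.path.exists(cached_features_file) and not overwrite_cache:
            return torch.load(cached_features_file)
        features = build_features()
        torch.save(features, cached_features_file)
        return features
```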
@@ -69,7 +69,7 @@ class BaseTransformer(pl.LightningModule):
 config=None,
 tokenizer=None,
 model=None,
-**config_kwargs
+**config_kwargs,
 ):
 """Initialize a model, tokenizer and config."""
 super().__init__()

@@ -346,7 +346,7 @@ def generic_train(
 extra_callbacks=[],
 checkpoint_callback=None,
 logging_callback=None,
-**extra_train_kwargs
+**extra_train_kwargs,
 ):
 pl.seed_everything(args.seed)

@@ -7,21 +7,19 @@ from argparse import Namespace

 import numpy as np
 import torch
+from lightning_base import BaseTransformer, add_generic_args, generic_train
 from torch.utils.data import DataLoader, TensorDataset

-from lightning_base import BaseTransformer, add_generic_args, generic_train
 from transformers import glue_compute_metrics as compute_metrics
 from transformers import glue_convert_examples_to_features as convert_examples_to_features
-from transformers import glue_output_modes
+from transformers import glue_output_modes, glue_tasks_num_labels
 from transformers import glue_processors as processors
-from transformers import glue_tasks_num_labels


 logger = logging.getLogger(__name__)


 class GLUETransformer(BaseTransformer):

     mode = "sequence-classification"

     def __init__(self, hparams):
@@ -7,11 +7,10 @@ from importlib import import_module

 import numpy as np
 import torch
+from lightning_base import BaseTransformer, add_generic_args, generic_train
 from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
 from torch.nn import CrossEntropyLoss
 from torch.utils.data import DataLoader, TensorDataset

-from lightning_base import BaseTransformer, add_generic_args, generic_train
 from utils_ner import TokenClassificationTask


@@ -172,7 +172,6 @@ def train(args, train_dataset, model, tokenizer):
 for _ in train_iterator:
 epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
 for step, batch in enumerate(epoch_iterator):
-
 # Skip past any already trained steps if resuming training
 if steps_trained_in_current_epoch > 0:
 steps_trained_in_current_epoch -= 1

@@ -30,9 +30,10 @@ from transformers import (
     DataCollatorWithPadding,
     HfArgumentParser,
     SquadDataset,
+    Trainer,
+    TrainingArguments,
 )
 from transformers import SquadDataTrainingArguments as DataTrainingArguments
-from transformers import Trainer, TrainingArguments
 from transformers.trainer_utils import is_main_process

@@ -4,6 +4,7 @@ import json
 from typing import List

 from ltp import LTP
+
 from transformers import BertTokenizer


@@ -93,7 +94,6 @@ def prepare_ref(lines: List[str], ltp_tokenizer: LTP, bert_tokenizer: BertTokeni

 ref_ids = []
 for input_ids, chinese_word in zip(bert_res, ltp_res):
-
 input_tokens = []
 for id in input_ids:
 token = bert_tokenizer._convert_id_to_token(id)

@@ -19,9 +19,10 @@ import sys
 from dataclasses import dataclass, field
 from typing import Optional

-import transformers
 from seq2seq_trainer import Seq2SeqTrainer
 from seq2seq_training_args import Seq2SeqTrainingArguments

+import transformers
 from transformers import (
     AutoConfig,
     AutoModelForSeq2SeqLM,

@@ -337,7 +338,6 @@ def main():
 metrics["val_loss"] = round(metrics["val_loss"], 4)

 if trainer.is_world_process_zero():
-
 handle_metrics("val", metrics, training_args.output_dir)
 all_metrics.update(metrics)

@@ -16,8 +16,8 @@ from collections import defaultdict
 from pathlib import Path

 import pandas as pd

 from rouge_cli import calculate_rouge_path

 from utils import calculate_rouge


@@ -87,7 +87,6 @@ def test_single_sent_scores_dont_depend_on_newline_sep():


 def test_pegasus_newline():
-
 pred = [
 """" "a person who has such a video needs to immediately give it to the investigators," prosecutor says .<n> "it is a very disturbing scene," editor-in-chief of bild online tells "erin burnett: outfront" """
 ]
@@ -17,11 +17,11 @@ from pathlib import Path

 import numpy as np
 import pytest
-from torch.utils.data import DataLoader

 from pack_dataset import pack_data_dir
 from parameterized import parameterized
 from save_len_file import save_len_file
+from torch.utils.data import DataLoader

 from transformers import AutoTokenizer
 from transformers.models.mbart.modeling_mbart import shift_tokens_right
 from transformers.testing_utils import TestCasePlus, slow

@@ -18,6 +18,7 @@ import json
 import unittest

 from parameterized import parameterized
+
 from transformers import FSMTForConditionalGeneration, FSMTTokenizer
 from transformers.testing_utils import get_tests_dir, require_torch, slow, torch_device
 from utils import calculate_bleu

@@ -21,6 +21,7 @@ from unittest.mock import patch
 from parameterized import parameterized
 from run_eval import run_generate
 from run_eval_search import run_search
+
 from transformers.testing_utils import CaptureStdout, TestCasePlus, slow
 from utils import ROUGE_KEYS

@@ -29,7 +29,6 @@ from transformers import AutoTokenizer


 def pack_examples(tok, src_examples, tgt_examples, max_tokens=1024):
-
 finished_src, finished_tgt = [], []

 sorted_examples = list(zip(src_examples, tgt_examples))

@@ -20,6 +20,7 @@ import sys
 from collections import OrderedDict

 from run_eval import datetime_now, run_generate
+
 from utils import ROUGE_KEYS


@@ -17,6 +17,7 @@ from dataclasses import dataclass, field
 from typing import Optional

 from seq2seq_trainer import arg_to_scheduler
+
 from transformers import TrainingArguments

@@ -29,10 +29,10 @@ import torch
 import torch.distributed as dist
 from rouge_score import rouge_scorer, scoring
 from sacrebleu import corpus_bleu
+from sentence_splitter import add_newline_to_end_of_each_sentence
 from torch import nn
 from torch.utils.data import Dataset, Sampler

-from sentence_splitter import add_newline_to_end_of_each_sentence
 from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
 from transformers.models.bart.modeling_bart import shift_tokens_right
 from transformers.utils import cached_property

@@ -132,7 +132,7 @@ class AbstractSeq2SeqDataset(Dataset):
 type_path="train",
 n_obs=None,
 prefix="",
-**dataset_kwargs
+**dataset_kwargs,
 ):
 super().__init__()
 self.src_file = Path(data_dir).joinpath(type_path + ".source")

@@ -24,6 +24,7 @@ from typing import Dict, List, Optional, Tuple
 import numpy as np
 from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
 from torch import nn
+from utils_ner import Split, TokenClassificationDataset, TokenClassificationTask

 import transformers
 from transformers import (

@@ -38,7 +39,6 @@ from transformers import (
     set_seed,
 )
 from transformers.trainer_utils import is_main_process
-from utils_ner import Split, TokenClassificationDataset, TokenClassificationTask


 logger = logging.getLogger(__name__)

@@ -24,6 +24,7 @@ from typing import Dict, List, Optional, Tuple

 import numpy as np
 from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
+from utils_ner import Split, TFTokenClassificationDataset, TokenClassificationTask

 from transformers import (
     AutoConfig,

@@ -35,7 +36,6 @@ from transformers import (
     TFTrainingArguments,
 )
 from transformers.utils import logging as hf_logging
-from utils_ner import Split, TFTokenClassificationDataset, TokenClassificationTask


 hf_logging.set_verbosity_info()
@@ -3,7 +3,6 @@ import os
 from typing import List, TextIO, Union

 from conllu import parse_incr
-
 from utils_ner import InputExample, Split, TokenClassificationTask


@@ -23,6 +23,7 @@ from enum import Enum
 from typing import List, Optional, Union

 from filelock import FileLock
+
 from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available


@@ -240,7 +241,6 @@ if is_torch_available():
 # and the others will use the cache.
 lock_path = cached_features_file + ".lock"
 with FileLock(lock_path):
-
     if os.path.exists(cached_features_file) and not overwrite_cache:
         logger.info(f"Loading features from cached file {cached_features_file}")
         self.features = torch.load(cached_features_file)

@@ -23,10 +23,10 @@ from random import randint
 from typing import Optional

 import datasets
+import evaluate
 import numpy as np
 from datasets import DatasetDict, load_dataset

-import evaluate
 import transformers
 from transformers import (
     AutoConfig,
@@ -19,6 +19,7 @@ import sys
 from dataclasses import dataclass, field
 from typing import Optional

+import evaluate
 import numpy as np
 import torch
 from datasets import load_dataset

@@ -33,7 +34,6 @@ from torchvision.transforms import (
     ToTensor,
 )

-import evaluate
 import transformers
 from transformers import (
     MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,

@@ -21,8 +21,13 @@ import os
 from pathlib import Path

 import datasets
+import evaluate
 import torch
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import set_seed
 from datasets import load_dataset
+from huggingface_hub import Repository, create_repo
 from torch.utils.data import DataLoader
 from torchvision.transforms import (
     CenterCrop,

@@ -35,12 +40,7 @@ from torchvision.transforms import (
 )
 from tqdm.auto import tqdm

-import evaluate
 import transformers
-from accelerate import Accelerator
-from accelerate.logging import get_logger
-from accelerate.utils import set_seed
-from huggingface_hub import Repository, create_repo
 from transformers import AutoConfig, AutoImageProcessor, AutoModelForImageClassification, SchedulerType, get_scheduler
 from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
 from transformers.utils.versions import require_version
@@ -30,10 +30,10 @@ from itertools import chain
 from typing import Optional

 import datasets
+import evaluate
 import torch
 from datasets import load_dataset

-import evaluate
 import transformers
 from transformers import (
     CONFIG_MAPPING,

@@ -33,15 +33,15 @@ from pathlib import Path

 import datasets
 import torch
+from accelerate import Accelerator, DistributedType
+from accelerate.logging import get_logger
+from accelerate.utils import set_seed
 from datasets import load_dataset
+from huggingface_hub import Repository, create_repo
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm

 import transformers
-from accelerate import Accelerator, DistributedType
-from accelerate.logging import get_logger
-from accelerate.utils import set_seed
-from huggingface_hub import Repository, create_repo
 from transformers import (
     CONFIG_MAPPING,
     MODEL_MAPPING,

@@ -30,9 +30,9 @@ from itertools import chain
 from typing import Optional

 import datasets
+import evaluate
 from datasets import load_dataset

-import evaluate
 import transformers
 from transformers import (
     CONFIG_MAPPING,

@@ -33,15 +33,15 @@ from pathlib import Path

 import datasets
 import torch
+from accelerate import Accelerator, DistributedType
+from accelerate.logging import get_logger
+from accelerate.utils import set_seed
 from datasets import load_dataset
+from huggingface_hub import Repository, create_repo
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm

 import transformers
-from accelerate import Accelerator, DistributedType
-from accelerate.logging import get_logger
-from accelerate.utils import set_seed
-from huggingface_hub import Repository, create_repo
 from transformers import (
     CONFIG_MAPPING,
     MODEL_MAPPING,
@@ -30,17 +30,17 @@ from pathlib import Path
 from typing import Optional, Union

 import datasets
-import torch
-from datasets import load_dataset
-from torch.utils.data import DataLoader
-from tqdm.auto import tqdm

 import evaluate
-import transformers
+import torch
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
+from datasets import load_dataset
 from huggingface_hub import Repository, create_repo
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm

+import transformers
 from transformers import (
     CONFIG_MAPPING,
     MODEL_MAPPING,

@@ -25,11 +25,12 @@ from dataclasses import dataclass, field
 from typing import Optional

 import datasets
-from datasets import load_dataset

 import evaluate
-import transformers
+from datasets import load_dataset
 from trainer_qa import QuestionAnsweringTrainer
+from utils_qa import postprocess_qa_predictions

+import transformers
 from transformers import (
     AutoConfig,
     AutoModelForQuestionAnswering,

@@ -45,7 +46,6 @@ from transformers import (
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version, send_example_telemetry
 from transformers.utils.versions import require_version
-from utils_qa import postprocess_qa_predictions


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
@@ -25,11 +25,12 @@ from dataclasses import dataclass, field
 from typing import Optional

 import datasets
-from datasets import load_dataset

 import evaluate
-import transformers
+from datasets import load_dataset
 from trainer_qa import QuestionAnsweringTrainer
+from utils_qa import postprocess_qa_predictions_with_beam_search

+import transformers
 from transformers import (
     DataCollatorWithPadding,
     EvalPrediction,

@@ -44,7 +45,6 @@ from transformers import (
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version, send_example_telemetry
 from transformers.utils.versions import require_version
-from utils_qa import postprocess_qa_predictions_with_beam_search


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.

@@ -27,18 +27,19 @@ import random
 from pathlib import Path

 import datasets
+import evaluate
 import numpy as np
 import torch
-from datasets import load_dataset
-from torch.utils.data import DataLoader
-from tqdm.auto import tqdm

-import evaluate
-import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
+from datasets import load_dataset
 from huggingface_hub import Repository, create_repo
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+from utils_qa import postprocess_qa_predictions_with_beam_search

+import transformers
 from transformers import (
     AdamW,
     DataCollatorWithPadding,

@@ -52,7 +53,6 @@ from transformers import (
 )
 from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
 from transformers.utils.versions import require_version
-from utils_qa import postprocess_qa_predictions_with_beam_search


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
@@ -27,18 +27,19 @@ import random
 from pathlib import Path

 import datasets
+import evaluate
 import numpy as np
 import torch
-from datasets import load_dataset
-from torch.utils.data import DataLoader
-from tqdm.auto import tqdm
-
-import evaluate
-import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
+from datasets import load_dataset
 from huggingface_hub import Repository, create_repo
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+from utils_qa import postprocess_qa_predictions
+
+import transformers
 from transformers import (
 CONFIG_MAPPING,
 MODEL_MAPPING,
@@ -53,7 +54,6 @@ from transformers import (
 )
 from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
 from transformers.utils.versions import require_version
-from utils_qa import postprocess_qa_predictions


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
@@ -25,11 +25,11 @@ from dataclasses import dataclass, field
 from typing import List, Optional, Tuple

 import datasets
-from datasets import load_dataset
-
 import evaluate
-import transformers
+from datasets import load_dataset
 from trainer_seq2seq_qa import QuestionAnsweringSeq2SeqTrainer
+
+import transformers
 from transformers import (
 AutoConfig,
 AutoModelForSeq2SeqLM,
@@ -21,17 +21,17 @@ import sys
 from dataclasses import dataclass, field
 from typing import Optional

+import evaluate
 import numpy as np
 import torch
 from datasets import load_dataset
+from huggingface_hub import hf_hub_download
 from PIL import Image
 from torch import nn
 from torchvision import transforms
 from torchvision.transforms import functional

-import evaluate
 import transformers
-from huggingface_hub import hf_hub_download
 from transformers import (
 AutoConfig,
 AutoImageProcessor,
@@ -22,21 +22,21 @@ import random
 from pathlib import Path

 import datasets
+import evaluate
 import numpy as np
 import torch
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import set_seed
 from datasets import load_dataset
+from huggingface_hub import Repository, create_repo, hf_hub_download
 from PIL import Image
 from torch.utils.data import DataLoader
 from torchvision import transforms
 from torchvision.transforms import functional
 from tqdm.auto import tqdm

-import evaluate
 import transformers
-from accelerate import Accelerator
-from accelerate.logging import get_logger
-from accelerate.utils import set_seed
-from huggingface_hub import Repository, create_repo, hf_hub_download
 from transformers import (
 AutoConfig,
 AutoImageProcessor,
@@ -24,14 +24,14 @@ from typing import Dict, List, Optional, Union

 import datasets
 import torch
+from accelerate import Accelerator
+from accelerate.logging import get_logger
 from datasets import DatasetDict, concatenate_datasets, load_dataset
+from huggingface_hub import Repository, create_repo
 from torch.utils.data.dataloader import DataLoader
 from tqdm.auto import tqdm

 import transformers
-from accelerate import Accelerator
-from accelerate.logging import get_logger
-from huggingface_hub import Repository, create_repo
 from transformers import (
 AdamW,
 SchedulerType,
@@ -641,7 +641,6 @@ def main():

 # update step
 if (step + 1) % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-
 # compute grad norm for monitoring
 scale = (
 accelerator.scaler._scale.item()
@@ -26,11 +26,11 @@ from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Union

 import datasets
+import evaluate
 import numpy as np
 import torch
 from datasets import DatasetDict, load_dataset

-import evaluate
 import transformers
 from transformers import (
 AutoConfig,
@@ -708,7 +708,6 @@ def main():

 # Training
 if training_args.do_train:
-
 # use last checkpoint if exist
 if last_checkpoint is not None:
 checkpoint = last_checkpoint
@@ -26,10 +26,10 @@ from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional, Union

 import datasets
+import evaluate
 import torch
 from datasets import DatasetDict, load_dataset

-import evaluate
 import transformers
 from transformers import (
 AutoConfig,
@@ -25,13 +25,13 @@ from dataclasses import dataclass, field
 from typing import Optional

 import datasets
+import evaluate
 import nltk # Here to have a nice missing dependency error message early on
 import numpy as np
 from datasets import load_dataset
-
-import evaluate
-import transformers
 from filelock import FileLock
+
+import transformers
 from transformers import (
 AutoConfig,
 AutoModelForSeq2SeqLM,
@@ -27,20 +27,20 @@ import random
 from pathlib import Path

 import datasets
+import evaluate
 import nltk
 import numpy as np
 import torch
-from datasets import load_dataset
-from torch.utils.data import DataLoader
-from tqdm.auto import tqdm
-
-import evaluate
-import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
+from datasets import load_dataset
 from filelock import FileLock
 from huggingface_hub import Repository, create_repo
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+
+import transformers
 from transformers import (
 CONFIG_MAPPING,
 MODEL_MAPPING,
@@ -24,8 +24,8 @@ import tempfile
 from unittest import mock

 import torch
-
 from accelerate.utils import write_basic_config

 from transformers.testing_utils import TestCasePlus, get_gpu_count, run_command, slow, torch_device
 from transformers.utils import is_apex_available
+
@@ -24,10 +24,10 @@ from dataclasses import dataclass, field
 from typing import Optional

 import datasets
+import evaluate
 import numpy as np
 from datasets import load_dataset

-import evaluate
 import transformers
 from transformers import (
 AutoConfig,
@@ -22,17 +22,17 @@ import random
 from pathlib import Path

 import datasets
-import torch
-from datasets import load_dataset
-from torch.utils.data import DataLoader
-from tqdm.auto import tqdm
-
 import evaluate
-import transformers
+import torch
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
+from datasets import load_dataset
 from huggingface_hub import Repository, create_repo
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+
+import transformers
 from transformers import (
 AutoConfig,
 AutoModelForSequenceClassification,
@@ -25,10 +25,10 @@ from dataclasses import dataclass, field
 from typing import Optional

 import datasets
+import evaluate
 import numpy as np
 from datasets import load_dataset

-import evaluate
 import transformers
 from transformers import (
 AutoConfig,
@@ -26,10 +26,10 @@ from dataclasses import dataclass, field
 from typing import Optional

 import datasets
+import evaluate
 import numpy as np
 from datasets import ClassLabel, load_dataset

-import evaluate
 import transformers
 from transformers import (
 AutoConfig,
@@ -27,17 +27,17 @@ import random
 from pathlib import Path

 import datasets
-import torch
-from datasets import ClassLabel, load_dataset
-from torch.utils.data import DataLoader
-from tqdm.auto import tqdm
-
 import evaluate
-import transformers
+import torch
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
+from datasets import ClassLabel, load_dataset
 from huggingface_hub import Repository, create_repo
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+
+import transformers
 from transformers import (
 CONFIG_MAPPING,
 MODEL_MAPPING,
@@ -25,10 +25,10 @@ from dataclasses import dataclass, field
 from typing import Optional

 import datasets
+import evaluate
 import numpy as np
 from datasets import load_dataset

-import evaluate
 import transformers
 from transformers import (
 AutoConfig,
@@ -27,18 +27,18 @@ import random
 from pathlib import Path

 import datasets
+import evaluate
 import numpy as np
 import torch
-from datasets import load_dataset
-from torch.utils.data import DataLoader
-from tqdm.auto import tqdm
-
-import evaluate
-import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
+from datasets import load_dataset
 from huggingface_hub import Repository, create_repo
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+
+import transformers
 from transformers import (
 CONFIG_MAPPING,
 MODEL_MAPPING,
@@ -69,7 +69,6 @@ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

 # Parsing input arguments
 def parse_args():
-
 parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
 parser.add_argument(
 "--dataset_name",
@@ -751,5 +750,4 @@ def main():


 if __name__ == "__main__":
-
 main()
@@ -22,6 +22,7 @@ from typing import Dict, List, Optional

 import numpy as np
 import torch
+from utils_hans import HansDataset, InputFeatures, hans_processors, hans_tasks_num_labels

 import transformers
 from transformers import (
@@ -35,7 +36,6 @@ from transformers import (
 set_seed,
 )
 from transformers.trainer_utils import is_main_process
-from utils_hans import HansDataset, InputFeatures, hans_processors, hans_tasks_num_labels


 logger = logging.getLogger(__name__)
@@ -20,8 +20,8 @@ from dataclasses import dataclass
 from typing import List, Optional, Union

 import tqdm
-
 from filelock import FileLock

 from transformers import (
 BartTokenizer,
 BartTokenizerFast,
@@ -134,7 +134,6 @@ if is_torch_available():
 # and the others will use the cache.
 lock_path = cached_features_file + ".lock"
 with FileLock(lock_path):
-
 if os.path.exists(cached_features_file) and not overwrite_cache:
 logger.info(f"Loading features from cached file {cached_features_file}")
 self.features = torch.load(cached_features_file)
@@ -25,14 +25,14 @@ import random

 import numpy as np
 import torch
+from pabee.modeling_pabee_albert import AlbertForSequenceClassificationWithPabee
+from pabee.modeling_pabee_bert import BertForSequenceClassificationWithPabee
 from torch import nn
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange

 import transformers
-from pabee.modeling_pabee_albert import AlbertForSequenceClassificationWithPabee
-from pabee.modeling_pabee_bert import BertForSequenceClassificationWithPabee
 from transformers import (
 WEIGHTS_NAME,
 AdamW,
@@ -173,7 +173,6 @@ def train(args, train_dataset, model, tokenizer):
 for _ in train_iterator:
 epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
 for step, batch in enumerate(epoch_iterator):
-
 # Skip past any already trained steps if resuming training
 if steps_trained_in_current_epoch > 0:
 steps_trained_in_current_epoch -= 1
@@ -263,7 +262,6 @@ def train(args, train_dataset, model, tokenizer):


 def evaluate(args, model, tokenizer, prefix="", patience=0):
-
 if args.model_type == "albert":
 model.albert.set_regression_threshold(args.regression_threshold)
 model.albert.set_patience(patience)
@@ -736,7 +734,6 @@ def main():
 logger.info("Evaluate the following checkpoints: %s", checkpoints)

 for checkpoint in checkpoints:
-
 global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
 prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""

@@ -4,6 +4,7 @@ import sys
 from unittest.mock import patch

 import run_glue_with_pabee
+
 from transformers.testing_utils import TestCasePlus


@@ -24,9 +24,9 @@ import logging
 from collections import namedtuple

 import torch
-
 from model_bertabs import BertAbsSummarizer
 from models.model_builder import AbsSummarizer # The authors' implementation

 from transformers import BertTokenizer

+
@@ -24,10 +24,10 @@ import math

 import numpy as np
 import torch
+from configuration_bertabs import BertAbsConfig
 from torch import nn
 from torch.nn.init import xavier_uniform_

-from configuration_bertabs import BertAbsConfig
 from transformers import BertConfig, BertModel, PreTrainedModel


@@ -6,10 +6,10 @@ import sys
 from collections import namedtuple

 import torch
+from modeling_bertabs import BertAbs, build_predictor
 from torch.utils.data import DataLoader, SequentialSampler
 from tqdm import tqdm

-from modeling_bertabs import BertAbs, build_predictor
 from transformers import BertTokenizer

 from .utils_summarization import (
@@ -45,7 +45,6 @@ def evaluate(args):
 generated_summaries = []

 import nltk
-
 import rouge

 nltk.download("punkt")
@@ -3,8 +3,8 @@ from copy import deepcopy

 import numpy as np
 from datasets import ClassLabel, DatasetDict, load_dataset
-
 from evaluate import load
+
 from transformers import (
 AutoModelForSequenceClassification,
 AutoTokenizer,
@@ -1,7 +1,7 @@
+from arguments import TokenizerTrainingArguments
 from datasets import load_dataset
 from tqdm import tqdm

-from arguments import TokenizerTrainingArguments
 from transformers import AutoTokenizer, HfArgumentParser
 from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode

@@ -6,16 +6,16 @@ from pathlib import Path

 import datasets
 import torch
+from accelerate import Accelerator, DistributedType
+from arguments import TrainingArguments
 from datasets import load_dataset
+from huggingface_hub import Repository
 from torch.optim import AdamW
 from torch.utils.data import IterableDataset
 from torch.utils.data.dataloader import DataLoader
 from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe

 import transformers
-from accelerate import Accelerator, DistributedType
-from arguments import TrainingArguments
-from huggingface_hub import Repository
 from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed


@@ -5,15 +5,15 @@ import re
 from collections import defaultdict

 import torch
+from accelerate import Accelerator
+from accelerate.utils import set_seed
+from arguments import HumanEvalArguments
 from datasets import load_dataset, load_metric
 from torch.utils.data import IterableDataset
 from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm

 import transformers
-from accelerate import Accelerator
-from accelerate.utils import set_seed
-from arguments import HumanEvalArguments
 from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, StoppingCriteria, StoppingCriteriaList


@@ -1,4 +1,5 @@
 from arguments import InitializationArguments
+
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser


@@ -6,10 +6,9 @@ from functools import partial
 from typing import Dict, List, Optional, Set, Tuple, Type

 from datasets import Dataset
-from tqdm import tqdm
-
 from datasketch import MinHash, MinHashLSH
 from dpu_utils.utils.iterators import ThreadedIterator
+from tqdm import tqdm


 NON_ALPHA = re.compile("[^A-Za-z_0-9]")
@@ -9,10 +9,10 @@ import time
 from pathlib import Path

 import numpy as np
-from datasets import load_dataset
-
 from arguments import PreprocessingArguments
+from datasets import load_dataset
 from minhash_deduplication import deduplicate_dataset

 from transformers import AutoTokenizer, HfArgumentParser

+
@@ -1,9 +1,9 @@
 import multiprocessing
 import time

+from arguments import PretokenizationArguments
 from datasets import load_dataset

-from arguments import PretokenizationArguments
 from transformers import AutoTokenizer, HfArgumentParser


@@ -1,7 +1,6 @@
 from unittest import TestCase

 from datasets import Dataset
-
 from minhash_deduplication import deduplicate_dataset, make_duplicate_clusters


@@ -1,12 +1,12 @@
 import logging

 import torch
+from accelerate import Accelerator
+from arguments import EvaluationArguments
 from datasets import load_dataset
 from torch.utils.data import IterableDataset
 from torch.utils.data.dataloader import DataLoader

-from accelerate import Accelerator
-from arguments import EvaluationArguments
 from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, set_seed


@@ -1,8 +1,8 @@
+import gym
 import numpy as np
 import torch
-
-import gym
 from mujoco_py import GlfwContext

 from transformers import DecisionTransformerModel

+
@@ -229,7 +229,10 @@ class DeeBertModel(BertPreTrainedModel):
 sequence_output = encoder_outputs[0]
 pooled_output = self.pooler(sequence_output)

-outputs = (sequence_output, pooled_output,) + encoder_outputs[
+outputs = (
+sequence_output,
+pooled_output,
+) + encoder_outputs[
 1:
 ] # add hidden_states and attentions if they are here
 return outputs # sequence_output, pooled_output, (hidden_states), (attentions), highway exits
@@ -19,7 +19,6 @@ from .modeling_highway_bert import BertPreTrainedModel, DeeBertModel, HighwayExc
 ROBERTA_START_DOCSTRING,
 )
 class DeeRobertaModel(DeeBertModel):
-
 config_class = RobertaConfig
 base_model_prefix = "roberta"

@@ -36,7 +35,6 @@ class DeeRobertaModel(DeeBertModel):
 ROBERTA_START_DOCSTRING,
 )
 class DeeRobertaForSequenceClassification(BertPreTrainedModel):
-
 config_class = RobertaConfig
 base_model_prefix = "roberta"

@@ -4,6 +4,7 @@ import sys
 from unittest.mock import patch

 import run_glue_deebert
+
 from transformers.testing_utils import TestCasePlus, get_gpu_count, require_torch_non_multi_gpu, slow


@@ -45,7 +46,6 @@ class DeeBertTests(TestCasePlus):
 @slow
 @require_torch_non_multi_gpu
 def test_glue_deebert_train(self):
-
 train_args = """
 --model_type roberta
 --model_name_or_path roberta-base
@@ -21,14 +21,14 @@ import time

 import psutil
 import torch
+from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
+from lm_seqs_dataset import LmSeqsDataset
 from torch import nn
 from torch.optim import AdamW
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm

-from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
-from lm_seqs_dataset import LmSeqsDataset
 from transformers import get_linear_schedule_with_warmup
 from utils import logger

@@ -189,7 +189,6 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
 for _ in train_iterator:
 epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
 for step, batch in enumerate(epoch_iterator):
-
 # Skip past any already trained steps if resuming training
 if steps_trained_in_current_epoch > 0:
 steps_trained_in_current_epoch -= 1
@@ -24,9 +24,9 @@ import shutil

 import numpy as np
 import torch
-
 from distiller import Distiller
 from lm_seqs_dataset import LmSeqsDataset
+
 from transformers import (
 BertConfig,
 BertForMaskedLM,
@@ -5,13 +5,13 @@ import copy
 import logging
 import random

+import joblib
 import numpy as np
 import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
 from tqdm import tqdm

-import joblib
 from transformers import AdamW, GPT2LMHeadModel, get_linear_schedule_with_warmup


@@ -119,7 +119,6 @@ def recopy_gpt2(orig_model, device, max_steps):


 def intermittent_save(contexts, real_perps, past_perps, filename):
-
 """
 save the perplexity differences to filename

@@ -152,7 +151,6 @@ def collect_objective_set(
 filename="dev.jbl",
 recopy_model=recopy_gpt2,
 ):
-
 """
 Collect individual IGF values from pre-trained transformer model
 max_steps samples of training data to train secondary model
@@ -271,7 +269,6 @@ def generate_datasets(
 def train_secondary_learner(
 secondary_learner, train_dataset, max_epochs, batch_size, eval_freq=50, igf_model_path="secondary_learner.pt"
 ):
-
 """
 Train the secondary learner (igf_model)

@@ -28,11 +28,9 @@ Last, a plot is generated to compare the performance of IGF to standard fine-tun
 import argparse
 import random

+import joblib
 import numpy as np
 import torch
-from torch.utils.data import DataLoader, RandomSampler
-
-import joblib
 from igf.igf import (
 SecondaryLearner,
 collect_objective_set,
@@ -43,6 +41,8 @@ from igf.igf import (
 set_seed,
 train_secondary_learner,
 )
+from torch.utils.data import DataLoader, RandomSampler
+
 from transformers import GPT2LMHeadModel


@@ -55,7 +55,6 @@ def generate_n_pairs(
 data_file="data/tokenized_stories_train_wikitext103.jbl",
 igf_data_file="igf_context_pairs.jbl",
 ):
-
 """
 Collecting *n* pairs for training the secondary learner
 Args:
@@ -4,8 +4,6 @@ from dataclasses import dataclass
 from functools import partial
 from typing import Callable

-from tqdm.auto import tqdm
-
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
@@ -16,6 +14,8 @@ from flax import jax_utils, struct, traverse_util
 from flax.serialization import from_bytes, to_bytes
 from flax.training import train_state
 from flax.training.common_utils import shard
+from tqdm.auto import tqdm
+
 from transformers import BigBirdConfig, FlaxBigBirdForQuestionAnswering
 from transformers.models.big_bird.modeling_flax_big_bird import FlaxBigBirdForQuestionAnsweringModule

@@ -98,7 +98,6 @@ class Args:

 @dataclass
 class DataCollator:
-
 pad_id: int
 max_length: int = 4096 # no dynamic padding on TPUs

@@ -1,8 +1,8 @@
-from datasets import load_from_disk
-
 import jax
 import jax.numpy as jnp
 from bigbird_flax import FlaxBigBirdForNaturalQuestions
+from datasets import load_from_disk

 from transformers import BigBirdTokenizerFast

+
@@ -1,10 +1,9 @@
 import os

+import jsonlines
 import numpy as np
 from tqdm import tqdm
-
-import jsonlines


 DOC_STRIDE = 2048
 MAX_LENGTH = 4096
@@ -1,12 +1,12 @@
 import os
 from dataclasses import replace

-from datasets import load_dataset
-
 import jax
 import wandb
 from bigbird_flax import Args, DataCollator, FlaxBigBirdForNaturalQuestions, Trainer, build_tx, train_step, val_step
+from datasets import load_dataset
 from flax import jax_utils

 from transformers import BigBirdTokenizerFast

+
@@ -32,17 +32,17 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple

 import datasets
-import numpy as np
-from datasets import load_dataset
-from tqdm import tqdm
-
 import flax
 import jax
 import jax.numpy as jnp
+import numpy as np
 import optax
+from datasets import load_dataset
 from flax import jax_utils, traverse_util
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
+from tqdm import tqdm

 from transformers import (
 CONFIG_MAPPING,
 FLAX_MODEL_FOR_MASKED_LM_MAPPING,
@@ -20,6 +20,7 @@ import jax
 import jax.numpy as jnp
 from configuration_hybrid_clip import HybridCLIPConfig
 from flax.core.frozen_dict import FrozenDict
+
 from transformers import FLAX_MODEL_MAPPING, FlaxCLIPVisionModel
 from transformers.modeling_flax_utils import FlaxPreTrainedModel
 from transformers.models.clip.modeling_flax_clip import FlaxCLIPOutput
@@ -132,7 +133,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
 input_shape: Optional[Tuple] = None,
 seed: int = 0,
 dtype: jnp.dtype = jnp.float32,
-**kwargs
+**kwargs,
 ):
 if input_shape is None:
 input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
@@ -32,22 +32,22 @@ from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Callable, Optional

+import jax
+import jax.numpy as jnp
+import optax
 import torch
+from flax import jax_utils
+from flax.jax_utils import unreplicate
+from flax.training import train_state
+from flax.training.common_utils import get_metrics, shard, shard_prng_key
+from modeling_hybrid_clip import FlaxHybridCLIP
 from torchvision.datasets import VisionDataset
 from torchvision.io import ImageReadMode, read_image
 from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
 from torchvision.transforms.functional import InterpolationMode
 from tqdm import tqdm

-import jax
-import jax.numpy as jnp
-import optax
 import transformers
-from flax import jax_utils
-from flax.jax_utils import unreplicate
-from flax.training import train_state
-from flax.training.common_utils import get_metrics, shard, shard_prng_key
-from modeling_hybrid_clip import FlaxHybridCLIP
 from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, is_tensorboard_available, set_seed


@@ -28,19 +28,19 @@ from pathlib import Path
 from typing import Callable, Optional

 import datasets
-import numpy as np
-from datasets import Dataset, load_dataset
-from tqdm import tqdm
-
 import jax
 import jax.numpy as jnp
+import numpy as np
 import optax
-import transformers
+from datasets import Dataset, load_dataset
 from flax.core.frozen_dict import freeze, unfreeze
 from flax.training.common_utils import onehot, stack_forest
 from jax.experimental.maps import mesh
 from jax.experimental.pjit import pjit
 from partitions import set_partitions
+from tqdm import tqdm
+
+import transformers
 from transformers import (
 CONFIG_MAPPING,
 FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
@@ -6,18 +6,18 @@ from dataclasses import field
 from pathlib import Path
 from typing import Dict, List, Optional, Union

-import numpy as np
-from datasets import DatasetDict, load_dataset
-from tqdm import tqdm
-
 import flax
 import jax
 import jax.numpy as jnp
 import librosa
+import numpy as np
 import optax
+from datasets import DatasetDict, load_dataset
 from flax import jax_utils, traverse_util
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
+from tqdm import tqdm

 from transformers import (
 FlaxWav2Vec2ForPreTraining,
 HfArgumentParser,
Some files were not shown because too many files have changed in this diff.