mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
Use Python 3.9 syntax in tests (#37343)
Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
parent
0fc683d1cd
commit
1e6b546ea6
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 HuggingFace Inc.
|
# Copyright 2023 HuggingFace Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 HuggingFace Inc.
|
# Copyright 2024 HuggingFace Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 HuggingFace Inc.
|
# Copyright 2023 HuggingFace Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 HuggingFace Inc.
|
# Copyright 2024 HuggingFace Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 HuggingFace Inc.
|
# Copyright 2023 HuggingFace Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 HuggingFace Inc.
|
# Copyright 2024 HuggingFace Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 HuggingFace Inc.
|
# Copyright 2024 HuggingFace Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 HuggingFace Inc.
|
# Copyright 2024 HuggingFace Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 HuggingFace Inc.
|
# Copyright 2023 HuggingFace Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 HuggingFace Inc.
|
# Copyright 2023 HuggingFace Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 HuggingFace Inc.
|
# Copyright 2024 HuggingFace Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -14,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Union
|
from typing import Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
@ -35,7 +34,7 @@ if is_vision_available():
|
|||||||
AUTHORIZED_TYPES = ["string", "boolean", "integer", "number", "audio", "image", "any"]
|
AUTHORIZED_TYPES = ["string", "boolean", "integer", "number", "audio", "image", "any"]
|
||||||
|
|
||||||
|
|
||||||
def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]):
|
def create_inputs(tool_inputs: dict[str, dict[Union[str, type], str]]):
|
||||||
inputs = {}
|
inputs = {}
|
||||||
|
|
||||||
for input_name, input_desc in tool_inputs.items():
|
for input_name, input_desc in tool_inputs.items():
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 HuggingFace Inc.
|
# Copyright 2024 HuggingFace Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 The HuggingFace Team Inc.
|
# Copyright 2023 The HuggingFace Team Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -13,7 +13,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import io
|
|
||||||
import itertools
|
import itertools
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@ -436,9 +435,9 @@ class TrainerIntegrationDeepSpeedWithCustomConfig(TestCasePlus):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# use self.get_config_dict(stage) to use these to ensure the original is not modified
|
# use self.get_config_dict(stage) to use these to ensure the original is not modified
|
||||||
with io.open(self.ds_config_file[ZERO2], "r", encoding="utf-8") as f:
|
with open(self.ds_config_file[ZERO2], encoding="utf-8") as f:
|
||||||
config_zero2 = json.load(f)
|
config_zero2 = json.load(f)
|
||||||
with io.open(self.ds_config_file[ZERO3], "r", encoding="utf-8") as f:
|
with open(self.ds_config_file[ZERO3], encoding="utf-8") as f:
|
||||||
config_zero3 = json.load(f)
|
config_zero3 = json.load(f)
|
||||||
# The following setting slows things down, so don't enable it by default unless needed by a test.
|
# The following setting slows things down, so don't enable it by default unless needed by a test.
|
||||||
# It's in the file as a demo for users since we want everything to work out of the box even if slower.
|
# It's in the file as a demo for users since we want everything to work out of the box even if slower.
|
||||||
|
@ -17,7 +17,6 @@ import os
|
|||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Tuple
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
@ -186,7 +185,7 @@ class TestTrainerExt(TestCasePlus):
|
|||||||
def test_run_seq2seq_bnb(self):
|
def test_run_seq2seq_bnb(self):
|
||||||
from transformers.training_args import OptimizerNames
|
from transformers.training_args import OptimizerNames
|
||||||
|
|
||||||
def train_and_return_metrics(optim: str) -> Tuple[int, float]:
|
def train_and_return_metrics(optim: str) -> tuple[int, float]:
|
||||||
extra_args = "--skip_memory_metrics 0"
|
extra_args = "--skip_memory_metrics 0"
|
||||||
|
|
||||||
output_dir = self.run_trainer(
|
output_dir = self.run_trainer(
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The HuggingFace Team Inc.
|
# Copyright 2020 The HuggingFace Team Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The HuggingFace Team Inc.
|
# Copyright 2020 The HuggingFace Team Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2022 The HuggingFace Team Inc.
|
# Copyright 2022 The HuggingFace Team Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The HuggingFace Team Inc.
|
# Copyright 2020 The HuggingFace Team Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -14,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from typing import List, Union
|
from typing import Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
@ -86,7 +85,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
self.assertFalse(torch.isinf(scores_before_min_length).any())
|
self.assertFalse(torch.isinf(scores_before_min_length).any())
|
||||||
|
|
||||||
@parameterized.expand([(0,), ([0, 18],)])
|
@parameterized.expand([(0,), ([0, 18],)])
|
||||||
def test_new_min_length_dist_processor(self, eos_token_id: Union[int, List[int]]):
|
def test_new_min_length_dist_processor(self, eos_token_id: Union[int, list[int]]):
|
||||||
vocab_size = 20
|
vocab_size = 20
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
|
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The HuggingFace Team Inc.
|
# Copyright 2020 The HuggingFace Team Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 The HuggingFace Team Inc.
|
# Copyright 2023 The HuggingFace Team Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The HuggingFace Team Inc.
|
# Copyright 2020 The HuggingFace Team Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2019 Hugging Face inc.
|
# Copyright 2019 Hugging Face inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 HuggingFace Inc.
|
# Copyright 2024 HuggingFace Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 HuggingFace Inc.
|
# Copyright 2024 HuggingFace Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2022 HuggingFace Inc.
|
# Copyright 2022 HuggingFace Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2021 the HuggingFace Inc. team.
|
# Copyright 2021 the HuggingFace Inc. team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2021 the HuggingFace Inc. team.
|
# Copyright 2021 the HuggingFace Inc. team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2021 the HuggingFace Inc. team.
|
# Copyright 2021 the HuggingFace Inc. team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -107,7 +106,7 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
|||||||
json.dump(config_dict, fp)
|
json.dump(config_dict, fp)
|
||||||
|
|
||||||
# drop `processor_class` in tokenizer config
|
# drop `processor_class` in tokenizer config
|
||||||
with open(os.path.join(tmpdirname, TOKENIZER_CONFIG_FILE), "r") as f:
|
with open(os.path.join(tmpdirname, TOKENIZER_CONFIG_FILE)) as f:
|
||||||
config_dict = json.load(f)
|
config_dict = json.load(f)
|
||||||
config_dict.pop("processor_class")
|
config_dict.pop("processor_class")
|
||||||
|
|
||||||
@ -130,7 +129,7 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
|||||||
|
|
||||||
if os.path.isfile(os.path.join(tmpdirname, PROCESSOR_NAME)):
|
if os.path.isfile(os.path.join(tmpdirname, PROCESSOR_NAME)):
|
||||||
# drop `processor_class` in processor
|
# drop `processor_class` in processor
|
||||||
with open(os.path.join(tmpdirname, PROCESSOR_NAME), "r") as f:
|
with open(os.path.join(tmpdirname, PROCESSOR_NAME)) as f:
|
||||||
config_dict = json.load(f)
|
config_dict = json.load(f)
|
||||||
config_dict.pop("processor_class")
|
config_dict.pop("processor_class")
|
||||||
|
|
||||||
@ -138,7 +137,7 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
|||||||
f.write(json.dumps(config_dict))
|
f.write(json.dumps(config_dict))
|
||||||
|
|
||||||
# drop `processor_class` in tokenizer
|
# drop `processor_class` in tokenizer
|
||||||
with open(os.path.join(tmpdirname, TOKENIZER_CONFIG_FILE), "r") as f:
|
with open(os.path.join(tmpdirname, TOKENIZER_CONFIG_FILE)) as f:
|
||||||
config_dict = json.load(f)
|
config_dict = json.load(f)
|
||||||
config_dict.pop("processor_class")
|
config_dict.pop("processor_class")
|
||||||
|
|
||||||
@ -161,7 +160,7 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
|||||||
|
|
||||||
if os.path.isfile(os.path.join(tmpdirname, PROCESSOR_NAME)):
|
if os.path.isfile(os.path.join(tmpdirname, PROCESSOR_NAME)):
|
||||||
# drop `processor_class` in processor
|
# drop `processor_class` in processor
|
||||||
with open(os.path.join(tmpdirname, PROCESSOR_NAME), "r") as f:
|
with open(os.path.join(tmpdirname, PROCESSOR_NAME)) as f:
|
||||||
config_dict = json.load(f)
|
config_dict = json.load(f)
|
||||||
config_dict.pop("processor_class")
|
config_dict.pop("processor_class")
|
||||||
|
|
||||||
@ -169,7 +168,7 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
|||||||
f.write(json.dumps(config_dict))
|
f.write(json.dumps(config_dict))
|
||||||
|
|
||||||
# drop `processor_class` in feature extractor
|
# drop `processor_class` in feature extractor
|
||||||
with open(os.path.join(tmpdirname, FEATURE_EXTRACTOR_NAME), "r") as f:
|
with open(os.path.join(tmpdirname, FEATURE_EXTRACTOR_NAME)) as f:
|
||||||
config_dict = json.load(f)
|
config_dict = json.load(f)
|
||||||
config_dict.pop("processor_class")
|
config_dict.pop("processor_class")
|
||||||
|
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2021, The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2021, The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -179,7 +179,7 @@ class FlaxBartModelTester:
|
|||||||
|
|
||||||
outputs = model.decode(decoder_input_ids, encoder_outputs)
|
outputs = model.decode(decoder_input_ids, encoder_outputs)
|
||||||
|
|
||||||
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
|
diff = np.max(np.abs(outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))
|
||||||
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
||||||
|
|
||||||
def check_use_cache_forward_with_attn_mask(self, model_class_name, config, inputs_dict):
|
def check_use_cache_forward_with_attn_mask(self, model_class_name, config, inputs_dict):
|
||||||
@ -225,7 +225,7 @@ class FlaxBartModelTester:
|
|||||||
|
|
||||||
outputs = model.decode(decoder_input_ids, encoder_outputs, decoder_attention_mask=decoder_attention_mask)
|
outputs = model.decode(decoder_input_ids, encoder_outputs, decoder_attention_mask=decoder_attention_mask)
|
||||||
|
|
||||||
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
|
diff = np.max(np.abs(outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))
|
||||||
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 Ecole Polytechnique and HuggingFace Inc. team.
|
# Copyright 2020 Ecole Polytechnique and HuggingFace Inc. team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2021 HuggingFace Inc. team.
|
# Copyright 2021 HuggingFace Inc. team.
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2021 HuggingFace Inc.
|
# Copyright 2021 HuggingFace Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2018 Salesforce and HuggingFace Inc. team.
|
# Copyright 2018 Salesforce and HuggingFace Inc. team.
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -892,7 +891,7 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
target_end_logits = torch.tensor(
|
target_end_logits = torch.tensor(
|
||||||
[[-12.1736, -8.8487, -14.8877, -11.6713, -15.1165, -12.2396, -7.6828, -15.4153, -12.2528, -14.3671, -12.3596, -7.4272, -14.9615, -13.6356, -11.7939, -9.9767, -14.8112, -8.9567, -15.8798, -11.5291, -9.4249, -14.7544, -7.9387, -16.2789, -8.9702, -15.3111, -11.5585, -7.9992, -4.1127, 10.3209, -8.3926, -10.2005], [-11.1375, -15.4027, -12.6861, -16.9884, -13.7093, -10.3560, -15.7228, -12.9290, -15.8519, -13.7953, -10.2460, -15.7198, -14.2078, -12.8477, -11.4861, -16.1017, -11.8900, -16.4488, -13.2959, -10.3980, -15.4874, -10.3539, -16.8263, -10.9973, -17.0344, -9.2751, -10.1196, -13.8907, -12.1025, -13.0628, -12.8530, -13.8173]], # noqa: E321
|
[[-12.1736, -8.8487, -14.8877, -11.6713, -15.1165, -12.2396, -7.6828, -15.4153, -12.2528, -14.3671, -12.3596, -7.4272, -14.9615, -13.6356, -11.7939, -9.9767, -14.8112, -8.9567, -15.8798, -11.5291, -9.4249, -14.7544, -7.9387, -16.2789, -8.9702, -15.3111, -11.5585, -7.9992, -4.1127, 10.3209, -8.3926, -10.2005], [-11.1375, -15.4027, -12.6861, -16.9884, -13.7093, -10.3560, -15.7228, -12.9290, -15.8519, -13.7953, -10.2460, -15.7198, -14.2078, -12.8477, -11.4861, -16.1017, -11.8900, -16.4488, -13.2959, -10.3980, -15.4874, -10.3539, -16.8263, -10.9973, -17.0344, -9.2751, -10.1196, -13.8907, -12.1025, -13.0628, -12.8530, -13.8173]],
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2021, The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2021, The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -179,7 +179,7 @@ class FlaxBlenderbotModelTester:
|
|||||||
|
|
||||||
outputs = model.decode(decoder_input_ids, encoder_outputs)
|
outputs = model.decode(decoder_input_ids, encoder_outputs)
|
||||||
|
|
||||||
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
|
diff = np.max(np.abs(outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))
|
||||||
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
||||||
|
|
||||||
def check_use_cache_forward_with_attn_mask(self, model_class_name, config, inputs_dict):
|
def check_use_cache_forward_with_attn_mask(self, model_class_name, config, inputs_dict):
|
||||||
@ -225,7 +225,7 @@ class FlaxBlenderbotModelTester:
|
|||||||
|
|
||||||
outputs = model.decode(decoder_input_ids, encoder_outputs, decoder_attention_mask=decoder_attention_mask)
|
outputs = model.decode(decoder_input_ids, encoder_outputs, decoder_attention_mask=decoder_attention_mask)
|
||||||
|
|
||||||
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
|
diff = np.max(np.abs(outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))
|
||||||
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2021, The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2021, The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -178,7 +178,7 @@ class FlaxBlenderbotSmallModelTester:
|
|||||||
|
|
||||||
outputs = model.decode(decoder_input_ids, encoder_outputs)
|
outputs = model.decode(decoder_input_ids, encoder_outputs)
|
||||||
|
|
||||||
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
|
diff = np.max(np.abs(outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))
|
||||||
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
||||||
|
|
||||||
def check_use_cache_forward_with_attn_mask(self, model_class_name, config, inputs_dict):
|
def check_use_cache_forward_with_attn_mask(self, model_class_name, config, inputs_dict):
|
||||||
@ -224,7 +224,7 @@ class FlaxBlenderbotSmallModelTester:
|
|||||||
|
|
||||||
outputs = model.decode(decoder_input_ids, encoder_outputs, decoder_attention_mask=decoder_attention_mask)
|
outputs = model.decode(decoder_input_ids, encoder_outputs, decoder_attention_mask=decoder_attention_mask)
|
||||||
|
|
||||||
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
|
diff = np.max(np.abs(outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))
|
||||||
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2022 HuggingFace Inc.
|
# Copyright 2022 HuggingFace Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -128,7 +128,7 @@ class FlaxBloomModelTester:
|
|||||||
|
|
||||||
outputs = model(input_ids)
|
outputs = model(input_ids)
|
||||||
|
|
||||||
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
|
diff = np.max(np.abs(outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))
|
||||||
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
||||||
|
|
||||||
def check_use_cache_forward_with_attn_mask(self, model_class_name, config, inputs_dict):
|
def check_use_cache_forward_with_attn_mask(self, model_class_name, config, inputs_dict):
|
||||||
@ -163,7 +163,7 @@ class FlaxBloomModelTester:
|
|||||||
|
|
||||||
outputs = model(input_ids, attention_mask=attention_mask)
|
outputs = model(input_ids, attention_mask=attention_mask)
|
||||||
|
|
||||||
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
|
diff = np.max(np.abs(outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))
|
||||||
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 The Intel Labs Team Authors, The Microsoft Research Team Authors and HuggingFace Inc. team. All rights reserved.
|
# Copyright 2023 The Intel Labs Team Authors, The Microsoft Research Team Authors and HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -15,7 +14,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -36,14 +35,14 @@ class BridgeTowerImageProcessingTester:
|
|||||||
self,
|
self,
|
||||||
parent,
|
parent,
|
||||||
do_resize: bool = True,
|
do_resize: bool = True,
|
||||||
size: Dict[str, int] = None,
|
size: dict[str, int] = None,
|
||||||
size_divisor: int = 32,
|
size_divisor: int = 32,
|
||||||
do_rescale: bool = True,
|
do_rescale: bool = True,
|
||||||
rescale_factor: Union[int, float] = 1 / 255,
|
rescale_factor: Union[int, float] = 1 / 255,
|
||||||
do_normalize: bool = True,
|
do_normalize: bool = True,
|
||||||
do_center_crop: bool = True,
|
do_center_crop: bool = True,
|
||||||
image_mean: Optional[Union[float, List[float]]] = [0.48145466, 0.4578275, 0.40821073],
|
image_mean: Optional[Union[float, list[float]]] = [0.48145466, 0.4578275, 0.40821073],
|
||||||
image_std: Optional[Union[float, List[float]]] = [0.26862954, 0.26130258, 0.27577711],
|
image_std: Optional[Union[float, list[float]]] = [0.26862954, 0.26130258, 0.27577711],
|
||||||
do_pad: bool = True,
|
do_pad: bool = True,
|
||||||
batch_size=7,
|
batch_size=7,
|
||||||
min_resolution=30,
|
min_resolution=30,
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 The Intel Labs Team Authors, The Microsoft Research Team Authors and HuggingFace Inc. team. All rights reserved.
|
# Copyright 2023 The Intel Labs Team Authors, The Microsoft Research Team Authors and HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 Google T5 Authors and HuggingFace Inc. team.
|
# Copyright 2020 Google T5 Authors and HuggingFace Inc. team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -20,7 +19,6 @@ import shutil
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
from transformers import AddedToken, BatchEncoding, ByT5Tokenizer
|
from transformers import AddedToken, BatchEncoding, ByT5Tokenizer
|
||||||
from transformers.utils import cached_property, is_tf_available, is_torch_available
|
from transformers.utils import cached_property, is_tf_available, is_torch_available
|
||||||
@ -57,7 +55,7 @@ class ByT5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
pretrained_name = pretrained_name or cls.tmpdirname
|
pretrained_name = pretrained_name or cls.tmpdirname
|
||||||
return cls.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
|
return cls.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
|
||||||
|
|
||||||
def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20, min_length=5) -> Tuple[str, list]:
|
def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20, min_length=5) -> tuple[str, list]:
|
||||||
# XXX The default common tokenizer tests assume that every ID is decodable on its own.
|
# XXX The default common tokenizer tests assume that every ID is decodable on its own.
|
||||||
# This assumption is invalid for ByT5 because single bytes might not be
|
# This assumption is invalid for ByT5 because single bytes might not be
|
||||||
# valid utf-8 (byte 128 for instance).
|
# valid utf-8 (byte 128 for instance).
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2018 HuggingFace Inc. team.
|
# Copyright 2018 HuggingFace Inc. team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -15,7 +14,6 @@
|
|||||||
"""Testing suite for the PyTorch CANINE model."""
|
"""Testing suite for the PyTorch CANINE model."""
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
from transformers import CanineConfig, is_torch_available
|
from transformers import CanineConfig, is_torch_available
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
@ -383,7 +381,7 @@ class CanineModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
||||||
|
|
||||||
def recursive_check(tuple_object, dict_object):
|
def recursive_check(tuple_object, dict_object):
|
||||||
if isinstance(tuple_object, (List, Tuple)):
|
if isinstance(tuple_object, (list, tuple)):
|
||||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
|
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
|
||||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||||
elif tuple_object is None:
|
elif tuple_object is None:
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2021 Google AI and HuggingFace Inc. team.
|
# Copyright 2021 Google AI and HuggingFace Inc. team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 HuggingFace Inc.
|
# Copyright 2024 HuggingFace Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2021 HuggingFace Inc.
|
# Copyright 2021 HuggingFace Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 HuggingFace Inc.
|
# Copyright 2023 HuggingFace Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2021 HuggingFace Inc.
|
# Copyright 2021 HuggingFace Inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -18,7 +17,7 @@ import inspect
|
|||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Optional, Tuple
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
@ -229,8 +228,8 @@ class CLIPModelTesterMixin(ModelTesterMixin):
|
|||||||
def test_eager_matches_sdpa_inference(
|
def test_eager_matches_sdpa_inference(
|
||||||
self,
|
self,
|
||||||
torch_dtype: str,
|
torch_dtype: str,
|
||||||
use_attention_mask_options: Tuple[Optional[str], ...] = (None, "left", "right"),
|
use_attention_mask_options: tuple[Optional[str], ...] = (None, "left", "right"),
|
||||||
logit_keys: Tuple[str, ...] = ("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"),
|
logit_keys: tuple[str, ...] = ("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"),
|
||||||
):
|
):
|
||||||
if not self.all_model_classes[0]._supports_sdpa:
|
if not self.all_model_classes[0]._supports_sdpa:
|
||||||
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
|
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user