Fix torch version comparisons (#18460)

Comparisons like

    version.parse(torch.__version__) > version.parse("1.6")

are True for torch==1.6.0+cu101 or torch==1.6.0+cpu, even though 1.6.0 is not strictly greater than 1.6, because packaging treats a local version label such as +cu101 as greater than the bare release.

Comparisons on version.parse(version.parse(torch.__version__).base_version) are preferred instead (helper flags are available in pytorch_utils.py).
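
This is easy to reproduce with packaging alone; a minimal sketch (version strings chosen for illustration):

    from packaging import version

    # A local version label (+cu101, +cpu) sorts above the bare release,
    # so a "strictly newer than 1.6" check passes for a 1.6.0 build.
    print(version.parse("1.6.0+cu101") > version.parse("1.6"))  # True

    # base_version strips the local label, leaving only the release segment.
    base = version.parse(version.parse("1.6.0+cu101").base_version)
    print(base > version.parse("1.6"))  # False, as intended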
LSinev 2022-08-03 20:37:18 +03:00 committed by GitHub
parent be41eaf55f
commit 02b176c4ce
34 changed files with 164 additions and 87 deletions

View File

@@ -30,7 +30,7 @@ from transformers import (
 if is_apex_available():
     from apex import amp

-if version.parse(torch.__version__) >= version.parse("1.6"):
+if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.6"):
     _is_native_amp_available = True
     from torch.cuda.amp import autocast

View File

@@ -33,7 +33,7 @@ if is_apex_available():
     from apex import amp

-if version.parse(torch.__version__) >= version.parse("1.6"):
+if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.6"):
     _is_native_amp_available = True
     from torch.cuda.amp import autocast

View File

@@ -26,7 +26,7 @@ from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
 if is_apex_available():
     from apex import amp

-if version.parse(torch.__version__) >= version.parse("1.6"):
+if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.6"):
     _is_native_amp_available = True
     from torch.cuda.amp import autocast

View File

@@ -44,7 +44,7 @@ class GELUActivation(nn.Module):
     def __init__(self, use_gelu_python: bool = False):
         super().__init__()
-        if version.parse(torch.__version__) < version.parse("1.4") or use_gelu_python:
+        if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.4") or use_gelu_python:
             self.act = self._gelu_python
         else:
             self.act = nn.functional.gelu
@@ -110,7 +110,7 @@ class SiLUActivation(nn.Module):
     def __init__(self):
         super().__init__()
-        if version.parse(torch.__version__) < version.parse("1.7"):
+        if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.7"):
             self.act = self._silu_python
         else:
             self.act = nn.functional.silu
@@ -130,7 +130,7 @@ class MishActivation(nn.Module):
     def __init__(self):
         super().__init__()
-        if version.parse(torch.__version__) < version.parse("1.9"):
+        if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.9"):
             self.act = self._mish_python
         else:
             self.act = nn.functional.mish

View File

@@ -273,6 +273,8 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format
     import torch
     from torch.onnx import export

+    from .pytorch_utils import is_torch_less_than_1_11
+
     print(f"Using framework PyTorch: {torch.__version__}")

     with torch.no_grad():
@@ -281,7 +283,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format
         # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
         # so we check the torch version for backwards compatibility
-        if parse(torch.__version__) <= parse("1.10.99"):
+        if is_torch_less_than_1_11:
             export(
                 nlp.model,
                 model_args,

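For release versions, the new flag preserves the old cutoff: is_torch_less_than_1_11 admits every 1.10.x release exactly as the previous <= parse("1.10.99") check did. A quick illustrative check:

    from packaging import version

    # Both predicates accept 1.10.x releases and reject 1.11.0.
    for v in ("1.10.2", "1.11.0"):
        parsed = version.parse(v)
        print(v, parsed <= version.parse("1.10.99"), parsed < version.parse("1.11"))
    # 1.10.2 True True
    # 1.11.0 False False
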
View File

@@ -20,7 +20,6 @@ from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple, Union

 import torch
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -35,7 +34,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -212,7 +216,7 @@ class AlbertEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),

View File

@@ -24,7 +24,6 @@ from typing import List, Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -41,7 +40,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -195,7 +199,7 @@ class BertEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),

View File

@@ -23,7 +23,6 @@ from typing import Optional, Tuple, Union
 import numpy as np
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -38,7 +37,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward
+from ...pytorch_utils import apply_chunking_to_forward, is_torch_greater_than_1_6
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -260,7 +259,7 @@ class BigBirdEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),

View File

@@ -22,7 +22,6 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -36,7 +35,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel, SequenceSummary
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_convbert import ConvBertConfig
@@ -194,7 +198,7 @@ class ConvBertEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),

View File

@@ -19,7 +19,6 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -35,7 +34,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -83,7 +87,7 @@ class Data2VecTextForTextEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),

View File

@@ -21,12 +21,16 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn

 from ...activations import ACT2FN
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
+from ...pytorch_utils import (
+    Conv1D,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_or_equal_than_1_6,
+    prune_conv1d_layer,
+)
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -36,7 +40,7 @@ from ...utils import (
 )

-if version.parse(torch.__version__) >= version.parse("1.6"):
+if is_torch_greater_or_equal_than_1_6:
     is_amp_available = True
     from torch.cuda.amp import autocast
 else:

View File

@@ -23,7 +23,6 @@ from typing import Dict, List, Optional, Set, Tuple, Union
 import numpy as np
 import torch
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -40,7 +39,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -102,7 +106,7 @@ class Embeddings(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
         self.dropout = nn.Dropout(config.dropout)
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
             )

View File

@@ -21,7 +21,6 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -37,7 +36,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel, SequenceSummary
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -165,7 +169,7 @@ class ElectraEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),

View File

@@ -19,10 +19,10 @@ import random
 from typing import Dict, Optional, Tuple, Union

 import torch
-from packaging import version
 from torch import nn

 from ...modeling_outputs import BaseModelOutput
+from ...pytorch_utils import is_torch_greater_than_1_6
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from ..xlm.modeling_xlm import (
     XLMForMultipleChoice,
@@ -139,7 +139,7 @@ class FlaubertModel(XLMModel):
         super().__init__(config)
         self.layerdrop = getattr(config, "layerdrop", 0.0)
         self.pre_norm = getattr(config, "pre_norm", False)
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
             )

View File

@@ -22,7 +22,6 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn

 from transformers.utils.doc import add_code_sample_docstrings
@@ -30,6 +29,7 @@ from transformers.utils.doc import add_code_sample_docstrings
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import is_torch_greater_than_1_6
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -392,7 +392,7 @@ class FlavaTextEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),

View File

@@ -21,7 +21,6 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -44,7 +43,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward
+from ...pytorch_utils import apply_chunking_to_forward, is_torch_greater_than_1_6
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -118,7 +117,7 @@ class FNetEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),

View File

@@ -22,12 +22,18 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+from ...pytorch_utils import (
+    Conv1D,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_or_equal_than_1_6,
+    prune_conv1d_layer,
+)

-if version.parse(torch.__version__) >= version.parse("1.6"):
+if is_torch_greater_or_equal_than_1_6:
     is_amp_available = True
     from torch.cuda.amp import autocast
 else:
@@ -41,7 +47,6 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel, SequenceSummary
-from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,

View File

@@ -21,12 +21,18 @@ from typing import Any, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+from ...pytorch_utils import (
+    Conv1D,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_or_equal_than_1_6,
+    prune_conv1d_layer,
+)

-if version.parse(torch.__version__) >= version.parse("1.6"):
+if is_torch_greater_or_equal_than_1_6:
     is_amp_available = True
     from torch.cuda.amp import autocast
 else:
@@ -39,7 +45,6 @@ from ...modeling_outputs import (
     SequenceClassifierOutputWithPast,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_imagegpt import ImageGPTConfig

View File

@@ -21,7 +21,6 @@ from typing import Optional
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn

 from ...activations import ACT2FN
@@ -34,6 +33,7 @@ from ...modeling_utils import (
     find_pruneable_heads_and_indices,
     prune_linear_layer,
 )
+from ...pytorch_utils import is_torch_greater_than_1_6
 from ...utils import logging
 from .configuration_mctct import MCTCTConfig
@@ -153,7 +153,7 @@ class MCTCTEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),

View File

@@ -23,7 +23,6 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -39,7 +38,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -183,7 +187,7 @@ class NezhaEmbeddings(nn.Module):
         # any TensorFlow checkpoint file
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros((1, config.max_position_embeddings), dtype=torch.long),

View File

@@ -20,7 +20,6 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -34,7 +33,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_nystromformer import NystromformerConfig
@@ -68,7 +72,7 @@ class NystromformerEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2)
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),

View File

@@ -23,7 +23,6 @@ from typing import Dict, List, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -40,7 +39,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, is_torch_greater_than_1_6, prune_linear_layer
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -167,7 +166,7 @@ class QDQBertEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),

View File

@@ -20,7 +20,6 @@ from dataclasses import dataclass
 from typing import Optional, Tuple, Union

 import torch
-from packaging import version
 from torch import nn
 from torch.nn import CrossEntropyLoss
@@ -32,7 +31,12 @@ from ...modeling_outputs import (
     ModelOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_realm import RealmConfig
@@ -181,7 +185,7 @@ class RealmEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),

View File

@@ -20,7 +20,6 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -36,7 +35,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -83,7 +87,7 @@ class RobertaEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),

View File

@@ -21,7 +21,6 @@ from typing import List, Optional, Tuple
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import CrossEntropyLoss
@@ -35,14 +34,19 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    find_pruneable_heads_and_indices,
+    is_torch_greater_or_equal_than_1_10,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_vilt import ViltConfig

 logger = logging.get_logger(__name__)

-if version.parse(torch.__version__) < version.parse("1.10.0"):
+if not is_torch_greater_or_equal_than_1_10:
     logger.warning(
         f"You are using torch=={torch.__version__}, but torch>=1.10.0 is required to use "
         "ViltModel. Please upgrade torch."
@@ -251,7 +255,7 @@ class TextEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),

View File

@@ -19,7 +19,6 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -35,7 +34,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -76,7 +80,7 @@ class XLMRobertaXLEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),

View File

@@ -21,7 +21,6 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -35,7 +34,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_yoso import YosoConfig
@@ -257,7 +261,7 @@ class YosoEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2)
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),

View File

@@ -34,6 +34,7 @@ from .config import OnnxConfig

 if is_torch_available():
     from ..modeling_utils import PreTrainedModel
+    from ..pytorch_utils import is_torch_less_than_1_11

 if is_tf_available():
     from ..modeling_tf_utils import TFPreTrainedModel
@@ -155,7 +156,7 @@ def export_pytorch(
         # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
         # so we check the torch version for backwards compatibility
-        if parse(torch.__version__) < parse("1.10"):
+        if is_torch_less_than_1_11:
             # export can work with named args but the dict containing named args
             # has to be the last element of the args tuple.
             try:

View File

@@ -967,7 +967,9 @@ class Pipeline(_ScikitCompat):
     def get_inference_context(self):
         inference_context = (
-            torch.inference_mode if version.parse(torch.__version__) >= version.parse("1.9.0") else torch.no_grad
+            torch.inference_mode
+            if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.9.0")
+            else torch.no_grad
         )
         return inference_context

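Both torch.inference_mode and torch.no_grad are zero-argument context-manager factories, so the caller can use whichever was selected without branching; a hypothetical usage sketch (pipeline and model_inputs stand in for real objects):

    inference_context = pipeline.get_inference_context()
    with inference_context():
        # no gradient tracking during the forward pass; on torch >= 1.9,
        # inference_mode also skips autograd version-counter bookkeeping
        outputs = pipeline.model(**model_inputs)
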
View File

@@ -25,8 +25,12 @@ ALL_LAYERNORM_LAYERS = [nn.LayerNorm]

 logger = logging.get_logger(__name__)

-is_torch_less_than_1_8 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.8.0")
-is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")
+parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
+is_torch_greater_or_equal_than_1_6 = parsed_torch_version_base >= version.parse("1.6.0")
+is_torch_greater_than_1_6 = parsed_torch_version_base > version.parse("1.6.0")
+is_torch_less_than_1_8 = parsed_torch_version_base < version.parse("1.8.0")
+is_torch_greater_or_equal_than_1_10 = parsed_torch_version_base >= version.parse("1.10")
+is_torch_less_than_1_11 = parsed_torch_version_base < version.parse("1.11")

 def torch_int_div(tensor1, tensor2):

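Call sites then reduce to an import plus a truth test; a minimal sketch using the flag defined above, mirroring the gpt2/imagegpt change (the flag is evaluated once at import time against the base release, so torch==1.6.0+cu101 and torch==1.6.0+cpu behave exactly like torch==1.6.0):

    from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_6

    if is_torch_greater_or_equal_than_1_6:
        is_amp_available = True
        from torch.cuda.amp import autocast
    else:
        is_amp_available = False
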
View File

@@ -71,7 +71,12 @@ from .modelcard import TrainingSummary
 from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
 from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
 from .optimization import Adafactor, get_scheduler
-from .pytorch_utils import ALL_LAYERNORM_LAYERS
+from .pytorch_utils import (
+    ALL_LAYERNORM_LAYERS,
+    is_torch_greater_or_equal_than_1_6,
+    is_torch_greater_or_equal_than_1_10,
+    is_torch_less_than_1_11,
+)
 from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer_callback import (
     CallbackHandler,
@@ -165,11 +170,11 @@ if is_in_notebook():
 if is_apex_available():
     from apex import amp

-if version.parse(torch.__version__) >= version.parse("1.6"):
+if is_torch_greater_or_equal_than_1_6:
     _is_torch_generator_available = True
     _is_native_cuda_amp_available = True

-if version.parse(torch.__version__) >= version.parse("1.10"):
+if is_torch_greater_or_equal_than_1_10:
     _is_native_cpu_amp_available = True

 if is_datasets_available():
@@ -405,7 +410,7 @@ class Trainer:
             # Would have to update setup.py with torch>=1.12.0
             # which isn't ideally given that it will force people not using FSDP to also use torch>=1.12.0
             # below is the current alternative.
-            if version.parse(torch.__version__) < version.parse("1.12.0"):
+            if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"):
                 raise ValueError("FSDP requires PyTorch >= 1.12.0")
             from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
@@ -1676,7 +1681,7 @@ class Trainer:
                 is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
                     train_dataloader.sampler, RandomSampler
                 )
-                if version.parse(torch.__version__) < version.parse("1.11") or not is_random_sampler:
+                if is_torch_less_than_1_11 or not is_random_sampler:
                     # We just need to begin an iteration to create the randomization of the sampler.
                     # That was before PyTorch 1.11 however...
                     for _ in train_dataloader:
@@ -2430,7 +2435,7 @@ class Trainer:
         arguments, depending on the situation.
         """
         if self.use_cuda_amp or self.use_cpu_amp:
-            if version.parse(torch.__version__) >= version.parse("1.10"):
+            if is_torch_greater_or_equal_than_1_10:
                 ctx_manager = (
                     torch.cpu.amp.autocast(dtype=self.amp_dtype)
                     if self.use_cpu_amp

View File

@@ -835,7 +835,7 @@ def _get_learning_rate(self):
         last_lr = (
             # backward compatibility for pytorch schedulers
             self.lr_scheduler.get_last_lr()[0]
-            if version.parse(torch.__version__) >= version.parse("1.4")
+            if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.4")
             else self.lr_scheduler.get_lr()[0]
         )
         return last_lr

View File

@@ -300,7 +300,7 @@ def is_torch_bf16_gpu_available():
     # 4. torch.autocast exists
     # XXX: one problem here is that it may give invalid results on mixed gpus setup, so it's
     # really only correct for the 0th gpu (or currently set default device if different from 0)
-    if version.parse(torch.__version__) < version.parse("1.10"):
+    if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.10"):
         return False

     if torch.cuda.is_available() and torch.version.cuda is not None:
@@ -322,7 +322,7 @@ def is_torch_bf16_cpu_available():
     import torch

-    if version.parse(torch.__version__) < version.parse("1.10"):
+    if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.10"):
         return False

     try:
@@ -357,7 +357,7 @@ def is_torch_tf32_available():
         return False
     if int(torch.version.cuda.split(".")[0]) < 11:
         return False
-    if version.parse(torch.__version__) < version.parse("1.7"):
+    if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.7"):
         return False

     return True

View File

@@ -22,7 +22,6 @@ import os
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from typing import Optional, Tuple, Union
@@ -48,6 +47,7 @@ from ...pytorch_utils import (
     apply_chunking_to_forward,
     find_pruneable_heads_and_indices,
     prune_linear_layer,
+    is_torch_greater_than_1_6,
 )
 from ...utils import logging
 from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
@@ -157,7 +157,7 @@ class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),