[style] consistent nn. and nn.functional: part 3 tests (#12155)

* consistent nn. and nn.functional: p3 templates

* restore
This commit is contained in:
Stas Bekman 2021-06-14 12:18:22 -07:00 committed by GitHub
parent d9c0d08f9a
commit 372ab9cd6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 93 additions and 81 deletions

View File

@ -24,7 +24,7 @@ from .test_modeling_common import ids_tensor
if is_torch_available():
import torch
import torch.nn.functional as F
from torch import nn
from transformers.generation_logits_process import (
EncoderNoRepeatNGramLogitsProcessor,
@ -80,13 +80,13 @@ class LogitsProcessorTest(unittest.TestCase):
scores[1, 10] = (1 / length) - 0.4 # valley, 1st batch
# compute softmax
probs = F.softmax(scores, dim=-1)
probs = nn.functional.softmax(scores, dim=-1)
temp_dist_warper_sharper = TemperatureLogitsWarper(temperature=0.5)
temp_dist_warper_smoother = TemperatureLogitsWarper(temperature=1.3)
warped_prob_sharp = F.softmax(temp_dist_warper_sharper(input_ids, scores.clone()), dim=-1)
warped_prob_smooth = F.softmax(temp_dist_warper_smoother(input_ids, scores.clone()), dim=-1)
warped_prob_sharp = nn.functional.softmax(temp_dist_warper_sharper(input_ids, scores.clone()), dim=-1)
warped_prob_smooth = nn.functional.softmax(temp_dist_warper_smoother(input_ids, scores.clone()), dim=-1)
# uniform distribution stays uniform
self.assertTrue(torch.allclose(probs[0, :], warped_prob_sharp[0, :], atol=1e-3))

View File

@ -30,6 +30,7 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init, floats_te
if is_torch_available():
import torch
from torch import nn
from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPVisionConfig, CLIPVisionModel
from transformers.models.clip.modeling_clip import CLIP_PRETRAINED_MODEL_ARCHIVE_LIST
@ -140,9 +141,9 @@ class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Module))
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, torch.nn.Linear))
self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -44,6 +44,7 @@ from transformers.testing_utils import (
if is_torch_available():
import numpy as np
import torch
from torch import nn
from transformers import (
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
@ -1150,10 +1151,10 @@ class ModelTesterMixin:
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Embedding, AdaptiveEmbedding))
model.set_input_embeddings(torch.nn.Embedding(10, 10))
self.assertIsInstance(model.get_input_embeddings(), (nn.Embedding, AdaptiveEmbedding))
model.set_input_embeddings(nn.Embedding(10, 10))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, torch.nn.Linear))
self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_correct_missing_keys(self):
if not self.test_missing_keys:
@ -1337,7 +1338,7 @@ class ModelTesterMixin:
model.eval()
# Wrap model in nn.DataParallel
model = torch.nn.DataParallel(model)
model = nn.DataParallel(model)
with torch.no_grad():
_ = model(**self._prepare_for_class(inputs_dict, model_class))

View File

@ -27,6 +27,7 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
import torch
from torch import nn
from transformers import (
MODEL_MAPPING,
@ -176,9 +177,9 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Module))
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, torch.nn.Linear))
self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -30,6 +30,7 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
import torch
from torch import nn
from transformers import FSMTConfig, FSMTForConditionalGeneration, FSMTModel, FSMTTokenizer
from transformers.models.fsmt.modeling_fsmt import (
@ -160,10 +161,10 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Embedding))
model.set_input_embeddings(torch.nn.Embedding(10, 10))
self.assertIsInstance(model.get_input_embeddings(), (nn.Embedding))
model.set_input_embeddings(nn.Embedding(10, 10))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, torch.nn.modules.sparse.Embedding))
self.assertTrue(x is None or isinstance(x, nn.modules.sparse.Embedding))
def test_initialization_more(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()

View File

@ -26,7 +26,7 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention
if is_torch_available():
import torch
import torch.nn as nn
from torch import nn
from transformers import (
IBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
@ -304,9 +304,9 @@ class IBertModelTest(ModelTesterMixin, unittest.TestCase):
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), QuantEmbedding)
model.set_input_embeddings(torch.nn.Embedding(10, 10))
model.set_input_embeddings(nn.Embedding(10, 10))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, torch.nn.Linear))
self.assertTrue(x is None or isinstance(x, nn.Linear))
# Override
def test_feed_forward_chunking(self):
@ -350,7 +350,7 @@ class IBertModelIntegrationTest(unittest.TestCase):
weight_bit = 8
embedding = QuantEmbedding(2, 4, quant_mode=True, weight_bit=weight_bit)
embedding_weight = torch.tensor([[-1.0, -2.0, -3.0, -4.0], [5.0, 6.0, 7.0, 8.0]])
embedding.weight = torch.nn.Parameter(embedding_weight)
embedding.weight = nn.Parameter(embedding_weight)
expected_scaling_factor = embedding_weight.abs().max() / (2 ** (weight_bit - 1) - 1)
x, x_scaling_factor = embedding(torch.tensor(0))
@ -447,8 +447,8 @@ class IBertModelIntegrationTest(unittest.TestCase):
linear_q = QuantLinear(2, 4, quant_mode=True, per_channel=per_channel, weight_bit=weight_bit)
linear_dq = QuantLinear(2, 4, quant_mode=False, per_channel=per_channel, weight_bit=weight_bit)
linear_weight = torch.tensor([[-1.0, 2.0, 3.0, -4.0], [5.0, -6.0, -7.0, 8.0]]).T
linear_q.weight = torch.nn.Parameter(linear_weight)
linear_dq.weight = torch.nn.Parameter(linear_weight)
linear_q.weight = nn.Parameter(linear_weight)
linear_dq.weight = nn.Parameter(linear_weight)
q, q_scaling_factor = linear_q(x, x_scaling_factor)
q_int = q / q_scaling_factor
@ -477,7 +477,7 @@ class IBertModelIntegrationTest(unittest.TestCase):
def test_int_gelu(self):
gelu_q = IntGELU(quant_mode=True)
gelu_dq = torch.nn.GELU()
gelu_dq = nn.GELU()
x_int = torch.range(-10000, 10000, 1)
x_scaling_factor = torch.tensor(0.001)
@ -523,7 +523,7 @@ class IBertModelIntegrationTest(unittest.TestCase):
def test_int_softmax(self):
output_bit = 8
softmax_q = IntSoftmax(output_bit, quant_mode=True)
softmax_dq = torch.nn.Softmax()
softmax_dq = nn.Softmax()
# x_int = torch.range(-10000, 10000, 1)
def _test(array):
@ -590,12 +590,12 @@ class IBertModelIntegrationTest(unittest.TestCase):
x = x_int * x_scaling_factor
ln_q = IntLayerNorm(x.shape[1:], 1e-5, quant_mode=True, output_bit=output_bit)
ln_dq = torch.nn.LayerNorm(x.shape[1:], 1e-5)
ln_dq = nn.LayerNorm(x.shape[1:], 1e-5)
ln_q.weight = torch.nn.Parameter(torch.ones(x.shape[1:]))
ln_q.bias = torch.nn.Parameter(torch.ones(x.shape[1:]))
ln_dq.weight = torch.nn.Parameter(torch.ones(x.shape[1:]))
ln_dq.bias = torch.nn.Parameter(torch.ones(x.shape[1:]))
ln_q.weight = nn.Parameter(torch.ones(x.shape[1:]))
ln_q.bias = nn.Parameter(torch.ones(x.shape[1:]))
ln_dq.weight = nn.Parameter(torch.ones(x.shape[1:]))
ln_dq.bias = nn.Parameter(torch.ones(x.shape[1:]))
q, q_scaling_factor = ln_q(x, x_scaling_factor)
q_int = q / q_scaling_factor
@ -627,13 +627,13 @@ class IBertModelIntegrationTest(unittest.TestCase):
],
}
ln_dq.weight = torch.nn.Parameter(torch.ones(x.shape[1:]))
ln_dq.bias = torch.nn.Parameter(torch.ones(x.shape[1:]))
ln_dq.weight = nn.Parameter(torch.ones(x.shape[1:]))
ln_dq.bias = nn.Parameter(torch.ones(x.shape[1:]))
dq, dq_scaling_factor = ln_dq(x, x_scaling_factor)
for label, ln_fdqs in ln_fdqs_dict.items():
for ln_fdq in ln_fdqs:
ln_fdq.weight = torch.nn.Parameter(torch.ones(x.shape[1:]))
ln_fdq.bias = torch.nn.Parameter(torch.ones(x.shape[1:]))
ln_fdq.weight = nn.Parameter(torch.ones(x.shape[1:]))
ln_fdq.bias = nn.Parameter(torch.ones(x.shape[1:]))
q, q_scaling_factor = ln_fdq(x, x_scaling_factor)
if label:
self.assertTrue(torch.allclose(q, dq, atol=1e-4))

View File

@ -32,6 +32,7 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, r
if is_torch_available():
import torch
from torch import nn
from transformers import (
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
@ -241,7 +242,7 @@ class ReformerModelTester:
# set all position encodings to zero so that postions don't matter
with torch.no_grad():
embedding = model.embeddings.position_embeddings.embedding
embedding.weight = torch.nn.Parameter(torch.zeros(embedding.weight.shape).to(torch_device))
embedding.weight = nn.Parameter(torch.zeros(embedding.weight.shape).to(torch_device))
embedding.weight.requires_grad = False
half_seq_len = self.seq_length // 2

View File

@ -27,6 +27,7 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
import torch
from torch import nn
from transformers import TransfoXLConfig, TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel
from transformers.models.transfo_xl.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST
@ -362,11 +363,11 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
if hasattr(module, "emb_projs"):
for i in range(len(module.emb_projs)):
if module.emb_projs[i] is not None:
torch.nn.init.constant_(module.emb_projs[i], 0.0003)
nn.init.constant_(module.emb_projs[i], 0.0003)
if hasattr(module, "out_projs"):
for i in range(len(module.out_projs)):
if module.out_projs[i] is not None:
torch.nn.init.constant_(module.out_projs[i], 0.0003)
nn.init.constant_(module.out_projs[i], 0.0003)
for param in ["r_emb", "r_w_bias", "r_r_bias", "r_bias"]:
if hasattr(module, param) and getattr(module, param) is not None:

View File

@ -27,6 +27,7 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
import torch
from torch import nn
from transformers import ViTConfig, ViTForImageClassification, ViTModel
from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
@ -169,9 +170,9 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Module))
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, torch.nn.Linear))
self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -24,6 +24,7 @@ from transformers.testing_utils import require_torch
if is_torch_available():
import torch
from torch import nn
from transformers import (
Adafactor,
@ -70,7 +71,7 @@ class OptimizationTest(unittest.TestCase):
def test_adam_w(self):
w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
target = torch.tensor([0.4, 0.2, -0.5])
criterion = torch.nn.MSELoss()
criterion = nn.MSELoss()
# No warmup, constant schedule, no gradient clipping
optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0)
for _ in range(100):
@ -84,7 +85,7 @@ class OptimizationTest(unittest.TestCase):
def test_adafactor(self):
w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
target = torch.tensor([0.4, 0.2, -0.5])
criterion = torch.nn.MSELoss()
criterion = nn.MSELoss()
# No warmup, constant schedule, no gradient clipping
optimizer = Adafactor(
params=[w],
@ -109,7 +110,7 @@ class OptimizationTest(unittest.TestCase):
@require_torch
class ScheduleInitTest(unittest.TestCase):
m = torch.nn.Linear(50, 50) if is_torch_available() else None
m = nn.Linear(50, 50) if is_torch_available() else None
optimizer = AdamW(m.parameters(), lr=10.0) if is_torch_available() else None
num_steps = 10

View File

@ -32,6 +32,7 @@ from .test_pipelines_common import MonoInputPipelineCommonMixin
if is_torch_available():
import torch
from torch import nn
from transformers.models.gpt2 import GPT2Config, GPT2LMHeadModel
@ -59,8 +60,8 @@ class SimpleConversationPipelineTests(unittest.TestCase):
bias[76] = 1
weight = torch.zeros((V, D), requires_grad=True)
model.lm_head.bias = torch.nn.Parameter(bias)
model.lm_head.weight = torch.nn.Parameter(weight)
model.lm_head.bias = nn.Parameter(bias)
model.lm_head.weight = nn.Parameter(weight)
# # Created with:
# import tempfile

View File

@ -23,6 +23,7 @@ from .test_pipelines_common import MonoInputPipelineCommonMixin
if is_torch_available():
import torch
from torch import nn
from transformers.models.bart import BartConfig, BartForConditionalGeneration
@ -55,7 +56,7 @@ class SimpleSummarizationPipelineTests(unittest.TestCase):
bias = torch.zeros(V)
bias[76] = 10
model.lm_head.bias = torch.nn.Parameter(bias)
model.lm_head.bias = nn.Parameter(bias)
# # Generated with:
# import tempfile

View File

@ -53,6 +53,7 @@ from transformers.utils.hp_naming import TrialShortNamer
if is_torch_available():
import torch
from torch import nn
from torch.utils.data import IterableDataset
from transformers import (
@ -154,11 +155,11 @@ if is_torch_available():
for i in range(len(self.dataset)):
yield self.dataset[i]
class RegressionModel(torch.nn.Module):
class RegressionModel(nn.Module):
def __init__(self, a=0, b=0, double_output=False):
super().__init__()
self.a = torch.nn.Parameter(torch.tensor(a).float())
self.b = torch.nn.Parameter(torch.tensor(b).float())
self.a = nn.Parameter(torch.tensor(a).float())
self.b = nn.Parameter(torch.tensor(b).float())
self.double_output = double_output
self.config = None
@ -166,21 +167,21 @@ if is_torch_available():
y = input_x * self.a + self.b
if labels is None:
return (y, y) if self.double_output else (y,)
loss = torch.nn.functional.mse_loss(y, labels)
loss = nn.functional.mse_loss(y, labels)
return (loss, y, y) if self.double_output else (loss, y)
class RegressionDictModel(torch.nn.Module):
class RegressionDictModel(nn.Module):
def __init__(self, a=0, b=0):
super().__init__()
self.a = torch.nn.Parameter(torch.tensor(a).float())
self.b = torch.nn.Parameter(torch.tensor(b).float())
self.a = nn.Parameter(torch.tensor(a).float())
self.b = nn.Parameter(torch.tensor(b).float())
self.config = None
def forward(self, input_x, labels=None, **kwargs):
y = input_x * self.a + self.b
result = {"output": y}
if labels is not None:
result["loss"] = torch.nn.functional.mse_loss(y, labels)
result["loss"] = nn.functional.mse_loss(y, labels)
return result
class RegressionPreTrainedModel(PreTrainedModel):
@ -189,15 +190,15 @@ if is_torch_available():
def __init__(self, config):
super().__init__(config)
self.a = torch.nn.Parameter(torch.tensor(config.a).float())
self.b = torch.nn.Parameter(torch.tensor(config.b).float())
self.a = nn.Parameter(torch.tensor(config.a).float())
self.b = nn.Parameter(torch.tensor(config.b).float())
self.double_output = config.double_output
def forward(self, input_x, labels=None, **kwargs):
y = input_x * self.a + self.b
if labels is None:
return (y, y) if self.double_output else (y,)
loss = torch.nn.functional.mse_loss(y, labels)
loss = nn.functional.mse_loss(y, labels)
return (loss, y, y) if self.double_output else (loss, y)
class RegressionRandomPreTrainedModel(PreTrainedModel):
@ -206,8 +207,8 @@ if is_torch_available():
def __init__(self, config):
super().__init__(config)
self.a = torch.nn.Parameter(torch.tensor(config.a).float())
self.b = torch.nn.Parameter(torch.tensor(config.b).float())
self.a = nn.Parameter(torch.tensor(config.a).float())
self.b = nn.Parameter(torch.tensor(config.b).float())
def forward(self, input_x, labels=None, **kwargs):
y = input_x * self.a + self.b
@ -219,21 +220,21 @@ if is_torch_available():
if labels is None:
return (y,)
loss = torch.nn.functional.mse_loss(y, labels)
loss = nn.functional.mse_loss(y, labels)
return (loss, y)
class TstLayer(torch.nn.Module):
class TstLayer(nn.Module):
def __init__(self, hidden_size):
super().__init__()
self.linear1 = torch.nn.Linear(hidden_size, hidden_size)
self.ln1 = torch.nn.LayerNorm(hidden_size)
self.linear2 = torch.nn.Linear(hidden_size, hidden_size)
self.ln2 = torch.nn.LayerNorm(hidden_size)
self.bias = torch.nn.Parameter(torch.zeros(hidden_size))
self.linear1 = nn.Linear(hidden_size, hidden_size)
self.ln1 = nn.LayerNorm(hidden_size)
self.linear2 = nn.Linear(hidden_size, hidden_size)
self.ln2 = nn.LayerNorm(hidden_size)
self.bias = nn.Parameter(torch.zeros(hidden_size))
def forward(self, x):
h = self.ln1(torch.nn.functional.relu(self.linear1(x)))
h = torch.nn.functional.relu(self.linear2(x))
h = self.ln1(nn.functional.relu(self.linear1(x)))
h = nn.functional.relu(self.linear2(x))
return self.ln2(x + h + self.bias)
def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, pretrained=True, **kwargs):
@ -1065,7 +1066,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
assert_flos_extraction(trainer, trainer.model)
# with enforced DataParallel
assert_flos_extraction(trainer, torch.nn.DataParallel(trainer.model))
assert_flos_extraction(trainer, nn.DataParallel(trainer.model))
trainer.train()
self.assertTrue(isinstance(trainer.state.total_flos, float))
@ -1186,7 +1187,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
def test_no_wd_param_group(self):
model = torch.nn.Sequential(TstLayer(128), torch.nn.ModuleList([TstLayer(128), TstLayer(128)]))
model = nn.Sequential(TstLayer(128), nn.ModuleList([TstLayer(128), TstLayer(128)]))
trainer = Trainer(model=model)
trainer.create_optimizer_and_scheduler(10)
# fmt: off

View File

@ -24,6 +24,7 @@ from transformers.testing_utils import require_torch
if is_torch_available():
import torch
from torch import nn
from torch.utils.data import IterableDataset
from transformers.modeling_outputs import SequenceClassifierOutput
@ -40,18 +41,18 @@ if is_torch_available():
get_parameter_names,
)
class TstLayer(torch.nn.Module):
class TstLayer(nn.Module):
def __init__(self, hidden_size):
super().__init__()
self.linear1 = torch.nn.Linear(hidden_size, hidden_size)
self.ln1 = torch.nn.LayerNorm(hidden_size)
self.linear2 = torch.nn.Linear(hidden_size, hidden_size)
self.ln2 = torch.nn.LayerNorm(hidden_size)
self.bias = torch.nn.Parameter(torch.zeros(hidden_size))
self.linear1 = nn.Linear(hidden_size, hidden_size)
self.ln1 = nn.LayerNorm(hidden_size)
self.linear2 = nn.Linear(hidden_size, hidden_size)
self.ln2 = nn.LayerNorm(hidden_size)
self.bias = nn.Parameter(torch.zeros(hidden_size))
def forward(self, x):
h = self.ln1(torch.nn.functional.relu(self.linear1(x)))
h = torch.nn.functional.relu(self.linear2(x))
h = self.ln1(nn.functional.relu(self.linear1(x)))
h = nn.functional.relu(self.linear2(x))
return self.ln2(x + h + self.bias)
class RandomIterableDataset(IterableDataset):
@ -151,10 +152,10 @@ class TrainerUtilsTest(unittest.TestCase):
num_labels = 12
random_logits = torch.randn(4, 5, num_labels)
random_labels = torch.randint(0, num_labels, (4, 5))
loss = torch.nn.functional.cross_entropy(random_logits.view(-1, num_labels), random_labels.view(-1))
loss = nn.functional.cross_entropy(random_logits.view(-1, num_labels), random_labels.view(-1))
model_output = SequenceClassifierOutput(logits=random_logits)
label_smoothed_loss = LabelSmoother(0.1)(model_output, random_labels)
log_probs = -torch.nn.functional.log_softmax(random_logits, dim=-1)
log_probs = -nn.functional.log_softmax(random_logits, dim=-1)
expected_loss = (1 - epsilon) * loss + epsilon * log_probs.mean()
self.assertTrue(torch.allclose(label_smoothed_loss, expected_loss))
@ -163,10 +164,10 @@ class TrainerUtilsTest(unittest.TestCase):
random_labels[2, 1] = -100
random_labels[2, 3] = -100
loss = torch.nn.functional.cross_entropy(random_logits.view(-1, num_labels), random_labels.view(-1))
loss = nn.functional.cross_entropy(random_logits.view(-1, num_labels), random_labels.view(-1))
model_output = SequenceClassifierOutput(logits=random_logits)
label_smoothed_loss = LabelSmoother(0.1)(model_output, random_labels)
log_probs = -torch.nn.functional.log_softmax(random_logits, dim=-1)
log_probs = -nn.functional.log_softmax(random_logits, dim=-1)
# Mask the log probs with the -100 labels
log_probs[0, 1] = 0.0
log_probs[2, 1] = 0.0
@ -230,10 +231,10 @@ class TrainerUtilsTest(unittest.TestCase):
self.assertEqual(list(sorted(indices_process_0 + indices_process_1)), list(range(100)))
def test_get_parameter_names(self):
model = torch.nn.Sequential(TstLayer(128), torch.nn.ModuleList([TstLayer(128), TstLayer(128)]))
model = nn.Sequential(TstLayer(128), nn.ModuleList([TstLayer(128), TstLayer(128)]))
# fmt: off
self.assertEqual(
get_parameter_names(model, [torch.nn.LayerNorm]),
get_parameter_names(model, [nn.LayerNorm]),
['0.linear1.weight', '0.linear1.bias', '0.linear2.weight', '0.linear2.bias', '0.bias', '1.0.linear1.weight', '1.0.linear1.bias', '1.0.linear2.weight', '1.0.linear2.bias', '1.0.bias', '1.1.linear1.weight', '1.1.linear1.bias', '1.1.linear2.weight', '1.1.linear2.bias', '1.1.bias']
)
# fmt: on