[Benchmark] add TPU and TorchScript for benchmark (#4850)

* add TPU and TorchScript support to the benchmark utilities

* fix name in tests

* "fix email"

* make style

* better log message for tpu

* add more print and info for tpu

* allow possibility to print tpu metrics

* correct cpu usage

* fix test for non-install

* remove bogus file

* include psutil in testing

* run a couple of times before tracing in torchscript

* do not allow tpu memory tracing for now

* make style

* add torchscript to env

* better name for torch tpu

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
Patrick von Platen 2020-06-09 23:12:43 +02:00 committed by GitHub
parent f0340b3031
commit 2cfb947f59
9 changed files with 317 additions and 84 deletions
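
For reference, a minimal sketch of how the options added here are meant to be used from Python. The checkpoint name and flag values are illustrative (the tiny checkpoint is the one the new tests use), not part of this commit:

    from transformers import PyTorchBenchmark, PyTorchBenchmarkArguments

    args = PyTorchBenchmarkArguments(
        models=["sshleifer/tiny-gpt2"],   # tiny checkpoint, as in the new tests
        training=False,
        no_inference=False,
        torchscript=True,                 # trace the model with torch.jit.trace before timing
        no_tpu=True,                      # skip TPU; TPU memory benchmarking is not implemented yet
        sequence_lengths=[8],
        batch_sizes=[1],
    )
    results = PyTorchBenchmark(args).run()
    print(results.time_inference_result)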

View File

@ -84,7 +84,7 @@ extras["torch"] = ["torch"]
extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"]
extras["all"] = extras["serving"] + ["tensorflow", "torch"]
extras["testing"] = ["pytest", "pytest-xdist", "timeout-decorator"]
extras["testing"] = ["pytest", "pytest-xdist", "timeout-decorator", "psutil"]
extras["docs"] = ["recommonmark", "sphinx", "sphinx-markdown-tables", "sphinx-rtd-theme"]
extras["quality"] = [
"black",

View File

@ -78,6 +78,7 @@ from .file_utils import (
cached_path,
is_tf_available,
is_torch_available,
is_torch_tpu_available,
)
from .hf_argparser import HfArgumentParser

View File

@ -19,12 +19,17 @@
import logging
import os
import timeit
from transformers import MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING, PretrainedConfig, is_torch_available
from transformers import (
MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
PretrainedConfig,
is_torch_available,
is_torch_tpu_available,
)
from .benchmark_utils import Benchmark, Memory, start_memory_tracing, stop_memory_tracing
from .benchmark_utils import Benchmark, Memory, measure_peak_memory_cpu, start_memory_tracing, stop_memory_tracing
if is_torch_available():
@ -48,6 +53,10 @@ class PyTorchBenchmark(Benchmark):
def train(self, model_name, batch_size, sequence_length, trace_memory=False):
try:
config = self.config_dict[model_name]
if self.args.torchscript:
config.torchscript = True
model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
model.to(self.args.device)
model.train()
@ -58,15 +67,20 @@ class PyTorchBenchmark(Benchmark):
vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device
)
if self.args.torchscript:
raise NotImplementedError("Training for torchscript is currently not implemented")
else:
train_model = model
def compute_loss_and_backprob_encoder():
loss = model(input_ids, labels=input_ids)[0]
loss = train_model(input_ids, labels=input_ids)[0]
loss.backward()
model.zero_grad()
train_model.zero_grad()
def compute_loss_and_backprob_encoder_decoder():
loss = model(input_ids, decoder_input_ids=input_ids, labels=input_ids)[0]
loss = train_model(input_ids, decoder_input_ids=input_ids, labels=input_ids)[0]
loss.backward()
model.zero_grad()
train_model.zero_grad()
_train = (
compute_loss_and_backprob_encoder_decoder
@ -79,6 +93,7 @@ class PyTorchBenchmark(Benchmark):
trace = start_memory_tracing("transformers")
if self.args.n_gpu > 0:
# gpu
# clear gpu cache
torch.cuda.empty_cache()
if hasattr(torch.cuda, "max_memory_reserved"):
@ -89,8 +104,17 @@ class PyTorchBenchmark(Benchmark):
)
torch.cuda.reset_max_memory_cached()
# calculate loss and do backpropagation
_train()
# calculate loss and do backpropagation
_train()
elif not self.args.no_tpu and is_torch_tpu_available():
# tpu
raise NotImplementedError(
"Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with `args.no_memory=True`"
)
else:
# cpu
memory_bytes = measure_peak_memory_cpu(_train)
memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes
if self.args.trace_memory_line_by_line:
summary = stop_memory_tracing(trace)
@ -107,40 +131,47 @@ class PyTorchBenchmark(Benchmark):
)
memory = Memory(torch.cuda.max_memory_cached())
memory = Memory(torch.cuda.max_memory_reserved())
else:
# cpu
try:
import psutil
except (ImportError):
logger.warning(
"Psutil not installed, we won't log CPU memory usage. "
"Install psutil (pip install psutil) to use CPU memory tracing."
)
memory = "N/A"
else:
process = psutil.Process(os.getpid())
memory = Memory(process.memory_info().rss)
return memory, summary
else:
if (not self.args.no_tpu and is_torch_tpu_available()) or self.args.torchscript:
# run the model an additional 5 times to stabilize compilation for TPU and TorchScript
logger.info("Training on TPU or with TorchScript: running model 5 times to stabilize compilation")
timeit.repeat(
_train, repeat=1, number=5,
)
# as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
runtimes = timeit.repeat(_train, repeat=self.args.repeat, number=10,)
if not self.args.no_tpu and is_torch_tpu_available() and self.args.tpu_print_metrics:
import torch_xla.debug.metrics as met
self.print_fn(met.metrics_report())
return min(runtimes) / 10.0
except RuntimeError as e:
self.print_fn("Doesn't fit on GPU. {}".format(e))
return "N/A"
if trace_memory:
return "N/A", None
else:
return "N/A"
def inference(self, model_name, batch_size, sequence_length, trace_memory=False):
try:
config = self.config_dict[model_name]
model = None
if self.args.torchscript:
config.torchscript = True
if self.args.with_lm_head:
model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
else:
model = MODEL_MAPPING[config.__class__](config)
model.to(self.args.device)
model.eval()
model.to(self.args.device)
# encoder-decoder has vocab size saved differently
vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
@ -149,11 +180,22 @@ class PyTorchBenchmark(Benchmark):
vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device
)
if self.args.torchscript:
with torch.no_grad():
if config.is_encoder_decoder:
raise NotImplementedError("Torchscript is currently not supported for EncoderDecoder models")
else:
inference_model = torch.jit.trace(model, input_ids)
else:
inference_model = model
def encoder_decoder_forward():
model(input_ids, decoder_input_ids=input_ids)
with torch.no_grad():
inference_model(input_ids, decoder_input_ids=input_ids)
def encoder_forward():
model(input_ids)
with torch.no_grad():
inference_model(input_ids)
_forward = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward
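For readers unfamiliar with torch.jit.trace, the tracing step above boils down to the following standalone sketch. The tiny GPT-2 config is illustrative and not part of this commit; only the torchscript=True config flag and the trace call mirror the code above:

    import torch
    from transformers import GPT2Config, GPT2LMHeadModel

    config = GPT2Config(vocab_size=100, n_positions=64, n_embd=32, n_layer=2, n_head=2, torchscript=True)
    model = GPT2LMHeadModel(config).eval()
    input_ids = torch.randint(0, config.vocab_size, (1, 8), dtype=torch.long)

    with torch.no_grad():
        traced_model = torch.jit.trace(model, input_ids)   # same call used for encoder-only models above
        logits = traced_model(input_ids)[0]                # torchscript models return plain tuples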
@ -162,6 +204,7 @@ class PyTorchBenchmark(Benchmark):
trace = start_memory_tracing("transformers")
if self.args.n_gpu > 0:
# gpu
# clear gpu cache
torch.cuda.empty_cache()
if hasattr(torch.cuda, "max_memory_reserved"):
@ -172,7 +215,17 @@ class PyTorchBenchmark(Benchmark):
)
torch.cuda.reset_max_memory_cached()
_forward()
# run forward
_forward()
elif not self.args.no_tpu and is_torch_tpu_available():
# tpu
raise NotImplementedError(
"Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with `args.no_memory=True`"
)
else:
# cpu
memory_bytes = measure_peak_memory_cpu(_forward)
memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes
if self.args.trace_memory_line_by_line:
summary = stop_memory_tracing(trace)
@ -188,26 +241,30 @@ class PyTorchBenchmark(Benchmark):
"Please consider updating PyTorch to version 1.4 to get more accuracy on GPU memory usage"
)
memory = Memory(torch.cuda.max_memory_cached())
else:
# cpu
try:
import psutil
except (ImportError):
logger.warning(
"Psutil not installed, we won't log CPU memory usage. "
"Install psutil (pip install psutil) to use CPU memory tracing."
)
memory = "N/A"
else:
process = psutil.Process(os.getpid())
memory = Memory(process.memory_info().rss)
return memory, summary
else:
if (not self.args.no_tpu and is_torch_tpu_available()) or self.args.torchscript:
# run the model an additional 5 times to stabilize compilation for TPU and TorchScript
logger.info("Inference on TPU or with TorchScript: running model 5 times to stabilize compilation")
timeit.repeat(
_forward, repeat=1, number=5,
)
# as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
runtimes = timeit.repeat(_forward, repeat=self.args.repeat, number=10,)
if not self.args.no_tpu and is_torch_tpu_available() and self.args.tpu_print_metrics:
import torch_xla.debug.metrics as met
self.print_fn(met.metrics_report())
return min(runtimes) / 10.0
except RuntimeError as e:
self.print_fn("Doesn't fit on GPU. {}".format(e))
return "N/A"
if trace_memory:
return "N/A", None
else:
return "N/A"

View File

@ -18,25 +18,16 @@ import logging
from dataclasses import dataclass, field
from typing import Tuple
from ..file_utils import cached_property, is_torch_available, torch_required
from ..file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
from .benchmark_args_utils import BenchmarkArguments
if is_torch_available():
import torch
try:
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
_has_tpu = True
except ImportError:
_has_tpu = False
@torch_required
def is_tpu_available():
return _has_tpu
logger = logging.getLogger(__name__)
@ -45,7 +36,9 @@ logger = logging.getLogger(__name__)
class PyTorchBenchmarkArguments(BenchmarkArguments):
no_cuda: bool = field(default=False, metadata={"help": "Whether to run on available cuda devices"})
torchscript: bool = field(default=False, metadata={"help": "Trace the models using torchscript"})
no_tpu: bool = field(default=False, metadata={"help": "Whether to disable TPU, even when a TPU is available"})
fp16: bool = field(default=False, metadata={"help": "Use FP16 to accelerate inference."})
tpu_print_metrics: bool = field(default=False, metadata={"help": "Whether to print TPU debug metrics after benchmarking"})
@cached_property
@torch_required
@ -54,7 +47,7 @@ class PyTorchBenchmarkArguments(BenchmarkArguments):
if self.no_cuda:
device = torch.device("cpu")
n_gpu = 0
elif is_tpu_available():
elif is_torch_tpu_available():
device = xm.xla_device()
n_gpu = 0
else:

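Because PyTorchBenchmarkArguments is a dataclass, the new fields double as command-line flags when parsed with HfArgumentParser. A sketch (the wrapping script is hypothetical, similar in spirit to the benchmarking example scripts):

    from transformers import HfArgumentParser, PyTorchBenchmarkArguments

    parser = HfArgumentParser(PyTorchBenchmarkArguments)
    benchmark_args = parser.parse_args_into_dataclasses()[0]
    # e.g.: python run_benchmark.py --models sshleifer/tiny-gpt2 --torchscript --no_tpu
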
View File

@ -14,12 +14,15 @@ import sys
from abc import ABC, abstractmethod
from collections import defaultdict, namedtuple
from datetime import datetime
from typing import Iterable, List, NamedTuple, Optional, Union
from multiprocessing import Pipe, Process
from multiprocessing.connection import Connection
from signal import SIGKILL
from typing import Callable, Iterable, List, NamedTuple, Optional, Union
from transformers import AutoConfig, PretrainedConfig
from transformers import __version__ as version
from ..file_utils import is_tf_available, is_torch_available
from ..file_utils import is_tf_available, is_torch_available, is_torch_tpu_available
from .benchmark_args_utils import BenchmarkArguments
@ -128,6 +131,127 @@ class MemorySummary(NamedTuple):
MemoryTrace = List[UsedMemoryState]
def measure_peak_memory_cpu(function: Callable[[], None], interval=0.5) -> int:
"""
measures peak cpu memory consumption of a given `function`
running the function for at least interval seconds
and at most 20 * interval seconds.
This function is heavily inspired by: `memory_usage`
of the package `memory_profiler`: https://github.com/pythonprofilers/memory_profiler/blob/895c4ac7a08020d66ae001e24067da6dcea42451/memory_profiler.py#L239
Args:
- `function`: (`callable`): function() -> ...
function without any arguments to measure for which to measure the peak memory
- `interval`: (`float`)
interval in second for which to measure the memory usage
Returns:
- `max_memory`: (`int`)
cosumed memory peak in Bytes
"""
try:
import psutil
except (ImportError):
logger.warning(
"Psutil not installed, we won't log CPU memory usage. "
"Install Psutil (pip install psutil) to use CPU memory tracing."
)
max_memory = "N/A"
else:
def _get_memory(process_id: int) -> int:
"""
measures current cpu memory usage of a given `process_id`
Args:
- `process_id`: (`int`)
process_id for which to measure memory
Returns:
- `memory`: (`int`)
consumed memory in bytes
"""
process = psutil.Process(process_id)
try:
meminfo_attr = "memory_info" if hasattr(process, "memory_info") else "get_memory_info"
memory = getattr(process, meminfo_attr)()[0]
except psutil.AccessDenied:
raise ValueError("Error with Psutil.")
return memory
class MemoryMeasureProcess(Process):
"""
`MemoryMeasureProcess` inherits from `Process` and overwrites
its `run()` method. Used to measure the memory usage of a process
"""
def __init__(self, process_id: int, child_connection: Connection, interval: float):
super().__init__()
self.process_id = process_id
self.interval = interval
self.connection = child_connection
self.num_measurements = 1
self.mem_usage = _get_memory(process_id)
def run(self):
self.connection.send(0)
stop = False
while True:
self.mem_usage = max(self.mem_usage, _get_memory(self.process_id))
self.num_measurements += 1
if stop:
break
stop = self.connection.poll(self.interval)
# send results to parent pipe
self.connection.send(self.mem_usage)
self.connection.send(self.num_measurements)
while True:
# create child, parent connection
child_connection, parent_connection = Pipe()
# instantiate process
mem_process = MemoryMeasureProcess(os.getpid(), child_connection, interval)
mem_process.start()
# wait until we get memory
parent_connection.recv()
try:
# execute function
function()
# start parent connection
parent_connection.send(0)
# receive memory and num measurements
max_memory = parent_connection.recv()
num_measurements = parent_connection.recv()
except Exception:
# kill process in a clean way
parent = psutil.Process(os.getpid())
for child in parent.children(recursive=True):
os.kill(child.pid, SIGKILL)
mem_process.join(0)
raise RuntimeError("Process killed. Error in Process")
# run process for at most 20 * interval seconds or until it finishes
mem_process.join(20 * interval)
if (num_measurements > 4) or (interval < 1e-6):
break
# reduce interval
interval /= 10
return max_memory
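A usage sketch for measure_peak_memory_cpu (the import path is an assumption based on the file this hunk belongs to, and the workload is made up):

    from transformers.benchmark.benchmark_utils import measure_peak_memory_cpu  # assumed path

    def allocate():
        data = [0] * (10 * 1024 * 1024)   # allocate roughly 80 MB so the sampler has something to see
        return len(data)

    peak = measure_peak_memory_cpu(allocate, interval=0.1)
    print(peak)   # an int (bytes) when psutil is installed, otherwise the string "N/A"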
def start_memory_tracing(
modules_to_trace: Optional[Union[str, Iterable[str]]] = None,
modules_not_to_trace: Optional[Union[str, Iterable[str]]] = None,
@ -424,6 +548,10 @@ class Benchmark(ABC):
def is_gpu(self):
return self.args.n_gpu > 0
@property
def is_tpu(self):
return is_torch_tpu_available() and not self.args.no_tpu
@property
@abstractmethod
def framework_version(self):
@ -486,6 +614,10 @@ class Benchmark(ABC):
self.print_fn("======= INFERENCE - SPEED - RESULT =======")
self.print_results(inference_result_time)
self.save_to_csv(inference_result_time, self.args.inference_time_csv_file)
if self.is_tpu:
self.print_fn(
"TPU was used for inference. Note that the time after compilation stabilized (after ~10 inferences model.forward(..) calls) was measured."
)
if not self.args.no_memory:
self.print_fn("======= INFERENCE - MEMORY - RESULT =======")
@ -501,6 +633,10 @@ class Benchmark(ABC):
self.print_fn("======= TRAIN - SPEED - RESULT =======")
self.print_results(train_result_time)
self.save_to_csv(train_result_time, self.args.train_time_csv_file)
if self.is_tpu:
self.print_fn(
"TPU was used for training. Note that the time after compilation stabilized (after ~10 train loss=model.forward(...) + loss.backward() calls) was measured."
)
if not self.args.no_memory:
self.print_fn("======= TRAIN - MEMORY - RESULT =======")
@ -538,6 +674,8 @@ class Benchmark(ABC):
info = {}
info["transformers_version"] = version
info["framework"] = self.framework
if self.framework == "PyTorch":
info["use_torchscript"] = self.args.torchscript
info["framework_version"] = self.framework_version
info["python_version"] = platform.python_version()
info["system"] = platform.system()
@ -590,6 +728,10 @@ class Benchmark(ABC):
info["gpu_performance_state"] = py3nvml.nvmlDeviceGetPerformanceState(handle)
py3nvml.nvmlShutdown()
info["use_tpu"] = self.is_tpu
# TODO(PVP): See if we can add more information about TPU
# see: https://github.com/pytorch/xla/issues/2180
self._environment_info = info
return self._environment_info

View File

@ -68,6 +68,21 @@ except ImportError:
torch_cache_home = os.path.expanduser(
os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
)
try:
import torch_xla.core.xla_model as xm
tpu_device = xm.xla_device()
if _torch_available:
_torch_tpu_available = True # pylint: disable=
else:
_torch_tpu_available = False
except ImportError:
_torch_tpu_available = False
default_cache_path = os.path.join(torch_cache_home, "transformers")
@ -98,6 +113,10 @@ def is_tf_available():
return _tf_available
def is_torch_tpu_available():
return _torch_tpu_available
def add_start_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")

View File

@ -23,7 +23,7 @@ from .data.data_collator import DataCollator, DefaultDataCollator
from .modeling_utils import PreTrainedModel
from .optimization import AdamW, get_linear_schedule_with_warmup
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput
from .training_args import TrainingArguments, is_tpu_available
from .training_args import TrainingArguments, is_torch_tpu_available
try:
@ -38,7 +38,7 @@ def is_apex_available():
return _has_apex
if is_tpu_available():
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
@ -218,7 +218,7 @@ class Trainer:
# Create output directory if needed
if self.is_world_master():
os.makedirs(self.args.output_dir, exist_ok=True)
if is_tpu_available():
if is_torch_tpu_available():
# Set an xla_device flag on the model's config.
# We'll find a more elegant way and won't need to do this in the future.
self.model.config.xla_device = True
@ -226,7 +226,7 @@ class Trainer:
def get_train_dataloader(self) -> DataLoader:
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
if is_tpu_available():
if is_torch_tpu_available():
train_sampler = get_tpu_sampler(self.train_dataset)
else:
train_sampler = (
@ -251,7 +251,7 @@ class Trainer:
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
if is_tpu_available():
if is_torch_tpu_available():
sampler = SequentialDistributedSampler(
eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
)
@ -272,7 +272,7 @@ class Trainer:
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
# We use the same batch_size as for eval.
if is_tpu_available():
if is_torch_tpu_available():
sampler = SequentialDistributedSampler(
test_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
)
@ -407,7 +407,7 @@ class Trainer:
self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})
# Train!
if is_tpu_available():
if is_torch_tpu_available():
total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
else:
total_train_batch_size = (
@ -455,7 +455,7 @@ class Trainer:
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
train_dataloader.sampler.set_epoch(epoch)
if is_tpu_available():
if is_torch_tpu_available():
parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
self.args.device
)
@ -482,7 +482,7 @@ class Trainer:
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
if is_tpu_available():
if is_torch_tpu_available():
xm.optimizer_step(optimizer)
else:
optimizer.step()
@ -525,7 +525,7 @@ class Trainer:
if self.is_world_master():
self._rotate_checkpoints()
if is_tpu_available():
if is_torch_tpu_available():
xm.rendezvous("saving_optimizer_states")
xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
@ -588,7 +588,7 @@ class Trainer:
return loss.item()
def is_local_master(self) -> bool:
if is_tpu_available():
if is_torch_tpu_available():
return xm.is_master_ordinal(local=True)
else:
return self.args.local_rank in [-1, 0]
@ -598,7 +598,7 @@ class Trainer:
This will be True only in one process, even in distributed mode,
even when training on multiple machines.
"""
if is_tpu_available():
if is_torch_tpu_available():
return xm.is_master_ordinal(local=False)
else:
return self.args.local_rank == -1 or torch.distributed.get_rank() == 0
@ -611,7 +611,7 @@ class Trainer:
Will only save from the world_master process (unless in TPUs).
"""
if is_tpu_available():
if is_torch_tpu_available():
self._save_tpu(output_dir)
elif self.is_world_master():
self._save(output_dir)
@ -746,7 +746,7 @@ class Trainer:
label_ids: torch.Tensor = None
model.eval()
if is_tpu_available():
if is_torch_tpu_available():
dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
for inputs in tqdm(dataloader, desc=description):
@ -780,7 +780,7 @@ class Trainer:
preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
if label_ids is not None:
label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
elif is_tpu_available():
elif is_torch_tpu_available():
# tpu-comment: Get all predictions and labels from all worker shards of eval dataset
if preds is not None:
preds = xm.mesh_reduce("eval_preds", preds, torch.cat)

View File

@ -5,25 +5,15 @@ import os
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
from .file_utils import cached_property, is_torch_available, torch_required
from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
if is_torch_available():
import torch
try:
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
_has_tpu = True
except ImportError:
_has_tpu = False
@torch_required
def is_tpu_available():
return _has_tpu
logger = logging.getLogger(__name__)
@ -176,7 +166,7 @@ class TrainingArguments:
if self.no_cuda:
device = torch.device("cpu")
n_gpu = 0
elif is_tpu_available():
elif is_torch_tpu_available():
device = xm.xla_device()
n_gpu = 0
elif self.local_rank == -1:

View File

@ -33,6 +33,21 @@ class BenchmarkTest(unittest.TestCase):
self.check_results_dict_not_empty(results.time_inference_result)
self.check_results_dict_not_empty(results.memory_inference_result)
def test_inference_torchscript(self):
MODEL_ID = "sshleifer/tiny-gpt2"
benchmark_args = PyTorchBenchmarkArguments(
models=[MODEL_ID],
training=False,
no_inference=False,
torchscript=True,
sequence_lengths=[8],
batch_sizes=[1],
)
benchmark = PyTorchBenchmark(benchmark_args)
results = benchmark.run()
self.check_results_dict_not_empty(results.time_inference_result)
self.check_results_dict_not_empty(results.memory_inference_result)
def test_train_no_configs(self):
MODEL_ID = "sshleifer/tiny-gpt2"
benchmark_args = PyTorchBenchmarkArguments(
@ -76,6 +91,22 @@ class BenchmarkTest(unittest.TestCase):
self.check_results_dict_not_empty(results.time_train_result)
self.check_results_dict_not_empty(results.memory_train_result)
def test_train_with_configs_torchscript(self):
MODEL_ID = "sshleifer/tiny-gpt2"
config = AutoConfig.from_pretrained(MODEL_ID)
benchmark_args = PyTorchBenchmarkArguments(
models=[MODEL_ID],
training=True,
no_inference=True,
torchscript=True,
sequence_lengths=[8],
batch_sizes=[1],
)
benchmark = PyTorchBenchmark(benchmark_args, configs=[config])
results = benchmark.run()
self.check_results_dict_not_empty(results.time_train_result)
self.check_results_dict_not_empty(results.memory_train_result)
def test_train_encoder_decoder_with_configs(self):
MODEL_ID = "sshleifer/tinier_bart"
config = AutoConfig.from_pretrained(MODEL_ID)