[Benchmark] add tpu and torchscript for benchmark (#4850)

* add tpu and torchscript for benchmark
* fix name in tests
* "fix email"
* make style
* better log message for tpu
* add more print and info for tpu
* allow possibility to print tpu metrics
* correct cpu usage
* fix test for non-install
* remove bogus file
* include psutil in testing
* run a couple of times before tracing in torchscript
* do not allow tpu memory tracing for now
* make style
* add torchscript to env
* better name for torch tpu

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
parent f0340b3031
commit 2cfb947f59

setup.py (2 lines changed)
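For orientation, here is a minimal usage sketch of the options this commit introduces, mirroring the new tests at the bottom of the diff (it is an assumption that `PyTorchBenchmark` and `PyTorchBenchmarkArguments` are importable from the top-level `transformers` package at this commit, as the new tests suggest):

from transformers import PyTorchBenchmark, PyTorchBenchmarkArguments  # top-level import assumed

# benchmark a tiny model with the new torchscript option; on a TPU host,
# tpu_print_metrics=True would additionally dump torch_xla debug metrics
args = PyTorchBenchmarkArguments(
    models=["sshleifer/tiny-gpt2"],
    training=False,
    no_inference=False,
    torchscript=True,
    sequence_lengths=[8],
    batch_sizes=[1],
)
results = PyTorchBenchmark(args).run()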
setup.py
@@ -84,7 +84,7 @@ extras["torch"] = ["torch"]
 extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"]
 extras["all"] = extras["serving"] + ["tensorflow", "torch"]

-extras["testing"] = ["pytest", "pytest-xdist", "timeout-decorator"]
+extras["testing"] = ["pytest", "pytest-xdist", "timeout-decorator", "psutil"]
 extras["docs"] = ["recommonmark", "sphinx", "sphinx-markdown-tables", "sphinx-rtd-theme"]
 extras["quality"] = [
     "black",
src/transformers/__init__.py
@@ -78,6 +78,7 @@ from .file_utils import (
     cached_path,
     is_tf_available,
     is_torch_available,
+    is_torch_tpu_available,
 )
 from .hf_argparser import HfArgumentParser
src/transformers/benchmark/benchmark.py
@@ -19,12 +19,17 @@


 import logging
 import os
 import timeit

-from transformers import MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING, PretrainedConfig, is_torch_available
+from transformers import (
+    MODEL_MAPPING,
+    MODEL_WITH_LM_HEAD_MAPPING,
+    PretrainedConfig,
+    is_torch_available,
+    is_torch_tpu_available,
+)

-from .benchmark_utils import Benchmark, Memory, start_memory_tracing, stop_memory_tracing
+from .benchmark_utils import Benchmark, Memory, measure_peak_memory_cpu, start_memory_tracing, stop_memory_tracing


 if is_torch_available():
@@ -48,6 +53,10 @@ class PyTorchBenchmark(Benchmark):
     def train(self, model_name, batch_size, sequence_length, trace_memory=False):
         try:
             config = self.config_dict[model_name]
+
+            if self.args.torchscript:
+                config.torchscript = True
+
             model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
             model.to(self.args.device)
             model.train()
@@ -58,15 +67,20 @@ class PyTorchBenchmark(Benchmark):
                 vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device
             )

+            if self.args.torchscript:
+                raise NotImplementedError("Training for torchscript is currently not implemented")
+            else:
+                train_model = model
+
             def compute_loss_and_backprob_encoder():
-                loss = model(input_ids, labels=input_ids)[0]
+                loss = train_model(input_ids, labels=input_ids)[0]
                 loss.backward()
-                model.zero_grad()
+                train_model.zero_grad()

             def compute_loss_and_backprob_encoder_decoder():
-                loss = model(input_ids, decoder_input_ids=input_ids, labels=input_ids)[0]
+                loss = train_model(input_ids, decoder_input_ids=input_ids, labels=input_ids)[0]
                 loss.backward()
-                model.zero_grad()
+                train_model.zero_grad()

             _train = (
                 compute_loss_and_backprob_encoder_decoder
@@ -79,6 +93,7 @@ class PyTorchBenchmark(Benchmark):
                     trace = start_memory_tracing("transformers")

                 if self.args.n_gpu > 0:
+                    # gpu
                     # clear gpu cache
                     torch.cuda.empty_cache()
                     if hasattr(torch.cuda, "max_memory_reserved"):
@@ -89,8 +104,17 @@ class PyTorchBenchmark(Benchmark):
                         )
                         torch.cuda.reset_max_memory_cached()

-                # calculate loss and do backpropagation
-                _train()
+                    # calculate loss and do backpropagation
+                    _train()
+                elif not self.args.no_tpu and is_torch_tpu_available():
+                    # tpu
+                    raise NotImplementedError(
+                        "Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with `args.no_memory=True`"
+                    )
+                else:
+                    # cpu
+                    memory_bytes = measure_peak_memory_cpu(_train)
+                    memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes

                 if self.args.trace_memory_line_by_line:
                     summary = stop_memory_tracing(trace)
@@ -107,40 +131,47 @@ class PyTorchBenchmark(Benchmark):
                         )
-                        memory = Memory(torch.cuda.max_memory_cached())
+                        memory = Memory(torch.cuda.max_memory_reserved())
-                else:
-                    # cpu
-                    try:
-                        import psutil
-                    except (ImportError):
-                        logger.warning(
-                            "Psutil not installed, we won't log CPU memory usage. "
-                            "Install psutil (pip install psutil) to use CPU memory tracing."
-                        )
-                        memory = "N/A"
-                    else:
-                        process = psutil.Process(os.getpid())
-                        memory = Memory(process.memory_info().rss)

                 return memory, summary
             else:
+                if (not self.args.no_tpu and is_torch_tpu_available()) or self.args.torchscript:
+                    # run model 5 times to stabilize compilation for tpu and torchscript
+                    logger.info("Do training on TPU or torchscript. Running model 5 times to stabilize compilation")
+                    timeit.repeat(
+                        _train, repeat=1, number=5,
+                    )

                 # as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
                 runtimes = timeit.repeat(_train, repeat=self.args.repeat, number=10,)

+                if not self.args.no_tpu and is_torch_tpu_available() and self.args.tpu_print_metrics:
+                    import torch_xla.debug.metrics as met
+
+                    self.print_fn(met.metrics_report())

                 return min(runtimes) / 10.0
         except RuntimeError as e:
             self.print_fn("Doesn't fit on GPU. {}".format(e))
-            return "N/A"
+            if trace_memory:
+                return "N/A", None
+            else:
+                return "N/A"
     def inference(self, model_name, batch_size, sequence_length, trace_memory=False):
         try:
             config = self.config_dict[model_name]
             model = None
+
+            if self.args.torchscript:
+                config.torchscript = True

             if self.args.with_lm_head:
                 model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
             else:
                 model = MODEL_MAPPING[config.__class__](config)

-            model.to(self.args.device)
             model.eval()
+            model.to(self.args.device)

             # encoder-decoder has vocab size saved differently
             vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
@@ -149,11 +180,22 @@ class PyTorchBenchmark(Benchmark):
                 vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device
             )

+            if self.args.torchscript:
+                with torch.no_grad():
+                    if config.is_encoder_decoder:
+                        raise NotImplementedError("Torchscript is currently not supported for EncoderDecoder models")
+                    else:
+                        inference_model = torch.jit.trace(model, input_ids)
+            else:
+                inference_model = model

             def encoder_decoder_forward():
-                model(input_ids, decoder_input_ids=input_ids)
+                with torch.no_grad():
+                    inference_model(input_ids, decoder_input_ids=input_ids)

             def encoder_forward():
-                model(input_ids)
+                with torch.no_grad():
+                    inference_model(input_ids)

             _forward = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward
@@ -162,6 +204,7 @@ class PyTorchBenchmark(Benchmark):
                     trace = start_memory_tracing("transformers")

                 if self.args.n_gpu > 0:
+                    # gpu
                     # clear gpu cache
                     torch.cuda.empty_cache()
                     if hasattr(torch.cuda, "max_memory_reserved"):
@@ -172,7 +215,17 @@ class PyTorchBenchmark(Benchmark):
                         )
                         torch.cuda.reset_max_memory_cached()

-                _forward()
+                    # run forward
+                    _forward()
+                elif not self.args.no_tpu and is_torch_tpu_available():
+                    # tpu
+                    raise NotImplementedError(
+                        "Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with `args.no_memory=True`"
+                    )
+                else:
+                    # cpu
+                    memory_bytes = measure_peak_memory_cpu(_forward)
+                    memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes

                 if self.args.trace_memory_line_by_line:
                     summary = stop_memory_tracing(trace)
@@ -188,26 +241,30 @@ class PyTorchBenchmark(Benchmark):
                         "Please consider updating PyTorch to version 1.4 to get more accuracy on GPU memory usage"
                     )
                     memory = Memory(torch.cuda.max_memory_cached())
-                else:
-                    # cpu
-                    try:
-                        import psutil
-                    except (ImportError):
-                        logger.warning(
-                            "Psutil not installed, we won't log CPU memory usage. "
-                            "Install psutil (pip install psutil) to use CPU memory tracing."
-                        )
-                        memory = "N/A"
-                    else:
-                        process = psutil.Process(os.getpid())
-                        memory = Memory(process.memory_info().rss)

                 return memory, summary
             else:
+                if (not self.args.no_tpu and is_torch_tpu_available()) or self.args.torchscript:
+                    # run model 5 times to stabilize compilation for tpu and torchscript
+                    logger.info("Do inference on TPU or torchscript. Running model 5 times to stabilize compilation")
+                    timeit.repeat(
+                        _forward, repeat=1, number=5,
+                    )

                 # as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
                 runtimes = timeit.repeat(_forward, repeat=self.args.repeat, number=10,)

+                if not self.args.no_tpu and is_torch_tpu_available() and self.args.tpu_print_metrics:
+                    import torch_xla.debug.metrics as met
+
+                    self.print_fn(met.metrics_report())

                 return min(runtimes) / 10.0

         except RuntimeError as e:
             self.print_fn("Doesn't fit on GPU. {}".format(e))
-            return "N/A"
+            if trace_memory:
+                return "N/A", None
+            else:
+                return "N/A"
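A note on the timing arithmetic used in both `train` and `inference` above: `timeit.repeat(f, repeat=r, number=n)` returns a list of `r` totals, each measuring `n` consecutive calls, so `min(runtimes) / 10.0` with `number=10` yields a best-case per-call time. A self-contained sketch using only the standard library:

import timeit

def f():
    sum(range(1000))

runtimes = timeit.repeat(f, repeat=3, number=10)  # three totals, each over 10 calls
per_call = min(runtimes) / 10.0  # best-case seconds per single call
print(per_call)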
src/transformers/benchmark/benchmark_args.py
@@ -18,25 +18,16 @@ import logging
 from dataclasses import dataclass, field
 from typing import Tuple

-from ..file_utils import cached_property, is_torch_available, torch_required
+from ..file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
 from .benchmark_args_utils import BenchmarkArguments


 if is_torch_available():
     import torch

-try:
+if is_torch_tpu_available():
     import torch_xla.core.xla_model as xm
-
-    _has_tpu = True
-except ImportError:
-    _has_tpu = False
-
-
-@torch_required
-def is_tpu_available():
-    return _has_tpu


 logger = logging.getLogger(__name__)
@@ -45,7 +36,9 @@ logger = logging.getLogger(__name__)
 class PyTorchBenchmarkArguments(BenchmarkArguments):
     no_cuda: bool = field(default=False, metadata={"help": "Whether to run on available cuda devices"})
+    torchscript: bool = field(default=False, metadata={"help": "Trace the models using torchscript"})
     no_tpu: bool = field(default=False, metadata={"help": "Whether to run on available tpu devices"})
     fp16: bool = field(default=False, metadata={"help": "Use FP16 to accelerate inference."})
+    tpu_print_metrics: bool = field(default=False, metadata={"help": "Whether to print debug metrics on TPU"})

     @cached_property
     @torch_required

@@ -54,7 +47,7 @@ class PyTorchBenchmarkArguments(BenchmarkArguments):
         if self.no_cuda:
             device = torch.device("cpu")
             n_gpu = 0
-        elif is_tpu_available():
+        elif is_torch_tpu_available():
             device = xm.xla_device()
             n_gpu = 0
         else:
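Because `PyTorchBenchmarkArguments` is a dataclass of fields, it can also be populated from the command line; a sketch assuming the usual `HfArgumentParser` flow (the module path `transformers.benchmark.benchmark_args` is inferred from this diff's layout; boolean fields with `default=False` become `--flag` style switches):

from transformers import HfArgumentParser
from transformers.benchmark.benchmark_args import PyTorchBenchmarkArguments  # module path assumed

parser = HfArgumentParser(PyTorchBenchmarkArguments)
# e.g. invoked as: python run_benchmark.py --torchscript --no_tpu
benchmark_args = parser.parse_args_into_dataclasses()[0]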
src/transformers/benchmark/benchmark_utils.py
@@ -14,12 +14,15 @@ import sys
 from abc import ABC, abstractmethod
 from collections import defaultdict, namedtuple
 from datetime import datetime
-from typing import Iterable, List, NamedTuple, Optional, Union
+from multiprocessing import Pipe, Process
+from multiprocessing.connection import Connection
+from signal import SIGKILL
+from typing import Callable, Iterable, List, NamedTuple, Optional, Union

 from transformers import AutoConfig, PretrainedConfig
 from transformers import __version__ as version

-from ..file_utils import is_tf_available, is_torch_available
+from ..file_utils import is_tf_available, is_torch_available, is_torch_tpu_available
 from .benchmark_args_utils import BenchmarkArguments
@@ -128,6 +131,127 @@ class MemorySummary(NamedTuple):
 MemoryTrace = List[UsedMemoryState]


+def measure_peak_memory_cpu(function: Callable[[], None], interval=0.5) -> int:
+    """
+    measures peak cpu memory consumption of a given `function`
+    running the function for at least interval seconds
+    and at most 20 * interval seconds.
+    This function is heavily inspired by: `memory_usage`
+    of the package `memory_profiler`: https://github.com/pythonprofilers/memory_profiler/blob/895c4ac7a08020d66ae001e24067da6dcea42451/memory_profiler.py#L239
+
+    Args:
+        - `function`: (`callable`): function() -> ...
+            function without any arguments for which to measure the peak memory consumption
+
+        - `interval`: (`float`)
+            interval in seconds at which to measure the memory usage
+
+    Returns:
+        - `max_memory`: (`int`)
+            consumed memory peak in Bytes
+    """
+    try:
+        import psutil
+    except ImportError:
+        logger.warning(
+            "Psutil not installed, we won't log CPU memory usage. "
+            "Install psutil (pip install psutil) to use CPU memory tracing."
+        )
+        max_memory = "N/A"
+    else:
+
+        def _get_memory(process_id: int) -> int:
+            """
+            measures current cpu memory usage of a given `process_id`
+
+            Args:
+                - `process_id`: (`int`)
+                    process_id for which to measure memory
+
+            Returns:
+                - `memory`: (`int`)
+                    consumed memory in Bytes
+            """
+            process = psutil.Process(process_id)
+            try:
+                meminfo_attr = "memory_info" if hasattr(process, "memory_info") else "get_memory_info"
+                memory = getattr(process, meminfo_attr)()[0]
+            except psutil.AccessDenied:
+                raise ValueError("Error with Psutil.")
+            return memory
+
+        class MemoryMeasureProcess(Process):
+
+            """
+            `MemoryMeasureProcess` inherits from `Process` and overwrites
+            its `run()` method. Used to measure the memory usage of a process
+            """
+
+            def __init__(self, process_id: int, child_connection: Connection, interval: float):
+                super().__init__()
+                self.process_id = process_id
+                self.interval = interval
+                self.connection = child_connection
+                self.num_measurements = 1
+                self.mem_usage = _get_memory(process_id)
+
+            def run(self):
+                self.connection.send(0)
+                stop = False
+                while True:
+                    self.mem_usage = max(self.mem_usage, _get_memory(self.process_id))
+                    self.num_measurements += 1
+
+                    if stop:
+                        break
+
+                    stop = self.connection.poll(self.interval)
+
+                # send results to parent pipe
+                self.connection.send(self.mem_usage)
+                self.connection.send(self.num_measurements)
+
+        while True:
+            # create child, parent connection
+            child_connection, parent_connection = Pipe()
+
+            # instantiate process
+            mem_process = MemoryMeasureProcess(os.getpid(), child_connection, interval)
+            mem_process.start()
+
+            # wait until we get memory
+            parent_connection.recv()
+
+            try:
+                # execute function
+                function()
+
+                # start parent connection
+                parent_connection.send(0)
+
+                # receive memory and num measurements
+                max_memory = parent_connection.recv()
+                num_measurements = parent_connection.recv()
+            except Exception:
+                # kill process in a clean way
+                parent = psutil.Process(os.getpid())
+                for child in parent.children(recursive=True):
+                    os.kill(child.pid, SIGKILL)
+                mem_process.join(0)
+                raise RuntimeError("Process killed. Error in Process")
+
+            # wait at most 20 * interval seconds for the measuring process to finish
+            mem_process.join(20 * interval)
+
+            if (num_measurements > 4) or (interval < 1e-6):
+                break
+
+            # reduce interval
+            interval /= 10
+
+    return max_memory
+
+
 def start_memory_tracing(
     modules_to_trace: Optional[Union[str, Iterable[str]]] = None,
     modules_not_to_trace: Optional[Union[str, Iterable[str]]] = None,
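A hedged usage sketch of the helper added above (the import path follows this diff's file layout and is an assumption; the workload function is illustrative only):

from transformers.benchmark.benchmark_utils import measure_peak_memory_cpu  # import path assumed

def workload():
    buf = [0] * 10_000_000  # allocate tens of MB so the peak is visible
    return len(buf)

peak = measure_peak_memory_cpu(workload, interval=0.5)
print(peak)  # peak RSS in bytes, or the string "N/A" if psutil is not installed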
@@ -424,6 +548,10 @@ class Benchmark(ABC):
     def is_gpu(self):
         return self.args.n_gpu > 0

+    @property
+    def is_tpu(self):
+        return is_torch_tpu_available() and not self.args.no_tpu
+
     @property
     @abstractmethod
     def framework_version(self):
@@ -486,6 +614,10 @@ class Benchmark(ABC):
             self.print_fn("======= INFERENCE - SPEED - RESULT =======")
             self.print_results(inference_result_time)
             self.save_to_csv(inference_result_time, self.args.inference_time_csv_file)
+            if self.is_tpu:
+                self.print_fn(
+                    "TPU was used for inference. Note that the time after compilation stabilized (after ~10 model.forward(..) calls) was measured."
+                )

         if not self.args.no_memory:
             self.print_fn("======= INFERENCE - MEMORY - RESULT =======")

@@ -501,6 +633,10 @@ class Benchmark(ABC):
             self.print_fn("======= TRAIN - SPEED - RESULT =======")
             self.print_results(train_result_time)
             self.save_to_csv(train_result_time, self.args.train_time_csv_file)
+            if self.is_tpu:
+                self.print_fn(
+                    "TPU was used for training. Note that the time after compilation stabilized (after ~10 loss = model.forward(...) + loss.backward() calls) was measured."
+                )

         if not self.args.no_memory:
             self.print_fn("======= TRAIN - MEMORY - RESULT =======")
@@ -538,6 +674,8 @@ class Benchmark(ABC):
         info = {}
         info["transformers_version"] = version
         info["framework"] = self.framework
+        if self.framework == "PyTorch":
+            info["use_torchscript"] = self.args.torchscript
         info["framework_version"] = self.framework_version
         info["python_version"] = platform.python_version()
         info["system"] = platform.system()

@@ -590,6 +728,10 @@ class Benchmark(ABC):
             info["gpu_performance_state"] = py3nvml.nvmlDeviceGetPerformanceState(handle)
             py3nvml.nvmlShutdown()

+        info["use_tpu"] = self.is_tpu
+        # TODO(PVP): See if we can add more information about TPU
+        # see: https://github.com/pytorch/xla/issues/2180
+
         self._environment_info = info
         return self._environment_info
src/transformers/file_utils.py
@@ -68,6 +68,21 @@ except ImportError:
 torch_cache_home = os.path.expanduser(
     os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
 )


+try:
+    import torch_xla.core.xla_model as xm
+
+    tpu_device = xm.xla_device()
+
+    if _torch_available:
+        _torch_tpu_available = True  # pylint: disable=
+    else:
+        _torch_tpu_available = False
+except ImportError:
+    _torch_tpu_available = False
+
+
 default_cache_path = os.path.join(torch_cache_home, "transformers")
@@ -98,6 +113,10 @@ def is_tf_available():
     return _tf_available


+def is_torch_tpu_available():
+    return _torch_tpu_available
+
+
 def add_start_docstrings(*docstr):
     def docstring_decorator(fn):
         fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
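The new guard is meant to be used the way the call sites in this diff use it; a minimal device-selection sketch mirroring `benchmark_args.py` and `training_args.py` above (the fallback branch is illustrative):

from transformers.file_utils import is_torch_tpu_available

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()  # XLA device handle, e.g. for model.to(device)
else:
    device = "cpu"  # fall back when torch_xla is absent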
src/transformers/trainer.py
@@ -23,7 +23,7 @@ from .data.data_collator import DataCollator, DefaultDataCollator
 from .modeling_utils import PreTrainedModel
 from .optimization import AdamW, get_linear_schedule_with_warmup
 from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput
-from .training_args import TrainingArguments, is_tpu_available
+from .training_args import TrainingArguments, is_torch_tpu_available


 try:

@@ -38,7 +38,7 @@ def is_apex_available():
     return _has_apex


-if is_tpu_available():
+if is_torch_tpu_available():
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.metrics as met
     import torch_xla.distributed.parallel_loader as pl
@@ -218,7 +218,7 @@ class Trainer:
         # Create output directory if needed
         if self.is_world_master():
             os.makedirs(self.args.output_dir, exist_ok=True)
-        if is_tpu_available():
+        if is_torch_tpu_available():
             # Set an xla_device flag on the model's config.
             # We'll find a more elegant way and not need to do this in the future.
             self.model.config.xla_device = True

@@ -226,7 +226,7 @@ class Trainer:
     def get_train_dataloader(self) -> DataLoader:
         if self.train_dataset is None:
             raise ValueError("Trainer: training requires a train_dataset.")
-        if is_tpu_available():
+        if is_torch_tpu_available():
            train_sampler = get_tpu_sampler(self.train_dataset)
         else:
             train_sampler = (

@@ -251,7 +251,7 @@ class Trainer:

         eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

-        if is_tpu_available():
+        if is_torch_tpu_available():
             sampler = SequentialDistributedSampler(
                 eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
             )

@@ -272,7 +272,7 @@ class Trainer:

     def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
         # We use the same batch_size as for eval.
-        if is_tpu_available():
+        if is_torch_tpu_available():
             sampler = SequentialDistributedSampler(
                 test_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
             )

@@ -407,7 +407,7 @@ class Trainer:
             self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

         # Train!
-        if is_tpu_available():
+        if is_torch_tpu_available():
             total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
         else:
             total_train_batch_size = (

@@ -455,7 +455,7 @@ class Trainer:
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)

-            if is_tpu_available():
+            if is_torch_tpu_available():
                 parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                     self.args.device
                 )

@@ -482,7 +482,7 @@ class Trainer:
                     else:
                         torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

-                    if is_tpu_available():
+                    if is_torch_tpu_available():
                         xm.optimizer_step(optimizer)
                     else:
                         optimizer.step()

@@ -525,7 +525,7 @@ class Trainer:
                         if self.is_world_master():
                             self._rotate_checkpoints()

-                        if is_tpu_available():
+                        if is_torch_tpu_available():
                             xm.rendezvous("saving_optimizer_states")
                             xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                             xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

@@ -588,7 +588,7 @@ class Trainer:
         return loss.item()

     def is_local_master(self) -> bool:
-        if is_tpu_available():
+        if is_torch_tpu_available():
             return xm.is_master_ordinal(local=True)
         else:
             return self.args.local_rank in [-1, 0]

@@ -598,7 +598,7 @@ class Trainer:
         This will be True only in one process, even in distributed mode,
         even when training on multiple machines.
         """
-        if is_tpu_available():
+        if is_torch_tpu_available():
             return xm.is_master_ordinal(local=False)
         else:
             return self.args.local_rank == -1 or torch.distributed.get_rank() == 0

@@ -611,7 +611,7 @@ class Trainer:
         Will only save from the world_master process (unless in TPUs).
         """

-        if is_tpu_available():
+        if is_torch_tpu_available():
             self._save_tpu(output_dir)
         elif self.is_world_master():
             self._save(output_dir)

@@ -746,7 +746,7 @@ class Trainer:
         label_ids: torch.Tensor = None
         model.eval()

-        if is_tpu_available():
+        if is_torch_tpu_available():
             dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

         for inputs in tqdm(dataloader, desc=description):

@@ -780,7 +780,7 @@ class Trainer:
                 preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
             if label_ids is not None:
                 label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
-        elif is_tpu_available():
+        elif is_torch_tpu_available():
             # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
             if preds is not None:
                 preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
src/transformers/training_args.py
@@ -5,25 +5,15 @@ import os
 from dataclasses import dataclass, field
 from typing import Any, Dict, Optional, Tuple

-from .file_utils import cached_property, is_torch_available, torch_required
+from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required


 if is_torch_available():
     import torch

-try:
+if is_torch_tpu_available():
     import torch_xla.core.xla_model as xm
-
-    _has_tpu = True
-except ImportError:
-    _has_tpu = False
-
-
-@torch_required
-def is_tpu_available():
-    return _has_tpu


 logger = logging.getLogger(__name__)
@@ -176,7 +166,7 @@ class TrainingArguments:
         if self.no_cuda:
             device = torch.device("cpu")
             n_gpu = 0
-        elif is_tpu_available():
+        elif is_torch_tpu_available():
             device = xm.xla_device()
             n_gpu = 0
         elif self.local_rank == -1:
tests/test_benchmark.py
@@ -33,6 +33,21 @@ class BenchmarkTest(unittest.TestCase):
         self.check_results_dict_not_empty(results.time_inference_result)
         self.check_results_dict_not_empty(results.memory_inference_result)

+    def test_inference_torchscript(self):
+        MODEL_ID = "sshleifer/tiny-gpt2"
+        benchmark_args = PyTorchBenchmarkArguments(
+            models=[MODEL_ID],
+            training=False,
+            no_inference=False,
+            torchscript=True,
+            sequence_lengths=[8],
+            batch_sizes=[1],
+        )
+        benchmark = PyTorchBenchmark(benchmark_args)
+        results = benchmark.run()
+        self.check_results_dict_not_empty(results.time_inference_result)
+        self.check_results_dict_not_empty(results.memory_inference_result)
+
     def test_train_no_configs(self):
         MODEL_ID = "sshleifer/tiny-gpt2"
         benchmark_args = PyTorchBenchmarkArguments(

@@ -76,6 +91,22 @@ class BenchmarkTest(unittest.TestCase):
         self.check_results_dict_not_empty(results.time_train_result)
         self.check_results_dict_not_empty(results.memory_train_result)

+    def test_train_with_configs_torchscript(self):
+        MODEL_ID = "sshleifer/tiny-gpt2"
+        config = AutoConfig.from_pretrained(MODEL_ID)
+        benchmark_args = PyTorchBenchmarkArguments(
+            models=[MODEL_ID],
+            training=True,
+            no_inference=True,
+            torchscript=True,
+            sequence_lengths=[8],
+            batch_sizes=[1],
+        )
+        benchmark = PyTorchBenchmark(benchmark_args, configs=[config])
+        results = benchmark.run()
+        self.check_results_dict_not_empty(results.time_train_result)
+        self.check_results_dict_not_empty(results.memory_train_result)
+
     def test_train_encoder_decoder_with_configs(self):
         MODEL_ID = "sshleifer/tinier_bart"
         config = AutoConfig.from_pretrained(MODEL_ID)
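To exercise the new code paths, the torchscript tests added above can be run on their own; a sketch (the test-file path is assumed from the repository layout of the time):

import pytest

# select only the new torchscript benchmark tests by keyword
pytest.main(["tests/test_benchmark.py", "-k", "torchscript"])  # path assumed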