CPU/GPU memory benchmarking utilities - Remove support for python 3.5 (now only 3.6+) (#3186)

* memory benchmark rss

* have both forward pass and line-by-line mem tracing

* cleaned up tracing

* refactored and cleaning up API

* no f-strings yet...

* add GPU mem logging

* fix GPU memory monitoring

* style and quality

* clean up and doc

* update with comments

* Switching to python 3.6+

* fix quality
This commit is contained in:
Thomas Wolf 2020-03-17 15:17:11 +01:00 committed by GitHub
parent bd3feddf67
commit 2187c49f5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 565 additions and 32 deletions

View File

@ -3,7 +3,7 @@ jobs:
run_tests_torch_and_tf:
working_directory: ~/transformers
docker:
- image: circleci/python:3.5
- image: circleci/python:3.6
environment:
OMP_NUM_THREADS: 1
resource_class: xlarge
@ -46,7 +46,7 @@ jobs:
run_tests_custom_tokenizers:
working_directory: ~/transformers
docker:
- image: circleci/python:3.5
- image: circleci/python:3.6
environment:
RUN_CUSTOM_TOKENIZERS: yes
steps:
@ -56,7 +56,7 @@ jobs:
run_examples_torch:
working_directory: ~/transformers
docker:
- image: circleci/python:3.5
- image: circleci/python:3.6
environment:
OMP_NUM_THREADS: 1
resource_class: xlarge
@ -69,7 +69,7 @@ jobs:
deploy_doc:
working_directory: ~/transformers
docker:
- image: circleci/python:3.5
- image: circleci/python:3.6
steps:
- add_ssh_keys:
fingerprints:
@ -94,7 +94,7 @@ jobs:
check_repository_consistency:
working_directory: ~/transformers
docker:
- image: circleci/python:3.5
- image: circleci/python:3.6
resource_class: small
parallelism: 1
steps:

View File

@ -66,7 +66,7 @@ Choose the right framework for every part of a model's lifetime
## Installation
This repo is tested on Python 3.5+, PyTorch 1.0.0+ and TensorFlow 2.0.0-rc1
This repo is tested on Python 3.6+, PyTorch 1.0.0+ and TensorFlow 2.0.0-rc1
You should install 🤗 Transformers in a [virtual environment](https://docs.python.org/3/library/venv.html). If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).

View File

@ -1,6 +1,6 @@
# Installation
Transformers is tested on Python 3.5+ and PyTorch 1.1.0
Transformers is tested on Python 3.6+ and PyTorch 1.1.0
## With pip

View File

@ -24,7 +24,15 @@ import timeit
from time import time
from typing import List
from transformers import AutoConfig, AutoTokenizer, is_tf_available, is_torch_available
from transformers import (
AutoConfig,
AutoTokenizer,
MemorySummary,
is_tf_available,
is_torch_available,
start_memory_tracing,
stop_memory_tracing,
)
if is_tf_available():
@ -250,15 +258,21 @@ as they entered."""
def create_setup_and_compute(
model_names: List[str],
batch_sizes: List[int],
slice_sizes: List[int],
gpu: bool = True,
tensorflow: bool = False,
average_over: int = 3,
no_speed: bool = False,
no_memory: bool = False,
verbose: bool = False,
torchscript: bool = False,
xla: bool = False,
amp: bool = False,
fp16: bool = False,
save_to_csv: bool = False,
csv_filename: str = f"results_{round(time())}.csv",
csv_memory_filename: str = f"memory_{round(time())}.csv",
):
if xla:
tf.config.optimizer.set_jit(True)
@ -267,11 +281,25 @@ def create_setup_and_compute(
if tensorflow:
dictionary = {model_name: {} for model_name in model_names}
results = _compute_tensorflow(model_names, dictionary, average_over, amp)
results = _compute_tensorflow(
model_names, batch_sizes, slice_sizes, dictionary, average_over, amp, no_speed, no_memory, verbose
)
else:
device = "cuda" if (gpu and torch.cuda.is_available()) else "cpu"
dictionary = {model_name: {} for model_name in model_names}
results = _compute_pytorch(model_names, dictionary, average_over, device, torchscript, fp16)
results = _compute_pytorch(
model_names,
batch_sizes,
slice_sizes,
dictionary,
average_over,
device,
torchscript,
fp16,
no_speed,
no_memory,
verbose,
)
print("=========== RESULTS ===========")
for model_name in model_names:
@ -280,13 +308,19 @@ def create_setup_and_compute(
print("\t\t" + f"===== BATCH SIZE: {batch_size} =====")
for slice_size in results[model_name]["ss"]:
result = results[model_name]["results"][batch_size][slice_size]
memory = results[model_name]["memory"][batch_size][slice_size]
if isinstance(result, str):
print(f"\t\t{model_name}/{batch_size}/{slice_size}: " f"{result}")
print(f"\t\t{model_name}/{batch_size}/{slice_size}: " f"{result} " f"{memory}")
else:
print(f"\t\t{model_name}/{batch_size}/{slice_size}: " f"{(round(1000 * result) / 1000)}" f"s")
print(
f"\t\t{model_name}/{batch_size}/{slice_size}: "
f"{(round(1000 * result) / 1000)}"
f"s "
f"{memory}"
)
if save_to_csv:
with open(csv_filename, mode="w") as csv_file:
with open(csv_filename, mode="w") as csv_file, open(csv_memory_filename, mode="w") as csv_memory_file:
fieldnames = [
"model",
"1x8",
@ -317,6 +351,8 @@ def create_setup_and_compute(
writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
writer.writeheader()
memory_writer = csv.DictWriter(csv_memory_file, fieldnames=fieldnames)
memory_writer.writeheader()
for model_name in model_names:
model_results = {
@ -326,8 +362,52 @@ def create_setup_and_compute(
}
writer.writerow({"model": model_name, **model_results})
model_memory_results = {
f"{bs}x{ss}": results[model_name]["memory"][bs][ss]
for bs in results[model_name]["memory"]
for ss in results[model_name]["memory"][bs]
}
memory_writer.writerow({"model": model_name, **model_memory_results})
def _compute_pytorch(model_names, dictionary, average_over, device, torchscript, fp16):
def print_summary_statistics(summary: MemorySummary):
print(
"\nLines by line memory consumption:\n"
+ "\n".join(
f"{state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
for state in summary.sequential
)
)
print(
"\nLines with top memory consumption:\n"
+ "\n".join(
f"=> {state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
for state in summary.cumulative[:6]
)
)
print(
"\nLines with lowest memory consumption:\n"
+ "\n".join(
f"=> {state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
for state in summary.cumulative[-6:]
)
)
print(f"\nTotal memory increase: {summary.total}")
def _compute_pytorch(
model_names,
batch_sizes,
slice_sizes,
dictionary,
average_over,
device,
torchscript,
fp16,
no_speed,
no_memory,
verbose,
):
for c, model_name in enumerate(model_names):
print(f"{c + 1} / {len(model_names)}")
config = AutoConfig.from_pretrained(model_name, torchscript=torchscript)
@ -337,17 +417,17 @@ def _compute_pytorch(model_names, dictionary, average_over, device, torchscript,
tokenized_sequence = tokenizer.encode(input_text, add_special_tokens=False)
max_input_size = tokenizer.max_model_input_sizes[model_name]
batch_sizes = [1, 2, 4, 8]
slice_sizes = [8, 64, 128, 256, 512, 1024]
dictionary[model_name] = {"bs": batch_sizes, "ss": slice_sizes, "results": {}}
dictionary[model_name] = {"bs": batch_sizes, "ss": slice_sizes, "results": {}, "memory": {}}
dictionary[model_name]["results"] = {i: {} for i in batch_sizes}
dictionary[model_name]["memory"] = {i: {} for i in batch_sizes}
for batch_size in batch_sizes:
if fp16:
model.half()
model.to(device)
model.eval()
for slice_size in slice_sizes:
if max_input_size is not None and slice_size > max_input_size:
dictionary[model_name]["results"][batch_size][slice_size] = "N/A"
@ -362,18 +442,40 @@ def _compute_pytorch(model_names, dictionary, average_over, device, torchscript,
inference = model
inference(sequence)
print("Going through model with sequence of shape", sequence.shape)
runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3)
average_time = sum(runtimes) / float(len(runtimes)) / 3.0
dictionary[model_name]["results"][batch_size][slice_size] = average_time
if not no_memory:
# model.add_memory_hooks() # Forward method tracing (only for PyTorch models)
# Line by line memory tracing (all code in the module `transformers`) works for all models/arbitrary code
trace = start_memory_tracing("transformers")
inference(sequence)
summary = stop_memory_tracing(trace)
if verbose:
print_summary_statistics(summary)
dictionary[model_name]["memory"][batch_size][slice_size] = str(summary.total)
else:
dictionary[model_name]["memory"][batch_size][slice_size] = "N/A"
if not no_speed:
print("Going through model with sequence of shape", sequence.shape)
runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3)
average_time = sum(runtimes) / float(len(runtimes)) / 3.0
dictionary[model_name]["results"][batch_size][slice_size] = average_time
else:
dictionary[model_name]["results"][batch_size][slice_size] = "N/A"
except RuntimeError as e:
print("Doesn't fit on GPU.", e)
torch.cuda.empty_cache()
dictionary[model_name]["results"][batch_size][slice_size] = "N/A"
dictionary[model_name]["memory"][batch_size][slice_size] = "N/A"
return dictionary
def _compute_tensorflow(model_names, dictionary, average_over, amp):
def _compute_tensorflow(
model_names, batch_sizes, slice_sizes, dictionary, average_over, amp, no_speed, no_memory, verbose
):
for c, model_name in enumerate(model_names):
print(f"{c + 1} / {len(model_names)}")
config = AutoConfig.from_pretrained(model_name)
@ -383,11 +485,10 @@ def _compute_tensorflow(model_names, dictionary, average_over, amp):
tokenized_sequence = tokenizer.encode(input_text, add_special_tokens=False)
max_input_size = tokenizer.max_model_input_sizes[model_name]
batch_sizes = [1, 2, 4, 8]
slice_sizes = [8, 64, 128, 256, 512, 1024]
dictionary[model_name] = {"bs": batch_sizes, "ss": slice_sizes, "results": {}}
dictionary[model_name] = {"bs": batch_sizes, "ss": slice_sizes, "results": {}, "memory": {}}
dictionary[model_name]["results"] = {i: {} for i in batch_sizes}
dictionary[model_name]["memory"] = {i: {} for i in batch_sizes}
print("Using model", model)
@ -409,13 +510,31 @@ def _compute_tensorflow(model_names, dictionary, average_over, amp):
# To make sure that the model is traced + that the tensors are on the appropriate device
inference(sequence)
runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3)
average_time = sum(runtimes) / float(len(runtimes)) / 3.0
dictionary[model_name]["results"][batch_size][slice_size] = average_time
if not no_memory:
# Line by line memory tracing (all code in the module `transformers`) works for all models/arbitrary code
trace = start_memory_tracing("transformers")
inference(sequence)
summary = stop_memory_tracing(trace)
if verbose:
print_summary_statistics(summary)
dictionary[model_name]["memory"][batch_size][slice_size] = str(summary.total)
else:
dictionary[model_name]["memory"][batch_size][slice_size] = "N/A"
if not no_speed:
runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3)
average_time = sum(runtimes) / float(len(runtimes)) / 3.0
dictionary[model_name]["results"][batch_size][slice_size] = average_time
else:
dictionary[model_name]["results"][batch_size][slice_size] = "N/A"
except tf.errors.ResourceExhaustedError as e:
print("Doesn't fit on GPU.", e)
torch.cuda.empty_cache()
dictionary[model_name]["results"][batch_size][slice_size] = "N/A"
dictionary[model_name]["memory"][batch_size][slice_size] = "N/A"
return dictionary
@ -433,6 +552,9 @@ def main():
"of all available model "
"architectures.",
)
parser.add_argument("--verbose", required=False, action="store_true", help="Verbose memory tracing")
parser.add_argument("--no_speed", required=False, action="store_true", help="Don't perform speed measurments")
parser.add_argument("--no_memory", required=False, action="store_true", help="Don't perform memory measurments")
parser.add_argument(
"--torch", required=False, action="store_true", help="Benchmark the Pytorch version of the " "models"
)
@ -477,6 +599,8 @@ def main():
parser.add_argument(
"--average_over", required=False, default=30, type=int, help="Times an experiment will be run."
)
parser.add_argument("--batch_sizes", nargs="+", type=int, default=[1, 2, 4, 8])
parser.add_argument("--slice_sizes", nargs="+", type=int, default=[8, 64, 128, 256, 512, 1024])
args = parser.parse_args()
if args.models == "all":
@ -501,6 +625,8 @@ def main():
if is_torch_available():
create_setup_and_compute(
model_names=args.models,
batch_sizes=args.batch_sizes,
slice_sizes=args.slice_sizes,
tensorflow=False,
gpu=args.torch_cuda,
torchscript=args.torchscript,
@ -508,6 +634,9 @@ def main():
save_to_csv=args.save_to_csv,
csv_filename=args.csv_filename,
average_over=args.average_over,
no_speed=args.no_speed,
no_memory=args.no_memory,
verbose=args.verbose,
)
else:
raise ImportError("Trying to run a PyTorch benchmark but PyTorch was not found in the environment.")
@ -516,12 +645,17 @@ def main():
if is_tf_available():
create_setup_and_compute(
model_names=args.models,
batch_sizes=args.batch_sizes,
slice_sizes=args.slice_sizes,
tensorflow=True,
xla=args.xla,
amp=args.amp,
save_to_csv=args.save_to_csv,
csv_filename=args.csv_filename,
average_over=args.average_over,
no_speed=args.no_speed,
no_memory=args.no_memory,
verbose=args.verbose,
)
else:
raise ImportError("Trying to run a TensorFlow benchmark but TensorFlow was not found in the environment.")

View File

@ -2,3 +2,4 @@ tensorboardX
tensorboard
scikit-learn
seqeval
psutil

View File

@ -110,7 +110,7 @@ setup(
],
extras_require=extras,
scripts=["transformers-cli"],
python_requires=">=3.5.0",
python_requires=">=3.6.0",
classifiers=[
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Developers",
@ -119,7 +119,6 @@ setup(
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.5",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Topic :: Scientific/Engineering :: Artificial Intelligence",

View File

@ -19,6 +19,18 @@ else:
import logging
# Benchmarking
from .benchmark_utils import (
Frame,
Memory,
MemoryState,
MemorySummary,
MemoryTrace,
UsedMemoryState,
bytes_to_human_readable,
start_memory_tracing,
stop_memory_tracing,
)
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig
from .configuration_bart import BartConfig

View File

@ -0,0 +1,341 @@
"""
Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""
import linecache
import logging
import os
import sys
from collections import defaultdict
from typing import Iterable, List, NamedTuple, Optional, Union
from .file_utils import is_tf_available, is_torch_available
if is_torch_available():
from torch.cuda import empty_cache as torch_empty_cache
if is_tf_available():
from tensorflow.python.eager import context as tf_context
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
_is_memory_tracing_enabled = False
def is_memory_tracing_enabled():
global _is_memory_tracing_enabled
return _is_memory_tracing_enabled
class Frame(NamedTuple):
""" `Frame` is a NamedTuple used to gather the current frame state.
`Frame` has the following fields:
- 'filename' (string): Name of the file currently executed
- 'module' (string): Name of the module currently executed
- 'line_number' (int): Number of the line currently executed
- 'event' (string): Event that triggered the tracing (default will be "line")
- 'line_text' (string): Text of the line in the python script
"""
filename: str
module: str
line_number: int
event: str
line_text: str
class UsedMemoryState(NamedTuple):
""" `UsedMemoryState` are named tuples with the following fields:
- 'frame': a `Frame` namedtuple (see below) storing information on the current tracing frame (current file, location in current file)
- 'cpu_memory': CPU RSS memory state *before* executing the line
- 'gpu_memory': GPU used memory *before* executing the line (sum for all GPUs or for only `gpus_to_trace` if provided)
"""
frame: Frame
cpu_memory: int
gpu_memory: int
class Memory(NamedTuple):
""" `Memory` NamedTuple have a single field `bytes` and
you can get a human readable string of the number of bytes by calling `__repr__`
- `byte` (integer): number of bytes,
"""
bytes: int
def __repr__(self) -> str:
return bytes_to_human_readable(self.bytes)
class MemoryState(NamedTuple):
""" `MemoryState` are namedtuples listing frame + CPU/GPU memory with the following fields:
- `frame` (`Frame`): the current frame (see above)
- `cpu`: CPU memory consumed at during the current frame as a `Memory` named tuple
- `gpu`: GPU memory consumed at during the current frame as a `Memory` named tuple
- `cpu_gpu`: CPU + GPU memory consumed at during the current frame as a `Memory` named tuple
"""
frame: Frame
cpu: Memory
gpu: Memory
cpu_gpu: Memory
class MemorySummary(NamedTuple):
""" `MemorySummary` namedtuple otherwise with the fields:
- `sequential`: a list of `MemoryState` namedtuple (see below) computed from the provided `memory_trace`
by substracting the memory after executing each line from the memory before executing said line.
- `cumulative`: a list of `MemoryState` namedtuple (see below) with cumulative increase in memory for each line
obtained by summing repeted memory increase for a line if it's executed several times.
The list is sorted from the frame with the largest memory consumption to the frame with the smallest (can be negative if memory is released)
- `total`: total memory increase during the full tracing as a `Memory` named tuple (see below).
Line with memory release (negative consumption) are ignored if `ignore_released_memory` is `True` (default).
"""
sequential: List[MemoryState]
cumulative: List[MemoryState]
total: Memory
MemoryTrace = List[UsedMemoryState]
def start_memory_tracing(
modules_to_trace: Optional[Union[str, Iterable[str]]] = None,
modules_not_to_trace: Optional[Union[str, Iterable[str]]] = None,
events_to_trace: str = "line",
gpus_to_trace: Optional[List[int]] = None,
) -> MemoryTrace:
""" Setup line-by-line tracing to record rss mem (RAM) at each line of a module or sub-module.
See `../../examples/benchmarks.py for a usage example.
Current memory consumption is returned using psutil and in particular is the RSS memory
"Resident Set Size” (the non-swapped physical memory the process is using).
See https://psutil.readthedocs.io/en/latest/#psutil.Process.memory_info
Args:
- `modules_to_trace`: (None, string, list/tuple of string)
if None, all events are recorded
if string or list of strings: only events from the listed module/sub-module will be recorded (e.g. 'fairseq' or 'transformers.modeling_gpt2')
- `modules_not_to_trace`: (None, string, list/tuple of string)
if None, no module is avoided
if string or list of strings: events from the listed module/sub-module will not be recorded (e.g. 'torch')
- `events_to_trace`: string or list of string of events to be recorded (see official python doc for `sys.settrace` for the list of events)
default to line
- `gpus_to_trace`: (optional list, default None) list of GPUs to trace. Default to tracing all GPUs
Return:
- `memory_trace` is a list of `UsedMemoryState` for each event (default each line of the traced script).
- `UsedMemoryState` are named tuples with the following fields:
- 'frame': a `Frame` namedtuple (see below) storing information on the current tracing frame (current file, location in current file)
- 'cpu_memory': CPU RSS memory state *before* executing the line
- 'gpu_memory': GPU used memory *before* executing the line (sum for all GPUs or for only `gpus_to_trace` if provided)
`Frame` is a namedtuple used by `UsedMemoryState` to list the current frame state.
`Frame` has the following fields:
- 'filename' (string): Name of the file currently executed
- 'module' (string): Name of the module currently executed
- 'line_number' (int): Number of the line currently executed
- 'event' (string): Event that triggered the tracing (default will be "line")
- 'line_text' (string): Text of the line in the python script
"""
try:
import psutil
except (ImportError):
logger.warning(
"Psutil not installed, we won't log CPU memory usage. "
"Install psutil (pip install psutil) to use CPU memory tracing."
)
process = None
else:
process = psutil.Process(os.getpid())
try:
from py3nvml import py3nvml
py3nvml.nvmlInit()
devices = list(range(py3nvml.nvmlDeviceGetCount())) if gpus_to_trace is None else gpus_to_trace
py3nvml.nvmlShutdown()
except ImportError:
logger.warning(
"py3nvml not installed, we won't log GPU memory usage. "
"Install py3nvml (pip install py3nvml) to use GPU memory tracing."
)
log_gpu = False
except (OSError, py3nvml.NVMLError):
logger.warning("Error while initializing comunication with GPU. " "We won't perform GPU memory tracing.")
log_gpu = False
else:
log_gpu = is_torch_available() or is_tf_available()
memory_trace = []
def traceit(frame, event, args):
""" Tracing method executed before running each line in a module or sub-module
Record memory allocated in a list with debugging information
"""
global _is_memory_tracing_enabled
if not _is_memory_tracing_enabled:
return traceit
# Filter events
if events_to_trace is not None:
if isinstance(events_to_trace, str) and event != events_to_trace:
return traceit
elif isinstance(events_to_trace, (list, tuple)) and event not in events_to_trace:
return traceit
# Filter modules
name = frame.f_globals["__name__"]
if not isinstance(name, str):
return traceit
else:
# Filter whitelist of modules to trace
if modules_to_trace is not None:
if isinstance(modules_to_trace, str) and modules_to_trace not in name:
return traceit
elif isinstance(modules_to_trace, (list, tuple)) and all(m not in name for m in modules_to_trace):
return traceit
# Filter blacklist of modules not to trace
if modules_not_to_trace is not None:
if isinstance(modules_not_to_trace, str) and modules_not_to_trace in name:
return traceit
elif isinstance(modules_not_to_trace, (list, tuple)) and any(m in name for m in modules_not_to_trace):
return traceit
# Record current tracing state (file, location in file...)
lineno = frame.f_lineno
filename = frame.f_globals["__file__"]
if filename.endswith(".pyc") or filename.endswith(".pyo"):
filename = filename[:-1]
line = linecache.getline(filename, lineno).rstrip()
traced_state = Frame(filename, name, lineno, event, line)
# Record current memory state (rss memory) and compute difference with previous memory state
cpu_mem = 0
if process is not None:
mem = process.memory_info()
cpu_mem = mem.rss
gpu_mem = 0
if log_gpu:
# Clear GPU caches
if is_torch_available():
torch_empty_cache()
if is_tf_available():
tf_context.context()._clear_caches() # See https://github.com/tensorflow/tensorflow/issues/20218#issuecomment-416771802
# Sum used memory for all GPUs
py3nvml.nvmlInit()
for i in devices:
handle = py3nvml.nvmlDeviceGetHandleByIndex(i)
meminfo = py3nvml.nvmlDeviceGetMemoryInfo(handle)
gpu_mem += meminfo.used
py3nvml.nvmlShutdown()
mem_state = UsedMemoryState(traced_state, cpu_mem, gpu_mem)
memory_trace.append(mem_state)
return traceit
sys.settrace(traceit)
global _is_memory_tracing_enabled
_is_memory_tracing_enabled = True
return memory_trace
def stop_memory_tracing(
memory_trace: Optional[MemoryTrace] = None, ignore_released_memory: bool = True
) -> Optional[MemorySummary]:
""" Stop memory tracing cleanly and return a summary of the memory trace if a trace is given.
Args:
- `memory_trace` (optional output of start_memory_tracing, default: None): memory trace to convert in summary
- `ignore_released_memory` (boolean, default: None): if True we only sum memory increase to compute total memory
Return:
- None if `memory_trace` is None
- `MemorySummary` namedtuple otherwise with the fields:
- `sequential`: a list of `MemoryState` namedtuple (see below) computed from the provided `memory_trace`
by substracting the memory after executing each line from the memory before executing said line.
- `cumulative`: a list of `MemoryState` namedtuple (see below) with cumulative increase in memory for each line
obtained by summing repeted memory increase for a line if it's executed several times.
The list is sorted from the frame with the largest memory consumption to the frame with the smallest (can be negative if memory is released)
- `total`: total memory increase during the full tracing as a `Memory` named tuple (see below).
Line with memory release (negative consumption) are ignored if `ignore_released_memory` is `True` (default).
`Memory` named tuple have fields
- `byte` (integer): number of bytes,
- `string` (string): same as human readable string (ex: "3.5MB")
`Frame` are namedtuple used to list the current frame state and have the following fields:
- 'filename' (string): Name of the file currently executed
- 'module' (string): Name of the module currently executed
- 'line_number' (int): Number of the line currently executed
- 'event' (string): Event that triggered the tracing (default will be "line")
- 'line_text' (string): Text of the line in the python script
`MemoryState` are namedtuples listing frame + CPU/GPU memory with the following fields:
- `frame` (`Frame`): the current frame (see above)
- `cpu`: CPU memory consumed at during the current frame as a `Memory` named tuple
- `gpu`: GPU memory consumed at during the current frame as a `Memory` named tuple
- `cpu_gpu`: CPU + GPU memory consumed at during the current frame as a `Memory` named tuple
"""
global _is_memory_tracing_enabled
_is_memory_tracing_enabled = False
if memory_trace is not None and len(memory_trace) > 1:
memory_diff_trace = []
cumulative_memory_dict = defaultdict(lambda: [0, 0, 0])
for (frame, cpu_mem, gpu_mem), (next_frame, next_cpu_mem, next_gpu_mem) in zip(
memory_trace[:-1], memory_trace[1:]
):
cpu_mem_inc = next_cpu_mem - cpu_mem
gpu_mem_inc = next_gpu_mem - gpu_mem
cpu_gpu_mem_inc = cpu_mem_inc + gpu_mem_inc
memory_diff_trace.append(
MemoryState(
frame=frame, cpu=Memory(cpu_mem_inc), gpu=Memory(gpu_mem_inc), cpu_gpu=Memory(cpu_gpu_mem_inc),
)
)
cumulative_memory_dict[frame][0] += cpu_mem_inc
cumulative_memory_dict[frame][1] += gpu_mem_inc
cumulative_memory_dict[frame][2] += cpu_gpu_mem_inc
cumulative_memory = sorted(
list(cumulative_memory_dict.items()), key=lambda x: x[1][2], reverse=True
) # order by the total CPU + GPU memory increase
cumulative_memory = list(
MemoryState(
frame=frame, cpu=Memory(cpu_mem_inc), gpu=Memory(gpu_mem_inc), cpu_gpu=Memory(cpu_gpu_mem_inc),
)
for frame, (cpu_mem_inc, gpu_mem_inc, cpu_gpu_mem_inc) in cumulative_memory
)
if ignore_released_memory:
total_memory = sum(max(0, step_trace.cpu_gpu.bytes) for step_trace in memory_diff_trace)
else:
total_memory = sum(step_trace.cpu_gpu.bytes for step_trace in memory_diff_trace)
total_memory = Memory(total_memory)
return MemorySummary(sequential=memory_diff_trace, cumulative=cumulative_memory, total=total_memory)
return None
def bytes_to_human_readable(memory_amount: int) -> str:
""" Utility to convert a number of bytes (int) in a human readable string (with units)
"""
for unit in ["B", "KB", "MB", "GB"]:
if memory_amount > -1024.0 and memory_amount < 1024.0:
return "{:.3f}{}".format(memory_amount, unit)
memory_amount /= 1024.0
return "{:.3f}TB".format(memory_amount)

View File

@ -59,6 +59,8 @@ class GPT2Config(PretrainedConfig):
Number of hidden layers in the Transformer encoder.
n_head (:obj:`int`, optional, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
activation_function (:obj:`str`, optional, defaults to 'gelu'):
Activation function selected in the list ["relu", "swish", "gelu", "tanh", "gelu_new"].
resid_pdrop (:obj:`float`, optional, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
embd_pdrop (:obj:`int`, optional, defaults to 0.1):
@ -125,6 +127,7 @@ class GPT2Config(PretrainedConfig):
n_embd=768,
n_layer=12,
n_head=12,
activation_function="gelu_new",
resid_pdrop=0.1,
embd_pdrop=0.1,
attn_pdrop=0.1,
@ -147,6 +150,7 @@ class GPT2Config(PretrainedConfig):
self.n_embd = n_embd
self.n_layer = n_layer
self.n_head = n_head
self.activation_function = activation_function
self.resid_pdrop = resid_pdrop
self.embd_pdrop = embd_pdrop
self.attn_pdrop = attn_pdrop

View File

@ -24,7 +24,7 @@ import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from .activations import gelu_new
from .activations import ACT2FN
from .configuration_gpt2 import GPT2Config
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer
@ -203,7 +203,7 @@ class MLP(nn.Module):
nx = config.n_embd
self.c_fc = Conv1D(n_state, nx)
self.c_proj = Conv1D(nx, n_state)
self.act = gelu_new
self.act = ACT2FN[config.activation_function]
self.dropout = nn.Dropout(config.resid_pdrop)
def forward(self, x):

View File

@ -39,6 +39,7 @@ from .file_utils import (
logger = logging.getLogger(__name__)
try:
from torch.nn import Identity
except ImportError:
@ -66,6 +67,47 @@ class ModuleUtilsMixin:
params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
return sum(p.numel() for p in params)
@staticmethod
def _hook_rss_memory_pre_forward(module, *args, **kwargs):
try:
import psutil
except (ImportError):
raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
process = psutil.Process(os.getpid())
mem = process.memory_info()
module.mem_rss_pre_forward = mem.rss
return None
@staticmethod
def _hook_rss_memory_post_forward(module, *args, **kwargs):
try:
import psutil
except (ImportError):
raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
process = psutil.Process(os.getpid())
mem = process.memory_info()
module.mem_rss_post_forward = mem.rss
mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0)
return None
def add_memory_hooks(self):
""" Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero with `model.reset_memory_hooks_state()`
"""
for module in self.modules():
module.register_forward_pre_hook(self._hook_rss_memory_pre_forward)
module.register_forward_hook(self._hook_rss_memory_post_forward)
self.reset_memory_hooks_state()
def reset_memory_hooks_state(self):
for module in self.modules():
module.mem_rss_diff = 0
module.mem_rss_post_forward = 0
module.mem_rss_pre_forward = 0
class PreTrainedModel(nn.Module, ModuleUtilsMixin):
r""" Base class for all models.