mirror of https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00

CPU/GPU memory benchmarking utilities - Remove support for python 3.5 (now only 3.6+) (#3186)

* memory benchmark rss
* have both forward pass and line-by-line mem tracing
* cleaned up tracing
* refactored and cleaning up API
* no f-strings yet...
* add GPU mem logging
* fix GPU memory monitoring
* style and quality
* clean up and doc
* update with comments
* Switching to python 3.6+
* fix quality

This commit is contained in:
parent bd3feddf67
commit 2187c49f5c
@@ -3,7 +3,7 @@ jobs:
    run_tests_torch_and_tf:
        working_directory: ~/transformers
        docker:
-           - image: circleci/python:3.5
+           - image: circleci/python:3.6
        environment:
            OMP_NUM_THREADS: 1
        resource_class: xlarge
@@ -46,7 +46,7 @@ jobs:
    run_tests_custom_tokenizers:
        working_directory: ~/transformers
        docker:
-           - image: circleci/python:3.5
+           - image: circleci/python:3.6
        environment:
            RUN_CUSTOM_TOKENIZERS: yes
        steps:
@@ -56,7 +56,7 @@ jobs:
    run_examples_torch:
        working_directory: ~/transformers
        docker:
-           - image: circleci/python:3.5
+           - image: circleci/python:3.6
        environment:
            OMP_NUM_THREADS: 1
        resource_class: xlarge
@@ -69,7 +69,7 @@ jobs:
    deploy_doc:
        working_directory: ~/transformers
        docker:
-           - image: circleci/python:3.5
+           - image: circleci/python:3.6
        steps:
            - add_ssh_keys:
                fingerprints:
@@ -94,7 +94,7 @@ jobs:
    check_repository_consistency:
        working_directory: ~/transformers
        docker:
-           - image: circleci/python:3.5
+           - image: circleci/python:3.6
        resource_class: small
        parallelism: 1
        steps:
@@ -66,7 +66,7 @@ Choose the right framework for every part of a model's lifetime

## Installation

-This repo is tested on Python 3.5+, PyTorch 1.0.0+ and TensorFlow 2.0.0-rc1
+This repo is tested on Python 3.6+, PyTorch 1.0.0+ and TensorFlow 2.0.0-rc1

You should install 🤗 Transformers in a [virtual environment](https://docs.python.org/3/library/venv.html). If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
@@ -1,6 +1,6 @@
# Installation

-Transformers is tested on Python 3.5+ and PyTorch 1.1.0
+Transformers is tested on Python 3.6+ and PyTorch 1.1.0

## With pip
@@ -24,7 +24,15 @@ import timeit
from time import time
from typing import List

-from transformers import AutoConfig, AutoTokenizer, is_tf_available, is_torch_available
+from transformers import (
+    AutoConfig,
+    AutoTokenizer,
+    MemorySummary,
+    is_tf_available,
+    is_torch_available,
+    start_memory_tracing,
+    stop_memory_tracing,
+)


if is_tf_available():
@@ -250,15 +258,21 @@ as they entered."""


def create_setup_and_compute(
    model_names: List[str],
+    batch_sizes: List[int],
+    slice_sizes: List[int],
    gpu: bool = True,
    tensorflow: bool = False,
    average_over: int = 3,
+    no_speed: bool = False,
+    no_memory: bool = False,
+    verbose: bool = False,
    torchscript: bool = False,
    xla: bool = False,
    amp: bool = False,
    fp16: bool = False,
    save_to_csv: bool = False,
    csv_filename: str = f"results_{round(time())}.csv",
+    csv_memory_filename: str = f"memory_{round(time())}.csv",
):
    if xla:
        tf.config.optimizer.set_jit(True)
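
As a point of reference, a minimal, hypothetical sketch of driving this extended entry point directly from Python rather than through the CLI. The import path and every value below are illustrative assumptions, not part of the commit; only the argument names follow the new signature above.

# Hypothetical usage of the extended benchmark entry point (illustrative only, not part of this diff).
from benchmarks import create_setup_and_compute  # assumes examples/benchmarks.py is importable

create_setup_and_compute(
    model_names=["gpt2"],   # illustrative pretrained identifier
    batch_sizes=[1, 2],     # new argument, previously hard-coded inside _compute_*
    slice_sizes=[8, 64],    # new argument, previously hard-coded inside _compute_*
    tensorflow=False,       # benchmark the PyTorch implementation
    no_speed=False,         # keep timing measurements
    no_memory=False,        # keep line-by-line memory tracing
    verbose=True,           # print per-line memory statistics for each run
)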
@@ -267,11 +281,25 @@ def create_setup_and_compute(

    if tensorflow:
        dictionary = {model_name: {} for model_name in model_names}
-        results = _compute_tensorflow(model_names, dictionary, average_over, amp)
+        results = _compute_tensorflow(
+            model_names, batch_sizes, slice_sizes, dictionary, average_over, amp, no_speed, no_memory, verbose
+        )
    else:
        device = "cuda" if (gpu and torch.cuda.is_available()) else "cpu"
        dictionary = {model_name: {} for model_name in model_names}
-        results = _compute_pytorch(model_names, dictionary, average_over, device, torchscript, fp16)
+        results = _compute_pytorch(
+            model_names,
+            batch_sizes,
+            slice_sizes,
+            dictionary,
+            average_over,
+            device,
+            torchscript,
+            fp16,
+            no_speed,
+            no_memory,
+            verbose,
+        )

    print("=========== RESULTS ===========")
    for model_name in model_names:
@@ -280,13 +308,19 @@ def create_setup_and_compute(
            print("\t\t" + f"===== BATCH SIZE: {batch_size} =====")
            for slice_size in results[model_name]["ss"]:
                result = results[model_name]["results"][batch_size][slice_size]
+                memory = results[model_name]["memory"][batch_size][slice_size]
                if isinstance(result, str):
-                    print(f"\t\t{model_name}/{batch_size}/{slice_size}: " f"{result}")
+                    print(f"\t\t{model_name}/{batch_size}/{slice_size}: " f"{result} " f"{memory}")
                else:
-                    print(f"\t\t{model_name}/{batch_size}/{slice_size}: " f"{(round(1000 * result) / 1000)}" f"s")
+                    print(
+                        f"\t\t{model_name}/{batch_size}/{slice_size}: "
+                        f"{(round(1000 * result) / 1000)}"
+                        f"s "
+                        f"{memory}"
+                    )

    if save_to_csv:
-        with open(csv_filename, mode="w") as csv_file:
+        with open(csv_filename, mode="w") as csv_file, open(csv_memory_filename, mode="w") as csv_memory_file:
            fieldnames = [
                "model",
                "1x8",
@@ -317,6 +351,8 @@ def create_setup_and_compute(

            writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
            writer.writeheader()
+            memory_writer = csv.DictWriter(csv_memory_file, fieldnames=fieldnames)
+            memory_writer.writeheader()

            for model_name in model_names:
                model_results = {
@@ -326,8 +362,52 @@ def create_setup_and_compute(
                }
                writer.writerow({"model": model_name, **model_results})

+                model_memory_results = {
+                    f"{bs}x{ss}": results[model_name]["memory"][bs][ss]
+                    for bs in results[model_name]["memory"]
+                    for ss in results[model_name]["memory"][bs]
+                }
+                memory_writer.writerow({"model": model_name, **model_memory_results})


-def _compute_pytorch(model_names, dictionary, average_over, device, torchscript, fp16):
+def print_summary_statistics(summary: MemorySummary):
+    print(
+        "\nLines by line memory consumption:\n"
+        + "\n".join(
+            f"{state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
+            for state in summary.sequential
+        )
+    )
+    print(
+        "\nLines with top memory consumption:\n"
+        + "\n".join(
+            f"=> {state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
+            for state in summary.cumulative[:6]
+        )
+    )
+    print(
+        "\nLines with lowest memory consumption:\n"
+        + "\n".join(
+            f"=> {state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
+            for state in summary.cumulative[-6:]
+        )
+    )
+    print(f"\nTotal memory increase: {summary.total}")
+
+
+def _compute_pytorch(
+    model_names,
+    batch_sizes,
+    slice_sizes,
+    dictionary,
+    average_over,
+    device,
+    torchscript,
+    fp16,
+    no_speed,
+    no_memory,
+    verbose,
+):
    for c, model_name in enumerate(model_names):
        print(f"{c + 1} / {len(model_names)}")
        config = AutoConfig.from_pretrained(model_name, torchscript=torchscript)
@@ -337,17 +417,17 @@ def _compute_pytorch(model_names, dictionary, average_over, device, torchscript,
        tokenized_sequence = tokenizer.encode(input_text, add_special_tokens=False)

        max_input_size = tokenizer.max_model_input_sizes[model_name]
-        batch_sizes = [1, 2, 4, 8]
-        slice_sizes = [8, 64, 128, 256, 512, 1024]

-        dictionary[model_name] = {"bs": batch_sizes, "ss": slice_sizes, "results": {}}
+        dictionary[model_name] = {"bs": batch_sizes, "ss": slice_sizes, "results": {}, "memory": {}}
        dictionary[model_name]["results"] = {i: {} for i in batch_sizes}
+        dictionary[model_name]["memory"] = {i: {} for i in batch_sizes}

        for batch_size in batch_sizes:
            if fp16:
                model.half()
            model.to(device)
            model.eval()

            for slice_size in slice_sizes:
                if max_input_size is not None and slice_size > max_input_size:
                    dictionary[model_name]["results"][batch_size][slice_size] = "N/A"
@@ -362,18 +442,40 @@ def _compute_pytorch(model_names, dictionary, average_over, device, torchscript,
                            inference = model
                        inference(sequence)

-                        print("Going through model with sequence of shape", sequence.shape)
-                        runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3)
-                        average_time = sum(runtimes) / float(len(runtimes)) / 3.0
-                        dictionary[model_name]["results"][batch_size][slice_size] = average_time
+                        if not no_memory:
+                            # model.add_memory_hooks() # Forward method tracing (only for PyTorch models)
+
+                            # Line by line memory tracing (all code in the module `transformers`) works for all models/arbitrary code
+                            trace = start_memory_tracing("transformers")
+                            inference(sequence)
+                            summary = stop_memory_tracing(trace)
+
+                            if verbose:
+                                print_summary_statistics(summary)
+
+                            dictionary[model_name]["memory"][batch_size][slice_size] = str(summary.total)
+                        else:
+                            dictionary[model_name]["memory"][batch_size][slice_size] = "N/A"
+
+                        if not no_speed:
+                            print("Going through model with sequence of shape", sequence.shape)
+                            runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3)
+                            average_time = sum(runtimes) / float(len(runtimes)) / 3.0
+                            dictionary[model_name]["results"][batch_size][slice_size] = average_time
+                        else:
+                            dictionary[model_name]["results"][batch_size][slice_size] = "N/A"

                    except RuntimeError as e:
                        print("Doesn't fit on GPU.", e)
                        torch.cuda.empty_cache()
                        dictionary[model_name]["results"][batch_size][slice_size] = "N/A"
+                        dictionary[model_name]["memory"][batch_size][slice_size] = "N/A"
    return dictionary


-def _compute_tensorflow(model_names, dictionary, average_over, amp):
+def _compute_tensorflow(
+    model_names, batch_sizes, slice_sizes, dictionary, average_over, amp, no_speed, no_memory, verbose
+):
    for c, model_name in enumerate(model_names):
        print(f"{c + 1} / {len(model_names)}")
        config = AutoConfig.from_pretrained(model_name)
@@ -383,11 +485,10 @@ def _compute_tensorflow(model_names, dictionary, average_over, amp):
        tokenized_sequence = tokenizer.encode(input_text, add_special_tokens=False)

        max_input_size = tokenizer.max_model_input_sizes[model_name]
-        batch_sizes = [1, 2, 4, 8]
-        slice_sizes = [8, 64, 128, 256, 512, 1024]

-        dictionary[model_name] = {"bs": batch_sizes, "ss": slice_sizes, "results": {}}
+        dictionary[model_name] = {"bs": batch_sizes, "ss": slice_sizes, "results": {}, "memory": {}}
        dictionary[model_name]["results"] = {i: {} for i in batch_sizes}
+        dictionary[model_name]["memory"] = {i: {} for i in batch_sizes}

        print("Using model", model)
@@ -409,13 +510,31 @@ def _compute_tensorflow(model_names, dictionary, average_over, amp):
                        # To make sure that the model is traced + that the tensors are on the appropriate device
                        inference(sequence)

-                        runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3)
-                        average_time = sum(runtimes) / float(len(runtimes)) / 3.0
-                        dictionary[model_name]["results"][batch_size][slice_size] = average_time
+                        if not no_memory:
+                            # Line by line memory tracing (all code in the module `transformers`) works for all models/arbitrary code
+                            trace = start_memory_tracing("transformers")
+                            inference(sequence)
+                            summary = stop_memory_tracing(trace)
+
+                            if verbose:
+                                print_summary_statistics(summary)
+
+                            dictionary[model_name]["memory"][batch_size][slice_size] = str(summary.total)
+                        else:
+                            dictionary[model_name]["memory"][batch_size][slice_size] = "N/A"
+
+                        if not no_speed:
+                            runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3)
+                            average_time = sum(runtimes) / float(len(runtimes)) / 3.0
+                            dictionary[model_name]["results"][batch_size][slice_size] = average_time
+                        else:
+                            dictionary[model_name]["results"][batch_size][slice_size] = "N/A"

                    except tf.errors.ResourceExhaustedError as e:
                        print("Doesn't fit on GPU.", e)
                        torch.cuda.empty_cache()
                        dictionary[model_name]["results"][batch_size][slice_size] = "N/A"
+                        dictionary[model_name]["memory"][batch_size][slice_size] = "N/A"
    return dictionary
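
For orientation, the per-model entry that both `_compute_pytorch` and `_compute_tensorflow` now populate has the shape sketched below; all values here are placeholders for illustration, not measured numbers.

# Shape of dictionary[model_name] after a run (placeholder values only):
dictionary_entry = {
    "bs": [1, 2],   # batch sizes benchmarked
    "ss": [8, 64],  # slice (sequence) sizes benchmarked
    "results": {1: {8: 0.012, 64: 0.034}, 2: {8: "N/A", 64: "N/A"}},         # seconds, or "N/A" when skipped or out of memory
    "memory": {1: {8: "1.500MB", 64: "3.200MB"}, 2: {8: "N/A", 64: "N/A"}},  # str(summary.total), or "N/A"
}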
@@ -433,6 +552,9 @@ def main():
        "of all available model "
        "architectures.",
    )
+    parser.add_argument("--verbose", required=False, action="store_true", help="Verbose memory tracing")
+    parser.add_argument("--no_speed", required=False, action="store_true", help="Don't perform speed measurments")
+    parser.add_argument("--no_memory", required=False, action="store_true", help="Don't perform memory measurments")
    parser.add_argument(
        "--torch", required=False, action="store_true", help="Benchmark the Pytorch version of the " "models"
    )
@@ -477,6 +599,8 @@ def main():
    parser.add_argument(
        "--average_over", required=False, default=30, type=int, help="Times an experiment will be run."
    )
+    parser.add_argument("--batch_sizes", nargs="+", type=int, default=[1, 2, 4, 8])
+    parser.add_argument("--slice_sizes", nargs="+", type=int, default=[8, 64, 128, 256, 512, 1024])

    args = parser.parse_args()
    if args.models == "all":
@@ -501,6 +625,8 @@ def main():
        if is_torch_available():
            create_setup_and_compute(
                model_names=args.models,
+                batch_sizes=args.batch_sizes,
+                slice_sizes=args.slice_sizes,
                tensorflow=False,
                gpu=args.torch_cuda,
                torchscript=args.torchscript,
@@ -508,6 +634,9 @@ def main():
                save_to_csv=args.save_to_csv,
                csv_filename=args.csv_filename,
                average_over=args.average_over,
+                no_speed=args.no_speed,
+                no_memory=args.no_memory,
+                verbose=args.verbose,
            )
        else:
            raise ImportError("Trying to run a PyTorch benchmark but PyTorch was not found in the environment.")
@@ -516,12 +645,17 @@ def main():
        if is_tf_available():
            create_setup_and_compute(
                model_names=args.models,
+                batch_sizes=args.batch_sizes,
+                slice_sizes=args.slice_sizes,
                tensorflow=True,
                xla=args.xla,
                amp=args.amp,
                save_to_csv=args.save_to_csv,
                csv_filename=args.csv_filename,
                average_over=args.average_over,
+                no_speed=args.no_speed,
+                no_memory=args.no_memory,
+                verbose=args.verbose,
            )
        else:
            raise ImportError("Trying to run a TensorFlow benchmark but TensorFlow was not found in the environment.")
@@ -2,3 +2,4 @@ tensorboardX
tensorboard
scikit-learn
seqeval
+psutil
setup.py

@@ -110,7 +110,7 @@ setup(
    ],
    extras_require=extras,
    scripts=["transformers-cli"],
-    python_requires=">=3.5.0",
+    python_requires=">=3.6.0",
    classifiers=[
        "Development Status :: 5 - Production/Stable",
        "Intended Audience :: Developers",
@@ -119,7 +119,6 @@ setup(
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: OS Independent",
        "Programming Language :: Python :: 3",
-        "Programming Language :: Python :: 3.5",
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
@@ -19,6 +19,18 @@ else:

import logging

+# Benchmarking
+from .benchmark_utils import (
+    Frame,
+    Memory,
+    MemoryState,
+    MemorySummary,
+    MemoryTrace,
+    UsedMemoryState,
+    bytes_to_human_readable,
+    start_memory_tracing,
+    stop_memory_tracing,
+)
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig
from .configuration_bart import BartConfig
src/transformers/benchmark_utils.py (new file, 341 lines)

@@ -0,0 +1,341 @@
"""
|
||||
Utilities for working with the local dataset cache.
|
||||
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
|
||||
Copyright by the AllenNLP authors.
|
||||
"""
|
||||
|
||||
import linecache
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from typing import Iterable, List, NamedTuple, Optional, Union
|
||||
|
||||
from .file_utils import is_tf_available, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from torch.cuda import empty_cache as torch_empty_cache
|
||||
if is_tf_available():
|
||||
from tensorflow.python.eager import context as tf_context
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
_is_memory_tracing_enabled = False
|
||||
|
||||
|
||||
def is_memory_tracing_enabled():
|
||||
global _is_memory_tracing_enabled
|
||||
return _is_memory_tracing_enabled
|
||||
|
||||
|
||||
class Frame(NamedTuple):
    """ `Frame` is a NamedTuple used to gather the current frame state.
        `Frame` has the following fields:
        - 'filename' (string): Name of the file currently executed
        - 'module' (string): Name of the module currently executed
        - 'line_number' (int): Number of the line currently executed
        - 'event' (string): Event that triggered the tracing (default will be "line")
        - 'line_text' (string): Text of the line in the python script
    """

    filename: str
    module: str
    line_number: int
    event: str
    line_text: str


class UsedMemoryState(NamedTuple):
    """ `UsedMemoryState` are named tuples with the following fields:
        - 'frame': a `Frame` namedtuple (see below) storing information on the current tracing frame (current file, location in current file)
        - 'cpu_memory': CPU RSS memory state *before* executing the line
        - 'gpu_memory': GPU used memory *before* executing the line (sum for all GPUs or for only `gpus_to_trace` if provided)
    """

    frame: Frame
    cpu_memory: int
    gpu_memory: int


class Memory(NamedTuple):
    """ `Memory` NamedTuple have a single field `bytes` and
        you can get a human readable string of the number of bytes by calling `__repr__`
        - `byte` (integer): number of bytes,
    """

    bytes: int

    def __repr__(self) -> str:
        return bytes_to_human_readable(self.bytes)


class MemoryState(NamedTuple):
    """ `MemoryState` are namedtuples listing frame + CPU/GPU memory with the following fields:
        - `frame` (`Frame`): the current frame (see above)
        - `cpu`: CPU memory consumed at during the current frame as a `Memory` named tuple
        - `gpu`: GPU memory consumed at during the current frame as a `Memory` named tuple
        - `cpu_gpu`: CPU + GPU memory consumed at during the current frame as a `Memory` named tuple
    """

    frame: Frame
    cpu: Memory
    gpu: Memory
    cpu_gpu: Memory


class MemorySummary(NamedTuple):
    """ `MemorySummary` namedtuple otherwise with the fields:
        - `sequential`: a list of `MemoryState` namedtuple (see below) computed from the provided `memory_trace`
            by substracting the memory after executing each line from the memory before executing said line.
        - `cumulative`: a list of `MemoryState` namedtuple (see below) with cumulative increase in memory for each line
            obtained by summing repeted memory increase for a line if it's executed several times.
            The list is sorted from the frame with the largest memory consumption to the frame with the smallest (can be negative if memory is released)
        - `total`: total memory increase during the full tracing as a `Memory` named tuple (see below).
            Line with memory release (negative consumption) are ignored if `ignore_released_memory` is `True` (default).
    """

    sequential: List[MemoryState]
    cumulative: List[MemoryState]
    total: Memory


MemoryTrace = List[UsedMemoryState]
def start_memory_tracing(
    modules_to_trace: Optional[Union[str, Iterable[str]]] = None,
    modules_not_to_trace: Optional[Union[str, Iterable[str]]] = None,
    events_to_trace: str = "line",
    gpus_to_trace: Optional[List[int]] = None,
) -> MemoryTrace:
    """ Setup line-by-line tracing to record rss mem (RAM) at each line of a module or sub-module.
        See `../../examples/benchmarks.py for a usage example.
        Current memory consumption is returned using psutil and in particular is the RSS memory
        "Resident Set Size” (the non-swapped physical memory the process is using).
        See https://psutil.readthedocs.io/en/latest/#psutil.Process.memory_info

        Args:
        - `modules_to_trace`: (None, string, list/tuple of string)
            if None, all events are recorded
            if string or list of strings: only events from the listed module/sub-module will be recorded (e.g. 'fairseq' or 'transformers.modeling_gpt2')
        - `modules_not_to_trace`: (None, string, list/tuple of string)
            if None, no module is avoided
            if string or list of strings: events from the listed module/sub-module will not be recorded (e.g. 'torch')
        - `events_to_trace`: string or list of string of events to be recorded (see official python doc for `sys.settrace` for the list of events)
            default to line
        - `gpus_to_trace`: (optional list, default None) list of GPUs to trace. Default to tracing all GPUs

        Return:
        - `memory_trace` is a list of `UsedMemoryState` for each event (default each line of the traced script).
            - `UsedMemoryState` are named tuples with the following fields:
                - 'frame': a `Frame` namedtuple (see below) storing information on the current tracing frame (current file, location in current file)
                - 'cpu_memory': CPU RSS memory state *before* executing the line
                - 'gpu_memory': GPU used memory *before* executing the line (sum for all GPUs or for only `gpus_to_trace` if provided)

        `Frame` is a namedtuple used by `UsedMemoryState` to list the current frame state.
            `Frame` has the following fields:
            - 'filename' (string): Name of the file currently executed
            - 'module' (string): Name of the module currently executed
            - 'line_number' (int): Number of the line currently executed
            - 'event' (string): Event that triggered the tracing (default will be "line")
            - 'line_text' (string): Text of the line in the python script

    """
    try:
        import psutil
    except (ImportError):
        logger.warning(
            "Psutil not installed, we won't log CPU memory usage. "
            "Install psutil (pip install psutil) to use CPU memory tracing."
        )
        process = None
    else:
        process = psutil.Process(os.getpid())

    try:
        from py3nvml import py3nvml

        py3nvml.nvmlInit()
        devices = list(range(py3nvml.nvmlDeviceGetCount())) if gpus_to_trace is None else gpus_to_trace
        py3nvml.nvmlShutdown()
    except ImportError:
        logger.warning(
            "py3nvml not installed, we won't log GPU memory usage. "
            "Install py3nvml (pip install py3nvml) to use GPU memory tracing."
        )
        log_gpu = False
    except (OSError, py3nvml.NVMLError):
        logger.warning("Error while initializing comunication with GPU. " "We won't perform GPU memory tracing.")
        log_gpu = False
    else:
        log_gpu = is_torch_available() or is_tf_available()

    memory_trace = []

    def traceit(frame, event, args):
        """ Tracing method executed before running each line in a module or sub-module
            Record memory allocated in a list with debugging information
        """
        global _is_memory_tracing_enabled

        if not _is_memory_tracing_enabled:
            return traceit

        # Filter events
        if events_to_trace is not None:
            if isinstance(events_to_trace, str) and event != events_to_trace:
                return traceit
            elif isinstance(events_to_trace, (list, tuple)) and event not in events_to_trace:
                return traceit

        # Filter modules
        name = frame.f_globals["__name__"]
        if not isinstance(name, str):
            return traceit
        else:
            # Filter whitelist of modules to trace
            if modules_to_trace is not None:
                if isinstance(modules_to_trace, str) and modules_to_trace not in name:
                    return traceit
                elif isinstance(modules_to_trace, (list, tuple)) and all(m not in name for m in modules_to_trace):
                    return traceit

            # Filter blacklist of modules not to trace
            if modules_not_to_trace is not None:
                if isinstance(modules_not_to_trace, str) and modules_not_to_trace in name:
                    return traceit
                elif isinstance(modules_not_to_trace, (list, tuple)) and any(m in name for m in modules_not_to_trace):
                    return traceit

        # Record current tracing state (file, location in file...)
        lineno = frame.f_lineno
        filename = frame.f_globals["__file__"]
        if filename.endswith(".pyc") or filename.endswith(".pyo"):
            filename = filename[:-1]
        line = linecache.getline(filename, lineno).rstrip()
        traced_state = Frame(filename, name, lineno, event, line)

        # Record current memory state (rss memory) and compute difference with previous memory state
        cpu_mem = 0
        if process is not None:
            mem = process.memory_info()
            cpu_mem = mem.rss

        gpu_mem = 0
        if log_gpu:
            # Clear GPU caches
            if is_torch_available():
                torch_empty_cache()
            if is_tf_available():
                tf_context.context()._clear_caches()  # See https://github.com/tensorflow/tensorflow/issues/20218#issuecomment-416771802

            # Sum used memory for all GPUs
            py3nvml.nvmlInit()
            for i in devices:
                handle = py3nvml.nvmlDeviceGetHandleByIndex(i)
                meminfo = py3nvml.nvmlDeviceGetMemoryInfo(handle)
                gpu_mem += meminfo.used
            py3nvml.nvmlShutdown()

        mem_state = UsedMemoryState(traced_state, cpu_mem, gpu_mem)
        memory_trace.append(mem_state)

        return traceit

    sys.settrace(traceit)

    global _is_memory_tracing_enabled
    _is_memory_tracing_enabled = True

    return memory_trace
def stop_memory_tracing(
    memory_trace: Optional[MemoryTrace] = None, ignore_released_memory: bool = True
) -> Optional[MemorySummary]:
    """ Stop memory tracing cleanly and return a summary of the memory trace if a trace is given.

        Args:
        - `memory_trace` (optional output of start_memory_tracing, default: None): memory trace to convert in summary
        - `ignore_released_memory` (boolean, default: None): if True we only sum memory increase to compute total memory

        Return:
        - None if `memory_trace` is None
        - `MemorySummary` namedtuple otherwise with the fields:
            - `sequential`: a list of `MemoryState` namedtuple (see below) computed from the provided `memory_trace`
                by substracting the memory after executing each line from the memory before executing said line.
            - `cumulative`: a list of `MemoryState` namedtuple (see below) with cumulative increase in memory for each line
                obtained by summing repeted memory increase for a line if it's executed several times.
                The list is sorted from the frame with the largest memory consumption to the frame with the smallest (can be negative if memory is released)
            - `total`: total memory increase during the full tracing as a `Memory` named tuple (see below).
                Line with memory release (negative consumption) are ignored if `ignore_released_memory` is `True` (default).

        `Memory` named tuple have fields
        - `byte` (integer): number of bytes,
        - `string` (string): same as human readable string (ex: "3.5MB")

        `Frame` are namedtuple used to list the current frame state and have the following fields:
        - 'filename' (string): Name of the file currently executed
        - 'module' (string): Name of the module currently executed
        - 'line_number' (int): Number of the line currently executed
        - 'event' (string): Event that triggered the tracing (default will be "line")
        - 'line_text' (string): Text of the line in the python script

        `MemoryState` are namedtuples listing frame + CPU/GPU memory with the following fields:
        - `frame` (`Frame`): the current frame (see above)
        - `cpu`: CPU memory consumed at during the current frame as a `Memory` named tuple
        - `gpu`: GPU memory consumed at during the current frame as a `Memory` named tuple
        - `cpu_gpu`: CPU + GPU memory consumed at during the current frame as a `Memory` named tuple
    """
    global _is_memory_tracing_enabled
    _is_memory_tracing_enabled = False

    if memory_trace is not None and len(memory_trace) > 1:
        memory_diff_trace = []
        cumulative_memory_dict = defaultdict(lambda: [0, 0, 0])
        for (frame, cpu_mem, gpu_mem), (next_frame, next_cpu_mem, next_gpu_mem) in zip(
            memory_trace[:-1], memory_trace[1:]
        ):
            cpu_mem_inc = next_cpu_mem - cpu_mem
            gpu_mem_inc = next_gpu_mem - gpu_mem
            cpu_gpu_mem_inc = cpu_mem_inc + gpu_mem_inc
            memory_diff_trace.append(
                MemoryState(
                    frame=frame, cpu=Memory(cpu_mem_inc), gpu=Memory(gpu_mem_inc), cpu_gpu=Memory(cpu_gpu_mem_inc),
                )
            )
            cumulative_memory_dict[frame][0] += cpu_mem_inc
            cumulative_memory_dict[frame][1] += gpu_mem_inc
            cumulative_memory_dict[frame][2] += cpu_gpu_mem_inc

        cumulative_memory = sorted(
            list(cumulative_memory_dict.items()), key=lambda x: x[1][2], reverse=True
        )  # order by the total CPU + GPU memory increase
        cumulative_memory = list(
            MemoryState(
                frame=frame, cpu=Memory(cpu_mem_inc), gpu=Memory(gpu_mem_inc), cpu_gpu=Memory(cpu_gpu_mem_inc),
            )
            for frame, (cpu_mem_inc, gpu_mem_inc, cpu_gpu_mem_inc) in cumulative_memory
        )

        if ignore_released_memory:
            total_memory = sum(max(0, step_trace.cpu_gpu.bytes) for step_trace in memory_diff_trace)
        else:
            total_memory = sum(step_trace.cpu_gpu.bytes for step_trace in memory_diff_trace)
        total_memory = Memory(total_memory)
        return MemorySummary(sequential=memory_diff_trace, cumulative=cumulative_memory, total=total_memory)

    return None
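
To make the API above concrete, a short hedged sketch of the intended call pattern; the profiled code in the middle is arbitrary user code, and passing "transformers" as the traced package is simply the default used by examples/benchmarks.py.

# Sketch of the tracing API (illustrative; any code executed inside the traced package is recorded).
from transformers import start_memory_tracing, stop_memory_tracing

trace = start_memory_tracing("transformers")  # record every line executed inside the `transformers` package
# ... run the workload to profile here, e.g. a model forward pass ...
summary = stop_memory_tracing(trace)          # MemorySummary, or None if no trace is passed

print(summary.total)                          # total memory increase, rendered human readable via Memory.__repr__
for state in summary.cumulative[:3]:          # the three most memory-hungry lines
    print(state.frame.filename, state.frame.line_number, state.cpu_gpu)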
def bytes_to_human_readable(memory_amount: int) -> str:
    """ Utility to convert a number of bytes (int) in a human readable string (with units)
    """
    for unit in ["B", "KB", "MB", "GB"]:
        if memory_amount > -1024.0 and memory_amount < 1024.0:
            return "{:.3f}{}".format(memory_amount, unit)
        memory_amount /= 1024.0
    return "{:.3f}TB".format(memory_amount)
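
The helper divides by 1024 at each unit step, so for example (values follow directly from the code above):

# Example outputs of bytes_to_human_readable:
bytes_to_human_readable(512)        # "512.000B"
bytes_to_human_readable(3_500_000)  # "3.338MB"  (3_500_000 / 1024 / 1024 is roughly 3.338)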
@@ -59,6 +59,8 @@ class GPT2Config(PretrainedConfig):
            Number of hidden layers in the Transformer encoder.
        n_head (:obj:`int`, optional, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
+        activation_function (:obj:`str`, optional, defaults to 'gelu'):
+            Activation function selected in the list ["relu", "swish", "gelu", "tanh", "gelu_new"].
        resid_pdrop (:obj:`float`, optional, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        embd_pdrop (:obj:`int`, optional, defaults to 0.1):
@@ -125,6 +127,7 @@ class GPT2Config(PretrainedConfig):
        n_embd=768,
        n_layer=12,
        n_head=12,
+        activation_function="gelu_new",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
@@ -147,6 +150,7 @@ class GPT2Config(PretrainedConfig):
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
+        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
@@ -24,7 +24,7 @@ import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

-from .activations import gelu_new
+from .activations import ACT2FN
from .configuration_gpt2 import GPT2Config
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer
@@ -203,7 +203,7 @@ class MLP(nn.Module):
        nx = config.n_embd
        self.c_fc = Conv1D(n_state, nx)
        self.c_proj = Conv1D(nx, n_state)
-        self.act = gelu_new
+        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
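
Because the MLP activation is now read from the configuration rather than hard-wired to gelu_new, it can be swapped at config time. A small hedged sketch (the chosen value is illustrative; valid keys are the ACT2FN names listed in the GPT2Config docstring above):

# Hypothetical example: select the GPT-2 MLP activation through the config.
from transformers import GPT2Config

config = GPT2Config(activation_function="relu")  # e.g. "relu", "swish", "gelu", "tanh", "gelu_new"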
@@ -39,6 +39,7 @@ from .file_utils import (

logger = logging.getLogger(__name__)

+
try:
    from torch.nn import Identity
except ImportError:
@@ -66,6 +67,47 @@ class ModuleUtilsMixin:
        params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
        return sum(p.numel() for p in params)

+    @staticmethod
+    def _hook_rss_memory_pre_forward(module, *args, **kwargs):
+        try:
+            import psutil
+        except (ImportError):
+            raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
+
+        process = psutil.Process(os.getpid())
+        mem = process.memory_info()
+        module.mem_rss_pre_forward = mem.rss
+        return None
+
+    @staticmethod
+    def _hook_rss_memory_post_forward(module, *args, **kwargs):
+        try:
+            import psutil
+        except (ImportError):
+            raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
+
+        process = psutil.Process(os.getpid())
+        mem = process.memory_info()
+        module.mem_rss_post_forward = mem.rss
+        mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
+        module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0)
+        return None
+
+    def add_memory_hooks(self):
+        """ Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
+            Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero with `model.reset_memory_hooks_state()`
+        """
+        for module in self.modules():
+            module.register_forward_pre_hook(self._hook_rss_memory_pre_forward)
+            module.register_forward_hook(self._hook_rss_memory_post_forward)
+        self.reset_memory_hooks_state()
+
+    def reset_memory_hooks_state(self):
+        for module in self.modules():
+            module.mem_rss_diff = 0
+            module.mem_rss_post_forward = 0
+            module.mem_rss_pre_forward = 0
+

class PreTrainedModel(nn.Module, ModuleUtilsMixin):
    r""" Base class for all models.
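
To round off the hooks added above, a hedged usage sketch; the model class, checkpoint, and token ids are illustrative placeholders, and psutil must be installed for the hooks to record anything.

# Illustrative use of the new forward-pass memory hooks (placeholders for model and inputs).
import torch
from transformers import GPT2LMHeadModel, bytes_to_human_readable

model = GPT2LMHeadModel.from_pretrained("gpt2")
model.add_memory_hooks()                      # register pre/post forward hooks on every sub-module

input_ids = torch.tensor([[464, 3290, 318]])  # arbitrary token ids
model(input_ids)                              # forward pass; hooks accumulate per-module RSS deltas

for name, module in model.named_modules():
    if getattr(module, "mem_rss_diff", 0):
        print(name, bytes_to_human_readable(module.mem_rss_diff))

model.reset_memory_hooks_state()              # zero the counters before the next measurement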