From 2187c49f5cde57306c3fd1eb67dbc68fab9c6403 Mon Sep 17 00:00:00 2001 From: Thomas Wolf Date: Tue, 17 Mar 2020 15:17:11 +0100 Subject: [PATCH] CPU/GPU memory benchmarking utilities - Remove support for python 3.5 (now only 3.6+) (#3186) * memory benchmark rss * have both forward pass and line-by-line mem tracing * cleaned up tracing * refactored and cleaning up API * no f-strings yet... * add GPU mem logging * fix GPU memory monitoring * style and quality * clean up and doc * update with comments * Switching to python 3.6+ * fix quality --- .circleci/config.yml | 10 +- README.md | 2 +- docs/source/installation.md | 2 +- examples/benchmarks.py | 176 +++++++++++-- examples/requirements.txt | 1 + setup.py | 3 +- src/transformers/__init__.py | 12 + src/transformers/benchmark_utils.py | 341 +++++++++++++++++++++++++ src/transformers/configuration_gpt2.py | 4 + src/transformers/modeling_gpt2.py | 4 +- src/transformers/modeling_utils.py | 42 +++ 11 files changed, 565 insertions(+), 32 deletions(-) create mode 100644 src/transformers/benchmark_utils.py diff --git a/.circleci/config.yml b/.circleci/config.yml index ff7c021b6fc..a9a31ba131d 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -3,7 +3,7 @@ jobs: run_tests_torch_and_tf: working_directory: ~/transformers docker: - - image: circleci/python:3.5 + - image: circleci/python:3.6 environment: OMP_NUM_THREADS: 1 resource_class: xlarge @@ -46,7 +46,7 @@ jobs: run_tests_custom_tokenizers: working_directory: ~/transformers docker: - - image: circleci/python:3.5 + - image: circleci/python:3.6 environment: RUN_CUSTOM_TOKENIZERS: yes steps: @@ -56,7 +56,7 @@ jobs: run_examples_torch: working_directory: ~/transformers docker: - - image: circleci/python:3.5 + - image: circleci/python:3.6 environment: OMP_NUM_THREADS: 1 resource_class: xlarge @@ -69,7 +69,7 @@ jobs: deploy_doc: working_directory: ~/transformers docker: - - image: circleci/python:3.5 + - image: circleci/python:3.6 steps: - add_ssh_keys: fingerprints: @@ -94,7 +94,7 @@ jobs: check_repository_consistency: working_directory: ~/transformers docker: - - image: circleci/python:3.5 + - image: circleci/python:3.6 resource_class: small parallelism: 1 steps: diff --git a/README.md b/README.md index 119467b3e8b..b7d2b724939 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ Choose the right framework for every part of a model's lifetime ## Installation -This repo is tested on Python 3.5+, PyTorch 1.0.0+ and TensorFlow 2.0.0-rc1 +This repo is tested on Python 3.6+, PyTorch 1.0.0+ and TensorFlow 2.0.0-rc1 You should install 🤗 Transformers in a [virtual environment](https://docs.python.org/3/library/venv.html). If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/). 
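The examples/benchmarks.py changes below thread new batch-size, slice-size, memory and verbosity options through create_setup_and_compute. As a rough orientation, here is a minimal sketch of how that extended entry point might be driven programmatically; the model name and sizes are illustrative, and the import assumes the script is run from the examples/ directory (the script's own CLI, shown further below, is the supported path):

# Hypothetical driver for the extended create_setup_and_compute added in this patch.
# Values are placeholders; roughly what the CLI ends up calling for a PyTorch run.
from benchmarks import create_setup_and_compute  # examples/benchmarks.py after this patch

create_setup_and_compute(
    model_names=["gpt2"],   # illustrative model
    batch_sizes=[1],        # new: batch sizes are now configurable
    slice_sizes=[64, 128],  # new: sequence lengths to benchmark
    tensorflow=False,       # use the PyTorch path
    gpu=False,
    average_over=3,
    no_speed=True,          # new: skip timing, keep only memory tracing
    no_memory=False,        # new: line-by-line memory tracing stays enabled
    verbose=True,           # new: print per-line memory statistics
)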
diff --git a/docs/source/installation.md b/docs/source/installation.md index f4b7781ea9a..02f2951759d 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -1,6 +1,6 @@ # Installation -Transformers is tested on Python 3.5+ and PyTorch 1.1.0 +Transformers is tested on Python 3.6+ and PyTorch 1.1.0 ## With pip diff --git a/examples/benchmarks.py b/examples/benchmarks.py index 07de19d4b51..bf204b48653 100644 --- a/examples/benchmarks.py +++ b/examples/benchmarks.py @@ -24,7 +24,15 @@ import timeit from time import time from typing import List -from transformers import AutoConfig, AutoTokenizer, is_tf_available, is_torch_available +from transformers import ( + AutoConfig, + AutoTokenizer, + MemorySummary, + is_tf_available, + is_torch_available, + start_memory_tracing, + stop_memory_tracing, +) if is_tf_available(): @@ -250,15 +258,21 @@ as they entered.""" def create_setup_and_compute( model_names: List[str], + batch_sizes: List[int], + slice_sizes: List[int], gpu: bool = True, tensorflow: bool = False, average_over: int = 3, + no_speed: bool = False, + no_memory: bool = False, + verbose: bool = False, torchscript: bool = False, xla: bool = False, amp: bool = False, fp16: bool = False, save_to_csv: bool = False, csv_filename: str = f"results_{round(time())}.csv", + csv_memory_filename: str = f"memory_{round(time())}.csv", ): if xla: tf.config.optimizer.set_jit(True) @@ -267,11 +281,25 @@ def create_setup_and_compute( if tensorflow: dictionary = {model_name: {} for model_name in model_names} - results = _compute_tensorflow(model_names, dictionary, average_over, amp) + results = _compute_tensorflow( + model_names, batch_sizes, slice_sizes, dictionary, average_over, amp, no_speed, no_memory, verbose + ) else: device = "cuda" if (gpu and torch.cuda.is_available()) else "cpu" dictionary = {model_name: {} for model_name in model_names} - results = _compute_pytorch(model_names, dictionary, average_over, device, torchscript, fp16) + results = _compute_pytorch( + model_names, + batch_sizes, + slice_sizes, + dictionary, + average_over, + device, + torchscript, + fp16, + no_speed, + no_memory, + verbose, + ) print("=========== RESULTS ===========") for model_name in model_names: @@ -280,13 +308,19 @@ def create_setup_and_compute( print("\t\t" + f"===== BATCH SIZE: {batch_size} =====") for slice_size in results[model_name]["ss"]: result = results[model_name]["results"][batch_size][slice_size] + memory = results[model_name]["memory"][batch_size][slice_size] if isinstance(result, str): - print(f"\t\t{model_name}/{batch_size}/{slice_size}: " f"{result}") + print(f"\t\t{model_name}/{batch_size}/{slice_size}: " f"{result} " f"{memory}") else: - print(f"\t\t{model_name}/{batch_size}/{slice_size}: " f"{(round(1000 * result) / 1000)}" f"s") + print( + f"\t\t{model_name}/{batch_size}/{slice_size}: " + f"{(round(1000 * result) / 1000)}" + f"s " + f"{memory}" + ) if save_to_csv: - with open(csv_filename, mode="w") as csv_file: + with open(csv_filename, mode="w") as csv_file, open(csv_memory_filename, mode="w") as csv_memory_file: fieldnames = [ "model", "1x8", @@ -317,6 +351,8 @@ def create_setup_and_compute( writer = csv.DictWriter(csv_file, fieldnames=fieldnames) writer.writeheader() + memory_writer = csv.DictWriter(csv_memory_file, fieldnames=fieldnames) + memory_writer.writeheader() for model_name in model_names: model_results = { @@ -326,8 +362,52 @@ def create_setup_and_compute( } writer.writerow({"model": model_name, **model_results}) + model_memory_results = { + f"{bs}x{ss}": 
results[model_name]["memory"][bs][ss] + for bs in results[model_name]["memory"] + for ss in results[model_name]["memory"][bs] + } + memory_writer.writerow({"model": model_name, **model_memory_results}) -def _compute_pytorch(model_names, dictionary, average_over, device, torchscript, fp16): + +def print_summary_statistics(summary: MemorySummary): + print( + "\nLines by line memory consumption:\n" + + "\n".join( + f"{state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}" + for state in summary.sequential + ) + ) + print( + "\nLines with top memory consumption:\n" + + "\n".join( + f"=> {state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}" + for state in summary.cumulative[:6] + ) + ) + print( + "\nLines with lowest memory consumption:\n" + + "\n".join( + f"=> {state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}" + for state in summary.cumulative[-6:] + ) + ) + print(f"\nTotal memory increase: {summary.total}") + + +def _compute_pytorch( + model_names, + batch_sizes, + slice_sizes, + dictionary, + average_over, + device, + torchscript, + fp16, + no_speed, + no_memory, + verbose, +): for c, model_name in enumerate(model_names): print(f"{c + 1} / {len(model_names)}") config = AutoConfig.from_pretrained(model_name, torchscript=torchscript) @@ -337,17 +417,17 @@ def _compute_pytorch(model_names, dictionary, average_over, device, torchscript, tokenized_sequence = tokenizer.encode(input_text, add_special_tokens=False) max_input_size = tokenizer.max_model_input_sizes[model_name] - batch_sizes = [1, 2, 4, 8] - slice_sizes = [8, 64, 128, 256, 512, 1024] - dictionary[model_name] = {"bs": batch_sizes, "ss": slice_sizes, "results": {}} + dictionary[model_name] = {"bs": batch_sizes, "ss": slice_sizes, "results": {}, "memory": {}} dictionary[model_name]["results"] = {i: {} for i in batch_sizes} + dictionary[model_name]["memory"] = {i: {} for i in batch_sizes} for batch_size in batch_sizes: if fp16: model.half() model.to(device) model.eval() + for slice_size in slice_sizes: if max_input_size is not None and slice_size > max_input_size: dictionary[model_name]["results"][batch_size][slice_size] = "N/A" @@ -362,18 +442,40 @@ def _compute_pytorch(model_names, dictionary, average_over, device, torchscript, inference = model inference(sequence) - print("Going through model with sequence of shape", sequence.shape) - runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3) - average_time = sum(runtimes) / float(len(runtimes)) / 3.0 - dictionary[model_name]["results"][batch_size][slice_size] = average_time + if not no_memory: + # model.add_memory_hooks() # Forward method tracing (only for PyTorch models) + + # Line by line memory tracing (all code in the module `transformers`) works for all models/arbitrary code + trace = start_memory_tracing("transformers") + inference(sequence) + summary = stop_memory_tracing(trace) + + if verbose: + print_summary_statistics(summary) + + dictionary[model_name]["memory"][batch_size][slice_size] = str(summary.total) + else: + dictionary[model_name]["memory"][batch_size][slice_size] = "N/A" + + if not no_speed: + print("Going through model with sequence of shape", sequence.shape) + runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3) + average_time = sum(runtimes) / float(len(runtimes)) / 3.0 + dictionary[model_name]["results"][batch_size][slice_size] = average_time + else: + 
dictionary[model_name]["results"][batch_size][slice_size] = "N/A" + except RuntimeError as e: print("Doesn't fit on GPU.", e) torch.cuda.empty_cache() dictionary[model_name]["results"][batch_size][slice_size] = "N/A" + dictionary[model_name]["memory"][batch_size][slice_size] = "N/A" return dictionary -def _compute_tensorflow(model_names, dictionary, average_over, amp): +def _compute_tensorflow( + model_names, batch_sizes, slice_sizes, dictionary, average_over, amp, no_speed, no_memory, verbose +): for c, model_name in enumerate(model_names): print(f"{c + 1} / {len(model_names)}") config = AutoConfig.from_pretrained(model_name) @@ -383,11 +485,10 @@ def _compute_tensorflow(model_names, dictionary, average_over, amp): tokenized_sequence = tokenizer.encode(input_text, add_special_tokens=False) max_input_size = tokenizer.max_model_input_sizes[model_name] - batch_sizes = [1, 2, 4, 8] - slice_sizes = [8, 64, 128, 256, 512, 1024] - dictionary[model_name] = {"bs": batch_sizes, "ss": slice_sizes, "results": {}} + dictionary[model_name] = {"bs": batch_sizes, "ss": slice_sizes, "results": {}, "memory": {}} dictionary[model_name]["results"] = {i: {} for i in batch_sizes} + dictionary[model_name]["memory"] = {i: {} for i in batch_sizes} print("Using model", model) @@ -409,13 +510,31 @@ def _compute_tensorflow(model_names, dictionary, average_over, amp): # To make sure that the model is traced + that the tensors are on the appropriate device inference(sequence) - runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3) - average_time = sum(runtimes) / float(len(runtimes)) / 3.0 - dictionary[model_name]["results"][batch_size][slice_size] = average_time + if not no_memory: + # Line by line memory tracing (all code in the module `transformers`) works for all models/arbitrary code + trace = start_memory_tracing("transformers") + inference(sequence) + summary = stop_memory_tracing(trace) + + if verbose: + print_summary_statistics(summary) + + dictionary[model_name]["memory"][batch_size][slice_size] = str(summary.total) + else: + dictionary[model_name]["memory"][batch_size][slice_size] = "N/A" + + if not no_speed: + runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3) + average_time = sum(runtimes) / float(len(runtimes)) / 3.0 + dictionary[model_name]["results"][batch_size][slice_size] = average_time + else: + dictionary[model_name]["results"][batch_size][slice_size] = "N/A" + except tf.errors.ResourceExhaustedError as e: print("Doesn't fit on GPU.", e) torch.cuda.empty_cache() dictionary[model_name]["results"][batch_size][slice_size] = "N/A" + dictionary[model_name]["memory"][batch_size][slice_size] = "N/A" return dictionary @@ -433,6 +552,9 @@ def main(): "of all available model " "architectures.", ) + parser.add_argument("--verbose", required=False, action="store_true", help="Verbose memory tracing") + parser.add_argument("--no_speed", required=False, action="store_true", help="Don't perform speed measurments") + parser.add_argument("--no_memory", required=False, action="store_true", help="Don't perform memory measurments") parser.add_argument( "--torch", required=False, action="store_true", help="Benchmark the Pytorch version of the " "models" ) @@ -477,6 +599,8 @@ def main(): parser.add_argument( "--average_over", required=False, default=30, type=int, help="Times an experiment will be run." 
) + parser.add_argument("--batch_sizes", nargs="+", type=int, default=[1, 2, 4, 8]) + parser.add_argument("--slice_sizes", nargs="+", type=int, default=[8, 64, 128, 256, 512, 1024]) args = parser.parse_args() if args.models == "all": @@ -501,6 +625,8 @@ def main(): if is_torch_available(): create_setup_and_compute( model_names=args.models, + batch_sizes=args.batch_sizes, + slice_sizes=args.slice_sizes, tensorflow=False, gpu=args.torch_cuda, torchscript=args.torchscript, @@ -508,6 +634,9 @@ def main(): save_to_csv=args.save_to_csv, csv_filename=args.csv_filename, average_over=args.average_over, + no_speed=args.no_speed, + no_memory=args.no_memory, + verbose=args.verbose, ) else: raise ImportError("Trying to run a PyTorch benchmark but PyTorch was not found in the environment.") @@ -516,12 +645,17 @@ def main(): if is_tf_available(): create_setup_and_compute( model_names=args.models, + batch_sizes=args.batch_sizes, + slice_sizes=args.slice_sizes, tensorflow=True, xla=args.xla, amp=args.amp, save_to_csv=args.save_to_csv, csv_filename=args.csv_filename, average_over=args.average_over, + no_speed=args.no_speed, + no_memory=args.no_memory, + verbose=args.verbose, ) else: raise ImportError("Trying to run a TensorFlow benchmark but TensorFlow was not found in the environment.") diff --git a/examples/requirements.txt b/examples/requirements.txt index 36229755e81..6a4126c9263 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -2,3 +2,4 @@ tensorboardX tensorboard scikit-learn seqeval +psutil diff --git a/setup.py b/setup.py index eb9916d84c6..fd9180d110f 100644 --- a/setup.py +++ b/setup.py @@ -110,7 +110,7 @@ setup( ], extras_require=extras, scripts=["transformers-cli"], - python_requires=">=3.5.0", + python_requires=">=3.6.0", classifiers=[ "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", @@ -119,7 +119,6 @@ setup( "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.5", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Topic :: Scientific/Engineering :: Artificial Intelligence", diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 3c6e0bb9e1f..1c9e1ac4c77 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -19,6 +19,18 @@ else: import logging +# Benchmarking +from .benchmark_utils import ( + Frame, + Memory, + MemoryState, + MemorySummary, + MemoryTrace, + UsedMemoryState, + bytes_to_human_readable, + start_memory_tracing, + stop_memory_tracing, +) from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig from .configuration_bart import BartConfig diff --git a/src/transformers/benchmark_utils.py b/src/transformers/benchmark_utils.py new file mode 100644 index 00000000000..9223816123c --- /dev/null +++ b/src/transformers/benchmark_utils.py @@ -0,0 +1,341 @@ +""" +Utilities for working with the local dataset cache. +This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp +Copyright by the AllenNLP authors. 
+""" + +import linecache +import logging +import os +import sys +from collections import defaultdict +from typing import Iterable, List, NamedTuple, Optional, Union + +from .file_utils import is_tf_available, is_torch_available + + +if is_torch_available(): + from torch.cuda import empty_cache as torch_empty_cache +if is_tf_available(): + from tensorflow.python.eager import context as tf_context + + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +_is_memory_tracing_enabled = False + + +def is_memory_tracing_enabled(): + global _is_memory_tracing_enabled + return _is_memory_tracing_enabled + + +class Frame(NamedTuple): + """ `Frame` is a NamedTuple used to gather the current frame state. + `Frame` has the following fields: + - 'filename' (string): Name of the file currently executed + - 'module' (string): Name of the module currently executed + - 'line_number' (int): Number of the line currently executed + - 'event' (string): Event that triggered the tracing (default will be "line") + - 'line_text' (string): Text of the line in the python script + """ + + filename: str + module: str + line_number: int + event: str + line_text: str + + +class UsedMemoryState(NamedTuple): + """ `UsedMemoryState` are named tuples with the following fields: + - 'frame': a `Frame` namedtuple (see below) storing information on the current tracing frame (current file, location in current file) + - 'cpu_memory': CPU RSS memory state *before* executing the line + - 'gpu_memory': GPU used memory *before* executing the line (sum for all GPUs or for only `gpus_to_trace` if provided) + """ + + frame: Frame + cpu_memory: int + gpu_memory: int + + +class Memory(NamedTuple): + """ `Memory` NamedTuple have a single field `bytes` and + you can get a human readable string of the number of bytes by calling `__repr__` + - `byte` (integer): number of bytes, + """ + + bytes: int + + def __repr__(self) -> str: + return bytes_to_human_readable(self.bytes) + + +class MemoryState(NamedTuple): + """ `MemoryState` are namedtuples listing frame + CPU/GPU memory with the following fields: + - `frame` (`Frame`): the current frame (see above) + - `cpu`: CPU memory consumed at during the current frame as a `Memory` named tuple + - `gpu`: GPU memory consumed at during the current frame as a `Memory` named tuple + - `cpu_gpu`: CPU + GPU memory consumed at during the current frame as a `Memory` named tuple + """ + + frame: Frame + cpu: Memory + gpu: Memory + cpu_gpu: Memory + + +class MemorySummary(NamedTuple): + """ `MemorySummary` namedtuple otherwise with the fields: + - `sequential`: a list of `MemoryState` namedtuple (see below) computed from the provided `memory_trace` + by substracting the memory after executing each line from the memory before executing said line. + - `cumulative`: a list of `MemoryState` namedtuple (see below) with cumulative increase in memory for each line + obtained by summing repeted memory increase for a line if it's executed several times. + The list is sorted from the frame with the largest memory consumption to the frame with the smallest (can be negative if memory is released) + - `total`: total memory increase during the full tracing as a `Memory` named tuple (see below). + Line with memory release (negative consumption) are ignored if `ignore_released_memory` is `True` (default). 
+ """ + + sequential: List[MemoryState] + cumulative: List[MemoryState] + total: Memory + + +MemoryTrace = List[UsedMemoryState] + + +def start_memory_tracing( + modules_to_trace: Optional[Union[str, Iterable[str]]] = None, + modules_not_to_trace: Optional[Union[str, Iterable[str]]] = None, + events_to_trace: str = "line", + gpus_to_trace: Optional[List[int]] = None, +) -> MemoryTrace: + """ Setup line-by-line tracing to record rss mem (RAM) at each line of a module or sub-module. + See `../../examples/benchmarks.py for a usage example. + Current memory consumption is returned using psutil and in particular is the RSS memory + "Resident Set Size” (the non-swapped physical memory the process is using). + See https://psutil.readthedocs.io/en/latest/#psutil.Process.memory_info + + Args: + - `modules_to_trace`: (None, string, list/tuple of string) + if None, all events are recorded + if string or list of strings: only events from the listed module/sub-module will be recorded (e.g. 'fairseq' or 'transformers.modeling_gpt2') + - `modules_not_to_trace`: (None, string, list/tuple of string) + if None, no module is avoided + if string or list of strings: events from the listed module/sub-module will not be recorded (e.g. 'torch') + - `events_to_trace`: string or list of string of events to be recorded (see official python doc for `sys.settrace` for the list of events) + default to line + - `gpus_to_trace`: (optional list, default None) list of GPUs to trace. Default to tracing all GPUs + + Return: + - `memory_trace` is a list of `UsedMemoryState` for each event (default each line of the traced script). + - `UsedMemoryState` are named tuples with the following fields: + - 'frame': a `Frame` namedtuple (see below) storing information on the current tracing frame (current file, location in current file) + - 'cpu_memory': CPU RSS memory state *before* executing the line + - 'gpu_memory': GPU used memory *before* executing the line (sum for all GPUs or for only `gpus_to_trace` if provided) + + `Frame` is a namedtuple used by `UsedMemoryState` to list the current frame state. + `Frame` has the following fields: + - 'filename' (string): Name of the file currently executed + - 'module' (string): Name of the module currently executed + - 'line_number' (int): Number of the line currently executed + - 'event' (string): Event that triggered the tracing (default will be "line") + - 'line_text' (string): Text of the line in the python script + + """ + try: + import psutil + except (ImportError): + logger.warning( + "Psutil not installed, we won't log CPU memory usage. " + "Install psutil (pip install psutil) to use CPU memory tracing." + ) + process = None + else: + process = psutil.Process(os.getpid()) + + try: + from py3nvml import py3nvml + + py3nvml.nvmlInit() + devices = list(range(py3nvml.nvmlDeviceGetCount())) if gpus_to_trace is None else gpus_to_trace + py3nvml.nvmlShutdown() + except ImportError: + logger.warning( + "py3nvml not installed, we won't log GPU memory usage. " + "Install py3nvml (pip install py3nvml) to use GPU memory tracing." + ) + log_gpu = False + except (OSError, py3nvml.NVMLError): + logger.warning("Error while initializing comunication with GPU. 
" "We won't perform GPU memory tracing.") + log_gpu = False + else: + log_gpu = is_torch_available() or is_tf_available() + + memory_trace = [] + + def traceit(frame, event, args): + """ Tracing method executed before running each line in a module or sub-module + Record memory allocated in a list with debugging information + """ + global _is_memory_tracing_enabled + + if not _is_memory_tracing_enabled: + return traceit + + # Filter events + if events_to_trace is not None: + if isinstance(events_to_trace, str) and event != events_to_trace: + return traceit + elif isinstance(events_to_trace, (list, tuple)) and event not in events_to_trace: + return traceit + + # Filter modules + name = frame.f_globals["__name__"] + if not isinstance(name, str): + return traceit + else: + # Filter whitelist of modules to trace + if modules_to_trace is not None: + if isinstance(modules_to_trace, str) and modules_to_trace not in name: + return traceit + elif isinstance(modules_to_trace, (list, tuple)) and all(m not in name for m in modules_to_trace): + return traceit + + # Filter blacklist of modules not to trace + if modules_not_to_trace is not None: + if isinstance(modules_not_to_trace, str) and modules_not_to_trace in name: + return traceit + elif isinstance(modules_not_to_trace, (list, tuple)) and any(m in name for m in modules_not_to_trace): + return traceit + + # Record current tracing state (file, location in file...) + lineno = frame.f_lineno + filename = frame.f_globals["__file__"] + if filename.endswith(".pyc") or filename.endswith(".pyo"): + filename = filename[:-1] + line = linecache.getline(filename, lineno).rstrip() + traced_state = Frame(filename, name, lineno, event, line) + + # Record current memory state (rss memory) and compute difference with previous memory state + cpu_mem = 0 + if process is not None: + mem = process.memory_info() + cpu_mem = mem.rss + + gpu_mem = 0 + if log_gpu: + # Clear GPU caches + if is_torch_available(): + torch_empty_cache() + if is_tf_available(): + tf_context.context()._clear_caches() # See https://github.com/tensorflow/tensorflow/issues/20218#issuecomment-416771802 + + # Sum used memory for all GPUs + py3nvml.nvmlInit() + for i in devices: + handle = py3nvml.nvmlDeviceGetHandleByIndex(i) + meminfo = py3nvml.nvmlDeviceGetMemoryInfo(handle) + gpu_mem += meminfo.used + py3nvml.nvmlShutdown() + + mem_state = UsedMemoryState(traced_state, cpu_mem, gpu_mem) + memory_trace.append(mem_state) + + return traceit + + sys.settrace(traceit) + + global _is_memory_tracing_enabled + _is_memory_tracing_enabled = True + + return memory_trace + + +def stop_memory_tracing( + memory_trace: Optional[MemoryTrace] = None, ignore_released_memory: bool = True +) -> Optional[MemorySummary]: + """ Stop memory tracing cleanly and return a summary of the memory trace if a trace is given. + + Args: + - `memory_trace` (optional output of start_memory_tracing, default: None): memory trace to convert in summary + - `ignore_released_memory` (boolean, default: None): if True we only sum memory increase to compute total memory + + Return: + - None if `memory_trace` is None + - `MemorySummary` namedtuple otherwise with the fields: + - `sequential`: a list of `MemoryState` namedtuple (see below) computed from the provided `memory_trace` + by substracting the memory after executing each line from the memory before executing said line. 
+ - `cumulative`: a list of `MemoryState` namedtuple (see below) with cumulative increase in memory for each line + obtained by summing repeated memory increase for a line if it's executed several times. + The list is sorted from the frame with the largest memory consumption to the frame with the smallest (can be negative if memory is released) + - `total`: total memory increase during the full tracing as a `Memory` named tuple (see below). + Lines with memory release (negative consumption) are ignored if `ignore_released_memory` is `True` (default). + + `Memory` named tuples have a single field + - `bytes` (integer): number of bytes, + and their `__repr__` is a human readable string (ex: "3.5MB") + + `Frame` is a namedtuple used to list the current frame state and has the following fields: + - 'filename' (string): Name of the file currently executed + - 'module' (string): Name of the module currently executed + - 'line_number' (int): Number of the line currently executed + - 'event' (string): Event that triggered the tracing (default will be "line") + - 'line_text' (string): Text of the line in the python script + + `MemoryState` are namedtuples listing frame + CPU/GPU memory with the following fields: + - `frame` (`Frame`): the current frame (see above) + - `cpu`: CPU memory consumed during the current frame as a `Memory` named tuple + - `gpu`: GPU memory consumed during the current frame as a `Memory` named tuple + - `cpu_gpu`: CPU + GPU memory consumed during the current frame as a `Memory` named tuple + """ + global _is_memory_tracing_enabled + _is_memory_tracing_enabled = False + + if memory_trace is not None and len(memory_trace) > 1: + memory_diff_trace = [] + cumulative_memory_dict = defaultdict(lambda: [0, 0, 0]) + for (frame, cpu_mem, gpu_mem), (next_frame, next_cpu_mem, next_gpu_mem) in zip( + memory_trace[:-1], memory_trace[1:] + ): + cpu_mem_inc = next_cpu_mem - cpu_mem + gpu_mem_inc = next_gpu_mem - gpu_mem + cpu_gpu_mem_inc = cpu_mem_inc + gpu_mem_inc + memory_diff_trace.append( + MemoryState( + frame=frame, cpu=Memory(cpu_mem_inc), gpu=Memory(gpu_mem_inc), cpu_gpu=Memory(cpu_gpu_mem_inc), + ) + ) + cumulative_memory_dict[frame][0] += cpu_mem_inc + cumulative_memory_dict[frame][1] += gpu_mem_inc + cumulative_memory_dict[frame][2] += cpu_gpu_mem_inc + + cumulative_memory = sorted( + list(cumulative_memory_dict.items()), key=lambda x: x[1][2], reverse=True + ) # order by the total CPU + GPU memory increase + cumulative_memory = list( + MemoryState( + frame=frame, cpu=Memory(cpu_mem_inc), gpu=Memory(gpu_mem_inc), cpu_gpu=Memory(cpu_gpu_mem_inc), + ) + for frame, (cpu_mem_inc, gpu_mem_inc, cpu_gpu_mem_inc) in cumulative_memory + ) + + if ignore_released_memory: + total_memory = sum(max(0, step_trace.cpu_gpu.bytes) for step_trace in memory_diff_trace) + else: + total_memory = sum(step_trace.cpu_gpu.bytes for step_trace in memory_diff_trace) + total_memory = Memory(total_memory) + return MemorySummary(sequential=memory_diff_trace, cumulative=cumulative_memory, total=total_memory) + + return None + + +def bytes_to_human_readable(memory_amount: int) -> str: + """ Utility to convert a number of bytes (int) into a human readable string (with units) + """ + for unit in ["B", "KB", "MB", "GB"]: + if memory_amount > -1024.0 and memory_amount < 1024.0: + return "{:.3f}{}".format(memory_amount, unit) + memory_amount /= 1024.0 + return "{:.3f}TB".format(memory_amount) diff --git a/src/transformers/configuration_gpt2.py b/src/transformers/configuration_gpt2.py index 4957e9fd104..1f2352a6c96 
100644 --- a/src/transformers/configuration_gpt2.py +++ b/src/transformers/configuration_gpt2.py @@ -59,6 +59,8 @@ class GPT2Config(PretrainedConfig): Number of hidden layers in the Transformer encoder. n_head (:obj:`int`, optional, defaults to 12): Number of attention heads for each attention layer in the Transformer encoder. + activation_function (:obj:`str`, optional, defaults to 'gelu_new'): + Activation function selected in the list ["relu", "swish", "gelu", "tanh", "gelu_new"]. resid_pdrop (:obj:`float`, optional, defaults to 0.1): The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. embd_pdrop (:obj:`int`, optional, defaults to 0.1): @@ -125,6 +127,7 @@ class GPT2Config(PretrainedConfig): n_embd=768, n_layer=12, n_head=12, + activation_function="gelu_new", resid_pdrop=0.1, embd_pdrop=0.1, attn_pdrop=0.1, @@ -147,6 +150,7 @@ class GPT2Config(PretrainedConfig): self.n_embd = n_embd self.n_layer = n_layer self.n_head = n_head + self.activation_function = activation_function self.resid_pdrop = resid_pdrop self.embd_pdrop = embd_pdrop self.attn_pdrop = attn_pdrop diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py index b492d7fc374..04a95eff289 100644 --- a/src/transformers/modeling_gpt2.py +++ b/src/transformers/modeling_gpt2.py @@ -24,7 +24,7 @@ import torch import torch.nn as nn from torch.nn import CrossEntropyLoss -from .activations import gelu_new +from .activations import ACT2FN from .configuration_gpt2 import GPT2Config from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer @@ -203,7 +203,7 @@ class MLP(nn.Module): nx = config.n_embd self.c_fc = Conv1D(n_state, nx) self.c_proj = Conv1D(nx, n_state) - self.act = gelu_new + self.act = ACT2FN[config.activation_function] self.dropout = nn.Dropout(config.resid_pdrop) def forward(self, x): diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 467c329f5e4..e2c2ef1bf06 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -39,6 +39,7 @@ from .file_utils import ( logger = logging.getLogger(__name__) + try: from torch.nn import Identity except ImportError: @@ -66,6 +67,47 @@ class ModuleUtilsMixin: params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters() return sum(p.numel() for p in params) + @staticmethod + def _hook_rss_memory_pre_forward(module, *args, **kwargs): + try: + import psutil + except (ImportError): + raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.") + + process = psutil.Process(os.getpid()) + mem = process.memory_info() + module.mem_rss_pre_forward = mem.rss + return None + + @staticmethod + def _hook_rss_memory_post_forward(module, *args, **kwargs): + try: + import psutil + except (ImportError): + raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.") + + process = psutil.Process(os.getpid()) + mem = process.memory_info() + module.mem_rss_post_forward = mem.rss + mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward + module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0) + return None + + def add_memory_hooks(self): + """ Add a memory hook before and after each sub-module forward pass to record increase in memory consumption. 
+ Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero with `model.reset_memory_hooks_state()` + """ + for module in self.modules(): + module.register_forward_pre_hook(self._hook_rss_memory_pre_forward) + module.register_forward_hook(self._hook_rss_memory_post_forward) + self.reset_memory_hooks_state() + + def reset_memory_hooks_state(self): + for module in self.modules(): + module.mem_rss_diff = 0 + module.mem_rss_post_forward = 0 + module.mem_rss_pre_forward = 0 + class PreTrainedModel(nn.Module, ModuleUtilsMixin): r""" Base class for all models.
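For context, a minimal usage sketch of the two tracing mechanisms this patch adds; the model name and input text are illustrative placeholders, and it assumes torch and psutil are installed (py3nvml optional for GPU numbers):

# Sketch only: exercises the line-by-line tracer and the forward-pass memory hooks
# introduced by this patch. "gpt2" and the input sentence are placeholders.
import torch
from transformers import AutoModel, AutoTokenizer, start_memory_tracing, stop_memory_tracing

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModel.from_pretrained("gpt2")
input_ids = torch.tensor([tokenizer.encode("Hello world", add_special_tokens=False)])

# Line-by-line tracing: records CPU RSS (and GPU memory when py3nvml is available)
# before each line executed inside the `transformers` package.
trace = start_memory_tracing("transformers")
model(input_ids)
summary = stop_memory_tracing(trace)
print(summary.total)          # total increase as a `Memory` named tuple (human readable repr)
print(summary.cumulative[0])  # line with the largest cumulative increase

# Forward-pass hooks (PyTorch models only): each sub-module accumulates its RSS
# increase in `mem_rss_diff`; reset_memory_hooks_state() zeroes the counters.
model.add_memory_hooks()
model(input_ids)
print(model.mem_rss_diff)
model.reset_memory_hooks_state()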