Source code for zeropybench._benchmark

import contextlib
import gc
import inspect
import linecache
import sys
import textwrap
import time
import timeit
from collections.abc import Callable, Iterator, Sequence
from copy import deepcopy
from pathlib import Path
from typing import Any, ClassVar

import polars as pl

from ._io import BenchmarkReader, BenchmarkWriter
from ._jax import CodeASTParser
from ._plot import BenchmarkPlotter
from ._units import get_optimal_time_units, to_units

# Public names exported by this module.
__all__ = ['Benchmark', 'read_benchmark']

# Union of the value types a benchmark record may hold: the user-supplied keyword
# labels passed to ``bench(...)`` as well as the stored measurements.
ValidBenchmarkType = bool | int | float | str | list[float] | None


class Benchmark:
    """A class for multidimensional benchmarking of code snippets.

    Example::

        import jax.numpy as jnp

        bench = Benchmark(repeat=10, verbose=True)

        for N in [10_000, 100_000, 1_000_000]:
            x = jnp.ones(N)
            y = jnp.ones(1000)

            with bench(method='broadcast right', N=N):
                x[:, None] + y[None, :]

            with bench(method='broadcast left', N=N):
                x[None, :] + y[:, None]

        print(bench)

    Output::

        ┌───┬─────────────────┬───────────┬────────────────────────────┬──────────┬───────────────────────────┬───────────────────────┐
        │   ┆ method          ┆ N         ┆ median_execution_time (ms) ┆ ± (%)    ┆ first_execution_time (ms) ┆ compilation_time (ms) │
        ╞═══╪═════════════════╪═══════════╪════════════════════════════╪══════════╪═══════════════════════════╪═══════════════════════╡
        │ 0 ┆ broadcast right ┆ 10_000    ┆ 1.621035                   ┆ 7.45449  ┆ 114.662899                ┆ 82.715331             │
        │ 1 ┆ broadcast left  ┆ 10_000    ┆ 1.637076                   ┆ 1.211578 ┆ 76.16242                  ┆ 59.710502             │
        │ 2 ┆ broadcast right ┆ 100_000   ┆ 10.030257                  ┆ 1.190149 ┆ 44.89509                  ┆ 35.997909             │
        │ 3 ┆ broadcast left  ┆ 100_000   ┆ 10.297392                  ┆ 1.467783 ┆ 133.313475                ┆ 54.304368             │
        │ 4 ┆ broadcast right ┆ 1_000_000 ┆ 94.146169                  ┆ 0.159088 ┆ 137.82041                 ┆ 44.644484             │
        │ 5 ┆ broadcast left  ┆ 1_000_000 ┆ 196.27792                  ┆ 0.093503 ┆ 112.162553                ┆ 80.459331             │
        └───┴─────────────────┴───────────┴────────────────────────────┴──────────┴───────────────────────────┴───────────────────────┘

    Then export or plot::

        bench.write_csv('bench.csv')
        bench.plot()

    Attributes:
        repeat (int): The number of times the estimation of the elapsed time will be performed.
            Each repeat will usually execute the benchmarked code many times.
        min_duration_per_repeat (float): The minimum duration of one repeat, in seconds.
            The function will be executed as many times as necessary so that the total
            execution time is greater than this value.
    """

    DEFAULT_REPEAT: ClassVar[int] = 7
    """The default number of measurement repetitions."""

    DEFAULT_MIN_DURATION_PER_REPEAT: ClassVar[float] = 0.2
    """The default minimum duration per repeat in seconds."""

    _report: list[dict[str, ValidBenchmarkType]]
    """Storage for the individual measurements."""

    _cache: dict[str, str]
    """Per-instance cache mapping a file name to its source text, used to extract the body of
    the with statement. The text is read through the linecache module after its entry has been
    validated against the file on disk, and is then kept for the lifetime of the instance, so
    that a new Benchmark instance reflects modifications of the benchmarked source."""

    def __init__(
        self,
        *,
        repeat: int = DEFAULT_REPEAT,
        min_duration_per_repeat: float = DEFAULT_MIN_DURATION_PER_REPEAT,
        verbose: bool = False,
    ) -> None:
        """Returns a Benchmark instance.

        Args:
            repeat: The number of times the estimation of the elapsed time will be performed.
                Each repeat will usually execute the benchmarked code many times.
            min_duration_per_repeat: The minimum duration of one repeat, in seconds.
                The function will be executed as many times as necessary so that the total
                execution time is greater than this value. The execution time for this repeat
                is the mean value of the execution times.
            verbose: If True, print the setup and benchmarked code to stderr. For JAX
                benchmarks, this shows the JIT-compiled function definition and the actual
                timed call (e.g., ``__bench_func(x, y).block_until_ready()``).
        """
        self.repeat = repeat
        self.min_duration_per_repeat = min_duration_per_repeat
        self.verbose = verbose
        self._report = []
        self._cache = {}

    def __repr__(self) -> str:
        """Return the report rendered as a text table, with times in display units."""
        with pl.Config(
            thousands_separator='_',
            tbl_cols=-1,
            tbl_rows=-1,
            tbl_hide_column_data_types=True,
            tbl_hide_dataframe_shape=True,
        ):
            return str(self._to_display_dataframe())

    @contextlib.contextmanager
    def __call__(self, **keywords: ValidBenchmarkType) -> Iterator[None]:
        """Benchmark the body of the ``with`` statement this call is used in.

        The body first runs once as written and that run is timed. Its source text is then
        recovered from the caller's frame and executed repeatedly through ``timeit``. A summary
        line is printed to stderr and one record is appended to the report.

        Args:
            **keywords: Labels identifying this measurement. They are printed in the summary
                line and stored as columns of the record.
        """
        # clear the JAX compilation caches
        jax = sys.modules.get('jax')
        if jax is not None:
            jax.clear_caches()

        start_time = time.perf_counter()
        yield
        first_time = time.perf_counter() - start_time

        # NOTE: _get_execution_context walks a fixed number of frames up the stack, so it
        # must be called directly from this generator.
        code, f_locals, f_globals = self._get_execution_context()
        globals = f_globals | f_locals
        parser = CodeASTParser.from_code(code, globals)

        extra_info: dict[str, ValidBenchmarkType]
        if parser.is_jax_context():
            setup, param_names, globals = parser.transform_jax_code()
            is_single_array, jax_metadata = self._compile_jax(param_names, globals)
            code = f'__bench_func({", ".join(param_names)})'
            # Wait for the computation to finish so that it is included in the timing.
            if is_single_array:
                code += '.block_until_ready()'
            else:
                code = f'jax.block_until_ready({code})'
            extra_info = {
                'first_execution_time': first_time,
                **jax_metadata,
            }
            if self.verbose:
                print(f'Setup code:\n{textwrap.indent(setup, " ")}', file=sys.stderr)
        else:
            extra_info = {}

        if self.verbose:
            print(f'Benchmarked code:\n{textwrap.indent(code, " ")}', file=sys.stderr)

        if parser.is_jax_context():
            # In JAX context, first_time is not representative because of XLA optimizations.
            # Run the jitted code once to get an accurate estimate for autorange.
            timer = timeit.Timer(code, globals=globals)
            estimated_time = timer.timeit(1)
        else:
            estimated_time = first_time

        execution_times, number = self._run_many_times(code, estimated_time, globals)
        median, rel_stdev = self._get_statistics(execution_times)
        units = get_optimal_time_units([median])
        median_display = to_units(median, units)

        if keywords:
            message = ', '.join(f'{k}={v}' for k, v in keywords.items()) + ': '
        else:
            message = ''
        print(
            f'{message}{median_display:.3f} {units} ± {rel_stdev:.2f}% '
            f'(median of {self.repeat} runs, {number} loops each)',
            file=sys.stderr,
        )

        record: dict[str, ValidBenchmarkType] = {
            **keywords,
            'median_execution_time': median,
            'execution_times': execution_times,
            **extra_info,
        }
        self._report.append(record)

    def _get_execution_context(self) -> tuple[str, dict[str, Any], dict[str, Any]]:
        """Return the code as string, and the locals and globals as dicts."""
        # Three hops up from this frame: the __call__ generator, the frame that resumed
        # the generator, and finally the frame containing the with statement.
        cf = inspect.currentframe()
        assert cf is not None
        cf = cf.f_back
        assert cf is not None
        cf = cf.f_back
        assert cf is not None
        cf = cf.f_back
        assert cf is not None

        filename = cf.f_code.co_filename
        code = self._get_code(filename, cf.f_lineno)
        return code, cf.f_locals, cf.f_globals

    def _get_code(self, filename: str, line_number: int) -> str:
        """Return the content inside the with statement context as text.

        Args:
            filename: The file containing the with statement.
            line_number: The 1-based line number of the with statement.
        """
        lines = self._get_lines(filename)
        context = []
        line_with = lines[line_number - 1]
        indent_with = len(line_with) - len(line_with.lstrip())
        for line in lines[line_number:]:
            stripped_line = line.lstrip()
            indent = len(line) - len(stripped_line)
            # The body ends at the first non-blank line that is not indented deeper
            # than the with statement itself.
            if stripped_line and indent <= indent_with:
                break
            context.append(line)
        code = textwrap.dedent('\n'.join(context)).strip()
        return code

    def _get_lines(self, filename: str) -> list[str]:
        """Return the source lines of ``filename``, without line terminators.

        The text is fetched once per Benchmark instance and then served from ``self._cache``.
        """
        text = self._cache.get(filename)
        if text is None:
            # linecache keeps file contents for the whole process and does not re-check
            # the file by itself: discard the entry if the file changed on disk so that
            # the code extracted here matches the code that was actually executed.
            linecache.checkcache(filename)
            # Use linecache for special files (<...>, ipykernel temp files, etc.)
            # linecache handles both regular files and IPython/Jupyter execution
            text = ''.join(linecache.getlines(filename))
            if not text and filename == '<string>':
                # Handle python -c: get code from /proc/self/cmdline on Linux
                text = self._get_code_from_cmdline()
            self._cache[filename] = text
        return text.splitlines()

    @staticmethod
    def _get_code_from_cmdline() -> str:
        """Get Python code passed via 'python -c' from /proc/self/cmdline.

        Raises:
            RuntimeError: If '/proc/self/cmdline' does not exist or holds no '-c' flag.
        """
        cmdline_path = Path('/proc/self/cmdline')
        if not cmdline_path.exists():
            raise RuntimeError(
                "Cannot read source code: '/proc/self/cmdline' does not exist. "
                "Benchmarking code passed via 'python -c' is only supported on Linux."
            )
        # The arguments are separated by NUL bytes.
        cmdline = cmdline_path.read_bytes().decode().split('\x00')
        if '-c' not in cmdline:
            raise RuntimeError(
                "Cannot read source code: '-c' flag not found in command line. "
                "This is unexpected when the source file is '<string>'."
            )
        idx = cmdline.index('-c')
        return cmdline[idx + 1]

    def _compile_jax(
        self, param_names: list[str], globals: dict[str, Any]
    ) -> tuple[bool, dict[str, ValidBenchmarkType]]:
        """Compile the JAX function and return output type info and metadata.

        A failure of the lowering or compilation is reported on stderr and is not raised.

        Args:
            param_names: The names of the arguments of ``__bench_func``, looked up in
                ``globals``.
            globals: The namespace holding ``__bench_func`` and its argument values.

        Returns:
            A tuple (is_single_array, metadata) where is_single_array indicates
            whether the output is a single JAX array (vs tuple/pytree), and metadata
            is a dict with keys: compilation_time, generated_code_size, temp_size, hlo.
            On failure, is_single_array is False and every metadata value is None.
        """
        bench_func = globals['__bench_func']
        arg_values = [globals[name] for name in param_names]

        try:
            start_time = time.perf_counter()
            lowered = bench_func.lower(*arg_values)
            compiled = lowered.compile()
            compilation_time = time.perf_counter() - start_time
        except Exception as exc:
            print(
                f'Warning: the lowering or compilation of the JAX jitted function failed: {exc}',
                file=sys.stderr,
            )
            metadata: dict[str, ValidBenchmarkType] = {
                'compilation_time': None,
                'generated_code_size': None,
                'temp_size': None,
                'hlo': None,
            }
            is_single_array = False
        else:
            memory = compiled.memory_analysis()
            metadata = {
                'compilation_time': compilation_time,
                'generated_code_size': memory.generated_code_size_in_bytes,
                'temp_size': memory.temp_size_in_bytes,
                'hlo': compiled.as_text(),
            }
            is_single_array = lowered.out_tree.num_leaves == 1

        return is_single_array, metadata

    def _run_many_times(
        self, func: Callable[[], object] | str, first_time: float, globals: dict[str, Any] | None
    ) -> tuple[list[float], int]:
        """Returns execution times in seconds.

        Args:
            func: the function or code snippet to be executed.
            first_time: The execution time in seconds of the code that was run in the context
                manager.
            globals: The combined locals and globals of the code.

        Returns:
            A tuple (runs, number) where runs holds one per-loop time for each repeat and
            number is the loop count used for each repeat.
        """
        number, time_taken = self._autorange(func, first_time, globals)
        timer = timeit.Timer(func, globals=globals)
        # The measurement made while determining the loop count is the first repeat.
        runs = [time_taken / number]
        for _ in range(self.repeat - 1):
            gc.collect()
            runs.append(timer.timeit(number) / number)
        return runs, number

    def _autorange(
        self, func: Callable[[], object] | str, first_time: float, globals: dict[str, Any] | None
    ) -> tuple[int, float]:
        """Returns the number of loops so that total time is greater than min_duration_per_repeat.

        Calls the timeit method with increasing numbers from the sequence
        1, 2, 5, 10, 20, 50, ... until the time taken is at least min_duration_per_repeat
        Returns (number, time_taken_in_seconds).

        Adapted from the timeit module.
        """
        if first_time >= self.min_duration_per_repeat:
            return 1, first_time

        timer = timeit.Timer(func, globals=globals)
        i = 1
        while True:
            for j in 1, 2, 5:
                # A single loop is already covered by first_time.
                if (i, j) == (1, 1):
                    continue
                number = i * j
                time_taken = timer.timeit(number)
                if time_taken >= self.min_duration_per_repeat:
                    return number, time_taken
            i *= 10

    @staticmethod
    def _get_statistics(execution_times: list[float]) -> tuple[float, float]:
        """Return the median and the relative MAD scaled to estimate standard deviation"""
        df = pl.DataFrame({'values': [execution_times]})
        df = df.select(
            median=pl.col('values').list.median(), mad=Benchmark._get_mad(pl.col('values'))
        )
        median = df['median'].item()
        rel_stdev = 1.4826 * df['mad'].item() / median * 100
        return median, rel_stdev

    @staticmethod
    def _get_mad(column: pl.Expr) -> pl.Expr:
        """Return the Median Absolute Deviation."""
        expr_element = abs(pl.element() - pl.element().median()).median()
        return column.list.eval(expr_element).list.first()

    def to_dataframe(self) -> pl.DataFrame:
        """Returns the benchmark as a Polars dataframe with times in seconds."""
        if not self._report:
            schema = {'median_execution_time': pl.Float64(), 'execution_times': pl.List(pl.Float64)}
            return pl.DataFrame({'median_execution_time': [], 'execution_times': []}, schema=schema)

        if 'hlo' in self._report[0]:
            # The JAX metadata values may all be None: give the columns an explicit type.
            schema = {
                'compilation_time': pl.Float64(),
                'generated_code_size': pl.Int64(),
                'temp_size': pl.Int64(),
                'hlo': pl.String(),
            }
        else:
            schema = {}
        return pl.DataFrame(self._report, schema_overrides=schema)

    def _to_display_dataframe(self) -> pl.DataFrame:
        """Returns the benchmark as a Polars dataframe with times in display units."""
        df = self.to_dataframe()
        excluded_columns = [
            'median_execution_time',
            'execution_times',
            'first_execution_time',
            'compilation_time',
            'mad',
            'hlo',
            'temp_size',
            'generated_code_size',
        ]
        time_columns = [
            col for col in ('first_execution_time', 'compilation_time') if col in df.columns
        ]
        byte_columns = [col for col in ('generated_code_size', 'temp_size') if col in df.columns]

        if not self._report:
            return df.select(
                'median_execution_time',
                pl.lit(None, pl.Float64).alias('± (%)'),
                'execution_times',
            )

        units = get_optimal_time_units(df['median_execution_time'])
        suffix = f' ({units})'
        df = df.with_columns(
            mad=self._get_mad(pl.col('execution_times')),
        )
        df = df.select(
            pl.exclude(excluded_columns),
            to_units(pl.col('median_execution_time').name.suffix(suffix), units),
            (1.4826 * pl.col('mad') / pl.col('median_execution_time') * 100).alias('± (%)'),
            to_units(pl.col(time_columns).name.suffix(suffix), units),
            pl.col(byte_columns).name.suffix(' (B)'),
        ).with_row_index('')
        return df

    def __len__(self) -> int:
        """Returns the number of runs in the benchmark."""
        return len(self._report)

    def __bool__(self) -> bool:
        """Returns True if the benchmark is not empty."""
        return len(self) > 0

    def __getitem__(self, item: int) -> dict[str, ValidBenchmarkType]:
        """Returns the benchmark run with the given index (chronologically).

        Raises:
            IndexError: If the index is out of range.
        """
        try:
            return self._report[item]
        except IndexError:
            pass
        # Raised outside of the except clause so that the error is not chained to the
        # internal list lookup.
        message = f'contains only {len(self)} runs' if self else 'is empty'
        raise IndexError(f'Index {item} is out of range. The benchmark report {message}.')

    def to_dicts(self) -> list[dict[str, Any]]:
        """Returns the benchmark as a list of dicts."""
        return deepcopy(self._report)

    def write_csv(self, path: Path | str) -> None:
        """Writes the benchmark report as CSV.

        The file includes a header with metadata comments:

        - ``# repeat = <value>``
        - ``# min_duration_per_repeat = <value>``

        Args:
            path: The path of the CSV file.
        """
        self._create_writer().write_csv(path)

    def write_parquet(self, path: Path | str) -> None:
        """Writes the benchmark report as Parquet.

        The file includes metadata:

        - ``repeat``: The number of measurement repetitions
        - ``min_duration_per_repeat``: The minimum duration per repeat in seconds

        Args:
            path: The path of the Parquet file.
        """
        self._create_writer().write_parquet(path)

    def write_markdown(self, path: Path | str) -> None:
        """Writes the benchmark report as MarkDown table.

        Args:
            path: The path of the MarkDown file.
        """
        self._create_writer().write_markdown(path)

    def _create_writer(self) -> BenchmarkWriter:
        """Create a BenchmarkWriter instance."""
        return BenchmarkWriter(self.to_dataframe(), self.repeat, self.min_duration_per_repeat)

    def plot(
        self,
        *,
        x: str | pl.Expr | None = None,
        y: str | pl.Expr | None = None,
        by: str | Sequence[str] | None = None,
        reference: str | None = None,
        **subplots_keywords: Any,
    ) -> None:
        """Plots the benchmark report in a figure.

        Args:
            x: The x-axis of the plots, as a benchmark report column name or expression.
            y: The y-axis of the plots, as a benchmark report column name or expression.
            by: Key to divide into several subplots.
            reference: Legend label of the reference method for speedup comparison.
                When specified, a second column of subplots shows the speedup
                (reference_time / method_time) for each method. Values > 1 mean
                faster than the reference.
            **subplots_keywords: Passed through to the plotter's ``show`` method.
        """
        plotter = self._create_plotter(x=x, y=y, by=by, reference=reference)
        plotter.show(**subplots_keywords)

    def write_plot(
        self,
        path: Path | str,
        *,
        x: str | pl.Expr | None = None,
        y: str | pl.Expr | None = None,
        by: str | Sequence[str] | None = None,
        reference: str | None = None,
        **subplots_keywords: Any,
    ) -> None:
        """Saves the benchmark plot in a file.

        Args:
            path: The path of the plot file.
            x: The x-axis of the plots, as a benchmark report column name or expression.
            y: The y-axis of the plots, as a benchmark report column name or expression.
            by: Key to divide into several subplots.
            reference: Legend label of the reference method for speedup comparison.
                When specified, a second column of subplots shows the speedup
                (reference_time / method_time) for each method. Values > 1 mean
                faster than the reference.
            **subplots_keywords: Passed through to the plotter's ``save`` method.
        """
        plotter = self._create_plotter(x=x, y=y, by=by, reference=reference)
        plotter.save(path, **subplots_keywords)

    def _create_plotter(
        self,
        *,
        x: str | pl.Expr | None = None,
        y: str | pl.Expr | None = None,
        by: str | Sequence[str] | None = None,
        reference: str | None = None,
    ) -> BenchmarkPlotter:
        """Create a BenchmarkPlotter instance."""
        return BenchmarkPlotter(self.to_dataframe(), x=x, y=y, by=by, reference=reference)
def read_benchmark(path: Path | str, *, verbose: bool = False) -> Benchmark:
    """Build a Benchmark from a report previously written to a CSV or Parquet file.

    Reading is delegated to ``BenchmarkReader``, which returns the data together with the
    stored metadata (repeat, min_duration_per_repeat). The class defaults
    ``Benchmark.DEFAULT_REPEAT`` and ``Benchmark.DEFAULT_MIN_DURATION_PER_REPEAT`` are handed
    to the reader (presumably as fallbacks when the file carries no metadata; confirm
    against ``BenchmarkReader``).

    Args:
        path: The path to the CSV or Parquet file.
        verbose: Forwarded unchanged as the ``verbose`` flag of the returned Benchmark.

    Returns:
        A Benchmark instance with the data and metadata from the file.

    Raises:
        ValueError: If the file extension is not .csv or .parquet (raised by the reader
            according to the original documentation; not visible in this module).
    """
    fallbacks = (Benchmark.DEFAULT_REPEAT, Benchmark.DEFAULT_MIN_DURATION_PER_REPEAT)
    df, repeat, min_duration_per_repeat = BenchmarkReader(*fallbacks).read(path)

    settings = {
        'repeat': repeat,
        'min_duration_per_repeat': min_duration_per_repeat,
        'verbose': verbose,
    }
    bench = Benchmark(**settings)
    # Restore the stored measurements, one dict per row.
    bench._report = df.to_dicts()
    return bench