{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# JAX Examples" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Basic JAX Benchmarking\n", "\n", "ZeroPyBench automatically detects JAX arrays in your code and wraps the benchmarked expression in a JIT-compiled function.\n", "\n", ":::{note}\n", "For JAX code, ZeroPyBench provides additional measurements (see below).\n", ":::" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import jax.random as jr\n", "\n", "from zeropybench import Benchmark\n", "\n", "bench = Benchmark(repeat=20)\n", "x = jnp.ones(1000)\n", "y = jnp.ones(1000)\n", "\n", "with bench():\n", "    x + y" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Verbose Mode\n", "\n", "Use `verbose=True` to see the setup code (the JIT-compiled function) and the benchmarked code:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bench = Benchmark(verbose=True)\n", "\n", "with bench():\n", "    x + y" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For functions returning a PyTree (e.g., a tuple of arrays), the slightly slower `jax.block_until_ready` function is used instead." 
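] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To illustrate the difference outside of ZeroPyBench, here is a minimal plain-JAX sketch (not the code ZeroPyBench generates): a single array is synchronized with its `block_until_ready()` method, while a PyTree such as a tuple of arrays is passed to the `jax.block_until_ready` function, which waits for every leaf of the tree." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "# Minimal sketch in plain JAX, not ZeroPyBench's generated code.\n", "# JAX dispatches work asynchronously, so timing code must wait for results.\n", "a = jnp.arange(4.0)\n", "b = jnp.arange(4.0)\n", "\n", "# Single array: use its block_until_ready() method.\n", "single = (a + b).block_until_ready()\n", "\n", "# PyTree (here a tuple of arrays): jax.block_until_ready blocks on every\n", "# leaf and returns the same tree.\n", "pair = jax.block_until_ready((a + b, a * b))\n", "\n", "print(single)\n", "print(pair)"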
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from dataclasses import dataclass\n", "\n", "\n", "@jax.tree_util.register_dataclass\n", "@dataclass\n", "class Vector:\n", "    x: jax.Array\n", "    y: jax.Array\n", "\n", "    @classmethod\n", "    def normal(cls, key, shape):\n", "        key_x, key_y = jr.split(key)\n", "        return cls(jr.normal(key_x, shape), jr.normal(key_y, shape))\n", "\n", "    def __add__(self, other):\n", "        if not isinstance(other, Vector):\n", "            return NotImplemented\n", "        return Vector(self.x + other.x, self.y + other.y)\n", "\n", "\n", "key1, key2 = jr.split(jr.key(42))\n", "shape = (100,)\n", "v1 = Vector.normal(key1, shape)\n", "v2 = Vector.normal(key2, shape)\n", "\n", "bench = Benchmark(verbose=True)\n", "with bench():\n", "    v1 + v2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Multiple Bare Expressions\n", "\n", "When the benchmarked code contains several bare expressions (without assignment), each one is captured into a temporary variable so that `block_until_ready` can synchronize all computations. Calls to functions known to return `None` (such as `print` or functions annotated with `-> None`) are left as-is:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bench = Benchmark(verbose=True)\n", "\n", "with bench():\n", "    print(x)\n", "    x + y\n", "    x * y" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## JAX-Specific Report Fields\n", "\n", "When JAX code is detected, the benchmark report includes the following additional fields:\n", "\n", "- ``first_execution_time``: the execution time of the code inside the context manager, which usually corresponds to the non-jitted version of the benchmarked code.\n", "- ``compilation_time``: the time spent lowering and compiling the JIT-compiled function.\n", "- ``generated_code_size``: the total size of the generated machine code in bytes, including embedded constants.\n", "- ``temp_size``: the size of the preallocated temporary buffer arena in bytes. This accounts for intermediate buffers needed during execution, excluding input arguments, outputs, and constants.\n", "\n", ":::{warning}\n", "The ``generated_code_size`` and ``temp_size`` values may not be reliable on the CPU backend.\n", ":::" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(bench)\n", "run = bench[0]\n", "print('Execution times [μs]:', sorted(t * 1e6 for t in run['execution_times']))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualizing the HLO\n", "\n", "The `hlo` field contains the StableHLO representation of the compiled function. You can visualize it with the `visu-hlo` library:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from visu_hlo import show\n", "\n", "show(run['hlo'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Comparing Broadcasting Strategies\n", "\n", "Benchmark two broadcasting layouts of the same outer sum to compare their performance:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bench = Benchmark()\n", "\n", "for N in [100, 1000, 10000]:\n", "    x = jnp.ones(N)\n", "    y = jnp.ones(1000)\n", "\n", "    with bench(method='broadcast right', N=N):\n", "        x[:, None] + y[None, :]\n", "\n", "    with bench(method='broadcast left', N=N):\n", "        x[None, :] + y[:, None]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(bench)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Benchmarking JIT-compiled Functions\n", "\n", "ZeroPyBench handles both regular and JIT-compiled functions:\n", "\n", ":::{note}\n", "When benchmarking an already JIT-compiled function, ZeroPyBench reuses it directly without re-wrapping, preserving any `static_argnums` or other JIT options you specified.\n", ":::" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "\n", "\n", "@jax.jit\n", "def matmul(a, b):\n", "    return a @ b\n", "\n", "\n", "bench = Benchmark(verbose=True)\n", "\n", "for N in [256, 512]:\n", "    a = jnp.ones((N, N))\n", "    b = jnp.ones((N, N))\n", "\n", "    with bench(operation='matmul', N=N):\n", "        matmul(a, b)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(bench)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualizing the Matrix Multiplication HLO\n", "\n", "Display the HLO computational graph of the second run (N=512):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": 
[], "source": [ "show(bench[1]['hlo'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plotting Results" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bench.plot()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 4 }