JAX Examples

Basic JAX Benchmarking

ZeroPyBench automatically detects JAX arrays in your code and wraps the benchmarked expression in a JIT-compiled function.

Note

For JAX code, ZeroPyBench reports additional measurements, described in the section "JAX-Specific Report Fields" below.

import jax
import jax.numpy as jnp
import jax.random as jr

from zeropybench import Benchmark

bench = Benchmark(repeat=20)
x = jnp.ones(1000)
y = jnp.ones(1000)

with bench():
    x + y
5.880 µs ± 0.81% (median of 20 runs, 50000 loops each)

Verbose Mode

Use verbose=True to see the setup code (JIT-compiled function) and the benchmarked code:

bench = Benchmark(verbose=True)

with bench():
    x + y
Setup code:
    @jax.jit
    def __bench_func(x, y):
        return x + y
Benchmarked code:
    __bench_func(x, y).block_until_ready()
5.932 µs ± 0.32% (median of 7 runs, 50000 loops each)
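For comparison, the setup and benchmarked code printed above can be approximated by hand with the standard timeit module. This is only a sketch of the idea, not ZeroPyBench's implementation: the name bench_func, the fixed loop count, and the explicit warm-up call are assumptions made for illustration.

```python
import timeit

import jax
import jax.numpy as jnp

x = jnp.ones(1000)
y = jnp.ones(1000)


@jax.jit
def bench_func(x, y):
    return x + y


# Warm-up: the first call traces and compiles, so keep it out of the timing.
bench_func(x, y).block_until_ready()

loops = 1000  # arbitrary choice for this sketch
totals = timeit.repeat(
    lambda: bench_func(x, y).block_until_ready(),  # wait for the device result
    number=loops,
    repeat=7,
)
per_call = sorted(t / loops for t in totals)
median_time = per_call[len(per_call) // 2]
print(f'{median_time * 1e6:.3f} µs (median of 7 runs, {loops} loops each)')
```

Without block_until_ready, the timed call would only measure how long JAX takes to dispatch the computation, because JAX executes asynchronously.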

For functions returning a PyTree (e.g., a tuple of Arrays), the slightly slower jax.block_until_ready is used instead.

from dataclasses import dataclass


@jax.tree_util.register_dataclass
@dataclass
class Vector:
    x: jax.Array
    y: jax.Array

    @classmethod
    def normal(cls, key, shape):
        key_x, key_y = jr.split(key)
        return Vector(jr.normal(key_x, shape), jr.normal(key_y, shape))

    def __add__(self, other):
        if not isinstance(other, Vector):
            return NotImplemented
        return Vector(self.x + other.x, self.y + other.y)


key1, key2 = jr.split(jr.key(42))
shape = (100,)
v1 = Vector.normal(key1, shape)
v2 = Vector.normal(key2, shape)

bench = Benchmark(verbose=True)
with bench():
    v1 + v2
Setup code:
    @jax.jit
    def __bench_func(v1, v2):
        return v1 + v2
Benchmarked code:
    jax.block_until_ready(__bench_func(v1, v2))
9.488 µs ± 0.42% (median of 7 runs, 50000 loops each)
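The difference between the two synchronization calls can be seen with a small function that returns a tuple. The function add_and_mul below is a hypothetical example, not part of ZeroPyBench.

```python
import jax
import jax.numpy as jnp


@jax.jit
def add_and_mul(x, y):
    return x + y, x * y


x = jnp.ones(1000)
y = jnp.ones(1000)

out = add_and_mul(x, y)

# A plain tuple has no block_until_ready method of its own...
has_method = hasattr(out, 'block_until_ready')

# ...whereas jax.block_until_ready waits on every array leaf of the PyTree
# and returns a PyTree with the same structure.
ready = jax.block_until_ready(out)
print(has_method, type(ready).__name__, len(ready))
```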

Multiple Bare Expressions

When the benchmarked code contains several bare expressions (without assignment), they are each captured into temporary variables so that block_until_ready can synchronize all computations. Calls to functions known to return None (such as print or functions annotated with -> None) are left as-is:

bench = Benchmark(verbose=True)

with bench():
    print(x)
    x + y
    x * y
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
JitTracer(float32[1000])
Setup code:
    @jax.jit
    def __bench_func(x, y):
        print(x)
        __expr1 = x + y
        __expr2 = x * y
        return (__expr1, __expr2)
Benchmarked code:
    jax.block_until_ready(__bench_func(x, y))
8.547 µs ± 1.17% (median of 7 runs, 50000 loops each)
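A rewrite of this kind can be sketched with the standard ast module. The helper capture_bare_expressions and its returns_none parameter are hypothetical names chosen for this illustration; how ZeroPyBench actually performs the transformation (including how it recognizes functions annotated with -> None) is not shown here.

```python
import ast


def capture_bare_expressions(source, returns_none=('print',)):
    """Assign each bare expression to a temporary and return all of them."""
    lines, temporaries = [], []
    for stmt in ast.parse(source).body:
        is_bare = isinstance(stmt, ast.Expr)
        # Calls to names listed in returns_none are kept unchanged.
        is_none_call = (
            is_bare
            and isinstance(stmt.value, ast.Call)
            and isinstance(stmt.value.func, ast.Name)
            and stmt.value.func.id in returns_none
        )
        if is_bare and not is_none_call:
            name = f'__expr{len(temporaries) + 1}'
            temporaries.append(name)
            lines.append(f'{name} = {ast.unparse(stmt.value)}')
        else:
            lines.append(ast.unparse(stmt))
    if temporaries:
        # Returning every temporary lets the caller block on all results.
        lines.append(f"return ({', '.join(temporaries)})")
    return '\n'.join(lines)


rewritten = capture_bare_expressions('print(x)\nx + y\nx * y')
print(rewritten)
```

Applied to the three statements of the example, the sketch produces the same function body as the setup code printed above.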

JAX-Specific Report Fields

When JAX code is detected, the benchmark report includes additional fields:

  • first_execution_time: the execution time of the code inside the context manager, which usually corresponds to the non-jitted version of the benchmarked code.

  • compilation_time: the lowering and compilation time.

  • generated_code_size: the total size of the generated machine code in bytes, including embedded constants.

  • temp_size: the size of the preallocated temporary buffer arena in bytes. This accounts for intermediate buffers needed during execution, excluding input arguments, outputs, and constants.

Warning

The generated_code_size and temp_size values may not be reliable on the CPU backend.
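Quantities of this kind can be obtained by hand from JAX's ahead-of-time lowering and compilation API. The sketch below assumes nothing about how ZeroPyBench collects them; the function add and the variable names are illustrative, and memory_analysis() may return None on backends that do not report these statistics.

```python
import time

import jax
import jax.numpy as jnp


def add(x, y):
    return x + y


x = jnp.ones(1000)
y = jnp.ones(1000)

start = time.perf_counter()
lowered = jax.jit(add).lower(x, y)  # trace and lower for these argument shapes
compiled = lowered.compile()        # compile for the current backend
compilation_time = time.perf_counter() - start
print(f'lowering + compilation: {compilation_time * 1e3:.3f} ms')

stats = compiled.memory_analysis()
if stats is not None:
    print('generated code size (B):', stats.generated_code_size_in_bytes)
    print('temp size (B):', stats.temp_size_in_bytes)

# The compiled object is callable with arguments of the same shapes and dtypes.
result = compiled(x, y)
```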

print(bench)
run = bench[0]
print('Execution times [μs]:', sorted(_ * 1e6 for _ in run['execution_times']))
┌───┬────────────────┬──────────┬────────────────┬────────────────┬────────────────┬───────────────┐
│   ┆ median_executi ┆ ± (%)    ┆ first_executio ┆ compilation_ti ┆ generated_code ┆ temp_size (B) │
│   ┆ on_time (µs)   ┆          ┆ n_time (µs)    ┆ me (µs)        ┆ _size (B)      ┆               │
╞═══╪════════════════╪══════════╪════════════════╪════════════════╪════════════════╪═══════════════╡
│ 0 ┆ 8.547132       ┆ 1.173406 ┆ 57_029.436     ┆ 29_734.152999  ┆ 0              ┆ 0             │
└───┴────────────────┴──────────┴────────────────┴────────────────┴────────────────┴───────────────┘
Execution times [μs]: [8.479386939980031, 8.479485339994426, 8.500881360014318, 8.547131759987678, 8.550357340027404, 8.624570760002825, 8.70624356000917]
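The summary line can be recomputed from these raw timings. The median matches median_execution_time exactly, and the ± (%) column is numerically consistent with a scaled median absolute deviation (MAD × 1.4826) expressed relative to the median. That interpretation is inferred from the printed values only, not taken from a stated definition.

```python
from statistics import median

# Per-loop execution times in µs, copied from the output above.
times_us = [
    8.479386939980031,
    8.479485339994426,
    8.500881360014318,
    8.547131759987678,
    8.550357340027404,
    8.624570760002825,
    8.70624356000917,
]

med = median(times_us)
mad = median(abs(t - med) for t in times_us)  # median absolute deviation
spread_pct = 100 * 1.4826 * mad / med         # assumed scaling, see above

print(f'{med:.3f} µs ± {spread_pct:.2f}%')
```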

Visualizing the HLO

The hlo field contains the StableHLO representation of the compiled function. You can visualize it with the visu-hlo library:

from visu_hlo import show

show(run['hlo'])
[Figure: HLO graph of the benchmarked function, rendered by visu-hlo]

Comparing Broadcasting Strategies

Benchmark two broadcasting orders over several array sizes to compare their performance:

bench = Benchmark()

for N in [100, 1000, 10000]:
    x = jnp.ones(N)
    y = jnp.ones(1000)

    with bench(method='broadcast right', N=N):
        x[:, None] + y[None, :]

    with bench(method='broadcast left', N=N):
        x[None, :] + y[:, None]
method=broadcast right, N=100: 13.746 µs ± 2.43% (median of 7 runs, 20000 loops each)
method=broadcast left, N=100: 14.150 µs ± 2.89% (median of 7 runs, 20000 loops each)
method=broadcast right, N=1000: 50.810 µs ± 1.57% (median of 7 runs, 5000 loops each)
method=broadcast left, N=1000: 50.451 µs ± 1.94% (median of 7 runs, 5000 loops each)
method=broadcast right, N=10000: 12.588 ms ± 1.70% (median of 7 runs, 20 loops each)
method=broadcast left, N=10000: 12.741 ms ± 2.50% (median of 7 runs, 20 loops each)
print(bench)
┌───┬────────────┬────────┬────────────┬──────────┬────────────┬───────────┬───────────┬───────────┐
│   ┆ method     ┆ N      ┆ median_exe ┆ ± (%)    ┆ first_exec ┆ compilati ┆ generated ┆ temp_size │
│   ┆            ┆        ┆ cution_tim ┆          ┆ ution_time ┆ on_time   ┆ _code_siz ┆ (B)       │
│   ┆            ┆        ┆ e (ms)     ┆          ┆ (ms)       ┆ (ms)      ┆ e (B)     ┆           │
╞═══╪════════════╪════════╪════════════╪══════════╪════════════╪═══════════╪═══════════╪═══════════╡
│ 0 ┆ broadcast  ┆ 100    ┆ 0.013746   ┆ 2.428567 ┆ 34.956574  ┆ 20.968133 ┆ 0         ┆ 0         │
│   ┆ right      ┆        ┆            ┆          ┆            ┆           ┆           ┆           │
│ 1 ┆ broadcast  ┆ 100    ┆ 0.01415    ┆ 2.893307 ┆ 34.258964  ┆ 19.529872 ┆ 0         ┆ 0         │
│   ┆ left       ┆        ┆            ┆          ┆            ┆           ┆           ┆           │
│ 2 ┆ broadcast  ┆ 1_000  ┆ 0.05081    ┆ 1.568993 ┆ 36.311284  ┆ 22.613573 ┆ 0         ┆ 0         │
│   ┆ right      ┆        ┆            ┆          ┆            ┆           ┆           ┆           │
│ 3 ┆ broadcast  ┆ 1_000  ┆ 0.050451   ┆ 1.939443 ┆ 37.465634  ┆ 22.188032 ┆ 0         ┆ 0         │
│   ┆ left       ┆        ┆            ┆          ┆            ┆           ┆           ┆           │
│ 4 ┆ broadcast  ┆ 10_000 ┆ 12.587709  ┆ 1.702945 ┆ 37.154414  ┆ 31.106143 ┆ 0         ┆ 0         │
│   ┆ right      ┆        ┆            ┆          ┆            ┆           ┆           ┆           │
│ 5 ┆ broadcast  ┆ 10_000 ┆ 12.740685  ┆ 2.498716 ┆ 39.017844  ┆ 30.673434 ┆ 0         ┆ 0         │
│   ┆ left       ┆        ┆            ┆          ┆            ┆           ┆           ┆           │
└───┴────────────┴────────┴────────────┴──────────┴────────────┴───────────┴───────────┴───────────┘
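The two expressions compute the same values with transposed layouts, which a quick shape check makes explicit. The snippet uses N=100 and only illustrates the broadcasting rules; it does not time anything.

```python
import jax.numpy as jnp

N = 100
x = jnp.ones(N)
y = jnp.ones(1000)

right = x[:, None] + y[None, :]  # (N, 1) + (1, 1000) -> (N, 1000)
left = x[None, :] + y[:, None]   # (1, N) + (1000, 1) -> (1000, N)

print(right.shape, left.shape)
print(bool(jnp.array_equal(right, left.T)))

# Size in bytes of the materialized result.
output_bytes = right.size * right.dtype.itemsize
print(output_bytes)
```

With float32 arrays, the result for N=10000 holds 10000 × 1000 × 4 bytes = 40 MB, so the jump from microseconds to milliseconds in the table is at least partly a matter of how much memory has to be written.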

Benchmarking JIT-compiled Functions

ZeroPyBench handles both regular and JIT-compiled functions:

Note

When benchmarking an already JIT-compiled function, ZeroPyBench reuses it directly without re-wrapping, preserving any static_argnums or other JIT options you specified.

import jax


@jax.jit
def matmul(a, b):
    return a @ b


bench = Benchmark(verbose=True)

for N in [256, 512]:
    a = jnp.ones((N, N))
    b = jnp.ones((N, N))

    with bench(operation='matmul', N=N):
        matmul(a, b)
Setup code:
    __bench_func = matmul
Benchmarked code:
    __bench_func(a, b).block_until_ready()
operation=matmul, N=256: 202.963 µs ± 1.32% (median of 7 runs, 1000 loops each)
Setup code:
    __bench_func = matmul
Benchmarked code:
    __bench_func(a, b).block_until_ready()
operation=matmul, N=512: 1.234 ms ± 0.45% (median of 7 runs, 200 loops each)
print(bench)
┌───┬───────────┬─────┬─────────────┬──────────┬────────────┬────────────┬────────────┬────────────┐
│   ┆ operation ┆ N   ┆ median_exec ┆ ± (%)    ┆ first_exec ┆ compilatio ┆ generated_ ┆ temp_size  │
│   ┆           ┆     ┆ ution_time  ┆          ┆ ution_time ┆ n_time     ┆ code_size  ┆ (B)        │
│   ┆           ┆     ┆ (ms)        ┆          ┆ (ms)       ┆ (ms)       ┆ (B)        ┆            │
╞═══╪═══════════╪═════╪═════════════╪══════════╪════════════╪════════════╪════════════╪════════════╡
│ 0 ┆ matmul    ┆ 256 ┆ 0.202963    ┆ 1.315189 ┆ 8.609431   ┆ 0.11784    ┆ 0          ┆ 0          │
│ 1 ┆ matmul    ┆ 512 ┆ 1.233717    ┆ 0.45166  ┆ 8.463781   ┆ 0.11857    ┆ 0          ┆ 0          │
└───┴───────────┴─────┴─────────────┴──────────┴────────────┴────────────┴────────────┴────────────┘

Visualizing the Matrix Multiplication HLO

Display the HLO computational graph for the second run (N=512):

show(bench[1]['hlo'])
[Figure: HLO graph of matmul for the second run, rendered by visu-hlo]

Plotting Results

bench.plot()
[Figure: plot of the benchmark results produced by bench.plot()]