import heapq
import os
import subprocess
import sys
from mpmath import mp
import random

# Working precision of mpmath, in binary digits (mantissa bits). The Peres-tree
# redundancy curves evaluate eval_peres_tree on mp.mpf values at this precision.
mp.prec = 256

def make_peres_tree(n):
    """Construct a Peres extractor tree with n nodes.

    Nodes are added greedily, one at a time, always expanding the open node
    with the smallest ``(priority, v_count, path)`` key:

    - priority is the distance from the root, where every U step counts as 1
      and every V step counts as 2; it corresponds to -log2 of the entropy the
      node receives if p = 0.5.
    - v_count (number of V steps) is the tie-breaker, effectively optimizing
      for p = 0.5 + epsilon.
    - path (the string of "U"/"V" steps from the root) is unique per node, so
      it makes the ordering total and the resulting tree deterministic.

    A leaf is first expanded with its U child; the parent is then re-queued
    with its priority bumped by one so that it can later receive its V child,
    after which it is full and leaves the queue.

    Args:
        n: Number of nodes to build. Values below 1 still yield the root alone.

    Returns:
        The root as a nested tuple ``(path, children, order)``, where
        ``children`` is a list of nodes of the same shape (U child first, then
        V child) and ``order`` is the node's 1-based insertion index.
    """
    # Heap entries are mutable lists [priority, v_count, path, children, order]
    # so a popped node can be updated in place and pushed back. heapq compares
    # them lexicographically; the unique path string settles every tie before
    # the children list would ever be compared.
    root = [1, 0, "", [], 1]
    heap = [root]
    order = 1
    for _ in range(1, n):
        node = heapq.heappop(heap)
        order += 1
        priority, v_count, path, children, _ = node
        if not children:
            # First expansion: U child one step further away. Re-queue the
            # parent so its V child ends up two steps from the original spot.
            child = [priority + 1, v_count, path + "U", [], order]
            children.append(child)
            node[0] += 1
            heapq.heappush(heap, node)
        else:
            # Second expansion: V child; the parent is now full.
            child = [priority + 1, v_count + 1, path + "V", [], order]
            children.append(child)
        heapq.heappush(heap, child)

    def freeze(entry):
        """Convert a heap entry into an immutable (path, children, order) node."""
        return (entry[2], [freeze(c) for c in entry[3]], entry[4])

    return freeze(root)

def eval_peres_tree(tree, p):
    """Return the expected number of bits per input toss a Peres tree extracts.

    ``tree`` is a ``(path, children, order)`` node as produced by
    make_peres_tree. Its first child, if present, is treated as the U branch
    and its second child, if present, as the V branch. ``p`` is the
    probability of one toss outcome and may be a float or an mpmath number.

    The node itself contributes the Von Neumann rate p*q. The U branch is
    evaluated recursively at p^2 + q^2 and weighted by 1/2. The V branch is
    evaluated at p^2 / (p^2 + q^2) and weighted by (p^2 + q^2) / 2.
    """
    children = tree[1]
    q = 1 - p
    p_sq = p * p
    q_sq = q * q
    same = p_sq + q_sq
    rate = p * q
    if len(children) > 0:
        rate += eval_peres_tree(children[0], same) / 2
    if len(children) > 1:
        rate += same * eval_peres_tree(children[1], p_sq / same) / 2
    return rate

def _get_extractor():
    """Return path to extractor binary, or exit with an error."""
    script_dir = os.path.dirname(os.path.abspath(__file__))
    binary = os.path.join(script_dir, "main")
    if not os.path.exists(binary):
        sys.exit(f"error: {binary} not found.\n"
                 f"Compile with: g++ -O3 -flto -march=native -std=c++20 -o main extractor.cpp main.cpp")
    return binary

def eval_sim(batch, bitwidth, carry, p_values):
    """Evaluate expected bits per toss via the extractor binary.

    Runs ``<binary> model_rate <batch> <bitwidth> <carry>`` with one
    probability per line on stdin, and parses one float per probability from
    stdout.

    Args:
        batch, bitwidth, carry: Passed to the binary as ``str(...)``; callers
            use ints as well as strings such as "inf" or "tuned".
        p_values: Iterable of probabilities. Each is written with 17
            significant digits, enough for a float64 to round-trip exactly.

    Returns:
        A list of floats, one per entry of ``p_values``, in the same order.

    Raises:
        subprocess.CalledProcessError: If the binary exits with non-zero status.
        RuntimeError: If the binary does not print exactly one value per input.
    """
    binary = _get_extractor()
    # Materialize once: both the contents and the count are needed below.
    p_list = list(p_values)
    input_data = "\n".join(f"{p:.17g}" for p in p_list) + "\n"
    result = subprocess.run(
        [binary, "model_rate", str(batch), str(bitwidth), str(carry)],
        input=input_data, capture_output=True, text=True, check=True
    )
    # split() tolerates empty output, trailing blank lines and CRLF endings.
    values = [float(tok) for tok in result.stdout.split()]
    if len(values) != len(p_list):
        # Fail here with a clear message rather than later in plotting,
        # where a length mismatch against the p grid would be hard to trace.
        raise RuntimeError(
            f"{binary} returned {len(values)} values for {len(p_list)} inputs")
    return values

# Peres trees of the sizes used by the plots in this file, built once at import.
(TREE1, TREE4, TREE12, TREE16,
 TREE64, TREE80, TREE256, TREE600) = (
    make_peres_tree(size) for size in (1, 4, 12, 16, 64, 80, 256, 600))

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
from collections import deque

def label_y_at_p_half(ax):
    """Write each plotted line's value at p = 1/2 next to the y-axis.

    The value shown is the last y sample of each line, so the plot is assumed
    to have its x data ascending and ending at p = 1/2. The label is anchored
    at axes-fraction x = 0 (the left edge). This is p = 1/2 because the
    redundancy plots set an inverted x-range with 1/2 on the left. Lines
    without data are skipped, and each label takes its line's colour.
    """
    for curve in ax.get_lines():
        ys = curve.get_ydata()
        if len(ys) > 0:
            last = ys[-1]
            ax.annotate(
                f'{last:.1f}',
                xy=(0, last),
                xycoords=('axes fraction', 'data'),
                xytext=(-4, 0),
                textcoords='offset points',
                va='center',
                ha='right',
                color=curve.get_color(),
                fontsize=8,
                clip_on=False,
            )


def draw_tree_diagram(tree, filename="peres_tree14.png"):
    """Draw a diagram of a Peres extractor tree with insertion-order labels.

    Args:
        tree: Root node as returned by make_peres_tree, i.e. a nested
            ``(path, children, order)`` tuple.
        filename: Output image path.

    Each node is a rounded box showing its insertion order. The edge to a
    node's first child is labelled U_k and the edge to its second child V_k,
    where k is the parent's order. A "Toss" input line enters the root from
    above. Leaves take one unit of width, an internal node is centred over
    the combined width of its children, and each level sits one unit lower.
    """
    # order -> list of child orders
    node_children = {}
    root_order = tree[2]

    def collect(node):
        _path, children, order = node
        node_children[order] = [c[2] for c in children]
        for c in children:
            collect(c)
    collect(tree)

    # Leaf count of every subtree, memoised so the layout visits each node a
    # constant number of times instead of re-walking subtrees at every level.
    widths = {}

    def subtree_width(lbl):
        if lbl not in widths:
            ch = node_children[lbl]
            widths[lbl] = sum(subtree_width(c) for c in ch) if ch else 1
        return widths[lbl]

    def layout(lbl, x, y, positions):
        positions[lbl] = (x, y)
        ch = node_children[lbl]
        if not ch:
            return
        cx = x - subtree_width(lbl) / 2
        for c in ch:
            w = subtree_width(c)
            layout(c, cx + w / 2, y - 1, positions)
            cx += w

    positions = {}
    layout(root_order, 0, 0, positions)

    # Draw
    fig, ax = plt.subplots(1, 1, figsize=(8, 4.5))
    ax.set_aspect("equal")
    ax.set_axis_off()
    n_nodes = len(positions)
    ax.set_title(f"{n_nodes}-node Peres extractor", fontsize=13, fontweight="bold", pad=4)

    box_w, box_h = 0.6, 0.5
    box_pad = 0.05  # must match the pad in boxstyle "round,pad=0.05" below
    node_color = "#4a7ab5"
    edge_color = "#888888"
    label_color = "#333333"
    for lbl, (x, y) in positions.items():
        rect = patches.FancyBboxPatch(
            (x - box_w / 2, y - box_h / 2), box_w, box_h,
            boxstyle="round,pad=0.05", facecolor="#dce6f1",
            edgecolor=node_color, linewidth=1.0
        )
        ax.add_patch(rect)
        ax.text(x, y, str(lbl), ha="center", va="center", fontsize=11,
                fontweight="bold", color=node_color)

        for idx, c in enumerate(node_children[lbl]):
            cx, cy = positions[c]
            # Line from the bottom of the parent's rounded box to the top of
            # the child's. The rounded outline extends box_pad beyond the
            # nominal box_w x box_h rectangle.
            ax.annotate(
                "", xy=(cx, cy + box_h / 2 + box_pad), xytext=(x, y - box_h / 2 - box_pad),
                arrowprops=dict(arrowstyle="-",
                                color=edge_color, lw=0.9,
                                shrinkA=0, shrinkB=0)
            )
            # Edge label, nudged left for the U child and right for the V child.
            mid_x = (x + cx) / 2
            mid_y = (y - box_h / 2 + cy + box_h / 2) / 2
            prefix = "U" if idx == 0 else "V"
            offset_x = -0.25 if idx == 0 else 0.25
            ax.text(mid_x + offset_x, mid_y, f"$\\mathbf{{{prefix}_{{{lbl}}}}}$",
                    ha="center", va="center", fontsize=11, color=label_color)

    # "Toss" input line entering the root from above.
    rx, ry = positions[root_order]
    root_top = ry + box_h / 2 + box_pad
    toss_top = root_top + 0.5
    ax.annotate(
        "", xy=(rx, root_top), xytext=(rx, toss_top),
        arrowprops=dict(arrowstyle="-", color=edge_color, lw=0.9,
                        shrinkA=0, shrinkB=0)
    )
    ax.text(rx + 0.25, (root_top + toss_top) / 2,
            "$\\mathbf{Toss}$", ha="left", va="center", fontsize=11, color=label_color)

    # Fit axes around all node centres, with margin for the boxes and Toss line.
    all_x = [pos[0] for pos in positions.values()]
    all_y = [pos[1] for pos in positions.values()]
    ax.set_xlim(min(all_x) - 0.6, max(all_x) + 0.6)
    ax.set_ylim(min(all_y) - 0.6, max(all_y) + 0.8)

    fig.savefig(filename, dpi=180, bbox_inches="tight", pad_inches=0.1)
    plt.close(fig)
    print(f"  Saved {filename}")


def draw_von_neumann_bits(filename):
    """Save a linear-scale plot of Shannon entropy next to the Von Neumann
    extraction rate p*(1-p), both in bits per toss, over 0.001 <= p <= 0.999."""

    grid = np.linspace(0.001, 0.999, 500)

    def entropy(x):
        return -(x * np.log2(x) + (1 - x) * np.log2(1 - x))

    curves = (
        ([entropy(x) for x in grid], "Entropy"),
        ([x * (1 - x) for x in grid], "Extracted (Von Neumann)"),
    )

    fig, ax = plt.subplots(1, 1, figsize=(9, 5))
    for ys, name in curves:
        ax.plot(grid, ys, label=name, linewidth=1.5)
    ax.set_xlabel("Probability $p$", fontsize=12)
    ax.set_ylabel("Bits per toss", fontsize=12)
    ax.set_title("Shannon entropy vs. Von Neumann extraction rate", fontsize=14, fontweight="bold")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1.05)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(filename, dpi=180, bbox_inches="tight")
    plt.close(fig)
    print(f"  Saved {filename}")


def draw_von_neumann_redundancy(filename):
    """Save a linear-scale plot of the share of entropy (in %) that the Von
    Neumann extractor fails to extract, over 0.001 <= p <= 0.999."""

    grid = np.linspace(0.001, 0.999, 500)

    def percent_lost(x):
        entropy = -(x * np.log2(x) + (1 - x) * np.log2(1 - x))
        if entropy > 0:
            return (1 - x * (1 - x) / entropy) * 100
        return 100.0

    redundancy = [percent_lost(x) for x in grid]

    fig, ax = plt.subplots(1, 1, figsize=(9, 5))
    ax.plot(grid, redundancy, linewidth=1.5, color="C1")
    ax.set_xlabel("Probability $p$", fontsize=12)
    ax.set_ylabel("Redundancy (%)", fontsize=12)
    ax.set_title("Von Neumann extractor redundancy", fontsize=14, fontweight="bold")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 100)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(filename, dpi=180, bbox_inches="tight")
    plt.close(fig)
    print(f"  Saved {filename}")


def draw_peres_redundancy(filename):
    """Save redundancy-vs-p curves for Peres trees of 1, 4, 16, 64 and 256 nodes.

    Each tree's rate is evaluated in mpmath. Redundancy is the percentage of
    Shannon entropy not extracted. The p-axis is log2-scaled and runs from 1/2
    (left) down to 1/512 (right).
    """

    # 500 probabilities evenly spaced in log2 between 2**-9 and 2**-1.
    ps = 2.0 ** np.linspace(-9, -1, 500)

    def redundancy_curve(tree):
        out = []
        for p_float in ps:
            p = mp.mpf(p_float)
            h = -(p * mp.log(p, 2) + (1 - p) * mp.log(1 - p, 2))
            out.append((1 - float(eval_peres_tree(tree, p) / h)) * 100 if h > 0 else 100.0)
        return out

    fig, ax = plt.subplots(1, 1, figsize=(9, 5))

    for tree, label in ((TREE1, "1 node (VN)"),
                        (TREE4, "4 nodes"),
                        (TREE16, "16 nodes"),
                        (TREE64, "64 nodes"),
                        (TREE256, "256 nodes")):
        ax.plot(ps, redundancy_curve(tree), label=label, linewidth=1.5)

    ax.set_xscale("log", base=2)
    ax.set_xlabel("Probability $p$", fontsize=12)
    ax.set_ylabel("Redundancy (%)", fontsize=12)
    ax.set_title("Redundancy of Peres extractors", fontsize=14, fontweight="bold")
    ax.set_xlim(2**-1, 2**-9)
    ax.set_ylim(0, 100)
    ax.legend(fontsize=10, loc="lower right")
    ax.grid(True, alpha=0.3)
    label_y_at_p_half(ax)

    fig.tight_layout()
    fig.savefig(filename, dpi=180, bbox_inches="tight")
    plt.close(fig)
    print(f"  Saved {filename}")


def draw_elias_redundancy(filename):
    """Save redundancy-vs-p curves of the Elias extractor for several batch
    sizes.

    The rates come from the extractor binary with bitwidth "inf" and carry 0.
    The p-axis is log2-scaled from 1/2 (left) to 1/512 (right).
    """

    # 500 probabilities evenly spaced in log2 between 2**-9 and 2**-1.
    ps = 2.0 ** np.linspace(-9, -1, 500)
    h_values = -(ps * np.log2(ps) + (1 - ps) * np.log2(1 - ps))

    fig, ax = plt.subplots(1, 1, figsize=(9, 5))

    for batch in (2, 4, 16, 64, 127, 128):
        label = "batch=2 (VN)" if batch == 2 else f"batch={batch}"
        bpt = eval_sim(batch, "inf", 0, ps)
        redundancy = [(1 - b / h) * 100 if h > 0 else 100.0 for b, h in zip(bpt, h_values)]
        # batch=127 is drawn dotted to set it apart from the solid curves.
        style = ":" if batch == 127 else "-"
        ax.plot(ps, redundancy, label=label, linewidth=1.5, linestyle=style)

    ax.set_xscale("log", base=2)
    ax.set_xlabel("Probability $p$", fontsize=12)
    ax.set_ylabel("Redundancy (%)", fontsize=12)
    ax.set_title("Redundancy of Elias extractors", fontsize=14, fontweight="bold")
    ax.set_xlim(2**-1, 2**-9)
    ax.set_ylim(0, 100)
    ax.legend(fontsize=10, loc="lower right")
    ax.grid(True, alpha=0.3)
    label_y_at_p_half(ax)

    fig.tight_layout()
    fig.savefig(filename, dpi=180, bbox_inches="tight")
    plt.close(fig)
    print(f"  Saved {filename}")


def draw_comparison_redundancy(filename):
    """Save one plot comparing the redundancy of several extractors.

    The plot shows an 80-node Peres tree (evaluated in mpmath) against four
    extractor-binary configurations (Elias, + carry, + modinv, + adaptive)
    over log2-spaced p from 1/2 (left) to 1/512 (right).
    """

    ps = 2.0 ** np.linspace(-9, -1, 500)

    fig, ax = plt.subplots(1, 1, figsize=(9, 5))

    # Peres tree: evaluated in Python at mpmath precision.
    peres = []
    for p_float in ps:
        p = mp.mpf(p_float)
        h = -(p * mp.log(p, 2) + (1 - p) * mp.log(1 - p, 2))
        peres.append((1 - float(eval_peres_tree(TREE80, p) / h)) * 100 if h > 0 else 100.0)
    ax.plot(ps, peres, label="Peres (nodes=80)", linewidth=1.5, linestyle="-", color="C0")

    # Extractor-binary series: (label, color, (batch, bitwidth, carry)).
    h_values = -(ps * np.log2(ps) + (1 - ps) * np.log2(1 - ps))
    simulated = [
        ("Elias (batch=62)", "C1", (62, "inf", 0)),
        ("+ carry (batch=62)", "C2", (62, 64, "inf")),
        ("+ modinv (batch=67)", "C3", (67, 64, "inf")),
        ("+ adaptive (batch=*)", "C4", ("tuned", 64, "inf")),
    ]
    rates = [eval_sim(*args, ps) for _, _, args in simulated]
    for (label, color, _), bpt in zip(simulated, rates):
        redundancy = [(1 - b / h) * 100 if h > 0 else 100.0 for b, h in zip(bpt, h_values)]
        ax.plot(ps, redundancy, label=label, linewidth=1.5, linestyle="-", color=color)

    ax.set_xscale("log", base=2)
    ax.set_xlabel("Probability $p$", fontsize=12)
    ax.set_ylabel("Redundancy (%)", fontsize=12)
    ax.set_title("Redundancy comparison", fontsize=14, fontweight="bold")
    ax.set_xlim(2**-1, 2**-9)
    ax.set_ylim(bottom=0)
    ax.legend(fontsize=10, loc="upper left")
    ax.grid(True, alpha=0.3)
    label_y_at_p_half(ax)

    fig.tight_layout()
    fig.savefig(filename, dpi=180, bbox_inches="tight")
    plt.close(fig)
    print(f"  Saved {filename}")


def draw_overflow_redundancy(filename):
    """Save redundancy-vs-p curves of the extractor binary at bitwidth 64
    (carry "inf") for batch sizes 67 to 512, showing the loss due to overflow."""

    ps = 2.0 ** np.linspace(-9, -1, 500)
    entropy = -(ps * np.log2(ps) + (1 - ps) * np.log2(1 - ps))

    fig, ax = plt.subplots(1, 1, figsize=(9, 5))

    for batch in (67, 80, 128, 256, 512):
        rate = eval_sim(batch, 64, "inf", ps)
        ax.plot(ps,
                [(1 - r / h) * 100 if h > 0 else 100.0 for r, h in zip(rate, entropy)],
                label=f"batch={batch}", linewidth=1.5)

    ax.set_xscale("log", base=2)
    ax.set_xlabel("Probability $p$", fontsize=12)
    ax.set_ylabel("Redundancy (%)", fontsize=12)
    ax.set_title("Redundancy due to overflow at bits=64", fontsize=14, fontweight="bold")
    ax.set_xlim(2**-1, 2**-9)
    ax.set_ylim(bottom=0)
    ax.legend(fontsize=10, loc="upper right")
    ax.grid(True, alpha=0.3)
    label_y_at_p_half(ax)

    fig.tight_layout()
    fig.savefig(filename, dpi=180, bbox_inches="tight")
    plt.close(fig)
    print(f"  Saved {filename}")


def draw_adaptive_batch_redundancy(filename):
    """Save redundancy curves for fixed batch sizes (dotted) and for the
    binary's "tuned" adaptive batch (solid), all at bitwidth 64, carry "inf"."""

    ps = 2.0 ** np.linspace(-9, -1, 500)
    h_values = -(ps * np.log2(ps) + (1 - ps) * np.log2(1 - ps))

    specs = [(f"batch={b}", ":", b) for b in (64, 128, 256, 512)]
    specs.append(("adaptive", "-", "tuned"))
    series = [(label, ls, eval_sim(batch, 64, "inf", ps)) for label, ls, batch in specs]

    fig, ax = plt.subplots(1, 1, figsize=(9, 5))

    for label, ls, bpt in series:
        ax.plot(ps,
                [(1 - b / h) * 100 if h > 0 else 100.0 for b, h in zip(bpt, h_values)],
                label=label, linewidth=1.5, linestyle=ls)

    ax.set_xscale("log", base=2)
    ax.set_xlabel("Probability $p$", fontsize=12)
    ax.set_ylabel("Redundancy (%)", fontsize=12)
    ax.set_title("Adaptive vs. fixed batch size (bits=64)", fontsize=14, fontweight="bold")
    ax.set_xlim(2**-1, 2**-9)
    ax.set_ylim(bottom=0)
    ax.legend(fontsize=10, loc="upper right")
    ax.grid(True, alpha=0.3)
    label_y_at_p_half(ax)

    fig.tight_layout()
    fig.savefig(filename, dpi=180, bbox_inches="tight")
    plt.close(fig)
    print(f"  Saved {filename}")


def draw_carry_redundancy(filename, filename_normalized):
    """Compare redundancy across carry widths, and show the normalized gap.

    Evaluates the extractor binary at batch=32 with unbounded bitwidth for
    carry widths 0 (plain Elias), 2, 4, 6, 8 and "inf", then writes two plots:

    - ``filename``: absolute redundancy (% of entropy not extracted).
    - ``filename_normalized``: (red - red_inf) / (red_0 - red_inf) * 100,
      i.e. the share of the carry=0 -> carry=inf improvement still missing.
      Points where that gap is not positive are drawn as 0.
    """

    log2_ps = np.linspace(-9, -1, 500)
    ps = 2.0 ** log2_ps
    h_values = -(ps * np.log2(ps) + (1 - ps) * np.log2(1 - ps))

    # The first entry must be carry=0 and the last carry=inf: the normalized
    # plot uses them as its 100% and 0% references.
    series = [
        ("carry=0 (Elias)",     eval_sim(32, "inf", 0, ps)),
        ("carry=2",     eval_sim(32, "inf", 2, ps)),
        ("carry=4",     eval_sim(32, "inf", 4, ps)),
        ("carry=6",     eval_sim(32, "inf", 6, ps)),
        ("carry=8",     eval_sim(32, "inf", 8, ps)),
        ("carry=∞", eval_sim(32, "inf", "inf", ps)),
    ]

    all_redundancy = []
    for label, bpt in series:
        redundancy = np.array([(1 - b / h) * 100 if h > 0 else 100.0 for b, h in zip(bpt, h_values)])
        all_redundancy.append((label, redundancy))

    # --- Absolute redundancy ---
    fig, ax = plt.subplots(1, 1, figsize=(9, 5))
    for label, redundancy in all_redundancy:
        ax.plot(ps, redundancy, label=label, linewidth=1.5)
    ax.set_xscale("log", base=2)
    ax.set_xlabel("Probability $p$", fontsize=12)
    ax.set_ylabel("Redundancy (%)", fontsize=12)
    ax.set_title("Effect of carry width (batch=32)", fontsize=14, fontweight="bold")
    ax.set_xlim(2**-1, 2**-9)
    ax.set_ylim(bottom=0)
    ax.legend(fontsize=10, loc="upper right")
    ax.grid(True, alpha=0.3)
    label_y_at_p_half(ax)
    fig.tight_layout()
    fig.savefig(filename, dpi=180, bbox_inches="tight")
    plt.close(fig)
    print(f"  Saved {filename}")

    # --- Normalized: (red - red_inf) / (red_0 - red_inf) ---
    red_0 = all_redundancy[0][1]
    red_inf = all_redundancy[-1][1]
    gap = red_0 - red_inf
    positive = gap > 0

    fig, ax = plt.subplots(1, 1, figsize=(9, 5))
    for label, redundancy in all_redundancy:
        # Divide only where the gap is positive and leave 0 elsewhere.
        # np.where(cond, a / gap, 0) would evaluate the division everywhere
        # and emit divide-by-zero / invalid-value RuntimeWarnings.
        normalized = np.zeros_like(redundancy)
        np.divide(redundancy - red_inf, gap, out=normalized, where=positive)
        normalized *= 100
        ax.plot(ps, normalized, label=label, linewidth=1.5)
    ax.set_xscale("log", base=2)
    ax.set_xlabel("Probability $p$", fontsize=12)
    ax.set_ylabel("Remaining gap (%)", fontsize=12)
    ax.set_title("Normalized effect of carry width (batch=32)", fontsize=14, fontweight="bold")
    ax.set_xlim(2**-1, 2**-9)
    ax.set_ylim(0, 100)
    ax.legend(fontsize=10, loc="upper right")
    ax.grid(True, alpha=0.3)
    label_y_at_p_half(ax)
    fig.tight_layout()
    fig.savefig(filename_normalized, dpi=180, bbox_inches="tight")
    plt.close(fig)
    print(f"  Saved {filename_normalized}")


def draw_binomial_encoding(filename):
    """Draw the recursive binomial encoding for n=5, k=2, tracing V through 10100.

    Saves to ``filename`` a figure of five stacked rows of bit sequences. The
    bottom row lists all C(5,2) = 10 length-5 sequences with two 1s. Each row
    above lists the suffixes of the previous row's sequences whose first bit
    equals that row's "event" (the next bit of 10100). Every sequence is drawn
    in the global column (gpos) of the sequence it came from, so the traced
    sequence stays in column 8 in every row and is outlined in bold black.
    """

    fig, ax = plt.subplots(1, 1, figsize=(11, 5.5))

    # Layout constants (data units). A cell fits five monospace characters
    # plus padding on both sides and a gap to the next column.
    char_w = 0.13
    pad = 0.07
    col_gap = 0.10
    row_h = 0.55
    row_gap = 0.45
    cell_w = 5 * char_w + 2 * pad + col_gap

    # color_0 fills the first blue_count sequences of a row (in `rows` these
    # are the ones starting with 0); color_1 fills the rest (starting with 1).
    color_0 = "#b3d4f7"
    color_1 = "#f7c4b3"
    border = "#444444"

    # Each row: (sequences, global_positions, blue_count, v_global_pos, s_label, v_label)
    # V=10100 maps to global position 8 at every level.
    rows = [
        (["00011","00101","00110","01001","01010","01100","10001","10010","10100","11000"],
         list(range(10)), 6, 8, "$S = \\binom{5}{2} = 10$", "$V = 8$"),
        (["0001","0010","0100","1000"],
         [6,7,8,9], 3, 8, "$S = \\binom{4}{1} = 4$", "$V = 2$"),
        (["001","010","100"],
         [6,7,8], 2, 8, "$S = \\binom{3}{1} = 3$", "$V = 2$"),
        (["00"],
         [8], 1, 8, "$S = \\binom{2}{0} = 1$", "$V = 0$"),
        (["0"],
         [8], 1, 8, "$S = \\binom{1}{0} = 1$", "$V = 0$"),
    ]
    # Bits of 10100 consumed between consecutive rows (row i -> row i+1); the
    # final 0 is labelled separately above the top row.
    events = [1, 0, 1, 0]
    n_rows = len(rows)

    bw = cell_w - col_gap  # uniform box width for all rows

    for row_idx, (seqs, gpos, blue_count, v_gpos, s_label, v_label) in enumerate(rows):
        # Flip: row 0 (biggest) at bottom, row 4 (smallest) at top.
        flipped = n_rows - 1 - row_idx
        y = -(flipped * (row_h + row_gap))

        for i, (seq, gp) in enumerate(zip(seqs, gpos)):
            x_center = gp * cell_w + cell_w / 2
            x = x_center - bw / 2

            if blue_count is not None and i < blue_count:
                color = color_0
            elif blue_count is not None:
                color = color_1
            else:
                # Fallback for blue_count=None; no entry in `rows` uses it.
                color = "#ddd"

            # Highlight the traced sequence with a thicker black outline.
            is_v = (gp == v_gpos)
            lw = 2.0 if is_v else 0.8
            ec = "#000000" if is_v else border

            rect = patches.FancyBboxPatch(
                (x, y), bw, row_h,
                boxstyle="round,pad=0.02", facecolor=color,
                edgecolor=ec, linewidth=lw
            )
            ax.add_patch(rect)
            ax.text(x_center, y + row_h / 2, seq,
                    ha="center", va="center", fontsize=7, fontfamily="monospace",
                    color="#222222", fontweight="bold" if is_v else "normal")

        # Left labels
        leftmost_x = gpos[0] * cell_w
        ax.text(leftmost_x - 0.15, y + row_h / 2 + 0.12, s_label,
                ha="right", va="center", fontsize=9, color="#333333")
        ax.text(leftmost_x - 0.15, y + row_h / 2 - 0.12, v_label,
                ha="right", va="center", fontsize=9, color="#333333")

        # Event annotation: arrow points downward from this row to the next bigger row.
        if row_idx < len(events):
            ev = events[row_idx]
            v_x = v_gpos * cell_w + cell_w / 2
            next_flipped = n_rows - 1 - (row_idx + 1)
            next_y = -(next_flipped * (row_h + row_gap))
            mid_y = (y + next_y + row_h) / 2
            ax.text(v_x + 0.55, mid_y, f"event = {ev}",
                    ha="left", va="center", fontsize=8, color="#555555", style="italic")
            ax.annotate("", xy=(v_x, y - 0.02),
                        xytext=(v_x, next_y + row_h + 0.02),
                        arrowprops=dict(arrowstyle="->", color="#888888", lw=1.0))

    # Final event annotation at top: the last bit of 10100 has no row above it.
    # Column 8 is the traced column (v_gpos) shared by every row.
    v_x = 8 * cell_w + cell_w / 2
    top_row_y = 0
    ax.text(v_x + 0.55, top_row_y + row_h + 0.15, "event = 0",
            ha="left", va="center", fontsize=8, color="#555555", style="italic")

    y_bottom = -(n_rows - 1) * (row_h + row_gap)
    ax.set_xlim(-1.8, 10 * cell_w + 0.3)
    ax.set_ylim(y_bottom - 0.3, row_h + 1.0)
    ax.set_aspect("equal")
    ax.set_axis_off()

    fig.tight_layout()
    fig.savefig(filename, dpi=180, bbox_inches="tight", pad_inches=0.15)
    plt.close(fig)
    print(f"  Saved {filename}")


# Output registry: each output filename maps to a shared (generator, filenames)
# entry, where generator(outdir) writes every file in filenames into outdir.
TARGETS = {}

def _register(filenames, gen):
    """Record ``gen`` as the producer of ``filenames`` (one name or a list).

    Every listed name maps to the same ``(gen, names)`` tuple, which lets the
    command-line driver run a multi-file generator only once.
    """
    names = [filenames] if isinstance(filenames, str) else filenames
    shared = (gen, names)
    TARGETS.update(dict.fromkeys(names, shared))

_register("binomial_encoding.png",
          lambda outdir: draw_binomial_encoding(
              os.path.join(outdir, "binomial_encoding.png")))
_register("peres_tree12.png",
          lambda outdir: draw_tree_diagram(
              TREE12, os.path.join(outdir, "peres_tree12.png")))
_register("von_neumann_bits.png",
          lambda outdir: draw_von_neumann_bits(
              os.path.join(outdir, "von_neumann_bits.png")))
_register("von_neumann_redundancy.png",
          lambda outdir: draw_von_neumann_redundancy(
              os.path.join(outdir, "von_neumann_redundancy.png")))
_register("peres_redundancy.png",
          lambda outdir: draw_peres_redundancy(
              os.path.join(outdir, "peres_redundancy.png")))
_register("elias_redundancy.png",
          lambda outdir: draw_elias_redundancy(
              os.path.join(outdir, "elias_redundancy.png")))
_register("comparison_redundancy.png",
          lambda outdir: draw_comparison_redundancy(
              os.path.join(outdir, "comparison_redundancy.png")))
_register("overflow_redundancy.png",
          lambda outdir: draw_overflow_redundancy(
              os.path.join(outdir, "overflow_redundancy.png")))
_register("adaptive_batch_redundancy.png",
          lambda outdir: draw_adaptive_batch_redundancy(
              os.path.join(outdir, "adaptive_batch_redundancy.png")))
_register(["carry_redundancy.png", "carry_normalized_redundancy.png"],
          lambda outdir: draw_carry_redundancy(
              os.path.join(outdir, "carry_redundancy.png"),
              os.path.join(outdir, "carry_normalized_redundancy.png")))

if __name__ == "__main__":
    # Generate the targets named on the command line (default: all of them),
    # writing the images next to this script.
    outdir = os.path.dirname(os.path.abspath(__file__))
    requested = sys.argv[1:] if len(sys.argv) > 1 else list(TARGETS)
    # Validate every name before generating anything, so a typo late in the
    # argument list fails fast instead of after earlier targets were produced.
    unknown = [name for name in requested if name not in TARGETS]
    if unknown:
        sys.exit(f"error: unknown target {unknown[0]!r}")
    seen = set()
    for name in requested:
        gen, _ = TARGETS[name]
        # Names that share a generator (e.g. the two carry plots) run it once.
        if id(gen) in seen:
            continue
        seen.add(id(gen))
        gen(outdir)
