// Copyright (c) 2026 Pieter Wuille
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following
// conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

// ---------------------------------------------------------------------------
// CLI for the binomial randomness extractor. Subcommands:
//
//   exhaustive_test [M N]              correctness test
//   sim_bench BATCH P                  throughput benchmark
//   model_rate BATCH BITS CARRY        modelled extraction rate
//
// Running with no arguments is equivalent to `exhaustive_test` with defaults.
// ---------------------------------------------------------------------------

#include "extractor.h"

#include <algorithm>
#include <array>
#include <bit>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <map>
#include <numbers>
#include <random>
#include <span>
#include <string_view>
#include <unordered_map>
#include <utility>
#include <vector>

namespace {

// ---------------------------------------------------------------------------
// exhaustive_test
// ---------------------------------------------------------------------------

// Upper bound on M^N accepted by run_config. Every sequence is enumerated and checked,
// so this bounds the work (and memory) of a single exhaustive configuration.
constexpr uint64_t MAX_SEQUENCES = 10'000'000'000ULL;

/**
 * Multinomial coefficient (f0 + f1 + ...)! / (f0! * f1! * ...): the number of distinct
 * orderings of a sequence whose symbol frequencies are `freqs`.
 *
 * Built as a running product of binomial coefficients, dividing at every step. The
 * naive total! / prod(f!) overflows uint64_t as soon as the total exceeds 20, which the
 * exhaustive test allows (e.g. M=2, N=33 is within MAX_SEQUENCES). Every division here is
 * exact, and intermediates are at most result * total, so the result is exact whenever
 * that product fits in 64 bits.
 */
uint64_t multinomial(const std::vector<int>& freqs)
{
    uint64_t result = 1;
    uint64_t placed = 0;  // number of symbols accounted for so far
    for (int f : freqs) {
        for (int i = 1; i <= f; ++i) {
            ++placed;
            // Here result == (product for earlier groups) * C(placed - 1, i - 1), and
            // C(placed - 1, i - 1) * placed == i * C(placed, i), so the division is exact
            // and leaves result == (product for earlier groups) * C(placed, i).
            result = result * placed / i;
        }
    }
    return result;
}

/**
 * Exhaustively check the extractor on all m^n sequences of n events over an m-symbol
 * alphabet.
 *
 * For every sequence:
 *  - the span must equal the multinomial coefficient of the sequence's symbol frequencies;
 *  - the value must lie in [0, span);
 *  - within each frequency class no value may repeat. A class contains exactly
 *    multinomial(freqs) == span sequences, so distinct in-range values form a permutation
 *    of 0..span-1.
 * For m == 2 it also checks that process_batch_low_k and process_batch_table agree with
 * process_batch.
 *
 * Prints a summary line on success or the first failure found, and returns whether all
 * checks passed. Refuses (returns false) when m^n exceeds MAX_SEQUENCES.
 */
bool run_config(unsigned m, unsigned n)
{
    uint64_t total_seqs = 1;
    for (unsigned i = 0; i < n; ++i) {
        if (total_seqs > MAX_SEQUENCES / m) {
            std::cerr << "  (n=" << n << ", M=" << m << "): "
                      << m << "^" << n << " exceeds " << MAX_SEQUENCES
                      << " sequence limit\n";
            return false;
        }
        total_seqs *= m;
    }

    // For each frequency class, a bitmap of the values produced so far. This costs one bit
    // per sequence and flags duplicates immediately. Storing every 64-bit value and sorting
    // each class at the end would cost 64x the memory and O(N log N) time.
    std::map<std::vector<int>, std::vector<bool>> seen_by_freq;

    // Scratch buffers, reused across sequences to avoid per-sequence allocations.
    std::vector<uint8_t> events(n);
    std::vector<int> freqs(m);
    std::vector<unsigned> positions;
    positions.reserve(n);

    for (uint64_t seq = 0; seq < total_seqs; ++seq) {
        // Decode seq as n digits in base m.
        std::fill(freqs.begin(), freqs.end(), 0);
        uint64_t s = seq;
        for (unsigned i = 0; i < n; ++i) {
            unsigned e = s % m;
            s /= m;
            events[i] = e;
            ++freqs[e];
        }
        UniformSpan r;
        if (m == 2) {
            r = process_batch(events);
            // Exercise the alternative binary coders (which take the positions of the
            // 1-events) and verify agreement.
            positions.clear();
            for (unsigned i = 0; i < n; ++i) {
                if (events[i]) positions.push_back(i);
            }
            UniformSpan r_low_k = process_batch_low_k(positions, n);
            UniformSpan r_table = process_batch_table(positions, n);
            if (r.span != r_low_k.span || r.value != r_low_k.value) {
                std::cout << "FAIL low_k disagreement (n=" << n << ")\n";
                return false;
            }
            if (r.span != r_table.span || r.value != r_table.value) {
                std::cout << "FAIL table disagreement (n=" << n << ")\n";
                return false;
            }
        } else {
            r = process_batch_multi(events, m);
        }
        uint64_t expected = multinomial(freqs);
        if (r.span != expected) {
            std::cout << "FAIL span (n=" << n << ", M=" << m << ") seq=" << seq
                      << " got=" << r.span << " expected=" << expected << "\n";
            return false;
        }
        if (r.value >= r.span) {
            std::cout << "FAIL value out of range (n=" << n << ", M=" << m << ")\n";
            return false;
        }
        std::vector<bool>& seen = seen_by_freq[freqs];
        if (seen.empty()) seen.resize(expected, false);  // first member of this class (expected >= 1)
        if (seen[r.value]) {
            std::cout << "FAIL permutation (n=" << n << ", M=" << m << ")\n";
            return false;
        }
        seen[r.value] = true;
    }
    std::cout << "  (n=" << n << ", M=" << m << "): " << total_seqs << " sequences OK\n";
    return true;
}

// Built-in (M, N) pairs checked by `exhaustive_test` without arguments (and by running
// with no arguments at all): M-valued events in batches of N. The M == 2 entries also
// cross-check the alternative binary coders. Every pair keeps M^N far below MAX_SEQUENCES
// (the largest is 10^4).
constexpr std::pair<unsigned, unsigned> DEFAULT_CONFIGS[] = {
    {2, 1}, {2, 5}, {2, 10}, {2, 13},
    {3, 4}, {3, 8},
    {4, 3}, {4, 6},
    {5, 5},
    {6, 4},
    {10, 3}, {10, 4},
};

// ---------------------------------------------------------------------------
// sim_bench
// ---------------------------------------------------------------------------

// Xoshiro256++ PRNG. Fast and non-cryptographic; used to generate benchmark input.
// Satisfies the UniformRandomBitGenerator requirements, so it can also drive <random>
// distributions.
struct Xoshiro256pp {
    using result_type = uint64_t;

    /** Generator state. An all-zero state is a fixed point (it only ever outputs 0). */
    uint64_t s[4];

    /** Seed each 64-bit state word from two rd() draws. Explicit: a random_device is not a generator. */
    explicit Xoshiro256pp(std::random_device& rd)
    {
        for (auto& v : s) v = (uint64_t(rd()) << 32) | rd();
    }

    static constexpr result_type min() noexcept { return 0; }
    static constexpr result_type max() noexcept { return ~result_type{0}; }

    /** Advance the state and return the next 64-bit output. */
    uint64_t operator()() noexcept
    {
        uint64_t result = std::rotl(s[0] + s[3], 23) + s[0];
        uint64_t t = s[1] << 17;
        s[2] ^= s[0]; s[3] ^= s[1]; s[1] ^= s[2]; s[0] ^= s[3];
        s[2] ^= t;
        s[3] = std::rotl(s[3], 45);
        return result;
    }
};

/**
 * sim_bench: throughput benchmark for process_batch + extract_bit on biased binary input.
 *
 * Setup: fills a POOL_SIZE-byte pool with 0/1 bytes, each 1 with probability ~p.
 *
 * Main loop (runs forever, never returns):
 *  - Each round starts at a random offset into the pool.
 *  - It feeds batches_per_round consecutive batches of `batch` symbols through process_batch.
 *  - Each batch result, and then the running accumulator after merging it in, is drained
 *    with extract_bit while its span does not fit in 32 bits. The loop also runs while
 *    value >= span; NOTE(review): presumably an invalid/empty marker produced by
 *    extract_bit — confirm in extractor.h.
 *  - Every bit returned by extract_bit is counted.
 *
 * Reporting: about once per second, cumulative throughput, bits extracted per input symbol,
 * and redundancy relative to the binary entropy of p are printed to stderr.
 *
 * Expects 2 <= batch <= 4096 and 0 < p < 1 (validated by main).
 * NOTE(review): declared noexcept, so a failed 128 MiB pool allocation terminates.
 */
void cmd_sim_bench(unsigned batch, double p) noexcept
{
    constexpr size_t POOL_SIZE = 128 * 1024 * 1024;
    constexpr unsigned EVENTS_PER_ROUND = 4 * 1024 * 1024;

    std::random_device rd;
    Xoshiro256pp rng(rd);

    // Generate biased boolean pool. Use a 64-bit threshold comparison:
    // each RNG output < threshold means true.
    // (18446744073709551616.0 == 2^64, so threshold ~= p * 2^64.)
    std::vector<uint8_t> pool(POOL_SIZE);
    uint64_t threshold = static_cast<uint64_t>(p * 18446744073709551616.0);
    for (size_t i = 0; i < POOL_SIZE; ++i) {
        pool[i] = rng() < threshold;
    }

    // Each round processes at least one batch, even if batch > EVENTS_PER_ROUND.
    unsigned batches_per_round = EVENTS_PER_ROUND / batch;
    if (batches_per_round == 0) batches_per_round = 1;
    // Binary entropy H(p) in bits per symbol: the ideal extraction rate.
    double entropy = -(p * std::log2(p) + (1.0 - p) * std::log2(1.0 - p));
    size_t symbols_per_round = size_t(batches_per_round) * batch;

    uint64_t total_symbols = 0;
    uint64_t total_bits = 0;
    auto t_start = std::chrono::steady_clock::now();
    auto t_last = t_start;

    std::fprintf(stderr, "sim_bench: batch=%u, p=%.6g\n", batch, p);

    // Running carry between batches; starts as the trivial span of size 1.
    UniformSpan accum{1, 0};
    for (;;) {
        // Pick a random starting position that fits the round.
        size_t max_start = POOL_SIZE - symbols_per_round;
        size_t start = rng() % (max_start + 1);

        for (unsigned b = 0; b < batches_per_round; ++b) {
            std::span<const uint8_t> events(pool.data() + start + size_t(b) * batch, batch);
            UniformSpan batch_result = process_batch(events);
            // Reduce batch_result to 32 bits.
            while (batch_result.value >= batch_result.span || (batch_result.span >> 32) != 0) {
                auto bit = extract_bit(batch_result);
                if (bit.has_value()) ++total_bits;
            }
            accum = merge(accum, batch_result);
            // Reduce accumulator to 32 bits.
            while (accum.value >= accum.span || (accum.span >> 32) != 0) {
                auto bit = extract_bit(accum);
                if (bit.has_value()) ++total_bits;
            }
        }
        total_symbols += symbols_per_round;

        // Report cumulative statistics at most once per second.
        auto now = std::chrono::steady_clock::now();
        double elapsed = std::chrono::duration<double>(now - t_last).count();
        if (elapsed >= 1.0) {
            double total_elapsed = std::chrono::duration<double>(now - t_start).count();
            double sym_per_sec = total_symbols / total_elapsed;
            double bits_per_sec = total_bits / total_elapsed;
            double bits_per_sym = double(total_bits) / double(total_symbols);
            double redundancy = (1.0 - bits_per_sym / entropy) * 100.0;
            std::fprintf(stderr, "  %.1fs: %.2f Msym/s in, %.2f Mbit/s out, %.6f bit/sym extracted, %.2f%% redundancy\n",
                         total_elapsed, sym_per_sec / 1e6, bits_per_sec / 1e6, bits_per_sym, redundancy);
            t_last = now;
        }
    }
}

// ---------------------------------------------------------------------------
// model_rate: compute expected bits/toss for given extractor parameters.
// Uses 128-bit modular arithmetic to model extraction rate across a range
// of probabilities. Reads p values from stdin, writes bits-per-toss to stdout.
// ---------------------------------------------------------------------------

// Unsigned 128-bit integer (the GCC/Clang `__int128` extension), used by the model_rate
// helpers below for binomial coefficients modulo 2^128 and span products.
using uint128 = unsigned __int128;

/** Number of trailing zero bits of a 128-bit value; 128 when the value is zero. */
constexpr unsigned ctz128(uint128 x) noexcept
{
    const uint64_t low = static_cast<uint64_t>(x);
    const uint64_t high = static_cast<uint64_t>(x >> 64);
    if (low == 0) {
        return high == 0 ? 128 : 64 + std::countr_zero(high);
    }
    return std::countr_zero(low);
}

/** Inverse of an odd number modulo 2^128 (result * odd == 1 mod 2^128), via Newton-Hensel lifting. */
constexpr uint128 mod_inverse128(uint128 odd) noexcept
{
    const uint64_t low = static_cast<uint64_t>(odd);
    // Seed that is already correct modulo 2^4.
    uint64_t approx = low ^ (((low + 1) & 4) << 1);
    // Each step doubles the number of correct low bits: 4 -> 8 -> 16 -> 32 -> 64.
    for (int step = 0; step < 4; ++step) {
        approx *= 2 - approx * low;
    }
    // One last step at full width takes it from 64 to 128 correct bits.
    const uint128 wide = approx;
    return wide * (2 - wide * odd);
}

/** Factor x into {odd, shift} with x == odd << shift; zero is reported as {0, 128}. */
constexpr std::pair<uint128, unsigned> split_odd128(uint128 x) noexcept
{
    if (x != 0) {
        const unsigned zeros = ctz128(x);
        return {x >> zeros, zeros};
    }
    return {0, 128};
}

/** Largest n accepted by binomial_coef128. Matches the documented model_rate BATCH range (2..4096). */
constexpr unsigned FAC_TABLE_MAX_N = 4096;

/** Compute (n choose k) mod 2^128 using a precomputed factorial table.
 *  Requires n <= FAC_TABLE_MAX_N; returns 0 when k > n. */
uint128 binomial_coef128(unsigned n, unsigned k) noexcept
{
    // Factorials are stored as i! == odd << shift, together with odd^-1 mod 2^128. Odd parts
    // are invertible modulo 2^128, so n! / (k! (n-k)!) becomes a product of odd parts and
    // inverses, shifted left by the leftover power of two.
    // The table has FAC_TABLE_MAX_N + 1 entries so that n == 4096 (an allowed batch size) is in range.
    struct FacEntry { uint128 odd, inv; unsigned shift; };
    static constexpr auto FAC_TABLE = []{
        std::array<FacEntry, FAC_TABLE_MAX_N + 1> tbl{};
        uint128 fact_odd = 1;
        unsigned fact_shift = 0;
        for (unsigned i = 0; i < tbl.size(); i++) {
            if (i > 0) {
                auto [i_odd, i_shift] = split_odd128(uint128{i});
                fact_odd *= i_odd;
                fact_shift += i_shift;
            }
            tbl[i] = {fact_odd, mod_inverse128(fact_odd), fact_shift};
        }
        return tbl;
    }();

    if (k > n) return 0;
    const auto& ne = FAC_TABLE[n];
    const auto& ke = FAC_TABLE[k];
    const auto& me = FAC_TABLE[n - k];
    // The remaining shift is the 2-adic valuation of C(n, k): non-negative, and at most log2(n).
    return (ne.odd * ke.inv * me.inv) << (ne.shift - ke.shift - me.shift);
}

/** Precomputed table with log2(n!) for 0 <= n <= 4096. Index 4096 is needed because model_rate
 *  documents BATCH up to 4096, and precompute(n, bits) reads LOG2_FACTORIAL[n]. */
const struct Log2FactorialTable {
    std::array<long double, 4097> val;
    Log2FactorialTable() {
        for (unsigned i = 0; i < val.size(); i++) {
            // lgamma(i + 1) == ln(i!); rescale from nats to bits.
            val[i] = std::lgamma(i + 1.0L) * std::numbers::log2e_v<long double>;
        }
    }
    long double operator[](unsigned i) const noexcept { return val[i]; }
} LOG2_FACTORIAL;

/** Approximate log2(n choose k) from the log2-factorial table.
 *  Requires k <= n and n within the LOG2_FACTORIAL table. */
long double log2_binomial_coef(unsigned n, unsigned k) noexcept
{
    const long double log2_denominator = LOG2_FACTORIAL[k] + LOG2_FACTORIAL[n - k];
    return LOG2_FACTORIAL[n] - log2_denominator;
}

/** Bits needed to represent x (0 for x == 0): the 128-bit analogue of std::bit_width. */
constexpr unsigned bit_width128(uint128 x) noexcept
{
    const auto high = static_cast<uint64_t>(x >> 64);
    return high != 0 ? 64 + std::bit_width(high)
                     : std::bit_width(static_cast<uint64_t>(x));
}

/** Precomputed uniform range for randomness extraction: `size` distinct values, each
 *  reached by `count` of the batch arrangements that map into this range. */
struct Range {
    /** Size of the range (number of distinct values); 0 marks an unused slot. */
    uint128 size;
    /** How many times each value within the range is reached. Kept as long double because
     *  precompute(n, bits) can derive it from a floating-point estimate when it is huge. */
    long double count;
};
/** Up to two ranges per (n, k) combination. The second is for overflow handling
 *  (Range{0, 0.0L} when unused). */
using RangeInfo = std::array<Range, 2>;

/** Precompute ranges for batches of n binary observations, without overflow. */
std::vector<RangeInfo> precompute(unsigned n)
{
    std::vector<RangeInfo> ranges(n + 1);
    for (unsigned k = 0; k <= n; k++)
        ranges[k] = {Range{binomial_coef128(n, k), 1.0L}, Range{0, 0.0L}};
    return ranges;
}

/** Precompute ranges with integer arithmetic modulo 2^bits, with overflow handling. */
std::vector<RangeInfo> precompute(unsigned n, unsigned bits)
{
    std::vector<RangeInfo> ranges(n + 1);
    for (unsigned k = 0; k <= n; k++) {
        uint128 cmb = binomial_coef128(n, k);
        long double log2_ov = log2_binomial_coef(n, k) - (long double)(bits);
        uint128 mask = (bits < 128) ? ((uint128{1} << bits) - 1) : ~uint128{0};
        uint128 r = cmb & mask;
        if (log2_ov > 0.0L) {
            // Overflow occurred: residual range has count overflows+1, overflow range has count overflows.
            long double overflows = std::floor(std::exp2(log2_ov));
            uint128 overflow_r = (mask - r) + 1;
            ranges[k] = {Range{r, overflows + 1.0L}, Range{overflow_r, overflows}};
        } else {
            // No overflow.
            ranges[k] = {Range{r, 1.0L}, Range{0, 0.0L}};
        }
    }
    return ranges;
}

/** Determine bits extracted per toss, when reducing fully after every batch (no carry). */
long double model_no_carry(long double p, const std::vector<RangeInfo>& ranges) noexcept
{
    unsigned batch = ranges.size() - 1;
    long double hi = std::max(p, 1.0L - p);
    long double lo = std::min(p, 1.0L - p);
    long double ratio = lo / hi;
    long double prob = std::pow(hi, batch);
    long double ret = 0.0L;
    for (unsigned k = 0; k <= batch; k++) {
        if (prob != 0.0L) {
            for (auto [size, count] : ranges[k]) {
                // Decompose size into powers of two; each power-of-two block
                // of size 2^b contributes b bits per arrangement.
                auto c = size;
                while (c > 1) {
                    unsigned b = bit_width128(c) - 1;
                    ret += count * b * std::ldexp(prob, static_cast<int>(b));
                    c -= uint128{1} << b;
                }
            }
        }
        prob *= ratio;
    }
    return ret / batch;
}

/** Determine bits extracted per toss, assuming infinite carry precision. */
long double model_full_carry(long double p, const std::vector<RangeInfo>& ranges) noexcept
{
    unsigned batch = ranges.size() - 1;
    long double hi = std::max(p, 1.0L - p);
    long double lo = std::min(p, 1.0L - p);
    long double ratio = lo / hi;
    long double prob = std::pow(hi, batch);
    long double ret = 0.0L;
    for (unsigned k = 0; k <= batch; k++) {
        if (prob != 0.0L) {
            for (auto [size, count] : ranges[k]) {
                if (size > 0 && count > 0.0L) {
                    // With infinite carry, each range of size S contributes
                    // exactly log2(S) fractional bits.
                    long double size_d = size;
                    ret += std::log2(size_d) * size_d * (count * prob);
                }
            }
        }
        prob *= ratio;
    }
    return ret / batch;
}

/** Determine bits extracted per toss, when reducing to (span <= 2^carry_bits) after
 *  every batch. Iterates the Markov chain over carry states to find the steady-state
 *  distribution before measuring output.
 *
 *  State s (1..2^carry_bits) is the span carried into the next batch. Each iteration pushes
 *  the distribution through one batch for every k and every range, halving the product span
 *  (one bit per halving) until it fits again. The loop stops once the per-batch bit count
 *  changes by at most a relative 1e-15, or after 16384 iterations.
 *
 *  Preconditions (not checked here):
 *  - carry_bits >= 1. evaluate() routes 0 to model_no_carry; `carry_bits - 1` below would
 *    underflow otherwise.
 *  - 2^carry_bits + 1 must fit in size_t. Two such long double tables are allocated per
 *    iteration.
 *  - s * size is computed in 128 bits, so log2(range size) + carry_bits must stay below
 *    128 or the product wraps. */
long double model_bounded_carry(long double p, const std::vector<RangeInfo>& ranges, unsigned carry_bits) noexcept
{
    unsigned batch = ranges.size() - 1;
    auto max_state = uint128{1} << carry_bits;
    long double hi = std::max(p, 1.0L - p);
    long double lo = std::min(p, 1.0L - p);
    long double ratio = lo / hi;

    // Initialize state distribution over spans to be uniform over
    // (2^(carry_bits-1), 2^carry_bits].
    std::vector<long double> state_prob(size_t(max_state) + 1, 0.0L);
    auto half_state = uint128{1} << (carry_bits - 1);
    long double init_prob = 1.0L / (long double)(max_state - half_state);
    for (size_t s = size_t(half_state) + 1; s <= size_t(max_state); s++)
        state_prob[s] = init_prob;

    long double total_bits = 0.0L;
    long double prev_bits = -1.0L;

    // Run possibly many iterations to find steady-state distribution over spans.
    for (unsigned iter = 0; iter < 16384; iter++) {
        std::vector<long double> new_state_prob(size_t(max_state) + 1, 0.0L);
        total_bits = 0.0L;

        // Probability of one specific arrangement with k less-likely outcomes, starting at k = 0.
        long double prob = std::pow(hi, batch);
        for (unsigned k = 0; k <= batch; k++) {
            // prob only decreases with k (ratio <= 1), so once it underflows to 0 it stays 0.
            if (prob == 0.0L) break;
            // Iterate over all possible previous span sizes.
            for (size_t s = 1; s <= size_t(max_state); s++) {
                if (state_prob[s] == 0.0L) continue;
                // Probability per arrangement (size cancels: size/prod = size/(s*size) = 1/s).
                long double sp = state_prob[s] / (long double)s;
                // Iterate over all outcomes of a batch.
                for (auto [size, count] : ranges[k]) {
                    if (size == 0) continue;
                    // Reduce the product span to be <= 2^carry_bits, remembering the
                    // number of bits produced along the way. weight tracks probability
                    // per arrangement.
                    long double weight = (count * prob) * sp;
                    uint128 prod = uint128{s} * size;
                    unsigned bits = 0;
                    while (prod > max_state) {
                        // Deal with odd span: one arrangement resets to state 1.
                        if (prod & 1) {
                            new_state_prob[1] += weight;
                            total_bits += bits * weight;
                            --prod;
                        }
                        // Produce 1 bit by reducing span by a factor 2.
                        prod >>= 1;
                        weight += weight;
                        ++bits;
                    }
                    // All remaining arrangements land in state `prod`, each having produced `bits` bits.
                    long double remain = weight * (long double)prod;
                    new_state_prob[size_t(prod)] += remain;
                    total_bits += bits * remain;
                }
            }
            prob *= ratio;
        }

        // Update and normalize state distribution.
        state_prob = std::move(new_state_prob);
        long double sum = 0.0L;
        for (size_t s = 1; s <= size_t(max_state); s++) sum += state_prob[s];
        long double inv_sum = 1.0L / sum;
        for (size_t s = 1; s <= size_t(max_state); s++) state_prob[s] *= inv_sum;
        // Check convergence.
        if (prev_bits >= 0.0L && std::abs(total_bits - prev_bits) <= 1e-15L * total_bits)
            break;
        prev_bits = total_bits;
    }

    return total_bits / batch;
}

// ---------------------------------------------------------------------------
// Tune tables: adaptive batch selection as a function of -log2(p).
// ---------------------------------------------------------------------------

// One tuning breakpoint. Despite its name, `log2p` holds a -log2 probability (so it is
// positive). find_tuned_batch picks the `batch` of the last entry whose threshold does not
// exceed the -log2 probability it computes.
struct TuneEntry { long double log2p; unsigned batch; };

// Breakpoints for BITS == 8, sorted by ascending threshold. The first threshold, 1.0,
// corresponds to p = 0.5.
constexpr TuneEntry TUNE_8[] = {
    {1.0, 10}, {2.3152182988734133, 12}, {3.384661710849709, 15},
    {3.653265566597382, 16}, {3.6608995269358804, 18}, {3.84505604130457, 19},
    {3.8861498377256023, 22}, {4.595960283546672, 23}, {5.464898809524952, 32},
    {5.942684363699531, 40}, {6.149483783654269, 51}, {6.6660591003647784, 56},
    {6.877719522574525, 72}, {7.463858291293954, 82}, {7.521861095367855, 91},
    {7.804817514888896, 109}, {8.34085718573971, 140}, {8.730730752642367, 154},
    {9.014854872720187, 183}, {9.094788162949737, 214}, {9.33282845124401, 229},
    {9.584490992966657, 241}, {10.859222565631118, 256}
};

// Breakpoints for BITS == 16, sorted by ascending threshold.
constexpr TuneEntry TUNE_16[] = {
    {1.0, 18}, {1.9292649225190597, 19}, {2.2185511575504577, 20},
    {2.23317795648826, 21}, {2.715465118412977, 24}, {3.004089007232924, 26},
    {3.4270236187817593, 29}, {3.536150996608597, 30}, {3.548812494298568, 32},
    {3.719958085525595, 35}, {4.144548033443606, 36}, {4.190533610731624, 37},
    {4.384873336496124, 43}, {4.537483483608158, 48}, {4.894194696239511, 57},
    {5.15186370631363, 61}, {5.233812221156851, 66}, {5.405465968899529, 70},
    {5.814595645074128, 74}, {6.080427695771663, 93}, {6.473557542071636, 106},
    {6.579076865224145, 117}, {6.818160800028085, 134}, {7.169517714712086, 141},
    {7.220574592569081, 164}, {7.43737000420301, 193}, {7.8690443933434295, 218},
    {7.935603559127777, 226}, {7.976376514836684, 236}, {8.207061486379235, 278},
    {8.385241504376094, 283}, {8.43405524476637, 314}, {8.6777476804653, 352},
    {9.458878055429896, 362}, {9.792835682024329, 627}, {10.438425102850397, 724},
    {10.716304837689812, 810}, {10.876981872221705, 958}, {11.342934582267066, 1087},
    {11.392064763759265, 1145}, {11.402508317361985, 1306}, {11.798931839105926, 1449}
};

// Breakpoints for BITS == 32, sorted by ascending threshold.
constexpr TuneEntry TUNE_32[] = {
    {1.0, 34}, {1.3053219078882456, 35}, {1.821390573941862, 36},
    {1.987568095642787, 37}, {2.1272987193897084, 39}, {2.3700173287735993, 41},
    {2.4791970247631077, 42}, {2.640101451563181, 44}, {2.723395605312736, 46},
    {2.9652207685177236, 49}, {2.984539325363781, 51}, {3.232190163881575, 52},
    {3.3427201270855935, 57}, {3.477765888694056, 61}, {3.722316439549277, 62},
    {3.7455763777858095, 63}, {3.8328682813195982, 69}, {4.018244695205404, 75},
    {4.155989345005922, 79}, {4.301991645374505, 80}, {4.314868129950261, 83},
    {4.52889004247111, 90}, {4.547787430817317, 91}, {4.602324004955512, 92},
    {4.627344947334828, 97}, {4.693886476835967, 101}, {4.856143458098374, 104},
    {4.911147427078689, 113}, {5.098797401611607, 119}, {5.251868879918149, 123},
    {5.387804682348882, 136}, {5.444081442345237, 138}, {5.602640029290084, 147},
    {5.717216076853364, 154}, {5.73578293867853, 160}, {5.774053996596445, 165},
    {5.9283100974406855, 173}, {5.931823387269357, 177}, {6.012051858559962, 185},
    {6.057536851794875, 192}, {6.13713037771212, 203}, {6.2818377792987565, 206},
    {6.295474539019582, 209}, {6.315667165823255, 214}, {6.4484378487969645, 222},
    {6.6352189070271965, 252}, {6.842393672505558, 276}, {7.01466312711409, 292},
    {7.090593723169471, 317}, {7.240126075164868, 335}, {7.280732374100063, 343},
    {7.393437975979448, 375}, {7.546494226817195, 385}, {7.581186420249448, 414},
    {7.732045112414076, 442}, {7.763595961723617, 455}, {7.873806052218208, 485},
    {7.996529815027815, 509}, {8.090518759201501, 537}, {8.11669279010144, 549},
    {8.197162848172644, 559}, {8.509031928150128, 568}, {8.511211393347118, 671},
    {8.682806008992417, 675}, {8.79956818525292, 747}, {8.934136251044503, 803},
    {9.049022026088133, 849}, {9.172078292843807, 888}, {9.185632604115451, 923},
    {9.306047043258552, 983}, {9.329888005645472, 1009}, {9.459493312482316, 1056},
    {9.575405817154998, 1152}, {9.686208736854969, 1200}, {9.734208960410028, 1242},
    {9.77822099731279, 1293}, {9.882550012038184, 1349}
};

// Breakpoints for BITS == 64, sorted by ascending threshold.
constexpr TuneEntry TUNE_64[] = {
    {1.0, 67}, {1.5423534776474277, 68}, {1.6808828080591875, 69},
    {1.7528506537683572, 70}, {1.8651028232027407, 71}, {1.9467816232708224, 72},
    {2.0094888301918123, 74}, {2.138046740077643, 76}, {2.2544517914756894, 78},
    {2.3600756739027124, 81}, {2.439193723412559, 82}, {2.5291164288743286, 84},
    {2.572113792050173, 86}, {2.692150882404686, 89}, {2.7394384271342678, 91},
    {2.857296595653468, 94}, {2.9151254523785513, 96}, {2.9443256442743344, 97},
    {3.0459737437810004, 101}, {3.095243574183795, 103}, {3.1949154339913064, 105},
    {3.258944004549513, 109}, {3.329309589788051, 112}, {3.367174513846308, 114},
    {3.4607959177839827, 116}, {3.501335291518983, 121}, {3.6128565935353345, 124},
    {3.615788049640245, 126}, {3.688334668707313, 128}, {3.715691420786384, 129},
    {3.7419428657117546, 130}, {3.8038910772469947, 136}, {3.866517376849585, 140},
    {3.95347940276595, 145}, {3.96841427404187, 146}, {4.014841563681518, 149},
    {4.108962523239413, 150}, {4.135679932075356, 157}, {4.228371505613817, 162},
    {4.2834638209684215, 165}, {4.305778543189465, 170}, {4.357690355508934, 176},
    {4.480855380777648, 178}, {4.546478513263544, 187}, {4.605890664130133, 193},
    {4.671979922148429, 197}, {4.6935190690197786, 200}, {4.744482183233984, 206},
    {4.776447747956903, 208}, {4.787954684658638, 213}, {4.860386153284255, 217},
    {4.940174655544418, 218}, {4.980135650825773, 231}, {5.063378884374847, 239},
    {5.13702137157296, 245}, {5.166562951823467, 253}, {5.228653976749275, 259},
    {5.274598961216851, 264}, {5.324413974451085, 266}, {5.337470968708146, 278},
    {5.444900398384398, 282}, {5.514303706992132, 300}, {5.607135203891892, 311},
    {5.659083068053459, 319}, {5.714600334283497, 331}, {5.808853644604225, 343},
    {5.877254667153872, 354}, {5.8996348691013445, 357}, {5.901854954291529, 363},
    {5.92216888394571, 367}, {5.965597936317451, 380}, {6.11736077793662, 387},
    {6.160268400338973, 412}, {6.260475639682002, 431}, {6.337379799148384, 454},
    {6.432731649704969, 462}, {6.4466048164364835, 481}, {6.496460829109529, 486},
    {6.58040875519743, 509}, {6.606687582058371, 530}, {6.693361804465782, 547},
    {6.789506646818857, 570}, {6.901296387874755, 575}, {6.9260529093700605, 577},
    {6.94315277294724, 617}, {7.025907331379444, 623}, {7.067678625307426, 652},
    {7.132317358413498, 673}, {7.184311456409102, 703}, {7.260557207938804, 726},
    {7.335621306941565, 760}, {7.424285573923436, 799}, {7.4494022268263755, 804},
    {7.515095938311953, 816}, {7.523870745084778, 831}, {7.558224704381075, 877},
    {7.677907591228092, 899}, {7.683784189043544, 906}, {7.6890956517739975, 933},
    {7.7951445595498585, 955}, {7.879901010304549, 967}, {7.92225407146546, 1049},
    {8.04707347562178, 1107}, {8.101476637703492, 1109}, {8.169804857486305, 1149},
    {8.19219675123579, 1182}, {8.223414021453852, 1209}, {8.264955387810296, 1253},
    {8.356919374486434, 1304}, {8.432878724226821, 1332}, {8.437343920235701, 1344},
    {8.478544802068157, 1387}, {8.517357734345595, 1405}, {8.545034707365541, 1445},
    {8.551922746948694, 1458}, {8.61280101469098, 1532}, {8.65098197253091, 1550},
    {8.708577391164711, 1571}, {8.758351976178243, 1621}, {8.784740325158289, 1662},
    {8.79689507173195, 1690}, {8.87434823946105, 1753}, {8.921658264987187, 1794},
    {8.95316233091659, 1836}, {9.003896288916454, 1900}, {9.200198453330792, 2098},
    {9.27579801907579, 2111}, {9.34464474695715, 2230}, {9.398661831367296, 2331},
    {9.49357800109032, 2332}, {9.521442551767045, 2408}, {9.526484787160781, 2526},
    {9.675403225736204, 2619}, {9.686321599549105, 2658}, {9.691109181496035, 2728},
    {9.766445426470977, 2788}, {9.76975765531401, 2842}, {9.82135718525645, 2913},
    {9.830619398428873, 2934}, {9.894500372881463, 3063}, {9.927615586643217, 3094}
};

/** Look up the optimal batch size for a given p and bit width from the tune tables.
 *
 *  The table thresholds are -log2 probabilities that start at 1.0 (p = 0.5) and grow, so
 *  they only distinguish p <= 0.5. Every extraction model is symmetric under p <-> 1 - p
 *  (each uses max/min of p and 1 - p), so p is folded to min(p, 1 - p) first. Previously
 *  p = 0.99 fell below the first threshold and got the batch tuned for p = 0.5.
 *
 *  Returns the batch of the last entry whose threshold is <= -log2(min(p, 1 - p)), or the
 *  first entry's batch below the first threshold. For bits other than 8/16/32/64 it prints
 *  an error and returns 0. */
unsigned find_tuned_batch(long double p, unsigned bits) noexcept
{
    const long double log2p = -std::log2(std::min(p, 1.0L - p));
    std::span<const TuneEntry> table;
    switch (bits) {
    case  8: table = TUNE_8;  break;
    case 16: table = TUNE_16; break;
    case 32: table = TUNE_32; break;
    case 64: table = TUNE_64; break;
    default: std::fprintf(stderr, "unsupported bits=%u\n", bits); return 0;
    }
    unsigned batch = table.front().batch;
    for (const auto& [thresh, b] : table) {
        // Tables are sorted by threshold; stop at the first one above log2p (also for NaN).
        if (!(log2p >= thresh)) break;
        batch = b;
    }
    return batch;
}

/** Pick the rate model for the requested carry. An unbounded carry uses model_full_carry;
 *  a zero-width carry uses model_no_carry; anything else uses model_bounded_carry. */
long double evaluate(long double p, const std::vector<RangeInfo>& ranges,
                     unsigned carry_bits, bool carry_inf) noexcept
{
    if (!carry_inf) {
        return carry_bits == 0 ? model_no_carry(p, ranges)
                               : model_bounded_carry(p, ranges, carry_bits);
    }
    return model_full_carry(p, ranges);
}

void cmd_model_rate(const char* batch_arg, const char* bits_arg, const char* carry_arg)
{
    std::string_view batch_str = batch_arg;
    std::string_view bits_str = bits_arg;
    std::string_view carry_str = carry_arg;

    bool batch_tuned = (batch_str == "tuned");
    unsigned batch = batch_tuned ? 0 : std::atoi(batch_arg);
    bool bits_inf = (bits_str == "inf");
    unsigned bits = bits_inf ? 0 : std::atoi(bits_arg);
    bool carry_inf = (carry_str == "inf");
    unsigned carry_bits = carry_inf ? 0 : std::atoi(carry_arg);

    if (batch_tuned) {
        std::vector<long double> ps;
        long double p;
        while (std::scanf("%Lf", &p) == 1)
            ps.push_back(p);

        std::unordered_map<unsigned, std::vector<RangeInfo>> cache;
        std::vector<unsigned> batch_per_p(ps.size());

        for (size_t i = 0; i < ps.size(); i++) {
            unsigned b = find_tuned_batch(ps[i], bits);
            batch_per_p[i] = b;
            if (!cache.contains(b))
                cache.emplace(b, bits_inf ? precompute(b) : precompute(b, bits));
        }

        for (size_t i = 0; i < ps.size(); i++) {
            if (ps[i] <= 0.0L || ps[i] >= 1.0L) { std::printf("0\n"); continue; }
            std::printf("%.21Lg\n", evaluate(ps[i], cache.at(batch_per_p[i]), carry_bits, carry_inf));
        }
    } else {
        auto ranges = bits_inf ? precompute(batch) : precompute(batch, bits);
        long double p;
        while (std::scanf("%Lf", &p) == 1) {
            if (p <= 0.0L || p >= 1.0L) { std::printf("0\n"); continue; }
            std::printf("%.21Lg\n", evaluate(p, ranges, carry_bits, carry_inf));
        }
    }
}

}  // namespace

/** Run every built-in exhaustive_test configuration, continuing past failures so each one
 *  reports. Returns true iff all of them passed. */
static bool run_default_configs()
{
    bool ok = true;
    for (auto [m, n] : DEFAULT_CONFIGS) {
        ok = run_config(m, n) && ok;  // run_config first, so no config is short-circuited away
    }
    return ok;
}

/** CLI entry point: dispatches to exhaustive_test / sim_bench / model_rate (see usage text). */
int main(int argc, char** argv)
{
    if (argc >= 2 && std::strcmp(argv[1], "exhaustive_test") == 0) {
        if (argc == 4) {
            unsigned m = std::atoi(argv[2]);
            unsigned n = std::atoi(argv[3]);
            if (m < 2 || m > 256) {
                std::cerr << "m must be between 2 and 256\n";
                return 1;
            }
            return run_config(m, n) ? 0 : 1;
        }
        if (argc == 2) return run_default_configs() ? 0 : 1;
        std::fprintf(stderr, "Usage: exhaustive_test [M N]\n");
        return 1;
    }

    if (argc == 5 && std::strcmp(argv[1], "model_rate") == 0) {
        cmd_model_rate(argv[2], argv[3], argv[4]);
        return 0;
    }

    if (argc == 4 && std::strcmp(argv[1], "sim_bench") == 0) {
        unsigned batch = std::atoi(argv[2]);
        double p = std::atof(argv[3]);
        if (batch < 2 || batch > 4096) {
            std::cerr << "batch must be between 2 and 4096\n";
            return 1;
        }
        // Positive range test so NaN (atof accepts "nan") is rejected as well; a NaN p would
        // make cmd_sim_bench's float-to-integer threshold conversion undefined.
        if (!(p > 0.0 && p < 1.0)) {
            std::cerr << "p must be between 0.0 and 1.0 (exclusive)\n";
            return 1;
        }
        cmd_sim_bench(batch, p);
        return 0;
    }

    // No arguments: equivalent to `exhaustive_test` with the built-in configurations.
    if (argc == 1) return run_default_configs() ? 0 : 1;

    std::fprintf(stderr,
        "Usage: %s <command> [args...]\n"
        "\n"
        "Commands:\n"
        "  exhaustive_test [M N]\n"
        "      Exhaustively verify the extractor for M-valued events in batches of N.\n"
        "      Checks that spans match multinomial coefficients and that values form\n"
        "      permutations within each frequency class. Without arguments, runs a\n"
        "      built-in set of (M, N) pairs. M must be 2..256; M^N <= 10^10.\n"
        "\n"
        "  sim_bench BATCH P\n"
        "      Benchmark extraction throughput. Generates 128 MiB of biased random\n"
        "      booleans (each true with probability P), then repeatedly feeds batches\n"
        "      of BATCH symbols through process_batch + extract_bit. Prints per-second\n"
        "      stats (throughput, bits/symbol, redundancy) to stderr. Runs forever.\n"
        "      BATCH: 2..4096.  P: (0, 1) exclusive.\n"
        "\n"
        "  model_rate BATCH BITS CARRY\n"
        "      Compute the expected bits/toss for the extractor, reading p values from\n"
        "      stdin (one per line) and writing results to stdout.\n"
        "      BATCH: batch size (2..4096), or \"tuned\" for adaptive selection.\n"
        "      BITS:  integer width for overflow handling (2..128), or \"inf\".\n"
        "      CARRY: carry register width (0..128), or \"inf\".\n"
        "             0 = no carry (Elias), inf = full carry (log2-based).\n"
        "\n"
        "Running with no arguments is equivalent to `exhaustive_test`.\n", argv[0]);
    return 1;
}
