// Copyright (c) 2026 Pieter Wuille
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to furnish to do so, subject to the following
// conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

/** During batch processing, (span, value) is represented as
 *
 *     span  = (span_num / denom) * 2^span_shift
 *     value =  value_num / denom
 *
 * with span_num, value_num, denom kept modulo 2^64 via native uint64_t
 * wraparound. Divisions are deferred to a single modular-inverse computation
 * at the end. Multiplication, addition, subtraction, and left shifts all
 * commute with reduction mod 2^64, so intermediate overflow is harmless
 * as long as the final result fits in 64 bits (and extract_bit handles the
 * residual case where span itself does not).
 */

#include "extractor.h"

#include <array>
#include <bit>
#include <utility>

namespace {

/** Return the modular inverse of an odd x, modulo 2^64.
 *
 * Complexity: 8 mul. Constant-time (fixed iteration count, no branches on data).
 */
constexpr uint64_t mod_inverse(uint64_t odd) noexcept
{
    // Branch-free seed, correct to 4 bits: for odd x, the expression
    // x ^ (((x + 1) & 4) << 1) agrees with the inverse of x modulo 16
    // in the low nibble.
    uint64_t inv = odd ^ (((odd + 1) & 4) << 1);
    // Newton-Hensel lifting: each round doubles the number of correct low
    // bits (4 -> 8 -> 16 -> 32 -> 64), so four rounds reach a full 64-bit
    // inverse. Each round costs two multiplications.
    for (int round = 0; round < 4; ++round) {
        inv *= 2 - inv * odd;
    }
    return inv;
}

/** Given x, return (odd, shift) such that x == odd << shift.
 *
 * For x == 0 this returns (0, 64). Downstream arithmetic multiplies the odd
 * factor against another value, so an input of 0 produces a 0 contribution
 * that vanishes regardless of the attached shift.
 *
 * Complexity: 1 ctz. Constant-time if the architecture has a constant-time ctz.
 */
constexpr std::pair<uint64_t, unsigned> split_odd(uint64_t x) noexcept
{
    unsigned shift = std::countr_zero(x);
    // countr_zero(0) == 64 for uint64_t, and shifting a 64-bit value by 64
    // is undefined behavior. Mask the shift count to keep the expression
    // well-defined: when x == 0 the masked shift is 0 and 0 >> 0 still
    // yields the documented (0, 64) result. Branch-free, so the function
    // stays constant-time.
    return {x >> (shift & 63), shift};
}

/** Compute (num, denom) such that
 *  num * mod_inverse(denom) == (n choose k) modulo 2^64.
 *
 * The odd parts of numerator and denominator are accumulated separately,
 * with all powers of two folded into a single net shift applied at the end.
 *
 * Complexity: 2k mul, 2k ctz.
 *
 * Leaks both n and k through timing.
 */
std::pair<uint64_t, uint64_t> binomial_fraction(unsigned n, unsigned k) noexcept
{
    if (k > n) return {0, 1};
    uint64_t odd_num{1};
    uint64_t odd_denom{1};
    unsigned net_shift{0};
    // Build n * (n-1) * ... * (n-k+1) over 1 * 2 * ... * k, one factor pair
    // per step. (These split_odd calls could be replaced with lookups in an
    // O(k)-sized factorial table.)
    for (unsigned step = 1; step <= k; ++step) {
        auto [top_odd, top_shift] = split_odd(n + 1 - step);
        auto [bot_odd, bot_shift] = split_odd(step);
        odd_num *= top_odd;
        odd_denom *= bot_odd;
        // Every partial product is an integral binomial coefficient, so the
        // accumulated shift is non-negative despite unsigned wraparound.
        net_shift += top_shift - bot_shift;
    }
    return {odd_num << net_shift, odd_denom};
}

/** Compute (n choose k) mod 2^64 using the precomputed factorial table.
 *
 * Complexity: 2 mul (plus three table lookups). Requires n < 4096.
 *
 * Leaks information about n and k through timing.
 */
uint64_t binomial_table(unsigned n, unsigned k) noexcept
{
    /** Precomputed factorial table: FAC_TABLE[i] = {o, mod_inverse(o), s},
     *  where (o, s) = split_odd(i!). */
    static auto constexpr FAC_TABLE = []{
        struct FacEntry { uint64_t odd, inv; unsigned shift; };
        std::array<FacEntry, 4096> tbl{};
        uint64_t fact_odd = 1;
        unsigned fact_shift = 0;
        for (unsigned i = 0; i < tbl.size(); ++i) {
            if (i > 0) {
                auto [i_odd, i_shift] = split_odd(i);
                fact_odd *= i_odd;
                fact_shift += i_shift;
            }
            tbl[i] = {fact_odd, mod_inverse(fact_odd), fact_shift};
        }
        return tbl;
    }();

    if (k > n) return 0;
    auto& ne = FAC_TABLE[n];
    auto& ke = FAC_TABLE[k];
    auto& me = FAC_TABLE[n - k];
    return (ne.odd * ke.inv * me.inv) << (ne.shift - ke.shift - me.shift);
}

} // namespace

/** Encode a batch of binary events into a UniformSpan.
 *
 * The returned span equals (n choose k) mod 2^64, where n = events.size() and
 * k is the number of True entries. The returned value is the reverse-
 * lexicographic position of events within the (n choose k) sequences sharing
 * its frequency.
 *
 * Constant-time in the event values (no data-dependent branches).
 *
 * Complexity: 3N mul, 2N ctz, plus one mod_inverse (N = events.size()).
 */
UniformSpan process_batch(std::span<const uint8_t> events) noexcept
{
    // (span, value) carried in deferred-division form (see file header):
    // span = (span_num / denom) << span_shift, value = value_num / denom,
    // everything modulo 2^64.
    uint64_t span_num{1}, value_num{0}, denom{1};
    unsigned span_shift{0};
    /** Number of True events seen so far. */
    uint64_t k{0};
    /** Total events seen so far. */
    uint64_t n{0};

    for (uint8_t event : events) {
        // Full-width mask: all-ones if event is True, all-zeros otherwise.
        // (Events are assumed to be 0 or 1; not validated here.)
        uint64_t mask = ~uint64_t{event} + 1;
        k += event;
        ++n;
        // Count of events in this batch that match the current one's kind:
        // k if True, (n - k) if False. Selected branch-free via mask.
        // Always >= 1 (the current event matches itself), so split_odd never
        // sees a zero argument in this loop.
        uint64_t num_equal = (mask & k) + (~mask & (n - k));
        auto [n_odd, n_shift] = split_odd(n);
        auto [num_equal_odd, num_equal_shift] = split_odd(num_equal);

        // new_span = span * n / num_equal, kept in (num/denom) << shift form.
        // Only odd factors enter the numerator/denominator; powers of two are
        // routed into the shift.
        uint64_t new_span_num = span_num * n_odd;
        unsigned new_span_shift = span_shift + n_shift - num_equal_shift;
        uint64_t new_denom = denom * num_equal_odd;

        // For a True event, add (new_span - span), the number of arrangements
        // sharing this frequency that start with a False, so True-prefixed
        // arrangements index above them. For a False event, only the rescale
        // by num_equal_odd applies. Expressed branch-free via mask.
        // (The factor num_equal_odd rebases value_num from denominator denom
        // to new_denom, keeping the represented fraction unchanged.)
        value_num = num_equal_odd * (value_num - (mask & (span_num << span_shift)))
                    + (mask & (new_span_num << new_span_shift));

        span_num = new_span_num;
        span_shift = new_span_shift;
        denom = new_denom;
    }

    // Fold the deferred denominator in with a single modular inverse.
    // denom is a product of odd factors, hence odd and invertible mod 2^64.
    uint64_t denom_inv = mod_inverse(denom);
    return {(span_num * denom_inv) << span_shift, value_num * denom_inv};
}

/** Encode a batch of events drawn from [0, M) into a UniformSpan.
 *
 * Returned span equals the multinomial coefficient
 *     n! / (f_0! * ... * f_{M-1}!)
 * modulo 2^64, where f_i is the number of occurrences of event i.
 *
 * Constant-time in the event values.
 *
 * Complexity: (M+3)N mul, 3N ctz, plus one mod_inverse.
 */
UniformSpan process_batch_multi(std::span<const uint8_t> events, unsigned m) noexcept
{
    uint64_t span_num{1}, value_num{0}, denom{1};
    unsigned span_shift{0};
    /** freq[i] tracks how many events equal to i have been seen so far. */
    uint64_t freq[256] = {};
    /** Total events seen so far. */
    uint64_t n{0};

    for (uint8_t event : events) {
        ++n;
        // Scan all m slots branch-free. Update freq[event], and compute
        //   num_below = sum(freq[i] for i < event)
        //   num_equal = freq[event]
        // by masking each slot against the event comparison.
        uint64_t num_below{0}, num_equal{0};
        for (unsigned i = 0; i < m; ++i) {
            uint64_t mask_below = ~uint64_t{i < event} + 1;
            uint64_t mask_equal = ~uint64_t{i == event} + 1;
            freq[i] += (i == event);
            num_below += mask_below & freq[i];
            num_equal += mask_equal & freq[i];
        }
        auto [n_odd, n_shift] = split_odd(n);
        auto [num_equal_odd, num_equal_shift] = split_odd(num_equal);
        // num_below may be 0 (when event is the smallest value seen so far);
        // split_odd(0) returns (0, 64), so the product in the value update
        // below is 0 and contributes nothing.
        auto [num_below_odd, num_below_shift] = split_odd(num_below);

        // new_span = span * n / num_equal.
        uint64_t new_span_num = span_num * n_odd;
        unsigned new_span_shift = span_shift + n_shift - num_equal_shift;
        uint64_t new_denom = denom * num_equal_odd;

        // Arrangements starting with an event strictly below the current one
        // form a (num_below / num_equal) fraction of the old span; add that
        // many positions so the current event's encoding follows.
        //
        // The shift count is masked to stay below 64: when num_below == 0,
        // num_below_shift is 64 and the sum can reach 64 or more, and
        // shifting a 64-bit value by >= 64 is undefined behavior. In exactly
        // that case the shifted operand (num_below_odd * span_num) is 0, so
        // masking preserves the intended 0 contribution; for nonzero
        // num_below the true shift is below 64 and the mask is a no-op.
        unsigned below_shift = num_below_shift + span_shift - num_equal_shift;
        value_num = num_equal_odd * value_num
                    + ((num_below_odd * span_num) << (below_shift & 63));

        span_num = new_span_num;
        span_shift = new_span_shift;
        denom = new_denom;
    }

    // Fold the deferred denominator in with a single modular inverse.
    uint64_t denom_inv = mod_inverse(denom);
    return {(span_num * denom_inv) << span_shift, value_num * denom_inv};
}

/** Encode a batch of binary events given as the sorted 0-indexed positions of
 * True events and the total event count `n`.
 *
 * Efficient when the number of True events K is small; complexity is
 * K(K+3)+13 mul and K(K+3) ctz. Produces the same (span, value) as
 * process_batch would on the corresponding bit sequence.
 *
 * Leaks information about the returned value through timing.
 */
UniformSpan process_batch_low_k(std::span<const unsigned> positions, unsigned n) noexcept
{
    // Running value kept as the fraction value_num / value_denom (mod 2^64).
    uint64_t value_num{0};
    uint64_t value_denom{1};
    /** Count of True events consumed so far. */
    unsigned ones{0};
    for (unsigned pos : positions) {
        ++ones;
        // Each True at position pos contributes (pos choose ones); fold the
        // term into the running fraction by cross-multiplication so that a
        // single inversion suffices at the end.
        auto [term_num, term_denom] = binomial_fraction(pos, ones);
        value_num = value_num * term_denom + term_num * value_denom;
        value_denom *= term_denom;
    }
    // The span is (n choose ones), in the same fraction form.
    auto [span_num, span_denom] = binomial_fraction(n, ones);
    // Batch inversion: a single mod_inverse of the denominators' product
    // serves both results (all denominators are odd, hence invertible).
    uint64_t inv = mod_inverse(span_denom * value_denom);
    return {span_num * value_denom * inv, value_num * span_denom * inv};
}

/** Encode a batch of binary events using the precomputed factorial table.
 * Operates on 0-indexed positions of True events; cost is 2K+2 mul plus
 * table lookups. Requires n < 4096. Produces the same (span, value)
 * as process_batch would on the corresponding bit sequence.
 *
 * Leaks information about the returned value through timing.
 */
UniformSpan process_batch_table(std::span<const unsigned> positions, unsigned n) noexcept
{
    uint64_t value{0};
    unsigned ones{0};
    // The i-th True event (1-indexed) at position pos adds (pos choose i)
    // to the value.
    for (unsigned pos : positions) {
        value += binomial_table(pos, ++ones);
    }
    // Span is (n choose K) for K total True events.
    return {binomial_table(n, ones), value};
}

/** Combine entropy from two uniform spans into one.
 *
 * Interprets lo as a digit below hi: the merged value is
 * hi.value * lo.span + lo.value, drawn uniformly from [0, hi.span * lo.span).
 */
UniformSpan merge(const UniformSpan& hi, const UniformSpan& lo) noexcept
{
    // Positional combination: hi supplies the high "digit", lo the low one.
    const uint64_t combined_span = hi.span * lo.span;
    const uint64_t combined_value = hi.value * lo.span + lo.value;
    return {combined_span, combined_value};
}

/** Extract one uniform bit from state, mutating it.
 *
 * Returns the extracted bit, or std::nullopt if no bit could be produced
 * (in which case state is reset to {1, 0}).
 *
 * Timing side-channels: unlike process_batch and process_batch_multi, this
 * function is not constant-time. However, its runtime does not depend on the
 * value of the extracted bit. The only data-dependent branches turn on
 * information that is not secret in this setting:
 *   - the overflow fold (value >= span) is taken based on whether the
 *     prior computation overflowed, which the caller can already observe;
 *   - the reset path (value == maxval on an odd-sized span) is taken based
 *     on whether an output bit was produced at all.
 * The extracted bit itself (value & 1) never gates a branch.
 */
std::optional<bool> extract_bit(UniformSpan& state) noexcept
{
    uint64_t span = state.span;
    uint64_t value = state.value;

    // Overflow fold: if value >= span the observation lives in the q/S slice
    // of the true distribution, which is itself uniform over a span of size
    // 2^64 - state.span = (-state.span) in uint64_t. This also handles
    // span == 0 (true span is 2^64): value >= 0 is always true, so we take
    // this branch, and -0 == 0 leaves span at 0, giving maxval == MAX_VALUE.
    if (value >= span) {
        value -= span;
        span = -span;
    }
    // Largest value representable in the (possibly folded) span; for
    // span == 0 the wraparound makes maxval the all-ones value, matching a
    // full 2^64-sized span.
    uint64_t maxval = span - 1;

    // If the span has odd size (maxval even), one value cannot be paired with
    // a sibling of opposite parity; drop it. If we landed on that value, we
    // cannot emit a bit: reset the state and signal failure.
    if ((maxval & 1) == 0) {
        if (value == maxval) [[unlikely]] {
            state = {1, 0};
            return std::nullopt;
        }
        // Shrink to the even-sized sub-span; every remaining value now has a
        // sibling differing only in its bottom bit.
        --maxval;
    }

    // Emit the bottom bit of value; shift span and value down together. The
    // halved span holds (maxval + 1) / 2 values, one per sibling pair.
    bool bit = value & 1;
    state.span = (maxval >> 1) + 1;
    state.value = value >> 1;
    return bit;
}
