# Copyright (c) 2026 Pieter Wuille
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

"""Division-free, overflow-safe binomial randomness extractor.

During batch processing, (span, value) is represented as

    span  = (span_num / denom) * 2 ** span_shift
    value =  value_num / denom

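with span_num, value_num, denom kept modulo 2 ** BITS. Each divisor is split
into an odd factor and a power of two: odd factors accumulate in denom, which
therefore stays odd and invertible modulo 2 ** BITS, while powers of two are
tracked exactly in span_shift. Division by denom is deferred to a single
modular-inverse computation at the end. Multiplication, addition,
subtraction, and left shifts all commute with reduction mod 2 ** BITS, so
intermediate overflow is harmless: results are exact when they fit in BITS
bits and are otherwise the true values reduced mod 2 ** BITS (extract_bit
handles a span that has been reduced this way).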
"""
from dataclasses import dataclass
from typing import List, Optional, Tuple

# Word size of the modular arithmetic: spans, values and intermediates are
# all reduced modulo 2 ** BITS.
BITS = 64
# All-ones mask; for any Python int x, x & MAX_VALUE == x mod 2 ** BITS.
MAX_VALUE = 2 ** BITS - 1


@dataclass
class UniformSpan:
    """A value drawn uniformly from the integers in [0, span).

    Both fields are kept as BITS-bit words. If the true span exceeded
    2 ** BITS and only its residue is stored, value may also lie in
    [span, 2 ** BITS); extract_bit accounts for that case.
    """
    # Number of equally likely outcomes (possibly reduced mod 2 ** BITS).
    span: int
    # The outcome that was actually drawn.
    value: int


def modular_inverse(odd: int) -> int:
    """Return the inverse of odd modulo 2 ** BITS.

    odd must be odd: even numbers have no inverse modulo a power of two, and
    no check is performed, so the result is then meaningless.
    """
    # Seed: for odd x, x ^ (((x + 1) & 4) << 1) is the inverse of x modulo
    # 16, i.e. its low 4 bits are already correct.
    inv = (odd ^ (((odd + 1) & 4) << 1)) & MAX_VALUE
    # Newton/Hensel lifting: inv <- inv * (2 - odd * inv) doubles the number
    # of correct low bits (4 -> 8 -> 16 -> ...), so
    # ceil(log2(BITS / 4)) == (BITS - 1).bit_length() - 2 rounds suffice.
    rounds = (BITS - 1).bit_length() - 2
    while rounds > 0:
        inv = (inv * (2 - odd * inv)) & MAX_VALUE
        rounds -= 1
    return inv


def split_odd(x: int) -> Tuple[int, int]:
    """Factor x as odd << shift and return the pair (odd, shift).

    Zero has no such factorization and maps to (0, BITS), mirroring C++'s
    std::countr_zero on a BITS-wide unsigned integer. Because the odd factor
    is then 0, any product with it vanishes, and shifting that zero by BITS
    still contributes nothing.

    On real hardware the shift is a single "count trailing zeros" / "bit
    scan forward" instruction. Here it is derived from the lowest set bit.
    """
    if not x:
        return 0, BITS
    # x & -x isolates the lowest set bit; its bit position is the shift.
    shift = (x & -x).bit_length() - 1
    return x >> shift, shift


def process_batch(events: List[bool]) -> UniformSpan:
    """Encode a batch of binary events into a UniformSpan.

    The returned span is C(n, k) mod 2 ** BITS, where n = len(events) and k
    is the number of True entries. The returned value (also mod 2 ** BITS)
    is this sequence's index among the C(n, k) sequences with the same
    number of True entries. Concretely, it is the sum of C(p, j) over the
    True events, where p is an event's 0-based position and j its 1-based
    rank among the True events.
    """
    ones = 0
    # Running state: span == num / den * 2 ** shift and value == acc / den,
    # with num, den and acc reduced mod 2 ** BITS (num and den odd).
    num, shift, den, acc = 1, 0, 1, 0
    # Events are consumed in forward order. Since the reverse-lex position
    # is being computed, each event acts on the *front* of the suffix.
    for length, event in enumerate(events, start=1):
        ones += event
        # Number of events among the first `length` (this one included) of
        # the same kind as the current event.
        same = ones if event else length - ones
        length_odd, length_tz = split_odd(length)
        same_odd, same_tz = split_odd(same)
        # span *= length / same: the odd part of the divisor goes into den,
        # its power of two into the shift.
        next_num = num * length_odd & MAX_VALUE
        next_shift = shift + length_tz - same_tz
        # Re-express the value over the new denominator den * same_odd.
        acc = acc * same_odd
        if event:
            # A True event adds next_span - span: the arrangements with this
            # frequency whose front is False, so True-fronted ones index
            # above them. Both terms are written over the new denominator.
            acc += (next_num << next_shift) - same_odd * (num << shift)
        acc &= MAX_VALUE
        num, shift, den = next_num, next_shift, den * same_odd & MAX_VALUE
    # One modular inverse resolves the deferred denominator.
    inv = modular_inverse(den)
    return UniformSpan(
        span=(num * inv << shift) & MAX_VALUE,
        value=acc * inv & MAX_VALUE,
    )


def process_batch_multi(events: List[int], M: int) -> UniformSpan:
    """Encode a batch of events drawn from [0, M) into a UniformSpan.

    Returned span equals the multinomial coefficient n! / (f_0! .. f_{M-1}!)
    modulo 2 ** BITS, where n = len(events) and f_i is the number of
    occurrences of event i. The returned value (also modulo 2 ** BITS) is
    the position of events among all sequences sharing those frequencies.
    For M == 2 the result equals process_batch on [bool(e) for e in events].

    Raises:
        IndexError: if some event lies outside [0, M).
    """
    # freq[i] tracks how many events equal to i have been seen so far.
    freq = [0] * M
    # Base case (empty prefix): span == 1, value == 0.
    span_num, span_shift = 1, 0
    value_num, denom = 0, 1
    for n, event in enumerate(events, start=1):
        # Check explicitly: a negative event would otherwise index freq from
        # the end and be silently treated as event M + event.
        if not 0 <= event < M:
            raise IndexError(f"event {event!r} out of range [0, {M})")
        freq[event] += 1
        # Events equal to the current one (including it) and events strictly
        # below it, among the first n events.
        num_equal = freq[event]
        num_below = sum(freq[:event])
        # Split into odd/shift form so the division becomes a modular inverse.
        n_odd, n_shift = split_odd(n)
        num_equal_odd, num_equal_shift = split_odd(num_equal)
        # new_span = span * n / num_equal.
        new_span_num = (span_num * n_odd) & MAX_VALUE
        new_span_shift = span_shift + n_shift - num_equal_shift
        new_denom = (denom * num_equal_odd) & MAX_VALUE
        # Re-express the old value over the new denominator.
        value_num = (num_equal_odd * value_num) & MAX_VALUE
        if num_below:
            # Arrangements starting with an event lower than the current one
            # number span * num_below / num_equal. Add them so the current
            # event's indices follow. That count is an integer, so its
            # power-of-two exponent (the shift below) is non-negative.
            num_below_odd, num_below_shift = split_odd(num_below)
            shift = num_below_shift + span_shift - num_equal_shift
            value_num = (
                value_num + ((num_below_odd * span_num) << shift)
            ) & MAX_VALUE
        span_num, span_shift, denom = new_span_num, new_span_shift, new_denom
    # Fold the deferred denominator in with one modular inverse.
    denom_inv = modular_inverse(denom)
    return UniformSpan(
        span=((span_num * denom_inv) << span_shift) & MAX_VALUE,
        value=(value_num * denom_inv) & MAX_VALUE,
    )


def binomial_fraction(n: int, k: int) -> Tuple[int, int]:
    """Return (num, denom) with num * modular_inverse(denom) == C(n, k) mod 2**BITS.

    denom is a product of odd numbers, hence odd and invertible modulo
    2**BITS. When k < 0 or k > n the coefficient is zero and (0, 1) is
    returned. Only the ratio num / denom is meaningful; the individual
    components depend on how the product was formed.

    Since C(n, k) == C(n, n - k), the loop runs over m = min(k, n - k)
    factors. Complexity: 2m mul, 2m ctz. Leaks both n and k through timing.
    """
    if k < 0 or k > n:
        return 0, 1
    # Iterate over the smaller of k and n - k; the coefficient is the same.
    k = min(k, n - k)
    num, denom = 1, 1
    shift = 0
    for i in range(k):
        # C(n, k) = prod_{i<k} (n - i) / (i + 1). Odd parts go into num and
        # denom, and powers of two are netted in shift. These split_odd
        # calls could be replaced with lookups in an O(k)-sized factorial
        # table.
        mul_odd, mul_shift = split_odd(n - i)
        div_odd, div_shift = split_odd(i + 1)
        num = (num * mul_odd) & MAX_VALUE
        denom = (denom * div_odd) & MAX_VALUE
        shift += mul_shift - div_shift
    # After all factors, shift is the 2-adic valuation of C(n, k) (>= 0).
    return (num << shift) & MAX_VALUE, denom


# FAC_TABLE[i] = (odd part of i! mod 2**BITS, modular inverse of that odd
# part, number of trailing zero bits of i!), for 0 <= i < len(FAC_TABLE).
_FAC_TABLE_SIZE = 100
FAC_TABLE: List[Tuple[int, int, int]] = []


def _init_fac_table(size: int = _FAC_TABLE_SIZE) -> None:
    """(Re)build FAC_TABLE in place so that it covers 0! .. (size - 1)!.

    The existing list object is refilled rather than replaced, so references
    to FAC_TABLE held elsewhere stay valid. Calling this again does not
    append duplicate entries.
    """
    entries = []
    fact_odd = 1
    fact_shift = 0
    for i in range(size):
        if i > 0:
            # i! = (i-1)! * i: multiply in the odd part of i and add its
            # trailing-zero count.
            i_odd, i_shift = split_odd(i)
            fact_odd = (fact_odd * i_odd) & MAX_VALUE
            fact_shift += i_shift
        entries.append((fact_odd, modular_inverse(fact_odd), fact_shift))
    FAC_TABLE[:] = entries


_init_fac_table()


def binomial_table(n: int, k: int) -> int:
    """Compute C(n, k) mod 2**BITS via the precomputed factorial table.

    Returns 0 when k < 0 or k > n. Complexity: 2 mul plus the table lookup.
    Requires n < len(FAC_TABLE); larger n raises IndexError. Leaks
    information about n and k through timing.
    """
    # Negative k must be rejected explicitly: FAC_TABLE[k] would otherwise
    # silently index from the end of the table.
    if k < 0 or k > n:
        return 0
    n_odd, _, n_shift = FAC_TABLE[n]
    _, k_inv_odd, k_shift = FAC_TABLE[k]
    _, nmk_inv_odd, nmk_shift = FAC_TABLE[n - k]
    # C(n, k) = n! / (k! (n - k)!): multiply the odd parts (dividing via the
    # stored inverses), then shift by the net power of two, which is >= 0.
    odd = (n_odd * k_inv_odd * nmk_inv_odd) & MAX_VALUE
    return (odd << (n_shift - k_shift - nmk_shift)) & MAX_VALUE


def process_batch_low_k(positions: List[int], n: int) -> UniformSpan:
    """Encode a batch of binary events from the positions of its True events.

    positions lists the 0-indexed positions of the True events in increasing
    order, and n is the total number of events. The result is the same
    (span, value) that process_batch returns for the corresponding boolean
    sequence. value is the sum of C(position, j) over the True events, with
    j the 1-based rank of each, and span is C(n, K); both are mod 2 ** BITS.

    Intended for a small number K of True events: the cost is at most about
    K(K+3) mul and K(K+3) ctz. Leaks information about the returned value
    through timing.
    """
    total_num, total_den = 0, 1
    rank = 0
    for rank, position in enumerate(positions, start=1):
        term_num, term_den = binomial_fraction(position, rank)
        # a/b + c/d == (a*d + c*b) / (b*d); the denominators stay odd.
        total_num = (total_num * term_den + term_num * total_den) & MAX_VALUE
        total_den = (total_den * term_den) & MAX_VALUE
    span_num, span_den = binomial_fraction(n, rank)
    # Batch inversion: invert the product of both denominators once, then
    # multiply by the other denominator to recover each individual inverse.
    shared_inv = modular_inverse((total_den * span_den) & MAX_VALUE)
    span = (span_num * shared_inv * total_den) & MAX_VALUE
    value = (total_num * shared_inv * span_den) & MAX_VALUE
    return UniformSpan(span=span, value=value)


def process_batch_table(positions: List[int], n: int) -> UniformSpan:
    """Encode a batch of binary events using the precomputed factorial table.

    Takes the 0-indexed positions of the True events (ascending) and the
    total event count n. Returns the same (span, value) that process_batch
    produces for the corresponding boolean sequence. Costs 2K+2 mul plus
    table lookups for K True events. Requires n < len(FAC_TABLE). Leaks
    information about the returned value through timing.
    """
    total = 0
    rank = 0
    for rank, position in enumerate(positions, start=1):
        # The True event with 1-based rank `rank` contributes C(position, rank).
        total = (total + binomial_table(position, rank)) & MAX_VALUE
    return UniformSpan(binomial_table(n, rank), total)


def merge(u1: UniformSpan, u2: UniformSpan) -> UniformSpan:
    """Combine the entropy of two uniform spans into a single span.

    u1 is the high-order digit and u2 the low-order digit of a mixed-radix
    number: value = u1.value * u2.span + u2.value. Assuming the two inputs
    are independent, this is uniform over [0, u1.span * u2.span). Both
    results are reduced modulo 2 ** BITS.
    """
    span = u1.span * u2.span
    value = u1.value * u2.span + u2.value
    return UniformSpan(span=span & MAX_VALUE, value=value & MAX_VALUE)


def extract_bit(state: UniformSpan) -> Optional[bool]:
    """Extract one uniform bit from state, updating state in place.

    Returns the bit, or None when no bit could be produced. In the None case
    the state is reset to the neutral UniformSpan(1, 0).
    """
    span, value = state.span, state.value
    # state.span may be the residue of a larger true span
    # S = span + q * 2**BITS. Residues below span then have q + 1 preimages
    # each, and residues in [span, 2**BITS) have q, so each of the two
    # regions is uniform by itself. Find value's region and its size.
    if value >= span:
        value -= span
        size = (MAX_VALUE + 1) - span
    else:
        size = span
    # An odd-sized region leaves its last value without a partner of the
    # opposite parity; landing on it yields no bit.
    if size & 1:
        if value == size - 1:
            state.span, state.value = 1, 0
            return None
        size -= 1
    # size is now even: the low bit of value is uniform, and the remaining
    # bits are uniform over [0, size / 2).
    state.span = size >> 1
    state.value = value >> 1
    return bool(value & 1)


import itertools
import math
import unittest
from collections import defaultdict


def _multinomial(freqs):
    total = math.factorial(sum(freqs))
    for f in freqs:
        total //= math.factorial(f)
    return total


class ExhaustiveTest(unittest.TestCase):
    """Exhaustive checks of the encoders and span operations on small inputs.

    test_exhaustive: for small (n, M) with M**n tractable, enumerate all event
    sequences and verify that span equals the expected (multi)nomial,
    0 <= value < span, and that within each frequency class the values form
    a permutation of {0, ..., span-1} (i.e. every sequence produces a
    distinct encoding). For M == 2 every binary encoder and
    process_batch_multi must produce identical results.

    test_merge / test_extract_bit: check that merge and extract_bit map a
    uniform input onto a uniform output for all small spans.
    """

    CONFIGS = [
        (1, 2), (5, 2), (10, 2), (13, 2),
        (4, 3), (8, 3),
        (3, 4), (6, 4),
        (5, 5),
        (4, 6),
        (3, 10), (4, 10),
    ]

    def test_exhaustive(self):
        for n, M in self.CONFIGS:
            with self.subTest(n=n, M=M):
                # Keep the enumeration small enough to run quickly.
                self.assertLessEqual(M ** n, 10000)
                by_freq = defaultdict(list)
                for events in itertools.product(range(M), repeat=n):
                    freqs = [0] * M
                    for e in events:
                        freqs[e] += 1
                    expected = _multinomial(freqs)
                    if M == 2:
                        bits = [bool(e) for e in events]
                        result = process_batch(bits)
                        positions = [i for i, b in enumerate(bits) if b]
                        # Every binary encoder, and the general M-ary one,
                        # must agree exactly on binary input.
                        others = {
                            "low_k": process_batch_low_k(positions, len(bits)),
                            "table": process_batch_table(positions, len(bits)),
                            "multi": process_batch_multi(list(events), M),
                        }
                        for name, other in others.items():
                            self.assertEqual((result.span, result.value),
                                             (other.span, other.value),
                                             f"{name} disagrees for events={events}")
                    else:
                        result = process_batch_multi(list(events), M)
                    self.assertEqual(result.span, expected,
                                     f"span mismatch for events={events}")
                    self.assertGreaterEqual(result.value, 0,
                                            f"negative value for events={events}")
                    self.assertLess(result.value, result.span,
                                    f"value >= span for events={events}")
                    by_freq[tuple(freqs)].append(result.value)
                for freqs, values in by_freq.items():
                    self.assertEqual(sorted(values), list(range(len(values))),
                                     f"values for freqs={freqs} not a permutation")

    def test_merge(self):
        """merge must map all (v1, v2) pairs bijectively onto [0, s1 * s2)."""
        for s1 in range(1, 8):
            for s2 in range(1, 8):
                values = []
                for v1 in range(s1):
                    for v2 in range(s2):
                        merged = merge(UniformSpan(s1, v1), UniformSpan(s2, v2))
                        self.assertEqual(merged.span, s1 * s2)
                        values.append(merged.value)
                self.assertEqual(sorted(values), list(range(s1 * s2)),
                                 f"merge not a bijection for spans={s1},{s2}")

    def _check_extract_bit(self, span, values):
        """Run extract_bit on UniformSpan(span, v) for each v in values, which
        must be one whole uniform region as seen by extract_bit, and check
        that at most one value is discarded and each bit outcome maps its
        values onto a permutation of the new span."""
        size = len(values)
        new_values = {False: [], True: []}
        discarded = 0
        for value in values:
            state = UniformSpan(span, value)
            bit = extract_bit(state)
            if bit is None:
                discarded += 1
                self.assertEqual((state.span, state.value), (1, 0))
            else:
                self.assertEqual(state.span, size // 2)
                new_values[bit].append(state.value)
        self.assertEqual(discarded, size % 2, f"span={span}")
        for bit, got in new_values.items():
            self.assertEqual(sorted(got), list(range(size // 2)),
                             f"bit={bit} span={span}")

    def test_extract_bit(self):
        # Values below span: the ordinary (non-overflowed) region.
        for span in range(1, 40):
            self._check_extract_bit(span, range(span))
        # Values at or above span: the overflow region of size 2**BITS - span.
        for size in range(1, 40):
            span = MAX_VALUE + 1 - size
            self._check_extract_bit(span, range(span, MAX_VALUE + 1))


if __name__ == "__main__":
    # When executed as a script, run the unittest test cases defined in this
    # module.
    unittest.main()
