1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353 | // Copyright (c) 2026 Pieter Wuille
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to furnish to do so, subject to the following
// conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
/** During batch processing, (span, value) is represented as
*
* span = (span_num / denom) * 2^span_shift
* value = value_num / denom
*
* with span_num, value_num, denom kept modulo 2^64 via native uint64_t
* wraparound. Divisions are deferred to a single modular-inverse computation
* at the end. Multiplication, addition, subtraction, and left shifts all
* commute with reduction mod 2^64, so intermediate overflow is harmless
* as long as the final result fits in 64 bits (and extract_bit handles the
* residual case where span itself does not).
*/
#include "extractor.h"
#include <array>
#include <bit>
#include <utility>
namespace {
/** Return the modular inverse of an odd x, modulo 2^64.
 *
 * Complexity: 5 mul. Constant-time.
 */
constexpr uint64_t mod_inverse(uint64_t odd) noexcept
{
    // Branch-free seed correct to 4 bits: for an odd input x, XOR-ing with
    // ((x + 1) & 4) << 1 produces the inverse of x modulo 16 in the low
    // nibble.
    uint64_t result = odd ^ (((odd + 1) & 4) << 1);
    // Newton-Hensel lifting: each iteration r <- r * (2 - r * odd) doubles
    // the number of correct low bits, so four rounds take the 4-bit-correct
    // seed through 8, 16, 32, and finally all 64 bits.
    for (int round = 0; round < 4; ++round) {
        result *= 2 - result * odd;
    }
    return result;
}
/** Given x, return (odd, shift) such that x == odd << shift.
 *
 * For x == 0 this returns (0, 64). Downstream arithmetic multiplies the odd
 * factor against another value, so an input of 0 produces a 0 contribution
 * that vanishes regardless of the attached shift.
 *
 * Complexity: 1 ctz. Constant-time if the architecture has a constant-time ctz.
 */
constexpr std::pair<uint64_t, unsigned> split_odd(uint64_t x) noexcept
{
    unsigned shift = std::countr_zero(x);
    // For x == 0, shift is 64, and shifting a uint64_t by >= 64 is undefined
    // behavior in C++ ([expr.shift]). Masking the shift amount into [0, 63]
    // keeps the expression defined while preserving the documented result:
    // 0 >> 0 == 0, and for nonzero x the shift is at most 63, so the mask is
    // a no-op. Branch-free, so constant-time is preserved.
    return {x >> (shift & 63), shift};
}
/** Compute (num, denom) such that
 * num * mod_inverse(denom) == (n choose k) modulo 2^64.
 *
 * Complexity: 2k mul, 2k ctz.
 *
 * Leaks both n and k through timing.
 */
std::pair<uint64_t, uint64_t> binomial_fraction(unsigned n, unsigned k) noexcept
{
    if (k > n) return {0, 1};
    uint64_t numerator{1};
    uint64_t denominator{1};
    unsigned net_shift{0};
    // Accumulate (n)(n-1)...(n-k+1) / k! with odd parts kept in the
    // numerator/denominator and powers of two tracked separately.
    for (unsigned step = 1; step <= k; ++step) {
        // These split_odd calls could be replaced with lookups in an
        // O(k)-sized factorial table.
        auto [top_odd, top_shift] = split_odd(n - step + 1);
        auto [bot_odd, bot_shift] = split_odd(step);
        numerator *= top_odd;
        denominator *= bot_odd;
        net_shift += top_shift;
        net_shift -= bot_shift;
    }
    // The net power of two equals the 2-adic valuation of (n choose k),
    // which is nonnegative, so folding it into the numerator is exact.
    return {numerator << net_shift, denominator};
}
/** Compute (n choose k) mod 2^64 using the precomputed factorial table.
 *
 * Complexity: 2 mul (plus three table lookups). Requires n < 4096.
 *
 * Leaks information about n and k through timing.
 */
uint64_t binomial_table(unsigned n, unsigned k) noexcept
{
    /** Precomputed factorial table: FAC_TABLE[i] = {o, mod_inverse(o), s},
     * where (o, s) = split_odd(i!). */
    static constexpr auto FAC_TABLE = []{
        struct FacEntry { uint64_t odd, inv; unsigned shift; };
        std::array<FacEntry, 4096> table{};
        // Running odd part and 2-adic valuation of i!, updated incrementally.
        uint64_t running_odd = 1;
        unsigned running_shift = 0;
        table[0] = {1, mod_inverse(1), 0};
        for (unsigned i = 1; i < table.size(); ++i) {
            auto [odd_i, shift_i] = split_odd(i);
            running_odd *= odd_i;
            running_shift += shift_i;
            table[i] = {running_odd, mod_inverse(running_odd), running_shift};
        }
        return table;
    }();
    if (k > n) return 0;
    const auto& fn = FAC_TABLE[n];
    const auto& fk = FAC_TABLE[k];
    const auto& fm = FAC_TABLE[n - k];
    // n! / (k! * (n-k)!): multiply the odd parts (divisors via their
    // precomputed inverses), then apply the net power of two, all mod 2^64.
    return (fn.odd * fk.inv * fm.inv) << (fn.shift - fk.shift - fm.shift);
}
} // namespace
/** Encode a batch of binary events into a UniformSpan.
 *
 * The returned span equals (n choose k) mod 2^64, where n = events.size() and
 * k is the number of True entries. The returned value is the reverse-
 * lexicographic position of events within the (n choose k) sequences sharing
 * its frequency.
 *
 * Constant-time in the event values (no data-dependent branches).
 *
 * Complexity: 3N mul, 2N ctz, plus one mod_inverse (N = events.size()).
 */
UniformSpan process_batch(std::span<const uint8_t> events) noexcept
{
    // Loop invariant (see the file header comment):
    //   span  = (span_num / denom) << span_shift
    //   value = value_num / denom
    // with denom kept odd so one modular inverse divides it out at the end.
    uint64_t span_num{1}, value_num{0}, denom{1};
    unsigned span_shift{0};
    /** Number of True events seen so far. */
    uint64_t k{0};
    /** Total events seen so far. */
    uint64_t n{0};
    for (uint8_t event : events) {
        // Full-width mask: all-ones if event is True, all-zeros otherwise.
        // (Assumes event is 0 or 1; other values break the mask construction.)
        uint64_t mask = ~uint64_t{event} + 1;
        k += event;
        ++n;
        // Count of events in this batch that match the current one's kind:
        // k if True, (n - k) if False. Selected branch-free via mask.
        // Always >= 1 (the current event counts itself), so both split_odd
        // calls below see nonzero inputs.
        uint64_t num_equal = (mask & k) + (~mask & (n - k));
        auto [n_odd, n_shift] = split_odd(n);
        auto [num_equal_odd, num_equal_shift] = split_odd(num_equal);
        // new_span = span * n / num_equal, kept in (num/denom) << shift form.
        uint64_t new_span_num = span_num * n_odd;
        unsigned new_span_shift = span_shift + n_shift - num_equal_shift;
        uint64_t new_denom = denom * num_equal_odd;
        // For a True event, add (new_span - span), the number of arrangements
        // sharing this frequency that start with a False, so True-prefixed
        // arrangements index above them. For a False event, only the rescale
        // by num_equal_odd applies. Expressed branch-free via mask.
        value_num = num_equal_odd * (value_num - (mask & (span_num << span_shift)))
            + (mask & (new_span_num << new_span_shift));
        span_num = new_span_num;
        span_shift = new_span_shift;
        denom = new_denom;
    }
    // Fold the deferred denominator in with a single modular inverse; the
    // final left shift restores the powers of two tracked in span_shift.
    uint64_t denom_inv = mod_inverse(denom);
    return {(span_num * denom_inv) << span_shift, value_num * denom_inv};
}
/** Encode a batch of events drawn from [0, M) into a UniformSpan.
 *
 * Returned span equals the multinomial coefficient
 *   n! / (f_0! * ... * f_{M-1}!)
 * modulo 2^64, where f_i is the number of occurrences of event i.
 *
 * Constant-time in the event values.
 *
 * Complexity: (M+3)N mul, 3N ctz, plus one mod_inverse.
 */
UniformSpan process_batch_multi(std::span<const uint8_t> events, unsigned m) noexcept
{
    uint64_t span_num{1}, value_num{0}, denom{1};
    unsigned span_shift{0};
    /** freq[i] tracks how many events equal to i have been seen so far. */
    uint64_t freq[256] = {};
    /** Total events seen so far. */
    uint64_t n{0};
    for (uint8_t event : events) {
        ++n;
        // Scan all m slots branch-free. Update freq[event], and compute
        //   num_below = sum(freq[i] for i < event)
        //   num_equal = freq[event]
        // by masking each slot against the event comparison.
        uint64_t num_below{0}, num_equal{0};
        for (unsigned i = 0; i < m; ++i) {
            uint64_t mask_below = ~uint64_t{i < event} + 1;
            uint64_t mask_equal = ~uint64_t{i == event} + 1;
            freq[i] += (i == event);
            num_below += mask_below & freq[i];
            num_equal += mask_equal & freq[i];
        }
        auto [n_odd, n_shift] = split_odd(n);
        auto [num_equal_odd, num_equal_shift] = split_odd(num_equal);
        // num_below may be 0 (when event is the smallest value seen so far);
        // split_odd(0) returns (0, 64), which would push the shift amount in
        // the value update to >= 64 -- undefined behavior for a 64-bit shift.
        // The shift is therefore masked to [0, 63] below. That is exact:
        // when num_below == 0 the multiplicand num_below_odd is 0, so any
        // shift of it yields 0; when num_below > 0 the shift is the 2-adic
        // valuation of an integer arrangement count, hence < 64 and the mask
        // is a no-op. Masking is branch-free, preserving constant-time.
        auto [num_below_odd, num_below_shift] = split_odd(num_below);
        // new_span = span * n / num_equal.
        uint64_t new_span_num = span_num * n_odd;
        unsigned new_span_shift = span_shift + n_shift - num_equal_shift;
        uint64_t new_denom = denom * num_equal_odd;
        // Arrangements starting with an event strictly below the current one
        // form a (num_below / num_equal) fraction of the old span; add that
        // many positions so the current event's encoding follows.
        unsigned below_shift = num_below_shift + span_shift - num_equal_shift;
        value_num = num_equal_odd * value_num
            + ((num_below_odd * span_num) << (below_shift & 63));
        span_num = new_span_num;
        span_shift = new_span_shift;
        denom = new_denom;
    }
    // Fold the deferred denominator in with a single modular inverse.
    uint64_t denom_inv = mod_inverse(denom);
    return {(span_num * denom_inv) << span_shift, value_num * denom_inv};
}
/** Encode a batch of binary events given as the sorted 0-indexed positions of
 * True events and the total event count `n`.
 *
 * Efficient when the number of True events K is small; complexity is
 * K(K+3)+13 mul and K(K+3) ctz. Produces the same (span, value) as
 * process_batch would on the corresponding bit sequence.
 *
 * Leaks information about the returned value through timing.
 */
UniformSpan process_batch_low_k(std::span<const unsigned> positions, unsigned n) noexcept
{
    uint64_t val_num{0};
    uint64_t val_denom{1};
    unsigned ones{0};
    // Combinatorial number system: the i-th True position p (1-indexed i)
    // contributes (p choose i) to the value. Each term arrives as a fraction,
    // so cross-multiply onto a common (deferred) denominator.
    for (unsigned pos : positions) {
        ++ones;
        auto [term_num, term_denom] = binomial_fraction(pos, ones);
        val_num = val_num * term_denom + term_num * val_denom;
        val_denom *= term_denom;
    }
    // Span is (n choose ones), kept in the same fraction form.
    auto [sp_num, sp_denom] = binomial_fraction(n, ones);
    // Batch inversion: invert the product of both denominators once, then
    // isolate each individual inverse by multiplying with the other's value.
    uint64_t inv = mod_inverse(val_denom * sp_denom);
    return {sp_num * inv * val_denom, val_num * inv * sp_denom};
}
/** Encode a batch of binary events using the precomputed factorial table.
 * Operates on 0-indexed positions of True events; cost is 2K+2 mul plus
 * table lookups. Requires n < 4096. Produces the same (span, value)
 * as process_batch would on the corresponding bit sequence.
 *
 * Leaks information about the returned value through timing.
 */
UniformSpan process_batch_table(std::span<const unsigned> positions, unsigned n) noexcept
{
    uint64_t accumulated{0};
    unsigned ones{0};
    // Combinatorial number system: the i-th True position p (1-indexed i)
    // contributes (p choose i) to the value.
    for (unsigned pos : positions) {
        ++ones;
        accumulated += binomial_table(pos, ones);
    }
    // Span is (n choose ones).
    return {binomial_table(n, ones), accumulated};
}
/** Combine entropy from two uniform spans into one.
 *
 * Interprets lo as a digit below hi: the merged value is
 * hi.value * lo.span + lo.value, drawn uniformly from [0, hi.span * lo.span).
 */
UniformSpan merge(const UniformSpan& hi, const UniformSpan& lo) noexcept
{
    // Positional combination: hi supplies the high "digit", lo the low one.
    uint64_t combined_span = hi.span * lo.span;
    uint64_t combined_value = hi.value * lo.span + lo.value;
    return {combined_span, combined_value};
}
/** Extract one uniform bit from state, mutating it.
 *
 * Returns the extracted bit, or std::nullopt if no bit could be produced
 * (in which case state is reset to {1, 0}).
 *
 * Timing side-channels: unlike process_batch and process_batch_multi, this
 * function is not constant-time. However, its runtime does not depend on the
 * value of the extracted bit. The only data-dependent branches turn on
 * information that is not secret in this setting:
 * - the overflow fold (value >= span) is taken based on whether the
 *   prior computation overflowed, which the caller can already observe;
 * - the reset path (value == maxval on an odd-sized span) is taken based
 *   on whether an output bit was produced at all.
 * The extracted bit itself (value & 1) never gates a branch.
 */
std::optional<bool> extract_bit(UniformSpan& state) noexcept
{
    uint64_t s = state.span;
    uint64_t v = state.value;
    // Overflow fold: v >= s means the observation lies in the wrapped slice
    // of the true distribution, itself uniform over 2^64 - s = (-s) values.
    // This also covers s == 0 (true span 2^64): v >= 0 always holds, the
    // fold fires, and -0 == 0 keeps s at 0, making top below UINT64_MAX.
    if (v >= s) {
        v -= s;
        s = 0 - s;
    }
    // Largest representable value within the (possibly folded) span.
    uint64_t top = s - 1;
    // An odd-sized span (even top) has one value without an opposite-parity
    // sibling; it must be dropped. Landing on it means no bit can be emitted:
    // reset and report failure.
    if (!(top & 1)) {
        if (v == top) [[unlikely]] {
            state = {1, 0};
            return std::nullopt;
        }
        top -= 1;
    }
    // Emit the low bit of v; halve span and value in lockstep.
    state.span = (top >> 1) + 1;
    state.value = v >> 1;
    return static_cast<bool>(v & 1);
}
|