From 5abad1571c7d0869e29d55ca01df83fef8cd4606 Mon Sep 17 00:00:00 2001
From: Yawning Angel <yawning@schwanenlied.me>
Date: Thu, 19 Jun 2014 06:29:12 +0000
Subject: [PATCH] Use Vose's Alias Method to sample the weighted distribution.

The weight generation code also was cleaned up (and now can support
generating distributions that look like what ScrambleSuit does as
a compile time change).

Per: http://www.keithschwarz.com/darts-dice-coins/
---
 csrand/csrand.go      |   6 +-
 weighted_dist.go      | 180 ++++++++++++++++++++++++++++++++----------
 weighted_dist_test.go |  82 +++++++++++++++++++
 3 files changed, 225 insertions(+), 43 deletions(-)
 create mode 100644 weighted_dist_test.go

diff --git a/csrand/csrand.go b/csrand/csrand.go
index a3299aa..b059ed0 100644
--- a/csrand/csrand.go
+++ b/csrand/csrand.go
@@ -68,9 +68,9 @@ func (r csRandSource) Seed(seed int64) {
 	// No-op.
 }
 
-// Int63n returns, as a int64, a pseudo random number in [0, n).
-func Int63n(n int64) int64 {
-	return CsRand.Int63n(n)
+// Intn returns, as a int, a pseudo random number in [0, n).
+func Intn(n int) int {
+	return CsRand.Intn(n)
 }
 
 // Float64 returns, as a float64, a pesudo random number in [0.0,1.0).
diff --git a/weighted_dist.go b/weighted_dist.go
index 02fb26d..7c47cb8 100644
--- a/weighted_dist.go
+++ b/weighted_dist.go
@@ -28,6 +28,7 @@
 package obfs4
 
 import (
+	"container/list"
 	"fmt"
 	"math/rand"
 
@@ -36,27 +37,25 @@ import (
 )
 
 const (
-	minBuckets = 1
-	maxBuckets = 100
+	minValues = 1
+	maxValues = 100
 )
 
 // wDist is a weighted distribution.
 type wDist struct {
-	minValue    int
-	maxValue    int
-	values      []int
-	buckets     []int64
-	totalWeight int64
+	minValue int
+	maxValue int
+	values   []int
+	weights  []float64
 
-	rng *rand.Rand
+	alias []int
+	prob  []float64
 }
 
 // newWDist creates a weighted distribution of values ranging from min to max
 // based on a HashDrbg initialized with seed.
 func newWDist(seed *drbg.Seed, min, max int) (w *wDist) {
-	w = new(wDist)
-	w.minValue = min
-	w.maxValue = max
+	w = &wDist{minValue: min, maxValue: max}
 
 	if max <= min {
 		panic(fmt.Sprintf("wDist.Reset(): min >= max (%d, %d)", min, max))
@@ -67,46 +66,147 @@ func newWDist(seed *drbg.Seed, min, max int) (w *wDist) {
 	return
 }
 
-// sample generates a random value according to the distribution.
-func (w *wDist) sample() int {
-	retIdx := 0
-	var totalWeight int64
-	weight := csrand.Int63n(w.totalWeight)
-	for i, bucketWeight := range w.buckets {
-		totalWeight += bucketWeight
-		if weight <= totalWeight {
-			retIdx = i
-			break
+// genValues creates a slice containing a random number of random values
+// that when scaled by adding minValue will fall into [min, max].
+func (w *wDist) genValues(rng *rand.Rand) {
+	nValues := (w.maxValue + 1) - w.minValue
+	values := rng.Perm(nValues)
+	if nValues < minValues {
+		nValues = minValues
+	}
+	if nValues > maxValues {
+		nValues = maxValues
+	}
+	nValues = rng.Intn(nValues) + 1
+	w.values = values[:nValues]
+}
+
+// genBiasedWeights generates a non-uniform weight list, similar to the
+// ScrambleSuit prob_dist module.
+func (w *wDist) genBiasedWeights(rng *rand.Rand) {
+	w.weights = make([]float64, len(w.values))
+
+	culmProb := 0.0
+	for i := range w.values {
+		p := (1.0 - culmProb) * rng.Float64()
+		w.weights[i] = p
+		culmProb += p
+	}
+}
+
+// genUniformWeights generates a uniform weight list.
+func (w *wDist) genUniformWeights(rng *rand.Rand) {
+	w.weights = make([]float64, len(w.values))
+	for i := range w.weights {
+		w.weights[i] = rng.Float64()
+	}
+}
+
+// genTables calculates the alias and prob tables used for Vose's Alias method.
+// Algorithm taken from http://www.keithschwarz.com/darts-dice-coins/
+func (w *wDist) genTables() {
+	n := len(w.weights)
+	var sum float64
+	for _, weight := range w.weights {
+		sum += weight
+	}
+
+	// Create arrays $Alias$ and $Prob$, each of size $n$.
+	alias := make([]int, n)
+	prob := make([]float64, n)
+
+	// Create two worklists, $Small$ and $Large$.
+	small := list.New()
+	large := list.New()
+
+	scaled := make([]float64, n)
+	for i, weight := range w.weights {
+		// Multiply each probability by $n$.
+		p_i := weight * float64(n) / sum
+		scaled[i] = p_i
+
+		// For each scaled probability $p_i$:
+		if scaled[i] < 1.0 {
+			// If $p_i < 1$, add $i$ to $Small$.
+			small.PushBack(i)
+		} else {
+			// Otherwise ($p_i \ge 1$), add $i$ to $Large$.
+			large.PushBack(i)
 		}
 	}
 
-	return w.minValue + w.values[retIdx]
+	// While $Small$ and $Large$ are not empty: ($Large$ might be emptied first)
+	for small.Len() > 0 && large.Len() > 0 {
+		// Remove the first element from $Small$; call it $l$.
+		l := small.Remove(small.Front()).(int)
+		// Remove the first element from $Large$; call it $g$.
+		g := large.Remove(large.Front()).(int)
+
+		// Set $Prob[l] = p_l$.
+		prob[l] = scaled[l]
+		// Set $Alias[l] = g$.
+		alias[l] = g
+
+		// Set $p_g := (p_g + p_l) - 1$. (This is a more numerically stable option.)
+		scaled[g] = (scaled[g] + scaled[l]) - 1.0
+
+		if scaled[g] < 1.0 {
+			// If $p_g < 1$, add $g$ to $Small$.
+			small.PushBack(g)
+		} else {
+			// Otherwise ($p_g \ge 1$), add $g$ to $Large$.
+			large.PushBack(g)
+		}
+	}
+
+	// While $Large$ is not empty:
+	for large.Len() > 0 {
+		// Remove the first element from $Large$; call it $g$.
+		g := large.Remove(large.Front()).(int)
+		// Set $Prob[g] = 1$.
+		prob[g] = 1.0
+	}
+
+	// While $Small$ is not empty: This is only possible due to numerical instability.
+	for small.Len() > 0 {
+		// Remove the first element from $Small$; call it $l$.
+		l := small.Remove(small.Front()).(int)
+		// Set $Prob[l] = 1$.
+		prob[l] = 1.0
+	}
+
+	w.prob = prob
+	w.alias = alias
 }
 
 // reset generates a new distribution with the same min/max based on a new seed.
 func (w *wDist) reset(seed *drbg.Seed) {
 	// Initialize the deterministic random number generator.
 	drbg := drbg.NewHashDrbg(seed)
-	w.rng = rand.New(drbg)
+	rng := rand.New(drbg)
 
-	nBuckets := (w.maxValue + 1) - w.minValue
-	w.values = w.rng.Perm(nBuckets)
-	if nBuckets < minBuckets {
-		nBuckets = minBuckets
-	}
-	if nBuckets > maxBuckets {
-		nBuckets = maxBuckets
-	}
-	nBuckets = w.rng.Intn(nBuckets) + 1
-
-	w.totalWeight = 0
-	w.buckets = make([]int64, nBuckets)
-	for i, _ := range w.buckets {
-		prob := w.rng.Int63n(1000)
-		w.buckets[i] = prob
-		w.totalWeight += prob
+	w.genValues(rng)
+	//w.genBiasedWeights(rng)
+	w.genUniformWeights(rng)
+	w.genTables()
+}
+
+// sample generates a random value according to the distribution.
+func (w *wDist) sample() int {
+	var idx int
+
+	// Generate a fair die roll from an $n$-sided die; call the side $i$.
+	i := csrand.Intn(len(w.values))
+	// Flip a biased coin that comes up heads with probability $Prob[i]$.
+	if csrand.Float64() <= w.prob[i] {
+		// If the coin comes up "heads," return $i$.
+		idx = i
+	} else {
+		// Otherwise, return $Alias[i]$.
+		idx = w.alias[i]
 	}
-	w.buckets[len(w.buckets)-1] = w.totalWeight
+
+	return w.minValue + w.values[idx]
 }
 
 /* vim :set ts=4 sw=4 sts=4 noet : */
diff --git a/weighted_dist_test.go b/weighted_dist_test.go
new file mode 100644
index 0000000..14fecec
--- /dev/null
+++ b/weighted_dist_test.go
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) 2014, Yawning Angel <yawning at torproject dot org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ *  * Redistributions of source code must retain the above copyright notice,
+ *    this list of conditions and the following disclaimer.
+ *
+ *  * Redistributions in binary form must reproduce the above copyright notice,
+ *    this list of conditions and the following disclaimer in the documentation
+ *    and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+ * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ */
+
+package obfs4
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/yawning/obfs4/drbg"
+)
+
+const debug = false
+
+func TestWeightedDist(t *testing.T) {
+	seed, err := drbg.NewSeed()
+	if err != nil {
+		t.Fatal("failed to generate a DRBG seed:", err)
+	}
+
+	const nrTrials = 1000000
+
+	hist := make([]int, 1000)
+
+	w := newWDist(seed, 0, 999)
+	if debug {
+		// Dump a string representation of the probability table.
+		fmt.Println("Table:")
+		var sum float64
+		for _, weight := range w.weights {
+			sum += weight
+		}
+		for i, weight := range w.weights {
+			p := weight / sum
+			if p > 0.000001 { // Filter out tiny values.
+				fmt.Printf(" [%d]: %f\n", w.minValue+w.values[i], p)
+			}
+		}
+		fmt.Println()
+	}
+
+	for i := 0; i < nrTrials; i++ {
+		value := w.sample()
+		hist[value]++
+	}
+
+	if debug {
+		fmt.Println("Generated:")
+		for value, count := range hist {
+			if count != 0 {
+				p := float64(count) / float64(nrTrials)
+				fmt.Printf(" [%d]: %f (%d)\n", value, p, count)
+			}
+		}
+	}
+}
+
+/* vim :set ts=4 sw=4 sts=4 noet : */
-- 
GitLab