From 5abad1571c7d0869e29d55ca01df83fef8cd4606 Mon Sep 17 00:00:00 2001 From: Yawning Angel <yawning@schwanenlied.me> Date: Thu, 19 Jun 2014 06:29:12 +0000 Subject: [PATCH] Use Vose's Alias Method to sample the weighted distribution. The weight generation code also was cleaned up (and now can support generating distributions that look like what ScrambleSuit does as a compile time change). Per: http://www.keithschwarz.com/darts-dice-coins/ --- csrand/csrand.go | 6 +- weighted_dist.go | 180 ++++++++++++++++++++++++++++++++---------- weighted_dist_test.go | 82 +++++++++++++++++++ 3 files changed, 225 insertions(+), 43 deletions(-) create mode 100644 weighted_dist_test.go diff --git a/csrand/csrand.go b/csrand/csrand.go index a3299aa..b059ed0 100644 --- a/csrand/csrand.go +++ b/csrand/csrand.go @@ -68,9 +68,9 @@ func (r csRandSource) Seed(seed int64) { // No-op. } -// Int63n returns, as a int64, a pseudo random number in [0, n). -func Int63n(n int64) int64 { - return CsRand.Int63n(n) +// Intn returns, as a int, a pseudo random number in [0, n). +func Intn(n int) int { + return CsRand.Intn(n) } // Float64 returns, as a float64, a pesudo random number in [0.0,1.0). diff --git a/weighted_dist.go b/weighted_dist.go index 02fb26d..7c47cb8 100644 --- a/weighted_dist.go +++ b/weighted_dist.go @@ -28,6 +28,7 @@ package obfs4 import ( + "container/list" "fmt" "math/rand" @@ -36,27 +37,25 @@ import ( ) const ( - minBuckets = 1 - maxBuckets = 100 + minValues = 1 + maxValues = 100 ) // wDist is a weighted distribution. type wDist struct { - minValue int - maxValue int - values []int - buckets []int64 - totalWeight int64 + minValue int + maxValue int + values []int + weights []float64 - rng *rand.Rand + alias []int + prob []float64 } // newWDist creates a weighted distribution of values ranging from min to max // based on a HashDrbg initialized with seed. func newWDist(seed *drbg.Seed, min, max int) (w *wDist) { - w = new(wDist) - w.minValue = min - w.maxValue = max + w = &wDist{minValue: min, maxValue: max} if max <= min { panic(fmt.Sprintf("wDist.Reset(): min >= max (%d, %d)", min, max)) @@ -67,46 +66,147 @@ func newWDist(seed *drbg.Seed, min, max int) (w *wDist) { return } -// sample generates a random value according to the distribution. -func (w *wDist) sample() int { - retIdx := 0 - var totalWeight int64 - weight := csrand.Int63n(w.totalWeight) - for i, bucketWeight := range w.buckets { - totalWeight += bucketWeight - if weight <= totalWeight { - retIdx = i - break +// genValues creates a slice containing a random number of random values +// that when scaled by adding minValue will fall into [min, max]. +func (w *wDist) genValues(rng *rand.Rand) { + nValues := (w.maxValue + 1) - w.minValue + values := rng.Perm(nValues) + if nValues < minValues { + nValues = minValues + } + if nValues > maxValues { + nValues = maxValues + } + nValues = rng.Intn(nValues) + 1 + w.values = values[:nValues] +} + +// genBiasedWeights generates a non-uniform weight list, similar to the +// ScrambleSuit prob_dist module. +func (w *wDist) genBiasedWeights(rng *rand.Rand) { + w.weights = make([]float64, len(w.values)) + + culmProb := 0.0 + for i := range w.values { + p := (1.0 - culmProb) * rng.Float64() + w.weights[i] = p + culmProb += p + } +} + +// genUniformWeights generates a uniform weight list. +func (w *wDist) genUniformWeights(rng *rand.Rand) { + w.weights = make([]float64, len(w.values)) + for i := range w.weights { + w.weights[i] = rng.Float64() + } +} + +// genTables calculates the alias and prob tables used for Vose's Alias method. +// Algorithm taken from http://www.keithschwarz.com/darts-dice-coins/ +func (w *wDist) genTables() { + n := len(w.weights) + var sum float64 + for _, weight := range w.weights { + sum += weight + } + + // Create arrays $Alias$ and $Prob$, each of size $n$. + alias := make([]int, n) + prob := make([]float64, n) + + // Create two worklists, $Small$ and $Large$. + small := list.New() + large := list.New() + + scaled := make([]float64, n) + for i, weight := range w.weights { + // Multiply each probability by $n$. + p_i := weight * float64(n) / sum + scaled[i] = p_i + + // For each scaled probability $p_i$: + if scaled[i] < 1.0 { + // If $p_i < 1$, add $i$ to $Small$. + small.PushBack(i) + } else { + // Otherwise ($p_i \ge 1$), add $i$ to $Large$. + large.PushBack(i) } } - return w.minValue + w.values[retIdx] + // While $Small$ and $Large$ are not empty: ($Large$ might be emptied first) + for small.Len() > 0 && large.Len() > 0 { + // Remove the first element from $Small$; call it $l$. + l := small.Remove(small.Front()).(int) + // Remove the first element from $Large$; call it $g$. + g := large.Remove(large.Front()).(int) + + // Set $Prob[l] = p_l$. + prob[l] = scaled[l] + // Set $Alias[l] = g$. + alias[l] = g + + // Set $p_g := (p_g + p_l) - 1$. (This is a more numerically stable option.) + scaled[g] = (scaled[g] + scaled[l]) - 1.0 + + if scaled[g] < 1.0 { + // If $p_g < 1$, add $g$ to $Small$. + small.PushBack(g) + } else { + // Otherwise ($p_g \ge 1$), add $g$ to $Large$. + large.PushBack(g) + } + } + + // While $Large$ is not empty: + for large.Len() > 0 { + // Remove the first element from $Large$; call it $g$. + g := large.Remove(large.Front()).(int) + // Set $Prob[g] = 1$. + prob[g] = 1.0 + } + + // While $Small$ is not empty: This is only possible due to numerical instability. + for small.Len() > 0 { + // Remove the first element from $Small$; call it $l$. + l := small.Remove(small.Front()).(int) + // Set $Prob[l] = 1$. + prob[l] = 1.0 + } + + w.prob = prob + w.alias = alias } // reset generates a new distribution with the same min/max based on a new seed. func (w *wDist) reset(seed *drbg.Seed) { // Initialize the deterministic random number generator. drbg := drbg.NewHashDrbg(seed) - w.rng = rand.New(drbg) + rng := rand.New(drbg) - nBuckets := (w.maxValue + 1) - w.minValue - w.values = w.rng.Perm(nBuckets) - if nBuckets < minBuckets { - nBuckets = minBuckets - } - if nBuckets > maxBuckets { - nBuckets = maxBuckets - } - nBuckets = w.rng.Intn(nBuckets) + 1 - - w.totalWeight = 0 - w.buckets = make([]int64, nBuckets) - for i, _ := range w.buckets { - prob := w.rng.Int63n(1000) - w.buckets[i] = prob - w.totalWeight += prob + w.genValues(rng) + //w.genBiasedWeights(rng) + w.genUniformWeights(rng) + w.genTables() +} + +// sample generates a random value according to the distribution. +func (w *wDist) sample() int { + var idx int + + // Generate a fair die roll from an $n$-sided die; call the side $i$. + i := csrand.Intn(len(w.values)) + // Flip a biased coin that comes up heads with probability $Prob[i]$. + if csrand.Float64() <= w.prob[i] { + // If the coin comes up "heads," return $i$. + idx = i + } else { + // Otherwise, return $Alias[i]$. + idx = w.alias[i] } - w.buckets[len(w.buckets)-1] = w.totalWeight + + return w.minValue + w.values[idx] } /* vim :set ts=4 sw=4 sts=4 noet : */ diff --git a/weighted_dist_test.go b/weighted_dist_test.go new file mode 100644 index 0000000..14fecec --- /dev/null +++ b/weighted_dist_test.go @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2014, Yawning Angel <yawning at torproject dot org> + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +package obfs4 + +import ( + "fmt" + "testing" + + "github.com/yawning/obfs4/drbg" +) + +const debug = false + +func TestWeightedDist(t *testing.T) { + seed, err := drbg.NewSeed() + if err != nil { + t.Fatal("failed to generate a DRBG seed:", err) + } + + const nrTrials = 1000000 + + hist := make([]int, 1000) + + w := newWDist(seed, 0, 999) + if debug { + // Dump a string representation of the probability table. + fmt.Println("Table:") + var sum float64 + for _, weight := range w.weights { + sum += weight + } + for i, weight := range w.weights { + p := weight / sum + if p > 0.000001 { // Filter out tiny values. + fmt.Printf(" [%d]: %f\n", w.minValue+w.values[i], p) + } + } + fmt.Println() + } + + for i := 0; i < nrTrials; i++ { + value := w.sample() + hist[value]++ + } + + if debug { + fmt.Println("Generated:") + for value, count := range hist { + if count != 0 { + p := float64(count) / float64(nrTrials) + fmt.Printf(" [%d]: %f (%d)\n", value, p, count) + } + } + } +} + +/* vim :set ts=4 sw=4 sts=4 noet : */ -- GitLab