Skip to content
Snippets Groups Projects
Commit 5abad157 authored by Yawning Angel's avatar Yawning Angel
Browse files

Use Vose's Alias Method to sample the weighted distribution.

The weight generation code was also cleaned up, and can now generate
distributions that look like what ScrambleSuit produces (selectable
as a compile-time change).

Per: http://www.keithschwarz.com/darts-dice-coins/
parent 6245391c
No related branches found
No related tags found
No related merge requests found
...@@ -68,9 +68,9 @@ func (r csRandSource) Seed(seed int64) { ...@@ -68,9 +68,9 @@ func (r csRandSource) Seed(seed int64) {
// No-op. // No-op.
} }
// Int63n returns, as a int64, a pseudo random number in [0, n). // Intn returns, as a int, a pseudo random number in [0, n).
func Int63n(n int64) int64 { func Intn(n int) int {
return CsRand.Int63n(n) return CsRand.Intn(n)
} }
// Float64 returns, as a float64, a pseudo random number in [0.0,1.0). // Float64 returns, as a float64, a pseudo random number in [0.0,1.0).
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
package obfs4 package obfs4
import ( import (
"container/list"
"fmt" "fmt"
"math/rand" "math/rand"
...@@ -36,27 +37,25 @@ import ( ...@@ -36,27 +37,25 @@ import (
) )
const ( const (
minBuckets = 1 minValues = 1
maxBuckets = 100 maxValues = 100
) )
// wDist is a weighted distribution. // wDist is a weighted distribution.
type wDist struct { type wDist struct {
minValue int minValue int
maxValue int maxValue int
values []int values []int
buckets []int64 weights []float64
totalWeight int64
rng *rand.Rand alias []int
prob []float64
} }
// newWDist creates a weighted distribution of values ranging from min to max // newWDist creates a weighted distribution of values ranging from min to max
// based on a HashDrbg initialized with seed. // based on a HashDrbg initialized with seed.
func newWDist(seed *drbg.Seed, min, max int) (w *wDist) { func newWDist(seed *drbg.Seed, min, max int) (w *wDist) {
w = new(wDist) w = &wDist{minValue: min, maxValue: max}
w.minValue = min
w.maxValue = max
if max <= min { if max <= min {
panic(fmt.Sprintf("wDist.Reset(): min >= max (%d, %d)", min, max)) panic(fmt.Sprintf("wDist.Reset(): min >= max (%d, %d)", min, max))
...@@ -67,46 +66,147 @@ func newWDist(seed *drbg.Seed, min, max int) (w *wDist) { ...@@ -67,46 +66,147 @@ func newWDist(seed *drbg.Seed, min, max int) (w *wDist) {
return return
} }
// sample generates a random value according to the distribution. // genValues creates a slice containing a random number of random values
func (w *wDist) sample() int { // that when scaled by adding minValue will fall into [min, max].
retIdx := 0 func (w *wDist) genValues(rng *rand.Rand) {
var totalWeight int64 nValues := (w.maxValue + 1) - w.minValue
weight := csrand.Int63n(w.totalWeight) values := rng.Perm(nValues)
for i, bucketWeight := range w.buckets { if nValues < minValues {
totalWeight += bucketWeight nValues = minValues
if weight <= totalWeight { }
retIdx = i if nValues > maxValues {
break nValues = maxValues
}
nValues = rng.Intn(nValues) + 1
w.values = values[:nValues]
}
// genBiasedWeights generates a non-uniform weight list, similar to the
// ScrambleSuit prob_dist module.
func (w *wDist) genBiasedWeights(rng *rand.Rand) {
w.weights = make([]float64, len(w.values))
culmProb := 0.0
for i := range w.values {
p := (1.0 - culmProb) * rng.Float64()
w.weights[i] = p
culmProb += p
}
}
// genUniformWeights generates a uniform weight list.
func (w *wDist) genUniformWeights(rng *rand.Rand) {
w.weights = make([]float64, len(w.values))
for i := range w.weights {
w.weights[i] = rng.Float64()
}
}
// genTables calculates the alias and prob tables used for Vose's Alias method.
// Algorithm taken from http://www.keithschwarz.com/darts-dice-coins/
func (w *wDist) genTables() {
n := len(w.weights)
var sum float64
for _, weight := range w.weights {
sum += weight
}
// Create arrays $Alias$ and $Prob$, each of size $n$.
alias := make([]int, n)
prob := make([]float64, n)
// Create two worklists, $Small$ and $Large$.
small := list.New()
large := list.New()
scaled := make([]float64, n)
for i, weight := range w.weights {
// Multiply each probability by $n$.
p_i := weight * float64(n) / sum
scaled[i] = p_i
// For each scaled probability $p_i$:
if scaled[i] < 1.0 {
// If $p_i < 1$, add $i$ to $Small$.
small.PushBack(i)
} else {
// Otherwise ($p_i \ge 1$), add $i$ to $Large$.
large.PushBack(i)
} }
} }
return w.minValue + w.values[retIdx] // While $Small$ and $Large$ are not empty: ($Large$ might be emptied first)
for small.Len() > 0 && large.Len() > 0 {
// Remove the first element from $Small$; call it $l$.
l := small.Remove(small.Front()).(int)
// Remove the first element from $Large$; call it $g$.
g := large.Remove(large.Front()).(int)
// Set $Prob[l] = p_l$.
prob[l] = scaled[l]
// Set $Alias[l] = g$.
alias[l] = g
// Set $p_g := (p_g + p_l) - 1$. (This is a more numerically stable option.)
scaled[g] = (scaled[g] + scaled[l]) - 1.0
if scaled[g] < 1.0 {
// If $p_g < 1$, add $g$ to $Small$.
small.PushBack(g)
} else {
// Otherwise ($p_g \ge 1$), add $g$ to $Large$.
large.PushBack(g)
}
}
// While $Large$ is not empty:
for large.Len() > 0 {
// Remove the first element from $Large$; call it $g$.
g := large.Remove(large.Front()).(int)
// Set $Prob[g] = 1$.
prob[g] = 1.0
}
// While $Small$ is not empty: This is only possible due to numerical instability.
for small.Len() > 0 {
// Remove the first element from $Small$; call it $l$.
l := small.Remove(small.Front()).(int)
// Set $Prob[l] = 1$.
prob[l] = 1.0
}
w.prob = prob
w.alias = alias
} }
// reset generates a new distribution with the same min/max based on a new seed. // reset generates a new distribution with the same min/max based on a new seed.
func (w *wDist) reset(seed *drbg.Seed) { func (w *wDist) reset(seed *drbg.Seed) {
// Initialize the deterministic random number generator. // Initialize the deterministic random number generator.
drbg := drbg.NewHashDrbg(seed) drbg := drbg.NewHashDrbg(seed)
w.rng = rand.New(drbg) rng := rand.New(drbg)
nBuckets := (w.maxValue + 1) - w.minValue w.genValues(rng)
w.values = w.rng.Perm(nBuckets) //w.genBiasedWeights(rng)
if nBuckets < minBuckets { w.genUniformWeights(rng)
nBuckets = minBuckets w.genTables()
} }
if nBuckets > maxBuckets {
nBuckets = maxBuckets // sample generates a random value according to the distribution.
} func (w *wDist) sample() int {
nBuckets = w.rng.Intn(nBuckets) + 1 var idx int
w.totalWeight = 0 // Generate a fair die roll from an $n$-sided die; call the side $i$.
w.buckets = make([]int64, nBuckets) i := csrand.Intn(len(w.values))
for i, _ := range w.buckets { // Flip a biased coin that comes up heads with probability $Prob[i]$.
prob := w.rng.Int63n(1000) if csrand.Float64() <= w.prob[i] {
w.buckets[i] = prob // If the coin comes up "heads," return $i$.
w.totalWeight += prob idx = i
} else {
// Otherwise, return $Alias[i]$.
idx = w.alias[i]
} }
w.buckets[len(w.buckets)-1] = w.totalWeight
return w.minValue + w.values[idx]
} }
/* vim :set ts=4 sw=4 sts=4 noet : */ /* vim :set ts=4 sw=4 sts=4 noet : */
/*
* Copyright (c) 2014, Yawning Angel <yawning at torproject dot org>
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* * Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* * Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
package obfs4
import (
"fmt"
"testing"
"github.com/yawning/obfs4/drbg"
)
// debug, when set to true, makes TestWeightedDist print the normalized
// weight table and the observed sample frequencies to stdout.
const debug = false
// TestWeightedDist builds a weighted distribution over [0, 999] from a fresh
// DRBG seed, draws a large number of samples from it, and verifies that:
//
//   - every sample lies within the requested range,
//   - no value outside the distribution's value table is ever produced, and
//   - the observed frequency of each value is close to its normalized weight.
//
// When the debug constant is true, the normalized weight table and the
// observed frequencies are also printed.
func TestWeightedDist(t *testing.T) {
	// tolerance is the maximum allowed absolute difference between an
	// observed frequency and its expected probability.  With nrTrials
	// samples the standard error of a frequency is at most
	// sqrt(0.25/nrTrials) = 0.0005, so 0.01 leaves a very wide margin and
	// should only trip on a genuinely broken sampler.
	const (
		nrTrials  = 1000000
		minValue  = 0
		maxValue  = 999
		tolerance = 0.01
	)

	seed, err := drbg.NewSeed()
	if err != nil {
		t.Fatal("failed to generate a DRBG seed:", err)
	}

	w := newWDist(seed, minValue, maxValue)
	if len(w.values) == 0 {
		t.Fatal("distribution has no values")
	}
	if len(w.weights) != len(w.values) {
		t.Fatalf("weights/values length mismatch: %d != %d", len(w.weights), len(w.values))
	}

	var sum float64
	for _, weight := range w.weights {
		sum += weight
	}
	if sum <= 0 {
		t.Fatalf("weights sum to %f, expected a positive total", sum)
	}

	// Build the expected probability of each value, indexed by
	// (value - minValue), and remember which values are in the table at all.
	nValues := maxValue - minValue + 1
	expected := make([]float64, nValues)
	inSupport := make([]bool, nValues)
	for i, weight := range w.weights {
		value := w.minValue + w.values[i]
		if value < minValue || value > maxValue {
			t.Fatalf("table value %d outside [%d, %d]", value, minValue, maxValue)
		}
		expected[value-minValue] = weight / sum
		inSupport[value-minValue] = true
	}

	if debug {
		// Dump a string representation of the probability table.
		fmt.Println("Table:")
		for i, weight := range w.weights {
			p := weight / sum
			if p > 0.000001 { // Filter out tiny values.
				fmt.Printf(" [%d]: %f\n", w.minValue+w.values[i], p)
			}
		}
		fmt.Println()
	}

	hist := make([]int, nValues)
	for i := 0; i < nrTrials; i++ {
		value := w.sample()
		// Check the range explicitly so a bad sample is reported as a test
		// failure rather than an index-out-of-range panic.
		if value < minValue || value > maxValue {
			t.Fatalf("sample %d outside [%d, %d]", value, minValue, maxValue)
		}
		hist[value-minValue]++
	}

	if debug {
		fmt.Println("Generated:")
		for idx, count := range hist {
			if count != 0 {
				p := float64(count) / float64(nrTrials)
				fmt.Printf(" [%d]: %f (%d)\n", idx+minValue, p, count)
			}
		}
	}

	for idx, count := range hist {
		value := idx + minValue
		if count != 0 && !inSupport[idx] {
			t.Errorf("value %d sampled %d times but is not in the distribution", value, count)
			continue
		}
		observed := float64(count) / float64(nrTrials)
		diff := observed - expected[idx]
		if diff < 0 {
			diff = -diff
		}
		if diff > tolerance {
			t.Errorf("value %d: observed frequency %f, expected %f", value, observed, expected[idx])
		}
	}
}
/* vim :set ts=4 sw=4 sts=4 noet : */
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment