diff --git a/pkg/api/endpoints.go b/pkg/api/endpoints.go index c568c589fe4a77a2838b91413b5ddd7296a79fbf..e0cd9a9b270423efe2fb1c47c87743aba9ce566b 100644 --- a/pkg/api/endpoints.go +++ b/pkg/api/endpoints.go @@ -44,9 +44,9 @@ func sortEndpoints[T m.Bridge | m.Gateway](cc string, endpoints []*T, locations return endpoints, nil } - // without a country code we can only create a randomized lizst of endpoints + // without a country code we can only create a randomized list of endpoints if cc == "" { - return createWeightedRandomList(endpoints) + return wr.CreateWeightedRandomList(endpoints) } locationWeights, err := locationWeights(cc, locations, lm) @@ -57,23 +57,13 @@ func sortEndpoints[T m.Bridge | m.Gateway](cc string, endpoints []*T, locations tree := createLocationTree(endpoints, locationWeights) result := []*T{} - // iterate over _lists_ of gateways, grouped by location and sorted by distance + // iterate over _lists_ of endpoints, grouped by location and sorted by distance for _, list := range tree.Values() { - // prepare weighted randomized list for endpoints grouped by location - chooser, err := wr.NewChooser( - list..., - ) + // add weighted randomized endpoints per location to result list + result, err = wr.AppendWeightedRandomChoices(list, result) if err != nil { return nil, err } - for range list { - // append random element from current list to result - element, err := chooser.PickRandom() - if err != nil { - return nil, fmt.Errorf("failed to sort endpoint list: %v", err) - } - result = append(result, *element) - } } return result, nil @@ -97,29 +87,6 @@ func createLocationTree[T m.Bridge | m.Gateway](endpoints []*T, locationWeights return tree } -func createWeightedRandomList[T m.Bridge | m.Gateway](endpoints []*T) ([]*T, error) { - list := []wr.Choice[T]{} - for _, endpoint := range endpoints { - if endpoint == nil { - continue - } - list = append(list, wr.Choice[T]{Item: *endpoint, Weight: 1}) - } - choser, err := wr.NewChooser(list...) - if err != nil { - return nil, err - } - result := []*T{} - for range endpoints { - choice, err := choser.PickRandom() - if err != nil { - return nil, fmt.Errorf("failed to create weighted random endpoint list: %v", err) - } - result = append(result, choice) - } - return result, nil -} - func getLocation(endpoint interface{}) string { switch v := any(endpoint).(type) { case *m.Gateway: @@ -159,7 +126,7 @@ func locationWeights(cc string, locations locationMap, lm *latency.Metric) (map[ return locationWeights, nil } -// normalizeDistance normalizes the haversine distance between to points to a value between 0 and 1 +// normalizeDistance normalizes the haversine distance between two points to a value between 0 and 1 func normalizeDistance(distance float64) float64 { return (distance * 2) / geolocate.EARTH_CIRCUMFERENCE } @@ -169,7 +136,7 @@ func normalizeLatency(latency, maxLatency float64) float64 { return latency / maxLatency } -// combine adds a new value to an base value and normalizes it to a value between 0 and 1 +// combine adds a new value to a base value and normalizes it to a value between 0 and 1 func combine(base, newVal, newValWeight float64) float64 { if newVal < 0 || newVal > 1 { // ignore newVal if it's not normalized @@ -180,7 +147,7 @@ func combine(base, newVal, newValWeight float64) float64 { // sanitizeCountryCode will check if the passed ISO-2 country code is known to us. // For empty strings, they stay the same. -// For non-valid country codes, it will return an error. Otherwise, will return the original string. +// For non-valid country codes, it returns an error. Otherwise, it returns the original string. func sanitizeCountryCode(cc string) (string, error) { if cc == "" { return "", nil diff --git a/pkg/api/endpoints_test.go b/pkg/api/endpoints_test.go index 5397235f8805fc10894c629b9bc3740cde48338f..8b9b07cd18d1e70282c191f3c1c6c7dbde2494be 100644 --- a/pkg/api/endpoints_test.go +++ b/pkg/api/endpoints_test.go @@ -10,6 +10,7 @@ import ( "0xacab.org/leap/menshen/pkg/latency" m "0xacab.org/leap/menshen/pkg/models" + wr "0xacab.org/leap/menshen/pkg/weightedrand" "github.com/labstack/echo/v4" "github.com/magiconair/properties/assert" "github.com/stretchr/testify/require" @@ -380,7 +381,7 @@ func TestCreateWeightedRandomList(t *testing.T) { bridge1First := false bridge2First := false for range 100 { - randomlist, err := createWeightedRandomList(list) + randomlist, err := wr.CreateWeightedRandomList(list) assert.Equal(t, nil, err) bridge1Found := false bridge2Found := false diff --git a/pkg/weightedrand/weightedrand.go b/pkg/weightedrand/weightedrand.go index a452ca25732484c661487cf7045e0f7a1dddb6c5..37f9b0b7ad23046cc426a32dff07517442740ddb 100644 --- a/pkg/weightedrand/weightedrand.go +++ b/pkg/weightedrand/weightedrand.go @@ -23,11 +23,6 @@ type Choice[T any] struct { Weight int } -// NewChoice creates a new Choice with specified item and weight. -func NewChoice[T any](item *T, weight int) Choice[*T] { - return Choice[*T]{Item: item, Weight: weight} -} - // A Chooser caches many possible Choices in a structure designed to improve // performance on repeated calls for weighted random selection. type Chooser[T any] struct { @@ -80,3 +75,35 @@ func (c *Chooser[T]) PickRandom() (*T, error) { return nil, fmt.Errorf("failed to pick an item") } + +// CreateWeightedRandomList creates weighted randomized slice containing all elements of the source slice +func CreateWeightedRandomList[T any](source []*T) ([]*T, error) { + list := []Choice[T]{} + for _, element := range source { + if element == nil { + continue + } + list = append(list, Choice[T]{Item: *element, Weight: 1}) + } + result := []*T{} + return AppendWeightedRandomChoices(list, result) +} + +// AppendWeightedRandomChoices appends weighted randomized elements to a slice. +// The order of previously added elements in the result list is retained +func AppendWeightedRandomChoices[T any](choices []Choice[T], result []*T) ([]*T, error) { + chooser, err := NewChooser( + choices..., + ) + if err != nil { + return nil, err + } + for range choices { + element, err := chooser.PickRandom() + if err != nil { + return nil, fmt.Errorf("failed to create weighted random list: %v", err) + } + result = append(result, element) + } + return result, nil +} diff --git a/pkg/weightedrand/weightedrand_test.go b/pkg/weightedrand/weightedrand_test.go index f28e291600dc44e56b51b86bfd44a816b3fef455..06bb5d8997e5650777e6aaff3b0458261a0e8d09 100644 --- a/pkg/weightedrand/weightedrand_test.go +++ b/pkg/weightedrand/weightedrand_test.go @@ -8,6 +8,7 @@ import ( "github.com/jmcvetta/randutil" mrothwr "github.com/mroth/weightedrand/v2" + "github.com/tj/assert" ) const BMMinChoices = 10 @@ -55,8 +56,6 @@ func BenchmarkMultiple(b *testing.B) { } } -// THE SINGLE USAGE CASE IS AN ANTI-PATTERN FOR THE INTENDED USAGE OF THIS -// LIBRARY. Provide some optional benchmarks for that to illustrate the point. func BenchmarkSingle(b *testing.B) { if testing.Short() { b.Skip() @@ -129,3 +128,72 @@ func convertChoicesToMenshenWeightedRand[T rune](tb testing.TB, cs []mrothwr.Cho func fmt1eN(n int) string { return fmt.Sprintf("1e%d", int(math.Log10(float64(n)))) } + +func TestCreateWeightedRandomList(t *testing.T) { + element1 := "element1" + element2 := "element2" + element3 := "element3" + list := []*string{&element1, &element2, &element3} + expectedElements := []*string{&element1, &element2, &element3} + + element1First := false + element2First := false + element3First := false + for range 200 { + randomlist, err := CreateWeightedRandomList(list) + assert.ElementsMatch(t, expectedElements, randomlist) + assert.Equal(t, nil, err) + for i, element := range randomlist { + if *element == element1 { + if i == 0 { + element1First = true + } + } else if *element == element2 { + if i == 0 { + element2First = true + } + } else if *element == element3 { + if i == 0 { + element3First = true + } + } + } + if element1First && element2First && element3First { + break + } + } + assert.Equal(t, element1First, true, "element 1 at least once at first position in randomized list") + assert.Equal(t, element2First, true, "element 2 at least once at first position in randomized list") + assert.Equal(t, element3First, true, "element 3 at least once at first position in randomized list") +} + +func TestAppendWeightedRandomChoices(t *testing.T) { + element1 := "element1" + element2 := "element2" + element3 := "element3" + list := []*string{&element1} + + expectedElements := []*string{&element1, &element2, &element3} + element2Second := false + element3Second := false + for range 100 { + choices := []Choice[string]{{Item: element2, Weight: 1}, {Item: element3, Weight: 1}} + resultList, err := AppendWeightedRandomChoices(choices, list) + assert.Equal(t, nil, err) + assert.ElementsMatch(t, expectedElements, resultList, "all source elements appear in result list") + assert.Equal(t, element1, *resultList[0], "element 1 always first, randomized elements appended") + for i, element := range resultList { + if *element == element2 { + if i == 1 { + element2Second = true + } + } else if *element == element3 { + if i == 1 { + element3Second = true + } + } + } + } + assert.Equal(t, element2Second, true, "element 2 at least once at second position in randomized list") + assert.Equal(t, element3Second, true, "element 3 at least once at second position in randomized list") +}