diff --git a/pkg/api/bridge.go b/pkg/api/bridge.go index 4f89429a2e4bf63a6ea30ee02f050dc743f01355..aedeeaa67fb5ac3cead4dae88d612881f64b705b 100644 --- a/pkg/api/bridge.go +++ b/pkg/api/bridge.go @@ -2,6 +2,7 @@ package api import ( "net/http" + "slices" "strconv" "time" @@ -26,17 +27,48 @@ import ( // @Router /api/5/bridges [get] // @Security BucketTokenAuth func (r *registry) ListAllBridges(c echo.Context) error { - bridges := []*m.Bridge{} - for k := range r.bridges { - bridges = append(bridges, r.bridges[k]...) + // Step 1: Validate parameters + cc, err := sanitizeCountryCode(c.QueryParam(paramCountryCode)) + if err != nil { + return c.JSON(http.StatusBadRequest, err) } - filters := bridgeFiltersFromParams(c, []string{"type"}) + err = sanitizeLocations(c.QueryParam("loc"), r.locations) + if err != nil { + return c.JSON(http.StatusBadRequest, err) + } + + err = sanitizeTransport(c, []string{"udp", "tcp", "kcp", "quic"}) + if err != nil { + return c.JSON(http.StatusBadRequest, err) + } + + err = sanitizePort(c.QueryParam("port")) + if err != nil { + return c.JSON(http.StatusBadRequest, err) + } + + err = sanitizeType(c, []string{"obfs4", "obfs4-hop"}) + if err != nil { + return c.JSON(http.StatusBadRequest, err) + } + + // Step 2: We start with all bridges + bridges := r.AllBridges() + + // Step 3: apply gateway filters + filters := bridgeFiltersFromParams(c, []string{"type", "loc", "port", "tr"}) filters = maybeAddBridgeBucketFilter(c, filters) filters = maybeAddLastSeenBridgeCutoffFilter(r, filters) - filtered := filter[*m.Bridge](alltrue(filters), bridges) + availableBridges := filter[*m.Bridge](alltrue(filters), bridges) + + // Step 4: sort gateways + result, err := sortEndpoints(cc, availableBridges, r.locations, r.lm) - return c.JSON(http.StatusOK, filtered) + if err != nil { + return c.JSON(http.StatusBadRequest, nil) + } + return c.JSON(http.StatusOK, result) } // I'd love for this to be generic across "endpoints" (ie bridges AND gateways) @@ -51,15 +83,7 @@ func maybeAddBridgeBucketFilter(c echo.Context, filters []func(*m.Bridge) bool) } return append(filters, func(b *m.Bridge) bool { - if b.Bucket == "" { - return true - } - for _, bucket := range buckets { - if b.Bucket == bucket { - return true - } - } - return false + return slices.Contains(buckets, b.Bucket) }) } @@ -110,34 +134,12 @@ func maybeAddBridgeFilter(c echo.Context, param string, filters []func(*m.Bridge filter = func(b *m.Bridge) bool { return b.Location == q } + case "tr": + filter = func(b *m.Bridge) bool { + return b.Transport == q + } default: return filters } return append(filters, filter) } - -// BridgePicker godoc -// @Summary Get Bridges -// @Description fetch bridges by location -// @Tags Provisioning -// @Accept json -// @Produce json -// @Param location path string true "Location ID" -// @Success 200 {object} []models.Bridge -// @Failure 400 {object} error -// @Failure 404 {object} error -// @Failure 500 {object} error -// @Router /api/5/bridge/{location} [get] -// @Security BucketTokenAuth -func (r *registry) BridgePicker(c echo.Context) error { - location := c.Param("location") - // TODO return error if location not known - bridges := r.bridges[location] - - filters := make([]func(*m.Bridge) bool, 0) - filters = maybeAddBridgeBucketFilter(c, filters) - filters = maybeAddLastSeenBridgeCutoffFilter(r, filters) - filtered := filter[*m.Bridge](alltrue(filters), bridges) - - return c.JSON(http.StatusOK, filtered) -} diff --git a/pkg/api/bridge_test.go b/pkg/api/bridge_test.go index c98e50bb8a60d3a76d934ca9da1dbe4b5086dea9..0d6d6fc9e5a89fb0f2dbee4aae3bc93e638542f2 100644 --- a/pkg/api/bridge_test.go +++ b/pkg/api/bridge_test.go @@ -7,7 +7,6 @@ import ( "fmt" "net/http" "net/http/httptest" - "net/url" "os" "strings" "testing" @@ -25,7 +24,23 @@ func TestBridgeFilters(t *testing.T) { buckets string } - location1 := "New York" + locationNY := "New York" + locationMon := "Montreal" + locationStruct := m.Location{ + CountryCode: "US", + Label: locationNY, + DisplayName: locationNY, + Lat: "40.71", + Lon: "-74.00", + } + + locationStructCA := m.Location{ + CountryCode: "CA", + Label: locationMon, + DisplayName: locationMon, + Lat: "45.52", + Lon: "-73.65", + } bridge1 := &m.Bridge{ Host: "bridge1", @@ -33,13 +48,14 @@ func TestBridgeFilters(t *testing.T) { Transport: "TCP", Type: "obfs4", Bucket: "bucket1", + Location: locationNY, } - bridge2 := &m.Bridge{ Host: "bridge2", Port: 443, Transport: "TCP", Type: "obfs4", + Location: locationNY, } bridge3 := &m.Bridge{ @@ -49,85 +65,161 @@ func TestBridgeFilters(t *testing.T) { Type: "obfs4", Bucket: "bucket2", LastSeenMillis: time.Now().UnixMilli(), + Location: locationNY, + } + + bridge4 := &m.Bridge{ + Host: "bridge4", + Port: 443, + Transport: "TCP", + Type: "obfs4-hop", + Bucket: "bucket2", + LastSeenMillis: time.Now().UnixMilli(), + Location: locationNY, + } + + bridge5 := &m.Bridge{ + Host: "bridge5", + Port: 443, + Transport: "quic", + Type: "obfs4-hop", + Bucket: "bucket2", + LastSeenMillis: time.Now().UnixMilli(), + Location: locationNY, + } + + bridge6 := &m.Bridge{ + Host: "bridge6", + Port: 4430, + Transport: "kcp", + Type: "obfs4-hop", + Bucket: "bucket2", + LastSeenMillis: time.Now().UnixMilli(), + Location: locationMon, } testTable := []struct { name string mockRegistry *registry + parameter string expected func(*registry) string authToken string dbAuthTokens []authTokenDbEntry }{ - {"no auth token only private bridges", + {"no auth token only private bridges return empty", ®istry{ - bridges: bridgeMap{location1: []*m.Bridge{bridge1}}, + bridges: bridgeMap{locationNY: []*m.Bridge{bridge1, bridge3, bridge4, bridge5}, locationMon: []*m.Bridge{bridge6}}, + locations: locationMap{locationNY: &locationStruct}, }, + "?type=obfs4", func(r *registry) string { return "[]\n" }, "", []authTokenDbEntry{{"key1", "bucket1"}}, }, - {"auth token one private bridge", + {"auth token return one private bridge", ®istry{ - bridges: bridgeMap{location1: []*m.Bridge{bridge1}}, + bridges: bridgeMap{locationNY: []*m.Bridge{bridge1, bridge2, bridge3, bridge4, bridge5}, locationMon: []*m.Bridge{bridge6}}, + locations: locationMap{locationNY: &locationStruct}, }, + "?type=obfs4", func(r *registry) string { - bytes, err := json.Marshal(r.bridges[location1]) + bytes, err := json.Marshal([]*m.Bridge{bridge1}) assert.NoError(t, err) return string(bytes) }, "key1", []authTokenDbEntry{{"key1", "bucket1"}}, }, - {"auth token one private one public bridge", + {"auth token private and public bridges return private", ®istry{ - bridges: bridgeMap{location1: []*m.Bridge{ - bridge1, bridge2, - }, - }, + bridges: bridgeMap{locationNY: []*m.Bridge{bridge1, bridge2, bridge3, bridge4, bridge5}, locationMon: []*m.Bridge{bridge6}}, + locations: locationMap{locationNY: &locationStruct, locationMon: &locationStructCA}, }, + "?type=obfs4", func(r *registry) string { - bytes, err := json.Marshal(r.bridges[location1]) + bytes, err := json.Marshal([]*m.Bridge{bridge1}) assert.NoError(t, err) return string(bytes) }, "key1", []authTokenDbEntry{{"key1", "bucket1"}}, }, - {"auth token with multiple buckets two private bridges", + {"auth token with lastSeenCutoffMillis enabled ignore offline bridge", ®istry{ - bridges: bridgeMap{location1: []*m.Bridge{ + bridges: bridgeMap{locationNY: []*m.Bridge{ bridge1, bridge3, - }, - }, + }}, + locations: locationMap{locationNY: &locationStruct, locationMon: &locationStructCA}, + // Cut off is 5 seconds + lastSeenCutoffMillis: 5000, }, + "?type=obfs4", func(r *registry) string { - bytes, err := json.Marshal(r.bridges[location1]) + bytes, err := json.Marshal([]*m.Bridge{bridge3}) assert.NoError(t, err) return string(bytes) }, "key1", []authTokenDbEntry{{"key1", "bucket1,bucket2"}}, }, - {"auth token with lastSeenCutoffMillis enabled", + {"transport quic return 1 bridge", ®istry{ - bridges: bridgeMap{location1: []*m.Bridge{ - bridge1, - bridge3, - }, - }, - // Cut off is 5 seconds - lastSeenCutoffMillis: 5000, + bridges: bridgeMap{locationNY: []*m.Bridge{bridge1, bridge3, bridge4, bridge5}, locationMon: []*m.Bridge{bridge6}}, + locations: locationMap{locationNY: &locationStruct, locationMon: &locationStructCA}, }, + "?tr=quic", func(r *registry) string { - bytes, err := json.Marshal([]*m.Bridge{bridge3}) + bytes, err := json.Marshal([]*m.Bridge{bridge5}) assert.NoError(t, err) return string(bytes) }, "key1", - []authTokenDbEntry{{"key1", "bucket1,bucket2"}}, + []authTokenDbEntry{{"key1", "bucket2"}}, + }, + {"transport kcp return 1 bridge", + ®istry{ + bridges: bridgeMap{locationNY: []*m.Bridge{bridge1, bridge3, bridge4, bridge5}, locationMon: []*m.Bridge{bridge6}}, + locations: locationMap{locationNY: &locationStruct, locationMon: &locationStructCA}, + }, + "?tr=kcp", + func(r *registry) string { + bytes, err := json.Marshal([]*m.Bridge{bridge6}) + assert.NoError(t, err) + return string(bytes) + }, + "key1", + []authTokenDbEntry{{"key1", "bucket2"}}, + }, + {"location return 1 bridge", + ®istry{ + bridges: bridgeMap{locationNY: []*m.Bridge{bridge1, bridge3, bridge4, bridge5}, locationMon: []*m.Bridge{bridge6}}, + locations: locationMap{locationNY: &locationStruct, locationMon: &locationStructCA}, + }, + "?loc=Montreal", + func(r *registry) string { + bytes, err := json.Marshal([]*m.Bridge{bridge6}) + assert.NoError(t, err) + return string(bytes) + }, + "key1", + []authTokenDbEntry{{"key1", "bucket2"}}, + }, + {"port return 1 bridge", + ®istry{ + bridges: bridgeMap{locationNY: []*m.Bridge{bridge1, bridge3, bridge4, bridge5}, locationMon: []*m.Bridge{bridge6}}, + locations: locationMap{locationNY: &locationStruct, locationMon: &locationStructCA}, + }, + "?port=4430", + func(r *registry) string { + bytes, err := json.Marshal([]*m.Bridge{bridge6}) + assert.NoError(t, err) + return string(bytes) + }, + "key1", + []authTokenDbEntry{{"key1", "bucket2"}}, }, } @@ -161,10 +253,9 @@ func TestBridgeFilters(t *testing.T) { e.Use(storageMiddleware(db)) e.Use(authTokenMiddleware) e.GET("/api/5/bridges", tc.mockRegistry.ListAllBridges) - e.GET("/api/5/bridge/:location", tc.mockRegistry.BridgePicker) // First test ListAllBridges - req := httptest.NewRequest(http.MethodGet, "/api/5/bridges?type=obfs4", nil) + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/5/bridges%s", tc.parameter), nil) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) if tc.authToken != "" { req.Header.Set("x-menshen-auth-token", tc.authToken) @@ -174,18 +265,109 @@ func TestBridgeFilters(t *testing.T) { e.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, strings.TrimSpace(expectedResponse), strings.TrimSpace(rec.Body.String())) + }) + } +} + +// if filtereing results in multiple Bridges being returned, we cannot do string comparisons in the +// test for verifying the result. This test has been moved (quick and dirty) from TestBridgeFilters. +func TestFiltersResultInRandomOrder(t *testing.T) { + type authTokenDbEntry struct { + key string + buckets string + } + + location1 := "New York" + locationStruct := m.Location{ + CountryCode: "US", + Label: location1, + DisplayName: location1, + Lat: "40.71", + Lon: "-74.00", + } - // The output from BridgePicker should be the same since we're just filtering by buckets - req = httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/5/bridge/%v", url.PathEscape(location1)), nil) + bridge1 := &m.Bridge{ + Host: "bridge1", + Port: 443, + Transport: "TCP", + Type: "obfs4", + Bucket: "bucket1", + Location: location1, + } + + bridge3 := &m.Bridge{ + Host: "bridge3", + Port: 443, + Transport: "TCP", + Type: "obfs4", + Bucket: "bucket2", + LastSeenMillis: time.Now().UnixMilli(), + Location: location1, + } + + testTable := []struct { + name string + mockRegistry *registry + expected []*m.Bridge + authToken string + dbAuthTokens []authTokenDbEntry + }{ + {"auth token with multiple buckets two private bridges", + ®istry{ + bridges: bridgeMap{location1: []*m.Bridge{bridge1, bridge3}}, + locations: locationMap{location1: &locationStruct}, + }, + []*m.Bridge{bridge1, bridge3}, + "key1", + []authTokenDbEntry{{"key1", "bucket1,bucket2"}}, + }, + } + + for _, tc := range testTable { + t.Run(tc.name, func(t *testing.T) { + + dir, err := os.MkdirTemp("", "") + assert.NoError(t, err) + defer os.RemoveAll(dir) + + db, err := storage.OpenDatabase(dir + "/db1.sql") + assert.NoError(t, err) + defer db.Close() + + stmt, err := db.Prepare("INSERT INTO tokens(key, buckets) VALUES(?, ?)") + assert.NoError(t, err) + + for _, token := range tc.dbAuthTokens { + h := sha256.New() + h.Write([]byte(token.key)) + authTokenHashBytes := h.Sum(nil) + + authTokenHashString := base64.StdEncoding.EncodeToString(authTokenHashBytes) + _, err = stmt.Exec(authTokenHashString, token.buckets) + assert.NoError(t, err) + } + + e := echo.New() + e.Use(storageMiddleware(db)) + e.Use(authTokenMiddleware) + e.GET("/api/5/bridges", tc.mockRegistry.ListAllBridges) + + req := httptest.NewRequest(http.MethodGet, "/api/5/bridges?type=obfs4", nil) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) if tc.authToken != "" { req.Header.Set("x-menshen-auth-token", tc.authToken) } - rec = httptest.NewRecorder() + rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, strings.TrimSpace(expectedResponse), strings.TrimSpace(rec.Body.String())) + + // Unmarshal the JSON data into the struct + var bridges []*m.Bridge + err = json.Unmarshal(rec.Body.Bytes(), &bridges) + assert.Equal(t, err, nil) + assert.ElementsMatch(t, tc.expected, bridges) }) } + } diff --git a/pkg/api/endpoints.go b/pkg/api/endpoints.go index 057a34c7b7c9969257e9c62f291cf78b5f39dadb..619d0ac4358aa9e19115bb2ba6889bd8cc491a04 100644 --- a/pkg/api/endpoints.go +++ b/pkg/api/endpoints.go @@ -3,6 +3,7 @@ package api import ( "errors" + "fmt" "slices" "strconv" "strings" @@ -196,13 +197,24 @@ func sanitizeLocations(location string, locations locationMap) error { return errors.New("unknown location") } -func sanitizeTransport(transport string, allowedTransports []string) error { - if transport == "" { +func sanitizeTransport(c echo.Context, allowedTransports []string) error { + return sanitizeToLower(c, "tr", allowedTransports) +} + +func sanitizeType(c echo.Context, allowedTypes []string) error { + return sanitizeToLower(c, "type", allowedTypes) +} + +func sanitizeToLower(c echo.Context, paramKey string, allowedList []string) error { + queryParam := c.QueryParam(paramKey) + if queryParam == "" { return nil } - if !slices.Contains(allowedTransports, transport) { - return errors.New("unknown transport") + queryParam = strings.ToLower(queryParam) + if !slices.Contains(allowedList, queryParam) { + return fmt.Errorf("unknown value %s for query param %s", c.QueryParam(paramKey), paramKey) } + c.QueryParams().Set(paramKey, queryParam) return nil } diff --git a/pkg/api/gateway.go b/pkg/api/gateway.go index 1b675680275b6c14d3338fc80395e88f66f44923..b62f8db6aa7383456d4d5a81a382633d71d9e1f6 100644 --- a/pkg/api/gateway.go +++ b/pkg/api/gateway.go @@ -45,7 +45,7 @@ func (r *registry) ListAllGateways(c echo.Context) error { return c.JSON(http.StatusBadRequest, err) } - err = sanitizeTransport(c.QueryParam("tr"), []string{"udp", "tcp"}) + err = sanitizeTransport(c, []string{"udp", "tcp"}) if err != nil { return c.JSON(http.StatusBadRequest, err) } diff --git a/pkg/api/registry.go b/pkg/api/registry.go index 932435be8f95d2b60ea1da106051e5a5946753b2..f6adc2ebab25b465a65e4702beade7ef35ca5519 100644 --- a/pkg/api/registry.go +++ b/pkg/api/registry.go @@ -256,8 +256,12 @@ func newRegistry(cfg *Config) (*registry, error) { return r, nil } -func (r *registry) Stop() { - r.lb.Stop() +func (r *registry) AllBridges() []*m.Bridge { + bridges := []*m.Bridge{} + for k := range r.bridges { + bridges = append(bridges, r.bridges[k]...) + } + return bridges } func (r *registry) AllGateways() []*m.Gateway {