From a3049438421e467270d3d3c4ce6e1923c30e787d Mon Sep 17 00:00:00 2001 From: Pea Nut <peanut2@systemli.org> Date: Fri, 5 Jul 2024 09:04:38 +0200 Subject: [PATCH] Rename GetAllGateways to FetchAllGateways and change return values of FetchAllGateways After removing `Bitmask.ReloadFirewall()` there are no calls to `b.api.GetAllGateways()` where the calling function actually need the gateways as return values. So I renamed the function to FetchAllGateways. The function now only returns an error (instead of `([]bonafide.Gateway, error)`). This simplifies the code of v5 a bit, as we get rid of yet another function/abstraction. In v3 `gatewayPool.getAll()` now also returns just an error. --- pkg/vpn/bonafide/bonafide.go | 9 ++++----- pkg/vpn/bonafide/gateways.go | 10 ++++++---- pkg/vpn/interface.go | 2 +- pkg/vpn/menshen/gateway.go | 23 ++++++---------------- pkg/vpn/menshen/integration_test.go | 30 +++++++++++++++++------------ pkg/vpn/openvpn.go | 2 +- 6 files changed, 36 insertions(+), 40 deletions(-) diff --git a/pkg/vpn/bonafide/bonafide.go b/pkg/vpn/bonafide/bonafide.go index bca88c27..f2b0afbc 100644 --- a/pkg/vpn/bonafide/bonafide.go +++ b/pkg/vpn/bonafide/bonafide.go @@ -288,16 +288,15 @@ func (b *Bonafide) GetBestGateways(transport string) ([]Gateway, error) { return gws, err } -// GetAllGateways only filters gateways by transport. +// FetchGateways only filters gateways by transport. // if "any" is provided it will return all gateways for all transports -func (b *Bonafide) GetAllGateways(transport string) ([]Gateway, error) { +func (b *Bonafide) FetchAllGateways(transport string) error { err := b.maybeInitializeEIP() // XXX needs to wait for bonafide too if err != nil { - return nil, err + return err } - gws, err := b.gateways.getAll(transport, b.tzOffsetHours) - return gws, err + return b.gateways.getAll(transport, b.tzOffsetHours) } func (b *Bonafide) GetLocationQualityMap(transport string) map[string]float64 { diff --git a/pkg/vpn/bonafide/gateways.go b/pkg/vpn/bonafide/gateways.go index bbf5e084..86f82391 100644 --- a/pkg/vpn/bonafide/gateways.go +++ b/pkg/vpn/bonafide/gateways.go @@ -317,17 +317,19 @@ func (p *gatewayPool) getBestLocation(transport string, tz int) string { } -func (p *gatewayPool) getAll(transport string, tz int) ([]Gateway, error) { +func (p *gatewayPool) getAll(transport string, tz int) error { if (&gatewayPool{} == p) { log.Warn().Msg("getAll tried to access uninitialized struct") - return []Gateway{}, nil + return nil } log.Debug().Msg("seems to be initialized...") if len(p.recommended) == 0 { - return p.getGatewaysFromMenshen(transport, 999) + _, err := p.getGatewaysFromMenshen(transport, 999) + return err } - return p.getGatewaysByTimezone(transport, tz, 999) + _, err := p.getGatewaysByTimezone(transport, tz, 999) + return err } /* picks at most max gateways, filtering by transport, from the ordered list menshen returned */ diff --git a/pkg/vpn/interface.go b/pkg/vpn/interface.go index 91575e6f..e4408b60 100644 --- a/pkg/vpn/interface.go +++ b/pkg/vpn/interface.go @@ -19,6 +19,6 @@ type apiInterface interface { GetOpenvpnArgs() ([]string, error) GetGatewayByIP(ip string) (bonafide.Gateway, error) GetBestGateways(transport string) ([]bonafide.Gateway, error) - GetAllGateways(transport string) ([]bonafide.Gateway, error) + FetchAllGateways(transport string) error GetSnowflakeCh() chan *snowflake.StatusEvent } diff --git a/pkg/vpn/menshen/gateway.go b/pkg/vpn/menshen/gateway.go index 99ff8a77..0f81228f 100644 --- a/pkg/vpn/menshen/gateway.go +++ b/pkg/vpn/menshen/gateway.go @@ -28,22 +28,6 @@ func (m *Menshen) GetGatewayByIP(ip string) (bonafide.Gateway, error) { return bonafide.Gateway{}, fmt.Errorf("Could not find a gateway with ip %s", ip) } -// Fetch gateways and return a list filtered transport. The parameter transport can -// have the value "any". Then it wil lreturn all gateways for all transports. -// GetAllGateways get's called once during startup -// TODO: maybe merge GetAllGateways and fetchGateways -func (m *Menshen) GetAllGateways(transport string) ([]bonafide.Gateway, error) { - // TODO: implement obfsv4 support (transport can have the value "any") - if transport == "obfs4" { - return []bonafide.Gateway{}, errors.New("obfs4 is not supported for v5 right now") - } - err := m.fetchGateways(transport) - if err != nil { - return []bonafide.Gateway{}, err - } - return NewBonafideGatewayArray(m.Gateways), nil -} - // Returns a list of gateways that we will connect to. First checks if automatic gateway // selection should be used. func (m *Menshen) GetBestGateways(transport string) ([]bonafide.Gateway, error) { @@ -112,9 +96,14 @@ func getGatewayNames(gateways []*models.ModelsGateway) []string { // Asks menshen for gateways. The gateways are stored in m.Gateways // Currently, there is not CountryCode filtering // The vars m.gwLocations and m.gwsByLocation are updated -func (m *Menshen) fetchGateways(transport string) error { +func (m *Menshen) FetchAllGateways(transport string) error { log.Trace().Msg("Fetching gateways from menshen") + // TODO: implement obfsv4 support (transport can have the value "any") + if transport == "obfs4" { + errors.New("obfs4 is not supported for v5 right now") + } + // reset if called multiple times m.gwLocations = []string{} m.gwsByLocation = make(map[string][]*models.ModelsGateway) diff --git a/pkg/vpn/menshen/integration_test.go b/pkg/vpn/menshen/integration_test.go index 9b910dbf..50a916f0 100644 --- a/pkg/vpn/menshen/integration_test.go +++ b/pkg/vpn/menshen/integration_test.go @@ -12,6 +12,8 @@ import ( "github.com/stretchr/testify/require" ) +// integration tests need to know where to find menshen: set API_URL="http://localhost:8443" via env + func init() { log.Logger = zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout}).With().Timestamp().Logger() } @@ -26,12 +28,16 @@ func getMenshenInstance(t *testing.T) *Menshen { return m } -func TestGetAllGateways(t *testing.T) { - // needs API_URL="http://localhost:8443" via env +func TestFetchAllGateways(t *testing.T) { m := getMenshenInstance(t) - gateways, err := m.GetAllGateways("openvpn") - require.NoError(t, err, "GetAllGateways returned an error") + err := m.FetchAllGateways("openvpn") + require.NoError(t, err, "FetchAllGateways returned an error") + m.SetAutomaticGateway() + + gateways, err := m.GetBestGateways("openvpn") + require.NoError(t, err, "GetBestGateways returned an error") assert.Greater(t, len(gateways), 0, "There should multiple gateways fetched") + log.Info(). Int("gateways", len(gateways)). Msg("Got gateways") @@ -75,8 +81,8 @@ func TestLatency(t *testing.T) { func TestLocationQualityMap(t *testing.T) { m := getMenshenInstance(t) - _, err := m.GetAllGateways("openvpn") - require.NoError(t, err, "GetAllGateways returned an error") + err := m.FetchAllGateways("openvpn") + require.NoError(t, err, "FetchAllGateways returned an error") locationQualtyMap := m.GetLocationQualityMap("openvpn") for _, quality := range locationQualtyMap { @@ -87,8 +93,8 @@ func TestLocationQualityMap(t *testing.T) { func TestLocationLabels(t *testing.T) { m := getMenshenInstance(t) - _, err := m.GetAllGateways("openvpn") - require.NoError(t, err, "GetAllGateways returned an error") + err := m.FetchAllGateways("openvpn") + require.NoError(t, err, "FetchAllGateways returned an error") labelMap := m.GetLocationLabels("transport") for _, city := range labelMap { @@ -105,8 +111,8 @@ func TestLocationLabels(t *testing.T) { func TestGetBestLocation(t *testing.T) { m := getMenshenInstance(t) - _, err := m.GetAllGateways("openvpn") - require.NoError(t, err, "GetAllGateways returned an error") + err := m.FetchAllGateways("openvpn") + require.NoError(t, err, "FetchAllGateways returned an error") locationQualtyMap := m.GetLocationQualityMap("openvpn") @@ -128,8 +134,8 @@ func TestGetBestGatewaysShuffled(t *testing.T) { transport := "openvpn" m := getMenshenInstance(t) - _, err := m.GetAllGateways(transport) - assert.NoError(t, err, "GetAllGateways returned an error") + err := m.FetchAllGateways(transport) + assert.NoError(t, err, "FetchAllGateways returned an error") location, err := m.GetBestLocation(transport) assert.NoError(t, err, "m.GetBestLocation returned an error") diff --git a/pkg/vpn/openvpn.go b/pkg/vpn/openvpn.go index 7d64e20b..1ece141f 100644 --- a/pkg/vpn/openvpn.go +++ b/pkg/vpn/openvpn.go @@ -411,7 +411,7 @@ func (b *Bitmask) getCert() error { // Explicit call to GetGateways, to be able to fetch them all before starting the vpn func (b *Bitmask) fetchGateways() { log.Info().Msg("Fetching gateways...") - _, err := b.api.GetAllGateways(b.transport) + err := b.api.FetchAllGateways(b.transport) if err != nil { log.Warn(). Err(err). -- GitLab