diff --git a/pkg/vpn/bonafide/bonafide.go b/pkg/vpn/bonafide/bonafide.go index bca88c275da2b7a48600bcda83e31855cdbb8e8b..8e66cdeb84b4c1d303ae4af18ad92d829d3a695a 100644 --- a/pkg/vpn/bonafide/bonafide.go +++ b/pkg/vpn/bonafide/bonafide.go @@ -288,16 +288,15 @@ func (b *Bonafide) GetBestGateways(transport string) ([]Gateway, error) { return gws, err } -// GetAllGateways only filters gateways by transport. +// FetchGateways only filters gateways by transport. // if "any" is provided it will return all gateways for all transports -func (b *Bonafide) GetAllGateways(transport string) ([]Gateway, error) { +func (b *Bonafide) FetchGateways(transport string) error { err := b.maybeInitializeEIP() // XXX needs to wait for bonafide too if err != nil { - return nil, err + return err } - gws, err := b.gateways.getAll(transport, b.tzOffsetHours) - return gws, err + return b.gateways.getAll(transport, b.tzOffsetHours) } func (b *Bonafide) GetLocationQualityMap(transport string) map[string]float64 { diff --git a/pkg/vpn/bonafide/gateways.go b/pkg/vpn/bonafide/gateways.go index 88c7581e40ed51bbb1857465d036029e977dc066..7bb1da59f341f6997e030bf77a6063b97bfbfa0b 100644 --- a/pkg/vpn/bonafide/gateways.go +++ b/pkg/vpn/bonafide/gateways.go @@ -317,17 +317,19 @@ func (p *gatewayPool) getBestLocation(transport string, tz int) string { } -func (p *gatewayPool) getAll(transport string, tz int) ([]Gateway, error) { +func (p *gatewayPool) getAll(transport string, tz int) error { if (&gatewayPool{} == p) { log.Warn().Msg("getAll tried to access uninitialized struct") - return []Gateway{}, nil + return nil } log.Debug().Msg("seems to be initialized...") if len(p.recommended) == 0 { - return p.getGatewaysFromMenshen(transport, 999) + _, err := p.getGatewaysFromMenshen(transport, 999) + return err } - return p.getGatewaysByTimezone(transport, tz, 999) + _, err := p.getGatewaysByTimezone(transport, tz, 999) + return err } /* picks at most max gateways, filtering by transport, from the ordered list menshen returned */ diff --git a/pkg/vpn/interface.go b/pkg/vpn/interface.go index 91575e6fcbd64f5b4c9a4114abb2093fbf7901dc..0e99409690a5063e04c763d599538d0467772344 100644 --- a/pkg/vpn/interface.go +++ b/pkg/vpn/interface.go @@ -19,6 +19,6 @@ type apiInterface interface { GetOpenvpnArgs() ([]string, error) GetGatewayByIP(ip string) (bonafide.Gateway, error) GetBestGateways(transport string) ([]bonafide.Gateway, error) - GetAllGateways(transport string) ([]bonafide.Gateway, error) + FetchGateways(transport string) error GetSnowflakeCh() chan *snowflake.StatusEvent } diff --git a/pkg/vpn/menshen/gateway.go b/pkg/vpn/menshen/gateway.go index d4eb90f16baa0e278bd1d7787837e197d708dc53..2fd3031c91c88faef28f4be717b2dad499f44d32 100644 --- a/pkg/vpn/menshen/gateway.go +++ b/pkg/vpn/menshen/gateway.go @@ -28,22 +28,6 @@ func (m *Menshen) GetGatewayByIP(ip string) (bonafide.Gateway, error) { return bonafide.Gateway{}, fmt.Errorf("Could not find a gateway with ip %s", ip) } -// Fetch gateways and return a list filtered transport. The parameter transport can -// have the value "any". Then it wil lreturn all gateways for all transports. -// GetAllGateways get's called once during startup -// TODO: maybe merge GetAllGateways and fetchGateways -func (m *Menshen) GetAllGateways(transport string) ([]bonafide.Gateway, error) { - // TODO: implement obfsv4 support (transport can have the value "any") - if transport == "obfs4" { - return []bonafide.Gateway{}, errors.New("obfs4 is not supported for v5 right now") - } - err := m.fetchGateways(transport) - if err != nil { - return []bonafide.Gateway{}, err - } - return NewBonafideGatewayArray(m.Gateways), nil -} - // Returns a list of gateways that we will connect to. First checks if automatic gateway // selection should be used. func (m *Menshen) GetBestGateways(transport string) ([]bonafide.Gateway, error) { @@ -112,9 +96,14 @@ func getGatewayNames(gateways []*models.ModelsGateway) []string { // Asks menshen for gateways. The gateways are stored in m.Gateways // Currently, there is not CountryCode filtering // The vars m.gwLocations and m.gwsByLocation are updated -func (m *Menshen) fetchGateways(transport string) error { +func (m *Menshen) FetchGateways(transport string) error { log.Trace().Msg("Fetching gateways from menshen") + // TODO: implement obfsv4 support (transport can have the value "any") + if transport == "obfs4" { + errors.New("obfs4 is not supported for v5 right now") + } + // reset if called multiple times m.gwLocations = []string{} m.gwsByLocation = make(map[string][]*models.ModelsGateway) diff --git a/pkg/vpn/menshen/integration_test.go b/pkg/vpn/menshen/integration_test.go index 9b910dbfbfa229fc43d5837f81369a597edf45c8..65c8172ed426b3ad89725eddc20a77b42e2b6f6f 100644 --- a/pkg/vpn/menshen/integration_test.go +++ b/pkg/vpn/menshen/integration_test.go @@ -12,6 +12,8 @@ import ( "github.com/stretchr/testify/require" ) +// integration tests need to know where to find menshen: set API_URL="http://localhost:8443" via env + func init() { log.Logger = zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout}).With().Timestamp().Logger() } @@ -26,12 +28,16 @@ func getMenshenInstance(t *testing.T) *Menshen { return m } -func TestGetAllGateways(t *testing.T) { - // needs API_URL="http://localhost:8443" via env +func TestFetchGateways(t *testing.T) { m := getMenshenInstance(t) - gateways, err := m.GetAllGateways("openvpn") - require.NoError(t, err, "GetAllGateways returned an error") + err := m.FetchGateways("openvpn") + require.NoError(t, err, "FetchGateways returned an error") + m.SetAutomaticGateway() + + gateways, err := m.GetBestGateways("openvpn") + require.NoError(t, err, "GetBestGateways returned an error") assert.Greater(t, len(gateways), 0, "There should multiple gateways fetched") + log.Info(). Int("gateways", len(gateways)). Msg("Got gateways") @@ -75,8 +81,8 @@ func TestLatency(t *testing.T) { func TestLocationQualityMap(t *testing.T) { m := getMenshenInstance(t) - _, err := m.GetAllGateways("openvpn") - require.NoError(t, err, "GetAllGateways returned an error") + err := m.FetchGateways("openvpn") + require.NoError(t, err, "FetchGateways returned an error") locationQualtyMap := m.GetLocationQualityMap("openvpn") for _, quality := range locationQualtyMap { @@ -87,8 +93,8 @@ func TestLocationQualityMap(t *testing.T) { func TestLocationLabels(t *testing.T) { m := getMenshenInstance(t) - _, err := m.GetAllGateways("openvpn") - require.NoError(t, err, "GetAllGateways returned an error") + err := m.FetchGateways("openvpn") + require.NoError(t, err, "FetchGateways returned an error") labelMap := m.GetLocationLabels("transport") for _, city := range labelMap { @@ -105,8 +111,8 @@ func TestLocationLabels(t *testing.T) { func TestGetBestLocation(t *testing.T) { m := getMenshenInstance(t) - _, err := m.GetAllGateways("openvpn") - require.NoError(t, err, "GetAllGateways returned an error") + err := m.FetchGateways("openvpn") + require.NoError(t, err, "FetchGateways returned an error") locationQualtyMap := m.GetLocationQualityMap("openvpn") @@ -128,8 +134,8 @@ func TestGetBestGatewaysShuffled(t *testing.T) { transport := "openvpn" m := getMenshenInstance(t) - _, err := m.GetAllGateways(transport) - assert.NoError(t, err, "GetAllGateways returned an error") + err := m.FetchGateways(transport) + assert.NoError(t, err, "FetchGateways returned an error") location, err := m.GetBestLocation(transport) assert.NoError(t, err, "m.GetBestLocation returned an error") diff --git a/pkg/vpn/openvpn.go b/pkg/vpn/openvpn.go index 7d64e20b26136848e20ebf6826b39d7917e0007c..95f510fe38dcaa1afd43c091fe4f6fedf303d732 100644 --- a/pkg/vpn/openvpn.go +++ b/pkg/vpn/openvpn.go @@ -411,7 +411,7 @@ func (b *Bitmask) getCert() error { // Explicit call to GetGateways, to be able to fetch them all before starting the vpn func (b *Bitmask) fetchGateways() { log.Info().Msg("Fetching gateways...") - _, err := b.api.GetAllGateways(b.transport) + err := b.api.FetchGateways(b.transport) if err != nil { log.Warn(). Err(err).