diff --git a/pkg/vpn/bonafide/bonafide.go b/pkg/vpn/bonafide/bonafide.go index bca88c275da2b7a48600bcda83e31855cdbb8e8b..f2b0afbc9c5e0104e722bff3027b8d96cf47d45c 100644 --- a/pkg/vpn/bonafide/bonafide.go +++ b/pkg/vpn/bonafide/bonafide.go @@ -288,16 +288,15 @@ func (b *Bonafide) GetBestGateways(transport string) ([]Gateway, error) { return gws, err } -// GetAllGateways only filters gateways by transport. +// FetchGateways only filters gateways by transport. // if "any" is provided it will return all gateways for all transports -func (b *Bonafide) GetAllGateways(transport string) ([]Gateway, error) { +func (b *Bonafide) FetchAllGateways(transport string) error { err := b.maybeInitializeEIP() // XXX needs to wait for bonafide too if err != nil { - return nil, err + return err } - gws, err := b.gateways.getAll(transport, b.tzOffsetHours) - return gws, err + return b.gateways.getAll(transport, b.tzOffsetHours) } func (b *Bonafide) GetLocationQualityMap(transport string) map[string]float64 { diff --git a/pkg/vpn/bonafide/gateways.go b/pkg/vpn/bonafide/gateways.go index bbf5e084d04048862b367bef36d28db7ae435c8c..86f82391b9a081e1bca166c30738dc009e090dcd 100644 --- a/pkg/vpn/bonafide/gateways.go +++ b/pkg/vpn/bonafide/gateways.go @@ -317,17 +317,19 @@ func (p *gatewayPool) getBestLocation(transport string, tz int) string { } -func (p *gatewayPool) getAll(transport string, tz int) ([]Gateway, error) { +func (p *gatewayPool) getAll(transport string, tz int) error { if (&gatewayPool{} == p) { log.Warn().Msg("getAll tried to access uninitialized struct") - return []Gateway{}, nil + return nil } log.Debug().Msg("seems to be initialized...") if len(p.recommended) == 0 { - return p.getGatewaysFromMenshen(transport, 999) + _, err := p.getGatewaysFromMenshen(transport, 999) + return err } - return p.getGatewaysByTimezone(transport, tz, 999) + _, err := p.getGatewaysByTimezone(transport, tz, 999) + return err } /* picks at most max gateways, filtering by transport, from the ordered list menshen returned */ diff --git a/pkg/vpn/interface.go b/pkg/vpn/interface.go index 91575e6fcbd64f5b4c9a4114abb2093fbf7901dc..e4408b601471b02cee0fa339137e167787ad57ae 100644 --- a/pkg/vpn/interface.go +++ b/pkg/vpn/interface.go @@ -19,6 +19,6 @@ type apiInterface interface { GetOpenvpnArgs() ([]string, error) GetGatewayByIP(ip string) (bonafide.Gateway, error) GetBestGateways(transport string) ([]bonafide.Gateway, error) - GetAllGateways(transport string) ([]bonafide.Gateway, error) + FetchAllGateways(transport string) error GetSnowflakeCh() chan *snowflake.StatusEvent } diff --git a/pkg/vpn/menshen/gateway.go b/pkg/vpn/menshen/gateway.go index 99ff8a77df927c5daf9de9980f08dc718f6ffb71..0f81228f968f5055e9296b273fe2c673cf529fe8 100644 --- a/pkg/vpn/menshen/gateway.go +++ b/pkg/vpn/menshen/gateway.go @@ -28,22 +28,6 @@ func (m *Menshen) GetGatewayByIP(ip string) (bonafide.Gateway, error) { return bonafide.Gateway{}, fmt.Errorf("Could not find a gateway with ip %s", ip) } -// Fetch gateways and return a list filtered transport. The parameter transport can -// have the value "any". Then it wil lreturn all gateways for all transports. -// GetAllGateways get's called once during startup -// TODO: maybe merge GetAllGateways and fetchGateways -func (m *Menshen) GetAllGateways(transport string) ([]bonafide.Gateway, error) { - // TODO: implement obfsv4 support (transport can have the value "any") - if transport == "obfs4" { - return []bonafide.Gateway{}, errors.New("obfs4 is not supported for v5 right now") - } - err := m.fetchGateways(transport) - if err != nil { - return []bonafide.Gateway{}, err - } - return NewBonafideGatewayArray(m.Gateways), nil -} - // Returns a list of gateways that we will connect to. First checks if automatic gateway // selection should be used. func (m *Menshen) GetBestGateways(transport string) ([]bonafide.Gateway, error) { @@ -112,9 +96,14 @@ func getGatewayNames(gateways []*models.ModelsGateway) []string { // Asks menshen for gateways. The gateways are stored in m.Gateways // Currently, there is not CountryCode filtering // The vars m.gwLocations and m.gwsByLocation are updated -func (m *Menshen) fetchGateways(transport string) error { +func (m *Menshen) FetchAllGateways(transport string) error { log.Trace().Msg("Fetching gateways from menshen") + // TODO: implement obfsv4 support (transport can have the value "any") + if transport == "obfs4" { + errors.New("obfs4 is not supported for v5 right now") + } + // reset if called multiple times m.gwLocations = []string{} m.gwsByLocation = make(map[string][]*models.ModelsGateway) diff --git a/pkg/vpn/menshen/integration_test.go b/pkg/vpn/menshen/integration_test.go index 9b910dbfbfa229fc43d5837f81369a597edf45c8..50a916f0d10364038276e56560d08afd3c7d7842 100644 --- a/pkg/vpn/menshen/integration_test.go +++ b/pkg/vpn/menshen/integration_test.go @@ -12,6 +12,8 @@ import ( "github.com/stretchr/testify/require" ) +// integration tests need to know where to find menshen: set API_URL="http://localhost:8443" via env + func init() { log.Logger = zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout}).With().Timestamp().Logger() } @@ -26,12 +28,16 @@ func getMenshenInstance(t *testing.T) *Menshen { return m } -func TestGetAllGateways(t *testing.T) { - // needs API_URL="http://localhost:8443" via env +func TestFetchAllGateways(t *testing.T) { m := getMenshenInstance(t) - gateways, err := m.GetAllGateways("openvpn") - require.NoError(t, err, "GetAllGateways returned an error") + err := m.FetchAllGateways("openvpn") + require.NoError(t, err, "FetchAllGateways returned an error") + m.SetAutomaticGateway() + + gateways, err := m.GetBestGateways("openvpn") + require.NoError(t, err, "GetBestGateways returned an error") assert.Greater(t, len(gateways), 0, "There should multiple gateways fetched") + log.Info(). Int("gateways", len(gateways)). Msg("Got gateways") @@ -75,8 +81,8 @@ func TestLatency(t *testing.T) { func TestLocationQualityMap(t *testing.T) { m := getMenshenInstance(t) - _, err := m.GetAllGateways("openvpn") - require.NoError(t, err, "GetAllGateways returned an error") + err := m.FetchAllGateways("openvpn") + require.NoError(t, err, "FetchAllGateways returned an error") locationQualtyMap := m.GetLocationQualityMap("openvpn") for _, quality := range locationQualtyMap { @@ -87,8 +93,8 @@ func TestLocationQualityMap(t *testing.T) { func TestLocationLabels(t *testing.T) { m := getMenshenInstance(t) - _, err := m.GetAllGateways("openvpn") - require.NoError(t, err, "GetAllGateways returned an error") + err := m.FetchAllGateways("openvpn") + require.NoError(t, err, "FetchAllGateways returned an error") labelMap := m.GetLocationLabels("transport") for _, city := range labelMap { @@ -105,8 +111,8 @@ func TestLocationLabels(t *testing.T) { func TestGetBestLocation(t *testing.T) { m := getMenshenInstance(t) - _, err := m.GetAllGateways("openvpn") - require.NoError(t, err, "GetAllGateways returned an error") + err := m.FetchAllGateways("openvpn") + require.NoError(t, err, "FetchAllGateways returned an error") locationQualtyMap := m.GetLocationQualityMap("openvpn") @@ -128,8 +134,8 @@ func TestGetBestGatewaysShuffled(t *testing.T) { transport := "openvpn" m := getMenshenInstance(t) - _, err := m.GetAllGateways(transport) - assert.NoError(t, err, "GetAllGateways returned an error") + err := m.FetchAllGateways(transport) + assert.NoError(t, err, "FetchAllGateways returned an error") location, err := m.GetBestLocation(transport) assert.NoError(t, err, "m.GetBestLocation returned an error") diff --git a/pkg/vpn/openvpn.go b/pkg/vpn/openvpn.go index 7d64e20b26136848e20ebf6826b39d7917e0007c..1ece141f63693a61da65a65f7ad67bf8616608e0 100644 --- a/pkg/vpn/openvpn.go +++ b/pkg/vpn/openvpn.go @@ -411,7 +411,7 @@ func (b *Bitmask) getCert() error { // Explicit call to GetGateways, to be able to fetch them all before starting the vpn func (b *Bitmask) fetchGateways() { log.Info().Msg("Fetching gateways...") - _, err := b.api.GetAllGateways(b.transport) + err := b.api.FetchAllGateways(b.transport) if err != nil { log.Warn(). Err(err).