Skip to content
Snippets Groups Projects
Commit 2b78ce86 authored by cyberta's avatar cyberta
Browse files

reuse hopping mechanism to reconnect on write error

parent f24462b2
No related branches found
No related tags found
1 merge request!61Reconnect on error
Pipeline #234704 failed
...@@ -74,7 +74,6 @@ type Client struct { ...@@ -74,7 +74,6 @@ type Client struct {
newObfs4Conn chan net.Conn newObfs4Conn chan net.Conn
obfs4Conns []net.Conn obfs4Conns []net.Conn
obfs4Endpoints []*Obfs4Config obfs4Endpoints []*Obfs4Config
obfs4Dialer *obfsvpn.Dialer
obfs4Failures map[string]int32 obfs4Failures map[string]int32
EventLogger EventLogger EventLogger EventLogger
state clientState state clientState
...@@ -177,22 +176,11 @@ func (c *Client) Start() (bool, error) { ...@@ -177,22 +176,11 @@ func (c *Client) Start() (bool, error) {
c.updateState(starting) c.updateState(starting)
obfs4Endpoint := c.obfs4Endpoints[0] obfs4Endpoint := c.obfs4Endpoints[0]
obfs4Conn, err := c.createObfs4Connection(obfs4Endpoint)
c.obfs4Dialer, err = obfsvpn.NewDialerFromCert(obfs4Endpoint.Cert)
if err != nil {
return false, fmt.Errorf("could not dial obfs4 remote: %w", err)
}
if c.kcpConfig.Enabled {
c.obfs4Dialer.DialFunc = obfsvpn.GetKCPDialer(c.kcpConfig, c.log)
}
obfs4Conn, err := c.obfs4Dialer.Dial("tcp", obfs4Endpoint.Remote)
if err != nil { if err != nil {
c.error("Could not dial obfs4 remote: %v", err) c.error("Could not dial obfs4 remote: %v", err)
return false, fmt.Errorf("could not dial remote: %w", err) return false, fmt.Errorf("could not dial remote: %w", err)
} }
c.obfs4Conns = []net.Conn{obfs4Conn} c.obfs4Conns = []net.Conn{obfs4Conn}
c.updateState(running) c.updateState(running)
...@@ -222,6 +210,26 @@ func (c *Client) Start() (bool, error) { ...@@ -222,6 +210,26 @@ func (c *Client) Start() (bool, error) {
return true, nil return true, nil
} }
func (c *Client) createObfs4Connection(obfs4Endpoint *Obfs4Config) (net.Conn, error) {
var err error
obfs4Dialer, err := obfsvpn.NewDialerFromCert(obfs4Endpoint.Cert)
if err != nil {
c.error("Could not dial obfs4 remote: %v", err)
return nil, err
}
if c.kcpConfig.Enabled {
obfs4Dialer.DialFunc = obfsvpn.GetKCPDialer(c.kcpConfig, c.log)
}
ctx, cancel := context.WithTimeout(context.Background(), dialGiveUpTime)
defer cancel()
c.log("Dialing remote: %v", obfs4Endpoint.Remote)
return obfs4Dialer.DialContext(ctx, "tcp", obfs4Endpoint.Remote)
}
// updateState sets a new client state, logs it and sends an event to the clients // updateState sets a new client state, logs it and sends an event to the clients
// EventLogger in case it is available. Always set the state with this function in // EventLogger in case it is available. Always set the state with this function in
// order to ensure integrating clients receive an update state event via FFI. // order to ensure integrating clients receive an update state event via FFI.
...@@ -244,6 +252,40 @@ func (c *Client) pickRandomEndpoint() *Obfs4Config { ...@@ -244,6 +252,40 @@ func (c *Client) pickRandomEndpoint() *Obfs4Config {
return endpoint return endpoint
} }
func (c *Client) reconnect(obfs4Endpoint *Obfs4Config, sleepSeconds int) {
newObfs4Conn, err := c.createObfs4Connection(obfs4Endpoint)
if err != nil {
newRemote := obfs4Endpoint.Remote
_, ok := c.obfs4Failures[newRemote]
if ok {
c.obfs4Failures[newRemote] += 1
} else {
c.obfs4Failures[newRemote] = 1
}
c.error("Could not dial obfs4 remote: %v (failures: %d)", err, c.obfs4Failures[newRemote])
}
if newObfs4Conn == nil {
c.error("Did not get obfs4: %v ", err)
} else {
c.outLock.Lock()
c.obfs4Conns = append([]net.Conn{newObfs4Conn}, c.obfs4Conns...)
c.outLock.Unlock()
c.newObfs4Conn <- newObfs4Conn
c.log("Dialed new remote")
// If we wait sleepSeconds here to clean up the previous connection, we can guarantee that the
// connection list will not grow unbounded
go func() {
time.Sleep(time.Duration(sleepSeconds) * time.Second)
c.cleanupOldConn()
}()
}
}
func (c *Client) hop() { func (c *Client) hop() {
for { for {
select { select {
...@@ -285,51 +327,8 @@ func (c *Client) hop() { ...@@ -285,51 +327,8 @@ func (c *Client) hop() {
} }
c.log("HOPPING to %+v", newRemote) c.log("HOPPING to %+v", newRemote)
c.reconnect(obfs4Endpoint, sleepSeconds)
obfs4Dialer, err := obfsvpn.NewDialerFromCert(obfs4Endpoint.Cert)
if err != nil {
c.error("Could not dial obfs4 remote: %v", err)
return
}
if c.kcpConfig.Enabled {
c.obfs4Dialer.DialFunc = obfsvpn.GetKCPDialer(c.kcpConfig, c.log)
}
ctx, cancel := context.WithTimeout(context.Background(), dialGiveUpTime)
defer cancel()
c.log("Dialing new remote: %v", newRemote)
newObfs4Conn, err := obfs4Dialer.DialContext(ctx, "tcp", newRemote)
if err != nil {
_, ok := c.obfs4Failures[newRemote]
if ok {
c.obfs4Failures[newRemote] += 1
} else {
c.obfs4Failures[newRemote] = 1
}
c.error("Could not dial obfs4 remote: %v (failures: %d)", err, c.obfs4Failures[newRemote])
}
if newObfs4Conn == nil {
c.error("Did not get obfs4: %v ", err)
} else {
c.outLock.Lock()
c.obfs4Conns = append([]net.Conn{newObfs4Conn}, c.obfs4Conns...)
c.outLock.Unlock()
c.newObfs4Conn <- newObfs4Conn
c.log("Dialed new remote")
// If we wait sleepSeconds here to clean up the previous connection, we can guarantee that the
// connection list will not grow unbounded
go func() {
time.Sleep(time.Duration(sleepSeconds) * time.Second)
c.cleanupOldConn()
}()
}
} }
} }
...@@ -386,6 +385,8 @@ func (c *Client) readUDPWriteTCP() { ...@@ -386,6 +385,8 @@ func (c *Client) readUDPWriteTCP() {
_, err = conn.Write(tcpBuffer) _, err = conn.Write(tcpBuffer)
if err != nil { if err != nil {
c.error("Write err from %v to %v: %v", conn.LocalAddr(), conn.RemoteAddr(), err) c.error("Write err from %v to %v: %v", conn.LocalAddr(), conn.RemoteAddr(), err)
time.Sleep(time.Duration(3) * time.Second)
c.reconnect(c.obfs4Endpoints[0], 20)
return return
} }
}() }()
...@@ -446,6 +447,8 @@ func (c *Client) readTCPWriteUDP() { ...@@ -446,6 +447,8 @@ func (c *Client) readTCPWriteUDP() {
c.openvpnAddrLock.RUnlock() c.openvpnAddrLock.RUnlock()
if err != nil { if err != nil {
c.error("Write err from %v to %v: %v", c.openvpnConn.LocalAddr(), c.openvpnConn.RemoteAddr(), err) c.error("Write err from %v to %v: %v", c.openvpnConn.LocalAddr(), c.openvpnConn.RemoteAddr(), err)
time.Sleep(time.Duration(3) * time.Second)
c.reconnect(c.obfs4Endpoints[0], 20)
return return
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment