diff --git a/client/client.go b/client/client.go index 1a3ef1a5cbfb7f980bd492d044415e0814a2157d..94b7d03a75b2b9f6b4a6b97f68c5527800f14413 100644 --- a/client/client.go +++ b/client/client.go @@ -46,6 +46,18 @@ type Obfs4Config struct { Cert string } +type Obfs4Conn struct { + net.Conn + config Obfs4Config +} + +func NewObfs4Conn(conn net.Conn, config Obfs4Config) *Obfs4Conn { + return &Obfs4Conn{ + Conn: conn, + config: config, + } +} + func (oc *Obfs4Config) String() string { return oc.Remote } @@ -72,8 +84,8 @@ type HoppingConfig struct { type Client struct { kcpConfig obfsvpn.KCPConfig ProxyAddr string - newObfs4Conn chan net.Conn - obfs4Conns []net.Conn + newObfs4Conn chan Obfs4Conn + obfs4Conns []Obfs4Conn obfs4Endpoints []*Obfs4Config obfs4Failures map[string]int32 EventLogger EventLogger @@ -100,7 +112,7 @@ func NewClient(ctx context.Context, stop context.CancelFunc, config Config) *Cli kcpConfig: config.KCPConfig, obfs4Failures: map[string]int32{}, minHopSeconds: config.HoppingConfig.MinHopSeconds, - newObfs4Conn: make(chan net.Conn), + newObfs4Conn: make(chan Obfs4Conn), obfs4Endpoints: obfs4Endpoints, stop: stop, state: stopped, @@ -182,7 +194,7 @@ func (c *Client) Start() (bool, error) { c.error("Could not dial obfs4 remote: %v", err) return false, fmt.Errorf("could not dial remote: %w", err) } - c.obfs4Conns = []net.Conn{obfs4Conn} + c.obfs4Conns = []Obfs4Conn{*obfs4Conn} c.updateState(running) @@ -206,7 +218,7 @@ func (c *Client) Start() (bool, error) { return true, nil } -func (c *Client) createObfs4Connection(obfs4Endpoint *Obfs4Config) (net.Conn, error) { +func (c *Client) createObfs4Connection(obfs4Endpoint *Obfs4Config) (*Obfs4Conn, error) { var err error obfs4Dialer, err := obfsvpn.NewDialerFromCert(obfs4Endpoint.Cert) @@ -223,7 +235,11 @@ func (c *Client) createObfs4Connection(obfs4Endpoint *Obfs4Config) (net.Conn, er defer cancel() c.log("Dialing remote: %v", obfs4Endpoint.Remote) - return obfs4Dialer.DialContext(ctx, "tcp", obfs4Endpoint.Remote) + conn, err := obfs4Dialer.DialContext(ctx, "tcp", obfs4Endpoint.Remote) + if err != nil { + return nil, err + } + return NewObfs4Conn(conn, *obfs4Endpoint), nil } // updateState sets a new client state, logs it and sends an event to the clients @@ -266,10 +282,10 @@ func (c *Client) connectObfs4(obfs4Endpoint *Obfs4Config, sleepSeconds int) { c.error("Did not get obfs4: %v ", err) } else { c.outLock.Lock() - c.obfs4Conns = append([]net.Conn{newObfs4Conn}, c.obfs4Conns...) + c.obfs4Conns = append([]Obfs4Conn{*newObfs4Conn}, c.obfs4Conns...) c.outLock.Unlock() - c.newObfs4Conn <- newObfs4Conn + c.newObfs4Conn <- *newObfs4Conn c.log("Dialed new remote") // If we wait sleepSeconds here to clean up the previous connection, we can guarantee that the @@ -335,13 +351,11 @@ func (c *Client) cleanupOldConn() { if len(c.obfs4Conns) > 1 { c.log("Connections: %v", len(c.obfs4Conns)) connToClose := c.obfs4Conns[len(c.obfs4Conns)-1] - if connToClose != nil { - c.log("Cleaning up old connection to %v", connToClose.RemoteAddr()) + c.log("Cleaning up old connection to %v", connToClose.RemoteAddr()) - err := connToClose.Close() - if err != nil { - c.log("Error closing obfs4 connection to %v: %v", connToClose.RemoteAddr(), err) - } + err := connToClose.Close() + if err != nil { + c.log("Error closing obfs4 connection to %v: %v", connToClose.RemoteAddr(), err) } // Remove the connection from our tracking list @@ -378,9 +392,10 @@ func (c *Client) readUDPWriteTCP() { } _, err = conn.Write(tcpBuffer) if err != nil { - c.error("Write err from %v to %v: %v", conn.LocalAddr(), conn.RemoteAddr(), err) + c.error("readUDPWriteTCP: Write err from %v to %v: %v", conn.LocalAddr(), conn.RemoteAddr(), err) time.Sleep(reconnectTime) - c.connectObfs4(c.obfs4Endpoints[0], 20) + config := c.obfs4Conns[0].config + c.connectObfs4(&config, 20) } } } @@ -436,7 +451,7 @@ func (c *Client) readTCPWriteUDP() { _, err := c.openvpnConn.WriteToUDP(tcpBytes, c.openvpnAddr) c.openvpnAddrLock.RUnlock() if err != nil { - c.error("Write err from %v to %v: %v", c.openvpnConn.LocalAddr(), c.openvpnConn.RemoteAddr(), err) + c.error("readTCPWriteUDP: Write err from %v to %v: %v", c.openvpnConn.LocalAddr(), c.openvpnConn.RemoteAddr(), err) c.openvpnAddrLock.Lock() c.openvpnConn.Close() c.openvpnConn, err = c.createOpenvpnConnection()