From 2001f0b698183b998dbf8e52f5d40a0d82aeef09 Mon Sep 17 00:00:00 2001
From: Yawning Angel <yawning@schwanenlied.me>
Date: Sun, 1 Jun 2014 04:51:33 +0000
Subject: [PATCH] Generate client keypairs before connecting, instead of after.

Part of issue #9.
---
 handshake_ntor.go      |  9 ++-------
 handshake_ntor_test.go | 24 ++++++++++++++++++------
 obfs4.go               | 14 +++++++++++++-
 3 files changed, 33 insertions(+), 14 deletions(-)

diff --git a/handshake_ntor.go b/handshake_ntor.go
index fc107c2..92f00dc 100644
--- a/handshake_ntor.go
+++ b/handshake_ntor.go
@@ -121,14 +121,9 @@ type clientHandshake struct {
 	serverMark           []byte
 }
 
-func newClientHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.PublicKey) (*clientHandshake, error) {
-	var err error
-
+func newClientHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.PublicKey, sessionKey *ntor.Keypair) (*clientHandshake, error) {
 	hs := new(clientHandshake)
-	hs.keypair, err = ntor.NewKeypair(true)
-	if err != nil {
-		return nil, err
-	}
+	hs.keypair = sessionKey
 	hs.nodeID = nodeID
 	hs.serverIdentity = serverIdentity
 	hs.padLen = csrand.IntRange(clientMinPadLength, clientMaxPadLength)
diff --git a/handshake_ntor_test.go b/handshake_ntor_test.go
index b3e0a4d..69fb442 100644
--- a/handshake_ntor_test.go
+++ b/handshake_ntor_test.go
@@ -43,9 +43,13 @@ func TestHandshakeNtor(t *testing.T) {
 	// Test client handshake padding.
 	for l := clientMinPadLength; l <= clientMaxPadLength; l++ {
 		// Generate the client state and override the pad length.
-		clientHs, err := newClientHandshake(nodeID, idKeypair.Public())
+		clientKeypair, err := ntor.NewKeypair(true)
 		if err != nil {
-			t.Fatalf("[%d:0] newClientHandshake failed:", l, err)
+			t.Fatalf("[%d:0] ntor.NewKeypair failed: %s", l, err)
+		}
+		clientHs, err := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair)
+		if err != nil {
+			t.Fatalf("[%d:0] newClientHandshake failed: %s", l, err)
 		}
 		clientHs.padLen = l
 
@@ -99,9 +103,13 @@ func TestHandshakeNtor(t *testing.T) {
 	// Test server handshake padding.
 	for l := serverMinPadLength; l <= serverMaxPadLength+inlineSeedFrameLength; l++ {
 		// Generate the client state and override the pad length.
-		clientHs, err := newClientHandshake(nodeID, idKeypair.Public())
+		clientKeypair, err := ntor.NewKeypair(true)
+		if err != nil {
+			t.Fatalf("[%d:0] ntor.NewKeypair failed: %s", l, err)
+		}
+		clientHs, err := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair)
 		if err != nil {
-			t.Fatalf("[%d:0] newClientHandshake failed:", l, err)
+			t.Fatalf("[%d:0] newClientHandshake failed: %s", l, err)
 		}
 		clientHs.padLen = clientMinPadLength
 
@@ -146,9 +154,13 @@ func TestHandshakeNtor(t *testing.T) {
 	}
 
 	// Test oversized client padding.
-	clientHs, err := newClientHandshake(nodeID, idKeypair.Public())
+	clientKeypair, err := ntor.NewKeypair(true)
+	if err != nil {
+		t.Fatalf("ntor.NewKeypair failed: %s", err)
+	}
+	clientHs, err := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair)
 	if err != nil {
-		t.Fatalf("newClientHandshake failed:", err)
+		t.Fatalf("newClientHandshake failed: %s", err)
 	}
 
 	clientHs.padLen = clientMaxPadLength + 1
diff --git a/obfs4.go b/obfs4.go
index c780e0c..cc5e3b9 100644
--- a/obfs4.go
+++ b/obfs4.go
@@ -69,6 +69,8 @@ const (
 type Obfs4Conn struct {
 	conn net.Conn
 
+	sessionKey *ntor.Keypair
+
 	lenProbDist *wDist
 	iatProbDist *wDist
 
@@ -157,6 +159,8 @@ func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicK
 	}
 
 	defer func() {
+		// The session key is not needed past returning from this routine.
+		c.sessionKey = nil
 		if err != nil {
 			c.setBroken()
 		}
@@ -165,7 +169,7 @@ func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicK
 	// Generate/send the client handshake.
 	var hs *clientHandshake
 	var blob []byte
-	hs, err = newClientHandshake(nodeID, publicKey)
+	hs, err = newClientHandshake(nodeID, publicKey, c.sessionKey)
 	if err != nil {
 		return
 	}
@@ -576,6 +580,14 @@ func DialObfs4DialFn(dialFn DialFn, network, address, nodeID, publicKey string,
 		}
 		c.iatProbDist = newWDist(iatSeed, 0, maxIatDelay)
 	}
+
+	// Generate the session keypair *before* connecting to the remote peer.
+	c.sessionKey, err = ntor.NewKeypair(true)
+	if err != nil {
+		return nil, err
+	}
+
+	// Connect to the remote peer.
 	c.conn, err = dialFn(network, address)
 	if err != nil {
 		return nil, err
-- 
GitLab