diff --git a/obfs4.go b/obfs4.go
index 562015a2312d19d21719974c41823ed744b1a4af..a92c0945d4589250cfa59d01b50c2ff0cf5a4e08 100644
--- a/obfs4.go
+++ b/obfs4.go
@@ -32,6 +32,7 @@ import (
 	"bytes"
 	"fmt"
 	"io"
+	"math/rand"
 	"net"
 	"syscall"
 	"time"
@@ -76,8 +77,6 @@ type Obfs4Conn struct {
 	// Server side state.
 	listener *Obfs4Listener
 	startTime time.Time
-	closeDelayBytes int
-	closeDelay int
 }
 
 func (c *Obfs4Conn) padBurst(burst *bytes.Buffer) (err error) {
@@ -117,7 +116,7 @@ func (c *Obfs4Conn) closeAfterDelay() {
 	// I-it's not like I w-wanna handshake with you or anything.  B-b-baka!
 	defer c.conn.Close()
 
-	delay := time.Duration(c.closeDelay) * time.Second
+	delay := time.Duration(c.listener.closeDelay) * time.Second + connectionTimeout
 	deadline := c.startTime.Add(delay)
 	if time.Now().After(deadline) {
 		return
@@ -132,7 +131,7 @@ func (c *Obfs4Conn) closeAfterDelay() {
 	// interval passes or a certain size has been reached.
 	discarded := 0
 	var buf [framing.MaximumSegmentLength]byte
-	for discarded < int(c.closeDelayBytes) {
+	for discarded < int(c.listener.closeDelayBytes) {
 		n, err := c.conn.Read(buf[:])
 		if err != nil {
 			return
@@ -325,10 +324,10 @@ func (c *Obfs4Conn) ServerHandshake() error {
 
 	// Complete the handshake.
 	err := c.serverHandshake(c.listener.nodeID, c.listener.keyPair)
-	c.listener = nil
 	if err != nil {
 		c.closeAfterDelay()
 	}
+	c.listener = nil
 
 	return err
 }
@@ -524,7 +523,11 @@ type Obfs4Listener struct {
 
 	keyPair *ntor.Keypair
 	nodeID  *ntor.NodeID
+
 	seed    *DrbgSeed
+
+	closeDelayBytes int
+	closeDelay int
 }
 
 func (l *Obfs4Listener) Accept() (net.Conn, error) {
@@ -545,8 +548,6 @@ func (l *Obfs4Listener) Accept() (net.Conn, error) {
 		return nil, err
 	}
 	cObfs.startTime = time.Now()
-	cObfs.closeDelayBytes = cObfs.lenProbDist.rng.Intn(maxCloseDelayBytes)
-	cObfs.closeDelay = cObfs.lenProbDist.rng.Intn(maxCloseDelay)
 
 	return cObfs, nil
 }
@@ -585,6 +586,10 @@ func Listen(network, laddr, nodeID, privateKey, seed string) (net.Listener, erro
 		return nil, err
 	}
 
+	rng := rand.New(newHashDrbg(l.seed))
+	l.closeDelayBytes = rng.Intn(maxCloseDelayBytes)
+	l.closeDelay = rng.Intn(maxCloseDelay)
+
 	// Start up the listener.
 	l.listener, err = net.Listen(network, laddr)
 	if err != nil {