Skip to content

Commit 1d19cea

Browse files
Maisem Alineild
authored andcommitted
net/http/httptest: fix race in Server.Close
When run with race detector the test fails without the fix. Fixes #51799 Change-Id: I273adb6d3a2b1e0d606b9c27ab4c6a9aa4aa8064 GitHub-Last-Rev: a5ddd14 GitHub-Pull-Request: #51805 Reviewed-on: https://go-review.googlesource.com/c/go/+/393974 Reviewed-by: Damien Neil <[email protected]> Reviewed-by: Brad Fitzpatrick <[email protected]> Trust: Brad Fitzpatrick <[email protected]> Run-TryBot: Brad Fitzpatrick <[email protected]>
1 parent 3fd8b86 commit 1d19cea

2 files changed

Lines changed: 65 additions & 18 deletions

File tree

src/net/http/httptest/server.go

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -317,21 +317,17 @@ func (s *Server) wrap() {
317317
s.mu.Lock()
318318
defer s.mu.Unlock()
319319

320-
// Keep Close from returning until the user's ConnState hook
321-
// (if any) finishes. Without this, the call to forgetConn
322-
// below might send the count to 0 before we run the hook.
323-
s.wg.Add(1)
324-
defer s.wg.Done()
325-
326320
switch cs {
327321
case http.StateNew:
328-
s.wg.Add(1)
329322
if _, exists := s.conns[c]; exists {
330323
panic("invalid state transition")
331324
}
332325
if s.conns == nil {
333326
s.conns = make(map[net.Conn]http.ConnState)
334327
}
328+
// Add c to the set of tracked conns and increment it to the
329+
// waitgroup.
330+
s.wg.Add(1)
335331
s.conns[c] = cs
336332
if s.closed {
337333
// Probably just a socket-late-binding dial from
@@ -358,7 +354,14 @@ func (s *Server) wrap() {
358354
s.closeConn(c)
359355
}
360356
case http.StateHijacked, http.StateClosed:
361-
s.forgetConn(c)
357+
// Remove c from the set of tracked conns and decrement it from the
358+
// waitgroup, unless it was previously removed.
359+
if _, ok := s.conns[c]; ok {
360+
delete(s.conns, c)
361+
// Keep Close from returning until the user's ConnState hook
362+
// (if any) finishes.
363+
defer s.wg.Done()
364+
}
362365
}
363366
if oldHook != nil {
364367
oldHook(c, cs)
@@ -378,13 +381,3 @@ func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) {
378381
done <- struct{}{}
379382
}
380383
}
381-
382-
// forgetConn removes c from the set of tracked conns and decrements it from the
383-
// waitgroup, unless it was previously removed.
384-
// s.mu must be held.
385-
func (s *Server) forgetConn(c net.Conn) {
386-
if _, ok := s.conns[c]; ok {
387-
delete(s.conns, c)
388-
s.wg.Done()
389-
}
390-
}

src/net/http/httptest/server_test.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"io"
1010
"net"
1111
"net/http"
12+
"sync"
1213
"testing"
1314
)
1415

@@ -203,6 +204,59 @@ func TestServerZeroValueClose(t *testing.T) {
203204
ts.Close() // tests that it doesn't panic
204205
}
205206

207+
// Issue 51799: test hijacking a connection and then closing it
208+
// concurrently with closing the server.
209+
func TestCloseHijackedConnection(t *testing.T) {
210+
hijacked := make(chan net.Conn)
211+
ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
212+
defer close(hijacked)
213+
hj, ok := w.(http.Hijacker)
214+
if !ok {
215+
t.Fatal("failed to hijack")
216+
}
217+
c, _, err := hj.Hijack()
218+
if err != nil {
219+
t.Fatal(err)
220+
}
221+
hijacked <- c
222+
}))
223+
224+
var wg sync.WaitGroup
225+
wg.Add(1)
226+
go func() {
227+
defer wg.Done()
228+
req, err := http.NewRequest("GET", ts.URL, nil)
229+
if err != nil {
230+
t.Log(err)
231+
}
232+
// Use a client not associated with the Server.
233+
var c http.Client
234+
resp, err := c.Do(req)
235+
if err != nil {
236+
t.Log(err)
237+
return
238+
}
239+
resp.Body.Close()
240+
}()
241+
242+
wg.Add(1)
243+
conn := <-hijacked
244+
go func(conn net.Conn) {
245+
defer wg.Done()
246+
// Close the connection and then inform the Server that
247+
// we closed it.
248+
conn.Close()
249+
ts.Config.ConnState(conn, http.StateClosed)
250+
}(conn)
251+
252+
wg.Add(1)
253+
go func() {
254+
defer wg.Done()
255+
ts.Close()
256+
}()
257+
wg.Wait()
258+
}
259+
206260
func TestTLSServerWithHTTP2(t *testing.T) {
207261
modes := []struct {
208262
name string

0 commit comments

Comments
 (0)