Skip to content

Commit 3f42de4

Browse files
committed
Refactor dialer
Unify timeout dialer and context dialer implementation. Fix ContextDialer to use a real context-aware dialer. Signed-off-by: Ethan Chen <[email protected]>
1 parent 0396089 commit 3f42de4

4 files changed

Lines changed: 259 additions & 23 deletions

File tree

pkg/dialer/dialer.go

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package dialer
1818

1919
import (
2020
"context"
21+
stderrs "errors"
2122
"net"
2223
"time"
2324

@@ -29,51 +30,77 @@ type dialResult struct {
2930
err error
3031
}
3132

32-
// ContextDialer returns a GRPC net.Conn connected to the provided address
33+
var (
34+
// ErrTimeout represents a timeout error while dialing
35+
ErrTimeout = stderrs.New("timeout")
36+
)
37+
38+
// ContextDialer returns a GRPC net.Conn connected to the provided address.
39+
// It tolerates ENOENT and keeps retrying if provided context has a deadline.
40+
// It does a "one shot" connect attempt if provided context doesn't have a deadline.
3341
func ContextDialer(ctx context.Context, address string) (net.Conn, error) {
34-
if deadline, ok := ctx.Deadline(); ok {
35-
return timeoutDialer(address, time.Until(deadline))
42+
if _, ok := ctx.Deadline(); !ok {
43+
return contextDialer(ctx, address, nil)
3644
}
37-
return timeoutDialer(address, 0)
45+
return contextDialer(ctx, address, isNoent)
3846
}
3947

4048
// Dialer returns a GRPC net.Conn connected to the provided address
4149
// Deprecated: use ContextDialer and grpc.WithContextDialer.
4250
var Dialer = timeoutDialer
4351

44-
func timeoutDialer(address string, timeout time.Duration) (net.Conn, error) {
45-
var (
46-
stopC = make(chan struct{})
47-
synC = make(chan *dialResult)
48-
)
52+
// timeoutDialer connects to the provided address with a timeout.
53+
func timeoutDialer(address string, timeout time.Duration) (conn net.Conn, err error) {
54+
if timeout == 0 {
55+
return dialer(address, timeout)
56+
}
57+
58+
timeoutCtx, _ := context.WithTimeout(context.TODO(), timeout)
59+
return contextDialer(timeoutCtx, address, isNoent)
60+
}
61+
62+
// contextDialer connects to the provided address with a context.
63+
// It may tolerate certain errors and keep retrying before context canceled.
64+
// It returns immediately on context cancellation.
65+
func contextDialer(ctx context.Context, address string, tolerateErr func(error) bool) (net.Conn, error) {
66+
stopCh := make(chan struct{})
67+
resCh := make(chan *dialResult, 1)
68+
4969
go func() {
50-
defer close(synC)
70+
defer close(resCh)
5171
for {
5272
select {
53-
case <-stopC:
73+
case <-stopCh:
5474
return
5575
default:
56-
c, err := dialer(address, timeout)
57-
if isNoent(err) {
58-
<-time.After(10 * time.Millisecond)
76+
// Limit a single connect attempt timeout to 10s
77+
c, err := dialer(address, 10*time.Second)
78+
if tolerateErr != nil && tolerateErr(err) {
79+
time.Sleep(10 * time.Millisecond)
5980
continue
6081
}
61-
synC <- &dialResult{c, err}
82+
resCh <- &dialResult{c: c, err: err}
6283
return
6384
}
6485
}
6586
}()
87+
6688
select {
67-
case dr := <-synC:
68-
return dr.c, dr.err
69-
case <-time.After(timeout):
70-
close(stopC)
89+
case <-ctx.Done():
90+
close(stopCh)
7191
go func() {
72-
dr := <-synC
73-
if dr != nil && dr.c != nil {
74-
dr.c.Close()
92+
// If the dial succeed after timeout,
93+
// close it to prevent resource leak.
94+
r := <-resCh
95+
if r != nil && r.c != nil {
96+
r.c.Close()
7597
}
7698
}()
77-
return nil, errors.Errorf("dial %s: timeout", address)
99+
if ctx.Err() == context.DeadlineExceeded {
100+
return nil, errors.Wrapf(ErrTimeout, "dial %s", address)
101+
}
102+
return nil, errors.Wrapf(ctx.Err(), "dial %s", address)
103+
case res := <-resCh:
104+
return res.c, res.err
78105
}
79106
}

pkg/dialer/dialer_test.go

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
/*
2+
Copyright The containerd Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package dialer
18+
19+
import (
20+
"context"
21+
"crypto/md5"
22+
"errors"
23+
"fmt"
24+
"net"
25+
"strings"
26+
"testing"
27+
"time"
28+
)
29+
30+
func TestContextDialer(t *testing.T) {
31+
ts := []struct {
32+
Name string
33+
Server func(string, <-chan struct{}) (net.Addr, error)
34+
Timeout time.Duration
35+
Cancel bool
36+
ExpectErr bool
37+
Err error
38+
}{
39+
{
40+
Name: "dial successful",
41+
Server: newEchoServer,
42+
ExpectErr: false,
43+
},
44+
{
45+
Name: "context canceled",
46+
Server: newBlockServer,
47+
Cancel: true,
48+
ExpectErr: true,
49+
Err: context.Canceled,
50+
},
51+
{
52+
Name: "dial timeout",
53+
Server: newBlockServer,
54+
Timeout: 300 * time.Millisecond,
55+
ExpectErr: true,
56+
Err: ErrTimeout,
57+
},
58+
}
59+
t.Parallel()
60+
for i := range ts {
61+
tt := ts[i]
62+
t.Run(tt.Name, func(t *testing.T) {
63+
stop := make(chan struct{})
64+
addr, err := tt.Server(tt.Name, stop)
65+
if err != nil {
66+
t.Fatalf("failed to start test server: %v", err)
67+
}
68+
defer close(stop)
69+
if tt.Timeout == 0 {
70+
tt.Timeout = 3 * time.Second
71+
}
72+
ctx, cancel := context.WithTimeout(context.TODO(), tt.Timeout)
73+
defer cancel()
74+
if tt.Cancel {
75+
go func() {
76+
time.Sleep(300 * time.Millisecond)
77+
}()
78+
cancel()
79+
}
80+
_, err = ContextDialer(ctx, addr.String())
81+
if err != nil {
82+
if tt.ExpectErr && errors.Is(err, tt.Err) {
83+
// test OK
84+
return
85+
}
86+
t.Errorf("expect error %v, got %v", tt.Err, err)
87+
} else if tt.ExpectErr {
88+
t.Errorf("expect error %v, got <nil>", tt.Err)
89+
}
90+
})
91+
}
92+
93+
}
94+
95+
// newEchoServer setup a unix/pipe echo server for client to connect
96+
func newEchoServer(name string, stopCh <-chan struct{}) (net.Addr, error) {
97+
l, err := newListener(fmt.Sprintf("%x", md5.Sum([]byte(name))))
98+
if err != nil {
99+
return nil, err
100+
}
101+
102+
go func() {
103+
for {
104+
fd, lerr := l.Accept()
105+
if lerr != nil {
106+
if strings.Contains(lerr.Error(), "closed") {
107+
return
108+
}
109+
continue
110+
}
111+
go echoServer(fd)
112+
}
113+
}()
114+
115+
go func() {
116+
<-stopCh
117+
l.Close()
118+
}()
119+
120+
return l.Addr(), nil
121+
}
122+
123+
// newBlockServer setup a unix/pipe server but never accepts incoming connection
124+
func newBlockServer(name string, stopCh <-chan struct{}) (net.Addr, error) {
125+
l, err := newListener(fmt.Sprintf("%x", md5.Sum([]byte(name))))
126+
if err != nil {
127+
return nil, err
128+
}
129+
130+
go func() {
131+
<-stopCh
132+
l.Close()
133+
}()
134+
return l.Addr(), nil
135+
}
136+
137+
func echoServer(c net.Conn) {
138+
defer c.Close()
139+
buf := make([]byte, 512)
140+
for {
141+
n, err := c.Read(buf)
142+
if err != nil {
143+
return
144+
}
145+
data := buf[0:n]
146+
_, err = c.Write(data)
147+
if err != nil {
148+
return
149+
}
150+
}
151+
}

pkg/dialer/dialer_test_unix.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//+build !windows
2+
3+
/*
4+
Copyright The containerd Authors.
5+
6+
Licensed under the Apache License, Version 2.0 (the "License");
7+
you may not use this file except in compliance with the License.
8+
You may obtain a copy of the License at
9+
10+
http://www.apache.org/licenses/LICENSE-2.0
11+
12+
Unless required by applicable law or agreed to in writing, software
13+
distributed under the License is distributed on an "AS IS" BASIS,
14+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
See the License for the specific language governing permissions and
16+
limitations under the License.
17+
*/
18+
19+
package dialer
20+
21+
import (
22+
"net"
23+
"os"
24+
"path/filepath"
25+
)
26+
27+
func newListener(hash string) (net.Listener, error) {
28+
return net.Listen("unix", filepath.Join(os.TempDir(), hash+".sock"))
29+
}

pkg/dialer/dialer_test_windows.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//+build windows
2+
3+
/*
4+
Copyright The containerd Authors.
5+
6+
Licensed under the Apache License, Version 2.0 (the "License");
7+
you may not use this file except in compliance with the License.
8+
You may obtain a copy of the License at
9+
10+
http://www.apache.org/licenses/LICENSE-2.0
11+
12+
Unless required by applicable law or agreed to in writing, software
13+
distributed under the License is distributed on an "AS IS" BASIS,
14+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
See the License for the specific language governing permissions and
16+
limitations under the License.
17+
*/
18+
19+
package dialer
20+
21+
import (
22+
"net"
23+
24+
"github.com/Microsoft/go-winio"
25+
)
26+
27+
func newListener(hash string) (net.Listener, error) {
28+
return winio.ListenPipe(`\\.\pipe\`+hash, &winio.PipeConfig{})
29+
}

0 commit comments

Comments
 (0)