@@ -57,9 +57,17 @@ func WithStart(binary, address, daemonAddress, cgroup string, debug bool, exitHa
5757 return func (ctx context.Context , config shim.Config ) (_ shimapi.ShimService , _ io.Closer , err error ) {
5858 socket , err := newSocket (address )
5959 if err != nil {
60- return nil , nil , err
60+ if ! eaddrinuse (err ) {
61+ return nil , nil , err
62+ }
63+ if err := RemoveSocket (address ); err != nil {
64+ return nil , nil , errors .Wrap (err , "remove already used socket" )
65+ }
66+ if socket , err = newSocket (address ); err != nil {
67+ return nil , nil , err
68+ }
6169 }
62- defer socket . Close ()
70+
6371 f , err := socket .File ()
6472 if err != nil {
6573 return nil , nil , errors .Wrapf (err , "failed to get fd for socket %s" , address )
@@ -104,6 +112,8 @@ func WithStart(binary, address, daemonAddress, cgroup string, debug bool, exitHa
104112 if stderrLog != nil {
105113 stderrLog .Close ()
106114 }
115+ socket .Close ()
116+ RemoveSocket (address )
107117 }()
108118 log .G (ctx ).WithFields (logrus.Fields {
109119 "pid" : cmd .Process .Pid ,
@@ -138,6 +148,26 @@ func WithStart(binary, address, daemonAddress, cgroup string, debug bool, exitHa
138148 }
139149}
140150
151+ func eaddrinuse (err error ) bool {
152+ cause := errors .Cause (err )
153+ netErr , ok := cause .(* net.OpError )
154+ if ! ok {
155+ return false
156+ }
157+ if netErr .Op != "listen" {
158+ return false
159+ }
160+ syscallErr , ok := netErr .Err .(* os.SyscallError )
161+ if ! ok {
162+ return false
163+ }
164+ errno , ok := syscallErr .Err .(syscall.Errno )
165+ if ! ok {
166+ return false
167+ }
168+ return errno == syscall .EADDRINUSE
169+ }
170+
141171// setupOOMScore gets containerd's oom score and adds +1 to it
142172// to ensure a shim has a lower* score than the daemons
143173func setupOOMScore (shimPid int ) error {
@@ -210,31 +240,73 @@ func writeFile(path, address string) error {
210240 return os .Rename (tempPath , path )
211241}
212242
243+ const (
244+ abstractSocketPrefix = "\x00 "
245+ socketPathLimit = 106
246+ )
247+
248+ type socket string
249+
250+ func (s socket ) isAbstract () bool {
251+ return ! strings .HasPrefix (string (s ), "unix://" )
252+ }
253+
254+ func (s socket ) path () string {
255+ path := strings .TrimPrefix (string (s ), "unix://" )
256+ // if there was no trim performed, we assume an abstract socket
257+ if len (path ) == len (s ) {
258+ path = abstractSocketPrefix + path
259+ }
260+ return path
261+ }
262+
213263func newSocket (address string ) (* net.UnixListener , error ) {
214- if len (address ) > 106 {
215- return nil , errors .Errorf ("%q: unix socket path too long (> 106)" , address )
264+ if len (address ) > socketPathLimit {
265+ return nil , errors .Errorf ("%q: unix socket path too long (> %d)" , address , socketPathLimit )
266+ }
267+ var (
268+ sock = socket (address )
269+ path = sock .path ()
270+ )
271+ if ! sock .isAbstract () {
272+ if err := os .MkdirAll (filepath .Dir (path ), 0600 ); err != nil {
273+ return nil , errors .Wrapf (err , "%s" , path )
274+ }
216275 }
217- l , err := net .Listen ("unix" , " \x00 " + address )
276+ l , err := net .Listen ("unix" , path )
218277 if err != nil {
219- return nil , errors .Wrapf (err , "failed to listen to abstract unix socket %q" , address )
278+ return nil , errors .Wrapf (err , "failed to listen to unix socket %q (abstract: %t)" , address , sock .isAbstract ())
279+ }
280+ if err := os .Chmod (path , 0600 ); err != nil {
281+ l .Close ()
282+ return nil , err
220283 }
221284
222285 return l .(* net.UnixListener ), nil
223286}
224287
288+ // RemoveSocket removes the socket at the specified address if
289+ // it exists on the filesystem
290+ func RemoveSocket (address string ) error {
291+ sock := socket (address )
292+ if ! sock .isAbstract () {
293+ return os .Remove (sock .path ())
294+ }
295+ return nil
296+ }
297+
225298func connect (address string , d func (string , time.Duration ) (net.Conn , error )) (net.Conn , error ) {
226299 return d (address , 100 * time .Second )
227300}
228301
229- func annonDialer (address string , timeout time.Duration ) (net.Conn , error ) {
230- address = strings .TrimPrefix (address , "unix://" )
231- return net .DialTimeout ("unix" , "\x00 " + address , timeout )
302+ func anonDialer (address string , timeout time.Duration ) (net.Conn , error ) {
303+ return net .DialTimeout ("unix" , socket (address ).path (), timeout )
232304}
233305
234306// WithConnect connects to an existing shim
235307func WithConnect (address string , onClose func ()) Opt {
236308 return func (ctx context.Context , config shim.Config ) (shimapi.ShimService , io.Closer , error ) {
237- conn , err := connect (address , annonDialer )
309+ conn , err := connect (address , anonDialer )
238310 if err != nil {
239311 return nil , nil , err
240312 }
0 commit comments