[PATCH] Fix CVE-2020-15257
authorTianon Gravi <tianon@infosiftr.com>
Tue, 24 Nov 2020 12:38:31 +0000 (12:38 +0000)
committerFelix Geyer <fgeyer@debian.org>
Sun, 21 Feb 2021 17:18:35 +0000 (17:18 +0000)
This is the 1.2 backport. It's the Samuel Karp patch with additional changes:

 - Add ReadAddress function from commit 84a24711e88
 - Add "horten the unix socket path for shim" commit (a631796fda6)

Below is the original commit message:

-----------------------------------------------------------------------

Use path based unix socket for shims

This allows filesystem based ACLs for configuring access to the socket of a
shim.

Co-authored-by: Samuel Karp <skarp@amazon.com>
Signed-off-by: Samuel Karp <skarp@amazon.com>
Signed-off-by: Michael Crosby <michael@thepasture.io>
Signed-off-by: Michael Crosby <michael.crosby@apple.com>
-----------------------------------------------------------------------

containerd-shim: use path-based unix socket

This allows filesystem-based ACLs for configuring access to the socket
of a shim.

Ported from Michael Crosby's similar patch for v2 shims.

Signed-off-by: Samuel Karp <skarp@amazon.com>
-----------------------------------------------------------------------

Co-authored-by: Paulo Flabiano Smorigo <pfsmorigo@canonical.com>
Co-authored-by: varsha teratipally <teratipally@google.com>
Signed-off-by: Tianon Gravi <tianon@infosiftr.com>
Gbp-Pq: Name cve-2020-15257.patch

containerd/cmd/containerd-shim/main_unix.go
containerd/cmd/ctr/commands/shim/shim.go
containerd/container_test.go
containerd/runtime/v1/linux/bundle.go
containerd/runtime/v1/shim/client/client.go
containerd/runtime/v2/runc/service.go
containerd/runtime/v2/shim/shim.go
containerd/runtime/v2/shim/shim_unix.go
containerd/runtime/v2/shim/util.go
containerd/runtime/v2/shim/util_unix.go
containerd/runtime/v2/shim/util_windows.go

index e05dad611d9244a1af43eb501714fdd3badf06c1..150ff4d8a8cc483047c3ef8c83795148d2bc6e22 100644 (file)
@@ -62,7 +62,7 @@ var (
 func init() {
        flag.BoolVar(&debugFlag, "debug", false, "enable debug output in logs")
        flag.StringVar(&namespaceFlag, "namespace", "", "namespace that owns the shim")
-       flag.StringVar(&socketFlag, "socket", "", "abstract socket path to serve")
+       flag.StringVar(&socketFlag, "socket", "", "socket path to serve")
        flag.StringVar(&addressFlag, "address", "", "grpc address back to main containerd")
        flag.StringVar(&workdirFlag, "workdir", "", "path used to storge large temporary data")
        flag.StringVar(&runtimeRootFlag, "runtime-root", proc.RuncRoot, "root directory for the runtime")
@@ -161,10 +161,18 @@ func serve(ctx context.Context, server *ttrpc.Server, path string) error {
                l, err = net.FileListener(os.NewFile(3, "socket"))
                path = "[inherited from parent]"
        } else {
-               if len(path) > 106 {
-                       return errors.Errorf("%q: unix socket path too long (> 106)", path)
+               const (
+                       abstractSocketPrefix = "\x00"
+                       socketPathLimit      = 106
+               )
+               p := strings.TrimPrefix(path, "unix://")
+               if len(p) == len(path) {
+                       p = abstractSocketPrefix + p
                }
-               l, err = net.Listen("unix", "\x00"+path)
+               if len(p) > socketPathLimit {
+                       return errors.Errorf("%q: unix socket path too long (> %d)", p, socketPathLimit)
+               }
+               l, err = net.Listen("unix", p)
        }
        if err != nil {
                return err
index ec08cc68bb93af5b2000b6d21c40065664d59255..3dbb8b062f13ebaa6c638803bde82dbd6fdd1590 100644 (file)
@@ -231,7 +231,7 @@ func getTaskService(context *cli.Context) (task.TaskService, error) {
                return nil, errors.New("socket path must be specified")
        }
 
-       conn, err := net.Dial("unix", "\x00"+bindSocket)
+       conn, err := connectToAddress(bindSocket)
        if err != nil {
                return nil, err
        }
@@ -243,3 +243,13 @@ func getTaskService(context *cli.Context) (task.TaskService, error) {
 
        return task.NewTaskClient(client), nil
 }
+
+// as we changed the socket address from abstract, we need to have a backward
+// compatibility to handle the abstract sockets as well.
+func connectToAddress(address string) (net.Conn, error) {
+       conn, err := net.Dial("unix", address)
+       if err != nil {
+               return net.Dial("unix", "\x00"+address)
+       }
+       return conn, err
+}
index 927646da66f1f6602dd1e1df91f2e9c03545738d..a08785fcea2e81fc15858b23e2ca35818bed7375 100644 (file)
@@ -32,7 +32,9 @@ import (
        // Register the typeurl
        "github.com/containerd/containerd/cio"
        "github.com/containerd/containerd/containers"
+       "github.com/containerd/containerd/namespaces"
        "github.com/containerd/containerd/oci"
+       "github.com/containerd/containerd/platforms"
        _ "github.com/containerd/containerd/runtime"
        "github.com/containerd/typeurl"
        specs "github.com/opencontainers/runtime-spec/specs-go"
@@ -1528,3 +1530,59 @@ func TestContainerHook(t *testing.T) {
        }
        defer task.Delete(ctx, WithProcessKill)
 }
+
+func TestShimSockLength(t *testing.T) {
+       t.Parallel()
+
+       // Max length of namespace should be 76
+       namespace := strings.Repeat("n", 76)
+
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+
+       ctx = namespaces.WithNamespace(ctx, namespace)
+
+       client, err := newClient(t, address)
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer client.Close()
+
+       image, err := client.Pull(ctx, testImage,
+               WithPlatformMatcher(platforms.Default()),
+               WithPullUnpack,
+       )
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       id := strings.Repeat("c", 64)
+
+       // We don't have limitation with length of container name,
+       // but 64 bytes of sha256 is the common case
+       container, err := client.NewContainer(ctx, id,
+               WithNewSnapshot(id, image),
+               WithNewSpec(oci.WithImageConfig(image), withExitStatus(0)),
+       )
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer container.Delete(ctx, WithSnapshotCleanup)
+
+       task, err := container.NewTask(ctx, empty())
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer task.Delete(ctx)
+
+       statusC, err := task.Wait(ctx)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       if err := task.Start(ctx); err != nil {
+               t.Fatal(err)
+       }
+
+       <-statusC
+}
index d73866a2fd890fbea36fe01d95b6b2c215a52a98..84c06f2ab45b78c8f909a2be493a014dee788fe9 100644 (file)
@@ -20,6 +20,8 @@ package linux
 
 import (
        "context"
+       "crypto/sha256"
+       "fmt"
        "io/ioutil"
        "os"
        "path/filepath"
@@ -88,7 +90,7 @@ func ShimRemote(c *Config, daemonAddress, cgroup string, exitHandler func()) Shi
        return func(b *bundle, ns string, ropts *runctypes.RuncOptions) (shim.Config, client.Opt) {
                config := b.shimConfig(ns, c, ropts)
                return config,
-                       client.WithStart(c.Shim, b.shimAddress(ns), daemonAddress, cgroup, c.ShimDebug, exitHandler)
+                       client.WithStart(c.Shim, b.shimAddress(ns, daemonAddress), daemonAddress, cgroup, c.ShimDebug, exitHandler)
        }
 }
 
@@ -102,7 +104,7 @@ func ShimLocal(c *Config, exchange *exchange.Exchange) ShimOpt {
 // ShimConnect is a ShimOpt for connecting to an existing remote shim
 func ShimConnect(c *Config, onClose func()) ShimOpt {
        return func(b *bundle, ns string, ropts *runctypes.RuncOptions) (shim.Config, client.Opt) {
-               return b.shimConfig(ns, c, ropts), client.WithConnect(b.shimAddress(ns), onClose)
+               return b.shimConfig(ns, c, ropts), client.WithConnect(b.decideShimAddress(ns), onClose)
        }
 }
 
@@ -114,6 +116,11 @@ func (b *bundle) NewShimClient(ctx context.Context, namespace string, getClientO
 
 // Delete deletes the bundle from disk
 func (b *bundle) Delete() error {
+       address, _ := b.loadAddress()
+       if address != "" {
+               // we don't care about errors here
+               client.RemoveSocket(address)
+       }
        err := os.RemoveAll(b.path)
        if err == nil {
                return os.RemoveAll(b.workDir)
@@ -126,10 +133,34 @@ func (b *bundle) Delete() error {
        return errors.Wrapf(err, "Failed to remove both bundle and workdir locations: %v", err2)
 }
 
-func (b *bundle) shimAddress(namespace string) string {
+func (b *bundle) legacyShimAddress(namespace string) string {
        return filepath.Join(string(filepath.Separator), "containerd-shim", namespace, b.id, "shim.sock")
 }
 
+const socketRoot = "/run/containerd"
+
+func (b *bundle) shimAddress(namespace, socketPath string) string {
+       d := sha256.Sum256([]byte(filepath.Join(socketPath, namespace, b.id)))
+       return fmt.Sprintf("unix://%s/%x", filepath.Join(socketRoot, "s"), d)
+}
+
+func (b *bundle) loadAddress() (string, error) {
+       addressPath := filepath.Join(b.path, "address")
+       data, err := ioutil.ReadFile(addressPath)
+       if err != nil {
+               return "", err
+       }
+       return string(data), nil
+}
+
+func (b *bundle) decideShimAddress(namespace string) string {
+       address, err := b.loadAddress()
+       if err != nil {
+               return b.legacyShimAddress(namespace)
+       }
+       return address
+}
+
 func (b *bundle) shimConfig(namespace string, c *Config, runcOptions *runctypes.RuncOptions) shim.Config {
        var (
                criuPath      string
index 015d88c2dc66c2d77ed6b61d435d14a6f40cdfa0..1225e099ee19fc5f152df8625d6653b001edad9b 100644 (file)
@@ -20,10 +20,12 @@ package client
 
 import (
        "context"
+       "fmt"
        "io"
        "net"
        "os"
        "os/exec"
+       "path/filepath"
        "strings"
        "sync"
        "syscall"
@@ -53,9 +55,17 @@ func WithStart(binary, address, daemonAddress, cgroup string, debug bool, exitHa
        return func(ctx context.Context, config shim.Config) (_ shimapi.ShimService, _ io.Closer, err error) {
                socket, err := newSocket(address)
                if err != nil {
-                       return nil, nil, err
+                       if !eaddrinuse(err) {
+                               return nil, nil, err
+                       }
+                       if err := RemoveSocket(address); err != nil {
+                               return nil, nil, errors.Wrap(err, "remove already used socket")
+                       }
+                       if socket, err = newSocket(address); err != nil {
+                               return nil, nil, err
+                       }
                }
-               defer socket.Close()
+
                f, err := socket.File()
                if err != nil {
                        return nil, nil, errors.Wrapf(err, "failed to get fd for socket %s", address)
@@ -77,12 +87,18 @@ func WithStart(binary, address, daemonAddress, cgroup string, debug bool, exitHa
                go func() {
                        cmd.Wait()
                        exitHandler()
+                       socket.Close()
+                       RemoveSocket(address)
                }()
                log.G(ctx).WithFields(logrus.Fields{
                        "pid":     cmd.Process.Pid,
                        "address": address,
                        "debug":   debug,
                }).Infof("shim %s started", binary)
+
+               if err := writeAddress(filepath.Join(config.Path, "address"), address); err != nil {
+                       return nil, nil, err
+               }
                // set shim in cgroup if it is provided
                if cgroup != "" {
                        if err := setCgroup(cgroup, cmd); err != nil {
@@ -104,6 +120,26 @@ func WithStart(binary, address, daemonAddress, cgroup string, debug bool, exitHa
        }
 }
 
+func eaddrinuse(err error) bool {
+       cause := errors.Cause(err)
+       netErr, ok := cause.(*net.OpError)
+       if !ok {
+               return false
+       }
+       if netErr.Op != "listen" {
+               return false
+       }
+       syscallErr, ok := netErr.Err.(*os.SyscallError)
+       if !ok {
+               return false
+       }
+       errno, ok := syscallErr.Err.(syscall.Errno)
+       if !ok {
+               return false
+       }
+       return errno == syscall.EADDRINUSE
+}
+
 func newCommand(binary, daemonAddress string, debug bool, config shim.Config, socket *os.File) (*exec.Cmd, error) {
        selfExe, err := os.Executable()
        if err != nil {
@@ -144,31 +180,92 @@ func newCommand(binary, daemonAddress string, debug bool, config shim.Config, so
        return cmd, nil
 }
 
+// writeAddress writes a address file atomically
+func writeAddress(path, address string) error {
+       path, err := filepath.Abs(path)
+       if err != nil {
+               return err
+       }
+       tempPath := filepath.Join(filepath.Dir(path), fmt.Sprintf(".%s", filepath.Base(path)))
+       f, err := os.OpenFile(tempPath, os.O_RDWR|os.O_CREATE|os.O_EXCL|os.O_SYNC, 0666)
+       if err != nil {
+               return err
+       }
+       _, err = f.WriteString(address)
+       f.Close()
+       if err != nil {
+               return err
+       }
+       return os.Rename(tempPath, path)
+}
+
+const (
+       abstractSocketPrefix = "\x00"
+       socketPathLimit      = 106
+)
+
+type socket string
+
+func (s socket) isAbstract() bool {
+       return !strings.HasPrefix(string(s), "unix://")
+}
+
+func (s socket) path() string {
+       path := strings.TrimPrefix(string(s), "unix://")
+       // if there was no trim performed, we assume an abstract socket
+       if len(path) == len(s) {
+               path = abstractSocketPrefix + path
+       }
+       return path
+}
+
 func newSocket(address string) (*net.UnixListener, error) {
-       if len(address) > 106 {
-               return nil, errors.Errorf("%q: unix socket path too long (> 106)", address)
+       if len(address) > socketPathLimit {
+               return nil, errors.Errorf("%q: unix socket path too long (> %d)", address, socketPathLimit)
+       }
+       var (
+               sock = socket(address)
+               path = sock.path()
+       )
+       if !sock.isAbstract() {
+               if err := os.MkdirAll(filepath.Dir(path), 0600); err != nil {
+                       return nil, errors.Wrapf(err, "%s", path)
+               }
        }
-       l, err := net.Listen("unix", "\x00"+address)
+       l, err := net.Listen("unix", path)
        if err != nil {
-               return nil, errors.Wrapf(err, "failed to listen to abstract unix socket %q", address)
+               return nil, errors.Wrapf(err, "failed to listen to unix socket %q (abstract: %t)", address, sock.isAbstract())
+       }
+       if err := os.Chmod(path, 0600); err != nil {
+               l.Close()
+               return nil, err
        }
 
        return l.(*net.UnixListener), nil
 }
 
+// RemoveSocket removes the socket at the specified address if
+// it exists on the filesystem
+func RemoveSocket(address string) error {
+       sock := socket(address)
+       if !sock.isAbstract() {
+               return os.Remove(sock.path())
+       }
+       return nil
+}
+
 func connect(address string, d func(string, time.Duration) (net.Conn, error)) (net.Conn, error) {
        return d(address, 100*time.Second)
 }
 
-func annonDialer(address string, timeout time.Duration) (net.Conn, error) {
-       address = strings.TrimPrefix(address, "unix://")
-       return net.DialTimeout("unix", "\x00"+address, timeout)
+func anonDialer(address string, timeout time.Duration) (net.Conn, error) {
+       return net.DialTimeout("unix", socket(address).path(), timeout)
 }
 
 // WithConnect connects to an existing shim
 func WithConnect(address string, onClose func()) Opt {
        return func(ctx context.Context, config shim.Config) (shimapi.ShimService, io.Closer, error) {
-               conn, err := connect(address, annonDialer)
+               conn, err := connect(address, anonDialer)
                if err != nil {
                        return nil, nil, err
                }
index e3c78d6e792c349970abcb52a249645bf66aa1d3..7bae6070c85632db4ad31e4d8437624eb7a61029 100644 (file)
@@ -142,20 +142,26 @@ func (s *service) StartShim(ctx context.Context, id, containerdBinary, container
        if err != nil {
                return "", err
        }
-       address, err := shim.SocketAddress(ctx, id)
+       address, err := shim.SocketAddress(ctx, containerdAddress, id)
        if err != nil {
                return "", err
        }
        socket, err := shim.NewSocket(address)
        if err != nil {
-               return "", err
+               if !shim.SocketEaddrinuse(err) {
+                       return "", err
+               }
+               if err := shim.RemoveSocket(address); err != nil {
+                       return "", errors.Wrap(err, "remove already used socket")
+               }
+               if socket, err = shim.NewSocket(address); err != nil {
+                       return "", err
+               }
        }
-       defer socket.Close()
        f, err := socket.File()
        if err != nil {
                return "", err
        }
-       defer f.Close()
 
        cmd.ExtraFiles = append(cmd.ExtraFiles, f)
 
@@ -164,6 +170,7 @@ func (s *service) StartShim(ctx context.Context, id, containerdBinary, container
        }
        defer func() {
                if err != nil {
+                       _ = shim.RemoveSocket(address)
                        cmd.Process.Kill()
                }
        }()
@@ -581,6 +588,9 @@ func (s *service) Connect(ctx context.Context, r *taskAPI.ConnectRequest) (*task
 
 func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*ptypes.Empty, error) {
        s.cancel()
+       if address, err := shim.ReadAddress("address"); err == nil {
+               _ = shim.RemoveSocket(address)
+       }
        os.Exit(0)
        return empty, nil
 }
index 39484c191222b68d3ba1695a78af8d3b85cea6cb..b5fb3ff6d0e5b2aee225aea6afbb3ffee98120c7 100644 (file)
@@ -77,7 +77,7 @@ func parseFlags() {
        flag.BoolVar(&debugFlag, "debug", false, "enable debug output in logs")
        flag.StringVar(&namespaceFlag, "namespace", "", "namespace that owns the shim")
        flag.StringVar(&idFlag, "id", "", "id of the task")
-       flag.StringVar(&socketFlag, "socket", "", "abstract socket path to serve")
+       flag.StringVar(&socketFlag, "socket", "", "socket path to serve")
        flag.StringVar(&bundlePath, "bundle", "", "path to the bundle if not workdir")
 
        flag.StringVar(&addressFlag, "address", "", "grpc address back to main containerd")
@@ -239,11 +239,14 @@ func serve(ctx context.Context, server *ttrpc.Server, path string) error {
                return err
        }
        go func() {
-               defer l.Close()
                if err := server.Serve(ctx, l); err != nil &&
                        !strings.Contains(err.Error(), "use of closed network connection") {
                        logrus.WithError(err).Fatal("containerd-shim: ttrpc server failure")
                }
+               l.Close()
+               if address, err := ReadAddress("address"); err == nil {
+                       _ = RemoveSocket(address)
+               }
        }()
        return nil
 }
index 937aaaf0d8764b1454c4ec3dee0e74c492685700..42fba80d45a1ba2f0401c5b2e164208485631a2f 100644 (file)
@@ -58,15 +58,15 @@ func serveListener(path string) (net.Listener, error) {
                l, err = net.FileListener(os.NewFile(3, "socket"))
                path = "[inherited from parent]"
        } else {
-               if len(path) > 106 {
-                       return nil, errors.Errorf("%q: unix socket path too long (> 106)", path)
+               if len(path) > socketPathLimit {
+                       return nil, errors.Errorf("%q: unix socket path too long (> %d)", path, socketPathLimit)
                }
-               l, err = net.Listen("unix", "\x00"+path)
+               l, err = net.Listen("unix", path)
        }
        if err != nil {
                return nil, err
        }
-       logrus.WithField("socket", path).Debug("serving api on abstract socket")
+       logrus.WithField("socket", path).Debug("serving api on socket")
        return l, nil
 }
 
index b9c7524a61fd42938953eccd333af1eed49468cb..2cba62b8f9ff2d76763c3c1090d721f4b1dffdde 100644 (file)
@@ -19,6 +19,7 @@ package shim
 import (
        "context"
        "fmt"
+       "io/ioutil"
        "net"
        "os"
        "os/exec"
@@ -126,3 +127,22 @@ func WriteAddress(path, address string) error {
        }
        return os.Rename(tempPath, path)
 }
+
+// ErrNoAddress is returned when the address file has no content
+var ErrNoAddress = errors.New("no shim address")
+
+// ReadAddress returns the shim's socket address from the path
+func ReadAddress(path string) (string, error) {
+       path, err := filepath.Abs(path)
+       if err != nil {
+               return "", err
+       }
+       data, err := ioutil.ReadFile(path)
+       if err != nil {
+               return "", err
+       }
+       if len(data) == 0 {
+               return "", ErrNoAddress
+       }
+       return string(data), nil
+}
index 262fe2b363bed901a2e76e457aa9c2a667d7154a..d8a57a1da82c0e4b60fb46f15dce224140e67c48 100644 (file)
@@ -20,7 +20,10 @@ package shim
 
 import (
        "context"
+       "crypto/sha256"
+       "fmt"
        "net"
+       "os"
        "path/filepath"
        "strings"
        "syscall"
@@ -31,6 +34,8 @@ import (
        "github.com/pkg/errors"
 )
 
+const socketPathLimit = 106
+
 func getSysProcAttr() *syscall.SysProcAttr {
        return &syscall.SysProcAttr{
                Setpgid: true,
@@ -42,29 +47,101 @@ func SetScore(pid int) error {
        return sys.SetOOMScore(pid, sys.OOMScoreMaxKillable)
 }
 
-// SocketAddress returns an abstract socket address
-func SocketAddress(ctx context.Context, id string) (string, error) {
+const socketRoot = "/run/containerd"
+
+// SocketAddress returns a socket address
+func SocketAddress(ctx context.Context, socketPath, id string) (string, error) {
        ns, err := namespaces.NamespaceRequired(ctx)
        if err != nil {
                return "", err
        }
-       return filepath.Join(string(filepath.Separator), "containerd-shim", ns, id, "shim.sock"), nil
+       d := sha256.Sum256([]byte(filepath.Join(socketPath, ns, id)))
+       return fmt.Sprintf("unix://%s/%x", filepath.Join(socketRoot, "s"), d), nil
 }
 
-// AnonDialer returns a dialer for an abstract socket
+// AnonDialer returns a dialer for a socket
 func AnonDialer(address string, timeout time.Duration) (net.Conn, error) {
-       address = strings.TrimPrefix(address, "unix://")
-       return net.DialTimeout("unix", "\x00"+address, timeout)
+       return net.DialTimeout("unix", socket(address).path(), timeout)
 }
 
 // NewSocket returns a new socket
 func NewSocket(address string) (*net.UnixListener, error) {
-       if len(address) > 106 {
-               return nil, errors.Errorf("%q: unix socket path too long (> 106)", address)
+       var (
+               sock = socket(address)
+               path = sock.path()
+       )
+       if !sock.isAbstract() {
+               if err := os.MkdirAll(filepath.Dir(path), 0600); err != nil {
+                       return nil, errors.Wrapf(err, "%s", path)
+               }
        }
-       l, err := net.Listen("unix", "\x00"+address)
+       l, err := net.Listen("unix", path)
        if err != nil {
-               return nil, errors.Wrapf(err, "failed to listen to abstract unix socket %q", address)
+               return nil, err
+       }
+       if err := os.Chmod(path, 0600); err != nil {
+               os.Remove(sock.path())
+               l.Close()
+               return nil, err
        }
        return l.(*net.UnixListener), nil
 }
+
+const abstractSocketPrefix = "\x00"
+
+type socket string
+
+func (s socket) isAbstract() bool {
+       return !strings.HasPrefix(string(s), "unix://")
+}
+
+func (s socket) path() string {
+       path := strings.TrimPrefix(string(s), "unix://")
+       // if there was no trim performed, we assume an abstract socket
+       if len(path) == len(s) {
+               path = abstractSocketPrefix + path
+       }
+       return path
+}
+
+// RemoveSocket removes the socket at the specified address if
+// it exists on the filesystem
+func RemoveSocket(address string) error {
+       sock := socket(address)
+       if !sock.isAbstract() {
+               return os.Remove(sock.path())
+       }
+       return nil
+}
+
+// SocketEaddrinuse returns true if the provided error is caused by the
+// EADDRINUSE error number
+func SocketEaddrinuse(err error) bool {
+       netErr, ok := err.(*net.OpError)
+       if !ok {
+               return false
+       }
+       if netErr.Op != "listen" {
+               return false
+       }
+       syscallErr, ok := netErr.Err.(*os.SyscallError)
+       if !ok {
+               return false
+       }
+       errno, ok := syscallErr.Err.(syscall.Errno)
+       if !ok {
+               return false
+       }
+       return errno == syscall.EADDRINUSE
+}
+
+// CanConnect returns true if the socket provided at the address
+// is accepting new connections
+func CanConnect(address string) bool {
+       conn, err := AnonDialer(address, 100*time.Millisecond)
+       if err != nil {
+               return false
+       }
+       conn.Close()
+       return true
+}
index 986fc754bb3aa63c647b827cbef7aafbfb9bfdb3..be92604235efee1a25a2b1930f9f90774dd31138 100644 (file)
@@ -88,3 +88,9 @@ func NewSocket(address string) (net.Listener, error) {
        }
        return l, nil
 }
+
+// RemoveSocket removes the socket at the specified address if
+// it exists on the filesystem
+func RemoveSocket(address string) error {
+       return nil
+}