func init() {
flag.BoolVar(&debugFlag, "debug", false, "enable debug output in logs")
flag.StringVar(&namespaceFlag, "namespace", "", "namespace that owns the shim")
- flag.StringVar(&socketFlag, "socket", "", "abstract socket path to serve")
+ flag.StringVar(&socketFlag, "socket", "", "socket path to serve")
flag.StringVar(&addressFlag, "address", "", "grpc address back to main containerd")
flag.StringVar(&workdirFlag, "workdir", "", "path used to storge large temporary data")
flag.StringVar(&runtimeRootFlag, "runtime-root", proc.RuncRoot, "root directory for the runtime")
l, err = net.FileListener(os.NewFile(3, "socket"))
path = "[inherited from parent]"
} else {
- if len(path) > 106 {
- return errors.Errorf("%q: unix socket path too long (> 106)", path)
+ const (
+ abstractSocketPrefix = "\x00"
+ socketPathLimit = 106
+ )
+ p := strings.TrimPrefix(path, "unix://")
+ if len(p) == len(path) {
+ p = abstractSocketPrefix + p
}
- l, err = net.Listen("unix", "\x00"+path)
+ if len(p) > socketPathLimit {
+ return errors.Errorf("%q: unix socket path too long (> %d)", p, socketPathLimit)
+ }
+ l, err = net.Listen("unix", p)
}
if err != nil {
return err
return nil, errors.New("socket path must be specified")
}
- conn, err := net.Dial("unix", "\x00"+bindSocket)
+ conn, err := connectToAddress(bindSocket)
if err != nil {
return nil, err
}
return task.NewTaskClient(client), nil
}
+
+// connectToAddress dials the shim listening on address. Shim sockets used to
+// live in the abstract namespace, so when dialing address as given fails, the
+// abstract form (address prefixed with a NUL byte) is tried for backward
+// compatibility.
+func connectToAddress(address string) (net.Conn, error) {
+	if c, dialErr := net.Dial("unix", address); dialErr == nil {
+		return c, nil
+	}
+	// fall back to the legacy abstract socket
+	return net.Dial("unix", "\x00"+address)
+}
// Register the typeurl
"github.com/containerd/containerd/cio"
"github.com/containerd/containerd/containers"
+ "github.com/containerd/containerd/namespaces"
"github.com/containerd/containerd/oci"
+ "github.com/containerd/containerd/platforms"
_ "github.com/containerd/containerd/runtime"
"github.com/containerd/typeurl"
specs "github.com/opencontainers/runtime-spec/specs-go"
}
defer task.Delete(ctx, WithProcessKill)
}
+
+// TestShimSockLength creates and runs a task whose namespace is 76 characters
+// long and whose container id is 64 characters long, then waits for the task
+// to exit.
+// NOTE(review): presumably this guards the length of the shim socket address
+// derived from the namespace and id — confirm against the shim address code.
+func TestShimSockLength(t *testing.T) {
+ t.Parallel()
+
+ // Max length of namespace should be 76
+ namespace := strings.Repeat("n", 76)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ ctx = namespaces.WithNamespace(ctx, namespace)
+
+ client, err := newClient(t, address)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer client.Close()
+
+ image, err := client.Pull(ctx, testImage,
+ WithPlatformMatcher(platforms.Default()),
+ WithPullUnpack,
+ )
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ id := strings.Repeat("c", 64)
+
+ // We don't have limitation with length of container name,
+ // but 64 bytes of sha256 is the common case
+ container, err := client.NewContainer(ctx, id,
+ WithNewSnapshot(id, image),
+ WithNewSpec(oci.WithImageConfig(image), withExitStatus(0)),
+ )
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer container.Delete(ctx, WithSnapshotCleanup)
+
+ task, err := container.NewTask(ctx, empty())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer task.Delete(ctx)
+
+ // obtain the exit-status channel before the task is started
+ statusC, err := task.Wait(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if err := task.Start(ctx); err != nil {
+ t.Fatal(err)
+ }
+
+ // block until an exit status is delivered
+ <-statusC
+}
import (
"context"
+ "crypto/sha256"
+ "fmt"
"io/ioutil"
"os"
"path/filepath"
return func(b *bundle, ns string, ropts *runctypes.RuncOptions) (shim.Config, client.Opt) {
config := b.shimConfig(ns, c, ropts)
return config,
- client.WithStart(c.Shim, b.shimAddress(ns), daemonAddress, cgroup, c.ShimDebug, exitHandler)
+ client.WithStart(c.Shim, b.shimAddress(ns, daemonAddress), daemonAddress, cgroup, c.ShimDebug, exitHandler)
}
}
// ShimConnect is a ShimOpt for connecting to an existing remote shim
func ShimConnect(c *Config, onClose func()) ShimOpt {
return func(b *bundle, ns string, ropts *runctypes.RuncOptions) (shim.Config, client.Opt) {
- return b.shimConfig(ns, c, ropts), client.WithConnect(b.shimAddress(ns), onClose)
+ // connect to whichever address b.decideShimAddress selects for this bundle
+ return b.shimConfig(ns, c, ropts), client.WithConnect(b.decideShimAddress(ns), onClose)
}
}
// Delete deletes the bundle from disk
func (b *bundle) Delete() error {
+ address, _ := b.loadAddress()
+ if address != "" {
+ // we don't care about errors here
+ client.RemoveSocket(address)
+ }
err := os.RemoveAll(b.path)
if err == nil {
return os.RemoveAll(b.workDir)
return errors.Wrapf(err, "Failed to remove both bundle and workdir locations: %v", err2)
}
+// legacyShimAddress returns the address /containerd-shim/<namespace>/<id>/shim.sock.
+// The result carries no "unix://" scheme.
-func (b *bundle) shimAddress(namespace string) string {
+func (b *bundle) legacyShimAddress(namespace string) string {
return filepath.Join(string(filepath.Separator), "containerd-shim", namespace, b.id, "shim.sock")
}
+// socketRoot is the directory below which shim socket addresses are generated.
+const socketRoot = "/run/containerd"
+
+// shimAddress returns a "unix://" address of the form <socketRoot>/s/<hex>,
+// where <hex> is the SHA-256 digest of socketPath, namespace and the bundle id
+// joined as a path.
+func (b *bundle) shimAddress(namespace, socketPath string) string {
+	key := filepath.Join(socketPath, namespace, b.id)
+	digest := sha256.Sum256([]byte(key))
+	dir := filepath.Join(socketRoot, "s")
+	return fmt.Sprintf("unix://%s/%x", dir, digest)
+}
+
+// loadAddress returns the contents of the "address" file inside the bundle
+// directory, or the read error when the file cannot be read.
+func (b *bundle) loadAddress() (string, error) {
+	contents, err := ioutil.ReadFile(filepath.Join(b.path, "address"))
+	if err != nil {
+		return "", err
+	}
+	return string(contents), nil
+}
+
+// decideShimAddress returns the address recorded in the bundle's address file.
+// When that file cannot be loaded, the legacy address for namespace is
+// returned instead.
+func (b *bundle) decideShimAddress(namespace string) string {
+	if recorded, err := b.loadAddress(); err == nil {
+		return recorded
+	}
+	return b.legacyShimAddress(namespace)
+}
+
func (b *bundle) shimConfig(namespace string, c *Config, runcOptions *runctypes.RuncOptions) shim.Config {
var (
criuPath string
import (
"context"
+ "fmt"
"io"
"net"
"os"
"os/exec"
+ "path/filepath"
"strings"
"sync"
"syscall"
return func(ctx context.Context, config shim.Config) (_ shimapi.ShimService, _ io.Closer, err error) {
socket, err := newSocket(address)
if err != nil {
- return nil, nil, err
+ if !eaddrinuse(err) {
+ return nil, nil, err
+ }
+ if err := RemoveSocket(address); err != nil {
+ return nil, nil, errors.Wrap(err, "remove already used socket")
+ }
+ if socket, err = newSocket(address); err != nil {
+ return nil, nil, err
+ }
}
- defer socket.Close()
+
f, err := socket.File()
if err != nil {
return nil, nil, errors.Wrapf(err, "failed to get fd for socket %s", address)
go func() {
cmd.Wait()
exitHandler()
+ socket.Close()
+ RemoveSocket(address)
}()
log.G(ctx).WithFields(logrus.Fields{
"pid": cmd.Process.Pid,
"address": address,
"debug": debug,
}).Infof("shim %s started", binary)
+
+ if err := writeAddress(filepath.Join(config.Path, "address"), address); err != nil {
+ return nil, nil, err
+ }
// set shim in cgroup if it is provided
if cgroup != "" {
if err := setCgroup(cgroup, cmd); err != nil {
}
}
+// eaddrinuse reports whether the cause of err is a *net.OpError from a
+// "listen" operation that wraps an *os.SyscallError carrying
+// syscall.EADDRINUSE.
+func eaddrinuse(err error) bool {
+	opErr, isOpErr := errors.Cause(err).(*net.OpError)
+	if !isOpErr || opErr.Op != "listen" {
+		return false
+	}
+	if sysErr, isSysErr := opErr.Err.(*os.SyscallError); isSysErr {
+		errno, isErrno := sysErr.Err.(syscall.Errno)
+		return isErrno && errno == syscall.EADDRINUSE
+	}
+	return false
+}
+
func newCommand(binary, daemonAddress string, debug bool, config shim.Config, socket *os.File) (*exec.Cmd, error) {
selfExe, err := os.Executable()
if err != nil {
return cmd, nil
}
+// writeAddress writes an address file atomically: the address is written to a
+// hidden temporary file (".<name>") next to path, which is then renamed over
+// path.
+//
+// The temporary file is created with O_EXCL, so it is removed again on any
+// failure; otherwise a leftover file would make every later call fail.
+// It returns the first error from resolving path, creating, writing, closing
+// or renaming the temporary file.
+func writeAddress(path, address string) error {
+	path, err := filepath.Abs(path)
+	if err != nil {
+		return err
+	}
+	tempPath := filepath.Join(filepath.Dir(path), fmt.Sprintf(".%s", filepath.Base(path)))
+	f, err := os.OpenFile(tempPath, os.O_RDWR|os.O_CREATE|os.O_EXCL|os.O_SYNC, 0666)
+	if err != nil {
+		return err
+	}
+	_, err = f.WriteString(address)
+	// always close, but keep the write error if there was one
+	if closeErr := f.Close(); err == nil {
+		err = closeErr
+	}
+	if err == nil {
+		err = os.Rename(tempPath, path)
+	}
+	if err != nil {
+		// best effort cleanup; the original failure is what gets reported
+		_ = os.Remove(tempPath)
+		return err
+	}
+	return nil
+}
+
+const (
+	// abstractSocketPrefix is the NUL byte prepended to addresses that are
+	// treated as abstract sockets.
+	abstractSocketPrefix = "\x00"
+	// socketPathLimit is the maximum accepted length of a socket address.
+	socketPathLimit = 106
+)
+
+// socket is a shim socket address. An address with the "unix://" scheme names
+// a filesystem socket; any other address is assumed to be abstract.
+type socket string
+
+// isAbstract reports whether s lacks the "unix://" scheme.
+func (s socket) isAbstract() bool {
+	hasScheme := strings.HasPrefix(string(s), "unix://")
+	return !hasScheme
+}
+
+// path returns the name to listen on or dial: the address without its
+// "unix://" scheme, or the whole address prefixed with abstractSocketPrefix
+// when there is no scheme.
+func (s socket) path() string {
+	if s.isAbstract() {
+		return abstractSocketPrefix + string(s)
+	}
+	return strings.TrimPrefix(string(s), "unix://")
+}
+
func newSocket(address string) (*net.UnixListener, error) {
- if len(address) > 106 {
- return nil, errors.Errorf("%q: unix socket path too long (> 106)", address)
+ if len(address) > socketPathLimit {
+ return nil, errors.Errorf("%q: unix socket path too long (> %d)", address, socketPathLimit)
+ }
+ var (
+ sock = socket(address)
+ path = sock.path()
+ )
+ if !sock.isAbstract() {
+ if err := os.MkdirAll(filepath.Dir(path), 0600); err != nil {
+ return nil, errors.Wrapf(err, "%s", path)
+ }
}
- l, err := net.Listen("unix", "\x00"+address)
+ l, err := net.Listen("unix", path)
if err != nil {
- return nil, errors.Wrapf(err, "failed to listen to abstract unix socket %q", address)
+ return nil, errors.Wrapf(err, "failed to listen to unix socket %q (abstract: %t)", address, sock.isAbstract())
+ }
+ if err := os.Chmod(path, 0600); err != nil {
+ l.Close()
+ return nil, err
}
return l.(*net.UnixListener), nil
}
+// RemoveSocket removes the socket file named by address. Abstract addresses
+// have nothing on the filesystem, so for them it does nothing and returns nil;
+// otherwise the result of os.Remove is returned.
+func RemoveSocket(address string) error {
+	s := socket(address)
+	if s.isAbstract() {
+		return nil
+	}
+	return os.Remove(s.path())
+}
+
+// connect dials address with the dialer d, passing it a timeout of 100 seconds.
func connect(address string, d func(string, time.Duration) (net.Conn, error)) (net.Conn, error) {
return d(address, 100*time.Second)
}
+// anonDialer dials the unix socket for address with the given timeout, using
+// the name produced by socket(address).path().
-func annonDialer(address string, timeout time.Duration) (net.Conn, error) {
- address = strings.TrimPrefix(address, "unix://")
- return net.DialTimeout("unix", "\x00"+address, timeout)
+func anonDialer(address string, timeout time.Duration) (net.Conn, error) {
+ return net.DialTimeout("unix", socket(address).path(), timeout)
}
// WithConnect connects to an existing shim
func WithConnect(address string, onClose func()) Opt {
return func(ctx context.Context, config shim.Config) (shimapi.ShimService, io.Closer, error) {
- conn, err := connect(address, annonDialer)
+ conn, err := connect(address, anonDialer)
if err != nil {
return nil, nil, err
}
if err != nil {
return "", err
}
- address, err := shim.SocketAddress(ctx, id)
+ address, err := shim.SocketAddress(ctx, containerdAddress, id)
if err != nil {
return "", err
}
socket, err := shim.NewSocket(address)
if err != nil {
- return "", err
+ if !shim.SocketEaddrinuse(err) {
+ return "", err
+ }
+ if err := shim.RemoveSocket(address); err != nil {
+ return "", errors.Wrap(err, "remove already used socket")
+ }
+ if socket, err = shim.NewSocket(address); err != nil {
+ return "", err
+ }
}
- defer socket.Close()
f, err := socket.File()
if err != nil {
return "", err
}
- defer f.Close()
cmd.ExtraFiles = append(cmd.ExtraFiles, f)
}
defer func() {
if err != nil {
+ _ = shim.RemoveSocket(address)
cmd.Process.Kill()
}
}()
+// Shutdown calls s.cancel, makes a best-effort attempt to remove the socket
+// whose address is stored in the file "address" (a relative path), and then
+// exits the process with status 0.
func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*ptypes.Empty, error) {
s.cancel()
+ if address, err := shim.ReadAddress("address"); err == nil {
+ _ = shim.RemoveSocket(address)
+ }
os.Exit(0)
+ // not reached: os.Exit does not return; the statement below only satisfies the compiler
return empty, nil
}
flag.BoolVar(&debugFlag, "debug", false, "enable debug output in logs")
flag.StringVar(&namespaceFlag, "namespace", "", "namespace that owns the shim")
flag.StringVar(&idFlag, "id", "", "id of the task")
- flag.StringVar(&socketFlag, "socket", "", "abstract socket path to serve")
+ flag.StringVar(&socketFlag, "socket", "", "socket path to serve")
flag.StringVar(&bundlePath, "bundle", "", "path to the bundle if not workdir")
flag.StringVar(&addressFlag, "address", "", "grpc address back to main containerd")
return err
}
go func() {
- defer l.Close()
if err := server.Serve(ctx, l); err != nil &&
!strings.Contains(err.Error(), "use of closed network connection") {
logrus.WithError(err).Fatal("containerd-shim: ttrpc server failure")
}
+ l.Close()
+ if address, err := ReadAddress("address"); err == nil {
+ _ = RemoveSocket(address)
+ }
}()
return nil
}
l, err = net.FileListener(os.NewFile(3, "socket"))
path = "[inherited from parent]"
} else {
- if len(path) > 106 {
- return nil, errors.Errorf("%q: unix socket path too long (> 106)", path)
+ if len(path) > socketPathLimit {
+ return nil, errors.Errorf("%q: unix socket path too long (> %d)", path, socketPathLimit)
}
- l, err = net.Listen("unix", "\x00"+path)
+ l, err = net.Listen("unix", path)
}
if err != nil {
return nil, err
}
- logrus.WithField("socket", path).Debug("serving api on abstract socket")
+ logrus.WithField("socket", path).Debug("serving api on socket")
return l, nil
}
import (
"context"
"fmt"
+ "io/ioutil"
"net"
"os"
"os/exec"
}
return os.Rename(tempPath, path)
}
+
+// ErrNoAddress is returned when the address file has no content
+var ErrNoAddress = errors.New("no shim address")
+
+// ReadAddress returns the shim's socket address stored in the file at path.
+// The path is made absolute before reading. It returns the underlying error
+// when the path cannot be resolved or read, and ErrNoAddress when the file is
+// empty.
+func ReadAddress(path string) (string, error) {
+	absPath, err := filepath.Abs(path)
+	if err != nil {
+		return "", err
+	}
+	contents, err := ioutil.ReadFile(absPath)
+	switch {
+	case err != nil:
+		return "", err
+	case len(contents) == 0:
+		return "", ErrNoAddress
+	}
+	return string(contents), nil
+}
import (
"context"
+ "crypto/sha256"
+ "fmt"
"net"
+ "os"
"path/filepath"
"strings"
"syscall"
"github.com/pkg/errors"
)
+const socketPathLimit = 106
+
func getSysProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{
Setpgid: true,
return sys.SetOOMScore(pid, sys.OOMScoreMaxKillable)
}
-// SocketAddress returns an abstract socket address
-func SocketAddress(ctx context.Context, id string) (string, error) {
+// socketRoot is the directory below which socket addresses are generated.
+const socketRoot = "/run/containerd"
+
+// SocketAddress returns a "unix://" socket address of the form
+// <socketRoot>/s/<hex>, where <hex> is the SHA-256 digest of socketPath, the
+// namespace from ctx and id joined as a path. It returns an error when ctx
+// carries no namespace.
+func SocketAddress(ctx context.Context, socketPath, id string) (string, error) {
ns, err := namespaces.NamespaceRequired(ctx)
if err != nil {
return "", err
}
- return filepath.Join(string(filepath.Separator), "containerd-shim", ns, id, "shim.sock"), nil
+ d := sha256.Sum256([]byte(filepath.Join(socketPath, ns, id)))
+ return fmt.Sprintf("unix://%s/%x", filepath.Join(socketRoot, "s"), d), nil
}
-// AnonDialer returns a dialer for an abstract socket
+// AnonDialer dials the unix socket for address with the given timeout, using
+// the name produced by socket(address).path().
func AnonDialer(address string, timeout time.Duration) (net.Conn, error) {
- address = strings.TrimPrefix(address, "unix://")
- return net.DialTimeout("unix", "\x00"+address, timeout)
+ return net.DialTimeout("unix", socket(address).path(), timeout)
}
// NewSocket returns a new socket
func NewSocket(address string) (*net.UnixListener, error) {
- if len(address) > 106 {
- return nil, errors.Errorf("%q: unix socket path too long (> 106)", address)
+ var (
+ sock = socket(address)
+ path = sock.path()
+ )
+ if !sock.isAbstract() {
+ if err := os.MkdirAll(filepath.Dir(path), 0600); err != nil {
+ return nil, errors.Wrapf(err, "%s", path)
+ }
}
- l, err := net.Listen("unix", "\x00"+address)
+ l, err := net.Listen("unix", path)
if err != nil {
- return nil, errors.Wrapf(err, "failed to listen to abstract unix socket %q", address)
+ return nil, err
+ }
+ if err := os.Chmod(path, 0600); err != nil {
+ os.Remove(sock.path())
+ l.Close()
+ return nil, err
}
return l.(*net.UnixListener), nil
}
+
+// abstractSocketPrefix is the NUL byte prepended to addresses that are treated
+// as abstract sockets.
+const abstractSocketPrefix = "\x00"
+
+// socket is a shim socket address. An address with the "unix://" scheme names
+// a filesystem socket; any other address is assumed to be abstract.
+type socket string
+
+// isAbstract reports whether s lacks the "unix://" scheme.
+func (s socket) isAbstract() bool {
+	if strings.HasPrefix(string(s), "unix://") {
+		return false
+	}
+	return true
+}
+
+// path returns the name to listen on or dial: the address without its
+// "unix://" scheme, or the whole address prefixed with abstractSocketPrefix
+// when there is no scheme.
+func (s socket) path() string {
+	raw := string(s)
+	if trimmed := strings.TrimPrefix(raw, "unix://"); len(trimmed) != len(raw) {
+		return trimmed
+	}
+	// nothing was trimmed, so this is an abstract socket
+	return abstractSocketPrefix + raw
+}
+
+// RemoveSocket removes the socket file named by address. Abstract addresses
+// have nothing on the filesystem, so for them it does nothing and returns nil;
+// otherwise the result of os.Remove is returned.
+func RemoveSocket(address string) error {
+	s := socket(address)
+	if s.isAbstract() {
+		return nil
+	}
+	return os.Remove(s.path())
+}
+
+// SocketEaddrinuse reports whether err is a *net.OpError from a "listen"
+// operation that wraps an *os.SyscallError carrying syscall.EADDRINUSE.
+// The error is inspected as given; it is not unwrapped first.
+func SocketEaddrinuse(err error) bool {
+	opErr, isOpErr := err.(*net.OpError)
+	if !isOpErr || opErr.Op != "listen" {
+		return false
+	}
+	if sysErr, isSysErr := opErr.Err.(*os.SyscallError); isSysErr {
+		errno, isErrno := sysErr.Err.(syscall.Errno)
+		return isErrno && errno == syscall.EADDRINUSE
+	}
+	return false
+}
+
+// CanConnect reports whether a connection to the socket at address can be
+// established within 100 milliseconds. The probe connection is closed again
+// immediately.
+func CanConnect(address string) bool {
+	probe, err := AnonDialer(address, 100*time.Millisecond)
+	if err == nil {
+		probe.Close()
+	}
+	return err == nil
+}
}
return l, nil
}
+
+// RemoveSocket removes the socket at the specified address if
+// it exists on the filesystem
+//
+// This variant does no work and always returns nil.
+// NOTE(review): presumably the build for a platform whose shim sockets are
+// not filesystem entries — confirm against the file's build constraints.
+func RemoveSocket(address string) error {
+ return nil
+}