Ignore invalid "Host:" header between go1.6 and old docker clients
authorPaul Tagliamonte <paultag@debian.org>
Tue, 12 Jul 2016 14:46:35 +0000 (15:46 +0100)
committerTianon Gravi <tianon@debian.org>
Tue, 12 Jul 2016 14:46:35 +0000 (15:46 +0100)
Gbp-Pq: Name 22000--ignore-invalid-host-header.patch

docker/daemon.go
docker/daemon_unix.go
docker/daemon_windows.go
docker/hack/malformed_host_override.go [new file with mode: 0644]
docker/hack/malformed_host_override_test.go [new file with mode: 0644]

index bee921c78242fc0e68ad8a82c3c3e63c431475f6..7cd8adc2d7af4b420ef958d6187fb3b3524d3937 100644 (file)
@@ -250,13 +250,14 @@ func (cli *DaemonCli) CmdDaemon(args ...string) error {
                if len(protoAddrParts) != 2 {
                        logrus.Fatalf("bad format %s, expected PROTO://ADDR", protoAddr)
                }
-               l, err := listeners.Init(protoAddrParts[0], protoAddrParts[1], serverConfig.SocketGroup, serverConfig.TLSConfig)
+               ls, err := listeners.Init(protoAddrParts[0], protoAddrParts[1], serverConfig.SocketGroup, serverConfig.TLSConfig)
                if err != nil {
                        logrus.Fatal(err)
                }
+               ls = wrapListeners(protoAddrParts[0], ls)
 
                logrus.Debugf("Listener created for HTTP on %s (%s)", protoAddrParts[0], protoAddrParts[1])
-               api.Accept(protoAddrParts[1], l...)
+               api.Accept(protoAddrParts[1], ls...)
        }
 
        if err := migrateKey(); err != nil {
index b65eb1f0ef0ae560a470a28c4abfaad5ef219110..4e4b498a7842bb0ceadccb56c607bdda95157ed1 100644 (file)
@@ -4,10 +4,13 @@ package main
 
 import (
        "fmt"
+       "net"
        "os"
        "os/signal"
        "syscall"
 
+       "github.com/docker/docker/docker/hack"
+
        "github.com/Sirupsen/logrus"
        apiserver "github.com/docker/docker/api/server"
        "github.com/docker/docker/daemon"
@@ -80,3 +83,15 @@ func (cli *DaemonCli) getPlatformRemoteOptions() []libcontainerd.RemoteOption {
        }
        return opts
 }
+
+func wrapListeners(proto string, ls []net.Listener) []net.Listener {
+       switch proto {
+       case "unix":
+               ls[0] = &hack.MalformedHostHeaderOverride{ls[0]}
+       case "fd":
+               for i := range ls {
+                       ls[i] = &hack.MalformedHostHeaderOverride{ls[i]}
+               }
+       }
+       return ls
+}
index ae8d737d6c77b29e6abb0ea57594fa4a51d75455..b5ffbf99365a5440c486ce3579f35927f99f0330 100644 (file)
@@ -4,6 +4,7 @@ package main
 
 import (
        "fmt"
+       "net"
        "os"
        "syscall"
 
@@ -62,3 +63,7 @@ func setupConfigReloadTrap(configFile string, flags *mflag.FlagSet, reload func(
 func (cli *DaemonCli) getPlatformRemoteOptions() []libcontainerd.RemoteOption {
        return nil
 }
+
+func wrapListeners(proto string, ls []net.Listener) []net.Listener {
+       return ls
+}
diff --git a/docker/hack/malformed_host_override.go b/docker/hack/malformed_host_override.go
new file mode 100644 (file)
index 0000000..d4aa3dd
--- /dev/null
@@ -0,0 +1,121 @@
+// +build !windows
+
+package hack
+
+import "net"
+
+// MalformedHostHeaderOverride is a wrapper to be able
+// to overcome the 400 Bad request coming from old docker
+// clients that send an invalid Host header.
+type MalformedHostHeaderOverride struct {
+       net.Listener
+}
+
+// MalformedHostHeaderOverrideConn wraps the underlying unix
+// connection and keeps track of the first read from http.Server
+// which just reads the headers.
+type MalformedHostHeaderOverrideConn struct {
+       net.Conn
+       first bool
+}
+
+var closeConnHeader = []byte("\r\nConnection: close\r")
+
+// Read reads the first *read* request from http.Server to inspect
+// the Host header. If the Host starts with / then we're talking to
+// an old docker client which send an invalid Host header. To not
+// error out in http.Server we rewrite the first bytes of the request
+// to sanitize the Host header itself.
+// In case we're not dealing with old docker clients the data is just passed
+// to the server w/o modification.
+func (l *MalformedHostHeaderOverrideConn) Read(b []byte) (n int, err error) {
+       // http.Server uses a 4k buffer
+       if l.first && len(b) == 4096 {
+               // This keeps track of the first read from http.Server which just reads
+               // the headers
+               l.first = false
+               // The first read of the connection by http.Server is done limited to
+               // DefaultMaxHeaderBytes (usually 1 << 20) + 4096.
+               // Here we do the first read which gets us all the http headers to
+               // be inspected and modified below.
+               c, err := l.Conn.Read(b)
+               if err != nil {
+                       return c, err
+               }
+
+               var (
+                       start, end    int
+                       firstLineFeed = -1
+                       buf           []byte
+               )
+               for i := 0; i <= c-1-7; i++ {
+                       if b[i] == '\n' && firstLineFeed == -1 {
+                               firstLineFeed = i
+                       }
+                       if b[i] != '\n' {
+                               continue
+                       }
+
+                       if b[i+1] == '\r' && b[i+2] == '\n' {
+                               return c, nil
+                       }
+
+                       if b[i+1] != 'H' {
+                               continue
+                       }
+                       if b[i+2] != 'o' {
+                               continue
+                       }
+                       if b[i+3] != 's' {
+                               continue
+                       }
+                       if b[i+4] != 't' {
+                               continue
+                       }
+                       if b[i+5] != ':' {
+                               continue
+                       }
+                       if b[i+6] != ' ' {
+                               continue
+                       }
+                       if b[i+7] != '/' {
+                               continue
+                       }
+                       // ensure clients other than the docker clients do not get this hack
+                       if i != firstLineFeed {
+                               return c, nil
+                       }
+                       start = i + 7
+                       // now find where the value ends
+                       for ii, bbb := range b[start:c] {
+                               if bbb == '\n' {
+                                       end = start + ii
+                                       break
+                               }
+                       }
+                       buf = make([]byte, 0, c+len(closeConnHeader)-(end-start))
+                       // strip the value of the host header and
+                       // inject `Connection: close` to ensure we don't reuse this connection
+                       buf = append(buf, b[:start]...)
+                       buf = append(buf, closeConnHeader...)
+                       buf = append(buf, b[end:c]...)
+                       copy(b, buf)
+                       break
+               }
+               if len(buf) == 0 {
+                       return c, nil
+               }
+               return len(buf), nil
+       }
+       return l.Conn.Read(b)
+}
+
+// Accept makes the listener accepts connections and wraps the connection
+// in a MalformedHostHeaderOverrideConn initilizing first to true.
+func (l *MalformedHostHeaderOverride) Accept() (net.Conn, error) {
+       c, err := l.Listener.Accept()
+       if err != nil {
+               return c, err
+       }
+       return &MalformedHostHeaderOverrideConn{c, true}, nil
+}
diff --git a/docker/hack/malformed_host_override_test.go b/docker/hack/malformed_host_override_test.go
new file mode 100644 (file)
index 0000000..1a0a60b
--- /dev/null
@@ -0,0 +1,124 @@
+// +build !windows
+
+package hack
+
+import (
+       "bytes"
+       "io"
+       "net"
+       "strings"
+       "testing"
+)
+
+type bufConn struct {
+       net.Conn
+       buf *bytes.Buffer
+}
+
+func (bc *bufConn) Read(b []byte) (int, error) {
+       return bc.buf.Read(b)
+}
+
+func TestHeaderOverrideHack(t *testing.T) {
+       tests := [][2][]byte{
+               {
+                       []byte("GET /foo\nHost: /var/run/docker.sock\nUser-Agent: Docker\r\n\r\n"),
+                       []byte("GET /foo\nHost: \r\nConnection: close\r\nUser-Agent: Docker\r\n\r\n"),
+               },
+               {
+                       []byte("GET /foo\nHost: /var/run/docker.sock\nUser-Agent: Docker\nFoo: Bar\r\n"),
+                       []byte("GET /foo\nHost: \r\nConnection: close\r\nUser-Agent: Docker\nFoo: Bar\r\n"),
+               },
+               {
+                       []byte("GET /foo\nHost: /var/run/docker.sock\nUser-Agent: Docker\r\n\r\ntest something!"),
+                       []byte("GET /foo\nHost: \r\nConnection: close\r\nUser-Agent: Docker\r\n\r\ntest something!"),
+               },
+               {
+                       []byte("GET /foo\nHost: /var/run/docker.sock\nUser-Agent: Docker\r\n\r\ntest something! " + strings.Repeat("test", 15000)),
+                       []byte("GET /foo\nHost: \r\nConnection: close\r\nUser-Agent: Docker\r\n\r\ntest something! " + strings.Repeat("test", 15000)),
+               },
+               {
+                       []byte("GET /foo\nFoo: Bar\nHost: /var/run/docker.sock\nUser-Agent: Docker\r\n\r\n"),
+                       []byte("GET /foo\nFoo: Bar\nHost: /var/run/docker.sock\nUser-Agent: Docker\r\n\r\n"),
+               },
+       }
+
+       // Test for https://github.com/docker/docker/issues/23045
+       h0 := "GET /foo\nUser-Agent: Docker\r\n\r\n"
+       h0 = h0 + strings.Repeat("a", 4096-len(h0)-1) + "\n"
+       tests = append(tests, [2][]byte{[]byte(h0), []byte(h0)})
+
+       for _, pair := range tests {
+               read := make([]byte, 4096)
+               client := &bufConn{
+                       buf: bytes.NewBuffer(pair[0]),
+               }
+               l := MalformedHostHeaderOverrideConn{client, true}
+
+               n, err := l.Read(read)
+               if err != nil && err != io.EOF {
+                       t.Fatalf("read: %d - %d, err: %v\n%s", n, len(pair[0]), err, string(read[:n]))
+               }
+               if !bytes.Equal(read[:n], pair[1][:n]) {
+                       t.Fatalf("\n%s\n%s\n", read[:n], pair[1][:n])
+               }
+       }
+}
+
+func BenchmarkWithHack(b *testing.B) {
+       client, srv := net.Pipe()
+       done := make(chan struct{})
+       req := []byte("GET /foo\nHost: /var/run/docker.sock\nUser-Agent: Docker\n")
+       read := make([]byte, 4096)
+       b.SetBytes(int64(len(req) * 30))
+
+       l := MalformedHostHeaderOverrideConn{client, true}
+       go func() {
+               for {
+                       if _, err := srv.Write(req); err != nil {
+                               srv.Close()
+                               break
+                       }
+                       l.first = true // make sure each subsequent run uses the hack parsing
+               }
+               close(done)
+       }()
+
+       for i := 0; i < b.N; i++ {
+               for i := 0; i < 30; i++ {
+                       if n, err := l.Read(read); err != nil && err != io.EOF {
+                               b.Fatalf("read: %d - %d, err: %v\n%s", n, len(req), err, string(read[:n]))
+                       }
+               }
+       }
+       l.Close()
+       <-done
+}
+
+func BenchmarkNoHack(b *testing.B) {
+       client, srv := net.Pipe()
+       done := make(chan struct{})
+       req := []byte("GET /foo\nHost: /var/run/docker.sock\nUser-Agent: Docker\n")
+       read := make([]byte, 4096)
+       b.SetBytes(int64(len(req) * 30))
+
+       go func() {
+               for {
+                       if _, err := srv.Write(req); err != nil {
+                               srv.Close()
+                               break
+                       }
+               }
+               close(done)
+       }()
+
+       for i := 0; i < b.N; i++ {
+               for i := 0; i < 30; i++ {
+                       if _, err := client.Read(read); err != nil && err != io.EOF {
+                               b.Fatal(err)
+                       }
+               }
+       }
+       client.Close()
+       <-done
+}