[PATCH] THRIFT-5322: Guard against large string/binary lengths in Go
authorYuxuan 'fishy' Wang <yuxuan.wang@reddit.com>
Thu, 10 Dec 2020 22:42:37 +0000 (14:42 -0800)
committerLaszlo Boszormenyi (GCS) <gcs@debian.org>
Sun, 14 Feb 2021 18:50:04 +0000 (18:50 +0000)
Client: go

In TBinaryProtocol.ReadString, TBinaryProtocol.ReadBinary,
TCompactProtocol.ReadString, and TCompactProtocol.ReadBinary, use
safeReadBytes to prevent from large allocation on malformed sizes.

    $ go test -bench=SafeReadBytes -benchmem
    BenchmarkSafeReadBytes/normal-12                  625057              1789 ns/op            2176 B/op          5 allocs/op
    BenchmarkSafeReadBytes/max-askedSize-12           545271              2236 ns/op           14464 B/op          7 allocs/op
    PASS

Gbp-Pq: Name THRIFT-5322.patch

lib/go/thrift/binary_protocol.go
lib/go/thrift/binary_protocol_test.go
lib/go/thrift/compact_protocol.go

index 93ae898cf5e78953110352a2daf431b57a74f625..c66e4e456440c0e55489a3f62bd79311e175e7e5 100644 (file)
@@ -432,6 +432,15 @@ func (p *TBinaryProtocol) ReadString() (value string, err error) {
                err = invalidDataLength
                return
        }
+       if size == 0 {
+               return "", nil
+       }
+       if size < int32(len(p.buffer)) {
+               // Avoid allocation on small reads
+               buf := p.buffer[:size]
+               read, e := io.ReadFull(p.trans, buf)
+               return string(buf[:read]), NewTProtocolException(e)
+       }
 
        return p.readStringBody(size)
 }
@@ -445,9 +454,7 @@ func (p *TBinaryProtocol) ReadBinary() ([]byte, error) {
                return nil, invalidDataLength
        }
 
-       isize := int(size)
-       buf := make([]byte, isize)
-       _, err := io.ReadFull(p.trans, buf)
+       buf, err := safeReadBytes(size, p.trans)
        return buf, NewTProtocolException(err)
 }
 
@@ -468,38 +475,21 @@ func (p *TBinaryProtocol) readAll(buf []byte) error {
        return NewTProtocolException(err)
 }
 
-const readLimit = 32768
-
 func (p *TBinaryProtocol) readStringBody(size int32) (value string, err error) {
-       if size < 0 {
-               return "", nil
-       }
-
-       var (
-               buf bytes.Buffer
-               e   error
-               b   []byte
-       )
+       buf, err := safeReadBytes(size, p.trans)
+       return string(buf), NewTProtocolException(err)
+}
 
-       switch {
-       case int(size) <= len(p.buffer):
-               b = p.buffer[:size] // avoids allocation for small reads
-       case int(size) < readLimit:
-               b = make([]byte, size)
-       default:
-               b = make([]byte, readLimit)
+// This function is shared between TBinaryProtocol and TCompactProtocol.
+//
+// It tries to read size bytes from trans, in a way that prevents large
+// allocations when size is insanely large (mostly caused by malformed message).
+func safeReadBytes(size int32, trans io.Reader) ([]byte, error) {
+       if size < 0 {
+               return nil, nil
        }
 
-       for size > 0 {
-               _, e = io.ReadFull(p.trans, b)
-               buf.Write(b)
-               if e != nil {
-                       break
-               }
-               size -= readLimit
-               if size < readLimit && size > 0 {
-                       b = b[:size]
-               }
-       }
-       return buf.String(), NewTProtocolException(e)
+       buf := new(bytes.Buffer)
+       _, err := io.CopyN(buf, trans, int64(size))
+       return buf.Bytes(), err
 }
index 0462cc79deead19a3082b6282e807ca18c7cbb86..88bfd26b7ea76bc34cddf162c99f31067c6c6b15 100644 (file)
 package thrift
 
 import (
+       "bytes"
+       "math"
+       "strings"
        "testing"
 )
 
 func TestReadWriteBinaryProtocol(t *testing.T) {
        ReadWriteProtocolTest(t, NewTBinaryProtocolFactoryDefault())
 }
+
+const (
+       safeReadBytesSource = `
+Lorem ipsum dolor sit amet, consectetur adipiscing elit. Integer sit amet
+tincidunt nibh. Phasellus vel convallis libero, sit amet posuere quam. Nullam
+blandit velit at nibh fringilla, sed egestas erat dapibus. Sed hendrerit
+tincidunt accumsan. Curabitur consectetur bibendum dui nec hendrerit. Fusce quis
+turpis nec magna efficitur volutpat a ut nibh. Vestibulum odio risus, tristique
+a nisi et, congue mattis mi. Vivamus a nunc justo. Mauris molestie sagittis
+magna, hendrerit auctor lectus egestas non. Phasellus pretium, odio sit amet
+bibendum feugiat, velit nunc luctus erat, ac bibendum mi dui molestie nulla.
+Nullam fermentum magna eu elit vehicula tincidunt. Etiam ornare laoreet
+dignissim. Ut sed nunc ac neque vulputate fermentum. Morbi volutpat dapibus
+magna, at porttitor quam facilisis a. Donec eget fermentum risus. Aliquam erat
+volutpat.
+
+Phasellus molestie id ante vel iaculis. Fusce eget quam nec quam viverra laoreet
+vitae a dui. Mauris blandit blandit dui, iaculis interdum diam mollis at. Morbi
+vel sem et.
+`
+       safeReadBytesSourceLen = len(safeReadBytesSource)
+)
+
+func TestSafeReadBytes(t *testing.T) {
+       srcData := []byte(safeReadBytesSource)
+
+       for _, c := range []struct {
+               label     string
+               askedSize int32
+               dataSize  int
+       }{
+               {
+                       label:     "normal",
+                       askedSize: 100,
+                       dataSize:  100,
+               },
+               {
+                       label:     "max-askedSize",
+                       askedSize: math.MaxInt32,
+                       dataSize:  safeReadBytesSourceLen,
+               },
+       } {
+               t.Run(c.label, func(t *testing.T) {
+                       data := bytes.NewReader(srcData[:c.dataSize])
+                       buf, err := safeReadBytes(c.askedSize, data)
+                       if len(buf) != c.dataSize {
+                               t.Errorf(
+                                       "Expected to read %d bytes, got %d",
+                                       c.dataSize,
+                                       len(buf),
+                               )
+                       }
+                       if !strings.HasPrefix(safeReadBytesSource, string(buf)) {
+                               t.Errorf("Unexpected read data: %q", buf)
+                       }
+                       if int32(c.dataSize) < c.askedSize {
+                               // We expect error in this case
+                               if err == nil {
+                                       t.Errorf(
+                                               "Expected error when dataSize %d < askedSize %d, got nil",
+                                               c.dataSize,
+                                               c.askedSize,
+                                       )
+                               }
+                       } else {
+                               // We expect no error in this case
+                               if err != nil {
+                                       t.Errorf(
+                                               "Expected no error when dataSize %d >= askedSize %d, got: %v",
+                                               c.dataSize,
+                                               c.askedSize,
+                                               err,
+                                       )
+                               }
+                       }
+               })
+       }
+}
+
+func generateSafeReadBytesBenchmark(askedSize int32, dataSize int) func(b *testing.B) {
+       return func(b *testing.B) {
+               data := make([]byte, dataSize)
+               b.ResetTimer()
+               for i := 0; i < b.N; i++ {
+                       safeReadBytes(askedSize, bytes.NewReader(data))
+               }
+       }
+}
+
+func TestSafeReadBytesAlloc(t *testing.T) {
+       if testing.Short() {
+               // NOTE: Since this test runs a benchmark test, it takes at
+               // least 1 second.
+               //
+               // In general we try to avoid unit tests taking that long to run,
+               // but it's to verify a security issue so we made an exception
+               // here:
+               // https://issues.apache.org/jira/browse/THRIFT-5322
+               t.Skip("skipping test in short mode.")
+       }
+
+       const (
+               askedSize = int32(math.MaxInt32)
+               dataSize  = 4096
+       )
+
+       // The purpose of this test is that in the case a string header says
+       // that it has a string askedSize bytes long, the implementation should
+       // not just allocate askedSize bytes upfront. So when there're actually
+       // not enough data to be read (dataSize), the actual allocated bytes
+       // should be somewhere between dataSize and askedSize.
+       //
+       // Different approachs could have different memory overheads, so this
+       // target is arbitrary in nature. But when dataSize is small enough
+       // compare to askedSize, half the askedSize is a good and safe target.
+       const target = int64(askedSize) / 2
+
+       bm := testing.Benchmark(generateSafeReadBytesBenchmark(askedSize, dataSize))
+       actual := bm.AllocedBytesPerOp()
+       if actual > target {
+               t.Errorf(
+                       "Expected allocated bytes per op to be <= %d, got %d",
+                       target,
+                       actual,
+               )
+       } else {
+               t.Logf("Allocated bytes: %d B/op", actual)
+       }
+}
+
+func BenchmarkSafeReadBytes(b *testing.B) {
+       for _, c := range []struct {
+               label     string
+               askedSize int32
+               dataSize  int
+       }{
+               {
+                       label:     "normal",
+                       askedSize: 100,
+                       dataSize:  100,
+               },
+               {
+                       label:     "max-askedSize",
+                       askedSize: math.MaxInt32,
+                       dataSize:  4096,
+               },
+       } {
+               b.Run(c.label, generateSafeReadBytesBenchmark(c.askedSize, c.dataSize))
+       }
+}
index c5b6f4b19ddd916a261b094fd2c0d4b6cf279b7b..bb0ceba12116889cc8f070af1c7920a1d4b322bf 100644 (file)
@@ -561,17 +561,17 @@ func (p *TCompactProtocol) ReadString() (value string, err error) {
        if length < 0 {
                return "", invalidDataLength
        }
-
        if length == 0 {
                return "", nil
        }
-       var buf []byte
-       if length <= int32(len(p.buffer)) {
-               buf = p.buffer[0:length]
-       } else {
-               buf = make([]byte, length)
+       if length < int32(len(p.buffer)) {
+               // Avoid allocation on small reads
+               buf := p.buffer[:length]
+               read, e := io.ReadFull(p.trans, buf)
+               return string(buf[:read]), NewTProtocolException(e)
        }
-       _, e = io.ReadFull(p.trans, buf)
+
+       buf, e := safeReadBytes(length, p.trans)
        return string(buf), NewTProtocolException(e)
 }
 
@@ -588,8 +588,7 @@ func (p *TCompactProtocol) ReadBinary() (value []byte, err error) {
                return nil, invalidDataLength
        }
 
-       buf := make([]byte, length)
-       _, e = io.ReadFull(p.trans, buf)
+       buf, e := safeReadBytes(length, p.trans)
        return buf, NewTProtocolException(e)
 }