err = invalidDataLength
return
}
+ if size == 0 {
+ return "", nil
+ }
+ if size < int32(len(p.buffer)) {
+ // Avoid allocation on small reads
+ buf := p.buffer[:size]
+ read, e := io.ReadFull(p.trans, buf)
+ return string(buf[:read]), NewTProtocolException(e)
+ }
return p.readStringBody(size)
}
return nil, invalidDataLength
}
- isize := int(size)
- buf := make([]byte, isize)
- _, err := io.ReadFull(p.trans, buf)
+ buf, err := safeReadBytes(size, p.trans)
return buf, NewTProtocolException(err)
}
return NewTProtocolException(err)
}
-const readLimit = 32768
-
// readStringBody reads size bytes from p.trans through safeReadBytes and
// returns them as a string, passing the read error to NewTProtocolException.
func (p *TBinaryProtocol) readStringBody(size int32) (value string, err error) {
- if size < 0 {
- return "", nil
- }
-
- var (
- buf bytes.Buffer
- e error
- b []byte
- )
+ buf, err := safeReadBytes(size, p.trans)
+ return string(buf), NewTProtocolException(err)
+}
- switch {
- case int(size) <= len(p.buffer):
- b = p.buffer[:size] // avoids allocation for small reads
- case int(size) < readLimit:
- b = make([]byte, size)
- default:
- b = make([]byte, readLimit)
+// safeReadBytes is shared between TBinaryProtocol and TCompactProtocol.
+//
+// It tries to read size bytes from trans, in a way that prevents large
+// allocations when size is insanely large (mostly caused by a malformed message).
+func safeReadBytes(size int32, trans io.Reader) ([]byte, error) {
+ if size < 0 { // negative size: read nothing and report no error
+ return nil, nil
}
- for size > 0 {
- _, e = io.ReadFull(p.trans, b)
- buf.Write(b)
- if e != nil {
- break
- }
- size -= readLimit
- if size < readLimit && size > 0 {
- b = b[:size]
- }
- }
- return buf.String(), NewTProtocolException(e)
+ buf := new(bytes.Buffer) // starts empty and grows only as bytes actually arrive
+ _, err := io.CopyN(buf, trans, int64(size)) // err is non-nil when trans yields fewer than size bytes
+ return buf.Bytes(), err // bytes read before an error are still returned
}
package thrift
import (
+ "bytes"
+ "math"
+ "strings"
"testing"
)
func TestReadWriteBinaryProtocol(t *testing.T) {
ReadWriteProtocolTest(t, NewTBinaryProtocolFactoryDefault())
}
+
+const ( // fixture text that TestSafeReadBytes reads back through safeReadBytes
+ safeReadBytesSource = `
+Lorem ipsum dolor sit amet, consectetur adipiscing elit. Integer sit amet
+tincidunt nibh. Phasellus vel convallis libero, sit amet posuere quam. Nullam
+blandit velit at nibh fringilla, sed egestas erat dapibus. Sed hendrerit
+tincidunt accumsan. Curabitur consectetur bibendum dui nec hendrerit. Fusce quis
+turpis nec magna efficitur volutpat a ut nibh. Vestibulum odio risus, tristique
+a nisi et, congue mattis mi. Vivamus a nunc justo. Mauris molestie sagittis
+magna, hendrerit auctor lectus egestas non. Phasellus pretium, odio sit amet
+bibendum feugiat, velit nunc luctus erat, ac bibendum mi dui molestie nulla.
+Nullam fermentum magna eu elit vehicula tincidunt. Etiam ornare laoreet
+dignissim. Ut sed nunc ac neque vulputate fermentum. Morbi volutpat dapibus
+magna, at porttitor quam facilisis a. Donec eget fermentum risus. Aliquam erat
+volutpat.
+
+Phasellus molestie id ante vel iaculis. Fusce eget quam nec quam viverra laoreet
+vitae a dui. Mauris blandit blandit dui, iaculis interdum diam mollis at. Morbi
+vel sem et.
+`
+ safeReadBytesSourceLen = len(safeReadBytesSource) // byte length of the fixture; used as dataSize in the max-askedSize case
+)
+
+func TestSafeReadBytes(t *testing.T) { // checks the bytes and error returned by safeReadBytes for each case below
+ srcData := []byte(safeReadBytesSource)
+
+ for _, c := range []struct {
+ label string // subtest name
+ askedSize int32 // size requested from safeReadBytes
+ dataSize int // number of fixture bytes the reader actually holds
+ }{
+ {
+ label: "normal",
+ askedSize: 100,
+ dataSize: 100,
+ },
+ {
+ label: "max-askedSize",
+ askedSize: math.MaxInt32,
+ dataSize: safeReadBytesSourceLen,
+ },
+ } {
+ t.Run(c.label, func(t *testing.T) {
+ data := bytes.NewReader(srcData[:c.dataSize])
+ buf, err := safeReadBytes(c.askedSize, data)
+ if len(buf) != c.dataSize { // every available byte must be returned, even on a short read
+ t.Errorf(
+ "Expected to read %d bytes, got %d",
+ c.dataSize,
+ len(buf),
+ )
+ }
+ if !strings.HasPrefix(safeReadBytesSource, string(buf)) { // returned bytes must match the start of the fixture
+ t.Errorf("Unexpected read data: %q", buf)
+ }
+ if int32(c.dataSize) < c.askedSize {
+ // Fewer bytes are available than were asked for, so an error is expected.
+ if err == nil {
+ t.Errorf(
+ "Expected error when dataSize %d < askedSize %d, got nil",
+ c.dataSize,
+ c.askedSize,
+ )
+ }
+ } else {
+ // Enough bytes are available, so no error is expected.
+ if err != nil {
+ t.Errorf(
+ "Expected no error when dataSize %d >= askedSize %d, got: %v",
+ c.dataSize,
+ c.askedSize,
+ err,
+ )
+ }
+ }
+ })
+ }
+}
+
+func generateSafeReadBytesBenchmark(askedSize int32, dataSize int) func(b *testing.B) { // builds a benchmark body for the given sizes
+ return func(b *testing.B) {
+ data := make([]byte, dataSize) // zero-filled input; only its length matters here
+ b.ResetTimer() // keep the allocation of data out of the measurement
+ for i := 0; i < b.N; i++ {
+ safeReadBytes(askedSize, bytes.NewReader(data)) // fresh reader per iteration; result and error are discarded
+ }
+ }
+}
+
+func TestSafeReadBytesAlloc(t *testing.T) { // bounds the bytes allocated per safeReadBytes call when askedSize far exceeds the data
+ if testing.Short() {
+ // NOTE: Since this test runs a benchmark test, it takes at
+ // least 1 second.
+ //
+ // In general we try to avoid unit tests taking that long to run,
+ // but it's to verify a security issue so we made an exception
+ // here:
+ // https://issues.apache.org/jira/browse/THRIFT-5322
+ t.Skip("skipping test in short mode.")
+ }
+
+ const (
+ askedSize = int32(math.MaxInt32)
+ dataSize = 4096
+ )
+
+ // The purpose of this test is that in the case a string header says
+ // that it has a string askedSize bytes long, the implementation should
+ // not just allocate askedSize bytes upfront. So when there is actually
+ // not enough data to be read (dataSize), the actual allocated bytes
+ // should be somewhere between dataSize and askedSize.
+ //
+ // Different approaches could have different memory overheads, so this
+ // target is arbitrary in nature. But when dataSize is small enough
+ // compared to askedSize, half the askedSize is a good and safe target.
+ const target = int64(askedSize) / 2
+
+ bm := testing.Benchmark(generateSafeReadBytesBenchmark(askedSize, dataSize)) // run the benchmark in-process to collect allocation statistics
+ actual := bm.AllocedBytesPerOp()
+ if actual > target {
+ t.Errorf(
+ "Expected allocated bytes per op to be <= %d, got %d",
+ target,
+ actual,
+ )
+ } else {
+ t.Logf("Allocated bytes: %d B/op", actual)
+ }
+}
+
+func BenchmarkSafeReadBytes(b *testing.B) { // runs generateSafeReadBytesBenchmark as one sub-benchmark per case below
+ for _, c := range []struct {
+ label string // sub-benchmark name
+ askedSize int32 // size requested from safeReadBytes
+ dataSize int // number of bytes the reader actually holds
+ }{
+ {
+ label: "normal",
+ askedSize: 100,
+ dataSize: 100,
+ },
+ {
+ label: "max-askedSize",
+ askedSize: math.MaxInt32,
+ dataSize: 4096,
+ },
+ } {
+ b.Run(c.label, generateSafeReadBytesBenchmark(c.askedSize, c.dataSize))
+ }
+}
if length < 0 {
return "", invalidDataLength
}
-
if length == 0 {
return "", nil
}
- var buf []byte
- if length <= int32(len(p.buffer)) {
- buf = p.buffer[0:length]
- } else {
- buf = make([]byte, length)
+ if length < int32(len(p.buffer)) {
+ // Avoid allocation on small reads
+ buf := p.buffer[:length]
+ read, e := io.ReadFull(p.trans, buf)
+ return string(buf[:read]), NewTProtocolException(e)
}
- _, e = io.ReadFull(p.trans, buf)
+
+ buf, e := safeReadBytes(length, p.trans)
return string(buf), NewTProtocolException(e)
}
return nil, invalidDataLength
}
- buf := make([]byte, length)
- _, e = io.ReadFull(p.trans, buf)
+ buf, e := safeReadBytes(length, p.trans)
return buf, NewTProtocolException(e)
}