Просмотр исходного кода

Merge pull request #26 from coroot/fix_mongodb_message_parsing

Avoiding unnecessary allocations in the case of truncated mongodb packets
Nikolay Sivko 2 лет назад
Родитель
Сommit
4a9d581f65
2 измененных файлов с 51 добавлено и 36 удалено
  1. 22 32
      tracing/mongo.go
  2. 29 4
      tracing/tracing_test.go

+ 22 - 32
tracing/mongo.go

@@ -1,8 +1,6 @@
 package tracing
 
 import (
-	"bufio"
-	"bytes"
 	"encoding/binary"
 	"github.com/coroot/coroot-node-agent/ebpftracer"
 	"go.mongodb.org/mongo-driver/bson"
@@ -10,7 +8,6 @@ import (
 	"go.opentelemetry.io/otel/codes"
 	semconv "go.opentelemetry.io/otel/semconv/v1.18.0"
 	"go.opentelemetry.io/otel/trace"
-	"io"
 	"time"
 )
 
@@ -31,38 +28,31 @@ func handleMongoQuery(start, end time.Time, r *ebpftracer.L7Request, attrs []att
 	span.End(trace.WithTimestamp(end))
 }
 
-type mongoMsgHeader struct {
-	MessageLength int32
-	RequestID     int32
-	ResponseTo    int32
-	OpCode        int32
-}
+const (
+	mongoHeaderLength      = 20
+	mongoOpCodeOffset      = 12
+	mongoSectionKindLength = 1
+	mongoSectionSizeLength = 4
+	mongoSectionKindBody   = 0
+)
 
-func parseMongo(payload []byte) string {
-	h := &mongoMsgHeader{}
-	reader := bufio.NewReader(bytes.NewReader(payload))
-	if err := binary.Read(reader, binary.LittleEndian, h); err != nil {
-		return ""
-	}
-	if h.OpCode != MongoOpMSG {
-		return ""
+func parseMongo(payload []byte) (res string) {
+	res = "<truncated>"
+	if len(payload) < mongoHeaderLength+mongoSectionKindLength+mongoSectionSizeLength {
+		return
 	}
-	if _, err := reader.Discard(4); err != nil { //flagBits
-		return ""
+	opCode := binary.LittleEndian.Uint32(payload[mongoOpCodeOffset:])
+	if opCode != MongoOpMSG {
+		return
 	}
-	if sectionKind, err := reader.ReadByte(); err != nil || sectionKind != 0 {
-		return ""
+	sectionKind := payload[mongoHeaderLength]
+	if sectionKind != mongoSectionKindBody {
+		return
 	}
-	return bsonToString(reader)
-}
-
-func bsonToString(r io.Reader) (res string) {
-	res = "<truncated>"
-	defer func() {
-		recover()
-	}()
-	if raw, err := bson.NewFromIOReader(r); err == nil {
-		res = raw.String()
+	sectionData := payload[mongoHeaderLength+mongoSectionKindLength:]
+	sectionLength := binary.LittleEndian.Uint32(sectionData)
+	if sectionLength < 1 || int(sectionLength) > len(sectionData) {
+		return
 	}
-	return
+	return bson.Raw(sectionData).String()
 }

+ 29 - 4
tracing/tracing_test.go

@@ -2,6 +2,7 @@ package tracing
 
 import (
 	"bytes"
+	"encoding/binary"
 	"github.com/stretchr/testify/assert"
 	"go.mongodb.org/mongo-driver/bson"
 	"testing"
@@ -52,12 +53,36 @@ func Test_parseRedis(t *testing.T) {
 	assert.Equal(t, "mylist", args)
 }
 
+type mongoHeader struct {
+	MessageLength int32
+	RequestID     int32
+	ResponseTo    int32
+	OpCode        int32
+	Flags         int32
+	SectionKind   uint8
+}
+
 func Test_parseMongo(t *testing.T) {
+	buf := bytes.NewBuffer(nil)
 	v := bson.M{"a": "bssssssssssssssssssssssssssssssssssssssssss"}
-	buf := make([]byte, 1024)
 	data, err := bson.Marshal(v)
+
+	h := mongoHeader{
+		MessageLength: 16 + 4 + 1 + int32(len(data)),
+		OpCode:        MongoOpMSG,
+	}
+
+	assert.NoError(t, binary.Write(buf, binary.LittleEndian, h))
+	_, err = buf.Write(data)
 	assert.NoError(t, err)
-	copy(buf, data)
-	assert.Equal(t, `{"a": "bssssssssssssssssssssssssssssssssssssssssss"}`, bsonToString(bytes.NewReader(buf)))
-	assert.Equal(t, `<truncated>`, bsonToString(bytes.NewReader(buf[:20])))
+
+	payload := buf.Bytes()
+
+	assert.Equal(t, `{"a": "bssssssssssssssssssssssssssssssssssssssssss"}`, parseMongo(payload))
+	assert.Equal(t, `<truncated>`, parseMongo(payload[:20]))
+
+	dataSize := binary.LittleEndian.Uint32(data)
+
+	binary.LittleEndian.PutUint32(payload[mongoHeaderLength+mongoSectionKindLength:], dataSize+1)
+	assert.Equal(t, `<truncated>`, parseMongo(payload))
 }