瀏覽代碼

add revalidation of the LISTEN sockets by reading `/proc/net/tcp[6]`

Anton Petruhin 4 年之前
父節點
當前提交
6863d5008b
共有 10 個文件被更改,包括 149 次插入73 次删除
  1. 9 8
      cgroup/cgroup_linux.go
  2. 75 31
      containers/container.go
  3. 4 1
      containers/registry.go
  4. 8 8
      ebpftracer/Vagrantfile
  5. 8 20
      ebpftracer/init.go
  6. 1 0
      ebpftracer/tracer.go
  7. 33 0
      proc/fd.go
  8. 1 0
      proc/fixtures/123/fd/5
  9. 0 4
      proc/proc.go
  10. 10 1
      proc/proc_test.go

+ 9 - 8
cgroup/cgroup_linux.go

@@ -3,38 +3,39 @@ package cgroup
 import (
 	"github.com/vishvananda/netns"
 	"golang.org/x/sys/unix"
-	"k8s.io/klog/v2"
 	"runtime"
 )
 
-func init() {
+func Init() error {
 	selfNs, err := netns.GetFromPath("/proc/self/ns/cgroup")
 	if err != nil {
-		klog.Exitln(err)
+		return err
 	}
 	defer selfNs.Close()
 	hostNs, err := netns.GetFromPath("/proc/1/ns/cgroup")
 	if err != nil {
-		klog.Exitln(err)
+		return err
 	}
 	defer hostNs.Close()
 	if selfNs.Equal(hostNs) {
-		return
+		return nil
 	}
 
 	runtime.LockOSThread()
 	defer runtime.UnlockOSThread()
 	if err := unix.Setns(int(hostNs), unix.CLONE_NEWCGROUP); err != nil {
-		klog.Exitln(err)
+		return err
 	}
 
 	cg, err := NewFromProcessCgroupFile("/proc/self/cgroup")
 	if err != nil {
-		klog.Exitln(err)
+		return err
 	}
 	baseCgroupPath = cg.Id
 
 	if err := unix.Setns(int(selfNs), unix.CLONE_NEWCGROUP); err != nil {
-		klog.Exitln(err)
+		return err
 	}
+
+	return nil
 }

+ 75 - 31
containers/container.go

@@ -334,15 +334,11 @@ func (c *Container) onFileOpen(pid uint32, fd uint32) {
 	}
 }
 
-func (c *Container) onListenOpen(pid uint32, addr netaddr.IPPort) {
-	netNs, err := proc.GetNetNs(pid)
-	isHostNs := err == nil && hostNetNsId == netNs.UniqueId()
-	_ = netNs.Close()
-	if addr.IP().IsLoopback() && !isHostNs {
-		return
+func (c *Container) onListenOpen(pid uint32, addr netaddr.IPPort, safe bool) {
+	if !safe {
+		c.lock.Lock()
+		defer c.lock.Unlock()
 	}
-	c.lock.Lock()
-	defer c.lock.Unlock()
 	if _, ok := c.listens[addr]; !ok {
 		c.listens[addr] = map[uint32]time.Time{}
 	}
@@ -641,7 +637,7 @@ func (c *Container) gc(now time.Time) {
 
 	established := map[AddrPair]struct{}{}
 	establishedDst := map[netaddr.IPPort]struct{}{}
-	listens := map[netaddr.IPPort]struct{}{}
+	listens := map[netaddr.IPPort]string{}
 	for pid := range c.pids {
 		sockets, err := proc.GetSockets(pid)
 		if err != nil {
@@ -649,7 +645,7 @@ func (c *Container) gc(now time.Time) {
 		}
 		for _, s := range sockets {
 			if s.Listen {
-				listens[s.SAddr] = struct{}{}
+				listens[s.SAddr] = s.Inode
 			} else {
 				established[AddrPair{src: s.SAddr, dst: s.DAddr}] = struct{}{}
 				establishedDst[s.DAddr] = struct{}{}
@@ -658,27 +654,7 @@ func (c *Container) gc(now time.Time) {
 		break
 	}
 
-	for addr, byPid := range c.listens {
-		_, open := listens[addr]
-		if open {
-			continue
-		}
-		for pid, closedAt := range byPid {
-			if closedAt.IsZero() {
-				byPid[pid] = now
-			}
-		}
-	}
-	for pid, addrs := range c.listens {
-		for addr, closedAt := range addrs {
-			if !closedAt.IsZero() && now.Sub(closedAt) > gcInterval {
-				delete(c.listens[pid], addr)
-			}
-		}
-		if len(c.listens[pid]) == 0 {
-			delete(c.listens, pid)
-		}
-	}
+	c.revalidateListens(now, listens)
 
 	for srcDst := range c.connectionsActive {
 		if _, ok := established[srcDst]; !ok {
@@ -704,6 +680,74 @@ func (c *Container) gc(now time.Time) {
 	}
 }
 
+func (c *Container) revalidateListens(now time.Time, actualListens map[netaddr.IPPort]string) {
+	for addr, byPid := range c.listens {
+		if _, open := actualListens[addr]; open {
+			continue
+		}
+		klog.Warningln("deleting the outdated listen:", addr)
+		for pid, closedAt := range byPid {
+			if closedAt.IsZero() {
+				byPid[pid] = now
+			}
+		}
+	}
+
+	missingListens := map[netaddr.IPPort]string{}
+	for addr, inode := range actualListens {
+		byPids, found := c.listens[addr]
+		if !found {
+			missingListens[addr] = inode
+			continue
+		}
+		open := false
+		for _, closedAt := range byPids {
+			if closedAt.IsZero() {
+				open = true
+				break
+			}
+		}
+		if !open {
+			missingListens[addr] = inode
+		}
+	}
+
+	if len(missingListens) > 0 {
+		inodeToPid := map[string]uint32{}
+		for pid := range c.pids {
+			fds, err := proc.ReadFds(pid)
+			if err != nil {
+				continue
+			}
+			for _, fd := range fds {
+				if fd.SocketInode != "" {
+					inodeToPid[fd.SocketInode] = pid
+				}
+			}
+		}
+		for addr, inode := range missingListens {
+			pid, found := inodeToPid[inode]
+			if !found {
+				klog.Errorln("failed to determine pid for listen:", addr)
+				continue
+			}
+			klog.Warningln("missing listen found:", addr, pid)
+			c.onListenOpen(pid, addr, true)
+		}
+	}
+
+	for addr, pids := range c.listens {
+		for pid, closedAt := range pids {
+			if !closedAt.IsZero() && now.Sub(closedAt) > gcInterval {
+				delete(c.listens[addr], pid)
+			}
+		}
+		if len(c.listens[addr]) == 0 {
+			delete(c.listens, addr)
+		}
+	}
+}
+
 func resolveFd(pid uint32, fd uint32) (mntId string, logPath string) {
 	info := proc.GetFdInfo(pid, fd)
 	if info == nil {

+ 4 - 1
containers/registry.go

@@ -55,6 +55,9 @@ func NewRegistry(reg prometheus.Registerer, kernelVersion string) (*Registry, er
 	if err != nil {
 		return nil, err
 	}
+	if err := cgroup.Init(); err != nil {
+		return nil, err
+	}
 	if err := DockerdInit(); err != nil {
 		klog.Warningln(err)
 	}
@@ -149,7 +152,7 @@ func (r *Registry) handleEvents(ch <-chan ebpftracer.Event) {
 
 			case ebpftracer.EventTypeListenOpen:
 				if c := r.getOrCreateContainer(e.Pid); c != nil {
-					c.onListenOpen(e.Pid, e.SrcAddr)
+					c.onListenOpen(e.Pid, e.SrcAddr, false)
 				} else {
 					klog.Infoln("TCP listen open from unknown container", e)
 				}

+ 8 - 8
ebpftracer/Vagrantfile

@@ -19,16 +19,16 @@ Vagrant.configure("2") do |config|
 	config.vm.box_check_update = false
 	config.vm.synced_folder "..", "/tmp/src"
 
-	config.vm.define "ubuntu1810" do |ubuntu1810|
-		ubuntu1810.vm.box = "generic/ubuntu1810"
+	config.vm.define "ubuntu1810" do |c|
+		c.vm.box = "generic/ubuntu1810"
     end
-	config.vm.define "ubuntu2004" do |ubuntu2004|
-		ubuntu2004.vm.box = "generic/ubuntu2004"
+	config.vm.define "ubuntu2004" do |c|
+		c.vm.box = "generic/ubuntu2004"
     end
-	config.vm.define "ubuntu2010" do |ubuntu2010|
-		ubuntu2010.vm.box = "generic/ubuntu2010"
+	config.vm.define "ubuntu2010" do |c|
+		c.vm.box = "generic/ubuntu2010"
     end
-    config.vm.define "ubuntu2110" do |ubuntu2110|
-        ubuntu2110.vm.box = "generic/ubuntu2110"
+    config.vm.define "ubuntu2110" do |c|
+        c.vm.box = "generic/ubuntu2110"
     end
 end

+ 8 - 20
ebpftracer/init.go

@@ -3,13 +3,10 @@ package ebpftracer
 import (
 	"github.com/coroot/coroot-node-agent/proc"
 	"k8s.io/klog/v2"
-	"os"
-	"path"
-	"strconv"
 	"strings"
 )
 
-type fd struct {
+type file struct {
 	pid uint32
 	fd  uint32
 }
@@ -19,7 +16,7 @@ type sock struct {
 	proc.Sock
 }
 
-func readFds(pids []uint32) (fds []fd, socks []sock) {
+func readFds(pids []uint32) (files []file, socks []sock) {
 	nss := map[string]map[string]sock{}
 	for _, pid := range pids {
 		ns, err := proc.GetNetNs(pid)
@@ -41,28 +38,19 @@ func readFds(pids []uint32) (fds []fd, socks []sock) {
 			}
 		}
 
-		fdDir := proc.Path(pid, "fd")
-		entries, err := os.ReadDir(fdDir)
+		fds, err := proc.ReadFds(pid)
 		if err != nil {
 			continue
 		}
-		for _, entry := range entries {
-			dest, err := os.Readlink(path.Join(fdDir, entry.Name()))
-			if err != nil {
-				continue
-			}
+		for _, fd := range fds {
 			switch {
-			case strings.HasPrefix(dest, "socket:[") && strings.HasSuffix(dest, "]"):
-				inode := dest[len("socket:[") : len(dest)-1]
-				if s, ok := sockets[inode]; ok {
+			case fd.SocketInode != "":
+				if s, ok := sockets[fd.SocketInode]; ok {
 					s.pid = pid
 					socks = append(socks, s)
 				}
-			default:
-				i, err := strconv.Atoi(entry.Name())
-				if err == nil {
-					fds = append(fds, fd{pid: pid, fd: uint32(i)})
-				}
+			case strings.HasPrefix(fd.Dest, "/"):
+				files = append(files, file{pid: pid, fd: fd.Fd})
 			}
 		}
 	}

+ 1 - 0
ebpftracer/tracer.go

@@ -250,6 +250,7 @@ func runEventsReader(name string, r *perf.Reader, ch chan<- Event, e rawEvent) {
 		}
 		if rec.LostSamples > 0 {
 			klog.Errorln(name, "lost samples:", rec.LostSamples)
+			continue
 		}
 		if err := binary.Read(bytes.NewBuffer(rec.RawSample), binary.LittleEndian, e); err != nil {
 			klog.Warningln("failed to read msg:", err)

+ 33 - 0
proc/fd.go

@@ -2,10 +2,43 @@ package proc
 
 import (
 	"os"
+	"path"
 	"strconv"
 	"strings"
 )
 
+type Fd struct {
+	Fd   uint32
+	Dest string
+
+	SocketInode string
+}
+
+func ReadFds(pid uint32) ([]Fd, error) {
+	fdDir := Path(pid, "fd")
+	entries, err := os.ReadDir(fdDir)
+	if err != nil {
+		return nil, err
+	}
+	res := make([]Fd, 0, len(entries))
+	for _, entry := range entries {
+		fd, err := strconv.Atoi(entry.Name())
+		if err != nil {
+			continue
+		}
+		dest, err := os.Readlink(path.Join(fdDir, entry.Name()))
+		if err != nil {
+			continue
+		}
+		var socketInode string
+		if strings.HasPrefix(dest, "socket:[") && strings.HasSuffix(dest, "]") {
+			socketInode = dest[len("socket:[") : len(dest)-1]
+		}
+		res = append(res, Fd{Fd: uint32(fd), Dest: dest, SocketInode: socketInode})
+	}
+	return res, nil
+}
+
 type FdInfo struct {
 	MntId string
 	Flags int

+ 1 - 0
proc/fixtures/123/fd/5

@@ -0,0 +1 @@
+socket:[321]

+ 0 - 4
proc/proc.go

@@ -49,7 +49,3 @@ func ListPids() ([]uint32, error) {
 	}
 	return res, nil
 }
-
-func SetRoot(path string) {
-	root = path
-}

+ 10 - 1
proc/proc_test.go

@@ -8,7 +8,7 @@ import (
 )
 
 func init() {
-	SetRoot("fixtures")
+	root = "fixtures"
 }
 
 func TestListPids(t *testing.T) {
@@ -30,6 +30,15 @@ func TestGetMountInfo(t *testing.T) {
 	}, res)
 }
 
+func TestReadFds(t *testing.T) {
+	fds, err := ReadFds(123)
+	require.NoError(t, err)
+	assert.Equal(t, []Fd{
+		{Fd: 4, Dest: "/var/lib/postgresql/data/pg_wal/000000010000000000000001"},
+		{Fd: 5, Dest: "socket:[321]", SocketInode: "321"},
+	}, fds)
+}
+
 func TestGetFdInfo(t *testing.T) {
 	res := GetFdInfo(123, 4)
 	assert.Equal(t, FdInfo{