port_whitelist.go 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. package flags
import (
	"log"
	"strconv"
	"strings"

	"github.com/coroot/coroot-node-agent/ebpftracer/l7"
)
// PortWhitelist is a set of port numbers.
type PortWhitelist struct {
	// Ports is keyed by port number; a port counts as whitelisted when its
	// entry is true (missing entries read as false).
	Ports map[uint16]bool
}
// Package-level port whitelists. Both are nil until InitPortWhitelists
// assigns them; GetProtocolByPort consults them in declaration order.
var (
	// MysqlPorts holds the ports that GetProtocolByPort maps to l7.ProtocolMysql.
	MysqlPorts *PortWhitelist
	// MariadbPorts holds the ports that GetProtocolByPort maps to l7.ProtocolMariaDB.
	MariadbPorts *PortWhitelist
)
  16. // 创建端口白名单
  17. func NewPortWhitelist(portList string) *PortWhitelist {
  18. whitelist := &PortWhitelist{
  19. Ports: make(map[uint16]bool),
  20. }
  21. if portList == "" {
  22. return whitelist
  23. }
  24. ports := strings.Split(portList, ",")
  25. for _, portStr := range ports {
  26. portStr = strings.TrimSpace(portStr)
  27. if port, err := strconv.ParseUint(portStr, 10, 16); err == nil {
  28. whitelist.Ports[uint16(port)] = true
  29. }
  30. }
  31. return whitelist
  32. }
  33. // 检查端口是否在白名单中
  34. func (pw *PortWhitelist) Contains(port uint16) bool {
  35. if pw == nil {
  36. return false
  37. }
  38. return pw.Ports[port]
  39. }
  40. // 获取协议类型(基于端口白名单)
  41. func GetProtocolByPort(port uint16) l7.Protocol {
  42. if MysqlPorts != nil && MysqlPorts.Contains(port) {
  43. return l7.ProtocolMysql
  44. }
  45. if MariadbPorts != nil && MariadbPorts.Contains(port) {
  46. return l7.ProtocolMariaDB
  47. }
  48. // 如果端口不在白名单中,返回默认协议
  49. if *MysqlDefault == "mariadb" {
  50. return l7.ProtocolMariaDB
  51. }
  52. return l7.ProtocolMysql
  53. }
  54. // 初始化端口白名单
  55. func InitPortWhitelists() {
  56. MysqlPorts = NewPortWhitelist(*MysqlPortWhitelist)
  57. MariadbPorts = NewPortWhitelist(*MariadbPortWhitelist)
  58. }