port_whitelist.go 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. package flags
  2. import (
  3. "strconv"
  4. "strings"
  5. "github.com/coroot/coroot-node-agent/ebpftracer/l7"
  6. )
  // PortWhitelist is a set of TCP ports. A port is a member only when Ports
  // maps it to true; missing keys and keys mapped to false are non-members.
  type PortWhitelist struct {
  	// Ports holds the whitelisted ports.
  	Ports map[uint16]bool
  }
  11. // 全局端口白名单实例
  12. var (
  13. MysqlPorts *PortWhitelist
  14. MariadbPorts *PortWhitelist
  15. TidbPorts *PortWhitelist
  16. )
  17. // 创建端口白名单
  18. func NewPortWhitelist(portList string) *PortWhitelist {
  19. whitelist := &PortWhitelist{
  20. Ports: make(map[uint16]bool),
  21. }
  22. if portList == "" {
  23. return whitelist
  24. }
  25. ports := strings.Split(portList, ",")
  26. for _, portStr := range ports {
  27. portStr = strings.TrimSpace(portStr)
  28. if port, err := strconv.ParseUint(portStr, 10, 16); err == nil {
  29. whitelist.Ports[uint16(port)] = true
  30. }
  31. }
  32. return whitelist
  33. }
  34. // 检查端口是否在白名单中
  35. func (pw *PortWhitelist) Contains(port uint16) bool {
  36. if pw == nil {
  37. return false
  38. }
  39. return pw.Ports[port]
  40. }
  41. // 获取协议类型(基于端口白名单)
  42. func GetProtocolByPort(port uint16) l7.Protocol {
  43. if MysqlPorts != nil && MysqlPorts.Contains(port) {
  44. return l7.ProtocolMysql
  45. }
  46. if MariadbPorts != nil && MariadbPorts.Contains(port) {
  47. return l7.ProtocolMariaDB
  48. }
  49. if TidbPorts != nil && TidbPorts.Contains(port) {
  50. return l7.ProtocolTiDB
  51. }
  52. // 如果端口不在白名单中,返回默认协议
  53. if *MysqlDefault == "mariadb" {
  54. return l7.ProtocolMariaDB
  55. }
  56. if *MysqlDefault == "tidb" {
  57. return l7.ProtocolTiDB
  58. }
  59. return l7.ProtocolMysql
  60. }
  61. // 初始化端口白名单
  62. func InitPortWhitelists() {
  63. MysqlPorts = NewPortWhitelist(*MysqlPortWhitelist)
  64. MariadbPorts = NewPortWhitelist(*MariadbPortWhitelist)
  65. TidbPorts = NewPortWhitelist(*TidbPortWhitelist)
  66. }