inject_asm_code_amd64.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. //go:build linux && amd64
  2. // +build linux,amd64
  3. package aotinject
  4. /*
  5. #cgo CFLAGS: -I ../inject/include
  6. #cgo amd64 LDFLAGS: ${SRCDIR}/../inject/lib/libhotpatch_amd64.a
  7. #cgo arm64 LDFLAGS: ${SRCDIR}/../inject/lib/libhotpatch_arm64.a
  8. #include "hotpatch.h"
  9. #include <stdlib.h>
  10. */
  11. import "C"
  12. import (
  13. "debug/elf"
  14. "errors"
  15. "fmt"
  16. "syscall"
  17. "unsafe"
  18. "golang.org/x/arch/x86/x86asm"
  19. )
  20. var PID string
  21. const jumpBackAddrOffset = 15
  22. const nopEntryOffset = 12
  23. const nopLen = 1
  24. const longJumpSize = 17
  25. type ProcessMapsInfo struct {
  26. Start, End uint64
  27. Path string
  28. }
  29. type ProcessFunctionInfo struct {
  30. Name string
  31. Offset uint64
  32. Start, Size uint64
  33. }
  34. type InstMemStruct struct {
  35. InstStartAddr uint64
  36. InstEndAddr uint64
  37. CodeArray []byte
  38. InstCodeArray [][]byte
  39. InsertIndex int
  40. }
  41. // 解析code中的指令,并打印每一条指令
  42. func parseAndPrintInstructions(code []byte) {
  43. pc := uint64(0)
  44. for pc < uint64(len(code)) {
  45. inst, err := x86asm.Decode(code[pc:], 64)
  46. if err != nil {
  47. fmt.Printf("Decode error: %v\n", err)
  48. break
  49. }
  50. fmt.Printf("0x%x:[%d] %s\n", pc, pc, inst.String())
  51. pc += uint64(inst.Len)
  52. }
  53. }
  54. func findSendFunctionNameAddr(SymAddr uint64, code []byte, sendFunctionName string, sendSymAddr uint64) (uint64, uint64, error) {
  55. pc := uint64(0)
  56. for pc < uint64(len(code)) {
  57. inst, err := x86asm.Decode(code[pc:], 64)
  58. if err != nil {
  59. fmt.Printf("Decode error: %v\n", err)
  60. break
  61. }
  62. fmt.Printf("0x%x:[%d] %s\n", SymAddr+pc, pc, inst.String())
  63. if inst.Op == x86asm.CALL {
  64. if v, ok := inst.Args[0].(x86asm.Rel); ok {
  65. fmt.Printf("%s\n", inst.Args[0].String())
  66. if SymAddr+pc+uint64(v)+5 == sendSymAddr {
  67. fmt.Printf("Found send function call at 0x%x\n", SymAddr+pc)
  68. return SymAddr + pc, pc, nil
  69. }
  70. }
  71. }
  72. pc += uint64(inst.Len)
  73. }
  74. return 0, 0, errors.New("Send function Offset not found")
  75. }
  76. // 获取 Send 函数附近的内存数据,用于长跳转到 libmylib 中,需要凑够 17 个字节,先向上查找,再向下查找
  77. func findMemForLongJump(SymAddr uint64, code []byte, sendFunctionName string, sendSymAddr uint64) (InstMemStruct, error) {
  78. // 定义一个结构体,储存 pc、指令、和指令长度
  79. type InstStruct struct {
  80. pc uint64
  81. inst x86asm.Inst
  82. len int
  83. }
  84. // 定义个字节数组用于储存遍历出来的指令,字节数组
  85. var codeArray []byte
  86. // 定义一个字节数组的数组,用于储存遍历出来的每个指令的字节内容
  87. var InstCodeArray [][]byte
  88. // 定义一个InstStruct数组,用于储存遍历出来的指令
  89. var instArray []InstStruct
  90. // 定义一个标志记录记录找到的Send函数在instArray数组的位置
  91. var sendIndex int
  92. pc := uint64(0)
  93. for pc < uint64(len(code)) {
  94. inst, err := x86asm.Decode(code[pc:], 64)
  95. if err != nil {
  96. fmt.Printf("Decode error: %v\n", err)
  97. break
  98. }
  99. // 将指令和指令长度存入InstStruct结构体中
  100. instArray = append(instArray, InstStruct{pc: pc, inst: inst, len: inst.Len})
  101. // 如果地址等于Send函数的地址,记录下Send函数在instArray数组的位置
  102. if SymAddr+pc == sendSymAddr {
  103. sendIndex = len(instArray) - 1
  104. }
  105. pc += uint64(inst.Len)
  106. }
  107. // 打印 sendIndex 和 instArray 的长度和 Send 函数的指令
  108. fmt.Printf("sendIndex: %d, instArray len: %d, Send Function: %s\n", sendIndex, len(instArray), instArray[sendIndex].inst.String())
  109. InstStartAddr := uint64(0)
  110. InstEndAddr := uint64(sendSymAddr) + 5
  111. // 从Send函数开始向上查找,凑够17个字节
  112. for i := sendIndex; i >= 0; i-- {
  113. if i != sendIndex && instArray[i].inst.Op == x86asm.CALL || instArray[i].inst.Op == x86asm.JMP || instArray[i].inst.Op == x86asm.JE || instArray[i].inst.Op == x86asm.JG || instArray[i].inst.Op == x86asm.JGE || instArray[i].inst.Op == x86asm.JL || instArray[i].inst.Op == x86asm.JLE || instArray[i].inst.Op == x86asm.JNE {
  114. break
  115. }
  116. // 将指令存入codeArray数组,指令应该从 code 中的pc开始,长度为instArray[i].len
  117. codeArray = append(codeArray, code[instArray[i].pc:instArray[i].pc+uint64(instArray[i].len)]...)
  118. InstCodeArray = append(InstCodeArray, code[instArray[i].pc:instArray[i].pc+uint64(instArray[i].len)])
  119. // 记录当前指令的地址
  120. InstStartAddr = SymAddr + instArray[i].pc
  121. // 打印本次循环的具体指令和指令长度
  122. fmt.Printf("inst: %s\n", instArray[i].inst.String())
  123. // 打印codeArray数组的长度和内容
  124. fmt.Printf("codeArray len: %d, codeArray: %v\n", len(codeArray), codeArray)
  125. // 如果codeArray数组长度大于等于17,或者 该指令代码等于 Call、jmp、jne、je、jg、jge、jl、jle、jne、等跳转指令,跳出循环,用于跳转的指令有哪些,可以根据实际情况添加
  126. if len(codeArray) >= longJumpSize {
  127. fmt.Printf("codeArray len: %d, codeArray: %v\n", len(codeArray), codeArray)
  128. // 打印InstCodeArray数组的长度和内容
  129. fmt.Printf("InstCodeArray len: %d, InstCodeArray: %v\n", len(InstCodeArray), InstCodeArray)
  130. break
  131. }
  132. }
  133. // 将 InstCodeArray 数组进行倒序
  134. for i := 0; i < len(InstCodeArray)/2; i++ {
  135. InstCodeArray[i], InstCodeArray[len(InstCodeArray)-1-i] = InstCodeArray[len(InstCodeArray)-1-i], InstCodeArray[i]
  136. }
  137. insertIndex := len(InstCodeArray) - 1
  138. // 如果 codeArray 的长度没有凑够 17 个字节,则从Send函数开始向下查找,继续凑够17个字节
  139. if len(codeArray) < longJumpSize {
  140. // 从Send函数开始向下查找,凑够17个字节
  141. for i := sendIndex + 1; i < len(instArray); i++ {
  142. if instArray[i].inst.Op == x86asm.CALL || instArray[i].inst.Op == x86asm.JMP || instArray[i].inst.Op == x86asm.JE || instArray[i].inst.Op == x86asm.JG || instArray[i].inst.Op == x86asm.JGE || instArray[i].inst.Op == x86asm.JL || instArray[i].inst.Op == x86asm.JLE || instArray[i].inst.Op == x86asm.JNE {
  143. break
  144. }
  145. // 将指令存入codeArray数组,指令应该从 code 中的pc开始,长度为instArray[i].len
  146. codeArray = append(codeArray, code[instArray[i].pc:instArray[i].pc+uint64(instArray[i].len)]...)
  147. InstCodeArray = append(InstCodeArray, code[instArray[i].pc:instArray[i].pc+uint64(instArray[i].len)])
  148. // 记录当前指令的地址
  149. InstEndAddr = SymAddr + instArray[i].pc + uint64(instArray[i].len)
  150. // 打印本次循环的具体指令和指令长度
  151. fmt.Printf("inst: %s\n", instArray[i].inst.String())
  152. // 打印codeArray数组的长度和内容
  153. fmt.Printf("codeArray len: %d, codeArray: %v\n", len(codeArray), codeArray)
  154. // 如果codeArray数组长度大于等于17,或者 该指令代码等于 Call、jmp、jne、je、jg、jge、jl、jle、jne、等跳转指令,跳出循环,用于跳转的指令有哪些,可以根据实际情况添加
  155. if len(codeArray) >= longJumpSize {
  156. fmt.Printf("codeArray len: %d, codeArray: %v\n", len(codeArray), codeArray)
  157. // 打印InstCodeArray数组的长度和内容
  158. fmt.Printf("InstCodeArray len: %d, InstCodeArray: %v\n", len(InstCodeArray), InstCodeArray)
  159. break
  160. }
  161. }
  162. }
  163. // 打印最终的codeArray数组的长度和内容
  164. fmt.Printf("codeArray len: %d, codeArray: %v\n", len(codeArray), codeArray)
  165. // 打印最终的InstAddr地址
  166. fmt.Printf("InstAddr: %x-%x Len: %d, InstCodeArray: %v, InstCodeArray len: %d\n", InstStartAddr, InstEndAddr, len(codeArray), InstCodeArray, len(InstCodeArray))
  167. instMemStruct := InstMemStruct{InstStartAddr: InstStartAddr, InstEndAddr: InstEndAddr, CodeArray: codeArray, InstCodeArray: InstCodeArray, InsertIndex: insertIndex}
  168. if len(codeArray) < longJumpSize {
  169. return instMemStruct, errors.New("Not enough memory for long jump")
  170. }
  171. return instMemStruct, nil
  172. }
  173. // 从 PLT 中获取函数的地址,先从改进程的 maps 中的 glibc 中获取对应的函数地址,在反过来从 plt 段中的 jmp 中计算出该函数的地址
  174. func getFunctionOffsetPLT(pid int, libPath, functionName string) (elf.Symbol, error) {
  175. // 获取进程 maps 中 libc.so 的基地址
  176. processLibcMapsInfo, err := getProcessMapsInfo(pid, "libc.so.6")
  177. // 获取 libc 中 send 函数的地址
  178. libcFunction, err := getFunctionOffset(processLibcMapsInfo.Path, functionName)
  179. // 打印 libc 中 send 函数的地址
  180. fmt.Printf("libc send function: %s, addr: %x, addr: %x\n", functionName, libcFunction.Value, libcFunction.Value+processLibcMapsInfo.Start)
  181. processMapsInfo, err := getProcessMapsInfo(pid)
  182. // 打印进程 maps 中的地址
  183. fmt.Printf("process maps: %s, addr: %x\n", processMapsInfo.Path, processMapsInfo.Start)
  184. // 解析 libPath 中的 plt 段,遍历所有的 plt 段的地址,找到 所有 jmp 指令所跳转的地址为 libc 中 send 函数的地址
  185. elfFile, err := elf.Open(libPath)
  186. if err != nil {
  187. return elf.Symbol{}, fmt.Errorf("failed to open ELF file: %v", err)
  188. }
  189. defer elfFile.Close()
  190. // 遍历 plt 段
  191. pltSection := elfFile.Section(".plt")
  192. if pltSection == nil {
  193. return elf.Symbol{}, errors.New(".plt section not found")
  194. }
  195. pltData, err := pltSection.Data()
  196. if err != nil {
  197. return elf.Symbol{}, fmt.Errorf("failed to read .plt section data: %v", err)
  198. }
  199. entrySize := 16
  200. // 计算 plt 段中一共有多少 plt 入口,使用 plt 段的大小除以 plt 入口的大小
  201. // plt 入口的数量为 len(pltData)/entrySize
  202. entryNum := len(pltData) / entrySize
  203. baseJmpAddr := uint64(0)
  204. isSetBaseJmpAddr := false
  205. // 初始化一个 map,key 为指向的地址,value 为 plt 段中的地址
  206. pltEntryMap := make(map[uint64]uint64)
  207. // 遍历 plt 段
  208. for i := 0; i < len(pltData); i += int(entrySize) {
  209. // 获取 plt 段中的地址
  210. addr := pltSection.Addr + uint64(i)
  211. // 打印 addr
  212. fmt.Printf("addr: %x\n", addr)
  213. // 获取 plt 段中的 jmp 指令
  214. jmpInst := pltData[i : i+int(entrySize)]
  215. // 解析 jmp 指令
  216. inst, err := x86asm.Decode(jmpInst, 64)
  217. if err != nil {
  218. fmt.Printf("Decode error: %v\n", err)
  219. break
  220. }
  221. // 判断 inst 是否为 jmp 指令
  222. if inst.Op == x86asm.JMP {
  223. jmpAddr := addr + uint64(inst.Len) + uint64(int32(inst.Args[0].(x86asm.Mem).Disp))
  224. pltEntryMap[jmpAddr] = addr
  225. if isSetBaseJmpAddr == false {
  226. // baseAddr = addr
  227. baseJmpAddr = jmpAddr
  228. isSetBaseJmpAddr = true
  229. }
  230. }
  231. }
  232. var data []byte
  233. data = make([]byte, 8*(entryNum+1))
  234. // 使用 ptrace attach 目标进程
  235. fmt.Printf("Attach Process: %d\n", pid)
  236. if err := syscall.PtraceAttach(pid); err != nil {
  237. fmt.Printf("PtraceAttach Err: %v\n", err)
  238. }
  239. // 等待目标进程停止
  240. if _, err := syscall.Wait4(pid, nil, 0, nil); err != nil {
  241. fmt.Printf("wait4: %v", err)
  242. }
  243. dataLen, _ := syscall.PtracePeekData(pid, uintptr(baseJmpAddr+processMapsInfo.Start), data)
  244. fmt.Printf("Ptrace DETACH: %d\n", pid)
  245. // 恢复执行
  246. if err = syscall.PtraceDetach(pid); err != nil {
  247. fmt.Printf("ptrace DETACH: %v", err)
  248. }
  249. // 按 uint64 数组遍历 data
  250. for i := 0; i < dataLen; i += 8 {
  251. // 获取 data 中的地址
  252. addr := *(*uint64)(unsafe.Pointer(&data[i]))
  253. // 打印 data 中的地址
  254. fmt.Printf("addr: 0x%x\n", addr)
  255. // 如果 data 中的地址等于 libc 中 send 函数的地址,返回该地址
  256. if processLibcMapsInfo.Start+libcFunction.Value == addr {
  257. fmt.Printf("Found %s at address: 0x%x\n", functionName, baseJmpAddr)
  258. fmt.Printf("Found %s from address: 0x%x[0x%x]\n", functionName, uint64(i)+baseJmpAddr, processMapsInfo.Start+uint64(i)+baseJmpAddr)
  259. // 判断 pltEntryMap 中是否存在 baseJmpAddr,如果存在,返回 baseJmpAddr
  260. if value, ok := pltEntryMap[uint64(i)+baseJmpAddr]; ok {
  261. fmt.Printf("Found %s at address on plt: 0x%x\n", functionName, value)
  262. return elf.Symbol{Name: functionName, Value: value, Size: 0}, nil
  263. }
  264. }
  265. }
  266. return elf.Symbol{}, fmt.Errorf("function %s not found", functionName)
  267. }
  268. // 生成一个函数,用于生成长跳指令插入,使用 r11 寄存器,使用前需要保存 r11 寄存器的值,使用后恢复 r11 寄存器的值,跳转的指令从参数传入,第二个参数为指令长度,如果指令长度不够,需要在后面添加 nop 指令
  269. func generateLongJumpCode(cwFunctionAddr uint64, length int, saveR11 bool) []byte {
  270. jumpCode := []byte{}
  271. if saveR11 {
  272. jumpCode = append(jumpCode, 0x41, 0x53) // push r11
  273. }
  274. jumpCode = append(jumpCode,
  275. 0x49, 0xbb, byte(cwFunctionAddr), byte(cwFunctionAddr>>8), byte(cwFunctionAddr>>16), byte(cwFunctionAddr>>24),
  276. byte(cwFunctionAddr>>32), byte(cwFunctionAddr>>40), byte(cwFunctionAddr>>48), byte(cwFunctionAddr>>56), // movabs r11, cwFunctionAddr
  277. 0x41, 0xff, 0xe3, // jmp r11
  278. )
  279. if saveR11 {
  280. jumpCode = append(jumpCode, 0x41, 0x5b) // pop r11
  281. }
  282. // Add NOP instructions if the length is not enough
  283. for len(jumpCode) < length {
  284. jumpCode = append(jumpCode, 0x90) // NOP instruction
  285. }
  286. return jumpCode
  287. }
  288. // 生成新指令,用于替换原指令,一个参数是原来的指令 instCodeArray,另一个参数是跳转回去的地址,第三个参数是原始 Send 函数的地址,第四个参数是在哪个指令前插入自定义指令
  289. func generateNewCode(instCodeArray [][]byte, jumpBackAddr uint64, sendAddr uint64, insertIndex int) ([]byte, int) {
  290. newCode := []byte{}
  291. hookOffset := 0
  292. for i, inst := range instCodeArray {
  293. if i == insertIndex {
  294. /**
  295. 约定从栈顶拿8个字节储存 ebpf 中配置的 header 长度:先备份 rdx 到 -0x08(%rsp)
  296. 待 ebpf 将值写到 -0x08(%rsp) 之后,再从 -0x08(%rsp) 赋值到 rdx
  297. mov %rdx, -0x08(%rsp)
  298. mov -0x08(%rsp), %rdx
  299. **/
  300. newCode = append(newCode, 0x48, 0x89, 0x54, 0x24, 0xf8) // mov %rdx, -0x08(%rsp)
  301. hookOffset = len(newCode)
  302. newCode = append(newCode, 0x90) // nop
  303. newCode = append(newCode, 0x48, 0x8b, 0x54, 0x24, 0xf8) // mov -0x08(%rsp), %rdx
  304. /**
  305. 生成 Call 指令利用 r11 寄存器进程长调用调用原来的 Send 函数
  306. **/
  307. CallSendCode := []byte{
  308. 0x41, 0x53, // push r11
  309. 0x49, 0xbb, byte(sendAddr), byte(sendAddr >> 8), byte(sendAddr >> 16), byte(sendAddr >> 24), byte(sendAddr >> 32), byte(sendAddr >> 40), byte(sendAddr >> 48), byte(sendAddr >> 56), // movabs r11, sendAddr
  310. 0x41, 0xff, 0xd3, // call r11
  311. 0x41, 0x5b, // pop r11
  312. }
  313. newCode = append(newCode, CallSendCode...)
  314. // 生成长跳指令,跳转回去
  315. newCode = append(newCode, generateLongJumpCode(jumpBackAddr, longJumpSize, false)...)
  316. continue
  317. }
  318. newCode = append(newCode, inst...)
  319. }
  320. return newCode, hookOffset
  321. }
  322. // 查找nop函数中hookOffset的位置
  323. func findNopFunctionHookOffset(pid int, cwSym *ProcessFunctionInfo) (int, error) {
  324. // 定义一个字节数组内容:0x48, 0x8b, 0x54, 0x24, 0xf8
  325. nopCode := []byte{0x48, 0x8b, 0x54, 0x24, 0xf8}
  326. code, err := readMemory(pid, cwSym.Start+nopEntryOffset, cwSym.Size-nopEntryOffset)
  327. if err != nil {
  328. fmt.Printf("readMemory error: %v\n", err)
  329. return 0, err
  330. }
  331. fmt.Printf("nopCode: %v\n", code)
  332. pc := uint64(0)
  333. for pc < uint64(len(code)) {
  334. inst, err := x86asm.Decode(code[pc:], 64)
  335. if err != nil {
  336. fmt.Printf("Decode error: %v\n", err)
  337. break
  338. }
  339. // 打印对应指令的字节码
  340. fmt.Printf("%s[%d]-[%v]\n", inst.String(), pc, code[pc:pc+uint64(inst.Len)])
  341. // 如果指令的字节码等于nopCode,返回当前的pc
  342. if string(code[pc:pc+uint64(inst.Len)]) == string(nopCode) {
  343. return int(pc) - nopLen, nil
  344. }
  345. pc += uint64(inst.Len)
  346. }
  347. return 0, errors.New("hookOffset not found")
  348. }