inject_asm_code_arm64.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394
  1. //go:build linux && arm64
  2. // +build linux,arm64
  3. package aotinject
  4. /*
  5. #cgo CFLAGS: -I ../inject/include
  6. #cgo amd64 LDFLAGS: ${SRCDIR}/../inject/lib/libhotpatch_amd64.a
  7. #cgo arm64 LDFLAGS: ${SRCDIR}/../inject/lib/libhotpatch_arm64.a
  8. #include "hotpatch.h"
  9. #include <stdlib.h>
  10. */
  11. import "C"
  12. import (
  13. "debug/elf"
  14. "errors"
  15. "fmt"
  16. "syscall"
  17. "unsafe"
  18. "golang.org/x/arch/arm64/arm64asm"
  19. "golang.org/x/arch/x86/x86asm"
  20. )
  21. var PID string
  22. // 定义一个全局变量,类型为 uint64,值为 4
  23. const instLen = 4
  24. const jumpBackAddrOffset = 24
  25. const nopEntryOffset = 12
  26. const nopLen = 4
  27. const longJumpSize = 28
  28. type ProcessMapsInfo struct {
  29. Start, End uint64
  30. Path string
  31. }
  32. type ProcessFunctionInfo struct {
  33. Name string
  34. Offset uint64
  35. Start, Size uint64
  36. }
  37. type InstMemStruct struct {
  38. InstStartAddr uint64
  39. InstEndAddr uint64
  40. CodeArray []byte
  41. InstCodeArray [][]byte
  42. InsertIndex int
  43. }
  44. // 解析code中的指令,并打印每一条指令
  45. func parseAndPrintInstructions(code []byte) {
  46. pc := uint64(0)
  47. for pc < uint64(len(code)) {
  48. inst, err := arm64asm.Decode(code[pc:])
  49. if err != nil {
  50. fmt.Printf("Decode error: %v\n", err)
  51. break
  52. }
  53. fmt.Printf("0x%x:[%d] %s\n", pc, pc, inst.String())
  54. pc += instLen
  55. }
  56. }
  57. func findSendFunctionNameAddr(SymAddr uint64, code []byte, sendFunctionName string, sendSymAddr uint64) (uint64, uint64, error) {
  58. pc := uint64(0)
  59. for pc < uint64(len(code)) {
  60. inst, err := arm64asm.Decode(code[pc:])
  61. if err != nil {
  62. fmt.Printf("Decode error: %v\n", err)
  63. break
  64. }
  65. fmt.Printf("0x%x:[%d] %s, %v\n", SymAddr+pc, pc, inst.String(), inst.Args[0].String())
  66. if inst.Op == arm64asm.BL {
  67. if v, ok := inst.Args[0].(arm64asm.PCRel); ok {
  68. fmt.Printf("%s\n", inst.Args[0].String())
  69. if SymAddr+pc+uint64(v) == sendSymAddr {
  70. fmt.Printf("Found send function call at 0x%x\n", SymAddr+pc)
  71. return SymAddr + pc, pc, nil
  72. }
  73. }
  74. }
  75. pc += instLen
  76. }
  77. return 0, 0, errors.New("Send function Offset not found")
  78. }
  79. // 获取 Send 函数附近的内存数据,用于长跳转到 libmylib 中,需要凑够 17 个字节,先向上查找,再向下查找
  80. func findMemForLongJump(SymAddr uint64, code []byte, sendFunctionName string, sendSymAddr uint64) (InstMemStruct, error) {
  81. // 定义一个结构体,储存 pc、指令、和指令长度
  82. type InstStruct struct {
  83. pc uint64
  84. inst arm64asm.Inst
  85. len int
  86. }
  87. // 定义个字节数组用于储存遍历出来的指令,字节数组
  88. var codeArray []byte
  89. // 定义一个字节数组的数组,用于储存遍历出来的每个指令的字节内容
  90. var InstCodeArray [][]byte
  91. // 定义一个InstStruct数组,用于储存遍历出来的指令
  92. var instArray []InstStruct
  93. // 定义一个标志记录记录找到的Send函数在instArray数组的位置
  94. var sendIndex int
  95. pc := uint64(0)
  96. for pc < uint64(len(code)) {
  97. inst, err := arm64asm.Decode(code[pc:])
  98. if err != nil {
  99. fmt.Printf("Decode error: %v\n", err)
  100. break
  101. }
  102. // 将指令和指令长度存入InstStruct结构体中
  103. instArray = append(instArray, InstStruct{pc: pc, inst: inst, len: instLen})
  104. // 如果地址等于Send函数的地址,记录下Send函数在instArray数组的位置
  105. if SymAddr+pc == sendSymAddr {
  106. sendIndex = len(instArray) - 1
  107. }
  108. pc += instLen
  109. }
  110. // 打印 sendIndex 和 instArray 的长度和 Send 函数的指令
  111. fmt.Printf("sendIndex: %d, instArray len: %d, Send Function: %s\n", sendIndex, len(instArray), instArray[sendIndex].inst.String())
  112. InstStartAddr := uint64(0)
  113. InstEndAddr := uint64(sendSymAddr) + 4
  114. // 从Send函数开始向上查找,凑够17个字节
  115. for i := sendIndex; i >= 0; i-- {
  116. if i != sendIndex && instArray[i].inst.Op == arm64asm.BL || instArray[i].inst.Op == arm64asm.B || instArray[i].inst.Op == arm64asm.BLR || instArray[i].inst.Op == arm64asm.CBNZ || instArray[i].inst.Op == arm64asm.CBZ || instArray[i].inst.Op == arm64asm.RET || instArray[i].inst.Op == arm64asm.TBZ || instArray[i].inst.Op == arm64asm.TBNZ {
  117. break
  118. }
  119. // 将指令存入codeArray数组,指令应该从 code 中的pc开始,长度为instArray[i].len
  120. codeArray = append(codeArray, code[instArray[i].pc:instArray[i].pc+uint64(instArray[i].len)]...)
  121. InstCodeArray = append(InstCodeArray, code[instArray[i].pc:instArray[i].pc+uint64(instArray[i].len)])
  122. // 记录当前指令的地址
  123. InstStartAddr = SymAddr + instArray[i].pc
  124. // 打印本次循环的具体指令和指令长度
  125. fmt.Printf("inst: %s\n", instArray[i].inst.String())
  126. // 打印codeArray数组的长度和内容
  127. fmt.Printf("codeArray len: %d, codeArray: %v\n", len(codeArray), codeArray)
  128. // 如果codeArray数组长度大于等于17,或者 该指令代码等于 Call、jmp、jne、je、jg、jge、jl、jle、jne、等跳转指令,跳出循环,用于跳转的指令有哪些,可以根据实际情况添加
  129. if len(codeArray) >= longJumpSize {
  130. fmt.Printf("codeArray len: %d, codeArray: %v\n", len(codeArray), codeArray)
  131. // 打印InstCodeArray数组的长度和内容
  132. fmt.Printf("InstCodeArray len: %d, InstCodeArray: %v\n", len(InstCodeArray), InstCodeArray)
  133. break
  134. }
  135. }
  136. // 将 InstCodeArray 数组进行倒序
  137. for i := 0; i < len(InstCodeArray)/2; i++ {
  138. InstCodeArray[i], InstCodeArray[len(InstCodeArray)-1-i] = InstCodeArray[len(InstCodeArray)-1-i], InstCodeArray[i]
  139. }
  140. insertIndex := len(InstCodeArray) - 1
  141. // 如果 codeArray 的长度没有凑够 17 个字节,则从Send函数开始向下查找,继续凑够17个字节
  142. if len(codeArray) < longJumpSize {
  143. // 从Send函数开始向下查找,凑够17个字节
  144. for i := sendIndex + 1; i < len(instArray); i++ {
  145. if instArray[i].inst.Op == arm64asm.BL || instArray[i].inst.Op == arm64asm.B || instArray[i].inst.Op == arm64asm.BLR || instArray[i].inst.Op == arm64asm.CBNZ || instArray[i].inst.Op == arm64asm.CBZ || instArray[i].inst.Op == arm64asm.RET || instArray[i].inst.Op == arm64asm.TBZ || instArray[i].inst.Op == arm64asm.TBNZ {
  146. break
  147. }
  148. // 将指令存入codeArray数组,指令应该从 code 中的pc开始,长度为instArray[i].len
  149. codeArray = append(codeArray, code[instArray[i].pc:instArray[i].pc+uint64(instArray[i].len)]...)
  150. InstCodeArray = append(InstCodeArray, code[instArray[i].pc:instArray[i].pc+uint64(instArray[i].len)])
  151. // 记录当前指令的地址
  152. InstEndAddr = SymAddr + instArray[i].pc + uint64(instArray[i].len)
  153. // 打印本次循环的具体指令和指令长度
  154. fmt.Printf("inst: %s\n", instArray[i].inst.String())
  155. // 打印codeArray数组的长度和内容
  156. fmt.Printf("codeArray len: %d, codeArray: %v\n", len(codeArray), codeArray)
  157. // 如果codeArray数组长度大于等于17,或者 该指令代码等于 Call、jmp、jne、je、jg、jge、jl、jle、jne、等跳转指令,跳出循环,用于跳转的指令有哪些,可以根据实际情况添加
  158. if len(codeArray) >= longJumpSize {
  159. fmt.Printf("codeArray len: %d, codeArray: %v\n", len(codeArray), codeArray)
  160. // 打印InstCodeArray数组的长度和内容
  161. fmt.Printf("InstCodeArray len: %d, InstCodeArray: %v\n", len(InstCodeArray), InstCodeArray)
  162. break
  163. }
  164. }
  165. }
  166. // 打印最终的codeArray数组的长度和内容
  167. fmt.Printf("codeArray len: %d, codeArray: %v\n", len(codeArray), codeArray)
  168. // 打印最终的InstAddr地址
  169. fmt.Printf("InstAddr: %x-%x Len: %d, InstCodeArray: %v, InstCodeArray len: %d\n", InstStartAddr, InstEndAddr, len(codeArray), InstCodeArray, len(InstCodeArray))
  170. instMemStruct := InstMemStruct{InstStartAddr: InstStartAddr, InstEndAddr: InstEndAddr, CodeArray: codeArray, InstCodeArray: InstCodeArray, InsertIndex: insertIndex}
  171. if len(codeArray) < longJumpSize {
  172. return instMemStruct, errors.New("Not enough memory for long jump")
  173. }
  174. return instMemStruct, nil
  175. }
  176. // 从 PLT 中获取函数的地址,先从改进程的 maps 中的 glibc 中获取对应的函数地址,在反过来从 plt 段中的 jmp 中计算出该函数的地址
  177. func getFunctionOffsetPLT(pid int, libPath, functionName string) (elf.Symbol, error) {
  178. // 获取进程 maps 中 libc.so 的基地址
  179. processLibcMapsInfo, err := getProcessMapsInfo(pid, "libc.so.6")
  180. // 获取 libc 中 send 函数的地址
  181. libcFunction, err := getFunctionOffset(processLibcMapsInfo.Path, functionName)
  182. // 打印 libc 中 send 函数的地址
  183. fmt.Printf("libc send function: %s, addr: %x, addr: %x\n", functionName, libcFunction.Value, libcFunction.Value+processLibcMapsInfo.Start)
  184. processMapsInfo, err := getProcessMapsInfo(pid)
  185. // 打印进程 maps 中的地址
  186. fmt.Printf("process maps: %s, addr: %x\n", processMapsInfo.Path, processMapsInfo.Start)
  187. // 解析 libPath 中的 plt 段,遍历所有的 plt 段的地址,找到 所有 jmp 指令所跳转的地址为 libc 中 send 函数的地址
  188. elfFile, err := elf.Open(libPath)
  189. if err != nil {
  190. return elf.Symbol{}, fmt.Errorf("failed to open ELF file: %v", err)
  191. }
  192. defer elfFile.Close()
  193. // 遍历 plt 段
  194. pltSection := elfFile.Section(".plt")
  195. if pltSection == nil {
  196. return elf.Symbol{}, errors.New(".plt section not found")
  197. }
  198. pltData, err := pltSection.Data()
  199. if err != nil {
  200. return elf.Symbol{}, fmt.Errorf("failed to read .plt section data: %v", err)
  201. }
  202. entrySize := 16
  203. // 计算 plt 段中一共有多少 plt 入口,使用 plt 段的大小除以 plt 入口的大小
  204. // plt 入口的数量为 len(pltData)/entrySize
  205. entryNum := len(pltData) / entrySize
  206. baseJmpAddr := uint64(0)
  207. isSetBaseJmpAddr := false
  208. // 初始化一个 map,key 为指向的地址,value 为 plt 段中的地址
  209. pltEntryMap := make(map[uint64]uint64)
  210. // 遍历 plt 段
  211. for i := 0; i < len(pltData); i += int(entrySize) {
  212. // 获取 plt 段中的地址
  213. addr := pltSection.Addr + uint64(i)
  214. // 打印 addr
  215. fmt.Printf("addr: %x\n", addr)
  216. // 获取 plt 段中的 jmp 指令
  217. jmpInst := pltData[i : i+int(entrySize)]
  218. // 解析 jmp 指令
  219. inst, err := x86asm.Decode(jmpInst, 64)
  220. if err != nil {
  221. fmt.Printf("Decode error: %v\n", err)
  222. break
  223. }
  224. // 判断 inst 是否为 jmp 指令
  225. if inst.Op == x86asm.JMP {
  226. jmpAddr := addr + instLen + uint64(int32(inst.Args[0].(x86asm.Mem).Disp))
  227. pltEntryMap[jmpAddr] = addr
  228. if isSetBaseJmpAddr == false {
  229. // baseAddr = addr
  230. baseJmpAddr = jmpAddr
  231. isSetBaseJmpAddr = true
  232. }
  233. }
  234. }
  235. var data []byte
  236. data = make([]byte, 8*(entryNum+1))
  237. // 使用 ptrace attach 目标进程
  238. fmt.Printf("Attach Process: %d\n", pid)
  239. if err := syscall.PtraceAttach(pid); err != nil {
  240. fmt.Printf("PtraceAttach Err: %v\n", err)
  241. }
  242. // 等待目标进程停止
  243. if _, err := syscall.Wait4(pid, nil, 0, nil); err != nil {
  244. fmt.Printf("wait4: %v", err)
  245. }
  246. dataLen, _ := syscall.PtracePeekData(pid, uintptr(baseJmpAddr+processMapsInfo.Start), data)
  247. fmt.Printf("Ptrace DETACH: %d\n", pid)
  248. // 恢复执行
  249. if err = syscall.PtraceDetach(pid); err != nil {
  250. fmt.Printf("ptrace DETACH: %v", err)
  251. }
  252. // 按 uint64 数组遍历 data
  253. for i := 0; i < dataLen; i += 8 {
  254. // 获取 data 中的地址
  255. addr := *(*uint64)(unsafe.Pointer(&data[i]))
  256. // 打印 data 中的地址
  257. fmt.Printf("addr: 0x%x\n", addr)
  258. // 如果 data 中的地址等于 libc 中 send 函数的地址,返回该地址
  259. if processLibcMapsInfo.Start+libcFunction.Value == addr {
  260. fmt.Printf("Found %s at address: 0x%x\n", functionName, baseJmpAddr)
  261. fmt.Printf("Found %s from address: 0x%x[0x%x]\n", functionName, uint64(i)+baseJmpAddr, processMapsInfo.Start+uint64(i)+baseJmpAddr)
  262. // 判断 pltEntryMap 中是否存在 baseJmpAddr,如果存在,返回 baseJmpAddr
  263. if value, ok := pltEntryMap[uint64(i)+baseJmpAddr]; ok {
  264. fmt.Printf("Found %s at address on plt: 0x%x\n", functionName, value)
  265. return elf.Symbol{Name: functionName, Value: value, Size: 0}, nil
  266. }
  267. }
  268. }
  269. return elf.Symbol{}, fmt.Errorf("function %s not found", functionName)
  270. }
  271. func buildArm64Enc(imm16 uint64, hw uint64) []byte {
  272. PC_START := uint64(56)
  273. rd := 0x10 // x16
  274. fixed := 0x1e5
  275. // Combine all parts into a single 32-bit value
  276. result := uint64(rd<<0) | (imm16 << 5) | (hw << 21) | uint64(fixed<<23)
  277. fmt.Printf("set *(unsigned int*)($origin+%d) = 0x%08x\n", PC_START+hw*4, result)
  278. return []byte{
  279. byte(result),
  280. byte(result >> 8),
  281. byte(result >> 16),
  282. byte(result >> 24),
  283. }
  284. }
  285. // 生成一个函数,用于生成长跳指令插入,使用 x16 寄存器,使用前需要保存 x16 寄存器的值,使用后恢复 x16 寄存器的值,跳转的指令从参数传入,第二个参数为指令长度,如果指令长度不够,需要在后面添加 nop 指令
  286. func generateLongJumpCode(cwFunctionAddr uint64, length int, saveX16 bool) []byte {
  287. jumpCode := []byte{}
  288. if saveX16 {
  289. jumpCode = append(jumpCode, 0xFF, 0x43, 0x00, 0xD1) // sub sp, sp, #16
  290. jumpCode = append(jumpCode, 0xF0, 0x03, 0x00, 0xF9) // str x16, [sp]
  291. }
  292. jumpCode = append(jumpCode, buildArm64Enc(cwFunctionAddr&0xFFFF, 0)...) // movz x16, (cwFunctionAddr & 0xFFFF)
  293. jumpCode = append(jumpCode, buildArm64Enc((cwFunctionAddr>>16)&0xFFFF, 1)...) // movk x16, ((cwFunctionAddr >> 16) & 0xFFFF) LSL #16
  294. jumpCode = append(jumpCode, buildArm64Enc((cwFunctionAddr>>32)&0xFFFF, 2)...) // movk x16, ((cwFunctionAddr >> 16) & 0xFFFF) LSL #16
  295. // jumpCode = append(jumpCode, buildArm64Enc((cwFunctionAddr >> 48) & 0xFFFF, 3)...) // movk x16, ((cwFunctionAddr >> 16) & 0xFFFF) LSL #16
  296. jumpCode = append(jumpCode, 0x00, 0x02, 0x1F, 0xD6) // br x16
  297. if saveX16 {
  298. jumpCode = append(jumpCode, 0xF0, 0x03, 0x40, 0xF9) // ldr x16, [sp]
  299. jumpCode = append(jumpCode, 0xFF, 0x43, 0x00, 0x91) // add sp, sp, #16
  300. }
  301. // Add NOP instructions if the length is not enough
  302. for len(jumpCode) < length {
  303. jumpCode = append(jumpCode, 0x1F, 0x20, 0x03, 0xD5) // NOP instruction
  304. }
  305. return jumpCode
  306. }
  307. // 生成新指令,用于替换原指令,一个参数是原来的指令 instCodeArray,另一个参数是跳转回去的地址,第三个参数是原始 Send 函数的地址,第四个参数是在哪个指令前插入自定义指令
  308. func generateNewCode(instCodeArray [][]byte, jumpBackAddr uint64, sendAddr uint64, insertIndex int) ([]byte, int) {
  309. newCode := []byte{}
  310. hookOffset := 0
  311. for i, inst := range instCodeArray {
  312. if i == insertIndex {
  313. /**
  314. 约定从栈顶拿8个字节储存 ebpf 中配置的 header 长度:先备份 x2 到 -0x08(%sp)
  315. 待 ebpf 将值写到 -0x08(%sp) 之后,再从 -0x08(%sp) 赋值到 x2
  316. str x2, [sp, #-8]!
  317. ldr x2, [sp], #8
  318. **/
  319. newCode = append(newCode, 0xE2, 0x83, 0x1F, 0xF8) // str x2, [sp, #-8]
  320. hookOffset = len(newCode)
  321. newCode = append(newCode, 0x1F, 0x20, 0x03, 0xD5) // NOP instruction
  322. newCode = append(newCode, 0xE2, 0x83, 0x5F, 0xF8) // ldr x2, [sp, #-8]
  323. /**
  324. 生成 Call 指令利用 x16 寄存器进程长调用调用原来的 Send 函数
  325. **/
  326. newCode = append(newCode, buildArm64Enc(sendAddr&0xFFFF, 0)...) // movz x16, (sendAddr & 0xFFFF)
  327. newCode = append(newCode, buildArm64Enc((sendAddr>>16)&0xFFFF, 1)...) // movk x16, ((sendAddr >> 16) & 0xFFFF) LSL #16
  328. newCode = append(newCode, buildArm64Enc((sendAddr>>32)&0xFFFF, 2)...) // movk x16, ((sendAddr >> 16) & 0xFFFF) LSL #16
  329. newCode = append(newCode, buildArm64Enc((sendAddr>>48)&0xFFFF, 3)...) // movk x16, ((sendAddr >> 16) & 0xFFFF) LSL #16
  330. newCode = append(newCode, 0x00, 0x02, 0x3F, 0xD6) // blr x16
  331. // 生成长跳指令,跳转回去
  332. newCode = append(newCode, generateLongJumpCode(jumpBackAddr, longJumpSize, false)...)
  333. continue
  334. }
  335. newCode = append(newCode, inst...)
  336. }
  337. return newCode, hookOffset
  338. }
  339. // 查找nop函数中hookOffset的位置
  340. func findNopFunctionHookOffset(pid int, cwSym *ProcessFunctionInfo) (int, error) {
  341. // 定义一个字节数组内容:0x48, 0x8b, 0x54, 0x24, 0xf8
  342. nopCode := []byte{0xE2, 0x83, 0x5F, 0xF8}
  343. code, err := readMemory(pid, cwSym.Start+nopEntryOffset, cwSym.Size-nopEntryOffset)
  344. if err != nil {
  345. fmt.Printf("readMemory error: %v\n", err)
  346. return 0, err
  347. }
  348. fmt.Printf("nopCode: %v\n", code)
  349. pc := uint64(0)
  350. for pc < uint64(len(code)) {
  351. inst, err := arm64asm.Decode(code[pc:])
  352. if err != nil {
  353. fmt.Printf("Decode error: %v\n", err)
  354. break
  355. }
  356. // 打印对应指令的字节码
  357. fmt.Printf("%s[%d]-[%v]\n", inst.String(), pc, code[pc:pc+instLen])
  358. // 如果指令的字节码等于nopCode,返回当前的pc
  359. if string(code[pc:pc+instLen]) == string(nopCode) {
  360. return int(pc) - nopLen, nil
  361. }
  362. pc += instLen
  363. }
  364. return 0, errors.New("hookOffset not found")
  365. }