inject_asm_code_arm64.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. //go:build linux && arm64
  2. // +build linux,arm64
  3. package aotinject
  4. /*
  5. #cgo CFLAGS: -I ../inject/include
  6. #cgo amd64 LDFLAGS: ${SRCDIR}/../inject/lib/libhotpatch_amd64.a
  7. #cgo arm64 LDFLAGS: ${SRCDIR}/../inject/lib/libhotpatch_arm64.a
  8. #include "hotpatch.h"
  9. #include <stdlib.h>
  10. */
  11. import "C"
  12. import (
  13. "debug/elf"
  14. "errors"
  15. "fmt"
  16. "syscall"
  17. "unsafe"
  18. "golang.org/x/arch/arm64/arm64asm"
  19. "golang.org/x/arch/x86/x86asm"
  20. )
  21. var PID string
  22. // 定义一个全局变量,类型为 uint64,值为 4
  23. const instLen = 4
  24. const jumpBackAddrOffset = 24
  25. const nopEntryOffset = 12
  26. const nopLen = 4
  27. const longJumpSize = 28
  28. type ProcessMapsInfo struct {
  29. Start, End uint64
  30. Path string
  31. }
  32. type ProcessFunctionInfo struct {
  33. Name string
  34. Start, Size uint64
  35. }
  36. type InstMemStruct struct {
  37. InstStartAddr uint64
  38. InstEndAddr uint64
  39. CodeArray []byte
  40. InstCodeArray [][]byte
  41. InsertIndex int
  42. }
  43. // 解析code中的指令,并打印每一条指令
  44. func parseAndPrintInstructions(code []byte) {
  45. pc := uint64(0)
  46. for pc < uint64(len(code)) {
  47. inst, err := arm64asm.Decode(code[pc:])
  48. if err != nil {
  49. fmt.Printf("Decode error: %v\n", err)
  50. break
  51. }
  52. fmt.Printf("0x%x:[%d] %s\n", pc, pc, inst.String())
  53. pc += instLen
  54. }
  55. }
  56. func findSendFunctionNameAddr(SymAddr uint64, code []byte, sendFunctionName string, sendSymAddr uint64) (uint64, uint64, error) {
  57. pc := uint64(0)
  58. for pc < uint64(len(code)) {
  59. inst, err := arm64asm.Decode(code[pc:])
  60. if err != nil {
  61. fmt.Printf("Decode error: %v\n", err)
  62. break
  63. }
  64. fmt.Printf("0x%x:[%d] %s, %v\n", SymAddr+pc, pc, inst.String(), inst.Args[0].String())
  65. if inst.Op == arm64asm.BL {
  66. if v, ok := inst.Args[0].(arm64asm.PCRel); ok {
  67. fmt.Printf("%s\n", inst.Args[0].String())
  68. if SymAddr+pc+uint64(v) == sendSymAddr {
  69. fmt.Printf("Found send function call at 0x%x\n", SymAddr+pc)
  70. return SymAddr + pc, pc, nil
  71. }
  72. }
  73. }
  74. pc += instLen
  75. }
  76. return 0, 0, errors.New("Send function Offset not found")
  77. }
  78. // 获取 Send 函数附近的内存数据,用于长跳转到 libmylib 中,需要凑够 17 个字节,先向上查找,再向下查找
  79. func findMemForLongJump(SymAddr uint64, code []byte, sendFunctionName string, sendSymAddr uint64) (InstMemStruct, error) {
  80. // 定义一个结构体,储存 pc、指令、和指令长度
  81. type InstStruct struct {
  82. pc uint64
  83. inst arm64asm.Inst
  84. len int
  85. }
  86. // 定义个字节数组用于储存遍历出来的指令,字节数组
  87. var codeArray []byte
  88. // 定义一个字节数组的数组,用于储存遍历出来的每个指令的字节内容
  89. var InstCodeArray [][]byte
  90. // 定义一个InstStruct数组,用于储存遍历出来的指令
  91. var instArray []InstStruct
  92. // 定义一个标志记录记录找到的Send函数在instArray数组的位置
  93. var sendIndex int
  94. pc := uint64(0)
  95. for pc < uint64(len(code)) {
  96. inst, err := arm64asm.Decode(code[pc:])
  97. if err != nil {
  98. fmt.Printf("Decode error: %v\n", err)
  99. break
  100. }
  101. // 将指令和指令长度存入InstStruct结构体中
  102. instArray = append(instArray, InstStruct{pc: pc, inst: inst, len: instLen})
  103. // 如果地址等于Send函数的地址,记录下Send函数在instArray数组的位置
  104. if SymAddr+pc == sendSymAddr {
  105. sendIndex = len(instArray) - 1
  106. }
  107. pc += instLen
  108. }
  109. // 打印 sendIndex 和 instArray 的长度和 Send 函数的指令
  110. fmt.Printf("sendIndex: %d, instArray len: %d, Send Function: %s\n", sendIndex, len(instArray), instArray[sendIndex].inst.String())
  111. InstStartAddr := uint64(0)
  112. InstEndAddr := uint64(sendSymAddr) + 4
  113. // 从Send函数开始向上查找,凑够17个字节
  114. for i := sendIndex; i >= 0; i-- {
  115. if i != sendIndex && instArray[i].inst.Op == arm64asm.BL || instArray[i].inst.Op == arm64asm.B || instArray[i].inst.Op == arm64asm.BLR || instArray[i].inst.Op == arm64asm.CBNZ || instArray[i].inst.Op == arm64asm.CBZ || instArray[i].inst.Op == arm64asm.RET || instArray[i].inst.Op == arm64asm.TBZ || instArray[i].inst.Op == arm64asm.TBNZ {
  116. break
  117. }
  118. // 将指令存入codeArray数组,指令应该从 code 中的pc开始,长度为instArray[i].len
  119. codeArray = append(codeArray, code[instArray[i].pc:instArray[i].pc+uint64(instArray[i].len)]...)
  120. InstCodeArray = append(InstCodeArray, code[instArray[i].pc:instArray[i].pc+uint64(instArray[i].len)])
  121. // 记录当前指令的地址
  122. InstStartAddr = SymAddr + instArray[i].pc
  123. // 打印本次循环的具体指令和指令长度
  124. fmt.Printf("inst: %s\n", instArray[i].inst.String())
  125. // 打印codeArray数组的长度和内容
  126. fmt.Printf("codeArray len: %d, codeArray: %v\n", len(codeArray), codeArray)
  127. // 如果codeArray数组长度大于等于17,或者 该指令代码等于 Call、jmp、jne、je、jg、jge、jl、jle、jne、等跳转指令,跳出循环,用于跳转的指令有哪些,可以根据实际情况添加
  128. if len(codeArray) >= longJumpSize {
  129. fmt.Printf("codeArray len: %d, codeArray: %v\n", len(codeArray), codeArray)
  130. // 打印InstCodeArray数组的长度和内容
  131. fmt.Printf("InstCodeArray len: %d, InstCodeArray: %v\n", len(InstCodeArray), InstCodeArray)
  132. break
  133. }
  134. }
  135. // 将 InstCodeArray 数组进行倒序
  136. for i := 0; i < len(InstCodeArray)/2; i++ {
  137. InstCodeArray[i], InstCodeArray[len(InstCodeArray)-1-i] = InstCodeArray[len(InstCodeArray)-1-i], InstCodeArray[i]
  138. }
  139. insertIndex := len(InstCodeArray) - 1
  140. // 如果 codeArray 的长度没有凑够 17 个字节,则从Send函数开始向下查找,继续凑够17个字节
  141. if len(codeArray) < longJumpSize {
  142. // 从Send函数开始向下查找,凑够17个字节
  143. for i := sendIndex + 1; i < len(instArray); i++ {
  144. if instArray[i].inst.Op == arm64asm.BL || instArray[i].inst.Op == arm64asm.B || instArray[i].inst.Op == arm64asm.BLR || instArray[i].inst.Op == arm64asm.CBNZ || instArray[i].inst.Op == arm64asm.CBZ || instArray[i].inst.Op == arm64asm.RET || instArray[i].inst.Op == arm64asm.TBZ || instArray[i].inst.Op == arm64asm.TBNZ {
  145. break
  146. }
  147. // 将指令存入codeArray数组,指令应该从 code 中的pc开始,长度为instArray[i].len
  148. codeArray = append(codeArray, code[instArray[i].pc:instArray[i].pc+uint64(instArray[i].len)]...)
  149. InstCodeArray = append(InstCodeArray, code[instArray[i].pc:instArray[i].pc+uint64(instArray[i].len)])
  150. // 记录当前指令的地址
  151. InstEndAddr = SymAddr + instArray[i].pc + uint64(instArray[i].len)
  152. // 打印本次循环的具体指令和指令长度
  153. fmt.Printf("inst: %s\n", instArray[i].inst.String())
  154. // 打印codeArray数组的长度和内容
  155. fmt.Printf("codeArray len: %d, codeArray: %v\n", len(codeArray), codeArray)
  156. // 如果codeArray数组长度大于等于17,或者 该指令代码等于 Call、jmp、jne、je、jg、jge、jl、jle、jne、等跳转指令,跳出循环,用于跳转的指令有哪些,可以根据实际情况添加
  157. if len(codeArray) >= longJumpSize {
  158. fmt.Printf("codeArray len: %d, codeArray: %v\n", len(codeArray), codeArray)
  159. // 打印InstCodeArray数组的长度和内容
  160. fmt.Printf("InstCodeArray len: %d, InstCodeArray: %v\n", len(InstCodeArray), InstCodeArray)
  161. break
  162. }
  163. }
  164. }
  165. // 打印最终的codeArray数组的长度和内容
  166. fmt.Printf("codeArray len: %d, codeArray: %v\n", len(codeArray), codeArray)
  167. // 打印最终的InstAddr地址
  168. fmt.Printf("InstAddr: %x-%x Len: %d, InstCodeArray: %v, InstCodeArray len: %d\n", InstStartAddr, InstEndAddr, len(codeArray), InstCodeArray, len(InstCodeArray))
  169. instMemStruct := InstMemStruct{InstStartAddr: InstStartAddr, InstEndAddr: InstEndAddr, CodeArray: codeArray, InstCodeArray: InstCodeArray, InsertIndex: insertIndex}
  170. if len(codeArray) < longJumpSize {
  171. return instMemStruct, errors.New("Not enough memory for long jump")
  172. }
  173. return instMemStruct, nil
  174. }
  175. // 从 PLT 中获取函数的地址,先从改进程的 maps 中的 glibc 中获取对应的函数地址,在反过来从 plt 段中的 jmp 中计算出该函数的地址
  176. func getFunctionOffsetPLT(pid int, libPath, functionName string) (elf.Symbol, error) {
  177. // 获取进程 maps 中 libc.so 的基地址
  178. processLibcMapsInfo, err := getProcessMapsInfo(pid, "libc.so.6")
  179. // 获取 libc 中 send 函数的地址
  180. libcFunction, err := getFunctionOffset(processLibcMapsInfo.Path, functionName)
  181. // 打印 libc 中 send 函数的地址
  182. fmt.Printf("libc send function: %s, addr: %x, addr: %x\n", functionName, libcFunction.Value, libcFunction.Value+processLibcMapsInfo.Start)
  183. processMapsInfo, err := getProcessMapsInfo(pid)
  184. // 打印进程 maps 中的地址
  185. fmt.Printf("process maps: %s, addr: %x\n", processMapsInfo.Path, processMapsInfo.Start)
  186. // 解析 libPath 中的 plt 段,遍历所有的 plt 段的地址,找到 所有 jmp 指令所跳转的地址为 libc 中 send 函数的地址
  187. elfFile, err := elf.Open(libPath)
  188. if err != nil {
  189. return elf.Symbol{}, fmt.Errorf("failed to open ELF file: %v", err)
  190. }
  191. defer elfFile.Close()
  192. // 遍历 plt 段
  193. pltSection := elfFile.Section(".plt")
  194. if pltSection == nil {
  195. return elf.Symbol{}, errors.New(".plt section not found")
  196. }
  197. pltData, err := pltSection.Data()
  198. if err != nil {
  199. return elf.Symbol{}, fmt.Errorf("failed to read .plt section data: %v", err)
  200. }
  201. entrySize := 16
  202. // 计算 plt 段中一共有多少 plt 入口,使用 plt 段的大小除以 plt 入口的大小
  203. // plt 入口的数量为 len(pltData)/entrySize
  204. entryNum := len(pltData) / entrySize
  205. baseJmpAddr := uint64(0)
  206. isSetBaseJmpAddr := false
  207. // 初始化一个 map,key 为指向的地址,value 为 plt 段中的地址
  208. pltEntryMap := make(map[uint64]uint64)
  209. // 遍历 plt 段
  210. for i := 0; i < len(pltData); i += int(entrySize) {
  211. // 获取 plt 段中的地址
  212. addr := pltSection.Addr + uint64(i)
  213. // 打印 addr
  214. fmt.Printf("addr: %x\n", addr)
  215. // 获取 plt 段中的 jmp 指令
  216. jmpInst := pltData[i : i+int(entrySize)]
  217. // 解析 jmp 指令
  218. inst, err := x86asm.Decode(jmpInst, 64)
  219. if err != nil {
  220. fmt.Printf("Decode error: %v\n", err)
  221. break
  222. }
  223. // 判断 inst 是否为 jmp 指令
  224. if inst.Op == x86asm.JMP {
  225. jmpAddr := addr + instLen + uint64(int32(inst.Args[0].(x86asm.Mem).Disp))
  226. pltEntryMap[jmpAddr] = addr
  227. if isSetBaseJmpAddr == false {
  228. // baseAddr = addr
  229. baseJmpAddr = jmpAddr
  230. isSetBaseJmpAddr = true
  231. }
  232. }
  233. }
  234. var data []byte
  235. data = make([]byte, 8*(entryNum+1))
  236. // 使用 ptrace attach 目标进程
  237. fmt.Printf("Attach Process: %d\n", pid)
  238. if err := syscall.PtraceAttach(pid); err != nil {
  239. fmt.Printf("PtraceAttach Err: %v\n", err)
  240. }
  241. // 等待目标进程停止
  242. if _, err := syscall.Wait4(pid, nil, 0, nil); err != nil {
  243. fmt.Printf("wait4: %v", err)
  244. }
  245. dataLen, _ := syscall.PtracePeekData(pid, uintptr(baseJmpAddr+processMapsInfo.Start), data)
  246. fmt.Printf("Ptrace DETACH: %d\n", pid)
  247. // 恢复执行
  248. if err = syscall.PtraceDetach(pid); err != nil {
  249. fmt.Printf("ptrace DETACH: %v", err)
  250. }
  251. // 按 uint64 数组遍历 data
  252. for i := 0; i < dataLen; i += 8 {
  253. // 获取 data 中的地址
  254. addr := *(*uint64)(unsafe.Pointer(&data[i]))
  255. // 打印 data 中的地址
  256. fmt.Printf("addr: 0x%x\n", addr)
  257. // 如果 data 中的地址等于 libc 中 send 函数的地址,返回该地址
  258. if processLibcMapsInfo.Start+libcFunction.Value == addr {
  259. fmt.Printf("Found %s at address: 0x%x\n", functionName, baseJmpAddr)
  260. fmt.Printf("Found %s from address: 0x%x[0x%x]\n", functionName, uint64(i)+baseJmpAddr, processMapsInfo.Start+uint64(i)+baseJmpAddr)
  261. // 判断 pltEntryMap 中是否存在 baseJmpAddr,如果存在,返回 baseJmpAddr
  262. if value, ok := pltEntryMap[uint64(i)+baseJmpAddr]; ok {
  263. fmt.Printf("Found %s at address on plt: 0x%x\n", functionName, value)
  264. return elf.Symbol{Name: functionName, Value: value, Size: 0}, nil
  265. }
  266. }
  267. }
  268. return elf.Symbol{}, fmt.Errorf("function %s not found", functionName)
  269. }
  270. func buildArm64Enc(imm16 uint64, hw uint64) []byte {
  271. PC_START := uint64(56)
  272. rd := 0x10 // x16
  273. fixed := 0x1e5
  274. // Combine all parts into a single 32-bit value
  275. result := uint64(rd<<0) | (imm16 << 5) | (hw << 21) | uint64(fixed<<23)
  276. fmt.Printf("set *(unsigned int*)($origin+%d) = 0x%08x\n", PC_START+hw*4, result)
  277. return []byte{
  278. byte(result),
  279. byte(result >> 8),
  280. byte(result >> 16),
  281. byte(result >> 24),
  282. }
  283. }
  284. // 生成一个函数,用于生成长跳指令插入,使用 x16 寄存器,使用前需要保存 x16 寄存器的值,使用后恢复 x16 寄存器的值,跳转的指令从参数传入,第二个参数为指令长度,如果指令长度不够,需要在后面添加 nop 指令
  285. func generateLongJumpCode(cwFunctionAddr uint64, length int, saveX16 bool) []byte {
  286. jumpCode := []byte{}
  287. if saveX16 {
  288. jumpCode = append(jumpCode, 0xFF, 0x43, 0x00, 0xD1) // sub sp, sp, #16
  289. jumpCode = append(jumpCode, 0xF0, 0x03, 0x00, 0xF9) // str x16, [sp]
  290. }
  291. jumpCode = append(jumpCode, buildArm64Enc(cwFunctionAddr&0xFFFF, 0)...) // movz x16, (cwFunctionAddr & 0xFFFF)
  292. jumpCode = append(jumpCode, buildArm64Enc((cwFunctionAddr>>16)&0xFFFF, 1)...) // movk x16, ((cwFunctionAddr >> 16) & 0xFFFF) LSL #16
  293. jumpCode = append(jumpCode, buildArm64Enc((cwFunctionAddr>>32)&0xFFFF, 2)...) // movk x16, ((cwFunctionAddr >> 16) & 0xFFFF) LSL #16
  294. // jumpCode = append(jumpCode, buildArm64Enc((cwFunctionAddr >> 48) & 0xFFFF, 3)...) // movk x16, ((cwFunctionAddr >> 16) & 0xFFFF) LSL #16
  295. jumpCode = append(jumpCode, 0x00, 0x02, 0x1F, 0xD6) // br x16
  296. if saveX16 {
  297. jumpCode = append(jumpCode, 0xF0, 0x03, 0x40, 0xF9) // ldr x16, [sp]
  298. jumpCode = append(jumpCode, 0xFF, 0x43, 0x00, 0x91) // add sp, sp, #16
  299. }
  300. // Add NOP instructions if the length is not enough
  301. for len(jumpCode) < length {
  302. jumpCode = append(jumpCode, 0x1F, 0x20, 0x03, 0xD5) // NOP instruction
  303. }
  304. return jumpCode
  305. }
  306. // 生成新指令,用于替换原指令,一个参数是原来的指令 instCodeArray,另一个参数是跳转回去的地址,第三个参数是原始 Send 函数的地址,第四个参数是在哪个指令前插入自定义指令
  307. func generateNewCode(instCodeArray [][]byte, jumpBackAddr uint64, sendAddr uint64, insertIndex int) ([]byte, int) {
  308. newCode := []byte{}
  309. hookOffset := 0
  310. for i, inst := range instCodeArray {
  311. if i == insertIndex {
  312. /**
  313. 约定从栈顶拿8个字节储存 ebpf 中配置的 header 长度:先备份 x2 到 -0x08(%sp)
  314. 待 ebpf 将值写到 -0x08(%sp) 之后,再从 -0x08(%sp) 赋值到 x2
  315. str x2, [sp, #-8]!
  316. ldr x2, [sp], #8
  317. **/
  318. newCode = append(newCode, 0xE2, 0x83, 0x1F, 0xF8) // str x2, [sp, #-8]
  319. hookOffset = len(newCode)
  320. newCode = append(newCode, 0x1F, 0x20, 0x03, 0xD5) // NOP instruction
  321. newCode = append(newCode, 0xE2, 0x83, 0x5F, 0xF8) // ldr x2, [sp, #-8]
  322. /**
  323. 生成 Call 指令利用 x16 寄存器进程长调用调用原来的 Send 函数
  324. **/
  325. newCode = append(newCode, buildArm64Enc(sendAddr&0xFFFF, 0)...) // movz x16, (sendAddr & 0xFFFF)
  326. newCode = append(newCode, buildArm64Enc((sendAddr>>16)&0xFFFF, 1)...) // movk x16, ((sendAddr >> 16) & 0xFFFF) LSL #16
  327. newCode = append(newCode, buildArm64Enc((sendAddr>>32)&0xFFFF, 2)...) // movk x16, ((sendAddr >> 16) & 0xFFFF) LSL #16
  328. newCode = append(newCode, buildArm64Enc((sendAddr>>48)&0xFFFF, 3)...) // movk x16, ((sendAddr >> 16) & 0xFFFF) LSL #16
  329. newCode = append(newCode, 0x00, 0x02, 0x3F, 0xD6) // blr x16
  330. // 生成长跳指令,跳转回去
  331. newCode = append(newCode, generateLongJumpCode(jumpBackAddr, longJumpSize, false)...)
  332. continue
  333. }
  334. newCode = append(newCode, inst...)
  335. }
  336. return newCode, hookOffset
  337. }
  338. // 查找nop函数中hookOffset的位置
  339. func findNopFunctionHookOffset(pid int, cwSym *ProcessFunctionInfo) (int, error) {
  340. // 定义一个字节数组内容:0x48, 0x8b, 0x54, 0x24, 0xf8
  341. nopCode := []byte{0xE2, 0x83, 0x5F, 0xF8}
  342. code, err := readMemory(pid, cwSym.Start+nopEntryOffset, cwSym.Size-nopEntryOffset)
  343. if err != nil {
  344. fmt.Printf("readMemory error: %v\n", err)
  345. return 0, err
  346. }
  347. fmt.Printf("nopCode: %v\n", code)
  348. pc := uint64(0)
  349. for pc < uint64(len(code)) {
  350. inst, err := arm64asm.Decode(code[pc:])
  351. if err != nil {
  352. fmt.Printf("Decode error: %v\n", err)
  353. break
  354. }
  355. // 打印对应指令的字节码
  356. fmt.Printf("%s[%d]-[%v]\n", inst.String(), pc, code[pc:pc+instLen])
  357. // 如果指令的字节码等于nopCode,返回当前的pc
  358. if string(code[pc:pc+instLen]) == string(nopCode) {
  359. return int(pc) - nopLen, nil
  360. }
  361. pc += instLen
  362. }
  363. return 0, errors.New("hookOffset not found")
  364. }