//go:build linux && amd64 // +build linux,amd64 package aotinject /* #cgo CFLAGS: -I ../inject/include #cgo amd64 LDFLAGS: ${SRCDIR}/../inject/lib/libhotpatch_amd64.a #cgo arm64 LDFLAGS: ${SRCDIR}/../inject/lib/libhotpatch_arm64.a #include "hotpatch.h" #include */ import "C" import ( "debug/elf" "errors" "fmt" "syscall" "unsafe" "golang.org/x/arch/x86/x86asm" ) var PID string const jumpBackAddrOffset = 15 const nopEntryOffset = 12 const nopLen = 1 const longJumpSize = 17 type ProcessMapsInfo struct { Start, End uint64 Path string } type ProcessFunctionInfo struct { Name string Offset uint64 Start, Size uint64 } type InstMemStruct struct { InstStartAddr uint64 InstEndAddr uint64 CodeArray []byte InstCodeArray [][]byte InsertIndex int } // 解析code中的指令,并打印每一条指令 func parseAndPrintInstructions(code []byte) { pc := uint64(0) for pc < uint64(len(code)) { inst, err := x86asm.Decode(code[pc:], 64) if err != nil { fmt.Printf("Decode error: %v\n", err) break } fmt.Printf("0x%x:[%d] %s\n", pc, pc, inst.String()) pc += uint64(inst.Len) } } func findSendFunctionNameAddr(SymAddr uint64, code []byte, sendFunctionName string, sendSymAddr uint64) (uint64, uint64, error) { pc := uint64(0) for pc < uint64(len(code)) { inst, err := x86asm.Decode(code[pc:], 64) if err != nil { fmt.Printf("Decode error: %v\n", err) break } fmt.Printf("0x%x:[%d] %s\n", SymAddr+pc, pc, inst.String()) if inst.Op == x86asm.CALL { if v, ok := inst.Args[0].(x86asm.Rel); ok { fmt.Printf("%s\n", inst.Args[0].String()) if SymAddr+pc+uint64(v)+5 == sendSymAddr { fmt.Printf("Found send function call at 0x%x\n", SymAddr+pc) return SymAddr + pc, pc, nil } } } pc += uint64(inst.Len) } return 0, 0, errors.New("Send function Offset not found") } // 获取 Send 函数附近的内存数据,用于长跳转到 libmylib 中,需要凑够 17 个字节,先向上查找,再向下查找 func findMemForLongJump(SymAddr uint64, code []byte, sendFunctionName string, sendSymAddr uint64) (InstMemStruct, error) { // 定义一个结构体,储存 pc、指令、和指令长度 type InstStruct struct { pc uint64 inst x86asm.Inst len int } // 定义个字节数组用于储存遍历出来的指令,字节数组 var codeArray []byte // 定义一个字节数组的数组,用于储存遍历出来的每个指令的字节内容 var InstCodeArray [][]byte // 定义一个InstStruct数组,用于储存遍历出来的指令 var instArray []InstStruct // 定义一个标志记录记录找到的Send函数在instArray数组的位置 var sendIndex int pc := uint64(0) for pc < uint64(len(code)) { inst, err := x86asm.Decode(code[pc:], 64) if err != nil { fmt.Printf("Decode error: %v\n", err) break } // 将指令和指令长度存入InstStruct结构体中 instArray = append(instArray, InstStruct{pc: pc, inst: inst, len: inst.Len}) // 如果地址等于Send函数的地址,记录下Send函数在instArray数组的位置 if SymAddr+pc == sendSymAddr { sendIndex = len(instArray) - 1 } pc += uint64(inst.Len) } // 打印 sendIndex 和 instArray 的长度和 Send 函数的指令 fmt.Printf("sendIndex: %d, instArray len: %d, Send Function: %s\n", sendIndex, len(instArray), instArray[sendIndex].inst.String()) InstStartAddr := uint64(0) InstEndAddr := uint64(sendSymAddr) + 5 // 从Send函数开始向上查找,凑够17个字节 for i := sendIndex; i >= 0; i-- { if i != sendIndex && instArray[i].inst.Op == x86asm.CALL || instArray[i].inst.Op == x86asm.JMP || instArray[i].inst.Op == x86asm.JE || instArray[i].inst.Op == x86asm.JG || instArray[i].inst.Op == x86asm.JGE || instArray[i].inst.Op == x86asm.JL || instArray[i].inst.Op == x86asm.JLE || instArray[i].inst.Op == x86asm.JNE { break } // 将指令存入codeArray数组,指令应该从 code 中的pc开始,长度为instArray[i].len codeArray = append(codeArray, code[instArray[i].pc:instArray[i].pc+uint64(instArray[i].len)]...) InstCodeArray = append(InstCodeArray, code[instArray[i].pc:instArray[i].pc+uint64(instArray[i].len)]) // 记录当前指令的地址 InstStartAddr = SymAddr + instArray[i].pc // 打印本次循环的具体指令和指令长度 fmt.Printf("inst: %s\n", instArray[i].inst.String()) // 打印codeArray数组的长度和内容 fmt.Printf("codeArray len: %d, codeArray: %v\n", len(codeArray), codeArray) // 如果codeArray数组长度大于等于17,或者 该指令代码等于 Call、jmp、jne、je、jg、jge、jl、jle、jne、等跳转指令,跳出循环,用于跳转的指令有哪些,可以根据实际情况添加 if len(codeArray) >= longJumpSize { fmt.Printf("codeArray len: %d, codeArray: %v\n", len(codeArray), codeArray) // 打印InstCodeArray数组的长度和内容 fmt.Printf("InstCodeArray len: %d, InstCodeArray: %v\n", len(InstCodeArray), InstCodeArray) break } } // 将 InstCodeArray 数组进行倒序 for i := 0; i < len(InstCodeArray)/2; i++ { InstCodeArray[i], InstCodeArray[len(InstCodeArray)-1-i] = InstCodeArray[len(InstCodeArray)-1-i], InstCodeArray[i] } insertIndex := len(InstCodeArray) - 1 // 如果 codeArray 的长度没有凑够 17 个字节,则从Send函数开始向下查找,继续凑够17个字节 if len(codeArray) < longJumpSize { // 从Send函数开始向下查找,凑够17个字节 for i := sendIndex + 1; i < len(instArray); i++ { if instArray[i].inst.Op == x86asm.CALL || instArray[i].inst.Op == x86asm.JMP || instArray[i].inst.Op == x86asm.JE || instArray[i].inst.Op == x86asm.JG || instArray[i].inst.Op == x86asm.JGE || instArray[i].inst.Op == x86asm.JL || instArray[i].inst.Op == x86asm.JLE || instArray[i].inst.Op == x86asm.JNE { break } // 将指令存入codeArray数组,指令应该从 code 中的pc开始,长度为instArray[i].len codeArray = append(codeArray, code[instArray[i].pc:instArray[i].pc+uint64(instArray[i].len)]...) InstCodeArray = append(InstCodeArray, code[instArray[i].pc:instArray[i].pc+uint64(instArray[i].len)]) // 记录当前指令的地址 InstEndAddr = SymAddr + instArray[i].pc + uint64(instArray[i].len) // 打印本次循环的具体指令和指令长度 fmt.Printf("inst: %s\n", instArray[i].inst.String()) // 打印codeArray数组的长度和内容 fmt.Printf("codeArray len: %d, codeArray: %v\n", len(codeArray), codeArray) // 如果codeArray数组长度大于等于17,或者 该指令代码等于 Call、jmp、jne、je、jg、jge、jl、jle、jne、等跳转指令,跳出循环,用于跳转的指令有哪些,可以根据实际情况添加 if len(codeArray) >= longJumpSize { fmt.Printf("codeArray len: %d, codeArray: %v\n", len(codeArray), codeArray) // 打印InstCodeArray数组的长度和内容 fmt.Printf("InstCodeArray len: %d, InstCodeArray: %v\n", len(InstCodeArray), InstCodeArray) break } } } // 打印最终的codeArray数组的长度和内容 fmt.Printf("codeArray len: %d, codeArray: %v\n", len(codeArray), codeArray) // 打印最终的InstAddr地址 fmt.Printf("InstAddr: %x-%x Len: %d, InstCodeArray: %v, InstCodeArray len: %d\n", InstStartAddr, InstEndAddr, len(codeArray), InstCodeArray, len(InstCodeArray)) instMemStruct := InstMemStruct{InstStartAddr: InstStartAddr, InstEndAddr: InstEndAddr, CodeArray: codeArray, InstCodeArray: InstCodeArray, InsertIndex: insertIndex} if len(codeArray) < longJumpSize { return instMemStruct, errors.New("Not enough memory for long jump") } return instMemStruct, nil } // 从 PLT 中获取函数的地址,先从改进程的 maps 中的 glibc 中获取对应的函数地址,在反过来从 plt 段中的 jmp 中计算出该函数的地址 func getFunctionOffsetPLT(pid int, libPath, functionName string) (elf.Symbol, error) { // 获取进程 maps 中 libc.so 的基地址 processLibcMapsInfo, err := getProcessMapsInfo(pid, "libc.so.6") // 获取 libc 中 send 函数的地址 libcFunction, err := getFunctionOffset(processLibcMapsInfo.Path, functionName) // 打印 libc 中 send 函数的地址 fmt.Printf("libc send function: %s, addr: %x, addr: %x\n", functionName, libcFunction.Value, libcFunction.Value+processLibcMapsInfo.Start) processMapsInfo, err := getProcessMapsInfo(pid) // 打印进程 maps 中的地址 fmt.Printf("process maps: %s, addr: %x\n", processMapsInfo.Path, processMapsInfo.Start) // 解析 libPath 中的 plt 段,遍历所有的 plt 段的地址,找到 所有 jmp 指令所跳转的地址为 libc 中 send 函数的地址 elfFile, err := elf.Open(libPath) if err != nil { return elf.Symbol{}, fmt.Errorf("failed to open ELF file: %v", err) } defer elfFile.Close() // 遍历 plt 段 pltSection := elfFile.Section(".plt") if pltSection == nil { return elf.Symbol{}, errors.New(".plt section not found") } pltData, err := pltSection.Data() if err != nil { return elf.Symbol{}, fmt.Errorf("failed to read .plt section data: %v", err) } entrySize := 16 // 计算 plt 段中一共有多少 plt 入口,使用 plt 段的大小除以 plt 入口的大小 // plt 入口的数量为 len(pltData)/entrySize entryNum := len(pltData) / entrySize baseJmpAddr := uint64(0) isSetBaseJmpAddr := false // 初始化一个 map,key 为指向的地址,value 为 plt 段中的地址 pltEntryMap := make(map[uint64]uint64) // 遍历 plt 段 for i := 0; i < len(pltData); i += int(entrySize) { // 获取 plt 段中的地址 addr := pltSection.Addr + uint64(i) // 打印 addr fmt.Printf("addr: %x\n", addr) // 获取 plt 段中的 jmp 指令 jmpInst := pltData[i : i+int(entrySize)] // 解析 jmp 指令 inst, err := x86asm.Decode(jmpInst, 64) if err != nil { fmt.Printf("Decode error: %v\n", err) break } // 判断 inst 是否为 jmp 指令 if inst.Op == x86asm.JMP { jmpAddr := addr + uint64(inst.Len) + uint64(int32(inst.Args[0].(x86asm.Mem).Disp)) pltEntryMap[jmpAddr] = addr if isSetBaseJmpAddr == false { // baseAddr = addr baseJmpAddr = jmpAddr isSetBaseJmpAddr = true } } } var data []byte data = make([]byte, 8*(entryNum+1)) // 使用 ptrace attach 目标进程 fmt.Printf("Attach Process: %d\n", pid) if err := syscall.PtraceAttach(pid); err != nil { fmt.Printf("PtraceAttach Err: %v\n", err) } // 等待目标进程停止 if _, err := syscall.Wait4(pid, nil, 0, nil); err != nil { fmt.Printf("wait4: %v", err) } dataLen, _ := syscall.PtracePeekData(pid, uintptr(baseJmpAddr+processMapsInfo.Start), data) fmt.Printf("Ptrace DETACH: %d\n", pid) // 恢复执行 if err = syscall.PtraceDetach(pid); err != nil { fmt.Printf("ptrace DETACH: %v", err) } // 按 uint64 数组遍历 data for i := 0; i < dataLen; i += 8 { // 获取 data 中的地址 addr := *(*uint64)(unsafe.Pointer(&data[i])) // 打印 data 中的地址 fmt.Printf("addr: 0x%x\n", addr) // 如果 data 中的地址等于 libc 中 send 函数的地址,返回该地址 if processLibcMapsInfo.Start+libcFunction.Value == addr { fmt.Printf("Found %s at address: 0x%x\n", functionName, baseJmpAddr) fmt.Printf("Found %s from address: 0x%x[0x%x]\n", functionName, uint64(i)+baseJmpAddr, processMapsInfo.Start+uint64(i)+baseJmpAddr) // 判断 pltEntryMap 中是否存在 baseJmpAddr,如果存在,返回 baseJmpAddr if value, ok := pltEntryMap[uint64(i)+baseJmpAddr]; ok { fmt.Printf("Found %s at address on plt: 0x%x\n", functionName, value) return elf.Symbol{Name: functionName, Value: value, Size: 0}, nil } } } return elf.Symbol{}, fmt.Errorf("function %s not found", functionName) } // 生成一个函数,用于生成长跳指令插入,使用 r11 寄存器,使用前需要保存 r11 寄存器的值,使用后恢复 r11 寄存器的值,跳转的指令从参数传入,第二个参数为指令长度,如果指令长度不够,需要在后面添加 nop 指令 func generateLongJumpCode(cwFunctionAddr uint64, length int, saveR11 bool) []byte { jumpCode := []byte{} if saveR11 { jumpCode = append(jumpCode, 0x41, 0x53) // push r11 } jumpCode = append(jumpCode, 0x49, 0xbb, byte(cwFunctionAddr), byte(cwFunctionAddr>>8), byte(cwFunctionAddr>>16), byte(cwFunctionAddr>>24), byte(cwFunctionAddr>>32), byte(cwFunctionAddr>>40), byte(cwFunctionAddr>>48), byte(cwFunctionAddr>>56), // movabs r11, cwFunctionAddr 0x41, 0xff, 0xe3, // jmp r11 ) if saveR11 { jumpCode = append(jumpCode, 0x41, 0x5b) // pop r11 } // Add NOP instructions if the length is not enough for len(jumpCode) < length { jumpCode = append(jumpCode, 0x90) // NOP instruction } return jumpCode } // 生成新指令,用于替换原指令,一个参数是原来的指令 instCodeArray,另一个参数是跳转回去的地址,第三个参数是原始 Send 函数的地址,第四个参数是在哪个指令前插入自定义指令 func generateNewCode(instCodeArray [][]byte, jumpBackAddr uint64, sendAddr uint64, insertIndex int) ([]byte, int) { newCode := []byte{} hookOffset := 0 for i, inst := range instCodeArray { if i == insertIndex { /** 约定从栈顶拿8个字节储存 ebpf 中配置的 header 长度:先备份 rdx 到 -0x08(%rsp) 待 ebpf 将值写到 -0x08(%rsp) 之后,再从 -0x08(%rsp) 赋值到 rdx mov %rdx, -0x08(%rsp) mov -0x08(%rsp), %rdx **/ newCode = append(newCode, 0x48, 0x89, 0x54, 0x24, 0xf8) // mov %rdx, -0x08(%rsp) hookOffset = len(newCode) newCode = append(newCode, 0x90) // nop newCode = append(newCode, 0x48, 0x8b, 0x54, 0x24, 0xf8) // mov -0x08(%rsp), %rdx /** 生成 Call 指令利用 r11 寄存器进程长调用调用原来的 Send 函数 **/ CallSendCode := []byte{ 0x41, 0x53, // push r11 0x49, 0xbb, byte(sendAddr), byte(sendAddr >> 8), byte(sendAddr >> 16), byte(sendAddr >> 24), byte(sendAddr >> 32), byte(sendAddr >> 40), byte(sendAddr >> 48), byte(sendAddr >> 56), // movabs r11, sendAddr 0x41, 0xff, 0xd3, // call r11 0x41, 0x5b, // pop r11 } newCode = append(newCode, CallSendCode...) // 生成长跳指令,跳转回去 newCode = append(newCode, generateLongJumpCode(jumpBackAddr, longJumpSize, false)...) continue } newCode = append(newCode, inst...) } return newCode, hookOffset } // 查找nop函数中hookOffset的位置 func findNopFunctionHookOffset(pid int, cwSym *ProcessFunctionInfo) (int, error) { // 定义一个字节数组内容:0x48, 0x8b, 0x54, 0x24, 0xf8 nopCode := []byte{0x48, 0x8b, 0x54, 0x24, 0xf8} code, err := readMemory(pid, cwSym.Start+nopEntryOffset, cwSym.Size-nopEntryOffset) if err != nil { fmt.Printf("readMemory error: %v\n", err) return 0, err } fmt.Printf("nopCode: %v\n", code) pc := uint64(0) for pc < uint64(len(code)) { inst, err := x86asm.Decode(code[pc:], 64) if err != nil { fmt.Printf("Decode error: %v\n", err) break } // 打印对应指令的字节码 fmt.Printf("%s[%d]-[%v]\n", inst.String(), pc, code[pc:pc+uint64(inst.Len)]) // 如果指令的字节码等于nopCode,返回当前的pc if string(code[pc:pc+uint64(inst.Len)]) == string(nopCode) { return int(pc) - nopLen, nil } pc += uint64(inst.Len) } return 0, errors.New("hookOffset not found") }