de4dot-cex/de4dot.mdecrypt/DynamicMethodsDecrypter.cs

580 lines
19 KiB
C#

/*
Copyright (C) 2011-2012 de4dot@gmail.com
This file is part of de4dot.
de4dot is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
de4dot is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with de4dot. If not, see <http://www.gnu.org/licenses/>.
*/
using System;
using System.IO;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Reflection;
using dot10.DotNet;
using dot10.DotNet.MD;
using de4dot.blocks;
namespace de4dot.mdecrypt {
/// <summary>
/// Dumps the IL of methods in a module whose real method bodies are presumably only
/// produced by a JIT <c>compileMethod</c> hook installed by the module itself.
/// </summary>
/// <remarks>
/// Presumed usage (callers are not visible here):
/// <list type="number">
/// <item>Set <see cref="Module"/> and <see cref="DecryptMethodsInfo"/>.</item>
/// <item>Call <see cref="installCompileMethod"/>. It overwrites slot 0 of the jitter's
/// vtable with our own native stub.</item>
/// <item>Call <see cref="loadObfuscator"/>. It runs the module constructor.</item>
/// <item>Call <see cref="decryptMethods"/>. It calls whatever compileMethod is now in
/// slot 0 for each method and records the IL that reaches our hook.</item>
/// </list>
/// Only 32-bit processes are supported; the constructor throws otherwise.
/// </remarks>
public class DynamicMethodsDecrypter {
    static DynamicMethodsDecrypter instance;
    DecryptMethodsInfo decryptMethodsInfo;

    /// <summary>A managed callback that is exposed to native code.</summary>
    struct FuncPtrInfo<D> {
        // Holds on to the delegate so the function pointer obtained from it stays
        // valid (GetFunctionPointerForDelegate does not keep the delegate alive).
        public D del;
        // Native-callable thunk for the delegate; set by prepare().
        public IntPtr ptr;
        // Address of the matching stub in the code emitted by createOurCode().
        // That memory comes from VirtualAlloc, despite the "InDll" name.
        public IntPtr ptrInDll;

        // Pre-JITs the delegate target and fetches its native thunk.
        public void prepare(Delegate del) {
            RuntimeHelpers.PrepareDelegate(del);
            ptr = Marshal.GetFunctionPointerForDelegate(del);
        }
    }

    /// <summary>PE section header as laid out in a mapped image (8-byte name read as a ulong).</summary>
    [StructLayout(LayoutKind.Sequential, Pack = 1)]
    struct IMAGE_SECTION_HEADER {
        public ulong name;
        public uint VirtualSize;
        public uint VirtualAddress;
        public uint SizeOfRawData;
        public uint PointerToRawData;
        public uint PointerToRelocations;
        public uint PointerToLinenumbers;
        public ushort NumberOfRelocations;
        public ushort NumberOfLinenumbers;
        public uint Characteristics;
    }

    /// <summary>
    /// Partial mirror of the jitter's CORINFO_METHOD_INFO. Only the leading fields we use
    /// are declared; Size pads the struct to 0x88 bytes.
    /// </summary>
    [StructLayout(LayoutKind.Sequential, Pack=1, Size=0x88)]
    struct CORINFO_METHOD_INFO {
        public IntPtr ftn;
        public IntPtr scope;
        public IntPtr ILCode;
        public uint ILCodeSize;
        public ushort maxStack;
        public ushort EHCount;
        // The fields above take 0x14 bytes on 32-bit, so 0x74 other bytes follow
        // (up to Size=0x88); we never touch them.
    }

    /// <summary>State for the method currently being decrypted by decryptMethod().</summary>
    class DecryptContext {
        public DumpedMethod dm;
        public MethodDef method;
    }

    FuncPtrInfo<CompileMethod> ourCompileMethodInfo = new FuncPtrInfo<CompileMethod>();
    FuncPtrInfo<ReturnMethodToken> returnMethodTokenInfo = new FuncPtrInfo<ReturnMethodToken>();
    FuncPtrInfo<ReturnNameOfMethod> returnNameOfMethodInfo = new FuncPtrInfo<ReturnNameOfMethod>();
    // Value found in jitter vtbl slot 0 before we patched it.
    IntPtr origCompileMethod;
    // Set by installCompileMethod() from getEndOfText(); not read anywhere in this class.
    IntPtr jitterTextFreeMem;
    // Native "callMethod" trampoline emitted by createOurCode().
    IntPtr callMethod;
    CallMethod callMethodDelegate;
    IntPtr jitterInstance;
    IntPtr jitterVtbl;
    Module moduleToDecrypt;
    IntPtr hInstModule;
    // Fake "comp" object passed to compileMethod when we are decrypting (see initializeOurComp()).
    IntPtr ourCompMem;
    bool compileMethodIsThisCall;
    IntPtr ourCodeAddr;
    MDTable methodDefTable;
    // Address of the MethodDef table inside the loaded image.
    IntPtr methodDefTablePtr;
    ModuleDefMD dot10Module;
    MethodDef moduleCctor;
    // RVA of the module .cctor's IL code (past its header); 0 if there is no module .cctor.
    uint moduleCctorCodeRva;
    IntPtr moduleToDecryptScope;
    DecryptContext ctx = new DecryptContext();

    /// <summary>Process-wide instance, created on first access (no locking).</summary>
    public static DynamicMethodsDecrypter Instance {
        get {
            if (instance != null)
                return instance;
            return instance = new DynamicMethodsDecrypter();
        }
    }

    // First CLR build number treated as .NET 4.5.
    static readonly Version VersionNet45 = new Version(4, 0, 30319, 17020);

    /// <summary>
    /// Rejects non-32-bit processes and selects the compileMethod calling convention.
    /// </summary>
    /// <exception cref="ApplicationException">The process is not 32-bit.</exception>
    DynamicMethodsDecrypter() {
        if (UIntPtr.Size != 4)
            throw new ApplicationException("Only 32-bit dynamic methods decryption is supported");

        // .NET 4.5's compileMethod has thiscall calling convention
        compileMethodIsThisCall = Environment.Version >= VersionNet45;
    }

    [DllImport("kernel32", CharSet = CharSet.Ansi)]
    static extern IntPtr GetModuleHandle(string name);

    [DllImport("kernel32", CharSet = CharSet.Ansi)]
    static extern IntPtr GetProcAddress(IntPtr hModule, string name);

    [DllImport("kernel32")]
    static extern bool VirtualProtect(IntPtr addr, int size, uint newProtect, out uint oldProtect);

    const uint PAGE_EXECUTE_READWRITE = 0x40;

    [DllImport("kernel32")]
    static extern IntPtr VirtualAlloc(IntPtr lpAddress, UIntPtr dwSize, uint flAllocationType, uint flProtect);

    delegate IntPtr GetJit();
    // 'handled' is written through a pointer pushed by our native stub; when true the
    // stub returns immediately instead of calling the original compileMethod.
    delegate int CompileMethod(IntPtr jitter, IntPtr comp, IntPtr info, uint flags, IntPtr nativeEntry, IntPtr nativeSizeOfCode, out bool handled);
    delegate int ReturnMethodToken();
    delegate string ReturnNameOfMethod();
    // First argument is the compileMethod to invoke; see the callMethod stub in createOurCode().
    delegate int CallMethod(IntPtr compileMethod, IntPtr jitter, IntPtr comp, IntPtr info, uint flags, IntPtr nativeEntry, IntPtr nativeSizeOfCode);

    /// <summary>Options controlling which methods are decrypted and the optional module .cctor replacement.</summary>
    public DecryptMethodsInfo DecryptMethodsInfo {
        set { decryptMethodsInfo = value; }
    }

    /// <summary>
    /// Sets the module to decrypt. This may only be done once.
    /// </summary>
    /// <remarks>
    /// Setting it caches the module's HINSTANCE and native scope pointer, loads it with
    /// dot10, and locates the in-memory MethodDef table.
    /// </remarks>
    /// <exception cref="ApplicationException">A module has already been set.</exception>
    public unsafe Module Module {
        set {
            if (moduleToDecrypt != null)
                throw new ApplicationException("Module has already been initialized");

            moduleToDecrypt = value;
            hInstModule = Marshal.GetHINSTANCE(moduleToDecrypt);
            moduleToDecryptScope = getScope(moduleToDecrypt);
            dot10Module = ModuleDefMD.Load(hInstModule);
            methodDefTable = dot10Module.TablesStream.MethodTable;
            // StartOffset is a file offset; convert it to an RVA and add it to the image base.
            methodDefTablePtr = new IntPtr((byte*)hInstModule + (uint)dot10Module.MetaData.PEImage.ToRVA(methodDefTable.StartOffset));
            initializeDot10Methods();
        }
    }

    /// <summary>
    /// Returns the runtime's native pointer for <paramref name="module"/>. This is the value
    /// that appears in CORINFO_METHOD_INFO.scope for the module's methods.
    /// </summary>
    /// <remarks>
    /// ModuleHandle.m_ptr is either the IntPtr itself or a RuntimeModule whose m_pData holds
    /// it. Presumably this differs between CLR versions.
    /// </remarks>
    static IntPtr getScope(Module module) {
        var obj = getFieldValue(module.ModuleHandle, "m_ptr");
        if (obj is IntPtr)
            return (IntPtr)obj;
        if (obj.GetType().ToString() == "System.Reflection.RuntimeModule")
            return (IntPtr)getFieldValue(obj, "m_pData");
        throw new ApplicationException(string.Format("m_ptr is an invalid type: {0}", obj.GetType()));
    }

    /// <summary>Reads an instance field (public or not) by name via reflection.</summary>
    /// <exception cref="ApplicationException">The field does not exist on obj's runtime type.</exception>
    static object getFieldValue(object obj, string fieldName) {
        var field = obj.GetType().GetField(fieldName, BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
        if (field == null)
            throw new ApplicationException(string.Format("Could not get field {0}::{1}", obj.GetType(), fieldName));
        return field.GetValue(obj);
    }

    /// <summary>
    /// Computes <see cref="moduleCctorCodeRva"/>, the RVA of the first IL byte of the
    /// &lt;Module&gt; type's static constructor, or 0 if there is none.
    /// </summary>
    unsafe void initializeDot10Methods() {
        moduleCctor = dot10Module.GlobalType.FindStaticConstructor();
        if (moduleCctor == null)
            moduleCctorCodeRva = 0;
        else {
            byte* p = (byte*)hInstModule + (uint)moduleCctor.RVA;
            if ((*p & 3) == 2)
                // Tiny header: a single byte.
                moduleCctorCodeRva = (uint)moduleCctor.RVA + 1;
            else
                // Fat header: the header size in dwords is the high nibble of byte 1.
                moduleCctorCodeRva = (uint)((uint)moduleCctor.RVA + (p[1] >> 4) * 4);
        }
    }

    /// <summary>
    /// Hooks the jitter by replacing vtable slot 0 (compileMethod) with our native stub.
    /// </summary>
    /// <remarks>
    /// The original slot value is saved in <see cref="origCompileMethod"/>. Our own methods
    /// are pre-JITted before the hook is written.
    /// </remarks>
    public unsafe void installCompileMethod() {
        var hJitterDll = getJitterDllHandle();
        jitterTextFreeMem = getEndOfText(hJitterDll);
        var getJitPtr = GetProcAddress(hJitterDll, "getJit");
        var getJit = (GetJit)Marshal.GetDelegateForFunctionPointer(getJitPtr, typeof(GetJit));
        jitterInstance = getJit();
        jitterVtbl = *(IntPtr*)jitterInstance;
        origCompileMethod = *(IntPtr*)jitterVtbl;

        prepareMethods();
        initializeDelegateFunctionPointers();
        createOurCode();
        callMethodDelegate = (CallMethod)Marshal.GetDelegateForFunctionPointer(callMethod, typeof(CallMethod));

        writeCompileMethod(ourCompileMethodInfo.ptrInDll);
    }

    /// <summary>
    /// Writes <paramref name="newCompileMethod"/> into jitter vtable slot 0. The page is made
    /// writable for the write and its previous protection is restored afterwards.
    /// </summary>
    unsafe void writeCompileMethod(IntPtr newCompileMethod) {
        uint oldProtect;
        if (!VirtualProtect(jitterVtbl, IntPtr.Size, PAGE_EXECUTE_READWRITE, out oldProtect))
            throw new ApplicationException("Could not enable write access to jitter vtbl");
        *(IntPtr*)jitterVtbl = newCompileMethod;
        // Best-effort restore; the result is intentionally ignored.
        VirtualProtect(jitterVtbl, IntPtr.Size, oldProtect, out oldProtect);
    }

    /// <summary>Creates the managed callbacks and their native thunks.</summary>
    void initializeDelegateFunctionPointers() {
        ourCompileMethodInfo.prepare(ourCompileMethodInfo.del = compileMethod);
        returnMethodTokenInfo.prepare(returnMethodTokenInfo.del = returnMethodToken);
        returnNameOfMethodInfo.prepare(returnNameOfMethodInfo.del = returnNameOfMethod);
    }

    /// <summary>Runs the module constructor of the module being decrypted.</summary>
    public void loadObfuscator() {
        RuntimeHelpers.RunModuleConstructor(moduleToDecrypt.ModuleHandle);
    }

    /// <summary>
    /// True when vtable slot 0 holds neither our hook nor the original compileMethod,
    /// i.e. some other code has replaced our hook.
    /// </summary>
    public unsafe bool canDecryptMethods() {
        return *(IntPtr*)jitterVtbl != ourCompileMethodInfo.ptrInDll &&
            *(IntPtr*)jitterVtbl != origCompileMethod;
    }

    /// <summary>
    /// Returns the address 8 bytes before the end of the .text section (VirtualAddress +
    /// VirtualSize) of the PE image mapped at <paramref name="hDll"/>.
    /// </summary>
    /// <exception cref="ApplicationException">No section is named ".text".</exception>
    unsafe static IntPtr getEndOfText(IntPtr hDll) {
        byte* p = (byte*)hDll;
        p += *(uint*)(p + 0x3C);    // add DOSHDR.e_lfanew
        p += 4;                     // skip "PE\0\0" signature
        int numSections = *(ushort*)(p + 2);            // IMAGE_FILE_HEADER.NumberOfSections
        int sizeOptionalHeader = *(ushort*)(p + 0x10);  // IMAGE_FILE_HEADER.SizeOfOptionalHeader
        p += 0x14;                  // start of optional header
        p += sizeOptionalHeader;    // start of section table

        var textName = new byte[8] { (byte)'.', (byte)'t', (byte)'e', (byte)'x', (byte)'t', 0, 0, 0 };
        var name = new byte[8];
        var pSection = (IMAGE_SECTION_HEADER*)p;
        for (int i = 0; i < numSections; i++, pSection++) {
            Marshal.Copy(new IntPtr(pSection), name, 0, name.Length);
            if (!compareName(textName, name, name.Length))
                continue;

            uint size = pSection->VirtualSize;
            uint rva = pSection->VirtualAddress;
            int displ = -8;
            return new IntPtr((byte*)hDll + rva + size + displ);
        }

        throw new ApplicationException("Could not find .text section");
    }

    /// <summary>True if the first <paramref name="len"/> bytes of both arrays are equal.</summary>
    static bool compareName(byte[] b1, byte[] b2, int len) {
        for (int i = 0; i < len; i++) {
            if (b1[i] != b2[i])
                return false;
        }
        return true;
    }

    /// <summary>
    /// Pre-links this class's P/Invokes and pre-JITs all of its methods.
    /// </summary>
    /// <remarks>
    /// Presumably this is done so that our own code is not JIT-compiled through the hooked
    /// compileMethod later.
    /// </remarks>
    void prepareMethods() {
        Marshal.PrelinkAll(GetType());
        foreach (var methodInfo in GetType().GetMethods(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Static | BindingFlags.Instance))
            RuntimeHelpers.PrepareMethod(methodInfo.MethodHandle);
    }

    /// <summary>
    /// Emits the x86 stubs into newly VirtualAlloc'd RWX memory and records their addresses.
    /// </summary>
    /// <remarks>
    /// The stubs are:
    /// <list type="bullet">
    /// <item>our compileMethod, which forwards to the managed <see cref="compileMethod"/> and
    /// falls back to the original when not handled;</item>
    /// <item>callMethod, a stdcall trampoline whose first argument is the target;</item>
    /// <item>the getMethodToken and getMethodName stubs referenced from the fake comp object.</item>
    /// </list>
    /// </remarks>
    unsafe void createOurCode() {
        var code = new NativeCodeGenerator();

        // our compileMethod() func
        int compileMethodOffset = code.Size;

        int numPushedArgs = compileMethodIsThisCall ? 5 : 6;

        code.writeByte(0x51);           // push ecx
        code.writeByte(0x50);           // push eax
        code.writeByte(0x54);           // push esp  (address of saved eax = 'out bool handled')
        for (int i = 0; i < 5; i++)
            writePushDwordPtrEspDispl(code, (sbyte)(0xC + numPushedArgs * 4));  // push dword ptr [esp+XXh]
        if (!compileMethodIsThisCall)
            writePushDwordPtrEspDispl(code, (sbyte)(0xC + numPushedArgs * 4));  // push dword ptr [esp+XXh]
        else
            code.writeByte(0x51);       // push ecx  (jitter 'this')
        code.writeCall(ourCompileMethodInfo.ptr);
        code.writeByte(0x5A);           // pop edx  ('handled')
        code.writeByte(0x59);           // pop ecx
        code.writeBytes(0x84, 0xD2);    // test dl, dl
        code.writeBytes(0x74, 0x03);    // jz $+5  (skip the retn below)
        code.writeBytes(0xC2, (ushort)(numPushedArgs * 4)); // retn 14h/18h

        // Not handled: re-push the caller's arguments and call the original compileMethod.
        for (int i = 0; i < numPushedArgs; i++)
            writePushDwordPtrEspDispl(code, (sbyte)(numPushedArgs * 4));    // push dword ptr [esp+XXh]
        code.writeCall(origCompileMethod);
        code.writeBytes(0xC2, (ushort)(numPushedArgs * 4)); // retn 14h/18h

        // Our callMethod() code. 1st arg is the method to call. stdcall calling convention.
        int callMethodOffset = code.Size;
        code.writeByte(0x58);           // pop eax (ret addr)
        code.writeByte(0x5A);           // pop edx (method to call)
        if (compileMethodIsThisCall)
            code.writeByte(0x59);       // pop ecx (this ptr)
        code.writeByte(0x50);           // push eax (ret addr)
        code.writeBytes(0xFF, 0xE2);    // jmp edx

        // Returns token of method. The retn pops 2 pointer-sized args (presumably this + method).
        int getMethodTokenOffset = code.Size;
        code.writeCall(returnMethodTokenInfo.ptr);
        code.writeBytes(0xC2, (ushort)(IntPtr.Size * 2));

        // Returns name of method. The retn pops 3 pointer-sized args (presumably this + 2 params).
        int getMethodNameOffset = code.Size;
        code.writeCall(returnNameOfMethodInfo.ptr);
        code.writeBytes(0xC2, (ushort)(IntPtr.Size * 3));

        // 0x1000 = MEM_COMMIT
        ourCodeAddr = VirtualAlloc(IntPtr.Zero, new UIntPtr((ulong)code.Size), 0x00001000, PAGE_EXECUTE_READWRITE);
        IntPtr baseAddr = ourCodeAddr;
        ourCompileMethodInfo.ptrInDll = new IntPtr((byte*)baseAddr + compileMethodOffset);
        callMethod = new IntPtr((byte*)baseAddr + callMethodOffset);
        returnMethodTokenInfo.ptrInDll = new IntPtr((byte*)baseAddr + getMethodTokenOffset);
        returnNameOfMethodInfo.ptrInDll = new IntPtr((byte*)baseAddr + getMethodNameOffset);
        byte[] theCode = code.getCode(baseAddr);
        Marshal.Copy(theCode, 0, baseAddr, theCode.Length);
    }

    // Writes push dword ptr [esp+displ]  (FF 74 24 disp8)
    static void writePushDwordPtrEspDispl(NativeCodeGenerator code, sbyte displ) {
        code.writeBytes(0xFF, 0x74);
        code.writeBytes(0x24, (byte)displ);
    }

    /// <summary>Returns the handle of the loaded jitter DLL, trying "mscorjit" then "clrjit".</summary>
    /// <exception cref="ApplicationException">Neither module is loaded.</exception>
    static IntPtr getJitterDllHandle() {
        var hJitterDll = GetModuleHandle("mscorjit");
        if (hJitterDll == IntPtr.Zero)
            hJitterDll = GetModuleHandle("clrjit");
        if (hJitterDll == IntPtr.Zero)
            throw new ApplicationException("Could not get a handle to the jitter DLL");
        return hJitterDll;
    }

    /// <summary>
    /// Managed half of our compileMethod hook. It is called from the native stub with the
    /// jitter's compileMethod arguments.
    /// </summary>
    /// <remarks>
    /// <para>When <paramref name="comp"/> is our fake comp object, the call comes from
    /// decryptMethod(). In that case the IL (and any extra sections) described by
    /// <paramref name="info"/> is copied into ctx.dm and the call is marked handled.</para>
    /// <para>Otherwise, if the method belongs to the module being decrypted and is its module
    /// .cctor, and a non-empty replacement body was supplied, the hook is removed. The
    /// original compileMethod is then called with the replacement IL.</para>
    /// <para>In every other case <paramref name="handled"/> is false, so the native stub
    /// calls the original compileMethod itself.</para>
    /// </remarks>
    unsafe int compileMethod(IntPtr jitter, IntPtr comp, IntPtr info, uint flags, IntPtr nativeEntry, IntPtr nativeSizeOfCode, out bool handled) {
        if (ourCompMem != IntPtr.Zero && comp == ourCompMem) {
            // We're decrypting methods
            var info2 = (CORINFO_METHOD_INFO*)info;
            ctx.dm.code = new byte[info2->ILCodeSize];
            Marshal.Copy(info2->ILCode, ctx.dm.code, 0, ctx.dm.code.Length);
            ctx.dm.mhMaxStack = info2->maxStack;
            ctx.dm.mhCodeSize = info2->ILCodeSize;
            // 8 = "more sections" flag in the method header.
            // NOTE(review): assumes the extra sections directly follow the IL in memory,
            // as in a normal method body — confirm for decrypted bodies.
            if ((ctx.dm.mhFlags & 8) != 0)
                ctx.dm.extraSections = readExtraSections((byte*)info2->ILCode + info2->ILCodeSize);
            updateFromMethodDefTableRow();

            handled = true;
            return 0;
        }
        else {
            // We're not decrypting methods

            var info2 = (CORINFO_METHOD_INFO*)info;
            if (info2->scope != moduleToDecryptScope) {
                handled = false;
                return 0;
            }

            uint codeRva = (uint)((byte*)info2->ILCode - (byte*)hInstModule);
            // An empty replacement would make &moduleCctorBytes[0] throw
            // IndexOutOfRangeException inside this native callback. Treat it like "no
            // replacement" and let the original compileMethod handle the method.
            if (decryptMethodsInfo.moduleCctorBytes != null &&
                decryptMethodsInfo.moduleCctorBytes.Length != 0 &&
                moduleCctorCodeRva != 0 &&
                moduleCctorCodeRva == codeRva) {
                fixed (byte* newIlCodeBytes = &decryptMethodsInfo.moduleCctorBytes[0]) {
                    // Unhook first, then compile the replacement .cctor with the original jitter.
                    writeCompileMethod(origCompileMethod);
                    info2->ILCode = new IntPtr(newIlCodeBytes);
                    info2->ILCodeSize = (uint)decryptMethodsInfo.moduleCctorBytes.Length;
                    handled = true;
                    return callMethodDelegate(origCompileMethod, jitter, comp, info, flags, nativeEntry, nativeSizeOfCode);
                }
            }
        }

        handled = false;
        return 0;
    }

    /// <summary>Rounds <paramref name="p"/> up to a multiple of <paramref name="alignment"/> (expected to be a power of two).</summary>
    unsafe static byte* align(byte* p, int alignment) {
        return (byte*)new IntPtr((long)((ulong)(p + alignment - 1) & ~(ulong)(alignment - 1)));
    }

    /// <summary>
    /// Copies the method's extra data sections, which start at the first 4-byte boundary at
    /// or after <paramref name="p"/>.
    /// </summary>
    unsafe static byte[] readExtraSections(byte* p) {
        p = align(p, 4);
        byte* startPos = p;
        p = parseSection(p);
        int size = (int)(p - startPos);
        var sections = new byte[size];
        Marshal.Copy(new IntPtr(startPos), sections, 0, sections.Length);
        return sections;
    }

    /// <summary>
    /// Walks a chain of exception-handling data sections and returns a pointer just past the
    /// last one. Each section is 4-byte aligned; the chain continues while flag 0x80 is set.
    /// </summary>
    /// <exception cref="ApplicationException">A section is not an EH table or has unexpected kind bits.</exception>
    unsafe static byte* parseSection(byte* p) {
        byte flags;
        do {
            p = align(p, 4);

            flags = *p++;
            if ((flags & 1) == 0)
                throw new ApplicationException("Not an exception section");
            if ((flags & 0x3E) != 0)
                throw new ApplicationException("Invalid bits set");

            if ((flags & 0x40) != 0) {
                // Fat: 24-bit DataSize follows the flags byte; clauses are 24 bytes each.
                p--;
                int num = (int)(*(uint*)p >> 8) / 24;
                p += 4 + num * 24;
            }
            else {
                // Small: 1-byte DataSize, 2 reserved bytes; clauses are 12 bytes each.
                int num = *p++ / 12;
                p += 2 + num * 12;
            }
        } while ((flags & 0x80) != 0);
        return p;
    }

    /// <summary>
    /// Fills ctx.dm's md* fields from the in-memory MethodDef row for ctx.dm.token.
    /// The columns are read in order: RVA, ImplFlags, Flags, Name, Signature, ParamList.
    /// </summary>
    unsafe void updateFromMethodDefTableRow() {
        uint methodIndex = ctx.dm.token - 0x06000001;
        byte* row = (byte*)methodDefTablePtr + methodIndex * methodDefTable.RowSize;
        ctx.dm.mdRVA = read(row, methodDefTable.Columns[0]);
        ctx.dm.mdImplFlags = (ushort)read(row, methodDefTable.Columns[1]);
        ctx.dm.mdFlags = (ushort)read(row, methodDefTable.Columns[2]);
        ctx.dm.mdName = read(row, methodDefTable.Columns[3]);
        ctx.dm.mdSignature = read(row, methodDefTable.Columns[4]);
        ctx.dm.mdParamList = read(row, methodDefTable.Columns[5]);
    }

    /// <summary>Reads a 1-, 2- or 4-byte column value from a table row.</summary>
    /// <exception cref="ApplicationException">The column has any other size.</exception>
    static unsafe uint read(byte* row, ColumnInfo colInfo) {
        switch (colInfo.Size) {
        case 1: return *(row + colInfo.Offset);
        case 2: return *(ushort*)(row + colInfo.Offset);
        case 4: return *(uint*)(row + colInfo.Offset);
        default: throw new ApplicationException(string.Format("Unknown size: {0}", colInfo.Size));
        }
    }

    // Callback reached through the getMethodName stub (fake comp slot mem[7]).
    string returnNameOfMethod() {
        return ctx.method.Name.String;
    }

    // Callback reached through the getMethodToken stub (fake comp slots mem[13]/mem[14]).
    int returnMethodToken() {
        return ctx.method.MDToken.ToInt32();
    }

    /// <summary>
    /// Decrypts either every MethodDef row or only decryptMethodsInfo.methodsToDecrypt
    /// when that list is non-null.
    /// </summary>
    /// <exception cref="ApplicationException">compileMethod is not hooked by third-party code.</exception>
    public DumpedMethods decryptMethods() {
        if (!canDecryptMethods())
            throw new ApplicationException("Can't decrypt methods since compileMethod() isn't hooked yet");
        installCompileMethod2();

        var dumpedMethods = new DumpedMethods();

        if (decryptMethodsInfo.methodsToDecrypt == null) {
            for (uint rid = 1; rid <= methodDefTable.Rows; rid++)
                dumpedMethods.add(decryptMethod(0x06000000 + rid));
        }
        else {
            foreach (var token in decryptMethodsInfo.methodsToDecrypt)
                dumpedMethods.add(decryptMethod(token));
        }

        return dumpedMethods;
    }

    /// <summary>
    /// Decrypts a single method by calling the currently installed compileMethod with our
    /// fake comp object. Our hook then records the IL it is given.
    /// </summary>
    /// <remarks>
    /// Methods with RVA 0 get empty code and are not sent to the jitter. The header fields
    /// are taken from the (possibly encrypted) on-disk header.
    /// </remarks>
    unsafe DumpedMethod decryptMethod(uint token) {
        if (!canDecryptMethods())
            throw new ApplicationException("Can't decrypt methods since compileMethod() isn't hooked yet");

        ctx = new DecryptContext();
        ctx.dm = new DumpedMethod();
        ctx.dm.token = token;

        ctx.method = dot10Module.ResolveMethod(MDToken.ToRID(token));
        if (ctx.method == null)
            throw new ApplicationException(string.Format("Could not find method {0:X8}", token));

        byte* mh = (byte*)hInstModule + (uint)ctx.method.RVA;
        byte* code;
        if (mh == (byte*)hInstModule) {
            // RVA 0: no body.
            ctx.dm.mhMaxStack = 0;
            ctx.dm.mhCodeSize = 0;
            ctx.dm.mhFlags = 0;
            ctx.dm.mhLocalVarSigTok = 0;
            code = null;
        }
        else if ((*mh & 3) == 2) {
            // Tiny header: code size in the upper 6 bits, max stack is implicitly 8.
            uint headerSize = 1;
            ctx.dm.mhMaxStack = 8;
            ctx.dm.mhCodeSize = (uint)(*mh >> 2);
            ctx.dm.mhFlags = 2;
            ctx.dm.mhLocalVarSigTok = 0;
            code = mh + headerSize;
        }
        else {
            // Fat header: size (in dwords) in the high nibble of byte 1.
            uint headerSize = (uint)((mh[1] >> 4) * 4);
            ctx.dm.mhMaxStack = *(ushort*)(mh + 2);
            ctx.dm.mhCodeSize = *(uint*)(mh + 4);
            ctx.dm.mhFlags = *(ushort*)mh;
            ctx.dm.mhLocalVarSigTok = *(uint*)(mh + 8);
            code = mh + headerSize;
        }

        CORINFO_METHOD_INFO info = default(CORINFO_METHOD_INFO);
        info.ILCode = new IntPtr(code);
        info.ILCodeSize = ctx.dm.mhCodeSize;
        info.maxStack = ctx.dm.mhMaxStack;
        info.scope = moduleToDecryptScope;

        initializeOurComp();
        if (code == null) {
            ctx.dm.code = new byte[0];
            updateFromMethodDefTableRow();
        }
        else
            // The last two arguments are placeholders; our compileMethod does not use them
            // on the decrypt path.
            callMethodDelegate(*(IntPtr*)jitterVtbl, jitterInstance, ourCompMem, new IntPtr(&info), 0, new IntPtr(0x12345678), new IntPtr(0x3ABCDEF0));

        var dm = ctx.dm;
        ctx = null;
        return dm;
    }

    /// <summary>
    /// Allocates (once) and re-initialises the fake "comp" object passed to compileMethod
    /// while decrypting.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the slot indices and values mimic whatever the third-party hook reads
    /// from its comp argument. They cannot be derived from this file; do not change them
    /// without checking against the targeted obfuscator/runtime.
    /// </remarks>
    unsafe void initializeOurComp() {
        const int numIndexes = 15;
        // Marshal.AllocHGlobal throws OutOfMemoryException on failure; it never returns zero.
        if (ourCompMem == IntPtr.Zero)
            ourCompMem = Marshal.AllocHGlobal(numIndexes * IntPtr.Size);

        IntPtr* mem = (IntPtr*)ourCompMem;
        for (int i = 0; i < numIndexes; i++)
            mem[i] = IntPtr.Zero;

        mem[1] = new IntPtr(mem + 2);
        mem[3] = new IntPtr(IntPtr.Size * 5);
        mem[5] = new IntPtr(IntPtr.Size * 7);
        mem[6] = new IntPtr(mem + 7);
        mem[7] = returnNameOfMethodInfo.ptrInDll;
        mem[8] = new IntPtr(mem);
        mem[13] = returnMethodTokenInfo.ptrInDll;   // .NET 2.0
        mem[14] = returnMethodTokenInfo.ptrInDll;   // .NET 4.0
    }

    bool hasInstalledCompileMethod2 = false;

    /// <summary>
    /// Makes the third-party hook call our compileMethod instead of the original one. It
    /// does this by patching the saved origCompileMethod pointer found near that hook (done
    /// at most once).
    /// </summary>
    /// <exception cref="ApplicationException">No such pointer was found in the scanned range.</exception>
    unsafe void installCompileMethod2() {
        if (hasInstalledCompileMethod2)
            return;

        if (!patchDword(*(IntPtr*)jitterVtbl, 0x30000, origCompileMethod, ourCompileMethodInfo.ptrInDll))
            throw new ApplicationException("Couldn't patch compileMethod");

        hasInstalledCompileMethod2 = true;
    }

    /// <summary>
    /// Scans pointer-aligned slots starting at the 4 KiB page containing
    /// <paramref name="addr"/> and continuing for <paramref name="size"/> bytes. The first
    /// slot equal to <paramref name="origValue"/> is replaced with <paramref name="newValue"/>.
    /// </summary>
    /// <returns>True if a slot was patched.</returns>
    unsafe bool patchDword(IntPtr addr, int size, IntPtr origValue, IntPtr newValue) {
        addr = new IntPtr(addr.ToInt64() & ~0xFFF);
        var endAddr = new IntPtr(addr.ToInt64() + size);
        for (; addr.ToPointer() < endAddr.ToPointer(); addr = new IntPtr(addr.ToInt64() + 0x1000)) {
            try {
                for (int i = 0; i < 0x1000; i += IntPtr.Size) {
                    var addr2 = (IntPtr*)((byte*)addr + i);
                    if (*addr2 == origValue) {
                        *addr2 = newValue;
                        return true;
                    }
                }
            }
            catch {
                // Deliberately best-effort: skip pages that fault and keep scanning.
                // NOTE(review): on .NET 4+ access violations are corrupted-state exceptions
                // and may not reach this handler — confirm on the targeted runtime.
            }
        }
        return false;
    }
}
}