/*
Copyright (C) 2011-2012 de4dot@gmail.com
This file is part of de4dot.
de4dot is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
de4dot is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with de4dot. If not, see <http://www.gnu.org/licenses/>.
*/
using System;
using System.Collections.Generic;
using dot10.DotNet;
using dot10.DotNet.Emit;
using dot10.DotNet.MD;
namespace de4dot.blocks {
// Enumerates .NET framework flavors. Member order determines the underlying values.
// NOTE(review): presumably describes the framework an analyzed module targets; the
// code that consumes this enum is not visible here -- confirm.
public enum FrameworkType {
Unknown,
Desktop,
Silverlight, // and WindowsPhone, XNA Xbox360
CompactFramework,
XNA,
Zune,
}
#if PORT
// Lookup table for one module that resolves type references against the module's own types.
class TypeCache {
ModuleDefinition module;
// Every type of the module is stored as both key and value (see init())
de4dot.blocks.OLD_REMOVE.TypeDefinitionDict typeRefToDef = new de4dot.blocks.OLD_REMOVE.TypeDefinitionDict();
// Remembers the module and immediately indexes all of its types
public TypeCache(ModuleDefinition module) {
this.module = module;
init();
}
// Adds every type returned by module.GetTypes() to the dictionary
void init() {
foreach (var type in module.GetTypes())
typeRefToDef.add(type, type);
}
// Returns whatever the dictionary finds for typeReference.
// NOTE(review): presumably null when nothing matches; find() is not visible here -- confirm
public TypeDefinition lookup(TypeReference typeReference) {
return typeRefToDef.find(typeReference);
}
}
#endif
#if PORT
// Holds one lazily created TypeCache per module.
public class TypeCaches {
// Keyed by module; entries are created on first lookup() and dropped by invalidate()/invalidateAll()
Dictionary typeCaches = new Dictionary();
// Should be called when the whole module is reloaded or when a lot of types have been
// modified (eg. renamed)
// A null module is ignored.
public void invalidate(ModuleDefinition module) {
if (module == null)
return;
typeCaches.Remove(module);
}
// Call this to invalidate all modules
// Returns the modules that had a cache before it was cleared.
public List invalidateAll() {
var list = new List(typeCaches.Keys);
typeCaches.Clear();
return list;
}
// Resolves typeReference through the module's cache, building the cache first if the
// module has none yet.
public TypeDefinition lookup(ModuleDefinition module, TypeReference typeReference) {
TypeCache typeCache;
if (!typeCaches.TryGetValue(module, out typeCache))
typeCaches[module] = typeCache = new TypeCache(module);
return typeCache.lookup(typeReference);
}
}
#endif
#if PORT
// Counts how many times each method has been added and reports the most frequent one.
public class CallCounter {
// Call count per method, keyed by a MethodReferenceAndDeclaringTypeKey wrapper
Dictionary calls = new Dictionary();
// Increments the count of calledMethod; a method not seen before starts at 0
// because TryGetValue leaves count at its default.
public void add(MethodReference calledMethod) {
int count;
var key = new de4dot.blocks.OLD_REMOVE.MethodReferenceAndDeclaringTypeKey(calledMethod);
calls.TryGetValue(key, out count);
calls[key] = count + 1;
}
// Same as most(out numCalls) but discards the count
public MethodReference most() {
int numCalls;
return most(out numCalls);
}
// Returns the method with the highest count and stores that count in numCalls.
// Returns null and sets numCalls to 0 when nothing has been added. On a tie, the
// first key enumerated with that count wins because the comparison is strict.
public MethodReference most(out int numCalls) {
MethodReference method = null;
int callCount = 0;
foreach (var key in calls.Keys) {
if (calls[key] > callCount) {
callCount = calls[key];
method = key.MethodReference;
}
}
numCalls = callCount;
return method;
}
}
#endif
#if PORT
// Counts method references by their full name.
public class MethodCalls {
// Count per method full name; keys are compared ordinally
Dictionary methodCalls = new Dictionary(StringComparer.Ordinal);
// Adds every instruction operand of the method body that is a MethodReference.
// Methods without a body are ignored.
public void addMethodCalls(MethodDef method) {
if (!method.HasBody)
return;
foreach (var instr in method.Body.Instructions) {
var calledMethod = instr.Operand as MethodReference;
if (calledMethod != null)
add(calledMethod);
}
}
// Increments the count stored under method.FullName, creating the entry if needed
public void add(MethodReference method) {
string key = method.FullName;
if (!methodCalls.ContainsKey(key))
methodCalls[key] = 0;
methodCalls[key]++;
}
// Returns the recorded count for methodFullName, or 0 if it was never added
public int count(string methodFullName) {
int count;
methodCalls.TryGetValue(methodFullName, out count);
return count;
}
// Returns true if methodFullName has been added at least once
public bool called(string methodFullName) {
return count(methodFullName) != 0;
}
}
#endif
public static class DotNetUtils {
#if PORT
public static readonly TypeCaches typeCaches = new TypeCaches();
#endif
// Returns the module's global type (module.GlobalType).
public static TypeDef getModuleType(ModuleDef module) {
    var globalType = module.GlobalType;
    return globalType;
}
// Returns the result of FindClassConstructor() on the module's global type.
public static MethodDef getModuleTypeCctor(ModuleDef module) {
    var globalType = module.GlobalType;
    return globalType.FindClassConstructor();
}
// Returns true if the method has a body made up only of nop and ret instructions
// (an empty instruction list also counts). Returns false if there is no body.
public static bool isEmpty(MethodDef method) {
    var body = method.Body;
    if (body == null)
        return false;
    foreach (var instr in body.Instructions) {
        switch (instr.OpCode.Code) {
        case Code.Nop:
        case Code.Ret:
            // Allowed in an "empty" method; keep scanning
            break;
        default:
            return false;
        }
    }
    return true;
}
// Returns true if the instruction that getInstruction() yields, starting at index 0,
// is a ret. Returns false if the method has no body or no instruction is returned.
// NOTE(review): the getInstruction(instructions, ref index) overload is not visible
// here; presumably it skips over obfuscator filler code -- confirm.
public static bool isEmptyObfuscated(MethodDef method) {
    if (method.Body == null)
        return false;
    int position = 0;
    var first = getInstruction(method.Body.Instructions, ref position);
    return first != null && first.OpCode.Code == Code.Ret;
}
#if PORT
// Returns the first field whose type full name equals typeName and whose IsStatic
// flag equals isStatic, or null if typeDefinition is null or no field matches.
public static FieldDefinition findFieldType(TypeDefinition typeDefinition, string typeName, bool isStatic) {
    if (typeDefinition != null) {
        foreach (var candidate in typeDefinition.Fields) {
            if (candidate.FieldType.FullName != typeName)
                continue;
            if (candidate.IsStatic == isStatic)
                return candidate;
        }
    }
    return null;
}
#endif
// Convenience overload: same as the four-argument findMethods() with isStatic = true,
// i.e. only static methods are returned.
public static IEnumerable findMethods(IEnumerable methods, string returnType, string[] argsTypes) {
return findMethods(methods, returnType, argsTypes, true);
}
// Lazily yields every method in 'methods' that has a body and whose signature matches:
//  - the signature exists and IsDefault is true
//  - IsStatic equals isStatic
//  - it has no generic parameters
//  - the return type's full name equals returnType
//  - the parameter types' full names equal argsTypes, in order
// Methods are yielded in the order they are enumerated.
public static IEnumerable findMethods(IEnumerable methods, string returnType, string[] argsTypes, bool isStatic) {
foreach (var method in methods) {
var sig = method.MethodSig;
if (sig == null || !method.HasBody || !sig.IsDefault)
continue;
if (method.IsStatic != isStatic || sig.Params.Count != argsTypes.Length)
continue;
if (sig.GenParamCount > 0)
continue;
if (sig.RetType.GetFullName() != returnType)
continue;
for (int i = 0; i < argsTypes.Length; i++) {
if (sig.Params[i].GetFullName() != argsTypes[i])
goto next; // a parameter type differs: skip this method
}
yield return method;
next: ;
}
}
// Returns true if the type's full name is System.Delegate or System.MulticastDelegate.
// A null type yields false.
public static bool isDelegate(IType type) {
    if (type == null)
        return false;
    var fullName = type.FullName;
    if (fullName == "System.Delegate")
        return true;
    return fullName == "System.MulticastDelegate";
}
// Returns true if the type's direct base type is one of the delegate base types
// recognized by isDelegate(). A null type yields false.
public static bool derivesFromDelegate(TypeDef type) {
    if (type == null)
        return false;
    return isDelegate(type.BaseType);
}
#if PORT
// Compares the canonicalized scope name of the type with the lower-cased
// (invariant culture) assembly name and returns true if they are equal.
public static bool isSameAssembly(TypeReference type, string assembly) {
    var scopeName = MemberReferenceHelper.getCanonicalizedScopeName(type.Scope);
    var wantedName = assembly.ToLowerInvariant();
    return scopeName == wantedName;
}
// Returns true if the method's full name equals
// "<returnType> <declaring type full name>::<method name><parameters>".
// A null method yields false.
public static bool isMethod(MethodReference method, string returnType, string parameters) {
    if (method == null)
        return false;
    var actual = method.FullName;
    var expected = returnType + " " + method.DeclaringType.FullName + "::" + method.Name + parameters;
    return actual == expected;
}
#endif
// Returns true if the method's full name equals
// "<returnType> <declaring type full name>::<method name><parameters>".
// A null method yields false.
public static bool isMethod(IMethod method, string returnType, string parameters) {
    if (method == null)
        return false;
    var actual = method.FullName;
    var expected = returnType + " " + method.DeclaringType.FullName + "::" + method.Name + parameters;
    return actual == expected;
}
// Strips a trailing ".dll" (matched case-insensitively) from the given name.
// Names without that extension are returned unchanged.
public static string getDllName(string dll) {
    const string extension = ".dll";
    bool hasExtension = dll.EndsWith(extension, StringComparison.OrdinalIgnoreCase);
    return hasExtension ? dll.Substring(0, dll.Length - extension.Length) : dll;
}
// Returns true if getPInvokeMethod() finds a method for the given name in the type.
public static bool hasPinvokeMethod(TypeDef type, string methodName) {
    var pinvokeMethod = getPInvokeMethod(type, methodName);
    return pinvokeMethod != null;
}
// Returns the first method of the type that has an ImplMap whose Name equals
// methodName, or null if the type is null or no such method exists.
public static MethodDef getPInvokeMethod(TypeDef type, string methodName) {
    if (type != null) {
        var wantedName = new UTF8String(methodName);
        foreach (var candidate in type.Methods) {
            var implMap = candidate.ImplMap;
            if (implMap != null && UTF8String.Equals(implMap.Name, wantedName))
                return candidate;
        }
    }
    return null;
}
#if PORT
// Returns the first method of the type for which isPinvokeMethod(method, dll, funcName)
// is true, or null if there is none.
public static MethodDef getPInvokeMethod(TypeDefinition type, string dll, string funcName) {
    MethodDef found = null;
    foreach (var candidate in type.Methods) {
        if (!isPinvokeMethod(candidate, dll, funcName))
            continue;
        found = candidate;
        break;
    }
    return found;
}
// Returns true if the method has P/Invoke info whose entry point equals funcName and
// whose module name matches dll. The two DLL names are compared case-insensitively
// after getDllName() has removed a trailing ".dll". A null method yields false.
public static bool isPinvokeMethod(MethodDef method, string dll, string funcName) {
    if (method == null || method.PInvokeInfo == null)
        return false;
    if (method.PInvokeInfo.EntryPoint != funcName)
        return false;
    var wantedDll = getDllName(dll);
    var actualDll = getDllName(method.PInvokeInfo.Module.Name);
    return wantedDll.Equals(actualDll, StringComparison.OrdinalIgnoreCase);
}
// Returns the first method of the type whose Name equals name, or null if the type
// is null or no method matches.
public static MethodDef getMethod(TypeDefinition type, string name) {
    if (type != null) {
        foreach (var candidate in type.Methods) {
            if (candidate.Name == name)
                return candidate;
        }
    }
    return null;
}
// Resolves methodReference within the given type. A reference that already is a
// MethodDef is returned as is; otherwise the first method of the type that
// compareMethodReference() accepts is returned. Returns null if either argument is
// null or nothing matches.
public static MethodDef getMethod(TypeDefinition type, MethodReference methodReference) {
    if (type == null || methodReference == null)
        return null;
    var alreadyDef = methodReference as MethodDef;
    if (alreadyDef != null)
        return alreadyDef;
    foreach (var candidate in type.Methods) {
        if (MemberReferenceHelper.compareMethodReference(candidate, methodReference))
            return candidate;
    }
    return null;
}
// Resolves the method in the module using the method's own declaring type.
// A null method yields null.
public static MethodDef getMethod(ModuleDefinition module, MethodReference method) {
    return method == null ? null : getMethod(module, method, method.DeclaringType);
}
// Like getMethod(module, method) but looks the method up in the element type of its
// declaring type (DeclaringType.GetElementType()). A null method yields null.
public static MethodDef getMethod2(ModuleDefinition module, MethodReference method) {
    return method == null ? null : getMethod(module, method, method.DeclaringType.GetElementType());
}
// Resolves the method in the module. A method that already is a MethodDef is returned
// as is; otherwise declaringType is resolved with getType() and the method is looked
// up in the resulting type. A null method yields null.
static MethodDef getMethod(ModuleDefinition module, MethodReference method, TypeReference declaringType) {
    if (method == null)
        return null;
    var alreadyDef = method as MethodDef;
    if (alreadyDef != null)
        return alreadyDef;
    var ownerType = getType(module, declaringType);
    return getMethod(ownerType, method);
}
#endif
// Returns the first method of the type for which isMethod(method, returnType, parameters)
// is true, or null if there is none.
public static MethodDef getMethod(TypeDef type, string returnType, string parameters) {
    MethodDef match = null;
    foreach (var candidate in type.Methods) {
        if (!isMethod(candidate, returnType, parameters))
            continue;
        match = candidate;
        break;
    }
    return match;
}
// Resolves the method in the module, using the ScopeType of the method's declaring
// type as the type to search. A null method yields null.
public static MethodDef getMethod2(ModuleDef module, IMethod method) {
    return method == null ? null : getMethod(module, method, method.DeclaringType.ScopeType);
}
// Resolves a type signature to a TypeDef of the module. Pinned/modifier wrappers are
// removed first; if what remains is not a TypeDefOrRefSig the result is null,
// otherwise the wrapped TypeDefOrRef is resolved with the other getType() overload.
public static TypeDef getType(ModuleDef module, TypeSig type) {
    var defOrRefSig = type.RemovePinnedAndModifiers() as TypeDefOrRefSig;
    return defOrRefSig == null ? null : getType(module, defOrRefSig.TypeDefOrRef);
}
// Returns the TypeDef for 'type' if it is owned by 'module', else null.
// A TypeDef is used directly. A TypeRef is resolved only when its definition assembly
// and the module's assembly are both non-null and have equal names.
public static TypeDef getType(ModuleDef module, ITypeDefOrRef type) {
    var typeDef = type as TypeDef;
    if (typeDef == null) {
        var typeRef = type as TypeRef;
        if (typeRef != null) {
            var refAsm = typeRef.DefinitionAssembly;
            var modAsm = module.Assembly;
            bool sameAssembly = refAsm != null && modAsm != null && refAsm.Name == modAsm.Name;
            if (sameAssembly)
                typeDef = typeRef.Resolve();
        }
    }
    if (typeDef == null)
        return null;
    // Only types owned by this exact module are accepted
    return typeDef.OwnerModule == module ? typeDef : null;
}
// Resolves the method in the module. A method that already is a MethodDef is returned
// as is; otherwise declaringType is resolved with getType() and the method is looked
// up in the resulting type. A null method yields null.
static MethodDef getMethod(ModuleDef module, IMethod method, ITypeDefOrRef declaringType) {
    if (method == null)
        return null;
    var alreadyDef = method as MethodDef;
    if (alreadyDef != null)
        return alreadyDef;
    return getMethod(getType(module, declaringType), method);
}
// Resolves methodRef within the given type. A reference that already is a MethodDef is
// returned as is; otherwise type.FindMethod() is called with the reference's name and
// signature. Returns null if either argument is null.
public static MethodDef getMethod(TypeDef type, IMethod methodRef) {
    if (type == null || methodRef == null)
        return null;
    var alreadyDef = methodRef as MethodDef;
    return alreadyDef ?? type.FindMethod(methodRef.Name, methodRef.MethodSig);
}
#if PORT
// Lazily yields the methods of the type that have no P/Invoke info and are not named
// ".ctor" or ".cctor".
public static IEnumerable getNormalMethods(TypeDefinition type) {
foreach (var method in type.Methods) {
if (method.HasPInvokeInfo)
continue;
if (method.Name == ".ctor" || method.Name == ".cctor")
continue;
yield return method;
}
}
// Resolves a type reference to a type definition. A reference that already is a
// TypeDefinition is returned as is; otherwise it is looked up through the shared
// typeCaches for the module. A null reference yields null.
public static TypeDefinition getType(ModuleDefinition module, TypeReference typeReference) {
    if (typeReference == null)
        return null;
    var alreadyDef = typeReference as TypeDefinition;
    if (alreadyDef != null)
        return alreadyDef;
    return typeCaches.lookup(module, typeReference);
}
// Resolves a field reference in the module. A reference that already is a
// FieldDefinition is returned as is; otherwise the declaring type is resolved with
// getType() and the field is looked up in it. A null field yields null.
public static FieldDefinition getField(ModuleDefinition module, FieldReference field) {
    if (field == null)
        return null;
    var alreadyDef = field as FieldDefinition;
    if (alreadyDef != null)
        return alreadyDef;
    var ownerType = getType(module, field.DeclaringType);
    return getField(ownerType, field);
}
// Resolves fieldReference within the given type. A reference that already is a
// FieldDefinition is returned as is; otherwise the first field of the type that
// compareFieldReference() accepts is returned. Returns null if either argument is
// null or nothing matches.
public static FieldDefinition getField(TypeDefinition type, FieldReference fieldReference) {
    if (type == null || fieldReference == null)
        return null;
    var alreadyDef = fieldReference as FieldDefinition;
    if (alreadyDef != null)
        return alreadyDef;
    foreach (var candidate in type.Fields) {
        if (MemberReferenceHelper.compareFieldReference(candidate, fieldReference))
            return candidate;
    }
    return null;
}
// Returns the first field of the type whose field type has the given full name, or
// null if the type is null or no field matches.
public static FieldDefinition getField(TypeDefinition type, string typeFullName) {
    if (type != null) {
        foreach (var candidate in type.Fields) {
            if (candidate.FieldType.FullName == typeFullName)
                return candidate;
        }
    }
    return null;
}
// Returns the first field of the type whose Name equals name, or null if the type is
// null or no field matches.
public static FieldDefinition getFieldByName(TypeDefinition type, string name) {
    if (type != null) {
        foreach (var candidate in type.Fields) {
            if (candidate.Name == name)
                return candidate;
        }
    }
    return null;
}
// Returns, in instruction order, every instruction operand of the method body that is
// a MethodReference. The result is an empty list when the method has no body.
public static IEnumerable getMethodCalls(MethodDef method) {
var list = new List();
if (method.HasBody) {
foreach (var instr in method.Body.Instructions) {
var calledMethod = instr.Operand as MethodReference;
if (calledMethod != null)
list.Add(calledMethod);
}
}
return list;
}
// Creates a new MethodCalls instance, feeds it the given method through
// addMethodCalls() and returns it.
public static MethodCalls getMethodCallCounts(MethodDef method) {
    var counts = new MethodCalls();
    counts.addMethodCalls(method);
    return counts;
}
// Returns true if the method body contains an ldstr instruction whose operand equals s.
// Returns false if the method is null or has no body.
public static bool hasString(MethodDef method, string s) {
    if (method != null && method.Body != null) {
        foreach (var instr in method.Body.Instructions) {
            if (instr.OpCode.Code != Code.Ldstr)
                continue;
            if ((string)instr.Operand == s)
                return true;
        }
    }
    return false;
}
// Returns the operands of all ldstr instructions in the method body, in instruction
// order (duplicates are kept). The list is empty if the method is null or has no body.
public static IList getCodeStrings(MethodDef method) {
var strings = new List();
if (method != null && method.Body != null) {
foreach (var instr in method.Body.Instructions) {
if (instr.OpCode.Code == Code.Ldstr)
strings.Add((string)instr.Operand);
}
}
return strings;
}
#endif
// Returns the operands of all ldstr instructions in the method body, in instruction
// order (duplicates are kept). The list is empty if the method is null or has no body.
public static IList getCodeStrings(MethodDef method) {
var strings = new List();
if (method != null && method.Body != null) {
foreach (var instr in method.Body.Instructions) {
if (instr.OpCode.Code == Code.Ldstr)
strings.Add((string)instr.Operand);
}
}
return strings;
}
// Convenience overload: looks up a single resource name by wrapping it in a
// one-element list and calling the other getResource() overload.
public static Resource getResource(ModuleDef module, string name) {
return getResource(module, new List { name });
}
// Returns the first module resource whose Name equals one of the candidate strings.
// Candidates are tried in the order given; each one is first cut off at its first
// NUL character (see removeFromNullChar()). Returns null if the module has no
// resources or nothing matches.
public static Resource getResource(ModuleDef module, IEnumerable strings) {
if (!module.HasResources)
return null;
var resources = module.Resources;
foreach (var tmp in strings) {
var resourceName = removeFromNullChar(tmp);
if (resourceName == null)
continue;
var name = new UTF8String(resourceName);
foreach (var resource in resources) {
if (UTF8String.Equals(resource.Name, name))
return resource;
}
}
return null;
}
// Returns the part of s that precedes its first NUL character, or s itself if it
// contains none. A null input yields null instead of throwing, so that callers which
// test the result for null (as getResource() does) can skip such entries.
static string removeFromNullChar(string s) {
    if (s == null)
        return null;
    int index = s.IndexOf('\0');
    return index < 0 ? s : s.Substring(0, index);
}
#if PORT
// Creates a new method from the given one. Copies most things but not everything:
// name, attributes, return type, metadata token, calling-convention related flags,
// declaring type, parameters (name, attributes, type), generic parameters (name and
// attributes only) and, through copyBodyFromTo(), the body.
public static MethodDef clone(MethodDef method) {
    var copy = new MethodDef(method.Name, method.Attributes, method.MethodReturnType.ReturnType) {
        MetadataToken = method.MetadataToken,
        Attributes = method.Attributes,
        ImplAttributes = method.ImplAttributes,
        HasThis = method.HasThis,
        ExplicitThis = method.ExplicitThis,
        CallingConvention = method.CallingConvention,
        SemanticsAttributes = method.SemanticsAttributes,
        DeclaringType = method.DeclaringType,
    };
    foreach (var param in method.Parameters) {
        var paramCopy = new ParameterDefinition(param.Name, param.Attributes, param.ParameterType);
        copy.Parameters.Add(paramCopy);
    }
    foreach (var genericParam in method.GenericParameters) {
        var genericParamCopy = new GenericParameter(genericParam.Name, copy);
        genericParamCopy.Attributes = genericParam.Attributes;
        copy.GenericParameters.Add(genericParamCopy);
    }
    copyBodyFromTo(method, copy);
    return copy;
}
#endif
// Creates and returns a new method based on the given one.
// Copies most things but not everything: the name, impl flags, flags, rid and
// declaring type are copied; the MethodSig instance is passed to the new method as is
// (not duplicated); parameter definitions and generic parameters, including their
// constraints, are recreated; the body is copied with copyBodyFromTo().
// NOTE(review): restoreBody() does nothing when the destination has no body and
// copyLocalsFromTo() dereferences the destination body. Whether a fresh MethodDefUser
// already has a body is not visible here -- confirm.
public static MethodDef clone(MethodDef method) {
    var newMethod = new MethodDefUser(method.Name, method.MethodSig, method.ImplFlags, method.Flags);
    newMethod.Rid = method.Rid;
    newMethod.DeclaringType2 = method.DeclaringType;
    foreach (var pd in method.ParamList)
        newMethod.ParamList.Add(new ParamDefUser(pd.Name, pd.Sequence, pd.Flags));
    foreach (var gp in method.GenericParams) {
        var newGp = new GenericParamUser(gp.Number, gp.Flags, gp.Name);
        // Read the constraints from the source parameter; the new one starts out empty
        foreach (var gpc in gp.GenericParamConstraints)
            newGp.GenericParamConstraints.Add(new GenericParamConstraintUser(gpc.Constraint));
        newMethod.GenericParams.Add(newGp);
    }
    copyBodyFromTo(method, newMethod);
    // Return the copy, not the method that was passed in
    return newMethod;
}
#if PORT
// Returns a new Instruction with the same offset, opcode, operand and sequence point
// as the given one. The operand object itself is shared, not duplicated.
public static Instruction clone(Instruction instr) {
    var copy = new Instruction();
    copy.Offset = instr.Offset;
    copy.OpCode = instr.OpCode;
    copy.Operand = instr.Operand;
    copy.SequencePoint = instr.SequencePoint;
    return copy;
}
// Produces an independent copy of the method's instructions and exception handlers.
// instructions: clones of the body's instructions, in the same order; operands that
//   are an Instruction or an Instruction[] are redirected to the matching clones.
// exceptionHandlers: new handlers whose start/end instructions refer to the clones;
//   HandlerType and CatchType are taken over unchanged.
// Both outputs are empty lists if the method is null or has no body.
public static void copyBody(MethodDef method, out IList instructions, out IList exceptionHandlers) {
if (method == null || !method.HasBody) {
instructions = new List();
exceptionHandlers = new List();
return;
}
var oldInstrs = method.Body.Instructions;
var oldExHandlers = method.Body.ExceptionHandlers;
instructions = new List(oldInstrs.Count);
exceptionHandlers = new List(oldExHandlers.Count);
// Maps each original instruction to its position so the clone at the same position can be found
var oldToIndex = Utils.createObjectToIndexDictionary(oldInstrs);
foreach (var oldInstr in oldInstrs)
instructions.Add(clone(oldInstr));
// The clones still reference the original instructions in their operands: fix them up
foreach (var newInstr in instructions) {
var operand = newInstr.Operand;
if (operand is Instruction)
newInstr.Operand = instructions[oldToIndex[(Instruction)operand]];
else if (operand is Instruction[]) {
var oldArray = (Instruction[])operand;
var newArray = new Instruction[oldArray.Length];
for (int i = 0; i < oldArray.Length; i++)
newArray[i] = instructions[oldToIndex[oldArray[i]]];
newInstr.Operand = newArray;
}
}
foreach (var oldEx in oldExHandlers) {
var newEx = new ExceptionHandler(oldEx.HandlerType) {
TryStart = getInstruction(instructions, oldToIndex, oldEx.TryStart),
TryEnd = getInstruction(instructions, oldToIndex, oldEx.TryEnd),
FilterStart = getInstruction(instructions, oldToIndex, oldEx.FilterStart),
HandlerStart= getInstruction(instructions, oldToIndex, oldEx.HandlerStart),
HandlerEnd = getInstruction(instructions, oldToIndex, oldEx.HandlerEnd),
CatchType = oldEx.CatchType,
};
exceptionHandlers.Add(newEx);
}
}
// Returns the element of 'instructions' at the index that instructionToIndex stores
// for 'instruction', or null if 'instruction' is null.
static Instruction getInstruction(IList instructions, IDictionary instructionToIndex, Instruction instruction) {
if (instruction == null)
return null;
return instructions[instructionToIndex[instruction]];
}
#endif
// Produces an independent copy of the method's instructions and exception handlers.
// instructions: clones (Instruction.Clone()) of the body's instructions, in the same
//   order; an operand that is an Instruction is redirected to the matching clone, and
//   an operand that is a list of instructions is replaced by a new Instruction[] of
//   the matching clones.
// exceptionHandlers: new handlers whose start/end instructions refer to the clones;
//   HandlerType and CatchType are taken over unchanged.
// Both outputs are empty lists if the method is null or has no body.
public static void copyBody(MethodDef method, out IList instructions, out IList exceptionHandlers) {
if (method == null || !method.HasBody) {
instructions = new List();
exceptionHandlers = new List();
return;
}
var oldInstrs = method.Body.Instructions;
var oldExHandlers = method.Body.ExceptionHandlers;
instructions = new List(oldInstrs.Count);
exceptionHandlers = new List(oldExHandlers.Count);
// Maps each original instruction to its position so the clone at the same position can be found
var oldToIndex = Utils.createObjectToIndexDictionary(oldInstrs);
foreach (var oldInstr in oldInstrs)
instructions.Add(oldInstr.Clone());
// The clones still reference the original instructions in their operands: fix them up
foreach (var newInstr in instructions) {
var operand = newInstr.Operand;
if (operand is Instruction)
newInstr.Operand = instructions[oldToIndex[(Instruction)operand]];
else if (operand is IList) {
var oldArray = (IList)operand;
var newArray = new Instruction[oldArray.Count];
for (int i = 0; i < oldArray.Count; i++)
newArray[i] = instructions[oldToIndex[oldArray[i]]];
newInstr.Operand = newArray;
}
}
foreach (var oldEx in oldExHandlers) {
var newEx = new ExceptionHandler(oldEx.HandlerType) {
TryStart = getInstruction(instructions, oldToIndex, oldEx.TryStart),
TryEnd = getInstruction(instructions, oldToIndex, oldEx.TryEnd),
FilterStart = getInstruction(instructions, oldToIndex, oldEx.FilterStart),
HandlerStart = getInstruction(instructions, oldToIndex, oldEx.HandlerStart),
HandlerEnd = getInstruction(instructions, oldToIndex, oldEx.HandlerEnd),
CatchType = oldEx.CatchType,
};
exceptionHandlers.Add(newEx);
}
}
// Returns the element of 'instructions' at the index that instructionToIndex stores
// for 'instruction', or null if 'instruction' is null.
static Instruction getInstruction(IList instructions, IDictionary instructionToIndex, Instruction instruction) {
if (instruction == null)
return null;
return instructions[instructionToIndex[instruction]];
}
#if PORT
// Replaces the contents of the method body's instruction list and exception handler
// list with the given sequences (the existing lists are cleared and refilled in
// order). Does nothing if the method is null or has no body.
public static void restoreBody(MethodDef method, IEnumerable instructions, IEnumerable exceptionHandlers) {
if (method == null || !method.HasBody)
return;
var bodyInstrs = method.Body.Instructions;
bodyInstrs.Clear();
foreach (var instr in instructions)
bodyInstrs.Add(instr);
var bodyExceptionHandlers = method.Body.ExceptionHandlers;
bodyExceptionHandlers.Clear();
foreach (var eh in exceptionHandlers)
bodyExceptionHandlers.Add(eh);
}
#endif
// Replaces the contents of the method body's instruction list and exception handler
// list with the given sequences (the existing lists are cleared and refilled in
// order). Does nothing if the method is null or its Body is null.
public static void restoreBody(MethodDef method, IEnumerable instructions, IEnumerable exceptionHandlers) {
if (method == null || method.Body == null)
return;
var bodyInstrs = method.Body.Instructions;
bodyInstrs.Clear();
foreach (var instr in instructions)
bodyInstrs.Add(instr);
var bodyExceptionHandlers = method.Body.ExceptionHandlers;
bodyExceptionHandlers.Clear();
foreach (var eh in exceptionHandlers)
bodyExceptionHandlers.Add(eh);
}
// Copies the body of fromMethod into toMethod: a copy of the instructions and
// exception handlers is made with copyBody() and stored with restoreBody(), then the
// locals are copied and updateInstructionOperands() is run. Does nothing if both
// arguments are the same method.
// NOTE(review): copyBody()/restoreBody() tolerate a missing body, but
// copyLocalsFromTo() dereferences both bodies, so a method without a body would throw
// there -- confirm callers never pass one.
public static void copyBodyFromTo(MethodDef fromMethod, MethodDef toMethod) {
if (fromMethod == toMethod)
return;
IList instructions;
IList exceptionHandlers;
copyBody(fromMethod, out instructions, out exceptionHandlers);
restoreBody(toMethod, instructions, exceptionHandlers);
copyLocalsFromTo(fromMethod, toMethod);
updateInstructionOperands(fromMethod, toMethod);
}
// Replaces the locals of toMethod's body with newly created Local objects, one per
// local of fromMethod's body, each having the same type as its source. Both methods
// must have a body.
static void copyLocalsFromTo(MethodDef fromMethod, MethodDef toMethod) {
    var sourceBody = fromMethod.Body;
    var targetLocals = toMethod.Body.LocalList;
    targetLocals.Clear();
    foreach (var sourceLocal in sourceBody.LocalList)
        targetLocals.Add(new Local(sourceLocal.Type));
}
static void updateInstructionOperands(MethodDef fromMethod, MethodDef toMethod) {
var fromBody = fromMethod.Body;
var toBody = toMethod.Body;
toBody.InitLocals = fromBody.InitLocals;
toBody.MaxStack = fromBody.MaxStack;
var newOperands = new Dictionary