Methodsrewriter is now working

This commit is contained in:
de4dot 2011-09-27 22:06:43 +02:00
parent 695dd81b43
commit c257f16787
14 changed files with 504 additions and 171 deletions

View File

@ -51,12 +51,14 @@
<Compile Include="methodsrewriter\Operand.cs" />
<Compile Include="methodsrewriter\Resolver.cs" />
<Compile Include="methodsrewriter\ResolverUtils.cs" />
<Compile Include="methodsrewriter\TypeInstanceResolver.cs" />
<Compile Include="methodsrewriter\TypeResolver.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
<Compile Include="SimpleData.cs" />
<Compile Include="Utils.cs" />
</ItemGroup>
<ItemGroup>
<Reference Include="System" />
<Reference Include="System.Runtime.Remoting" />
</ItemGroup>
<ItemGroup>

View File

@ -69,6 +69,10 @@ namespace AssemblyData {
static class Utils {
static Random random = new Random();
public static uint getRandomUint() {
return (uint)(random.NextDouble() * uint.MaxValue);
}
public static Type getDelegateType(Type returnType, Type[] args) {
Type[] types;
if (returnType == typeof(void)) {

View File

@ -56,6 +56,7 @@ namespace AssemblyData.methodsrewriter {
}
IMethodsRewriter methodsRewriter;
string methodName;
IList<Instruction> allInstructions;
IList<ExceptionHandler> allExceptionHandlers;
ILGenerator ilg;
@ -69,13 +70,15 @@ namespace AssemblyData.methodsrewriter {
List<LocalBuilder> locals;
List<Label> labels;
Dictionary<Instruction, int> instrToIndex;
Stack<ExceptionHandler> exceptionHandlersStack;
public Type DelegateType {
get { return delegateType; }
}
public CodeGenerator(IMethodsRewriter methodsRewriter) {
public CodeGenerator(IMethodsRewriter methodsRewriter, string methodName) {
this.methodsRewriter = methodsRewriter;
this.methodName = methodName;
}
public void setMethodInfo(MMethod methodInfo) {
@ -89,7 +92,7 @@ namespace AssemblyData.methodsrewriter {
this.allInstructions = allInstructions;
this.allExceptionHandlers = allExceptionHandlers;
var dm = new DynamicMethod("emulated_" + methodInfo.methodBase.Name, methodReturnType, methodParameters, methodsRewriter.GetType(), true);
var dm = new DynamicMethod(methodName, methodReturnType, methodParameters, methodInfo.methodBase.Module, true);
var lastInstr = allInstructions[allInstructions.Count - 1];
ilg = dm.GetILGenerator(lastInstr.Offset + lastInstr.GetSize());
@ -97,7 +100,9 @@ namespace AssemblyData.methodsrewriter {
initLocals();
initLabels();
exceptionHandlersStack = new Stack<ExceptionHandler>();
for (int i = 0; i < allInstructions.Count; i++) {
updateExceptionHandlers(i);
var instr = allInstructions[i];
ilg.MarkLabel(labels[i]);
if (instr.Operand is Operand)
@ -105,10 +110,66 @@ namespace AssemblyData.methodsrewriter {
else
writeInstr(instr);
}
updateExceptionHandlers(-1);
return dm.CreateDelegate(delegateType);
}
Instruction getExceptionInstruction(int instructionIndex) {
return instructionIndex < 0 ? null : allInstructions[instructionIndex];
}
void updateExceptionHandlers(int instructionIndex) {
var instr = getExceptionInstruction(instructionIndex);
updateExceptionHandlers(instr);
if (addTryStart(instr))
updateExceptionHandlers(instr);
}
void updateExceptionHandlers(Instruction instr) {
while (exceptionHandlersStack.Count > 0) {
var ex = exceptionHandlersStack.Peek();
if (ex.TryEnd == instr) {
}
if (ex.FilterStart == instr) {
}
if (ex.HandlerStart == instr) {
if (ex.HandlerType == ExceptionHandlerType.Finally)
ilg.BeginFinallyBlock();
else
ilg.BeginCatchBlock(Resolver.getRtType(ex.CatchType));
}
if (ex.HandlerEnd == instr) {
exceptionHandlersStack.Pop();
if (exceptionHandlersStack.Count == 0 || !isSameTryBlock(ex, exceptionHandlersStack.Peek()))
ilg.EndExceptionBlock();
}
else
break;
}
}
bool addTryStart(Instruction instr) {
var list = new List<ExceptionHandler>();
foreach (var ex in allExceptionHandlers) {
if (ex.TryStart == instr)
list.Add(ex);
}
list.Reverse();
foreach (var ex in list) {
if (exceptionHandlersStack.Count == 0 || !isSameTryBlock(ex, exceptionHandlersStack.Peek()))
ilg.BeginExceptionBlock();
exceptionHandlersStack.Push(ex);
}
return list.Count > 0;
}
static bool isSameTryBlock(ExceptionHandler ex1, ExceptionHandler ex2) {
return ex1.TryStart == ex2.TryStart && ex1.TryEnd == ex2.TryEnd;
}
void initInstrToIndex() {
instrToIndex = new Dictionary<Instruction, int>(allInstructions.Count);
for (int i = 0; i < allInstructions.Count; i++)
@ -118,7 +179,7 @@ namespace AssemblyData.methodsrewriter {
void initLocals() {
locals = new List<LocalBuilder>();
foreach (var local in methodInfo.methodDefinition.Body.Variables)
locals.Add(ilg.DeclareLocal(methodsRewriter.getRtType(local.VariableType), local.IsPinned));
locals.Add(ilg.DeclareLocal(Resolver.getRtType(local.VariableType), local.IsPinned));
tempObjLocal = ilg.DeclareLocal(typeof(object));
tempObjArrayLocal = ilg.DeclareLocal(typeof(object[]));
}
@ -250,7 +311,7 @@ namespace AssemblyData.methodsrewriter {
case OperandType.InlineType:
case OperandType.InlineMethod:
case OperandType.InlineField:
var obj = methodsRewriter.getRtObject((MemberReference)instr.Operand);
var obj = Resolver.getRtObject((MemberReference)instr.Operand);
if (obj is ConstructorInfo)
ilg.Emit(opcode, (ConstructorInfo)obj);
else if (obj is MethodInfo)
@ -279,7 +340,6 @@ namespace AssemblyData.methodsrewriter {
ilg.Emit(opcode, checked((byte)getLocalIndex((VariableDefinition)instr.Operand)));
break;
case OperandType.InlineSig: //TODO:
default:
throw new ApplicationException(string.Format("Unknown OperandType {0}", instr.OpCode.OperandType));

View File

@ -23,8 +23,6 @@ using Mono.Cecil;
namespace AssemblyData.methodsrewriter {
interface IMethodsRewriter {
Type getRtType(TypeReference typeReference);
object getRtObject(MemberReference memberReference);
Type getDelegateType(MethodBase methodBase);
}
}

View File

@ -34,11 +34,56 @@ using ROpCodes = System.Reflection.Emit.OpCodes;
namespace AssemblyData.methodsrewriter {
delegate object RewrittenMethod(object[] args);
class MethodsFinder {
Dictionary<Module, MethodsModule> moduleToMethods = new Dictionary<Module, MethodsModule>();
class MethodsModule {
const int MAX_METHODS = 30;
List<MethodBase> methods = new List<MethodBase>(MAX_METHODS);
int next;
public MethodsModule(Module module) {
var flags = BindingFlags.DeclaredOnly | BindingFlags.Instance | BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic;
foreach (var type in module.GetTypes()) {
if (methods.Count >= MAX_METHODS)
break;
foreach (var method in type.GetMethods(flags)) {
if (methods.Count >= MAX_METHODS)
break;
methods.Add(method);
}
}
foreach (var method in module.GetMethods(flags)) {
if (methods.Count >= MAX_METHODS)
break;
methods.Add(method);
}
}
public MethodBase getNext() {
return methods[next++ % methods.Count];
}
}
public MethodBase getMethod(Module module) {
MethodsModule methodsModule;
if (!moduleToMethods.TryGetValue(module, out methodsModule))
moduleToMethods[module] = methodsModule = new MethodsModule(module);
return methodsModule.getNext();
}
}
class MethodsRewriter : IMethodsRewriter {
Dictionary<Module, MModule> modules = new Dictionary<Module, MModule>();
MethodsFinder methodsFinder = new MethodsFinder();
Dictionary<MethodBase, NewMethodInfo> realMethodToNewMethod = new Dictionary<MethodBase, NewMethodInfo>();
List<NewMethodInfo> newMethodInfos = new List<NewMethodInfo>();
// There's no documented way to get a dynamic method's MethodInfo. If we name the
// method and it's a unique random name, we can still find the emulated method.
Dictionary<string, NewMethodInfo> delegateNameToNewMethodInfo = new Dictionary<string, NewMethodInfo>(StringComparer.Ordinal);
class NewMethodInfo {
// Original method
public MethodBase oldMethod;
@ -53,101 +98,26 @@ namespace AssemblyData.methodsrewriter {
public RewrittenMethod rewrittenMethod;
public NewMethodInfo(MethodBase oldMethod) {
// Name of method used by delegateInstance
public string delegateMethodName;
// Name of method used by rewrittenMethod
public string rewrittenMethodName;
public NewMethodInfo(MethodBase oldMethod, int delegateIndex, string delegateMethodName, string rewrittenMethodName) {
this.oldMethod = oldMethod;
this.delegateIndex = delegateIndex;
this.delegateMethodName = delegateMethodName;
this.rewrittenMethodName = rewrittenMethodName;
}
}
MModule loadAssembly(Module module) {
MModule info;
if (modules.TryGetValue(module, out info))
return info;
info = new MModule(module, ModuleDefinition.ReadModule(module.FullyQualifiedName));
modules[module] = info;
return info;
}
MModule getModule(ModuleDefinition moduleDefinition) {
foreach (var mm in modules.Values) {
if (mm.moduleDefinition == moduleDefinition)
return mm;
public bool isRewrittenMethod(string name) {
return name == rewrittenMethodName;
}
return null;
}
MModule getModule(AssemblyNameReference assemblyRef) {
foreach (var mm in modules.Values) {
var asm = mm.moduleDefinition.Assembly;
if (asm.Name.FullName == assemblyRef.FullName)
return mm;
public bool isDelegateMethod(string name) {
return name == delegateMethodName;
}
return null;
}
MModule getModule(IMetadataScope scope) {
if (scope is ModuleDefinition)
return getModule((ModuleDefinition)scope);
else if (scope is AssemblyNameReference)
return getModule((AssemblyNameReference)scope);
return null;
}
MType getType(TypeReference typeReference) {
var module = getModule(typeReference.Scope);
if (module != null)
return module.getType(typeReference);
return null;
}
MMethod getMethod(MethodReference methodReference) {
var module = getModule(methodReference.DeclaringType.Scope);
if (module != null)
return module.getMethod(methodReference);
return null;
}
MField getField(FieldReference fieldReference) {
var module = getModule(fieldReference.DeclaringType.Scope);
if (module != null)
return module.getField(fieldReference);
return null;
}
public object getRtObject(MemberReference memberReference) {
if (memberReference is TypeReference)
return getRtType((TypeReference)memberReference);
else if (memberReference is FieldReference)
return getRtField((FieldReference)memberReference);
else if (memberReference is MethodReference)
return getRtMethod((MethodReference)memberReference);
throw new ApplicationException(string.Format("Unknown MemberReference: {0}", memberReference));
}
public Type getRtType(TypeReference typeReference) {
var mtype = getType(typeReference);
if (mtype != null)
return mtype.type;
return Resolver.resolve(typeReference);
}
public FieldInfo getRtField(FieldReference fieldReference) {
var mfield = getField(fieldReference);
if (mfield != null)
return mfield.fieldInfo;
return Resolver.resolve(fieldReference);
}
public MethodBase getRtMethod(MethodReference methodReference) {
var mmethod = getMethod(methodReference);
if (mmethod != null)
return mmethod.methodBase;
return Resolver.resolve(methodReference);
}
public Type getDelegateType(MethodBase methodBase) {
@ -159,12 +129,13 @@ namespace AssemblyData.methodsrewriter {
if (newMethodInfo.rewrittenMethod != null)
return newMethodInfo.rewrittenMethod;
var dm = new DynamicMethod("method_" + newMethodInfo.oldMethod.Name, typeof(object), new Type[] { GetType(), typeof(object[]) }, GetType(), true);
var dm = new DynamicMethod(newMethodInfo.rewrittenMethodName, typeof(object), new Type[] { GetType(), typeof(object[]) }, newMethodInfo.oldMethod.Module, true);
var ilg = dm.GetILGenerator();
ilg.Emit(ROpCodes.Ldarg_0);
ilg.Emit(ROpCodes.Ldc_I4, newMethodInfo.delegateIndex);
ilg.Emit(ROpCodes.Call, GetType().GetMethod("rtGetDelegateInstance", BindingFlags.DeclaredOnly | BindingFlags.NonPublic | BindingFlags.Instance));
ilg.Emit(ROpCodes.Castclass, newMethodInfo.delegateType);
var args = newMethodInfo.oldMethod.GetParameters();
for (int i = 0; i < args.Length; i++) {
@ -192,20 +163,29 @@ namespace AssemblyData.methodsrewriter {
return newMethodInfo.rewrittenMethod;
}
string getDelegateMethodName(string methodName) {
string name = null;
do {
name = string.Format(" {0} DMN {1:X8} ", methodName, Utils.getRandomUint());
} while (delegateNameToNewMethodInfo.ContainsKey(name));
return name;
}
public void createMethod(MethodBase realMethod) {
if (realMethodToNewMethod.ContainsKey(realMethod))
return;
var newMethodInfo = new NewMethodInfo(realMethod);
newMethodInfo.delegateIndex = newMethodInfos.Count;
var newMethodInfo = new NewMethodInfo(realMethod, newMethodInfos.Count, getDelegateMethodName(realMethod.Name), getDelegateMethodName(realMethod.Name));
newMethodInfos.Add(newMethodInfo);
delegateNameToNewMethodInfo[newMethodInfo.delegateMethodName] = newMethodInfo;
delegateNameToNewMethodInfo[newMethodInfo.rewrittenMethodName] = newMethodInfo;
realMethodToNewMethod[realMethod] = newMethodInfo;
var moduleInfo = loadAssembly(realMethod.Module);
var moduleInfo = Resolver.loadAssembly(realMethod.Module);
var methodInfo = moduleInfo.getMethod(realMethod);
if (!methodInfo.methodDefinition.HasBody || methodInfo.methodDefinition.Body.Instructions.Count == 0)
throw new ApplicationException(string.Format("Method {0} ({1:X8}) has no body", methodInfo.methodDefinition, methodInfo.methodDefinition.MetadataToken.ToUInt32()));
var codeGenerator = new CodeGenerator(this);
var codeGenerator = new CodeGenerator(this, newMethodInfo.delegateMethodName);
codeGenerator.setMethodInfo(methodInfo);
newMethodInfo.delegateType = codeGenerator.DelegateType;
@ -273,7 +253,7 @@ namespace AssemblyData.methodsrewriter {
}
}
var method = getMethod((MethodReference)instr.Operand);
var method = Resolver.getMethod((MethodReference)instr.Operand);
if (method != null) {
createMethod(method.methodBase);
var newMethodInfo = realMethodToNewMethod[method.methodBase];
@ -340,13 +320,61 @@ namespace AssemblyData.methodsrewriter {
return list;
}
static FieldInfo getStackTraceStackFramesField() {
var flags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance;
return ResolverUtils.getFieldThrow(typeof(StackTrace), typeof(StackFrame[]), flags, "Could not find StackTrace's frames (StackFrame[]) field");
}
static FieldInfo getStackFrameMethodField() {
var flags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance;
return ResolverUtils.getFieldThrow(typeof(StackFrame), typeof(MethodBase), flags, "Could not find StackFrame's method (MethodBase) field");
}
static void writeMethodBase(StackFrame frame, MethodBase method) {
var methodField = getStackFrameMethodField();
methodField.SetValue(frame, method);
if (frame.GetMethod() != method)
throw new ApplicationException(string.Format("Could not set new method: {0}", method));
}
NewMethodInfo getNewMethodInfo(string name) {
NewMethodInfo info;
delegateNameToNewMethodInfo.TryGetValue(name, out info);
return info;
}
// Called after the StackTrace ctor has been called.
static StackTrace static_rtFixStackTrace(StackTrace stackTrace, MethodsRewriter self) {
return self.rtFixStackTrace(stackTrace);
}
StackTrace rtFixStackTrace(StackTrace stackTrace) {
//TODO:
var framesField = getStackTraceStackFramesField();
var frames = (StackFrame[])framesField.GetValue(stackTrace);
var newFrames = new List<StackFrame>(frames.Length);
foreach (var frame in frames) {
var method = frame.GetMethod();
var info = getNewMethodInfo(method.Name);
if (info == null) {
newFrames.Add(frame);
}
else if (info.isRewrittenMethod(method.Name)) {
// Write random method from the same module
writeMethodBase(frame, methodsFinder.getMethod(info.oldMethod.Module));
newFrames.Add(frame);
}
else if (info.isDelegateMethod(method.Name)) {
// Write original method
writeMethodBase(frame, info.oldMethod);
newFrames.Add(frame);
}
else {
throw new ApplicationException("BUG: Shouldn't be here");
}
}
framesField.SetValue(stackTrace, newFrames.ToArray());
return stackTrace;
}

View File

@ -40,5 +40,9 @@ namespace AssemblyData.methodsrewriter {
this.type = type;
this.data = data;
}
public override string ToString() {
return "{" + type + " => " + data + "}";
}
}
}

View File

@ -24,8 +24,101 @@ using Mono.Cecil;
using de4dot.blocks;
namespace AssemblyData.methodsrewriter {
public static class Resolver {
static class Resolver {
static Dictionary<string, AssemblyResolver> assemblyResolvers = new Dictionary<string, AssemblyResolver>(StringComparer.Ordinal);
static Dictionary<Module, MModule> modules = new Dictionary<Module, MModule>();
public static MModule loadAssembly(Module module) {
MModule info;
if (modules.TryGetValue(module, out info))
return info;
info = new MModule(module, ModuleDefinition.ReadModule(module.FullyQualifiedName));
modules[module] = info;
return info;
}
static MModule getModule(ModuleDefinition moduleDefinition) {
foreach (var mm in modules.Values) {
if (mm.moduleDefinition == moduleDefinition)
return mm;
}
return null;
}
static MModule getModule(AssemblyNameReference assemblyRef) {
foreach (var mm in modules.Values) {
var asm = mm.moduleDefinition.Assembly;
if (asm.Name.FullName == assemblyRef.FullName)
return mm;
}
return null;
}
public static MModule getModule(IMetadataScope scope) {
if (scope is ModuleDefinition)
return getModule((ModuleDefinition)scope);
else if (scope is AssemblyNameReference)
return getModule((AssemblyNameReference)scope);
return null;
}
public static MType getType(TypeReference typeReference) {
var module = getModule(typeReference.Scope);
if (module != null)
return module.getType(typeReference);
return null;
}
public static MMethod getMethod(MethodReference methodReference) {
var module = getModule(methodReference.DeclaringType.Scope);
if (module != null)
return module.getMethod(methodReference);
return null;
}
public static MField getField(FieldReference fieldReference) {
var module = getModule(fieldReference.DeclaringType.Scope);
if (module != null)
return module.getField(fieldReference);
return null;
}
public static object getRtObject(MemberReference memberReference) {
if (memberReference is TypeReference)
return getRtType((TypeReference)memberReference);
else if (memberReference is FieldReference)
return getRtField((FieldReference)memberReference);
else if (memberReference is MethodReference)
return getRtMethod((MethodReference)memberReference);
throw new ApplicationException(string.Format("Unknown MemberReference: {0}", memberReference));
}
public static Type getRtType(TypeReference typeReference) {
var mtype = getType(typeReference);
if (mtype != null)
return mtype.type;
return Resolver.resolve(typeReference);
}
public static FieldInfo getRtField(FieldReference fieldReference) {
var mfield = getField(fieldReference);
if (mfield != null)
return mfield.fieldInfo;
return Resolver.resolve(fieldReference);
}
public static MethodBase getRtMethod(MethodReference methodReference) {
var mmethod = getMethod(methodReference);
if (mmethod != null)
return mmethod.methodBase;
return Resolver.resolve(methodReference);
}
static AssemblyResolver getAssemblyResolver(IMetadataScope scope) {
var asmName = DotNetUtils.getFullAssemblyName(scope);
@ -35,7 +128,7 @@ namespace AssemblyData.methodsrewriter {
return resolver;
}
public static Type resolve(TypeReference typeReference) {
static Type resolve(TypeReference typeReference) {
var elemType = typeReference.GetElementType();
var resolver = getAssemblyResolver(elemType.Scope);
var resolvedType = resolver.resolve(elemType);
@ -44,7 +137,7 @@ namespace AssemblyData.methodsrewriter {
throw new ApplicationException(string.Format("Could not resolve type {0} ({1:X8}) in assembly {2}", typeReference, typeReference.MetadataToken.ToUInt32(), resolver));
}
public static FieldInfo resolve(FieldReference fieldReference) {
static FieldInfo resolve(FieldReference fieldReference) {
var resolver = getAssemblyResolver(fieldReference.DeclaringType.Scope);
var fieldInfo = resolver.resolve(fieldReference);
if (fieldInfo != null)
@ -52,7 +145,7 @@ namespace AssemblyData.methodsrewriter {
throw new ApplicationException(string.Format("Could not resolve field {0} ({1:X8}) in assembly {2}", fieldReference, fieldReference.MetadataToken.ToUInt32(), resolver));
}
public static MethodBase resolve(MethodReference methodReference) {
static MethodBase resolve(MethodReference methodReference) {
var resolver = getAssemblyResolver(methodReference.DeclaringType.Scope);
var methodBase = resolver.resolve(methodReference);
if (methodBase != null)

View File

@ -91,9 +91,10 @@ namespace AssemblyData.methodsrewriter {
for (int i = 0; i < aGpargs.Length; i++) {
var aArg = aGpargs[i];
var bArg = bGpargs[i];
if (aArg.IsGenericParameter)
continue;
if (!compareTypes(aArg, bGpargs[i]))
if (!compareTypes(aArg, bArg))
return false;
}
@ -246,5 +247,74 @@ namespace AssemblyData.methodsrewriter {
foreach (var m in type.GetMethods(flags))
yield return m;
}
class CachedMemberInfo {
Type type;
Type memberType;
public CachedMemberInfo(Type type, Type memberType) {
this.type = type;
this.memberType = memberType;
}
public override int GetHashCode() {
return type.GetHashCode() ^ memberType.GetHashCode();
}
public override bool Equals(object obj) {
var other = obj as CachedMemberInfo;
if (other == null)
return false;
return type == other.type && memberType == other.memberType;
}
}
static Dictionary<CachedMemberInfo, FieldInfo> cachedFieldInfos = new Dictionary<CachedMemberInfo, FieldInfo>();
public static FieldInfo getField(Type type, Type fieldType, BindingFlags flags) {
var key = new CachedMemberInfo(type, fieldType);
FieldInfo fieldInfo;
if (cachedFieldInfos.TryGetValue(key, out fieldInfo))
return fieldInfo;
foreach (var field in type.GetFields(flags)) {
if (field.FieldType == fieldType) {
cachedFieldInfos[key] = field;
return field;
}
}
return null;
}
public static FieldInfo getFieldThrow(Type type, Type fieldType, BindingFlags flags, string msg) {
var info = getField(type, fieldType, flags);
if (info != null)
return info;
throw new ApplicationException(msg);
}
public static List<FieldInfo> getFields(Type type, Type fieldType, BindingFlags flags) {
var list = new List<FieldInfo>();
foreach (var field in type.GetFields(flags)) {
if (field.FieldType == fieldType)
list.Add(field);
}
return list;
}
public static Type makeInstanceType(Type type, TypeReference typeReference) {
var git = typeReference as GenericInstanceType;
if (git == null)
return type;
var types = new Type[git.GenericArguments.Count];
bool isTypeDef = true;
for (int i = 0; i < git.GenericArguments.Count; i++) {
var arg = git.GenericArguments[i];
if (!(arg is GenericParameter))
isTypeDef = false;
types[i] = Resolver.getRtType(arg);
}
if (isTypeDef)
return type;
return type.MakeGenericType(types);
}
}
}

View File

@ -0,0 +1,102 @@
/*
Copyright (C) 2011 de4dot@gmail.com
This file is part of de4dot.
de4dot is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
de4dot is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with de4dot. If not, see <http://www.gnu.org/licenses/>.
*/
using System;
using System.Collections.Generic;
using System.Reflection;
using Mono.Cecil;
using de4dot.blocks;
namespace AssemblyData.methodsrewriter {
class TypeInstanceResolver {
Type type;
Dictionary<string, List<MethodBase>> methods;
Dictionary<string, List<FieldInfo>> fields;
public TypeInstanceResolver(Type type, TypeReference typeReference) {
this.type = ResolverUtils.makeInstanceType(type, typeReference);
}
public FieldInfo resolve(FieldReference fieldReference) {
initFields();
List<FieldInfo> list;
if (!fields.TryGetValue(fieldReference.Name, out list))
return null;
var git = fieldReference.DeclaringType as GenericInstanceType;
if (git != null)
fieldReference = new FieldReferenceExpander(fieldReference, git).expand();
foreach (var field in list) {
if (ResolverUtils.compareFields(field, fieldReference))
return field;
}
return null;
}
void initFields() {
if (fields != null)
return;
fields = new Dictionary<string, List<FieldInfo>>(StringComparer.Ordinal);
var flags = BindingFlags.DeclaredOnly | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance;
foreach (var field in type.GetFields(flags)) {
List<FieldInfo> list;
if (!fields.TryGetValue(field.Name, out list))
fields[field.Name] = list = new List<FieldInfo>();
list.Add(field);
}
}
public MethodBase resolve(MethodReference methodReference) {
initMethods();
List<MethodBase> list;
if (!methods.TryGetValue(methodReference.Name, out list))
return null;
var git = methodReference.DeclaringType as GenericInstanceType;
if (git != null)
methodReference = new MethodReferenceExpander(methodReference, git).expand();
foreach (var method in list) {
if (ResolverUtils.compareMethods(method, methodReference))
return method;
}
return null;
}
void initMethods() {
if (methods != null)
return;
methods = new Dictionary<string, List<MethodBase>>(StringComparer.Ordinal);
var flags = BindingFlags.DeclaredOnly | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance;
foreach (var method in ResolverUtils.getMethodBases(type, flags)) {
List<MethodBase> list;
if (!methods.TryGetValue(method.Name, out list))
methods[method.Name] = list = new List<MethodBase>();
list.Add(method);
}
}
}
}

View File

@ -21,73 +21,31 @@ using System;
using System.Collections.Generic;
using System.Reflection;
using Mono.Cecil;
using de4dot.blocks;
namespace AssemblyData.methodsrewriter {
class TypeResolver {
public Type type;
Dictionary<string, List<MethodBase>> methods;
Dictionary<string, List<FieldInfo>> fields;
Dictionary<TypeReferenceKey, TypeInstanceResolver> typeRefToInstance = new Dictionary<TypeReferenceKey, TypeInstanceResolver>();
public TypeResolver(Type type) {
this.type = type;
}
public FieldInfo resolve(FieldReference fieldReference) {
initFields();
List<FieldInfo> list;
if (!fields.TryGetValue(fieldReference.Name, out list))
return null;
foreach (var field in list) {
if (ResolverUtils.compareFields(field, fieldReference))
return field;
}
return null;
TypeInstanceResolver getTypeInstance(TypeReference typeReference) {
var key = new TypeReferenceKey(typeReference);
TypeInstanceResolver instance;
if (!typeRefToInstance.TryGetValue(key, out instance))
typeRefToInstance[key] = instance = new TypeInstanceResolver(type, typeReference);
return instance;
}
void initFields() {
if (fields != null)
return;
fields = new Dictionary<string, List<FieldInfo>>(StringComparer.Ordinal);
var flags = BindingFlags.DeclaredOnly | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance;
foreach (var field in type.GetFields(flags)) {
List<FieldInfo> list;
if (!fields.TryGetValue(field.Name, out list))
fields[field.Name] = list = new List<FieldInfo>();
list.Add(field);
}
public FieldInfo resolve(FieldReference fieldReference) {
return getTypeInstance(fieldReference.DeclaringType).resolve(fieldReference);
}
public MethodBase resolve(MethodReference methodReference) {
initMethods();
List<MethodBase> list;
if (!methods.TryGetValue(methodReference.Name, out list))
return null;
foreach (var method in list) {
if (ResolverUtils.compareMethods(method, methodReference))
return method;
}
return null;
}
void initMethods() {
if (methods != null)
return;
methods = new Dictionary<string, List<MethodBase>>(StringComparer.Ordinal);
var flags = BindingFlags.DeclaredOnly | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance;
foreach (var method in ResolverUtils.getMethodBases(type, flags)) {
List<MethodBase> list;
if (!methods.TryGetValue(method.Name, out list))
methods[method.Name] = list = new List<MethodBase>();
list.Add(method);
}
return getTypeInstance(methodReference.DeclaringType).resolve(methodReference);
}
}
}

View File

@ -19,10 +19,9 @@
using System;
using Mono.Cecil;
using de4dot.blocks;
namespace de4dot.renamer {
abstract class Expander {
namespace de4dot.blocks {
public abstract class Expander {
protected bool modified = false;
protected void checkModified(object a, object b) {
@ -31,7 +30,7 @@ namespace de4dot.renamer {
}
}
class TypeReferenceExpander : Expander {
public class TypeReferenceExpander : Expander {
TypeReference typeReference;
GenericInstanceType git;
@ -144,7 +143,7 @@ namespace de4dot.renamer {
}
}
abstract class MultiTypeExpander : Expander {
public abstract class MultiTypeExpander : Expander {
GenericInstanceType git;
public MultiTypeExpander(GenericInstanceType git) {
@ -162,7 +161,7 @@ namespace de4dot.renamer {
}
}
class MethodReferenceExpander : MultiTypeExpander {
public class MethodReferenceExpander : MultiTypeExpander {
MethodReference methodReference;
public MethodReferenceExpander(MethodReference methodReference, GenericInstanceType git)
@ -194,7 +193,21 @@ namespace de4dot.renamer {
}
}
class EventReferenceExpander : MultiTypeExpander {
public class FieldReferenceExpander : MultiTypeExpander {
FieldReference fieldReference;
public FieldReferenceExpander(FieldReference fieldReference, GenericInstanceType git)
: base(git) {
this.fieldReference = fieldReference;
}
public FieldReference expand() {
var fr = new FieldReference(fieldReference.Name, expandType(fieldReference.FieldType));
return getResult(fieldReference, fr);
}
}
public class EventReferenceExpander : MultiTypeExpander {
EventReference eventReference;
public EventReferenceExpander(EventReference eventReference, GenericInstanceType git)
@ -208,7 +221,7 @@ namespace de4dot.renamer {
}
}
class PropertyReferenceExpander : MultiTypeExpander {
public class PropertyReferenceExpander : MultiTypeExpander {
PropertyReference propertyReference;
public PropertyReferenceExpander(PropertyReference propertyReference, GenericInstanceType git)

View File

@ -43,6 +43,7 @@
<Compile Include="HandlerBlock.cs" />
<Compile Include="Instr.cs" />
<Compile Include="InstructionListParser.cs" />
<Compile Include="MemberReferenceExpander.cs" />
<Compile Include="MemberReferenceHelper.cs" />
<Compile Include="MethodBlocks.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />

View File

@ -97,7 +97,6 @@
<Compile Include="Program.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
<Compile Include="renamer\DefinitionsRenamer.cs" />
<Compile Include="renamer\MemberReferenceExpander.cs" />
<Compile Include="renamer\MemberRefFinder.cs" />
<Compile Include="renamer\MemberRefs.cs" />
<Compile Include="renamer\MemberRenameState.cs" />

View File

@ -18,6 +18,7 @@
*/
using Mono.Cecil;
using de4dot.blocks;
namespace de4dot.renamer {
abstract class RefExpander {