Add code to map virtual methods to base/iface methods

This commit is contained in:
de4dot 2011-11-16 23:08:27 +01:00
parent 79eb228200
commit b58c3843e3
9 changed files with 579 additions and 7 deletions

View File

@ -159,6 +159,7 @@
<Compile Include="renamer\asmmodules\IResolver.cs" />
<Compile Include="renamer\asmmodules\MemberRefFinder.cs" />
<Compile Include="renamer\asmmodules\MethodDef.cs" />
<Compile Include="renamer\asmmodules\MethodNameScopes.cs" />
<Compile Include="renamer\asmmodules\Module.cs" />
<Compile Include="renamer\asmmodules\Modules.cs" />
<Compile Include="renamer\asmmodules\PropertyDef.cs" />

View File

@ -40,6 +40,8 @@ namespace de4dot.renamer {
Log.n("Renaming all obfuscated symbols");
modules.initialize();
modules.initializeVirtualMembers();
modules.cleanUp();
}
}
}

View File

@ -42,5 +42,13 @@ namespace de4dot.renamer.asmmodules {
yield return m;
}
}
public bool isVirtual() {
foreach (var method in methodDefinitions()) {
if (method.IsVirtual)
return true;
}
return false;
}
}
}

View File

@ -39,9 +39,12 @@ namespace de4dot.renamer.asmmodules {
return null;
}
public void unload() {
foreach (var module in asmDef.Modules)
public void unload(string asmFullName) {
foreach (var module in asmDef.Modules) {
DotNetUtils.typeCaches.invalidate(module);
AssemblyResolver.Instance.removeModule(module);
}
AssemblyResolver.Instance.removeModule(asmFullName);
}
}
@ -50,7 +53,7 @@ namespace de4dot.renamer.asmmodules {
Dictionary<string, ExternalAssembly> assemblies = new Dictionary<string, ExternalAssembly>();
ExternalAssembly load(TypeReference type) {
var asmFullName = DotNetUtils.getFullAssemblyName(type.Scope);
var asmFullName = DotNetUtils.getFullAssemblyName(type);
ExternalAssembly asm;
if (assemblies.TryGetValue(asmFullName, out asm))
return asm;
@ -93,10 +96,10 @@ namespace de4dot.renamer.asmmodules {
}
public void unloadAll() {
foreach (var asm in assemblies.Values) {
if (asm == null)
foreach (var pair in assemblies) {
if (pair.Value == null)
continue;
asm.unload();
pair.Value.unload(pair.Key);
}
assemblies.Clear();
}

View File

@ -28,5 +28,17 @@ namespace de4dot.renamer.asmmodules {
public MethodDef(MethodDefinition methodDefinition, TypeDef owner, int index)
: base(methodDefinition, owner, index) {
}
public bool isPublic() {
return MethodDefinition.IsPublic;
}
public bool isVirtual() {
return MethodDefinition.IsVirtual;
}
public bool isNewSlot() {
return MethodDefinition.IsNewSlot;
}
}
}

View File

@ -0,0 +1,82 @@
/*
Copyright (C) 2011 de4dot@gmail.com
This file is part of de4dot.
de4dot is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
de4dot is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with de4dot. If not, see <http://www.gnu.org/licenses/>.
*/
using System;
using System.Collections.Generic;
namespace de4dot.renamer.asmmodules {
class MethodNameScope {
List<MethodDef> methods = new List<MethodDef>();
public List<MethodDef> Methods {
get { return methods; }
}
public int Count {
get { return methods.Count; }
}
public void add(MethodDef method) {
methods.Add(method);
}
public void merge(MethodNameScope other) {
if (this == other)
return;
methods.AddRange(other.methods);
}
}
class MethodNameScopes {
Dictionary<MethodDef, MethodNameScope> methodScopes = new Dictionary<MethodDef, MethodNameScope>();
public void same(MethodDef a, MethodDef b) {
merge(get(a), get(b));
}
public void add(MethodDef methodDef) {
get(methodDef);
}
MethodNameScope get(MethodDef method) {
if (!method.isVirtual())
throw new ApplicationException("Not a virtual method");
MethodNameScope scope;
if (!methodScopes.TryGetValue(method, out scope)) {
methodScopes[method] = scope = new MethodNameScope();
scope.add(method);
}
return scope;
}
void merge(MethodNameScope a, MethodNameScope b) {
if (a == b)
return;
if (a.Count < b.Count) {
MethodNameScope tmp = a;
a = b;
b = tmp;
}
a.merge(b);
foreach (var methodDef in b.Methods)
methodScopes[methodDef] = a;
}
}
}

View File

@ -29,7 +29,7 @@ namespace de4dot.renamer.asmmodules {
Dictionary<ModuleDefinition, Module> modulesDict = new Dictionary<ModuleDefinition, Module>();
AssemblyHash assemblyHash = new AssemblyHash();
List<TypeDef> allTypes = new List<TypeDef>(); //TODO: Do we need this?
List<TypeDef> allTypes = new List<TypeDef>();
List<TypeDef> baseTypes = new List<TypeDef>(); //TODO: Do we need this?
List<TypeDef> nonNestedTypes; //TODO: Do we need this?
@ -240,6 +240,18 @@ namespace de4dot.renamer.asmmodules {
return otherTypesDict[key] = typeDef;
}
public void initializeVirtualMembers() {
var scopes = new MethodNameScopes();
foreach (var typeDef in allTypes)
typeDef.initializeVirtualMembers(scopes, this);
}
public void cleanUp() {
externalAssemblies.unloadAll();
foreach (var module in DotNetUtils.typeCaches.invalidateAll())
AssemblyResolver.Instance.removeModule(module);
}
// Returns null if it's a non-loaded module/assembly
IEnumerable<Module> findModules(TypeReference type) {
var scope = type.Scope;

View File

@ -40,5 +40,13 @@ namespace de4dot.renamer.asmmodules {
yield return m;
}
}
public bool isVirtual() {
foreach (var method in methodDefinitions()) {
if (method.IsVirtual)
return true;
}
return false;
}
}
}

View File

@ -20,6 +20,7 @@
using System;
using System.Collections.Generic;
using Mono.Cecil;
using de4dot.blocks;
namespace de4dot.renamer.asmmodules {
class TypeInfo {
@ -29,6 +30,172 @@ namespace de4dot.renamer.asmmodules {
this.typeReference = typeReference;
this.typeDef = typeDef;
}
public TypeInfo(TypeInfo other, GenericInstanceType git) {
this.typeReference = TypeReferenceInstance.make(other.typeReference, git);
this.typeDef = other.typeDef;
}
public override int GetHashCode() {
return typeDef.GetHashCode() +
MemberReferenceHelper.typeHashCode(typeReference);
}
public override bool Equals(object obj) {
var other = obj as TypeInfo;
if (other == null)
return false;
return typeDef == other.typeDef &&
MemberReferenceHelper.compareTypes(typeReference, other.typeReference);
}
public override string ToString() {
return typeReference.ToString();
}
}
class MethodInst {
public MethodDef origMethodDef;
public MethodReference methodReference;
public MethodInst(MethodDef origMethodDef, MethodReference methodReference) {
this.origMethodDef = origMethodDef;
this.methodReference = methodReference;
}
public override string ToString() {
return methodReference.ToString();
}
}
class MethodInstances {
Dictionary<MethodReferenceKey, List<MethodInst>> methodInstances = new Dictionary<MethodReferenceKey, List<MethodInst>>();
public void initializeFrom(MethodInstances other, GenericInstanceType git) {
foreach (var list in other.methodInstances.Values) {
foreach (var methodInst in list) {
MethodReference newMethod = MethodReferenceInstance.make(methodInst.methodReference, git);
add(new MethodInst(methodInst.origMethodDef, newMethod));
}
}
}
public void add(MethodInst methodInst) {
List<MethodInst> list;
var key = new MethodReferenceKey(methodInst.methodReference);
if (methodInst.origMethodDef.isNewSlot() || !methodInstances.TryGetValue(key, out list))
methodInstances[key] = list = new List<MethodInst>();
list.Add(methodInst);
}
public List<MethodInst> lookup(MethodReference methodReference) {
List<MethodInst> list;
methodInstances.TryGetValue(new MethodReferenceKey(methodReference), out list);
return list;
}
public IEnumerable<List<MethodInst>> getMethods() {
return methodInstances.Values;
}
}
// Keeps track which methods of an interface have been implemented
class InterfaceMethodInfo {
TypeInfo iface;
Dictionary<MethodDef, MethodDef> ifaceMethodToClassMethod = new Dictionary<MethodDef, MethodDef>();
public TypeInfo IFace {
get { return iface; }
}
public Dictionary<MethodDef, MethodDef> IfaceMethodToClassMethod {
get { return ifaceMethodToClassMethod; }
}
public InterfaceMethodInfo(TypeInfo iface) {
this.iface = iface;
foreach (var methodDef in iface.typeDef.getAllMethods())
ifaceMethodToClassMethod[methodDef] = null;
}
public InterfaceMethodInfo(TypeInfo iface, InterfaceMethodInfo other) {
this.iface = iface;
foreach (var key in other.ifaceMethodToClassMethod.Keys)
ifaceMethodToClassMethod[key] = other.ifaceMethodToClassMethod[key];
}
public void merge(InterfaceMethodInfo other) {
foreach (var key in other.ifaceMethodToClassMethod.Keys) {
if (other.ifaceMethodToClassMethod[key] == null)
continue;
if (ifaceMethodToClassMethod[key] != null)
throw new ApplicationException("Interface method already initialized");
ifaceMethodToClassMethod[key] = other.ifaceMethodToClassMethod[key];
}
}
public void addMethod(MethodDef ifaceMethod, MethodDef classMethod) {
if (!ifaceMethodToClassMethod.ContainsKey(ifaceMethod))
throw new ApplicationException("Could not find interface method");
ifaceMethodToClassMethod[ifaceMethod] = classMethod;
}
public void addMethodIfEmpty(MethodDef ifaceMethod, MethodDef classMethod) {
if (ifaceMethodToClassMethod[ifaceMethod] == null)
ifaceMethodToClassMethod[ifaceMethod] = classMethod;
}
public override string ToString() {
return iface.ToString();
}
}
class InterfaceMethodInfos {
Dictionary<TypeReferenceKey, InterfaceMethodInfo> interfaceMethods = new Dictionary<TypeReferenceKey, InterfaceMethodInfo>();
public IEnumerable<InterfaceMethodInfo> AllInfos {
get { return interfaceMethods.Values; }
}
public void initializeFrom(InterfaceMethodInfos other, GenericInstanceType git) {
foreach (var pair in other.interfaceMethods) {
var oldTypeInfo = pair.Value.IFace;
var newTypeInfo = new TypeInfo(oldTypeInfo, git);
var oldKey = new TypeReferenceKey(oldTypeInfo.typeReference);
var newKey = new TypeReferenceKey(newTypeInfo.typeReference);
InterfaceMethodInfo newMethodsInfo = new InterfaceMethodInfo(newTypeInfo, other.interfaceMethods[oldKey]);
if (interfaceMethods.ContainsKey(newKey))
newMethodsInfo.merge(interfaceMethods[newKey]);
interfaceMethods[newKey] = newMethodsInfo;
}
}
public void addInterface(TypeInfo iface) {
var key = new TypeReferenceKey(iface.typeReference);
if (!interfaceMethods.ContainsKey(key))
interfaceMethods[key] = new InterfaceMethodInfo(iface);
}
public void addMethod(TypeInfo iface, MethodDef ifaceMethod, MethodDef classMethod) {
addMethod(iface.typeReference, ifaceMethod, classMethod);
}
public void addMethod(TypeReference iface, MethodDef ifaceMethod, MethodDef classMethod) {
InterfaceMethodInfo info;
var key = new TypeReferenceKey(iface);
if (!interfaceMethods.TryGetValue(key, out info))
throw new ApplicationException("Could not find interface");
info.addMethod(ifaceMethod, classMethod);
}
public void addMethodIfEmpty(TypeInfo iface, MethodDef ifaceMethod, MethodDef classMethod) {
InterfaceMethodInfo info;
var key = new TypeReferenceKey(iface.typeReference);
if (!interfaceMethods.TryGetValue(key, out info))
throw new ApplicationException("Could not find interface");
info.addMethodIfEmpty(ifaceMethod, classMethod);
}
}
class TypeDef : Ref {
@ -42,6 +209,11 @@ namespace de4dot.renamer.asmmodules {
internal IList<TypeDef> derivedTypes = new List<TypeDef>();
Module module;
bool initializeVirtualMembersCalled = false;
MethodInstances virtualMethodInstances = new MethodInstances();
Dictionary<TypeInfo, bool> allImplementedInterfaces = new Dictionary<TypeInfo, bool>();
InterfaceMethodInfos interfaceMethodInfos = new InterfaceMethodInfos();
public bool HasModule {
get { return module != null; }
}
@ -101,6 +273,10 @@ namespace de4dot.renamer.asmmodules {
return fields.find(fr);
}
public IEnumerable<MethodDef> getAllMethods() {
return methods.getAll();
}
public void addMembers() {
var type = TypeDefinition;
@ -113,5 +289,273 @@ namespace de4dot.renamer.asmmodules {
for (int i = 0; i < type.Properties.Count; i++)
add(new PropertyDef(type.Properties[i], this, i));
}
public void initializeVirtualMembers(MethodNameScopes scopes, IResolver resolver) {
if (initializeVirtualMembersCalled)
return;
initializeVirtualMembersCalled = true;
foreach (var iface in interfaces)
iface.typeDef.initializeVirtualMembers(scopes, resolver);
if (baseType != null)
baseType.typeDef.initializeVirtualMembers(scopes, resolver);
foreach (var methodDef in methods.getAll()) {
if (methodDef.isVirtual())
scopes.add(methodDef);
}
instantiateVirtualMembers(scopes);
initializeInterfaceMethods(scopes);
}
void initializeAllInterfaces() {
if (baseType != null)
initializeInterfaces(baseType);
foreach (var iface in interfaces) {
allImplementedInterfaces[iface] = true;
interfaceMethodInfos.addInterface(iface);
initializeInterfaces(iface);
}
}
void initializeInterfaces(TypeInfo typeInfo) {
var git = typeInfo.typeReference as GenericInstanceType;
interfaceMethodInfos.initializeFrom(typeInfo.typeDef.interfaceMethodInfos, git);
foreach (var info in typeInfo.typeDef.allImplementedInterfaces.Keys) {
var newTypeInfo = new TypeInfo(info, git);
allImplementedInterfaces[newTypeInfo] = true;
}
}
void initializeInterfaceMethods(MethodNameScopes scopes) {
initializeAllInterfaces();
if (TypeDefinition.IsInterface)
return;
//--- Partition II 12.2 Implementing virtual methods on interfaces:
//--- The VES shall use the following algorithm to determine the appropriate
//--- implementation of an interface's virtual abstract methods:
//---
//--- * If the base class implements the interface, start with the same virtual methods
//--- that it provides; otherwise, create an interface that has empty slots for all
//--- virtual functions.
// Done. See initializeAllInterfaces().
var methodsDict = new Dictionary<MethodReferenceKey, MethodDef>();
//--- * If this class explicitly specifies that it implements the interface (i.e., the
//--- interfaces that appear in this class InterfaceImpl table, §22.23)
//--- * If the class defines any public virtual newslot methods whose name and
//--- signature match a virtual method on the interface, then use these new virtual
//--- methods to implement the corresponding interface method.
if (interfaces.Count > 0) {
methodsDict.Clear();
foreach (var method in methods.getAll()) {
if (!method.isPublic() || !method.isVirtual() || !method.isNewSlot())
continue;
methodsDict[new MethodReferenceKey(method.MethodDefinition)] = method;
}
foreach (var ifaceInfo in interfaces) {
foreach (var methodsList in ifaceInfo.typeDef.virtualMethodInstances.getMethods()) {
if (methodsList.Count != 1) // Never happens
throw new ApplicationException("Interface with more than one method in the list");
var methodInst = methodsList[0];
var ifaceMethod = methodInst.origMethodDef;
if (!ifaceMethod.isVirtual())
continue;
var ifaceMethodReference = MethodReferenceInstance.make(methodInst.methodReference, ifaceInfo.typeReference as GenericInstanceType);
MethodDef classMethod;
var key = new MethodReferenceKey(ifaceMethodReference);
if (!methodsDict.TryGetValue(key, out classMethod))
continue;
interfaceMethodInfos.addMethod(ifaceInfo, ifaceMethod, classMethod);
}
}
}
//--- * If there are any virtual methods in the interface that still have empty slots,
//--- see if there are any public virtual methods, but not public virtual newslot
//--- methods, available on this class (directly or inherited) having the same name
//--- and signature, then use these to implement the corresponding methods on the
//--- interface.
methodsDict.Clear();
foreach (var methodInstList in virtualMethodInstances.getMethods()) {
// This class' method is at the end
for (int i = methodInstList.Count - 1; i >= 0; i--) {
var classMethod = methodInstList[i];
// These methods are guaranteed to be virtual.
// We should allow newslot methods, despite what the official doc says.
if (!classMethod.origMethodDef.isPublic())
continue;
methodsDict[new MethodReferenceKey(classMethod.methodReference)] = classMethod.origMethodDef;
break;
}
}
foreach (var ifaceInfo in allImplementedInterfaces.Keys) {
foreach (var methodsList in ifaceInfo.typeDef.virtualMethodInstances.getMethods()) {
if (methodsList.Count != 1) // Never happens
throw new ApplicationException("Interface with more than one method in the list");
var ifaceMethod = methodsList[0].origMethodDef;
if (!ifaceMethod.isVirtual())
continue;
var ifaceMethodRef = MethodReferenceInstance.make(ifaceMethod.MethodDefinition, ifaceInfo.typeReference as GenericInstanceType);
MethodDef classMethod;
var key = new MethodReferenceKey(ifaceMethodRef);
if (!methodsDict.TryGetValue(key, out classMethod))
continue;
interfaceMethodInfos.addMethodIfEmpty(ifaceInfo, ifaceMethod, classMethod);
}
}
//--- * Apply all MethodImpls that are specified for this class, thereby placing
//--- explicitly specified virtual methods into the interface in preference to those
//--- inherited or chosen by name matching.
methodsDict.Clear();
var ifaceMethodsDict = new Dictionary<MethodReferenceAndDeclaringTypeKey, MethodDef>();
var overrideMethods = new Dictionary<MethodDef, bool>();
foreach (var ifaceInfo in allImplementedInterfaces.Keys) {
var git = ifaceInfo.typeReference as GenericInstanceType;
foreach (var ifaceMethod in ifaceInfo.typeDef.methods.getAll()) {
MethodReference ifaceMethodReference = ifaceMethod.MethodDefinition;
if (git != null)
ifaceMethodReference = simpleClone(ifaceMethod.MethodDefinition, git);
ifaceMethodsDict[new MethodReferenceAndDeclaringTypeKey(ifaceMethodReference)] = ifaceMethod;
}
}
foreach (var classMethod in methods.getAll()) {
if (!classMethod.isVirtual())
continue;
foreach (var overrideMethod in classMethod.MethodDefinition.Overrides) {
MethodDef ifaceMethod;
var key = new MethodReferenceAndDeclaringTypeKey(overrideMethod);
if (!ifaceMethodsDict.TryGetValue(key, out ifaceMethod)) {
// We couldn't find the interface method (eg. interface not resolved) or
// it overrides a base class method, and not an interface method.
continue;
}
interfaceMethodInfos.addMethod(overrideMethod.DeclaringType, ifaceMethod, classMethod);
overrideMethods[classMethod] = true;
}
}
//--- * If the current class is not abstract and there are any interface methods that
//--- still have empty slots, then the program is invalid.
// Check it anyway. C# requires a method, even if it's abstract. I don't think anyone
// writes pure CIL assemblies.
foreach (var info in interfaceMethodInfos.AllInfos) {
foreach (var pair in info.IfaceMethodToClassMethod) {
if (pair.Value != null)
continue;
if (!resolvedAllInterfaces() || !resolvedBaseClasses())
continue;
string errMsg = string.Format(
"Could not find interface method {0} ({1:X8}). Type: {2} ({3:X8})",
pair.Key.MethodDefinition,
pair.Key.MethodDefinition.MetadataToken.ToInt32(),
TypeDefinition,
TypeDefinition.MetadataToken.ToInt32());
// Ignore if COM class
if (!hasAttribute("System.Runtime.InteropServices.TypeLibTypeAttribute"))
throw new ApplicationException(errMsg);
Log.e("{0}", errMsg);
}
}
foreach (var info in interfaceMethodInfos.AllInfos) {
foreach (var pair in info.IfaceMethodToClassMethod) {
if (pair.Value == null)
continue;
if (overrideMethods.ContainsKey(pair.Value))
continue;
scopes.same(pair.Key, pair.Value);
}
}
}
bool hasAttribute(string name) {
foreach (var attr in TypeDefinition.CustomAttributes) {
if (attr.AttributeType.FullName == name)
return true;
}
return false;
}
// Returns true if all interfaces have been resolved
bool? resolvedAllInterfacesResult;
bool resolvedAllInterfaces() {
if (!resolvedAllInterfacesResult.HasValue) {
resolvedAllInterfacesResult = true; // If we find a circular reference
resolvedAllInterfacesResult = resolvedAllInterfacesInternal();
}
return resolvedAllInterfacesResult.Value;
}
bool resolvedAllInterfacesInternal() {
if (TypeDefinition.Interfaces.Count != interfaces.Count)
return false;
foreach (var ifaceInfo in interfaces) {
if (!ifaceInfo.typeDef.resolvedAllInterfaces())
return false;
}
return true;
}
// Returns true if all base classes have been resolved
bool? resolvedBaseClassesResult;
bool resolvedBaseClasses() {
if (!resolvedBaseClassesResult.HasValue) {
resolvedBaseClassesResult = true; // If we find a circular reference
resolvedBaseClassesResult = resolvedBaseClassesInternal();
}
return resolvedBaseClassesResult.Value;
}
bool resolvedBaseClassesInternal() {
if (TypeDefinition.BaseType == null)
return true;
if (baseType == null)
return false;
return baseType.typeDef.resolvedBaseClasses();
}
MethodReference simpleClone(MethodReference methodReference, TypeReference declaringType) {
var m = new MethodReference(methodReference.Name, methodReference.MethodReturnType.ReturnType, declaringType);
m.MethodReturnType.ReturnType = methodReference.MethodReturnType.ReturnType;
m.HasThis = methodReference.HasThis;
m.ExplicitThis = methodReference.ExplicitThis;
m.CallingConvention = methodReference.CallingConvention;
foreach (var p in methodReference.Parameters)
m.Parameters.Add(new ParameterDefinition(p.Name, p.Attributes, p.ParameterType));
foreach (var gp in methodReference.GenericParameters)
m.GenericParameters.Add(new GenericParameter(declaringType));
return m;
}
void instantiateVirtualMembers(MethodNameScopes scopes) {
if (!TypeDefinition.IsInterface) {
if (baseType != null)
virtualMethodInstances.initializeFrom(baseType.typeDef.virtualMethodInstances, baseType.typeReference as GenericInstanceType);
// Figure out which methods we override in the base class
foreach (var methodDef in methods.getAll()) {
if (!methodDef.isVirtual() || methodDef.isNewSlot())
continue;
var methodInstList = virtualMethodInstances.lookup(methodDef.MethodDefinition);
if (methodInstList == null)
continue;
foreach (var methodInst in methodInstList)
scopes.same(methodDef, methodInst.origMethodDef);
}
}
foreach (var methodDef in methods.getAll()) {
if (!methodDef.isVirtual())
continue;
virtualMethodInstances.add(new MethodInst(methodDef, methodDef.MethodDefinition));
}
}
}
}