de4dot-cex/de4dot.code/renamer/asmmodules/Modules.cs

511 lines
14 KiB
C#
Raw Normal View History

2011-11-15 21:26:51 +08:00
/*
Copyright (C) 2011-2015 de4dot@gmail.com

This file is part of de4dot.

de4dot is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

de4dot is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with de4dot. If not, see <http://www.gnu.org/licenses/>.
*/
using System;
using System.Collections.Generic;
using dnlib.DotNet;
2011-11-15 21:26:51 +08:00
using de4dot.blocks;
namespace de4dot.code.renamer.asmmodules {
public class Modules : IResolver {
2011-11-15 21:26:51 +08:00
// Set by Initialize(); once true, Add() throws.
bool initializeCalled;
// Used by ResolveOther() to resolve types through the deobfuscator context.
readonly IDeobfuscatorContext deobfuscatorContext;

// Registered modules, in the order Add() accepted them.
readonly List<Module> modules = new List<Module>();
// ModuleDef -> wrapper; lets Add() ignore a ModuleDef that was already registered.
readonly Dictionary<ModuleDef, Module> modulesDict = new Dictionary<ModuleDef, Module>();

// Registered modules grouped by assembly key; queried by FindModules().
readonly AssemblyHash assemblyHash = new AssemblyHash();
// Every type of every registered module; filled by InitAllTypes().
readonly List<MTypeDef> allTypes = new List<MTypeDef>();
// Types without a resolved base type, or whose base type's typeDef has no module; filled by InitAllTypes().
readonly List<MTypeDef> baseTypes = new List<MTypeDef>();
// Assigned by InitAllTypes(); null until Initialize() has run.
List<MTypeDef> nonNestedTypes;
2011-11-17 11:17:03 +08:00
2011-11-18 23:55:54 +08:00
// Every module registered through Add(), in registration order.
public IList<Module> TheModules {
	get {
		return modules;
	}
}

// All types of all registered modules; populated by Initialize().
public IEnumerable<MTypeDef> AllTypes {
	get {
		return allTypes;
	}
}

// Types with no resolved base type, or whose base type's typeDef has no module; populated by Initialize().
public IEnumerable<MTypeDef> BaseTypes {
	get {
		return baseTypes;
	}
}

// Types that are not listed as a nested type of any other type; null until Initialize() has run.
public List<MTypeDef> NonNestedTypes {
	get {
		return nonNestedTypes;
	}
}
2011-11-15 21:26:51 +08:00
// Groups modules by the assembly they belong to.
class AssemblyHash {
	// Assembly key (see GetModuleKey / GetAssemblyName) -> modules of that assembly.
	IDictionary<string, ModuleHash> assemblyHash = new Dictionary<string, ModuleHash>(StringComparer.Ordinal);

	// Files the module under its assembly key, creating the per-assembly bucket on first use.
	public void Add(Module module) {
		var key = GetModuleKey(module);
		ModuleHash bucket;
		if (!assemblyHash.TryGetValue(key, out bucket)) {
			bucket = new ModuleHash();
			assemblyHash[key] = bucket;
		}
		bucket.Add(module);
	}

	// A module with an assembly is keyed by that assembly's name; otherwise by the base name of its file location.
	static string GetModuleKey(Module module) {
		var moduleDef = module.ModuleDefMD;
		var asm = moduleDef.Assembly;
		if (asm == null)
			return Utils.GetBaseName(moduleDef.Location);
		return GetAssemblyName(asm);
	}

	// Returns the bucket for the given assembly, or null if no module of it was added.
	public ModuleHash Lookup(IAssembly asm) {
		ModuleHash bucket;
		if (!assemblyHash.TryGetValue(GetAssemblyName(asm), out bucket))
			return null;
		return bucket;
	}

	// Key for an assembly: "" for null, the simple name when it has no public key/token,
	// otherwise the full name.
	static string GetAssemblyName(IAssembly asm) {
		if (asm == null)
			return string.Empty;
		if (!PublicKeyBase.IsNullOrEmpty2(asm.PublicKeyOrToken))
			return asm.FullName;
		return asm.Name;
	}
}
// All modules of one assembly, indexed by module name, plus the assembly's manifest module.
class ModuleHash {
	// Module name -> module.
	readonly ModulesDict modulesDict = new ModulesDict();
	// The module whose ModuleDefMD is the assembly's ManifestModule, once it has been added.
	Module mainModule;

	// Adds a module of this assembly.
	// Throws UserException if a second manifest module shows up, and (via ModulesDict)
	// ApplicationException if a module with the same name was already added.
	public void Add(Module module) {
		var asm = module.ModuleDefMD.Assembly;
		if (asm != null && ReferenceEquals(asm.ManifestModule, module.ModuleDefMD)) {
			if (mainModule != null) {
				// Each module path goes on its own line; previously the two ran together.
				throw new UserException(string.Format(
						"Two modules in the same assembly are main modules.\n" +
						"Is one 32-bit and the other 64-bit?\n" +
						" Module1: \"{0}\"\n" +
						" Module2: \"{1}\"",
						module.ModuleDefMD.Location,
						mainModule.ModuleDefMD.Location));
			}
			mainModule = module;
		}
		modulesDict.Add(module);
	}

	// Returns the module with the given name, or null.
	public Module Lookup(string moduleName) {
		return modulesDict.Lookup(moduleName);
	}

	// All modules of this assembly.
	public IEnumerable<Module> Modules {
		get { return modulesDict.Modules; }
	}
}
// Module name (ordinal comparison) -> module; duplicate names are rejected.
class ModulesDict {
	IDictionary<string, Module> modulesDict = new Dictionary<string, Module>(StringComparer.Ordinal);

	// Registers a module under its module name.
	// Throws ApplicationException if a module with that name is already present.
	public void Add(Module module) {
		string moduleName = module.ModuleDefMD.Name.String;
		if (Lookup(moduleName) == null) {
			modulesDict[moduleName] = module;
			return;
		}
		throw new ApplicationException(string.Format("Module \"{0}\" was found twice", moduleName));
	}

	// Returns the module with this name, or null when there is none.
	public Module Lookup(string moduleName) {
		Module found;
		if (!modulesDict.TryGetValue(moduleName, out found))
			return null;
		return found;
	}

	public IEnumerable<Module> Modules {
		get {
			return modulesDict.Values;
		}
	}
}
// True when no module has been registered.
public bool Empty {
	get {
		return modules.Count == 0;
	}
}
2012-04-05 03:06:10 +08:00
// Creates an empty module set. deobfuscatorContext is used later by ResolveOther()
// to resolve types.
public Modules(IDeobfuscatorContext deobfuscatorContext)
{
	this.deobfuscatorContext = deobfuscatorContext;
}
2013-01-19 20:03:57 +08:00
// Registers a module. Registering the same ModuleDef twice is a no-op.
// Throws ApplicationException if Initialize() has already been called.
public void Add(Module module) {
	if (initializeCalled)
		throw new ApplicationException("initialize() has been called");

	var moduleDef = module.ModuleDefMD;
	if (modulesDict.ContainsKey(moduleDef))
		return;

	modulesDict[moduleDef] = module;
	modules.Add(module);
	assemblyHash.Add(module);
}
2013-01-19 20:03:57 +08:00
// Runs the one-time setup passes over all registered modules; afterwards Add() throws.
public void Initialize() {
	initializeCalled = true;

	// Pass 1: collect member references in every module.
	FindAllMemberRefs();
	// Pass 2: gather all types and link owners, base types and interfaces.
	InitAllTypes();
	// Pass 3: resolve collected references through this resolver.
	// NOTE(review): the pass order presumably matters; keep it as-is.
	ResolveAllRefs();
}
2013-01-19 20:03:57 +08:00
// Asks each registered module to find its member references, threading one running
// index (passed by ref) through all modules in registration order.
void FindAllMemberRefs() {
	Logger.v("Finding all MemberRefs");
	// Per-module progress is only logged when there is more than one module.
	bool logPerModule = modules.Count > 1;
	int index = 0;
	foreach (var module in modules) {
		if (logPerModule)
			Logger.v("Finding all MemberRefs ({0})", module.Filename);
		Logger.Instance.Indent();
		try {
			module.FindAllMemberRefs(ref index);
		}
		finally {
			// Keep log indentation balanced even if the module throws.
			Logger.Instance.DeIndent();
		}
	}
}
2013-01-19 20:03:57 +08:00
// Asks each registered module to resolve its references, using this instance as the resolver.
void ResolveAllRefs() {
	Logger.v("Resolving references");
	// Per-module progress is only logged when there is more than one module.
	bool logPerModule = modules.Count > 1;
	foreach (var module in modules) {
		if (logPerModule)
			Logger.v("Resolving references ({0})", module.Filename);
		Logger.Instance.Indent();
		try {
			module.ResolveAllRefs(this);
		}
		finally {
			// Keep log indentation balanced even if the module throws.
			Logger.Instance.DeIndent();
		}
	}
}
2013-01-19 20:03:57 +08:00
// Collects every type of every registered module into allTypes, links the type graph,
// and computes nonNestedTypes and baseTypes. The passes run in the original order.
void InitAllTypes() {
	foreach (var module in modules)
		allTypes.AddRange(module.GetAllTypes());

	InitOwners();
	InitBaseTypeLinks();
	InitInterfaces();
	FindNonNestedTypes();
	FindBaseTypes();
}

// Sets Owner of each nested type to the MTypeDef wrapping its declaring TypeDef.
// The indexer throws KeyNotFoundException if a declaring type is not in allTypes
// (presumably impossible since GetAllTypes() includes declaring types — verify).
void InitOwners() {
	var typeToTypeDef = new Dictionary<TypeDef, MTypeDef>(allTypes.Count);
	foreach (var typeDef in allTypes)
		typeToTypeDef[typeDef.TypeDef] = typeDef;

	foreach (var typeDef in allTypes) {
		var declaringType = typeDef.TypeDef.DeclaringType;
		if (declaringType != null)
			typeDef.Owner = typeToTypeDef[declaringType];
	}
}

// Links each type to its resolved base type and records it as a derived type of that base.
// Base types outside the registered modules are resolved through ResolveOther().
void InitBaseTypeLinks() {
	foreach (var typeDef in allTypes) {
		var baseType = typeDef.TypeDef.BaseType;
		if (baseType == null)
			continue;
		var baseTypeDef = ResolveType(baseType) ?? ResolveOther(baseType);
		if (baseTypeDef == null)
			continue;
		typeDef.AddBaseType(baseTypeDef, baseType);
		baseTypeDef.derivedTypes.Add(typeDef);
	}
}

// Adds each resolvable implemented interface to its implementing type.
void InitInterfaces() {
	foreach (var typeDef in allTypes) {
		foreach (var iface in typeDef.TypeDef.Interfaces) {
			var ifaceTypeDef = ResolveType(iface.Interface) ?? ResolveOther(iface.Interface);
			if (ifaceTypeDef != null)
				typeDef.AddInterface(ifaceTypeDef, iface.Interface);
		}
	}
}

// nonNestedTypes = allTypes minus every type listed in some type's NestedTypes.
void FindNonNestedTypes() {
	var candidates = new HashSet<MTypeDef>(allTypes);
	foreach (var t in allTypes) {
		foreach (var nested in t.NestedTypes)
			candidates.Remove(nested);
	}
	nonNestedTypes = new List<MTypeDef>(candidates);
}

// baseTypes = types without a linked base type, or whose base type's typeDef has no module.
void FindBaseTypes() {
	foreach (var typeDef in allTypes) {
		if (typeDef.baseType == null || !typeDef.baseType.typeDef.HasModule)
			baseTypes.Add(typeDef);
	}
}
2011-12-02 05:32:09 +08:00
// Map keyed by type references. Exact lookups include the assembly version;
// TryGetSimilarValue() can fall back to a stored key from another assembly version.
class AssemblyKeyDictionary<T> where T : class {
	// Exact map: keys compared with SigComparerOptions.CompareAssemblyVersion.
	Dictionary<ITypeDefOrRef, T> dict = new Dictionary<ITypeDefOrRef, T>(new TypeEqualityComparer(SigComparerOptions.CompareAssemblyVersion));
	// Keys that were stored with a non-null value, grouped with the default TypeEqualityComparer
	// (no CompareAssemblyVersion option).
	Dictionary<ITypeDefOrRef, List<ITypeDefOrRef>> refs = new Dictionary<ITypeDefOrRef, List<ITypeDefOrRef>>(TypeEqualityComparer.Instance);

	// Getter throws KeyNotFoundException for a missing key.
	// Setter stores the value; a non-null value also records the key for TryGetSimilarValue().
	public T this[ITypeDefOrRef type] {
		get {
			T found;
			if (!TryGetValue(type, out found))
				throw new KeyNotFoundException();
			return found;
		}
		set {
			dict[type] = value;
			if (value == null)
				return;
			List<ITypeDefOrRef> sameTypeRefs;
			if (!refs.TryGetValue(type, out sameTypeRefs)) {
				sameTypeRefs = new List<ITypeDefOrRef>();
				refs[type] = sameTypeRefs;
			}
			sameTypeRefs.Add(type);
		}
	}

	// Exact (version-sensitive) lookup.
	public bool TryGetValue(ITypeDefOrRef type, out T value) {
		return dict.TryGetValue(type, out value);
	}

	// Looks for a stored key that matches 'type' under the default comparer.
	// If type has no DefinitionAssembly, the first stored candidate wins. Otherwise candidates
	// must have the same public key token and a version >= type's version; the lowest such
	// version wins.
	public bool TryGetSimilarValue(ITypeDefOrRef type, out T value) {
		List<ITypeDefOrRef> candidates;
		if (refs.TryGetValue(type, out candidates)) {
			var wantedAsm = type.DefinitionAssembly;
			ITypeDefOrRef best = null;
			IAssembly bestAsm = null;
			foreach (var candidate in candidates) {
				if (!dict.TryGetValue(candidate, out value))
					continue;
				if (wantedAsm == null) {
					best = candidate;
					break;
				}
				var candidateAsm = candidate.DefinitionAssembly;
				if (candidateAsm == null)
					continue;
				// Check pkt or we could return a type in eg. a SL assembly when it's not a SL app.
				if (!PublicKeyBase.TokenEquals(wantedAsm.PublicKeyOrToken, candidateAsm.PublicKeyOrToken))
					continue;
				if (wantedAsm.Version > candidateAsm.Version)
					continue;	// older than requested
				if (best != null && bestAsm.Version <= candidateAsm.Version)
					continue;	// not closer than the current best
				bestAsm = candidateAsm;
				best = candidate;
			}
			if (best != null) {
				value = dict[best];
				return true;
			}
		}
		value = default(T);
		return false;
	}
}
// Cache for ResolveOther(): type reference -> MTypeDef. Values may be null for types that
// could not be resolved or are still being resolved (circular-reference placeholder).
readonly AssemblyKeyDictionary<MTypeDef> typeToTypeDefDict = new AssemblyKeyDictionary<MTypeDef>();
2013-01-19 20:03:57 +08:00
// Resolves a type through the deobfuscator context and wraps it in a new MTypeDef with its
// members, resolvable interfaces and base type. InitAllTypes() uses it as the fallback
// when ResolveType() returns null. Results, including nulls, are cached in typeToTypeDefDict.
// If the context cannot resolve the type, a cached entry for a compatible assembly version
// is used (TryGetSimilarValue). Returns null for a null type or scope type, or when nothing
// could be resolved.
public MTypeDef ResolveOther(ITypeDefOrRef type) {
	if (type == null)
		return null;
	var key = type.ScopeType;
	if (key == null)
		return null;

	MTypeDef cached;
	if (typeToTypeDefDict.TryGetValue(key, out cached))
		return cached;

	var resolvedDef = deobfuscatorContext.ResolveType(key);
	if (resolvedDef == null) {
		MTypeDef similar;
		typeToTypeDefDict.TryGetSimilarValue(key, out similar);
		typeToTypeDefDict[key] = similar;
		return similar;
	}

	if (typeToTypeDefDict.TryGetValue(resolvedDef, out cached)) {
		typeToTypeDefDict[key] = cached;
		return cached;
	}

	// Placeholders: a recursive ResolveOther() on either key (circular reference) gets null.
	typeToTypeDefDict[key] = null;
	typeToTypeDefDict[resolvedDef] = null;

	var created = new MTypeDef(resolvedDef, null, 0);
	created.AddMembers();
	foreach (var ifaceImpl in created.TypeDef.Interfaces) {
		var ifaceDef = ResolveOther(ifaceImpl.Interface);
		if (ifaceDef != null)
			created.AddInterface(ifaceDef, ifaceImpl.Interface);
	}
	var baseTypeRef = created.TypeDef.BaseType;
	var baseDef = ResolveOther(baseTypeRef);
	if (baseDef != null)
		created.AddBaseType(baseDef, baseTypeRef);

	typeToTypeDefDict[key] = created;
	if (key != resolvedDef)
		typeToTypeDefDict[resolvedDef] = created;
	return created;
}
2013-01-19 20:03:57 +08:00
// Calls InitializeVirtualMembers(groups, this) on every collected type and returns the
// MethodNameGroups they all filled in.
public MethodNameGroups InitializeVirtualMembers() {
	var result = new MethodNameGroups();
	for (int i = 0; i < allTypes.Count; i++)
		allTypes[i].InitializeVirtualMembers(result, this);
	return result;
}
2013-01-19 20:03:57 +08:00
// Forwards the OnTypesRenamed() notification to every registered module, in registration order.
public void OnTypesRenamed() {
	for (int i = 0; i < modules.Count; i++)
		modules[i].OnTypesRenamed();
}
2013-01-19 20:03:57 +08:00
// Does nothing unless compiled with PORT defined; the PORT code drops invalidated
// modules from the assembly resolver.
public void CleanUp() {
#if PORT
	var invalidated = DotNetUtils.typeCaches.invalidateAll();
	foreach (var module in invalidated)
		AssemblyResolver.Instance.removeModule(module);
#endif
}
2011-11-15 21:26:51 +08:00
// Returns the registered modules that may define 'type', or null if it's a non-loaded
// module/assembly (or type/scope is null).
//   AssemblyRef scope: every module of that assembly.
//   ModuleDef scope:   that module if registered; else the fallback below.
//   ModuleRef scope:   type's own module if the ModuleRef names it and it is registered;
//                      else the fallback below.
//   Fallback:          the module named by the scope in type's own assembly.
// Throws ApplicationException for any other scope type.
IEnumerable<Module> FindModules(ITypeDefOrRef type) {
	if (type == null)
		return null;
	var scope = type.Scope;
	if (scope == null)
		return null;

	var scopeType = scope.ScopeType;
	switch (scopeType) {
	case ScopeType.AssemblyRef:
		return FindModules((AssemblyRef)scope);

	case ScopeType.ModuleDef: {
		var found = FindModules((ModuleDef)scope);
		if (found != null)
			return found;
		break;
	}

	case ScopeType.ModuleRef: {
		var moduleRef = (ModuleRef)scope;
		if (moduleRef.Name == type.Module.Name) {
			var found = FindModules(type.Module);
			if (found != null)
				return found;
		}
		break;
	}

	default:
		throw new ApplicationException(string.Format("scope is an unsupported type: {0}", scope.GetType()));
	}

	// Fallback for ModuleDef/ModuleRef scopes: look the scope name up in type's assembly.
	var asm = type.Module.Assembly;
	if (asm == null)
		return null;
	var moduleHash = assemblyHash.Lookup(asm);
	if (moduleHash == null)
		return null;
	var module = moduleHash.Lookup(scope.ScopeName);
	if (module == null)
		return null;
	return new List<Module> { module };
}
2013-01-19 20:03:57 +08:00
// All registered modules of the referenced assembly, or null if none were registered.
IEnumerable<Module> FindModules(AssemblyRef assemblyRef) {
	var moduleHash = assemblyHash.Lookup(assemblyRef);
	return moduleHash == null ? null : moduleHash.Modules;
}
2013-01-19 20:03:57 +08:00
// A one-element list with the wrapper registered for moduleDef, or null if it wasn't registered.
IEnumerable<Module> FindModules(ModuleDef moduleDef) {
	Module module;
	if (!modulesDict.TryGetValue(moduleDef, out module))
		return null;
	return new List<Module> { module };
}
2013-01-19 20:03:57 +08:00
// True if typeRef is a TypeSpec whose signature is an SZ-array, an array or a pointer.
// Callers use this to suppress "Could not resolve" errors for such types.
bool IsAutoCreatedType(ITypeDefOrRef typeRef) {
	var spec = typeRef as TypeSpec;
	var sig = spec == null ? null : spec.TypeSig;
	if (sig == null)
		return false;
	return sig.IsSZArray || sig.IsArray || sig.IsPointer;
}
2013-01-19 20:03:57 +08:00
// Resolves typeRef to an MTypeDef in one of the registered modules.
// Returns null if the defining module isn't registered or no module resolves it; logs an
// error in the latter case unless typeRef is an array/pointer TypeSpec.
public MTypeDef ResolveType(ITypeDefOrRef typeRef) {
	var candidates = FindModules(typeRef);
	if (candidates == null)
		return null;

	foreach (var module in candidates) {
		var resolved = module.ResolveType(typeRef);
		if (resolved != null)
			return resolved;
	}

	if (!IsAutoCreatedType(typeRef)) {
		Logger.e("Could not resolve TypeRef {0} ({1:X8}) (from {2} -> {3})",
				Utils.RemoveNewlines(typeRef),
				typeRef.MDToken.ToInt32(),
				typeRef.Module,
				typeRef.Scope);
	}
	return null;
}
2013-01-19 20:03:57 +08:00
// Resolves methodRef to an MMethodDef in one of the registered modules.
// Returns null if it has no declaring type, the declaring type's module isn't registered,
// or no module resolves it; logs an error in the last case unless the declaring type is an
// array/pointer TypeSpec.
// DeclaringType is re-read at each use, as before (it may not be a plain field getter).
public MMethodDef ResolveMethod(IMethodDefOrRef methodRef) {
	if (methodRef.DeclaringType == null)
		return null;
	var candidates = FindModules(methodRef.DeclaringType);
	if (candidates == null)
		return null;

	foreach (var module in candidates) {
		var resolved = module.ResolveMethod(methodRef);
		if (resolved != null)
			return resolved;
	}

	if (!IsAutoCreatedType(methodRef.DeclaringType)) {
		Logger.e("Could not resolve MethodRef {0} ({1:X8}) (from {2} -> {3})",
				Utils.RemoveNewlines(methodRef),
				methodRef.MDToken.ToInt32(),
				methodRef.DeclaringType.Module,
				methodRef.DeclaringType.Scope);
	}
	return null;
}
2013-01-19 20:03:57 +08:00
// Resolves fieldRef to an MFieldDef in one of the registered modules.
// Returns null if it has no declaring type, the declaring type's module isn't registered,
// or no module resolves it; logs an error in the last case unless the declaring type is an
// array/pointer TypeSpec.
// DeclaringType is re-read at each use, as before (it may not be a plain field getter).
public MFieldDef ResolveField(MemberRef fieldRef) {
	if (fieldRef.DeclaringType == null)
		return null;
	var candidates = FindModules(fieldRef.DeclaringType);
	if (candidates == null)
		return null;

	foreach (var module in candidates) {
		var resolved = module.ResolveField(fieldRef);
		if (resolved != null)
			return resolved;
	}

	if (!IsAutoCreatedType(fieldRef.DeclaringType)) {
		Logger.e("Could not resolve FieldRef {0} ({1:X8}) (from {2} -> {3})",
				Utils.RemoveNewlines(fieldRef),
				fieldRef.MDToken.ToInt32(),
				fieldRef.DeclaringType.Module,
				fieldRef.DeclaringType.Scope);
	}
	return null;
}
}
}