/*
    Copyright (C) 2011-2015 de4dot@gmail.com

    This file is part of de4dot.

    de4dot is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    de4dot is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with de4dot.  If not, see <http://www.gnu.org/licenses/>.
*/

using System;
using System.Collections.Generic;
using dnlib.DotNet;
using de4dot.blocks;

namespace de4dot.code.renamer.asmmodules {
	/// <summary>
	/// The set of <see cref="Module"/>s taking part in renaming. Once
	/// <see cref="Initialize"/> has run, it links every type to its owner, base type,
	/// interfaces and derived types, and resolves type/method/field references to the
	/// <see cref="MTypeDef"/>/<see cref="MMethodDef"/>/<see cref="MFieldDef"/> wrappers.
	/// </summary>
	public class Modules : IResolver {
		bool initializeCalled = false;
		IDeobfuscatorContext deobfuscatorContext;
		// Added modules, in insertion order.
		List<Module> modules = new List<Module>();
		// ModuleDef -> wrapper; used to ignore duplicate adds and to resolve ModuleDef scopes.
		Dictionary<ModuleDef, Module> modulesDict = new Dictionary<ModuleDef, Module>();
		// Assembly name -> modules of that assembly; used to resolve AssemblyRef/ModuleRef scopes.
		AssemblyHash assemblyHash = new AssemblyHash();

		List<MTypeDef> allTypes = new List<MTypeDef>();
		List<MTypeDef> baseTypes = new List<MTypeDef>();
		List<MTypeDef> nonNestedTypes;

		/// <summary>All added modules, in the order they were added.</summary>
		public IList<Module> TheModules {
			get { return modules; }
		}

		/// <summary>Every type of every added module (filled by <see cref="Initialize"/>).</summary>
		public IEnumerable<MTypeDef> AllTypes {
			get { return allTypes; }
		}

		/// <summary>
		/// Types with no resolved base type, or whose resolved base type's
		/// <c>typeDef.HasModule</c> is false (filled by <see cref="Initialize"/>).
		/// </summary>
		public IEnumerable<MTypeDef> BaseTypes {
			get { return baseTypes; }
		}

		/// <summary>
		/// Types that are not listed as a nested type of any other type.
		/// <c>null</c> until <see cref="Initialize"/> has been called.
		/// </summary>
		public List<MTypeDef> NonNestedTypes {
			get { return nonNestedTypes; }
		}

		/// <summary>
		/// Groups added modules by assembly name. A module without an assembly is
		/// keyed by the base name of its file location instead.
		/// </summary>
		class AssemblyHash {
			IDictionary<string, ModuleHash> assemblyHash = new Dictionary<string, ModuleHash>(StringComparer.Ordinal);

			public void Add(Module module) {
				ModuleHash moduleHash;
				var key = GetModuleKey(module);
				if (!assemblyHash.TryGetValue(key, out moduleHash))
					assemblyHash[key] = moduleHash = new ModuleHash();
				moduleHash.Add(module);
			}

			static string GetModuleKey(Module module) {
				if (module.ModuleDefMD.Assembly != null)
					return GetAssemblyName(module.ModuleDefMD.Assembly);
				return Utils.GetBaseName(module.ModuleDefMD.Location);
			}

			/// <summary>Returns the modules of <paramref name="asm"/>, or null if none were added.</summary>
			public ModuleHash Lookup(IAssembly asm) {
				ModuleHash moduleHash;
				if (assemblyHash.TryGetValue(GetAssemblyName(asm), out moduleHash))
					return moduleHash;
				return null;
			}

			// Assemblies without a public key/token are keyed by simple name only,
			// all others by full name.
			static string GetAssemblyName(IAssembly asm) {
				if (asm == null)
					return string.Empty;
				if (PublicKeyBase.IsNullOrEmpty2(asm.PublicKeyOrToken))
					return asm.Name;
				return asm.FullName;
			}
		}

		/// <summary>The modules of a single assembly, keyed by module name.</summary>
		class ModuleHash {
			ModulesDict modulesDict = new ModulesDict();
			Module mainModule = null;

			/// <summary>
			/// Adds a module. Throws <see cref="UserException"/> if a second module
			/// claims to be the assembly's manifest module.
			/// </summary>
			public void Add(Module module) {
				var asm = module.ModuleDefMD.Assembly;
				if (asm != null && ReferenceEquals(asm.ManifestModule, module.ModuleDefMD)) {
					if (mainModule != null) {
						// Each part of the message goes on its own line.
						throw new UserException(string.Format(
							"Two modules in the same assembly are main modules.\n" +
							"Is one 32-bit and the other 64-bit?\n" +
							" Module1: \"{0}\"\n" +
							" Module2: \"{1}\"",
							module.ModuleDefMD.Location,
							mainModule.ModuleDefMD.Location));
					}
					mainModule = module;
				}
				modulesDict.Add(module);
			}

			public Module Lookup(string moduleName) {
				return modulesDict.Lookup(moduleName);
			}

			public IEnumerable<Module> Modules {
				get { return modulesDict.Values(); }
			}
		}

		/// <summary>Module name -> module, compared ordinally.</summary>
		class ModulesDict {
			IDictionary<string, Module> modulesDict = new Dictionary<string, Module>(StringComparer.Ordinal);

			/// <summary>Adds a module; throws if a module with the same name was already added.</summary>
			public void Add(Module module) {
				var moduleName = module.ModuleDefMD.Name.String;
				if (Lookup(moduleName) != null)
					throw new ApplicationException(string.Format("Module \"{0}\" was found twice", moduleName));
				modulesDict[moduleName] = module;
			}

			/// <summary>Returns the module named <paramref name="moduleName"/>, or null.</summary>
			public Module Lookup(string moduleName) {
				Module module;
				if (modulesDict.TryGetValue(moduleName, out module))
					return module;
				return null;
			}

			public IEnumerable<Module> Values() {
				return modulesDict.Values;
			}

			public IEnumerable<Module> Modules {
				get { return modulesDict.Values; }
			}
		}

		/// <summary>True if no module has been added.</summary>
		public bool Empty {
			get { return modules.Count == 0; }
		}

		/// <param name="deobfuscatorContext">Used by <see cref="ResolveOther"/> to resolve types outside the added modules.</param>
		public Modules(IDeobfuscatorContext deobfuscatorContext) {
			this.deobfuscatorContext = deobfuscatorContext;
		}

		/// <summary>
		/// Adds a module. Adding a module whose <c>ModuleDefMD</c> was already added is a no-op.
		/// </summary>
		/// <exception cref="ApplicationException"><see cref="Initialize"/> has already been called.</exception>
		public void Add(Module module) {
			if (initializeCalled)
				throw new ApplicationException("initialize() has been called");
			if (modulesDict.ContainsKey(module.ModuleDefMD))
				return;
			modulesDict[module.ModuleDefMD] = module;
			modules.Add(module);
			assemblyHash.Add(module);
		}

		/// <summary>
		/// Finds all MemberRefs, builds the type hierarchy, then resolves all references.
		/// After this call <see cref="Add"/> throws.
		/// </summary>
		public void Initialize() {
			initializeCalled = true;
			FindAllMemberRefs();
			InitAllTypes();
			ResolveAllRefs();
		}

		void FindAllMemberRefs() {
			Logger.v("Finding all MemberRefs");
			// Passed by ref, so the counter carries over from one module to the next.
			int index = 0;
			foreach (var module in modules) {
				if (modules.Count > 1)
					Logger.v("Finding all MemberRefs ({0})", module.Filename);
				Logger.Instance.Indent();
				try {
					module.FindAllMemberRefs(ref index);
				}
				finally {
					// Keep the logger's indent balanced even if the module throws.
					Logger.Instance.DeIndent();
				}
			}
		}

		void ResolveAllRefs() {
			Logger.v("Resolving references");
			foreach (var module in modules) {
				if (modules.Count > 1)
					Logger.v("Resolving references ({0})", module.Filename);
				Logger.Instance.Indent();
				try {
					module.ResolveAllRefs(this);
				}
				finally {
					// Keep the logger's indent balanced even if the module throws.
					Logger.Instance.DeIndent();
				}
			}
		}

		/// <summary>
		/// Collects every type of every module, then sets owners, base types,
		/// derived types and interfaces, and computes <see cref="NonNestedTypes"/>
		/// and <see cref="BaseTypes"/>.
		/// </summary>
		void InitAllTypes() {
			foreach (var module in modules)
				allTypes.AddRange(module.GetAllTypes());

			var typeToTypeDef = new Dictionary<TypeDef, MTypeDef>(allTypes.Count);
			foreach (var typeDef in allTypes)
				typeToTypeDef[typeDef.TypeDef] = typeDef;

			// Initialize Owner
			foreach (var typeDef in allTypes) {
				if (typeDef.TypeDef.DeclaringType != null)
					typeDef.Owner = typeToTypeDef[typeDef.TypeDef.DeclaringType];
			}

			// Initialize baseType and derivedTypes. A base type outside the added
			// modules is resolved via ResolveOther().
			foreach (var typeDef in allTypes) {
				var baseType = typeDef.TypeDef.BaseType;
				if (baseType == null)
					continue;
				var baseTypeDef = ResolveType(baseType) ?? ResolveOther(baseType);
				if (baseTypeDef != null) {
					typeDef.AddBaseType(baseTypeDef, baseType);
					baseTypeDef.derivedTypes.Add(typeDef);
				}
			}

			// Initialize interfaces
			foreach (var typeDef in allTypes) {
				foreach (var iface in typeDef.TypeDef.Interfaces) {
					var ifaceTypeDef = ResolveType(iface.Interface) ?? ResolveOther(iface.Interface);
					if (ifaceTypeDef != null)
						typeDef.AddInterface(ifaceTypeDef, iface.Interface);
				}
			}

			// Find all non-nested types: start with every type, then remove each one
			// that appears as a nested type of another.
			var allTypesDict = new Dictionary<MTypeDef, bool>();
			foreach (var t in allTypes)
				allTypesDict[t] = true;
			foreach (var t in allTypes) {
				foreach (var t2 in t.NestedTypes)
					allTypesDict.Remove(t2);
			}
			nonNestedTypes = new List<MTypeDef>(allTypesDict.Keys);

			foreach (var typeDef in allTypes) {
				if (typeDef.baseType == null || !typeDef.baseType.typeDef.HasModule)
					baseTypes.Add(typeDef);
			}
		}

		/// <summary>
		/// Type -> value map whose exact lookups compare assembly versions
		/// (<see cref="SigComparerOptions.CompareAssemblyVersion"/>). It also groups
		/// keys with the default <see cref="TypeEqualityComparer"/>, so that
		/// <see cref="TryGetSimilarValue"/> can fall back to the same type from
		/// another assembly version.
		/// </summary>
		class AssemblyKeyDictionary<T> where T : class {
			Dictionary<ITypeDefOrRef, T> dict = new Dictionary<ITypeDefOrRef, T>(new TypeEqualityComparer(SigComparerOptions.CompareAssemblyVersion));
			Dictionary<ITypeDefOrRef, List<ITypeDefOrRef>> refs = new Dictionary<ITypeDefOrRef, List<ITypeDefOrRef>>(TypeEqualityComparer.Instance);

			/// <summary>
			/// Gets or sets the value for <paramref name="type"/>. The getter throws
			/// <see cref="KeyNotFoundException"/> if the type is missing. Null values may
			/// be stored, but only non-null ones become candidates for
			/// <see cref="TryGetSimilarValue"/>.
			/// </summary>
			public T this[ITypeDefOrRef type] {
				get {
					T value;
					if (TryGetValue(type, out value))
						return value;
					throw new KeyNotFoundException();
				}
				set {
					dict[type] = value;

					if (value != null) {
						List<ITypeDefOrRef> list;
						if (!refs.TryGetValue(type, out list))
							refs[type] = list = new List<ITypeDefOrRef>();
						list.Add(type);
					}
				}
			}

			/// <summary>Exact (version-sensitive) lookup.</summary>
			public bool TryGetValue(ITypeDefOrRef type, out T value) {
				return dict.TryGetValue(type, out value);
			}

			/// <summary>
			/// Looks for the same type in another assembly version. A candidate must have
			/// an equal public key token and a version &gt;= <paramref name="type"/>'s; the
			/// lowest such version wins. If <paramref name="type"/> has no defining
			/// assembly, the first stored candidate is returned.
			/// </summary>
			public bool TryGetSimilarValue(ITypeDefOrRef type, out T value) {
				List<ITypeDefOrRef> list;
				if (!refs.TryGetValue(type, out list)) {
					value = default(T);
					return false;
				}

				// Find a type whose version is >= type's version and closest to it.

				ITypeDefOrRef foundType = null;
				var typeAsmName = type.DefinitionAssembly;
				IAssembly foundAsmName = null;
				foreach (var otherRef in list) {
					if (!dict.TryGetValue(otherRef, out value))
						continue;

					if (typeAsmName == null) {
						foundType = otherRef;
						break;
					}

					var otherAsmName = otherRef.DefinitionAssembly;
					if (otherAsmName == null)
						continue;
					// Require an equal public key token, or we could return a type from e.g.
					// a SL assembly when it's not a SL app.
					if (!PublicKeyBase.TokenEquals(typeAsmName.PublicKeyOrToken, otherAsmName.PublicKeyOrToken))
						continue;
					if (typeAsmName.Version > otherAsmName.Version)
						continue;	// old version

					if (foundType == null) {
						foundAsmName = otherAsmName;
						foundType = otherRef;
						continue;
					}

					if (foundAsmName.Version <= otherAsmName.Version)
						continue;
					foundAsmName = otherAsmName;
					foundType = otherRef;
				}

				if (foundType != null) {
					value = dict[foundType];
					return true;
				}

				value = default(T);
				return false;
			}
		}

		// Cache for ResolveOther(). Null values are cached too, so a failed resolution
		// is not retried.
		AssemblyKeyDictionary<MTypeDef> typeToTypeDefDict = new AssemblyKeyDictionary<MTypeDef>();

		/// <summary>
		/// Resolves a type that is not in any of the added modules. It first uses
		/// <c>deobfuscatorContext.ResolveType</c>; if that fails, it falls back to a
		/// cached type with a compatible assembly version. A newly resolved type is
		/// wrapped in a new <see cref="MTypeDef"/>, and its interfaces and base type
		/// are resolved recursively. Results are cached.
		/// </summary>
		/// <returns>The wrapper, or null if the type can't be resolved.</returns>
		public MTypeDef ResolveOther(ITypeDefOrRef type) {
			if (type == null)
				return null;
			type = type.ScopeType;
			if (type == null)
				return null;

			MTypeDef typeDef;
			if (typeToTypeDefDict.TryGetValue(type, out typeDef))
				return typeDef;

			var typeDef2 = deobfuscatorContext.ResolveType(type);
			if (typeDef2 == null) {
				typeToTypeDefDict.TryGetSimilarValue(type, out typeDef);
				typeToTypeDefDict[type] = typeDef;
				return typeDef;
			}

			if (typeToTypeDefDict.TryGetValue(typeDef2, out typeDef)) {
				typeToTypeDefDict[type] = typeDef;
				return typeDef;
			}

			typeToTypeDefDict[type] = null;	// In case of a circular reference
			typeToTypeDefDict[typeDef2] = null;

			typeDef = new MTypeDef(typeDef2, null, 0);
			typeDef.AddMembers();
			foreach (var iface in typeDef.TypeDef.Interfaces) {
				var ifaceDef = ResolveOther(iface.Interface);
				if (ifaceDef == null)
					continue;
				typeDef.AddInterface(ifaceDef, iface.Interface);
			}
			var baseDef = ResolveOther(typeDef.TypeDef.BaseType);
			if (baseDef != null)
				typeDef.AddBaseType(baseDef, typeDef.TypeDef.BaseType);

			typeToTypeDefDict[type] = typeDef;
			if (type != typeDef2)
				typeToTypeDefDict[typeDef2] = typeDef;
			return typeDef;
		}

		/// <summary>
		/// Calls <c>InitializeVirtualMembers</c> on every type and returns the groups it fills.
		/// </summary>
		public MethodNameGroups InitializeVirtualMembers() {
			var groups = new MethodNameGroups();
			foreach (var typeDef in allTypes)
				typeDef.InitializeVirtualMembers(groups, this);
			return groups;
		}

		/// <summary>Forwards the notification to every added module.</summary>
		public void OnTypesRenamed() {
			foreach (var module in modules)
				module.OnTypesRenamed();
		}

		/// <summary>Currently a no-op; the body is only compiled when <c>PORT</c> is defined.</summary>
		public void CleanUp() {
#if PORT
			foreach (var module in DotNetUtils.typeCaches.invalidateAll())
				AssemblyResolver.Instance.removeModule(module);
#endif
		}

		// Returns null if it's a non-loaded module/assembly
		IEnumerable<Module> FindModules(ITypeDefOrRef type) {
			if (type == null)
				return null;
			var scope = type.Scope;
			if (scope == null)
				return null;

			var scopeType = scope.ScopeType;
			if (scopeType == ScopeType.AssemblyRef)
				return FindModules((AssemblyRef)scope);

			if (scopeType == ScopeType.ModuleDef) {
				var found = FindModules((ModuleDef)scope);
				if (found != null)
					return found;
			}

			// The type's owner module can be null (e.g. a detached type); treat that as
			// "not loaded" instead of dereferencing it.
			var ownerModule = type.Module;

			if (scopeType == ScopeType.ModuleRef) {
				var moduleRef = (ModuleRef)scope;
				if (ownerModule != null && moduleRef.Name == ownerModule.Name) {
					var found = FindModules(ownerModule);
					if (found != null)
						return found;
				}
			}

			if (scopeType == ScopeType.ModuleRef || scopeType == ScopeType.ModuleDef) {
				if (ownerModule == null)
					return null;
				var asm = ownerModule.Assembly;
				if (asm == null)
					return null;
				var moduleHash = assemblyHash.Lookup(asm);
				if (moduleHash == null)
					return null;
				var module = moduleHash.Lookup(scope.ScopeName);
				if (module == null)
					return null;
				return new List<Module> { module };
			}

			throw new ApplicationException(string.Format("scope is an unsupported type: {0}", scope.GetType()));
		}

		IEnumerable<Module> FindModules(AssemblyRef assemblyRef) {
			var moduleHash = assemblyHash.Lookup(assemblyRef);
			if (moduleHash != null)
				return moduleHash.Modules;
			return null;
		}

		IEnumerable<Module> FindModules(ModuleDef moduleDef) {
			Module module;
			if (modulesDict.TryGetValue(moduleDef, out module))
				return new List<Module> { module };
			return null;
		}

		// True for TypeSpecs that are arrays or pointers; ResolveType/Method/Field
		// don't log an error when those fail to resolve.
		bool IsAutoCreatedType(ITypeDefOrRef typeRef) {
			var ts = typeRef as TypeSpec;
			if (ts == null)
				return false;
			var sig = ts.TypeSig;
			if (sig == null)
				return false;
			return sig.IsSZArray || sig.IsArray || sig.IsPointer;
		}

		/// <summary>
		/// Resolves <paramref name="typeRef"/> to a type in one of the added modules.
		/// Returns null if its scope isn't loaded or it isn't found. A lookup failure
		/// is logged unless the type is an array/pointer TypeSpec.
		/// </summary>
		public MTypeDef ResolveType(ITypeDefOrRef typeRef) {
			var candidates = FindModules(typeRef);
			if (candidates == null)
				return null;
			foreach (var module in candidates) {
				var rv = module.ResolveType(typeRef);
				if (rv != null)
					return rv;
			}
			if (IsAutoCreatedType(typeRef))
				return null;
			Logger.e("Could not resolve TypeRef {0} ({1:X8}) (from {2} -> {3})",
						Utils.RemoveNewlines(typeRef),
						typeRef.MDToken.ToInt32(),
						typeRef.Module,
						typeRef.Scope);
			return null;
		}

		/// <summary>
		/// Resolves <paramref name="methodRef"/> to a method in one of the added modules.
		/// Returns null if it has no declaring type, the declaring type's scope isn't
		/// loaded, or it isn't found. Failures are logged as in <see cref="ResolveType"/>.
		/// </summary>
		public MMethodDef ResolveMethod(IMethodDefOrRef methodRef) {
			if (methodRef.DeclaringType == null)
				return null;
			var candidates = FindModules(methodRef.DeclaringType);
			if (candidates == null)
				return null;
			foreach (var module in candidates) {
				var rv = module.ResolveMethod(methodRef);
				if (rv != null)
					return rv;
			}
			if (IsAutoCreatedType(methodRef.DeclaringType))
				return null;
			Logger.e("Could not resolve MethodRef {0} ({1:X8}) (from {2} -> {3})",
						Utils.RemoveNewlines(methodRef),
						methodRef.MDToken.ToInt32(),
						methodRef.DeclaringType.Module,
						methodRef.DeclaringType.Scope);
			return null;
		}

		/// <summary>
		/// Resolves <paramref name="fieldRef"/> to a field in one of the added modules.
		/// Returns null if it has no declaring type, the declaring type's scope isn't
		/// loaded, or it isn't found. Failures are logged as in <see cref="ResolveType"/>.
		/// </summary>
		public MFieldDef ResolveField(MemberRef fieldRef) {
			if (fieldRef.DeclaringType == null)
				return null;
			var candidates = FindModules(fieldRef.DeclaringType);
			if (candidates == null)
				return null;
			foreach (var module in candidates) {
				var rv = module.ResolveField(fieldRef);
				if (rv != null)
					return rv;
			}
			if (IsAutoCreatedType(fieldRef.DeclaringType))
				return null;
			Logger.e("Could not resolve FieldRef {0} ({1:X8}) (from {2} -> {3})",
						Utils.RemoveNewlines(fieldRef),
						fieldRef.MDToken.ToInt32(),
						fieldRef.DeclaringType.Module,
						fieldRef.DeclaringType.Scope);
			return null;
		}
	}
}