diff --git a/de4dot.code/deobfuscators/SmartAssembly/TamperProtectionRemover.cs b/de4dot.code/deobfuscators/SmartAssembly/TamperProtectionRemover.cs index 14de31d1..7d4f1d6b 100644 --- a/de4dot.code/deobfuscators/SmartAssembly/TamperProtectionRemover.cs +++ b/de4dot.code/deobfuscators/SmartAssembly/TamperProtectionRemover.cs @@ -28,6 +28,11 @@ namespace de4dot.deobfuscators.SmartAssembly { ModuleDefinition module; List pinvokeMethods = new List(); + enum Type { + V1, + V2, + } + public IList PinvokeMethods { get { return pinvokeMethods; } } @@ -61,155 +66,193 @@ namespace de4dot.deobfuscators.SmartAssembly { public int End { get; set; } } - IList findTamperBlocks(Blocks blocks, IList allBlocks, out MethodDefinition pinvokeMethod) { - var list = new List(3); - - var first = findFirstBlocks(allBlocks, blocks.Locals, out pinvokeMethod); - if (first == null) - return null; - - var second = first[1]; - var badBlock = second.Block.LastInstr.isBrfalse() ? second.Block.Targets[0] : second.Block.FallThrough; - var last = findLastBlock(badBlock); - if (last == null) - return null; - - list.AddRange(first); - list.Add(last); - return list; + class TamperBlocks { + public Type type; + public MethodDefinition pinvokeMethod; + public BlockInfo first; + public BlockInfo second; + public BlockInfo bad; } - IList findFirstBlocks(IList allBlocks, IList locals, out MethodDefinition pinvokeMethod) { - pinvokeMethod = null; + TamperBlocks findTamperBlocks(Blocks blocks, IList allBlocks) { + var tamperBlocks = new TamperBlocks(); + if (!findFirstBlocks(tamperBlocks, allBlocks, blocks.Locals)) + return null; + + var second = tamperBlocks.second; + var badBlock = second.Block.LastInstr.isBrfalse() ? second.Block.Targets[0] : second.Block.FallThrough; + tamperBlocks.bad = findBadBlock(badBlock); + if (tamperBlocks.bad == null) + return null; + + return tamperBlocks; + } + + bool findFirstBlocks(TamperBlocks tamperBlocks, IList allBlocks, IList locals) { foreach (var b in allBlocks) { - if (!b.LastInstr.isBrfalse()) - continue; - try { - var block = b; - var list = new List(); - var instrs = block.Instructions; - int start = instrs.Count - 1; - int end = start; - Instr instr; - MethodReference method; - - /* - * ldc.i4.0 - * stloc X - * call GetExecutingAssembly() - * stloc Y - * ldloc Y - * callvirt Location - * ldc.i4.1 - * ldloca X - * call StrongNameSignatureVerificationEx - * pop - * ldloc X - * brfalse bad_code - * ldloc Y - * callvirt FullName() - * ldstr "......" - * callvirt EndsWith(string) - * brfalse bad_code / brtrue good_code - */ - - instr = instrs[--start]; - if (!instr.isLdloc()) - continue; - var loc0 = Instr.getLocalVar(locals, instr); - - instr = instrs[--start]; - if (instr.OpCode != OpCodes.Pop) - continue; - - instr = instrs[--start]; - if (instr.OpCode != OpCodes.Call) - continue; - pinvokeMethod = DotNetUtils.getMethod(module, instr.Operand as MethodReference); - if (!DotNetUtils.isPinvokeMethod(pinvokeMethod, "mscorwks", "StrongNameSignatureVerificationEx")) - continue; - - while (true) { - instr = instrs[--start]; - if (instr.OpCode == OpCodes.Callvirt) - break; - } - method = (MethodReference)instr.Operand; - if (method.ToString() != "System.String System.Reflection.Assembly::get_Location()") - continue; - - while (true) { - instr = instrs[--start]; - if (instr.OpCode == OpCodes.Call) - break; - } - method = (MethodReference)instr.Operand; - if (method.ToString() != "System.Reflection.Assembly System.Reflection.Assembly::GetExecutingAssembly()") - continue; - - instr = instrs[--start]; - if (!instr.isStloc() || Instr.getLocalVar(locals, instr) != loc0) - continue; - instr = instrs[--start]; - if (!instr.isLdcI4()) - continue; - - list.Add(new BlockInfo { - Block = block, - Start = start, - End = end, - }); - - block = block.FallThrough; - instrs = block.Instructions; - start = end = 0; - - instr = instrs[end++]; - if (!instr.isLdloc()) - continue; - - instr = instrs[end++]; - if (instr.OpCode != OpCodes.Callvirt) - continue; - method = (MethodReference)instr.Operand; - if (method.ToString() != "System.String System.Reflection.Assembly::get_FullName()") - continue; - - instr = instrs[end++]; - if (instr.OpCode != OpCodes.Ldstr) - continue; - - instr = instrs[end++]; - if (instr.OpCode != OpCodes.Callvirt) - continue; - method = (MethodReference)instr.Operand; - if (method.ToString() != "System.Boolean System.String::EndsWith(System.String)") - continue; - - instr = instrs[end++]; - if (!instr.isBrfalse() && !instr.isBrtrue()) - continue; - - end--; - list.Add(new BlockInfo { - Block = block, - Start = start, - End = end, - }); - - return list; + if (findFirstBlocks(b, tamperBlocks, allBlocks, locals)) + return true; } catch (ArgumentOutOfRangeException) { continue; } } - return null; + return false; } - BlockInfo findLastBlock(Block last) { + static int findCallMethod(Block block, int index, bool keepLooking, Func func) { + var instrs = block.Instructions; + for (int i = index; i < instrs.Count; i++) { + var instr = instrs[i]; + if (instr.OpCode.Code != Code.Call && instr.OpCode.Code != Code.Callvirt) + continue; + + var calledMethod = instr.Operand as MethodReference; + if (calledMethod != null && func(calledMethod)) + return i; + if (!keepLooking) + return -1; + } + return -1; + } + + bool findFirstBlocks(Block block, TamperBlocks tamperBlocks, IList allBlocks, IList locals) { + if (!block.LastInstr.isBrfalse()) + return false; + + /* + * ldc.i4.0 + * stloc X + * call GetExecutingAssembly() + * stloc Y + * ldloc Y + * callvirt Location + * ldc.i4.1 + * ldloca X + * call StrongNameSignatureVerificationEx + * pop / brfalse bad_code + * ldloc X + * brfalse bad_code + * ldloc Y + * callvirt FullName() + * ldstr "......" + * callvirt EndsWith(string) + * brfalse bad_code / brtrue good_code + */ + + var instrs = block.Instructions; + int end = instrs.Count - 1; + Instr instr; + MethodReference method; + tamperBlocks.type = Type.V1; + + int index = 0; + + int start = findCallMethod(block, index, true, (calledMethod) => calledMethod.ToString() == "System.Reflection.Assembly System.Reflection.Assembly::GetExecutingAssembly()"); + if (start < 0) + return false; + index = start + 1; + instr = instrs[--start]; + if (!instr.isStloc()) + return false; + var loc0 = Instr.getLocalVar(locals, instr); + instr = instrs[--start]; + if (!instr.isLdcI4()) + return false; + + index = findCallMethod(block, index, false, (calledMethod) => calledMethod.ToString() == "System.String System.Reflection.Assembly::get_Location()"); + if (index < 0) + return false; + index++; + + index = findCallMethod(block, index, false, (calledMethod) => { + tamperBlocks.pinvokeMethod = DotNetUtils.getMethod(module, calledMethod); + return DotNetUtils.isPinvokeMethod(tamperBlocks.pinvokeMethod, "mscorwks", "StrongNameSignatureVerificationEx"); + }); + if (index < 0) + return false; + index++; + + if (!instrs[index].isBrfalse()) { + if (instrs[index].OpCode.Code != Code.Pop) + return false; + instr = instrs[index + 1]; + if (!instr.isLdloc() || Instr.getLocalVar(locals, instr) != loc0) + return false; + if (!instrs[index + 2].isBrfalse()) + return false; + + tamperBlocks.type = Type.V1; + tamperBlocks.first = new BlockInfo { + Block = block, + Start = start, + End = end, + }; + } + else { + tamperBlocks.type = Type.V2; + tamperBlocks.first = new BlockInfo { + Block = block, + Start = start, + End = end, + }; + + block = block.FallThrough; + if (block == null) + return false; + instrs = block.Instructions; + index = 0; + instr = instrs[index]; + if (!instr.isLdloc() || Instr.getLocalVar(locals, instr) != loc0) + return false; + if (!instrs[index + 1].isBrfalse()) + return false; + } + + block = block.FallThrough; + instrs = block.Instructions; + start = end = 0; + + instr = instrs[end++]; + if (!instr.isLdloc()) + return false; + + instr = instrs[end++]; + if (instr.OpCode != OpCodes.Callvirt) + return false; + method = (MethodReference)instr.Operand; + if (method.ToString() != "System.String System.Reflection.Assembly::get_FullName()") + return false; + + instr = instrs[end++]; + if (instr.OpCode != OpCodes.Ldstr) + return false; + + instr = instrs[end++]; + if (instr.OpCode != OpCodes.Callvirt) + return false; + method = (MethodReference)instr.Operand; + if (method.ToString() != "System.Boolean System.String::EndsWith(System.String)") + return false; + + instr = instrs[end++]; + if (!instr.isBrfalse() && !instr.isBrtrue()) + return false; + + end--; + tamperBlocks.second = new BlockInfo { + Block = block, + Start = start, + End = end, + }; + + return true; + } + + BlockInfo findBadBlock(Block last) { /* * ldstr "........." * newobj System.Security.SecurityException(string) @@ -248,42 +291,66 @@ namespace de4dot.deobfuscators.SmartAssembly { } bool removeTamperProtection(Blocks blocks) { - MethodDefinition pinvokeMethod; var allBlocks = new List(blocks.MethodBlocks.getAllBlocks()); - var blockInfos = findTamperBlocks(blocks, allBlocks, out pinvokeMethod); + var tamperBlocks = findTamperBlocks(blocks, allBlocks); - if (blockInfos == null) { + if (tamperBlocks == null) { if (isTamperProtected(allBlocks)) Log.w("Could not remove tamper protection code: {0} ({1:X8})", blocks.Method, blocks.Method.MetadataToken.ToUInt32()); return false; } - if (blockInfos.Count != 3) - throw new ApplicationException("Invalid state"); - var first = blockInfos[0]; - var second = blockInfos[1]; - var bad = blockInfos[2]; - if (first.Block.Targets.Count != 1 || first.Block.Targets[0] != bad.Block) - throw new ApplicationException("Invalid state"); - if (second.Start != 0 || second.End + 1 != second.Block.Instructions.Count) - throw new ApplicationException("Invalid state"); - if (bad.Start != 0 || bad.End + 1 != bad.Block.Instructions.Count) - throw new ApplicationException("Invalid state"); - var goodBlock = second.Block.LastInstr.isBrtrue() ? second.Block.Targets[0] : second.Block.FallThrough; - - first.Block.remove(first.Start, first.End - first.Start + 1); - first.Block.replaceLastInstrsWithBranch(0, goodBlock); - removeDeadBlock(second); - removeDeadBlock(bad); - pinvokeMethods.Add(pinvokeMethod); + switch (tamperBlocks.type) { + case Type.V1: + removeTamperV1(tamperBlocks); + break; + case Type.V2: + removeTamperV2(tamperBlocks); + break; + default: + throw new ApplicationException("Unknown type"); + } + pinvokeMethods.Add(tamperBlocks.pinvokeMethod); return true; } - void removeDeadBlock(BlockInfo info) { - var parent = (ScopeBlock)info.Block.Parent; + void removeTamperV1(TamperBlocks tamperBlocks) { + var first = tamperBlocks.first; + var second = tamperBlocks.second; + var bad = tamperBlocks.bad; + var goodBlock = second.Block.LastInstr.isBrtrue() ? second.Block.Targets[0] : second.Block.FallThrough; + + if (first.Block.Targets.Count != 1 || first.Block.Targets[0] != bad.Block) + throw new ApplicationException("Invalid state"); + + first.Block.remove(first.Start, first.End - first.Start + 1); + first.Block.replaceLastInstrsWithBranch(0, goodBlock); + removeDeadBlock(second.Block); + removeDeadBlock(bad.Block); + } + + void removeTamperV2(TamperBlocks tamperBlocks) { + var first = tamperBlocks.first; + var second = tamperBlocks.second.Block; + var bad = tamperBlocks.bad.Block; + var firstFallthrough = first.Block.FallThrough; + var goodBlock = second.LastInstr.isBrtrue() ? second.Targets[0] : second.FallThrough; + + if (first.Block.Targets.Count != 1 || first.Block.Targets[0] != bad) + throw new ApplicationException("Invalid state"); + + first.Block.remove(first.Start, first.End - first.Start + 1); + first.Block.replaceLastInstrsWithBranch(0, goodBlock); + removeDeadBlock(firstFallthrough); + removeDeadBlock(second); + removeDeadBlock(bad); + } + + void removeDeadBlock(Block block) { + var parent = (ScopeBlock)block.Parent; if (parent != null) // null if already dead - parent.removeDeadBlock(info.Block); + parent.removeDeadBlock(block); } } }