/*
    Copyright (C) 2011-2015 de4dot@gmail.com

    This file is part of de4dot.

    de4dot is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    de4dot is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with de4dot.  If not, see <http://www.gnu.org/licenses/>.
*/

using System;
using System.Collections.Generic;

namespace de4dot.blocks {
	/// <summary>
	/// Produces an ordering of the base blocks of a <see cref="ScopeBlock"/>.
	/// The ordering is computed by the nested <see cref="Sorter"/> and then handed
	/// to <c>ForwardScanOrder</c> for a final fix-up pass.
	/// </summary>
	class BlocksSorter {
		readonly ScopeBlock scopeBlock;

		/// <summary>
		/// Per-block bookkeeping used by the depth-first traversal in <see cref="Sorter"/>.
		/// </summary>
		class BlockInfo {
			// Order in which the block was first visited; -1 means "not visited yet".
			public int dfsNumber = -1;
			// Smallest dfsNumber known to be reachable from this block (Tarjan's "low link").
			public int low;
			public BaseBlock baseBlock;
			// True while the block is on Sorter's SCC stack.
			public bool onStack;

			public BlockInfo(BaseBlock baseBlock) {
				this.baseBlock = baseBlock;
			}

			/// <summary>Returns true once a DFS number has been assigned.</summary>
			public bool Visited() {
				return dfsNumber >= 0;
			}

			public override string ToString() {
				return string.Format("L:{0}, D:{1}, S:{2}", low, dfsNumber, onStack);
			}
		}

		// It uses Tarjan's strongly connected components algorithm to find all SCCs.
		// See http://www.ics.uci.edu/~eppstein/161/960220.html or wikipedia for a good explanation.
		// The non-Tarjan code is still pretty simple and can (should) be improved.
		class Sorter {
			readonly ScopeBlock scopeBlock;
			readonly IList<BaseBlock> validBlocks;
			// Only blocks present in this map take part in the traversal.
			readonly Dictionary<BaseBlock, BlockInfo> blockToInfo = new Dictionary<BaseBlock, BlockInfo>();
			// Tarjan's stack of blocks whose SCC has not been emitted yet.
			readonly Stack<BlockInfo> stack = new Stack<BlockInfo>();
			// Scratch list filled by Visit(BlockInfo) and drained by Sort().
			List<BaseBlock> sorted;
			int dfsNumber = 0;
			readonly bool skipFirstBlock;
			// Set by Sort() when skipFirstBlock is true; this block is never visited.
			BaseBlock firstBlock;

			/// <summary>
			/// Creates a sorter for <paramref name="validBlocks"/>.
			/// </summary>
			/// <param name="scopeBlock">Scope used to map a target block to one of the valid blocks (via <c>ToChild</c>).</param>
			/// <param name="validBlocks">The blocks to sort.</param>
			/// <param name="skipFirstBlock">If true, the first valid block is kept out of the
			/// traversal and simply placed first in the result.</param>
			public Sorter(ScopeBlock scopeBlock, IList<BaseBlock> validBlocks, bool skipFirstBlock) {
				this.scopeBlock = scopeBlock;
				this.validBlocks = validBlocks;
				this.skipFirstBlock = skipFirstBlock;
			}

			/// <summary>
			/// Sorts the valid blocks. The original first block is always first in the result.
			/// </summary>
			/// <returns>A new list; empty if there are no valid blocks.</returns>
			/// <exception cref="ApplicationException">The SCC stack is not empty after all blocks were visited.</exception>
			public List<BaseBlock> Sort() {
				if (validBlocks.Count == 0)
					return new List<BaseBlock>();
				if (skipFirstBlock)
					firstBlock = validBlocks[0];

				foreach (var block in validBlocks) {
					if (block != firstBlock)
						blockToInfo[block] = new BlockInfo(block);
				}

				sorted = new List<BaseBlock>(validBlocks.Count);
				var finalList = new List<BaseBlock>(validBlocks.Count);

				// When the first block is skipped, start the traversal at its targets.
				if (firstBlock is Block) {
					foreach (var target in GetTargets(firstBlock)) {
						Visit(target);
						finalList.AddRange(sorted);
						sorted.Clear();
					}
				}

				// Pick up every block that has not been visited yet.
				foreach (var bb in validBlocks) {
					Visit(bb);
					finalList.AddRange(sorted);
					sorted.Clear();
				}

				if (stack.Count > 0)
					throw new ApplicationException("Stack isn't empty");

				if (firstBlock != null)
					finalList.Insert(0, firstBlock);
				else if (validBlocks[0] != finalList[0]) {
					// Make sure the original first block is first
					int index = finalList.IndexOf(validBlocks[0]);
					finalList.RemoveAt(index);
					finalList.Insert(0, validBlocks[0]);
				}
				return finalList;
			}

			/// <summary>
			/// Visits <paramref name="bb"/> unless it is unknown to this sorter, is the
			/// skipped first block, or has already been visited.
			/// </summary>
			void Visit(BaseBlock bb) {
				var info = GetInfo(bb);
				if (info == null)
					return;
				if (info.baseBlock == firstBlock)
					return;
				if (info.Visited())
					return;
				Visit(info);
			}

			/// <summary>
			/// Maps <paramref name="baseBlock"/> through <c>scopeBlock.ToChild</c> and returns
			/// the matching <see cref="BlockInfo"/>, or null if there is none.
			/// </summary>
			BlockInfo GetInfo(BaseBlock baseBlock) {
				baseBlock = scopeBlock.ToChild(baseBlock);
				if (baseBlock == null)
					return null;
				BlockInfo info;
				blockToInfo.TryGetValue(baseBlock, out info);
				return info;
			}

			/// <summary>
			/// Collects the targets of <paramref name="baseBlock"/>, dispatching on its runtime type.
			/// </summary>
			List<BaseBlock> GetTargets(BaseBlock baseBlock) {
				var list = new List<BaseBlock>();

				if (baseBlock is Block) {
					var block = (Block)baseBlock;
					AddTargets(list, block.GetTargets());
				}
				else if (baseBlock is TryBlock)
					AddTargets(list, (TryBlock)baseBlock);
				else if (baseBlock is TryHandlerBlock)
					AddTargets(list, (TryHandlerBlock)baseBlock);
				else
					AddTargets(list, (ScopeBlock)baseBlock);

				return list;
			}

			// Adds the targets of the try body, then each handler block followed by its own targets.
			void AddTargets(List<BaseBlock> dest, TryBlock tryBlock) {
				AddTargets(dest, (ScopeBlock)tryBlock);
				foreach (var tryHandlerBlock in tryBlock.TryHandlerBlocks) {
					dest.Add(tryHandlerBlock);
					AddTargets(dest, tryHandlerBlock);
				}
			}

			// Adds the handler's own targets, then its filter and handler sub-blocks, each
			// followed by its targets.
			void AddTargets(List<BaseBlock> dest, TryHandlerBlock tryHandlerBlock) {
				AddTargets(dest, (ScopeBlock)tryHandlerBlock);

				dest.Add(tryHandlerBlock.FilterHandlerBlock);
				AddTargets(dest, tryHandlerBlock.FilterHandlerBlock);

				dest.Add(tryHandlerBlock.HandlerBlock);
				AddTargets(dest, tryHandlerBlock.HandlerBlock);
			}

			// Adds the targets of every block returned by scopeBlock.GetAllBlocks().
			void AddTargets(List<BaseBlock> dest, ScopeBlock scopeBlock) {
				foreach (var block in scopeBlock.GetAllBlocks())
					AddTargets(dest, block.GetTargets());
			}

			// Appends the items of source to dest in reverse order.
			// NOTE(review): the element type is assumed to be Block (what Block.GetTargets()
			// is expected to yield) - confirm against the Block class.
			void AddTargets(List<BaseBlock> dest, IEnumerable<Block> source) {
				var list = new List<Block>(source);
				list.Reverse();
				foreach (var block in list)
					dest.Add(block);
			}

			// One saved "stack frame" of the iterative Visit(BlockInfo) below.
			struct VisitState {
				public BlockInfo Info;
				public List<BaseBlock> Targets;
				public int TargetIndex;
				public BlockInfo TargetInfo;
				public VisitState(BlockInfo info) {
					this.Info = info;
					this.Targets = null;
					this.TargetIndex = 0;
					this.TargetInfo = null;
				}
			}
			readonly Stack<VisitState> visitStateStack = new Stack<VisitState>();

			/// <summary>
			/// Runs the strongly-connected-components traversal starting at <paramref name="info"/>.
			/// Each completed component is inserted at the front of <c>sorted</c>; components with
			/// more than one block are first ordered by a nested <see cref="Sorter"/> and
			/// <see cref="SortLoopBlock"/>.
			/// </summary>
			/// <exception cref="ApplicationException">The skipped first block was about to be visited.</exception>
			void Visit(BlockInfo info) {
				// This method used to be recursive but to prevent stack overflows,
				// it's not recursive anymore.

				VisitState state = new VisitState(info);

// Entry point of an emulated call: 'state' describes the block being entered.
recursive_call:
				if (state.Info.baseBlock == firstBlock)
					throw new ApplicationException("Can't visit firstBlock");
				stack.Push(state.Info);
				state.Info.onStack = true;
				state.Info.dfsNumber = dfsNumber;
				state.Info.low = dfsNumber;
				dfsNumber++;

				state.Targets = GetTargets(state.Info.baseBlock);
				state.TargetIndex = 0;

// Resumes the target loop of the current frame (also reached after an emulated return).
return_to_caller:
				for (; state.TargetIndex < state.Targets.Count; state.TargetIndex++) {
					state.TargetInfo = GetInfo(state.Targets[state.TargetIndex]);
					if (state.TargetInfo == null)
						continue;
					if (state.TargetInfo.baseBlock == firstBlock)
						continue;

					if (!state.TargetInfo.Visited()) {
						// Emulated recursive call: save this frame and enter the target.
						visitStateStack.Push(state);
						state = new VisitState(state.TargetInfo);
						goto recursive_call;
					}
					else if (state.TargetInfo.onStack)
						state.Info.low = Math.Min(state.Info.low, state.TargetInfo.dfsNumber);
				}

				// Only the root of a component (low == dfsNumber) pops and emits it.
				if (state.Info.low != state.Info.dfsNumber)
					goto return_from_method;
				var sccBlocks = new List<BaseBlock>();
				while (true) {
					var poppedInfo = stack.Pop();
					poppedInfo.onStack = false;
					sccBlocks.Add(poppedInfo.baseBlock);
					if (ReferenceEquals(state.Info, poppedInfo))
						break;
				}
				if (sccBlocks.Count > 1) {
					// Reverse so the component's root comes first; the nested sorter
					// keeps that block first (skipFirstBlock == true).
					sccBlocks.Reverse();
					var result = new Sorter(scopeBlock, sccBlocks, true).Sort();
					SortLoopBlock(result);
					sorted.InsertRange(0, result);
				}
				else {
					sorted.Insert(0, sccBlocks[0]);
				}

// Emulated return: restore the caller's frame, if any, and continue its loop.
return_from_method:
				if (visitStateStack.Count == 0)
					return;
				state = visitStateStack.Pop();
				state.Info.low = Math.Min(state.Info.low, state.TargetInfo.low);
				state.TargetIndex++;
				goto return_to_caller;
			}

			/// <summary>
			/// Moves the block returned by <see cref="GetLoopStartBlock"/>, if any, to the end of
			/// <paramref name="list"/>.
			/// </summary>
			/// <exception cref="ApplicationException">The loop start block could not be removed from the list.</exception>
			void SortLoopBlock(List<BaseBlock> list) {
				// Some popular decompilers sometimes produce bad output unless the loop condition
				// checker block is at the end of the loop. Eg., they may use a while loop when
				// it's really a for/foreach loop.

				var loopStart = GetLoopStartBlock(list);
				if (loopStart == null)
					return;

				if (!list.Remove(loopStart))
					throw new ApplicationException("Could not remove block");
				list.Add(loopStart);
			}

			/// <summary>
			/// Returns the <see cref="Block"/> in <paramref name="list"/> with the most sources
			/// that are not themselves in the list, or null if no block has such a source.
			/// On a tie, the first such block in dictionary enumeration order wins.
			/// </summary>
			Block GetLoopStartBlock(List<BaseBlock> list) {
				// Used as a set of the Blocks that belong to the loop.
				var loopBlocks = new Dictionary<Block, bool>(list.Count);
				foreach (var bb in list) {
					var block = bb as Block;
					if (block != null)
						loopBlocks[block] = true;
				}

				// Block -> number of sources outside the loop.
				var targetBlocks = new Dictionary<Block, int>();
				foreach (var bb in list) {
					var block = bb as Block;
					if (block == null)
						continue;
					foreach (var source in block.Sources) {
						if (loopBlocks.ContainsKey(source))
							continue;
						int count;
						targetBlocks.TryGetValue(block, out count);
						targetBlocks[block] = count + 1;
					}
				}

				int max = -1;
				Block loopStart = null;
				foreach (var kv in targetBlocks) {
					if (kv.Value <= max)
						continue;
					max = kv.Value;
					loopStart = kv.Key;
				}

				return loopStart;
			}
		}

		public BlocksSorter(ScopeBlock scopeBlock) {
			this.scopeBlock = scopeBlock;
		}

		/// <summary>
		/// Sorts <c>scopeBlock.BaseBlocks</c> and passes the result through
		/// <c>ForwardScanOrder.Fix()</c>.
		/// </summary>
		public List<BaseBlock> Sort() {
			var sorted = new Sorter(scopeBlock, scopeBlock.BaseBlocks, false).Sort();
			return new ForwardScanOrder(scopeBlock, sorted).Fix();
		}
	}
}