diff --git a/blocks/BlocksSorter.cs b/blocks/BlocksSorter.cs
index 9566810b..0f7f90e4 100644
--- a/blocks/BlocksSorter.cs
+++ b/blocks/BlocksSorter.cs
@@ -17,70 +17,260 @@
along with de4dot. If not, see .
*/
+using System;
using System.Collections.Generic;
namespace de4dot.blocks {
class BlocksSorter {
ScopeBlock scopeBlock;
- Dictionary visited;
- List sorted;
+
+ class BlockInfo {
+ public int dfsNumber = -1;
+ public int low;
+ public BaseBlock baseBlock;
+ public bool onStack;
+
+ public BlockInfo(BaseBlock baseBlock) {
+ this.baseBlock = baseBlock;
+ }
+
+ public bool visited() {
+ return dfsNumber >= 0;
+ }
+
+ public override string ToString() {
+ return string.Format("L:{0}, D:{1}, S:{2}", low, dfsNumber, onStack);
+ }
+ }
+
+ // It uses Tarjan's strongly connected components algorithm to find all SCCs.
+ // See http://www.ics.uci.edu/~eppstein/161/960220.html or wikipedia for a good explanation.
+ // The non-Tarjan code is still pretty simple and can (should) be improved.
+ class Sorter {
+ ScopeBlock scopeBlock;
+ IList validBlocks;
+ Dictionary blockToInfo = new Dictionary();
+ Stack stack = new Stack();
+ List sorted;
+ int dfsNumber = 0;
+ bool skipFirstBlock;
+ BaseBlock firstBlock;
+
+ public Sorter(ScopeBlock scopeBlock, IList validBlocks, bool skipFirstBlock) {
+ this.scopeBlock = scopeBlock;
+ this.validBlocks = validBlocks;
+ this.skipFirstBlock = skipFirstBlock;
+ }
+
+ public List sort() {
+ if (validBlocks.Count == 0)
+ return new List();
+ if (skipFirstBlock)
+ firstBlock = validBlocks[0];
+
+ foreach (var block in validBlocks) {
+ if (block != firstBlock)
+ blockToInfo[block] = new BlockInfo(block);
+ }
+
+ sorted = new List(validBlocks.Count);
+ var finalList = new List(validBlocks.Count);
+
+ if (firstBlock is Block) {
+ foreach (var target in getTargets(firstBlock)) {
+ visit(target);
+ finalList.AddRange(sorted);
+ sorted.Clear();
+ }
+ }
+ foreach (var bb in validBlocks) {
+ visit(bb);
+ finalList.AddRange(sorted);
+ sorted.Clear();
+ }
+
+ if (stack.Count > 0)
+ throw new ApplicationException("Stack isn't empty");
+
+ if (firstBlock != null)
+ finalList.Insert(0, firstBlock);
+ else if (validBlocks[0] != finalList[0]) {
+ // Make sure the original first block is first
+ int index = finalList.IndexOf(validBlocks[0]);
+ finalList.RemoveAt(index);
+ finalList.Insert(0, validBlocks[0]);
+ }
+ return finalList;
+ }
+
+ void visit(BaseBlock bb) {
+ var info = getInfo(bb);
+ if (info == null)
+ return;
+ if (info.baseBlock == firstBlock)
+ return;
+ if (info.visited())
+ return;
+ visit(info);
+ }
+
+ BlockInfo getInfo(BaseBlock baseBlock) {
+ baseBlock = scopeBlock.toChild(baseBlock);
+ if (baseBlock == null)
+ return null;
+ BlockInfo info;
+ blockToInfo.TryGetValue(baseBlock, out info);
+ return info;
+ }
+
+ List getTargets(BaseBlock baseBlock) {
+ var list = new List();
+
+ if (baseBlock is Block) {
+ var block = (Block)baseBlock;
+ addTargets(list, block.getTargets());
+ }
+ else if (baseBlock is TryBlock)
+ addTargets(list, (TryBlock)baseBlock);
+ else if (baseBlock is TryHandlerBlock)
+ addTargets(list, (TryHandlerBlock)baseBlock);
+ else
+ addTargets(list, (ScopeBlock)baseBlock);
+
+ return list;
+ }
+
+ void addTargets(List dest, TryBlock tryBlock) {
+ addTargets(dest, (ScopeBlock)tryBlock);
+ foreach (var tryHandlerBlock in tryBlock.TryHandlerBlocks) {
+ dest.Add(tryHandlerBlock);
+ addTargets(dest, tryHandlerBlock);
+ }
+ }
+
+ void addTargets(List dest, TryHandlerBlock tryHandlerBlock) {
+ addTargets(dest, (ScopeBlock)tryHandlerBlock);
+
+ dest.Add(tryHandlerBlock.FilterHandlerBlock);
+ addTargets(dest, tryHandlerBlock.FilterHandlerBlock);
+
+ dest.Add(tryHandlerBlock.HandlerBlock);
+ addTargets(dest, tryHandlerBlock.HandlerBlock);
+ }
+
+ void addTargets(List dest, ScopeBlock scopeBlock) {
+ foreach (var block in scopeBlock.getAllBlocks())
+ addTargets(dest, block.getTargets());
+ }
+
+ void addTargets(List dest, IEnumerable source) {
+ var list = new List(source);
+ list.Reverse();
+ foreach (var block in list)
+ dest.Add(block);
+ }
+
+ void visit(BlockInfo info) {
+ if (info.baseBlock == firstBlock)
+ throw new ApplicationException("Can't visit firstBlock");
+ stack.Push(info);
+ info.onStack = true;
+ info.dfsNumber = dfsNumber;
+ info.low = dfsNumber;
+ dfsNumber++;
+
+ foreach (var tmp in getTargets(info.baseBlock)) {
+ var targetInfo = getInfo(tmp);
+ if (targetInfo == null)
+ continue;
+ if (targetInfo.baseBlock == firstBlock)
+ continue;
+
+ if (!targetInfo.visited()) {
+ visit(targetInfo);
+ info.low = Math.Min(info.low, targetInfo.low);
+ }
+ else if (targetInfo.onStack)
+ info.low = Math.Min(info.low, targetInfo.dfsNumber);
+ }
+
+ if (info.low != info.dfsNumber)
+ return;
+ var sccBlocks = new List();
+ while (true) {
+ var poppedInfo = stack.Pop();
+ poppedInfo.onStack = false;
+ sccBlocks.Add(poppedInfo.baseBlock);
+ if (ReferenceEquals(info, poppedInfo))
+ break;
+ }
+ if (sccBlocks.Count > 1) {
+ sccBlocks.Reverse();
+ var result = new Sorter(scopeBlock, sccBlocks, true).sort();
+ sortLoopBlock(result);
+ sorted.InsertRange(0, result);
+ }
+ else {
+ sorted.Insert(0, sccBlocks[0]);
+ }
+ }
+
+ void sortLoopBlock(List list) {
+ // Some popular decompilers sometimes produce bad output unless the loop condition
+ // checker block is at the end of the loop. Eg., they may use a while loop when
+ // it's really a for/foreach loop.
+
+ var loopStart = getLoopStartBlock(list);
+ if (loopStart == null)
+ return;
+
+ if (!list.Remove(loopStart))
+ throw new ApplicationException("Could not remove block");
+ list.Add(loopStart);
+ }
+
+ Block getLoopStartBlock(List list) {
+ var loopBlocks = new Dictionary(list.Count);
+ foreach (var bb in list) {
+ var block = bb as Block;
+ if (block != null)
+ loopBlocks[block] = true;
+ }
+
+ var targetBlocks = new Dictionary();
+ foreach (var bb in list) {
+ var block = bb as Block;
+ if (block == null)
+ continue;
+ foreach (var source in block.Sources) {
+ if (loopBlocks.ContainsKey(source))
+ continue;
+ int count;
+ targetBlocks.TryGetValue(block, out count);
+ targetBlocks[block] = count + 1;
+ }
+ }
+
+ int max = -1;
+ Block loopStart = null;
+ foreach (var kv in targetBlocks) {
+ if (kv.Value <= max)
+ continue;
+ max = kv.Value;
+ loopStart = kv.Key;
+ }
+
+ return loopStart;
+ }
+ }
public BlocksSorter(ScopeBlock scopeBlock) {
this.scopeBlock = scopeBlock;
}
- bool hasVisited(BaseBlock bb) {
- bool hasVisited;
- if (visited.TryGetValue(bb, out hasVisited))
- return hasVisited;
- visited[bb] = false;
- return false;
- }
-
public List sort() {
- visited = new Dictionary();
- sorted = new List(scopeBlock.BaseBlocks.Count);
-
- if (scopeBlock.BaseBlocks.Count > 0)
- search(scopeBlock.BaseBlocks[0]);
- sorted.Reverse(); // It's in reverse order
-
- // Just in case there's dead code or unreferenced exception blocks
- foreach (var bb in scopeBlock.BaseBlocks) {
- if (hasVisited(bb))
- continue;
- sorted.Add(bb);
- }
-
- sorted = new ForwardScanOrder(scopeBlock, sorted).fix();
-
- return sorted;
- }
-
- // Depth-first order
- void search(BaseBlock bb) {
- if (hasVisited(bb))
- return;
-
- visited[bb] = true;
- var block = bb as Block; // Block or ScopeBlock
- if (block != null) {
- // Since the sorted array will be in reverse order, and we want the
- // conditional branches to fall through to their fall-through target, make
- // sure the FallThrough target is added last! Some conditional instructions
- // aren't reversible (eg. beq and bne.un) since they don't take the same
- // types of arguments. This will also make sure .NET Reflector doesn't
- // crash (sometimes).
- var targets = new List(block.getTargets());
- targets.Reverse();
-
- foreach (var target in targets) {
- var child = scopeBlock.toChild(target);
- if (child != null)
- search(child);
- }
- }
- sorted.Add(bb);
+ var sorted = new Sorter(scopeBlock, scopeBlock.BaseBlocks, false).sort();
+ return new ForwardScanOrder(scopeBlock, sorted).fix();
}
}
}