Restore method return types

This commit is contained in:
de4dot 2011-11-01 02:22:05 +01:00
parent bfb83b7482
commit 8faf7389ad

View File

@ -29,27 +29,42 @@ namespace de4dot.deobfuscators {
class TypesRestorer { class TypesRestorer {
ModuleDefinition module; ModuleDefinition module;
List<MethodDefinition> allMethods; List<MethodDefinition> allMethods;
Dictionary<ParameterDefinition, ArgAccessInfo> argInfos = new Dictionary<ParameterDefinition, ArgAccessInfo>(); Dictionary<ParameterDefinition, TypeInfo<ParameterDefinition>> argInfos = new Dictionary<ParameterDefinition, TypeInfo<ParameterDefinition>>();
List<ArgAccessInfo> changedArgInfos = new List<ArgAccessInfo>(); List<TypeInfo<ParameterDefinition>> changedArgInfos = new List<TypeInfo<ParameterDefinition>>();
Dictionary<FieldReferenceAndDeclaringTypeKey, FieldWriteInfo> fieldWrites = new Dictionary<FieldReferenceAndDeclaringTypeKey, FieldWriteInfo>(); List<TypeInfo<ParameterDefinition>> changedReturnTypes = new List<TypeInfo<ParameterDefinition>>();
TypeInfo<ParameterDefinition> methodReturnInfo;
Dictionary<FieldReferenceAndDeclaringTypeKey, TypeInfo<FieldDefinition>> fieldWrites = new Dictionary<FieldReferenceAndDeclaringTypeKey, TypeInfo<FieldDefinition>>();
class ArgAccessInfo { class TypeInfo<T> {
public Dictionary<TypeReferenceKey, bool> types = new Dictionary<TypeReferenceKey, bool>(); public Dictionary<TypeReferenceKey, bool> types = new Dictionary<TypeReferenceKey, bool>();
public TypeReference newType = null; public TypeReference newType = null;
public ParameterDefinition arg; public T arg;
public ArgAccessInfo(ParameterDefinition arg) { public TypeInfo(T arg) {
this.arg = arg; this.arg = arg;
} }
public bool updateNewType() {
if (types.Count == 0)
return false;
TypeReference theNewType = null;
foreach (var key in types.Keys) {
if (theNewType == null) {
theNewType = key.TypeReference;
continue;
} }
theNewType = getCommonBaseClass(theNewType, key.TypeReference);
if (theNewType == null)
break;
}
if (theNewType == null)
return false;
if (MemberReferenceHelper.compareTypes(theNewType, newType))
return false;
class FieldWriteInfo { newType = theNewType;
public Dictionary<TypeReferenceKey, bool> types = new Dictionary<TypeReferenceKey, bool>(); return true;
public TypeReference newType = null;
public FieldDefinition field;
public FieldWriteInfo(FieldDefinition field) {
this.field = field;
} }
} }
@ -68,7 +83,7 @@ namespace de4dot.deobfuscators {
continue; continue;
var key = new FieldReferenceAndDeclaringTypeKey(field); var key = new FieldReferenceAndDeclaringTypeKey(field);
fieldWrites[key] = new FieldWriteInfo(field); fieldWrites[key] = new TypeInfo<FieldDefinition>(field);
} }
} }
@ -83,46 +98,28 @@ namespace de4dot.deobfuscators {
bool deobfuscateMethods() { bool deobfuscateMethods() {
changedArgInfos.Clear(); changedArgInfos.Clear();
changedReturnTypes.Clear();
foreach (var method in allMethods) { foreach (var method in allMethods) {
methodReturnInfo = new TypeInfo<ParameterDefinition>(method.MethodReturnType.Parameter2);
deobfuscateMethod(method); deobfuscateMethod(method);
if (methodReturnInfo.updateNewType())
changedReturnTypes.Add(methodReturnInfo);
foreach (var info in argInfos.Values) { foreach (var info in argInfos.Values) {
if (info.types.Count == 0) if (info.updateNewType())
continue;
TypeReference newType = null;
foreach (var key in info.types.Keys) {
if (newType == null) {
newType = key.TypeReference;
continue;
}
newType = getCommonBaseClass(newType, key.TypeReference);
if (newType == null)
break;
}
if (newType == null)
continue;
if (MemberReferenceHelper.compareTypes(newType, info.newType))
continue;
info.newType = newType;
changedArgInfos.Add(info); changedArgInfos.Add(info);
} }
} }
if (changedArgInfos.Count == 0) if (changedArgInfos.Count == 0 && changedReturnTypes.Count == 0)
return false; return false;
changedArgInfos.Sort((a, b) => { changedArgInfos.Sort((a, b) => sortTypeInfos(a, b));
if (a.arg.Method.MetadataToken.ToInt32() < b.arg.Method.MetadataToken.ToInt32()) return -1; changedReturnTypes.Sort((a, b) => sortTypeInfos(a, b));
if (a.arg.Method.MetadataToken.ToInt32() > b.arg.Method.MetadataToken.ToInt32()) return 1;
if (a.arg.Index < b.arg.Index) return -1;
if (a.arg.Index < b.arg.Index) return 1;
return 0;
});
bool changed = false; bool changed = false;
if (changedArgInfos.Count > 0) {
Log.v("Changing method arg types from object -> real type"); Log.v("Changing method arg types from object -> real type");
Log.indent(); Log.indent();
IMethodSignature updatedMethod = null; IMethodSignature updatedMethod = null;
@ -143,22 +140,54 @@ namespace de4dot.deobfuscators {
} }
Log.deIndent(); Log.deIndent();
Log.deIndent(); Log.deIndent();
}
if (changedReturnTypes.Count > 0) {
Log.v("Changing method return types from object -> real type");
Log.indent();
foreach (var info in changedReturnTypes) {
if (info.newType == null || MemberReferenceHelper.isSystemObject(info.newType))
continue;
Log.v("{0:X8}: new type {1} ({2:X8})", info.arg.Method.MetadataToken.ToInt32(), info.newType, info.newType.MetadataToken.ToInt32());
info.arg.Method.MethodReturnType.ReturnType = info.newType;
info.arg.ParameterType = info.newType;
changed = true;
}
Log.deIndent();
}
return changed; return changed;
} }
static int sortTypeInfos(TypeInfo<ParameterDefinition> a, TypeInfo<ParameterDefinition> b) {
if (a.arg.Method.MetadataToken.ToInt32() < b.arg.Method.MetadataToken.ToInt32())
return -1;
if (a.arg.Method.MetadataToken.ToInt32() > b.arg.Method.MetadataToken.ToInt32())
return 1;
if (a.arg.Index < b.arg.Index)
return -1;
if (a.arg.Index < b.arg.Index)
return 1;
return 0;
}
void deobfuscateMethod(MethodDefinition method) { void deobfuscateMethod(MethodDefinition method) {
if (!method.IsStatic || method.Body == null) if (!method.IsStatic || method.Body == null)
return; return;
bool fixReturnType = MemberReferenceHelper.isSystemObject(method.MethodReturnType.ReturnType);
argInfos.Clear(); argInfos.Clear();
foreach (var arg in method.Parameters) { foreach (var arg in method.Parameters) {
if (arg.ParameterType == null || arg.ParameterType.IsValueType) if (arg.ParameterType == null || arg.ParameterType.IsValueType)
continue; continue;
if (!MemberReferenceHelper.isSystemObject(arg.ParameterType)) if (!MemberReferenceHelper.isSystemObject(arg.ParameterType))
continue; continue;
argInfos[arg] = new ArgAccessInfo(arg); argInfos[arg] = new TypeInfo<ParameterDefinition>(arg);
} }
if (argInfos.Count == 0) if (argInfos.Count == 0 && !fixReturnType)
return; return;
var methodParams = DotNetUtils.getParameters(method); var methodParams = DotNetUtils.getParameters(method);
@ -169,6 +198,15 @@ namespace de4dot.deobfuscators {
for (int i = 0; i < instructions.Count; i++) { for (int i = 0; i < instructions.Count; i++) {
var instr = instructions[i]; var instr = instructions[i];
switch (instr.OpCode.Code) { switch (instr.OpCode.Code) {
case Code.Ret:
if (!fixReturnType)
break;
var type = getLoadedType(method, instructions, i);
if (type == null)
break;
methodReturnInfo.types[new TypeReferenceKey(type)] = true;
break;
case Code.Call: case Code.Call:
case Code.Calli: case Code.Calli:
case Code.Callvirt: case Code.Callvirt:
@ -327,12 +365,10 @@ namespace de4dot.deobfuscators {
if (methodParam == null || type == null) if (methodParam == null || type == null)
return false; return false;
if (type.IsValueType) if (!isValidType(type))
return false;
if (MemberReferenceHelper.isSystemObject(type))
return false; return false;
ArgAccessInfo info; TypeInfo<ParameterDefinition> info;
if (!argInfos.TryGetValue(methodParam, out info)) if (!argInfos.TryGetValue(methodParam, out info))
return false; return false;
var key = new TypeReferenceKey(type); var key = new TypeReferenceKey(type);
@ -389,13 +425,13 @@ namespace de4dot.deobfuscators {
if (!updateFields()) if (!updateFields())
return false; return false;
var infos = new List<FieldWriteInfo>(fieldWrites.Values); var infos = new List<TypeInfo<FieldDefinition>>(fieldWrites.Values);
infos.Sort((a, b) => { infos.Sort((a, b) => {
if (a.field.DeclaringType.MetadataToken.ToInt32() < b.field.DeclaringType.MetadataToken.ToInt32()) return -1; if (a.arg.DeclaringType.MetadataToken.ToInt32() < b.arg.DeclaringType.MetadataToken.ToInt32()) return -1;
if (a.field.DeclaringType.MetadataToken.ToInt32() > b.field.DeclaringType.MetadataToken.ToInt32()) return 1; if (a.arg.DeclaringType.MetadataToken.ToInt32() > b.arg.DeclaringType.MetadataToken.ToInt32()) return 1;
if (a.field.MetadataToken.ToInt32() < b.field.MetadataToken.ToInt32()) return -1; if (a.arg.MetadataToken.ToInt32() < b.arg.MetadataToken.ToInt32()) return -1;
if (a.field.MetadataToken.ToInt32() > b.field.MetadataToken.ToInt32()) return 1; if (a.arg.MetadataToken.ToInt32() > b.arg.MetadataToken.ToInt32()) return 1;
return 0; return 0;
}); });
@ -409,9 +445,9 @@ namespace de4dot.deobfuscators {
if (info.newType == null || MemberReferenceHelper.isSystemObject(info.newType)) if (info.newType == null || MemberReferenceHelper.isSystemObject(info.newType))
continue; continue;
fieldWrites.Remove(new FieldReferenceAndDeclaringTypeKey(info.field)); fieldWrites.Remove(new FieldReferenceAndDeclaringTypeKey(info.arg));
Log.v("{0:X8}: new type: {1} ({2:X8})", info.field.MetadataToken.ToInt32(), info.newType, info.newType.MetadataToken.ToInt32()); Log.v("{0:X8}: new type: {1} ({2:X8})", info.arg.MetadataToken.ToInt32(), info.newType, info.newType.MetadataToken.ToInt32());
info.field.FieldType = info.newType; info.arg.FieldType = info.newType;
changed = true; changed = true;
} }
Log.deIndent(); Log.deIndent();
@ -432,14 +468,28 @@ namespace de4dot.deobfuscators {
continue; continue;
var field = instr.Operand as FieldReference; var field = instr.Operand as FieldReference;
FieldWriteInfo info; TypeInfo<FieldDefinition> info;
if (!fieldWrites.TryGetValue(new FieldReferenceAndDeclaringTypeKey(field), out info)) if (!fieldWrites.TryGetValue(new FieldReferenceAndDeclaringTypeKey(field), out info))
continue; continue;
int instrIndex = i; var fieldType = getLoadedType(method, instructions, i);
if (fieldType == null)
continue;
info.types[new TypeReferenceKey(fieldType)] = true;
}
}
bool changed = false;
foreach (var info in fieldWrites.Values)
changed |= info.updateNewType();
return changed;
}
TypeReference getLoadedType(MethodDefinition method, IList<Instruction> instructions, int instrIndex) {
var prev = getPreviousInstruction(instructions, ref instrIndex); var prev = getPreviousInstruction(instructions, ref instrIndex);
if (prev == null) if (prev == null)
continue; return null;
TypeReference fieldType; TypeReference fieldType;
switch (prev.OpCode.Code) { switch (prev.OpCode.Code) {
@ -452,21 +502,21 @@ namespace de4dot.deobfuscators {
case Code.Callvirt: case Code.Callvirt:
var calledMethod = prev.Operand as MethodReference; var calledMethod = prev.Operand as MethodReference;
if (calledMethod == null) if (calledMethod == null)
continue; return null;
fieldType = calledMethod.MethodReturnType.ReturnType; fieldType = calledMethod.MethodReturnType.ReturnType;
break; break;
case Code.Newarr: case Code.Newarr:
fieldType = prev.Operand as TypeReference; fieldType = prev.Operand as TypeReference;
if (fieldType == null) if (fieldType == null)
continue; return null;
fieldType = new ArrayType(fieldType); fieldType = new ArrayType(fieldType);
break; break;
case Code.Newobj: case Code.Newobj:
var ctor = prev.Operand as MethodReference; var ctor = prev.Operand as MethodReference;
if (ctor == null) if (ctor == null)
continue; return null;
fieldType = ctor.DeclaringType; fieldType = ctor.DeclaringType;
break; break;
@ -492,7 +542,7 @@ namespace de4dot.deobfuscators {
case Code.Ldloc_3: case Code.Ldloc_3:
var local = DotNetUtils.getLocalVar(method.Body.Variables, prev); var local = DotNetUtils.getLocalVar(method.Body.Variables, prev);
if (local == null) if (local == null)
continue; return null;
fieldType = local.VariableType; fieldType = local.VariableType;
break; break;
@ -500,54 +550,32 @@ namespace de4dot.deobfuscators {
case Code.Ldsfld: case Code.Ldsfld:
var field2 = prev.Operand as FieldReference; var field2 = prev.Operand as FieldReference;
if (field2 == null) if (field2 == null)
continue; return null;
fieldType = field2.FieldType; fieldType = field2.FieldType;
break; break;
default: default:
continue; return null;
} }
if (fieldType == null) if (!isValidType(fieldType))
continue; return null;
if (fieldType.IsValueType)
continue;
if (MemberReferenceHelper.isSystemObject(fieldType))
continue;
if (MemberReferenceHelper.verifyType(fieldType, "mscorlib", "System.Void"))
continue;
if (fieldType is GenericParameter)
continue;
info.types[new TypeReferenceKey(fieldType)] = true; return fieldType;
}
} }
bool changed = false; static bool isValidType(TypeReference type) {
foreach (var info in fieldWrites.Values) { if (type == null)
if (info.types.Count == 0) return false;
continue; if (type.IsValueType)
return false;
TypeReference newType = null; if (MemberReferenceHelper.isSystemObject(type))
foreach (var key in info.types.Keys) { return false;
if (newType == null) { if (MemberReferenceHelper.verifyType(type, "mscorlib", "System.Void"))
newType = key.TypeReference; return false;
continue; if (type is GenericParameter)
} return false;
newType = getCommonBaseClass(newType, key.TypeReference); return true;
if (newType == null)
break;
}
if (newType == null)
continue;
if (MemberReferenceHelper.compareTypes(newType, info.newType))
continue;
info.newType = newType;
changed = true;
}
return changed;
} }
static TypeReference getCommonBaseClass(TypeReference a, TypeReference b) { static TypeReference getCommonBaseClass(TypeReference a, TypeReference b) {