Skip to content

Commit

Permalink
Merge pull request #1376 from Johanmyst/fixCPPClasses
Browse files Browse the repository at this point in the history
Fix incomplete/incorrect callgraphs due to incorrect class names
  • Loading branch information
yuleisui authored Feb 15, 2024
2 parents 2965b6e + 78cb220 commit 2c2a53e
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 84 deletions.
6 changes: 3 additions & 3 deletions svf-llvm/include/SVF-LLVM/ObjTypeInference.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class ObjTypeInference
typedef Map<const Value *, const Type *> ValueToType;
typedef std::pair<const Value *, bool> ValueBoolPair;
typedef Map<const Value *, Set<std::string>> ValueToClassNames;
typedef Map<const CallBase *, Set<const Function *>> AllocToClsNameSources;
typedef Map<const Value *, Set<const Value *>> ObjToClsNameSources;


private:
Expand All @@ -57,7 +57,7 @@ class ObjTypeInference
ValueToSources _valueToAllocs; // value allocations (stack, static, heap) cache
ValueToClassNames _thisPtrClassNames; // thisptr class name cache
ValueToSources _valueToAllocOrClsNameSources; // value alloc/clsname sources cache
AllocToClsNameSources _allocToClsNameSources; // alloc clsname sources cache
ObjToClsNameSources _objToClsNameSources; // alloc clsname sources cache


public:
Expand Down Expand Up @@ -122,7 +122,7 @@ class ObjTypeInference
Set<const Value *> &bwFindAllocOrClsNameSources(const Value *startValue);

/// forward find class name sources starting from an allocation
Set<const Function *> &fwFindClsNameSources(const CallBase *alloc);
Set<const Value *> &fwFindClsNameSources(const Value *startValue);
};
}
#endif //SVF_OBJTYPEINFERENCE_H
10 changes: 8 additions & 2 deletions svf-llvm/lib/CppUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
*/

#include "SVF-LLVM/CppUtil.h"
#include "SVF-LLVM/BasicTypes.h"
#include "SVF-LLVM/LLVMUtil.h"
#include "Util/Casting.h"
#include "Util/SVFUtil.h"
#include "SVF-LLVM/LLVMModule.h"
#include "SVF-LLVM/ObjTypeInference.h"
Expand Down Expand Up @@ -640,9 +642,9 @@ Set<std::string> cppUtil::extractClsNamesFromFunc(const Function *foo)
{
assert(foo->hasName() && "foo does not have a name? possible indirect call");
const std::string &name = foo->getName().str();
if (isConstructor(foo))
if (isConstructor(foo) || isDestructor(foo))
{
// c++ constructor
// c++ constructor or destructor
DemangledName demangledName = cppUtil::demangle(name);
return {demangledName.className};
}
Expand Down Expand Up @@ -797,6 +799,10 @@ bool cppUtil::isClsNameSource(const Value *val)
if(!foo) return false;
return isConstructor(foo) || isDestructor(foo) || isTemplateFunc(foo) || isDynCast(foo);
}
else if (const auto *func = SVFUtil::dyn_cast<Function>(val))
{
return isConstructor(func) || isDestructor(func) || isTemplateFunc(func);
}
return false;
}

Expand Down
172 changes: 93 additions & 79 deletions svf-llvm/lib/ObjTypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@
*/

#include "SVF-LLVM/ObjTypeInference.h"
#include "SVF-LLVM/BasicTypes.h"
#include "SVF-LLVM/LLVMModule.h"
#include "SVF-LLVM/LLVMUtil.h"
#include "SVF-LLVM/CppUtil.h"
#include "Util/Casting.h"

#define TYPE_DEBUG 0 /* Turn this on if you're debugging type inference */
#define ERR_MSG(msg) \
Expand Down Expand Up @@ -590,69 +592,77 @@ u32_t ObjTypeInference::objTyToNumFields(const Type *objTy)


/*!
* get or infer the class names of thisptr
* get or infer the class names of thisptr; starting from :param:`thisPtr`, will walk backwards to find
* all potential sources for the class name. Valid sources include global or stack variables, heap allocations,
* or C++ dynamic casts/constructors/destructors.
* If the source site is a global/stack/heap variable, find the corresponding constructor/destructor to
* extract the class' name from (since the type of the variable is not reliable but the demangled name is)
* @param thisPtr
* @return
* @return a set of all possible type names that :param:`thisPtr` could point to
*/
Set<std::string> &ObjTypeInference::inferThisPtrClsName(const Value *thisPtr)
{
auto it = _thisPtrClassNames.find(thisPtr);
if (it != _thisPtrClassNames.end()) return it->second;

Set<std::string> names;
auto insertClassNames = [&names](Set<std::string> &classNames)

// Lambda for checking a function is a valid name source & extracting a class name from it
auto addNamesFromFunc = [&names](const Function *func) -> void
{
names.insert(classNames.begin(), classNames.end());
ABORT_IFNOT(isClsNameSource(func), "Func is invalid class name source: " + dumpValueAndDbgInfo(func));
for (auto name : extractClsNamesFromFunc(func)) names.insert(name);
};

// backward find heap allocations or class name sources
Set<const Value *> &vals = bwFindAllocOrClsNameSources(thisPtr);
for (const auto &val: vals)
// Lambda for getting callee & extracting class name for calls to constructors/destructors/template funcs
auto addNamesFromCall = [&names, &addNamesFromFunc](const CallBase *call) -> void
{
ABORT_IFNOT(isClsNameSource(call), "Call is invalid class name source: " + dumpValueAndDbgInfo(call));

const auto *func = call->getCalledFunction();
if (isDynCast(func)) names.insert(extractClsNameFromDynCast(call));
else addNamesFromFunc(func);
};

// Walk backwards to find all valid source sites for the pointer (e.g. stack/global/heap variables)
for (const auto &val: bwFindAllocOrClsNameSources(thisPtr))
{
// A source site is either a constructor/destructor/template function from which the class name can be
// extracted; a call to a C++ constructor/destructor/template function from which the class name can be
// extracted; or an allocation site of an object (i.e. a stack/global/heap variable), from which a
// forward walk can be performed to find calls to C++ constructor/destructor/template functions from
// which the class' name can then be extracted; skip starting pointer
if (val == thisPtr) continue;

if (const auto *func = SVFUtil::dyn_cast<Function>(val))
{
// extract class name from function name
Set<std::string> classNames = extractClsNamesFromFunc(func);
insertClassNames(classNames);
// Constructor/destructor/template func; extract name from func directly
addNamesFromFunc(func);
}
else if (SVFUtil::isa<LoadInst, StoreInst, GetElementPtrInst, AllocaInst, GlobalValue>(val))
else if (isClsNameSource(val))
{
// extract class name from instructions
const Type *type = infersiteToType(val);
const std::string &className = typeToClsName(type);
if (!className.empty())
{
Set<std::string> tgt{className};
insertClassNames(tgt);
}
// Call to constructor/destructor/template func; get callee; extract name from callee
ABORT_IFNOT(SVFUtil::isa<CallBase>(val), "Call source site is not a callbase: " + dumpValueAndDbgInfo(val));
addNamesFromCall(SVFUtil::cast<CallBase>(val));
}
else if (const auto *callBase = SVFUtil::dyn_cast<CallBase>(val))
else if (isAlloc(val))
{
if (const Function *callFunc = callBase->getCalledFunction())
// Stack/global/heap allocation site; walk forward; find constructor/destructor/template calls
ABORT_IFNOT((SVFUtil::isa<AllocaInst, CallBase, GlobalVariable>(val)),
"Alloc site source is not a stack/heap/global variable: " + dumpValueAndDbgInfo(val));
for (const auto *src : fwFindClsNameSources(val))
{
Set<std::string> classNames = extractClsNamesFromFunc(callFunc);
insertClassNames(classNames);
if (isDynCast(callFunc))
{
// dynamic cast
Set<std::string> tgt{extractClsNameFromDynCast(callBase)};
insertClassNames(tgt);
}
else if (isNewAlloc(callFunc))
{
// for heap allocation, we forward find class name sources
Set<const Function *>& srcs = fwFindClsNameSources(callBase);
for (const auto &src: srcs)
{
classNames = extractClsNamesFromFunc(src);
insertClassNames(classNames);
}
}
if (const auto *func = SVFUtil::dyn_cast<Function>(src)) addNamesFromFunc(func);
else if (const auto *call = SVFUtil::dyn_cast<CallBase>(src)) addNamesFromCall(call);
else ABORT_MSG("Source site from forward walk is invalid: " + dumpValueAndDbgInfo(src));
}
}
else
{
ERR_MSG("Unsupported source type found:" + dumpValueAndDbgInfo(val));
}
}

return _thisPtrClassNames[thisPtr] = names;
}

Expand Down Expand Up @@ -711,48 +721,43 @@ Set<const Value *> &ObjTypeInference::bwFindAllocOrClsNameSources(const Value *s
workList.push({curValue, true});
}

// current inst reside in cpp self-inference function
// If current value is an instruction inside a constructor/destructor/template, use it as a source
if (const auto *inst = SVFUtil::dyn_cast<Instruction>(curValue))
{
if (const Function *foo = inst->getFunction())
if (const auto *parent = inst->getFunction())
{
if (isConstructor(foo) || isDestructor(foo) || isTemplateFunc(foo) || isDynCast(foo))
{
insertSource(foo);
if (canUpdate)
{
_valueToAllocOrClsNameSources[curValue] = sources;
}
continue;
}
if (isClsNameSource(parent)) insertSource(parent);
}
}

// If the current value is an object (global, heap, stack, etc) or name source (constructor/destructor,
// a C++ dynamic cast, or a template function), use it as a source
if (isAlloc(curValue) || isClsNameSource(curValue))
{
insertSource(curValue);
}
else if (const auto *getElementPtrInst = SVFUtil::dyn_cast<GetElementPtrInst>(curValue))

// Explore the current value further depending on the type of the value; use cached values if possible
if (const auto *getElementPtrInst = SVFUtil::dyn_cast<GetElementPtrInst>(curValue))
{
insertSource(getElementPtrInst);
insertSourcesOrPushWorklist(getElementPtrInst->getPointerOperand());
}
else if (const auto *bitCastInst = SVFUtil::dyn_cast<BitCastInst>(curValue))
{
Value *prevVal = bitCastInst->getOperand(0);
insertSourcesOrPushWorklist(prevVal);
insertSourcesOrPushWorklist(bitCastInst->getOperand(0));
}
else if (const auto *phiNode = SVFUtil::dyn_cast<PHINode>(curValue))
{
for (u32_t i = 0; i < phiNode->getNumOperands(); ++i)
for (const auto *op : phiNode->operand_values())
{
insertSourcesOrPushWorklist(phiNode->getOperand(i));
insertSourcesOrPushWorklist(op);
}
}
else if (const auto *loadInst = SVFUtil::dyn_cast<LoadInst>(curValue))
{
for (const auto &use: loadInst->getPointerOperand()->uses())
for (const auto *user : loadInst->getPointerOperand()->users())
{
if (const auto *storeInst = SVFUtil::dyn_cast<StoreInst>(use.getUser()))
if (const auto *storeInst = SVFUtil::dyn_cast<StoreInst>(user))
{
if (storeInst->getPointerOperand() == loadInst->getPointerOperand())
{
Expand All @@ -763,9 +768,9 @@ Set<const Value *> &ObjTypeInference::bwFindAllocOrClsNameSources(const Value *s
}
else if (const auto *argument = SVFUtil::dyn_cast<Argument>(curValue))
{
for (const auto &use: argument->getParent()->uses())
for (const auto *user: argument->getParent()->users())
{
if (const auto *callBase = SVFUtil::dyn_cast<CallBase>(use.getUser()))
if (const auto *callBase = SVFUtil::dyn_cast<CallBase>(user))
{
// skip function as parameter
// e.g., call void @foo(%struct.ssl_ctx_st* %9, i32 (i8*, i32, i32, i8*)* @passwd_callback)
Expand All @@ -778,7 +783,7 @@ Set<const Value *> &ObjTypeInference::bwFindAllocOrClsNameSources(const Value *s
else if (const auto *callBase = SVFUtil::dyn_cast<CallBase>(curValue))
{
ABORT_IFNOT(!callBase->doesNotReturn(), "callbase does not return:" + dumpValueAndDbgInfo(callBase));
if (Function *callee = callBase->getCalledFunction())
if (const auto *callee = callBase->getCalledFunction())
{
if (!callee->isDeclaration())
{
Expand All @@ -790,47 +795,56 @@ Set<const Value *> &ObjTypeInference::bwFindAllocOrClsNameSources(const Value *s
}
}
}

// If updating is allowed; store the gathered sources as sources for the current value in the cache
if (canUpdate)
{
_valueToAllocOrClsNameSources[curValue] = sources;
}
}

return _valueToAllocOrClsNameSources[startValue];
}

Set<const Function *> &ObjTypeInference::fwFindClsNameSources(const CallBase *alloc)
Set<const Value *> &ObjTypeInference::fwFindClsNameSources(const Value *startValue)
{
assert(startValue && "startValue was null?");

// consult cache
auto tIt = _allocToClsNameSources.find(alloc);
if (tIt != _allocToClsNameSources.end())
auto tIt = _objToClsNameSources.find(startValue);
if (tIt != _objToClsNameSources.end())
{
return tIt->second;
}

Set<const Function *> clsSources;
// for heap allocation, we forward find class name sources
auto inferViaCppCall = [&clsSources](const CallBase *callBase)
Set<const Value *> sources;

// Lambda for adding a callee to the sources iff it is a constructor/destructor/template/dyncast
auto inferViaCppCall = [&sources](const CallBase *caller)
{
if (!callBase->getCalledFunction()) return;
const Function *constructFoo = callBase->getCalledFunction();
clsSources.insert(constructFoo);
if (!caller) return;
if (isClsNameSource(caller)) sources.insert(caller);
};
for (const auto &use: alloc->uses())

// Find all calls of starting val (or through cast); add as potential source iff applicable
for (const auto *user : startValue->users())
{
if (const auto *cppCall = SVFUtil::dyn_cast<CallBase>(use.getUser()))
if (const auto *caller = SVFUtil::dyn_cast<CallBase>(user))
{
inferViaCppCall(cppCall);
inferViaCppCall(caller);
}
else if (const auto *bitCastInst = SVFUtil::dyn_cast<BitCastInst>(use.getUser()))
else if (const auto *bitcast = SVFUtil::dyn_cast<BitCastInst>(user))
{
for (const auto &use2: bitCastInst->uses())
for (const auto *cast_user : bitcast->users())
{
if (const auto *cppCall2 = SVFUtil::dyn_cast<CallBase>(use2.getUser()))
if (const auto *caller = SVFUtil::dyn_cast<CallBase>(cast_user))
{
inferViaCppCall(cppCall2);
inferViaCppCall(caller);
}
}
}
}
return _allocToClsNameSources[alloc] = SVFUtil::move(clsSources);
}

// Store sources in cache for starting value & return the found sources
return _objToClsNameSources[startValue] = SVFUtil::move(sources);
}

0 comments on commit 2c2a53e

Please sign in to comment.