Skip to content
4 changes: 4 additions & 0 deletions llvm/include/llvm/Analysis/IR2Vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -598,12 +598,16 @@ class LLVM_ABI FlowAwareEmbedder : public Embedder {
// FlowAware embeddings would benefit from caching instruction embeddings as
// they are reused while computing the embeddings of other instructions.
mutable InstEmbeddingsMap InstVecMap;
static SmallVector<Function *, 15> FuncStack;
Embedding computeEmbeddings(const Instruction &I) const override;
static SmallMapVector<const Function *, SmallVector<const Function *, 10>, 16>
FuncCallMap;

public:
FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab)
: Embedder(F, Vocab) {}
void invalidateEmbeddings() override { InstVecMap.clear(); }
static void computeFuncCallMap(Module &M);
};

} // namespace ir2vec
Expand Down
41 changes: 40 additions & 1 deletion llvm/lib/Analysis/IR2Vec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
Expand Down Expand Up @@ -61,7 +62,10 @@ cl::opt<IR2VecKind> IR2VecEmbeddingKind(
"Generate flow-aware embeddings")),
cl::init(IR2VecKind::Symbolic), cl::desc("IR2Vec embedding kind"),
cl::cat(IR2VecCategory));

// static members of Flowaware Embeddings
SmallVector<Function *, 15> FlowAwareEmbedder::FuncStack;
SmallMapVector<const Function *, SmallVector<const Function *, 10>, 16>
FlowAwareEmbedder::FuncCallMap;
} // namespace ir2vec
} // namespace llvm

Expand Down Expand Up @@ -207,6 +211,23 @@ Embedding FlowAwareEmbedder::computeEmbeddings(const Instruction &I) const {
// TODO: Handle call instructions differently.
// For now, we treat them like other instructions
Embedding ArgEmb(Dimension, 0);

if (isa<CallInst>(I)) {
const auto *Ci = dyn_cast<CallInst>(&I);
Function *Func = Ci->getCalledFunction();
if (Func) {
if (!Func->isDeclaration() &&
std::find(FuncStack.begin(), FuncStack.end(), Func) ==
FuncStack.end()) {
FuncStack.push_back(Func);
auto Emb = Embedder::create(IR2VecEmbeddingKind, *Func, Vocab);
auto FuncVec = Emb->getFunctionVector();
std::transform(ArgEmb.begin(), ArgEmb.end(), FuncVec.begin(),
FuncVec.end(), std::plus<double>());
FuncStack.pop_back();
}
}
}
for (const auto &Op : I.operands()) {
// If the operand is defined elsewhere, we use its embedding
if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
Expand Down Expand Up @@ -245,6 +266,24 @@ Embedding FlowAwareEmbedder::computeEmbeddings(const Instruction &I) const {
return InstVector;
}

void FlowAwareEmbedder::computeFuncCallMap(Module &M) {
CallGraph Cg = CallGraph(M);
for (auto CallItr = Cg.begin(); CallItr != Cg.end(); CallItr++) {
if (CallItr->first && !CallItr->first->isDeclaration()) {
const auto *ParentFunc = CallItr->first;
CallGraphNode *Cgn = CallItr->second.get();
if (Cgn) {
for (auto It = Cgn->begin(); It != Cgn->end(); It++) {
const auto *Func = It->second->getFunction();
if (Func && !Func->isDeclaration()) {
FuncCallMap[ParentFunc].push_back(Func);
}
}
}
}
}
}

// ==----------------------------------------------------------------------===//
// VocabStorage
//===----------------------------------------------------------------------===//
Expand Down