lower requirements for lower llvm/clang versions

3.7.1 works with the exception of InsTrim, 3.8.1 and above is ok
This commit is contained in:
hexcoder-
2020-01-30 21:32:08 +01:00
parent b13bb64c3b
commit ceed66930e
7 changed files with 387 additions and 64 deletions

View File

@ -3,10 +3,23 @@
#include <stdarg.h> #include <stdarg.h>
#include <unistd.h> #include <unistd.h>
#include "llvm/Config/llvm-config.h"
#if LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR < 5
typedef long double max_align_t;
#endif
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/DenseSet.h"
#if LLVM_VERSION_MAJOR > 3 || \
(LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4)
#include "llvm/IR/CFG.h" #include "llvm/IR/CFG.h"
#include "llvm/IR/Dominators.h" #include "llvm/IR/Dominators.h"
#include "llvm/IR/DebugInfo.h"
#else
#include "llvm/Support/CFG.h"
#include "llvm/Analysis/Dominators.h"
#include "llvm/DebugInfo.h"
#endif
#include "llvm/IR/IRBuilder.h" #include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h" #include "llvm/IR/Instructions.h"
#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/LegacyPassManager.h"
@ -16,9 +29,7 @@
#include "llvm/Support/CommandLine.h" #include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/BasicBlock.h" #include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include <unordered_set> #include <unordered_set>
#include <random> #include <random>
#include <list> #include <list>
@ -97,7 +108,7 @@ struct InsTrim : public ModulePass {
// ripped from aflgo // ripped from aflgo
static bool isBlacklisted(const Function *F) { static bool isBlacklisted(const Function *F) {
static const SmallVector<std::string, 4> Blacklist = { static const char *Blacklist[] = {
"asan.", "asan.",
"llvm.", "llvm.",
@ -173,6 +184,8 @@ struct InsTrim : public ModulePass {
StringRef instFilename; StringRef instFilename;
unsigned int instLine = 0; unsigned int instLine = 0;
#if LLVM_VERSION_MAJOR >= 4 || \
(LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR >= 7)
for (auto &BB : F) { for (auto &BB : F) {
BasicBlock::iterator IP = BB.getFirstInsertionPt(); BasicBlock::iterator IP = BB.getFirstInsertionPt();
@ -227,6 +240,48 @@ struct InsTrim : public ModulePass {
} }
#else
for (auto &BB : F) {
BasicBlock::iterator IP = BB.getFirstInsertionPt();
IRBuilder<> IRB(&(*IP));
if (Loc.isUnknown()) Loc = IP->getDebugLoc();
}
if (!Loc.isUnknown()) {
DILocation cDILoc(Loc.getAsMDNode(C));
instLine = cDILoc.getLineNumber();
instFilename = cDILoc.getFilename();
/* Continue only if we know where we actually are */
if (!instFilename.str().empty()) {
for (std::list<std::string>::iterator it = myWhitelist.begin();
it != myWhitelist.end(); ++it) {
if (instFilename.str().length() >= it->length()) {
if (instFilename.str().compare(
instFilename.str().length() - it->length(),
it->length(), *it) == 0) {
instrumentBlock = true;
break;
}
}
}
}
}
#endif
/* Either we couldn't figure out our location or the location is /* Either we couldn't figure out our location or the location is
* not whitelisted, so we skip instrumentation. */ * not whitelisted, so we skip instrumentation. */
if (!instrumentBlock) { if (!instrumentBlock) {

View File

@ -36,7 +36,7 @@ else
endif endif
LLVMVER = $(shell $(LLVM_CONFIG) --version 2>/dev/null ) LLVMVER = $(shell $(LLVM_CONFIG) --version 2>/dev/null )
LLVM_UNSUPPORTED = $(shell $(LLVM_CONFIG) --version 2>/dev/null | egrep -q '^3\.[0-7]|^1[2-9]' && echo 1 || echo 0 ) LLVM_UNSUPPORTED = $(shell $(LLVM_CONFIG) --version 2>/dev/null | egrep -q '^3\.[0-3]|^1[2-9]' && echo 1 || echo 0 )
LLVM_NEW_API = $(shell $(LLVM_CONFIG) --version 2>/dev/null | egrep -q '^1[0-9]' && echo 1 || echo 0 ) LLVM_NEW_API = $(shell $(LLVM_CONFIG) --version 2>/dev/null | egrep -q '^1[0-9]' && echo 1 || echo 0 )
LLVM_MAJOR = $(shell $(LLVM_CONFIG) --version 2>/dev/null | sed 's/\..*//') LLVM_MAJOR = $(shell $(LLVM_CONFIG) --version 2>/dev/null | sed 's/\..*//')
LLVM_BINDIR = $(shell $(LLVM_CONFIG) --bindir 2>/dev/null) LLVM_BINDIR = $(shell $(LLVM_CONFIG) --bindir 2>/dev/null)
@ -201,7 +201,7 @@ endif
ln -sf afl-clang-fast ../afl-clang-fast++ ln -sf afl-clang-fast ../afl-clang-fast++
../libLLVMInsTrim.so: LLVMInsTrim.so.cc MarkNodes.cc | test_deps ../libLLVMInsTrim.so: LLVMInsTrim.so.cc MarkNodes.cc | test_deps
$(CXX) $(CLANG_CFL) -DLLVMInsTrim_EXPORTS -fno-rtti -fPIC -std=$(LLVM_STDCXX) -shared $< MarkNodes.cc -o $@ $(CLANG_LFL) -$(CXX) $(CLANG_CFL) -DLLVMInsTrim_EXPORTS -fno-rtti -fPIC -std=$(LLVM_STDCXX) -shared $< MarkNodes.cc -o $@ $(CLANG_LFL)
../afl-llvm-pass.so: afl-llvm-pass.so.cc | test_deps ../afl-llvm-pass.so: afl-llvm-pass.so.cc | test_deps
$(CXX) $(CLANG_CFL) -DLLVMInsTrim_EXPORTS -fno-rtti -fPIC -std=$(LLVM_STDCXX) -shared $< -o $@ $(CLANG_LFL) $(CXX) $(CLANG_CFL) -DLLVMInsTrim_EXPORTS -fno-rtti -fPIC -std=$(LLVM_STDCXX) -shared $< -o $@ $(CLANG_LFL)

View File

@ -3,11 +3,22 @@
#include <queue> #include <queue>
#include <set> #include <set>
#include <vector> #include <vector>
#include "llvm/Config/llvm-config.h"
#if LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR < 5
typedef long double max_align_t;
#endif
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/IR/BasicBlock.h" #include "llvm/IR/BasicBlock.h"
#if LLVM_VERSION_MAJOR > 3 || \
(LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4)
#include "llvm/IR/CFG.h" #include "llvm/IR/CFG.h"
#else
#include "llvm/Support/CFG.h"
#endif
#include "llvm/IR/Constants.h" #include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h" #include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h" #include "llvm/IR/IRBuilder.h"

View File

@ -37,14 +37,26 @@
#include <fstream> #include <fstream>
#include <sys/time.h> #include <sys/time.h>
#include "llvm/IR/DebugInfo.h" #include "llvm/Config/llvm-config.h"
#include "llvm/IR/BasicBlock.h" #if LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR < 5
typedef long double max_align_t;
#endif
#include "llvm/IR/IRBuilder.h" #include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Module.h" #include "llvm/IR/Module.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "llvm/Transforms/IPO/PassManagerBuilder.h"
#if LLVM_VERSION_MAJOR > 3 || \
(LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4)
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/CFG.h" #include "llvm/IR/CFG.h"
#else
#include "llvm/DebugInfo.h"
#include "llvm/Support/CFG.h"
#endif
using namespace llvm; using namespace llvm;
@ -78,7 +90,7 @@ class AFLCoverage : public ModulePass {
// ripped from aflgo // ripped from aflgo
static bool isBlacklisted(const Function *F) { static bool isBlacklisted(const Function *F) {
static const SmallVector<std::string, 4> Blacklist = { static const char *Blacklist[] = {
"asan.", "asan.",
"llvm.", "llvm.",
@ -197,6 +209,8 @@ bool AFLCoverage::runOnModule(Module &M) {
* For now, just instrument the block if we are not able * For now, just instrument the block if we are not able
* to determine our location. */ * to determine our location. */
DebugLoc Loc = IP->getDebugLoc(); DebugLoc Loc = IP->getDebugLoc();
#if LLVM_VERSION_MAJOR >= 4 || \
(LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR >= 7)
if (Loc) { if (Loc) {
DILocation *cDILoc = dyn_cast<DILocation>(Loc.getAsMDNode()); DILocation *cDILoc = dyn_cast<DILocation>(Loc.getAsMDNode());
@ -249,6 +263,47 @@ bool AFLCoverage::runOnModule(Module &M) {
} }
#else
if (!Loc.isUnknown()) {
DILocation cDILoc(Loc.getAsMDNode(C));
unsigned int instLine = cDILoc.getLineNumber();
StringRef instFilename = cDILoc.getFilename();
(void)instLine;
/* Continue only if we know where we actually are */
if (!instFilename.str().empty()) {
for (std::list<std::string>::iterator it = myWhitelist.begin();
it != myWhitelist.end(); ++it) {
/* We don't check for filename equality here because
* filenames might actually be full paths. Instead we
* check that the actual filename ends in the filename
* specified in the list. */
if (instFilename.str().length() >= it->length()) {
if (instFilename.str().compare(
instFilename.str().length() - it->length(),
it->length(), *it) == 0) {
instrumentBlock = true;
break;
}
}
}
}
}
#endif
/* Either we couldn't figure out our location or the location is /* Either we couldn't figure out our location or the location is
* not whitelisted, so we skip instrumentation. */ * not whitelisted, so we skip instrumentation. */
if (!instrumentBlock) continue; if (!instrumentBlock) continue;
@ -273,13 +328,19 @@ bool AFLCoverage::runOnModule(Module &M) {
// result: a little more speed and less map pollution // result: a little more speed and less map pollution
int more_than_one = -1; int more_than_one = -1;
// fprintf(stderr, "BB %u: ", cur_loc); // fprintf(stderr, "BB %u: ", cur_loc);
for (BasicBlock *Pred : predecessors(&BB)) { for (pred_iterator PI = pred_begin(&BB), E = pred_end(&BB); PI != E;
++PI) {
BasicBlock *Pred = *PI;
int count = 0; int count = 0;
if (more_than_one == -1) more_than_one = 0; if (more_than_one == -1) more_than_one = 0;
// fprintf(stderr, " %p=>", Pred); // fprintf(stderr, " %p=>", Pred);
for (BasicBlock *Succ : successors(Pred)) { for (succ_iterator SI = succ_begin(Pred), E = succ_end(Pred); SI != E;
++SI) {
BasicBlock *Succ = *SI;
// if (count > 0) // if (count > 0)
// fprintf(stderr, "|"); // fprintf(stderr, "|");

View File

@ -22,9 +22,9 @@
#include <string> #include <string>
#include <fstream> #include <fstream>
#include <sys/time.h> #include <sys/time.h>
#include "llvm/Config/llvm-config.h"
#include "llvm/ADT/Statistic.h" #include "llvm/ADT/Statistic.h"
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/IRBuilder.h" #include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h" #include "llvm/IR/Module.h"
@ -32,10 +32,19 @@
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Pass.h" #include "llvm/Pass.h"
#include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/ValueTracking.h"
#if LLVM_VERSION_MAJOR > 3 || \
(LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4)
#include "llvm/IR/Verifier.h"
#include "llvm/IR/DebugInfo.h"
#else
#include "llvm/Analysis/Verifier.h"
#include "llvm/DebugInfo.h"
#define nullptr 0
#endif
#include <set> #include <set>
using namespace llvm; using namespace llvm;
@ -115,7 +124,7 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
c = M.getOrInsertFunction("tolower", Int32Ty, Int32Ty c = M.getOrInsertFunction("tolower", Int32Ty, Int32Ty
#if LLVM_VERSION_MAJOR < 5 #if LLVM_VERSION_MAJOR < 5
, ,
nullptr NULL
#endif #endif
); );
#if LLVM_VERSION_MAJOR < 9 #if LLVM_VERSION_MAJOR < 9
@ -140,6 +149,8 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
* For now, just instrument the block if we are not able * For now, just instrument the block if we are not able
* to determine our location. */ * to determine our location. */
DebugLoc Loc = IP->getDebugLoc(); DebugLoc Loc = IP->getDebugLoc();
#if LLVM_VERSION_MAJOR >= 4 || \
(LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR >= 7)
if (Loc) { if (Loc) {
DILocation *cDILoc = dyn_cast<DILocation>(Loc.getAsMDNode()); DILocation *cDILoc = dyn_cast<DILocation>(Loc.getAsMDNode());
@ -192,6 +203,47 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
} }
#else
if (!Loc.isUnknown()) {
DILocation cDILoc(Loc.getAsMDNode(C));
unsigned int instLine = cDILoc.getLineNumber();
StringRef instFilename = cDILoc.getFilename();
(void)instLine;
/* Continue only if we know where we actually are */
if (!instFilename.str().empty()) {
for (std::list<std::string>::iterator it = myWhitelist.begin();
it != myWhitelist.end(); ++it) {
/* We don't check for filename equality here because
* filenames might actually be full paths. Instead we
* check that the actual filename ends in the filename
* specified in the list. */
if (instFilename.str().length() >= it->length()) {
if (instFilename.str().compare(
instFilename.str().length() - it->length(),
it->length(), *it) == 0) {
instrumentBlock = true;
break;
}
}
}
}
}
#endif
/* Either we couldn't figure out our location or the location is /* Either we couldn't figure out our location or the location is
* not whitelisted, so we skip instrumentation. */ * not whitelisted, so we skip instrumentation. */
if (!instrumentBlock) continue; if (!instrumentBlock) continue;

View File

@ -24,16 +24,25 @@
#include <fstream> #include <fstream>
#include <sys/time.h> #include <sys/time.h>
#include "llvm/Config/llvm-config.h"
#include "llvm/Pass.h" #include "llvm/Pass.h"
#include "llvm/IR/DebugInfo.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/LegacyPassManager.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IR/Module.h" #include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h" #include "llvm/IR/IRBuilder.h"
#if LLVM_VERSION_MAJOR > 3 || \
(LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4)
#include "llvm/IR/Verifier.h"
#include "llvm/IR/DebugInfo.h"
#else
#include "llvm/Analysis/Verifier.h"
#include "llvm/DebugInfo.h"
#define nullptr 0
#endif
using namespace llvm; using namespace llvm;
@ -66,7 +75,7 @@ class SplitComparesTransform : public ModulePass {
static bool isBlacklisted(const Function *F) { static bool isBlacklisted(const Function *F) {
static const SmallVector<std::string, 5> Blacklist = { static const char *Blacklist[] = {
"asan.", "llvm.", "sancov.", "__ubsan_handle_", "ign." "asan.", "llvm.", "sancov.", "__ubsan_handle_", "ign."
@ -139,6 +148,8 @@ bool SplitComparesTransform::simplifyCompares(Module &M) {
* For now, just instrument the block if we are not able * For now, just instrument the block if we are not able
* to determine our location. */ * to determine our location. */
DebugLoc Loc = IP->getDebugLoc(); DebugLoc Loc = IP->getDebugLoc();
#if LLVM_VERSION_MAJOR >= 4 || \
(LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR >= 7)
if (Loc) { if (Loc) {
DILocation *cDILoc = dyn_cast<DILocation>(Loc.getAsMDNode()); DILocation *cDILoc = dyn_cast<DILocation>(Loc.getAsMDNode());
@ -191,6 +202,47 @@ bool SplitComparesTransform::simplifyCompares(Module &M) {
} }
#else
if (!Loc.isUnknown()) {
DILocation cDILoc(Loc.getAsMDNode(C));
unsigned int instLine = cDILoc.getLineNumber();
StringRef instFilename = cDILoc.getFilename();
(void)instLine;
/* Continue only if we know where we actually are */
if (!instFilename.str().empty()) {
for (std::list<std::string>::iterator it = myWhitelist.begin();
it != myWhitelist.end(); ++it) {
/* We don't check for filename equality here because
* filenames might actually be full paths. Instead we
* check that the actual filename ends in the filename
* specified in the list. */
if (instFilename.str().length() >= it->length()) {
if (instFilename.str().compare(
instFilename.str().length() - it->length(),
it->length(), *it) == 0) {
instrumentBlock = true;
break;
}
}
}
}
}
#endif
/* Either we couldn't figure out our location or the location is /* Either we couldn't figure out our location or the location is
* not whitelisted, so we skip instrumentation. */ * not whitelisted, so we skip instrumentation. */
if (!instrumentBlock) continue; if (!instrumentBlock) continue;
@ -283,7 +335,8 @@ bool SplitComparesTransform::simplifyCompares(Module &M) {
* block bb it is now at the position where the old IcmpInst was */ * block bb it is now at the position where the old IcmpInst was */
Instruction *icmp_np; Instruction *icmp_np;
icmp_np = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1); icmp_np = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1);
bb->getInstList().insert(bb->getTerminator()->getIterator(), icmp_np); bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
icmp_np);
/* create a new basic block which holds the new EQ icmp */ /* create a new basic block which holds the new EQ icmp */
Instruction *icmp_eq; Instruction *icmp_eq;
@ -348,7 +401,8 @@ bool SplitComparesTransform::simplifyCompares(Module &M) {
* block bb it is now at the position where the old IcmpInst was */ * block bb it is now at the position where the old IcmpInst was */
Instruction *fcmp_np; Instruction *fcmp_np;
fcmp_np = CmpInst::Create(Instruction::FCmp, new_pred, op0, op1); fcmp_np = CmpInst::Create(Instruction::FCmp, new_pred, op0, op1);
bb->getInstList().insert(bb->getTerminator()->getIterator(), fcmp_np); bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
fcmp_np);
/* create a new basic block which holds the new EQ fcmp */ /* create a new basic block which holds the new EQ fcmp */
Instruction *fcmp_eq; Instruction *fcmp_eq;
@ -469,20 +523,21 @@ bool SplitComparesTransform::simplifyIntSignedness(Module &M) {
s_op0 = BinaryOperator::Create(Instruction::LShr, op0, s_op0 = BinaryOperator::Create(Instruction::LShr, op0,
ConstantInt::get(IntType, bitw - 1)); ConstantInt::get(IntType, bitw - 1));
bb->getInstList().insert(bb->getTerminator()->getIterator(), s_op0); bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op0);
t_op0 = new TruncInst(s_op0, Int1Ty); t_op0 = new TruncInst(s_op0, Int1Ty);
bb->getInstList().insert(bb->getTerminator()->getIterator(), t_op0); bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), t_op0);
s_op1 = BinaryOperator::Create(Instruction::LShr, op1, s_op1 = BinaryOperator::Create(Instruction::LShr, op1,
ConstantInt::get(IntType, bitw - 1)); ConstantInt::get(IntType, bitw - 1));
bb->getInstList().insert(bb->getTerminator()->getIterator(), s_op1); bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op1);
t_op1 = new TruncInst(s_op1, Int1Ty); t_op1 = new TruncInst(s_op1, Int1Ty);
bb->getInstList().insert(bb->getTerminator()->getIterator(), t_op1); bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), t_op1);
/* compare of the sign bits */ /* compare of the sign bits */
icmp_sign_bit = icmp_sign_bit =
CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_op0, t_op1); CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_op0, t_op1);
bb->getInstList().insert(bb->getTerminator()->getIterator(), icmp_sign_bit); bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
icmp_sign_bit);
/* create a new basic block which is executed if the signedness bit is /* create a new basic block which is executed if the signedness bit is
* different */ * different */
@ -557,6 +612,8 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
LLVMContext &C = M.getContext(); LLVMContext &C = M.getContext();
#if LLVM_VERSION_MAJOR > 3 || \
(LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 7)
const DataLayout &dl = M.getDataLayout(); const DataLayout &dl = M.getDataLayout();
/* define unions with floating point and (sign, exponent, mantissa) triples /* define unions with floating point and (sign, exponent, mantissa) triples
@ -571,6 +628,8 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
} }
#endif
std::vector<CmpInst *> fcomps; std::vector<CmpInst *> fcomps;
/* get all EQ, NE, GT, and LT fcmps. if the other two /* get all EQ, NE, GT, and LT fcmps. if the other two
@ -669,11 +728,11 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
Instruction *b_op0, *b_op1; Instruction *b_op0, *b_op1;
b_op0 = CastInst::Create(Instruction::BitCast, op0, b_op0 = CastInst::Create(Instruction::BitCast, op0,
IntegerType::get(C, op_size)); IntegerType::get(C, op_size));
bb->getInstList().insert(bb->getTerminator()->getIterator(), b_op0); bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), b_op0);
b_op1 = CastInst::Create(Instruction::BitCast, op1, b_op1 = CastInst::Create(Instruction::BitCast, op1,
IntegerType::get(C, op_size)); IntegerType::get(C, op_size));
bb->getInstList().insert(bb->getTerminator()->getIterator(), b_op1); bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), b_op1);
/* isolate signs of value of floating point type */ /* isolate signs of value of floating point type */
@ -684,21 +743,22 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
s_s0 = s_s0 =
BinaryOperator::Create(Instruction::LShr, b_op0, BinaryOperator::Create(Instruction::LShr, b_op0,
ConstantInt::get(b_op0->getType(), op_size - 1)); ConstantInt::get(b_op0->getType(), op_size - 1));
bb->getInstList().insert(bb->getTerminator()->getIterator(), s_s0); bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_s0);
t_s0 = new TruncInst(s_s0, Int1Ty); t_s0 = new TruncInst(s_s0, Int1Ty);
bb->getInstList().insert(bb->getTerminator()->getIterator(), t_s0); bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), t_s0);
s_s1 = s_s1 =
BinaryOperator::Create(Instruction::LShr, b_op1, BinaryOperator::Create(Instruction::LShr, b_op1,
ConstantInt::get(b_op1->getType(), op_size - 1)); ConstantInt::get(b_op1->getType(), op_size - 1));
bb->getInstList().insert(bb->getTerminator()->getIterator(), s_s1); bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_s1);
t_s1 = new TruncInst(s_s1, Int1Ty); t_s1 = new TruncInst(s_s1, Int1Ty);
bb->getInstList().insert(bb->getTerminator()->getIterator(), t_s1); bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), t_s1);
/* compare of the sign bits */ /* compare of the sign bits */
icmp_sign_bit = icmp_sign_bit =
CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_s0, t_s1); CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_s0, t_s1);
bb->getInstList().insert(bb->getTerminator()->getIterator(), icmp_sign_bit); bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
icmp_sign_bit);
/* create a new basic block which is executed if the signedness bits are /* create a new basic block which is executed if the signedness bits are
* equal */ * equal */
@ -730,16 +790,16 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
Instruction::LShr, b_op1, Instruction::LShr, b_op1,
ConstantInt::get(b_op1->getType(), shiftR_exponent)); ConstantInt::get(b_op1->getType(), shiftR_exponent));
signequal_bb->getInstList().insert( signequal_bb->getInstList().insert(
signequal_bb->getTerminator()->getIterator(), s_e0); BasicBlock::iterator(signequal_bb->getTerminator()), s_e0);
signequal_bb->getInstList().insert( signequal_bb->getInstList().insert(
signequal_bb->getTerminator()->getIterator(), s_e1); BasicBlock::iterator(signequal_bb->getTerminator()), s_e1);
t_e0 = new TruncInst(s_e0, IntExponentTy); t_e0 = new TruncInst(s_e0, IntExponentTy);
t_e1 = new TruncInst(s_e1, IntExponentTy); t_e1 = new TruncInst(s_e1, IntExponentTy);
signequal_bb->getInstList().insert( signequal_bb->getInstList().insert(
signequal_bb->getTerminator()->getIterator(), t_e0); BasicBlock::iterator(signequal_bb->getTerminator()), t_e0);
signequal_bb->getInstList().insert( signequal_bb->getInstList().insert(
signequal_bb->getTerminator()->getIterator(), t_e1); BasicBlock::iterator(signequal_bb->getTerminator()), t_e1);
if (sizeInBits - precision < exTySizeBytes * 8) { if (sizeInBits - precision < exTySizeBytes * 8) {
@ -750,9 +810,9 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
Instruction::And, t_e1, Instruction::And, t_e1,
ConstantInt::get(t_e1->getType(), mask_exponent)); ConstantInt::get(t_e1->getType(), mask_exponent));
signequal_bb->getInstList().insert( signequal_bb->getInstList().insert(
signequal_bb->getTerminator()->getIterator(), m_e0); BasicBlock::iterator(signequal_bb->getTerminator()), m_e0);
signequal_bb->getInstList().insert( signequal_bb->getInstList().insert(
signequal_bb->getTerminator()->getIterator(), m_e1); BasicBlock::iterator(signequal_bb->getTerminator()), m_e1);
} else { } else {
@ -780,7 +840,7 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
icmp_exponent = icmp_exponent =
CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, m_e0, m_e1); CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, m_e0, m_e1);
signequal_bb->getInstList().insert( signequal_bb->getInstList().insert(
signequal_bb->getTerminator()->getIterator(), icmp_exponent); BasicBlock::iterator(signequal_bb->getTerminator()), icmp_exponent);
icmp_exponent_result = icmp_exponent_result =
BinaryOperator::Create(Instruction::Xor, icmp_exponent, t_s0); BinaryOperator::Create(Instruction::Xor, icmp_exponent, t_s0);
break; break;
@ -789,7 +849,7 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
icmp_exponent = icmp_exponent =
CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, m_e0, m_e1); CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, m_e0, m_e1);
signequal_bb->getInstList().insert( signequal_bb->getInstList().insert(
signequal_bb->getTerminator()->getIterator(), icmp_exponent); BasicBlock::iterator(signequal_bb->getTerminator()), icmp_exponent);
icmp_exponent_result = icmp_exponent_result =
BinaryOperator::Create(Instruction::Xor, icmp_exponent, t_s0); BinaryOperator::Create(Instruction::Xor, icmp_exponent, t_s0);
break; break;
@ -798,7 +858,8 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
} }
signequal_bb->getInstList().insert( signequal_bb->getInstList().insert(
signequal_bb->getTerminator()->getIterator(), icmp_exponent_result); BasicBlock::iterator(signequal_bb->getTerminator()),
icmp_exponent_result);
{ {
@ -822,19 +883,19 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
m_f1 = BinaryOperator::Create( m_f1 = BinaryOperator::Create(
Instruction::And, b_op1, Instruction::And, b_op1,
ConstantInt::get(b_op1->getType(), mask_fraction)); ConstantInt::get(b_op1->getType(), mask_fraction));
middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), middle_bb->getInstList().insert(
m_f0); BasicBlock::iterator(middle_bb->getTerminator()), m_f0);
middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), middle_bb->getInstList().insert(
m_f1); BasicBlock::iterator(middle_bb->getTerminator()), m_f1);
if (needTrunc) { if (needTrunc) {
t_f0 = new TruncInst(m_f0, IntFractionTy); t_f0 = new TruncInst(m_f0, IntFractionTy);
t_f1 = new TruncInst(m_f1, IntFractionTy); t_f1 = new TruncInst(m_f1, IntFractionTy);
middle_bb->getInstList().insert( middle_bb->getInstList().insert(
middle_bb->getTerminator()->getIterator(), t_f0); BasicBlock::iterator(middle_bb->getTerminator()), t_f0);
middle_bb->getInstList().insert( middle_bb->getInstList().insert(
middle_bb->getTerminator()->getIterator(), t_f1); BasicBlock::iterator(middle_bb->getTerminator()), t_f1);
} else { } else {
@ -850,9 +911,9 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
t_f0 = new TruncInst(b_op0, IntFractionTy); t_f0 = new TruncInst(b_op0, IntFractionTy);
t_f1 = new TruncInst(b_op1, IntFractionTy); t_f1 = new TruncInst(b_op1, IntFractionTy);
middle_bb->getInstList().insert( middle_bb->getInstList().insert(
middle_bb->getTerminator()->getIterator(), t_f0); BasicBlock::iterator(middle_bb->getTerminator()), t_f0);
middle_bb->getInstList().insert( middle_bb->getInstList().insert(
middle_bb->getTerminator()->getIterator(), t_f1); BasicBlock::iterator(middle_bb->getTerminator()), t_f1);
} else { } else {
@ -882,7 +943,7 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
icmp_fraction = icmp_fraction =
CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_f0, t_f1); CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_f0, t_f1);
middle_bb->getInstList().insert( middle_bb->getInstList().insert(
middle_bb->getTerminator()->getIterator(), icmp_fraction); BasicBlock::iterator(middle_bb->getTerminator()), icmp_fraction);
icmp_fraction_result = icmp_fraction_result =
BinaryOperator::Create(Instruction::Xor, icmp_fraction, t_s0); BinaryOperator::Create(Instruction::Xor, icmp_fraction, t_s0);
break; break;
@ -891,7 +952,7 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
icmp_fraction = icmp_fraction =
CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_f0, t_f1); CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_f0, t_f1);
middle_bb->getInstList().insert( middle_bb->getInstList().insert(
middle_bb->getTerminator()->getIterator(), icmp_fraction); BasicBlock::iterator(middle_bb->getTerminator()), icmp_fraction);
icmp_fraction_result = icmp_fraction_result =
BinaryOperator::Create(Instruction::Xor, icmp_fraction, t_s0); BinaryOperator::Create(Instruction::Xor, icmp_fraction, t_s0);
break; break;
@ -899,8 +960,8 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
} }
middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), middle_bb->getInstList().insert(
icmp_fraction_result); BasicBlock::iterator(middle_bb->getTerminator()), icmp_fraction_result);
PHINode *PN = PHINode::Create(Int1Ty, 3, ""); PHINode *PN = PHINode::Create(Int1Ty, 3, "");
@ -1037,18 +1098,21 @@ size_t SplitComparesTransform::splitIntCompares(Module &M, unsigned bitw) {
s_op0 = BinaryOperator::Create(Instruction::LShr, op0, s_op0 = BinaryOperator::Create(Instruction::LShr, op0,
ConstantInt::get(OldIntType, bitw / 2)); ConstantInt::get(OldIntType, bitw / 2));
bb->getInstList().insert(bb->getTerminator()->getIterator(), s_op0); bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op0);
op0_high = new TruncInst(s_op0, NewIntType); op0_high = new TruncInst(s_op0, NewIntType);
bb->getInstList().insert(bb->getTerminator()->getIterator(), op0_high); bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
op0_high);
s_op1 = BinaryOperator::Create(Instruction::LShr, op1, s_op1 = BinaryOperator::Create(Instruction::LShr, op1,
ConstantInt::get(OldIntType, bitw / 2)); ConstantInt::get(OldIntType, bitw / 2));
bb->getInstList().insert(bb->getTerminator()->getIterator(), s_op1); bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op1);
op1_high = new TruncInst(s_op1, NewIntType); op1_high = new TruncInst(s_op1, NewIntType);
bb->getInstList().insert(bb->getTerminator()->getIterator(), op1_high); bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
op1_high);
icmp_high = CmpInst::Create(Instruction::ICmp, pred, op0_high, op1_high); icmp_high = CmpInst::Create(Instruction::ICmp, pred, op0_high, op1_high);
bb->getInstList().insert(bb->getTerminator()->getIterator(), icmp_high); bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
icmp_high);
/* now we have to destinguish between == != and > < */ /* now we have to destinguish between == != and > < */
if (pred == CmpInst::ICMP_EQ || pred == CmpInst::ICMP_NE) { if (pred == CmpInst::ICMP_EQ || pred == CmpInst::ICMP_NE) {
@ -1194,13 +1258,19 @@ bool SplitComparesTransform::runOnModule(Module &M) {
<< "bit: " << splitIntCompares(M, bitw) << " splitted\n"; << "bit: " << splitIntCompares(M, bitw) << " splitted\n";
bitw >>= 1; bitw >>= 1;
#if LLVM_VERSION_MAJOR > 3 || \
(LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 7)
[[clang::fallthrough]]; /*FALLTHRU*/ /* FALLTHROUGH */ [[clang::fallthrough]]; /*FALLTHRU*/ /* FALLTHROUGH */
#endif
case 32: case 32:
errs() << "Split-integer-compare-pass " << bitw errs() << "Split-integer-compare-pass " << bitw
<< "bit: " << splitIntCompares(M, bitw) << " splitted\n"; << "bit: " << splitIntCompares(M, bitw) << " splitted\n";
bitw >>= 1; bitw >>= 1;
#if LLVM_VERSION_MAJOR > 3 || \
(LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 7)
[[clang::fallthrough]]; /*FALLTHRU*/ /* FALLTHROUGH */ [[clang::fallthrough]]; /*FALLTHRU*/ /* FALLTHROUGH */
#endif
case 16: case 16:
errs() << "Split-integer-compare-pass " << bitw errs() << "Split-integer-compare-pass " << bitw
<< "bit: " << splitIntCompares(M, bitw) << " splitted\n"; << "bit: " << splitIntCompares(M, bitw) << " splitted\n";

View File

@ -23,8 +23,9 @@
#include <fstream> #include <fstream>
#include <sys/time.h> #include <sys/time.h>
#include "llvm/Config/llvm-config.h"
#include "llvm/ADT/Statistic.h" #include "llvm/ADT/Statistic.h"
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/IRBuilder.h" #include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h" #include "llvm/IR/Module.h"
@ -32,10 +33,20 @@
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Pass.h" #include "llvm/Pass.h"
#include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/IRBuilder.h"
#if LLVM_VERSION_MAJOR > 3 || \
(LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4)
#include "llvm/IR/Verifier.h"
#include "llvm/IR/DebugInfo.h"
#else
#include "llvm/Analysis/Verifier.h"
#include "llvm/DebugInfo.h"
#define nullptr 0
#endif
#include <set> #include <set>
using namespace llvm; using namespace llvm;
@ -69,7 +80,7 @@ class SplitSwitchesTransform : public ModulePass {
static bool isBlacklisted(const Function *F) { static bool isBlacklisted(const Function *F) {
static const SmallVector<std::string, 5> Blacklist = { static const char *Blacklist[] = {
"asan.", "llvm.", "sancov.", "__ubsan_handle_", "ign." "asan.", "llvm.", "sancov.", "__ubsan_handle_", "ign."
@ -140,7 +151,7 @@ BasicBlock *SplitSwitchesTransform::switchConvert(
IntegerType * ByteType = IntegerType::get(OrigBlock->getContext(), 8); IntegerType * ByteType = IntegerType::get(OrigBlock->getContext(), 8);
unsigned BytesInValue = bytesChecked.size(); unsigned BytesInValue = bytesChecked.size();
std::vector<uint8_t> setSizes; std::vector<uint8_t> setSizes;
std::vector<std::set<uint8_t>> byteSets(BytesInValue, std::set<uint8_t>()); std::vector<std::set<uint8_t> > byteSets(BytesInValue, std::set<uint8_t>());
assert(ValTypeBitWidth >= 8 && ValTypeBitWidth <= 64); assert(ValTypeBitWidth >= 8 && ValTypeBitWidth <= 64);
@ -213,8 +224,25 @@ BasicBlock *SplitSwitchesTransform::switchConvert(
NewNode->getInstList().push_back(Comp); NewNode->getInstList().push_back(Comp);
bytesChecked[smallestIndex] = true; bytesChecked[smallestIndex] = true;
if (std::all_of(bytesChecked.begin(), bytesChecked.end(), bool allBytesAreChecked = true;
[](bool b) { return b; })) {
for (std::vector<bool>::iterator BCI = bytesChecked.begin(),
E = bytesChecked.end();
BCI != E; ++BCI) {
if (!*BCI) {
allBytesAreChecked = false;
break;
}
}
// if (std::all_of(bytesChecked.begin(), bytesChecked.end(),
// [](bool b) { return b; })) {
if (allBytesAreChecked) {
assert(Cases.size() == 1); assert(Cases.size() == 1);
BranchInst::Create(Cases[0].BB, NewDefault, Comp, NewNode); BranchInst::Create(Cases[0].BB, NewDefault, Comp, NewNode);
@ -306,6 +334,10 @@ BasicBlock *SplitSwitchesTransform::switchConvert(
bool SplitSwitchesTransform::splitSwitches(Module &M) { bool SplitSwitchesTransform::splitSwitches(Module &M) {
#if (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR < 7)
LLVMContext &C = M.getContext();
#endif
std::vector<SwitchInst *> switches; std::vector<SwitchInst *> switches;
/* iterate over all functions, bbs and instruction and add /* iterate over all functions, bbs and instruction and add
@ -327,6 +359,8 @@ bool SplitSwitchesTransform::splitSwitches(Module &M) {
* For now, just instrument the block if we are not able * For now, just instrument the block if we are not able
* to determine our location. */ * to determine our location. */
DebugLoc Loc = IP->getDebugLoc(); DebugLoc Loc = IP->getDebugLoc();
#if LLVM_VERSION_MAJOR >= 4 || \
(LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR >= 7)
if (Loc) { if (Loc) {
DILocation *cDILoc = dyn_cast<DILocation>(Loc.getAsMDNode()); DILocation *cDILoc = dyn_cast<DILocation>(Loc.getAsMDNode());
@ -379,6 +413,47 @@ bool SplitSwitchesTransform::splitSwitches(Module &M) {
} }
#else
if (!Loc.isUnknown()) {
DILocation cDILoc(Loc.getAsMDNode(C));
unsigned int instLine = cDILoc.getLineNumber();
StringRef instFilename = cDILoc.getFilename();
(void)instLine;
/* Continue only if we know where we actually are */
if (!instFilename.str().empty()) {
for (std::list<std::string>::iterator it = myWhitelist.begin();
it != myWhitelist.end(); ++it) {
/* We don't check for filename equality here because
* filenames might actually be full paths. Instead we
* check that the actual filename ends in the filename
* specified in the list. */
if (instFilename.str().length() >= it->length()) {
if (instFilename.str().compare(
instFilename.str().length() - it->length(),
it->length(), *it) == 0) {
instrumentBlock = true;
break;
}
}
}
}
}
#endif
/* Either we couldn't figure out our location or the location is /* Either we couldn't figure out our location or the location is
* not whitelisted, so we skip instrumentation. */ * not whitelisted, so we skip instrumentation. */
if (!instrumentBlock) continue; if (!instrumentBlock) continue;
@ -426,8 +501,7 @@ bool SplitSwitchesTransform::splitSwitches(Module &M) {
* if the default block is set as an unreachable we avoid creating one * if the default block is set as an unreachable we avoid creating one
* because will never be a valid target.*/ * because will never be a valid target.*/
BasicBlock *NewDefault = nullptr; BasicBlock *NewDefault = nullptr;
NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault"); NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault", F, Default);
NewDefault->insertInto(F, Default);
BranchInst::Create(Default, NewDefault); BranchInst::Create(Default, NewDefault);
/* Prepare cases vector. */ /* Prepare cases vector. */