mirror of
https://github.com/AFLplusplus/AFLplusplus.git
synced 2025-06-17 20:28:08 +00:00
run code formatter
This commit is contained in:
@ -36,202 +36,236 @@ using namespace llvm;
|
||||
|
||||
namespace {
|
||||
|
||||
class CompareTransform : public ModulePass {
|
||||
class CompareTransform : public ModulePass {
|
||||
|
||||
public:
|
||||
static char ID;
|
||||
CompareTransform() : ModulePass(ID) {
|
||||
}
|
||||
public:
|
||||
static char ID;
|
||||
CompareTransform() : ModulePass(ID) {
|
||||
|
||||
bool runOnModule(Module &M) override;
|
||||
}
|
||||
|
||||
bool runOnModule(Module &M) override;
|
||||
|
||||
#if LLVM_VERSION_MAJOR < 4
|
||||
const char * getPassName() const override {
|
||||
#else
|
||||
StringRef getPassName() const override {
|
||||
#endif
|
||||
return "transforms compare functions";
|
||||
}
|
||||
private:
|
||||
bool transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp
|
||||
,const bool processStrncmp, const bool processStrcasecmp, const bool processStrncasecmp);
|
||||
};
|
||||
}
|
||||
const char *getPassName() const override {
|
||||
|
||||
#else
|
||||
StringRef getPassName() const override {
|
||||
|
||||
#endif
|
||||
return "transforms compare functions";
|
||||
|
||||
}
|
||||
|
||||
private:
|
||||
bool transformCmps(Module &M, const bool processStrcmp,
|
||||
const bool processMemcmp, const bool processStrncmp,
|
||||
const bool processStrcasecmp,
|
||||
const bool processStrncasecmp);
|
||||
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
char CompareTransform::ID = 0;
|
||||
|
||||
bool CompareTransform::transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp
|
||||
, const bool processStrncmp, const bool processStrcasecmp, const bool processStrncasecmp) {
|
||||
bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
|
||||
const bool processMemcmp,
|
||||
const bool processStrncmp,
|
||||
const bool processStrcasecmp,
|
||||
const bool processStrncasecmp) {
|
||||
|
||||
std::vector<CallInst*> calls;
|
||||
LLVMContext &C = M.getContext();
|
||||
IntegerType *Int8Ty = IntegerType::getInt8Ty(C);
|
||||
IntegerType *Int32Ty = IntegerType::getInt32Ty(C);
|
||||
IntegerType *Int64Ty = IntegerType::getInt64Ty(C);
|
||||
std::vector<CallInst *> calls;
|
||||
LLVMContext & C = M.getContext();
|
||||
IntegerType * Int8Ty = IntegerType::getInt8Ty(C);
|
||||
IntegerType * Int32Ty = IntegerType::getInt32Ty(C);
|
||||
IntegerType * Int64Ty = IntegerType::getInt64Ty(C);
|
||||
|
||||
#if LLVM_VERSION_MAJOR < 9
|
||||
Constant*
|
||||
Constant *
|
||||
#else
|
||||
FunctionCallee
|
||||
#endif
|
||||
c = M.getOrInsertFunction("tolower",
|
||||
Int32Ty,
|
||||
Int32Ty
|
||||
c = M.getOrInsertFunction("tolower", Int32Ty, Int32Ty
|
||||
#if LLVM_VERSION_MAJOR < 5
|
||||
, nullptr
|
||||
,
|
||||
nullptr
|
||||
#endif
|
||||
);
|
||||
);
|
||||
#if LLVM_VERSION_MAJOR < 9
|
||||
Function* tolowerFn = cast<Function>(c);
|
||||
Function *tolowerFn = cast<Function>(c);
|
||||
#else
|
||||
FunctionCallee tolowerFn = c;
|
||||
#endif
|
||||
|
||||
/* iterate over all functions, bbs and instruction and add suitable calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp */
|
||||
/* iterate over all functions, bbs and instruction and add suitable calls to
|
||||
* strcmp/memcmp/strncmp/strcasecmp/strncasecmp */
|
||||
for (auto &F : M) {
|
||||
|
||||
for (auto &BB : F) {
|
||||
for(auto &IN: BB) {
|
||||
CallInst* callInst = nullptr;
|
||||
|
||||
for (auto &IN : BB) {
|
||||
|
||||
CallInst *callInst = nullptr;
|
||||
|
||||
if ((callInst = dyn_cast<CallInst>(&IN))) {
|
||||
|
||||
bool isStrcmp = processStrcmp;
|
||||
bool isMemcmp = processMemcmp;
|
||||
bool isStrncmp = processStrncmp;
|
||||
bool isStrcasecmp = processStrcasecmp;
|
||||
bool isStrcmp = processStrcmp;
|
||||
bool isMemcmp = processMemcmp;
|
||||
bool isStrncmp = processStrncmp;
|
||||
bool isStrcasecmp = processStrcasecmp;
|
||||
bool isStrncasecmp = processStrncasecmp;
|
||||
|
||||
Function *Callee = callInst->getCalledFunction();
|
||||
if (!Callee)
|
||||
continue;
|
||||
if (callInst->getCallingConv() != llvm::CallingConv::C)
|
||||
continue;
|
||||
if (!Callee) continue;
|
||||
if (callInst->getCallingConv() != llvm::CallingConv::C) continue;
|
||||
StringRef FuncName = Callee->getName();
|
||||
isStrcmp &= !FuncName.compare(StringRef("strcmp"));
|
||||
isMemcmp &= !FuncName.compare(StringRef("memcmp"));
|
||||
isStrncmp &= !FuncName.compare(StringRef("strncmp"));
|
||||
isStrcasecmp &= !FuncName.compare(StringRef("strcasecmp"));
|
||||
isStrcmp &= !FuncName.compare(StringRef("strcmp"));
|
||||
isMemcmp &= !FuncName.compare(StringRef("memcmp"));
|
||||
isStrncmp &= !FuncName.compare(StringRef("strncmp"));
|
||||
isStrcasecmp &= !FuncName.compare(StringRef("strcasecmp"));
|
||||
isStrncasecmp &= !FuncName.compare(StringRef("strncasecmp"));
|
||||
|
||||
if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && !isStrncasecmp)
|
||||
if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp &&
|
||||
!isStrncasecmp)
|
||||
continue;
|
||||
|
||||
/* Verify the strcmp/memcmp/strncmp/strcasecmp/strncasecmp function prototype */
|
||||
/* Verify the strcmp/memcmp/strncmp/strcasecmp/strncasecmp function
|
||||
* prototype */
|
||||
FunctionType *FT = Callee->getFunctionType();
|
||||
|
||||
|
||||
isStrcmp &= FT->getNumParams() == 2 &&
|
||||
FT->getReturnType()->isIntegerTy(32) &&
|
||||
FT->getParamType(0) == FT->getParamType(1) &&
|
||||
FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
|
||||
isStrcasecmp &= FT->getNumParams() == 2 &&
|
||||
FT->getReturnType()->isIntegerTy(32) &&
|
||||
FT->getParamType(0) == FT->getParamType(1) &&
|
||||
FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
|
||||
isMemcmp &= FT->getNumParams() == 3 &&
|
||||
isStrcmp &=
|
||||
FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) &&
|
||||
FT->getParamType(0) == FT->getParamType(1) &&
|
||||
FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
|
||||
isStrcasecmp &=
|
||||
FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) &&
|
||||
FT->getParamType(0) == FT->getParamType(1) &&
|
||||
FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
|
||||
isMemcmp &= FT->getNumParams() == 3 &&
|
||||
FT->getReturnType()->isIntegerTy(32) &&
|
||||
FT->getParamType(0)->isPointerTy() &&
|
||||
FT->getParamType(1)->isPointerTy() &&
|
||||
FT->getParamType(2)->isIntegerTy();
|
||||
isStrncmp &= FT->getNumParams() == 3 &&
|
||||
FT->getReturnType()->isIntegerTy(32) &&
|
||||
FT->getParamType(0) == FT->getParamType(1) &&
|
||||
FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()) &&
|
||||
FT->getParamType(2)->isIntegerTy();
|
||||
isStrncmp &= FT->getNumParams() == 3 &&
|
||||
FT->getReturnType()->isIntegerTy(32) &&
|
||||
FT->getParamType(0) == FT->getParamType(1) &&
|
||||
FT->getParamType(0) ==
|
||||
IntegerType::getInt8PtrTy(M.getContext()) &&
|
||||
FT->getParamType(2)->isIntegerTy();
|
||||
isStrncasecmp &= FT->getNumParams() == 3 &&
|
||||
FT->getReturnType()->isIntegerTy(32) &&
|
||||
FT->getParamType(0) == FT->getParamType(1) &&
|
||||
FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()) &&
|
||||
FT->getParamType(2)->isIntegerTy();
|
||||
FT->getReturnType()->isIntegerTy(32) &&
|
||||
FT->getParamType(0) == FT->getParamType(1) &&
|
||||
FT->getParamType(0) ==
|
||||
IntegerType::getInt8PtrTy(M.getContext()) &&
|
||||
FT->getParamType(2)->isIntegerTy();
|
||||
|
||||
if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && !isStrncasecmp)
|
||||
if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp &&
|
||||
!isStrncasecmp)
|
||||
continue;
|
||||
|
||||
/* is a str{n,}{case,}cmp/memcmp, check if we have
|
||||
* str{case,}cmp(x, "const") or str{case,}cmp("const", x)
|
||||
* strn{case,}cmp(x, "const", ..) or strn{case,}cmp("const", x, ..)
|
||||
* memcmp(x, "const", ..) or memcmp("const", x, ..) */
|
||||
Value *Str1P = callInst->getArgOperand(0), *Str2P = callInst->getArgOperand(1);
|
||||
Value *Str1P = callInst->getArgOperand(0),
|
||||
*Str2P = callInst->getArgOperand(1);
|
||||
StringRef Str1, Str2;
|
||||
bool HasStr1 = getConstantStringInfo(Str1P, Str1);
|
||||
bool HasStr2 = getConstantStringInfo(Str2P, Str2);
|
||||
bool HasStr1 = getConstantStringInfo(Str1P, Str1);
|
||||
bool HasStr2 = getConstantStringInfo(Str2P, Str2);
|
||||
|
||||
/* handle cases of one string is const, one string is variable */
|
||||
if (!(HasStr1 ^ HasStr2))
|
||||
continue;
|
||||
if (!(HasStr1 ^ HasStr2)) continue;
|
||||
|
||||
if (isMemcmp || isStrncmp || isStrncasecmp) {
|
||||
|
||||
/* check if third operand is a constant integer
|
||||
* strlen("constStr") and sizeof() are treated as constant */
|
||||
Value *op2 = callInst->getArgOperand(2);
|
||||
ConstantInt* ilen = dyn_cast<ConstantInt>(op2);
|
||||
if (!ilen)
|
||||
continue;
|
||||
/* final precaution: if size of compare is larger than constant string skip it*/
|
||||
uint64_t literalLength = HasStr1 ? GetStringLength(Str1P) : GetStringLength(Str2P);
|
||||
if (literalLength < ilen->getZExtValue())
|
||||
continue;
|
||||
Value * op2 = callInst->getArgOperand(2);
|
||||
ConstantInt *ilen = dyn_cast<ConstantInt>(op2);
|
||||
if (!ilen) continue;
|
||||
/* final precaution: if size of compare is larger than constant
|
||||
* string skip it*/
|
||||
uint64_t literalLength =
|
||||
HasStr1 ? GetStringLength(Str1P) : GetStringLength(Str2P);
|
||||
if (literalLength < ilen->getZExtValue()) continue;
|
||||
|
||||
}
|
||||
|
||||
calls.push_back(callInst);
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if (!calls.size())
|
||||
return false;
|
||||
errs() << "Replacing " << calls.size() << " calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp\n";
|
||||
if (!calls.size()) return false;
|
||||
errs() << "Replacing " << calls.size()
|
||||
<< " calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp\n";
|
||||
|
||||
for (auto &callInst: calls) {
|
||||
for (auto &callInst : calls) {
|
||||
|
||||
Value *Str1P = callInst->getArgOperand(0), *Str2P = callInst->getArgOperand(1);
|
||||
StringRef Str1, Str2, ConstStr;
|
||||
Value *Str1P = callInst->getArgOperand(0),
|
||||
*Str2P = callInst->getArgOperand(1);
|
||||
StringRef Str1, Str2, ConstStr;
|
||||
std::string TmpConstStr;
|
||||
Value *VarStr;
|
||||
bool HasStr1 = getConstantStringInfo(Str1P, Str1);
|
||||
Value * VarStr;
|
||||
bool HasStr1 = getConstantStringInfo(Str1P, Str1);
|
||||
getConstantStringInfo(Str2P, Str2);
|
||||
uint64_t constLen, sizedLen;
|
||||
bool isMemcmp = !callInst->getCalledFunction()->getName().compare(StringRef("memcmp"));
|
||||
bool isSizedcmp = isMemcmp
|
||||
|| !callInst->getCalledFunction()->getName().compare(StringRef("strncmp"))
|
||||
|| !callInst->getCalledFunction()->getName().compare(StringRef("strncasecmp"));
|
||||
bool isCaseInsensitive = !callInst->getCalledFunction()->getName().compare(StringRef("strcasecmp"))
|
||||
|| !callInst->getCalledFunction()->getName().compare(StringRef("strncasecmp"));
|
||||
bool isMemcmp =
|
||||
!callInst->getCalledFunction()->getName().compare(StringRef("memcmp"));
|
||||
bool isSizedcmp = isMemcmp ||
|
||||
!callInst->getCalledFunction()->getName().compare(
|
||||
StringRef("strncmp")) ||
|
||||
!callInst->getCalledFunction()->getName().compare(
|
||||
StringRef("strncasecmp"));
|
||||
bool isCaseInsensitive = !callInst->getCalledFunction()->getName().compare(
|
||||
StringRef("strcasecmp")) ||
|
||||
!callInst->getCalledFunction()->getName().compare(
|
||||
StringRef("strncasecmp"));
|
||||
|
||||
if (isSizedcmp) {
|
||||
Value *op2 = callInst->getArgOperand(2);
|
||||
ConstantInt* ilen = dyn_cast<ConstantInt>(op2);
|
||||
|
||||
Value * op2 = callInst->getArgOperand(2);
|
||||
ConstantInt *ilen = dyn_cast<ConstantInt>(op2);
|
||||
sizedLen = ilen->getZExtValue();
|
||||
|
||||
}
|
||||
|
||||
if (HasStr1) {
|
||||
|
||||
TmpConstStr = Str1.str();
|
||||
VarStr = Str2P;
|
||||
constLen = isMemcmp ? sizedLen : GetStringLength(Str1P);
|
||||
}
|
||||
else {
|
||||
|
||||
} else {
|
||||
|
||||
TmpConstStr = Str2.str();
|
||||
VarStr = Str1P;
|
||||
constLen = isMemcmp ? sizedLen : GetStringLength(Str2P);
|
||||
|
||||
}
|
||||
|
||||
/* properly handle zero terminated C strings by adding the terminating 0 to
|
||||
* the StringRef (in comparison to std::string a StringRef has built-in
|
||||
* runtime bounds checking, which makes debugging easier) */
|
||||
TmpConstStr.append("\0", 1); ConstStr = StringRef(TmpConstStr);
|
||||
TmpConstStr.append("\0", 1);
|
||||
ConstStr = StringRef(TmpConstStr);
|
||||
|
||||
if (isSizedcmp && constLen > sizedLen) {
|
||||
constLen = sizedLen;
|
||||
}
|
||||
if (isSizedcmp && constLen > sizedLen) { constLen = sizedLen; }
|
||||
|
||||
errs() << callInst->getCalledFunction()->getName() << ": len " << constLen << ": " << ConstStr << "\n";
|
||||
errs() << callInst->getCalledFunction()->getName() << ": len " << constLen
|
||||
<< ": " << ConstStr << "\n";
|
||||
|
||||
/* split before the call instruction */
|
||||
BasicBlock *bb = callInst->getParent();
|
||||
BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(callInst));
|
||||
BasicBlock *next_bb = BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
|
||||
BasicBlock *next_bb =
|
||||
BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
|
||||
BranchInst::Create(end_bb, next_bb);
|
||||
PHINode *PN = PHINode::Create(Int32Ty, constLen + 1, "cmp_phi");
|
||||
|
||||
@ -249,71 +283,81 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp, const
|
||||
|
||||
char c = isCaseInsensitive ? tolower(ConstStr[i]) : ConstStr[i];
|
||||
|
||||
|
||||
BasicBlock::iterator IP = next_bb->getFirstInsertionPt();
|
||||
IRBuilder<> IRB(&*IP);
|
||||
IRBuilder<> IRB(&*IP);
|
||||
|
||||
Value* v = ConstantInt::get(Int64Ty, i);
|
||||
Value *ele = IRB.CreateInBoundsGEP(VarStr, v, "empty");
|
||||
Value *v = ConstantInt::get(Int64Ty, i);
|
||||
Value *ele = IRB.CreateInBoundsGEP(VarStr, v, "empty");
|
||||
Value *load = IRB.CreateLoad(ele);
|
||||
if (isCaseInsensitive) {
|
||||
|
||||
// load >= 'A' && load <= 'Z' ? load | 0x020 : load
|
||||
std::vector<Value *> args;
|
||||
args.push_back(load);
|
||||
load = IRB.CreateCall(tolowerFn, args, "tmp");
|
||||
load = IRB.CreateTrunc(load, Int8Ty);
|
||||
|
||||
}
|
||||
|
||||
Value *isub;
|
||||
if (HasStr1)
|
||||
isub = IRB.CreateSub(ConstantInt::get(Int8Ty, c), load);
|
||||
else
|
||||
isub = IRB.CreateSub(load, ConstantInt::get(Int8Ty, c));
|
||||
|
||||
Value *sext = IRB.CreateSExt(isub, Int32Ty);
|
||||
Value *sext = IRB.CreateSExt(isub, Int32Ty);
|
||||
PN->addIncoming(sext, cur_bb);
|
||||
|
||||
|
||||
if (i < constLen - 1) {
|
||||
next_bb = BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
|
||||
|
||||
next_bb =
|
||||
BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
|
||||
BranchInst::Create(end_bb, next_bb);
|
||||
|
||||
Value *icmp = IRB.CreateICmpEQ(isub, ConstantInt::get(Int8Ty, 0));
|
||||
IRB.CreateCondBr(icmp, next_bb, end_bb);
|
||||
cur_bb->getTerminator()->eraseFromParent();
|
||||
|
||||
} else {
|
||||
//IRB.CreateBr(end_bb);
|
||||
|
||||
// IRB.CreateBr(end_bb);
|
||||
|
||||
}
|
||||
|
||||
//add offset to varstr
|
||||
//create load
|
||||
//create signed isub
|
||||
//create icmp
|
||||
//create jcc
|
||||
//create next_bb
|
||||
// add offset to varstr
|
||||
// create load
|
||||
// create signed isub
|
||||
// create icmp
|
||||
// create jcc
|
||||
// create next_bb
|
||||
|
||||
}
|
||||
|
||||
/* since the call is the first instruction of the bb it is safe to
|
||||
* replace it with a phi instruction */
|
||||
BasicBlock::iterator ii(callInst);
|
||||
ReplaceInstWithInst(callInst->getParent()->getInstList(), ii, PN);
|
||||
|
||||
}
|
||||
|
||||
|
||||
return true;
|
||||
|
||||
}
|
||||
|
||||
bool CompareTransform::runOnModule(Module &M) {
|
||||
|
||||
if (getenv("AFL_QUIET") == NULL)
|
||||
llvm::errs() << "Running compare-transform-pass by laf.intel@gmail.com, extended by heiko@hexco.de\n";
|
||||
llvm::errs() << "Running compare-transform-pass by laf.intel@gmail.com, "
|
||||
"extended by heiko@hexco.de\n";
|
||||
transformCmps(M, true, true, true, true, true);
|
||||
verifyModule(M);
|
||||
|
||||
return true;
|
||||
|
||||
}
|
||||
|
||||
static void registerCompTransPass(const PassManagerBuilder &,
|
||||
legacy::PassManagerBase &PM) {
|
||||
legacy::PassManagerBase &PM) {
|
||||
|
||||
auto p = new CompareTransform();
|
||||
PM.add(p);
|
||||
|
Reference in New Issue
Block a user