run code formatter

This commit is contained in:
Andrea Fioraldi
2019-09-02 18:49:43 +02:00
parent 2ae4ca91b4
commit b24639d011
57 changed files with 8674 additions and 7125 deletions

View File

@ -36,202 +36,236 @@ using namespace llvm;
namespace {
class CompareTransform : public ModulePass {
class CompareTransform : public ModulePass {
public:
static char ID;
CompareTransform() : ModulePass(ID) {
}
public:
static char ID;
CompareTransform() : ModulePass(ID) {
bool runOnModule(Module &M) override;
}
bool runOnModule(Module &M) override;
#if LLVM_VERSION_MAJOR < 4
const char * getPassName() const override {
#else
StringRef getPassName() const override {
#endif
return "transforms compare functions";
}
private:
bool transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp
,const bool processStrncmp, const bool processStrcasecmp, const bool processStrncasecmp);
};
}
const char *getPassName() const override {
#else
StringRef getPassName() const override {
#endif
return "transforms compare functions";
}
private:
bool transformCmps(Module &M, const bool processStrcmp,
const bool processMemcmp, const bool processStrncmp,
const bool processStrcasecmp,
const bool processStrncasecmp);
};
} // namespace
char CompareTransform::ID = 0;
bool CompareTransform::transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp
, const bool processStrncmp, const bool processStrcasecmp, const bool processStrncasecmp) {
bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
const bool processMemcmp,
const bool processStrncmp,
const bool processStrcasecmp,
const bool processStrncasecmp) {
std::vector<CallInst*> calls;
LLVMContext &C = M.getContext();
IntegerType *Int8Ty = IntegerType::getInt8Ty(C);
IntegerType *Int32Ty = IntegerType::getInt32Ty(C);
IntegerType *Int64Ty = IntegerType::getInt64Ty(C);
std::vector<CallInst *> calls;
LLVMContext & C = M.getContext();
IntegerType * Int8Ty = IntegerType::getInt8Ty(C);
IntegerType * Int32Ty = IntegerType::getInt32Ty(C);
IntegerType * Int64Ty = IntegerType::getInt64Ty(C);
#if LLVM_VERSION_MAJOR < 9
Constant*
Constant *
#else
FunctionCallee
#endif
c = M.getOrInsertFunction("tolower",
Int32Ty,
Int32Ty
c = M.getOrInsertFunction("tolower", Int32Ty, Int32Ty
#if LLVM_VERSION_MAJOR < 5
, nullptr
,
nullptr
#endif
);
);
#if LLVM_VERSION_MAJOR < 9
Function* tolowerFn = cast<Function>(c);
Function *tolowerFn = cast<Function>(c);
#else
FunctionCallee tolowerFn = c;
#endif
/* iterate over all functions, bbs and instruction and add suitable calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp */
/* iterate over all functions, bbs and instruction and add suitable calls to
* strcmp/memcmp/strncmp/strcasecmp/strncasecmp */
for (auto &F : M) {
for (auto &BB : F) {
for(auto &IN: BB) {
CallInst* callInst = nullptr;
for (auto &IN : BB) {
CallInst *callInst = nullptr;
if ((callInst = dyn_cast<CallInst>(&IN))) {
bool isStrcmp = processStrcmp;
bool isMemcmp = processMemcmp;
bool isStrncmp = processStrncmp;
bool isStrcasecmp = processStrcasecmp;
bool isStrcmp = processStrcmp;
bool isMemcmp = processMemcmp;
bool isStrncmp = processStrncmp;
bool isStrcasecmp = processStrcasecmp;
bool isStrncasecmp = processStrncasecmp;
Function *Callee = callInst->getCalledFunction();
if (!Callee)
continue;
if (callInst->getCallingConv() != llvm::CallingConv::C)
continue;
if (!Callee) continue;
if (callInst->getCallingConv() != llvm::CallingConv::C) continue;
StringRef FuncName = Callee->getName();
isStrcmp &= !FuncName.compare(StringRef("strcmp"));
isMemcmp &= !FuncName.compare(StringRef("memcmp"));
isStrncmp &= !FuncName.compare(StringRef("strncmp"));
isStrcasecmp &= !FuncName.compare(StringRef("strcasecmp"));
isStrcmp &= !FuncName.compare(StringRef("strcmp"));
isMemcmp &= !FuncName.compare(StringRef("memcmp"));
isStrncmp &= !FuncName.compare(StringRef("strncmp"));
isStrcasecmp &= !FuncName.compare(StringRef("strcasecmp"));
isStrncasecmp &= !FuncName.compare(StringRef("strncasecmp"));
if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && !isStrncasecmp)
if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp &&
!isStrncasecmp)
continue;
/* Verify the strcmp/memcmp/strncmp/strcasecmp/strncasecmp function prototype */
/* Verify the strcmp/memcmp/strncmp/strcasecmp/strncasecmp function
* prototype */
FunctionType *FT = Callee->getFunctionType();
isStrcmp &= FT->getNumParams() == 2 &&
FT->getReturnType()->isIntegerTy(32) &&
FT->getParamType(0) == FT->getParamType(1) &&
FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
isStrcasecmp &= FT->getNumParams() == 2 &&
FT->getReturnType()->isIntegerTy(32) &&
FT->getParamType(0) == FT->getParamType(1) &&
FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
isMemcmp &= FT->getNumParams() == 3 &&
isStrcmp &=
FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) &&
FT->getParamType(0) == FT->getParamType(1) &&
FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
isStrcasecmp &=
FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) &&
FT->getParamType(0) == FT->getParamType(1) &&
FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
isMemcmp &= FT->getNumParams() == 3 &&
FT->getReturnType()->isIntegerTy(32) &&
FT->getParamType(0)->isPointerTy() &&
FT->getParamType(1)->isPointerTy() &&
FT->getParamType(2)->isIntegerTy();
isStrncmp &= FT->getNumParams() == 3 &&
FT->getReturnType()->isIntegerTy(32) &&
FT->getParamType(0) == FT->getParamType(1) &&
FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()) &&
FT->getParamType(2)->isIntegerTy();
isStrncmp &= FT->getNumParams() == 3 &&
FT->getReturnType()->isIntegerTy(32) &&
FT->getParamType(0) == FT->getParamType(1) &&
FT->getParamType(0) ==
IntegerType::getInt8PtrTy(M.getContext()) &&
FT->getParamType(2)->isIntegerTy();
isStrncasecmp &= FT->getNumParams() == 3 &&
FT->getReturnType()->isIntegerTy(32) &&
FT->getParamType(0) == FT->getParamType(1) &&
FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()) &&
FT->getParamType(2)->isIntegerTy();
FT->getReturnType()->isIntegerTy(32) &&
FT->getParamType(0) == FT->getParamType(1) &&
FT->getParamType(0) ==
IntegerType::getInt8PtrTy(M.getContext()) &&
FT->getParamType(2)->isIntegerTy();
if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && !isStrncasecmp)
if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp &&
!isStrncasecmp)
continue;
/* is a str{n,}{case,}cmp/memcmp, check if we have
* str{case,}cmp(x, "const") or str{case,}cmp("const", x)
* strn{case,}cmp(x, "const", ..) or strn{case,}cmp("const", x, ..)
* memcmp(x, "const", ..) or memcmp("const", x, ..) */
Value *Str1P = callInst->getArgOperand(0), *Str2P = callInst->getArgOperand(1);
Value *Str1P = callInst->getArgOperand(0),
*Str2P = callInst->getArgOperand(1);
StringRef Str1, Str2;
bool HasStr1 = getConstantStringInfo(Str1P, Str1);
bool HasStr2 = getConstantStringInfo(Str2P, Str2);
bool HasStr1 = getConstantStringInfo(Str1P, Str1);
bool HasStr2 = getConstantStringInfo(Str2P, Str2);
/* handle cases of one string is const, one string is variable */
if (!(HasStr1 ^ HasStr2))
continue;
if (!(HasStr1 ^ HasStr2)) continue;
if (isMemcmp || isStrncmp || isStrncasecmp) {
/* check if third operand is a constant integer
* strlen("constStr") and sizeof() are treated as constant */
Value *op2 = callInst->getArgOperand(2);
ConstantInt* ilen = dyn_cast<ConstantInt>(op2);
if (!ilen)
continue;
/* final precaution: if size of compare is larger than constant string skip it*/
uint64_t literalLength = HasStr1 ? GetStringLength(Str1P) : GetStringLength(Str2P);
if (literalLength < ilen->getZExtValue())
continue;
Value * op2 = callInst->getArgOperand(2);
ConstantInt *ilen = dyn_cast<ConstantInt>(op2);
if (!ilen) continue;
/* final precaution: if size of compare is larger than constant
* string skip it*/
uint64_t literalLength =
HasStr1 ? GetStringLength(Str1P) : GetStringLength(Str2P);
if (literalLength < ilen->getZExtValue()) continue;
}
calls.push_back(callInst);
}
}
}
}
if (!calls.size())
return false;
errs() << "Replacing " << calls.size() << " calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp\n";
if (!calls.size()) return false;
errs() << "Replacing " << calls.size()
<< " calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp\n";
for (auto &callInst: calls) {
for (auto &callInst : calls) {
Value *Str1P = callInst->getArgOperand(0), *Str2P = callInst->getArgOperand(1);
StringRef Str1, Str2, ConstStr;
Value *Str1P = callInst->getArgOperand(0),
*Str2P = callInst->getArgOperand(1);
StringRef Str1, Str2, ConstStr;
std::string TmpConstStr;
Value *VarStr;
bool HasStr1 = getConstantStringInfo(Str1P, Str1);
Value * VarStr;
bool HasStr1 = getConstantStringInfo(Str1P, Str1);
getConstantStringInfo(Str2P, Str2);
uint64_t constLen, sizedLen;
bool isMemcmp = !callInst->getCalledFunction()->getName().compare(StringRef("memcmp"));
bool isSizedcmp = isMemcmp
|| !callInst->getCalledFunction()->getName().compare(StringRef("strncmp"))
|| !callInst->getCalledFunction()->getName().compare(StringRef("strncasecmp"));
bool isCaseInsensitive = !callInst->getCalledFunction()->getName().compare(StringRef("strcasecmp"))
|| !callInst->getCalledFunction()->getName().compare(StringRef("strncasecmp"));
bool isMemcmp =
!callInst->getCalledFunction()->getName().compare(StringRef("memcmp"));
bool isSizedcmp = isMemcmp ||
!callInst->getCalledFunction()->getName().compare(
StringRef("strncmp")) ||
!callInst->getCalledFunction()->getName().compare(
StringRef("strncasecmp"));
bool isCaseInsensitive = !callInst->getCalledFunction()->getName().compare(
StringRef("strcasecmp")) ||
!callInst->getCalledFunction()->getName().compare(
StringRef("strncasecmp"));
if (isSizedcmp) {
Value *op2 = callInst->getArgOperand(2);
ConstantInt* ilen = dyn_cast<ConstantInt>(op2);
Value * op2 = callInst->getArgOperand(2);
ConstantInt *ilen = dyn_cast<ConstantInt>(op2);
sizedLen = ilen->getZExtValue();
}
if (HasStr1) {
TmpConstStr = Str1.str();
VarStr = Str2P;
constLen = isMemcmp ? sizedLen : GetStringLength(Str1P);
}
else {
} else {
TmpConstStr = Str2.str();
VarStr = Str1P;
constLen = isMemcmp ? sizedLen : GetStringLength(Str2P);
}
/* properly handle zero terminated C strings by adding the terminating 0 to
* the StringRef (in comparison to std::string a StringRef has built-in
* runtime bounds checking, which makes debugging easier) */
TmpConstStr.append("\0", 1); ConstStr = StringRef(TmpConstStr);
TmpConstStr.append("\0", 1);
ConstStr = StringRef(TmpConstStr);
if (isSizedcmp && constLen > sizedLen) {
constLen = sizedLen;
}
if (isSizedcmp && constLen > sizedLen) { constLen = sizedLen; }
errs() << callInst->getCalledFunction()->getName() << ": len " << constLen << ": " << ConstStr << "\n";
errs() << callInst->getCalledFunction()->getName() << ": len " << constLen
<< ": " << ConstStr << "\n";
/* split before the call instruction */
BasicBlock *bb = callInst->getParent();
BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(callInst));
BasicBlock *next_bb = BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
BasicBlock *next_bb =
BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
BranchInst::Create(end_bb, next_bb);
PHINode *PN = PHINode::Create(Int32Ty, constLen + 1, "cmp_phi");
@ -249,71 +283,81 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp, const
char c = isCaseInsensitive ? tolower(ConstStr[i]) : ConstStr[i];
BasicBlock::iterator IP = next_bb->getFirstInsertionPt();
IRBuilder<> IRB(&*IP);
IRBuilder<> IRB(&*IP);
Value* v = ConstantInt::get(Int64Ty, i);
Value *ele = IRB.CreateInBoundsGEP(VarStr, v, "empty");
Value *v = ConstantInt::get(Int64Ty, i);
Value *ele = IRB.CreateInBoundsGEP(VarStr, v, "empty");
Value *load = IRB.CreateLoad(ele);
if (isCaseInsensitive) {
// load >= 'A' && load <= 'Z' ? load | 0x020 : load
std::vector<Value *> args;
args.push_back(load);
load = IRB.CreateCall(tolowerFn, args, "tmp");
load = IRB.CreateTrunc(load, Int8Ty);
}
Value *isub;
if (HasStr1)
isub = IRB.CreateSub(ConstantInt::get(Int8Ty, c), load);
else
isub = IRB.CreateSub(load, ConstantInt::get(Int8Ty, c));
Value *sext = IRB.CreateSExt(isub, Int32Ty);
Value *sext = IRB.CreateSExt(isub, Int32Ty);
PN->addIncoming(sext, cur_bb);
if (i < constLen - 1) {
next_bb = BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
next_bb =
BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
BranchInst::Create(end_bb, next_bb);
Value *icmp = IRB.CreateICmpEQ(isub, ConstantInt::get(Int8Ty, 0));
IRB.CreateCondBr(icmp, next_bb, end_bb);
cur_bb->getTerminator()->eraseFromParent();
} else {
//IRB.CreateBr(end_bb);
// IRB.CreateBr(end_bb);
}
//add offset to varstr
//create load
//create signed isub
//create icmp
//create jcc
//create next_bb
// add offset to varstr
// create load
// create signed isub
// create icmp
// create jcc
// create next_bb
}
/* since the call is the first instruction of the bb it is safe to
* replace it with a phi instruction */
BasicBlock::iterator ii(callInst);
ReplaceInstWithInst(callInst->getParent()->getInstList(), ii, PN);
}
return true;
}
bool CompareTransform::runOnModule(Module &M) {
if (getenv("AFL_QUIET") == NULL)
llvm::errs() << "Running compare-transform-pass by laf.intel@gmail.com, extended by heiko@hexco.de\n";
llvm::errs() << "Running compare-transform-pass by laf.intel@gmail.com, "
"extended by heiko@hexco.de\n";
transformCmps(M, true, true, true, true, true);
verifyModule(M);
return true;
}
static void registerCompTransPass(const PassManagerBuilder &,
legacy::PassManagerBase &PM) {
legacy::PassManagerBase &PM) {
auto p = new CompareTransform();
PM.add(p);