mirror of
https://github.com/AFLplusplus/AFLplusplus.git
synced 2025-06-08 16:21:32 +00:00
329 lines
12 KiB
C++
329 lines
12 KiB
C++
/*
 * Copyright 2016 laf-intel
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
#include <stdio.h>
|
|
#include <stdlib.h>
|
|
#include <unistd.h>
|
|
|
|
#include "llvm/ADT/Statistic.h"
|
|
#include "llvm/IR/IRBuilder.h"
|
|
#include "llvm/IR/LegacyPassManager.h"
|
|
#include "llvm/IR/Module.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
|
|
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
|
|
#include "llvm/IR/Verifier.h"
|
|
#include "llvm/Pass.h"
|
|
#include "llvm/Analysis/ValueTracking.h"
|
|
|
|
#include <set>
|
|
|
|
using namespace llvm;
|
|
|
|
namespace {
|
|
|
|
class CompareTransform : public ModulePass {
|
|
|
|
public:
|
|
static char ID;
|
|
CompareTransform() : ModulePass(ID) {
|
|
}
|
|
|
|
bool runOnModule(Module &M) override;
|
|
|
|
#if LLVM_VERSION_MAJOR < 4
|
|
const char * getPassName() const override {
|
|
#else
|
|
StringRef getPassName() const override {
|
|
#endif
|
|
return "transforms compare functions";
|
|
}
|
|
private:
|
|
bool transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp
|
|
,const bool processStrncmp, const bool processStrcasecmp, const bool processStrncasecmp);
|
|
};
|
|
}
|
|
|
|
|
|
/* Out-of-class definition of the static pass identifier that the
 * CompareTransform constructor passes to ModulePass.
 * NOTE(review): LLVM's legacy pass framework is documented to key on the
 * address of this variable rather than its value - confirm against the
 * LLVM pass-writing documentation. */
char CompareTransform::ID = 0;
|
|
|
|
/* Rewrites calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp in which
 * exactly one of the two compared buffers is a constant string into a chain
 * of single-byte comparisons, one basic block per byte, whose results are
 * merged by a PHI node that replaces the original call.
 *
 * For the sized variants (memcmp/strncmp/strncasecmp) the length argument
 * must be a non-zero constant that does not exceed the literal's length
 * (as reported by GetStringLength); other calls are left untouched.
 *
 * M                    module to transform
 * processStrcmp ...    each process* flag enables handling of the function
 *                      it is named after
 *
 * Returns true if at least one call was selected for replacement, false
 * otherwise. Note that the "tolower" declaration is requested from the
 * module unconditionally, even when no call ends up being replaced. */
bool CompareTransform::transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp
  , const bool processStrncmp, const bool processStrcasecmp, const bool processStrncasecmp) {

  std::vector<CallInst *> calls;
  LLVMContext &C = M.getContext();
  IntegerType *Int8Ty = IntegerType::getInt8Ty(C);
  IntegerType *Int32Ty = IntegerType::getInt32Ty(C);
  IntegerType *Int64Ty = IntegerType::getInt64Ty(C);
  Type *Int8PtrTy = IntegerType::getInt8PtrTy(C);

  /* Declaration of "int tolower(int)" used by the case-insensitive variants.
   * From LLVM 9 on getOrInsertFunction() returns a FunctionCallee, which is
   * not a Value and therefore cannot be passed to cast<Function>(); keep it
   * as a FunctionCallee there. Before LLVM 5 the variadic overload needs a
   * terminating nullptr. */
#if LLVM_VERSION_MAJOR >= 9
  FunctionCallee tolowerFn = M.getOrInsertFunction("tolower", Int32Ty, Int32Ty);
#elif LLVM_VERSION_MAJOR >= 5
  Function *tolowerFn = cast<Function>(M.getOrInsertFunction("tolower", Int32Ty, Int32Ty));
#else
  Function *tolowerFn = cast<Function>(M.getOrInsertFunction("tolower", Int32Ty, Int32Ty, nullptr));
#endif

  /* iterate over all functions, bbs and instructions and collect suitable
   * calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp */
  for (auto &F : M) {
    for (auto &BB : F) {
      for (auto &IN : BB) {

        CallInst *callInst = dyn_cast<CallInst>(&IN);
        if (!callInst)
          continue;

        /* indirect calls have no statically known callee */
        Function *Callee = callInst->getCalledFunction();
        if (!Callee)
          continue;
        if (callInst->getCallingConv() != llvm::CallingConv::C)
          continue;

        StringRef FuncName = Callee->getName();
        bool isStrcmp      = processStrcmp      && FuncName == "strcmp";
        bool isMemcmp      = processMemcmp      && FuncName == "memcmp";
        bool isStrncmp     = processStrncmp     && FuncName == "strncmp";
        bool isStrcasecmp  = processStrcasecmp  && FuncName == "strcasecmp";
        bool isStrncasecmp = processStrncasecmp && FuncName == "strncasecmp";

        if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && !isStrncasecmp)
          continue;

        /* Verify the strcmp/memcmp/strncmp/strcasecmp/strncasecmp function
         * prototype. Parameter types are only queried after the parameter
         * count has been checked. */
        FunctionType *FT = Callee->getFunctionType();
        const unsigned numParams = FT->getNumParams();
        const bool returnsInt32 = FT->getReturnType()->isIntegerTy(32);
        const bool charPtrArgs = numParams >= 2 &&
                                 FT->getParamType(0) == FT->getParamType(1) &&
                                 FT->getParamType(0) == Int8PtrTy;
        const bool intLenArg = numParams == 3 &&
                               FT->getParamType(2)->isIntegerTy();

        isStrcmp      &= numParams == 2 && returnsInt32 && charPtrArgs;
        isStrcasecmp  &= numParams == 2 && returnsInt32 && charPtrArgs;
        isMemcmp      &= returnsInt32 && intLenArg &&
                         FT->getParamType(0)->isPointerTy() &&
                         FT->getParamType(1)->isPointerTy();
        isStrncmp     &= returnsInt32 && intLenArg && charPtrArgs;
        isStrncasecmp &= returnsInt32 && intLenArg && charPtrArgs;

        if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && !isStrncasecmp)
          continue;

        /* is a str{n,}{case,}cmp/memcmp, check if we have
         * str{case,}cmp(x, "const") or str{case,}cmp("const", x)
         * strn{case,}cmp(x, "const", ..) or strn{case,}cmp("const", x, ..)
         * memcmp(x, "const", ..) or memcmp("const", x, ..) */
        Value *Str1P = callInst->getArgOperand(0), *Str2P = callInst->getArgOperand(1);
        StringRef Str1, Str2;
        bool HasStr1 = getConstantStringInfo(Str1P, Str1);
        bool HasStr2 = getConstantStringInfo(Str2P, Str2);

        /* handle only the case where exactly one string is constant */
        if (!(HasStr1 ^ HasStr2))
          continue;

        if (isMemcmp || isStrncmp || isStrncasecmp) {
          /* check if third operand is a constant integer
           * strlen("constStr") and sizeof() are treated as constant */
          ConstantInt *ilen = dyn_cast<ConstantInt>(callInst->getArgOperand(2));
          if (!ilen)
            continue;
          const uint64_t cmpLen = ilen->getZExtValue();
          /* a zero-length compare has no byte to test; rewriting it would
           * produce a PHI node without any incoming value, so leave the
           * original call in place */
          if (cmpLen == 0)
            continue;
          /* final precaution: if size of compare is larger than constant
           * string skip it */
          uint64_t literalLength = HasStr1 ? GetStringLength(Str1P) : GetStringLength(Str2P);
          if (literalLength < cmpLen)
            continue;
        }

        calls.push_back(callInst);
      }
    }
  }

  if (!calls.size())
    return false;
  errs() << "Replacing " << calls.size() << " calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp\n";

  for (auto &callInst : calls) {

    Value *Str1P = callInst->getArgOperand(0), *Str2P = callInst->getArgOperand(1);
    StringRef Str1, Str2, ConstStr;
    std::string TmpConstStr;
    Value *VarStr;
    bool HasStr1 = getConstantStringInfo(Str1P, Str1);
    getConstantStringInfo(Str2P, Str2);
    uint64_t constLen, sizedLen = 0;

    StringRef FuncName = callInst->getCalledFunction()->getName();
    bool isMemcmp = FuncName == "memcmp";
    bool isSizedcmp = isMemcmp || FuncName == "strncmp" || FuncName == "strncasecmp";
    bool isCaseInsensitive = FuncName == "strcasecmp" || FuncName == "strncasecmp";

    if (isSizedcmp) {
      /* the collection loop above only admitted constant lengths */
      ConstantInt *ilen = cast<ConstantInt>(callInst->getArgOperand(2));
      sizedLen = ilen->getZExtValue();
    }

    if (HasStr1) {
      TmpConstStr = Str1.str();
      VarStr = Str2P;
      constLen = isMemcmp ? sizedLen : GetStringLength(Str1P);
    }
    else {
      TmpConstStr = Str2.str();
      VarStr = Str1P;
      constLen = isMemcmp ? sizedLen : GetStringLength(Str2P);
    }

    /* properly handle zero terminated C strings by adding the terminating 0 to
     * the StringRef (in comparison to std::string a StringRef has built-in
     * runtime bounds checking, which makes debugging easier) */
    TmpConstStr.append("\0", 1); ConstStr = StringRef(TmpConstStr);

    if (isSizedcmp && constLen > sizedLen) {
      constLen = sizedLen;
    }

    /* defensive: never start rewriting the CFG when there is no byte to
     * compare or when the length would index past the constant */
    if (constLen == 0 || constLen > ConstStr.size())
      continue;

    errs() << FuncName << ": len " << constLen << ": " << ConstStr << "\n";

    /* split before the call instruction: bb keeps everything before the
     * call, end_bb starts with the call */
    BasicBlock *bb = callInst->getParent();
    BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(callInst));
    BasicBlock *next_bb = BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
    BranchInst::Create(end_bb, next_bb);
    PHINode *PN = PHINode::Create(Int32Ty, constLen + 1, "cmp_phi");

    /* redirect bb from end_bb to the first comparison block */
#if LLVM_VERSION_MAJOR < 8
    TerminatorInst *term = bb->getTerminator();
#else
    Instruction *term = bb->getTerminator();
#endif
    BranchInst::Create(next_bb, bb);
    term->eraseFromParent();

    for (uint64_t i = 0; i < constLen; i++) {

      BasicBlock *cur_bb = next_bb;

      /* byte of the constant compared in this block; go through unsigned
       * char because passing a negative char to tolower() is undefined */
      unsigned char c = static_cast<unsigned char>(ConstStr[i]);
      if (isCaseInsensitive)
        c = static_cast<unsigned char>(tolower(c));

      /* insert in front of the block's current terminator */
      BasicBlock::iterator IP = next_bb->getFirstInsertionPt();
      IRBuilder<> IRB(&*IP);

      Value *v = ConstantInt::get(Int64Ty, i);
      Value *ele = IRB.CreateInBoundsGEP(VarStr, v, "empty");
      Value *load = IRB.CreateLoad(ele);
      if (isCaseInsensitive) {
        /* tolower() takes and returns an i32: widen the loaded byte for the
         * call and narrow the result again so that the subtraction below
         * operates on two i8 values */
        load = IRB.CreateZExt(load, Int32Ty);
        std::vector<Value *> args;
        args.push_back(load);
        load = IRB.CreateCall(tolowerFn, args, "tmp");
        load = IRB.CreateTrunc(load, Int8Ty);
      }

      /* keep the operand order of the original call */
      Value *constByte = ConstantInt::get(Int8Ty, c);
      Value *isub = HasStr1 ? IRB.CreateSub(constByte, load)
                            : IRB.CreateSub(load, constByte);

      Value *sext = IRB.CreateSExt(isub, Int32Ty);
      PN->addIncoming(sext, cur_bb);

      if (i < constLen - 1) {
        /* more bytes follow: continue with the next block while the bytes
         * are equal, otherwise leave towards end_bb */
        next_bb = BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
        BranchInst::Create(end_bb, next_bb);

#if LLVM_VERSION_MAJOR < 8
        TerminatorInst *term = cur_bb->getTerminator();
#else
        Instruction *term = cur_bb->getTerminator();
#endif
        Value *icmp = IRB.CreateICmpEQ(isub, ConstantInt::get(Int8Ty, 0));
        IRB.CreateCondBr(icmp, next_bb, end_bb);
        term->eraseFromParent();
      }
      /* the last block keeps its unconditional branch to end_bb */
    }

    /* since the call is the first instruction of the bb it is safe to
     * replace it with a phi instruction */
    BasicBlock::iterator ii(callInst);
    ReplaceInstWithInst(callInst->getParent()->getInstList(), ii, PN);
  }

  return true;
}
|
|
|
|
/* Pass entry point: prints a banner unless the AFL_QUIET environment
 * variable is set, runs transformCmps() with every compare function enabled
 * and then calls verifyModule(). Neither result is inspected; the function
 * always reports the module as modified. */
bool CompareTransform::runOnModule(Module &M) {

  const bool quiet = getenv("AFL_QUIET") != NULL;
  if (!quiet) {
    llvm::errs() << "Running compare-transform-pass by laf.intel@gmail.com, extended by heiko@hexco.de\n";
  }

  /* strcmp, memcmp, strncmp, strcasecmp and strncasecmp are all handled */
  transformCmps(M, true, true, true, true, true);
  verifyModule(M);

  return true;
}
|
|
|
|
static void registerCompTransPass(const PassManagerBuilder &,
|
|
legacy::PassManagerBase &PM) {
|
|
|
|
auto p = new CompareTransform();
|
|
PM.add(p);
|
|
|
|
}
|
|
|
|
/* Hook registerCompTransPass into the legacy PassManagerBuilder at the
 * EP_OptimizerLast extension point. */
static RegisterStandardPasses RegisterCompTransPass(
    PassManagerBuilder::EP_OptimizerLast,
    registerCompTransPass);

/* Same callback at the EP_EnabledOnOptLevel0 extension point; judging by its
 * name this is presumably what schedules the pass when optimization is
 * disabled - confirm against the PassManagerBuilder documentation. */
static RegisterStandardPasses RegisterCompTransPass0(
    PassManagerBuilder::EP_EnabledOnOptLevel0,
    registerCompTransPass);
|
|
|