AFLplusplus/llvm_mode/compare-transform-pass.so.cc

377 lines
12 KiB
C++

/*
* Copyright 2016 laf-intel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Pass.h"
#include "llvm/Analysis/ValueTracking.h"
#include <set>
using namespace llvm;
namespace {

/* Module pass that rewrites calls to the C string/memory comparison functions
 * (strcmp, memcmp, strncmp, strcasecmp, strncasecmp). The actual rewriting is
 * implemented in transformCmps(); runOnModule() is the pass entry point. */
class CompareTransform : public ModulePass {

 public:
  /* Pass identifier handed to the ModulePass base class. */
  static char ID;

  CompareTransform() : ModulePass(ID) {
  }

  /* Human readable name of the pass. The return type of the overridden
   * method depends on the LLVM major version being compiled against. */
#if LLVM_VERSION_MAJOR >= 4
  StringRef getPassName() const override {
#else
  const char *getPassName() const override {
#endif
    return "transforms compare functions";
  }

  bool runOnModule(Module &M) override;

 private:
  /* Rewrites the comparison calls found in M; each flag enables handling of
   * the correspondingly named function. Returns true if at least one call
   * was replaced. */
  bool transformCmps(Module &M, const bool processStrcmp,
                     const bool processMemcmp, const bool processStrncmp,
                     const bool processStrcasecmp,
                     const bool processStrncasecmp);

};

}  // namespace

char CompareTransform::ID = 0;
/* Replaces calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp in which
 * exactly one of the two compared buffers is a constant string by a chain of
 * basic blocks that each compare a single byte.
 *
 * M                   module to transform
 * processStrcmp       handle calls to strcmp
 * processMemcmp       handle calls to memcmp
 * processStrncmp      handle calls to strncmp
 * processStrcasecmp   handle calls to strcasecmp
 * processStrncasecmp  handle calls to strncasecmp
 *
 * Returns true if at least one call was replaced, false otherwise. Note that
 * the "tolower" declaration is inserted into M in either case.
 *
 * Progress messages are written to errs() unless AFL_QUIET is set in the
 * environment. */
bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
                                     const bool processMemcmp,
                                     const bool processStrncmp,
                                     const bool processStrcasecmp,
                                     const bool processStrncasecmp) {

  std::vector<CallInst *> calls;
  LLVMContext &           C = M.getContext();
  IntegerType *           Int8Ty = IntegerType::getInt8Ty(C);
  IntegerType *           Int32Ty = IntegerType::getInt32Ty(C);
  IntegerType *           Int64Ty = IntegerType::getInt64Ty(C);
  Type *                  Int8PtrTy = IntegerType::getInt8PtrTy(C);
  const bool              beQuiet = getenv("AFL_QUIET") != nullptr;

  /* int tolower(int), used by the case insensitive variants */
#if LLVM_VERSION_MAJOR < 9
  Constant *
#else
  FunctionCallee
#endif
      c = M.getOrInsertFunction("tolower", Int32Ty, Int32Ty
#if LLVM_VERSION_MAJOR < 5
                                ,
                                nullptr
#endif
      );
#if LLVM_VERSION_MAJOR < 9
  Function *tolowerFn = cast<Function>(c);
#else
  FunctionCallee tolowerFn = c;
#endif

  /* Phase 1: iterate over all functions, bbs and instructions and collect
   * suitable calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp. The
   * calls are only collected here because rewriting them splits basic blocks,
   * which must not happen while iterating over them. */
  for (auto &F : M) {

    for (auto &BB : F) {

      for (auto &IN : BB) {

        CallInst *callInst = dyn_cast<CallInst>(&IN);
        if (!callInst) continue;

        Function *Callee = callInst->getCalledFunction();
        if (!Callee) continue;
        if (callInst->getCallingConv() != llvm::CallingConv::C) continue;

        StringRef FuncName = Callee->getName();
        bool      isStrcmp = processStrcmp && FuncName == "strcmp";
        bool      isMemcmp = processMemcmp && FuncName == "memcmp";
        bool      isStrncmp = processStrncmp && FuncName == "strncmp";
        bool      isStrcasecmp = processStrcasecmp && FuncName == "strcasecmp";
        bool      isStrncasecmp =
            processStrncasecmp && FuncName == "strncasecmp";

        if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp &&
            !isStrncasecmp)
          continue;

        /* Verify the strcmp/memcmp/strncmp/strcasecmp/strncasecmp function
         * prototype */
        FunctionType * FT = Callee->getFunctionType();
        const unsigned numParams = FT->getNumParams();
        const bool     returnsInt32 = FT->getReturnType()->isIntegerTy(32);
        const bool     twoCharPtrs =
            numParams >= 2 && FT->getParamType(0) == FT->getParamType(1) &&
            FT->getParamType(0) == Int8PtrTy;
        const bool hasIntLen =
            numParams == 3 && FT->getParamType(2)->isIntegerTy();

        isStrcmp &= numParams == 2 && returnsInt32 && twoCharPtrs;
        isStrcasecmp &= numParams == 2 && returnsInt32 && twoCharPtrs;
        isMemcmp &= numParams == 3 && returnsInt32 &&
                    FT->getParamType(0)->isPointerTy() &&
                    FT->getParamType(1)->isPointerTy() && hasIntLen;
        isStrncmp &= returnsInt32 && twoCharPtrs && hasIntLen;
        isStrncasecmp &= returnsInt32 && twoCharPtrs && hasIntLen;

        if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp &&
            !isStrncasecmp)
          continue;

        /* is a str{n,}{case,}cmp/memcmp, check if we have
         * str{case,}cmp(x, "const") or str{case,}cmp("const", x)
         * strn{case,}cmp(x, "const", ..) or strn{case,}cmp("const", x, ..)
         * memcmp(x, "const", ..) or memcmp("const", x, ..) */
        Value *Str1P = callInst->getArgOperand(0),
              *Str2P = callInst->getArgOperand(1);
        StringRef Str1, Str2;
        bool      HasStr1 = getConstantStringInfo(Str1P, Str1);
        bool      HasStr2 = getConstantStringInfo(Str2P, Str2);

        /* handle cases of one string is const, one string is variable */
        if (!(HasStr1 ^ HasStr2)) continue;

        /* GetStringLength() yields strlen + 1, or 0 if the length cannot be
         * determined. Without a known length no compare chain can be built
         * (the replacement PHI would end up without incoming values). */
        uint64_t literalLength =
            HasStr1 ? GetStringLength(Str1P) : GetStringLength(Str2P);
        if (literalLength == 0) continue;

        if (isMemcmp || isStrncmp || isStrncasecmp) {

          /* check if third operand is a constant integer
           * strlen("constStr") and sizeof() are treated as constant */
          ConstantInt *ilen = dyn_cast<ConstantInt>(callInst->getArgOperand(2));
          if (!ilen) continue;

          const uint64_t cmpLen = ilen->getZExtValue();

          /* a zero length compare has no bytes to split on; leave the call
           * untouched instead of emitting an empty PHI */
          if (cmpLen == 0) continue;

          /* final precaution: if size of compare is larger than constant
           * string skip it */
          if (literalLength < cmpLen) continue;

        }

        calls.push_back(callInst);

      }

    }

  }

  if (calls.empty()) return false;

  if (!beQuiet)
    errs() << "Replacing " << calls.size()
           << " calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp\n";

  /* Phase 2: replace every collected call */
  for (auto &callInst : calls) {

    Value *Str1P = callInst->getArgOperand(0),
          *Str2P = callInst->getArgOperand(1);
    StringRef   Str1, Str2, ConstStr;
    std::string TmpConstStr;
    Value *     VarStr;
    bool        HasStr1 = getConstantStringInfo(Str1P, Str1);
    getConstantStringInfo(Str2P, Str2);
    uint64_t constLen, sizedLen;

    const StringRef calleeName = callInst->getCalledFunction()->getName();
    const bool      isMemcmp = calleeName == "memcmp";
    const bool      isSizedcmp =
        isMemcmp || calleeName == "strncmp" || calleeName == "strncasecmp";
    const bool isCaseInsensitive =
        calleeName == "strcasecmp" || calleeName == "strncasecmp";

    if (isSizedcmp) {

      /* phase 1 only collected sized compares with a constant length */
      sizedLen = cast<ConstantInt>(callInst->getArgOperand(2))->getZExtValue();

    } else {

      sizedLen = 0;

    }

    if (HasStr1) {

      TmpConstStr = Str1.str();
      VarStr = Str2P;
      constLen = isMemcmp ? sizedLen : GetStringLength(Str1P);

    } else {

      TmpConstStr = Str2.str();
      VarStr = Str1P;
      constLen = isMemcmp ? sizedLen : GetStringLength(Str2P);

    }

    /* properly handle zero terminated C strings by adding the terminating 0 to
     * the StringRef (in comparison to std::string a StringRef has built-in
     * runtime bounds checking, which makes debugging easier) */
    TmpConstStr.append("\0", 1);
    ConstStr = StringRef(TmpConstStr);

    if (isSizedcmp && constLen > sizedLen) { constLen = sizedLen; }

    /* defensive: phase 1 already rejects these, but an empty chain would
     * leave the PHI without incoming values and break the module */
    if (constLen == 0) continue;

    if (!beQuiet)
      errs() << callInst->getCalledFunction()->getName() << ": len "
             << constLen << ": " << ConstStr << "\n";

    /* split before the call instruction */
    BasicBlock *bb = callInst->getParent();
    BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(callInst));
    BasicBlock *next_bb =
        BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
    BranchInst::Create(end_bb, next_bb);
    PHINode *PN = PHINode::Create(
        Int32Ty, static_cast<unsigned>(constLen + 1), "cmp_phi");

    /* redirect the original block into the first compare block */
#if LLVM_VERSION_MAJOR < 8
    TerminatorInst *term = bb->getTerminator();
#else
    Instruction *term = bb->getTerminator();
#endif
    BranchInst::Create(next_bb, bb);
    term->eraseFromParent();

    for (uint64_t i = 0; i < constLen; i++) {

      BasicBlock *cur_bb = next_bb;

      /* The constant byte is handled as unsigned char: passing a negative
       * char to the host tolower() is undefined behaviour. */
      const unsigned char rawByte = static_cast<unsigned char>(ConstStr[i]);
      const unsigned char constByte =
          isCaseInsensitive ? static_cast<unsigned char>(tolower(rawByte))
                            : rawByte;

      /* insert in front of the block's placeholder branch to end_bb */
      BasicBlock::iterator IP = next_bb->getFirstInsertionPt();
      IRBuilder<>          IRB(&*IP);

      /* add offset to varstr and load the byte */
      Value *v = ConstantInt::get(Int64Ty, i);
      Value *ele = IRB.CreateInBoundsGEP(VarStr, v, "empty");
      Value *load = IRB.CreateLoad(ele);

      if (isCaseInsensitive) {

        /* tolower is declared as i32(i32): widen the loaded byte before the
         * call and narrow the result back to a byte afterwards */
        std::vector<Value *> args;
        args.push_back(IRB.CreateZExt(load, Int32Ty));
        load = IRB.CreateCall(tolowerFn, args, "tmp");
        load = IRB.CreateTrunc(load, Int8Ty);

      }

      /* The result of the replaced call is the difference of the first
       * mismatching bytes. It is computed on zero extended values so that
       * its sign matches a comparison of unsigned chars. */
      Value *varWide = IRB.CreateZExt(load, Int32Ty);
      Value *constWide = ConstantInt::get(Int32Ty, constByte);
      Value *diff = HasStr1 ? IRB.CreateSub(constWide, varWide)
                            : IRB.CreateSub(varWide, constWide);
      PN->addIncoming(diff, cur_bb);

      if (i < constLen - 1) {

        /* not the last byte: continue with the next compare block if the
         * bytes are equal, otherwise leave the chain */
        next_bb =
            BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
        BranchInst::Create(end_bb, next_bb);

        Value *icmp =
            IRB.CreateICmpEQ(load, ConstantInt::get(Int8Ty, constByte));
        IRB.CreateCondBr(icmp, next_bb, end_bb);
        cur_bb->getTerminator()->eraseFromParent();

      }

      /* the last block keeps its unconditional branch to end_bb */

    }

    /* since the call is the first instruction of the bb it is safe to
     * replace it with a phi instruction */
    BasicBlock::iterator ii(callInst);
    ReplaceInstWithInst(callInst->getParent()->getInstList(), ii, PN);

  }

  return true;

}
/* Pass entry point: prints a banner unless AFL_QUIET is set in the
 * environment, rewrites all five supported compare functions and runs the
 * module verifier afterwards. Always reports the module as modified. */
bool CompareTransform::runOnModule(Module &M) {

  const bool quiet = (getenv("AFL_QUIET") != NULL);

  if (!quiet) {

    llvm::errs() << "Running compare-transform-pass by laf.intel@gmail.com, "
                    "extended by heiko@hexco.de\n";

  }

  /* enable strcmp, memcmp, strncmp, strcasecmp and strncasecmp handling */
  transformCmps(M, true, true, true, true, true);
  verifyModule(M);

  return true;

}
/* Extension point callback: creates a CompareTransform pass and adds it to
 * the given pass manager. The PassManagerBuilder argument is unused. */
static void registerCompTransPass(const PassManagerBuilder &,
                                  legacy::PassManagerBase &PM) {

  PM.add(new CompareTransform());

}
/* Hook registerCompTransPass into the EP_OptimizerLast extension point. */
static RegisterStandardPasses RegisterCompTransPass(
PassManagerBuilder::EP_OptimizerLast, registerCompTransPass);
/* Hook registerCompTransPass into the EP_EnabledOnOptLevel0 extension point
 * as well. */
static RegisterStandardPasses RegisterCompTransPass0(
PassManagerBuilder::EP_EnabledOnOptLevel0, registerCompTransPass);