mirror of
https://github.com/AFLplusplus/AFLplusplus.git
synced 2025-06-15 19:38:09 +00:00
afl++ 2.52c initial commit
This commit is contained in:
306
llvm_mode/compare-transform-pass.so.cc
Normal file
306
llvm_mode/compare-transform-pass.so.cc
Normal file
@ -0,0 +1,306 @@
|
||||
/*
|
||||
* Copyright 2016 laf-intel
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include "llvm/ADT/Statistic.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/LegacyPassManager.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
|
||||
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
|
||||
#include "llvm/IR/Verifier.h"
|
||||
#include "llvm/Pass.h"
|
||||
#include "llvm/Analysis/ValueTracking.h"
|
||||
|
||||
#include <set>
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
namespace {
|
||||
|
||||
class CompareTransform : public ModulePass {
|
||||
|
||||
public:
|
||||
static char ID;
|
||||
CompareTransform() : ModulePass(ID) {
|
||||
}
|
||||
|
||||
bool runOnModule(Module &M) override;
|
||||
|
||||
#if __clang_major__ < 4
|
||||
const char * getPassName() const override {
|
||||
#else
|
||||
StringRef getPassName() const override {
|
||||
#endif
|
||||
return "transforms compare functions";
|
||||
}
|
||||
private:
|
||||
bool transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp
|
||||
,const bool processStrncmp, const bool processStrcasecmp, const bool processStrncasecmp);
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
char CompareTransform::ID = 0;
|
||||
|
||||
bool CompareTransform::transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp
|
||||
, const bool processStrncmp, const bool processStrcasecmp, const bool processStrncasecmp) {
|
||||
|
||||
std::vector<CallInst*> calls;
|
||||
LLVMContext &C = M.getContext();
|
||||
IntegerType *Int8Ty = IntegerType::getInt8Ty(C);
|
||||
IntegerType *Int32Ty = IntegerType::getInt32Ty(C);
|
||||
IntegerType *Int64Ty = IntegerType::getInt64Ty(C);
|
||||
Constant* c = M.getOrInsertFunction("tolower",
|
||||
Int32Ty,
|
||||
Int32Ty
|
||||
#if __clang_major__ < 7
|
||||
, nullptr
|
||||
#endif
|
||||
);
|
||||
Function* tolowerFn = cast<Function>(c);
|
||||
|
||||
/* iterate over all functions, bbs and instruction and add suitable calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp */
|
||||
for (auto &F : M) {
|
||||
for (auto &BB : F) {
|
||||
for(auto &IN: BB) {
|
||||
CallInst* callInst = nullptr;
|
||||
|
||||
if ((callInst = dyn_cast<CallInst>(&IN))) {
|
||||
|
||||
bool isStrcmp = processStrcmp;
|
||||
bool isMemcmp = processMemcmp;
|
||||
bool isStrncmp = processStrncmp;
|
||||
bool isStrcasecmp = processStrcasecmp;
|
||||
bool isStrncasecmp = processStrncasecmp;
|
||||
|
||||
Function *Callee = callInst->getCalledFunction();
|
||||
if (!Callee)
|
||||
continue;
|
||||
if (callInst->getCallingConv() != llvm::CallingConv::C)
|
||||
continue;
|
||||
StringRef FuncName = Callee->getName();
|
||||
isStrcmp &= !FuncName.compare(StringRef("strcmp"));
|
||||
isMemcmp &= !FuncName.compare(StringRef("memcmp"));
|
||||
isStrncmp &= !FuncName.compare(StringRef("strncmp"));
|
||||
isStrcasecmp &= !FuncName.compare(StringRef("strcasecmp"));
|
||||
isStrncasecmp &= !FuncName.compare(StringRef("strncasecmp"));
|
||||
|
||||
if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && !isStrncasecmp)
|
||||
continue;
|
||||
|
||||
/* Verify the strcmp/memcmp/strncmp/strcasecmp/strncasecmp function prototype */
|
||||
FunctionType *FT = Callee->getFunctionType();
|
||||
|
||||
|
||||
isStrcmp &= FT->getNumParams() == 2 &&
|
||||
FT->getReturnType()->isIntegerTy(32) &&
|
||||
FT->getParamType(0) == FT->getParamType(1) &&
|
||||
FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
|
||||
isStrcasecmp &= FT->getNumParams() == 2 &&
|
||||
FT->getReturnType()->isIntegerTy(32) &&
|
||||
FT->getParamType(0) == FT->getParamType(1) &&
|
||||
FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
|
||||
isMemcmp &= FT->getNumParams() == 3 &&
|
||||
FT->getReturnType()->isIntegerTy(32) &&
|
||||
FT->getParamType(0)->isPointerTy() &&
|
||||
FT->getParamType(1)->isPointerTy() &&
|
||||
FT->getParamType(2)->isIntegerTy();
|
||||
isStrncmp &= FT->getNumParams() == 3 &&
|
||||
FT->getReturnType()->isIntegerTy(32) &&
|
||||
FT->getParamType(0) == FT->getParamType(1) &&
|
||||
FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()) &&
|
||||
FT->getParamType(2)->isIntegerTy();
|
||||
isStrncasecmp &= FT->getNumParams() == 3 &&
|
||||
FT->getReturnType()->isIntegerTy(32) &&
|
||||
FT->getParamType(0) == FT->getParamType(1) &&
|
||||
FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()) &&
|
||||
FT->getParamType(2)->isIntegerTy();
|
||||
|
||||
if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && !isStrncasecmp)
|
||||
continue;
|
||||
|
||||
/* is a str{n,}{case,}cmp/memcmp, check is we have
|
||||
* str{case,}cmp(x, "const") or str{case,}cmp("const", x)
|
||||
* strn{case,}cmp(x, "const", ..) or strn{case,}cmp("const", x, ..)
|
||||
* memcmp(x, "const", ..) or memcmp("const", x, ..) */
|
||||
Value *Str1P = callInst->getArgOperand(0), *Str2P = callInst->getArgOperand(1);
|
||||
StringRef Str1, Str2;
|
||||
bool HasStr1 = getConstantStringInfo(Str1P, Str1);
|
||||
bool HasStr2 = getConstantStringInfo(Str2P, Str2);
|
||||
|
||||
/* handle cases of one string is const, one string is variable */
|
||||
if (!(HasStr1 ^ HasStr2))
|
||||
continue;
|
||||
|
||||
if (isMemcmp || isStrncmp || isStrncasecmp) {
|
||||
/* check if third operand is a constant integer
|
||||
* strlen("constStr") and sizeof() are treated as constant */
|
||||
Value *op2 = callInst->getArgOperand(2);
|
||||
ConstantInt* ilen = dyn_cast<ConstantInt>(op2);
|
||||
if (!ilen)
|
||||
continue;
|
||||
/* final precaution: if size of compare is larger than constant string skip it*/
|
||||
uint64_t literalLength = HasStr1 ? GetStringLength(Str1P) : GetStringLength(Str2P);
|
||||
if (literalLength < ilen->getZExtValue())
|
||||
continue;
|
||||
}
|
||||
|
||||
calls.push_back(callInst);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!calls.size())
|
||||
return false;
|
||||
errs() << "Replacing " << calls.size() << " calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp\n";
|
||||
|
||||
for (auto &callInst: calls) {
|
||||
|
||||
Value *Str1P = callInst->getArgOperand(0), *Str2P = callInst->getArgOperand(1);
|
||||
StringRef Str1, Str2, ConstStr;
|
||||
Value *VarStr;
|
||||
bool HasStr1 = getConstantStringInfo(Str1P, Str1);
|
||||
getConstantStringInfo(Str2P, Str2);
|
||||
uint64_t constLen, sizedLen;
|
||||
bool isMemcmp = !callInst->getCalledFunction()->getName().compare(StringRef("memcmp"));
|
||||
bool isSizedcmp = isMemcmp
|
||||
|| !callInst->getCalledFunction()->getName().compare(StringRef("strncmp"))
|
||||
|| !callInst->getCalledFunction()->getName().compare(StringRef("strncasecmp"));
|
||||
bool isCaseInsensitive = !callInst->getCalledFunction()->getName().compare(StringRef("strcasecmp"))
|
||||
|| !callInst->getCalledFunction()->getName().compare(StringRef("strncasecmp"));
|
||||
|
||||
if (isSizedcmp) {
|
||||
Value *op2 = callInst->getArgOperand(2);
|
||||
ConstantInt* ilen = dyn_cast<ConstantInt>(op2);
|
||||
sizedLen = ilen->getZExtValue();
|
||||
}
|
||||
|
||||
if (HasStr1) {
|
||||
ConstStr = Str1;
|
||||
VarStr = Str2P;
|
||||
constLen = isMemcmp ? sizedLen : GetStringLength(Str1P);
|
||||
}
|
||||
else {
|
||||
ConstStr = Str2;
|
||||
VarStr = Str1P;
|
||||
constLen = isMemcmp ? sizedLen : GetStringLength(Str2P);
|
||||
}
|
||||
if (isSizedcmp && constLen > sizedLen) {
|
||||
constLen = sizedLen;
|
||||
}
|
||||
|
||||
errs() << callInst->getCalledFunction()->getName() << ": len " << constLen << ": " << ConstStr << "\n";
|
||||
|
||||
/* split before the call instruction */
|
||||
BasicBlock *bb = callInst->getParent();
|
||||
BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(callInst));
|
||||
BasicBlock *next_bb = BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
|
||||
BranchInst::Create(end_bb, next_bb);
|
||||
PHINode *PN = PHINode::Create(Int32Ty, constLen + 1, "cmp_phi");
|
||||
|
||||
TerminatorInst *term = bb->getTerminator();
|
||||
BranchInst::Create(next_bb, bb);
|
||||
term->eraseFromParent();
|
||||
|
||||
for (uint64_t i = 0; i < constLen; i++) {
|
||||
|
||||
BasicBlock *cur_bb = next_bb;
|
||||
|
||||
char c = isCaseInsensitive ? tolower(ConstStr[i]) : ConstStr[i];
|
||||
|
||||
|
||||
BasicBlock::iterator IP = next_bb->getFirstInsertionPt();
|
||||
IRBuilder<> IRB(&*IP);
|
||||
|
||||
Value* v = ConstantInt::get(Int64Ty, i);
|
||||
Value *ele = IRB.CreateInBoundsGEP(VarStr, v, "empty");
|
||||
Value *load = IRB.CreateLoad(ele);
|
||||
if (isCaseInsensitive) {
|
||||
// load >= 'A' && load <= 'Z' ? load | 0x020 : load
|
||||
std::vector<Value *> args;
|
||||
args.push_back(load);
|
||||
load = IRB.CreateCall(tolowerFn, args, "tmp");
|
||||
}
|
||||
Value *isub;
|
||||
if (HasStr1)
|
||||
isub = IRB.CreateSub(ConstantInt::get(Int8Ty, c), load);
|
||||
else
|
||||
isub = IRB.CreateSub(load, ConstantInt::get(Int8Ty, c));
|
||||
|
||||
Value *sext = IRB.CreateSExt(isub, Int32Ty);
|
||||
PN->addIncoming(sext, cur_bb);
|
||||
|
||||
|
||||
if (i < constLen - 1) {
|
||||
next_bb = BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
|
||||
BranchInst::Create(end_bb, next_bb);
|
||||
|
||||
TerminatorInst *term = cur_bb->getTerminator();
|
||||
Value *icmp = IRB.CreateICmpEQ(isub, ConstantInt::get(Int8Ty, 0));
|
||||
IRB.CreateCondBr(icmp, next_bb, end_bb);
|
||||
term->eraseFromParent();
|
||||
} else {
|
||||
//IRB.CreateBr(end_bb);
|
||||
}
|
||||
|
||||
//add offset to varstr
|
||||
//create load
|
||||
//create signed isub
|
||||
//create icmp
|
||||
//create jcc
|
||||
//create next_bb
|
||||
}
|
||||
|
||||
/* since the call is the first instruction of the bb it is safe to
|
||||
* replace it with a phi instruction */
|
||||
BasicBlock::iterator ii(callInst);
|
||||
ReplaceInstWithInst(callInst->getParent()->getInstList(), ii, PN);
|
||||
}
|
||||
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CompareTransform::runOnModule(Module &M) {
|
||||
|
||||
llvm::errs() << "Running compare-transform-pass by laf.intel@gmail.com, extended by heiko@hexco.de\n";
|
||||
transformCmps(M, true, true, true, true, true);
|
||||
verifyModule(M);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static void registerCompTransPass(const PassManagerBuilder &,
|
||||
legacy::PassManagerBase &PM) {
|
||||
|
||||
auto p = new CompareTransform();
|
||||
PM.add(p);
|
||||
|
||||
}
|
||||
|
||||
static RegisterStandardPasses RegisterCompTransPass(
|
||||
PassManagerBuilder::EP_OptimizerLast, registerCompTransPass);
|
||||
|
||||
static RegisterStandardPasses RegisterCompTransPass0(
|
||||
PassManagerBuilder::EP_EnabledOnOptLevel0, registerCompTransPass);
|
||||
|
Reference in New Issue
Block a user