afl++ 2.52c initial commit

This commit is contained in:
van Hauser
2019-05-28 16:40:24 +02:00
parent 1b3d018d35
commit f367728c44
194 changed files with 29133 additions and 0 deletions

121
llvm_mode/Makefile Normal file

@@ -0,0 +1,121 @@
#
# american fuzzy lop - LLVM instrumentation
# -----------------------------------------
#
# Written by Laszlo Szekeres <lszekeres@google.com> and
# Michal Zalewski <lcamtuf@google.com>
#
# LLVM integration design comes from Laszlo Szekeres.
#
# Copyright 2015, 2016 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
PREFIX ?= /usr/local
HELPER_PATH = $(PREFIX)/lib/afl
BIN_PATH = $(PREFIX)/bin
VERSION = $(shell grep '^\#define VERSION ' ../config.h | cut -d '"' -f2)
LLVM_CONFIG ?= llvm-config
CFLAGS ?= -O3 -funroll-loops
CFLAGS += -Wall -D_FORTIFY_SOURCE=2 -g -Wno-pointer-sign \
-DAFL_PATH=\"$(HELPER_PATH)\" -DBIN_PATH=\"$(BIN_PATH)\" \
-DVERSION=\"$(VERSION)\"
ifdef AFL_TRACE_PC
CFLAGS += -DUSE_TRACE_PC=1
endif
CXXFLAGS ?= -O3 -funroll-loops
CXXFLAGS += -Wall -D_FORTIFY_SOURCE=2 -g -Wno-pointer-sign \
-DVERSION=\"$(VERSION)\" -Wno-variadic-macros
CLANG_CFL = `$(LLVM_CONFIG) --cxxflags` -fno-rtti -fpic $(CXXFLAGS)
CLANG_LFL = `$(LLVM_CONFIG) --ldflags` $(LDFLAGS)
# User teor2345 reports that this is required to make things work on MacOS X.
ifeq "$(shell uname)" "Darwin"
CLANG_LFL += -Wl,-flat_namespace -Wl,-undefined,suppress
endif
# We were using llvm-config --bindir to get the location of clang, but
# this seems to be busted on some distros, so using the one in $PATH is
# probably better.
ifeq "$(origin CC)" "default"
CC = clang
CXX = clang++
endif
ifndef AFL_TRACE_PC
PROGS = ../afl-clang-fast ../afl-llvm-pass.so ../afl-llvm-rt.o ../afl-llvm-rt-32.o ../afl-llvm-rt-64.o ../compare-transform-pass.so ../split-compares-pass.so ../split-switches-pass.so
else
PROGS = ../afl-clang-fast ../afl-llvm-rt.o ../afl-llvm-rt-32.o ../afl-llvm-rt-64.o ../compare-transform-pass.so ../split-compares-pass.so ../split-switches-pass.so
endif
all: test_deps $(PROGS) test_build all_done
test_deps:
ifndef AFL_TRACE_PC
@echo "[*] Checking for working 'llvm-config'..."
@which $(LLVM_CONFIG) >/dev/null 2>&1 || ( echo "[-] Oops, can't find 'llvm-config'. Install clang or set \$$LLVM_CONFIG or \$$PATH beforehand."; echo " (Sometimes, the binary will be named llvm-config-3.5 or something like that.)"; exit 1 )
else
@echo "[!] Note: using -fsanitize=trace-pc mode (this will fail with older LLVM)."
endif
@echo "[*] Checking for working '$(CC)'..."
@which $(CC) >/dev/null 2>&1 || ( echo "[-] Oops, can't find '$(CC)'. Make sure that it's in your \$$PATH (or set \$$CC and \$$CXX)."; exit 1 )
@echo "[*] Checking for '../afl-showmap'..."
@test -f ../afl-showmap || ( echo "[-] Oops, can't find '../afl-showmap'. Be sure to compile AFL first."; exit 1 )
@echo "[+] All set and ready to build."
../afl-clang-fast: afl-clang-fast.c | test_deps
$(CC) $(CFLAGS) $< -o $@ $(LDFLAGS)
ln -sf afl-clang-fast ../afl-clang-fast++
../afl-llvm-pass.so: afl-llvm-pass.so.cc | test_deps
$(CXX) $(CLANG_CFL) -shared $< -o $@ $(CLANG_LFL)
# laf
../split-switches-pass.so: split-switches-pass.so.cc | test_deps
$(CXX) $(CLANG_CFL) -shared $< -o $@ $(CLANG_LFL)
../compare-transform-pass.so: compare-transform-pass.so.cc | test_deps
$(CXX) $(CLANG_CFL) -shared $< -o $@ $(CLANG_LFL)
../split-compares-pass.so: split-compares-pass.so.cc | test_deps
$(CXX) $(CLANG_CFL) -shared $< -o $@ $(CLANG_LFL)
# /laf
../afl-llvm-rt.o: afl-llvm-rt.o.c | test_deps
$(CC) $(CFLAGS) -fPIC -c $< -o $@
../afl-llvm-rt-32.o: afl-llvm-rt.o.c | test_deps
@printf "[*] Building 32-bit variant of the runtime (-m32)... "
@$(CC) $(CFLAGS) -m32 -fPIC -c $< -o $@ 2>/dev/null; if [ "$$?" = "0" ]; then echo "success!"; else echo "failed (that's fine)"; fi
../afl-llvm-rt-64.o: afl-llvm-rt.o.c | test_deps
@printf "[*] Building 64-bit variant of the runtime (-m64)... "
@$(CC) $(CFLAGS) -m64 -fPIC -c $< -o $@ 2>/dev/null; if [ "$$?" = "0" ]; then echo "success!"; else echo "failed (that's fine)"; fi
test_build: $(PROGS)
@echo "[*] Testing the CC wrapper and instrumentation output..."
unset AFL_USE_ASAN AFL_USE_MSAN AFL_INST_RATIO; AFL_QUIET=1 AFL_PATH=. AFL_CC=$(CC) ../afl-clang-fast $(CFLAGS) ../test-instr.c -o test-instr $(LDFLAGS)
echo 0 | ../afl-showmap -m none -q -o .test-instr0 ./test-instr
echo 1 | ../afl-showmap -m none -q -o .test-instr1 ./test-instr
@rm -f test-instr
@cmp -s .test-instr0 .test-instr1; DR="$$?"; rm -f .test-instr0 .test-instr1; if [ "$$DR" = "0" ]; then echo; echo "Oops, the instrumentation does not seem to be behaving correctly!"; echo; echo "Please ping <lcamtuf@google.com> to troubleshoot the issue."; echo; exit 1; fi
@echo "[+] All right, the instrumentation seems to be working!"
all_done: test_build
@echo "[+] All done! You can now use '../afl-clang-fast' to compile programs."
.NOTPARALLEL: clean
clean:
rm -f *.o *.so *~ a.out core core.[1-9][0-9]* test-instr .test-instr0 .test-instr1
rm -f $(PROGS) ../afl-clang-fast++


@@ -0,0 +1,20 @@
Usage
=====
By default the passes will not run when you compile programs using
afl-clang-fast. Hence, you can use AFL as usual.
To enable the passes you must set environment variables before you
compile the target project.
The following options exist:
export LAF_SPLIT_SWITCHES=1 Enables the split-switches pass.
export LAF_TRANSFORM_COMPARES=1 Enables the transform-compares pass
(strcmp, memcmp, strncmp, strcasecmp, strncasecmp).
export LAF_SPLIT_COMPARES=1 Enables the split-compares pass.
By default it will split all compares with a bit width <= 64 bits.
You can change this behaviour by setting
export LAF_SPLIT_COMPARES_BITW=<bit_width>.
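As a rough illustration of the effect (a hand-written sketch; interesting_branch()
is just a placeholder, and this is not output produced by the passes themselves):
with LAF_SPLIT_COMPARES=1 a single 64-bit equality check is conceptually rewritten
into a chain of smaller comparisons, so afl-fuzz can match the constant a few
bytes at a time instead of having to guess all 64 bits at once:

  /* original code */
  if (x == 0x4141414142424242ULL) interesting_branch();

  /* roughly how the split version behaves (hypothetical C rendering of
     the nested 32-bit comparisons injected at the IR level) */
  if ((uint32_t)(x >> 32) == 0x41414141U)
    if ((uint32_t)x == 0x42424242U)
      interesting_branch();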

192
llvm_mode/README.llvm Normal file

@@ -0,0 +1,192 @@
============================================
Fast LLVM-based instrumentation for afl-fuzz
============================================
(See ../docs/README for the general instruction manual.)
1) Introduction
---------------
!!! This works with LLVM up to version 6 !!!
The code in this directory allows you to instrument programs for AFL using
true compiler-level instrumentation, instead of the more crude
assembly-level rewriting approach taken by afl-gcc and afl-clang. This has
several interesting properties:
- The compiler can make many optimizations that are hard to pull off when
manually inserting assembly. As a result, some slow, CPU-bound programs will
run up to around 2x faster.
The gains are less pronounced for fast binaries, where the speed is limited
chiefly by the cost of creating new processes. In such cases, the gain will
probably stay within 10%.
- The instrumentation is CPU-independent. At least in principle, you should
be able to rely on it to fuzz programs on non-x86 architectures (after
building afl-fuzz with AFL_NO_X86=1).
- The instrumentation can cope a bit better with multi-threaded targets.
- Because the feature relies on the internals of LLVM, it is clang-specific
and will *not* work with GCC.
Once this implementation is shown to be sufficiently robust and portable, it
will probably replace afl-clang. For now, it can be built separately and
co-exists with the original code.
The idea and much of the implementation comes from Laszlo Szekeres.
2) How to use
-------------
In order to leverage this mechanism, you need to have clang installed on your
system. You should also make sure that the llvm-config tool is in your path
(or pointed to via LLVM_CONFIG in the environment).
Unfortunately, some systems that do have clang come without llvm-config or the
LLVM development headers; one example of this is FreeBSD. FreeBSD users will
also run into problems with clang being built statically and not being able to
load modules (you'll see "Service unavailable" when loading afl-llvm-pass.so).
To solve all your problems, you can grab pre-built binaries for your OS from:
http://llvm.org/releases/download.html
...and then put the bin/ directory from the tarball at the beginning of your
$PATH when compiling the feature and building packages later on. You don't need
to be root for that.
To build the instrumentation itself, type 'make'. This will generate binaries
called afl-clang-fast and afl-clang-fast++ in the parent directory. Once this
is done, you can instrument third-party code in a way similar to the standard
operating mode of AFL, e.g.:
CC=/path/to/afl/afl-clang-fast ./configure [...options...]
make
Be sure to also include CXX set to afl-clang-fast++ for C++ code.
The tool honors roughly the same environmental variables as afl-gcc (see
../docs/env_variables.txt). This includes AFL_INST_RATIO, AFL_USE_ASAN,
AFL_HARDEN, and AFL_DONT_OPTIMIZE.
Note: if you want the LLVM helper to be installed on your system for all
users, you need to build it before issuing 'make install' in the parent
directory.
3) Gotchas, feedback, bugs
--------------------------
This is an early-stage mechanism, so field reports are welcome. You can send bug
reports to <afl-users@googlegroups.com>.
4) Bonus feature #1: deferred instrumentation
---------------------------------------------
AFL tries to optimize performance by executing the targeted binary just once,
stopping it just before main(), and then cloning this "master" process to get
a steady supply of targets to fuzz.
Although this approach eliminates much of the OS-, linker- and libc-level
costs of executing the program, it does not always help with binaries that
perform other time-consuming initialization steps - say, parsing a large config
file before getting to the fuzzed data.
In such cases, it's beneficial to initialize the forkserver a bit later, once
most of the initialization work is already done, but before the binary attempts
to read the fuzzed input and parse it; in some cases, this can offer a 10x+
performance gain. You can implement delayed initialization in LLVM mode in a
fairly simple way.
First, find a suitable location in the code where the delayed cloning can
take place. This needs to be done with *extreme* care to avoid breaking the
binary. In particular, the program will probably malfunction if you select
a location after:
- The creation of any vital threads or child processes - since the forkserver
can't clone them easily.
- The initialization of timers via setitimer() or equivalent calls.
- The creation of temporary files, network sockets, offset-sensitive file
descriptors, and similar shared-state resources - but only provided that
their state meaningfully influences the behavior of the program later on.
- Any access to the fuzzed input, including reading the metadata about its
size.
With the location selected, add this code in the appropriate spot:
#ifdef __AFL_HAVE_MANUAL_CONTROL
__AFL_INIT();
#endif
You don't need the #ifdef guards, but including them ensures that the program
will keep working normally when compiled with a tool other than afl-clang-fast.
Finally, recompile the program with afl-clang-fast (afl-gcc or afl-clang will
*not* generate a deferred-initialization binary) - and you should be all set!
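For instance, a target that parses a large configuration file before touching
the fuzzed input could defer the forkserver like this (a minimal sketch;
parse_config() and process_input() are placeholders for the target's own
functions, not AFL APIs):

  int main(int argc, char **argv) {

    parse_config("/etc/target.conf");   /* expensive, input-independent setup */

  #ifdef __AFL_HAVE_MANUAL_CONTROL
    __AFL_INIT();                       /* fork server starts here */
  #endif

    return process_input(argv[1]);      /* everything below runs per input */

  }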
5) Bonus feature #2: persistent mode
------------------------------------
Some libraries provide APIs that are stateless, or whose state can be reset in
between processing different input files. When such a reset is performed, a
single long-lived process can be reused to try out multiple test cases,
eliminating the need for repeated fork() calls and the associated OS overhead.
The basic structure of the program that does this would be:
while (__AFL_LOOP(1000)) {
/* Read input data. */
/* Call library code to be fuzzed. */
/* Reset state. */
}
/* Exit normally */
The numerical value specified within the loop controls the maximum number
of iterations before AFL will restart the process from scratch. This minimizes
the impact of memory leaks and similar glitches; 1000 is a good starting point,
and going much higher increases the likelihood of hiccups without giving you
any real performance benefits.
A more detailed template is shown in ../experimental/persistent_demo/.
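For reference, a minimal stdin-driven harness along these lines could look as
follows (a sketch only; process_one_input() stands in for the library call you
actually want to fuzz, and the persistent_demo template remains the
authoritative version):

  #include <string.h>
  #include <unistd.h>

  /* hypothetical entry point into the library under test */
  void process_one_input(const unsigned char *data, size_t len);

  static unsigned char buf[4096];

  int main(void) {

    while (__AFL_LOOP(1000)) {

      memset(buf, 0, sizeof(buf));               /* reset state        */
      ssize_t len = read(0, buf, sizeof(buf));   /* read input data    */
      if (len > 0) process_one_input(buf, (size_t)len); /* fuzzed call */

    }

    return 0;

  }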
Similarly to the previous mode, the feature works only with afl-clang-fast;
#ifdef guards can be used to suppress it when using other compilers.
Note that as with the previous mode, the feature is easy to misuse; if you
do not fully reset the critical state, you may end up with false positives or
waste a whole lot of CPU power doing nothing useful at all. Be particularly
wary of memory leaks and of the state of file descriptors.
PS. Because there are task switches still involved, the mode isn't as fast as
"pure" in-process fuzzing offered, say, by LLVM's LibFuzzer; but it is a lot
faster than the normal fork() model, and compared to in-process fuzzing,
should be a lot more robust.
6) Bonus feature #3: new 'trace-pc-guard' mode
----------------------------------------------
Recent versions of LLVM are shipping with a built-in execution tracing feature
that provides AFL with the necessary tracing data without the need to
post-process the assembly or install any compiler plugins. See:
http://clang.llvm.org/docs/SanitizerCoverage.html#tracing-pcs-with-guards
As of this writing, the feature is only available on SVN trunk, and is yet to
make it to an official release of LLVM. Nevertheless, if you have a
sufficiently recent compiler and want to give it a try, build afl-clang-fast
this way:
AFL_TRACE_PC=1 make clean all
Note that this mode is currently about 20% slower than "vanilla" afl-clang-fast,
and about 5-10% slower than afl-clang. This is likely because the
instrumentation is not inlined, and instead involves a function call. On systems
that support it, compiling your target with -flto should help.

381
llvm_mode/afl-clang-fast.c Normal file

@@ -0,0 +1,381 @@
/*
american fuzzy lop - LLVM-mode wrapper for clang
------------------------------------------------
Written by Laszlo Szekeres <lszekeres@google.com> and
Michal Zalewski <lcamtuf@google.com>
LLVM integration design comes from Laszlo Szekeres.
Copyright 2015, 2016 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at:
http://www.apache.org/licenses/LICENSE-2.0
This program is a drop-in replacement for clang, similar in most respects
to ../afl-gcc. It tries to figure out compilation mode, adds a bunch
of flags, and then calls the real compiler.
*/
#define AFL_MAIN
#include "../config.h"
#include "../types.h"
#include "../debug.h"
#include "../alloc-inl.h"
#include <stdio.h>
#include <unistd.h>
#include <stdlib.h>
#include <string.h>
static u8* obj_path; /* Path to runtime libraries */
static u8** cc_params; /* Parameters passed to the real CC */
static u32 cc_par_cnt = 1; /* Param count, including argv0 */
/* Try to find the runtime libraries. If that fails, abort. */
static void find_obj(u8* argv0) {
u8 *afl_path = getenv("AFL_PATH");
u8 *slash, *tmp;
if (afl_path) {
tmp = alloc_printf("%s/afl-llvm-rt.o", afl_path);
if (!access(tmp, R_OK)) {
obj_path = afl_path;
ck_free(tmp);
return;
}
ck_free(tmp);
}
slash = strrchr(argv0, '/');
if (slash) {
u8 *dir;
*slash = 0;
dir = ck_strdup(argv0);
*slash = '/';
tmp = alloc_printf("%s/afl-llvm-rt.o", dir);
if (!access(tmp, R_OK)) {
obj_path = dir;
ck_free(tmp);
return;
}
ck_free(tmp);
ck_free(dir);
}
if (!access(AFL_PATH "/afl-llvm-rt.o", R_OK)) {
obj_path = AFL_PATH;
return;
}
FATAL("Unable to find 'afl-llvm-rt.o' or 'afl-llvm-pass.so'. Please set AFL_PATH");
}
/* Copy argv to cc_params, making the necessary edits. */
static void edit_params(u32 argc, char** argv) {
u8 fortify_set = 0, asan_set = 0, x_set = 0, maybe_linking = 1, bit_mode = 0;
u8 *name;
cc_params = ck_alloc((argc + 128) * sizeof(u8*));
name = strrchr(argv[0], '/');
if (!name) name = argv[0]; else name++;
if (!strcmp(name, "afl-clang-fast++")) {
u8* alt_cxx = getenv("AFL_CXX");
cc_params[0] = alt_cxx ? alt_cxx : (u8*)"clang++";
} else {
u8* alt_cc = getenv("AFL_CC");
cc_params[0] = alt_cc ? alt_cc : (u8*)"clang";
}
/* There are two ways to compile afl-clang-fast. In the traditional mode, we
use afl-llvm-pass.so to inject instrumentation. In the experimental
'trace-pc-guard' mode, we use native LLVM instrumentation callbacks
instead. The latter is a very recent addition - see:
http://clang.llvm.org/docs/SanitizerCoverage.html#tracing-pcs-with-guards */
// laf
if (getenv("LAF_SPLIT_SWITCHES")) {
cc_params[cc_par_cnt++] = "-Xclang";
cc_params[cc_par_cnt++] = "-load";
cc_params[cc_par_cnt++] = "-Xclang";
cc_params[cc_par_cnt++] = alloc_printf("%s/split-switches-pass.so", obj_path);
}
if (getenv("LAF_TRANSFORM_COMPARES")) {
cc_params[cc_par_cnt++] = "-Xclang";
cc_params[cc_par_cnt++] = "-load";
cc_params[cc_par_cnt++] = "-Xclang";
cc_params[cc_par_cnt++] = alloc_printf("%s/compare-transform-pass.so", obj_path);
}
if (getenv("LAF_SPLIT_COMPARES")) {
cc_params[cc_par_cnt++] = "-Xclang";
cc_params[cc_par_cnt++] = "-load";
cc_params[cc_par_cnt++] = "-Xclang";
cc_params[cc_par_cnt++] = alloc_printf("%s/split-compares-pass.so", obj_path);
}
// /laf
#ifdef USE_TRACE_PC
cc_params[cc_par_cnt++] = "-fsanitize-coverage=trace-pc-guard";
cc_params[cc_par_cnt++] = "-mllvm";
cc_params[cc_par_cnt++] = "-sanitizer-coverage-block-threshold=0";
#else
cc_params[cc_par_cnt++] = "-Xclang";
cc_params[cc_par_cnt++] = "-load";
cc_params[cc_par_cnt++] = "-Xclang";
cc_params[cc_par_cnt++] = alloc_printf("%s/afl-llvm-pass.so", obj_path);
#endif /* ^USE_TRACE_PC */
cc_params[cc_par_cnt++] = "-Qunused-arguments";
/* Detect stray -v calls from ./configure scripts. */
if (argc == 2 && !strcmp(argv[1], "-v")) maybe_linking = 0;
while (--argc) {
u8* cur = *(++argv);
if (!strcmp(cur, "-m32")) bit_mode = 32;
if (!strcmp(cur, "-m64")) bit_mode = 64;
if (!strcmp(cur, "-x")) x_set = 1;
if (!strcmp(cur, "-c") || !strcmp(cur, "-S") || !strcmp(cur, "-E"))
maybe_linking = 0;
if (!strcmp(cur, "-fsanitize=address") ||
!strcmp(cur, "-fsanitize=memory")) asan_set = 1;
if (strstr(cur, "FORTIFY_SOURCE")) fortify_set = 1;
if (!strcmp(cur, "-shared")) maybe_linking = 0;
if (!strcmp(cur, "-Wl,-z,defs") ||
!strcmp(cur, "-Wl,--no-undefined")) continue;
cc_params[cc_par_cnt++] = cur;
}
if (getenv("AFL_HARDEN")) {
cc_params[cc_par_cnt++] = "-fstack-protector-all";
if (!fortify_set)
cc_params[cc_par_cnt++] = "-D_FORTIFY_SOURCE=2";
}
if (!asan_set) {
if (getenv("AFL_USE_ASAN")) {
if (getenv("AFL_USE_MSAN"))
FATAL("ASAN and MSAN are mutually exclusive");
if (getenv("AFL_HARDEN"))
FATAL("ASAN and AFL_HARDEN are mutually exclusive");
cc_params[cc_par_cnt++] = "-U_FORTIFY_SOURCE";
cc_params[cc_par_cnt++] = "-fsanitize=address";
} else if (getenv("AFL_USE_MSAN")) {
if (getenv("AFL_USE_ASAN"))
FATAL("ASAN and MSAN are mutually exclusive");
if (getenv("AFL_HARDEN"))
FATAL("MSAN and AFL_HARDEN are mutually exclusive");
cc_params[cc_par_cnt++] = "-U_FORTIFY_SOURCE";
cc_params[cc_par_cnt++] = "-fsanitize=memory";
}
}
#ifdef USE_TRACE_PC
if (getenv("AFL_INST_RATIO"))
FATAL("AFL_INST_RATIO not available at compile time with 'trace-pc'.");
#endif /* USE_TRACE_PC */
if (!getenv("AFL_DONT_OPTIMIZE")) {
cc_params[cc_par_cnt++] = "-g";
cc_params[cc_par_cnt++] = "-O3";
cc_params[cc_par_cnt++] = "-funroll-loops";
}
if (getenv("AFL_NO_BUILTIN")) {
cc_params[cc_par_cnt++] = "-fno-builtin-strcmp";
cc_params[cc_par_cnt++] = "-fno-builtin-strncmp";
cc_params[cc_par_cnt++] = "-fno-builtin-strcasecmp";
cc_params[cc_par_cnt++] = "-fno-builtin-strncasecmp";
cc_params[cc_par_cnt++] = "-fno-builtin-memcmp";
}
cc_params[cc_par_cnt++] = "-D__AFL_HAVE_MANUAL_CONTROL=1";
cc_params[cc_par_cnt++] = "-D__AFL_COMPILER=1";
cc_params[cc_par_cnt++] = "-DFUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION=1";
/* When the user tries to use persistent or deferred forkserver modes by
appending a single line to the program, we want to reliably inject a
signature into the binary (to be picked up by afl-fuzz) and we want
to call a function from the runtime .o file. This is unnecessarily
painful for three reasons:
1) We need to convince the compiler not to optimize out the signature.
This is done with __attribute__((used)).
2) We need to convince the linker, when called with -Wl,--gc-sections,
not to do the same. This is done by forcing an assignment to a
'volatile' pointer.
3) We need to declare __afl_persistent_loop() in the global namespace,
but doing this within a method in a class is hard - :: and extern "C"
are forbidden and __attribute__((alias(...))) doesn't work. Hence the
__asm__ aliasing trick.
*/
cc_params[cc_par_cnt++] = "-D__AFL_LOOP(_A)="
"({ static volatile char *_B __attribute__((used)); "
" _B = (char*)\"" PERSIST_SIG "\"; "
#ifdef __APPLE__
"__attribute__((visibility(\"default\"))) "
"int _L(unsigned int) __asm__(\"___afl_persistent_loop\"); "
#else
"__attribute__((visibility(\"default\"))) "
"int _L(unsigned int) __asm__(\"__afl_persistent_loop\"); "
#endif /* ^__APPLE__ */
"_L(_A); })";
cc_params[cc_par_cnt++] = "-D__AFL_INIT()="
"do { static volatile char *_A __attribute__((used)); "
" _A = (char*)\"" DEFER_SIG "\"; "
#ifdef __APPLE__
"__attribute__((visibility(\"default\"))) "
"void _I(void) __asm__(\"___afl_manual_init\"); "
#else
"__attribute__((visibility(\"default\"))) "
"void _I(void) __asm__(\"__afl_manual_init\"); "
#endif /* ^__APPLE__ */
"_I(); } while (0)";
if (maybe_linking) {
if (x_set) {
cc_params[cc_par_cnt++] = "-x";
cc_params[cc_par_cnt++] = "none";
}
switch (bit_mode) {
case 0:
cc_params[cc_par_cnt++] = alloc_printf("%s/afl-llvm-rt.o", obj_path);
break;
case 32:
cc_params[cc_par_cnt++] = alloc_printf("%s/afl-llvm-rt-32.o", obj_path);
if (access(cc_params[cc_par_cnt - 1], R_OK))
FATAL("-m32 is not supported by your compiler");
break;
case 64:
cc_params[cc_par_cnt++] = alloc_printf("%s/afl-llvm-rt-64.o", obj_path);
if (access(cc_params[cc_par_cnt - 1], R_OK))
FATAL("-m64 is not supported by your compiler");
break;
}
}
cc_params[cc_par_cnt] = NULL;
}
/* Main entry point */
int main(int argc, char** argv) {
if (isatty(2) && !getenv("AFL_QUIET")) {
#ifdef USE_TRACE_PC
SAYF(cCYA "afl-clang-fast [tpcg] " cBRI VERSION cRST " by <lszekeres@google.com>\n");
#else
SAYF(cCYA "afl-clang-fast " cBRI VERSION cRST " by <lszekeres@google.com>\n");
#endif /* ^USE_TRACE_PC */
}
if (argc < 2) {
SAYF("\n"
"This is a helper application for afl-fuzz. It serves as a drop-in replacement\n"
"for clang, letting you recompile third-party code with the required runtime\n"
"instrumentation. A common use pattern would be one of the following:\n\n"
" CC=%s/afl-clang-fast ./configure\n"
" CXX=%s/afl-clang-fast++ ./configure\n\n"
"In contrast to the traditional afl-clang tool, this version is implemented as\n"
"an LLVM pass and tends to offer improved performance with slow programs.\n\n"
"You can specify custom next-stage toolchain via AFL_CC and AFL_CXX. Setting\n"
"AFL_HARDEN enables hardening optimizations in the compiled code.\n\n",
BIN_PATH, BIN_PATH);
exit(1);
}
find_obj(argv[0]);
edit_params(argc, argv);
execvp(cc_params[0], (char**)cc_params);
FATAL("Oops, failed to execute '%s' - check your PATH", cc_params[0]);
return 0;
}


@@ -0,0 +1,221 @@
/*
american fuzzy lop - LLVM-mode instrumentation pass
---------------------------------------------------
Written by Laszlo Szekeres <lszekeres@google.com> and
Michal Zalewski <lcamtuf@google.com>
LLVM integration design comes from Laszlo Szekeres. C bits copied-and-pasted
from afl-as.c are Michal's fault.
Copyright 2015, 2016 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at:
http://www.apache.org/licenses/LICENSE-2.0
This library is plugged into LLVM when invoking clang through afl-clang-fast.
It tells the compiler to add code roughly equivalent to the bits discussed
in ../afl-as.h.
*/
#define AFL_LLVM_PASS
#include "../config.h"
#include "../debug.h"
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include "llvm/IR/BasicBlock.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/CFG.h"
#include <algorithm>
using namespace llvm;
namespace {
class AFLCoverage : public ModulePass {
public:
static char ID;
AFLCoverage() : ModulePass(ID) { }
bool runOnModule(Module &M) override;
// StringRef getPassName() const override {
// return "American Fuzzy Lop Instrumentation";
// }
};
}
char AFLCoverage::ID = 0;
bool AFLCoverage::runOnModule(Module &M) {
LLVMContext &C = M.getContext();
IntegerType *Int8Ty = IntegerType::getInt8Ty(C);
IntegerType *Int32Ty = IntegerType::getInt32Ty(C);
unsigned int cur_loc = 0;
/* Show a banner */
char be_quiet = 0;
if (isatty(2) && !getenv("AFL_QUIET")) {
SAYF(cCYA "afl-llvm-pass " cBRI VERSION cRST " by <lszekeres@google.com>\n");
} else be_quiet = 1;
/* Decide instrumentation ratio */
char* inst_ratio_str = getenv("AFL_INST_RATIO");
unsigned int inst_ratio = 100;
if (inst_ratio_str) {
if (sscanf(inst_ratio_str, "%u", &inst_ratio) != 1 || !inst_ratio ||
inst_ratio > 100)
FATAL("Bad value of AFL_INST_RATIO (must be between 1 and 100)");
}
/* Get globals for the SHM region and the previous location. Note that
__afl_prev_loc is thread-local. */
GlobalVariable *AFLMapPtr =
new GlobalVariable(M, PointerType::get(Int8Ty, 0), false,
GlobalValue::ExternalLinkage, 0, "__afl_area_ptr");
GlobalVariable *AFLPrevLoc = new GlobalVariable(
M, Int32Ty, false, GlobalValue::ExternalLinkage, 0, "__afl_prev_loc",
0, GlobalVariable::GeneralDynamicTLSModel, 0, false);
/* Instrument all the things! */
int inst_blocks = 0;
for (auto &F : M)
for (auto &BB : F) {
BasicBlock::iterator IP = BB.getFirstInsertionPt();
IRBuilder<> IRB(&(*IP));
if (AFL_R(100) >= inst_ratio) continue;
/* Make up cur_loc */
//cur_loc++;
cur_loc = AFL_R(MAP_SIZE);
// only instrument if this basic block is the destination of a previous
// basic block that has multiple successors
// this gets rid of ~5-10% of instrumentations that are unnecessary
// result: a little more speed and less map pollution
int more_than_one = -1;
//fprintf(stderr, "BB %u: ", cur_loc);
for (BasicBlock *Pred : predecessors(&BB)) {
int count = 0;
if (more_than_one == -1)
more_than_one = 0;
//fprintf(stderr, " %p=>", Pred);
for (BasicBlock *Succ : successors(Pred)) {
//if (count > 0)
// fprintf(stderr, "|");
if (Succ != NULL) count++;
//fprintf(stderr, "%p", Succ);
}
if (count > 1)
more_than_one = 1;
}
//fprintf(stderr, " == %d\n", more_than_one);
if (more_than_one != 1)
continue;
ConstantInt *CurLoc = ConstantInt::get(Int32Ty, cur_loc);
/* Load prev_loc */
LoadInst *PrevLoc = IRB.CreateLoad(AFLPrevLoc);
PrevLoc->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None));
Value *PrevLocCasted = IRB.CreateZExt(PrevLoc, IRB.getInt32Ty());
/* Load SHM pointer */
LoadInst *MapPtr = IRB.CreateLoad(AFLMapPtr);
MapPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None));
Value *MapPtrIdx =
IRB.CreateGEP(MapPtr, IRB.CreateXor(PrevLocCasted, CurLoc));
/* Update bitmap */
LoadInst *Counter = IRB.CreateLoad(MapPtrIdx);
Counter->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None));
Value *Incr = IRB.CreateAdd(Counter, ConstantInt::get(Int8Ty, 1));
IRB.CreateStore(Incr, MapPtrIdx)
->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None));
/* Set prev_loc to cur_loc >> 1 */
StoreInst *Store =
IRB.CreateStore(ConstantInt::get(Int32Ty, cur_loc >> 1), AFLPrevLoc);
Store->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None));
inst_blocks++;
}
/* Say something nice. */
if (!be_quiet) {
if (!inst_blocks) WARNF("No instrumentation targets found.");
else OKF("Instrumented %u locations (%s mode, ratio %u%%).",
inst_blocks, getenv("AFL_HARDEN") ? "hardened" :
((getenv("AFL_USE_ASAN") || getenv("AFL_USE_MSAN")) ?
"ASAN/MSAN" : "non-hardened"), inst_ratio);
}
return true;
}
static void registerAFLPass(const PassManagerBuilder &,
legacy::PassManagerBase &PM) {
PM.add(new AFLCoverage());
}
static RegisterStandardPasses RegisterAFLPass(
PassManagerBuilder::EP_OptimizerLast, registerAFLPass);
static RegisterStandardPasses RegisterAFLPass0(
PassManagerBuilder::EP_EnabledOnOptLevel0, registerAFLPass);

309
llvm_mode/afl-llvm-rt.o.c Normal file

@@ -0,0 +1,309 @@
/*
american fuzzy lop - LLVM instrumentation bootstrap
---------------------------------------------------
Written by Laszlo Szekeres <lszekeres@google.com> and
Michal Zalewski <lcamtuf@google.com>
LLVM integration design comes from Laszlo Szekeres.
Copyright 2015, 2016 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at:
http://www.apache.org/licenses/LICENSE-2.0
This code is the rewrite of afl-as.h's main_payload.
*/
#include "../config.h"
#include "../types.h"
#include <stdio.h>
#include <stdlib.h>
#include <signal.h>
#include <unistd.h>
#include <string.h>
#include <assert.h>
#include <sys/mman.h>
#include <sys/shm.h>
#include <sys/wait.h>
#include <sys/types.h>
/* This is a somewhat ugly hack for the experimental 'trace-pc-guard' mode.
Basically, we need to make sure that the forkserver is initialized after
the LLVM-generated runtime initialization pass, not before. */
#ifdef USE_TRACE_PC
# define CONST_PRIO 5
#else
# define CONST_PRIO 0
#endif /* ^USE_TRACE_PC */
/* Globals needed by the injected instrumentation. The __afl_area_initial region
is used for instrumentation output before __afl_map_shm() has a chance to run.
It will end up as .comm, so it shouldn't be too wasteful. */
u8 __afl_area_initial[MAP_SIZE];
u8* __afl_area_ptr = __afl_area_initial;
__thread u32 __afl_prev_loc;
/* Running in persistent mode? */
static u8 is_persistent;
/* SHM setup. */
static void __afl_map_shm(void) {
u8 *id_str = getenv(SHM_ENV_VAR);
/* If we're running under AFL, attach to the appropriate region, replacing the
early-stage __afl_area_initial region that is needed to allow some really
hacky .init code to work correctly in projects such as OpenSSL. */
if (id_str) {
u32 shm_id = atoi(id_str);
__afl_area_ptr = shmat(shm_id, NULL, 0);
/* Whooooops. */
if (__afl_area_ptr == (void *)-1) _exit(1);
/* Write something into the bitmap so that even with low AFL_INST_RATIO,
our parent doesn't give up on us. */
__afl_area_ptr[0] = 1;
}
}
/* Fork server logic. */
static void __afl_start_forkserver(void) {
static u8 tmp[4];
s32 child_pid;
u8 child_stopped = 0;
void (*old_sigchld_handler)(int) = signal(SIGCHLD, SIG_DFL);
/* Phone home and tell the parent that we're OK. If parent isn't there,
assume we're not running in forkserver mode and just execute program. */
if (write(FORKSRV_FD + 1, tmp, 4) != 4) return;
while (1) {
u32 was_killed;
int status;
/* Wait for parent by reading from the pipe. Abort if read fails. */
if (read(FORKSRV_FD, &was_killed, 4) != 4) _exit(1);
/* If we stopped the child in persistent mode, but there was a race
condition and afl-fuzz already issued SIGKILL, write off the old
process. */
if (child_stopped && was_killed) {
child_stopped = 0;
if (waitpid(child_pid, &status, 0) < 0) _exit(1);
}
if (!child_stopped) {
/* Once woken up, create a clone of our process. */
child_pid = fork();
if (child_pid < 0) _exit(1);
/* In child process: close fds, resume execution. */
if (!child_pid) {
signal(SIGCHLD, old_sigchld_handler);
close(FORKSRV_FD);
close(FORKSRV_FD + 1);
return;
}
} else {
/* Special handling for persistent mode: if the child is alive but
currently stopped, simply restart it with SIGCONT. */
kill(child_pid, SIGCONT);
child_stopped = 0;
}
/* In parent process: write PID to pipe, then wait for child. */
if (write(FORKSRV_FD + 1, &child_pid, 4) != 4) _exit(1);
if (waitpid(child_pid, &status, is_persistent ? WUNTRACED : 0) < 0)
_exit(1);
/* In persistent mode, the child stops itself with SIGSTOP to indicate
a successful run. In this case, we want to wake it up without forking
again. */
if (WIFSTOPPED(status)) child_stopped = 1;
/* Relay wait status to pipe, then loop back. */
if (write(FORKSRV_FD + 1, &status, 4) != 4) _exit(1);
}
}
/* A simplified persistent mode handler, used as explained in README.llvm. */
int __afl_persistent_loop(unsigned int max_cnt) {
static u8 first_pass = 1;
static u32 cycle_cnt;
if (first_pass) {
/* Make sure that every iteration of __AFL_LOOP() starts with a clean slate.
On subsequent calls, the parent will take care of that, but on the first
iteration, it's our job to erase any trace of whatever happened
before the loop. */
if (is_persistent) {
memset(__afl_area_ptr, 0, MAP_SIZE);
__afl_area_ptr[0] = 1;
__afl_prev_loc = 0;
}
cycle_cnt = max_cnt;
first_pass = 0;
return 1;
}
if (is_persistent) {
if (--cycle_cnt) {
raise(SIGSTOP);
__afl_area_ptr[0] = 1;
__afl_prev_loc = 0;
return 1;
} else {
/* When exiting __AFL_LOOP(), make sure that the subsequent code that
follows the loop is not traced. We do that by pivoting back to the
dummy output region. */
__afl_area_ptr = __afl_area_initial;
}
}
return 0;
}
/* This one can be called from user code when deferred forkserver mode
is enabled. */
void __afl_manual_init(void) {
static u8 init_done;
if (!init_done) {
__afl_map_shm();
__afl_start_forkserver();
init_done = 1;
}
}
/* Proper initialization routine. */
__attribute__((constructor(CONST_PRIO))) void __afl_auto_init(void) {
is_persistent = !!getenv(PERSIST_ENV_VAR);
if (getenv(DEFER_ENV_VAR)) return;
__afl_manual_init();
}
/* The following stuff deals with supporting -fsanitize-coverage=trace-pc-guard.
It remains non-operational in the traditional, plugin-backed LLVM mode.
For more info about 'trace-pc-guard', see README.llvm.
The first function (__sanitizer_cov_trace_pc_guard) is called back on every
edge (as opposed to every basic block). */
void __sanitizer_cov_trace_pc_guard(uint32_t* guard) {
__afl_area_ptr[*guard]++;
}
/* Init callback. Populates instrumentation IDs. Note that we're using
ID of 0 as a special value to indicate non-instrumented bits. That may
still touch the bitmap, but in a fairly harmless way. */
void __sanitizer_cov_trace_pc_guard_init(uint32_t* start, uint32_t* stop) {
u32 inst_ratio = 100;
u8* x;
if (start == stop || *start) return;
x = getenv("AFL_INST_RATIO");
if (x) inst_ratio = atoi(x);
if (!inst_ratio || inst_ratio > 100) {
fprintf(stderr, "[-] ERROR: Invalid AFL_INST_RATIO (must be 1-100).\n");
abort();
}
/* Make sure that the first element in the range is always set - we use that
to avoid duplicate calls (which can happen as an artifact of the underlying
implementation in LLVM). */
*(start++) = R(MAP_SIZE - 1) + 1;
while (start < stop) {
if (R(100) < inst_ratio) *start = R(MAP_SIZE - 1) + 1;
else *start = 0;
start++;
}
}


@@ -0,0 +1,306 @@
/*
* Copyright 2016 laf-intel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Pass.h"
#include "llvm/Analysis/ValueTracking.h"
#include <set>
using namespace llvm;
namespace {
class CompareTransform : public ModulePass {
public:
static char ID;
CompareTransform() : ModulePass(ID) {
}
bool runOnModule(Module &M) override;
#if __clang_major__ < 4
const char * getPassName() const override {
#else
StringRef getPassName() const override {
#endif
return "transforms compare functions";
}
private:
bool transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp
,const bool processStrncmp, const bool processStrcasecmp, const bool processStrncasecmp);
};
}
char CompareTransform::ID = 0;
bool CompareTransform::transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp
, const bool processStrncmp, const bool processStrcasecmp, const bool processStrncasecmp) {
std::vector<CallInst*> calls;
LLVMContext &C = M.getContext();
IntegerType *Int8Ty = IntegerType::getInt8Ty(C);
IntegerType *Int32Ty = IntegerType::getInt32Ty(C);
IntegerType *Int64Ty = IntegerType::getInt64Ty(C);
Constant* c = M.getOrInsertFunction("tolower",
Int32Ty,
Int32Ty
#if __clang_major__ < 7
, nullptr
#endif
);
Function* tolowerFn = cast<Function>(c);
/* iterate over all functions, basic blocks and instructions and collect suitable calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp */
for (auto &F : M) {
for (auto &BB : F) {
for(auto &IN: BB) {
CallInst* callInst = nullptr;
if ((callInst = dyn_cast<CallInst>(&IN))) {
bool isStrcmp = processStrcmp;
bool isMemcmp = processMemcmp;
bool isStrncmp = processStrncmp;
bool isStrcasecmp = processStrcasecmp;
bool isStrncasecmp = processStrncasecmp;
Function *Callee = callInst->getCalledFunction();
if (!Callee)
continue;
if (callInst->getCallingConv() != llvm::CallingConv::C)
continue;
StringRef FuncName = Callee->getName();
isStrcmp &= !FuncName.compare(StringRef("strcmp"));
isMemcmp &= !FuncName.compare(StringRef("memcmp"));
isStrncmp &= !FuncName.compare(StringRef("strncmp"));
isStrcasecmp &= !FuncName.compare(StringRef("strcasecmp"));
isStrncasecmp &= !FuncName.compare(StringRef("strncasecmp"));
if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && !isStrncasecmp)
continue;
/* Verify the strcmp/memcmp/strncmp/strcasecmp/strncasecmp function prototype */
FunctionType *FT = Callee->getFunctionType();
isStrcmp &= FT->getNumParams() == 2 &&
FT->getReturnType()->isIntegerTy(32) &&
FT->getParamType(0) == FT->getParamType(1) &&
FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
isStrcasecmp &= FT->getNumParams() == 2 &&
FT->getReturnType()->isIntegerTy(32) &&
FT->getParamType(0) == FT->getParamType(1) &&
FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
isMemcmp &= FT->getNumParams() == 3 &&
FT->getReturnType()->isIntegerTy(32) &&
FT->getParamType(0)->isPointerTy() &&
FT->getParamType(1)->isPointerTy() &&
FT->getParamType(2)->isIntegerTy();
isStrncmp &= FT->getNumParams() == 3 &&
FT->getReturnType()->isIntegerTy(32) &&
FT->getParamType(0) == FT->getParamType(1) &&
FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()) &&
FT->getParamType(2)->isIntegerTy();
isStrncasecmp &= FT->getNumParams() == 3 &&
FT->getReturnType()->isIntegerTy(32) &&
FT->getParamType(0) == FT->getParamType(1) &&
FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()) &&
FT->getParamType(2)->isIntegerTy();
if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && !isStrncasecmp)
continue;
/* is a str{n,}{case,}cmp/memcmp, check if we have
* str{case,}cmp(x, "const") or str{case,}cmp("const", x)
* strn{case,}cmp(x, "const", ..) or strn{case,}cmp("const", x, ..)
* memcmp(x, "const", ..) or memcmp("const", x, ..) */
Value *Str1P = callInst->getArgOperand(0), *Str2P = callInst->getArgOperand(1);
StringRef Str1, Str2;
bool HasStr1 = getConstantStringInfo(Str1P, Str1);
bool HasStr2 = getConstantStringInfo(Str2P, Str2);
/* only handle the case where exactly one string is constant and the other is variable */
if (!(HasStr1 ^ HasStr2))
continue;
if (isMemcmp || isStrncmp || isStrncasecmp) {
/* check if third operand is a constant integer
* strlen("constStr") and sizeof() are treated as constant */
Value *op2 = callInst->getArgOperand(2);
ConstantInt* ilen = dyn_cast<ConstantInt>(op2);
if (!ilen)
continue;
/* final precaution: if the size of the compare is larger than the constant string, skip it */
uint64_t literalLength = HasStr1 ? GetStringLength(Str1P) : GetStringLength(Str2P);
if (literalLength < ilen->getZExtValue())
continue;
}
calls.push_back(callInst);
}
}
}
}
if (!calls.size())
return false;
errs() << "Replacing " << calls.size() << " calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp\n";
for (auto &callInst: calls) {
Value *Str1P = callInst->getArgOperand(0), *Str2P = callInst->getArgOperand(1);
StringRef Str1, Str2, ConstStr;
Value *VarStr;
bool HasStr1 = getConstantStringInfo(Str1P, Str1);
getConstantStringInfo(Str2P, Str2);
uint64_t constLen, sizedLen;
bool isMemcmp = !callInst->getCalledFunction()->getName().compare(StringRef("memcmp"));
bool isSizedcmp = isMemcmp
|| !callInst->getCalledFunction()->getName().compare(StringRef("strncmp"))
|| !callInst->getCalledFunction()->getName().compare(StringRef("strncasecmp"));
bool isCaseInsensitive = !callInst->getCalledFunction()->getName().compare(StringRef("strcasecmp"))
|| !callInst->getCalledFunction()->getName().compare(StringRef("strncasecmp"));
if (isSizedcmp) {
Value *op2 = callInst->getArgOperand(2);
ConstantInt* ilen = dyn_cast<ConstantInt>(op2);
sizedLen = ilen->getZExtValue();
}
if (HasStr1) {
ConstStr = Str1;
VarStr = Str2P;
constLen = isMemcmp ? sizedLen : GetStringLength(Str1P);
}
else {
ConstStr = Str2;
VarStr = Str1P;
constLen = isMemcmp ? sizedLen : GetStringLength(Str2P);
}
if (isSizedcmp && constLen > sizedLen) {
constLen = sizedLen;
}
errs() << callInst->getCalledFunction()->getName() << ": len " << constLen << ": " << ConstStr << "\n";
/* split before the call instruction */
BasicBlock *bb = callInst->getParent();
BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(callInst));
BasicBlock *next_bb = BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
BranchInst::Create(end_bb, next_bb);
PHINode *PN = PHINode::Create(Int32Ty, constLen + 1, "cmp_phi");
TerminatorInst *term = bb->getTerminator();
BranchInst::Create(next_bb, bb);
term->eraseFromParent();
for (uint64_t i = 0; i < constLen; i++) {
BasicBlock *cur_bb = next_bb;
char c = isCaseInsensitive ? tolower(ConstStr[i]) : ConstStr[i];
BasicBlock::iterator IP = next_bb->getFirstInsertionPt();
IRBuilder<> IRB(&*IP);
Value* v = ConstantInt::get(Int64Ty, i);
Value *ele = IRB.CreateInBoundsGEP(VarStr, v, "empty");
Value *load = IRB.CreateLoad(ele);
if (isCaseInsensitive) {
// load >= 'A' && load <= 'Z' ? load | 0x020 : load
std::vector<Value *> args;
args.push_back(load);
load = IRB.CreateCall(tolowerFn, args, "tmp");
}
Value *isub;
if (HasStr1)
isub = IRB.CreateSub(ConstantInt::get(Int8Ty, c), load);
else
isub = IRB.CreateSub(load, ConstantInt::get(Int8Ty, c));
Value *sext = IRB.CreateSExt(isub, Int32Ty);
PN->addIncoming(sext, cur_bb);
if (i < constLen - 1) {
next_bb = BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
BranchInst::Create(end_bb, next_bb);
TerminatorInst *term = cur_bb->getTerminator();
Value *icmp = IRB.CreateICmpEQ(isub, ConstantInt::get(Int8Ty, 0));
IRB.CreateCondBr(icmp, next_bb, end_bb);
term->eraseFromParent();
} else {
//IRB.CreateBr(end_bb);
}
//add offset to varstr
//create load
//create signed isub
//create icmp
//create jcc
//create next_bb
}
/* since the call is the first instruction of the bb it is safe to
* replace it with a phi instruction */
BasicBlock::iterator ii(callInst);
ReplaceInstWithInst(callInst->getParent()->getInstList(), ii, PN);
}
return true;
}
bool CompareTransform::runOnModule(Module &M) {
llvm::errs() << "Running compare-transform-pass by laf.intel@gmail.com, extended by heiko@hexco.de\n";
transformCmps(M, true, true, true, true, true);
verifyModule(M);
return true;
}
static void registerCompTransPass(const PassManagerBuilder &,
legacy::PassManagerBase &PM) {
auto p = new CompareTransform();
PM.add(p);
}
static RegisterStandardPasses RegisterCompTransPass(
PassManagerBuilder::EP_OptimizerLast, registerCompTransPass);
static RegisterStandardPasses RegisterCompTransPass0(
PassManagerBuilder::EP_EnabledOnOptLevel0, registerCompTransPass);


@@ -0,0 +1,527 @@
/*
* Copyright 2016 laf-intel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
using namespace llvm;
namespace {
class SplitComparesTransform : public ModulePass {
public:
static char ID;
SplitComparesTransform() : ModulePass(ID) {}
bool runOnModule(Module &M) override;
#if __clang_major__ >= 4
StringRef getPassName() const override {
#else
const char * getPassName() const override {
#endif
return "simplifies and splits ICMP instructions";
}
private:
bool splitCompares(Module &M, unsigned bitw);
bool simplifyCompares(Module &M);
bool simplifySignedness(Module &M);
};
}
char SplitComparesTransform::ID = 0;
/* This function splits ICMP instructions with xGE or xLE predicates into two
* ICMP instructions with predicate xGT or xLT and EQ */
bool SplitComparesTransform::simplifyCompares(Module &M) {
LLVMContext &C = M.getContext();
std::vector<Instruction*> icomps;
IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
/* iterate over all functions, basic blocks and instructions and add
* all integer comparisons with >= and <= predicates to the icomps vector */
for (auto &F : M) {
for (auto &BB : F) {
for (auto &IN: BB) {
CmpInst* selectcmpInst = nullptr;
if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
if (selectcmpInst->getPredicate() != CmpInst::ICMP_UGE &&
selectcmpInst->getPredicate() != CmpInst::ICMP_SGE &&
selectcmpInst->getPredicate() != CmpInst::ICMP_ULE &&
selectcmpInst->getPredicate() != CmpInst::ICMP_SLE ) {
continue;
}
auto op0 = selectcmpInst->getOperand(0);
auto op1 = selectcmpInst->getOperand(1);
IntegerType* intTyOp0 = dyn_cast<IntegerType>(op0->getType());
IntegerType* intTyOp1 = dyn_cast<IntegerType>(op1->getType());
/* this is probably not needed but we do it anyway */
if (!intTyOp0 || !intTyOp1) {
continue;
}
icomps.push_back(selectcmpInst);
}
}
}
}
if (!icomps.size()) {
return false;
}
for (auto &IcmpInst: icomps) {
BasicBlock* bb = IcmpInst->getParent();
auto op0 = IcmpInst->getOperand(0);
auto op1 = IcmpInst->getOperand(1);
/* find out what the new predicate is going to be */
auto pred = dyn_cast<CmpInst>(IcmpInst)->getPredicate();
CmpInst::Predicate new_pred;
switch(pred) {
case CmpInst::ICMP_UGE:
new_pred = CmpInst::ICMP_UGT;
break;
case CmpInst::ICMP_SGE:
new_pred = CmpInst::ICMP_SGT;
break;
case CmpInst::ICMP_ULE:
new_pred = CmpInst::ICMP_ULT;
break;
case CmpInst::ICMP_SLE:
new_pred = CmpInst::ICMP_SLT;
break;
default: // keep the compiler happy
continue;
}
/* split before the icmp instruction */
BasicBlock* end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst));
/* the old bb now contains an unconditional jump to the new one (end_bb);
* we need to delete it later */
/* create the ICMP instruction with new_pred and add it to the old basic
* block bb it is now at the position where the old IcmpInst was */
Instruction* icmp_np;
icmp_np = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1);
bb->getInstList().insert(bb->getTerminator()->getIterator(), icmp_np);
/* create a new basic block which holds the new EQ icmp */
Instruction *icmp_eq;
/* insert middle_bb before end_bb */
BasicBlock* middle_bb = BasicBlock::Create(C, "injected",
end_bb->getParent(), end_bb);
icmp_eq = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, op0, op1);
middle_bb->getInstList().push_back(icmp_eq);
/* add an unconditional branch to the end of middle_bb with destination
* end_bb */
BranchInst::Create(end_bb, middle_bb);
/* replace the uncond branch with a conditional one, which depends on the
* new_pred icmp. True goes to end, false to the middle (injected) bb */
auto term = bb->getTerminator();
BranchInst::Create(end_bb, middle_bb, icmp_np, bb);
term->eraseFromParent();
/* replace the old IcmpInst (which is the first inst in end_bb) with a PHI
* inst to wire up the loose ends */
PHINode *PN = PHINode::Create(Int1Ty, 2, "");
/* the first result depends on the outcome of icmp_eq */
PN->addIncoming(icmp_eq, middle_bb);
/* if the source was the original bb we know that the icmp_np yielded true
* hence we can hardcode this value */
PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
/* replace the old IcmpInst with our new and shiny PHI inst */
BasicBlock::iterator ii(IcmpInst);
ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
}
return true;
}
/* this function transforms signed compares to equivalent unsigned compares */
bool SplitComparesTransform::simplifySignedness(Module &M) {
LLVMContext &C = M.getContext();
std::vector<Instruction*> icomps;
IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
/* iterate over all functions, basic blocks and instructions and add
* all signed compares to icomps vector */
for (auto &F : M) {
for (auto &BB : F) {
for(auto &IN: BB) {
CmpInst* selectcmpInst = nullptr;
if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
if (selectcmpInst->getPredicate() != CmpInst::ICMP_SGT &&
selectcmpInst->getPredicate() != CmpInst::ICMP_SLT
) {
continue;
}
auto op0 = selectcmpInst->getOperand(0);
auto op1 = selectcmpInst->getOperand(1);
IntegerType* intTyOp0 = dyn_cast<IntegerType>(op0->getType());
IntegerType* intTyOp1 = dyn_cast<IntegerType>(op1->getType());
/* see above */
if (!intTyOp0 || !intTyOp1) {
continue;
}
/* I think this is not possible, but too lazy to look it up */
if (intTyOp0->getBitWidth() != intTyOp1->getBitWidth()) {
continue;
}
icomps.push_back(selectcmpInst);
}
}
}
}
if (!icomps.size()) {
return false;
}
for (auto &IcmpInst: icomps) {
BasicBlock* bb = IcmpInst->getParent();
auto op0 = IcmpInst->getOperand(0);
auto op1 = IcmpInst->getOperand(1);
IntegerType* intTyOp0 = dyn_cast<IntegerType>(op0->getType());
unsigned bitw = intTyOp0->getBitWidth();
IntegerType *IntType = IntegerType::get(C, bitw);
/* get the new predicate */
auto pred = dyn_cast<CmpInst>(IcmpInst)->getPredicate();
CmpInst::Predicate new_pred;
if (pred == CmpInst::ICMP_SGT) {
new_pred = CmpInst::ICMP_UGT;
} else {
new_pred = CmpInst::ICMP_ULT;
}
BasicBlock* end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst));
/* create a 1-bit compare for the sign bit. To do this, shift and truncate
 * the original operands so that only the sign bit remains. */
Instruction *s_op0, *t_op0, *s_op1, *t_op1, *icmp_sign_bit;
s_op0 = BinaryOperator::Create(Instruction::LShr, op0, ConstantInt::get(IntType, bitw - 1));
bb->getInstList().insert(bb->getTerminator()->getIterator(), s_op0);
t_op0 = new TruncInst(s_op0, Int1Ty);
bb->getInstList().insert(bb->getTerminator()->getIterator(), t_op0);
s_op1 = BinaryOperator::Create(Instruction::LShr, op1, ConstantInt::get(IntType, bitw - 1));
bb->getInstList().insert(bb->getTerminator()->getIterator(), s_op1);
t_op1 = new TruncInst(s_op1, Int1Ty);
bb->getInstList().insert(bb->getTerminator()->getIterator(), t_op1);
/* compare of the sign bits */
icmp_sign_bit = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_op0, t_op1);
bb->getInstList().insert(bb->getTerminator()->getIterator(), icmp_sign_bit);
/* create a new basic block which is executed if the signedness bit is
* different */
Instruction *icmp_inv_sig_cmp;
BasicBlock* sign_bb = BasicBlock::Create(C, "sign", end_bb->getParent(), end_bb);
if (pred == CmpInst::ICMP_SGT) {
/* if we check for > and op0 is positive while op1 is negative, the final
 * result is true. If op0 is negative and op1 is positive, the compare
 * must result in false.
 */
icmp_inv_sig_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_op0, t_op1);
} else {
/* just the inverse of the above statement */
icmp_inv_sig_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_op0, t_op1);
}
sign_bb->getInstList().push_back(icmp_inv_sig_cmp);
BranchInst::Create(end_bb, sign_bb);
/* create a new bb which is executed if signedness is equal */
Instruction *icmp_usign_cmp;
BasicBlock* middle_bb = BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
/* we can do a normal unsigned compare now */
icmp_usign_cmp = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1);
middle_bb->getInstList().push_back(icmp_usign_cmp);
BranchInst::Create(end_bb, middle_bb);
auto term = bb->getTerminator();
/* if the sign is eq do a normal unsigned cmp, else we have to check the
* signedness bit */
BranchInst::Create(middle_bb, sign_bb, icmp_sign_bit, bb);
term->eraseFromParent();
PHINode *PN = PHINode::Create(Int1Ty, 2, "");
PN->addIncoming(icmp_usign_cmp, middle_bb);
PN->addIncoming(icmp_inv_sig_cmp, sign_bb);
BasicBlock::iterator ii(IcmpInst);
ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
}
return true;
}
/* splits icmps of size bitw into two nested icmps with bitw/2 size each */
bool SplitComparesTransform::splitCompares(Module &M, unsigned bitw) {
LLVMContext &C = M.getContext();
IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
IntegerType *OldIntType = IntegerType::get(C, bitw);
IntegerType *NewIntType = IntegerType::get(C, bitw / 2);
std::vector<Instruction*> icomps;
if (bitw % 2) {
return false;
}
/* not supported yet */
if (bitw > 64) {
return false;
}
/* get all EQ, NE, UGT, and ULT icmps of width bitw. If the other two
 * functions were executed, only these four predicates should exist */
for (auto &F : M) {
for (auto &BB : F) {
for(auto &IN: BB) {
CmpInst* selectcmpInst = nullptr;
if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
if(selectcmpInst->getPredicate() != CmpInst::ICMP_EQ &&
selectcmpInst->getPredicate() != CmpInst::ICMP_NE &&
selectcmpInst->getPredicate() != CmpInst::ICMP_UGT &&
selectcmpInst->getPredicate() != CmpInst::ICMP_ULT
) {
continue;
}
auto op0 = selectcmpInst->getOperand(0);
auto op1 = selectcmpInst->getOperand(1);
IntegerType* intTyOp0 = dyn_cast<IntegerType>(op0->getType());
IntegerType* intTyOp1 = dyn_cast<IntegerType>(op1->getType());
if (!intTyOp0 || !intTyOp1) {
continue;
}
/* check if the bitwidths are the one we are looking for */
if (intTyOp0->getBitWidth() != bitw || intTyOp1->getBitWidth() != bitw) {
continue;
}
icomps.push_back(selectcmpInst);
}
}
}
}
if (!icomps.size()) {
return false;
}
for (auto &IcmpInst: icomps) {
BasicBlock* bb = IcmpInst->getParent();
auto op0 = IcmpInst->getOperand(0);
auto op1 = IcmpInst->getOperand(1);
auto pred = dyn_cast<CmpInst>(IcmpInst)->getPredicate();
BasicBlock* end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst));
/* create the comparison of the top halves of the original operands */
Instruction *s_op0, *op0_high, *s_op1, *op1_high, *icmp_high;
s_op0 = BinaryOperator::Create(Instruction::LShr, op0, ConstantInt::get(OldIntType, bitw / 2));
bb->getInstList().insert(bb->getTerminator()->getIterator(), s_op0);
op0_high = new TruncInst(s_op0, NewIntType);
bb->getInstList().insert(bb->getTerminator()->getIterator(), op0_high);
s_op1 = BinaryOperator::Create(Instruction::LShr, op1, ConstantInt::get(OldIntType, bitw / 2));
bb->getInstList().insert(bb->getTerminator()->getIterator(), s_op1);
op1_high = new TruncInst(s_op1, NewIntType);
bb->getInstList().insert(bb->getTerminator()->getIterator(), op1_high);
icmp_high = CmpInst::Create(Instruction::ICmp, pred, op0_high, op1_high);
bb->getInstList().insert(bb->getTerminator()->getIterator(), icmp_high);
/* now we have to distinguish between == / != and > / < */
if (pred == CmpInst::ICMP_EQ || pred == CmpInst::ICMP_NE) {
/* transformation for == and != icmps */
/* create a compare for the lower half of the original operands */
Instruction *op0_low, *op1_low, *icmp_low;
BasicBlock* cmp_low_bb = BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
op0_low = new TruncInst(op0, NewIntType);
cmp_low_bb->getInstList().push_back(op0_low);
op1_low = new TruncInst(op1, NewIntType);
cmp_low_bb->getInstList().push_back(op1_low);
icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low);
cmp_low_bb->getInstList().push_back(icmp_low);
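/* cmp_low_bb falls through unconditionally to end_bb */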
BranchInst::Create(end_bb, cmp_low_bb);
/* depending on the result of the high-half compare, either go straight to
 * the end or continue with the low-half comparison */
auto term = bb->getTerminator();
if (pred == CmpInst::ICMP_EQ) {
BranchInst::Create(cmp_low_bb, end_bb, icmp_high, bb);
} else {
/* CmpInst::ICMP_NE */
BranchInst::Create(end_bb, cmp_low_bb, icmp_high, bb);
}
term->eraseFromParent();
/* create the PHI and connect the edges accordingly */
PHINode *PN = PHINode::Create(Int1Ty, 2, "");
PN->addIncoming(icmp_low, cmp_low_bb);
if (pred == CmpInst::ICMP_EQ) {
PN->addIncoming(ConstantInt::get(Int1Ty, 0), bb);
} else {
/* CmpInst::ICMP_NE */
PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
}
/* replace the old icmp with the new PHI */
BasicBlock::iterator ii(IcmpInst);
ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
} else {
/* CmpInst::ICMP_UGT and CmpInst::ICMP_ULT */
/* transformations for < and > */
/* create a basic block which checks for the inverse predicate.
 * if it holds we can go straight to the end; if not we have to go to the
 * bb which compares the lower halves of the operands */
Instruction *icmp_inv_cmp, *op0_low, *op1_low, *icmp_low;
BasicBlock* inv_cmp_bb = BasicBlock::Create(C, "inv_cmp", end_bb->getParent(), end_bb);
if (pred == CmpInst::ICMP_UGT) {
icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, op0_high, op1_high);
} else {
icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, op0_high, op1_high);
}
inv_cmp_bb->getInstList().push_back(icmp_inv_cmp);
auto term = bb->getTerminator();
term->eraseFromParent();
BranchInst::Create(end_bb, inv_cmp_bb, icmp_high, bb);
/* create a bb which handles the compare of the lower halves */
BasicBlock* cmp_low_bb = BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
op0_low = new TruncInst(op0, NewIntType);
cmp_low_bb->getInstList().push_back(op0_low);
op1_low = new TruncInst(op1, NewIntType);
cmp_low_bb->getInstList().push_back(op1_low);
icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low);
cmp_low_bb->getInstList().push_back(icmp_low);
BranchInst::Create(end_bb, cmp_low_bb);
BranchInst::Create(end_bb, cmp_low_bb, icmp_inv_cmp, inv_cmp_bb);
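/* three incoming values: 1 if the high-half compare already satisfied the
 * predicate (from bb), 0 if the inverse predicate held on the high halves
 * (from inv_cmp_bb), and the low-half compare result otherwise, i.e. when
 * the high halves were equal (from cmp_low_bb) */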
PHINode *PN = PHINode::Create(Int1Ty, 3);
PN->addIncoming(icmp_low, cmp_low_bb);
PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
PN->addIncoming(ConstantInt::get(Int1Ty, 0), inv_cmp_bb);
BasicBlock::iterator ii(IcmpInst);
ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
}
}
return true;
}
bool SplitComparesTransform::runOnModule(Module &M) {
int bitw = 64;
char* bitw_env = getenv("LAF_SPLIT_COMPARES_BITW");
if (bitw_env) {
bitw = atoi(bitw_env);
}
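/* LAF_SPLIT_COMPARES_BITW selects the widest compare width to split (64, 32
 * or 16); first canonicalize the module so that only EQ/NE/UGT/ULT integer
 * compares remain (see the comment in splitCompares()) */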
simplifyCompares(M);
simplifySignedness(M);
errs() << "Split-compare-pass by laf.intel@gmail.com\n";
switch (bitw) {
case 64:
errs() << "Running split-compare-pass " << 64 << "\n";
splitCompares(M, 64);
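/* splitting the 64-bit compares creates new 32-bit compares, which the next
 * case splits again (and those in turn become 16-bit compares) */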
[[clang::fallthrough]];
/* fallthrough */
case 32:
errs() << "Running split-compare-pass " << 32 << "\n";
splitCompares(M, 32);
[[clang::fallthrough]];
/* fallthrough */
case 16:
errs() << "Running split-compare-pass " << 16 << "\n";
splitCompares(M, 16);
break;
default:
errs() << "Not running split-compare-pass, invalid LAF_SPLIT_COMPARES_BITW value\n";
return false;
}
verifyModule(M);
return true;
}
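/* hook the pass into clang's legacy pass manager: once at the very end of the
 * optimizer pipeline and once for -O0 builds, so it runs whenever the shared
 * object is loaded as a plugin */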
static void registerSplitComparesPass(const PassManagerBuilder &,
legacy::PassManagerBase &PM) {
PM.add(new SplitComparesTransform());
}
static RegisterStandardPasses RegisterSplitComparesPass(
PassManagerBuilder::EP_OptimizerLast, registerSplitComparesPass);
static RegisterStandardPasses RegisterSplitComparesTransPass0(
PassManagerBuilder::EP_EnabledOnOptLevel0, registerSplitComparesPass);

315
llvm_mode/split-switches-pass.so.cc Normal file
View File

@ -0,0 +1,315 @@
/*
* Copyright 2016 laf-intel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Pass.h"
#include "llvm/Analysis/ValueTracking.h"
#include <set>
using namespace llvm;
namespace {
class SplitSwitchesTransform : public ModulePass {
public:
static char ID;
SplitSwitchesTransform() : ModulePass(ID) {
}
bool runOnModule(Module &M) override;
#if __clang_major__ >= 4
StringRef getPassName() const override {
#else
const char * getPassName() const override {
#endif
return "splits switch constructs";
}
struct CaseExpr {
ConstantInt* Val;
BasicBlock* BB;
CaseExpr(ConstantInt *val = nullptr, BasicBlock *bb = nullptr) :
Val(val), BB(bb) { }
};
typedef std::vector<CaseExpr> CaseVector;
private:
bool splitSwitches(Module &M);
bool transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp);
BasicBlock* switchConvert(CaseVector Cases, std::vector<bool> bytesChecked,
BasicBlock* OrigBlock, BasicBlock* NewDefault,
Value* Val, unsigned level);
};
}
char SplitSwitchesTransform::ID = 0;
/* switchConvert - Transform simple list of Cases into list of CaseRange's */
BasicBlock* SplitSwitchesTransform::switchConvert(CaseVector Cases, std::vector<bool> bytesChecked,
BasicBlock* OrigBlock, BasicBlock* NewDefault,
Value* Val, unsigned level) {
unsigned ValTypeBitWidth = Cases[0].Val->getBitWidth();
IntegerType *ValType = IntegerType::get(OrigBlock->getContext(), ValTypeBitWidth);
IntegerType *ByteType = IntegerType::get(OrigBlock->getContext(), 8);
unsigned BytesInValue = bytesChecked.size();
std::vector<uint8_t> setSizes;
std::vector<std::set<uint8_t>> byteSets(BytesInValue, std::set<uint8_t>());
/* for each of the possible cases we iterate over all bytes of the value
 * and build a set of the possible values at each byte position in byteSets */
for (CaseExpr& Case: Cases) {
for (unsigned i = 0; i < BytesInValue; i++) {
uint8_t byte = (Case.Val->getZExtValue() >> (i*8)) & 0xFF;
byteSets[i].insert(byte);
}
}
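/* pick the byte position (among those not yet checked) with the fewest
 * distinct values; the fewer values, the smaller the emitted compare tree */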
unsigned smallestIndex = 0;
unsigned smallestSize = 257;
for(unsigned i = 0; i < byteSets.size(); i++) {
if (bytesChecked[i])
continue;
if (byteSets[i].size() < smallestSize) {
smallestIndex = i;
smallestSize = byteSets[i].size();
}
}
assert(bytesChecked[smallestIndex] == false);
/* there are only smallestSize different bytes at index smallestIndex */
Instruction *Shift, *Trunc;
Function* F = OrigBlock->getParent();
BasicBlock* NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock", F);
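/* extract the chosen byte from the switch value: shift it into the low 8 bits
 * and truncate to i8 (unless the value is only one byte wide to begin with) */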
Shift = BinaryOperator::Create(Instruction::LShr, Val, ConstantInt::get(ValType, smallestIndex * 8));
NewNode->getInstList().push_back(Shift);
if (ValTypeBitWidth > 8) {
Trunc = new TruncInst(Shift, ByteType);
NewNode->getInstList().push_back(Trunc);
}
else {
/* not necessary to trunc */
Trunc = Shift;
}
/* this is the trivial case: we can check the byte directly.
 * if the byte does not match, go to the default block; if it matches,
 * mark the byte as checked. if this was the last byte to check,
 * we can finally branch to the block belonging to this case */
if (smallestSize == 1) {
uint8_t byte = *(byteSets[smallestIndex].begin());
/* insert instructions to check whether the value we are switching on is equal to byte */
ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_EQ, Trunc, ConstantInt::get(ByteType, byte), "byteMatch");
NewNode->getInstList().push_back(Comp);
bytesChecked[smallestIndex] = true;
if (std::all_of(bytesChecked.begin(), bytesChecked.end(), [](bool b){return b;} )) {
assert(Cases.size() == 1);
BranchInst::Create(Cases[0].BB, NewDefault, Comp, NewNode);
/* we have to update the phi nodes! */
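/* the case block is now reached from NewNode instead of the original switch
 * block, so rewrite the matching incoming edge */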
for (BasicBlock::iterator I = Cases[0].BB->begin(); I != Cases[0].BB->end(); ++I) {
if (!isa<PHINode>(&*I)) {
continue;
}
PHINode *PN = cast<PHINode>(I);
/* Only update the first occurrence. */
unsigned Idx = 0, E = PN->getNumIncomingValues();
for (; Idx != E; ++Idx) {
if (PN->getIncomingBlock(Idx) == OrigBlock) {
PN->setIncomingBlock(Idx, NewNode);
break;
}
}
}
}
else {
BasicBlock* BB = switchConvert(Cases, bytesChecked, OrigBlock, NewDefault, Val, level + 1);
BranchInst::Create(BB, NewDefault, Comp, NewNode);
}
}
/* there is no single byte we can check directly, so split the tree */
else {
std::vector<uint8_t> byteVector;
std::copy(byteSets[smallestIndex].begin(), byteSets[smallestIndex].end(), std::back_inserter(byteVector));
std::sort(byteVector.begin(), byteVector.end());
uint8_t pivot = byteVector[byteVector.size() / 2];
/* we already chose to divide the cases based on the value of the byte at index smallestIndex.
 * the pivot value determines the threshold for the decision: if a case value
 * is smaller at this byte index, move it to the LHS vector, otherwise to the RHS vector */
CaseVector LHSCases, RHSCases;
for (CaseExpr& Case: Cases) {
uint8_t byte = (Case.Val->getZExtValue() >> (smallestIndex*8)) & 0xFF;
if (byte < pivot) {
LHSCases.push_back(Case);
}
else {
RHSCases.push_back(Case);
}
}
BasicBlock *LBB, *RBB;
LBB = switchConvert(LHSCases, bytesChecked, OrigBlock, NewDefault, Val, level + 1);
RBB = switchConvert(RHSCases, bytesChecked, OrigBlock, NewDefault, Val, level + 1);
/* insert an instruction to check whether the extracted byte is below the pivot */
ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_ULT, Trunc, ConstantInt::get(ByteType, pivot), "byteMatch");
NewNode->getInstList().push_back(Comp);
BranchInst::Create(LBB, RBB, Comp, NewNode);
}
return NewNode;
}
bool SplitSwitchesTransform::splitSwitches(Module &M) {
std::vector<SwitchInst*> switches;
/* iterate over all functions, bbs and instruction and add
* all switches to switches vector for later processing */
for (auto &F : M) {
for (auto &BB : F) {
SwitchInst* switchInst = nullptr;
if ((switchInst = dyn_cast<SwitchInst>(BB.getTerminator()))) {
if (switchInst->getNumCases() < 1)
continue;
switches.push_back(switchInst);
}
}
}
if (!switches.size())
return false;
errs() << "Rewriting " << switches.size() << " switch statements\n";
for (auto &SI: switches) {
BasicBlock *CurBlock = SI->getParent();
BasicBlock *OrigBlock = CurBlock;
Function *F = CurBlock->getParent();
/* this is the value we are switching on */
Value *Val = SI->getCondition();
BasicBlock* Default = SI->getDefaultDest();
/* If there is only the default destination, don't bother with the code below. */
if (!SI->getNumCases()) {
continue;
}
/* Create a new, empty default block so that the new hierarchy of
 * if-then statements goes to this and the PHI nodes are happy.
 * if the default block is set as unreachable we avoid creating one,
 * because it will never be a valid target. */
BasicBlock *NewDefault = nullptr;
NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault");
NewDefault->insertInto(F, Default);
BranchInst::Create(Default, NewDefault);
/* Prepare cases vector. */
CaseVector Cases;
for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; ++i)
#if __clang_major__ < 7
Cases.push_back(CaseExpr(i.getCaseValue(), i.getCaseSuccessor()));
#else
Cases.push_back(CaseExpr(i->getCaseValue(), i->getCaseSuccessor()));
#endif
std::vector<bool> bytesChecked(Cases[0].Val->getBitWidth() / 8, false);
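/* one flag per byte of the switch value; switchConvert() sets a flag once the
 * emitted tree has pinned that byte down to a single value */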
BasicBlock* SwitchBlock = switchConvert(Cases, bytesChecked, OrigBlock, NewDefault, Val, 0);
/* Branch to our shiny new if-then stuff... */
BranchInst::Create(SwitchBlock, OrigBlock);
/* We are now done with the switch instruction, delete it. */
CurBlock->getInstList().erase(SI);
/* we have to update the phi nodes! */
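/* the default block is now reached through NewDefault instead of the original
 * switch block, so rewrite the matching incoming edge */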
for (BasicBlock::iterator I = Default->begin(); I != Default->end(); ++I) {
if (!isa<PHINode>(&*I)) {
continue;
}
PHINode *PN = cast<PHINode>(I);
/* Only update the first occurrence. */
unsigned Idx = 0, E = PN->getNumIncomingValues();
for (; Idx != E; ++Idx) {
if (PN->getIncomingBlock(Idx) == OrigBlock) {
PN->setIncomingBlock(Idx, NewDefault);
break;
}
}
}
}
verifyModule(M);
return true;
}
bool SplitSwitchesTransform::runOnModule(Module &M) {
llvm::errs() << "Running split-switches-pass by laf.intel@gmail.com\n";
splitSwitches(M);
verifyModule(M);
return true;
}
static void registerSplitSwitchesTransPass(const PassManagerBuilder &,
legacy::PassManagerBase &PM) {
auto p = new SplitSwitchesTransform();
PM.add(p);
}
static RegisterStandardPasses RegisterSplitSwitchesTransPass(
PassManagerBuilder::EP_OptimizerLast, registerSplitSwitchesTransPass);
static RegisterStandardPasses RegisterSplitSwitchesTransPass0(
PassManagerBuilder::EP_EnabledOnOptLevel0, registerSplitSwitchesTransPass);
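/* Usage sketch (an assumption for illustration, not part of this file): once
 * the llvm_mode Makefile has built ../split-switches-pass.so, the plugin can
 * be loaded into clang so that the two extension points registered above fire,
 * e.g.:
 *
 *   clang -Xclang -load -Xclang ./split-switches-pass.so -O1 -c target.c
 *
 * afl-clang-fast is expected to add the equivalent flags itself when the laf
 * passes are enabled. */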