eliminate call stack recursion in compile method

Some OSes (notably, Windows CE) restrict the size of the call stack
such that recursive compilation of branch instructions can lead to
stack overflow in methods with large numbers of such instructions.  In
fact, a worst-case method could even lead to overflow when the stack
size limit is relatively generous.

The solution is to convert this recursion into iteration with an
explicit stack to maintain state about alternate paths through each
branch.
This commit is contained in:
Joel Dice 2012-10-13 09:46:12 -06:00
parent 3e0ab35ba1
commit 8f308291b0
2 changed files with 314 additions and 112 deletions

View File

@ -1388,6 +1388,10 @@ class Frame {
}
~Frame() {
dispose();
}
void dispose() {
if (level > 1) {
context->eventLog.append(PopContextEvent);
}
@ -3744,21 +3748,10 @@ isReferenceTailCall(MyThread* t, object code, unsigned ip, object caller,
(t, code, ip, caller, methodReferenceReturnCode(t, calleeReference));
}
void
compile(MyThread* t, Frame* initialFrame, unsigned ip,
int exceptionHandlerStart = -1);
// Compile the code path that begins at `ip`, bracketing the work with a
// snapshot of the compiler state: the state is captured from the frame's
// compiler before compile() runs and reinstated once it returns.
void
saveStateAndCompile(MyThread* t, Frame* initialFrame, unsigned ip)
{
  Compiler::State* const saved = initialFrame->c->saveState();

  compile(t, initialFrame, ip);

  initialFrame->c->restoreState(saved);
}
bool
integerBranch(MyThread* t, Frame* frame, object code, unsigned& ip,
unsigned size, Compiler::Operand* a, Compiler::Operand* b)
unsigned size, Compiler::Operand* a, Compiler::Operand* b,
unsigned* newIpp)
{
if (ip + 3 > codeLength(t, code)) {
return false;
@ -3802,14 +3795,14 @@ integerBranch(MyThread* t, Frame* frame, object code, unsigned& ip,
return false;
}
saveStateAndCompile(t, frame, newIp);
*newIpp = newIp;
return true;
}
bool
floatBranch(MyThread* t, Frame* frame, object code, unsigned& ip,
unsigned size, bool lessIfUnordered, Compiler::Operand* a,
Compiler::Operand* b)
Compiler::Operand* b, unsigned* newIpp)
{
if (ip + 3 > codeLength(t, code)) {
return false;
@ -3869,7 +3862,7 @@ floatBranch(MyThread* t, Frame* frame, object code, unsigned& ip,
return false;
}
saveStateAndCompile(t, frame, newIp);
*newIpp = newIp;
return true;
}
@ -4043,17 +4036,126 @@ targetFieldOffset(Context* context, object field)
}
}
// A LIFO scratch stack whose storage is carved out of a Zone.  It offers
// raw byte-level push/peek/pop plus word-sized convenience helpers built on
// top of them.
class Stack {
 public:
  // Resource registered against the owning thread; releasing it disposes
  // the stack's zone.  (Presumably invoked when the thread unwinds without
  // running ~Stack -- confirm against Thread::Resource.)
  class MyResource: public Thread::Resource {
   public:
    MyResource(Stack* s): Resource(s->thread), s(s) { }

    virtual void release() {
      s->zone.dispose();
    }

    Stack* s;
  };

  Stack(MyThread* t):
    thread(t),
    zone(t->m->system, t->m->heap, 0),
    resource(this)
  { }

  ~Stack() {
    zone.dispose();
  }

  // Push a single word-sized value.
  void pushValue(uintptr_t v) {
    uintptr_t* slot = static_cast<uintptr_t*>(push(BytesPerWord));
    *slot = v;
  }

  // Read the word `offset` words below the top of the stack without
  // removing anything; offset 0 is the most recently pushed word.
  uintptr_t peekValue(unsigned offset) {
    const unsigned depth = (offset + 1) * BytesPerWord;
    return *static_cast<uintptr_t*>(peek(depth));
  }

  // Remove and return the topmost word.
  uintptr_t popValue() {
    uintptr_t* top = static_cast<uintptr_t*>(peek(BytesPerWord));
    uintptr_t v = *top;
    pop(BytesPerWord);
    return v;
  }

  // Reserve `size` bytes on top of the stack and return their address.
  void* push(unsigned size) {
    return zone.allocate(size);
  }

  // Address of the region that starts `size` bytes below the top.
  void* peek(unsigned size) {
    return zone.peek(size);
  }

  // Discard the topmost `size` bytes.
  void pop(unsigned size) {
    zone.pop(size);
  }

  // Declaration order matters: `resource` is constructed from `this` and
  // reads `thread`, so it must come after `thread` (and is kept last).
  MyThread* thread;
  Zone zone;
  MyResource resource;
};
// Bookkeeping record for a switch instruction whose targets are visited one
// at a time.
//
// frame() and ipTable() locate their data purely by address arithmetic
// relative to `this`: the table of `count` 4-byte target ips (padded) is
// taken to end exactly where this object begins, and a padded Frame is taken
// to end exactly where that table begins.  Whoever constructs a SwitchState
// must place it in memory accordingly.
class SwitchState {
 public:
  SwitchState(Compiler::State* state,
              unsigned count,
              unsigned defaultIp,
              Compiler::Operand* key,
              Promise* start,
              int bottom,
              int top):
    state(state),
    count(count),
    defaultIp(defaultIp),
    key(key),
    start(start),
    bottom(bottom),
    top(top),
    index(0)
  { }

  // The Frame stored directly beneath the ip table.
  Frame* frame() {
    uint8_t* tableBase = reinterpret_cast<uint8_t*>(ipTable());
    return reinterpret_cast<Frame*>(tableBase - pad(sizeof(Frame)));
  }

  // The table of target ips stored directly beneath this object.
  uint32_t* ipTable() {
    uint8_t* self = reinterpret_cast<uint8_t*>(this);
    return reinterpret_cast<uint32_t*>(self - pad(count * 4));
  }

  Compiler::State* state;  // saved compiler state
  unsigned count;          // number of entries in the ip table
  unsigned defaultIp;      // ip of the default target
  Compiler::Operand* key;  // operand holding the switch key
  Promise* start;
  int bottom;
  int top;
  unsigned index;          // starts at 0
};
void
compile(MyThread* t, Frame* initialFrame, unsigned ip,
int exceptionHandlerStart)
compile(MyThread* t, Frame* initialFrame, unsigned initialIp,
int exceptionHandlerStart = -1)
{
THREAD_RUNTIME_ARRAY(t, uint8_t, stackMap,
codeMaxStack(t, methodCode(t, initialFrame->context->method)));
Frame myFrame(initialFrame, RUNTIME_ARRAY_BODY(stackMap));
Frame* frame = &myFrame;
enum {
Return,
Unbranch,
Unsubroutine,
Untable0,
Untable1,
Unswitch
};
Frame* frame = initialFrame;
Compiler* c = frame->c;
Context* context = frame->context;
unsigned stackSize = codeMaxStack(t, methodCode(t, context->method));
Stack stack(t);
unsigned ip = initialIp;
unsigned newIp;
stack.pushValue(Return);
start:
uint8_t* stackMap = static_cast<uint8_t*>(stack.push(stackSize));
frame = new (stack.push(sizeof(Frame))) Frame(frame, stackMap);
loop:
object code = methodCode(t, context->method);
PROTECT(t, code);
@ -4061,7 +4163,7 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
if (context->visitTable[ip] ++) {
// we've already visited this part of the code
frame->visitLogicalIp(ip);
return;
goto next;
}
frame->startLogicalIp(ip);
@ -4318,7 +4420,7 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
case areturn: {
handleExit(t, frame);
c->return_(TargetBytesPerWord, frame->popObject());
} return;
} goto next;
case arraylength: {
frame->pushInt
@ -4363,7 +4465,7 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
if (ip == codeLength(t, code)) {
c->trap();
}
} return;
} goto next;
case bipush:
frame->pushInt
@ -4427,7 +4529,9 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
Compiler::Operand* a = frame->popLong();
Compiler::Operand* b = frame->popLong();
if (not floatBranch(t, frame, code, ip, 8, false, a, b)) {
if (floatBranch(t, frame, code, ip, 8, false, a, b, &newIp)) {
goto branch;
} else {
frame->pushInt
(c->call
(c->constant
@ -4442,7 +4546,9 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
Compiler::Operand* a = frame->popLong();
Compiler::Operand* b = frame->popLong();
if (not floatBranch(t, frame, code, ip, 8, true, a, b)) {
if (floatBranch(t, frame, code, ip, 8, true, a, b, &newIp)) {
goto branch;
} else {
frame->pushInt
(c->call
(c->constant
@ -4540,7 +4646,9 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
Compiler::Operand* a = frame->popInt();
Compiler::Operand* b = frame->popInt();
if (not floatBranch(t, frame, code, ip, 4, false, a, b)) {
if (floatBranch(t, frame, code, ip, 4, false, a, b, &newIp)) {
goto branch;
} else {
frame->pushInt
(c->call
(c->constant
@ -4553,7 +4661,9 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
Compiler::Operand* a = frame->popInt();
Compiler::Operand* b = frame->popInt();
if (not floatBranch(t, frame, code, ip, 4, true, a, b)) {
if (floatBranch(t, frame, code, ip, 4, true, a, b, &newIp)) {
goto branch;
} else {
frame->pushInt
(c->call
(c->constant
@ -4887,7 +4997,7 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
case if_acmpeq:
case if_acmpne: {
uint32_t offset = codeReadInt16(t, code, ip);
uint32_t newIp = (ip - 3) + offset;
newIp = (ip - 3) + offset;
assert(t, newIp < codeLength(t, code));
Compiler::Operand* a = frame->popObject();
@ -4899,9 +5009,7 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
} else {
c->jumpIfNotEqual(TargetBytesPerWord, a, b, target);
}
saveStateAndCompile(t, frame, newIp);
} break;
} goto branch;
case if_icmpeq:
case if_icmpne:
@ -4910,7 +5018,7 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
case if_icmplt:
case if_icmple: {
uint32_t offset = codeReadInt16(t, code, ip);
uint32_t newIp = (ip - 3) + offset;
newIp = (ip - 3) + offset;
assert(t, newIp < codeLength(t, code));
Compiler::Operand* a = frame->popInt();
@ -4939,9 +5047,7 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
default:
abort(t);
}
saveStateAndCompile(t, frame, newIp);
} break;
} goto branch;
case ifeq:
case ifne:
@ -4950,7 +5056,7 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
case iflt:
case ifle: {
uint32_t offset = codeReadInt16(t, code, ip);
uint32_t newIp = (ip - 3) + offset;
newIp = (ip - 3) + offset;
assert(t, newIp < codeLength(t, code));
Compiler::Operand* target = frame->machineIp(newIp);
@ -4980,14 +5086,12 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
default:
abort(t);
}
saveStateAndCompile(t, frame, newIp);
} break;
} goto branch;
case ifnull:
case ifnonnull: {
uint32_t offset = codeReadInt16(t, code, ip);
uint32_t newIp = (ip - 3) + offset;
newIp = (ip - 3) + offset;
assert(t, newIp < codeLength(t, code));
Compiler::Operand* a = c->constant(0, Compiler::ObjectType);
@ -4999,9 +5103,7 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
} else {
c->jumpIfNotEqual(TargetBytesPerWord, a, b, target);
}
saveStateAndCompile(t, frame, newIp);
} break;
} goto branch;
case iinc: {
uint8_t index = codeBody(t, code, ip++);
@ -5298,7 +5400,7 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
case freturn: {
handleExit(t, frame);
c->return_(4, frame->popInt());
} return;
} goto next;
case ishl: {
Compiler::Operand* a = frame->popInt();
@ -5358,7 +5460,6 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
case jsr:
case jsr_w: {
uint32_t thisIp;
uint32_t newIp;
if (instruction == jsr) {
uint32_t offset = codeReadInt16(t, code, ip);
@ -5376,10 +5477,11 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
c->jmp(frame->machineIp(newIp));
saveStateAndCompile(t, frame, newIp);
frame->endSubroutine(start);
} break;
stack.pushValue(start);
stack.pushValue(ip);
stack.pushValue(Unsubroutine);
ip = newIp;
} goto start;
case l2d: {
frame->pushLong(c->i2f(8, 8, frame->popLong()));
@ -5409,7 +5511,9 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
Compiler::Operand* a = frame->popLong();
Compiler::Operand* b = frame->popLong();
if (not integerBranch(t, frame, code, ip, 8, a, b)) {
if (integerBranch(t, frame, code, ip, 8, a, b, &newIp)) {
goto branch;
} else {
frame->pushInt
(c->call
(c->constant
@ -5567,7 +5671,8 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
(frame->addressPromise(c->machineIp(defaultIp)));
Promise* start = 0;
THREAD_RUNTIME_ARRAY(t, uint32_t, ipTable, pairCount);
uint32_t* ipTable = static_cast<uint32_t*>
(stack.push(sizeof(uint32_t) * pairCount));
for (int32_t i = 0; i < pairCount; ++i) {
unsigned index = ip + (i * 8);
int32_t key = codeReadInt32(t, code, index);
@ -5598,19 +5703,15 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
TARGET_THREAD_CODEIMAGE), address)
: address);
Compiler::State* state = c->saveState();
new (stack.push(sizeof(SwitchState))) SwitchState
(c->saveState(), pairCount, defaultIp, 0, 0, 0, 0);
for (int32_t i = 0; i < pairCount; ++i) {
compile(t, frame, RUNTIME_ARRAY_BODY(ipTable)[i]);
c->restoreState(state);
}
goto switchloop;
} else {
// a switch statement with no cases, apparently
c->jmp(frame->machineIp(defaultIp));
ip = defaultIp;
}
ip = defaultIp;
} break;
case lor: {
@ -5635,7 +5736,7 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
case dreturn: {
handleExit(t, frame);
c->return_(8, frame->popLong());
} return;
} goto next;
case lshl: {
Compiler::Operand* a = frame->popInt();
@ -6062,7 +6163,7 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
case ret: {
unsigned index = codeBody(t, code, ip);
frame->returnFromSubroutine(index);
} return;
} goto next;
case return_:
if (needsReturnBarrier(t, context->method)) {
@ -6071,7 +6172,7 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
handleExit(t, frame);
c->return_(0, 0);
return;
goto next;
case sipush:
frame->pushInt
@ -6096,7 +6197,9 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
int32_t top = codeReadInt32(t, code, ip);
Promise* start = 0;
THREAD_RUNTIME_ARRAY(t, uint32_t, ipTable, top - bottom + 1);
unsigned count = top - bottom + 1;
uint32_t* ipTable = static_cast<uint32_t*>
(stack.push(sizeof(uint32_t) * count));
for (int32_t i = 0; i < top - bottom + 1; ++i) {
unsigned index = ip + (i * 4);
uint32_t newIp = base + codeReadInt32(t, code, index);
@ -6119,43 +6222,12 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
c->save(1, key);
saveStateAndCompile(t, frame, defaultIp);
c->jumpIfGreater(4, c->constant(top, Compiler::IntegerType), key,
frame->machineIp(defaultIp));
c->save(1, key);
saveStateAndCompile(t, frame, defaultIp);
Compiler::Operand* normalizedKey
= (bottom
? c->sub(4, c->constant(bottom, Compiler::IntegerType), key) : key);
Compiler::Operand* entry = c->memory
(frame->absoluteAddressOperand(start), Compiler::AddressType, 0,
normalizedKey, TargetBytesPerWord);
c->jmp
(c->load
(TargetBytesPerWord, TargetBytesPerWord, context->bootContext
? c->add
(TargetBytesPerWord, c->memory
(c->register_(t->arch->thread()), Compiler::AddressType,
TARGET_THREAD_CODEIMAGE), entry)
: entry,
TargetBytesPerWord));
Compiler::State* state = c->saveState();
for (int32_t i = 0; i < top - bottom + 1; ++i) {
compile(t, frame, RUNTIME_ARRAY_BODY(ipTable)[i]);
c->restoreState(state);
}
new (stack.push(sizeof(SwitchState))) SwitchState
(c->saveState(), count, defaultIp, key, start, bottom, top);
stack.pushValue(Untable0);
ip = defaultIp;
} break;
} goto start;
case wide: {
switch (codeBody(t, code, ip++)) {
@ -6199,7 +6271,7 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
unsigned index = codeReadInt16(t, code, ip);
c->jmp(loadLocal(context, 1, index));
frame->returnFromSubroutine(index);
} return;
} goto next;
default: abort(t);
}
@ -6208,6 +6280,113 @@ compile(MyThread* t, Frame* initialFrame, unsigned ip,
default: abort(t);
}
}
next:
frame->dispose();
frame = 0;
stack.pop(sizeof(Frame));
stack.pop(stackSize);
switch (stack.popValue()) {
case Return:
return;
case Unbranch:
ip = stack.popValue();
c->restoreState(reinterpret_cast<Compiler::State*>(stack.popValue()));
frame = static_cast<Frame*>(stack.peek(sizeof(Frame)));
goto loop;
case Untable0: {
SwitchState* s = static_cast<SwitchState*>
(stack.peek(sizeof(SwitchState)));
frame = s->frame();
c->restoreState(s->state);
c->jumpIfGreater(4, c->constant(s->top, Compiler::IntegerType), s->key,
frame->machineIp(s->defaultIp));
c->save(1, s->key);
ip = s->defaultIp;
stack.pushValue(Untable1);
} goto start;
case Untable1: {
SwitchState* s = static_cast<SwitchState*>
(stack.peek(sizeof(SwitchState)));
frame = s->frame();
c->restoreState(s->state);
Compiler::Operand* normalizedKey
= (s->bottom
? c->sub(4, c->constant(s->bottom, Compiler::IntegerType), s->key)
: s->key);
Compiler::Operand* entry = c->memory
(frame->absoluteAddressOperand(s->start), Compiler::AddressType, 0,
normalizedKey, TargetBytesPerWord);
c->jmp
(c->load
(TargetBytesPerWord, TargetBytesPerWord, context->bootContext
? c->add
(TargetBytesPerWord, c->memory
(c->register_(t->arch->thread()), Compiler::AddressType,
TARGET_THREAD_CODEIMAGE), entry)
: entry,
TargetBytesPerWord));
s->state = c->saveState();
} goto switchloop;
case Unswitch: {
SwitchState* s = static_cast<SwitchState*>
(stack.peek(sizeof(SwitchState)));
frame = s->frame();
c->restoreState
(static_cast<SwitchState*>(stack.peek(sizeof(SwitchState)))->state);
} goto switchloop;
case Unsubroutine: {
ip = stack.popValue();
unsigned start = stack.popValue();
frame = reinterpret_cast<Frame*>(stack.peek(sizeof(Frame)));
frame->endSubroutine(start);
} goto loop;
default:
abort(t);
}
switchloop: {
SwitchState* s = static_cast<SwitchState*>
(stack.peek(sizeof(SwitchState)));
if (s->index < s->count) {
ip = s->ipTable()[s->index++];
stack.pushValue(Unswitch);
goto start;
} else {
ip = s->defaultIp;
unsigned count = s->count * 4;
stack.pop(sizeof(SwitchState));
stack.pop(count);
frame = reinterpret_cast<Frame*>(stack.peek(sizeof(Frame)));
goto loop;
}
}
branch:
stack.pushValue(reinterpret_cast<uintptr_t>(c->saveState()));
stack.pushValue(ip);
stack.pushValue(Unbranch);
ip = newIp;
goto start;
}
FILE* compileLog = 0;

View File

@ -20,10 +20,13 @@ class Zone: public Allocator {
public:
class Segment {
public:
Segment(Segment* next, unsigned size): next(next), size(size) { }
Segment(Segment* next, unsigned size):
next(next), size(size), position(0)
{ }
Segment* next;
uintptr_t size;
uintptr_t position;
uint8_t data[0];
};
@ -31,7 +34,6 @@ class Zone: public Allocator {
s(s),
allocator(allocator),
segment(0),
position(0),
minimumFootprint(minimumFootprint < sizeof(Segment) ? 0 :
minimumFootprint - sizeof(Segment))
{ }
@ -55,7 +57,7 @@ class Zone: public Allocator {
}
bool tryEnsure(unsigned space) {
if (segment == 0 or position + space > segment->size) {
if (segment == 0 or segment->position + space > segment->size) {
unsigned size = padToPage
(max
(space, max
@ -72,26 +74,24 @@ class Zone: public Allocator {
}
segment = new (p) Segment(segment, size - sizeof(Segment));
position = 0;
}
return true;
}
void ensure(unsigned space) {
if (segment == 0 or position + space > segment->size) {
if (segment == 0 or segment->position + space > segment->size) {
unsigned size = padToPage(space + sizeof(Segment));
segment = new (allocator->allocate(size))
Segment(segment, size - sizeof(Segment));
position = 0;
}
}
virtual void* tryAllocate(unsigned size) {
size = pad(size);
if (tryEnsure(size)) {
void* r = segment->data + position;
position += size;
void* r = segment->data + segment->position;
segment->position += size;
return r;
} else {
return 0;
@ -99,17 +99,41 @@ class Zone: public Allocator {
}
virtual void* allocate(unsigned size) {
size = pad(size);
void* p = tryAllocate(size);
if (p) {
return p;
} else {
ensure(size);
void* r = segment->data + position;
position += size;
void* r = segment->data + segment->position;
segment->position += size;
return r;
}
}
void* peek(unsigned size) {
size = pad(size);
Segment* s = segment;
while (s->position < size) {
size -= s->position;
s = s->next;
}
return s->data + (s->position - size);
}
void pop(unsigned size) {
size = pad(size);
Segment* s = segment;
while (s->position < size) {
size -= s->position;
Segment* next = s->next;
allocator->free(s, sizeof(Segment) + s->size);
s = next;
}
s->position -= size;
segment = s;
}
virtual void free(const void*, unsigned) {
// not supported
abort(s);
@ -119,7 +143,6 @@ class Zone: public Allocator {
Allocator* allocator;
void* context;
Segment* segment;
unsigned position;
unsigned minimumFootprint;
};