diff --git a/src/compile.cpp b/src/compile.cpp index 905aeee7a4..d0901367cc 100644 --- a/src/compile.cpp +++ b/src/compile.cpp @@ -28,7 +28,7 @@ namespace { const bool Verbose = false; const bool DebugNatives = false; -const bool DebugTraces = false; +const bool DebugCallTable = false; const bool DebugFrameMaps = false; const bool CheckArrayBounds = true; @@ -42,6 +42,7 @@ class MyThread: public Thread { ip(t->ip), base(t->base), stack(t->stack), + nativeMethod(0), next(t->trace) { t->trace = this; @@ -61,6 +62,7 @@ class MyThread: public Thread { void* ip; void* base; void* stack; + object nativeMethod; CallTrace* next; }; @@ -70,8 +72,7 @@ class MyThread: public Thread { base(0), stack(0), trace(0), - reference(0), - methodInvoked(0) + reference(0) { } void* ip; @@ -79,41 +80,60 @@ class MyThread: public Thread { void* stack; CallTrace* trace; Reference* reference; - object methodInvoked; }; +object +resolveThisPointer(MyThread* t, void* stack, object method) +{ + return reinterpret_cast(stack)[methodParameterFootprint(t, method)]; +} + object resolveTarget(MyThread* t, void* stack, object method) { - if (method) { - unsigned parameterFootprint = methodParameterFootprint(t, method); + object class_ = objectClass(t, resolveThisPointer(t, stack, method)); - object class_ = objectClass - (t, reinterpret_cast(stack)[parameterFootprint]); + if (classVmFlags(t, class_) & BootstrapFlag) { + PROTECT(t, method); + PROTECT(t, class_); - if (classVmFlags(t, class_) & BootstrapFlag) { - PROTECT(t, method); - PROTECT(t, class_); - - resolveClass(t, className(t, class_)); - if (UNLIKELY(t->exception)) return 0; - } - - if (classFlags(t, methodClass(t, method)) & ACC_INTERFACE) { - return findInterfaceMethod(t, method, class_); - } else { - return findMethod(t, method, class_); - } + resolveClass(t, className(t, class_)); + if (UNLIKELY(t->exception)) return 0; } - return method; + if (classFlags(t, methodClass(t, method)) & ACC_INTERFACE) { + return findInterfaceMethod(t, method, class_); + } else { + return findMethod(t, method, class_); + } +} + +object& +methodTree(MyThread* t); + +object +methodTreeSentinal(MyThread* t); + +intptr_t +compareIpToMethodBounds(Thread* t, intptr_t ip, object method) +{ + intptr_t start = reinterpret_cast + (&singletonValue(t, methodCompiled(t, method), 0)); + if (ip < start) { + return -1; + } else if (ip < start + singletonCount(t, methodCompiled(t, method))) { + return 0; + } else { + return 1; + } } object -findTraceNode(MyThread* t, void* address); - -void -insertTraceNode(MyThread* t, object node); +methodForIp(MyThread* t, void* ip) +{ + return treeQuery(t, methodTree(t), reinterpret_cast(ip), + methodTreeSentinal(t), compareIpToMethodBounds); +} class MyStackWalker: public Processor::StackWalker { public: @@ -124,7 +144,7 @@ class MyStackWalker: public Processor::StackWalker { { } virtual void visit(Heap::Visitor* v) { - v->visit(&(walker->node)); + v->visit(&(walker->method_)); v->visit(&(walker->nativeMethod)); } @@ -133,13 +153,12 @@ class MyStackWalker: public Processor::StackWalker { MyStackWalker(MyThread* t): t(t), + ip_(t->ip ? t->ip : (stack ? *static_cast(stack) : 0)), base(t->base), stack(t->stack), trace(t->trace), - node(t->ip ? findTraceNode(t, t->ip) : - (stack ? findTraceNode(t, *static_cast(stack)) : - 0)), - nativeMethod(resolveNativeMethod(t, stack, node)), + nativeMethod(trace->nativeMethod), + method_(ip_ ? methodForIp(t, ip_) : 0), protector(this) { } @@ -148,24 +167,11 @@ class MyStackWalker: public Processor::StackWalker { base(w->base), stack(w->stack), trace(w->trace), - node(w->node), nativeMethod(w->nativeMethod), + method_(w->method_), protector(this) { } - static object resolveNativeMethod(MyThread* t, void* stack, object node) { - if (node) { - object target = traceNodeTarget(t, node); - if (traceNodeVirtualCall(t, node)) { - target = resolveTarget(t, stack, target); - } - if (target and methodFlags(t, target) & ACC_NATIVE) { - return target; - } - } - return 0; - } - virtual void walk(Processor::StackVisitor* v) { if (stack == 0) { return; @@ -189,14 +195,13 @@ class MyStackWalker: public Processor::StackWalker { } else { stack = static_cast(base) + 1; base = *static_cast(base); - node = findTraceNode(t, *static_cast(stack)); - if (node == 0) { + method_ = methodForIp(t, *static_cast(stack)); + if (method_ == 0) { if (trace and trace->stack) { base = trace->base; stack = static_cast(trace->stack); + nativeMethod = trace->nativeMethod; trace = trace->next; - node = findTraceNode(t, *static_cast(stack)); - nativeMethod = resolveNativeMethod(t, stack, node); } else { return false; } @@ -209,7 +214,7 @@ class MyStackWalker: public Processor::StackWalker { if (nativeMethod) { return nativeMethod; } else { - return traceNodeMethod(t, node); + return method_; } } @@ -217,9 +222,8 @@ class MyStackWalker: public Processor::StackWalker { if (nativeMethod) { return 0; } else { - intptr_t start = reinterpret_cast - (&singletonValue(t, methodCompiled(t, traceNodeMethod(t, node)), 0)); - return traceNodeAddress(t, node) - start; + return reinterpret_cast(ip_) - reinterpret_cast + (&singletonValue(t, methodCompiled(t, method_), 0)); } } @@ -243,11 +247,12 @@ class MyStackWalker: public Processor::StackWalker { } MyThread* t; + void* ip_; void* base; void* stack; MyThread::CallTrace* trace; - object node; object nativeMethod; + object method_; MyProtector protector; }; @@ -303,6 +308,7 @@ class TraceElement: public Compiler::TraceHandler { Context* context; Promise* address; + intptr_t addressValue; object target; bool virtualCall; TraceElement* next; @@ -467,6 +473,7 @@ class Context { method(method), objectPool(0), traceLog(0), + traceLogCount(0), visitTable(makeVisitTable(t, &zone, method)), rootTable(makeRootTable(t, &zone, method)), eventLog(t->m->system, t->m->heap, 1024), @@ -480,6 +487,7 @@ class Context { method(0), objectPool(0), traceLog(0), + traceLogCount(0), visitTable(0), rootTable(0), eventLog(t->m->system, t->m->heap, 0), @@ -496,6 +504,7 @@ class Context { object method; PoolElement* objectPool; TraceElement* traceLog; + unsigned traceLogCount; uint16_t* visitTable; uintptr_t* rootTable; bool dirtyRoots; @@ -1075,6 +1084,8 @@ class Frame { (context->zone.allocate(sizeof(TraceElement) + (mapSize * BytesPerWord))) TraceElement(context, target, virtualCall, context->traceLog); + ++ context->traceLogCount; + context->eventLog.append(TraceEvent); context->eventLog.appendAddress(e); @@ -1097,6 +1108,15 @@ savedTargetIndex(MyThread* t, object method) return codeMaxLocals(t, methodCode(t, method)); } +object +findCallNode(MyThread* t, void* address); + +void +insertCallNode(MyThread* t, object node); + +void +removeCallNode(MyThread* t, object node); + void findUnwindTarget(MyThread* t, void** targetIp, void** targetBase, void** targetStack) @@ -1112,9 +1132,8 @@ findUnwindTarget(MyThread* t, void** targetIp, void** targetBase, *targetIp = 0; while (*targetIp == 0) { - object node = findTraceNode(t, ip); - if (node) { - object method = traceNodeMethod(t, node); + object method = methodForIp(t, ip); + if (method) { PROTECT(t, method); uint8_t* compiled = reinterpret_cast @@ -3659,6 +3678,29 @@ calculateFrameMaps(MyThread* t, Context* context, uintptr_t* originalRoots, Allocator* codeAllocator(MyThread* t); +int +compareTraceElementPointers(const void* va, const void* vb) +{ + TraceElement* a = *static_cast(va); + TraceElement* b = *static_cast(vb); + if (a->addressValue > b->addressValue) { + return 1; + } else if (a->addressValue < b->addressValue) { + return -1; + } else { + return 0; + } +} + +intptr_t +compareMethodBounds(Thread* t, object a, object b) +{ + return reinterpret_cast + (&singletonValue(t, methodCompiled(t, a), 0)) + - reinterpret_cast + (&singletonValue(t, methodCompiled(t, b), 0)); +} + object finish(MyThread* t, Context* context, const char* name) { @@ -3680,18 +3722,61 @@ finish(MyThread* t, Context* context, const char* name) if (context->method) { PROTECT(t, result); - unsigned mapSize = frameMapSizeInWords(t, context->method); + { object code = methodCode(t, context->method); - for (TraceElement* p = context->traceLog; p; p = p->next) { - object node = makeTraceNode - (t, p->address->value(c), 0, context->method, p->target, - p->virtualCall, mapSize, false); + code = makeCode(t, 0, + codeExceptionHandlerTable(t, code), + codeLineNumberTable(t, code), + codeMaxStack(t, code), + codeMaxLocals(t, code), + 0, false); - if (mapSize) { - memcpy(&traceNodeMap(t, node, 0), p->map, mapSize * BytesPerWord); + set(t, context->method, MethodCode, code); + } + + if (context->traceLogCount) { + TraceElement* elements[context->traceLogCount]; + unsigned index = 0; + for (TraceElement* p = context->traceLog; p; p = p->next) { + elements[index++] = p; + p->addressValue = p->address->value(c); + + if (p->target) { + insertCallNode + (t, makeCallNode + (t, p->address->value(c), p->target, p->virtualCall, 0)); + } } - insertTraceNode(t, node); + qsort(elements, context->traceLogCount, sizeof(TraceElement*), + compareTraceElementPointers); + + unsigned size = frameSize(t, context->method); + object map = makeIntArray + (t, context->traceLogCount + + ceiling(context->traceLogCount * size, 32), + false); + + for (unsigned i = 0; i < context->traceLogCount; ++i) { + TraceElement* p = elements[i]; + + intArrayBody(t, map, i) = static_cast(p->addressValue) + - reinterpret_cast(start); + + for (unsigned j = 0; j < size; ++j) { + unsigned index = ((i * size) + j); + int32_t* v = &intArrayBody + (t, map, context->traceLogCount + (index / 32)); + + if (getBit(p->map, j)) { + *v |= static_cast(1) << (index % 32); + } else { + *v &= ~(static_cast(1) << (index % 32)); + } + } + } + + set(t, methodCode(t, context->method), CodePool, map); } for (PoolElement* p = context->objectPool; p; p = p->next) { @@ -3709,6 +3794,13 @@ finish(MyThread* t, Context* context, const char* name) updateLineNumberTable(t, c, methodCode(t, context->method), reinterpret_cast(start)); + { object node = makeTreeNode + (t, context->method, methodTreeSentinal(t), methodTreeSentinal(t)); + + methodTree(t) = treeInsert + (t, methodTree(t), node, methodTreeSentinal(t), compareMethodBounds); + } + if (Verbose) { logCompile (start, c->codeSize(), @@ -3864,13 +3956,13 @@ compile(MyThread* t, object method); void* compileMethod2(MyThread* t) { - object node = findTraceNode(t, *static_cast(t->stack)); + object node = findCallNode(t, *static_cast(t->stack)); PROTECT(t, node); - object target = traceNodeTarget(t, node); + object target = callNodeTarget(t, node); PROTECT(t, target); - if (traceNodeVirtualCall(t, node)) { + if (callNodeVirtualCall(t, node)) { target = resolveTarget(t, t->stack, target); } @@ -3881,10 +3973,19 @@ compileMethod2(MyThread* t) if (UNLIKELY(t->exception)) { return 0; } else { - if (not traceNodeVirtualCall(t, node)) { + if (callNodeVirtualCall(t, node)) { + classVtable + (t, objectClass + (t, resolveThisPointer(t, t->stack, target)), methodOffset(t, target)) + = &singletonValue(t, methodCompiled(t, target), 0); + } else { + ACQUIRE(t, t->m->classLock); + + removeCallNode(t, node); + Context context(t); context.c->updateCall - (reinterpret_cast(traceNodeAddress(t, node)), + (reinterpret_cast(callNodeAddress(t, node)), &singletonValue(t, methodCompiled(t, target), 0)); } return &singletonValue(t, methodCompiled(t, target), 0); @@ -4088,22 +4189,19 @@ invokeNative2(MyThread* t, object method) uint64_t FORCE_ALIGN invokeNative(MyThread* t) { - object node = findTraceNode(t, *static_cast(t->stack)); - object target; - if (node) { - target = traceNodeTarget(t, node); - if (traceNodeVirtualCall(t, node)) { - target = resolveTarget(t, t->stack, target); + if (t->trace->nativeMethod == 0) { + object node = findCallNode(t, *static_cast(t->stack)); + t->trace->nativeMethod = callNodeTarget(t, node); + if (callNodeVirtualCall(t, node)) { + t->trace->nativeMethod = resolveTarget + (t, t->stack, t->trace->nativeMethod); } - } else { - target = t->methodInvoked; - t->methodInvoked = 0; } uint64_t result = 0; if (LIKELY(t->exception == 0)) { - result = invokeNative2(t, target); + result = invokeNative2(t, t->trace->nativeMethod); } if (UNLIKELY(t->exception)) { @@ -4112,13 +4210,36 @@ invokeNative(MyThread* t) return result; } } +unsigned +frameMapIndex(MyThread* t, object method, int32_t offset) +{ + object map = codePool(t, methodCode(t, method)); + unsigned mapSize = ceiling + (intArrayLength(t, map), (32 / frameSize(t, method)) + 1); + unsigned indexSize = intArrayLength(t, map) - mapSize; + + unsigned bottom = 0; + unsigned top = indexSize; + for (unsigned span = top - bottom; span; span = top - bottom) { + unsigned middle = bottom + (span / 2); + int32_t v = intArrayBody(t, map, middle); + + if (offset == v) { + return (indexSize * 32) + (frameSize(t, method) * middle); + } else if (offset < v) { + top = middle; + } else { + bottom = middle + 1; + } + } + + abort(t); +} void -visitStackAndLocals(MyThread* t, Heap::Visitor* v, void* base, object node, - void* calleeBase, unsigned argumentFootprint) +visitStackAndLocals(MyThread* t, Heap::Visitor* v, void* base, object method, + void* ip, void* calleeBase, unsigned argumentFootprint) { - object method = traceNodeMethod(t, node); - unsigned count; if (calleeBase) { unsigned parameterFootprint = methodParameterFootprint(t, method); @@ -4131,11 +4252,17 @@ visitStackAndLocals(MyThread* t, Heap::Visitor* v, void* base, object node, } if (count) { - uintptr_t* map = &traceNodeMap(t, node, 0); + object map = codePool(t, methodCode(t, method)); + int index = frameMapIndex + (t, method, difference + (ip, &singletonValue(t, methodCompiled(t, method), 0))); for (unsigned i = 0; i < count; ++i) { - if (getBit(map, i)) { - v->visit(localObject(t, base, method, i)); + int j = index + i; + if ((intArrayBody(t, map, j / 32) + & (static_cast(1) << (j % 32)))) + { + v->visit(localObject(t, base, method, i)); } } } @@ -4156,15 +4283,15 @@ visitStack(MyThread* t, Heap::Visitor* v) unsigned argumentFootprint = 0; while (stack) { - object node = findTraceNode(t, ip); - if (node) { - PROTECT(t, node); + object method = methodForIp(t, ip); + if (method) { + PROTECT(t, method); - visitStackAndLocals(t, v, base, node, calleeBase, argumentFootprint); + visitStackAndLocals + (t, v, base, method, ip, calleeBase, argumentFootprint); calleeBase = base; - argumentFootprint = methodParameterFootprint - (t, traceNodeMethod(t, node)); + argumentFootprint = methodParameterFootprint(t, method); stack = static_cast(base) + 1; if (stack) { @@ -4348,21 +4475,19 @@ invoke(Thread* thread, object method, ArgumentList* arguments) unsigned returnCode = methodReturnCode(t, method); unsigned returnType = fieldType(t, returnCode); - if (methodFlags(t, method) & ACC_NATIVE) { - t->methodInvoked = method; - } - uint64_t result; { MyThread::CallTrace trace(t); + if (methodFlags(t, method) & ACC_NATIVE) { + trace.nativeMethod = method; + } + result = vmInvoke (t, &singletonValue(t, methodCompiled(t, method), 0), arguments->array, arguments->position, returnType); } - assert(t, t->methodInvoked == 0); - object r; switch (returnCode) { case ByteField: @@ -4403,7 +4528,7 @@ class SegFaultHandler: public System::SignalHandler { { MyThread* t = static_cast(m->localThread->get()); if (t->state == Thread::ActiveState) { - object node = findTraceNode(t, *ip); + object node = findCallNode(t, *ip); if (node) { t->ip = *ip; t->base = *base; @@ -4428,8 +4553,10 @@ class MyProcessor: public Processor { allocator(allocator), defaultCompiled(0), nativeCompiled(0), - addressTable(0), - addressCount(0), + callTable(0), + callTableSize(0), + methodTree(0), + methodTreeSentinal(0), indirectCaller(0), indirectCallerSize(0), codeAllocator(s, allocator, true, 64 * 1024) @@ -4540,10 +4667,14 @@ class MyProcessor: public Processor { if (t == t->m->rootThread) { v->visit(&defaultCompiled); v->visit(&nativeCompiled); - v->visit(&addressTable); + v->visit(&callTable); + v->visit(&methodTree); + v->visit(&methodTreeSentinal); } - v->visit(&(t->methodInvoked)); + for (MyThread::CallTrace* trace = t->trace; trace; trace = trace->next) { + v->visit(&(trace->nativeMethod)); + } for (Reference* r = t->reference; r; r = r->next) { v->visit(&(r->target)); @@ -4704,8 +4835,10 @@ class MyProcessor: public Processor { Allocator* allocator; object defaultCompiled; object nativeCompiled; - object addressTable; - unsigned addressCount; + object callTable; + unsigned callTableSize; + object methodTree; + object methodTreeSentinal; uint8_t* indirectCaller; unsigned indirectCallerSize; SegFaultHandler segFaultHandler; @@ -4716,11 +4849,15 @@ MyProcessor* processor(MyThread* t) { MyProcessor* p = static_cast(t->m->processor); - if (p->addressTable == 0) { + if (p->callTable == 0) { ACQUIRE(t, t->m->classLock); - if (p->addressTable == 0) { - p->addressTable = makeArray(t, 128, true); + if (p->callTable == 0) { + p->callTable = makeArray(t, 128, true); + + p->methodTree = p->methodTreeSentinal = makeTreeNode(t, 0, 0, 0); + set(t, p->methodTree, TreeNodeLeft, p->methodTreeSentinal); + set(t, p->methodTree, TreeNodeRight, p->methodTreeSentinal); Context context(t); Compiler* c = context.c; @@ -4782,23 +4919,23 @@ compile(MyThread* t, object method) } object -findTraceNode(MyThread* t, void* address) +findCallNode(MyThread* t, void* address) { - if (DebugTraces) { + if (DebugCallTable) { fprintf(stderr, "find trace node %p\n", address); } MyProcessor* p = processor(t); - object table = p->addressTable; + object table = p->callTable; intptr_t key = reinterpret_cast(address); unsigned index = static_cast(key) & (arrayLength(t, table) - 1); for (object n = arrayBody(t, table, index); - n; n = traceNodeNext(t, n)) + n; n = callNodeNext(t, n)) { - intptr_t k = traceNodeAddress(t, n); + intptr_t k = callNodeAddress(t, n); if (k == key) { return n; @@ -4822,26 +4959,17 @@ resizeTable(MyThread* t, object oldTable, unsigned newLength) for (unsigned i = 0; i < arrayLength(t, oldTable); ++i) { for (oldNode = arrayBody(t, oldTable, i); oldNode; - oldNode = traceNodeNext(t, oldNode)) + oldNode = callNodeNext(t, oldNode)) { - intptr_t k = traceNodeAddress(t, oldNode); + intptr_t k = callNodeAddress(t, oldNode); unsigned index = k & (newLength - 1); - object newNode = makeTraceNode - (t, traceNodeAddress(t, oldNode), - arrayBody(t, newTable, index), - traceNodeMethod(t, oldNode), - traceNodeTarget(t, oldNode), - traceNodeVirtualCall(t, oldNode), - traceNodeLength(t, oldNode), - false); - - if (traceNodeLength(t, oldNode)) { - memcpy(&traceNodeMap(t, newNode, 0), - &traceNodeMap(t, oldNode, 0), - traceNodeLength(t, oldNode) * BytesPerWord); - } + object newNode = makeCallNode + (t, callNodeAddress(t, oldNode), + callNodeTarget(t, oldNode), + callNodeVirtualCall(t, oldNode), + arrayBody(t, newTable, index)); set(t, newTable, ArrayBody + (index * BytesPerWord), newNode); } @@ -4851,29 +4979,85 @@ resizeTable(MyThread* t, object oldTable, unsigned newLength) } void -insertTraceNode(MyThread* t, object node) +insertCallNode(MyThread* t, object node) { - if (DebugTraces) { + if (DebugCallTable) { fprintf(stderr, "insert trace node %p\n", - reinterpret_cast(traceNodeAddress(t, node))); + reinterpret_cast(callNodeAddress(t, node))); } MyProcessor* p = processor(t); PROTECT(t, node); - ++ p->addressCount; + ++ p->callTableSize; - if (p->addressCount >= arrayLength(t, p->addressTable) * 2) { - p->addressTable = resizeTable - (t, p->addressTable, arrayLength(t, p->addressTable) * 2); + if (p->callTableSize >= arrayLength(t, p->callTable) * 2) { + p->callTable = resizeTable + (t, p->callTable, arrayLength(t, p->callTable) * 2); } - intptr_t key = traceNodeAddress(t, node); + intptr_t key = callNodeAddress(t, node); unsigned index = static_cast(key) - & (arrayLength(t, p->addressTable) - 1); + & (arrayLength(t, p->callTable) - 1); - set(t, node, TraceNodeNext, arrayBody(t, p->addressTable, index)); - set(t, p->addressTable, ArrayBody + (index * BytesPerWord), node); + set(t, node, CallNodeNext, arrayBody(t, p->callTable, index)); + set(t, p->callTable, ArrayBody + (index * BytesPerWord), node); +} + +void +removeCallNode(MyThread* t, object node) +{ + if (DebugCallTable) { + fprintf(stderr, "remove call node %p\n", + reinterpret_cast(callNodeAddress(t, node))); + } + + MyProcessor* p = processor(t); + PROTECT(t, node); + + object oldNode = 0; + PROTECT(t, oldNode); + + object newNode = 0; + PROTECT(t, newNode); + + intptr_t key = callNodeAddress(t, node); + unsigned index = static_cast(key) + & (arrayLength(t, p->callTable) - 1); + + for (oldNode = arrayBody(t, p->callTable, index); + oldNode; + oldNode = callNodeNext(t, oldNode)) + { + if (oldNode != node) { + newNode = makeCallNode + (t, callNodeAddress(t, oldNode), + callNodeTarget(t, oldNode), + callNodeVirtualCall(t, oldNode), + newNode); + } + } + + set(t, p->callTable, ArrayBody + (index * BytesPerWord), newNode); + + -- p->callTableSize; + + if (p->callTableSize <= arrayLength(t, p->callTable) / 3) { + p->callTable = resizeTable + (t, p->callTable, arrayLength(t, p->callTable) / 2); + } +} + +object& +methodTree(MyThread* t) +{ + return processor(t)->methodTree; +} + +object +methodTreeSentinal(MyThread* t) +{ + return processor(t)->methodTreeSentinal; } Allocator* diff --git a/src/types.def b/src/types.def index a14af97b5e..75ee570ec0 100644 --- a/src/types.def +++ b/src/types.def @@ -91,13 +91,22 @@ (object method) (int ip)) -(type traceNode +(type treeNode + (object value) + (object left) + (object right)) + +(type treePath + (uintptr_t fresh) + (object node) + (object root) + (object ancestors)) + +(type callNode (intptr_t address) - (object next) - (object method) (object target) (uintptr_t virtualCall) - (array uintptr_t map)) + (object next)) (type array (noassert array object body)) diff --git a/src/util.cpp b/src/util.cpp index d983951a06..25f07f08e3 100644 --- a/src/util.cpp +++ b/src/util.cpp @@ -12,6 +12,235 @@ using namespace vm; +namespace { + +inline object +cloneTreeNode(Thread* t, object n) +{ + return makeTreeNode + (t, treeNodeValue(t, n), treeNodeLeft(t, n), treeNodeRight(t, n)); +} + +inline object +getTreeNodeValue(Thread*, object n) +{ + return reinterpret_cast + (cast(n, TreeNodeValue) & PointerMask); +} + +inline void +setTreeNodeValue(Thread* t, object n, object value) +{ + intptr_t red = cast(n, TreeNodeValue) & (~PointerMask); + + set(t, n, TreeNodeValue, value); + + cast(n, TreeNodeValue) |= red; +} + +inline bool +treeNodeRed(Thread*, object n) +{ + return (cast(n, TreeNodeValue) & (~PointerMask)) != 1; +} + +inline void +setTreeNodeRed(Thread*, object n, intptr_t red) +{ + cast(n, TreeNodeValue) |= red; +} + +object +treeFind(Thread* t, object old, object node, object sentinal, + intptr_t (*compare)(Thread* t, object a, object b)) +{ + PROTECT(t, old); + PROTECT(t, node); + PROTECT(t, sentinal); + + object newRoot = cloneTreeNode(t, old); + PROTECT(t, newRoot); + + object new_ = newRoot; + PROTECT(t, new_); + + object ancestors = 0; + PROTECT(t, ancestors); + + while (old != sentinal) { + ancestors = makePair(t, new_, ancestors); + + intptr_t difference = compare + (t, getTreeNodeValue(t, node), getTreeNodeValue(t, node)); + + if (difference < 0) { + old = treeNodeLeft(t, old); + object n = cloneTreeNode(t, old); + set(t, new_, TreeNodeLeft, n); + new_ = n; + } else if (difference > 0) { + old = treeNodeRight(t, old); + object n = cloneTreeNode(t, old); + set(t, new_, TreeNodeRight, n); + new_ = n; + } else { + return makeTreePath(t, false, new_, newRoot, pairSecond(t, ancestors)); + } + } + + setTreeNodeValue(t, new_, getTreeNodeValue(t, node)); + + return makeTreePath(t, true, new_, newRoot, ancestors); +} + +object +leftRotate(Thread* t, object n) +{ + object child = cloneTreeNode(t, treeNodeRight(t, n)); + set(t, n, TreeNodeRight, treeNodeLeft(t, child)); + set(t, child, TreeNodeLeft, n); + return child; +} + +object +rightRotate(Thread* t, object n) +{ + object child = cloneTreeNode(t, treeNodeLeft(t, n)); + set(t, n, TreeNodeLeft, treeNodeRight(t, child)); + set(t, child, TreeNodeRight, n); + return child; +} + +object +treeAdd(Thread* t, object path) +{ + object new_ = treePathNode(t, path); + PROTECT(t, new_); + + object newRoot = treePathRoot(t, path); + PROTECT(t, newRoot); + + object ancestors = treePathAncestors(t, path); + PROTECT(t, ancestors); + + // rebalance + setTreeNodeRed(t, new_, 1); + while (ancestors != 0 and treeNodeRed(t, pairFirst(t, ancestors))) { + if (pairFirst(t, ancestors) + == treeNodeLeft(t, pairFirst(t, pairSecond(t, ancestors)))) + { + if (treeNodeRed + (t, treeNodeRight(t, pairFirst(t, pairSecond(t, ancestors))))) + { + setTreeNodeRed(t, pairFirst(t, ancestors), 1); + + object n = cloneTreeNode + (t, treeNodeRight(t, pairFirst(t, pairSecond(t, ancestors)))); + + set(t, pairFirst(t, pairSecond(t, ancestors)), TreeNodeRight, n); + + setTreeNodeRed + (t, treeNodeRight + (t, pairFirst(t, pairSecond(t, ancestors))), 0); + + setTreeNodeRed(t, pairFirst(t, pairSecond(t, ancestors)), 0); + + new_ = pairFirst(t, pairSecond(t, ancestors)); + ancestors = pairSecond(t, pairSecond(t, ancestors)); + } else { + if (new_ == treeNodeRight(t, pairFirst(t, ancestors))) { + new_ = pairFirst(t, ancestors); + ancestors = pairSecond(t, ancestors); + + object n = leftRotate(t, new_); + + if (new_ == treeNodeRight(t, pairFirst(t, ancestors))) { + set(t, pairFirst(t, ancestors), TreeNodeRight, n); + } else { + set(t, pairFirst(t, ancestors), TreeNodeLeft, n); + } + ancestors = makePair(t, n, ancestors); + } + setTreeNodeRed(t, pairFirst(t, ancestors), 0); + setTreeNodeRed(t, pairFirst(t, pairSecond(t, ancestors)), 1); + + object n = rightRotate(t, pairFirst(t, pairSecond(t, ancestors))); + if (pairSecond(t, pairSecond(t, ancestors)) == 0) { + newRoot = n; + } else if (treeNodeRight + (t, pairFirst(t, pairSecond(t, pairSecond(t, ancestors)))) + == pairFirst(t, pairSecond(t, ancestors))) + { + set(t, pairFirst(t, pairSecond(t, pairSecond(t, ancestors))), + TreeNodeRight, n); + } else { + set(t, pairFirst(t, pairSecond(t, pairSecond(t, ancestors))), + TreeNodeLeft, n); + } + // done + } + } else { // this is just the reverse of the code above (right and + // left swapped): + if (treeNodeRed + (t, treeNodeLeft(t, pairFirst(t, pairSecond(t, ancestors))))) + { + setTreeNodeRed(t, pairFirst(t, ancestors), 1); + + object n = cloneTreeNode + (t, treeNodeLeft(t, pairFirst(t, pairSecond(t, ancestors)))); + + set(t, pairFirst(t, pairSecond(t, ancestors)), TreeNodeLeft, n); + + setTreeNodeRed + (t, treeNodeLeft + (t, pairFirst(t, pairSecond(t, ancestors))), 0); + + setTreeNodeRed(t, pairFirst(t, pairSecond(t, ancestors)), 0); + + new_ = pairFirst(t, pairSecond(t, ancestors)); + ancestors = pairSecond(t, pairSecond(t, ancestors)); + } else { + if (new_ == treeNodeLeft(t, pairFirst(t, ancestors))) { + new_ = pairFirst(t, ancestors); + ancestors = pairSecond(t, ancestors); + + object n = rightRotate(t, new_); + + if (new_ == treeNodeLeft(t, pairFirst(t, ancestors))) { + set(t, pairFirst(t, ancestors), TreeNodeLeft, n); + } else { + set(t, pairFirst(t, ancestors), TreeNodeRight, n); + } + ancestors = makePair(t, n, ancestors); + } + setTreeNodeRed(t, pairFirst(t, ancestors), 0); + setTreeNodeRed(t, pairFirst(t, pairSecond(t, ancestors)), 1); + + object n = leftRotate(t, pairFirst(t, pairSecond(t, ancestors))); + if (pairSecond(t, pairSecond(t, ancestors)) == 0) { + newRoot = n; + } else if (treeNodeLeft + (t, pairFirst(t, pairSecond(t, pairSecond(t, ancestors)))) + == pairFirst(t, pairSecond(t, ancestors))) + { + set(t, pairFirst(t, pairSecond(t, pairSecond(t, ancestors))), + TreeNodeLeft, n); + } else { + set(t, pairFirst(t, pairSecond(t, pairSecond(t, ancestors))), + TreeNodeRight, n); + } + // done + } + } + } + + setTreeNodeRed(t, newRoot, 0); + + return newRoot; +} + +} // namespace + namespace vm { object @@ -275,4 +504,35 @@ vectorAppend(Thread* t, object vector, object value) return vector; } +object +treeQuery(Thread* t, object tree, intptr_t key, object sentinal, + intptr_t (*compare)(Thread* t, intptr_t key, object b)) +{ + object node = tree; + while (node != sentinal) { + intptr_t difference = compare(t, key, getTreeNodeValue(t, node)); + if (difference < 0) { + node = treeNodeLeft(t, node); + } else if (difference > 0) { + node = treeNodeRight(t, node); + } else { + return node; + } + } + + return 0; +} + +object +treeInsert(Thread* t, object tree, object node, object sentinal, + intptr_t (*compare)(Thread* t, object a, object b)) +{ + object path = treeFind(t, tree, node, sentinal, compare); + if (treePathFresh(t, path)) { + return treeAdd(t, path); + } else { + return tree; + } +} + } // namespace vm diff --git a/src/util.h b/src/util.h index 4d74aa7d9e..12ad95e0f6 100644 --- a/src/util.h +++ b/src/util.h @@ -83,6 +83,14 @@ listAppend(Thread* t, object list, object value); object vectorAppend(Thread* t, object vector, object value); +object +treeQuery(Thread* t, object tree, intptr_t key, object sentinal, + intptr_t (*compare)(Thread* t, intptr_t key, object b)); + +object +treeInsert(Thread* t, object tree, object node, object sentinal, + intptr_t (*compare)(Thread* t, object a, object b)); + } // vm #endif//UTIL_H