package java.util; import java.util.Comparator; import java.lang.Iterable; public class PersistentSet implements Iterable { private static final Node NullNode = new Node(null); static { NullNode.left = NullNode; NullNode.right = NullNode; } private final Node root; private final Comparator comparator; private final int size; public PersistentSet() { this(NullNode, new Comparator() { public int compare(T a, T b) { return ((Comparable) a).compareTo(b); } }, 0); } public PersistentSet(Comparator comparator) { this(NullNode, comparator, 0); } private PersistentSet(Node root, Comparator comparator, int size) { this.root = root; this.comparator = comparator; this.size = size; } public Comparator comparator() { return comparator; } public PersistentSet add(T value) { return add(value, false); } public int size() { return size; } public PersistentSet add(T value, boolean replaceExisting) { Path p = find(value); if (! p.fresh) { if (replaceExisting) { return p.replaceWith(value); } else { return this; } } return add(p); } private PersistentSet add(Path p) { if (! p.fresh) throw new IllegalArgumentException(); Node new_ = p.node; Node newRoot = p.root.root; Cell> ancestors = p.ancestors; // rebalance new_.red = true; while (ancestors != null && ancestors.value.red) { if (ancestors.value == ancestors.next.value.left) { if (ancestors.next.value.right.red) { ancestors.value.red = false; ancestors.next.value.right = new Node(ancestors.next.value.right); ancestors.next.value.right.red = false; ancestors.next.value.red = true; new_ = ancestors.next.value; ancestors = ancestors.next.next; } else { if (new_ == ancestors.value.right) { new_ = ancestors.value; ancestors = ancestors.next; Node n = leftRotate(new_); if (ancestors.value.right == new_) { ancestors.value.right = n; } else { ancestors.value.left = n; } ancestors = new Cell(n, ancestors); } ancestors.value.red = false; ancestors.next.value.red = true; Node n = rightRotate(ancestors.next.value); if (ancestors.next.next == null) { newRoot = n; } else if (ancestors.next.next.value.right == ancestors.next.value) { ancestors.next.next.value.right = n; } else { ancestors.next.next.value.left = n; } // done } } else { if (ancestors.next.value.left.red) { ancestors.value.red = false; ancestors.next.value.left = new Node(ancestors.next.value.left); ancestors.next.value.left.red = false; ancestors.next.value.red = true; new_ = ancestors.next.value; ancestors = ancestors.next.next; } else { if (new_ == ancestors.value.left) { new_ = ancestors.value; ancestors = ancestors.next; Node n = rightRotate(new_); if (ancestors.value.right == new_) { ancestors.value.right = n; } else { ancestors.value.left = n; } ancestors = new Cell(n, ancestors); } ancestors.value.red = false; ancestors.next.value.red = true; Node n = leftRotate(ancestors.next.value); if (ancestors.next.next == null) { newRoot = n; } else if (ancestors.next.next.value.right == ancestors.next.value) { ancestors.next.next.value.right = n; } else { ancestors.next.next.value.left = n; } // done } } } newRoot.red = false; return new PersistentSet(newRoot, comparator, size + 1); } private static Node leftRotate(Node n) { Node child = new Node(n.right); n.right = child.left; child.left = n; return child; } private static Node rightRotate(Node n) { Node child = new Node(n.left); n.left = child.right; child.right = n; return child; } public PersistentSet remove(T value) { Path p = find(value); if (! p.fresh) { return remove(p); } return this; } private PersistentSet remove(Path p) { Node new_ = p.node; Node newRoot = p.root.root; Cell> ancestors = p.ancestors; Node dead; if (new_.left == NullNode || new_.right == NullNode) { dead = new_; } else { Cell> path = successor(new_, ancestors); dead = path.value; ancestors = path.next; } Node child; if (dead.left != NullNode) { child = dead.left; } else { child = dead.right; } if (ancestors == null) { child.red = false; return new PersistentSet(child, comparator, 1); } else if (dead == ancestors.value.left) { ancestors.value.left = child; } else { ancestors.value.right = child; } if (dead != new_) { new_.value = dead.value; } if (! dead.red) { // rebalance while (ancestors != null && ! child.red) { if (child == ancestors.value.left) { Node sibling = ancestors.value.right = new Node(ancestors.value.right); if (sibling.red) { sibling.red = false; ancestors.value.red = true; Node n = leftRotate(ancestors.value); if (ancestors.next == null) { newRoot = n; } else if (ancestors.next.value.right == ancestors.value) { ancestors.next.value.right = n; } else { ancestors.next.value.left = n; } ancestors.next = new Cell(n, ancestors.next); sibling = ancestors.value.right; } if (! (sibling.left.red || sibling.right.red)) { sibling.red = true; child = ancestors.value; ancestors = ancestors.next; } else { if (! sibling.right.red) { sibling.left = new Node(sibling.left); sibling.left.red = false; sibling.red = true; sibling = ancestors.value.right = rightRotate(sibling); } sibling.red = ancestors.value.red; ancestors.value.red = false; sibling.right = new Node(sibling.right); sibling.right.red = false; Node n = leftRotate(ancestors.value); if (ancestors.next == null) { newRoot = n; } else if (ancestors.next.value.right == ancestors.value) { ancestors.next.value.right = n; } else { ancestors.next.value.left = n; } child = newRoot; ancestors = null; } } else { Node sibling = ancestors.value.left = new Node(ancestors.value.left); if (sibling.red) { sibling.red = false; ancestors.value.red = true; Node n = rightRotate(ancestors.value); if (ancestors.next == null) { newRoot = n; } else if (ancestors.next.value.left == ancestors.value) { ancestors.next.value.left = n; } else { ancestors.next.value.right = n; } ancestors.next = new Cell(n, ancestors.next); sibling = ancestors.value.left; } if (! (sibling.right.red || sibling.left.red)) { sibling.red = true; child = ancestors.value; ancestors = ancestors.next; } else { if (! sibling.left.red) { sibling.right = new Node(sibling.right); sibling.right.red = false; sibling.red = true; sibling = ancestors.value.left = leftRotate(sibling); } sibling.red = ancestors.value.red; ancestors.value.red = false; sibling.left = new Node(sibling.left); sibling.left.red = false; Node n = rightRotate(ancestors.value); if (ancestors.next == null) { newRoot = n; } else if (ancestors.next.value.left == ancestors.value) { ancestors.next.value.left = n; } else { ancestors.next.value.right = n; } child = newRoot; ancestors = null; } } } child.red = false; } return new PersistentSet(newRoot, comparator, size - 1); } private static Cell> minimum(Node n, Cell> ancestors) { while (n.left != NullNode) { n.left = new Node(n.left); ancestors = new Cell(n, ancestors); n = n.left; } return new Cell(n, ancestors); } private static Cell> successor(Node n, Cell> ancestors) { if (n.right != NullNode) { n.right = new Node(n.right); return minimum(n.right, new Cell(n, ancestors)); } while (ancestors != null && n == ancestors.value.right) { n = ancestors.value; ancestors = ancestors.next; } return ancestors; } public Path find(T value) { Node newRoot = new Node(root); Cell> ancestors = null; Node old = root; Node new_ = newRoot; while (old != NullNode) { ancestors = new Cell(new_, ancestors); int difference = comparator.compare(value, old.value); if (difference < 0) { old = old.left; new_ = new_.left = new Node(old); } else if (difference > 0) { old = old.right; new_ = new_.right = new Node(old); } else { return new Path(false, new_, new PersistentSet(newRoot, comparator, size), ancestors.next); } } new_.value = value; return new Path(true, new_, new PersistentSet(newRoot, comparator, size), ancestors); } public Path first() { if (root == NullNode) return null; Node newRoot = new Node(root); Cell> ancestors = null; Node old = root; Node new_ = newRoot; while (old.left != NullNode) { ancestors = new Cell(new_, ancestors); old = old.left; new_ = new_.left = new Node(old); } return new Path(true, new_, new PersistentSet(newRoot, comparator, size), ancestors); } public Path last() { if (root == NullNode) return null; Node newRoot = new Node(root); Cell> ancestors = null; Node old = root; Node new_ = newRoot; while (old.right != NullNode) { ancestors = new Cell(new_, ancestors); old = old.right; new_ = new_.right = new Node(old); } return new Path(true, new_, new PersistentSet(newRoot, comparator, size), ancestors); } public java.util.Iterator iterator() { return new Iterator(first()); } private Path successor(Path p) { Cell> s = successor(p.node, p.ancestors); if (s == null) { return null; } else { return new Path(false, s.value, p.root, s.next); } } // public void dump(java.io.PrintStream out) { // dump(root, out, 0); // } // private static void indent(java.io.PrintStream out, int level) { // for (int i = 0; i < level; ++i) out.print(" "); // } // private static void dump(Node n, java.io.PrintStream out, int level) { // indent(out, level); // out.print(n == NullNode ? null : n.value); // out.println(n == NullNode ? "" : n.red ? " (red)" : " (black)"); // if (n.left != NullNode || n.right != NullNode) { // dump(n.left, out, level + 1); // dump(n.right, out, level + 1); // } // } // private static int[] randomSet(java.util.Random r, int size) { // int[] data = new int[size]; // for (int i = size - 1; i >= 0; --i) { // data[i] = i + 1; // } // for (int i = size - 1; i >= 0; --i) { // int n = r.nextInt(size); // int tmp = data[i]; // data[i] = data[n]; // data[n] = tmp; // } // return data; // } // public static void main(String[] args) { // java.util.Random r = new java.util.Random(Integer.parseInt(args[0])); // int size = 18; // PersistentSet[] sets = new PersistentSet[size]; // PersistentSet s = new PersistentSet(); // int[] data = randomSet(r, size); // for (int i = 0; i < size; ++i) { // System.out.println("-- add " + data[i] + " -- "); // sets[i] = s = s.add(data[i]); // dump(s.root, System.out, 0); // } // System.out.println("\npersistence:\n"); // for (int i = 0; i < size; ++i) { // dump(sets[i].root, System.out, 0); // System.out.println("--"); // } // data = randomSet(r, size); // System.out.println("\nremoval:\n"); // for (int i = 0; i < size; ++i) { // System.out.println("-- remove " + data[i] + " -- "); // sets[i] = s = s.remove(data[i]); // dump(s.root, System.out, 0); // } // System.out.println("\npersistence:\n"); // for (int i = 0; i < size; ++i) { // dump(sets[i].root, System.out, 0); // System.out.println("--"); // } // } private static class Node { public T value; public Node left; public Node right; public boolean red; public Node(Node basis) { if (basis != null) { value = basis.value; left = basis.left; right = basis.right; red = basis.red; } } } public static class Path { private final boolean fresh; private final Node node; private final PersistentSet root; private final Cell> ancestors; public Path(boolean fresh, Node node, PersistentSet root, Cell> ancestors) { this.fresh = fresh; this.node = node; this.root = root; this.ancestors = ancestors; } public T value() { return node.value; } public boolean fresh() { return fresh; } public PersistentSet root() { return root; } public Path successor() { return root.successor(this); } public PersistentSet remove() { return root.remove(this); } public PersistentSet add() { if (! fresh) throw new IllegalStateException(); return root.add(this); } public PersistentSet replaceWith(T value) { if (fresh) throw new IllegalStateException(); if (root.comparator.compare(node.value, value) != 0) throw new IllegalArgumentException(); node.value = value; return root; } } public class Iterator implements java.util.Iterator { private PersistentSet.Path path; private Iterator(PersistentSet.Path path) { this.path = path; } private Iterator(Iterator start) { path = start.path; } public boolean hasNext() { return path != null; } public T next() { PersistentSet.Path p = path; path = path.successor(); return p.value(); } public void remove() { throw new UnsupportedOperationException(); } } }