modify TreeSet.MyIterator to support both ascending and descending iteration

This also fixes a bug such that the remove() method left the iterator
in an inconsistent state.
This commit is contained in:
Joel Dice 2013-12-04 17:52:27 -07:00
parent bb18637f13
commit 2000c139ea
3 changed files with 89 additions and 6 deletions

View File

@ -358,6 +358,18 @@ public class PersistentSet <T> implements Iterable <T> {
return new Cell(n, ancestors); return new Cell(n, ancestors);
} }
private static <T> Cell<Node<T>> maximum(Node<T> n,
Cell<Node<T>> ancestors)
{
while (n.right != NullNode) {
n.right = new Node(n.right);
ancestors = new Cell(n, ancestors);
n = n.right;
}
return new Cell(n, ancestors);
}
private static <T> Cell<Node<T>> successor(Node<T> n, private static <T> Cell<Node<T>> successor(Node<T> n,
Cell<Node<T>> ancestors) Cell<Node<T>> ancestors)
{ {
@ -374,6 +386,22 @@ public class PersistentSet <T> implements Iterable <T> {
return ancestors; return ancestors;
} }
private static <T> Cell<Node<T>> predecessor(Node<T> n,
Cell<Node<T>> ancestors)
{
if (n.left != NullNode) {
n.left = new Node(n.left);
return maximum(n.left, new Cell(n, ancestors));
}
while (ancestors != null && n == ancestors.value.left) {
n = ancestors.value;
ancestors = ancestors.next;
}
return ancestors;
}
public Path<T> find(T value) { public Path<T> find(T value) {
Node<T> newRoot = new Node(root); Node<T> newRoot = new Node(root);
Cell<Node<T>> ancestors = null; Cell<Node<T>> ancestors = null;
@ -456,6 +484,15 @@ public class PersistentSet <T> implements Iterable <T> {
} }
} }
private Path<T> predecessor(Path<T> p) {
Cell<Node<T>> s = predecessor(p.node, p.ancestors);
if (s == null) {
return null;
} else {
return new Path(false, s.value, p.root, s.next);
}
}
private static class Node <T> { private static class Node <T> {
public T value; public T value;
public Node left; public Node left;
@ -503,6 +540,10 @@ public class PersistentSet <T> implements Iterable <T> {
return root.successor(this); return root.successor(this);
} }
public Path<T> predecessor() {
return root.predecessor(this);
}
public PersistentSet<T> remove() { public PersistentSet<T> remove() {
if (fresh) throw new IllegalStateException(); if (fresh) throw new IllegalStateException();

View File

@ -57,9 +57,7 @@ public class TreeSet<T> extends AbstractSet<T> implements Collection<T> {
} }
public Iterator<T> descendingIterator() { public Iterator<T> descendingIterator() {
ArrayList<T> iterable = new ArrayList<T>(this); return new MyIterator<T>(set.last(), true);
Collections.reverse(iterable);
return iterable.iterator();
} }
public String toString() { public String toString() {
@ -141,12 +139,18 @@ public class TreeSet<T> extends AbstractSet<T> implements Collection<T> {
private Cell<T> prevCell; private Cell<T> prevCell;
private Cell<T> prevPrevCell; private Cell<T> prevPrevCell;
private boolean canRemove = false; private boolean canRemove = false;
private final boolean reversed;
private MyIterator(PersistentSet.Path<Cell<T>> path) { private MyIterator(PersistentSet.Path<Cell<T>> path) {
this(path, false);
}
private MyIterator(PersistentSet.Path<Cell<T>> path, boolean reversed) {
this.path = path; this.path = path;
this.reversed = reversed;
if (path != null) { if (path != null) {
cell = path.value(); cell = path.value();
nextPath = path.successor(); nextPath = nextPath();
} }
} }
@ -157,6 +161,7 @@ public class TreeSet<T> extends AbstractSet<T> implements Collection<T> {
prevCell = start.prevCell; prevCell = start.prevCell;
prevPrevCell = start.prevPrevCell; prevPrevCell = start.prevPrevCell;
canRemove = start.canRemove; canRemove = start.canRemove;
reversed = start.reversed;
} }
public boolean hasNext() { public boolean hasNext() {
@ -166,7 +171,7 @@ public class TreeSet<T> extends AbstractSet<T> implements Collection<T> {
public T next() { public T next() {
if (cell == null) { if (cell == null) {
path = nextPath; path = nextPath;
nextPath = path.successor(); nextPath = nextPath();
cell = path.value(); cell = path.value();
} }
prevPrevCell = prevCell; prevPrevCell = prevCell;
@ -176,6 +181,10 @@ public class TreeSet<T> extends AbstractSet<T> implements Collection<T> {
return prevCell.value; return prevCell.value;
} }
private PersistentSet.Path nextPath() {
return reversed ? path.predecessor() : path.successor();
}
public void remove() { public void remove() {
if (! canRemove) throw new IllegalStateException(); if (! canRemove) throw new IllegalStateException();
@ -190,11 +199,12 @@ public class TreeSet<T> extends AbstractSet<T> implements Collection<T> {
} else { } else {
// cell is alone in the list. // cell is alone in the list.
set = (PersistentSet) path.remove(); set = (PersistentSet) path.remove();
path = path.successor(); path = nextPath;
if (path != null) { if (path != null) {
prevCell = null; prevCell = null;
cell = path.value(); cell = path.value();
path = (PersistentSet.Path) set.find((Cell) cell); path = (PersistentSet.Path) set.find((Cell) cell);
nextPath = nextPath();
} }
} }

View File

@ -48,7 +48,39 @@ public class Tree {
} }
} }
private static void ascendingIterator() {
TreeSet<Integer> t = new TreeSet<Integer>();
t.add(7);
t.add(2);
t.add(9);
t.add(2);
Iterator<Integer> iter = t.iterator();
expect(2 == (int)iter.next());
expect(7 == (int)iter.next());
iter.remove();
expect(9 == (int)iter.next());
expect(!iter.hasNext());
isEqual(printList(t), "2, 9");
}
private static void descendingIterator() {
TreeSet<Integer> t = new TreeSet<Integer>();
t.add(7);
t.add(2);
t.add(9);
t.add(2);
Iterator<Integer> iter = t.descendingIterator();
expect(9 == (int)iter.next());
expect(7 == (int)iter.next());
iter.remove();
expect(2 == (int)iter.next());
expect(!iter.hasNext());
isEqual(printList(t), "2, 9");
}
public static void main(String args[]) { public static void main(String args[]) {
ascendingIterator();
descendingIterator();
TreeSet<Integer> t1 = new TreeSet<Integer>(new MyCompare()); TreeSet<Integer> t1 = new TreeSet<Integer>(new MyCompare());
t1.add(5); t1.add(2); t1.add(1); t1.add(8); t1.add(3); t1.add(5); t1.add(2); t1.add(1); t1.add(8); t1.add(3);
isEqual(printList(t1), "1, 2, 3, 5, 8"); isEqual(printList(t1), "1, 2, 3, 5, 8");