package java.util;

import java.util.Comparator;
import java.lang.Iterable;

public class PersistentSet <T> implements Iterable <T> {
  // Shared sentinel used wherever a node has no child; the static initializer
  // below makes it its own left and right child.
  private static final Node NullNode = new Node(null);

  static {
    NullNode.left = NullNode;
    NullNode.right = NullNode;
  }

  // Root of this set's tree; the public constructors start it at NullNode.
  private final Node<T> root;
  // Ordering used by find() to locate values.
  private final Comparator<T> comparator;
  // Element count returned by size().
  private final int size;

  /**
   * Creates an empty set whose ordering comes from the elements themselves:
   * each comparison casts the first operand to {@link Comparable} and calls
   * its {@code compareTo}.
   */
  public PersistentSet() {
    this(NullNode, new Comparator<T>() {
        public int compare(T lhs, T rhs) {
          Comparable<T> ordered = (Comparable<T>) lhs;
          return ordered.compareTo(rhs);
        }
      }, 0);
  }

  /**
   * Creates an empty set ordered by the given comparator.
   *
   * @param comparator ordering used for every lookup in this set and in the
   *        sets derived from it
   */
  public PersistentSet(Comparator<T> comparator) {
    this(NullNode, comparator, 0);
  }

  /**
   * Wraps an already-built tree.
   *
   * @param root root node of the tree (NullNode for an empty set)
   * @param comparator ordering the tree was built with
   * @param size number of elements reported by {@link #size()}
   */
  private PersistentSet(Node<T> root, Comparator<T> comparator, int size) {
    this.size = size;
    this.comparator = comparator;
    this.root = root;
  }

  /** Returns the comparator this set orders its elements with. */
  public Comparator<T> comparator() {
    return this.comparator;
  }

  /**
   * Returns a set containing {@code value}; an element that already compares
   * equal is kept rather than replaced.
   */
  public PersistentSet<T> add(T value) {
    final boolean replaceExisting = false;
    return add(value, replaceExisting);
  }

  /** Returns the element count recorded for this set. */
  public int size() {
    return this.size;
  }

  /**
   * Returns a set containing {@code value}.
   *
   * @param value the element to insert
   * @param replaceExisting when an element comparing equal is already
   *        present: {@code true} stores {@code value} in its place,
   *        {@code false} returns this set unchanged
   * @return the resulting set
   */
  public PersistentSet<T> add(T value, boolean replaceExisting) {
    Path<T> location = find(value);
    if (location.fresh) {
      // Nothing compares equal yet: insert at the located position.
      return add(location);
    }

    return replaceExisting ? location.replaceWith(value) : this;
  }

  /**
   * Inserts the fresh node addressed by {@code p} and restores the red-black
   * colouring along the path, returning a set one element larger.
   *
   * The nodes reachable through {@code p.ancestors} are the copies made when
   * the path was built; nodes off the path (uncles) are copied here before
   * being recoloured.
   *
   * @param p a path whose {@code fresh} flag is set
   * @return a new set rooted at the rebalanced tree, with size + 1
   * @throws IllegalArgumentException if {@code p} is not fresh
   */
  private PersistentSet<T> add(Path<T> p) {
    if (! p.fresh) throw new IllegalArgumentException();

    Node<T> new_ = p.node;
    Node<T> newRoot = p.root.root;
    Cell<Node<T>> ancestors = p.ancestors;

    // rebalance
    // The inserted node starts red; loop while its parent (ancestors.value)
    // is red as well.
    new_.red = true;
    while (ancestors != null && ancestors.value.red) {
      if (ancestors.value == ancestors.next.value.left) {
        // Parent is the left child of the grandparent; the uncle is on the
        // right.
        if (ancestors.next.value.right.red) {
          // Red uncle: blacken parent and a copy of the uncle, redden the
          // grandparent, and continue from the grandparent.
          ancestors.value.red = false;
          ancestors.next.value.right = new Node(ancestors.next.value.right);
          ancestors.next.value.right.red = false;
          ancestors.next.value.red = true;
          new_ = ancestors.next.value;
          ancestors = ancestors.next.next;
        } else {
          if (new_ == ancestors.value.right) {
            // Node is a right child: rotate the parent left first and
            // re-link the rotated subtree under the grandparent.
            new_ = ancestors.value;
            ancestors = ancestors.next;

            Node<T> n = leftRotate(new_);
            if (ancestors.value.right == new_) {
              ancestors.value.right = n;
            } else {
              ancestors.value.left = n;
            }
            ancestors = new Cell(n, ancestors);
          }
          // Parent becomes black, grandparent red, then the grandparent is
          // rotated right.
          ancestors.value.red = false;
          ancestors.next.value.red = true;

          Node<T> n = rightRotate(ancestors.next.value);
          if (ancestors.next.next == null) {
            // The grandparent was the root, so the rotation result is the
            // new root.
            newRoot = n;
          } else if (ancestors.next.next.value.right == ancestors.next.value) {
            ancestors.next.next.value.right = n;
          } else {
            ancestors.next.next.value.left = n;
          }
          // done
          // (ancestors.value is black now, so the loop condition fails)
        }
      } else {
        // Mirror image: parent is the right child, the uncle is on the left.
        if (ancestors.next.value.left.red) {
          ancestors.value.red = false;
          ancestors.next.value.left = new Node(ancestors.next.value.left);
          ancestors.next.value.left.red = false;
          ancestors.next.value.red = true;
          new_ = ancestors.next.value;
          ancestors = ancestors.next.next;
        } else {
          if (new_ == ancestors.value.left) {
            new_ = ancestors.value;
            ancestors = ancestors.next;

            Node<T> n = rightRotate(new_);
            if (ancestors.value.right == new_) {
              ancestors.value.right = n;
            } else {
              ancestors.value.left = n;
            }
            ancestors = new Cell(n, ancestors);
          }
          ancestors.value.red = false;
          ancestors.next.value.red = true;

          Node<T> n = leftRotate(ancestors.next.value);
          if (ancestors.next.next == null) {
            newRoot = n;
          } else if (ancestors.next.next.value.right == ancestors.next.value) {
            ancestors.next.next.value.right = n;
          } else {
            ancestors.next.next.value.left = n;
          }
          // done
        }
      }
    }

    // The root is always left black.
    newRoot.red = false;

    return new PersistentSet(newRoot, comparator, size + 1);
  }

  /**
   * Rotates the subtree rooted at {@code n} to the left.
   *
   * The right child is copied before being promoted; {@code n} itself is
   * modified in place and becomes the left child of that copy.
   *
   * @return the copy that now heads the subtree (the caller must re-link it)
   */
  private static <T> Node<T> leftRotate(Node<T> n) {
    Node<T> pivot = new Node(n.right);
    n.right = pivot.left;
    pivot.left = n;
    return pivot;
  }

  /**
   * Rotates the subtree rooted at {@code n} to the right.
   *
   * The left child is copied before being promoted; {@code n} itself is
   * modified in place and becomes the right child of that copy.
   *
   * @return the copy that now heads the subtree (the caller must re-link it)
   */
  private static <T> Node<T> rightRotate(Node<T> n) {
    Node<T> pivot = new Node(n.left);
    n.left = pivot.right;
    pivot.right = n;
    return pivot;
  }

  /**
   * Returns a set without the element comparing equal to {@code value}, or
   * this set itself when no such element is found.
   */
  public PersistentSet<T> remove(T value) {
    Path<T> location = find(value);
    return location.fresh ? this : remove(location);
  }

  /**
   * Removes the node addressed by {@code p} and restores the red-black
   * colouring, returning a set one element smaller.
   *
   * Only nodes that were copied for this operation are modified: the path
   * nodes were copied when {@code p} was built, and every off-path node
   * (the spliced-in child, siblings, nephews) is copied here before its
   * colour or links are changed, so the tree the path was derived from is
   * left untouched.
   *
   * @param p path to the node to remove
   * @return a new set with size - 1 (an empty set when the last element is
   *         removed)
   */
  private PersistentSet<T> remove(Path<T> p) {
    Node<T> new_ = p.node;
    Node<T> newRoot = p.root.root;
    Cell<Node<T>> ancestors = p.ancestors;

    // "dead" is the node physically unlinked: the target itself when it has
    // at most one child, otherwise its in-order successor.
    Node<T> dead;
    if (new_.left == NullNode || new_.right == NullNode) {
      dead = new_;
    } else {
      Cell<Node<T>> path = successor(new_, ancestors);
      dead = path.value;
      ancestors = path.next;
    }

    Node<T> child;
    if (dead.left != NullNode) {
      child = dead.left;
    } else {
      child = dead.right;
    }

    // The child was not on the copied path, so it is still shared with the
    // source tree. Copy it before it can be recoloured below; the shared
    // NullNode sentinel is never copied.
    if (child != NullNode) {
      child = new Node(child);
    }

    if (ancestors == null) {
      // The root itself is being removed and its only child (if any) takes
      // its place.
      if (child != NullNode) {
        child.red = false;
      }
      return new PersistentSet(child, comparator, size - 1);
    } else if (dead == ancestors.value.left) {
      ancestors.value.left = child;
    } else {
      ancestors.value.right = child;
    }

    if (dead != new_) {
      // The successor was unlinked instead of the target: move its value up.
      new_.value = dead.value;
    }

    if (! dead.red) {
      // rebalance
      // Removing a black node leaves the child's side one black short.
      while (ancestors != null && ! child.red) {
        if (child == ancestors.value.left) {
          Node<T> sibling = ancestors.value.right
            = new Node(ancestors.value.right);
          if (sibling.red) {
            // Red sibling: rotate it above the parent so the new sibling is
            // black.
            sibling.red = false;
            ancestors.value.red = true;

            Node<T> n = leftRotate(ancestors.value);
            if (ancestors.next == null) {
              newRoot = n;
            } else if (ancestors.next.value.right == ancestors.value) {
              ancestors.next.value.right = n;
            } else {
              ancestors.next.value.left = n;
            }
            ancestors.next = new Cell(n, ancestors.next);

            // The node now in the sibling position came from the source
            // tree; copy it before it is modified.
            sibling = ancestors.value.right
              = new Node(ancestors.value.right);
          }

          if (! (sibling.left.red || sibling.right.red)) {
            // Both nephews black: redden the sibling and move the deficit
            // up to the parent.
            sibling.red = true;
            child = ancestors.value;
            ancestors = ancestors.next;
          } else {
            if (! sibling.right.red) {
              // Only the near nephew is red: rotate it into the sibling
              // position so the far nephew is red.
              sibling.left = new Node(sibling.left);
              sibling.left.red = false;

              sibling.red = true;
              sibling = ancestors.value.right = rightRotate(sibling);
            }

            sibling.red = ancestors.value.red;
            ancestors.value.red = false;

            sibling.right = new Node(sibling.right);
            sibling.right.red = false;

            Node<T> n = leftRotate(ancestors.value);
            if (ancestors.next == null) {
              newRoot = n;
            } else if (ancestors.next.value.right == ancestors.value) {
              ancestors.next.value.right = n;
            } else {
              ancestors.next.value.left = n;
            }

            // Balanced: end the loop; the root is blackened below.
            child = newRoot;
            ancestors = null;
          }
        } else {
          // Mirror image: child is the right child, sibling is on the left.
          Node<T> sibling = ancestors.value.left
            = new Node(ancestors.value.left);
          if (sibling.red) {
            sibling.red = false;
            ancestors.value.red = true;

            Node<T> n = rightRotate(ancestors.value);
            if (ancestors.next == null) {
              newRoot = n;
            } else if (ancestors.next.value.left == ancestors.value) {
              ancestors.next.value.left = n;
            } else {
              ancestors.next.value.right = n;
            }
            ancestors.next = new Cell(n, ancestors.next);

            // Copy the new sibling before modifying it (see above).
            sibling = ancestors.value.left
              = new Node(ancestors.value.left);
          }

          if (! (sibling.right.red || sibling.left.red)) {
            sibling.red = true;
            child = ancestors.value;
            ancestors = ancestors.next;
          } else {
            if (! sibling.left.red) {
              sibling.right = new Node(sibling.right);
              sibling.right.red = false;

              sibling.red = true;
              sibling = ancestors.value.left = leftRotate(sibling);
            }

            sibling.red = ancestors.value.red;
            ancestors.value.red = false;

            sibling.left = new Node(sibling.left);
            sibling.left.red = false;

            Node<T> n = rightRotate(ancestors.value);
            if (ancestors.next == null) {
              newRoot = n;
            } else if (ancestors.next.value.left == ancestors.value) {
              ancestors.next.value.left = n;
            } else {
              ancestors.next.value.right = n;
            }

            child = newRoot;
            ancestors = null;
          }
        }
      }

      // Either a red node absorbs the missing black, or this is the root.
      child.red = false;
    }

    return new PersistentSet(newRoot, comparator, size - 1);
  }

  /**
   * Walks to the leftmost node below {@code n}, copying each left child
   * before stepping into it and pushing the node being left onto the
   * ancestor list.
   *
   * @param n subtree root to start from (modified in place)
   * @param ancestors ancestor list of {@code n}, nearest first
   * @return a cell whose value is the leftmost node and whose tail is that
   *         node's ancestor list
   */
  private static <T> Cell<Node<T>> minimum(Node<T> n,
                                           Cell<Node<T>> ancestors)
  {
    Node<T> current = n;
    Cell<Node<T>> trail = ancestors;
    for (; current.left != NullNode; current = current.left) {
      current.left = new Node(current.left);
      trail = new Cell(current, trail);
    }

    return new Cell(current, trail);
  }

  /**
   * Locates the in-order successor of {@code n}.
   *
   * With a right subtree present, the right child is copied and the
   * leftmost node beneath it is returned. Otherwise the ancestor list is
   * climbed until a node is reached from its left side.
   *
   * @param n the node to advance from
   * @param ancestors ancestor list of {@code n}, nearest first
   * @return a cell holding the successor followed by its ancestors, or
   *         {@code null} when the climb runs out of ancestors
   */
  private static <T> Cell<Node<T>> successor(Node<T> n,
                                             Cell<Node<T>> ancestors)
  {
    if (n.right == NullNode) {
      Node<T> current = n;
      Cell<Node<T>> trail = ancestors;
      // Climb while we keep arriving from a right child.
      while (trail != null && current == trail.value.right) {
        current = trail.value;
        trail = trail.next;
      }
      return trail;
    }

    n.right = new Node(n.right);
    return minimum(n.right, new Cell(n, ancestors));
  }

  /**
   * Searches for {@code value}, copying every node visited so the returned
   * path can be modified without touching this set's tree.
   *
   * @param value the value to look up
   * @return a non-fresh path to the matching node when the comparator
   *         reports equality; otherwise a fresh path whose node is a new
   *         leaf holding {@code value} at the position where it belongs
   */
  public Path<T> find(T value) {
    Node<T> copiedRoot = new Node(root);
    Cell<Node<T>> trail = null;

    Node<T> original = root;
    Node<T> copy = copiedRoot;
    while (original != NullNode) {
      trail = new Cell(copy, trail);

      int order = comparator.compare(value, original.value);
      if (order == 0) {
        // trail currently starts with the match itself; its ancestors are
        // the rest of the list.
        return new Path(false, copy,
                        new PersistentSet(copiedRoot, comparator, size),
                        trail.next);
      }

      if (order < 0) {
        original = original.left;
        copy.left = new Node(original);
        copy = copy.left;
      } else {
        original = original.right;
        copy.right = new Node(original);
        copy = copy.right;
      }
    }

    // Fell off the tree: copy is a duplicate of the sentinel, reused as the
    // leaf for the missing value.
    copy.value = value;
    return new Path(true, copy,
                    new PersistentSet(copiedRoot, comparator, size),
                    trail);
  }

  /**
   * Returns a path to the leftmost (smallest) element, copying the nodes
   * along the way.
   *
   * The path is reported as not fresh because its node is an element that
   * is already in the set, matching what {@code find} returns for a present
   * value. This keeps {@code Path.add()} from inserting the element a
   * second time and lets {@code Path.replaceWith()} operate on it.
   *
   * @return the path to the first element, or {@code null} when the set is
   *         empty
   */
  public Path<T> first() {
    if (root == NullNode) return null;

    Node<T> newRoot = new Node(root);
    Cell<Node<T>> ancestors = null;

    Node<T> old = root;
    Node<T> new_ = newRoot;
    while (old.left != NullNode) {
      ancestors = new Cell(new_, ancestors);

      old = old.left;
      new_ = new_.left = new Node(old);
    }

    return new Path(false, new_,
                    new PersistentSet(newRoot, comparator, size),
                    ancestors);
  }

  /**
   * Returns a path to the rightmost (largest) element, copying the nodes
   * along the way.
   *
   * The path is reported as not fresh because its node is an element that
   * is already in the set, matching what {@code find} returns for a present
   * value. This keeps {@code Path.add()} from inserting the element a
   * second time and lets {@code Path.replaceWith()} operate on it.
   *
   * @return the path to the last element, or {@code null} when the set is
   *         empty
   */
  public Path<T> last() {
    if (root == NullNode) return null;

    Node<T> newRoot = new Node(root);
    Cell<Node<T>> ancestors = null;

    Node<T> old = root;
    Node<T> new_ = newRoot;
    while (old.right != NullNode) {
      ancestors = new Cell(new_, ancestors);

      old = old.right;
      new_ = new_.right = new Node(old);
    }

    return new Path(false, new_,
                    new PersistentSet(newRoot, comparator, size),
                    ancestors);
  }

  /**
   * Returns an iterator positioned at {@link #first()}; for an empty set
   * that position is {@code null} and the iterator has no elements.
   */
  public java.util.Iterator<T> iterator() {
    Path<T> start = first();
    return new Iterator(start);
  }

  /**
   * Advances a path to the next element in order.
   *
   * @return a non-fresh path sharing {@code p}'s root set, or {@code null}
   *         when {@code p} has no successor
   */
  private Path<T> successor(Path<T> p) {
    Cell<Node<T>> next = successor(p.node, p.ancestors);
    if (next != null) {
      return new Path(false, next.value, p.root, next.next);
    }
    return null;
  }

//   public void dump(java.io.PrintStream out) {
//     dump(root, out, 0);
//   }

//   private static void indent(java.io.PrintStream out, int level) {
//     for (int i = 0; i < level; ++i) out.print("  ");
//   }

//   private static <T> void dump(Node<T> n, java.io.PrintStream out, int level) {
//     indent(out, level);
//     out.print(n == NullNode ? null : n.value);
//     out.println(n == NullNode ? "" : n.red ? " (red)" : " (black)");
//     if (n.left != NullNode || n.right != NullNode) {
//       dump(n.left, out, level + 1);
//       dump(n.right, out, level + 1);
//     }
//   }

//   private static int[] randomSet(java.util.Random r, int size) {
//     int[] data = new int[size];
//     for (int i = size - 1; i >= 0; --i) {
//       data[i] = i + 1;
//     }

//     for (int i = size - 1; i >= 0; --i) {
//       int n = r.nextInt(size);
//       int tmp = data[i];
//       data[i] = data[n];
//       data[n] = tmp;
//     }

//     return data;
//   }

//   public static void main(String[] args) {
//     java.util.Random r = new java.util.Random(Integer.parseInt(args[0]));
//     int size = 18;
//     PersistentSet<Integer>[] sets = new PersistentSet[size];
//     PersistentSet<Integer> s = new PersistentSet();

//     int[] data = randomSet(r, size);

//     for (int i = 0; i < size; ++i) {
//       System.out.println("-- add " + data[i] + " -- ");
//       sets[i] = s = s.add(data[i]);
//       dump(s.root, System.out, 0);
//     }

//     System.out.println("\npersistence:\n");
//     for (int i = 0; i < size; ++i) {
//       dump(sets[i].root, System.out, 0);
//       System.out.println("--");
//     }

//     data = randomSet(r, size);

//     System.out.println("\nremoval:\n");
//     for (int i = 0; i < size; ++i) {
//       System.out.println("-- remove " + data[i] + " -- ");
//       sets[i] = s = s.remove(data[i]);
//       dump(s.root, System.out, 0);
//     }

//     System.out.println("\npersistence:\n");
//     for (int i = 0; i < size; ++i) {
//       dump(sets[i].root, System.out, 0);
//       System.out.println("--");
//     }
//   }

  /**
   * Tree node: a value, two child links and the red/black colour bit.
   */
  private static class Node <T> {
    public T value;
    public Node left;
    public Node right;
    public boolean red;

    /**
     * Creates a shallow copy of {@code basis} (same value, same child
     * references, same colour). A {@code null} basis leaves every field at
     * its default.
     */
    public Node(Node<T> basis) {
      if (basis == null) {
        return;
      }

      this.red = basis.red;
      this.right = basis.right;
      this.left = basis.left;
      this.value = basis.value;
    }
  }

  /**
   * A position inside a set: the node at that position, the list of its
   * ancestors (nearest first) and the set the position belongs to.
   */
  public static class Path <T> {
    private final boolean fresh;
    private final Node<T> node;
    private final PersistentSet<T> root;
    private final Cell<Node<T>> ancestors;

    /**
     * @param fresh flag reported by {@link #fresh()}; {@code find} passes
     *        {@code true} when the value was not found
     * @param node node at this position
     * @param root set this path belongs to
     * @param ancestors ancestors of {@code node}, nearest first
     */
    public Path(boolean fresh, Node<T> node, PersistentSet<T> root,
                Cell<Node<T>> ancestors)
    {
      this.ancestors = ancestors;
      this.root = root;
      this.node = node;
      this.fresh = fresh;
    }

    /** Returns the value stored in this path's node. */
    public T value() {
      return this.node.value;
    }

    /** Returns the fresh flag this path was created with. */
    public boolean fresh() {
      return this.fresh;
    }

    /** Returns the set this path belongs to. */
    public PersistentSet<T> root() {
      return this.root;
    }

    /** Returns the path to the next element, or null when there is none. */
    public Path<T> successor() {
      final PersistentSet<T> owner = this.root;
      return owner.successor(this);
    }

    /** Returns a set with this path's node removed. */
    public PersistentSet<T> remove() {
      final PersistentSet<T> owner = this.root;
      return owner.remove(this);
    }

    /**
     * Returns a set with this path's node inserted.
     *
     * @throws IllegalStateException if this path is not fresh
     */
    public PersistentSet<T> add() {
      if (fresh) {
        return root.add(this);
      }

      throw new IllegalStateException();
    }

    /**
     * Overwrites the value in this path's node and returns the path's set.
     *
     * @throws IllegalStateException if this path is fresh
     * @throws IllegalArgumentException if {@code value} does not compare
     *         equal to the value currently stored
     */
    public PersistentSet<T> replaceWith(T value) {
      if (fresh) {
        throw new IllegalStateException();
      }

      int order = root.comparator.compare(node.value, value);
      if (order != 0) {
        throw new IllegalArgumentException();
      }

      node.value = value;
      return root;
    }
  }
  
  /**
   * Iterator that starts at a given path and advances with
   * {@code Path.successor()}; a {@code null} path marks the end.
   * Removal through the iterator is not supported.
   */
  public class Iterator <T> implements java.util.Iterator <T> {
    // Position of the element next() will return; null once exhausted.
    private PersistentSet.Path<T> path;

    private Iterator(PersistentSet.Path<T> path) {
      this.path = path;
    }

    // Starts a second iterator at the current position of another one.
    private Iterator(Iterator<T> start) {
      path = start.path;
    }

    /** Returns whether another element is available. */
    public boolean hasNext() {
      return path != null;
    }

    /**
     * Returns the current element and advances to its successor.
     *
     * @throws java.util.NoSuchElementException if the iteration has no more
     *         elements
     */
    public T next() {
      if (path == null) {
        // Previously this dereferenced null and surfaced as a
        // NullPointerException.
        throw new java.util.NoSuchElementException();
      }

      PersistentSet.Path<T> p = path;
      path = path.successor();
      return p.value();
    }

    /**
     * Always fails; this iterator does not support removal.
     *
     * @throws UnsupportedOperationException always
     */
    public void remove() {
      throw new UnsupportedOperationException();
    }
  }
}