/* * Copyright 2011 Roger Kapsi * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // package org.ardverk.collection.spt; import java.io.Serializable; import java.util.AbstractCollection; import java.util.Collection; import java.util.ConcurrentModificationException; import java.util.Iterator; import java.util.Map; import java.util.NoSuchElementException; import java.util.Set; /** * A simple/lightweight implementation of a PATRICIA {@link Trie}. */ public class PatriciaTrie extends AbstractTrie implements Serializable { private static final long serialVersionUID = 7464215084236615537L; private static KeyAnalyzer DEFAULT = new KeyAnalyzer() { @Override public boolean isSet(Object key, int bitIndex) { return ((PatriciaKey)key).isBitSet(bitIndex); } @SuppressWarnings({ "rawtypes", "unchecked" }) @Override public int bitIndex(Object key, Object otherKey) { return ((PatriciaKey)key).bitIndex(otherKey); } }; private final KeyAnalyzer keyAnalyzer; private volatile RootNode root = new RootNode(); private volatile int size = 0; private transient volatile Entry[] entries = null; private transient volatile EntrySet entrySet = null; private transient volatile KeySet keySet = null; private transient volatile Values values = null; private transient volatile int modCount = 0; public PatriciaTrie() { this(DEFAULT); } public PatriciaTrie(KeyAnalyzer keyAnalyzer) { this.keyAnalyzer = keyAnalyzer; } public PatriciaTrie(Map m) { this(keyAnalyzer(m), m); } public PatriciaTrie(KeyAnalyzer keyAnalyzer, Map m) { this.keyAnalyzer = keyAnalyzer; putAll(m); } /** * Returns the {@link KeyAnalyzer}. */ public KeyAnalyzer getKeyAnalyzer() { return keyAnalyzer; } @Override public Entry select(K key) { Node entry = selectR(root.left, key, -1); if (!entry.isEmpty()) { return entry; } return null; } private Node selectR(Node h, K key, int bitIndex) { if (h.bitIndex <= bitIndex) { return h; } if (!isSet(key, h.bitIndex)) { return selectR(h.left, key, h.bitIndex); } else { return selectR(h.right, key, h.bitIndex); } } @Override public V put(K key, V value) { // This is a shortcut! The root is the only place to store null! if (key == null) { return putForNullKey(key, value); } Entry entry = select(key); K existing = null; if (entry != null) { existing = entry.getKey(); if (equals(key, existing)) { return entry.setValue(value); } } int bitIndex = bitIndex(key, existing); if (bitIndex == KeyAnalyzer.NULL_KEY) { return putForNullKey(key, value); } assert (bitIndex >= 0); root.left = putR(root.left, key, value, bitIndex, root); incrementSize(); return null; } /** * Stores the given key-value at the {@link RootNode}. */ private V putForNullKey(K key, V value) { if (root.isEmpty()) { incrementSize(); } return root.setKeyValue(key, value); } private Node putR(Node h, K key, V value, int bitIndex, Node p) { if ((h.bitIndex >= bitIndex) || (h.bitIndex <= p.bitIndex)) { Node t = new Node(key, value, bitIndex); boolean isSet = isSet(key, t.bitIndex); t.left = isSet ? h : t; t.right = isSet ? t : h; return t; } if (!isSet(key, h.bitIndex)) { h.left = putR(h.left, key, value, bitIndex, h); } else { h.right = putR(h.right, key, value, bitIndex, h); } return h; } @Override public V remove(Object key) { @SuppressWarnings("unchecked") Entry entry = entry((K)key); if (entry != null) { return removeEntry(entry); } return null; } /** * Removes the given {@link Entry} from the Trie. */ private V removeEntry(final Entry entry) { // We're traversing the old Trie and adding elements to the new Trie! RootNode old = clear0(); traverseR(old.left, new Cursor() { @Override public boolean select(Entry e) { if (!entry.equals(e)) { put(e.getKey(), e.getValue()); } return true; } }, -1); return entry.getValue(); } @Override public void select(K key, Cursor cursor) { selectR(root.left, key, cursor, -1); } private boolean selectR(Node h, K key, Cursor cursor, int bitIndex) { if (h.bitIndex <= bitIndex) { if (!h.isEmpty()) { return cursor.select(h); } return true; } if (!isSet(key, h.bitIndex)) { if (selectR(h.left, key, cursor, h.bitIndex)) { return selectR(h.right, key, cursor, h.bitIndex); } } else { if (selectR(h.right, key, cursor, h.bitIndex)) { return selectR(h.left, key, cursor, h.bitIndex); } } return false; } @Override public void traverse(Cursor cursor) { traverseR(root.left, cursor, -1); } private static boolean traverseR(Node h, Cursor cursor, int bitIndex) { if (h.bitIndex <= bitIndex) { if (!h.isEmpty()) { return cursor.select(h); } return true; } if (traverseR(h.left, cursor, h.bitIndex)) { return traverseR(h.right, cursor, h.bitIndex); } return false; } @Override public void clear() { clear0(); } @Override public int size() { return size; } @Override public Set> entrySet() { if (entrySet == null) { entrySet = new EntrySet(); } return entrySet; } @Override public Set keySet() { if (keySet == null) { keySet = new KeySet(); } return keySet; } @Override public Collection values() { if (values == null) { values = new Values(); } return values; } @Override public Entry firstEntry() { Node entry = followLeft(root.left, -1, root); if (!entry.isEmpty()) { return entry; } return null; } @Override public Entry lastEntry() { Node entry = followRight(root.left, -1); if (!entry.isEmpty()) { return entry; } return null; } private Node followLeft(Node h, int bitIndex, Node p) { if (h.bitIndex <= bitIndex) { if (!h.isEmpty()) { return h; } return p; } return followLeft(h.left, h.bitIndex, h); } private Node followRight(Node h, int bitIndex) { if (h.bitIndex <= bitIndex) { return h; } return followRight(h.right, h.bitIndex); } /** * Increments the {@link #size} counter and calls {@link #clearEntriesArray()}. */ private void incrementSize() { ++size; clearEntriesArray(); } /** * Clears the {@link PatriciaTrie} and returns the old {@link RootNode}. * The {@link RootNode} may be used to {@link #traverse(RootNode, Cursor)} * the old {@link PatriciaTrie}. * * @see #remove(Object) */ private RootNode clear0() { RootNode previous = root; root = new RootNode(); size = 0; clearEntriesArray(); return previous; } /** * Clears the {@link #entries} array. */ private void clearEntriesArray() { entries = null; ++modCount; } /** * @see KeyAnalyzer#isSet(Object, int) */ private boolean isSet(K key, int bitIndex) { return keyAnalyzer.isSet(key, bitIndex); } /** * @see KeyAnalyzer#bitIndex(Object, Object) */ private int bitIndex(K key, K otherKey) { return keyAnalyzer.bitIndex(key, otherKey); } /** * Turns the {@link PatriciaTrie} into an {@link Entry[]}. The array * is being cached for as long as the {@link PatriciaTrie} isn't being * modified. * * @see ViewIterator */ private Entry[] toArray() { if (entries == null) { @SuppressWarnings("unchecked") final Entry[] dst = new Entry[size()]; traverse(new Cursor() { private int index = 0; @Override public boolean select(Entry entry) { dst[index++] = entry; return true; } }); entries = dst; } return entries; } /** * Returns a {@link KeyAnalyzer} for the given {@link Map}. */ @SuppressWarnings({ "unchecked", "rawtypes" }) private static KeyAnalyzer keyAnalyzer(Map m) { if (m instanceof PatriciaTrie) { return ((PatriciaTrie)m).getKeyAnalyzer(); } return DEFAULT; } /** * An {@link Iterator} for {@link Entry}s. * * @see PatriciaTrie#toArray() */ private abstract class ViewIterator implements Iterator { private final Entry[] entries = toArray(); private int expectedModCount = PatriciaTrie.this.modCount; private int index = 0; private Entry current = null; @Override public boolean hasNext() { return index < entries.length; } @Override public E next() { if (!hasNext()) { throw new NoSuchElementException(); } if (expectedModCount != PatriciaTrie.this.modCount) { throw new ConcurrentModificationException(); } current = entries[index++]; return next(current); } /** * Called for each {@link Entry}. * * @see #next() */ protected abstract E next(Entry entry); @Override public void remove() { if (current == null) { throw new IllegalStateException(); } removeEntry(current); expectedModCount = PatriciaTrie.this.modCount; current = null; } } /** * An abstract base class for the various views. */ private abstract class AbstractView extends AbstractCollection { @Override public void clear() { PatriciaTrie.this.clear(); } @Override public int size() { return PatriciaTrie.this.size(); } } /** * @see PatriciaTrie#entrySet() */ private class EntrySet extends AbstractView> implements Set> { private Entry entry(Entry entry) { Entry other = PatriciaTrie.this.entry(entry.getKey()); if (other != null && other.equals(entry)) { return other; } return null; } @SuppressWarnings("unchecked") @Override public boolean contains(Object o) { if (o instanceof Entry) { return entry((Entry)o) != null; } return false; } @Override public boolean remove(Object o) { if (o instanceof Entry) { @SuppressWarnings("unchecked") Entry entry = entry((Entry)o); if (entry != null) { int size = size(); PatriciaTrie.this.removeEntry(entry); return size != size(); } } return false; } @Override public Iterator> iterator() { return new ViewIterator>() { @SuppressWarnings("unchecked") @Override protected Entry next(Entry entry) { return (Entry)entry; } }; } } /** * @see PatriciaTrie#keySet() */ private class KeySet extends AbstractView implements Set { @Override public boolean remove(Object key) { int size = size(); PatriciaTrie.this.remove(key); return size != size(); } @Override public boolean contains(Object o) { return PatriciaTrie.this.containsKey(o); } @Override public Iterator iterator() { return new ViewIterator() { @Override protected K next(Entry entry) { return entry.getKey(); } }; } } /** * @see PatriciaTrie#values() */ private class Values extends AbstractView { @Override public boolean remove(Object value) { for (Entry entry : entrySet()) { if (AbstractTrie.equals(value, entry.getValue())) { int size = size(); PatriciaTrie.this.removeEntry(entry); return size != size(); } } return false; } @Override public Iterator iterator() { return new ViewIterator() { @Override protected V next(Entry entry) { return entry.getValue(); } }; } } /** * The root node of the {@link Trie}. */ private static class RootNode extends Node { private static final long serialVersionUID = -8857149853096688620L; private boolean empty = true; public RootNode() { super(null, null, -1); this.left = this; } /** * Sets the key and value of the root node. */ public V setKeyValue(K key, V value) { this.key = key; this.empty = false; return setValue(value); } @Override public boolean isEmpty() { return empty; } } /** * A node in the {@link Trie}. */ private static class Node implements Entry, Serializable { public static final long serialVersionUID = -2409938371345117780L; public final int bitIndex; public K key; public V value; public Node left; public Node right; public Node(K key, V value, int bitIndex) { this.bitIndex = bitIndex; this.key = key; this.value = value; } /** * Returns {@code true} if the {@link Node} has no key-value. */ public boolean isEmpty() { return false; } @Override public K getKey() { return key; } @Override public V getValue() { return value; } @Override public V setValue(V value) { V existing = this.value; this.value = value; return existing; } @Override public int hashCode() { return 31 * (key != null ? key.hashCode() : 0) + (value != null ? value.hashCode() : 0); } @Override public boolean equals(Object o) { if (o == this) { return true; } else if (!(o instanceof Entry)) { return false; } Entry other = (Entry)o; return AbstractTrie.equals(key, other.getKey()) && AbstractTrie.equals(value, other.getValue()); } @Override public String toString() { return key + " (" + bitIndex + ") -> " + value; } } void printTrie() { showRecurse( root.left, -1, 0 ); System.out.println("==================================="); } public void showRecurse(Node t, int bitIndex, int h) { if (t.bitIndex <= bitIndex) return; showRecurse(t.right, t.bitIndex, h+1); printnode(t, h); showRecurse(t.left, t.bitIndex, h+1); } public void printnode(Node x, int h) { for (int i = 0; i < h; i++) System.out.print(" "); System.out.println("[" + x.key + "("+x.bitIndex+")" + "," + x.value+"]"); } }