Reading the AVL tree page on Wikipedia motivated me enough to write an implementation in Haskell.
Several AVL tree implementations already exist: a package uploaded to Hackage, a gist snippet, and a polymorphic stanamically balanced AVL tree. Still, as a classic data structure and algorithm exercise, I decided to write a simple implementation of my own.
> {-# LANGUAGE BangPatterns #-}
> module AVL where
The following imports are only for benchmarking and for comparison with Data.Map from the containers package. None of these modules are used in the code implementing the AVL tree.
> import Control.DeepSeq (NFData(..), deepseq)
> import Criterion.Main
> import System.Random
> import qualified Data.Map as M
Like other binary trees, the AVL tree has a leaf constructor and a node constructor with left and right branches, plus a field holding the height of the node:
> data AVL a
>   = Node {-# UNPACK #-} !Int !(AVL a) !a !(AVL a)
>   | Leaf
>   deriving (Eq, Show)
The comments in Data.Map.Base include a note about the order of constructors:
When type has 2 constructors, a forward conditional jump is made
when successfully matching second constructor, in GHC 7.0.
This was still true in GHC 7.6.1, the version used at the time of
writing. This is why the Node constructor comes before the Leaf
constructor.
An alias for the Leaf node:
> empty :: AVL a
> empty = Leaf
> {-# INLINEABLE empty #-}
Height of a tree. The height of a Leaf node is defined as 0.
> height :: AVL a -> Int
> height t = case t of
>   Leaf -> 0
>   Node !n _ _ _ -> n
> {-# INLINE height #-}
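To make the role of the stored height more concrete, here is a minimal standalone sketch of how the balance of a node can be read off from it. It repeats the data type so that it runs on its own; the names `node`, `balanceFactor` and `example` are hypothetical helpers for illustration and are not part of the module in this post.

```haskell
-- Standalone sketch: the data type mirrors the one defined in this post.
data AVL a
  = Node !Int !(AVL a) !a !(AVL a)
  | Leaf
  deriving (Eq, Show)

-- Stored height, with Leaf counted as 0 (same convention as the post).
height :: AVL a -> Int
height Leaf           = 0
height (Node h _ _ _) = h

-- Hypothetical smart constructor: derives the height field from the children.
node :: AVL a -> a -> AVL a -> AVL a
node l x r = Node (1 + max (height l) (height r)) l x r

-- Hypothetical helper: height of the right subtree minus height of the left.
-- In an AVL tree this stays within -1..1 for every node.
balanceFactor :: AVL a -> Int
balanceFactor Leaf           = 0
balanceFactor (Node _ l _ r) = height r - height l

-- A right-leaning chain 1 -> 2 -> 3, which violates the AVL invariant.
example :: AVL Int
example = node Leaf 1 (node Leaf 2 (node Leaf 3 Leaf))
```

Here `height example` is 3 and `balanceFactor example` is 2: the right side is two levels taller than the left, which is the kind of situation a rotation is meant to repair.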
Insert an element into the AVL tree. This function calls rebalance after
inserting the new element. The new element is also evaluated strictly
inside the local function go.
> insert :: Ord a => a -> AVL a -> AVL a
> insert = go where
>   go :: Ord a => a -> AVL a -> AVL a
>   go !n Leaf = Node 1 Leaf n Leaf
>   go !n (Node h l !n' r) = case compare n' n of
>     LT -> rebalance $ Node h l n' (insert n r)
>     _  -> rebalance $ Node h (insert n l) n' r
> {-# INLINEABLE insert #-}
A function to check whether a given element is a member of the tree. Again, the given element and the element of the pattern-matched node are evaluated strictly.
> member :: (Ord a, Eq a) => a -> AVL a -> Bool
> member _ Leaf = False
> member !x (Node _ l !y r) = case compare x y of
>   LT -> member x l
>   GT -> member x r
>   EQ -> True
> {-# INLINEABLE member #-}
There could be more functions, such as delete and merge, but I'm a little too lazy to write them in this post.
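As one example of such an extra function, here is a minimal standalone sketch of an in-order traversal. The name `toList` and the hand-built `sample` tree are hypothetical and not part of the module in this post; the data type is repeated only so that the snippet runs on its own.

```haskell
-- Standalone sketch: the data type mirrors the one defined in this post.
data AVL a
  = Node !Int !(AVL a) !a !(AVL a)
  | Leaf

-- Hypothetical helper: in-order traversal. For a valid search tree the
-- elements come out in ascending order. An accumulator avoids repeated (++).
toList :: AVL a -> [a]
toList t = go t []
  where
    go Leaf           acc = acc
    go (Node _ l x r) acc = go l (x : go r acc)

-- A small hand-built tree with 2 at the root.
sample :: AVL Int
sample = Node 2 (Node 1 Leaf 1 Leaf) 2 (Node 1 Leaf 3 Leaf)
```

With this, `toList sample` gives `[1,2,3]`. A traversal like this is also handy for testing, since the output of a correct search tree should always be sorted.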
Here comes the balancing function. It rebalances the tree, considering four cases: right-right, right-left, left-left, and left-right.
First, a Leaf node is matched; it needs no rebalancing. Then the
heights of the left and right subtrees are compared, leading to the
next case.
> rebalance :: AVL a -> AVL a
> rebalance Leaf = Leaf
> rebalance n1@(Node _ l1 x1 r1) = case compare hL1 hR1 of
>     LT -> rightIsHeavy
>     GT -> leftIsHeavy
>     EQ -> updateHeight n1
>   where
>     hL1 = height l1
>     hR1 = height r1
When the right node is heavy, the heights of its own left and right children are compared in turn, and the rotations are performed accordingly.
>     rightIsHeavy = case r1 of
>       Leaf -> n1
>       Node _ l2 x2 r2 ->
>         case compare (height l2) (height r2) of
The right-right case: a single rotation that moves the middle node to the top and the top node to the left:
>           LT -> Node (hL1+2) (Node (hL1+1) l1 x1 l2) x2 r2
The right-left case: the bottom element moves to the top, the top element to the left, and the middle element to the right, with the hanging subtrees reattached in the appropriate order.
>           GT -> case l2 of
>             Leaf -> n1
>             Node h3 l3 x3 r3 ->
>               Node (h3+1) (Node h3 l1 x1 l3) x3 (Node h3 r3 x2 r2)
Nothing is left to do when both children already have the same height, so the given node is returned as is.
>           EQ -> n1
The left-left and left-right cases are symmetrical to the above:
>     leftIsHeavy = case l1 of
>       Leaf -> n1
>       Node _ l2 x2 r2 ->
>         case compare (height l2) (height r2) of
>           LT -> case r2 of
>             Leaf -> n1
>             Node h3 l3 x3 r3 ->
>               Node (h3+1) (Node h3 l2 x2 l3) x3 (Node h3 r3 x1 r1)
>           GT -> Node (hR1+2) l2 x2 (Node (hR1+1) r2 x1 r1)
>           EQ -> n1
> {-# INLINE rebalance #-}
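To see the right-right case in isolation, here is a minimal standalone sketch of a single left rotation on the smallest tree that needs one. Unlike rebalance above, which writes the new heights directly, this sketch recomputes them through a smart constructor. The names `node`, `rotateLeft`, `chain` and `rotated` are hypothetical and not part of the module in this post.

```haskell
-- Standalone sketch: the data type mirrors the one defined in this post.
data AVL a
  = Node !Int !(AVL a) !a !(AVL a)
  | Leaf
  deriving (Eq, Show)

height :: AVL a -> Int
height Leaf           = 0
height (Node h _ _ _) = h

-- Hypothetical smart constructor: derives the height field from the children.
node :: AVL a -> a -> AVL a -> AVL a
node l x r = Node (1 + max (height l) (height r)) l x r

-- Hypothetical helper: single left rotation for the right-right case.
-- The right child becomes the new root, the old root becomes its left
-- child, and the right child's former left subtree (l2) is reattached
-- as the old root's right subtree.
rotateLeft :: AVL a -> AVL a
rotateLeft (Node _ l1 x1 (Node _ l2 x2 r2)) = node (node l1 x1 l2) x2 r2
rotateLeft t                                = t

-- A right-leaning chain 1 -> 2 -> 3.
chain :: AVL Int
chain = node Leaf 1 (node Leaf 2 (node Leaf 3 Leaf))

-- After the rotation, 2 sits at the root with 1 and 3 as children.
rotated :: AVL Int
rotated = rotateLeft chain
```

The result `rotated` equals `Node 2 (Node 1 Leaf 1 Leaf) 2 (Node 1 Leaf 3 Leaf)`, which has the same shape and heights as the expression `Node (hL1+2) (Node (hL1+1) l1 x1 l2) x2 r2` with `hL1 = 0`.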
A non-recursive height-updating function, used in the EQ case of rebalance:
> updateHeight :: AVL a -> AVL a
> updateHeight t = case t of
>   Leaf -> Leaf
>   Node _ Leaf n Leaf -> Node 1 Leaf n Leaf
>   Node _ l@(Node h _ _ _) n Leaf -> Node (h+1) l n Leaf
>   Node _ Leaf n r@(Node h _ _ _) -> Node (h+1) Leaf n r
>   Node _ l@(Node hl _ _ _) n r@(Node hr _ _ _) -> Node h' l n r where
>     h' | hl < hr = hr + 1
>        | otherwise = hl + 1
> {-# INLINE updateHeight #-}
That's all the AVL tree needs for the insert and member functions. For
testing, a function to check the balance:
> isBalanced :: AVL a -> Bool
> isBalanced t = case t of
>   Leaf -> True
>   Node h l _ r ->
>     abs (h - height l) <= 1 && abs (h - height r) <= 1 &&
>     isBalanced l && isBalanced r
> {-# INLINE isBalanced #-}
Simple check:
ghci> isBalanced $ foldr insert empty [1..1024]
True
ghci> isBalanced $ foldr insert empty [1024,1023..1]
True
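Note that isBalanced compares each node's stored height with the stored heights of its children, so it trusts the height fields to be correct; branches of rebalance that return a node unchanged also leave its stored height unchanged. As a stricter complement, here is a minimal standalone sketch of a check that recomputes every height from scratch and compares the two subtrees of each node against each other. The names `checkedHeight`, `isValidAVL` and the three sample trees are hypothetical and not part of the module in this post, and the sketch does not verify the ordering of elements.

```haskell
-- Standalone sketch: the data type mirrors the one defined in this post.
data AVL a
  = Node !Int !(AVL a) !a !(AVL a)
  | Leaf

-- Hypothetical helper: returns Just the recomputed height when every node
-- below is balanced and carries a correct height field, Nothing otherwise.
checkedHeight :: AVL a -> Maybe Int
checkedHeight Leaf = Just 0
checkedHeight (Node h l _ r) = do
  hl <- checkedHeight l
  hr <- checkedHeight r
  let h' = 1 + max hl hr
  if abs (hl - hr) <= 1 && h == h' then Just h' else Nothing

-- Hypothetical helper: True when heights and balance hold for the whole tree.
isValidAVL :: AVL a -> Bool
isValidAVL t = case checkedHeight t of
  Just _  -> True
  Nothing -> False

-- Balanced, with correct height fields.
good :: AVL Char
good = Node 2 (Node 1 Leaf 'a' Leaf) 'b' Leaf

-- Same shape, but the root's height field is stale (1 instead of 2).
staleHeight :: AVL Char
staleHeight = Node 1 (Node 1 Leaf 'a' Leaf) 'b' Leaf

-- Correct height fields, but the right side is two levels taller.
lopsided :: AVL Char
lopsided = Node 3 Leaf 'a' (Node 2 Leaf 'b' (Node 1 Leaf 'c' Leaf))
```

Here `isValidAVL good` is `True`, while `isValidAVL staleHeight` and `isValidAVL lopsided` are both `False`. Running a check along these lines on trees built with insert would be a more demanding test than the one shown above.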
Benchmarks:
> instance NFData a => NFData (AVL a) where
>   rnf Leaf = ()
>   rnf (Node h l x r) = rnf h `seq` rnf l `seq` rnf x `seq` rnf r
>
> avlBenches :: [Benchmark]
> avlBenches =
>   let tn n = foldr insert empty [0..n-1::Int]
>       insertAVL !k =
>         let !x = let x' = tn k in x' `deepseq` x'
>         in bench ("n=" ++ show k) (whnfIO $ insertRand x (0,k-1))
>       insertRand t (a,b) = do
>         x <- getStdRandom (randomR (a,b))
>         let y = insert x t
>         return $! y
>       memberAVL k =
>         let x = tn k
>         in x `deepseq` bench ("n=" ++ show k) (whnfIO $ memberRand x (0,k-1))
>       memberRand t (a,b) = do
>         x <- getStdRandom (randomR (a,b))
>         let y = member x t
>         return $! y
>       mn n = let xs = [0..n-1::Int] in M.fromList $ zip xs (repeat ())
>       insertMap !k =
>         let !x = let x' = mn k in x' `deepseq` x'
>         in bench ("n=" ++ show k) (whnfIO $ insertRandM x (0,k-1))
>       insertRandM m (a,b) = do
>         x <- getStdRandom (randomR (a,b))
>         let y = M.insert x () m
>         return $! y
>       memberMap k =
>         let !x = let x' = mn k in x' `deepseq` x'
>         in bench ("n=" ++ show k) (whnfIO $ memberRandM x (0,k-1))
>       memberRandM m (a,b) = do
>         x <- getStdRandom (randomR (a,b))
>         let y = M.member x m
>         return $! y
>       benchmarks =
>         [ bgroup "AVL"
>           [ bgroup "insert" [insertAVL (2^k) | k <- [10..14::Int]]
>           , bgroup "member" [memberAVL (2^k) | k <- [10..14::Int]]
>           ]
>         , bgroup "Data.Map"
>           [ bgroup "insert" [insertMap (2^k) | k <- [10..14::Int]]
>           , bgroup "member" [memberMap (2^k) | k <- [10..14::Int]]
>           ]
>         ]
>   in benchmarks
Using the above benchmarks as main:
> main :: IO ()
> main = defaultMain avlBenches
Compile, run the benchmark, and view the results in the HTML report:
$ ghc -O2 -fllvm AVL.lhs -main-is AVL -o AVL
$ ./AVL -o bench.html
Report is here.
Lessons learned: even a simple implementation like the one above can
yield a data structure with performance close to that of the standard
packages. The benchmark shows that AVL tree insertion was slightly
slower than insertion in Data.Map, while the performance of member
lookup was almost identical. That said, when there is no need to write
your own implementation, just use one from a proven package; in most
cases the provided data structures have more features. Data.Map.Map,
for example, has useful functions like insertWith, unionWith, etc.
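By the way, during benchmarking I was using nfIO instead of whnfIO for a
while, which resulted in insertion times that increased linearly with
the size of the tree. Presumably this is because nfIO fully evaluates
the returned tree on every run, traversing all of its nodes, whereas
whnfIO stops at the outermost constructor.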