Below is my implementation of a persistent red black tree in Rust.
I have a few questions about potential improvements. Currently the data and nodes are stored in referenced counted pointers. Is this the best way to do it?
On a similar note, the pattern matching statements for the balance functions are quite verbose because of the use of rc. Is there a way to write it more concisely?
Also, because I use rc pointers, should I implement the drop trait?
#[allow(clippy::module_inception)] pub mod red_black_tree { use std::cmp::Ordering; use std::rc::Rc; pub enum RedBlackTree<T> where T: Ord, { Node { color: Color, data: Rc<T>, left: Rc<RedBlackTree<T>>, right: Rc<RedBlackTree<T>>, }, Leaf, } #[derive(Clone)] pub enum Color { Red, Black, } impl<T> RedBlackTree<T> where T: Ord, { pub fn new() -> Self { RedBlackTree::Leaf } pub fn contains(&self, item: T) -> bool { match self { RedBlackTree::Node { color: _, data, left, right, } => match item.cmp(&data) { Ordering::Less => left.contains(item), Ordering::Equal => true, Ordering::Greater => right.contains(item), }, RedBlackTree::Leaf => false, } } pub fn insert(&self, item: T) -> Self { match self { RedBlackTree::Node { color, data, left, right, } => match item.cmp(&data) { Ordering::Less => RedBlackTree::Node { color: color.clone(), data: Rc::clone(data), left: Rc::new(left.insert(item)), right: Rc::clone(right), } .balance() .make_black(), Ordering::Equal => RedBlackTree::Node { color: color.clone(), data: Rc::clone(data), left: Rc::clone(left), right: Rc::clone(right), } .make_black(), Ordering::Greater => RedBlackTree::Node { color: color.clone(), data: Rc::clone(data), left: Rc::clone(left), right: Rc::new(right.insert(item)), } .balance() .make_black(), }, RedBlackTree::Leaf => RedBlackTree::Node { color: Color::Black, data: Rc::new(item), left: Rc::new(RedBlackTree::new()), right: Rc::new(RedBlackTree::new()), }, } } pub fn get(&self, item: &T) -> Option<Rc<T>> { match self { RedBlackTree::Node { color: _, data, left, right, } => match item.cmp(&data) { Ordering::Less => left.get(item), Ordering::Equal => Option::from(Rc::clone(data)), Ordering::Greater => right.get(item), }, RedBlackTree::Leaf => Option::None, } } fn make_black(&self) -> Self { match self { RedBlackTree::Node { color: _, data, left, right, } => RedBlackTree::Node { color: Color::Black, data: Rc::clone(data), left: Rc::clone(left), right: Rc::clone(right), }, RedBlackTree::Leaf => RedBlackTree::new(), } } fn balance(self) -> Self { if let RedBlackTree::Node { color: Color::Black, data: parent_data, left: parent_left, right: parent_right, } = &self { if let RedBlackTree::Node { color: Color::Red, data: child_data, left: child_left, right: child_right, } = Rc::as_ref(&parent_left) { if let RedBlackTree::Node { color: Color::Red, data: grandchild_data, left: grandchild_left, right: grandchild_right, } = Rc::as_ref(&child_left) { return RedBlackTree::from( grandchild_left, grandchild_right, child_right, parent_right, grandchild_data, child_data, parent_data, ); } else if let RedBlackTree::Node { color: Color::Red, data: grandchild_data, left: grandchild_left, right: grandchild_right, } = Rc::as_ref(&child_right) { return RedBlackTree::from( child_left, grandchild_left, grandchild_right, parent_right, child_data, grandchild_data, parent_data, ); } } else if let RedBlackTree::Node { color: Color::Red, data: child_data, left: child_left, right: child_right, } = Rc::as_ref(&parent_right) { if let RedBlackTree::Node { color: Color::Red, data: grandchild_data, left: grandchild_left, right: grandchild_right, } = Rc::as_ref(&child_left) { return RedBlackTree::from( parent_left, grandchild_left, grandchild_right, child_right, parent_data, grandchild_data, child_data, ); } else if let RedBlackTree::Node { color: Color::Red, data: grandchild_data, left: grandchild_left, right: grandchild_right, } = Rc::as_ref(&child_right) { return RedBlackTree::from( parent_left, child_left, grandchild_left, grandchild_right, parent_data, child_data, grandchild_data, ); } } } self.clone() } #[allow(clippy::many_single_char_names)] fn from( a: &Rc<RedBlackTree<T>>, b: &Rc<RedBlackTree<T>>, c: &Rc<RedBlackTree<T>>, d: &Rc<RedBlackTree<T>>, x: &Rc<T>, y: &Rc<T>, z: &Rc<T>, ) -> RedBlackTree<T> { RedBlackTree::Node { color: Color::Red, data: Rc::clone(y), left: Rc::new(RedBlackTree::Node { color: Color::Black, data: Rc::clone(x), left: Rc::clone(a), right: Rc::clone(b), }), right: Rc::new(RedBlackTree::Node { color: Color::Black, data: Rc::clone(z), left: Rc::clone(c), right: Rc::clone(d), }), } } } impl<T> Clone for RedBlackTree<T> where T: Ord, { fn clone(&self) -> Self { match self { RedBlackTree::Node { color, data, left, right, } => RedBlackTree::Node { color: color.clone(), data: Rc::clone(data), left: Rc::clone(left), right: Rc::clone(right), }, RedBlackTree::Leaf => RedBlackTree::new(), } } } impl<T> Default for RedBlackTree<T> where T: Ord, { fn default() -> Self { RedBlackTree::<T>::new() } } } #[cfg(test)] mod tests { use super::red_black_tree::*; use std::cmp::*; struct Point { x: i64, y: i64, } impl Point { fn new(x: i64, y: i64) -> Point { Point { x, y } } fn magnitude_squared(&self) -> u64 { (self.x as u64).pow(2) + (self.y as u64).pow(2) } } impl PartialEq for Point { fn eq(&self, other: &Self) -> bool { self.magnitude_squared() == other.magnitude_squared() } } impl Eq for Point {} impl PartialOrd for Point { fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) } } impl Ord for Point { fn cmp(&self, other: &Self) -> Ordering { self.magnitude_squared().cmp(&other.magnitude_squared()) } } #[test] fn test_empty() { let tree = RedBlackTree::<i64>::new(); assert!(!tree.contains(0)); assert!(!tree.contains(5)); assert!(!tree.contains(-20)); } #[test] fn test_insert() { let mut tree: RedBlackTree<char> = RedBlackTree::new(); assert!(!tree.contains('a')); assert!(!tree.contains('b')); assert!(!tree.contains('c')); tree = tree.insert('a'); assert!(tree.contains('a')); assert!(!tree.contains('b')); assert!(!tree.contains('c')); tree = tree.insert('b'); assert!(tree.contains('a')); assert!(tree.contains('b')); assert!(!tree.contains('c')); } #[test] fn test_get() { let mut tree: RedBlackTree<Point> = RedBlackTree::new(); tree = tree.insert(Point::new(0, 0)); tree = tree.insert(Point::new(1, 1)); tree = tree.insert(Point::new(2, 2)); tree = tree.insert(Point::new(3, 4)); assert_eq!(tree.get(&Point::new(0, 0)).unwrap().x, 0); assert_eq!(tree.get(&Point::new(0, 0)).unwrap().y, 0); assert_eq!(tree.get(&Point::new(0, 5)).unwrap().x, 3); assert_eq!(tree.get(&Point::new(0, 5)).unwrap().y, 4); } }
Edit: the rebalancing function relies is based on this diagram (source):
#[allow(clippy::module_inception)]
? Generally, when a lint is triggered, we fix the code rather than disable the lint. The same goes for#[allow(clippy::many_single_char_names)]
. I'm scared atfrom(a, b, c, d, x, y, z)
.\$\endgroup\$from(a, b, c, d, x, y, z)
function, I was following various guides that all used that naming scheme. I've attached at image explaining it\$\endgroup\$red_black_tree::red_black_tree::RedBlackTree
. The correct way is to simply remove themod
declaration (the file is automatically a module).\$\endgroup\$