8
\$\begingroup\$

I'm learning Rust by implementing basic data structures and algorithms. I implemented a binary heap (max heap):

mod binary_heap { #[derive(Debug)] pub struct MaxHeap<T> { pub data: Vec<T>, } impl<T> MaxHeap<T> where T: PartialOrd, { pub fn new() -> MaxHeap<T> { MaxHeap { data: vec![] } } pub fn push(&mut self, value: T) { self.data.push(value); let new_node_index: usize = self.data.len() - 1; self.sift_up(new_node_index); } pub fn pop(&mut self) -> Option<T> { match self.data.len() { 0 => None, _ => { let deleted_node = self.data.swap_remove(0); self.sift_down(); Some(deleted_node) } } } fn sift_up(&mut self, mut new_node_index: usize) { while !self.is_root(new_node_index) && self.is_greater_than_parent(new_node_index) { let parent_index = self.parent_index(new_node_index); self.data.swap(parent_index, new_node_index); new_node_index = self.parent_index(new_node_index); } } fn is_root(&self, node_index: usize) -> bool { node_index == 0 } fn is_greater_than_parent(&self, node_index: usize) -> bool { let parent_index = self.parent_index(node_index); self.data[node_index] > self.data[parent_index] } fn sift_down(&mut self) { let mut sifted_down_node_index: usize = 0; while self.has_greater_child(sifted_down_node_index) { let larger_child_index = self.calculate_larger_child_index(sifted_down_node_index); self.data.swap(sifted_down_node_index, larger_child_index); sifted_down_node_index = larger_child_index; } } fn left_child_index(&self, index: usize) -> usize { (index * 2) + 1 } fn right_child_index(&self, index: usize) -> usize { (index * 2) + 2 } fn parent_index(&self, index: usize) -> usize { (index - 1) / 2 } fn has_greater_child(&self, index: usize) -> bool { let left_child_index: usize = self.left_child_index(index); let right_child_index: usize = self.right_child_index(index); self.data.get(left_child_index).is_some() && self.data[left_child_index] > self.data[index] || self.data.get(right_child_index).is_some() && self.data[right_child_index] > self.data[index] } fn calculate_larger_child_index(&self, index: usize) -> usize { let left_child_index: usize = self.left_child_index(index); let right_child_index: usize = self.right_child_index(index); let left_child = self.data.get(left_child_index); let right_child = self.data.get(right_child_index); if ((right_child.is_some() && left_child.is_some()) && right_child > left_child) || left_child.is_none() { return right_child_index; } else { return left_child_index; } } } #[cfg(test)] mod test { use super::*; #[test] fn push_test() { let mut heap: MaxHeap<i32> = MaxHeap::new(); heap.push(3); heap.push(2); assert_eq!(vec![3, 2], heap.data); } #[test] fn root_node_is_always_the_biggest_element_in_heap_after_push_test() { let mut heap: MaxHeap<i32> = MaxHeap::new(); heap.push(5); heap.push(10); heap.push(2); assert_eq!(10, heap.data[0]); heap.push(3); assert_eq!(10, heap.data[0]); heap.push(20); assert_eq!(20, heap.data[0]); } #[test] fn pop_always_pop_the_root_node() { let mut heap: MaxHeap<i32> = MaxHeap::new(); heap.push(10); heap.push(4); heap.push(7); heap.pop(); assert!(!heap.data.contains(&10)); } #[test] fn pop_returns_a_variant_of_the_option_enum() { let mut heap: MaxHeap<i32> = MaxHeap::new(); heap.push(10); assert_eq!(Some(10), heap.pop()); assert_eq!(None, heap.pop()); } #[test] fn root_node_is_always_the_biggest_element_in_heap_after_pop_test() { let mut heap: MaxHeap<i32> = MaxHeap::new(); heap.push(5); heap.push(10); heap.push(2); heap.pop(); assert_eq!(5, heap.data[0]); heap.pop(); assert_eq!(2, heap.data[0]); } #[test] fn heap_is_generic_over_some_type_t() { let mut heap: MaxHeap<(i32, String)> = MaxHeap::new(); heap.push((2, String::from("Zanzibar"))); heap.push((10, String::from("Porto"))); heap.push((5, String::from("Beijing"))); let element = heap.pop(); assert_eq!(Some((10, String::from("Porto"))), element); assert_eq!((5, String::from("Beijing")), heap.data[0]); } } } 

Also I'd like to implement a min heap, without duplicating much code and without the need for the client code to always wrap any future elements in a Reverse struct.

What can be improved here? Any feedback is much appreciated!

Thanks!

\$\endgroup\$

    1 Answer 1

    2
    \$\begingroup\$

    The overall implementation looks good!

    Here are some of my suggestions:

    • The field data can be private.
    • A as_vec method can be implemented to return a reference to data.
    • A peek method can be implemented to return a reference to the top element.
    • The constraint of the items in the max heap should be Ord rather than PartialOrd.
    • There is no need to explicitly specify the types. For example the usize from this can be removed: let new_node_index: usize = ....
    • I would use if rather than match in the pop method.
    • The two returns in calculate_larger_child_index can be removed to keep it consistent with the other methods.

    Code with changes:

    pub mod binary_heap { #[derive(Debug)] pub struct MaxHeap<T> { data: Vec<T>, } impl<T> MaxHeap<T> where T: Ord, { pub fn new() -> MaxHeap<T> { MaxHeap { data: vec![] } } pub fn peek(&self) -> &T { &self.data[0] } pub fn as_vec(&self) -> &Vec<T> { &self.data } pub fn push(&mut self, value: T) { self.data.push(value); let new_node_index = self.data.len() - 1; self.sift_up(new_node_index); } pub fn pop(&mut self) -> Option<T> { if !self.data.is_empty() { let deleted_node = self.data.swap_remove(0); self.sift_down(); Some(deleted_node) } else { None } } fn sift_up(&mut self, mut new_node_index: usize) { while !self.is_root(new_node_index) && self.is_greater_than_parent(new_node_index) { let parent_index = self.parent_index(new_node_index); self.data.swap(parent_index, new_node_index); new_node_index = self.parent_index(new_node_index); } } fn is_root(&self, node_index: usize) -> bool { node_index == 0 } fn is_greater_than_parent(&self, node_index: usize) -> bool { let parent_index = self.parent_index(node_index); self.data[node_index] > self.data[parent_index] } fn sift_down(&mut self) { let mut sifted_down_node_index: usize = 0; while self.has_greater_child(sifted_down_node_index) { let larger_child_index = self.calculate_larger_child_index(sifted_down_node_index); self.data.swap(sifted_down_node_index, larger_child_index); sifted_down_node_index = larger_child_index; } } fn left_child_index(&self, index: usize) -> usize { (index * 2) + 1 } fn right_child_index(&self, index: usize) -> usize { (index * 2) + 2 } fn parent_index(&self, index: usize) -> usize { (index - 1) / 2 } fn has_greater_child(&self, index: usize) -> bool { let left_child_index = self.left_child_index(index); let right_child_index = self.right_child_index(index); self.data.get(left_child_index).is_some() && self.data[left_child_index] > self.data[index] || self.data.get(right_child_index).is_some() && self.data[right_child_index] > self.data[index] } fn calculate_larger_child_index(&self, index: usize) -> usize { let left_child_index = self.left_child_index(index); let right_child_index = self.right_child_index(index); let left_child = self.data.get(left_child_index); let right_child = self.data.get(right_child_index); if ((right_child.is_some() && left_child.is_some()) && right_child > left_child) || left_child.is_none() { right_child_index } else { left_child_index } } } #[cfg(test)] mod test { use super::*; #[test] fn push_test() { let mut heap: MaxHeap<i32> = MaxHeap::new(); heap.push(3); heap.push(2); assert_eq!(&vec![3, 2], heap.as_vec()); } #[test] fn root_node_is_always_the_biggest_element_in_heap_after_push_test() { let mut heap: MaxHeap<i32> = MaxHeap::new(); heap.push(5); heap.push(10); heap.push(2); assert_eq!(&10, heap.peek()); heap.push(3); assert_eq!(&10, heap.peek()); heap.push(20); assert_eq!(&20, heap.peek()); } #[test] fn pop_always_pop_the_root_node() { let mut heap: MaxHeap<i32> = MaxHeap::new(); heap.push(10); heap.push(4); heap.push(7); heap.pop(); assert!(!heap.as_vec().contains(&10)); } #[test] fn pop_returns_a_variant_of_the_option_enum() { let mut heap: MaxHeap<i32> = MaxHeap::new(); heap.push(10); assert_eq!(Some(10), heap.pop()); assert_eq!(None, heap.pop()); } #[test] fn root_node_is_always_the_biggest_element_in_heap_after_pop_test() { let mut heap: MaxHeap<i32> = MaxHeap::new(); heap.push(5); heap.push(10); heap.push(2); heap.pop(); assert_eq!(&5, heap.peek()); heap.pop(); assert_eq!(&2, heap.peek()); } #[test] fn heap_is_generic_over_some_type_t() { let mut heap: MaxHeap<(i32, String)> = MaxHeap::new(); heap.push((2, String::from("Zanzibar"))); heap.push((10, String::from("Porto"))); heap.push((5, String::from("Beijing"))); let element = heap.pop(); assert_eq!(Some((10, String::from("Porto"))), element); assert_eq!((5, String::from("Beijing")), heap.data[0]); } } } 

    Further possible improvements:

    • Implement the Iterator and/or IntoIterator trait for MaxHeap
    \$\endgroup\$
    2
    • \$\begingroup\$Thanks for the feedback. I implemented all of your comments in the original source code! Why should the heap use Ord rather than PartialOrd?\$\endgroup\$CommentedMay 29, 2021 at 19:00
    • 1
      \$\begingroup\$Using PartialOrd means that not all values will have defined order and using >, <, etc. operators with such values will always return false. For example both f64::NAN < 0.5 and f64::NAN > 0.5 return false, which can lead to inconsistencies in your data structure. You can still use PartialOrd as a constraint, as long as you either check the values before insertion and reject the ones that cannot be ordered or you define a custom ordering for them (for example treat them as smaller/larger than every other value).\$\endgroup\$
      – DobromirM
      CommentedMay 30, 2021 at 20:22

    Start asking to get answers

    Find the answer to your question by asking.

    Ask question

    Explore related questions

    See similar questions with these tags.