< 返回版块

automanyang 发表于 2020-04-27 15:39

rust实现的双向list

list的功能

首先要定义list的接口,就是list有的具体功能。可以看出这个list不仅仅在头尾可以增加和删除,也可以在list中间增加删除。

trait定义:

pub trait ListInterface<T> {
    // NodePtr是list中node的指针,通过NodePtr操作node,具体定义参见后面的代码。
    type NodePtr;
    // IntoIter是一个iterator,list可以转换为IntoIter,此后list就不存在了。
    type IntoIter: Iterator;
    // Iter也是一个iteraotor,可以用来遍历list。
    type Iter: Iterator;

    fn is_empty(&self) -> bool;
    // node的数量
    fn len(&self) -> usize;
    // at是具体的位置,从0开始直到len - 1。如果at > (len - 1)返回值为None
    fn get(&self, at: usize) -> Option<Self::NodePtr>;

    fn push_head(&mut self, value: T) -> Self::NodePtr;
    fn push_tail(&mut self, value: T) -> Self::NodePtr;    
    fn pop_head(&mut self) -> Option<Self::NodePtr>;
    fn pop_tail(&mut self) -> Option<Self::NodePtr>;
    
    // 如果at > (len - 1),实际就是push_tail
    fn insert_at(&mut self, at: usize, value: T) -> Self::NodePtr;
    // 如果before不是指向本list的node,就不插入,返回值是None
    fn insert(&mut self, before: &Self::NodePtr, value: T) -> Option<Self::NodePtr>;
    // 如果at > (len - 1),就不remove,返回值为None
    fn remove_at(&mut self, at: usize) -> Option<Self::NodePtr>;
    // 如果ptr不是指向本list的node,就不remove,返回值是None
    fn remove(&mut self, ptr: &Self::NodePtr) -> Option<Self::NodePtr>;

    // 将node移动到head
    fn top(&mut self, ptr: &Self::NodePtr) -> Option<Self::NodePtr>;
    // 将node移动到tail
    fn bottom(&mut self, ptr: &Self::NodePtr) -> Option<Self::NodePtr>;

    fn into_iter(self) -> Self::IntoIter;
    fn iter(&self) -> Self::Iter;
}

数据结构

首先定义list中的node,如下:

#[cfg_attr(test, derive(Eq, PartialEq))]
pub struct Node<T> {
    pub value: T,
    id: Option<usize>,
    pre: Option<Pointer<T>>,
    next: Option<Pointer<T>>,
}
impl<T> Node<T> {
    fn new(id: usize, value: T) -> Self {
        Self {
            value,
            id: Some(id),
            pre: None,
            next: None,
        }
    }
}
impl<T: std::fmt::Debug> fmt::Debug for Node<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "({}{:?}{})",
            if self.pre.is_some() { "<" } else { "|" },
            self.value,
            if self.next.is_some() { ">" } else { "|" },
        )
    }
}

注意,Node中只有value是pub的,其他都是外界不可操作的,Node的new方法也不是pub。

其中的id解释一下,每个List都有一个唯一的id,List中Node的id与List的id一致,如果Node的id与List的id不一致,那就说明Node不是这个List产生的。

Pointer定义如下,指向Node。

#[cfg_attr(test, derive(Debug, Eq, PartialEq))]
pub struct Pointer<T>(Rc<RefCell<Node<T>>>);
impl<T> Pointer<T> {
    fn new(id: usize, value: T) -> Self {
        Self(Rc::new(RefCell::new(Node::new(id, value))))
    }
    pub fn node(&self) -> Ref<Node<T>> {
        self.0.borrow()
    }
    pub fn node_mut(&self) -> RefMut<Node<T>> {
        self.0.borrow_mut()
    }
}
impl<T> Clone for Pointer<T> {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

这是是List的定义:

pub struct List<T> {
    id: usize,
    count: usize,
    head: Option<Pointer<T>>,
    tail: Option<Pointer<T>>,
}
impl<T> List<T> {
    pub fn new() -> Self {
        static mut ID: AtomicUsize = AtomicUsize::new(1);

        Self {
            id: unsafe { ID.fetch_add(1, Ordering::SeqCst) },
            count: 0,
            head: None,
            tail: None,
        }
    }
    fn new_pointer(&mut self, value: T) -> Pointer<T> {
        self.count += 1;
        Pointer::new(self.id, value)
    }
    fn contains(&self, ptr: &Pointer<T>) -> bool {
        ptr.node()
            .id
            .map(|v| v == self.id)
            .unwrap_or(false)
    }
}
impl<T> Default for List<T> {
    fn default() -> Self { Self::new() }
}
impl<T> Drop for List<T> {
    fn drop(&mut self) {
        self.iter().for_each(|_| {
            self.pop_tail();
        });
    }
}
impl<T: std::fmt::Debug> fmt::Debug for List<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut r = write!(f, "({})-[{}]", self.id, self.count);
        let mut it = self.iter();
        while let Some(p) = it.next() {
            r = write!(f, "-{:?}", p.node());
        }
        r
    }
}

再最后,是IntoIter和Iter的定义:

pub struct IntoIter<T>(List<T>);
impl<T> Iterator for IntoIter<T> {
    type Item = Pointer<T>;

    fn next(&mut self) -> Option<Self::Item> {
        self.0.pop_head()
    }
}

pub struct Iter<T> {
    next: Option<Pointer<T>>,
}
impl<T> Iterator for Iter<T> {
    type Item = Pointer<T>;

    fn next(&mut self) -> Option<Self::Item> {
        let p: Option<Pointer<T>> = self.next.take();
        self.next = p
            .as_ref()
            .and_then(|v| v.node().next.as_ref().map(|v2| v2.clone()));
        p
    }
}

实现ListInterface trait

基本上都比较简单,代码如下:

impl<T> ListInterface<T> for List<T> {
    type NodePtr = Pointer<T>;
    type IntoIter = IntoIter<T>;
    type Iter = Iter<T>;
    
    fn is_empty(&self) -> bool {
        self.count == 0
    }
    fn len(&self) -> usize {
        self.count
    }
    fn get(&self, at: usize) -> Option<Pointer<T>> {
        self.iter().nth(at)
    }
    fn push_head(&mut self, value: T) -> Pointer<T> {
        if let Some(head) = self.head.clone() {
            self.insert(&head, value).unwrap()
        } else {
            let ptr = self.new_pointer(value);
            self.tail.replace(ptr.clone());
            self.head.replace(ptr.clone());
            ptr
        }
    }
    fn push_tail(&mut self, value: T) -> Pointer<T> {
        let ptr = self.new_pointer(value);
        if let Some(tail) = self.tail.clone() {
            tail.node_mut().next.replace(ptr.clone());
            ptr.node_mut().pre.replace(tail);
            self.tail.replace(ptr.clone());
        } else {
            self.tail.replace(ptr.clone());
            self.head.replace(ptr.clone());
        }
        ptr
    }
    fn pop_head(&mut self) -> Option<Pointer<T>> {
        self.head.clone().and_then(|v| self.remove(&v))
    }
    fn pop_tail(&mut self) -> Option<Pointer<T>> {
        self.tail.clone().and_then(|v| self.remove(&v))
    }
    fn insert_at(&mut self, at: usize, value: T) -> Pointer<T> {
        if let Some(ref p) = self.get(at) {
            self.insert(p, value).unwrap()
        } else {
            self.push_tail(value)
        }
    }
    fn insert(&mut self, at: &Pointer<T>, value: T) -> Option<Pointer<T>> {
        if !self.contains(at) {
            return None;
        }

        let ptr = self.new_pointer(value);
        if let Some(ref pre) = at.node().pre {
            pre.node_mut().next.replace(ptr.clone());
            ptr.node_mut().pre.replace(pre.clone());
        } else {
            // the at was the head, so the node will be the head now.
            self.head.replace(ptr.clone());
        }
        at.node_mut().pre.replace(ptr.clone());
        ptr.node_mut().next.replace(at.clone());
        Some(ptr)
    }
    fn remove_at(&mut self, at: usize) -> Option<Pointer<T>> {
        let p = self.get(at);
        p.as_ref().map(|v| self.remove(v));
        p
    }
    fn remove(&mut self, ptr: &Pointer<T>) -> Option<Pointer<T>> {
        if !self.contains(ptr) {
            return None;
        }

        if let Some(ref pre) = ptr.node().pre {
            // ptr不是head
            if let Some(ref next) = ptr.node().next {
                // ptr不是tail
                pre.node_mut().next.replace(next.clone());
                next.node_mut().pre.replace(pre.clone());
            } else {
                // ptr是tail,tail改为前一个
                pre.node_mut().next.take();
                self.tail.replace(pre.clone());
            }
        } else {
            // ptr is head
            if let Some(ref next) = ptr.node().next {
                // ptr is not tail
                next.node_mut().pre.take();
                self.head.replace(next.clone());
            } else {
                // node is tail
                self.head.take();
                self.tail.take();
            }
        }
        // split the node from the list
        ptr.node_mut().pre.take();
        ptr.node_mut().next.take();

        self.count -= 1;
        ptr.node_mut().id.take();
        Some(ptr.clone())
    }
    fn top(&mut self, ptr: &Pointer<T>) -> Option<Pointer<T>> {
        if !self.contains(ptr) {
            return None;
        }

        if let Some(ref pre) = ptr.node().pre {
            // ptr不是head
            if let Some(ref next) = ptr.node().next {
                // ptr不是tail
                pre.node_mut().next.replace(next.clone());
                next.node_mut().pre.replace(pre.clone());
            } else {
                // ptr是tail,tail改为前一个
                pre.node_mut().next.take();
                self.tail.replace(pre.clone());
            }
        }
        if ptr.node().pre.is_some() {
            // ptr不是head,将node放在head之前
            if let Some(ref head) = self.head {
                head.node_mut().pre.replace(ptr.clone());
            }
            ptr.node_mut().pre = None;
            ptr.node_mut().next = self.head.clone();
            // ptr改为node
            self.head.replace(ptr.clone());
        }
        Some(ptr.clone())
    }
    fn bottom(&mut self, ptr: &Pointer<T>) -> Option<Pointer<T>> {
        if !self.contains(ptr) {
            return None;
        }

        if let Some(ref next) = ptr.node().next {
            // ptr不是tail
            if let Some(ref pre) = ptr.node().pre {
                // ptr不是head
                next.node_mut().pre.replace(pre.clone());
                pre.node_mut().next.replace(next.clone());
            } else {
                // ptr是head,head改为后一个
                next.node_mut().pre.take();
                self.head.replace(next.clone());
            }
        }
        if ptr.node().next.is_some() {
            // ptr不是tail,将ptr放在tail之后
            if let Some(ref tail) = self.tail {
                tail.node_mut().next.replace(ptr.clone());
            }
            ptr.node_mut().next = None;
            ptr.node_mut().pre = self.tail.clone();
            // ptr改为node
            self.tail.replace(ptr.clone());
        }
        Some(ptr.clone())
    }
    fn into_iter(self) -> IntoIter<T> {
        IntoIter(self)
    }
    fn iter(&self) -> Iter<T> {
        Iter {
            next: self.head.clone(),
        }
    }
}

关于Send和Sync

List是!Send + !Sync,原因是Pointer中使用了Rc,Rc是!Send + !Sync。

那么可以使List支持Send + Sync吗?

将Rc替换为Arc是否支持Send + Sync呢?答案是否。

Arc是Send + Sync必须要求T: Send + Sync,参见如下std中的定义:

impl<T: ?Sized + Sync + Send> Send for Arc<T>
impl<T: ?Sized + Sync + Send> Sync for Arc<T>

在Pointer的定义中,就是要求RefCell必须是Send + Sync。但是RefCell明确不支持Sync,参见如下std中的定义:

impl<T> Send for RefCell<T>
where
    T: Send + ?Sized, 
[src]

impl<T> !Sync for RefCell<T>
where
    T: ?Sized, 

所以Pointer就不支持Send + Sync,那么List就更不能支持Send + Sync。

如何让List支持Send + Sync?

可以将Pointer修改如下,Pointer就是Send + Sync。

pub struct Pointer<T: Send>(Arc<Mutex<RefCell<Node<T>>>>);

同时Node和List也就是Send + Sync了,前提条件就是T必须是Send + Sync。如下所示:

impl<T> Send for Node<T>
where
    T: Send, 

impl<T> Sync for Node<T>
where
    T: Send + Sync, 

impl<T> Send for List<T>
where
    T: Send, 

impl<T> Sync for List<T>
where
    T: Send, 

定义宏list,用来生成List:

macro_rules! list {
    [] => { List::new() };
    [$($x: expr),+] => {{
        let mut list = List::new();
        $(list.push_tail($x);)+
        list
    }};
    [$($x: expr,)+] => { list![$($x),+] }
}

测试

len()方法

测试代码如下:

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_list_len() {
        let l = list![1_u32, 2, 3, 4];
        dbg!(&l);
        assert_eq!(4, l.len())
    }
}

如下的测试命令:

cargo test test_list_len -- --nocapture --test-threads=1

输出如下:

running 1 test
[leetcode/src/utilites/list3.rs:372] &l = (1)-[4]-(|1>)-(<2>)-(<3>)-(<4|)
test utilites::list3::tests::test_list_len ... ok

test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 24 filtered out

get()方法

测试代码如下:

    #[test]
    fn test_list_get() {
        let l = list![1_u32, 2, 3, 4];
        dbg!(&l);

        assert_eq!(1, l.get(0).unwrap().node().value);
        assert_eq!(3, l.get(2).unwrap().node().value);
        assert_eq!(4, l.get(3).unwrap().node().value);
        assert_eq!(None, l.get(4));
        assert_eq!(None, l.get(14));

        l.get(1).unwrap().node_mut().value = 22;
        assert_eq!(22, l.get(1).unwrap().node().value);

        dbg!(&l);
    }

输出如下:

running 1 test
[leetcode/src/utilites/list3.rs:381] &l = (1)-[4]-(|1>)-(<2>)-(<3>)-(<4|)
[leetcode/src/utilites/list3.rs:392] &l = (1)-[4]-(|1>)-(<22>)-(<3>)-(<4|)
test utilites::list3::tests::test_list_get ... ok

test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 25 filtered out

push_head()方法

测试代码如下:

    fn test_list_push_head() {
        let mut l = List::<u32>::new();
        l.push_head(3);
        l.push_head(2);
        l.push_head(1);

        dbg!(&l);
        assert_eq!(1, l.get(0).unwrap().node().value);
        assert_eq!(2, l.get(1).unwrap().node().value);
        assert_eq!(3, l.get(2).unwrap().node().value);
        
        // --

        let mut l = list![11_u32, 12, 13, 14];
        l.push_head(3);
        l.push_head(2);
        l.push_head(1);

        dbg!(&l);
        assert_eq!(7, l.len());
        assert_eq!(1, l.get(0).unwrap().node().value);
        assert_eq!(3, l.get(2).unwrap().node().value);
        assert_eq!(11, l.get(3).unwrap().node().value);
        assert_eq!(13, l.get(5).unwrap().node().value);
        assert_eq!(14, l.get(6).unwrap().node().value);
        assert_eq!(None, l.get(7));
        assert_eq!(None, l.get(17));
    }

输出如下:

running 1 test
[leetcode/src/utilites/list3.rs:402] &l = (1)-[3]-(|1>)-(<2>)-(<3|)
[leetcode/src/utilites/list3.rs:414] &l = (2)-[7]-(|1>)-(<2>)-(<3>)-(<11>)-(<12>)-(<13>)-(<14|)
test utilites::list3::tests::test_list_push_head ... ok

test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 26 filtered out

更多的测试代码,就不在这里啰嗦了,请参考playgournd的链接,有兴趣可以玩玩。 https://play.rust-lang.org/?version=stable&mode=debug&edition=2018&gist=7534c79e06a299af42e170dad583b5c5

评论区

写评论
jonirrings 2020-04-28 14:02

点赞,收藏!!!双向链表把我写得好伤心

1 共 1 条评论, 1 页