< 返回版块

dvorakchen 发表于 2023-03-04 22:37

Tags:rust,问答

我有一个全局 HashMap 如下:

static mut CONNECTING_CLIENT: Lazy<Mutex<HashMap<String, TcpStream>>> =
    Lazy::new(|| Mutex::new(HashMap::new()));

并且封装了一个方法来获取值的可变借用,如下:

pub async fn get_client_stream(username: &String) -> Option<&mut TcpStream> {
    unsafe {
        CONNECTING_CLIENT
        .lock()
        .await
        .get_mut(username)
        //  这个 get_mut 返回值是 Option<&mut TcpStream>
    }
}

报了如下的错误:

cannot return reference to temporary value
returns a reference to data owned by the current function

我猜想是因为 MutexGuard 持有了借用,出了 unsafe 块后就失效了。

参考了 HashMap 的源码后修改成了以下方式:

pub async fn get_client_stream(username: &String) -> Option<&mut TcpStream> {
    let mut v = unsafe {
        CONNECTING_CLIENT
        .lock()
        .await
    };

    v.get_mut(username).map(|v| {
        unsafe {
            &mut *(v as *mut TcpStream)
        }
    })
}

这样一来,虽然编译是通过了,但我不能保证这个 unsafe 是否 safe,不知道会不会有问题

求教可能会发生什么问题?有没有其他更好的写法?

评论区

写评论
lithbitren 2023-03-05 01:07

不管是不是static/unsafe,智能指针包裹的可变引用尽可能不要跨作用域传递,最好是用完就drop,免得影响其他引用,包括但不限于Mutex/RwLock/RefCell。

如果觉得每次用的时候都得写一串东西用起来太麻烦,可以考虑用宏来封装,起码不用考虑这么多类型、生命周期以及所有权问题。

macro_rules! get_client_stream {
    ($username: expr) => {
        unsafe {
            CONNECTING_CLIENT
            .lock()
            .await
            .get_mut($username)
        }
    }
}
hax10 2023-03-04 23:49

我觉得你没必要使用全局可变的静态变量,这样做才需要使用不安全代码块。最好初始化程序时建立一个服务器对象,这个对象能供多条线程使用。

简单示例:

use std::collections::HashMap;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::net::TcpStream;
use tokio::sync::RwLock;

pub struct Server {
    connections: Arc<RwLock<HashMap<String, TcpStream>>>,
}

impl Server {
    pub fn new() -> Server {
        Server {
            connections: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    async fn process_connection(&self, conn: TcpStream) {
        // 获取散列表的写作权
        let mut hashmap = self.connections.write().await;
        hashmap.insert(String::from("鸡大保"), conn);
        let conn_ref = hashmap.get(&String::from("鸡大保")).unwrap();
        println!("{:#?}", conn_ref);
    }
}

#[tokio::main]
async fn main() {
    let s = Server::new();
    let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap();

    loop {
        let (conn, _) = listener.accept().await.unwrap();
        s.process_connection(conn).await;
    }
}
1 共 2 条评论, 1 页