< 返回版块

LeJane 发表于 2020-05-28 17:16

quiche原来代码的逻辑在:https://github.com/cloudflare/quiche/blob/master/examples/server.rs

我的代码:

use byteorder::{LittleEndian, ReadBytesExt};
use ring::rand::*;
use std::collections::HashMap;
use std::net;
use tracing::{debug, error, info, warn};
use v1::build_routers;
use v1::utils::router::Context as ReqContext;
use v1::utils::router::Handler;
use v1::utils::{Client, ClientMap, ClientUidMap};

/// Size in bytes of the outgoing packet buffer in `main`, also passed to
/// `config.set_max_packet_size` there.
const MAX_DATAGRAM_SIZE: usize = 65535;

/// A response body that has not yet been fully written to its stream.
///
/// `handle_writable` sends `body[written..]`, advances `written` by the number
/// of bytes accepted, and drops the entry once `written == body.len()`.
pub struct PartialResponse {
    /// Complete response payload.
    pub body: Vec<u8>,

    /// Number of bytes of `body` already handed to the connection.
    pub written: usize,
}

/// Per-connection server state.
pub struct Client {
    /// The quiche connection for this peer.
    pub conn: std::pin::Pin<Box<quiche::Connection>>,

    /// Responses still being written, keyed by stream ID
    /// (looked up by `stream_id` in `handle_writable`).
    pub partial_responses: HashMap<u64, PartialResponse>,
}

/// Active connections keyed by connection-ID bytes; each value pairs the
/// peer's socket address with its `Client` state (see the insert in `main`).
pub type ClientMap = HashMap<Vec<u8>, (net::SocketAddr, Client)>;

/// Maps a `u64` id (logged as `uid` in `main`) to a byte string that `main`
/// decodes as UTF-8 and compares against a connection's `trace_id()`.
pub type ClientUidMap = HashMap<u64, Vec<u8>>;

/// Runs the QUIC server event loop on `127.0.0.1:4433`.
///
/// Sets up tracing, a mio-polled UDP socket and the quiche configuration, then
/// loops forever: wait for readiness or the earliest connection timeout, drain
/// incoming datagrams into per-connection quiche state, hand readable stream
/// data to the router, flush outgoing packets, and drop closed connections.
///
/// # Panics
///
/// Panics on any setup failure (`unwrap`), on socket errors other than
/// `WouldBlock`, and on the `unwrap` calls inside the packet/stream handling.
///
/// NOTE(review): as posted this function does not compile. `client` is a
/// mutable borrow taken from `clients` and is still in use when `&mut clients`
/// is stored in the request context (rustc E0499 on `online_clients`).
fn main() {
    tracing::subscriber::set_global_default(
        tracing_subscriber::FmtSubscriber::builder()
            .with_env_filter(
                tracing_subscriber::EnvFilter::from_default_env()
                    .add_directive(tracing_subscriber::filter::LevelFilter::INFO.into()),
            )
            .finish(),
    )
    .unwrap();

    // `buf` receives datagrams and is reused for stream reads; `out` holds
    // packets about to be sent.
    let mut buf = [0; 65535];
    let mut out = [0; MAX_DATAGRAM_SIZE];

    // Setup the event loop.
    let poll = mio::Poll::new().unwrap();
    let mut events = mio::Events::with_capacity(1024);

    // Create the UDP listening socket, and register it with the event loop.
    let socket = net::UdpSocket::bind("127.0.0.1:4433").unwrap();

    let socket = mio::net::UdpSocket::from_socket(socket).unwrap();
    // Registered edge-triggered for readability, so the read loop below keeps
    // receiving until `WouldBlock`.
    poll.register(
        &socket,
        mio::Token(0),
        mio::Ready::readable(),
        mio::PollOpt::edge(),
    )
    .unwrap();

    // Create the configuration for the QUIC connections.
    let mut config = quiche::Config::new(quiche::PROTOCOL_VERSION).unwrap();

    config.load_cert_chain_from_pem_file("cert.crt").unwrap();
    config.load_priv_key_from_pem_file("cert.key").unwrap();

    config
        .set_application_protos(b"\x05hq-27\x08http/0.9")
        .unwrap();

    config.set_max_idle_timeout(5000);
    config.set_max_packet_size(MAX_DATAGRAM_SIZE as u64);
    config.set_initial_max_data(10_000_000);
    config.set_initial_max_stream_data_bidi_local(100_000_000);
    config.set_initial_max_stream_data_bidi_remote(100_000_000);
    config.set_initial_max_stream_data_uni(1_000_000);
    config.set_initial_max_streams_bidi(100);
    config.set_initial_max_streams_uni(100);
    config.set_disable_active_migration(true);
    config.enable_early_data();

    // Random HMAC key used below to derive server connection IDs from the
    // client-chosen destination connection ID.
    let rng = SystemRandom::new();
    let conn_id_seed = ring::hmac::Key::generate(ring::hmac::HMAC_SHA256, &rng).unwrap();

    let mut clients = ClientMap::new();
    let router = build_routers();
    let mut clients_uid_map = ClientUidMap::new();

    loop {
        // Find the shorter timeout from all the active connections.
        //
        // TODO: use event loop that properly supports timers
        let timeout = clients.values().filter_map(|(_, c)| c.conn.timeout()).min();

        poll.poll(&mut events, timeout).unwrap();

        // Read incoming UDP packets from the socket and feed them to quiche,
        // until there are no more packets to read.
        'read: loop {
            // If the event loop reported no events, it means that the timeout
            // has expired, so handle it without attempting to read packets. We
            // will then proceed with the send loop.
            if events.is_empty() {
                debug!("timed out");

                clients.values_mut().for_each(|(_, c)| c.conn.on_timeout());

                break 'read;
            }

            let (len, src) = match socket.recv_from(&mut buf) {
                Ok(v) => v,

                Err(e) => {
                    // There are no more UDP packets to read, so end the read
                    // loop.
                    if e.kind() == std::io::ErrorKind::WouldBlock {
                        debug!("recv() would block");
                        break 'read;
                    }

                    panic!("recv() failed: {:?}", e);
                }
            };

            debug!("got {} bytes", len);

            let pkt_buf = &mut buf[..len];

            // Parse the QUIC packet's header.
            let hdr = match quiche::Header::from_slice(pkt_buf, quiche::MAX_CONN_ID_LEN) {
                Ok(v) => v,

                Err(e) => {
                    error!("Parsing packet header failed: {:?}", e);
                    continue 'read;
                }
            };

            info!("got packet {:?}", hdr);

            // Candidate server connection ID: HMAC of the packet's DCID,
            // truncated to the maximum connection-ID length.
            let conn_id = ring::hmac::sign(&conn_id_seed, &hdr.dcid);
            let conn_id = &conn_id.as_ref()[..quiche::MAX_CONN_ID_LEN];

            // Lookup a connection based on the packet's connection ID. If there
            // is no connection matching, create a new one.
            let (_, client) = if !clients.contains_key(&hdr.dcid) && !clients.contains_key(conn_id)
            {
                if hdr.ty != quiche::Type::Initial {
                    error!("Packet is not Initial");
                    continue 'read;
                }

                if !quiche::version_is_supported(hdr.version) {
                    warn!("Doing version negotiation");

                    let len = quiche::negotiate_version(&hdr.scid, &hdr.dcid, &mut out).unwrap();

                    let out = &out[..len];

                    if let Err(e) = socket.send_to(out, &src) {
                        if e.kind() == std::io::ErrorKind::WouldBlock {
                            debug!("send() would block");
                            // Unlabeled `break`: leaves the enclosing 'read loop.
                            break;
                        }

                        panic!("send() failed: {:?}", e);
                    }
                    continue 'read;
                }

                let mut scid = [0; quiche::MAX_CONN_ID_LEN];
                scid.copy_from_slice(&conn_id);

                // Token is always present in Initial packets.
                let token = hdr.token.as_ref().unwrap();

                // Do stateless retry if the client didn't send a token.
                if token.is_empty() {
                    warn!("Doing stateless retry");

                    let new_token = generate_token(&hdr, &src);

                    let len =
                        quiche::retry(&hdr.scid, &hdr.dcid, &scid, &new_token, &mut out).unwrap();

                    let out = &out[..len];

                    if let Err(e) = socket.send_to(out, &src) {
                        if e.kind() == std::io::ErrorKind::WouldBlock {
                            debug!("send() would block");
                            // Unlabeled `break`: leaves the enclosing 'read loop.
                            break;
                        }

                        panic!("send() failed: {:?}", e);
                    }
                    continue 'read;
                }

                // On success this yields the bytes that followed the address
                // in the token, i.e. the DCID stored by `generate_token`.
                let odcid = validate_token(&src, token);

                // The token was not valid, meaning the retry failed, so
                // drop the packet.
                if odcid == None {
                    error!("Invalid address validation token");
                    continue 'read;
                }

                if scid.len() != hdr.dcid.len() {
                    error!("Invalid destination connection ID");
                    continue 'read;
                }

                // Reuse the source connection ID we sent in the Retry
                // packet, instead of changing it again.
                scid.copy_from_slice(&hdr.dcid);

                debug!(
                    "New connection: dcid={} scid={}",
                    hex_dump(&hdr.dcid),
                    hex_dump(&scid)
                );

                let conn = quiche::accept(&scid, odcid, &mut config).unwrap();

                let client = Client {
                    conn,
                    partial_responses: HashMap::new(),
                };

                clients.insert(scid.to_vec(), (src, client));

                clients.get_mut(&scid[..]).unwrap()
            } else {
                 // Known connection: try the raw DCID first, then the
                 // HMAC-derived ID. This mutable borrow of `clients` is the
                 // "first mutable borrow" in the reported E0499 error.
                 match clients.get_mut(&hdr.dcid) {
                     Some(v) => v,

                     None => clients.get_mut(conn_id).unwrap(),
                 }
            };

            // Process potentially coalesced packets.
            let read = match client.conn.recv(pkt_buf) {
                Ok(v) => v,

                Err(e) => {
                    error!("{} recv failed: {:?}", client.conn.trace_id(), e);
                    continue 'read;
                }
            };

            debug!("{} processed {} bytes", client.conn.trace_id(), read);

            if client.conn.is_in_early_data() || client.conn.is_established() {
                // Handle writable streams.
                for stream_id in client.conn.writable() {
                    handle_writable(client, stream_id);
                }

                // Process all readable streams.
                for s in client.conn.readable() {
                    while let Ok((read, fin)) = client.conn.stream_recv(s, &mut buf) {
                        debug!("{} received {} bytes", client.conn.trace_id(), read);

                        let stream_buf = &buf[..read];

                        debug!(
                            "{} stream {} has {} bytes (fin? {})",
                            client.conn.trace_id(),
                            s,
                            stream_buf.len(),
                            fin
                        );
                        //todo:handle stream result from front.

                        // Request frame as parsed here: u16 code (LE), u8
                        // version, u16 body length (LE), then the body bytes.
                        // NOTE(review): the `unwrap`s and the slice below panic
                        // if fewer bytes arrived than the frame claims.
                        // `stream_buf` is a `&[u8]`, so `.clone()` copies the
                        // reference rather than the bytes.
                        let mut cursor = std::io::Cursor::new(stream_buf.clone());
                        let code = cursor.read_u16::<LittleEndian>().unwrap();
                        let version = cursor.read_u8().unwrap();
                        let len = cursor.read_u16::<LittleEndian>().unwrap();
                        let position = cursor.position() as usize;
                        let mut new_body = Vec::new();

                        if len > 0 {
                            let body = &stream_buf[position..position + (len as usize)];
                            new_body.extend_from_slice(body);
                        }
                        // The connection's trace id, as bytes, is what the
                        // context carries in its `dcid` field.
                        let dcid = client.conn.trace_id().as_bytes().to_vec();

                        // NOTE(review): `&mut clients` here conflicts with the
                        // still-live `client` borrow (rustc E0499).
                        let ctx = ReqContext {
                            code,
                            version,
                            body_length: len,
                            body: new_body,
                            dcid: dcid,
                            online_users: &mut clients_uid_map,
                            online_clients: &mut clients,
                        };

                        info!(
                            "stream request content-> code:{},version:{},len:{},position:{}",
                            code, version, len, position
                        );

                        router.call(client, s, ctx).unwrap();
                    }
                }
            }
        }

        // Flush: send every packet each connection has ready, to its stored
        // peer address.
        for (peer, client) in clients.values_mut() {
            loop {
                let write = match client.conn.send(&mut out) {
                    Ok(v) => v,

                    Err(quiche::Error::Done) => {
                        debug!("{} done writing", client.conn.trace_id());
                        break;
                    }

                    Err(e) => {
                        error!("{} send failed: {:?}", client.conn.trace_id(), e);

                        client.conn.close(false, 0x1, b"fail").ok();
                        break;
                    }
                };

                // TODO: coalesce packets.
                if let Err(e) = socket.send_to(&out[..write], &peer) {
                    if e.kind() == std::io::ErrorKind::WouldBlock {
                        debug!("send() would block");
                        break;
                    }

                    panic!("send() failed: {:?}", e);
                }

                debug!("{} written {} bytes", client.conn.trace_id(), write);
            }
        }

        // Garbage collect closed connections.
        clients.retain(|_, (_, ref mut c)| {
            debug!("Collecting garbage");

            if c.conn.is_closed() {
                info!(
                    "{} connection collected {:?}",
                    c.conn.trace_id(),
                    c.conn.stats()
                );
            }
            // Find the uid whose stored bytes equal this connection's trace id;
            // 0 doubles as "not found".
            // NOTE(review): this lookup and the removal below run for every
            // connection, not only closed ones -- confirm that is intended.
            let mut uid: u64 = 0;
            for (id, dcid) in clients_uid_map.iter() {
                let dcid_str = std::str::from_utf8(dcid).unwrap();
                info!(
                    "find online users->dcid:{} trace_id:{} uid:{}",
                    dcid_str,
                    c.conn.trace_id(),
                    id
                );
                if dcid_str == c.conn.trace_id() {
                    uid = *id;
                }
            }
            if uid > 0 {
                clients_uid_map.remove(&uid);
            }
            !c.conn.is_closed()
        });
    }
}

/// Builds the address-validation token handed to `quiche::retry`.
///
/// Layout: the literal bytes `quiche`, followed by the raw octets of the
/// client's IP address (4 for IPv4, 16 for IPv6), followed by the destination
/// connection ID taken from `hdr`.
///
/// NOTE(review): the token is neither encrypted nor authenticated.
fn generate_token(hdr: &quiche::Header, src: &net::SocketAddr) -> Vec<u8> {
    let mut token = b"quiche".to_vec();

    // Append the address octets directly instead of via a temporary Vec.
    match src.ip() {
        std::net::IpAddr::V4(v4) => token.extend_from_slice(&v4.octets()),
        std::net::IpAddr::V6(v6) => token.extend_from_slice(&v6.octets()),
    }

    token.extend_from_slice(&hdr.dcid);
    token
}

/// Checks an address-validation token against the sender's address.
///
/// A valid token starts with the literal bytes `quiche`, immediately followed
/// by the raw octets of `src`'s IP address. On success returns whatever bytes
/// follow the address (possibly empty); otherwise returns `None`.
fn validate_token<'a>(src: &net::SocketAddr, token: &'a [u8]) -> Option<&'a [u8]> {
    const MAGIC: &[u8] = b"quiche";

    // A too-short token cannot start with the magic, so one check covers both
    // the length test and the prefix comparison.
    if !token.starts_with(MAGIC) {
        return None;
    }
    let after_magic = &token[MAGIC.len()..];

    let ip_octets: Vec<u8> = match src.ip() {
        std::net::IpAddr::V4(v4) => v4.octets().to_vec(),
        std::net::IpAddr::V6(v6) => v6.octets().to_vec(),
    };

    if after_magic.starts_with(&ip_octets) {
        Some(&after_magic[ip_octets.len()..])
    } else {
        None
    }
}

fn handle_writable(client: &mut Client, stream_id: u64) {
    let conn = &mut client.conn;

    debug!("{} stream {} is writable", conn.trace_id(), stream_id);

    if !client.partial_responses.contains_key(&stream_id) {
        return;
    }

    let resp = client.partial_responses.get_mut(&stream_id).unwrap();
    let body = &resp.body[resp.written..];

    let written = match conn.stream_send(stream_id, &body, true) {
        Ok(v) => v,

        Err(quiche::Error::Done) => 0,

        Err(e) => {
            error!("{} stream send failed {:?} .....", conn.trace_id(), e);
            return;
        }
    };

    resp.written += written;

    if resp.written == resp.body.len() {
        client.partial_responses.remove(&stream_id);
    }
}

/// Renders `buf` as a lowercase hexadecimal string, two digits per byte and no
/// separators. An empty slice yields an empty string.
fn hex_dump(buf: &[u8]) -> String {
    let mut hex = String::with_capacity(buf.len() * 2);

    for byte in buf {
        hex.push_str(&format!("{:02x}", byte));
    }

    hex
}

报错如下:

   error[E0499]: cannot borrow `clients` as mutable more than once at a time
   --> src/main.rs:279:45
    |
218 |                 match clients.get_mut(&hdr.dcid) {
    |                       ------- first mutable borrow occurs here
...
279 |                             online_clients: &mut clients,
    |                                             ^^^^^^^^^^^^ second mutable borrow occurs here
...
287 |                         router.call(client, s, ctx).unwrap();
    |                                     ------ first borrow later used here

error: aborting due to previous error

For more information about this error, try `rustc --explain E0499`.
error: could not compile `v1`.

To learn more, run the command again with --verbose.

评论区

写评论
作者 LeJane 2020-06-01 09:40

ctx里面是必须对clients里面进行引用的,然后无论我怎么拆分都无法把逻辑拆开,用interior mutability也绕不开借用检查的,会出另外的问题,所以这个quiche就不适合用来做聊天,更适合做http的request->response这种方式。 对以下内容的回复:

solarsail 2020-05-30 16:37

`client` 是 `clients.get_mut()` 得到的值的独占引用,而 ctx 中包含了对 clients 的独占引用,这两者没法调和,至少得改一边。router.call() 的 API 我不了解,如果结构定义都不能动的话,我会考虑让 client 和它所属的容器 clients 脱钩,变成一个 owned 对象,比如通过 clone()(需要加个 derive)。处理完再同步回 clients。如果 `client` 和 `ctx` 的结构都可以调整,可以试试用些 interior mutability 模式绕过检查,比如 Rc<RefCell<_>> 之类。另外就是考虑 call() 是否真的需要 clients 的独占引用,感觉一边修改 HashMap 里的值,同时还要增删 HashMap 本身的条目,这种方式按照目前的设计不太成立。如果是需要删除 client,因为外面有对 client 的引用,是没法删的;如果需要增加条目(有点奇怪),也可以把需要加的内容传出来,放在外面处理。如果真的是需要删除的话,HashMap 里存 Rc<RefCell<_>> 可能更符合逻辑。

作者 LeJane 2020-05-28 19:41

这个代码是quiche里面的前后端通信的,然后你说的只在ctx里面借用clients是不行的,就是因为在call上边的代码借用了clients,然后在call里面借用才报错的,这里应该只能更改代码的有关逻辑,可是这块的逻辑我是真的不知道如何更改了

对以下内容的回复:

plus7wist 2020-05-28 17:43

router.call(client, s, ctx).unwrap() 这一句,`client` 和 `ctx` 都可变地借用了 clients。Rust 禁止这种行为,没有办法绕开它,你需要想办法修改程序。你的业务我看得不是很明白,我只举个例子:只在 ctx 借用 clients,然后在 call 里面想办法找到你需要的 client,就可以通过借用检查。

1 共 4 条评论, 1 页