1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
use super::super::ShardStream;
use async_tungstenite::tungstenite::Message;
use futures_channel::mpsc::{self, UnboundedReceiver, UnboundedSender};
use futures_util::{
    future::{self, Either},
    sink::SinkExt,
    stream::StreamExt,
};
use tokio::time::timeout;
#[allow(unused_imports)]
use tracing::{debug, info, trace, warn};

/// Shuttles WebSocket messages between a [`ShardStream`] and a pair of
/// unbounded channels.
///
/// Build one with [`SocketForwarder::new`], which also hands back the channel
/// ends that pair with `rx` and `tx`, then drive it with
/// [`SocketForwarder::run`].
pub struct SocketForwarder {
    /// Messages queued by the holder of the sender returned from `new`;
    /// `run` writes each of them to `stream`.
    rx: UnboundedReceiver<Message>,
    /// The underlying connection. `run` reads it with `next()` and writes to
    /// it with `send()` / `close()`.
    pub stream: ShardStream,
    /// Messages read from `stream` are pushed here; they come out of the
    /// receiver returned from `new`.
    tx: UnboundedSender<Message>,
}

impl SocketForwarder {
    /// Wraps `stream` in a forwarder and creates the two unbounded channels
    /// used to talk to it.
    ///
    /// Returns, in order:
    /// 1. the forwarder itself,
    /// 2. a receiver yielding every message the forwarder reads from `stream`,
    /// 3. a sender whose messages the forwarder writes to `stream`.
    ///
    /// Nothing is forwarded until [`run`](Self::run) is awaited.
    pub fn new(
        stream: ShardStream,
    ) -> (Self, UnboundedReceiver<Message>, UnboundedSender<Message>) {
        // Inbound direction: stream -> forwarder -> caller.
        let (inbound_tx, inbound_rx) = mpsc::unbounded();
        // Outbound direction: caller -> forwarder -> stream.
        let (outbound_tx, outbound_rx) = mpsc::unbounded();

        let forwarder = Self {
            rx: outbound_rx,
            stream,
            tx: inbound_tx,
        };

        (forwarder, inbound_rx, outbound_tx)
    }

    /// Drives the forwarder until one side goes away.
    ///
    /// Each loop turn races the next outbound message (from `rx`) against the
    /// next inbound message (from `stream`, wrapped in a 90 second timeout):
    ///
    /// * outbound message: written to `stream`; a send error ends the loop.
    /// * outbound channel finished: `stream` is closed (the result of the
    ///   close is ignored) and the loop ends.
    /// * inbound message: pushed into `tx`; the loop ends if that fails.
    /// * inbound error, end of `stream`, or timeout: `tx` is closed and the
    ///   loop ends.
    ///
    /// The timeout future is rebuilt on every turn, so forwarding an outbound
    /// message restarts the 90 second window.
    /// NOTE(review): confirm that restart is intended rather than a deadline
    /// measured from the last inbound message.
    ///
    /// Takes `self` by value, so the forwarder and everything it owns is
    /// dropped when this returns.
    pub async fn run(mut self) {
        const READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(90);

        debug!("[SocketForwarder] Starting driving loop");

        loop {
            let next_outbound = self.rx.next();
            let next_inbound = timeout(READ_TIMEOUT, self.stream.next());

            // Stage 1: deal with the outbound side right here; if the inbound
            // side won the race, pull its result out for stage 2.
            let inbound = match future::select(next_outbound, next_inbound).await {
                Either::Left((Some(msg), _)) => {
                    trace!("[SocketForwarder] Sending msg: {}", msg);
                    if let Err(err) = self.stream.send(msg).await {
                        warn!("[SocketForwarder] Got error when sending: {}", err);
                        break;
                    }
                    continue;
                }
                Either::Left((None, _)) => {
                    debug!("[SocketForwarder] Got None, closing stream");
                    // Best-effort close: there is nobody left to report a failure to.
                    let _ = self.stream.close(None).await;

                    break;
                }
                Either::Right((inbound, _)) => inbound,
            };

            // Stage 2: a successfully read message is handed on; every other
            // outcome is logged and falls through to the shared shutdown below.
            match inbound {
                Ok(Some(Ok(msg))) => {
                    if self.tx.unbounded_send(msg).is_err() {
                        break;
                    }
                    continue;
                }
                Ok(Some(Err(err))) => {
                    warn!("[SocketForwarder] Got error: {}, closing tx", err);
                }
                Ok(None) => {
                    debug!("[SocketForwarder] Got None, closing tx");
                }
                Err(why) => {
                    warn!("[SocketForwarder] Error: {}", why);
                }
            }

            self.tx.close_channel();
            break;
        }

        debug!("[SocketForwarder] Leaving loop");
    }
}