1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
use super::{
    config::ShardConfig,
    error::{Error, Result},
    event::Events,
    processor::{Latency, Session, ShardProcessor},
    sink::ShardSink,
    stage::Stage,
};
use crate::{listener::Listeners, EventTypeFlags};
use futures_util::{
    future::{self, AbortHandle},
    stream::Stream,
};

use once_cell::sync::OnceCell;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use tokio::sync::watch::Receiver as WatchReceiver;
use tracing::debug;

use async_tungstenite::tungstenite::protocol::{frame::coding::CloseCode, CloseFrame};
use async_tungstenite::tungstenite::Message;
use std::borrow::Cow;
use twilight_model::gateway::event::Event;

/// Information about a shard, including its latency, current session sequence,
/// and connection stage.
///
/// Returned by `Shard::info`.
#[derive(Clone, Debug)]
pub struct Information {
    id: u64,
    latency: Latency,
    seq: u64,
    stage: Stage,
}

impl Information {
    /// Returns the ID of the shard.
    pub fn id(&self) -> u64 {
        self.id
    }

    /// Returns the latency information for the shard.
    ///
    /// This includes the average latency over all time, and the latency
    /// information for the 5 most recent heartbeats.
    pub fn latency(&self) -> &Latency {
        &self.latency
    }

    /// The current sequence of the connection.
    ///
    /// This is the number of the event that was received this session (without
    /// reconnecting). A larger number typically means that the shard has been
    /// connected for a longer time, while a smaller number typically means that
    /// it has been connected for a shorter time.
    pub fn seq(&self) -> u64 {
        self.seq
    }

    /// The current stage of the shard.
    ///
    /// For example, once a shard is fully booted then it will be
    /// [`Connected`].
    ///
    /// [`Connected`]: enum.Stage.html#variant.Connected
    pub fn stage(&self) -> Stage {
        self.stage
    }
}
/// Holds the session ID and sequence number with which this shard's session
/// can later be resumed.
///
/// Returned by `Shard::shutdown_resumable`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ResumeSession {
    /// ID of the session to resume.
    pub session_id: String,
    /// Sequence number of the session at the time it was shut down.
    pub sequence: u64,
}

/// Shared state behind a [`Shard`]; every clone of a `Shard` points at the
/// same `ShardRef` through an `Arc`.
#[derive(Debug)]
struct ShardRef {
    /// Configuration the shard was created with.
    config: Arc<ShardConfig>,
    /// Registered event listeners; new streams are added by `events` /
    /// `some_events` and all are removed on shutdown.
    listeners: Listeners<Event>,
    /// Abort handle of the spawned processor task. Empty until `Shard::start`
    /// succeeds; only the first value set is kept.
    processor_handle: OnceCell<AbortHandle>,
    /// Watch receiver yielding the processor's current session. Empty until
    /// `Shard::start` succeeds. NOTE(review): presumably updated by the
    /// processor on reconnect/resume — confirm against `ShardProcessor`.
    session: OnceCell<WatchReceiver<Arc<Session>>>,
}

/// A single gateway shard.
///
/// Cloning a `Shard` is cheap: all clones share the same underlying state
/// (configuration, listeners, processor handle and session) through an `Arc`.
#[derive(Clone, Debug)]
pub struct Shard(Arc<ShardRef>);

impl Shard {
    /// Creates a new shard, which will automatically connect to the gateway.
    ///
    /// # Examples
    ///
    /// Create a new shard, wait a second, and then print its current connection
    /// stage:
    ///
    /// ```no_run
    /// use twilight_gateway::Shard;
    /// use std::{env, time::Duration};
    /// use tokio::time as tokio_time;
    ///
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    /// let mut shard = Shard::new(env::var("DISCORD_TOKEN")?);
    /// shard.start().await?;
    ///
    /// tokio_time::delay_for(Duration::from_secs(1)).await;
    ///
    /// let info = shard.info().await?;
    /// println!("Shard stage: {}", info.stage());
    /// # Ok(()) }
    /// ```
    pub fn new(config: impl Into<ShardConfig>) -> Self {
        Self::_new(config.into())
    }

    fn _new(config: ShardConfig) -> Self {
        let config = Arc::new(config);

        Self(Arc::new(ShardRef {
            config,
            listeners: Listeners::default(),
            processor_handle: OnceCell::new(),
            session: OnceCell::new(),
        }))
    }

    /// Returns an immutable reference to the configuration used for this client.
    pub fn config(&self) -> &ShardConfig {
        &self.0.config
    }

    /// Start the shard, connecting it to the gateway and starting the process
    /// of receiving and processing events.
    ///
    /// # Errors
    ///
    /// Errors if the `ShardProcessor` could not be started.
    pub async fn start(&mut self) -> Result<()> {
        let url = self
            .0
            .config
            .http_client()
            .gateway()
            .authed()
            .await
            .map_err(|source| Error::GettingGatewayUrl { source })?
            .url;

        let config = Arc::clone(&self.0.config);
        let listeners = self.0.listeners.clone();
        let (processor, wrx) = ShardProcessor::new(config, url, listeners)
            .await
            .map_err(|source| Error::Processor { source })?;
        let (fut, handle) = future::abortable(processor.run());

        tokio::spawn(async move {
            let _ = fut.await;

            debug!("[Shard] Shard processor future ended");
        });

        // We know that these haven't been set, so we can ignore this.
        let _ = self.0.processor_handle.set(handle);
        let _ = self.0.session.set(wrx);

        Ok(())
    }

    /// Creates a new stream of events from the shard.
    ///
    /// There can be multiple streams of events. All events will be broadcast to
    /// all streams of events.
    ///
    /// All event types except for [`EventType::SHARD_PAYLOAD`] are enabled.
    /// If you need to enable it, consider calling [`some_events`] instead.
    ///
    /// [`EventType::SHARD_PAYLOAD`]: events/struct.EventType.html#const.SHARD_PAYLOAD
    /// [`some_events`]: #method.some_events
    pub async fn events(&self) -> impl Stream<Item = Event> {
        let rx = self.0.listeners.add(EventTypeFlags::default());

        Events::new(EventTypeFlags::default(), rx)
    }

    /// Creates a new filtered stream of events from the shard.
    ///
    /// Only the events specified in the bitflags will be sent over the stream.
    ///
    /// # Examples
    ///
    /// Filter the events so that you only receive the [`Event::ShardConnected`]
    /// and [`Event::ShardDisconnected`] events:
    ///
    /// ```no_run
    /// use twilight_gateway::{EventTypeFlags, Event, Shard};
    /// use futures::StreamExt;
    /// use std::env;
    ///
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    /// let mut shard = Shard::new(env::var("DISCORD_TOKEN")?);
    /// shard.start().await?;
    ///
    /// let event_types = EventTypeFlags::SHARD_CONNECTED | EventTypeFlags::SHARD_DISCONNECTED;
    /// let mut events = shard.some_events(event_types).await;
    ///
    /// while let Some(event) = events.next().await {
    ///     match event {
    ///         Event::ShardConnected(_) => println!("Shard is now connected"),
    ///         Event::ShardDisconnected(_) => println!("Shard is now disconnected"),
    ///         // No other event will come in through the stream.
    ///         _ => {},
    ///     }
    /// }
    /// # Ok(()) }
    /// ```
    pub async fn some_events(&self, event_types: EventTypeFlags) -> impl Stream<Item = Event> {
        let rx = self.0.listeners.add(event_types);

        Events::new(event_types, rx)
    }

    /// Returns information about the running of the shard, such as the current
    /// connection stage.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Shutdown`] if the shard isn't actively running.
    ///
    /// [`Error::Shutdown`]: error/enum.Error.html
    pub async fn info(&self) -> Result<Information> {
        let session = self.session()?;

        Ok(Information {
            id: self.config().shard()[0],
            latency: session.heartbeats.latency().await,
            seq: session.seq(),
            stage: session.stage(),
        })
    }

    /// Returns a handle to the current session
    ///
    /// # Note
    ///
    /// This session can be invalidated if it is kept around
    /// under a reconnect or resume. In consequence this call
    /// should not be cached.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Shutdown`] if the shard isn't actively running.
    ///
    /// [`Error::Shutdown`]: error/enum.Error.html
    pub fn session(&self) -> Result<Arc<Session>> {
        let session = self.0.session.get().ok_or(Error::Stopped)?;

        Ok(Arc::clone(&session.borrow()))
    }

    /// Returns an interface implementing the `Sink` trait which can be used to
    /// send messages.
    ///
    /// # Note
    ///
    /// This call should not be cached for too long
    /// as it will be invalidated by reconnects and
    /// resumes.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Shutdown`] if the shard isn't actively running.
    ///
    /// [`Error::Shutdown`]: error/enum.Error.html
    pub fn sink(&self) -> Result<ShardSink> {
        let session = self.session()?;

        Ok(ShardSink(session.tx.clone()))
    }

    /// Send a command over the gateway.
    ///
    /// # Errors
    /// Fails if command could not be serialized, or if the command could
    /// not be sent.
    ///
    /// Returns [`Error::Shutdown`] if the shard isn't actively running.
    ///
    /// [`Error::Shutdown`]: error/enum.Error.html
    pub async fn command(&self, com: &impl serde::Serialize) -> Result<()> {
        let payload = Message::Text(
            crate::json_to_string(&com)
                .map_err(|err| Error::PayloadSerialization { source: err })?,
        );
        let session = self.session()?;

        // Tick ratelimiter.
        session.ratelimit.lock().await.tick().await;

        session
            .tx
            .unbounded_send(payload)
            .map_err(|err| Error::SendingMessage { source: err })?;
        Ok(())
    }

    /// Shuts down the shard.
    ///
    /// This will cleanly close the connection, causing discord to end the session and show the bot offline
    pub async fn shutdown(&self) {
        self.0.listeners.remove_all();

        if let Some(processor_handle) = self.0.processor_handle.get() {
            processor_handle.abort();
        }

        if let Ok(session) = self.session() {
            // Since we're shutting down now, we don't care if it sends or not.
            let _ = session.tx.unbounded_send(Message::Close(None));
            session.stop_heartbeater().await;
        }
    }

    /// This will shut down the shard in a resumable way and return shard id and optional session info to resume with later if this shard is resumable
    pub async fn shutdown_resumable(&self) -> (u64, Option<ResumeSession>) {
        self.0.listeners.remove_all();

        if let Some(processor_handle) = self.0.processor_handle.get() {
            processor_handle.abort();
        }

        let shard_id = self.config().shard()[0];

        let session = match self.session() {
            Ok(session) => session,
            Err(_) => return (shard_id, None),
        };

        let _ = session.tx.unbounded_send(Message::Close(Some(CloseFrame {
            code: CloseCode::Restart,
            reason: Cow::from("Closing in a resumable way"),
        })));

        let session_id = session.id.lock().await.clone();
        let sequence = session.seq.load(Ordering::Relaxed);

        session.stop_heartbeater().await;

        let data = match session_id {
            Some(id) => Some(ResumeSession {
                session_id: id,
                sequence,
            }),
            None => None,
        };

        (shard_id, data)
    }
}