From 29a0851f485957cc6410ccf9d261c781c1d2f757 Mon Sep 17 00:00:00 2001 From: "Nathan.fooo" <86001920+appflowy@users.noreply.github.com> Date: Thu, 8 Feb 2024 15:11:23 +0800 Subject: [PATCH] feat: Rate limit of ws client (#306) * feat: implement rate limit for client * chore: check * chore: check * chore: check * chore: update * chore: add client version * chore: update --- Cargo.lock | 88 ++++++++++++++++++++++++- libs/client-api-test-util/src/client.rs | 1 + libs/client-api/Cargo.toml | 1 + libs/client-api/src/http.rs | 9 ++- libs/client-api/src/ws/client.rs | 38 ++++++++++- libs/realtime-entity/src/message.rs | 12 +++- libs/realtime/src/client.rs | 2 +- libs/realtime/src/collaborate/server.rs | 5 +- 8 files changed, 141 insertions(+), 15 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6c2af36e..7718c59e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1274,6 +1274,7 @@ dependencies = [ "getrandom 0.2.12", "gotrue", "gotrue-entity", + "governor", "mime", "mime_guess", "parking_lot 0.12.1", @@ -1719,6 +1720,19 @@ dependencies = [ "syn 2.0.48", ] +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown 0.14.3", + "lock_api", + "once_cell", + "parking_lot_core 0.9.9", +] + [[package]] name = "data-encoding" version = "2.5.0" @@ -2171,6 +2185,12 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +[[package]] +name = "futures-timer" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" + [[package]] name = "futures-util" version = "0.3.30" @@ -2314,6 +2334,24 @@ dependencies = [ "serde_json", ] +[[package]] +name = "governor" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "821239e5672ff23e2a7060901fa622950bbd80b649cdaadd78d1c1767ed14eb4" +dependencies = [ + "cfg-if", + "dashmap", + "futures", + "futures-timer", + "no-std-compat", + "nonzero_ext", + "parking_lot 0.12.1", + "quanta 0.11.1", + "rand 0.8.5", + "smallvec", +] + [[package]] name = "h2" version = "0.3.24" @@ -2963,6 +3001,15 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c41e0c4fef86961ac6d6f8a82609f55f31b05e4fce149ac5710e439df7619ba4" +[[package]] +name = "mach2" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b955cdeb2a02b9117f121ce63aa52d08ade45de53e48fe6a38b39c10f6f709" +dependencies = [ + "libc", +] + [[package]] name = "markup5ever" version = "0.11.0" @@ -3116,7 +3163,7 @@ dependencies = [ "futures-util", "once_cell", "parking_lot 0.12.1", - "quanta", + "quanta 0.12.2", "rustc_version", "skeptic", "smallvec", @@ -3176,6 +3223,12 @@ dependencies = [ "libc", ] +[[package]] +name = "no-std-compat" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c" + [[package]] name = "nom" version = "7.1.3" @@ -3186,6 +3239,12 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nonzero_ext" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" + [[package]] name = "normpath" version = "1.1.1" @@ -3942,6 +4001,22 @@ dependencies = [ "unicase", ] +[[package]] +name = "quanta" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a17e662a7a8291a865152364c20c7abc5e60486ab2001e8ec10b24862de0b9ab" +dependencies = [ + "crossbeam-utils", + "libc", + "mach2", + "once_cell", + "raw-cpuid 10.7.0", + "wasi 0.11.0+wasi-snapshot-preview1", + "web-sys", + "winapi", +] + [[package]] name = "quanta" version = "0.12.2" @@ -3951,7 +4026,7 @@ dependencies = [ "crossbeam-utils", "libc", "once_cell", - "raw-cpuid", + "raw-cpuid 11.0.1", "wasi 0.11.0+wasi-snapshot-preview1", "web-sys", "winapi", @@ -4053,6 +4128,15 @@ dependencies = [ "rand_core 0.5.1", ] +[[package]] +name = "raw-cpuid" +version = "10.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" +dependencies = [ + "bitflags 1.3.2", +] + [[package]] name = "raw-cpuid" version = "11.0.1" diff --git a/libs/client-api-test-util/src/client.rs b/libs/client-api-test-util/src/client.rs index 5e574ec2..e267e1db 100644 --- a/libs/client-api-test-util/src/client.rs +++ b/libs/client-api-test-util/src/client.rs @@ -53,6 +53,7 @@ pub fn localhost_client_with_device_id(device_id: &str) -> Client { &LOCALHOST_GOTRUE, device_id, ClientConfiguration::default(), + "test", ) } diff --git a/libs/client-api/Cargo.toml b/libs/client-api/Cargo.toml index b2e87f3d..138c0040 100644 --- a/libs/client-api/Cargo.toml +++ b/libs/client-api/Cargo.toml @@ -43,6 +43,7 @@ serde.workspace = true database-entity.workspace = true app-error = { workspace = true, features = ["tokio_error", "bincode_error"] } scraper = { version = "0.17.1", optional = true } +governor = { version = "0.6.0" } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] tokio-retry = "0.3" diff --git a/libs/client-api/src/http.rs b/libs/client-api/src/http.rs index a51682b0..3fb8df51 100644 --- a/libs/client-api/src/http.rs +++ b/libs/client-api/src/http.rs @@ -46,10 +46,6 @@ use url::Url; use gotrue_entity::dto::SignUpResponse::{Authenticated, NotAuthenticated}; use gotrue_entity::dto::{GotrueTokenResponse, UpdateGotrueUserParams, User}; -/// The API version of the client. -/// 0.0.4 -/// fix refresh token issue -pub const CLIENT_API_VERSION: &str = "0.0.5"; pub const X_COMPRESSION_TYPE: &str = "X-Compression-Type"; pub const X_COMPRESSION_BUFFER_SIZE: &str = "X-Compression-Buffer-Size"; pub const X_COMPRESSION_TYPE_BROTLI: &str = "brotli"; @@ -110,6 +106,7 @@ pub struct Client { pub base_url: String, ws_addr: String, pub device_id: String, + pub client_id: String, pub(crate) token: Arc>, pub(crate) is_refreshing_token: Arc, pub(crate) refresh_ret_txs: Arc>>, @@ -134,6 +131,7 @@ impl Client { gotrue_url: &str, device_id: &str, config: ClientConfiguration, + client_id: &str, ) -> Self { let reqwest_client = reqwest::Client::new(); Self { @@ -146,6 +144,7 @@ impl Client { refresh_ret_txs: Default::default(), config, device_id: device_id.to_string(), + client_id: client_id.to_string(), } } @@ -1193,7 +1192,7 @@ impl Client { let request_builder = self .cloud_client .request(method, url) - .header("client-version", CLIENT_API_VERSION) + .header("client-version", self.client_id.clone()) .header("client-timestamp", ts_now.to_string()) .header("device_id", self.device_id.clone()) .bearer_auth(access_token); diff --git a/libs/client-api/src/ws/client.rs b/libs/client-api/src/ws/client.rs index cedcff1e..c4315066 100644 --- a/libs/client-api/src/ws/client.rs +++ b/libs/client-api/src/ws/client.rs @@ -1,17 +1,23 @@ use futures_util::{SinkExt, StreamExt}; +use governor::clock::DefaultClock; +use governor::middleware::NoOpMiddleware; +use governor::state::{InMemoryState, NotKeyed}; +use governor::{Quota, RateLimiter}; use parking_lot::RwLock; use std::borrow::Cow; use std::collections::HashMap; + +use futures_util::FutureExt; +use std::num::NonZeroU32; use std::sync::{Arc, Weak}; use std::time::Duration; - use tokio::sync::broadcast::{channel, Receiver, Sender}; use crate::ws::{ConnectState, ConnectStateNotify, WSError, WebSocketChannel}; use crate::ServerFixIntervalPing; use crate::{platform_spawn, retry_connect}; use realtime_entity::collab_msg::CollabMessage; -use realtime_entity::message::RealtimeMessage; +use realtime_entity::message::{RealtimeMessage, SystemMessage}; use realtime_entity::user::UserMessage; use tokio::sync::{oneshot, Mutex}; use tracing::{debug, error, info, trace, warn}; @@ -59,6 +65,8 @@ pub struct WSClient { collab_channels: Arc>, ping: Arc>>, stop_tx: Mutex>>, + rate_limiter: + Arc>>, } impl WSClient { @@ -72,6 +80,7 @@ impl WSClient { let ping = Arc::new(Mutex::new(None)); let http_sender = Arc::new(http_sender); let (user_channel, _) = channel(1); + let rate_limiter = gen_rate_limiter(10); WSClient { addr: Arc::new(parking_lot::Mutex::new(None)), config, @@ -82,6 +91,7 @@ impl WSClient { collab_channels, ping, stop_tx: Mutex::new(None), + rate_limiter: Arc::new(tokio::sync::RwLock::new(rate_limiter)), } } @@ -145,6 +155,7 @@ impl WSClient { *self.ping.lock().await = Some(ping); let user_message_tx = self.user_channel.as_ref().clone(); + let rate_limiter = self.rate_limiter.clone(); // Receive messages from the websocket, and send them to the channels. platform_spawn(async move { while let Some(Ok(ws_msg)) = stream.next().await { @@ -180,7 +191,14 @@ impl WSClient { RealtimeMessage::User(user_message) => { let _ = user_message_tx.send(user_message); }, - RealtimeMessage::ServerKickedOff => {}, + RealtimeMessage::System(sys_message) => match sys_message { + SystemMessage::RateLimit(limit) => { + *rate_limiter.write().await = gen_rate_limiter(limit); + }, + SystemMessage::KickOff => { + // + }, + }, } }, Err(err) => { @@ -210,12 +228,15 @@ impl WSClient { let mut rx = self.sender.subscribe(); let weak_http_sender = Arc::downgrade(&self.http_sender); + let rate_limiter = self.rate_limiter.clone(); let device_id = device_id.to_string(); platform_spawn(async move { loop { tokio::select! { _ = &mut stop_rx => break, Ok(msg) = rx.recv() => { + rate_limiter.read().await.until_ready().fuse().await; + let len = msg.len(); // The maximum size allowed for a WebSocket message is 65,536 bytes. If the message exceeds // 40,960 bytes (to avoid occupying the entire space), it should be sent over HTTP instead. @@ -307,3 +328,14 @@ impl WSClient { self.state_notify.lock().set_state(state); } } + +fn gen_rate_limiter( + mut times_per_sec: u32, +) -> RateLimiter { + // make sure the rate limiter is not zero + if times_per_sec == 0 { + times_per_sec = 1; + } + let quota = Quota::per_second(NonZeroU32::new(times_per_sec).unwrap()); + RateLimiter::direct(quota) +} diff --git a/libs/realtime-entity/src/message.rs b/libs/realtime-entity/src/message.rs index 2a30e674..441861c4 100644 --- a/libs/realtime-entity/src/message.rs +++ b/libs/realtime-entity/src/message.rs @@ -13,15 +13,15 @@ use websocket::Message; pub enum RealtimeMessage { Collab(CollabMessage), User(UserMessage), - ServerKickedOff, + System(SystemMessage), } impl RealtimeMessage { pub fn device_id(&self) -> Option { match self { RealtimeMessage::Collab(msg) => msg.device_id(), - RealtimeMessage::ServerKickedOff => None, RealtimeMessage::User(_) => None, + RealtimeMessage::System(_) => None, } } } @@ -30,8 +30,8 @@ impl Display for RealtimeMessage { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { RealtimeMessage::Collab(msg) => f.write_fmt(format_args!("Collab:{}", msg.object_id())), - RealtimeMessage::ServerKickedOff => f.write_fmt(format_args!("ServerKickedOff")), RealtimeMessage::User(_) => f.write_fmt(format_args!("User")), + RealtimeMessage::System(_) => f.write_fmt(format_args!("System")), } } } @@ -111,3 +111,9 @@ impl From for Message { Message::Binary(bytes) } } + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum SystemMessage { + RateLimit(u32), + KickOff, +} diff --git a/libs/realtime/src/client.rs b/libs/realtime/src/client.rs index cc05a4b8..b2e9a0fd 100644 --- a/libs/realtime/src/client.rs +++ b/libs/realtime/src/client.rs @@ -222,7 +222,7 @@ where match &msg { RealtimeMessage::Collab(_) => ctx.binary(msg), RealtimeMessage::User(_) => ctx.binary(msg), - RealtimeMessage::ServerKickedOff => ctx.stop(), + RealtimeMessage::System(_) => ctx.binary(msg), } } } diff --git a/libs/realtime/src/collaborate/server.rs b/libs/realtime/src/collaborate/server.rs index dbb0a8f4..0cfc9a8d 100644 --- a/libs/realtime/src/collaborate/server.rs +++ b/libs/realtime/src/collaborate/server.rs @@ -28,6 +28,7 @@ use crate::collaborate::permission::CollabAccessControl; use crate::collaborate::retry::{CollabUserMessage, SubscribeGroupIfNeed}; use crate::util::channel_ext::UnboundedSenderSink; use database::collab::CollabStorage; +use realtime_entity::message::SystemMessage; #[derive(Clone)] pub struct CollabServer { @@ -509,7 +510,9 @@ impl CollabClientStream { } pub fn disconnect(&self) { - self.sink.do_send(RealtimeMessage::ServerKickedOff); + self + .sink + .do_send(RealtimeMessage::System(SystemMessage::KickOff)); } }