feat: Rate limit of ws client (#306)

* feat: implement rate limit for client

* chore: check

* chore: check

* chore: check

* chore: update

* chore: add client version

* chore: update
This commit is contained in:
Nathan.fooo 2024-02-08 15:11:23 +08:00 committed by GitHub
parent e1307f4f5d
commit 29a0851f48
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 141 additions and 15 deletions

88
Cargo.lock generated
View File

@ -1274,6 +1274,7 @@ dependencies = [
"getrandom 0.2.12",
"gotrue",
"gotrue-entity",
"governor",
"mime",
"mime_guess",
"parking_lot 0.12.1",
@ -1719,6 +1720,19 @@ dependencies = [
"syn 2.0.48",
]
[[package]]
name = "dashmap"
version = "5.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856"
dependencies = [
"cfg-if",
"hashbrown 0.14.3",
"lock_api",
"once_cell",
"parking_lot_core 0.9.9",
]
[[package]]
name = "data-encoding"
version = "2.5.0"
@ -2171,6 +2185,12 @@ version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004"
[[package]]
name = "futures-timer"
version = "3.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c"
[[package]]
name = "futures-util"
version = "0.3.30"
@ -2314,6 +2334,24 @@ dependencies = [
"serde_json",
]
[[package]]
name = "governor"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "821239e5672ff23e2a7060901fa622950bbd80b649cdaadd78d1c1767ed14eb4"
dependencies = [
"cfg-if",
"dashmap",
"futures",
"futures-timer",
"no-std-compat",
"nonzero_ext",
"parking_lot 0.12.1",
"quanta 0.11.1",
"rand 0.8.5",
"smallvec",
]
[[package]]
name = "h2"
version = "0.3.24"
@ -2963,6 +3001,15 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c41e0c4fef86961ac6d6f8a82609f55f31b05e4fce149ac5710e439df7619ba4"
[[package]]
name = "mach2"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19b955cdeb2a02b9117f121ce63aa52d08ade45de53e48fe6a38b39c10f6f709"
dependencies = [
"libc",
]
[[package]]
name = "markup5ever"
version = "0.11.0"
@ -3116,7 +3163,7 @@ dependencies = [
"futures-util",
"once_cell",
"parking_lot 0.12.1",
"quanta",
"quanta 0.12.2",
"rustc_version",
"skeptic",
"smallvec",
@ -3176,6 +3223,12 @@ dependencies = [
"libc",
]
[[package]]
name = "no-std-compat"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c"
[[package]]
name = "nom"
version = "7.1.3"
@ -3186,6 +3239,12 @@ dependencies = [
"minimal-lexical",
]
[[package]]
name = "nonzero_ext"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
[[package]]
name = "normpath"
version = "1.1.1"
@ -3942,6 +4001,22 @@ dependencies = [
"unicase",
]
[[package]]
name = "quanta"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a17e662a7a8291a865152364c20c7abc5e60486ab2001e8ec10b24862de0b9ab"
dependencies = [
"crossbeam-utils",
"libc",
"mach2",
"once_cell",
"raw-cpuid 10.7.0",
"wasi 0.11.0+wasi-snapshot-preview1",
"web-sys",
"winapi",
]
[[package]]
name = "quanta"
version = "0.12.2"
@ -3951,7 +4026,7 @@ dependencies = [
"crossbeam-utils",
"libc",
"once_cell",
"raw-cpuid",
"raw-cpuid 11.0.1",
"wasi 0.11.0+wasi-snapshot-preview1",
"web-sys",
"winapi",
@ -4053,6 +4128,15 @@ dependencies = [
"rand_core 0.5.1",
]
[[package]]
name = "raw-cpuid"
version = "10.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332"
dependencies = [
"bitflags 1.3.2",
]
[[package]]
name = "raw-cpuid"
version = "11.0.1"

View File

@ -53,6 +53,7 @@ pub fn localhost_client_with_device_id(device_id: &str) -> Client {
&LOCALHOST_GOTRUE,
device_id,
ClientConfiguration::default(),
"test",
)
}

View File

@ -43,6 +43,7 @@ serde.workspace = true
database-entity.workspace = true
app-error = { workspace = true, features = ["tokio_error", "bincode_error"] }
scraper = { version = "0.17.1", optional = true }
governor = { version = "0.6.0" }
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
tokio-retry = "0.3"

View File

@ -46,10 +46,6 @@ use url::Url;
use gotrue_entity::dto::SignUpResponse::{Authenticated, NotAuthenticated};
use gotrue_entity::dto::{GotrueTokenResponse, UpdateGotrueUserParams, User};
/// The API version of the client.
/// 0.0.4
/// fix refresh token issue
pub const CLIENT_API_VERSION: &str = "0.0.5";
pub const X_COMPRESSION_TYPE: &str = "X-Compression-Type";
pub const X_COMPRESSION_BUFFER_SIZE: &str = "X-Compression-Buffer-Size";
pub const X_COMPRESSION_TYPE_BROTLI: &str = "brotli";
@ -110,6 +106,7 @@ pub struct Client {
pub base_url: String,
ws_addr: String,
pub device_id: String,
pub client_id: String,
pub(crate) token: Arc<RwLock<ClientToken>>,
pub(crate) is_refreshing_token: Arc<AtomicBool>,
pub(crate) refresh_ret_txs: Arc<RwLock<Vec<RefreshTokenSender>>>,
@ -134,6 +131,7 @@ impl Client {
gotrue_url: &str,
device_id: &str,
config: ClientConfiguration,
client_id: &str,
) -> Self {
let reqwest_client = reqwest::Client::new();
Self {
@ -146,6 +144,7 @@ impl Client {
refresh_ret_txs: Default::default(),
config,
device_id: device_id.to_string(),
client_id: client_id.to_string(),
}
}
@ -1193,7 +1192,7 @@ impl Client {
let request_builder = self
.cloud_client
.request(method, url)
.header("client-version", CLIENT_API_VERSION)
.header("client-version", self.client_id.clone())
.header("client-timestamp", ts_now.to_string())
.header("device_id", self.device_id.clone())
.bearer_auth(access_token);

View File

@ -1,17 +1,23 @@
use futures_util::{SinkExt, StreamExt};
use governor::clock::DefaultClock;
use governor::middleware::NoOpMiddleware;
use governor::state::{InMemoryState, NotKeyed};
use governor::{Quota, RateLimiter};
use parking_lot::RwLock;
use std::borrow::Cow;
use std::collections::HashMap;
use futures_util::FutureExt;
use std::num::NonZeroU32;
use std::sync::{Arc, Weak};
use std::time::Duration;
use tokio::sync::broadcast::{channel, Receiver, Sender};
use crate::ws::{ConnectState, ConnectStateNotify, WSError, WebSocketChannel};
use crate::ServerFixIntervalPing;
use crate::{platform_spawn, retry_connect};
use realtime_entity::collab_msg::CollabMessage;
use realtime_entity::message::RealtimeMessage;
use realtime_entity::message::{RealtimeMessage, SystemMessage};
use realtime_entity::user::UserMessage;
use tokio::sync::{oneshot, Mutex};
use tracing::{debug, error, info, trace, warn};
@ -59,6 +65,8 @@ pub struct WSClient {
collab_channels: Arc<RwLock<ChannelByObjectId>>,
ping: Arc<Mutex<Option<ServerFixIntervalPing>>>,
stop_tx: Mutex<Option<oneshot::Sender<()>>>,
rate_limiter:
Arc<tokio::sync::RwLock<RateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>>>,
}
impl WSClient {
@ -72,6 +80,7 @@ impl WSClient {
let ping = Arc::new(Mutex::new(None));
let http_sender = Arc::new(http_sender);
let (user_channel, _) = channel(1);
let rate_limiter = gen_rate_limiter(10);
WSClient {
addr: Arc::new(parking_lot::Mutex::new(None)),
config,
@ -82,6 +91,7 @@ impl WSClient {
collab_channels,
ping,
stop_tx: Mutex::new(None),
rate_limiter: Arc::new(tokio::sync::RwLock::new(rate_limiter)),
}
}
@ -145,6 +155,7 @@ impl WSClient {
*self.ping.lock().await = Some(ping);
let user_message_tx = self.user_channel.as_ref().clone();
let rate_limiter = self.rate_limiter.clone();
// Receive messages from the websocket, and send them to the channels.
platform_spawn(async move {
while let Some(Ok(ws_msg)) = stream.next().await {
@ -180,7 +191,14 @@ impl WSClient {
RealtimeMessage::User(user_message) => {
let _ = user_message_tx.send(user_message);
},
RealtimeMessage::ServerKickedOff => {},
RealtimeMessage::System(sys_message) => match sys_message {
SystemMessage::RateLimit(limit) => {
*rate_limiter.write().await = gen_rate_limiter(limit);
},
SystemMessage::KickOff => {
//
},
},
}
},
Err(err) => {
@ -210,12 +228,15 @@ impl WSClient {
let mut rx = self.sender.subscribe();
let weak_http_sender = Arc::downgrade(&self.http_sender);
let rate_limiter = self.rate_limiter.clone();
let device_id = device_id.to_string();
platform_spawn(async move {
loop {
tokio::select! {
_ = &mut stop_rx => break,
Ok(msg) = rx.recv() => {
rate_limiter.read().await.until_ready().fuse().await;
let len = msg.len();
// The maximum size allowed for a WebSocket message is 65,536 bytes. If the message exceeds
// 40,960 bytes (to avoid occupying the entire space), it should be sent over HTTP instead.
@ -307,3 +328,14 @@ impl WSClient {
self.state_notify.lock().set_state(state);
}
}
fn gen_rate_limiter(
mut times_per_sec: u32,
) -> RateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware> {
// make sure the rate limiter is not zero
if times_per_sec == 0 {
times_per_sec = 1;
}
let quota = Quota::per_second(NonZeroU32::new(times_per_sec).unwrap());
RateLimiter::direct(quota)
}

View File

@ -13,15 +13,15 @@ use websocket::Message;
pub enum RealtimeMessage {
Collab(CollabMessage),
User(UserMessage),
ServerKickedOff,
System(SystemMessage),
}
impl RealtimeMessage {
pub fn device_id(&self) -> Option<String> {
match self {
RealtimeMessage::Collab(msg) => msg.device_id(),
RealtimeMessage::ServerKickedOff => None,
RealtimeMessage::User(_) => None,
RealtimeMessage::System(_) => None,
}
}
}
@ -30,8 +30,8 @@ impl Display for RealtimeMessage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RealtimeMessage::Collab(msg) => f.write_fmt(format_args!("Collab:{}", msg.object_id())),
RealtimeMessage::ServerKickedOff => f.write_fmt(format_args!("ServerKickedOff")),
RealtimeMessage::User(_) => f.write_fmt(format_args!("User")),
RealtimeMessage::System(_) => f.write_fmt(format_args!("System")),
}
}
}
@ -111,3 +111,9 @@ impl From<RealtimeMessage> for Message {
Message::Binary(bytes)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SystemMessage {
RateLimit(u32),
KickOff,
}

View File

@ -222,7 +222,7 @@ where
match &msg {
RealtimeMessage::Collab(_) => ctx.binary(msg),
RealtimeMessage::User(_) => ctx.binary(msg),
RealtimeMessage::ServerKickedOff => ctx.stop(),
RealtimeMessage::System(_) => ctx.binary(msg),
}
}
}

View File

@ -28,6 +28,7 @@ use crate::collaborate::permission::CollabAccessControl;
use crate::collaborate::retry::{CollabUserMessage, SubscribeGroupIfNeed};
use crate::util::channel_ext::UnboundedSenderSink;
use database::collab::CollabStorage;
use realtime_entity::message::SystemMessage;
#[derive(Clone)]
pub struct CollabServer<S, U, AC> {
@ -509,7 +510,9 @@ impl CollabClientStream {
}
pub fn disconnect(&self) {
self.sink.do_send(RealtimeMessage::ServerKickedOff);
self
.sink
.do_send(RealtimeMessage::System(SystemMessage::KickOff));
}
}