From 5a7e223b99e94c42945f4d549834b6f0bc17368b Mon Sep 17 00:00:00 2001 From: nathan Date: Tue, 14 Mar 2023 00:05:23 +0800 Subject: [PATCH] feat: support ws --- Cargo.lock | 38 ++++++++ Cargo.toml | 4 + src/api/mod.rs | 2 + src/api/ws.rs | 31 +++++++ src/application.rs | 3 +- src/component/auth/user.rs | 27 +++--- src/component/mod.rs | 1 + src/component/ws/client.rs | 168 +++++++++++++++++++++++++++++++++++ src/component/ws/entities.rs | 84 ++++++++++++++++++ src/component/ws/mod.rs | 12 +++ src/component/ws/server.rs | 62 +++++++++++++ 11 files changed, 415 insertions(+), 17 deletions(-) create mode 100644 src/api/ws.rs create mode 100644 src/component/ws/client.rs create mode 100644 src/component/ws/entities.rs create mode 100644 src/component/ws/mod.rs create mode 100644 src/component/ws/server.rs diff --git a/Cargo.lock b/Cargo.lock index cd257a37..f0b696a2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9,6 +9,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f728064aca1c318585bf4bb04ffcfac9e75e508ab4e8b1bd9ba5dfe04e2cbed5" dependencies = [ "actix-rt", + "actix_derive", "bitflags", "bytes", "crossbeam-channel", @@ -291,6 +292,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "actix_derive" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d44b8fee1ced9671ba043476deddef739dd0959bf77030b26b738cc591737a7" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "adler" version = "1.0.2" @@ -398,6 +410,7 @@ checksum = "224afbd727c3d6e4b90103ece64b8d1b67fbb1973b1046c2281eed3f3803f800" name = "appflowy-server" version = "0.1.0" dependencies = [ + "actix", "actix-cors", "actix-http", "actix-identity", @@ -410,8 +423,11 @@ dependencies = [ "anyhow", "argon2", "async-stream", + "bincode", + "bytes", "chrono", "config", + "dashmap", "derive_more", "fancy-regex", "futures-util", @@ -526,6 +542,15 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + [[package]] name = "bit-set" version = "0.5.3" @@ -874,6 +899,19 @@ dependencies = [ "syn", ] +[[package]] +name = "dashmap" +version = "5.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc" +dependencies = [ + "cfg-if", + "hashbrown", + "lock_api", + "once_cell", + "parking_lot_core 0.9.7", +] + [[package]] name = "derive_more" version = "0.99.17" diff --git a/Cargo.toml b/Cargo.toml index 8e90c660..00d13884 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +actix = "0.13" actix-web = "4.3.1" actix-http = "3.3.1" actix-rt = "2" @@ -46,6 +47,9 @@ unicode-segmentation = "1.0" lazy_static = "1.4.0" fancy-regex = "0.11.0" validator = "0.16.0" +bytes = "1.4.0" +bincode = "1.3.3" +dashmap = "5.4" # tracing tracing = { version = "0.1.37" } diff --git a/src/api/mod.rs b/src/api/mod.rs index 6ad6bb75..a7ea590f 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,5 +1,7 @@ mod token; mod user; +mod ws; pub use token::token_scope; pub use user::user_scope; +pub use ws::ws_scope; diff --git a/src/api/ws.rs b/src/api/ws.rs new file mode 100644 index 00000000..da9bc75b --- /dev/null +++ b/src/api/ws.rs @@ -0,0 +1,31 @@ +use crate::component::auth::LoggedUser; +use crate::component::ws::{MessageReceivers, WSClient, WSServer}; +use actix::Addr; +use actix_web::web::{Data, Path, Payload}; +use actix_web::{get, web, HttpRequest, HttpResponse, Result, Scope}; +use actix_web_actors::ws; + +pub fn ws_scope() -> Scope { + web::scope("/ws").service(establish_ws_connection) +} + +#[get("/{token}")] +pub async fn establish_ws_connection( + request: HttpRequest, + payload: Payload, + token: Path, + server: Data>, + msg_receivers: Data, +) -> Result { + tracing::info!("establish_ws_connection"); + let user = LoggedUser::from_token(token.clone())?; + let client = WSClient::new(user, server.get_ref().clone(), msg_receivers); + let result = ws::start(client, &request, payload); + match result { + Ok(response) => Ok(response), + Err(e) => { + tracing::error!("ws connection error: {:?}", e); + Err(e) + } + } +} diff --git a/src/application.rs b/src/application.rs index 421bf146..b0fc2906 100644 --- a/src/application.rs +++ b/src/application.rs @@ -1,4 +1,4 @@ -use crate::api::{token_scope, user_scope}; +use crate::api::{token_scope, user_scope, ws_scope}; use crate::component::auth::HEADER_TOKEN; use crate::config::config::{Config, DatabaseSetting}; use crate::middleware::cors::default_cors; @@ -67,6 +67,7 @@ pub async fn run( .app_data(web::JsonConfig::default().limit(4096)) .service(user_scope()) .service(token_scope()) + .service(ws_scope()) .app_data(Data::new(state.clone())) }) .listen(listener)? diff --git a/src/component/auth/user.rs b/src/component/auth/user.rs index a11af3f2..c6f4ab29 100644 --- a/src/component/auth/user.rs +++ b/src/component/auth/user.rs @@ -139,18 +139,13 @@ pub async fn change_password( err: format!("{}", e), })?; // Save password to disk - sqlx::query!( - r#" - UPDATE users - SET password = $1 - WHERE uid = $2 - "#, - new_hash_password.expose_secret(), - uid - ) - .execute(&mut transaction) - .await - .context("Failed to change user's password in the database.")?; + let sql = "update users set password = ? where uid = ?"; + let _ = sqlx::query(sql) + .bind(new_hash_password.expose_secret()) + .bind(uid) + .execute(&mut transaction) + .await + .context("Failed to change user's password in the database.")?; transaction .commit() @@ -228,7 +223,7 @@ pub struct LoggedUser { uid: Secret, } -impl std::convert::From for LoggedUser { +impl From for LoggedUser { fn from(c: Claim) -> Self { Self { uid: Secret::new(c.user_id()), @@ -269,7 +264,7 @@ impl FromRequest for LoggedUser { } } -impl std::convert::TryFrom<&HeaderValue> for LoggedUser { +impl TryFrom<&HeaderValue> for LoggedUser { type Error = AuthError; fn try_from(header: &HeaderValue) -> Result { @@ -360,13 +355,13 @@ impl Token { } } -impl std::convert::From for Token { +impl From for Token { fn from(val: String) -> Self { Self(val) } } -impl std::convert::Into for Token { +impl Into for Token { fn into(self) -> String { self.0 } diff --git a/src/component/mod.rs b/src/component/mod.rs index 79bb8910..adfbdbb9 100644 --- a/src/component/mod.rs +++ b/src/component/mod.rs @@ -1,2 +1,3 @@ pub mod auth; pub mod token_state; +pub mod ws; diff --git a/src/component/ws/client.rs b/src/component/ws/client.rs new file mode 100644 index 00000000..6687717c --- /dev/null +++ b/src/component/ws/client.rs @@ -0,0 +1,168 @@ +use crate::component::auth::LoggedUser; +use crate::component::ws::entities::{ + Connect, Disconnect, Socket, SocketMessagePayload, WebSocketMessage, +}; +use crate::component::ws::server::WSServer; +use crate::component::ws::{HEARTBEAT_INTERVAL, PING_TIMEOUT}; +use actix::*; +use actix_http::ws::Message::*; +use actix_web::web::Data; +use actix_web_actors::ws; +use bytes::Bytes; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Instant; + +pub trait MessageReceiver: Send + Sync { + fn receive(&self, data: WSClientData); +} + +#[derive(Default)] +pub struct MessageReceivers { + inner: HashMap>, +} + +impl MessageReceivers { + pub fn new() -> Self { + MessageReceivers::default() + } + + pub fn set(&mut self, channel: u8, receiver: Arc) { + self.inner.insert(channel, receiver); + } + + pub fn get(&self, source: u8) -> Option> { + self.inner.get(&source).cloned() + } +} + +pub struct WSClientData { + pub(crate) user: Arc, + pub(crate) socket: Socket, + pub(crate) data: Bytes, +} + +pub struct WSClient { + user: Arc, + server: Addr, + msg_receivers: Data, + hb: Instant, +} + +impl WSClient { + pub fn new( + user: LoggedUser, + server: Addr, + msg_receivers: Data, + ) -> Self { + Self { + user: Arc::new(user), + server, + msg_receivers, + hb: Instant::now(), + } + } + + fn hb(&self, ctx: &mut ws::WebsocketContext) { + ctx.run_interval(HEARTBEAT_INTERVAL, |client, ctx| { + if Instant::now().duration_since(client.hb) > PING_TIMEOUT { + client.server.do_send(Disconnect { + user: client.user.clone(), + }); + ctx.stop(); + } else { + ctx.ping(b""); + } + }); + } + + fn handle_binary_message(&self, bytes: Bytes, socket: Socket) { + let payload = SocketMessagePayload::from_bytes(&bytes); + match self.msg_receivers.get(payload.channel) { + None => { + tracing::error!("Can't find the receiver for {:?}", payload.channel); + } + Some(handler) => { + let client_data = WSClientData { + user: self.user.clone(), + socket, + data: Bytes::from(payload.data), + }; + handler.receive(client_data); + } + } + } +} + +impl StreamHandler> for WSClient { + fn handle(&mut self, msg: Result, ctx: &mut Self::Context) { + match msg { + Ok(Ping(msg)) => { + self.hb = Instant::now(); + ctx.pong(&msg); + } + Ok(Pong(_msg)) => { + // tracing::debug!("Receive {} pong {:?}", &self.session_id, &msg); + self.hb = Instant::now(); + } + Ok(Binary(bytes)) => { + let socket = ctx.address().recipient(); + self.handle_binary_message(bytes, socket); + } + Ok(Text(_)) => { + tracing::warn!("Receive unexpected text message"); + } + Ok(Close(reason)) => { + ctx.close(reason); + ctx.stop(); + } + Ok(ws::Message::Continuation(_)) => {} + Ok(ws::Message::Nop) => {} + Err(e) => { + tracing::error!("WebSocketStream protocol error {:?}", e); + ctx.stop(); + } + } + } +} + +impl Handler for WSClient { + type Result = (); + + fn handle(&mut self, msg: WebSocketMessage, ctx: &mut Self::Context) { + ctx.binary(msg.0); + } +} + +impl Actor for WSClient { + type Context = ws::WebsocketContext; + + fn started(&mut self, ctx: &mut Self::Context) { + self.hb(ctx); + let socket = ctx.address().recipient(); + let connect = Connect { + socket, + user: self.user.clone(), + }; + self.server + .send(connect) + .into_actor(self) + .then(|res, _client, _ctx| { + match res { + Ok(Ok(_)) => tracing::trace!("Send connect message to server success"), + Ok(Err(e)) => tracing::error!("Send connect message to server failed: {:?}", e), + Err(e) => tracing::error!("Send connect message to server failed: {:?}", e), + } + fut::ready(()) + }) + .wait(ctx); + } + + fn stopping(&mut self, _: &mut Self::Context) -> Running { + self.server.do_send(Disconnect { + user: self.user.clone(), + }); + + Running::Stop + } +} diff --git a/src/component/ws/entities.rs b/src/component/ws/entities.rs new file mode 100644 index 00000000..ea4dd37e --- /dev/null +++ b/src/component/ws/entities.rs @@ -0,0 +1,84 @@ +use crate::component::auth::LoggedUser; +use actix::{Message, Recipient}; +use bytes::Bytes; +use serde::{Deserialize, Serialize}; +use std::fmt::Formatter; +use std::sync::Arc; + +pub type Socket = Recipient; + +#[derive(Serialize, Deserialize, Debug, Clone, Hash, PartialEq, Eq)] +pub struct WSSessionId(pub String); + +impl> std::convert::From for WSSessionId { + fn from(s: T) -> Self { + WSSessionId(s.as_ref().to_owned()) + } +} + +impl std::fmt::Display for WSSessionId { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let desc = &self.0.to_string(); + f.write_str(desc) + } +} + +pub struct Session { + pub user: Arc, + pub socket: Socket, +} + +impl std::convert::From for Session { + fn from(c: Connect) -> Self { + Self { + user: c.user, + socket: c.socket, + } + } +} + +#[derive(Debug, Message, Clone)] +#[rtype(result = "Result<(), WSError>")] +pub struct Connect { + pub socket: Socket, + pub user: Arc, +} + +#[derive(Debug, Message, Clone)] +#[rtype(result = "Result<(), WSError>")] +pub struct Disconnect { + pub user: Arc, +} + +#[derive(Debug, Message, Clone)] +#[rtype(result = "()")] +pub struct WebSocketMessage(pub Bytes); + +impl std::ops::Deref for WebSocketMessage { + type Target = Bytes; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SocketMessagePayload { + pub(crate) channel: u8, + pub(crate) data: Vec, +} + +impl SocketMessagePayload { + pub fn to_bytes(&self) -> Vec { + bincode::serialize(self).unwrap() + } + + pub fn from_bytes>(bytes: T) -> Self { + bincode::deserialize(bytes.as_ref()).unwrap() + } +} + +#[derive(Debug)] +pub enum WSError { + Internal, +} diff --git a/src/component/ws/mod.rs b/src/component/ws/mod.rs new file mode 100644 index 00000000..f77b68e1 --- /dev/null +++ b/src/component/ws/mod.rs @@ -0,0 +1,12 @@ +use std::time::Duration; + +mod client; +mod entities; +mod server; + +pub use client::*; +pub use server::WSServer; + +pub(crate) const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(8); +pub(crate) const PING_TIMEOUT: Duration = Duration::from_secs(60); +pub(crate) const MAX_PAYLOAD_SIZE: usize = 262_144; // max payload size is 256k diff --git a/src/component/ws/server.rs b/src/component/ws/server.rs new file mode 100644 index 00000000..fb714747 --- /dev/null +++ b/src/component/ws/server.rs @@ -0,0 +1,62 @@ +use crate::component::ws::entities::{ + Connect, Disconnect, Session, WSError, WSSessionId, WebSocketMessage, +}; +use actix::{Actor, Context, Handler}; +use dashmap::DashMap; + +pub struct WSServer { + sessions: DashMap, +} + +impl std::default::Default for WSServer { + fn default() -> Self { + Self { + sessions: DashMap::new(), + } + } +} +impl WSServer { + pub fn new() -> Self { + WSServer::default() + } + + pub fn send(&self, _msg: WebSocketMessage) { + unimplemented!() + } +} + +impl Actor for WSServer { + type Context = Context; + fn started(&mut self, _ctx: &mut Self::Context) {} +} + +impl Handler for WSServer { + type Result = Result<(), WSError>; + fn handle(&mut self, msg: Connect, _ctx: &mut Context) -> Self::Result { + let session: Session = msg.into(); + self.sessions.insert(session.id.clone(), session); + Ok(()) + } +} + +impl Handler for WSServer { + type Result = Result<(), WSError>; + fn handle(&mut self, msg: Disconnect, _: &mut Context) -> Self::Result { + self.sessions.remove(&msg.sid); + Ok(()) + } +} + +impl Handler for WSServer { + type Result = (); + + fn handle(&mut self, _msg: WebSocketMessage, _ctx: &mut Context) -> Self::Result { + unimplemented!() + } +} + +impl actix::Supervised for WSServer { + fn restarting(&mut self, _ctx: &mut Context) { + tracing::warn!("restarting"); + } +}