From d834637d2fcb81619434269e9a9a181f8975b952 Mon Sep 17 00:00:00 2001 From: nathan Date: Tue, 14 Mar 2023 09:34:00 +0800 Subject: [PATCH] feat: ws message --- src/api/user.rs | 6 +++--- src/api/ws.rs | 5 +++-- src/application.rs | 2 +- src/component/auth/user.rs | 8 ++++---- src/component/ws/client.rs | 20 ++++++++------------ src/component/ws/entities.rs | 20 ++++++++++++++++---- src/component/ws/server.rs | 8 ++------ src/state.rs | 9 +++++---- 8 files changed, 42 insertions(+), 36 deletions(-) diff --git a/src/api/user.rs b/src/api/user.rs index 7502bd3d..b0e65a48 100644 --- a/src/api/user.rs +++ b/src/api/user.rs @@ -30,7 +30,7 @@ async fn login_handler( let password = UserPassword::parse(req.password) .map_err(|_| InputParamsError::InvalidPassword)? .0; - let (resp, token) = login(state.pg_pool.clone(), state.cache.clone(), email, password).await?; + let (resp, token) = login(state.pg_pool.clone(), state.user.clone(), email, password).await?; // Renews the session key, assigning existing session state to new key. session.renew(); @@ -43,7 +43,7 @@ async fn login_handler( } async fn logout_handler(logged_user: LoggedUser, state: Data) -> Result { - logout(logged_user, state.cache.clone()).await; + logout(logged_user, state.user.clone()).await; Ok(HttpResponse::Ok().finish()) } @@ -62,7 +62,7 @@ async fn register_handler(req: Json, state: Data) -> Res let resp = register( state.pg_pool.clone(), - state.cache.clone(), + state.user.clone(), name, email, password, diff --git a/src/api/ws.rs b/src/api/ws.rs index da9bc75b..93d3acf6 100644 --- a/src/api/ws.rs +++ b/src/api/ws.rs @@ -1,5 +1,6 @@ use crate::component::auth::LoggedUser; use crate::component::ws::{MessageReceivers, WSClient, WSServer}; +use crate::state::State; use actix::Addr; use actix_web::web::{Data, Path, Payload}; use actix_web::{get, web, HttpRequest, HttpResponse, Result, Scope}; @@ -14,14 +15,14 @@ pub async fn establish_ws_connection( request: HttpRequest, payload: Payload, token: Path, + _state: Data, server: Data>, msg_receivers: Data, ) -> Result { tracing::info!("establish_ws_connection"); let user = LoggedUser::from_token(token.clone())?; let client = WSClient::new(user, server.get_ref().clone(), msg_receivers); - let result = ws::start(client, &request, payload); - match result { + match ws::start(client, &request, payload) { Ok(response) => Ok(response), Err(e) => { tracing::error!("ws connection error: {:?}", e); diff --git a/src/application.rs b/src/application.rs index b0fc2906..731c1799 100644 --- a/src/application.rs +++ b/src/application.rs @@ -87,7 +87,7 @@ pub async fn init_state(configuration: &Config) -> State { State { pg_pool, - cache: Arc::new(Default::default()), + user: Arc::new(Default::default()), } } diff --git a/src/component/auth/user.rs b/src/component/auth/user.rs index c6f4ab29..771c3038 100644 --- a/src/component/auth/user.rs +++ b/src/component/auth/user.rs @@ -2,7 +2,7 @@ use crate::component::auth::{ compute_hash_password, internal_error, validate_credentials, AuthError, Credentials, }; use crate::config::env::{domain, jwt_secret}; -use crate::state::Cache; +use crate::state::UserCache; use crate::telemetry::spawn_blocking_with_tracing; use actix_web::http::header::HeaderValue; use actix_web::{FromRequest, HttpRequest}; @@ -20,7 +20,7 @@ use tokio::sync::RwLock; pub async fn login( pg_pool: PgPool, - cache: Arc>, + cache: Arc>, email: String, password: String, ) -> Result<(LoginResponse, Secret), AuthError> { @@ -50,13 +50,13 @@ pub async fn login( } } -pub async fn logout(logged_user: LoggedUser, cache: Arc>) { +pub async fn logout(logged_user: LoggedUser, cache: Arc>) { cache.write().await.unauthorized(logged_user); } pub async fn register( pg_pool: PgPool, - cache: Arc>, + cache: Arc>, username: String, email: String, password: String, diff --git a/src/component/ws/client.rs b/src/component/ws/client.rs index 1f0109c6..cb4092d3 100644 --- a/src/component/ws/client.rs +++ b/src/component/ws/client.rs @@ -1,6 +1,6 @@ use crate::component::auth::LoggedUser; use crate::component::ws::entities::{ - Connect, Disconnect, Socket, SocketMessagePayload, WebSocketMessage, + Connect, Disconnect, MessageDetail, MessagePayload, Socket, WebSocketMessage, }; use crate::component::ws::server::WSServer; use crate::component::ws::{HEARTBEAT_INTERVAL, PING_TIMEOUT}; @@ -36,10 +36,10 @@ impl MessageReceivers { } } +#[allow(dead_code)] pub struct WSClientData { - pub(crate) user: Arc, pub(crate) socket: Socket, - pub(crate) data: Bytes, + pub(crate) detail: MessageDetail, } pub struct WSClient { @@ -77,17 +77,13 @@ impl WSClient { } fn handle_binary_message(&self, bytes: Bytes, socket: Socket) { - let payload = SocketMessagePayload::from_bytes(&bytes); - match self.msg_receivers.get(payload.channel) { + let MessagePayload { channel, detail } = MessagePayload::from_bytes(&bytes); + match self.msg_receivers.get(channel) { None => { - tracing::error!("Can't find the receiver for {:?}", payload.channel); + tracing::error!("Can't find the receiver for {:?}", channel); } Some(handler) => { - let client_data = WSClientData { - user: self.user.clone(), - socket, - data: Bytes::from(payload.data), - }; + let client_data = WSClientData { socket, detail }; handler.receive(client_data); } } @@ -102,7 +98,7 @@ impl StreamHandler> for WSClient { ctx.pong(&msg); } Ok(Pong(_msg)) => { - tracing::trace!("Receive {} pong {:?}", &self.session_id, &msg); + // tracing::trace!("Receive {} pong {:?}", &self.session_id, &msg); self.hb = Instant::now(); } Ok(Binary(bytes)) => { diff --git a/src/component/ws/entities.rs b/src/component/ws/entities.rs index d31c302e..6915d672 100644 --- a/src/component/ws/entities.rs +++ b/src/component/ws/entities.rs @@ -63,17 +63,29 @@ impl std::ops::Deref for WebSocketMessage { } #[derive(Debug, Serialize, Deserialize)] -pub struct SocketMessagePayload { +pub struct MessagePayload { pub(crate) channel: u8, - pub(crate) data: Vec, + pub(crate) detail: MessageDetail, } -impl SocketMessagePayload { +impl MessagePayload { pub fn from_bytes>(bytes: T) -> Self { - bincode::deserialize(bytes.as_ref()).unwrap() + serde_json::from_slice(bytes.as_ref()).unwrap() } } +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum MessageDetail { + Document(MessageContent), + Database(MessageContent), +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct MessageContent { + content: String, +} + #[derive(Debug)] pub enum WSError { Internal, diff --git a/src/component/ws/server.rs b/src/component/ws/server.rs index 5cdb946a..46078298 100644 --- a/src/component/ws/server.rs +++ b/src/component/ws/server.rs @@ -13,9 +13,7 @@ impl WSServer { WSServer::default() } - pub fn send(&self, _msg: WebSocketMessage) { - unimplemented!() - } + pub fn send(&self, _msg: WebSocketMessage) {} } impl Actor for WSServer { @@ -40,9 +38,7 @@ impl Handler for WSServer { impl Handler for WSServer { type Result = (); - fn handle(&mut self, _msg: WebSocketMessage, _ctx: &mut Context) -> Self::Result { - unimplemented!() - } + fn handle(&mut self, _msg: WebSocketMessage, _ctx: &mut Context) -> Self::Result {} } impl actix::Supervised for WSServer { diff --git a/src/state.rs b/src/state.rs index 1ca93c91..8b57ee85 100644 --- a/src/state.rs +++ b/src/state.rs @@ -9,7 +9,7 @@ use tokio::sync::RwLock; #[derive(Clone)] pub struct State { pub pg_pool: PgPool, - pub cache: Arc>, + pub user: Arc>, } impl State { @@ -27,13 +27,14 @@ enum AuthStatus { pub const EXPIRED_DURATION_DAYS: i64 = 30; #[derive(Debug, Default)] -pub struct Cache { +pub struct UserCache { + // Keep track the user authentication state user: BTreeMap, } -impl Cache { +impl UserCache { pub fn new() -> Self { - Cache::default() + UserCache::default() } pub fn is_authorized(&self, user: &LoggedUser) -> bool {