feat: ws message

This commit is contained in:
nathan 2023-03-14 09:34:00 +08:00
parent db498ef5e8
commit d834637d2f
8 changed files with 42 additions and 36 deletions

View File

@ -30,7 +30,7 @@ async fn login_handler(
let password = UserPassword::parse(req.password)
.map_err(|_| InputParamsError::InvalidPassword)?
.0;
let (resp, token) = login(state.pg_pool.clone(), state.cache.clone(), email, password).await?;
let (resp, token) = login(state.pg_pool.clone(), state.user.clone(), email, password).await?;
// Renews the session key, assigning existing session state to new key.
session.renew();
@ -43,7 +43,7 @@ async fn login_handler(
}
async fn logout_handler(logged_user: LoggedUser, state: Data<State>) -> Result<HttpResponse> {
logout(logged_user, state.cache.clone()).await;
logout(logged_user, state.user.clone()).await;
Ok(HttpResponse::Ok().finish())
}
@ -62,7 +62,7 @@ async fn register_handler(req: Json<RegisterRequest>, state: Data<State>) -> Res
let resp = register(
state.pg_pool.clone(),
state.cache.clone(),
state.user.clone(),
name,
email,
password,

View File

@ -1,5 +1,6 @@
use crate::component::auth::LoggedUser;
use crate::component::ws::{MessageReceivers, WSClient, WSServer};
use crate::state::State;
use actix::Addr;
use actix_web::web::{Data, Path, Payload};
use actix_web::{get, web, HttpRequest, HttpResponse, Result, Scope};
@ -14,14 +15,14 @@ pub async fn establish_ws_connection(
request: HttpRequest,
payload: Payload,
token: Path<String>,
_state: Data<State>,
server: Data<Addr<WSServer>>,
msg_receivers: Data<MessageReceivers>,
) -> Result<HttpResponse> {
tracing::info!("establish_ws_connection");
let user = LoggedUser::from_token(token.clone())?;
let client = WSClient::new(user, server.get_ref().clone(), msg_receivers);
let result = ws::start(client, &request, payload);
match result {
match ws::start(client, &request, payload) {
Ok(response) => Ok(response),
Err(e) => {
tracing::error!("ws connection error: {:?}", e);

View File

@ -87,7 +87,7 @@ pub async fn init_state(configuration: &Config) -> State {
State {
pg_pool,
cache: Arc::new(Default::default()),
user: Arc::new(Default::default()),
}
}

View File

@ -2,7 +2,7 @@ use crate::component::auth::{
compute_hash_password, internal_error, validate_credentials, AuthError, Credentials,
};
use crate::config::env::{domain, jwt_secret};
use crate::state::Cache;
use crate::state::UserCache;
use crate::telemetry::spawn_blocking_with_tracing;
use actix_web::http::header::HeaderValue;
use actix_web::{FromRequest, HttpRequest};
@ -20,7 +20,7 @@ use tokio::sync::RwLock;
pub async fn login(
pg_pool: PgPool,
cache: Arc<RwLock<Cache>>,
cache: Arc<RwLock<UserCache>>,
email: String,
password: String,
) -> Result<(LoginResponse, Secret<Token>), AuthError> {
@ -50,13 +50,13 @@ pub async fn login(
}
}
pub async fn logout(logged_user: LoggedUser, cache: Arc<RwLock<Cache>>) {
pub async fn logout(logged_user: LoggedUser, cache: Arc<RwLock<UserCache>>) {
cache.write().await.unauthorized(logged_user);
}
pub async fn register(
pg_pool: PgPool,
cache: Arc<RwLock<Cache>>,
cache: Arc<RwLock<UserCache>>,
username: String,
email: String,
password: String,

View File

@ -1,6 +1,6 @@
use crate::component::auth::LoggedUser;
use crate::component::ws::entities::{
Connect, Disconnect, Socket, SocketMessagePayload, WebSocketMessage,
Connect, Disconnect, MessageDetail, MessagePayload, Socket, WebSocketMessage,
};
use crate::component::ws::server::WSServer;
use crate::component::ws::{HEARTBEAT_INTERVAL, PING_TIMEOUT};
@ -36,10 +36,10 @@ impl MessageReceivers {
}
}
#[allow(dead_code)]
pub struct WSClientData {
pub(crate) user: Arc<LoggedUser>,
pub(crate) socket: Socket,
pub(crate) data: Bytes,
pub(crate) detail: MessageDetail,
}
pub struct WSClient {
@ -77,17 +77,13 @@ impl WSClient {
}
fn handle_binary_message(&self, bytes: Bytes, socket: Socket) {
let payload = SocketMessagePayload::from_bytes(&bytes);
match self.msg_receivers.get(payload.channel) {
let MessagePayload { channel, detail } = MessagePayload::from_bytes(&bytes);
match self.msg_receivers.get(channel) {
None => {
tracing::error!("Can't find the receiver for {:?}", payload.channel);
tracing::error!("Can't find the receiver for {:?}", channel);
}
Some(handler) => {
let client_data = WSClientData {
user: self.user.clone(),
socket,
data: Bytes::from(payload.data),
};
let client_data = WSClientData { socket, detail };
handler.receive(client_data);
}
}
@ -102,7 +98,7 @@ impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for WSClient {
ctx.pong(&msg);
}
Ok(Pong(_msg)) => {
tracing::trace!("Receive {} pong {:?}", &self.session_id, &msg);
// tracing::trace!("Receive {} pong {:?}", &self.session_id, &msg);
self.hb = Instant::now();
}
Ok(Binary(bytes)) => {

View File

@ -63,17 +63,29 @@ impl std::ops::Deref for WebSocketMessage {
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SocketMessagePayload {
pub struct MessagePayload {
pub(crate) channel: u8,
pub(crate) data: Vec<u8>,
pub(crate) detail: MessageDetail,
}
impl SocketMessagePayload {
impl MessagePayload {
pub fn from_bytes<T: AsRef<[u8]>>(bytes: T) -> Self {
bincode::deserialize(bytes.as_ref()).unwrap()
serde_json::from_slice(bytes.as_ref()).unwrap()
}
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum MessageDetail {
Document(MessageContent),
Database(MessageContent),
}
#[derive(Debug, Serialize, Deserialize)]
pub struct MessageContent {
content: String,
}
#[derive(Debug)]
pub enum WSError {
Internal,

View File

@ -13,9 +13,7 @@ impl WSServer {
WSServer::default()
}
pub fn send(&self, _msg: WebSocketMessage) {
unimplemented!()
}
pub fn send(&self, _msg: WebSocketMessage) {}
}
impl Actor for WSServer {
@ -40,9 +38,7 @@ impl Handler<Disconnect> for WSServer {
impl Handler<WebSocketMessage> for WSServer {
type Result = ();
fn handle(&mut self, _msg: WebSocketMessage, _ctx: &mut Context<Self>) -> Self::Result {
unimplemented!()
}
fn handle(&mut self, _msg: WebSocketMessage, _ctx: &mut Context<Self>) -> Self::Result {}
}
impl actix::Supervised for WSServer {

View File

@ -9,7 +9,7 @@ use tokio::sync::RwLock;
#[derive(Clone)]
pub struct State {
pub pg_pool: PgPool,
pub cache: Arc<RwLock<Cache>>,
pub user: Arc<RwLock<UserCache>>,
}
impl State {
@ -27,13 +27,14 @@ enum AuthStatus {
pub const EXPIRED_DURATION_DAYS: i64 = 30;
#[derive(Debug, Default)]
pub struct Cache {
pub struct UserCache {
// Keep track the user authentication state
user: BTreeMap<String, AuthStatus>,
}
impl Cache {
impl UserCache {
pub fn new() -> Self {
Cache::default()
UserCache::default()
}
pub fn is_authorized(&self, user: &LoggedUser) -> bool {