feat: support ws

This commit is contained in:
nathan 2023-03-14 00:05:23 +08:00
parent 542bc83144
commit 5a7e223b99
11 changed files with 415 additions and 17 deletions

38
Cargo.lock generated
View File

@ -9,6 +9,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f728064aca1c318585bf4bb04ffcfac9e75e508ab4e8b1bd9ba5dfe04e2cbed5"
dependencies = [
"actix-rt",
"actix_derive",
"bitflags",
"bytes",
"crossbeam-channel",
@ -291,6 +292,17 @@ dependencies = [
"tokio",
]
[[package]]
name = "actix_derive"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d44b8fee1ced9671ba043476deddef739dd0959bf77030b26b738cc591737a7"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "adler"
version = "1.0.2"
@ -398,6 +410,7 @@ checksum = "224afbd727c3d6e4b90103ece64b8d1b67fbb1973b1046c2281eed3f3803f800"
name = "appflowy-server"
version = "0.1.0"
dependencies = [
"actix",
"actix-cors",
"actix-http",
"actix-identity",
@ -410,8 +423,11 @@ dependencies = [
"anyhow",
"argon2",
"async-stream",
"bincode",
"bytes",
"chrono",
"config",
"dashmap",
"derive_more",
"fancy-regex",
"futures-util",
@ -526,6 +542,15 @@ version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
[[package]]
name = "bincode"
version = "1.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad"
dependencies = [
"serde",
]
[[package]]
name = "bit-set"
version = "0.5.3"
@ -874,6 +899,19 @@ dependencies = [
"syn",
]
[[package]]
name = "dashmap"
version = "5.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc"
dependencies = [
"cfg-if",
"hashbrown",
"lock_api",
"once_cell",
"parking_lot_core 0.9.7",
]
[[package]]
name = "derive_more"
version = "0.99.17"

View File

@ -6,6 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
actix = "0.13"
actix-web = "4.3.1"
actix-http = "3.3.1"
actix-rt = "2"
@ -46,6 +47,9 @@ unicode-segmentation = "1.0"
lazy_static = "1.4.0"
fancy-regex = "0.11.0"
validator = "0.16.0"
bytes = "1.4.0"
bincode = "1.3.3"
dashmap = "5.4"
# tracing
tracing = { version = "0.1.37" }

View File

@ -1,5 +1,7 @@
mod token;
mod user;
mod ws;
pub use token::token_scope;
pub use user::user_scope;
pub use ws::ws_scope;

31
src/api/ws.rs Normal file
View File

@ -0,0 +1,31 @@
use crate::component::auth::LoggedUser;
use crate::component::ws::{MessageReceivers, WSClient, WSServer};
use actix::Addr;
use actix_web::web::{Data, Path, Payload};
use actix_web::{get, web, HttpRequest, HttpResponse, Result, Scope};
use actix_web_actors::ws;
pub fn ws_scope() -> Scope {
web::scope("/ws").service(establish_ws_connection)
}
#[get("/{token}")]
pub async fn establish_ws_connection(
request: HttpRequest,
payload: Payload,
token: Path<String>,
server: Data<Addr<WSServer>>,
msg_receivers: Data<MessageReceivers>,
) -> Result<HttpResponse> {
tracing::info!("establish_ws_connection");
let user = LoggedUser::from_token(token.clone())?;
let client = WSClient::new(user, server.get_ref().clone(), msg_receivers);
let result = ws::start(client, &request, payload);
match result {
Ok(response) => Ok(response),
Err(e) => {
tracing::error!("ws connection error: {:?}", e);
Err(e)
}
}
}

View File

@ -1,4 +1,4 @@
use crate::api::{token_scope, user_scope};
use crate::api::{token_scope, user_scope, ws_scope};
use crate::component::auth::HEADER_TOKEN;
use crate::config::config::{Config, DatabaseSetting};
use crate::middleware::cors::default_cors;
@ -67,6 +67,7 @@ pub async fn run(
.app_data(web::JsonConfig::default().limit(4096))
.service(user_scope())
.service(token_scope())
.service(ws_scope())
.app_data(Data::new(state.clone()))
})
.listen(listener)?

View File

@ -139,18 +139,13 @@ pub async fn change_password(
err: format!("{}", e),
})?;
// Save password to disk
sqlx::query!(
r#"
UPDATE users
SET password = $1
WHERE uid = $2
"#,
new_hash_password.expose_secret(),
uid
)
.execute(&mut transaction)
.await
.context("Failed to change user's password in the database.")?;
let sql = "update users set password = ? where uid = ?";
let _ = sqlx::query(sql)
.bind(new_hash_password.expose_secret())
.bind(uid)
.execute(&mut transaction)
.await
.context("Failed to change user's password in the database.")?;
transaction
.commit()
@ -228,7 +223,7 @@ pub struct LoggedUser {
uid: Secret<String>,
}
impl std::convert::From<Claim> for LoggedUser {
impl From<Claim> for LoggedUser {
fn from(c: Claim) -> Self {
Self {
uid: Secret::new(c.user_id()),
@ -269,7 +264,7 @@ impl FromRequest for LoggedUser {
}
}
impl std::convert::TryFrom<&HeaderValue> for LoggedUser {
impl TryFrom<&HeaderValue> for LoggedUser {
type Error = AuthError;
fn try_from(header: &HeaderValue) -> Result<Self, Self::Error> {
@ -360,13 +355,13 @@ impl Token {
}
}
impl std::convert::From<String> for Token {
impl From<String> for Token {
fn from(val: String) -> Self {
Self(val)
}
}
impl std::convert::Into<String> for Token {
impl Into<String> for Token {
fn into(self) -> String {
self.0
}

View File

@ -1,2 +1,3 @@
pub mod auth;
pub mod token_state;
pub mod ws;

168
src/component/ws/client.rs Normal file
View File

@ -0,0 +1,168 @@
use crate::component::auth::LoggedUser;
use crate::component::ws::entities::{
Connect, Disconnect, Socket, SocketMessagePayload, WebSocketMessage,
};
use crate::component::ws::server::WSServer;
use crate::component::ws::{HEARTBEAT_INTERVAL, PING_TIMEOUT};
use actix::*;
use actix_http::ws::Message::*;
use actix_web::web::Data;
use actix_web_actors::ws;
use bytes::Bytes;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
pub trait MessageReceiver: Send + Sync {
fn receive(&self, data: WSClientData);
}
#[derive(Default)]
pub struct MessageReceivers {
inner: HashMap<u8, Arc<dyn MessageReceiver>>,
}
impl MessageReceivers {
pub fn new() -> Self {
MessageReceivers::default()
}
pub fn set(&mut self, channel: u8, receiver: Arc<dyn MessageReceiver>) {
self.inner.insert(channel, receiver);
}
pub fn get(&self, source: u8) -> Option<Arc<dyn MessageReceiver>> {
self.inner.get(&source).cloned()
}
}
pub struct WSClientData {
pub(crate) user: Arc<LoggedUser>,
pub(crate) socket: Socket,
pub(crate) data: Bytes,
}
pub struct WSClient {
user: Arc<LoggedUser>,
server: Addr<WSServer>,
msg_receivers: Data<MessageReceivers>,
hb: Instant,
}
impl WSClient {
pub fn new(
user: LoggedUser,
server: Addr<WSServer>,
msg_receivers: Data<MessageReceivers>,
) -> Self {
Self {
user: Arc::new(user),
server,
msg_receivers,
hb: Instant::now(),
}
}
fn hb(&self, ctx: &mut ws::WebsocketContext<Self>) {
ctx.run_interval(HEARTBEAT_INTERVAL, |client, ctx| {
if Instant::now().duration_since(client.hb) > PING_TIMEOUT {
client.server.do_send(Disconnect {
user: client.user.clone(),
});
ctx.stop();
} else {
ctx.ping(b"");
}
});
}
fn handle_binary_message(&self, bytes: Bytes, socket: Socket) {
let payload = SocketMessagePayload::from_bytes(&bytes);
match self.msg_receivers.get(payload.channel) {
None => {
tracing::error!("Can't find the receiver for {:?}", payload.channel);
}
Some(handler) => {
let client_data = WSClientData {
user: self.user.clone(),
socket,
data: Bytes::from(payload.data),
};
handler.receive(client_data);
}
}
}
}
impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for WSClient {
fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
match msg {
Ok(Ping(msg)) => {
self.hb = Instant::now();
ctx.pong(&msg);
}
Ok(Pong(_msg)) => {
// tracing::debug!("Receive {} pong {:?}", &self.session_id, &msg);
self.hb = Instant::now();
}
Ok(Binary(bytes)) => {
let socket = ctx.address().recipient();
self.handle_binary_message(bytes, socket);
}
Ok(Text(_)) => {
tracing::warn!("Receive unexpected text message");
}
Ok(Close(reason)) => {
ctx.close(reason);
ctx.stop();
}
Ok(ws::Message::Continuation(_)) => {}
Ok(ws::Message::Nop) => {}
Err(e) => {
tracing::error!("WebSocketStream protocol error {:?}", e);
ctx.stop();
}
}
}
}
impl Handler<WebSocketMessage> for WSClient {
type Result = ();
fn handle(&mut self, msg: WebSocketMessage, ctx: &mut Self::Context) {
ctx.binary(msg.0);
}
}
impl Actor for WSClient {
type Context = ws::WebsocketContext<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
self.hb(ctx);
let socket = ctx.address().recipient();
let connect = Connect {
socket,
user: self.user.clone(),
};
self.server
.send(connect)
.into_actor(self)
.then(|res, _client, _ctx| {
match res {
Ok(Ok(_)) => tracing::trace!("Send connect message to server success"),
Ok(Err(e)) => tracing::error!("Send connect message to server failed: {:?}", e),
Err(e) => tracing::error!("Send connect message to server failed: {:?}", e),
}
fut::ready(())
})
.wait(ctx);
}
fn stopping(&mut self, _: &mut Self::Context) -> Running {
self.server.do_send(Disconnect {
user: self.user.clone(),
});
Running::Stop
}
}

View File

@ -0,0 +1,84 @@
use crate::component::auth::LoggedUser;
use actix::{Message, Recipient};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use std::fmt::Formatter;
use std::sync::Arc;
pub type Socket = Recipient<WebSocketMessage>;
#[derive(Serialize, Deserialize, Debug, Clone, Hash, PartialEq, Eq)]
pub struct WSSessionId(pub String);
impl<T: AsRef<str>> std::convert::From<T> for WSSessionId {
fn from(s: T) -> Self {
WSSessionId(s.as_ref().to_owned())
}
}
impl std::fmt::Display for WSSessionId {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let desc = &self.0.to_string();
f.write_str(desc)
}
}
pub struct Session {
pub user: Arc<LoggedUser>,
pub socket: Socket,
}
impl std::convert::From<Connect> for Session {
fn from(c: Connect) -> Self {
Self {
user: c.user,
socket: c.socket,
}
}
}
#[derive(Debug, Message, Clone)]
#[rtype(result = "Result<(), WSError>")]
pub struct Connect {
pub socket: Socket,
pub user: Arc<LoggedUser>,
}
#[derive(Debug, Message, Clone)]
#[rtype(result = "Result<(), WSError>")]
pub struct Disconnect {
pub user: Arc<LoggedUser>,
}
#[derive(Debug, Message, Clone)]
#[rtype(result = "()")]
pub struct WebSocketMessage(pub Bytes);
impl std::ops::Deref for WebSocketMessage {
type Target = Bytes;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SocketMessagePayload {
pub(crate) channel: u8,
pub(crate) data: Vec<u8>,
}
impl SocketMessagePayload {
pub fn to_bytes(&self) -> Vec<u8> {
bincode::serialize(self).unwrap()
}
pub fn from_bytes<T: AsRef<[u8]>>(bytes: T) -> Self {
bincode::deserialize(bytes.as_ref()).unwrap()
}
}
#[derive(Debug)]
pub enum WSError {
Internal,
}

12
src/component/ws/mod.rs Normal file
View File

@ -0,0 +1,12 @@
use std::time::Duration;
mod client;
mod entities;
mod server;
pub use client::*;
pub use server::WSServer;
pub(crate) const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(8);
pub(crate) const PING_TIMEOUT: Duration = Duration::from_secs(60);
pub(crate) const MAX_PAYLOAD_SIZE: usize = 262_144; // max payload size is 256k

View File

@ -0,0 +1,62 @@
use crate::component::ws::entities::{
Connect, Disconnect, Session, WSError, WSSessionId, WebSocketMessage,
};
use actix::{Actor, Context, Handler};
use dashmap::DashMap;
pub struct WSServer {
sessions: DashMap<WSSessionId, Session>,
}
impl std::default::Default for WSServer {
fn default() -> Self {
Self {
sessions: DashMap::new(),
}
}
}
impl WSServer {
pub fn new() -> Self {
WSServer::default()
}
pub fn send(&self, _msg: WebSocketMessage) {
unimplemented!()
}
}
impl Actor for WSServer {
type Context = Context<Self>;
fn started(&mut self, _ctx: &mut Self::Context) {}
}
impl Handler<Connect> for WSServer {
type Result = Result<(), WSError>;
fn handle(&mut self, msg: Connect, _ctx: &mut Context<Self>) -> Self::Result {
let session: Session = msg.into();
self.sessions.insert(session.id.clone(), session);
Ok(())
}
}
impl Handler<Disconnect> for WSServer {
type Result = Result<(), WSError>;
fn handle(&mut self, msg: Disconnect, _: &mut Context<Self>) -> Self::Result {
self.sessions.remove(&msg.sid);
Ok(())
}
}
impl Handler<WebSocketMessage> for WSServer {
type Result = ();
fn handle(&mut self, _msg: WebSocketMessage, _ctx: &mut Context<Self>) -> Self::Result {
unimplemented!()
}
}
impl actix::Supervised for WSServer {
fn restarting(&mut self, _ctx: &mut Context<WSServer>) {
tracing::warn!("restarting");
}
}