feat: support ws
This commit is contained in:
parent
542bc83144
commit
5a7e223b99
|
|
@ -9,6 +9,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "f728064aca1c318585bf4bb04ffcfac9e75e508ab4e8b1bd9ba5dfe04e2cbed5"
|
||||
dependencies = [
|
||||
"actix-rt",
|
||||
"actix_derive",
|
||||
"bitflags",
|
||||
"bytes",
|
||||
"crossbeam-channel",
|
||||
|
|
@ -291,6 +292,17 @@ dependencies = [
|
|||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "actix_derive"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6d44b8fee1ced9671ba043476deddef739dd0959bf77030b26b738cc591737a7"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "adler"
|
||||
version = "1.0.2"
|
||||
|
|
@ -398,6 +410,7 @@ checksum = "224afbd727c3d6e4b90103ece64b8d1b67fbb1973b1046c2281eed3f3803f800"
|
|||
name = "appflowy-server"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"actix",
|
||||
"actix-cors",
|
||||
"actix-http",
|
||||
"actix-identity",
|
||||
|
|
@ -410,8 +423,11 @@ dependencies = [
|
|||
"anyhow",
|
||||
"argon2",
|
||||
"async-stream",
|
||||
"bincode",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"config",
|
||||
"dashmap",
|
||||
"derive_more",
|
||||
"fancy-regex",
|
||||
"futures-util",
|
||||
|
|
@ -526,6 +542,15 @@ version = "1.6.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
|
||||
|
||||
[[package]]
|
||||
name = "bincode"
|
||||
version = "1.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bit-set"
|
||||
version = "0.5.3"
|
||||
|
|
@ -874,6 +899,19 @@ dependencies = [
|
|||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dashmap"
|
||||
version = "5.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"hashbrown",
|
||||
"lock_api",
|
||||
"once_cell",
|
||||
"parking_lot_core 0.9.7",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_more"
|
||||
version = "0.99.17"
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ edition = "2021"
|
|||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
actix = "0.13"
|
||||
actix-web = "4.3.1"
|
||||
actix-http = "3.3.1"
|
||||
actix-rt = "2"
|
||||
|
|
@ -46,6 +47,9 @@ unicode-segmentation = "1.0"
|
|||
lazy_static = "1.4.0"
|
||||
fancy-regex = "0.11.0"
|
||||
validator = "0.16.0"
|
||||
bytes = "1.4.0"
|
||||
bincode = "1.3.3"
|
||||
dashmap = "5.4"
|
||||
|
||||
# tracing
|
||||
tracing = { version = "0.1.37" }
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
mod token;
|
||||
mod user;
|
||||
mod ws;
|
||||
|
||||
pub use token::token_scope;
|
||||
pub use user::user_scope;
|
||||
pub use ws::ws_scope;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,31 @@
|
|||
use crate::component::auth::LoggedUser;
|
||||
use crate::component::ws::{MessageReceivers, WSClient, WSServer};
|
||||
use actix::Addr;
|
||||
use actix_web::web::{Data, Path, Payload};
|
||||
use actix_web::{get, web, HttpRequest, HttpResponse, Result, Scope};
|
||||
use actix_web_actors::ws;
|
||||
|
||||
pub fn ws_scope() -> Scope {
|
||||
web::scope("/ws").service(establish_ws_connection)
|
||||
}
|
||||
|
||||
#[get("/{token}")]
|
||||
pub async fn establish_ws_connection(
|
||||
request: HttpRequest,
|
||||
payload: Payload,
|
||||
token: Path<String>,
|
||||
server: Data<Addr<WSServer>>,
|
||||
msg_receivers: Data<MessageReceivers>,
|
||||
) -> Result<HttpResponse> {
|
||||
tracing::info!("establish_ws_connection");
|
||||
let user = LoggedUser::from_token(token.clone())?;
|
||||
let client = WSClient::new(user, server.get_ref().clone(), msg_receivers);
|
||||
let result = ws::start(client, &request, payload);
|
||||
match result {
|
||||
Ok(response) => Ok(response),
|
||||
Err(e) => {
|
||||
tracing::error!("ws connection error: {:?}", e);
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
use crate::api::{token_scope, user_scope};
|
||||
use crate::api::{token_scope, user_scope, ws_scope};
|
||||
use crate::component::auth::HEADER_TOKEN;
|
||||
use crate::config::config::{Config, DatabaseSetting};
|
||||
use crate::middleware::cors::default_cors;
|
||||
|
|
@ -67,6 +67,7 @@ pub async fn run(
|
|||
.app_data(web::JsonConfig::default().limit(4096))
|
||||
.service(user_scope())
|
||||
.service(token_scope())
|
||||
.service(ws_scope())
|
||||
.app_data(Data::new(state.clone()))
|
||||
})
|
||||
.listen(listener)?
|
||||
|
|
|
|||
|
|
@ -139,18 +139,13 @@ pub async fn change_password(
|
|||
err: format!("{}", e),
|
||||
})?;
|
||||
// Save password to disk
|
||||
sqlx::query!(
|
||||
r#"
|
||||
UPDATE users
|
||||
SET password = $1
|
||||
WHERE uid = $2
|
||||
"#,
|
||||
new_hash_password.expose_secret(),
|
||||
uid
|
||||
)
|
||||
.execute(&mut transaction)
|
||||
.await
|
||||
.context("Failed to change user's password in the database.")?;
|
||||
let sql = "update users set password = ? where uid = ?";
|
||||
let _ = sqlx::query(sql)
|
||||
.bind(new_hash_password.expose_secret())
|
||||
.bind(uid)
|
||||
.execute(&mut transaction)
|
||||
.await
|
||||
.context("Failed to change user's password in the database.")?;
|
||||
|
||||
transaction
|
||||
.commit()
|
||||
|
|
@ -228,7 +223,7 @@ pub struct LoggedUser {
|
|||
uid: Secret<String>,
|
||||
}
|
||||
|
||||
impl std::convert::From<Claim> for LoggedUser {
|
||||
impl From<Claim> for LoggedUser {
|
||||
fn from(c: Claim) -> Self {
|
||||
Self {
|
||||
uid: Secret::new(c.user_id()),
|
||||
|
|
@ -269,7 +264,7 @@ impl FromRequest for LoggedUser {
|
|||
}
|
||||
}
|
||||
|
||||
impl std::convert::TryFrom<&HeaderValue> for LoggedUser {
|
||||
impl TryFrom<&HeaderValue> for LoggedUser {
|
||||
type Error = AuthError;
|
||||
|
||||
fn try_from(header: &HeaderValue) -> Result<Self, Self::Error> {
|
||||
|
|
@ -360,13 +355,13 @@ impl Token {
|
|||
}
|
||||
}
|
||||
|
||||
impl std::convert::From<String> for Token {
|
||||
impl From<String> for Token {
|
||||
fn from(val: String) -> Self {
|
||||
Self(val)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::convert::Into<String> for Token {
|
||||
impl Into<String> for Token {
|
||||
fn into(self) -> String {
|
||||
self.0
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,2 +1,3 @@
|
|||
pub mod auth;
|
||||
pub mod token_state;
|
||||
pub mod ws;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,168 @@
|
|||
use crate::component::auth::LoggedUser;
|
||||
use crate::component::ws::entities::{
|
||||
Connect, Disconnect, Socket, SocketMessagePayload, WebSocketMessage,
|
||||
};
|
||||
use crate::component::ws::server::WSServer;
|
||||
use crate::component::ws::{HEARTBEAT_INTERVAL, PING_TIMEOUT};
|
||||
use actix::*;
|
||||
use actix_http::ws::Message::*;
|
||||
use actix_web::web::Data;
|
||||
use actix_web_actors::ws;
|
||||
use bytes::Bytes;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
pub trait MessageReceiver: Send + Sync {
|
||||
fn receive(&self, data: WSClientData);
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct MessageReceivers {
|
||||
inner: HashMap<u8, Arc<dyn MessageReceiver>>,
|
||||
}
|
||||
|
||||
impl MessageReceivers {
|
||||
pub fn new() -> Self {
|
||||
MessageReceivers::default()
|
||||
}
|
||||
|
||||
pub fn set(&mut self, channel: u8, receiver: Arc<dyn MessageReceiver>) {
|
||||
self.inner.insert(channel, receiver);
|
||||
}
|
||||
|
||||
pub fn get(&self, source: u8) -> Option<Arc<dyn MessageReceiver>> {
|
||||
self.inner.get(&source).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct WSClientData {
|
||||
pub(crate) user: Arc<LoggedUser>,
|
||||
pub(crate) socket: Socket,
|
||||
pub(crate) data: Bytes,
|
||||
}
|
||||
|
||||
pub struct WSClient {
|
||||
user: Arc<LoggedUser>,
|
||||
server: Addr<WSServer>,
|
||||
msg_receivers: Data<MessageReceivers>,
|
||||
hb: Instant,
|
||||
}
|
||||
|
||||
impl WSClient {
|
||||
pub fn new(
|
||||
user: LoggedUser,
|
||||
server: Addr<WSServer>,
|
||||
msg_receivers: Data<MessageReceivers>,
|
||||
) -> Self {
|
||||
Self {
|
||||
user: Arc::new(user),
|
||||
server,
|
||||
msg_receivers,
|
||||
hb: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
fn hb(&self, ctx: &mut ws::WebsocketContext<Self>) {
|
||||
ctx.run_interval(HEARTBEAT_INTERVAL, |client, ctx| {
|
||||
if Instant::now().duration_since(client.hb) > PING_TIMEOUT {
|
||||
client.server.do_send(Disconnect {
|
||||
user: client.user.clone(),
|
||||
});
|
||||
ctx.stop();
|
||||
} else {
|
||||
ctx.ping(b"");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn handle_binary_message(&self, bytes: Bytes, socket: Socket) {
|
||||
let payload = SocketMessagePayload::from_bytes(&bytes);
|
||||
match self.msg_receivers.get(payload.channel) {
|
||||
None => {
|
||||
tracing::error!("Can't find the receiver for {:?}", payload.channel);
|
||||
}
|
||||
Some(handler) => {
|
||||
let client_data = WSClientData {
|
||||
user: self.user.clone(),
|
||||
socket,
|
||||
data: Bytes::from(payload.data),
|
||||
};
|
||||
handler.receive(client_data);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for WSClient {
|
||||
fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
|
||||
match msg {
|
||||
Ok(Ping(msg)) => {
|
||||
self.hb = Instant::now();
|
||||
ctx.pong(&msg);
|
||||
}
|
||||
Ok(Pong(_msg)) => {
|
||||
// tracing::debug!("Receive {} pong {:?}", &self.session_id, &msg);
|
||||
self.hb = Instant::now();
|
||||
}
|
||||
Ok(Binary(bytes)) => {
|
||||
let socket = ctx.address().recipient();
|
||||
self.handle_binary_message(bytes, socket);
|
||||
}
|
||||
Ok(Text(_)) => {
|
||||
tracing::warn!("Receive unexpected text message");
|
||||
}
|
||||
Ok(Close(reason)) => {
|
||||
ctx.close(reason);
|
||||
ctx.stop();
|
||||
}
|
||||
Ok(ws::Message::Continuation(_)) => {}
|
||||
Ok(ws::Message::Nop) => {}
|
||||
Err(e) => {
|
||||
tracing::error!("WebSocketStream protocol error {:?}", e);
|
||||
ctx.stop();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Handler<WebSocketMessage> for WSClient {
|
||||
type Result = ();
|
||||
|
||||
fn handle(&mut self, msg: WebSocketMessage, ctx: &mut Self::Context) {
|
||||
ctx.binary(msg.0);
|
||||
}
|
||||
}
|
||||
|
||||
impl Actor for WSClient {
|
||||
type Context = ws::WebsocketContext<Self>;
|
||||
|
||||
fn started(&mut self, ctx: &mut Self::Context) {
|
||||
self.hb(ctx);
|
||||
let socket = ctx.address().recipient();
|
||||
let connect = Connect {
|
||||
socket,
|
||||
user: self.user.clone(),
|
||||
};
|
||||
self.server
|
||||
.send(connect)
|
||||
.into_actor(self)
|
||||
.then(|res, _client, _ctx| {
|
||||
match res {
|
||||
Ok(Ok(_)) => tracing::trace!("Send connect message to server success"),
|
||||
Ok(Err(e)) => tracing::error!("Send connect message to server failed: {:?}", e),
|
||||
Err(e) => tracing::error!("Send connect message to server failed: {:?}", e),
|
||||
}
|
||||
fut::ready(())
|
||||
})
|
||||
.wait(ctx);
|
||||
}
|
||||
|
||||
fn stopping(&mut self, _: &mut Self::Context) -> Running {
|
||||
self.server.do_send(Disconnect {
|
||||
user: self.user.clone(),
|
||||
});
|
||||
|
||||
Running::Stop
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,84 @@
|
|||
use crate::component::auth::LoggedUser;
|
||||
use actix::{Message, Recipient};
|
||||
use bytes::Bytes;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Formatter;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub type Socket = Recipient<WebSocketMessage>;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Hash, PartialEq, Eq)]
|
||||
pub struct WSSessionId(pub String);
|
||||
|
||||
impl<T: AsRef<str>> std::convert::From<T> for WSSessionId {
|
||||
fn from(s: T) -> Self {
|
||||
WSSessionId(s.as_ref().to_owned())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for WSSessionId {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
let desc = &self.0.to_string();
|
||||
f.write_str(desc)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Session {
|
||||
pub user: Arc<LoggedUser>,
|
||||
pub socket: Socket,
|
||||
}
|
||||
|
||||
impl std::convert::From<Connect> for Session {
|
||||
fn from(c: Connect) -> Self {
|
||||
Self {
|
||||
user: c.user,
|
||||
socket: c.socket,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Message, Clone)]
|
||||
#[rtype(result = "Result<(), WSError>")]
|
||||
pub struct Connect {
|
||||
pub socket: Socket,
|
||||
pub user: Arc<LoggedUser>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Message, Clone)]
|
||||
#[rtype(result = "Result<(), WSError>")]
|
||||
pub struct Disconnect {
|
||||
pub user: Arc<LoggedUser>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Message, Clone)]
|
||||
#[rtype(result = "()")]
|
||||
pub struct WebSocketMessage(pub Bytes);
|
||||
|
||||
impl std::ops::Deref for WebSocketMessage {
|
||||
type Target = Bytes;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SocketMessagePayload {
|
||||
pub(crate) channel: u8,
|
||||
pub(crate) data: Vec<u8>,
|
||||
}
|
||||
|
||||
impl SocketMessagePayload {
|
||||
pub fn to_bytes(&self) -> Vec<u8> {
|
||||
bincode::serialize(self).unwrap()
|
||||
}
|
||||
|
||||
pub fn from_bytes<T: AsRef<[u8]>>(bytes: T) -> Self {
|
||||
bincode::deserialize(bytes.as_ref()).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum WSError {
|
||||
Internal,
|
||||
}
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
use std::time::Duration;
|
||||
|
||||
mod client;
|
||||
mod entities;
|
||||
mod server;
|
||||
|
||||
pub use client::*;
|
||||
pub use server::WSServer;
|
||||
|
||||
pub(crate) const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(8);
|
||||
pub(crate) const PING_TIMEOUT: Duration = Duration::from_secs(60);
|
||||
pub(crate) const MAX_PAYLOAD_SIZE: usize = 262_144; // max payload size is 256k
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
use crate::component::ws::entities::{
|
||||
Connect, Disconnect, Session, WSError, WSSessionId, WebSocketMessage,
|
||||
};
|
||||
use actix::{Actor, Context, Handler};
|
||||
use dashmap::DashMap;
|
||||
|
||||
pub struct WSServer {
|
||||
sessions: DashMap<WSSessionId, Session>,
|
||||
}
|
||||
|
||||
impl std::default::Default for WSServer {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
sessions: DashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl WSServer {
|
||||
pub fn new() -> Self {
|
||||
WSServer::default()
|
||||
}
|
||||
|
||||
pub fn send(&self, _msg: WebSocketMessage) {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
impl Actor for WSServer {
|
||||
type Context = Context<Self>;
|
||||
fn started(&mut self, _ctx: &mut Self::Context) {}
|
||||
}
|
||||
|
||||
impl Handler<Connect> for WSServer {
|
||||
type Result = Result<(), WSError>;
|
||||
fn handle(&mut self, msg: Connect, _ctx: &mut Context<Self>) -> Self::Result {
|
||||
let session: Session = msg.into();
|
||||
self.sessions.insert(session.id.clone(), session);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Handler<Disconnect> for WSServer {
|
||||
type Result = Result<(), WSError>;
|
||||
fn handle(&mut self, msg: Disconnect, _: &mut Context<Self>) -> Self::Result {
|
||||
self.sessions.remove(&msg.sid);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Handler<WebSocketMessage> for WSServer {
|
||||
type Result = ();
|
||||
|
||||
fn handle(&mut self, _msg: WebSocketMessage, _ctx: &mut Context<Self>) -> Self::Result {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
impl actix::Supervised for WSServer {
|
||||
fn restarting(&mut self, _ctx: &mut Context<WSServer>) {
|
||||
tracing::warn!("restarting");
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue