From 90ae1d5fb678426b29d4218c9727ca7523fad765 Mon Sep 17 00:00:00 2001 From: "Nathan.fooo" <86001920+appflowy@users.noreply.github.com> Date: Wed, 10 May 2023 20:26:30 +0800 Subject: [PATCH] Feat/ws test (#4) * test: ws test * test: update * test: update * test: sync update * feat: ws test --- Cargo.lock | 27 ++++++- Cargo.toml | 19 ++--- crates/websocket/Cargo.toml | 5 +- crates/websocket/src/channel_ext.rs | 40 ++++++++++ crates/websocket/src/client.rs | 51 +++++------- crates/websocket/src/entities.rs | 72 ++++++++++++++++- crates/websocket/src/error.rs | 2 +- crates/websocket/src/lib.rs | 1 + crates/websocket/src/server.rs | 119 +++++++++++++++++++++------- src/api/ws.rs | 4 +- src/application.rs | 2 +- src/state.rs | 2 +- src/telemetry.rs | 12 ++- tests/util/test_server.rs | 33 +++++++- tests/ws/client.rs | 90 +++++++++++++++++++++ tests/ws/mod.rs | 3 +- tests/ws/test.rs | 32 ++++++++ 17 files changed, 424 insertions(+), 90 deletions(-) create mode 100644 crates/websocket/src/channel_ext.rs create mode 100644 tests/ws/client.rs create mode 100644 tests/ws/test.rs diff --git a/Cargo.lock b/Cargo.lock index fb2a4e37..dc63905d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -442,13 +442,14 @@ dependencies = [ "actix-web-flash-messages", "anyhow", "argon2", + "assert-json-diff", "async-stream", "bincode", "bytes", "chrono", + "collab", "collab-client-ws", - "collab-persistence", - "collab-sync", + "collab-plugins", "config", "dashmap", "derive_more", @@ -466,6 +467,7 @@ dependencies = [ "serde_json", "snowflake", "sqlx", + "tempfile", "thiserror", "token", "tokio", @@ -536,6 +538,16 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "assert-json-diff" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "async-stream" version = "0.3.5" @@ -812,6 +824,7 @@ dependencies = [ [[package]] name = "collab" version = "0.1.0" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=af4941#af4941ba5394157869eca56d4c937dbec1f0a0e3" dependencies = [ "anyhow", "bytes", @@ -828,8 +841,10 @@ dependencies = [ [[package]] name = "collab-client-ws" version = "0.1.0" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=af4941#af4941ba5394157869eca56d4c937dbec1f0a0e3" dependencies = [ "bytes", + "collab-sync", "futures-util", "serde", "serde_json", @@ -844,6 +859,7 @@ dependencies = [ [[package]] name = "collab-persistence" version = "0.1.0" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=af4941#af4941ba5394157869eca56d4c937dbec1f0a0e3" dependencies = [ "bincode", "chrono", @@ -863,11 +879,14 @@ dependencies = [ [[package]] name = "collab-plugins" version = "0.1.0" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=af4941#af4941ba5394157869eca56d4c937dbec1f0a0e3" dependencies = [ "collab", "collab-client-ws", "collab-persistence", "collab-sync", + "futures-util", + "tokio", "tracing", "y-sync", "yrs", @@ -876,6 +895,7 @@ dependencies = [ [[package]] name = "collab-sync" version = "0.1.0" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=af4941#af4941ba5394157869eca56d4c937dbec1f0a0e3" dependencies = [ "bytes", "collab", @@ -3621,14 +3641,13 @@ dependencies = [ "actix-web-actors", "bytes", "collab", - "collab-persistence", "collab-plugins", - "collab-sync", "dashmap", "futures-util", "parking_lot 0.12.1", "secrecy", "serde", + "serde_json", "thiserror", "tokio", "tokio-stream", diff --git a/Cargo.toml b/Cargo.toml index b58d1cf6..bf0d5494 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,8 +51,6 @@ bytes = "1.4.0" bincode = "1.3.3" dashmap = "5.4" rcgen = { version = "0.10.0", features = ["pem", "x509-parser"] } -collab-sync = {version = "0.1.0"} -collab-persistence = {version = "0.1.0"} # tracing tracing = { version = "0.1.37" } @@ -66,16 +64,19 @@ sqlx = { version = "0.6", default-features = false, features = ["runtime-actix-r token = { path = "./crates/token" } snowflake = { path = "./crates/snowflake" } websocket = { path = "./crates/websocket" } +collab-plugins = { version = "0.1.0", features = ["sync", "disk_rocksdb"] } [dev-dependencies] once_cell = "1.7.2" +collab = { version = "0.1.0" } collab-client-ws = { version = "0.1.0" } +tempfile = "3.4.0" +assert-json-diff = "2.0.2" [[bin]] name = "appflowy_server" path = "src/main.rs" - [lib] path = "src/lib.rs" @@ -87,14 +88,14 @@ members = [ ] [patch.crates-io] -collab = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" } -collab-client-ws = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" } -collab-sync= { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" } -collab-persistence = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" } -collab-plugins = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "4a12ed" } +collab = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "af4941" } +collab-client-ws = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "af4941" } +collab-sync= { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "af4941" } +collab-persistence = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "af4941" } +collab-plugins = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "af4941" } #collab = { path = "./crates/AppFlowy-Collab/collab" } #collab-client-ws = { path = "./crates/AppFlowy-Collab/collab-client-ws" } -#collab-sync = { path = "./crates/AppFlowy-Collab/collab-sync" } #collab-persistence = { path = "./crates/AppFlowy-Collab/collab-persistence" } +#collab-sync = { path = "./crates/AppFlowy-Collab/collab-sync" } #collab-plugins = { path = "./crates/AppFlowy-Collab/collab-plugins"} diff --git a/crates/websocket/Cargo.toml b/crates/websocket/Cargo.toml index 64d85c68..dce64df9 100644 --- a/crates/websocket/Cargo.toml +++ b/crates/websocket/Cargo.toml @@ -9,6 +9,7 @@ edition = "2021" actix = "0.13" actix-web-actors = { version = "4.2.0" } serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" thiserror = "1.0.30" bytes = "1.0" secrecy = { version = "0.8", features = ["serde"] } @@ -20,6 +21,4 @@ tokio = { version = "1.26", features = ["sync"] } dashmap = "5.4.0" collab = { version = "0.1.0"} -collab-sync = { version = "0.1.0"} -collab-persistence = { version = "0.1.0"} -collab-plugins = { version = "0.1.0", features = ["disk_rocksdb"]} +collab-plugins = { version = "0.1.0", features = ["disk_rocksdb", "sync"]} diff --git a/crates/websocket/src/channel_ext.rs b/crates/websocket/src/channel_ext.rs new file mode 100644 index 00000000..0cce1a12 --- /dev/null +++ b/crates/websocket/src/channel_ext.rs @@ -0,0 +1,40 @@ +use crate::error::WSError; +use futures_util::Sink; +use std::fmt::Debug; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pub struct UnboundedSenderSink(pub tokio::sync::mpsc::UnboundedSender); + +impl UnboundedSenderSink { + pub fn new(tx: tokio::sync::mpsc::UnboundedSender) -> Self { + Self(tx) + } +} + +impl Sink for UnboundedSenderSink +where + T: Send + Sync + 'static + Debug, +{ + type Error = WSError; + + fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + // An unbounded channel can always accept messages without blocking, so we always return Ready. + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { + let _ = self.0.send(item); + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + // There is no buffering in an unbounded channel, so we always return Ready. + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + // An unbounded channel is closed by dropping the sender, so we don't need to do anything here. + Poll::Ready(Ok(())) + } +} diff --git a/crates/websocket/src/client.rs b/crates/websocket/src/client.rs index 551d0a61..2b84b2ef 100644 --- a/crates/websocket/src/client.rs +++ b/crates/websocket/src/client.rs @@ -1,5 +1,5 @@ -use crate::entities::{ClientMessage, Connect, Disconnect, ServerMessage, WSUser}; -use crate::error::WSError; +use crate::entities::{ClientMessage, Connect, Disconnect, ServerMessage, WSMessage, WSUser}; + use crate::CollabServer; use actix::{ fut, Actor, ActorContext, ActorFutureExt, Addr, AsyncContext, ContextFutureSpawner, Handler, @@ -7,13 +7,12 @@ use actix::{ }; use actix_web_actors::ws; use bytes::Bytes; +use std::ops::Deref; -use collab_sync::msg::CollabMessage; -use futures_util::Sink; +use collab_plugins::sync::msg::CollabMessage; -use std::pin::Pin; use std::sync::Arc; -use std::task::{Context, Poll}; + use std::time::{Duration, Instant}; const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); @@ -48,10 +47,13 @@ impl CollabSession { }); } - fn send_to_server(&self, bytes: Bytes) { - match CollabMessage::from_vec(bytes.to_vec()) { - Ok(collab_msg) => { + fn forward_binary_to_ws_server(&self, bytes: Bytes) { + match WSMessage::from_vec(bytes.to_vec()) { + Ok(ws_message) => { + tracing::trace!("[WSClient]: forward message to server"); + let collab_msg = CollabMessage::from_vec(&ws_message.payload).unwrap(); self.server.do_send(ClientMessage { + business_id: ws_message.business_id, user: self.user.clone(), collab_msg, }); @@ -83,7 +85,7 @@ impl Actor for CollabSession { tracing::trace!("Send connect message to server success") }, _ => { - tracing::error!("Send connect message to server failed"); + tracing::error!("🔴Send connect message to server failed"); ctx.stop(); }, } @@ -103,8 +105,9 @@ impl Actor for CollabSession { impl Handler for CollabSession { type Result = (); - fn handle(&mut self, msg: ServerMessage, ctx: &mut Self::Context) { - ctx.binary(msg.collab_msg); + fn handle(&mut self, server_msg: ServerMessage, ctx: &mut Self::Context) { + tracing::trace!("[WSClient]: forward message to client"); + ctx.binary(WSMessage::from(server_msg)); } } @@ -129,7 +132,7 @@ impl StreamHandler> for CollabSession { }, ws::Message::Text(_) => {}, ws::Message::Binary(bytes) => { - self.send_to_server(bytes); + self.forward_binary_to_ws_server(bytes); }, ws::Message::Close(reason) => { ctx.close(reason); @@ -145,24 +148,10 @@ impl StreamHandler> for CollabSession { /// A helper struct that wraps the [Recipient] type to implement the [Sink] trait pub struct ClientSink(pub Recipient); +impl Deref for ClientSink { + type Target = Recipient; -impl Sink for ClientSink { - type Error = WSError; - - fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn start_send(self: Pin<&mut Self>, item: CollabMessage) -> Result<(), Self::Error> { - self.0.do_send(ServerMessage { collab_msg: item }); - Ok(()) - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) + fn deref(&self) -> &Self::Target { + &self.0 } } diff --git a/crates/websocket/src/entities.rs b/crates/websocket/src/entities.rs index 3ba6cf54..89e82d2b 100644 --- a/crates/websocket/src/entities.rs +++ b/crates/websocket/src/entities.rs @@ -1,9 +1,10 @@ use crate::error::WSError; use actix::{Message, Recipient}; - -use collab_sync::msg::CollabMessage; +use bytes::Bytes; +use collab_plugins::sync::msg::CollabMessage; use secrecy::{ExposeSecret, Secret}; - +use serde::{Deserialize, Serialize}; +use std::fmt::{Display, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::Arc; @@ -12,6 +13,12 @@ pub struct WSUser { pub user_id: Secret, } +impl Display for WSUser { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str(self.user_id.expose_secret()) + } +} + impl Hash for WSUser { fn hash(&self, state: &mut H) { let uid: &String = self.user_id.expose_secret(); @@ -45,6 +52,7 @@ pub struct Disconnect { #[derive(Debug, Message, Clone)] #[rtype(result = "()")] pub struct ClientMessage { + pub business_id: String, pub user: Arc, pub collab_msg: CollabMessage, } @@ -52,5 +60,61 @@ pub struct ClientMessage { #[derive(Debug, Message, Clone)] #[rtype(result = "()")] pub struct ServerMessage { - pub collab_msg: CollabMessage, + pub business_id: String, + pub payload: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WSMessage { + pub business_id: String, + pub payload: Vec, +} + +impl WSMessage { + pub fn from_vec(bytes: Vec) -> Result { + serde_json::from_slice(&bytes) + } +} + +impl From for Bytes { + fn from(msg: WSMessage) -> Self { + let bytes = serde_json::to_vec(&msg).unwrap_or_default(); + Bytes::from(bytes) + } +} + +impl From for WSMessage { + fn from(server_msg: ServerMessage) -> Self { + Self { + business_id: server_msg.business_id, + payload: server_msg.payload, + } + } +} + +impl From for WSMessage { + fn from(msg: CollabMessage) -> Self { + Self { + business_id: msg.business_id().to_string(), + payload: msg.to_vec(), + } + } +} + +impl From for ServerMessage { + fn from(msg: CollabMessage) -> Self { + Self { + business_id: msg.business_id().to_string(), + payload: msg.to_vec(), + } + } +} + +impl From for WSMessage { + fn from(client_msg: ClientMessage) -> Self { + Self { + business_id: client_msg.business_id, + payload: client_msg.collab_msg.to_vec(), + } + } } diff --git a/crates/websocket/src/error.rs b/crates/websocket/src/error.rs index 54729bf7..6564f12c 100644 --- a/crates/websocket/src/error.rs +++ b/crates/websocket/src/error.rs @@ -1,7 +1,7 @@ #[derive(Debug, thiserror::Error)] pub enum WSError { #[error(transparent)] - Persistence(#[from] collab_persistence::error::PersistenceError), + Persistence(#[from] collab_plugins::disk::error::PersistenceError), #[error("Internal failure: {0}")] Internal(#[from] Box), diff --git a/crates/websocket/src/lib.rs b/crates/websocket/src/lib.rs index e98c29ca..bab6ea69 100644 --- a/crates/websocket/src/lib.rs +++ b/crates/websocket/src/lib.rs @@ -1,3 +1,4 @@ +mod channel_ext; mod client; pub mod entities; mod error; diff --git a/crates/websocket/src/server.rs b/crates/websocket/src/server.rs index 8eb573af..8d057a6f 100644 --- a/crates/websocket/src/server.rs +++ b/crates/websocket/src/server.rs @@ -1,26 +1,27 @@ -use crate::entities::{ClientMessage, Connect, Disconnect, WSUser}; +use crate::entities::{ClientMessage, Connect, Disconnect, ServerMessage, WSMessage, WSUser}; use crate::error::WSError; use crate::ClientSink; +use crate::channel_ext::UnboundedSenderSink; use actix::{Actor, Context, Handler, ResponseFuture}; use collab::core::collab::MutexCollab; use collab::core::origin::CollabOrigin; -use collab_persistence::kv::rocks_kv::RocksCollabDB; -use collab_persistence::kv::KVStore; -use collab_plugins::disk_plugin::rocksdb_server::RocksdbServerDiskPlugin; -use collab_sync::server::{ +use collab_plugins::disk::keys::make_collab_id_key; +use collab_plugins::disk::kv::rocks_kv::RocksCollabDB; +use collab_plugins::disk::kv::KVStore; +use collab_plugins::disk::rocksdb_server::RocksdbServerDiskPlugin; +use collab_plugins::sync::msg::CollabMessage; +use collab_plugins::sync::server::{ CollabBroadcast, CollabGroup, CollabIDGen, CollabId, NonZeroNodeId, COLLAB_ID_LEN, }; use dashmap::DashMap; use parking_lot::{Mutex, RwLock}; use std::collections::HashMap; - -use collab_persistence::keys::make_collab_id_key; -use collab_sync::msg::CollabMessage; - use std::sync::Arc; + use tokio::sync::mpsc::Sender; use tokio_stream::wrappers::ReceiverStream; +use tokio_stream::StreamExt; #[derive(Clone)] pub struct CollabServer { @@ -30,7 +31,7 @@ pub struct CollabServer { /// Memory cache for fast lookup of collab_id from object_id collab_id_by_object_id: Arc>, collab_groups: Arc>>, - client_streams: Arc, ClientStream>>>, + client_streams: Arc, WSClientStream>>>, } impl CollabServer { @@ -46,6 +47,7 @@ impl CollabServer { }) } + /// Create a new collab id for the object id. fn create_collab_id(&self, object_id: &str) -> Result { let collab_id = self.collab_id_gen.lock().next_id(); let collab_key = make_collab_id_key(object_id.as_ref()); @@ -56,6 +58,8 @@ impl CollabServer { Ok(collab_id) } + /// Get the collab id for the object + /// If the object doesn't have a collab id, return None fn get_collab_id(&self, object_id: &str) -> Option { let collab_key = make_collab_id_key(object_id.as_ref()); let read_txn = self.db.read_txn(); @@ -66,6 +70,7 @@ impl CollabServer { Some(CollabId::from_be_bytes(bytes)) } + /// Get or create a collab id if the object doesn't have one fn get_or_create_collab_id(&self, object_id: &str) -> Result { let collab_id = self.get_collab_id(object_id); if let Some(collab_id) = collab_id { @@ -81,6 +86,7 @@ impl CollabServer { } } + /// Create the collab group for the object if it doesn't exist fn create_group_if_need(&self, collab_id: CollabId, object_id: &str) { if self.collab_groups.read().contains_key(&collab_id) { return; @@ -108,10 +114,17 @@ impl Actor for CollabServer { impl Handler for CollabServer { type Result = Result<(), WSError>; - fn handle(&mut self, msg: Connect, _ctx: &mut Context) -> Self::Result { - let (stream_tx, rx) = tokio::sync::mpsc::channel(100); - let stream = ClientStream::new(ClientSink(msg.socket), ReceiverStream::new(rx), stream_tx); - self.client_streams.write().insert(msg.user, stream); + fn handle(&mut self, new_conn: Connect, _ctx: &mut Context) -> Self::Result { + tracing::trace!("[WSServer]: {} connect", new_conn.user); + + // When receive a new connection, create a new [ClientStream] that holds the connection's websocket + let (stream_tx, stream_rx) = tokio::sync::mpsc::channel(1000); + let stream = WSClientStream::new( + ClientSink(new_conn.socket), + ReceiverStream::new(stream_rx), + stream_tx, + ); + self.client_streams.write().insert(new_conn.user, stream); Ok(()) } } @@ -119,6 +132,7 @@ impl Handler for CollabServer { impl Handler for CollabServer { type Result = Result<(), WSError>; fn handle(&mut self, msg: Disconnect, _: &mut Context) -> Self::Result { + tracing::trace!("[WSServer]: {} disconnect", msg.user); self.client_streams.write().remove(&msg.user); Ok(()) } @@ -127,13 +141,16 @@ impl Handler for CollabServer { impl Handler for CollabServer { type Result = ResponseFuture<()>; - fn handle(&mut self, msg: ClientMessage, _ctx: &mut Context) -> Self::Result { - let object_id = msg.collab_msg.object_id(); + fn handle(&mut self, client_msg: ClientMessage, _ctx: &mut Context) -> Self::Result { + let object_id = client_msg.collab_msg.object_id(); + // Get the collab_id for the object_id. If the object_id is not exist, create a new collab_id for it. + // Also create a new [CollabGroup] for the collab_id if it is not exist. if let Ok(collab_id) = self.get_or_create_collab_id(object_id) { if let Some(collab_group) = self.collab_groups.write().get_mut(&collab_id) { - if let Some(stream) = self.client_streams.write().get_mut(&msg.user) { - if let Some((sink, stream)) = stream.split() { - let origin = match msg.collab_msg.origin() { + if let Some(client_stream) = self.client_streams.write().get_mut(&client_msg.user) { + // If the client's stream is not subscribed to the collab group, subscribe it. + if let Some((sink, stream)) = client_stream.split() { + let origin = match client_msg.collab_msg.origin() { None => CollabOrigin::Empty, Some(client) => client.clone(), }; @@ -147,8 +164,19 @@ impl Handler for CollabServer { let client_streams = self.client_streams.clone(); Box::pin(async move { - if let Some(client_stream) = client_streams.read().get(&msg.user) { - let _ = client_stream.stream_tx.send(Ok(msg.collab_msg)).await; + if let Some(client_stream) = client_streams.read().get(&client_msg.user) { + tracing::trace!( + "[WSServer]: receives client message: {:?}", + client_msg.collab_msg.msg_id() + ); + match client_stream + .stream_tx + .send(Ok(WSMessage::from(client_msg))) + .await + { + Ok(_) => {}, + Err(e) => tracing::trace!("send error: {:?}", e), + } } }) } else { @@ -163,17 +191,17 @@ impl actix::Supervised for CollabServer { } } -pub struct ClientStream { +pub struct WSClientStream { sink: Option, - stream: Option>>, - stream_tx: Sender>, + stream: Option>>, + stream_tx: Sender>, } -impl ClientStream { +impl WSClientStream { pub fn new( sink: ClientSink, - stream: ReceiverStream>, - stream_tx: Sender>, + stream: ReceiverStream>, + stream_tx: Sender>, ) -> Self { Self { sink: Some(sink), @@ -182,9 +210,40 @@ impl ClientStream { } } - pub fn split(&mut self) -> Option<(ClientSink, ReceiverStream>)> { - let sink = self.sink.take()?; - let stream = self.stream.take()?; + #[allow(clippy::type_complexity)] + pub fn split(&mut self) -> Option<(UnboundedSenderSink, ReceiverStream>)> + where + T: TryFrom + Into + Send + Sync + 'static, + { + let client_sink = self.sink.take()?; + let mut stream = self.stream.take()?; + + // forward sink + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); + tokio::spawn(async move { + while let Some(msg) = rx.recv().await { + client_sink.do_send(msg.into()); + } + }); + let sink = UnboundedSenderSink::::new(tx); + + // forward stream + let (tx, rx) = tokio::sync::mpsc::channel(100); + tokio::spawn(async move { + while let Some(Ok(msg)) = stream.next().await { + let _ = tx.send(T::try_from(msg)).await; + } + }); + let stream = ReceiverStream::new(rx); + Some((sink, stream)) } } + +impl TryFrom for CollabMessage { + type Error = WSError; + + fn try_from(value: WSMessage) -> Result { + CollabMessage::from_vec(&value.payload).map_err(|e| WSError::Internal(Box::new(e))) + } +} diff --git a/src/api/ws.rs b/src/api/ws.rs index ed350e2b..462cc105 100644 --- a/src/api/ws.rs +++ b/src/api/ws.rs @@ -5,6 +5,7 @@ use actix_web::web::{Data, Path, Payload}; use actix_web::{get, web, HttpRequest, HttpResponse, Result, Scope}; use actix_web_actors::ws; use secrecy::Secret; + use websocket::entities::WSUser; use websocket::{CollabServer, CollabSession}; @@ -20,12 +21,13 @@ pub async fn establish_ws_connection( state: Data, server: Data>, ) -> Result { + tracing::trace!("{:?}", request); let user = LoggedUser::from_token(&state.config.application.server_key, token.as_str())?; let client = CollabSession::new(user.into(), server.get_ref().clone()); match ws::start(client, &request, payload) { Ok(response) => Ok(response), Err(e) => { - tracing::error!("ws connection error: {:?}", e); + tracing::error!("🔴ws connection error: {:?}", e); Err(e) }, } diff --git a/src/application.rs b/src/application.rs index 41f0f6d3..37e059e0 100644 --- a/src/application.rs +++ b/src/application.rs @@ -12,7 +12,7 @@ use actix_web::{dev::Server, web, web::Data, App, HttpServer}; use actix::Actor; -use collab_persistence::kv::rocks_kv::RocksCollabDB; +use collab_plugins::disk::kv::rocks_kv::RocksCollabDB; use openssl::ssl::{SslAcceptor, SslAcceptorBuilder, SslFiletype, SslMethod}; use openssl::x509::X509; use secrecy::{ExposeSecret, Secret}; diff --git a/src/state.rs b/src/state.rs index 16de5560..95930965 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,8 +1,8 @@ use crate::component::auth::LoggedUser; use crate::config::config::Config; use chrono::{DateTime, Utc}; -use collab_persistence::kv::rocks_kv::RocksCollabDB; +use collab_plugins::disk::kv::rocks_kv::RocksCollabDB; use snowflake::Snowflake; use sqlx::PgPool; use std::collections::BTreeMap; diff --git a/src/telemetry.rs b/src/telemetry.rs index b0b5772b..aaea2e68 100644 --- a/src/telemetry.rs +++ b/src/telemetry.rs @@ -3,6 +3,7 @@ use tracing::subscriber::set_global_default; use tracing::Subscriber; use tracing_bunyan_formatter::{BunyanFormattingLayer, JsonStorageLayer}; use tracing_log::LogTracer; + use tracing_subscriber::fmt::MakeWriter; use tracing_subscriber::{layer::SubscriberExt, EnvFilter}; @@ -15,11 +16,18 @@ pub fn get_subscriber( where Sink: for<'a> MakeWriter<'a> + Send + Sync + 'static, { - let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(env_filter)); - // let env_filter = EnvFilter::new(env_filter); + let env_filter = match EnvFilter::try_from_default_env() { + Ok(env_filter) => { + dbg!("Using default env filter"); + env_filter + }, + Err(_) => EnvFilter::new(env_filter), + }; let formatting_layer = BunyanFormattingLayer::new(name, sink); tracing_subscriber::fmt() .with_ansi(true) + .with_target(true) + .with_max_level(tracing::Level::TRACE) .finish() .with(env_filter) .with(JsonStorageLayer) diff --git a/tests/util/test_server.rs b/tests/util/test_server.rs index 78bb6ad6..1c952e42 100644 --- a/tests/util/test_server.rs +++ b/tests/util/test_server.rs @@ -2,9 +2,17 @@ use appflowy_server::application::{init_state, Application}; use appflowy_server::config::config::{get_configuration, DatabaseSetting}; use appflowy_server::state::State; use appflowy_server::telemetry::{get_subscriber, init_subscriber}; +use collab::core::collab::MutexCollab; +use collab::core::origin::CollabOrigin; + +use collab_plugins::disk::keys::make_collab_id_key; + +use collab_plugins::disk::rocksdb_server::RocksdbServerDiskPlugin; +use collab_plugins::sync::server::{CollabId, COLLAB_ID_LEN}; use once_cell::sync::Lazy; use reqwest::Certificate; use std::path::PathBuf; +use std::sync::Arc; use appflowy_server::component::auth::{RegisterResponse, HEADER_TOKEN}; use sqlx::types::Uuid; @@ -12,11 +20,13 @@ use sqlx::{Connection, Executor, PgConnection, PgPool}; // Ensure that the `tracing` stack is only initialised once using `once_cell` static TRACING: Lazy<()> = Lazy::new(|| { - let level = "trace".to_string(); + let level = "debug".to_string(); let mut filters = vec![]; filters.push(format!("appflowy_server={}", level)); filters.push(format!("collab_client_ws={}", level)); - filters.push(format!("hyper={}", level)); + filters.push(format!("websocket={}", level)); + filters.push(format!("collab_sync={}", level)); + // filters.push(format!("hyper={}", level)); filters.push(format!("actix_web={}", level)); let subscriber_name = "test".to_string(); @@ -89,6 +99,25 @@ impl TestServer { .await .expect("Change password failed") } + + pub fn get_doc(&self, object_id: &str) -> serde_json::Value { + let collab = MutexCollab::new(CollabOrigin::Empty, object_id, vec![]); + let collab_id = self.collab_id_from_object_id(object_id); + let plugin = RocksdbServerDiskPlugin::new(collab_id, self.state.rocksdb.clone()).unwrap(); + collab.lock().add_plugin(Arc::new(plugin)); + collab.initial(); + let collab = collab.lock(); + collab.to_json_value() + } + + pub fn collab_id_from_object_id(&self, object_id: &str) -> CollabId { + let read_txn = self.state.rocksdb.read_txn(); + let collab_key = make_collab_id_key(object_id.as_ref()); + let value = read_txn.get(collab_key.as_ref()).unwrap().unwrap(); + let mut bytes = [0; COLLAB_ID_LEN]; + bytes[0..COLLAB_ID_LEN].copy_from_slice(value.as_ref()); + CollabId::from_be_bytes(bytes) + } } pub async fn spawn_server() -> TestServer { diff --git a/tests/ws/client.rs b/tests/ws/client.rs new file mode 100644 index 00000000..9c9998b1 --- /dev/null +++ b/tests/ws/client.rs @@ -0,0 +1,90 @@ +use collab::core::collab::MutexCollab; +use collab::core::origin::{CollabClient, CollabOrigin}; + +use collab_client_ws::{WSClient, WSMessageHandler}; +use collab_plugins::disk::kv::rocks_kv::RocksCollabDB; +use collab_plugins::disk::rocksdb::RocksdbDiskPlugin; +use collab_plugins::sync::SyncPlugin; +use std::net::SocketAddr; +use std::ops::Deref; +use std::path::PathBuf; +use std::sync::Arc; +use tempfile::TempDir; + +pub async fn spawn_client( + uid: i64, + object_id: &str, + address: String, +) -> std::io::Result { + let ws_client = WSClient::new(address, 100); + let addr = ws_client.connect().await.unwrap().unwrap(); + let origin = origin_from_tcp_stream(&addr); + let handler = ws_client.subscribe("collab".to_string()).await.unwrap(); + + // + let (sink, stream) = (handler.sink(), handler.stream()); + let collab = Arc::new(MutexCollab::new(origin.clone(), object_id, vec![])); + let sync_plugin = SyncPlugin::new(origin, object_id, collab.clone(), sink, stream); + collab.lock().add_plugin(Arc::new(sync_plugin)); + + // disk + let tempdir = TempDir::new().unwrap(); + let db_path = tempdir.into_path(); + let db = Arc::new(RocksCollabDB::open(db_path.clone()).unwrap()); + let disk_plugin = RocksdbDiskPlugin::new(uid, db.clone()).unwrap(); + collab.lock().add_plugin(Arc::new(disk_plugin)); + collab.initial(); + + let cleaner = Cleaner::new(db_path); + Ok(TestClient { + ws_client, + db, + collab, + cleaner, + handlers: vec![handler], + }) +} + +fn origin_from_tcp_stream(addr: &SocketAddr) -> CollabOrigin { + let origin = CollabClient::new(addr.port() as i64, &addr.to_string()); + CollabOrigin::Client(origin) +} + +pub struct TestClient { + #[allow(dead_code)] + ws_client: WSClient, + pub db: Arc, + pub collab: Arc, + + #[allow(dead_code)] + cleaner: Cleaner, + + #[allow(dead_code)] + handlers: Vec>, +} + +struct Cleaner(PathBuf); + +impl Cleaner { + fn new(dir: PathBuf) -> Self { + Cleaner(dir) + } + + fn cleanup(dir: &PathBuf) { + let _ = std::fs::remove_dir_all(dir); + } +} + +impl Drop for Cleaner { + fn drop(&mut self) { + Self::cleanup(&self.0) + } +} + +impl Deref for TestClient { + type Target = Arc; + + fn deref(&self) -> &Self::Target { + &self.collab + } +} diff --git a/tests/ws/mod.rs b/tests/ws/mod.rs index 8b137891..0f3cbef3 100644 --- a/tests/ws/mod.rs +++ b/tests/ws/mod.rs @@ -1 +1,2 @@ - +mod client; +mod test; diff --git a/tests/ws/test.rs b/tests/ws/test.rs new file mode 100644 index 00000000..16c17380 --- /dev/null +++ b/tests/ws/test.rs @@ -0,0 +1,32 @@ +use crate::util::{spawn_server, TestUser}; +use crate::ws::client::spawn_client; +use serde_json::json; +use std::time::Duration; + +#[actix_rt::test] +async fn ws_conn_test() { + let server = spawn_server().await; + let test_user = TestUser::generate(); + let token = test_user.register(&server).await; + let address = format!("{}/{}", server.ws_addr, token); + let client = spawn_client(1, "1", address).await.unwrap(); + + wait_a_sec().await; + { + let collab = client.lock(); + collab.insert("1", "a"); + } + wait_a_sec().await; + + let value = server.get_doc("1"); + assert_json_diff::assert_json_eq!( + value, + json!({ + "1": "a" + }) + ); +} + +async fn wait_a_sec() { + tokio::time::sleep(Duration::from_secs(2)).await; +}