diff --git a/Cargo.lock b/Cargo.lock index fb41d68a..865cd6df 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -824,6 +824,7 @@ dependencies = [ [[package]] name = "collab" version = "0.1.0" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=47a4e9#47a4e903b63825b59f5d42351aa4a23cf5ef29f6" dependencies = [ "anyhow", "bytes", @@ -840,6 +841,7 @@ dependencies = [ [[package]] name = "collab-client-ws" version = "0.1.0" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=47a4e9#47a4e903b63825b59f5d42351aa4a23cf5ef29f6" dependencies = [ "bytes", "collab-sync", @@ -857,6 +859,7 @@ dependencies = [ [[package]] name = "collab-persistence" version = "0.1.0" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=47a4e9#47a4e903b63825b59f5d42351aa4a23cf5ef29f6" dependencies = [ "bincode", "chrono", @@ -876,6 +879,7 @@ dependencies = [ [[package]] name = "collab-plugins" version = "0.1.0" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=47a4e9#47a4e903b63825b59f5d42351aa4a23cf5ef29f6" dependencies = [ "collab", "collab-client-ws", @@ -891,6 +895,7 @@ dependencies = [ [[package]] name = "collab-sync" version = "0.1.0" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Collab?rev=47a4e9#47a4e903b63825b59f5d42351aa4a23cf5ef29f6" dependencies = [ "bytes", "collab", diff --git a/Cargo.toml b/Cargo.toml index 4b83ef43..da5e8c7a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -88,11 +88,11 @@ members = [ ] [patch.crates-io] -collab = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "ad656c" } -collab-client-ws = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "ad656c" } -collab-sync= { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "ad656c" } -collab-persistence = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "ad656c" } -collab-plugins = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "ad656c" } +collab = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "47a4e9" } +collab-client-ws = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "47a4e9" } +collab-sync= { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "47a4e9" } +collab-persistence = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "47a4e9" } +collab-plugins = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "47a4e9" } #collab = { path = "./crates/AppFlowy-Collab/collab" } #collab-client-ws = { path = "./crates/AppFlowy-Collab/collab-client-ws" } diff --git a/crates/websocket/src/client.rs b/crates/websocket/src/client.rs index 3ce8d17b..8da31106 100644 --- a/crates/websocket/src/client.rs +++ b/crates/websocket/src/client.rs @@ -54,7 +54,6 @@ impl CollabSession { fn forward_binary_to_ws_server(&self, bytes: Bytes) { match WSMessage::from_vec(bytes.to_vec()) { Ok(ws_message) => { - tracing::trace!("[WSClient]: forward message to server"); let collab_msg = CollabMessage::from_vec(&ws_message.payload).unwrap(); self.server.do_send(ClientMessage { business_id: ws_message.business_id, @@ -110,7 +109,6 @@ impl Handler for CollabSession { type Result = (); fn handle(&mut self, server_msg: ServerMessage, ctx: &mut Self::Context) { - tracing::trace!("[WSClient]: forward message to client"); ctx.binary(WSMessage::from(server_msg)); } } diff --git a/crates/websocket/src/entities.rs b/crates/websocket/src/entities.rs index 89e82d2b..f9823d6c 100644 --- a/crates/websocket/src/entities.rs +++ b/crates/websocket/src/entities.rs @@ -52,7 +52,7 @@ pub struct Disconnect { #[derive(Debug, Message, Clone)] #[rtype(result = "()")] pub struct ClientMessage { - pub business_id: String, + pub business_id: u8, pub user: Arc, pub collab_msg: CollabMessage, } @@ -60,13 +60,15 @@ pub struct ClientMessage { #[derive(Debug, Message, Clone)] #[rtype(result = "()")] pub struct ServerMessage { - pub business_id: String, + pub business_id: u8, + pub object_id: String, pub payload: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct WSMessage { - pub business_id: String, + pub business_id: u8, + pub object_id: String, pub payload: Vec, } @@ -87,6 +89,7 @@ impl From for WSMessage { fn from(server_msg: ServerMessage) -> Self { Self { business_id: server_msg.business_id, + object_id: server_msg.object_id, payload: server_msg.payload, } } @@ -95,7 +98,8 @@ impl From for WSMessage { impl From for WSMessage { fn from(msg: CollabMessage) -> Self { Self { - business_id: msg.business_id().to_string(), + business_id: msg.business_id(), + object_id: msg.object_id().to_string(), payload: msg.to_vec(), } } @@ -104,7 +108,8 @@ impl From for WSMessage { impl From for ServerMessage { fn from(msg: CollabMessage) -> Self { Self { - business_id: msg.business_id().to_string(), + business_id: msg.business_id(), + object_id: msg.object_id().to_string(), payload: msg.to_vec(), } } @@ -114,6 +119,7 @@ impl From for WSMessage { fn from(client_msg: ClientMessage) -> Self { Self { business_id: client_msg.business_id, + object_id: client_msg.collab_msg.object_id().to_string(), payload: client_msg.collab_msg.to_vec(), } } diff --git a/crates/websocket/src/error.rs b/crates/websocket/src/error.rs index 6564f12c..eb17cc44 100644 --- a/crates/websocket/src/error.rs +++ b/crates/websocket/src/error.rs @@ -1,8 +1,5 @@ -#[derive(Debug, thiserror::Error)] +#[derive(Debug, Clone, thiserror::Error)] pub enum WSError { - #[error(transparent)] - Persistence(#[from] collab_plugins::disk::error::PersistenceError), - - #[error("Internal failure: {0}")] - Internal(#[from] Box), + #[error("Internal failure:{0}")] + Internal(String), } diff --git a/crates/websocket/src/server.rs b/crates/websocket/src/server.rs index a34dcd2a..491a4c95 100644 --- a/crates/websocket/src/server.rs +++ b/crates/websocket/src/server.rs @@ -19,8 +19,7 @@ use parking_lot::{Mutex, RwLock}; use std::collections::HashMap; use std::sync::Arc; -use tokio::sync::mpsc::Sender; -use tokio_stream::wrappers::ReceiverStream; +use tokio_stream::wrappers::{BroadcastStream, ReceiverStream}; use tokio_stream::StreamExt; #[derive(Clone)] @@ -53,10 +52,14 @@ impl CollabServer { fn create_collab_id(&self, object_id: &str) -> Result { let collab_id = self.collab_id_gen.lock().next_id(); let collab_key = make_collab_id_key(object_id.as_ref()); - self.db.with_write_txn(|w_txn| { - w_txn.insert(collab_key.as_ref(), collab_id.to_be_bytes())?; - Ok(()) - })?; + self + .db + .with_write_txn(|w_txn| { + w_txn.insert(collab_key.as_ref(), collab_id.to_be_bytes())?; + Ok(()) + }) + .map_err(|e| WSError::Internal(e.to_string()))?; + tracing::trace!("[WSServer]: Create new collab id: {}", collab_id); Ok(collab_id) } @@ -93,6 +96,11 @@ impl CollabServer { if self.collab_groups.read().contains_key(&collab_id) { return; } + tracing::trace!( + "[WSServer]: Create new group: collab_id:{} object_id:{}", + collab_id, + object_id + ); let collab = MutexCollab::new(CollabOrigin::Empty, object_id, vec![]); let plugin = RocksdbServerDiskPlugin::new(collab_id, self.db.clone()).unwrap(); @@ -119,13 +127,7 @@ impl Handler for CollabServer { fn handle(&mut self, new_conn: Connect, _ctx: &mut Context) -> Self::Result { tracing::trace!("[WSServer]: {} connect", new_conn.user); - // When receive a new connection, create a new [ClientStream] that holds the connection's websocket - let (stream_tx, stream_rx) = tokio::sync::mpsc::channel(1000); - let stream = WSClientStream::new( - ClientSink(new_conn.socket), - ReceiverStream::new(stream_rx), - stream_tx, - ); + let stream = WSClientStream::new(ClientSink(new_conn.socket)); self.client_streams.write().insert(new_conn.user, stream); Ok(()) } @@ -149,17 +151,33 @@ impl Handler for CollabServer { // Also create a new [CollabGroup] for the collab_id if it is not exist. if let Ok(collab_id) = self.get_or_create_collab_id(object_id) { if let Some(collab_group) = self.collab_groups.write().get_mut(&collab_id) { - if let Some(client_stream) = self.client_streams.write().get_mut(&client_msg.user) { - // If the client's stream is not subscribed to the collab group, subscribe it. - if let Some((sink, stream)) = client_stream.split() { - let origin = match client_msg.collab_msg.origin() { - None => CollabOrigin::Empty, - Some(client) => client.clone(), - }; - let sub = collab_group - .broadcast - .subscribe(origin.clone(), sink, stream); - collab_group.subscribers.insert(origin, sub); + let origin = match client_msg.collab_msg.origin() { + None => { + tracing::error!("🔴The origin from client message is empty"); + CollabOrigin::Empty + }, + Some(client) => client.clone(), + }; + + let is_subscribe = collab_group.subscribers.get(&origin).is_some(); + // If the client's stream is not subscribed to the collab group, subscribe it. + if !is_subscribe { + if let Some(client_stream) = self.client_streams.write().get_mut(&client_msg.user) { + if let Some((sink, stream)) = client_stream.stream_object::( + object_id.to_string(), + move |object_id, msg| msg.object_id() == object_id, + move |object_id, msg| msg.object_id == object_id, + ) { + tracing::trace!( + "[WSServer]: {} subscribe group:{}", + client_msg.user, + collab_id + ); + let subscription = collab_group + .broadcast + .subscribe(origin.clone(), sink, stream); + collab_group.subscribers.insert(origin, subscription); + } } } } @@ -168,16 +186,17 @@ impl Handler for CollabServer { Box::pin(async move { if let Some(client_stream) = client_streams.read().get(&client_msg.user) { tracing::trace!( - "[WSServer]: receives client message: {:?}", + "[WSServer]: receives: [collab_id:{}|oid:{}|msg_id:{:?}]", + collab_id, + client_msg.collab_msg.object_id(), client_msg.collab_msg.msg_id() ); match client_stream .stream_tx .send(Ok(WSMessage::from(client_msg))) - .await { Ok(_) => {}, - Err(e) => tracing::trace!("send error: {:?}", e), + Err(e) => tracing::error!("🔴send error: {:?}", e), } } }) @@ -194,37 +213,41 @@ impl actix::Supervised for CollabServer { } pub struct WSClientStream { - sink: Option, - stream: Option>>, - stream_tx: Sender>, + sink: ClientSink, + stream_tx: tokio::sync::broadcast::Sender>, } impl WSClientStream { - pub fn new( - sink: ClientSink, - stream: ReceiverStream>, - stream_tx: Sender>, - ) -> Self { - Self { - sink: Some(sink), - stream: Some(stream), - stream_tx, - } + pub fn new(sink: ClientSink) -> Self { + // When receive a new connection, create a new [ClientStream] that holds the connection's websocket + let (stream_tx, _) = tokio::sync::broadcast::channel(1000); + Self { sink, stream_tx } } + /// Returns a [UnboundedSenderSink] and a [ReceiverStream] for the object_id. #[allow(clippy::type_complexity)] - pub fn split(&mut self) -> Option<(UnboundedSenderSink, ReceiverStream>)> + pub fn stream_object( + &mut self, + object_id: String, + sink_filter: F1, + stream_filter: F2, + ) -> Option<(UnboundedSenderSink, ReceiverStream>)> where T: TryFrom + Into + Send + Sync + 'static, + F1: Fn(&str, &T) -> bool + Send + Sync + 'static, + F2: Fn(&str, &WSMessage) -> bool + Send + Sync + 'static, { - let client_sink = self.sink.take()?; - let mut stream = self.stream.take()?; + let client_sink = self.sink.clone(); + let mut stream = BroadcastStream::new(self.stream_tx.subscribe()); + let cloned_object_id = object_id.clone(); // forward sink let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); tokio::spawn(async move { while let Some(msg) = rx.recv().await { - client_sink.do_send(msg.into()); + if sink_filter(&cloned_object_id, &msg) { + client_sink.do_send(msg.into()); + } } }); let sink = UnboundedSenderSink::::new(tx); @@ -232,8 +255,10 @@ impl WSClientStream { // forward stream let (tx, rx) = tokio::sync::mpsc::channel(100); tokio::spawn(async move { - while let Some(Ok(msg)) = stream.next().await { - let _ = tx.send(T::try_from(msg)).await; + while let Some(Ok(Ok(msg))) = stream.next().await { + if stream_filter(&object_id, &msg) { + let _ = tx.send(T::try_from(msg)).await; + } } }); let stream = ReceiverStream::new(rx); @@ -246,6 +271,6 @@ impl TryFrom for CollabMessage { type Error = WSError; fn try_from(value: WSMessage) -> Result { - CollabMessage::from_vec(&value.payload).map_err(|e| WSError::Internal(Box::new(e))) + CollabMessage::from_vec(&value.payload).map_err(|e| WSError::Internal(e.to_string())) } } diff --git a/tests/util/test_server.rs b/tests/util/test_server.rs index 69a67a95..5572abe8 100644 --- a/tests/util/test_server.rs +++ b/tests/util/test_server.rs @@ -207,7 +207,7 @@ impl TestUser { pub fn generate() -> Self { Self { name: "Me".to_string(), - email: "me@appflowy.io".to_string(), + email: format!("{}@appflowy.io", Uuid::new_v4()), password: "Hello@AppFlowy123".to_string(), } } diff --git a/tests/ws/client.rs b/tests/ws/client.rs deleted file mode 100644 index 039bed41..00000000 --- a/tests/ws/client.rs +++ /dev/null @@ -1,98 +0,0 @@ -use collab::core::collab::MutexCollab; -use collab::core::origin::{CollabClient, CollabOrigin}; - -use collab_client_ws::{WSBusinessHandler, WSClient, WSClientConfig}; -use collab_plugins::disk::kv::rocks_kv::RocksCollabDB; -use collab_plugins::disk::rocksdb::RocksdbDiskPlugin; -use collab_plugins::sync::SyncPlugin; -use std::net::SocketAddr; -use std::ops::Deref; -use std::path::PathBuf; -use std::sync::Arc; -use std::time::Duration; -use tempfile::TempDir; - -pub async fn spawn_client( - uid: i64, - object_id: &str, - address: String, -) -> std::io::Result { - let ws_client = WSClient::new(address, WSClientConfig::default()); - let addr = ws_client.connect().await.unwrap().unwrap(); - let origin = origin_from_tcp_stream(&addr); - let handler = ws_client - .subscribe_business("collab".to_string()) - .await - .unwrap(); - - // - let (sink, stream) = (handler.sink(), handler.stream()); - let collab = Arc::new(MutexCollab::new(origin.clone(), object_id, vec![])); - let sync_plugin = SyncPlugin::new(origin, object_id, collab.clone(), sink, stream); - collab.lock().add_plugin(Arc::new(sync_plugin)); - - // disk - let tempdir = TempDir::new().unwrap(); - let db_path = tempdir.into_path(); - let db = Arc::new(RocksCollabDB::open(db_path.clone()).unwrap()); - let disk_plugin = RocksdbDiskPlugin::new(uid, db.clone()).unwrap(); - collab.lock().add_plugin(Arc::new(disk_plugin)); - collab.initial(); - - let cleaner = Cleaner::new(db_path); - Ok(TestClient { - ws_client, - db, - collab, - cleaner, - handlers: vec![handler], - }) -} - -fn origin_from_tcp_stream(addr: &SocketAddr) -> CollabOrigin { - let origin = CollabClient::new(addr.port() as i64, &addr.to_string()); - CollabOrigin::Client(origin) -} - -pub struct TestClient { - #[allow(dead_code)] - ws_client: WSClient, - pub db: Arc, - pub collab: Arc, - - #[allow(dead_code)] - cleaner: Cleaner, - - #[allow(dead_code)] - handlers: Vec>, -} - -struct Cleaner(PathBuf); - -impl Cleaner { - fn new(dir: PathBuf) -> Self { - Cleaner(dir) - } - - fn cleanup(dir: &PathBuf) { - let _ = std::fs::remove_dir_all(dir); - } -} - -impl Drop for Cleaner { - fn drop(&mut self) { - Self::cleanup(&self.0) - } -} - -impl Deref for TestClient { - type Target = Arc; - - fn deref(&self) -> &Self::Target { - &self.collab - } -} - -pub async fn wait(secs: u64) { - tokio::time::sleep(Duration::from_secs(secs)).await; -} diff --git a/tests/ws/mod.rs b/tests/ws/mod.rs index 44a68920..342983cc 100644 --- a/tests/ws/mod.rs +++ b/tests/ws/mod.rs @@ -1,3 +1,4 @@ -mod client; -mod test; +mod multiple_client_doc_test; +mod one_client_doc_test; +mod script; mod ws_reconnect; diff --git a/tests/ws/multiple_client_doc_test.rs b/tests/ws/multiple_client_doc_test.rs new file mode 100644 index 00000000..68579691 --- /dev/null +++ b/tests/ws/multiple_client_doc_test.rs @@ -0,0 +1,57 @@ +use crate::ws::script::{ScriptTest, TestScript::*}; +use serde_json::json; + +#[actix_rt::test] +async fn client_with_multiple_objects_test() { + let mut test = ScriptTest::new().await; + test + .run_scripts(vec![ + CreateClient { uid: 0 }, + OpenObject { + uid: 0, + object_id: "1".to_string(), + }, + OpenObject { + uid: 0, + object_id: "2".to_string(), + }, + ModifyClientCollab { + uid: 0, + object_id: "1".to_string(), + f: |collab| { + collab.insert("1", "a"); + }, + }, + ModifyClientCollab { + uid: 0, + object_id: "2".to_string(), + f: |collab| { + collab.insert("2", "b"); + }, + }, + Wait { secs: 2 }, + AssertClientContent { + uid: 0, + object_id: "1".to_string(), + expected: json!({ + "1": "a" + }), + }, + AssertClientEqualToServer { + uid: 0, + object_id: "1".to_string(), + }, + AssertClientContent { + uid: 0, + object_id: "2".to_string(), + expected: json!({ + "2": "b" + }), + }, + AssertClientEqualToServer { + uid: 0, + object_id: "2".to_string(), + }, + ]) + .await; +} diff --git a/tests/ws/one_client_doc_test.rs b/tests/ws/one_client_doc_test.rs new file mode 100644 index 00000000..fd17e7d3 --- /dev/null +++ b/tests/ws/one_client_doc_test.rs @@ -0,0 +1,120 @@ +use crate::ws::script::{ScriptTest, TestScript::*}; +use serde_json::json; + +#[actix_rt::test] +async fn single_client_connect_test() { + let mut test = ScriptTest::new().await; + test + .run_scripts(vec![ + CreateClient { uid: 0 }, + OpenObject { + uid: 0, + object_id: "1".to_string(), + }, + ModifyClientCollab { + uid: 0, + object_id: "1".to_string(), + f: |collab| { + collab.insert("1", "a"); + }, + }, + Wait { secs: 1 }, + AssertClientContent { + uid: 0, + object_id: "1".to_string(), + expected: json!({ + "1": "a" + }), + }, + AssertClientEqualToServer { + uid: 0, + object_id: "1".to_string(), + }, + ]) + .await; +} + +#[actix_rt::test] +async fn client_single_write_test() { + let mut test = ScriptTest::new().await; + test + .run_scripts(vec![ + CreateClient { uid: 0 }, + CreateClient { uid: 1 }, + OpenObject { + uid: 0, + object_id: "1".to_string(), + }, + OpenObject { + uid: 1, + object_id: "1".to_string(), + }, + ModifyClientCollab { + uid: 0, + object_id: "1".to_string(), + f: |collab| { + collab.insert("1", "a"); + }, + }, + Wait { secs: 2 }, + AssertClientEqualToServer { + uid: 0, + object_id: "1".to_string(), + }, + AssertClientEqualToServer { + uid: 1, + object_id: "1".to_string(), + }, + ]) + .await; +} + +#[actix_rt::test] +async fn client_multiple_write_test() { + let mut test = ScriptTest::new().await; + test + .run_scripts(vec![ + CreateClient { uid: 0 }, + CreateClient { uid: 1 }, + OpenObject { + uid: 0, + object_id: "1".to_string(), + }, + OpenObject { + uid: 1, + object_id: "1".to_string(), + }, + ModifyClientCollab { + uid: 0, + object_id: "1".to_string(), + f: |collab| { + collab.insert("1", "a"); + }, + }, + Wait { secs: 1 }, + ModifyClientCollab { + uid: 1, + object_id: "1".to_string(), + f: |collab| { + collab.insert("2", "b"); + }, + }, + Wait { secs: 1 }, + AssertServerContent { + object_id: "1".to_string(), + expected: json!({ + "1": "a", + "2": "b" + }), + }, + AssertClientEqualToServer { + uid: 0, + object_id: "1".to_string(), + }, + AssertClientEqualToServer { + uid: 1, + object_id: "1".to_string(), + }, + ]) + .await; +} diff --git a/tests/ws/script.rs b/tests/ws/script.rs new file mode 100644 index 00000000..6de4e90d --- /dev/null +++ b/tests/ws/script.rs @@ -0,0 +1,202 @@ +use crate::util::{spawn_server, TestServer, TestUser}; + +use collab::core::collab::MutexCollab; +use collab::core::origin::{CollabClient, CollabOrigin}; +use collab::preclude::Collab; +use collab_client_ws::{WSClient, WSClientConfig, WSObjectHandler}; +use collab_plugins::disk::kv::rocks_kv::RocksCollabDB; +use collab_plugins::disk::rocksdb::RocksdbDiskPlugin; +use collab_plugins::sync::SyncPlugin; +use serde_json::Value; +use std::collections::HashMap; +use std::net::SocketAddr; + +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; +use tempfile::TempDir; +use tokio::sync::RwLock; + +pub enum TestScript { + CreateClient { + uid: i64, + }, + OpenObject { + uid: i64, + object_id: String, + }, + AssertClientContent { + uid: i64, + object_id: String, + expected: Value, + }, + Wait { + secs: u64, + }, + AssertServerContent { + object_id: String, + expected: Value, + }, + ModifyClientCollab { + uid: i64, + object_id: String, + f: fn(&Collab), + }, + AssertClientEqualToServer { + uid: i64, + object_id: String, + }, +} + +pub struct ScriptTest { + server: TestServer, + pub clients: RwLock>, +} + +impl ScriptTest { + pub async fn new() -> Self { + let server = spawn_server().await; + ScriptTest { + server, + clients: RwLock::new(HashMap::new()), + } + } + + async fn get_client_doc_value(&self, uid: i64, object_id: &str) -> Value { + self + .clients + .read() + .await + .get(&uid) + .unwrap() + .collab_by_object_id + .get(object_id) + .unwrap() + .lock() + .to_json_value() + } + + pub async fn run_script(&mut self, script: TestScript) { + match script { + TestScript::CreateClient { uid } => { + let test_user = TestUser::generate(); + let token = test_user.register(&self.server).await; + let address = format!("{}/{}", self.server.ws_addr, token); + let ws = WSClient::new(address, WSClientConfig::default()); + let addr = ws.connect().await.unwrap().unwrap(); + let origin = origin_from_tcp_stream(&addr); + let tempdir = TempDir::new().unwrap(); + let db_path = tempdir.into_path(); + let db = Arc::new(RocksCollabDB::open(db_path.clone()).unwrap()); + let cleaner = Cleaner::new(db_path); + let client = TestClient { + ws, + db, + origin, + collab_by_object_id: Default::default(), + handlers: vec![], + cleaner, + }; + self.clients.write().await.insert(uid, client); + }, + TestScript::OpenObject { uid, object_id } => { + let mut clients = self.clients.write().await; + let client = clients.get_mut(&uid).unwrap(); + let handler = client.ws.subscribe(1, object_id.clone()).await.unwrap(); + let (sink, stream) = (handler.sink(), handler.stream()); + let collab = Arc::new(MutexCollab::new(client.origin.clone(), &object_id, vec![])); + + // Sync + let sync_plugin = SyncPlugin::new( + client.origin.clone(), + &object_id, + collab.clone(), + sink, + stream, + ); + collab.lock().add_plugin(Arc::new(sync_plugin)); + + // Disk + let disk_plugin = RocksdbDiskPlugin::new(uid, client.db.clone()).unwrap(); + collab.lock().add_plugin(Arc::new(disk_plugin)); + + collab.initial(); + client.handlers.push(handler); + client.collab_by_object_id.insert(object_id, collab); + }, + TestScript::Wait { secs } => { + tokio::time::sleep(Duration::from_secs(secs)).await; + }, + TestScript::AssertClientContent { + uid, + object_id, + expected, + } => { + let value = self.get_client_doc_value(uid, &object_id).await; + assert_json_diff::assert_json_eq!(value, expected); + }, + TestScript::AssertServerContent { + object_id, + expected, + } => { + let value = self.server.get_doc(&object_id); + assert_json_diff::assert_json_eq!(value, expected); + }, + TestScript::AssertClientEqualToServer { uid, object_id } => { + let server_value = self.server.get_doc(&object_id); + let client_value = self.get_client_doc_value(uid, &object_id).await; + assert_eq!(client_value, server_value); + assert_json_diff::assert_json_eq!(client_value, server_value); + }, + TestScript::ModifyClientCollab { uid, object_id, f } => { + let mut clients = self.clients.write().await; + let client = clients.get_mut(&uid).unwrap(); + let collab = client + .collab_by_object_id + .get_mut(&object_id) + .unwrap() + .lock(); + f(&collab); + }, + } + } + + pub async fn run_scripts(&mut self, scripts: Vec) { + for script in scripts { + self.run_script(script).await; + } + } +} + +fn origin_from_tcp_stream(addr: &SocketAddr) -> CollabOrigin { + let origin = CollabClient::new(addr.port() as i64, &addr.to_string()); + CollabOrigin::Client(origin) +} + +pub struct TestClient { + pub ws: WSClient, + pub db: Arc, + pub origin: CollabOrigin, + pub collab_by_object_id: HashMap>, + pub handlers: Vec>, + #[allow(dead_code)] + cleaner: Cleaner, +} + +struct Cleaner(PathBuf); + +impl Cleaner { + fn new(dir: PathBuf) -> Self { + Cleaner(dir) + } + + fn cleanup(dir: &PathBuf) { + let _ = std::fs::remove_dir_all(dir); + } +} + +impl Drop for Cleaner { + fn drop(&mut self) { + Self::cleanup(&self.0) + } +} diff --git a/tests/ws/test.rs b/tests/ws/test.rs deleted file mode 100644 index 9b3fb6f6..00000000 --- a/tests/ws/test.rs +++ /dev/null @@ -1,27 +0,0 @@ -use crate::util::{spawn_server, TestUser}; -use crate::ws::client::{spawn_client, wait}; -use serde_json::json; - -#[actix_rt::test] -async fn ws_conn_test() { - let server = spawn_server().await; - let test_user = TestUser::generate(); - let token = test_user.register(&server).await; - let address = format!("{}/{}", server.ws_addr, token); - let client = spawn_client(1, "1", address).await.unwrap(); - - wait(1).await; - { - let collab = client.lock(); - collab.insert("1", "a"); - } - wait(1).await; - - let value = server.get_doc("1"); - assert_json_diff::assert_json_eq!( - value, - json!({ - "1": "a" - }) - ); -} diff --git a/tests/ws/ws_reconnect.rs b/tests/ws/ws_reconnect.rs index 1cdf351a..c7f0e688 100644 --- a/tests/ws/ws_reconnect.rs +++ b/tests/ws/ws_reconnect.rs @@ -1,4 +1,5 @@ use crate::util::{spawn_server, TestUser}; +use std::time::Duration; use collab_client_ws::{WSClient, WSClientConfig}; @@ -18,5 +19,10 @@ async fn ws_retry_connect() { }, ); let _addr = ws_client.connect().await.unwrap().unwrap(); - // wait(20).await; + // wait(10).await; +} + +#[allow(dead_code)] +async fn wait(secs: u64) { + tokio::time::sleep(Duration::from_secs(secs)).await; }