diff --git a/Cargo.lock b/Cargo.lock index 278b5d85..fca046b4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1467,6 +1467,8 @@ name = "collab-stream" version = "0.1.0" dependencies = [ "anyhow", + "futures", + "rand 0.8.5", "redis 0.24.0", "thiserror", "tokio", diff --git a/libs/collab-stream/Cargo.toml b/libs/collab-stream/Cargo.toml index ac930cd6..8ea7ad16 100644 --- a/libs/collab-stream/Cargo.toml +++ b/libs/collab-stream/Cargo.toml @@ -11,3 +11,7 @@ tokio = { version = "1.26", features = ["rt-multi-thread", "macros" ] } tokio-stream = { version = "0.1.14" } thiserror = "1.0.58" anyhow = "1.0.81" + +[dev-dependencies] +futures = "0.3.30" +rand = "0.8.5" \ No newline at end of file diff --git a/libs/collab-stream/src/client.rs b/libs/collab-stream/src/client.rs index e9f8475e..9464873e 100644 --- a/libs/collab-stream/src/client.rs +++ b/libs/collab-stream/src/client.rs @@ -1,4 +1,6 @@ +use crate::error::StreamError; use crate::stream::CollabStream; +use crate::stream_group::CollabStreamGroup; use redis::aio::ConnectionManager; pub struct CollabStreamClient { @@ -12,6 +14,22 @@ impl CollabStreamClient { } pub async fn stream(&self, workspace_id: &str, oid: &str) -> CollabStream { - CollabStream::new(self.connection_manager.clone(), workspace_id, oid) + CollabStream::new(workspace_id, oid, self.connection_manager.clone()) + } + + pub async fn group_stream( + &self, + workspace_id: &str, + oid: &str, + group_name: &str, + ) -> Result { + let mut group = CollabStreamGroup::new( + workspace_id, + oid, + group_name, + self.connection_manager.clone(), + ); + group.ensure_consumer_group("0").await?; + Ok(group) } } diff --git a/libs/collab-stream/src/error.rs b/libs/collab-stream/src/error.rs index 07ef4ea6..cf1f3190 100644 --- a/libs/collab-stream/src/error.rs +++ b/libs/collab-stream/src/error.rs @@ -17,6 +17,9 @@ pub enum StreamError { #[error(transparent)] ParseIntError(#[from] std::num::ParseIntError), + #[error("Stream group already exists")] + GroupAlreadyExists(String), + #[error("Internal error: {0}")] Internal(anyhow::Error), } diff --git a/libs/collab-stream/src/lib.rs b/libs/collab-stream/src/lib.rs index 0ef0f1e0..07ac461d 100644 --- a/libs/collab-stream/src/lib.rs +++ b/libs/collab-stream/src/lib.rs @@ -2,3 +2,4 @@ pub mod client; mod error; pub mod model; pub mod stream; +pub mod stream_group; diff --git a/libs/collab-stream/src/model.rs b/libs/collab-stream/src/model.rs index 9402b6aa..0443ec5b 100644 --- a/libs/collab-stream/src/model.rs +++ b/libs/collab-stream/src/model.rs @@ -55,6 +55,9 @@ pub struct MessageReadByStreamKey(pub BTreeMap>); impl FromRedisValue for MessageReadByStreamKey { fn from_redis_value(v: &Value) -> RedisResult { let mut map: BTreeMap> = BTreeMap::new(); + if matches!(v, Value::Nil) { + return Ok(MessageReadByStreamKey(map)); + } let value_by_id = bulk_from_redis_value(v)?.iter(); for value in value_by_id { diff --git a/libs/collab-stream/src/stream.rs b/libs/collab-stream/src/stream.rs index 17c1aa81..e7ccfa55 100644 --- a/libs/collab-stream/src/stream.rs +++ b/libs/collab-stream/src/stream.rs @@ -10,7 +10,7 @@ pub struct CollabStream { } impl CollabStream { - pub fn new(connection_manager: ConnectionManager, workspace_id: &str, oid: &str) -> Self { + pub fn new(workspace_id: &str, oid: &str, connection_manager: ConnectionManager) -> Self { let stream_key = format!("af_collab-{}-{}", workspace_id, oid); Self { connection_manager, @@ -67,7 +67,7 @@ impl CollabStream { .map(|ct| format!("{}-{}", ct.timestamp_ms, ct.sequence_number)) .unwrap_or_else(|| "$".to_string()); - let options = StreamReadOptions::default().block(100); + let options = StreamReadOptions::default().group("1", "2").block(100); let map: MessageReadByStreamKey = self .connection_manager .xread_options(&[&self.stream_key], &[&id], &options) diff --git a/libs/collab-stream/src/stream_group.rs b/libs/collab-stream/src/stream_group.rs new file mode 100644 index 00000000..6826e4fb --- /dev/null +++ b/libs/collab-stream/src/stream_group.rs @@ -0,0 +1,112 @@ +use crate::error::StreamError; +use crate::model::{CreatedTime, Message, MessageRead, MessageReadByStreamKey}; +use redis::aio::ConnectionManager; +use redis::streams::{StreamMaxlen, StreamReadOptions}; +use redis::{pipe, AsyncCommands, RedisError, RedisResult}; + +#[derive(Clone)] +pub struct CollabStreamGroup { + connection_manager: ConnectionManager, + stream_key: String, + group_name: String, +} + +impl CollabStreamGroup { + pub fn new( + workspace_id: &str, + oid: &str, + group_name: &str, + connection_manager: ConnectionManager, + ) -> Self { + let group_name = group_name.to_string(); + let stream_key = format!("af_collab-{}-{}", workspace_id, oid); + Self { + group_name, + connection_manager, + stream_key, + } + } + + /// Ensures the consumer group exists, creating it if necessary. + /// start_id: + /// Use '$' if you want new messages or '0' to read from the beginning. + pub async fn ensure_consumer_group(&mut self, start_id: &str) -> Result<(), StreamError> { + let _: RedisResult<()> = self + .connection_manager + .xgroup_create_mkstream(&self.stream_key, &self.group_name, start_id) + .await; + + Ok(()) + } + + /// Acknowledges messages processed by a consumer. + pub async fn ack_messages(&mut self, message_ids: &[String]) -> Result<(), StreamError> { + self + .connection_manager + .xack(&self.stream_key, &self.group_name, message_ids) + .await?; + Ok(()) + } + + /// Inserts a single message into the Redis stream. + pub async fn insert_message(&mut self, message: Message) -> Result { + let tuple = message.into_tuple_array(); + let created_time = self + .connection_manager + .xadd(&self.stream_key, "*", tuple.as_slice()) + .await?; + Ok(created_time) + } + + /// Inserts multiple messages into the Redis stream using a pipeline. + /// + pub async fn insert_messages(&mut self, messages: Vec) -> Result<(), StreamError> { + let mut pipe = pipe(); + for message in messages { + let tuple = message.into_tuple_array(); + pipe.xadd(&self.stream_key, "*", tuple.as_slice()); + } + pipe.query_async(&mut self.connection_manager).await?; + Ok(()) + } + + /// Fetches number of messages from a Redis stream + /// Returns the messages that were not consumed yet. Which means each message is delivered to only + /// one consumer in the group + pub async fn fetch_messages( + &mut self, + consumer_name: &str, + count: usize, + ) -> Result, StreamError> { + let options = StreamReadOptions::default() + .group(&self.group_name, consumer_name) + .count(count) + .block(100); + + let map: MessageReadByStreamKey = self + .connection_manager + .xread_options(&[&self.stream_key], &[">"], &options) + .await?; + + match map.0.into_iter().next() { + None => Ok(Vec::with_capacity(0)), + Some((_, messages)) => Ok(messages.into_iter().map(Into::into).collect()), + } + } + + /// Reads all messages from the stream + /// + pub async fn read_all_message(&mut self) -> Result, StreamError> { + let read_messages: Vec = + self.connection_manager.xrange_all(&self.stream_key).await?; + Ok(read_messages.into_iter().map(Into::into).collect()) + } + + pub async fn clear(&mut self) -> Result<(), RedisError> { + self + .connection_manager + .xtrim(&self.stream_key, StreamMaxlen::Equals(0)) + .await?; + Ok(()) + } +} diff --git a/libs/collab-stream/tests/stream_test/group_read_test.rs b/libs/collab-stream/tests/stream_test/group_read_test.rs new file mode 100644 index 00000000..aeeb213c --- /dev/null +++ b/libs/collab-stream/tests/stream_test/group_read_test.rs @@ -0,0 +1,130 @@ +use crate::stream_test::test_util::{random_i64, stream_client}; +use collab_stream::model::Message; +use futures::future::join; + +#[tokio::test] +async fn single_group_read_message_test() { + let workspace_id = "w1"; + let oid = format!("o{}", random_i64()); + let client = stream_client().await; + let mut group = client.group_stream(workspace_id, &oid, "g1").await.unwrap(); + + let random_uid = random_i64(); + let msg = Message { + uid: random_uid, + raw_data: vec![1, 2, 3, 4, 5], + }; + + { + let client = stream_client().await; + let mut group = client.group_stream(workspace_id, &oid, "g2").await.unwrap(); + group.insert_message(msg).await.unwrap(); + } + + let messages = group.fetch_messages("consumer1", 1).await.unwrap(); + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].raw_data, vec![1, 2, 3, 4, 5]); + assert_eq!(messages[0].uid, random_uid); + + // after the message was consumed, it should not be available anymore + assert!(group + .fetch_messages("consumer1", 1) + .await + .unwrap() + .is_empty()); +} + +#[tokio::test] +async fn different_group_read_message_test() { + let oid = format!("o{}", random_i64()); + let client = stream_client().await; + let mut group_1 = client.group_stream("w1", &oid, "g1").await.unwrap(); + let mut group_2 = client.group_stream("w1", &oid, "g2").await.unwrap(); + + let random_uid = random_i64(); + let msg = Message { + uid: random_uid, + raw_data: vec![1, 2, 3, 4, 5], + }; + + { + let client = stream_client().await; + let mut group = client.group_stream("w1", &oid, "g2").await.unwrap(); + group.insert_message(msg).await.unwrap(); + } + + let (result1, result2) = join( + group_1.fetch_messages("consumer1", 1), + group_2.fetch_messages("consumer1", 1), + ) + .await; + let group_1_messages = result1.unwrap(); + let group_2_messages = result2.unwrap(); + assert_eq!(group_1_messages[0].raw_data, vec![1, 2, 3, 4, 5]); + assert_eq!(group_2_messages[0].raw_data, vec![1, 2, 3, 4, 5]); +} + +#[tokio::test] +async fn read_specific_num_of_message_test() { + let object_id = format!("o{}", random_i64()); + let client = stream_client().await; + let mut group_1 = client.group_stream("w1", &object_id, "g1").await.unwrap(); + let mut uids = vec![]; + { + let client = stream_client().await; + let mut group = client.group_stream("w1", &object_id, "g2").await.unwrap(); + let mut messages = vec![]; + for _i in 0..5 { + let random_uid = random_i64(); + uids.push(random_uid); + let msg = Message { + uid: random_uid, + raw_data: vec![1, 2, 3, 4, 5], + }; + messages.push(msg); + } + group.insert_messages(messages).await.unwrap(); + } + + let messages = group_1.fetch_messages("consumer1", 15).await.unwrap(); + assert_eq!(messages.len(), 5); + for i in 0..5 { + assert_eq!(messages[i].raw_data, vec![1, 2, 3, 4, 5]); + assert_eq!(messages[i].uid, uids[i]); + } +} + +#[tokio::test] +async fn read_all_message_test() { + let object_id = format!("o{}", random_i64()); + let client = stream_client().await; + let mut group = client.group_stream("w1", &object_id, "g1").await.unwrap(); + let mut uids = vec![]; + { + let client = stream_client().await; + let mut group = client.group_stream("w1", &object_id, "g2").await.unwrap(); + let mut messages = vec![]; + for _i in 0..5 { + let random_uid = random_i64(); + uids.push(random_uid); + let msg = Message { + uid: random_uid, + raw_data: vec![1, 2, 3, 4, 5], + }; + messages.push(msg); + } + group.insert_messages(messages).await.unwrap(); + } + + let messages = group.read_all_message().await.unwrap(); + let consumer_messages = group.fetch_messages("consumer1", 15).await.unwrap(); + assert_eq!(messages.len(), 5); + assert_eq!(consumer_messages.len(), 5); + for i in 0..5 { + assert_eq!(messages[i].raw_data, vec![1, 2, 3, 4, 5]); + assert_eq!(consumer_messages[i].raw_data, vec![1, 2, 3, 4, 5]); + + assert_eq!(messages[i].uid, uids[i]); + assert_eq!(consumer_messages[i].uid, uids[i]); + } +} diff --git a/libs/collab-stream/tests/stream_test/mod.rs b/libs/collab-stream/tests/stream_test/mod.rs index b843d0d2..eadfd2e3 100644 --- a/libs/collab-stream/tests/stream_test/mod.rs +++ b/libs/collab-stream/tests/stream_test/mod.rs @@ -1,2 +1,3 @@ +mod group_read_test; mod read_test; mod test_util; diff --git a/libs/collab-stream/tests/stream_test/read_test.rs b/libs/collab-stream/tests/stream_test/read_test.rs index b4913a92..374dfc6a 100644 --- a/libs/collab-stream/tests/stream_test/read_test.rs +++ b/libs/collab-stream/tests/stream_test/read_test.rs @@ -1,10 +1,11 @@ -use crate::stream_test::test_util::stream_client; +use crate::stream_test::test_util::{random_i64, stream_client}; use collab_stream::model::Message; #[tokio::test] async fn read_single_message_test() { + let oid = format!("o{}", random_i64()); let client_2 = stream_client().await; - let mut stream_2 = client_2.stream("w1", "o1").await; + let mut stream_2 = client_2.stream("w1", &oid).await; let (tx, mut rx) = tokio::sync::mpsc::channel(1); tokio::spawn(async move { @@ -19,7 +20,7 @@ async fn read_single_message_test() { { let client_1 = stream_client().await; - let mut stream_1 = client_1.stream("w1", "o1").await; + let mut stream_1 = client_1.stream("w1", &oid).await; stream_1.insert_message(msg).await.unwrap(); } @@ -29,13 +30,14 @@ async fn read_single_message_test() { #[tokio::test] async fn read_multiple_messages_test() { + let oid = format!("o{}", random_i64()); let client_2 = stream_client().await; - let mut stream_2 = client_2.stream("w1", "o1").await; + let mut stream_2 = client_2.stream("w1", &oid).await; stream_2.clear().await.unwrap(); { let client_1 = stream_client().await; - let mut stream_1 = client_1.stream("w1", "o1").await; + let mut stream_1 = client_1.stream("w1", &oid).await; let messages = vec![ Message { uid: 1001, diff --git a/libs/collab-stream/tests/stream_test/test_util.rs b/libs/collab-stream/tests/stream_test/test_util.rs index e3211f4f..7a14d5e7 100644 --- a/libs/collab-stream/tests/stream_test/test_util.rs +++ b/libs/collab-stream/tests/stream_test/test_util.rs @@ -1,5 +1,6 @@ use anyhow::Context; use collab_stream::client::CollabStreamClient; +use rand::{thread_rng, Rng}; pub async fn redis_client() -> redis::Client { let redis_uri = "redis://localhost:6379"; @@ -15,3 +16,9 @@ pub async fn stream_client() -> CollabStreamClient { .context("failed to create stream client") .unwrap() } + +pub fn random_i64() -> i64 { + let mut rng = thread_rng(); + let num: i64 = rng.gen(); + num +}