diff --git a/libs/collab-stream/src/model.rs b/libs/collab-stream/src/model.rs index 0443ec5b..5e8b5901 100644 --- a/libs/collab-stream/src/model.rs +++ b/libs/collab-stream/src/model.rs @@ -5,18 +5,49 @@ use std::str::FromStr; use crate::error::{internal, StreamError}; use redis::{FromRedisValue, RedisError, RedisResult, Value}; +/// The [MessageId] generated by XADD has two parts: a timestamp and a sequence number, separated by +/// a hyphen (-). The timestamp is based on the server's time when the message is added, and the +/// sequence number is used to differentiate messages added at the same millisecond. +/// +/// If multiple messages are added within the same millisecond, Redis increments the sequence number +/// for each subsequent message +/// +/// An example message ID might look like this: 1631020452097-0. In this example, 1631020452097 is +/// the timestamp in milliseconds, and 0 is the sequence number. #[derive(Debug)] -pub struct CreatedTime { +pub struct MessageId { pub timestamp_ms: u64, - - // applies if more than one message is sent at the same millisecond - // this distinguishes the order of the messages pub sequence_number: u16, } -impl CreatedTime { - fn from_redis_stream_key(bytes: &[u8]) -> Result { - let s = std::str::from_utf8(bytes)?; +impl ToString for MessageId { + fn to_string(&self) -> String { + format!("{}-{}", self.timestamp_ms, self.sequence_number) + } +} + +impl MessageId { + pub fn sub_ms(&mut self, ms: u64) { + if self.timestamp_ms < ms { + return; + } + self.timestamp_ms -= ms; + } +} + +impl TryFrom<&[u8]> for MessageId { + type Error = StreamError; + + fn try_from(s: &[u8]) -> Result { + let s = std::str::from_utf8(s)?; + Self::try_from(s) + } +} + +impl TryFrom<&str> for MessageId { + type Error = StreamError; + + fn try_from(s: &str) -> Result { let parts: Vec<_> = s.splitn(2, '-').collect(); if parts.len() != 2 { @@ -27,17 +58,17 @@ impl CreatedTime { let timestamp_ms = u64::from_str(parts[0])?; let sequence_number = u16::from_str(parts[1])?; - Ok(CreatedTime { + Ok(MessageId { timestamp_ms, sequence_number, }) } } -impl FromRedisValue for CreatedTime { +impl FromRedisValue for MessageId { fn from_redis_value(v: &Value) -> RedisResult { match v { - Value::Data(stream_key) => CreatedTime::from_redis_stream_key(stream_key).map_err(|_| { + Value::Data(stream_key) => MessageId::try_from(stream_key.as_slice()).map_err(|_| { RedisError::from(( redis::ErrorKind::TypeError, "invalid stream key", @@ -50,13 +81,13 @@ impl FromRedisValue for CreatedTime { } #[derive(Debug)] -pub struct MessageReadByStreamKey(pub BTreeMap>); +pub struct StreamMessageByStreamKey(pub BTreeMap>); -impl FromRedisValue for MessageReadByStreamKey { +impl FromRedisValue for StreamMessageByStreamKey { fn from_redis_value(v: &Value) -> RedisResult { - let mut map: BTreeMap> = BTreeMap::new(); + let mut map: BTreeMap> = BTreeMap::new(); if matches!(v, Value::Nil) { - return Ok(MessageReadByStreamKey(map)); + return Ok(StreamMessageByStreamKey(map)); } let value_by_id = bulk_from_redis_value(v)?.iter(); @@ -74,25 +105,26 @@ impl FromRedisValue for MessageReadByStreamKey { let stream_key = RedisString::from_redis_value(&key_values[0])?.0; let values = bulk_from_redis_value(&key_values[1])?.iter(); for value in values { - let value = MessageRead::from_redis_value(value)?; + let value = StreamMessage::from_redis_value(value)?; map.entry(stream_key.clone()).or_default().push(value); } } - Ok(MessageReadByStreamKey(map)) + Ok(StreamMessageByStreamKey(map)) } } +/// A message in the Redis stream. It's the same as [Message] but with additional metadata. #[derive(Debug)] -pub struct MessageRead { +pub struct StreamMessage { /// user who did the change pub uid: i64, pub raw_data: Vec, /// only applicable when reading from redis - pub created_time: CreatedTime, + pub message_id: MessageId, } -impl FromRedisValue for MessageRead { +impl FromRedisValue for StreamMessage { // Optimized parsing function fn from_redis_value(v: &Value) -> RedisResult { let bulk = bulk_from_redis_value(v)?; @@ -104,7 +136,7 @@ impl FromRedisValue for MessageRead { ))); } - let created_time = CreatedTime::from_redis_value(&bulk[0])?; + let created_time = MessageId::from_redis_value(&bulk[0])?; let fields = bulk_from_redis_value(&bulk[1])?; if fields.len() != 4 { return Err(RedisError::from(( @@ -119,10 +151,10 @@ impl FromRedisValue for MessageRead { verify_field(&fields[2], "data")?; let raw_data = Vec::::from_redis_value(&fields[3])?; - Ok(MessageRead { + Ok(StreamMessage { uid, raw_data, - created_time, + message_id: created_time, }) } @@ -136,8 +168,8 @@ pub struct Message { pub raw_data: Vec, } -impl From for Message { - fn from(m: MessageRead) -> Self { +impl From for Message { + fn from(m: StreamMessage) -> Self { Message { uid: m.uid, raw_data: m.raw_data, diff --git a/libs/collab-stream/src/stream.rs b/libs/collab-stream/src/stream.rs index e7ccfa55..c1cab6c5 100644 --- a/libs/collab-stream/src/stream.rs +++ b/libs/collab-stream/src/stream.rs @@ -1,5 +1,5 @@ use crate::error::StreamError; -use crate::model::{CreatedTime, Message, MessageRead, MessageReadByStreamKey}; +use crate::model::{Message, MessageId, StreamMessage, StreamMessageByStreamKey}; use redis::aio::ConnectionManager; use redis::streams::{StreamMaxlen, StreamReadOptions}; use redis::{pipe, AsyncCommands, RedisError}; @@ -19,13 +19,13 @@ impl CollabStream { } /// Inserts a single message into the Redis stream. - pub async fn insert_message(&mut self, message: Message) -> Result { + pub async fn insert_message(&mut self, message: Message) -> Result { let tuple = message.into_tuple_array(); - let created_time = self + let message_id = self .connection_manager .xadd(&self.stream_key, "*", tuple.as_slice()) .await?; - Ok(created_time) + Ok(message_id) } /// Inserts multiple messages into the Redis stream using a pipeline. @@ -42,9 +42,9 @@ impl CollabStream { /// Fetches the next message from a Redis stream after a specified entry. /// - pub async fn next(&mut self) -> Result, StreamError> { + pub async fn next(&mut self) -> Result, StreamError> { let options = StreamReadOptions::default().count(1).block(100); - let map: MessageReadByStreamKey = self + let map: StreamMessageByStreamKey = self .connection_manager .xread_options(&[&self.stream_key], &["$"], &options) .await?; @@ -56,21 +56,21 @@ impl CollabStream { .ok_or_else(|| StreamError::UnexpectedValue("Empty stream".into()))?; debug_assert_eq!(messages.len(), 1); - Ok(messages.pop().map(Into::into)) + Ok(messages.pop()) } pub async fn next_after( &mut self, - after: Option, - ) -> Result, StreamError> { - let id = after - .map(|ct| format!("{}-{}", ct.timestamp_ms, ct.sequence_number)) + after: Option, + ) -> Result, StreamError> { + let message_id = after + .map(|ct| ct.to_string()) .unwrap_or_else(|| "$".to_string()); let options = StreamReadOptions::default().group("1", "2").block(100); - let map: MessageReadByStreamKey = self + let map: StreamMessageByStreamKey = self .connection_manager - .xread_options(&[&self.stream_key], &[&id], &options) + .xread_options(&[&self.stream_key], &[&message_id], &options) .await?; let (_, mut messages) = map @@ -80,11 +80,11 @@ impl CollabStream { .ok_or_else(|| StreamError::UnexpectedValue("Empty stream".into()))?; debug_assert_eq!(messages.len(), 1); - Ok(messages.pop().map(Into::into)) + Ok(messages.pop()) } pub async fn read_all_message(&mut self) -> Result, StreamError> { - let read_messages: Vec = + let read_messages: Vec = self.connection_manager.xrange_all(&self.stream_key).await?; Ok(read_messages.into_iter().map(Into::into).collect()) } diff --git a/libs/collab-stream/src/stream_group.rs b/libs/collab-stream/src/stream_group.rs index 6826e4fb..a66dd30a 100644 --- a/libs/collab-stream/src/stream_group.rs +++ b/libs/collab-stream/src/stream_group.rs @@ -1,7 +1,7 @@ use crate::error::StreamError; -use crate::model::{CreatedTime, Message, MessageRead, MessageReadByStreamKey}; +use crate::model::{Message, MessageId, StreamMessage, StreamMessageByStreamKey}; use redis::aio::ConnectionManager; -use redis::streams::{StreamMaxlen, StreamReadOptions}; +use redis::streams::{StreamMaxlen, StreamPendingData, StreamPendingReply, StreamReadOptions}; use redis::{pipe, AsyncCommands, RedisError, RedisResult}; #[derive(Clone)] @@ -40,6 +40,12 @@ impl CollabStreamGroup { } /// Acknowledges messages processed by a consumer. + /// + /// In Redis streams, when a message is delivered to a consumer using XREADGROUP, it moves into + /// a pending state for that consumer. Redis expects you to manually acknowledge these messages + /// using XACK once they have been successfully processed. If you don't acknowledge a message, + /// it remains in the pending state for that consumer. Redis keeps track of these messages so you + /// can handle message failures or retries. pub async fn ack_messages(&mut self, message_ids: &[String]) -> Result<(), StreamError> { self .connection_manager @@ -49,13 +55,13 @@ impl CollabStreamGroup { } /// Inserts a single message into the Redis stream. - pub async fn insert_message(&mut self, message: Message) -> Result { + pub async fn insert_message(&mut self, message: Message) -> Result { let tuple = message.into_tuple_array(); - let created_time = self + let message_id = self .connection_manager .xadd(&self.stream_key, "*", tuple.as_slice()) .await?; - Ok(created_time) + Ok(message_id) } /// Inserts multiple messages into the Redis stream using a pipeline. @@ -73,33 +79,88 @@ impl CollabStreamGroup { /// Fetches number of messages from a Redis stream /// Returns the messages that were not consumed yet. Which means each message is delivered to only /// one consumer in the group - pub async fn fetch_messages( + /// + /// $: This symbol is used with the XREAD command to indicate that you want to start reading only + /// new messages that arrive in the stream after the read command has been issued. Essentially, + /// it tells Redis to ignore all the messages already in the stream and only listen for new ones. + /// It's particularly useful when you want to start processing messages from the current moment + /// forward and don't need to process historical messages. + /// + /// >: This symbol is used with the XREADGROUP command in the context of consumer groups. When a + /// consumer group reads from a stream using >, it tells Redis to deliver only messages that have + /// not yet been acknowledged by any consumer in the group. This allows different consumers in the + /// group to read and process different messages concurrently, without receiving messages that have + /// already been processed by another consumer. It's a way to distribute the workload of processing + /// stream messages across multiple consumers. + pub async fn consumer_messages( &mut self, consumer_name: &str, - count: usize, - ) -> Result, StreamError> { - let options = StreamReadOptions::default() + option: ConsumeOptions, + ) -> Result, StreamError> { + let mut options = StreamReadOptions::default() .group(&self.group_name, consumer_name) - .count(count) .block(100); - let map: MessageReadByStreamKey = self + let mut message_id = ">".to_string(); + match option { + ConsumeOptions::Empty => {}, + ConsumeOptions::Count(count) => { + options = options.count(count); + }, + ConsumeOptions::After(after) => { + message_id = after.to_string(); + }, + } + + let map: StreamMessageByStreamKey = self .connection_manager - .xread_options(&[&self.stream_key], &[">"], &options) + .xread_options(&[&self.stream_key], &[message_id], &options) .await?; match map.0.into_iter().next() { None => Ok(Vec::with_capacity(0)), - Some((_, messages)) => Ok(messages.into_iter().map(Into::into).collect()), + Some((_, messages)) => Ok(messages), + } + } + + /// Get messages starting from a specific message id. + /// returns list of messages excluding the message with the start_id + pub async fn get_messages_starting_from_id( + &mut self, + start_id: Option, + count: usize, + ) -> Result, StreamError> { + let options = StreamReadOptions::default().count(count).block(100); + let message_id = start_id.unwrap_or_else(|| "0".to_string()); + let map: StreamMessageByStreamKey = self + .connection_manager + .xread_options(&[&self.stream_key], &[message_id], &options) + .await?; + + match map.0.into_iter().next() { + None => Ok(Vec::with_capacity(0)), + Some((_, messages)) => Ok(messages), } } /// Reads all messages from the stream /// - pub async fn read_all_message(&mut self) -> Result, StreamError> { - let read_messages: Vec = + pub async fn get_all_message(&mut self) -> Result, StreamError> { + let read_messages: Vec = self.connection_manager.xrange_all(&self.stream_key).await?; - Ok(read_messages.into_iter().map(Into::into).collect()) + Ok(read_messages.into_iter().collect()) + } + + pub async fn pending_reply(&mut self) -> Result, StreamError> { + let reply: StreamPendingReply = self + .connection_manager + .xpending(&self.stream_key, &self.group_name) + .await?; + + match reply { + StreamPendingReply::Empty => Ok(None), + StreamPendingReply::Data(data) => Ok(Some(data)), + } } pub async fn clear(&mut self) -> Result<(), RedisError> { @@ -110,3 +171,9 @@ impl CollabStreamGroup { Ok(()) } } + +pub enum ConsumeOptions { + Empty, + Count(usize), + After(MessageId), +} diff --git a/libs/collab-stream/tests/collab_stream_test/mod.rs b/libs/collab-stream/tests/collab_stream_test/mod.rs new file mode 100644 index 00000000..a7c84e32 --- /dev/null +++ b/libs/collab-stream/tests/collab_stream_test/mod.rs @@ -0,0 +1,4 @@ +mod pubsub_test; +mod stream_group_test; +mod stream_test; +mod test_util; diff --git a/libs/collab-stream/tests/stream_test/pubsub_test.rs b/libs/collab-stream/tests/collab_stream_test/pubsub_test.rs similarity index 89% rename from libs/collab-stream/tests/stream_test/pubsub_test.rs rename to libs/collab-stream/tests/collab_stream_test/pubsub_test.rs index 0da4c113..e0e8b6f5 100644 --- a/libs/collab-stream/tests/stream_test/pubsub_test.rs +++ b/libs/collab-stream/tests/collab_stream_test/pubsub_test.rs @@ -1,4 +1,4 @@ -use crate::stream_test::test_util::{pubsub_client, random_i64}; +use crate::collab_stream_test::test_util::{pubsub_client, random_i64}; use collab_stream::pubsub::PubSubMessage; @@ -7,7 +7,7 @@ use std::time::Duration; use tokio::time::sleep; #[tokio::test] -async fn different_group_read_message_test() { +async fn pubsub_test() { let oid = format!("o{}", random_i64()); let client_1 = pubsub_client().await; let client_2 = pubsub_client().await; diff --git a/libs/collab-stream/tests/stream_test/group_read_test.rs b/libs/collab-stream/tests/collab_stream_test/stream_group_test.rs similarity index 63% rename from libs/collab-stream/tests/stream_test/group_read_test.rs rename to libs/collab-stream/tests/collab_stream_test/stream_group_test.rs index aeeb213c..34ce6f22 100644 --- a/libs/collab-stream/tests/stream_test/group_read_test.rs +++ b/libs/collab-stream/tests/collab_stream_test/stream_group_test.rs @@ -1,5 +1,6 @@ -use crate::stream_test::test_util::{random_i64, stream_client}; -use collab_stream::model::Message; +use crate::collab_stream_test::test_util::{random_i64, stream_client}; +use collab_stream::model::{Message, MessageId}; +use collab_stream::stream_group::ConsumeOptions; use futures::future::join; #[tokio::test] @@ -21,14 +22,17 @@ async fn single_group_read_message_test() { group.insert_message(msg).await.unwrap(); } - let messages = group.fetch_messages("consumer1", 1).await.unwrap(); + let messages = group + .consumer_messages("consumer1", ConsumeOptions::Empty) + .await + .unwrap(); assert_eq!(messages.len(), 1); assert_eq!(messages[0].raw_data, vec![1, 2, 3, 4, 5]); assert_eq!(messages[0].uid, random_uid); // after the message was consumed, it should not be available anymore assert!(group - .fetch_messages("consumer1", 1) + .consumer_messages("consumer1", ConsumeOptions::Count(1)) .await .unwrap() .is_empty()); @@ -54,8 +58,8 @@ async fn different_group_read_message_test() { } let (result1, result2) = join( - group_1.fetch_messages("consumer1", 1), - group_2.fetch_messages("consumer1", 1), + group_1.consumer_messages("consumer1", ConsumeOptions::Empty), + group_2.consumer_messages("consumer1", ConsumeOptions::Empty), ) .await; let group_1_messages = result1.unwrap(); @@ -86,7 +90,10 @@ async fn read_specific_num_of_message_test() { group.insert_messages(messages).await.unwrap(); } - let messages = group_1.fetch_messages("consumer1", 15).await.unwrap(); + let messages = group_1 + .consumer_messages("consumer1", ConsumeOptions::Count(15)) + .await + .unwrap(); assert_eq!(messages.len(), 5); for i in 0..5 { assert_eq!(messages[i].raw_data, vec![1, 2, 3, 4, 5]); @@ -116,9 +123,18 @@ async fn read_all_message_test() { group.insert_messages(messages).await.unwrap(); } - let messages = group.read_all_message().await.unwrap(); - let consumer_messages = group.fetch_messages("consumer1", 15).await.unwrap(); - assert_eq!(messages.len(), 5); + // get all the message in the group + let messages = group.get_all_message().await.unwrap(); + for i in 0..5 { + assert_eq!(messages[i].raw_data, vec![1, 2, 3, 4, 5]); + assert_eq!(messages[i].uid, uids[i]); + } + + // consume all message for given consumer + let consumer_messages = group + .consumer_messages("consumer1", ConsumeOptions::Count(15)) + .await + .unwrap(); assert_eq!(consumer_messages.len(), 5); for i in 0..5 { assert_eq!(messages[i].raw_data, vec![1, 2, 3, 4, 5]); @@ -127,4 +143,31 @@ async fn read_all_message_test() { assert_eq!(messages[i].uid, uids[i]); assert_eq!(consumer_messages[i].uid, uids[i]); } + + // get the pending state + let pending = group.pending_reply().await.unwrap().unwrap(); + assert_eq!(pending.consumers.len(), 1); + assert_eq!(pending.consumers[0].name, "consumer1".to_string(),); + assert_eq!(pending.consumers[0].pending, 5); + + // get pending message start from first message + let mut message_id = MessageId::try_from(pending.start_id.as_str()).unwrap(); + + // try to min 2 millisecond from the message id in order to get all the messages. Otherwise, only + // 4 messages will be returned. + message_id.sub_ms(2); + let pending_messages = group + .get_messages_starting_from_id(Some(message_id.to_string()), pending.count) + .await + .unwrap(); + assert_eq!(pending_messages.len(), 5); + + // ack all messages. + let message_ids = consumer_messages + .iter() + .map(|m| m.message_id.to_string()) + .collect::>(); + group.ack_messages(&message_ids).await.unwrap(); + let pending = group.pending_reply().await.unwrap(); + assert!(pending.is_none()); } diff --git a/libs/collab-stream/tests/stream_test/read_test.rs b/libs/collab-stream/tests/collab_stream_test/stream_test.rs similarity index 95% rename from libs/collab-stream/tests/stream_test/read_test.rs rename to libs/collab-stream/tests/collab_stream_test/stream_test.rs index 374dfc6a..b7e41200 100644 --- a/libs/collab-stream/tests/stream_test/read_test.rs +++ b/libs/collab-stream/tests/collab_stream_test/stream_test.rs @@ -1,4 +1,4 @@ -use crate::stream_test::test_util::{random_i64, stream_client}; +use crate::collab_stream_test::test_util::{random_i64, stream_client}; use collab_stream::model::Message; #[tokio::test] diff --git a/libs/collab-stream/tests/stream_test/test_util.rs b/libs/collab-stream/tests/collab_stream_test/test_util.rs similarity index 100% rename from libs/collab-stream/tests/stream_test/test_util.rs rename to libs/collab-stream/tests/collab_stream_test/test_util.rs diff --git a/libs/collab-stream/tests/main.rs b/libs/collab-stream/tests/main.rs index 27702e4e..e4a94b98 100644 --- a/libs/collab-stream/tests/main.rs +++ b/libs/collab-stream/tests/main.rs @@ -1 +1 @@ -mod stream_test; +mod collab_stream_test; diff --git a/libs/collab-stream/tests/stream_test/mod.rs b/libs/collab-stream/tests/stream_test/mod.rs deleted file mode 100644 index 422f94c3..00000000 --- a/libs/collab-stream/tests/stream_test/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -mod group_read_test; -mod pubsub_test; -mod read_test; -mod test_util;