feat: group stream (#399)

* feat: group stream

* chore: update
This commit is contained in:
Nathan.fooo 2024-03-20 11:24:31 +08:00 committed by GitHub
parent de92490e26
commit d4845a6784
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 291 additions and 8 deletions

2
Cargo.lock generated
View File

@ -1467,6 +1467,8 @@ name = "collab-stream"
version = "0.1.0"
dependencies = [
"anyhow",
"futures",
"rand 0.8.5",
"redis 0.24.0",
"thiserror",
"tokio",

View File

@ -11,3 +11,7 @@ tokio = { version = "1.26", features = ["rt-multi-thread", "macros" ] }
tokio-stream = { version = "0.1.14" }
thiserror = "1.0.58"
anyhow = "1.0.81"
[dev-dependencies]
futures = "0.3.30"
rand = "0.8.5"

View File

@ -1,4 +1,6 @@
use crate::error::StreamError;
use crate::stream::CollabStream;
use crate::stream_group::CollabStreamGroup;
use redis::aio::ConnectionManager;
pub struct CollabStreamClient {
@ -12,6 +14,22 @@ impl CollabStreamClient {
}
pub async fn stream(&self, workspace_id: &str, oid: &str) -> CollabStream {
CollabStream::new(self.connection_manager.clone(), workspace_id, oid)
CollabStream::new(workspace_id, oid, self.connection_manager.clone())
}
pub async fn group_stream(
&self,
workspace_id: &str,
oid: &str,
group_name: &str,
) -> Result<CollabStreamGroup, StreamError> {
let mut group = CollabStreamGroup::new(
workspace_id,
oid,
group_name,
self.connection_manager.clone(),
);
group.ensure_consumer_group("0").await?;
Ok(group)
}
}

View File

@ -17,6 +17,9 @@ pub enum StreamError {
#[error(transparent)]
ParseIntError(#[from] std::num::ParseIntError),
#[error("Stream group already exists")]
GroupAlreadyExists(String),
#[error("Internal error: {0}")]
Internal(anyhow::Error),
}

View File

@ -2,3 +2,4 @@ pub mod client;
mod error;
pub mod model;
pub mod stream;
pub mod stream_group;

View File

@ -55,6 +55,9 @@ pub struct MessageReadByStreamKey(pub BTreeMap<String, Vec<MessageRead>>);
impl FromRedisValue for MessageReadByStreamKey {
fn from_redis_value(v: &Value) -> RedisResult<Self> {
let mut map: BTreeMap<String, Vec<MessageRead>> = BTreeMap::new();
if matches!(v, Value::Nil) {
return Ok(MessageReadByStreamKey(map));
}
let value_by_id = bulk_from_redis_value(v)?.iter();
for value in value_by_id {

View File

@ -10,7 +10,7 @@ pub struct CollabStream {
}
impl CollabStream {
pub fn new(connection_manager: ConnectionManager, workspace_id: &str, oid: &str) -> Self {
pub fn new(workspace_id: &str, oid: &str, connection_manager: ConnectionManager) -> Self {
let stream_key = format!("af_collab-{}-{}", workspace_id, oid);
Self {
connection_manager,
@ -67,7 +67,7 @@ impl CollabStream {
.map(|ct| format!("{}-{}", ct.timestamp_ms, ct.sequence_number))
.unwrap_or_else(|| "$".to_string());
let options = StreamReadOptions::default().block(100);
let options = StreamReadOptions::default().group("1", "2").block(100);
let map: MessageReadByStreamKey = self
.connection_manager
.xread_options(&[&self.stream_key], &[&id], &options)

View File

@ -0,0 +1,112 @@
use crate::error::StreamError;
use crate::model::{CreatedTime, Message, MessageRead, MessageReadByStreamKey};
use redis::aio::ConnectionManager;
use redis::streams::{StreamMaxlen, StreamReadOptions};
use redis::{pipe, AsyncCommands, RedisError, RedisResult};
#[derive(Clone)]
pub struct CollabStreamGroup {
connection_manager: ConnectionManager,
stream_key: String,
group_name: String,
}
impl CollabStreamGroup {
pub fn new(
workspace_id: &str,
oid: &str,
group_name: &str,
connection_manager: ConnectionManager,
) -> Self {
let group_name = group_name.to_string();
let stream_key = format!("af_collab-{}-{}", workspace_id, oid);
Self {
group_name,
connection_manager,
stream_key,
}
}
/// Ensures the consumer group exists, creating it if necessary.
/// start_id:
/// Use '$' if you want new messages or '0' to read from the beginning.
pub async fn ensure_consumer_group(&mut self, start_id: &str) -> Result<(), StreamError> {
let _: RedisResult<()> = self
.connection_manager
.xgroup_create_mkstream(&self.stream_key, &self.group_name, start_id)
.await;
Ok(())
}
/// Acknowledges messages processed by a consumer.
pub async fn ack_messages(&mut self, message_ids: &[String]) -> Result<(), StreamError> {
self
.connection_manager
.xack(&self.stream_key, &self.group_name, message_ids)
.await?;
Ok(())
}
/// Inserts a single message into the Redis stream.
pub async fn insert_message(&mut self, message: Message) -> Result<CreatedTime, StreamError> {
let tuple = message.into_tuple_array();
let created_time = self
.connection_manager
.xadd(&self.stream_key, "*", tuple.as_slice())
.await?;
Ok(created_time)
}
/// Inserts multiple messages into the Redis stream using a pipeline.
///
pub async fn insert_messages(&mut self, messages: Vec<Message>) -> Result<(), StreamError> {
let mut pipe = pipe();
for message in messages {
let tuple = message.into_tuple_array();
pipe.xadd(&self.stream_key, "*", tuple.as_slice());
}
pipe.query_async(&mut self.connection_manager).await?;
Ok(())
}
/// Fetches number of messages from a Redis stream
/// Returns the messages that were not consumed yet. Which means each message is delivered to only
/// one consumer in the group
pub async fn fetch_messages(
&mut self,
consumer_name: &str,
count: usize,
) -> Result<Vec<Message>, StreamError> {
let options = StreamReadOptions::default()
.group(&self.group_name, consumer_name)
.count(count)
.block(100);
let map: MessageReadByStreamKey = self
.connection_manager
.xread_options(&[&self.stream_key], &[">"], &options)
.await?;
match map.0.into_iter().next() {
None => Ok(Vec::with_capacity(0)),
Some((_, messages)) => Ok(messages.into_iter().map(Into::into).collect()),
}
}
/// Reads all messages from the stream
///
pub async fn read_all_message(&mut self) -> Result<Vec<Message>, StreamError> {
let read_messages: Vec<MessageRead> =
self.connection_manager.xrange_all(&self.stream_key).await?;
Ok(read_messages.into_iter().map(Into::into).collect())
}
pub async fn clear(&mut self) -> Result<(), RedisError> {
self
.connection_manager
.xtrim(&self.stream_key, StreamMaxlen::Equals(0))
.await?;
Ok(())
}
}

View File

@ -0,0 +1,130 @@
use crate::stream_test::test_util::{random_i64, stream_client};
use collab_stream::model::Message;
use futures::future::join;
#[tokio::test]
async fn single_group_read_message_test() {
let workspace_id = "w1";
let oid = format!("o{}", random_i64());
let client = stream_client().await;
let mut group = client.group_stream(workspace_id, &oid, "g1").await.unwrap();
let random_uid = random_i64();
let msg = Message {
uid: random_uid,
raw_data: vec![1, 2, 3, 4, 5],
};
{
let client = stream_client().await;
let mut group = client.group_stream(workspace_id, &oid, "g2").await.unwrap();
group.insert_message(msg).await.unwrap();
}
let messages = group.fetch_messages("consumer1", 1).await.unwrap();
assert_eq!(messages.len(), 1);
assert_eq!(messages[0].raw_data, vec![1, 2, 3, 4, 5]);
assert_eq!(messages[0].uid, random_uid);
// after the message was consumed, it should not be available anymore
assert!(group
.fetch_messages("consumer1", 1)
.await
.unwrap()
.is_empty());
}
#[tokio::test]
async fn different_group_read_message_test() {
let oid = format!("o{}", random_i64());
let client = stream_client().await;
let mut group_1 = client.group_stream("w1", &oid, "g1").await.unwrap();
let mut group_2 = client.group_stream("w1", &oid, "g2").await.unwrap();
let random_uid = random_i64();
let msg = Message {
uid: random_uid,
raw_data: vec![1, 2, 3, 4, 5],
};
{
let client = stream_client().await;
let mut group = client.group_stream("w1", &oid, "g2").await.unwrap();
group.insert_message(msg).await.unwrap();
}
let (result1, result2) = join(
group_1.fetch_messages("consumer1", 1),
group_2.fetch_messages("consumer1", 1),
)
.await;
let group_1_messages = result1.unwrap();
let group_2_messages = result2.unwrap();
assert_eq!(group_1_messages[0].raw_data, vec![1, 2, 3, 4, 5]);
assert_eq!(group_2_messages[0].raw_data, vec![1, 2, 3, 4, 5]);
}
#[tokio::test]
async fn read_specific_num_of_message_test() {
let object_id = format!("o{}", random_i64());
let client = stream_client().await;
let mut group_1 = client.group_stream("w1", &object_id, "g1").await.unwrap();
let mut uids = vec![];
{
let client = stream_client().await;
let mut group = client.group_stream("w1", &object_id, "g2").await.unwrap();
let mut messages = vec![];
for _i in 0..5 {
let random_uid = random_i64();
uids.push(random_uid);
let msg = Message {
uid: random_uid,
raw_data: vec![1, 2, 3, 4, 5],
};
messages.push(msg);
}
group.insert_messages(messages).await.unwrap();
}
let messages = group_1.fetch_messages("consumer1", 15).await.unwrap();
assert_eq!(messages.len(), 5);
for i in 0..5 {
assert_eq!(messages[i].raw_data, vec![1, 2, 3, 4, 5]);
assert_eq!(messages[i].uid, uids[i]);
}
}
#[tokio::test]
async fn read_all_message_test() {
let object_id = format!("o{}", random_i64());
let client = stream_client().await;
let mut group = client.group_stream("w1", &object_id, "g1").await.unwrap();
let mut uids = vec![];
{
let client = stream_client().await;
let mut group = client.group_stream("w1", &object_id, "g2").await.unwrap();
let mut messages = vec![];
for _i in 0..5 {
let random_uid = random_i64();
uids.push(random_uid);
let msg = Message {
uid: random_uid,
raw_data: vec![1, 2, 3, 4, 5],
};
messages.push(msg);
}
group.insert_messages(messages).await.unwrap();
}
let messages = group.read_all_message().await.unwrap();
let consumer_messages = group.fetch_messages("consumer1", 15).await.unwrap();
assert_eq!(messages.len(), 5);
assert_eq!(consumer_messages.len(), 5);
for i in 0..5 {
assert_eq!(messages[i].raw_data, vec![1, 2, 3, 4, 5]);
assert_eq!(consumer_messages[i].raw_data, vec![1, 2, 3, 4, 5]);
assert_eq!(messages[i].uid, uids[i]);
assert_eq!(consumer_messages[i].uid, uids[i]);
}
}

View File

@ -1,2 +1,3 @@
mod group_read_test;
mod read_test;
mod test_util;

View File

@ -1,10 +1,11 @@
use crate::stream_test::test_util::stream_client;
use crate::stream_test::test_util::{random_i64, stream_client};
use collab_stream::model::Message;
#[tokio::test]
async fn read_single_message_test() {
let oid = format!("o{}", random_i64());
let client_2 = stream_client().await;
let mut stream_2 = client_2.stream("w1", "o1").await;
let mut stream_2 = client_2.stream("w1", &oid).await;
let (tx, mut rx) = tokio::sync::mpsc::channel(1);
tokio::spawn(async move {
@ -19,7 +20,7 @@ async fn read_single_message_test() {
{
let client_1 = stream_client().await;
let mut stream_1 = client_1.stream("w1", "o1").await;
let mut stream_1 = client_1.stream("w1", &oid).await;
stream_1.insert_message(msg).await.unwrap();
}
@ -29,13 +30,14 @@ async fn read_single_message_test() {
#[tokio::test]
async fn read_multiple_messages_test() {
let oid = format!("o{}", random_i64());
let client_2 = stream_client().await;
let mut stream_2 = client_2.stream("w1", "o1").await;
let mut stream_2 = client_2.stream("w1", &oid).await;
stream_2.clear().await.unwrap();
{
let client_1 = stream_client().await;
let mut stream_1 = client_1.stream("w1", "o1").await;
let mut stream_1 = client_1.stream("w1", &oid).await;
let messages = vec![
Message {
uid: 1001,

View File

@ -1,5 +1,6 @@
use anyhow::Context;
use collab_stream::client::CollabStreamClient;
use rand::{thread_rng, Rng};
pub async fn redis_client() -> redis::Client {
let redis_uri = "redis://localhost:6379";
@ -15,3 +16,9 @@ pub async fn stream_client() -> CollabStreamClient {
.context("failed to create stream client")
.unwrap()
}
pub fn random_i64() -> i64 {
let mut rng = thread_rng();
let num: i64 = rng.gen();
num
}