chore: redis stream for awareness

This commit is contained in:
Bartosz Sypytkowski 2024-10-08 06:23:03 +02:00
parent 1d7e35c2b9
commit f0b907157e
6 changed files with 212 additions and 22 deletions

View File

@ -1,6 +1,5 @@
use std::fmt::{Debug, Display, Formatter};
use collab::core::awareness::AwarenessUpdate;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use yrs::updates::decoder::{Decode, Decoder};
@ -22,7 +21,7 @@ pub const PERMISSION_GRANTED: u8 = 1;
pub enum Message {
Sync(SyncMessage),
Auth(Option<String>),
Awareness(AwarenessUpdate),
Awareness(Vec<u8>),
Custom(CustomMessage),
}
@ -44,7 +43,7 @@ impl Encode for Message {
},
Message::Awareness(update) => {
encoder.write_var(MSG_AWARENESS);
encoder.write_buf(update.encode_v1())
encoder.write_buf(&update)
},
Message::Custom(msg) => {
encoder.write_var(MSG_CUSTOM);
@ -64,8 +63,7 @@ impl Decode for Message {
},
MSG_AWARENESS => {
let data = decoder.read_buf()?;
let update = AwarenessUpdate::decode_v1(data)?;
Ok(Message::Awareness(update))
Ok(Message::Awareness(data.into()))
},
MSG_AUTH => {
let reason = if decoder.read_var::<u8>()? == PERMISSION_DENIED {

View File

@ -111,6 +111,7 @@ pub trait CollabSyncProtocol {
Message::Auth(reason) => self.handle_auth(collab, reason).await,
//FIXME: where is the QueryAwareness protocol?
Message::Awareness(update) => {
let update = AwarenessUpdate::decode_v1(&update)?;
self
.handle_awareness_update(message_origin, collab, update)
.await
@ -135,7 +136,7 @@ pub trait CollabSyncProtocol {
.map_err(|e| RTProtocolError::YrsTransaction(e.to_string()))?
.state_vector();
let awareness_update = awareness.update()?;
(state_vector, awareness_update)
(state_vector, awareness_update.encode_v1())
};
// 1. encode doc state vector

View File

@ -1,7 +1,10 @@
use crate::collab_update_sink::CollabUpdateSink;
use crate::collab_update_sink::{AwarenessUpdateSink, CollabUpdateSink};
use crate::error::StreamError;
use crate::lease::{Lease, LeaseAcquisition};
use crate::model::{CollabStreamUpdate, CollabStreamUpdateBatch, CollabUpdateEvent, MessageId};
use crate::model::{
AwarenessStreamUpdate, AwarenessStreamUpdateBatch, CollabStreamUpdate, CollabStreamUpdateBatch,
CollabUpdateEvent, MessageId,
};
use crate::pubsub::{CollabStreamPub, CollabStreamSub};
use crate::stream_group::{StreamConfig, StreamGroup};
use futures::Stream;
@ -88,6 +91,11 @@ impl CollabRedisStream {
CollabUpdateSink::new(self.connection_manager.clone(), stream_key)
}
/// Creates an [AwarenessUpdateSink] writing to the awareness stream of the given
/// workspace/object pair.
///
/// The sink receives its own clone of this client's connection manager; the target stream
/// key is produced by [AwarenessStreamUpdate::stream_key].
pub fn awareness_update_sink(&self, workspace_id: &str, object_id: &str) -> AwarenessUpdateSink {
  AwarenessUpdateSink::new(
    self.connection_manager.clone(),
    AwarenessStreamUpdate::stream_key(workspace_id, object_id),
  )
}
pub fn collab_updates(
&self,
workspace_id: &str,
@ -112,4 +120,29 @@ impl CollabRedisStream {
}
}
}
/// Returns an endless stream of awareness updates read from the Redis stream of the given
/// workspace/object pair.
///
/// Reading starts after `since`; when `since` is `None` the default [MessageId] is used as the
/// starting cursor. Entries are fetched with `XREAD` in batches of at most 100 and yielded in
/// [MessageId] order, with the cursor advanced past every yielded entry. A failed read is
/// propagated with `?` out of the `try_stream!` body.
// NOTE(review): the read options set no `block` timeout, so this loop appears to re-issue
// `XREAD` immediately whenever nothing new is available - confirm that polling this tightly
// is intended.
pub fn awareness_updates(
  &self,
  workspace_id: &str,
  object_id: &str,
  since: Option<MessageId>,
) -> impl Stream<Item = Result<AwarenessStreamUpdate, StreamError>> {
  let key = AwarenessStreamUpdate::stream_key(workspace_id, object_id);
  let options = StreamReadOptions::default().count(100);
  let mut connection = self.connection_manager.clone();
  let mut cursor = since.unwrap_or_default();
  async_stream::try_stream! {
    loop {
      // ask only for entries newer than the last one we handed out
      let from_id = cursor.to_string();
      let batch: AwarenessStreamUpdateBatch = connection
        .xread_options(&[&key], &[&from_id], &options)
        .await?;
      for (id, update) in batch.updates {
        cursor = cursor.max(id);
        yield update;
      }
    }
  }
}
}

View File

@ -1,5 +1,5 @@
use crate::error::StreamError;
use crate::model::{CollabStreamUpdate, MessageId};
use crate::model::{AwarenessStreamUpdate, CollabStreamUpdate, MessageId};
use collab::preclude::updates::encoder::Encode;
use redis::aio::ConnectionManager;
use redis::cmd;
@ -37,3 +37,31 @@ impl CollabUpdateSink {
Ok(msg_id)
}
}
/// Sink used to append awareness updates to a single Redis stream.
pub struct AwarenessUpdateSink {
/// Redis connection; `send` takes this lock before issuing its command.
conn: Mutex<ConnectionManager>,
/// Key of the Redis stream that `send` appends to.
stream_key: String,
}
impl AwarenessUpdateSink {
pub fn new(conn: ConnectionManager, stream_key: String) -> Self {
AwarenessUpdateSink {
conn: conn.into(),
stream_key,
}
}
pub async fn send(&self, msg: &AwarenessStreamUpdate) -> Result<MessageId, StreamError> {
let mut lock = self.conn.lock().await;
let msg_id: MessageId = cmd("XADD")
.arg(&self.stream_key)
.arg("*")
.arg("sender")
.arg(msg.sender.to_string())
.arg("data")
.arg(&*msg.data)
.query_async(&mut *lock)
.await?;
Ok(msg_id)
}
}

View File

@ -368,7 +368,7 @@ impl TryFrom<CollabUpdateEvent> for StreamBinary {
}
pub struct CollabStreamUpdate {
pub data: Vec<u8>,
pub data: Vec<u8>, // yrs::Update::encode_v1
pub state_vector: StateVector,
pub sender: CollabOrigin,
pub flags: UpdateFlags,
@ -390,7 +390,7 @@ impl CollabStreamUpdate {
/// Returns Redis stream key, that's storing entries mapped to/from [CollabStreamUpdate].
pub fn stream_key(workspace_id: &str, object_id: &str) -> String {
format!("af_update:{}:{}", workspace_id, object_id)
format!("af:{}:{}:updates", workspace_id, object_id)
}
}
@ -451,6 +451,53 @@ impl FromRedisValue for CollabStreamUpdateBatch {
}
}
/// Awareness update in the shape it is written to / read from the awareness Redis stream.
pub struct AwarenessStreamUpdate {
/// Raw update payload, stored in the stream entry's `data` field.
pub data: Vec<u8>, // AwarenessUpdate::encode_v1
/// Origin of the update, stored in the stream entry's `sender` field.
pub sender: CollabOrigin,
}
impl AwarenessStreamUpdate {
/// Builds the key of the Redis stream holding [AwarenessStreamUpdate] entries for the given
/// workspace/object pair, in the form `af:{workspace_id}:{object_id}:awareness`.
pub fn stream_key(workspace_id: &str, object_id: &str) -> String {
  ["af", workspace_id, object_id, "awareness"].join(":")
}
}
/// Batch of awareness updates decoded from one Redis reply.
pub(crate) struct AwarenessStreamUpdateBatch {
/// Decoded updates keyed (and therefore ordered) by the id of their stream entry.
pub updates: BTreeMap<MessageId, AwarenessStreamUpdate>,
}
impl FromRedisValue for AwarenessStreamUpdateBatch {
  /// Decodes a Redis stream-read reply into a batch of awareness updates keyed by message id.
  ///
  /// Every entry must carry a `data` field, otherwise decoding fails. An entry without a
  /// `sender` field is attributed to [CollabOrigin::Empty].
  fn from_redis_value(v: &Value) -> RedisResult<Self> {
    let rows: SRRows = SRRows::from_redis_value(v)?;
    // walk the reply: streams -> (stream key, messages) -> message -> (id, fields)
    let entries = rows
      .into_iter()
      .flatten()
      .flat_map(|(_stream_key, messages)| messages)
      .flatten();

    let mut updates = BTreeMap::new();
    for (raw_id, fields) in entries {
      let message_id = MessageId::try_from(raw_id).map_err(|e| internal(e.to_string()))?;
      let sender = match fields.get("sender") {
        Some(raw) => collab_origin_from_str(&String::from_redis_value(raw)?)?,
        None => CollabOrigin::Empty,
      };
      let raw_data = fields
        .get("data")
        .ok_or_else(|| internal("expecting field `data`"))?;
      let data: Vec<u8> = FromRedisValue::from_redis_value(raw_data)?;
      updates.insert(message_id, AwarenessStreamUpdate { data, sender });
    }
    Ok(AwarenessStreamUpdateBatch { updates })
  }
}
//FIXME: this should be `impl FromStr for CollabOrigin`
fn collab_origin_from_str(value: &str) -> RedisResult<CollabOrigin> {
match value {

View File

@ -12,14 +12,17 @@ use collab::lock::RwLock;
use collab::preclude::Collab;
use collab_entity::CollabType;
use collab_rt_entity::user::RealtimeUser;
use collab_rt_entity::{AckCode, BroadcastSync, CollabAck, MessageByObjectId, MsgId};
use collab_rt_entity::{
AckCode, AwarenessSync, BroadcastSync, CollabAck, MessageByObjectId, MsgId,
};
use collab_rt_entity::{ClientCollabMessage, CollabMessage};
use collab_rt_protocol::{decode_update, Message, MessageReader, RTProtocolError, SyncMessage};
use collab_stream::client::CollabRedisStream;
use collab_stream::collab_update_sink::CollabUpdateSink;
use collab_stream::collab_update_sink::{AwarenessUpdateSink, CollabUpdateSink};
use collab_stream::error::StreamError;
use collab_stream::model::{
CollabStreamUpdate, CollabUpdateEvent, MessageId, StreamBinary, UpdateFlags,
AwarenessStreamUpdate, CollabStreamUpdate, CollabUpdateEvent, MessageId, StreamBinary,
UpdateFlags,
};
use collab_stream::stream_group::StreamGroup;
use dashmap::DashMap;
@ -124,12 +127,22 @@ impl CollabGroup {
tasks and triggered when this `CollabGroup` is dropped.
*/
// setup task used to receive messages from Redis
// setup task used to receive collab updates from Redis
{
let state = state.clone();
tokio::spawn(async move {
if let Err(err) = Self::inbound_task(state).await {
tracing::warn!("failed to receive message: {}", err);
tracing::warn!("failed to receive collab update: {}", err);
}
});
}
// setup task used to receive awareness updates from Redis
{
let state = state.clone();
tokio::spawn(async move {
if let Err(err) = Self::inbound_awareness_task(state).await {
tracing::warn!("failed to receive awareness update: {}", err);
}
});
}
@ -156,7 +169,7 @@ impl CollabGroup {
&self.state.object_id
}
/// Task used to receive messages from Redis.
/// Task used to receive collab updates from Redis.
async fn inbound_task(state: Arc<CollabGroupState>) -> Result<(), RealtimeError> {
let mut updates = state.persister.collab_redis_stream.collab_updates(
&state.workspace_id,
@ -219,6 +232,64 @@ impl CollabGroup {
}
}
/// Task used to receive awareness updates from Redis.
async fn inbound_awareness_task(state: Arc<CollabGroupState>) -> Result<(), RealtimeError> {
let mut updates = state.persister.collab_redis_stream.awareness_updates(
&state.workspace_id,
&state.object_id,
None,
);
pin_mut!(updates);
loop {
tokio::select! {
_ = state.shutdown.cancelled() => {
break;
}
res = updates.next() => {
match res {
Some(Ok(awareness_update)) => {
Self::handle_inbound_awareness(&state, awareness_update).await;
},
Some(Err(err)) => {
tracing::warn!("failed to handle incoming update for collab `{}`: {}", state.object_id, err);
break;
},
None => {
break;
}
}
}
}
}
Ok(())
}
async fn handle_inbound_awareness(state: &CollabGroupState, update: AwarenessStreamUpdate) {
let sender = update.sender;
let message = AwarenessSync::new(
state.object_id.clone(),
Message::Awareness(update.data).encode_v1(),
sender.clone(),
);
for mut e in state.subscribers.iter_mut() {
let subscription = e.value_mut();
if sender == subscription.collab_origin {
continue; // don't send update to its sender
}
if let Err(err) = subscription.sink.send(message.clone().into()).await {
tracing::debug!(
"failed to send awareness `{}` update to `{}`: {}",
state.object_id,
subscription.collab_origin,
err
);
}
state.last_activity.store(Arc::new(Instant::now()));
}
}
async fn snapshot_task(state: Arc<CollabGroupState>, interval: Duration, is_new_collab: bool) {
if is_new_collab {
if let Err(err) = state.persister.save().await {
@ -606,7 +677,7 @@ impl CollabGroup {
async fn handle_awareness_update(
state: &CollabGroupState,
origin: &CollabOrigin,
update: AwarenessUpdate,
update: Vec<u8>,
) -> Result<Option<Vec<u8>>, RTProtocolError> {
state
.persister
@ -791,6 +862,7 @@ struct CollabPersister {
/// Collab stored temporarily.
temp_collab: ArcSwapOption<CollabSnapshot>,
update_sink: CollabUpdateSink,
awareness_sink: AwarenessUpdateSink,
}
impl CollabPersister {
@ -803,6 +875,7 @@ impl CollabPersister {
indexer: Option<Arc<dyn Indexer>>,
) -> Self {
let update_sink = collab_redis_stream.collab_update_sink(&workspace_id, &object_id);
let awareness_sink = collab_redis_stream.awareness_update_sink(&workspace_id, &object_id);
Self {
workspace_id,
object_id,
@ -811,6 +884,7 @@ impl CollabPersister {
collab_redis_stream,
indexer,
update_sink,
awareness_sink,
temp_collab: Default::default(),
}
}
@ -842,10 +916,18 @@ impl CollabPersister {
async fn send_awareness(
&self,
sender_session: &CollabOrigin,
awareness_update: AwarenessUpdate,
awareness_update: Vec<u8>,
) -> Result<MessageId, StreamError> {
// send awareness updates to redis queue: is it needed? What are we using awareness for here?
todo!()
// send awareness updates to redis queue:
// QUESTION: is it needed? Maybe we could reuse update_sink?
let msg_id = self
.awareness_sink
.send(&AwarenessStreamUpdate {
data: awareness_update,
sender: sender_session.clone(),
})
.await?;
Ok(msg_id)
}
async fn load(&self) -> Result<Arc<CollabSnapshot>, RealtimeError> {
@ -878,10 +960,11 @@ impl CollabPersister {
}
async fn save(&self) -> Result<(), RealtimeError> {
// 1. try to acquire lock
// 1. try to acquire lease
if let Some(lease) = self
.collab_redis_stream
.lease(&self.workspace_id, &self.object_id)
.await?
{
// 3. if collab has any changes (any redis updates were applied):
// 4. generate embeddings