diff --git a/libs/app_error/src/gotrue.rs b/libs/app_error/src/gotrue.rs index edea8528..a54af2e7 100644 --- a/libs/app_error/src/gotrue.rs +++ b/libs/app_error/src/gotrue.rs @@ -1,3 +1,4 @@ +use reqwest::StatusCode; use serde::{Deserialize, Serialize}; use std::fmt::{Display, Formatter}; use thiserror::Error; @@ -22,6 +23,9 @@ pub enum GoTrueError { #[error("{0}")] NotLoggedIn(String), + #[error("{0}")] + Auth(String), + #[error(transparent)] Unhandled(#[from] anyhow::Error), } @@ -49,6 +53,12 @@ impl From for GoTrueError { return GoTrueError::InvalidRequest(value.to_string()); } + if let Some(status) = value.status() { + if status == StatusCode::UNAUTHORIZED { + return GoTrueError::Auth(value.to_string()); + } + } + GoTrueError::Unhandled(value.into()) } } diff --git a/libs/app_error/src/lib.rs b/libs/app_error/src/lib.rs index 1d56f0a7..6804ef46 100644 --- a/libs/app_error/src/lib.rs +++ b/libs/app_error/src/lib.rs @@ -181,7 +181,7 @@ impl From for AppError { GoTrueError::RequestTimeout(msg) => AppError::RequestTimeout(msg), GoTrueError::InvalidRequest(msg) => AppError::InvalidRequest(msg), GoTrueError::ClientError(err) => AppError::OAuthError(err.to_string()), - GoTrueError::Unhandled(err) => AppError::Internal(err), + GoTrueError::Auth(err) => AppError::OAuthError(err), GoTrueError::Internal(err) => match (err.code, err.msg.as_str()) { (400, m) if m.starts_with("oauth error") => AppError::OAuthError(err.msg), (400, m) if m.starts_with("User already registered") => AppError::OAuthError(err.msg), @@ -189,6 +189,7 @@ impl From for AppError { (422, _) => AppError::InvalidRequest(err.msg), _ => AppError::OAuthError(err.to_string()), }, + GoTrueError::Unhandled(err) => AppError::Internal(err), GoTrueError::NotLoggedIn(msg) => AppError::NotLoggedIn(msg), } } diff --git a/libs/client-api/src/collab_sync/sync.rs b/libs/client-api/src/collab_sync/sync.rs index 4159dca3..bcbe8bd3 100644 --- a/libs/client-api/src/collab_sync/sync.rs +++ b/libs/client-api/src/collab_sync/sync.rs @@ -264,16 +264,17 @@ where where P: CollabSyncProtocol + Send + Sync + 'static, { + let payload = msg.payload(); if match msg.msg_id() { // The msg_id is None if the message is [ServerBroadcast] or [ServerAwareness] None => true, Some(msg_id) => sink.ack_msg(msg.origin(), msg.object_id(), msg_id).await, - } && !msg.payload().is_empty() + } && payload.is_some() { trace!("start process message: {:?}", msg.msg_id()); SyncStream::::process_payload( origin, - msg.payload(), + payload.unwrap(), object_id, protocol, collab, diff --git a/libs/client-api/src/http.rs b/libs/client-api/src/http.rs index e551c542..e329deda 100644 --- a/libs/client-api/src/http.rs +++ b/libs/client-api/src/http.rs @@ -374,7 +374,7 @@ impl Client { /// This function attempts to read the current token and, if successful, returns the expiration timestamp. /// /// # Returns - /// - `Ok(i64)`: An `i64` representing the expiration timestamp of the token. + /// - `Ok(i64)`: An `i64` representing the expiration timestamp of the token in seconds. /// - `Err(AppError)`: An `AppError` indicating either an inability to read the token or that the user is not logged in. /// #[inline] diff --git a/libs/client-api/src/ws/client.rs b/libs/client-api/src/ws/client.rs index 33c802b7..b3ada7df 100644 --- a/libs/client-api/src/ws/client.rs +++ b/libs/client-api/src/ws/client.rs @@ -160,22 +160,32 @@ impl WSClient { match msg { RealtimeMessage::Collab(collab_msg) => { if let Some(channels) = weak_channels.upgrade() { - if let Some(channel) = channels.read().get(collab_msg.object_id()) { + let object_id = collab_msg.object_id().to_owned(); + let is_channel_dropped = if let Some(channel) = channels.read().get(&object_id) + { match channel.upgrade() { None => { // when calling [WSClient::subscribe], the caller is responsible for keeping // the channel alive as long as it wants to receive messages from the websocket. warn!("channel is dropped"); + true }, Some(channel) => { channel.forward_to_stream(collab_msg); + false }, } } else { - warn!( - "can't find channel by object_id: {}", - collab_msg.object_id() - ); + warn!("can't find channel of object_id: {}", object_id); + false + }; + + // Try to remove the channel if it is dropped. If failed, will try again next time. + if is_channel_dropped { + if let Some(mut w) = channels.try_write() { + trace!("remove channel: {}", object_id); + w.remove(&object_id); + } } } else { warn!("channels are closed"); diff --git a/libs/realtime-entity/src/collab_msg.rs b/libs/realtime-entity/src/collab_msg.rs index 0f49bad9..f235d258 100644 --- a/libs/realtime-entity/src/collab_msg.rs +++ b/libs/realtime-entity/src/collab_msg.rs @@ -37,6 +37,7 @@ pub enum CollabMessage { ServerInit(ServerCollabInit), ServerAwareness(CollabAwarenessData), ServerBroadcast(CollabBroadcastData), + CloseCollab(CloseCollabData), } impl CollabSinkMessage for CollabMessage { @@ -45,7 +46,7 @@ impl CollabSinkMessage for CollabMessage { } fn length(&self) -> usize { - self.payload().len() + self.payload().map(|payload| payload.len()).unwrap_or(0) } fn can_merge(&self) -> bool { @@ -125,11 +126,15 @@ impl CollabMessage { CollabMessage::ServerInit(value) => Some(value.msg_id), CollabMessage::ServerBroadcast(_) => None, CollabMessage::ServerAwareness(_) => None, + CollabMessage::CloseCollab(_) => None, } } pub fn is_empty(&self) -> bool { - self.payload().is_empty() + self + .payload() + .map(|payload| payload.is_empty()) + .unwrap_or(true) } pub fn origin(&self) -> Option<&CollabOrigin> { @@ -140,6 +145,7 @@ impl CollabMessage { CollabMessage::ServerInit(value) => Some(&value.origin), CollabMessage::ServerBroadcast(value) => Some(&value.origin), CollabMessage::ServerAwareness(_) => None, + CollabMessage::CloseCollab(value) => Some(&value.origin), } } @@ -155,6 +161,7 @@ impl CollabMessage { CollabMessage::ServerInit(value) => &value.object_id, CollabMessage::ServerBroadcast(value) => &value.object_id, CollabMessage::ServerAwareness(value) => &value.object_id, + CollabMessage::CloseCollab(value) => &value.object_id, } } } @@ -198,6 +205,9 @@ impl Display for CollabMessage { value.object_id, value.payload.len(), )), + CollabMessage::CloseCollab(value) => { + f.write_fmt(format_args!("close collab: [oid:{}]", value.object_id,)) + }, } } } @@ -217,14 +227,15 @@ impl CollabMessage { serde_json::from_slice(data) } - pub fn payload(&self) -> &Bytes { + pub fn payload(&self) -> Option<&Bytes> { match self { - CollabMessage::ClientInit(value) => &value.payload, - CollabMessage::ClientUpdateSync(value) => &value.payload, - CollabMessage::ClientUpdateAck(value) => &value.payload, - CollabMessage::ServerInit(value) => &value.payload, - CollabMessage::ServerBroadcast(value) => &value.payload, - CollabMessage::ServerAwareness(value) => &value.payload, + CollabMessage::ClientInit(value) => Some(&value.payload), + CollabMessage::ClientUpdateSync(value) => Some(&value.payload), + CollabMessage::ClientUpdateAck(value) => Some(&value.payload), + CollabMessage::ServerInit(value) => Some(&value.payload), + CollabMessage::ServerBroadcast(value) => Some(&value.payload), + CollabMessage::ServerAwareness(value) => Some(&value.payload), + CollabMessage::CloseCollab(_) => None, } } } @@ -469,6 +480,17 @@ impl TryFrom for CollabMessage { } } +#[derive(Clone, Eq, PartialEq, Debug, Serialize, Deserialize)] +pub struct CloseCollabData { + origin: CollabOrigin, + object_id: String, +} + +impl From for CollabMessage { + fn from(value: CloseCollabData) -> Self { + CollabMessage::CloseCollab(value) + } +} impl From for RealtimeMessage { fn from(msg: CollabMessage) -> Self { Self::Collab(msg) diff --git a/libs/realtime/src/client.rs b/libs/realtime/src/client.rs index b8a40a8a..f931451b 100644 --- a/libs/realtime/src/client.rs +++ b/libs/realtime/src/client.rs @@ -12,7 +12,7 @@ use database::collab::CollabStorage; pub use realtime_entity::user::RealtimeUserImpl; use std::ops::Deref; use std::time::{Duration, Instant}; -use tracing::{error, trace}; +use tracing::error; pub struct ClientSession< U: Unpin + RealtimeUser, @@ -133,10 +133,7 @@ where fn handle(&mut self, msg: RealtimeMessage, ctx: &mut Self::Context) { match &msg { - RealtimeMessage::Collab(collab_msg) => { - trace!("{:?}: receives collab message: {:?}", self.user, collab_msg); - ctx.binary(msg) - }, + RealtimeMessage::Collab(_collab_msg) => ctx.binary(msg), RealtimeMessage::ServerKickedOff => { // The server will send this message to the client when the client is kicked out. So // set the current user to None and stop the session. diff --git a/libs/realtime/src/collaborate/broadcast.rs b/libs/realtime/src/collaborate/broadcast.rs index 010bc1f4..9a917991 100644 --- a/libs/realtime/src/collaborate/broadcast.rs +++ b/libs/realtime/src/collaborate/broadcast.rs @@ -46,13 +46,14 @@ impl CollabBroadcast { /// provided `buffer_capacity` size. pub fn new(object_id: &str, collab: MutexCollab, buffer_capacity: usize) -> Self { let object_id = object_id.to_owned(); + // broadcast channel let (sender, _) = channel(buffer_capacity); let (doc_sub, awareness_sub) = { let mut mutex_collab = collab.lock(); // Observer the document's update and broadcast it to all subscribers. let cloned_oid = object_id.clone(); - let sink = sender.clone(); + let broadcast_sink = sender.clone(); let doc_sub = mutex_collab .get_mut_awareness() .doc_mut() @@ -60,13 +61,13 @@ impl CollabBroadcast { let origin = CollabOrigin::from(txn); let payload = gen_update_message(&event.update); let msg = CollabBroadcastData::new(origin, cloned_oid.clone(), payload); - if let Err(_e) = sink.send(msg.into()) { - trace!("Broadcast group is closed"); + if let Err(e) = broadcast_sink.send(msg.into()) { + error!("broadcast sink fail: {}", e); } }) .unwrap(); - let sink = sender.clone(); + let broadcast_sink = sender.clone(); let cloned_oid = object_id.clone(); // Observer the awareness's update and broadcast it to all subscribers. @@ -76,7 +77,7 @@ impl CollabBroadcast { if let Ok(awareness_update) = gen_awareness_update_message(awareness, event) { let payload = Message::Awareness(awareness_update).encode_v1(); let msg = CollabAwarenessData::new(cloned_oid.clone(), payload); - if let Err(_e) = sink.send(msg.into()) { + if let Err(_e) = broadcast_sink.send(msg.into()) { trace!("Broadcast group is closed"); } } @@ -178,7 +179,13 @@ impl CollabBroadcast { error!("[🔴Server]: Incoming message's object id does not match the broadcast group's object id"); continue; } - let mut decoder = DecoderV1::from(collab_msg.payload().as_ref()); + + let payload = collab_msg.payload(); + if payload.is_none() { + continue; + } + + let mut decoder = DecoderV1::from(payload.unwrap().as_ref()); match sink.try_lock() { Ok(mut sink) => { let reader = MessageReader::new(&mut decoder); diff --git a/libs/realtime/src/collaborate/group.rs b/libs/realtime/src/collaborate/group.rs index 288bdb6f..1145604b 100644 --- a/libs/realtime/src/collaborate/group.rs +++ b/libs/realtime/src/collaborate/group.rs @@ -12,7 +12,7 @@ use std::sync::Arc; use tokio::sync::RwLock; use tokio::task::spawn_blocking; -use tracing::{error, event, warn}; +use tracing::{error, event, trace, warn}; pub struct CollabGroupCache { group_by_object_id: Arc>>>>, @@ -43,6 +43,15 @@ where } } + pub async fn remove_user(&self, object_id: &str, user: &U) { + let group_by_object_id = self.group_by_object_id.read().await; + if let Some(group) = group_by_object_id.get(object_id) { + if let Some(subscriber) = group.subscribers.write().await.remove(user) { + trace!("Remove subscriber: {}", subscriber.origin); + } + } + } + pub async fn contains_group(&self, object_id: &str) -> Result { let group_by_object_id = self.group_by_object_id.try_read()?; Ok(group_by_object_id.get(object_id).is_some()) diff --git a/libs/realtime/src/collaborate/retry.rs b/libs/realtime/src/collaborate/retry.rs index b54cd6c5..99ec8037 100644 --- a/libs/realtime/src/collaborate/retry.rs +++ b/libs/realtime/src/collaborate/retry.rs @@ -5,7 +5,7 @@ use collab::core::origin::CollabOrigin; use database::collab::CollabStorage; use futures_util::SinkExt; use parking_lot::Mutex; -use realtime_entity::collab_msg::CollabMessage; +use realtime_entity::collab_msg::{CollabMessage, CollabSinkMessage}; use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet}; use std::future; @@ -103,14 +103,20 @@ where } } - // If the client's stream is already subscribe to the collab, return. - if self - .groups - .contains_user(object_id, user) - .await - .unwrap_or(false) - { - return Ok(()); + // If the message is init sync message, which means the client just open the collab again. So + // remove the user from the group first and then subscribe the client's stream to the group. + if collab_message.is_init_msg() { + self.groups.remove_user(object_id, user).await; + } else { + // If the client's stream is already subscribe to the collab, return. + if self + .groups + .contains_user(object_id, user) + .await + .unwrap_or(false) + { + return Ok(()); + } } let origin = match collab_message.origin() { @@ -118,7 +124,7 @@ where error!("🔴The origin from client message is empty"); &CollabOrigin::Empty }, - Some(client) => client, + Some(origin) => origin, }; match self.client_stream_by_user.write().await.get_mut(user) { None => warn!("The client stream is not found"), @@ -234,20 +240,16 @@ where } debug!( - "Group: {} has {} members", - object_id, - collab_group.subscribers.read().await.len() - ); - trace!( - "Group: {} members: {:?}", + "{}: Group member: {}. member ids: {:?}", object_id, + collab_group.subscribers.read().await.len(), collab_group .subscribers .read() .await .values() - .map(|value| &value.origin) - .collect::>() + .map(|value| value.origin.client_user_id()) + .collect::>(), ); } }, diff --git a/libs/realtime/src/collaborate/server.rs b/libs/realtime/src/collaborate/server.rs index 4c558c69..beccabd4 100644 --- a/libs/realtime/src/collaborate/server.rs +++ b/libs/realtime/src/collaborate/server.rs @@ -131,7 +131,6 @@ where let editing_collab_by_user = self.editing_collab_by_user.clone(); Box::pin(async move { remove_user(&groups, &editing_collab_by_user, &msg.user).await; - if client_stream_by_user .write() .await