chore: remove subscriber when receive init message (#156)

This commit is contained in:
Nathan.fooo 2023-11-09 16:52:09 +08:00 committed by GitHub
parent f626e4a3b2
commit eb633c2ba4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 107 additions and 49 deletions

View File

@ -1,3 +1,4 @@
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use std::fmt::{Display, Formatter};
use thiserror::Error;
@ -22,6 +23,9 @@ pub enum GoTrueError {
#[error("{0}")]
NotLoggedIn(String),
#[error("{0}")]
Auth(String),
#[error(transparent)]
Unhandled(#[from] anyhow::Error),
}
@ -49,6 +53,12 @@ impl From<reqwest::Error> for GoTrueError {
return GoTrueError::InvalidRequest(value.to_string());
}
if let Some(status) = value.status() {
if status == StatusCode::UNAUTHORIZED {
return GoTrueError::Auth(value.to_string());
}
}
GoTrueError::Unhandled(value.into())
}
}

View File

@ -181,7 +181,7 @@ impl From<crate::gotrue::GoTrueError> for AppError {
GoTrueError::RequestTimeout(msg) => AppError::RequestTimeout(msg),
GoTrueError::InvalidRequest(msg) => AppError::InvalidRequest(msg),
GoTrueError::ClientError(err) => AppError::OAuthError(err.to_string()),
GoTrueError::Unhandled(err) => AppError::Internal(err),
GoTrueError::Auth(err) => AppError::OAuthError(err),
GoTrueError::Internal(err) => match (err.code, err.msg.as_str()) {
(400, m) if m.starts_with("oauth error") => AppError::OAuthError(err.msg),
(400, m) if m.starts_with("User already registered") => AppError::OAuthError(err.msg),
@ -189,6 +189,7 @@ impl From<crate::gotrue::GoTrueError> for AppError {
(422, _) => AppError::InvalidRequest(err.msg),
_ => AppError::OAuthError(err.to_string()),
},
GoTrueError::Unhandled(err) => AppError::Internal(err),
GoTrueError::NotLoggedIn(msg) => AppError::NotLoggedIn(msg),
}
}

View File

@ -264,16 +264,17 @@ where
where
P: CollabSyncProtocol + Send + Sync + 'static,
{
let payload = msg.payload();
if match msg.msg_id() {
// The msg_id is None if the message is [ServerBroadcast] or [ServerAwareness]
None => true,
Some(msg_id) => sink.ack_msg(msg.origin(), msg.object_id(), msg_id).await,
} && !msg.payload().is_empty()
} && payload.is_some()
{
trace!("start process message: {:?}", msg.msg_id());
SyncStream::<Sink, Stream>::process_payload(
origin,
msg.payload(),
payload.unwrap(),
object_id,
protocol,
collab,

View File

@ -374,7 +374,7 @@ impl Client {
/// This function attempts to read the current token and, if successful, returns the expiration timestamp.
///
/// # Returns
/// - `Ok(i64)`: An `i64` representing the expiration timestamp of the token.
/// - `Ok(i64)`: An `i64` representing the expiration timestamp of the token in seconds.
/// - `Err(AppError)`: An `AppError` indicating either an inability to read the token or that the user is not logged in.
///
#[inline]

View File

@ -160,22 +160,32 @@ impl WSClient {
match msg {
RealtimeMessage::Collab(collab_msg) => {
if let Some(channels) = weak_channels.upgrade() {
if let Some(channel) = channels.read().get(collab_msg.object_id()) {
let object_id = collab_msg.object_id().to_owned();
let is_channel_dropped = if let Some(channel) = channels.read().get(&object_id)
{
match channel.upgrade() {
None => {
// when calling [WSClient::subscribe], the caller is responsible for keeping
// the channel alive as long as it wants to receive messages from the websocket.
warn!("channel is dropped");
true
},
Some(channel) => {
channel.forward_to_stream(collab_msg);
false
},
}
} else {
warn!(
"can't find channel by object_id: {}",
collab_msg.object_id()
);
warn!("can't find channel of object_id: {}", object_id);
false
};
// Try to remove the channel if it is dropped. If failed, will try again next time.
if is_channel_dropped {
if let Some(mut w) = channels.try_write() {
trace!("remove channel: {}", object_id);
w.remove(&object_id);
}
}
} else {
warn!("channels are closed");

View File

@ -37,6 +37,7 @@ pub enum CollabMessage {
ServerInit(ServerCollabInit),
ServerAwareness(CollabAwarenessData),
ServerBroadcast(CollabBroadcastData),
CloseCollab(CloseCollabData),
}
impl CollabSinkMessage for CollabMessage {
@ -45,7 +46,7 @@ impl CollabSinkMessage for CollabMessage {
}
fn length(&self) -> usize {
self.payload().len()
self.payload().map(|payload| payload.len()).unwrap_or(0)
}
fn can_merge(&self) -> bool {
@ -125,11 +126,15 @@ impl CollabMessage {
CollabMessage::ServerInit(value) => Some(value.msg_id),
CollabMessage::ServerBroadcast(_) => None,
CollabMessage::ServerAwareness(_) => None,
CollabMessage::CloseCollab(_) => None,
}
}
pub fn is_empty(&self) -> bool {
self.payload().is_empty()
self
.payload()
.map(|payload| payload.is_empty())
.unwrap_or(true)
}
pub fn origin(&self) -> Option<&CollabOrigin> {
@ -140,6 +145,7 @@ impl CollabMessage {
CollabMessage::ServerInit(value) => Some(&value.origin),
CollabMessage::ServerBroadcast(value) => Some(&value.origin),
CollabMessage::ServerAwareness(_) => None,
CollabMessage::CloseCollab(value) => Some(&value.origin),
}
}
@ -155,6 +161,7 @@ impl CollabMessage {
CollabMessage::ServerInit(value) => &value.object_id,
CollabMessage::ServerBroadcast(value) => &value.object_id,
CollabMessage::ServerAwareness(value) => &value.object_id,
CollabMessage::CloseCollab(value) => &value.object_id,
}
}
}
@ -198,6 +205,9 @@ impl Display for CollabMessage {
value.object_id,
value.payload.len(),
)),
CollabMessage::CloseCollab(value) => {
f.write_fmt(format_args!("close collab: [oid:{}]", value.object_id,))
},
}
}
}
@ -217,14 +227,15 @@ impl CollabMessage {
serde_json::from_slice(data)
}
pub fn payload(&self) -> &Bytes {
pub fn payload(&self) -> Option<&Bytes> {
match self {
CollabMessage::ClientInit(value) => &value.payload,
CollabMessage::ClientUpdateSync(value) => &value.payload,
CollabMessage::ClientUpdateAck(value) => &value.payload,
CollabMessage::ServerInit(value) => &value.payload,
CollabMessage::ServerBroadcast(value) => &value.payload,
CollabMessage::ServerAwareness(value) => &value.payload,
CollabMessage::ClientInit(value) => Some(&value.payload),
CollabMessage::ClientUpdateSync(value) => Some(&value.payload),
CollabMessage::ClientUpdateAck(value) => Some(&value.payload),
CollabMessage::ServerInit(value) => Some(&value.payload),
CollabMessage::ServerBroadcast(value) => Some(&value.payload),
CollabMessage::ServerAwareness(value) => Some(&value.payload),
CollabMessage::CloseCollab(_) => None,
}
}
}
@ -469,6 +480,17 @@ impl TryFrom<RealtimeMessage> for CollabMessage {
}
}
#[derive(Clone, Eq, PartialEq, Debug, Serialize, Deserialize)]
pub struct CloseCollabData {
origin: CollabOrigin,
object_id: String,
}
impl From<CloseCollabData> for CollabMessage {
fn from(value: CloseCollabData) -> Self {
CollabMessage::CloseCollab(value)
}
}
impl From<CollabMessage> for RealtimeMessage {
fn from(msg: CollabMessage) -> Self {
Self::Collab(msg)

View File

@ -12,7 +12,7 @@ use database::collab::CollabStorage;
pub use realtime_entity::user::RealtimeUserImpl;
use std::ops::Deref;
use std::time::{Duration, Instant};
use tracing::{error, trace};
use tracing::error;
pub struct ClientSession<
U: Unpin + RealtimeUser,
@ -133,10 +133,7 @@ where
fn handle(&mut self, msg: RealtimeMessage, ctx: &mut Self::Context) {
match &msg {
RealtimeMessage::Collab(collab_msg) => {
trace!("{:?}: receives collab message: {:?}", self.user, collab_msg);
ctx.binary(msg)
},
RealtimeMessage::Collab(_collab_msg) => ctx.binary(msg),
RealtimeMessage::ServerKickedOff => {
// The server will send this message to the client when the client is kicked out. So
// set the current user to None and stop the session.

View File

@ -46,13 +46,14 @@ impl CollabBroadcast {
/// provided `buffer_capacity` size.
pub fn new(object_id: &str, collab: MutexCollab, buffer_capacity: usize) -> Self {
let object_id = object_id.to_owned();
// broadcast channel
let (sender, _) = channel(buffer_capacity);
let (doc_sub, awareness_sub) = {
let mut mutex_collab = collab.lock();
// Observer the document's update and broadcast it to all subscribers.
let cloned_oid = object_id.clone();
let sink = sender.clone();
let broadcast_sink = sender.clone();
let doc_sub = mutex_collab
.get_mut_awareness()
.doc_mut()
@ -60,13 +61,13 @@ impl CollabBroadcast {
let origin = CollabOrigin::from(txn);
let payload = gen_update_message(&event.update);
let msg = CollabBroadcastData::new(origin, cloned_oid.clone(), payload);
if let Err(_e) = sink.send(msg.into()) {
trace!("Broadcast group is closed");
if let Err(e) = broadcast_sink.send(msg.into()) {
error!("broadcast sink fail: {}", e);
}
})
.unwrap();
let sink = sender.clone();
let broadcast_sink = sender.clone();
let cloned_oid = object_id.clone();
// Observer the awareness's update and broadcast it to all subscribers.
@ -76,7 +77,7 @@ impl CollabBroadcast {
if let Ok(awareness_update) = gen_awareness_update_message(awareness, event) {
let payload = Message::Awareness(awareness_update).encode_v1();
let msg = CollabAwarenessData::new(cloned_oid.clone(), payload);
if let Err(_e) = sink.send(msg.into()) {
if let Err(_e) = broadcast_sink.send(msg.into()) {
trace!("Broadcast group is closed");
}
}
@ -178,7 +179,13 @@ impl CollabBroadcast {
error!("[🔴Server]: Incoming message's object id does not match the broadcast group's object id");
continue;
}
let mut decoder = DecoderV1::from(collab_msg.payload().as_ref());
let payload = collab_msg.payload();
if payload.is_none() {
continue;
}
let mut decoder = DecoderV1::from(payload.unwrap().as_ref());
match sink.try_lock() {
Ok(mut sink) => {
let reader = MessageReader::new(&mut decoder);

View File

@ -12,7 +12,7 @@ use std::sync::Arc;
use tokio::sync::RwLock;
use tokio::task::spawn_blocking;
use tracing::{error, event, warn};
use tracing::{error, event, trace, warn};
pub struct CollabGroupCache<S, U, AC> {
group_by_object_id: Arc<RwLock<HashMap<String, Arc<CollabGroup<U>>>>>,
@ -43,6 +43,15 @@ where
}
}
pub async fn remove_user(&self, object_id: &str, user: &U) {
let group_by_object_id = self.group_by_object_id.read().await;
if let Some(group) = group_by_object_id.get(object_id) {
if let Some(subscriber) = group.subscribers.write().await.remove(user) {
trace!("Remove subscriber: {}", subscriber.origin);
}
}
}
pub async fn contains_group(&self, object_id: &str) -> Result<bool, Error> {
let group_by_object_id = self.group_by_object_id.try_read()?;
Ok(group_by_object_id.get(object_id).is_some())

View File

@ -5,7 +5,7 @@ use collab::core::origin::CollabOrigin;
use database::collab::CollabStorage;
use futures_util::SinkExt;
use parking_lot::Mutex;
use realtime_entity::collab_msg::CollabMessage;
use realtime_entity::collab_msg::{CollabMessage, CollabSinkMessage};
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use std::future;
@ -103,14 +103,20 @@ where
}
}
// If the client's stream is already subscribe to the collab, return.
if self
.groups
.contains_user(object_id, user)
.await
.unwrap_or(false)
{
return Ok(());
// If the message is init sync message, which means the client just open the collab again. So
// remove the user from the group first and then subscribe the client's stream to the group.
if collab_message.is_init_msg() {
self.groups.remove_user(object_id, user).await;
} else {
// If the client's stream is already subscribe to the collab, return.
if self
.groups
.contains_user(object_id, user)
.await
.unwrap_or(false)
{
return Ok(());
}
}
let origin = match collab_message.origin() {
@ -118,7 +124,7 @@ where
error!("🔴The origin from client message is empty");
&CollabOrigin::Empty
},
Some(client) => client,
Some(origin) => origin,
};
match self.client_stream_by_user.write().await.get_mut(user) {
None => warn!("The client stream is not found"),
@ -234,20 +240,16 @@ where
}
debug!(
"Group: {} has {} members",
object_id,
collab_group.subscribers.read().await.len()
);
trace!(
"Group: {} members: {:?}",
"{}: Group member: {}. member ids: {:?}",
object_id,
collab_group.subscribers.read().await.len(),
collab_group
.subscribers
.read()
.await
.values()
.map(|value| &value.origin)
.collect::<Vec<_>>()
.map(|value| value.origin.client_user_id())
.collect::<Vec<_>>(),
);
}
},

View File

@ -131,7 +131,6 @@ where
let editing_collab_by_user = self.editing_collab_by_user.clone();
Box::pin(async move {
remove_user(&groups, &editing_collab_by_user, &msg.user).await;
if client_stream_by_user
.write()
.await