chore: remove subscriber when receive init message (#156)
This commit is contained in:
parent
f626e4a3b2
commit
eb633c2ba4
|
|
@ -1,3 +1,4 @@
|
|||
use reqwest::StatusCode;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::{Display, Formatter};
|
||||
use thiserror::Error;
|
||||
|
|
@ -22,6 +23,9 @@ pub enum GoTrueError {
|
|||
#[error("{0}")]
|
||||
NotLoggedIn(String),
|
||||
|
||||
#[error("{0}")]
|
||||
Auth(String),
|
||||
|
||||
#[error(transparent)]
|
||||
Unhandled(#[from] anyhow::Error),
|
||||
}
|
||||
|
|
@ -49,6 +53,12 @@ impl From<reqwest::Error> for GoTrueError {
|
|||
return GoTrueError::InvalidRequest(value.to_string());
|
||||
}
|
||||
|
||||
if let Some(status) = value.status() {
|
||||
if status == StatusCode::UNAUTHORIZED {
|
||||
return GoTrueError::Auth(value.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
GoTrueError::Unhandled(value.into())
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -181,7 +181,7 @@ impl From<crate::gotrue::GoTrueError> for AppError {
|
|||
GoTrueError::RequestTimeout(msg) => AppError::RequestTimeout(msg),
|
||||
GoTrueError::InvalidRequest(msg) => AppError::InvalidRequest(msg),
|
||||
GoTrueError::ClientError(err) => AppError::OAuthError(err.to_string()),
|
||||
GoTrueError::Unhandled(err) => AppError::Internal(err),
|
||||
GoTrueError::Auth(err) => AppError::OAuthError(err),
|
||||
GoTrueError::Internal(err) => match (err.code, err.msg.as_str()) {
|
||||
(400, m) if m.starts_with("oauth error") => AppError::OAuthError(err.msg),
|
||||
(400, m) if m.starts_with("User already registered") => AppError::OAuthError(err.msg),
|
||||
|
|
@ -189,6 +189,7 @@ impl From<crate::gotrue::GoTrueError> for AppError {
|
|||
(422, _) => AppError::InvalidRequest(err.msg),
|
||||
_ => AppError::OAuthError(err.to_string()),
|
||||
},
|
||||
GoTrueError::Unhandled(err) => AppError::Internal(err),
|
||||
GoTrueError::NotLoggedIn(msg) => AppError::NotLoggedIn(msg),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -264,16 +264,17 @@ where
|
|||
where
|
||||
P: CollabSyncProtocol + Send + Sync + 'static,
|
||||
{
|
||||
let payload = msg.payload();
|
||||
if match msg.msg_id() {
|
||||
// The msg_id is None if the message is [ServerBroadcast] or [ServerAwareness]
|
||||
None => true,
|
||||
Some(msg_id) => sink.ack_msg(msg.origin(), msg.object_id(), msg_id).await,
|
||||
} && !msg.payload().is_empty()
|
||||
} && payload.is_some()
|
||||
{
|
||||
trace!("start process message: {:?}", msg.msg_id());
|
||||
SyncStream::<Sink, Stream>::process_payload(
|
||||
origin,
|
||||
msg.payload(),
|
||||
payload.unwrap(),
|
||||
object_id,
|
||||
protocol,
|
||||
collab,
|
||||
|
|
|
|||
|
|
@ -374,7 +374,7 @@ impl Client {
|
|||
/// This function attempts to read the current token and, if successful, returns the expiration timestamp.
|
||||
///
|
||||
/// # Returns
|
||||
/// - `Ok(i64)`: An `i64` representing the expiration timestamp of the token.
|
||||
/// - `Ok(i64)`: An `i64` representing the expiration timestamp of the token in seconds.
|
||||
/// - `Err(AppError)`: An `AppError` indicating either an inability to read the token or that the user is not logged in.
|
||||
///
|
||||
#[inline]
|
||||
|
|
|
|||
|
|
@ -160,22 +160,32 @@ impl WSClient {
|
|||
match msg {
|
||||
RealtimeMessage::Collab(collab_msg) => {
|
||||
if let Some(channels) = weak_channels.upgrade() {
|
||||
if let Some(channel) = channels.read().get(collab_msg.object_id()) {
|
||||
let object_id = collab_msg.object_id().to_owned();
|
||||
let is_channel_dropped = if let Some(channel) = channels.read().get(&object_id)
|
||||
{
|
||||
match channel.upgrade() {
|
||||
None => {
|
||||
// when calling [WSClient::subscribe], the caller is responsible for keeping
|
||||
// the channel alive as long as it wants to receive messages from the websocket.
|
||||
warn!("channel is dropped");
|
||||
true
|
||||
},
|
||||
Some(channel) => {
|
||||
channel.forward_to_stream(collab_msg);
|
||||
false
|
||||
},
|
||||
}
|
||||
} else {
|
||||
warn!(
|
||||
"can't find channel by object_id: {}",
|
||||
collab_msg.object_id()
|
||||
);
|
||||
warn!("can't find channel of object_id: {}", object_id);
|
||||
false
|
||||
};
|
||||
|
||||
// Try to remove the channel if it is dropped. If failed, will try again next time.
|
||||
if is_channel_dropped {
|
||||
if let Some(mut w) = channels.try_write() {
|
||||
trace!("remove channel: {}", object_id);
|
||||
w.remove(&object_id);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
warn!("channels are closed");
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ pub enum CollabMessage {
|
|||
ServerInit(ServerCollabInit),
|
||||
ServerAwareness(CollabAwarenessData),
|
||||
ServerBroadcast(CollabBroadcastData),
|
||||
CloseCollab(CloseCollabData),
|
||||
}
|
||||
|
||||
impl CollabSinkMessage for CollabMessage {
|
||||
|
|
@ -45,7 +46,7 @@ impl CollabSinkMessage for CollabMessage {
|
|||
}
|
||||
|
||||
fn length(&self) -> usize {
|
||||
self.payload().len()
|
||||
self.payload().map(|payload| payload.len()).unwrap_or(0)
|
||||
}
|
||||
|
||||
fn can_merge(&self) -> bool {
|
||||
|
|
@ -125,11 +126,15 @@ impl CollabMessage {
|
|||
CollabMessage::ServerInit(value) => Some(value.msg_id),
|
||||
CollabMessage::ServerBroadcast(_) => None,
|
||||
CollabMessage::ServerAwareness(_) => None,
|
||||
CollabMessage::CloseCollab(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.payload().is_empty()
|
||||
self
|
||||
.payload()
|
||||
.map(|payload| payload.is_empty())
|
||||
.unwrap_or(true)
|
||||
}
|
||||
|
||||
pub fn origin(&self) -> Option<&CollabOrigin> {
|
||||
|
|
@ -140,6 +145,7 @@ impl CollabMessage {
|
|||
CollabMessage::ServerInit(value) => Some(&value.origin),
|
||||
CollabMessage::ServerBroadcast(value) => Some(&value.origin),
|
||||
CollabMessage::ServerAwareness(_) => None,
|
||||
CollabMessage::CloseCollab(value) => Some(&value.origin),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -155,6 +161,7 @@ impl CollabMessage {
|
|||
CollabMessage::ServerInit(value) => &value.object_id,
|
||||
CollabMessage::ServerBroadcast(value) => &value.object_id,
|
||||
CollabMessage::ServerAwareness(value) => &value.object_id,
|
||||
CollabMessage::CloseCollab(value) => &value.object_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -198,6 +205,9 @@ impl Display for CollabMessage {
|
|||
value.object_id,
|
||||
value.payload.len(),
|
||||
)),
|
||||
CollabMessage::CloseCollab(value) => {
|
||||
f.write_fmt(format_args!("close collab: [oid:{}]", value.object_id,))
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -217,14 +227,15 @@ impl CollabMessage {
|
|||
serde_json::from_slice(data)
|
||||
}
|
||||
|
||||
pub fn payload(&self) -> &Bytes {
|
||||
pub fn payload(&self) -> Option<&Bytes> {
|
||||
match self {
|
||||
CollabMessage::ClientInit(value) => &value.payload,
|
||||
CollabMessage::ClientUpdateSync(value) => &value.payload,
|
||||
CollabMessage::ClientUpdateAck(value) => &value.payload,
|
||||
CollabMessage::ServerInit(value) => &value.payload,
|
||||
CollabMessage::ServerBroadcast(value) => &value.payload,
|
||||
CollabMessage::ServerAwareness(value) => &value.payload,
|
||||
CollabMessage::ClientInit(value) => Some(&value.payload),
|
||||
CollabMessage::ClientUpdateSync(value) => Some(&value.payload),
|
||||
CollabMessage::ClientUpdateAck(value) => Some(&value.payload),
|
||||
CollabMessage::ServerInit(value) => Some(&value.payload),
|
||||
CollabMessage::ServerBroadcast(value) => Some(&value.payload),
|
||||
CollabMessage::ServerAwareness(value) => Some(&value.payload),
|
||||
CollabMessage::CloseCollab(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -469,6 +480,17 @@ impl TryFrom<RealtimeMessage> for CollabMessage {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Debug, Serialize, Deserialize)]
|
||||
pub struct CloseCollabData {
|
||||
origin: CollabOrigin,
|
||||
object_id: String,
|
||||
}
|
||||
|
||||
impl From<CloseCollabData> for CollabMessage {
|
||||
fn from(value: CloseCollabData) -> Self {
|
||||
CollabMessage::CloseCollab(value)
|
||||
}
|
||||
}
|
||||
impl From<CollabMessage> for RealtimeMessage {
|
||||
fn from(msg: CollabMessage) -> Self {
|
||||
Self::Collab(msg)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ use database::collab::CollabStorage;
|
|||
pub use realtime_entity::user::RealtimeUserImpl;
|
||||
use std::ops::Deref;
|
||||
use std::time::{Duration, Instant};
|
||||
use tracing::{error, trace};
|
||||
use tracing::error;
|
||||
|
||||
pub struct ClientSession<
|
||||
U: Unpin + RealtimeUser,
|
||||
|
|
@ -133,10 +133,7 @@ where
|
|||
|
||||
fn handle(&mut self, msg: RealtimeMessage, ctx: &mut Self::Context) {
|
||||
match &msg {
|
||||
RealtimeMessage::Collab(collab_msg) => {
|
||||
trace!("{:?}: receives collab message: {:?}", self.user, collab_msg);
|
||||
ctx.binary(msg)
|
||||
},
|
||||
RealtimeMessage::Collab(_collab_msg) => ctx.binary(msg),
|
||||
RealtimeMessage::ServerKickedOff => {
|
||||
// The server will send this message to the client when the client is kicked out. So
|
||||
// set the current user to None and stop the session.
|
||||
|
|
|
|||
|
|
@ -46,13 +46,14 @@ impl CollabBroadcast {
|
|||
/// provided `buffer_capacity` size.
|
||||
pub fn new(object_id: &str, collab: MutexCollab, buffer_capacity: usize) -> Self {
|
||||
let object_id = object_id.to_owned();
|
||||
// broadcast channel
|
||||
let (sender, _) = channel(buffer_capacity);
|
||||
let (doc_sub, awareness_sub) = {
|
||||
let mut mutex_collab = collab.lock();
|
||||
|
||||
// Observer the document's update and broadcast it to all subscribers.
|
||||
let cloned_oid = object_id.clone();
|
||||
let sink = sender.clone();
|
||||
let broadcast_sink = sender.clone();
|
||||
let doc_sub = mutex_collab
|
||||
.get_mut_awareness()
|
||||
.doc_mut()
|
||||
|
|
@ -60,13 +61,13 @@ impl CollabBroadcast {
|
|||
let origin = CollabOrigin::from(txn);
|
||||
let payload = gen_update_message(&event.update);
|
||||
let msg = CollabBroadcastData::new(origin, cloned_oid.clone(), payload);
|
||||
if let Err(_e) = sink.send(msg.into()) {
|
||||
trace!("Broadcast group is closed");
|
||||
if let Err(e) = broadcast_sink.send(msg.into()) {
|
||||
error!("broadcast sink fail: {}", e);
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let sink = sender.clone();
|
||||
let broadcast_sink = sender.clone();
|
||||
let cloned_oid = object_id.clone();
|
||||
|
||||
// Observer the awareness's update and broadcast it to all subscribers.
|
||||
|
|
@ -76,7 +77,7 @@ impl CollabBroadcast {
|
|||
if let Ok(awareness_update) = gen_awareness_update_message(awareness, event) {
|
||||
let payload = Message::Awareness(awareness_update).encode_v1();
|
||||
let msg = CollabAwarenessData::new(cloned_oid.clone(), payload);
|
||||
if let Err(_e) = sink.send(msg.into()) {
|
||||
if let Err(_e) = broadcast_sink.send(msg.into()) {
|
||||
trace!("Broadcast group is closed");
|
||||
}
|
||||
}
|
||||
|
|
@ -178,7 +179,13 @@ impl CollabBroadcast {
|
|||
error!("[🔴Server]: Incoming message's object id does not match the broadcast group's object id");
|
||||
continue;
|
||||
}
|
||||
let mut decoder = DecoderV1::from(collab_msg.payload().as_ref());
|
||||
|
||||
let payload = collab_msg.payload();
|
||||
if payload.is_none() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut decoder = DecoderV1::from(payload.unwrap().as_ref());
|
||||
match sink.try_lock() {
|
||||
Ok(mut sink) => {
|
||||
let reader = MessageReader::new(&mut decoder);
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ use std::sync::Arc;
|
|||
use tokio::sync::RwLock;
|
||||
use tokio::task::spawn_blocking;
|
||||
|
||||
use tracing::{error, event, warn};
|
||||
use tracing::{error, event, trace, warn};
|
||||
|
||||
pub struct CollabGroupCache<S, U, AC> {
|
||||
group_by_object_id: Arc<RwLock<HashMap<String, Arc<CollabGroup<U>>>>>,
|
||||
|
|
@ -43,6 +43,15 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn remove_user(&self, object_id: &str, user: &U) {
|
||||
let group_by_object_id = self.group_by_object_id.read().await;
|
||||
if let Some(group) = group_by_object_id.get(object_id) {
|
||||
if let Some(subscriber) = group.subscribers.write().await.remove(user) {
|
||||
trace!("Remove subscriber: {}", subscriber.origin);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn contains_group(&self, object_id: &str) -> Result<bool, Error> {
|
||||
let group_by_object_id = self.group_by_object_id.try_read()?;
|
||||
Ok(group_by_object_id.get(object_id).is_some())
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ use collab::core::origin::CollabOrigin;
|
|||
use database::collab::CollabStorage;
|
||||
use futures_util::SinkExt;
|
||||
use parking_lot::Mutex;
|
||||
use realtime_entity::collab_msg::CollabMessage;
|
||||
use realtime_entity::collab_msg::{CollabMessage, CollabSinkMessage};
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::future;
|
||||
|
|
@ -103,14 +103,20 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
// If the client's stream is already subscribe to the collab, return.
|
||||
if self
|
||||
.groups
|
||||
.contains_user(object_id, user)
|
||||
.await
|
||||
.unwrap_or(false)
|
||||
{
|
||||
return Ok(());
|
||||
// If the message is init sync message, which means the client just open the collab again. So
|
||||
// remove the user from the group first and then subscribe the client's stream to the group.
|
||||
if collab_message.is_init_msg() {
|
||||
self.groups.remove_user(object_id, user).await;
|
||||
} else {
|
||||
// If the client's stream is already subscribe to the collab, return.
|
||||
if self
|
||||
.groups
|
||||
.contains_user(object_id, user)
|
||||
.await
|
||||
.unwrap_or(false)
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
let origin = match collab_message.origin() {
|
||||
|
|
@ -118,7 +124,7 @@ where
|
|||
error!("🔴The origin from client message is empty");
|
||||
&CollabOrigin::Empty
|
||||
},
|
||||
Some(client) => client,
|
||||
Some(origin) => origin,
|
||||
};
|
||||
match self.client_stream_by_user.write().await.get_mut(user) {
|
||||
None => warn!("The client stream is not found"),
|
||||
|
|
@ -234,20 +240,16 @@ where
|
|||
}
|
||||
|
||||
debug!(
|
||||
"Group: {} has {} members",
|
||||
object_id,
|
||||
collab_group.subscribers.read().await.len()
|
||||
);
|
||||
trace!(
|
||||
"Group: {} members: {:?}",
|
||||
"{}: Group member: {}. member ids: {:?}",
|
||||
object_id,
|
||||
collab_group.subscribers.read().await.len(),
|
||||
collab_group
|
||||
.subscribers
|
||||
.read()
|
||||
.await
|
||||
.values()
|
||||
.map(|value| &value.origin)
|
||||
.collect::<Vec<_>>()
|
||||
.map(|value| value.origin.client_user_id())
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
}
|
||||
},
|
||||
|
|
|
|||
|
|
@ -131,7 +131,6 @@ where
|
|||
let editing_collab_by_user = self.editing_collab_by_user.clone();
|
||||
Box::pin(async move {
|
||||
remove_user(&groups, &editing_collab_by_user, &msg.user).await;
|
||||
|
||||
if client_stream_by_user
|
||||
.write()
|
||||
.await
|
||||
|
|
|
|||
Loading…
Reference in New Issue