test: add ws reconnect test (#58)

* test: add test

* test: add reconnect
This commit is contained in:
Nathan.fooo 2023-09-18 11:42:32 +08:00 committed by GitHub
parent 6c4bbbbf7f
commit 7ae645a7c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 272 additions and 81 deletions

2
Cargo.lock generated
View File

@ -781,6 +781,7 @@ dependencies = [
"futures-util",
"gotrue-entity",
"opener",
"parking_lot",
"reqwest",
"serde",
"serde_json",
@ -2477,6 +2478,7 @@ dependencies = [
"assert-json-diff",
"async-trait",
"bytes",
"chrono",
"collab",
"collab-define",
"collab-plugins",

View File

@ -16,7 +16,7 @@ storage-entity = { path = "../storage-entity" }
opener = "0.6.1"
url = "2.4.1"
tokio-stream = { version = "0.1.14" }
parking_lot = "0.12.1"
# ws
tracing = { version = "0.1" }
thiserror = "1.0.39"

View File

@ -10,7 +10,7 @@ use crate::ws::{BusinessID, ClientRealtimeMessage, WSError, WebSocketChannel};
use tokio::sync::broadcast::{channel, Receiver, Sender};
use tokio::sync::{Mutex, RwLock};
use tokio_retry::strategy::FixedInterval;
use tokio_retry::Retry;
use tokio_retry::{Condition, RetryIf};
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::MaybeTlsStream;
@ -37,11 +37,12 @@ impl Default for WSClientConfig {
type ChannelByObjectId = HashMap<String, Weak<WebSocketChannel>>;
pub struct WSClient {
addr: Mutex<Option<String>>,
addr: Arc<parking_lot::Mutex<Option<String>>>,
config: WSClientConfig,
state: Arc<Mutex<ConnectStateNotify>>,
sender: Sender<Message>,
channels: Arc<RwLock<HashMap<BusinessID, ChannelByObjectId>>>,
ping: Arc<Mutex<ServerFixIntervalPing>>,
ping: Arc<Mutex<Option<ServerFixIntervalPing>>>,
}
impl WSClient {
@ -49,14 +50,10 @@ impl WSClient {
let (sender, _) = channel(config.buffer_capacity);
let state = Arc::new(Mutex::new(ConnectStateNotify::new()));
let channels = Arc::new(RwLock::new(HashMap::new()));
let ping = Arc::new(Mutex::new(ServerFixIntervalPing::new(
Duration::from_secs(config.ping_per_secs),
state.clone(),
sender.clone(),
config.retry_connect_per_pings,
)));
let ping = Arc::new(Mutex::new(None));
WSClient {
addr: Mutex::new(None),
addr: Arc::new(parking_lot::Mutex::new(None)),
config,
state,
sender,
channels,
@ -65,12 +62,16 @@ impl WSClient {
}
pub async fn connect(&self, addr: String) -> Result<Option<SocketAddr>, WSError> {
*self.addr.lock().await = Some(addr.clone());
*self.addr.lock() = Some(addr.clone());
self.set_state(ConnectState::Connecting).await;
let retry_strategy = FixedInterval::new(Duration::from_secs(2)).take(3);
let action = ConnectAction::new(addr);
let stream = Retry::spawn(retry_strategy, action).await?;
let action = ConnectAction::new(addr.clone());
let cond = RetryCondition {
connecting_addr: addr,
addr: Arc::downgrade(&self.addr),
};
let stream = RetryIf::spawn(retry_strategy, action, cond).await?;
let addr = match stream.get_ref() {
MaybeTlsStream::Plain(s) => s.local_addr().ok(),
_ => None,
@ -80,7 +81,16 @@ impl WSClient {
self.set_state(ConnectState::Connected).await;
let weak_channels = Arc::downgrade(&self.channels);
let sender = self.sender.clone();
self.ping.lock().await.run();
let mut ping = ServerFixIntervalPing::new(
Duration::from_secs(self.config.ping_per_secs),
self.state.clone(),
sender.clone(),
self.config.retry_connect_per_pings,
);
ping.run();
*self.ping.lock().await = Some(ping);
// Receive messages from the websocket, and send them to the channels.
tokio::spawn(async move {
while let Some(Ok(msg)) = stream.next().await {
@ -158,7 +168,9 @@ impl WSClient {
}
pub async fn disconnect(&self) {
*self.addr.lock() = None;
let _ = self.sender.send(Message::Close(None));
self.set_state(ConnectState::Disconnected).await;
}
async fn set_state(&self, state: ConnectState) {
@ -295,3 +307,20 @@ impl ConnectState {
matches!(self, ConnectState::Disconnected)
}
}
struct RetryCondition {
connecting_addr: String,
addr: Weak<parking_lot::Mutex<Option<String>>>,
}
impl Condition<WSError> for RetryCondition {
fn should_retry(&mut self, _error: &WSError) -> bool {
self
.addr
.upgrade()
.map(|addr| match addr.lock().as_ref() {
None => false,
Some(addr) => addr == &self.connecting_addr,
})
.unwrap_or(false)
}
}

View File

@ -30,6 +30,7 @@ storage-entity = { path = "../storage-entity" }
y-sync = { version = "0.3.1" }
yrs = "0.16.5"
lib0 = "0.16.3"
chrono = "0.4.30"
[dev-dependencies]
actix = "0.13"

View File

@ -181,23 +181,26 @@ impl Deref for ClientWSSink {
pub struct RealtimeUserImpl {
pub uuid: String,
pub device_id: String,
timestamp: i64,
}
impl RealtimeUserImpl {
pub fn new(uuid: String, device_id: String) -> Self {
Self {
uuid,
device_id,
timestamp: chrono::Utc::now().timestamp(),
}
}
}
impl Display for RealtimeUserImpl {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!(
"uuid:{}|device_id:{}",
self.uuid, self.device_id,
"uuid:{}|device_id:{}:{}",
self.uuid, self.device_id, self.timestamp
))
}
}
impl RealtimeUser for RealtimeUserImpl {
fn id(&self) -> &str {
&self.uuid
}
fn device_id(&self) -> &str {
&self.device_id
}
}
impl RealtimeUser for RealtimeUserImpl {}

View File

@ -12,17 +12,19 @@ use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use storage::collab::CollabStorage;
use crate::entities::RealtimeUser;
use tokio::task::spawn_blocking;
use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec};
pub struct CollabGroupCache<S> {
group_by_object_id: RwLock<HashMap<String, Arc<CollabGroup>>>,
pub struct CollabGroupCache<S, U> {
group_by_object_id: RwLock<HashMap<String, Arc<CollabGroup<U>>>>,
storage: S,
}
impl<S> CollabGroupCache<S>
impl<S, U> CollabGroupCache<S, U>
where
S: CollabStorage + Clone,
U: RealtimeUser,
{
pub fn new(storage: S) -> Self {
Self {
@ -57,7 +59,7 @@ where
workspace_id: &str,
object_id: &str,
collab_type: CollabType,
) -> Arc<CollabGroup> {
) -> Arc<CollabGroup<U>> {
tracing::trace!("Create new group for object_id:{}", object_id);
let collab = MutexCollab::new(CollabOrigin::Server, object_id, vec![]);
let broadcast = CollabBroadcast::new(object_id, collab.clone(), 10);
@ -89,20 +91,22 @@ where
}
}
impl<S> Deref for CollabGroupCache<S>
impl<S, U> Deref for CollabGroupCache<S, U>
where
S: CollabStorage,
U: RealtimeUser,
{
type Target = RwLock<HashMap<String, Arc<CollabGroup>>>;
type Target = RwLock<HashMap<String, Arc<CollabGroup<U>>>>;
fn deref(&self) -> &Self::Target {
&self.group_by_object_id
}
}
impl<S> DerefMut for CollabGroupCache<S>
impl<S, U> DerefMut for CollabGroupCache<S, U>
where
S: CollabStorage,
U: RealtimeUser,
{
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.group_by_object_id
@ -110,7 +114,7 @@ where
}
/// A group used to manage a single [Collab] object
pub struct CollabGroup {
pub struct CollabGroup<U> {
pub collab: Arc<MutexCollab>,
/// A broadcast used to propagate updates produced by yrs [yrs::Doc] and [Awareness]
@ -119,10 +123,13 @@ pub struct CollabGroup {
/// A list of subscribers to this group. Each subscriber will receive updates from the
/// broadcast.
pub subscribers: RwLock<HashMap<CollabOrigin, Subscription>>,
pub subscribers: RwLock<HashMap<U, Subscription>>,
}
impl CollabGroup {
impl<U> CollabGroup<U>
where
U: RealtimeUser,
{
/// Mutate the [Collab] by the given closure
pub fn get_mut_collab<F>(&self, f: F)
where

View File

@ -14,27 +14,28 @@ use storage::error::StorageError;
use storage_entity::{InsertCollabParams, QueryCollabParams, RawData};
use crate::collaborate::group::CollabGroup;
use crate::entities::RealtimeUser;
use y_sync::awareness::Awareness;
use yrs::updates::decoder::Decode;
use yrs::{ReadTxn, StateVector, Transact, Update};
pub struct CollabStoragePlugin<S> {
pub struct CollabStoragePlugin<S, U> {
uid: i64,
workspace_id: String,
storage: S,
did_load: AtomicBool,
update_count: AtomicU32,
group: Weak<CollabGroup>,
group: Weak<CollabGroup<U>>,
collab_type: CollabType,
}
impl<S> CollabStoragePlugin<S> {
impl<S, U> CollabStoragePlugin<S, U> {
pub fn new(
uid: i64,
workspace_id: &str,
collab_type: CollabType,
storage: S,
group: Weak<CollabGroup>,
group: Weak<CollabGroup<U>>,
) -> Self {
let workspace_id = workspace_id.to_string();
let did_load = AtomicBool::new(false);
@ -62,9 +63,10 @@ fn init_collab_with_raw_data(raw_data: RawData, doc: &Doc) -> Result<(), Realtim
}
#[async_trait]
impl<S> CollabPlugin for CollabStoragePlugin<S>
impl<S, U> CollabPlugin for CollabStoragePlugin<S, U>
where
S: CollabStorage,
U: RealtimeUser,
{
async fn init(&self, object_id: &str, _origin: &CollabOrigin, doc: &Doc) {
let params = QueryCollabParams {

View File

@ -23,7 +23,7 @@ pub struct CollabServer<S, U> {
#[allow(dead_code)]
storage: S,
/// Keep track of all collab groups
groups: Arc<CollabGroupCache<S>>,
groups: Arc<CollabGroupCache<S, U>>,
/// Keep track of all object ids that a user is subscribed to
editing_collab_by_user: Arc<RwLock<HashMap<U, HashSet<Editing>>>>,
/// Keep track of all client streams
@ -45,6 +45,18 @@ where
client_stream_by_user: Default::default(),
})
}
fn remove_user(&self, user: &U) {
self.client_stream_by_user.write().remove(user);
let editing_set = self.editing_collab_by_user.write().remove(user);
if let Some(editing_set) = editing_set {
tracing::info!("Remove user from group: {}", user);
for editing in editing_set {
remove_user_from_group(user, &self.groups, &editing);
}
}
}
}
impl<S, U> Actor for CollabServer<S, U>
@ -58,19 +70,20 @@ where
impl<S, U> Handler<Connect<U>> for CollabServer<S, U>
where
U: RealtimeUser + Unpin,
S: 'static + Unpin,
S: CollabStorage + Unpin,
{
type Result = Result<(), RealtimeError>;
fn handle(&mut self, new_conn: Connect<U>, _ctx: &mut Context<Self>) -> Self::Result {
tracing::trace!("[💭Server]: new connection => {} ", new_conn.user);
// Remove the user from the group if the user is already connected
self.remove_user(&new_conn.user);
let stream = CollabClientStream::new(ClientWSSink(new_conn.socket));
self
.client_stream_by_user
.write()
.insert(new_conn.user, stream);
Ok(())
}
}
@ -83,21 +96,7 @@ where
type Result = Result<(), RealtimeError>;
fn handle(&mut self, msg: Disconnect<U>, _: &mut Context<Self>) -> Self::Result {
tracing::trace!("[💭Server]: disconnect => {}", msg.user);
self.client_stream_by_user.write().remove(&msg.user);
// Remove the user from all collab groups that the user is subscribed to
let editing_set = self.editing_collab_by_user.write().remove(&msg.user);
if let Some(editing_set) = editing_set {
if !editing_set.is_empty() {
let groups = self.groups.clone();
tokio::task::spawn_blocking(move || {
for editing in editing_set {
remove_user_from_group(&groups, &editing);
}
});
}
}
self.remove_user(&msg.user);
Ok(())
}
}
@ -136,7 +135,8 @@ async fn forward_message_to_collab_group<U>(
{
if let Some(client_stream) = client_streams.read().get(&client_msg.user) {
tracing::trace!(
"[💭Server]: receives: [oid:{}|msg_id:{:?}]",
"[💭Server]: receives: user:{} message: [oid:{}|msg_id:{:?}]",
client_msg.user,
client_msg.content.object_id(),
client_msg.content.msg_id()
);
@ -154,7 +154,7 @@ async fn forward_message_to_collab_group<U>(
async fn subscribe_collab_group_change_if_need<U, S>(
client_msg: &ClientMessage<U>,
groups: &Arc<CollabGroupCache<S>>,
groups: &Arc<CollabGroupCache<S, U>>,
edit_collab_by_user: &Arc<RwLock<HashMap<U, HashSet<Editing>>>>,
client_streams: &Arc<RwLock<HashMap<U, CollabClientStream>>>,
) -> Result<(), RealtimeError>
@ -200,7 +200,7 @@ where
if groups
.read()
.get(object_id)
.map(|group| group.subscribers.read().get(origin).is_some())
.map(|group| group.subscribers.read().get(&client_msg.user).is_some())
.unwrap_or(false)
{
return Ok(());
@ -213,7 +213,7 @@ where
collab_group
.subscribers
.write()
.entry(origin.clone())
.entry(client_msg.user.clone())
.or_insert_with(|| {
tracing::trace!(
"[💭Server]: {} subscribe group:{}",
@ -250,15 +250,16 @@ where
}
/// Remove the user from the group and remove the group from the cache if the group is empty.
fn remove_user_from_group<S>(groups: &Arc<CollabGroupCache<S>>, editing: &Editing)
fn remove_user_from_group<S, U>(user: &U, groups: &Arc<CollabGroupCache<S, U>>, editing: &Editing)
where
S: CollabStorage,
U: RealtimeUser,
{
let mut groups_write_guard = groups.write();
let should_remove_group = groups_write_guard.get_mut(&editing.object_id).map(|group| {
tracing::debug!("Remove subscriber: {}", editing.origin);
group.subscribers.write().remove(&editing.origin);
tracing::info!("Remove subscriber: {}", editing.origin);
group.subscribers.write().remove(user);
let should_remove = group.is_empty();
if should_remove {
group.flush_collab();

View File

@ -8,15 +8,15 @@ use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
use std::fmt::{Debug, Display};
use std::hash::Hash;
use std::sync::Arc;
pub trait RealtimeUser:
Clone + Debug + Send + Sync + 'static + Display + Hash + Eq + PartialEq
{
fn id(&self) -> &str;
fn device_id(&self) -> &str;
}
impl<T> RealtimeUser for Arc<T> where T: RealtimeUser {}
#[derive(Debug, Message, Clone)]
#[rtype(result = "Result<(), RealtimeError>")]
pub struct Connect<U> {

View File

@ -3,6 +3,7 @@ use actix::Addr;
use actix_web::web::{Data, Path, Payload};
use actix_web::{get, web, HttpRequest, HttpResponse, Result, Scope};
use actix_web_actors::ws;
use std::sync::Arc;
use realtime::client::{ClientWSSession, RealtimeUserImpl};
use realtime::collaborate::CollabServer;
@ -23,16 +24,13 @@ pub async fn establish_ws_connection(
payload: Payload,
path: Path<(String, String)>,
state: Data<AppState>,
server: Data<Addr<CollabServer<CollabStorageProxy, RealtimeUserImpl>>>,
server: Data<Addr<CollabServer<CollabStorageProxy, Arc<RealtimeUserImpl>>>>,
) -> Result<HttpResponse> {
tracing::info!("ws connect: {:?}", request);
let (token, device_id) = path.into_inner();
let auth = authorization_from_token(token.as_str(), &state)?;
let user_uuid = UserUuid::from_auth(auth)?;
let realtime_user = RealtimeUserImpl {
uuid: user_uuid.to_string(),
device_id,
};
let realtime_user = Arc::new(RealtimeUserImpl::new(user_uuid.to_string(), device_id));
let client = ClientWSSession::new(
realtime_user,
server.get_ref().clone(),

View File

@ -83,7 +83,7 @@ where
.map(|(_, server_key)| Key::from(server_key.expose_secret().as_bytes()))
.unwrap_or_else(Key::generate);
let collab_server = CollabServer::<_, RealtimeUserImpl>::new(storage.collab_storage.clone())
let collab_server = CollabServer::<_, Arc<RealtimeUserImpl>>::new(storage.collab_storage.clone())
.unwrap()
.start();
let mut server = HttpServer::new(move || {

View File

@ -30,3 +30,36 @@ async fn realtime_connect_test() {
}
}
}
#[tokio::test]
async fn realtime_disconnect_test() {
let _guard = REGISTERED_USER_MUTEX.lock().await;
let mut c = client_api_client();
c.sign_in_password(&REGISTERED_EMAIL, &REGISTERED_PASSWORD)
.await
.unwrap();
let ws_client = WSClient::new(WSClientConfig {
buffer_capacity: 100,
ping_per_secs: 2,
retry_connect_per_pings: 5,
});
ws_client
.connect(c.ws_url("fake_device_id").unwrap())
.await
.unwrap();
let mut state = ws_client.subscribe_connect_state().await;
loop {
tokio::select! {
_ = ws_client.disconnect() => {},
value = state.recv() => {
let new_state = value.unwrap();
if new_state == ConnectState::Disconnected {
break;
}
},
}
}
}

View File

@ -6,8 +6,11 @@ use collab_define::CollabType;
use crate::realtime::test_client::{assert_collab_json, TestClient};
use assert_json_diff::assert_json_eq;
use shared_entity::error_code::ErrorCode;
use std::time::Duration;
use storage::collab::FLUSH_PER_UPDATE;
use storage_entity::QueryCollabParams;
use uuid::Uuid;
#[tokio::test]
async fn realtime_write_collab_test() {
@ -25,7 +28,6 @@ async fn realtime_write_collab_test() {
// Wait for the messages to be sent
tokio::time::sleep(Duration::from_secs(2)).await;
test_client.disconnect().await;
assert_collab_json(
&mut test_client.api_client,
@ -73,6 +75,52 @@ async fn one_direction_peer_sync_test() {
assert_json_eq!(json_1, json_2);
}
#[tokio::test]
async fn same_user_with_same_device_id_test() {
let object_id = uuid::Uuid::new_v4().to_string();
let collab_type = CollabType::Document;
// Client_1_2 will force the server to disconnect client_1_1. So any changes made by client_1_1
// will not be saved to the server.
let device_id = Uuid::new_v4().to_string();
let client_1_1 =
TestClient::new_with_device_id(&object_id, &device_id, collab_type.clone()).await;
let mut client_1_2 =
TestClient::new_with_device_id(&object_id, &device_id, collab_type.clone()).await;
client_1_1.collab.lock().insert("1", "a");
client_1_2.collab.lock().insert("2", "b");
client_1_1.collab.lock().insert("3", "c");
tokio::time::sleep(Duration::from_millis(200)).await;
let json_1 = client_1_1.collab.lock().to_json_value();
let json_2 = client_1_2.collab.lock().to_json_value();
assert_json_eq!(
json_1,
json!({
"1": "a",
"3": "c"
})
);
assert_json_eq!(
json_2,
json!({
"2": "b"
})
);
assert_collab_json(
&mut client_1_2.api_client,
&object_id,
&collab_type,
5,
json!({
"2": "b"
}),
)
.await;
}
#[tokio::test]
async fn two_direction_peer_sync_test() {
let _client_api = client_api_client();
@ -174,3 +222,51 @@ async fn multiple_collab_edit_test() {
)
.await;
}
#[tokio::test]
async fn ws_reconnect_sync_test() {
let object_id = uuid::Uuid::new_v4().to_string();
let collab_type = CollabType::Document;
let mut test_client = TestClient::new(&object_id, collab_type.clone()).await;
// Disconnect the client and edit the collab. The updates will not be sent to the server.
test_client.disconnect().await;
for i in 0..=5 {
test_client
.collab
.lock()
.insert(&i.to_string(), i.to_string());
}
// it will return RecordNotFound error when trying to get the collab from the server
let err = test_client
.api_client
.get_collab(QueryCollabParams {
object_id: object_id.clone(),
collab_type: collab_type.clone(),
})
.await
.unwrap_err();
assert_eq!(err.code, ErrorCode::RecordNotFound);
// After reconnect the collab should be synced to the server.
test_client.reconnect().await;
// Wait for the messages to be sent
tokio::time::sleep(Duration::from_secs(2)).await;
assert_collab_json(
&mut test_client.api_client,
&object_id,
&collab_type,
3,
json!( {
"0": "0",
"1": "1",
"2": "2",
"3": "3",
"4": "4",
"5": "5",
}),
)
.await;
}

View File

@ -25,10 +25,16 @@ pub(crate) struct TestClient {
#[allow(dead_code)]
pub handler: Arc<WebSocketChannel>,
pub api_client: client_api::Client,
device_id: String,
}
impl TestClient {
pub(crate) async fn new(object_id: &str, collab_type: CollabType) -> Self {
pub(crate) async fn new_with_device_id(
object_id: &str,
device_id: &str,
collab_type: CollabType,
) -> Self {
let device_id = device_id.to_string();
let mut api_client = client_api_client();
let _guard = REGISTERED_USER_MUTEX.lock().await;
@ -38,7 +44,6 @@ impl TestClient {
.await
.unwrap();
let device_id = Uuid::new_v4().to_string();
// Connect to server via websocket
let ws_client = WSClient::new(WSClientConfig {
buffer_capacity: 100,
@ -67,7 +72,7 @@ impl TestClient {
.await
.unwrap();
let (sink, stream) = (handler.sink(), handler.stream());
let origin = CollabOrigin::Client(CollabClient::new(uid, device_id));
let origin = CollabOrigin::Client(CollabClient::new(uid, device_id.clone()));
let collab = Arc::new(MutexCollab::new(origin.clone(), object_id, vec![]));
let object = SyncObject::new(object_id, &workspace_id, collab_type);
@ -87,12 +92,26 @@ impl TestClient {
origin,
collab,
handler,
device_id,
}
}
pub(crate) async fn new(object_id: &str, collab_type: CollabType) -> Self {
let device_id = Uuid::new_v4().to_string();
Self::new_with_device_id(object_id, &device_id, collab_type).await
}
pub(crate) async fn disconnect(&self) {
self.ws_client.disconnect().await;
}
pub(crate) async fn reconnect(&self) {
self
.ws_client
.connect(self.api_client.ws_url(&self.device_id).unwrap())
.await
.unwrap();
}
}
#[allow(dead_code)]
@ -130,9 +149,9 @@ pub async fn assert_collab_json(
}
tokio::time::sleep(Duration::from_millis(200)).await;
},
Err(_) => {
Err(e) => {
if retry_count > 5 {
panic!("Query collab failed");
panic!("Query collab failed: {}", e);
}
tokio::time::sleep(Duration::from_millis(200)).await;
}