265 lines
6.5 KiB
Rust
265 lines
6.5 KiB
Rust
use std::fmt::{Display, Formatter};
|
|
use std::net::TcpListener;
|
|
|
|
use actix::{Actor, Addr};
|
|
use actix_web::dev::Server;
|
|
use actix_web::web::{Data, Path, Payload};
|
|
use actix_web::{get, web, App, HttpRequest, HttpResponse, HttpServer, Result};
|
|
|
|
use actix_web_actors::ws;
|
|
|
|
use collab::core::collab::MutexCollab;
|
|
use collab::core::origin::CollabOrigin;
|
|
use once_cell::sync::Lazy;
|
|
use realtime::core::{CollabManager, CollabSession};
|
|
use realtime::entities::RealtimeUser;
|
|
use serde_aux::field_attributes::deserialize_number_from_string;
|
|
use std::path::PathBuf;
|
|
|
|
use collab_define::CollabType;
|
|
use std::time::Duration;
|
|
|
|
use crate::util::log::{get_subscriber, init_subscriber};
|
|
use crate::util::storage_impl::CollabMemoryStorageImpl;
|
|
use storage::collab::CollabStorage;
|
|
use storage::entities::QueryCollabParams;
|
|
|
|
// Ensure that the `tracing` stack is only initialised once using `once_cell`
|
|
static TRACING: Lazy<()> = Lazy::new(|| {
|
|
let level = "trace".to_string();
|
|
let mut filters = vec![];
|
|
filters.push(format!("actix_web={}", level));
|
|
filters.push(format!("collab={}", level));
|
|
filters.push(format!("collab_sync={}", level));
|
|
filters.push(format!("collab_plugins={}", level));
|
|
filters.push(format!("realtime={}", level));
|
|
|
|
let subscriber_name = "test".to_string();
|
|
let subscriber = get_subscriber(subscriber_name, filters.join(","), std::io::stdout);
|
|
init_subscriber(subscriber);
|
|
});
|
|
|
|
#[derive(Clone)]
|
|
pub struct TestServer {
|
|
pub state: State,
|
|
pub api_client: reqwest::Client,
|
|
pub address: String,
|
|
pub port: u16,
|
|
pub ws_addr: String,
|
|
pub storage: CollabMemoryStorageImpl,
|
|
}
|
|
|
|
impl TestServer {
|
|
pub async fn get_collab(&self, object_id: &str) -> serde_json::Value {
|
|
let params = QueryCollabParams {
|
|
object_id: object_id.to_string(),
|
|
collab_type: CollabType::Document,
|
|
};
|
|
let raw_data = self.storage.get_collab(params).await.unwrap();
|
|
let collab =
|
|
MutexCollab::new_with_raw_data(CollabOrigin::Server, object_id, vec![raw_data], vec![])
|
|
.unwrap();
|
|
collab.async_initialize().await;
|
|
collab.to_json_value()
|
|
}
|
|
}
|
|
|
|
pub async fn spawn_server() -> TestServer {
|
|
Lazy::force(&TRACING);
|
|
let config = Config::default();
|
|
let state = init_state(config.clone()).await;
|
|
let storage = CollabMemoryStorageImpl::new(storage::collab::Config {
|
|
flush_per_update: 0,
|
|
});
|
|
let application = Application::build(config, state.clone(), storage.clone())
|
|
.await
|
|
.expect("Failed to build application");
|
|
|
|
let port = application.port();
|
|
tokio::spawn(async {
|
|
let _ = application.run_until_stopped().await;
|
|
});
|
|
let builder = reqwest::Client::builder();
|
|
let address = format!("http://localhost:{}", port);
|
|
let ws_addr = format!("ws://localhost:{}/ws", port);
|
|
|
|
let api_client = builder.no_proxy().build().unwrap();
|
|
TestServer {
|
|
state,
|
|
storage,
|
|
api_client,
|
|
address,
|
|
ws_addr,
|
|
port,
|
|
}
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct Cleaner {
|
|
path: PathBuf,
|
|
should_clean: bool,
|
|
}
|
|
|
|
impl Cleaner {
|
|
pub(crate) fn new(path: PathBuf) -> Self {
|
|
Self {
|
|
path,
|
|
should_clean: true,
|
|
}
|
|
}
|
|
|
|
fn cleanup(dir: &PathBuf) {
|
|
let _ = std::fs::remove_dir_all(dir);
|
|
}
|
|
}
|
|
|
|
impl Drop for Cleaner {
|
|
fn drop(&mut self) {
|
|
if self.should_clean {
|
|
Self::cleanup(&self.path)
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct Application {
|
|
port: u16,
|
|
server: Server,
|
|
}
|
|
|
|
impl Application {
|
|
pub async fn build<S>(config: Config, state: State, storage: S) -> Result<Self, anyhow::Error>
|
|
where
|
|
S: CollabStorage + Unpin,
|
|
{
|
|
let address = format!("{}:{}", config.application.host, config.application.port);
|
|
let listener = TcpListener::bind(&address)?;
|
|
let port = listener.local_addr().unwrap().port();
|
|
let server = run(listener, state, config, storage).await?;
|
|
|
|
Ok(Self { port, server })
|
|
}
|
|
|
|
pub async fn run_until_stopped(self) -> Result<(), std::io::Error> {
|
|
self.server.await
|
|
}
|
|
|
|
pub fn port(&self) -> u16 {
|
|
self.port
|
|
}
|
|
}
|
|
|
|
pub async fn run<S>(
|
|
listener: TcpListener,
|
|
state: State,
|
|
_config: Config,
|
|
storage: S,
|
|
) -> Result<Server, anyhow::Error>
|
|
where
|
|
S: CollabStorage + Unpin,
|
|
{
|
|
let collab_server = CollabManager::new(storage.clone()).unwrap().start();
|
|
let server = HttpServer::new(move || {
|
|
App::new()
|
|
.service(web::scope("/ws").service(establish_ws_connection))
|
|
.app_data(Data::new(collab_server.clone()))
|
|
.app_data(Data::new(state.clone()))
|
|
.app_data(Data::new(storage.clone()))
|
|
})
|
|
.listen(listener)?;
|
|
Ok(server.run())
|
|
}
|
|
|
|
#[get("/{token}")]
|
|
pub async fn establish_ws_connection(
|
|
request: HttpRequest,
|
|
payload: Payload,
|
|
token: Path<String>,
|
|
state: Data<State>,
|
|
server: Data<Addr<CollabManager<CollabMemoryStorageImpl>>>,
|
|
) -> Result<HttpResponse> {
|
|
tracing::trace!("{:?}", request);
|
|
let user = TestLoggedUser {
|
|
user_id: token.as_str().parse().unwrap(),
|
|
};
|
|
let client = CollabSession::new(
|
|
user,
|
|
server.get_ref().clone(),
|
|
Duration::from_secs(state.config.websocket.heartbeat_interval as u64),
|
|
Duration::from_secs(state.config.websocket.client_timeout as u64),
|
|
);
|
|
match ws::start(client, &request, payload) {
|
|
Ok(response) => Ok(response),
|
|
Err(e) => {
|
|
tracing::error!("🔴ws connection error: {:?}", e);
|
|
Err(e)
|
|
},
|
|
}
|
|
}
|
|
|
|
#[derive(serde::Deserialize, Clone, Debug)]
|
|
pub struct WebsocketSetting {
|
|
pub heartbeat_interval: u8,
|
|
pub client_timeout: u8,
|
|
}
|
|
|
|
impl Default for WebsocketSetting {
|
|
fn default() -> Self {
|
|
Self {
|
|
heartbeat_interval: 8,
|
|
client_timeout: 10,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(serde::Deserialize, Clone, Debug)]
|
|
pub struct Config {
|
|
pub application: ApplicationSetting,
|
|
pub websocket: WebsocketSetting,
|
|
}
|
|
|
|
impl Default for Config {
|
|
fn default() -> Self {
|
|
Self {
|
|
application: ApplicationSetting {
|
|
// Use a random OS port
|
|
port: 0,
|
|
host: "0.0.0.0".to_string(),
|
|
server_key: "".to_string(),
|
|
},
|
|
websocket: Default::default(),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(serde::Deserialize, Clone, Debug)]
|
|
pub struct ApplicationSetting {
|
|
#[serde(deserialize_with = "deserialize_number_from_string")]
|
|
pub port: u16,
|
|
pub host: String,
|
|
pub server_key: String,
|
|
}
|
|
#[derive(Clone)]
|
|
pub struct State {
|
|
pub config: Config,
|
|
}
|
|
pub async fn init_state(config: Config) -> State {
|
|
State { config }
|
|
}
|
|
|
|
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
|
|
pub struct TestLoggedUser {
|
|
pub user_id: i64,
|
|
}
|
|
|
|
impl Display for TestLoggedUser {
|
|
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
|
f.write_str("TestLoggedUser")
|
|
}
|
|
}
|
|
|
|
impl RealtimeUser for TestLoggedUser {
|
|
fn user_id(&self) -> &i64 {
|
|
&self.user_id
|
|
}
|
|
}
|