diff --git a/libs/wasm-test/README.md b/libs/wasm-test/README.md index 6337619f..3f15453d 100644 --- a/libs/wasm-test/README.md +++ b/libs/wasm-test/README.md @@ -1,4 +1,13 @@ +Before running the test, AppFlowy Cloud need to run with nginx server by this command: + +```shell +docker compose up -d +``` + + +```shell + ## Run test > Before executing the test, you need to install the [Chrome Driver](https://chromedriver.chromium.org/downloads). If diff --git a/src/api/ws.rs b/src/api/ws.rs index 34c3da89..9bb4ed08 100644 --- a/src/api/ws.rs +++ b/src/api/ws.rs @@ -3,6 +3,7 @@ use actix::Addr; use actix_web::web::{Data, Path, Payload}; use actix_web::{get, web, HttpRequest, HttpResponse, Result, Scope}; use actix_web_actors::ws; +use std::collections::HashMap; use std::sync::Arc; use crate::biz::collab::storage::CollabAccessControlStorage; @@ -16,9 +17,10 @@ use realtime::client::rt_client::RealtimeClient; use realtime::server::RealtimeServer; use semver::Version; + use shared_entity::response::AppResponseError; use std::time::Duration; -use tracing::{debug, error, instrument}; +use tracing::{debug, error, instrument, trace}; pub fn ws_scope() -> Scope { web::scope("/ws") @@ -65,12 +67,21 @@ pub async fn establish_ws_connection_v1( payload: Payload, state: Data, server: Data, + web::Query(query_params): web::Query>, ) -> Result { + // Try to parse the connect info from the request body + // If it fails, try to parse it from the query params let ConnectInfo { access_token, client_version, device_id, - } = ConnectInfo::try_from(&request)?; + } = match ConnectInfo::parse_from(&request) { + Ok(info) => info, + Err(_) => { + trace!("Failed to parse connect info from request body. Trying to parse from query params."); + ConnectInfo::parse_from(&query_params)? + }, + }; start_connect( &request, @@ -122,7 +133,7 @@ async fn start_connect( { Ok(response) => Ok(response), Err(e) => { - tracing::error!("🔴ws connection error: {:?}", e); + error!("🔴ws connection error: {:?}", e); Err(e) }, } @@ -142,44 +153,56 @@ struct ConnectInfo { device_id: String, } -impl TryFrom<&HttpRequest> for ConnectInfo { - type Error = AppError; - fn try_from(req: &HttpRequest) -> Result { - let headers = req.headers(); +const CLIENT_VERSION: &str = "client-version"; +const DEVICE_ID: &str = "device-id"; - let access_token = match headers.get(AUTHORIZATION) { - Some(token) => token - .to_str() - .map_err(|_| AppError::InvalidRequest("invalid access token".to_string()))?, - None => return Err(AppError::OAuthError("no access token".to_string())), - }; +// Trait for parameter extraction +trait ExtractParameter { + fn extract_param(&self, key: &str) -> Result; +} - let client_version = headers - .get("client-version") - .map(|v| { - v.to_str() - .map_err(|_| AppError::InvalidRequest("invalid client version".to_string())) - .and_then(|v| { - Version::parse(v) - .map_err(|_| AppError::InvalidRequest("fail to parse client version".to_string())) - }) +// Implement the trait for HashMap +impl ExtractParameter for HashMap { + fn extract_param(&self, key: &str) -> Result { + self + .get(key) + .ok_or_else(|| { + AppError::InvalidRequest(format!("Parameter with given key:{} not found", key)) }) - .unwrap_or_else(|| { - error!("fail to get client version from header, use default version 0.5.0"); - Ok(Version::new(0, 5, 0)) - })?; + .map(|s| s.to_string()) + } +} - let device_id = match headers.get("device-id") { - Some(device_id) => device_id - .to_str() - .map_err(|_| AppError::InvalidRequest("invalid device id".to_string()))?, - None => return Err(AppError::InvalidRequest("empty device id".to_string())), - }; +// Implement the trait for HttpRequest +impl ExtractParameter for HttpRequest { + fn extract_param(&self, key: &str) -> Result { + self + .headers() + .get(key) + .ok_or_else(|| AppError::InvalidRequest(format!("Header with given key:{} not found", key))) + .and_then(|value| { + value + .to_str() + .map_err(|_| { + AppError::InvalidRequest(format!("Invalid header value for given key:{}", key)) + }) + .map(|s| s.to_string()) + }) + } +} + +impl ConnectInfo { + fn parse_from(source: &T) -> Result { + let access_token = source.extract_param(AUTHORIZATION.as_str())?; + let client_version_str = source.extract_param(CLIENT_VERSION)?; + let client_version = Version::parse(&client_version_str) + .map_err(|_| AppError::InvalidRequest(format!("Invalid version:{}", client_version_str)))?; + let device_id = source.extract_param(DEVICE_ID)?; Ok(Self { - access_token: access_token.to_string(), + access_token, client_version, - device_id: device_id.to_string(), + device_id, }) } }