diff --git a/libs/client-api/src/http.rs b/libs/client-api/src/http.rs index 5ac9c26c..3adef375 100644 --- a/libs/client-api/src/http.rs +++ b/libs/client-api/src/http.rs @@ -944,7 +944,9 @@ impl Client { .into_data() } - pub fn ws_url(&self, device_id: &str) -> Result { + pub async fn ws_url(&self, device_id: &str) -> Result { + self.refresh_if_required().await?; + let access_token = self.access_token()?; Ok(format!("{}/{}/{}", self.ws_addr, access_token, device_id)) } @@ -1107,12 +1109,7 @@ impl Client { .into_data() } - #[instrument(level = "debug", skip_all, err)] - pub async fn http_client_with_auth( - &self, - method: Method, - url: &str, - ) -> Result { + pub async fn refresh_if_required(&self) -> Result<(), AppResponseError> { let expires_at = self.token_expires_at()?; // Refresh token if it's about to expire @@ -1124,6 +1121,16 @@ impl Client { // Add 10 seconds buffer self.refresh_token().await?; } + Ok(()) + } + + #[instrument(level = "debug", skip_all, err)] + pub async fn http_client_with_auth( + &self, + method: Method, + url: &str, + ) -> Result { + self.refresh_if_required().await?; let access_token = self.access_token()?; trace!("start request: {}, method: {}", url, method); diff --git a/tests/user/update.rs b/tests/user/update.rs index 892f3540..8dbdfa67 100644 --- a/tests/user/update.rs +++ b/tests/user/update.rs @@ -166,7 +166,7 @@ async fn user_change_notify_test() { let device_id = "fake_device_id"; let _ = ws_client - .connect(c.ws_url(device_id).unwrap(), device_id) + .connect(c.ws_url(device_id).await.unwrap(), device_id) .await .unwrap(); diff --git a/tests/util/test_client.rs b/tests/util/test_client.rs index 6fbbb6f9..4cd68cf1 100644 --- a/tests/util/test_client.rs +++ b/tests/util/test_client.rs @@ -68,7 +68,7 @@ impl TestClient { if start_ws_conn { ws_client - .connect(api_client.ws_url(&device_id).unwrap(), &device_id) + .connect(api_client.ws_url(&device_id).await.unwrap(), &device_id) .await .unwrap(); } @@ -490,7 +490,7 @@ impl TestClient { self .ws_client .connect( - self.api_client.ws_url(&self.device_id).unwrap(), + self.api_client.ws_url(&self.device_id).await.unwrap(), &self.device_id, ) .await diff --git a/tests/websocket/connect.rs b/tests/websocket/connect.rs index 942a908c..03779da7 100644 --- a/tests/websocket/connect.rs +++ b/tests/websocket/connect.rs @@ -9,7 +9,26 @@ async fn realtime_connect_test() { let device_id = "fake_device_id"; loop { tokio::select! { - _ = ws_client.connect(c.ws_url(device_id).unwrap(), device_id) => {}, + _ = ws_client.connect(c.ws_url(device_id).await.unwrap(), device_id) => {}, + value = state.recv() => { + let new_state = value.unwrap(); + if new_state == ConnectState::Connected { + break; + } + }, + } + } +} + +#[tokio::test] +async fn realtime_connect_after_token_exp_test() { + let (c, _user) = generate_unique_registered_user_client().await; + let ws_client = WSClient::new(WSClientConfig::default(), c.clone()); + let mut state = ws_client.subscribe_connect_state(); + let device_id = "fake_device_id"; + loop { + tokio::select! { + _ = ws_client.connect(c.ws_url(device_id).await.unwrap(), device_id) => {}, value = state.recv() => { let new_state = value.unwrap(); if new_state == ConnectState::Connected { @@ -26,7 +45,7 @@ async fn realtime_disconnect_test() { let ws_client = WSClient::new(WSClientConfig::default(), c.clone()); let device_id = "fake_device_id"; ws_client - .connect(c.ws_url(device_id).unwrap(), device_id) + .connect(c.ws_url(device_id).await.unwrap(), device_id) .await .unwrap();