248 lines
6.7 KiB
Rust
248 lines
6.7 KiB
Rust
use anyhow::Result;
|
|
use appflowy_worker::error::WorkerError;
|
|
use appflowy_worker::import_worker::report::{ImportNotifier, ImportProgress};
|
|
use appflowy_worker::import_worker::worker::{run_import_worker, ImportTask};
|
|
use appflowy_worker::s3_client::{BlobMeta, S3Client, S3StreamResponse};
|
|
use aws_sdk_s3::primitives::ByteStream;
|
|
use axum::async_trait;
|
|
|
|
use redis::aio::ConnectionManager;
|
|
use redis::AsyncCommands;
|
|
use redis::RedisResult;
|
|
use serde_json::json;
|
|
use sqlx::PgPool;
|
|
use sqlx::__rt::timeout;
|
|
use std::sync::{Arc, Once};
|
|
use std::time::Duration;
|
|
use tokio::runtime::Builder;
|
|
use tokio::task::LocalSet;
|
|
|
|
use tracing_subscriber::fmt::Subscriber;
|
|
use tracing_subscriber::util::SubscriberInitExt;
|
|
use tracing_subscriber::EnvFilter;
|
|
|
|
#[sqlx::test(migrations = false)]
|
|
async fn create_custom_task_test(pg_pool: PgPool) {
|
|
let redis_client = redis_connection_manager().await;
|
|
let stream_name = uuid::Uuid::new_v4().to_string();
|
|
let notifier = Arc::new(MockNotifier::new());
|
|
let mut task_provider = MockTaskProvider::new(redis_client.clone(), stream_name.clone());
|
|
let _ = run_importer_worker(
|
|
pg_pool,
|
|
redis_client.clone(),
|
|
notifier.clone(),
|
|
stream_name,
|
|
3,
|
|
);
|
|
|
|
let mut task_workspace_ids = vec![];
|
|
// generate 5 tasks
|
|
for _ in 0..5 {
|
|
let workspace_id = uuid::Uuid::new_v4().to_string();
|
|
task_workspace_ids.push(workspace_id.clone());
|
|
task_provider
|
|
.create_task(ImportTask::Custom(json!({"workspace_id": workspace_id})))
|
|
.await;
|
|
}
|
|
|
|
let mut rx = notifier.subscribe();
|
|
timeout(Duration::from_secs(30), async {
|
|
while let Ok(task) = rx.recv().await {
|
|
task_workspace_ids.retain(|_id| {
|
|
if let ImportProgress::Finished(_result) = &task {
|
|
return false;
|
|
}
|
|
true
|
|
});
|
|
|
|
if task_workspace_ids.is_empty() {
|
|
break;
|
|
}
|
|
}
|
|
})
|
|
.await
|
|
.unwrap();
|
|
}
|
|
|
|
// #[tokio::test]
|
|
// async fn consume_group_task_test() {
|
|
// let mut redis_client = redis_client().await;
|
|
// let stream_name = format!("import_task_stream_{}", uuid::Uuid::new_v4());
|
|
// let consumer_group = "import_task_group";
|
|
// let consumer_name = "appflowy_worker";
|
|
// let workspace_id = uuid::Uuid::new_v4().to_string();
|
|
// let user_uuid = uuid::Uuid::new_v4().to_string();
|
|
//
|
|
// let _: RedisResult<()> = redis_client.xgroup_create_mkstream(&stream_name, consumer_group, "0");
|
|
// // 1. insert a task
|
|
// let task = json!({
|
|
// "notion": {
|
|
// "uid": 1,
|
|
// "user_uuid": user_uuid,
|
|
// "workspace_id": workspace_id,
|
|
// "s3_key": workspace_id,
|
|
// "file_type": "zip",
|
|
// "host": "http::localhost",
|
|
// }
|
|
// });
|
|
//
|
|
// let _: () = redis_client
|
|
// .xadd(&stream_name, "*", &[("task", task.to_string())])
|
|
// .unwrap();
|
|
// tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
|
//
|
|
// // 2. consume a task
|
|
// let options = StreamReadOptions::default()
|
|
// .group(consumer_group, consumer_name)
|
|
// .count(3);
|
|
//
|
|
// let tasks: StreamReadReply = redis_client
|
|
// .xread_options(&[&stream_name], &[">"], &options)
|
|
// .unwrap();
|
|
// assert!(!tasks.keys.is_empty());
|
|
//
|
|
// for stream_key in tasks.keys {
|
|
// for stream_id in stream_key.ids {
|
|
// let task_str = match stream_id.map.get("task") {
|
|
// Some(value) => match value {
|
|
// Value::Data(data) => String::from_utf8_lossy(data).to_string(),
|
|
// _ => panic!("Task field is not a string"),
|
|
// },
|
|
// None => continue,
|
|
// };
|
|
//
|
|
// let _ = from_str::<ImportTask>(&task_str).unwrap();
|
|
// let _: () = redis_client
|
|
// .xack(&stream_name, consumer_group, &[stream_id.id.clone()])
|
|
// .unwrap();
|
|
// }
|
|
// }
|
|
// }
|
|
|
|
/// Returns a Redis [`ConnectionManager`](redis::aio::ConnectionManager) for the
/// instance at `redis://localhost:6379`.
///
/// # Panics
/// Panics if the client cannot be created or the connection manager cannot be obtained.
pub async fn redis_connection_manager() -> redis::aio::ConnectionManager {
  const REDIS_URI: &str = "redis://localhost:6379";
  let client = redis::Client::open(REDIS_URI).expect("failed to create redis client");
  client
    .get_connection_manager()
    .await
    .expect("failed to get redis connection manager")
}
|
|
|
|
fn run_importer_worker(
|
|
pg_pool: PgPool,
|
|
redis_client: ConnectionManager,
|
|
notifier: Arc<dyn ImportNotifier>,
|
|
stream_name: String,
|
|
tick_interval_secs: u64,
|
|
) -> std::thread::JoinHandle<()> {
|
|
setup_log();
|
|
let max_import_file_size = 1_000_000_000;
|
|
|
|
std::thread::spawn(move || {
|
|
let runtime = Builder::new_current_thread().enable_all().build().unwrap();
|
|
let local_set = LocalSet::new();
|
|
let import_worker_fut = local_set.run_until(run_import_worker(
|
|
pg_pool,
|
|
redis_client,
|
|
None,
|
|
Arc::new(MockS3Client),
|
|
notifier,
|
|
&stream_name,
|
|
tick_interval_secs,
|
|
max_import_file_size,
|
|
));
|
|
runtime.block_on(import_worker_fut).unwrap();
|
|
})
|
|
}
|
|
|
|
struct MockTaskProvider {
|
|
redis_client: ConnectionManager,
|
|
stream_name: String,
|
|
}
|
|
|
|
impl MockTaskProvider {
|
|
fn new(redis_client: ConnectionManager, stream_name: String) -> Self {
|
|
Self {
|
|
redis_client,
|
|
stream_name,
|
|
}
|
|
}
|
|
|
|
async fn create_task(&mut self, task: ImportTask) {
|
|
let task = serde_json::to_string(&task).unwrap();
|
|
let result: RedisResult<()> = self
|
|
.redis_client
|
|
.xadd(&self.stream_name, "*", &[("task", task.to_string())])
|
|
.await;
|
|
result.unwrap();
|
|
}
|
|
}
|
|
|
|
struct MockNotifier {
|
|
tx: tokio::sync::broadcast::Sender<ImportProgress>,
|
|
}
|
|
|
|
impl MockNotifier {
|
|
fn new() -> Self {
|
|
let (tx, _) = tokio::sync::broadcast::channel(100);
|
|
Self { tx }
|
|
}
|
|
fn subscribe(&self) -> tokio::sync::broadcast::Receiver<ImportProgress> {
|
|
self.tx.subscribe()
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl ImportNotifier for MockNotifier {
|
|
async fn notify_progress(&self, progress: ImportProgress) {
|
|
println!("notify_progress: {:?}", progress);
|
|
self.tx.send(progress).unwrap();
|
|
}
|
|
}
|
|
|
|
struct MockS3Client;
|
|
|
|
#[async_trait]
|
|
impl S3Client for MockS3Client {
|
|
async fn get_blob_stream(&self, _object_key: &str) -> Result<S3StreamResponse, WorkerError> {
|
|
todo!()
|
|
}
|
|
|
|
async fn put_blob(
|
|
&self,
|
|
_object_key: &str,
|
|
_content: ByteStream,
|
|
_content_type: Option<&str>,
|
|
) -> std::result::Result<(), WorkerError> {
|
|
todo!()
|
|
}
|
|
|
|
async fn delete_blob(&self, _object_key: &str) -> Result<(), WorkerError> {
|
|
Ok(())
|
|
}
|
|
|
|
async fn is_blob_exist(&self, _object_key: &str) -> Result<bool, WorkerError> {
|
|
Ok(false)
|
|
}
|
|
|
|
async fn get_blob_meta(&self, _object_key: &str) -> Result<BlobMeta, WorkerError> {
|
|
todo!()
|
|
}
|
|
}
|
|
|
|
pub fn setup_log() {
|
|
static START: Once = Once::new();
|
|
START.call_once(|| {
|
|
let level = std::env::var("RUST_LOG").unwrap_or("trace".to_string());
|
|
let mut filters = vec![];
|
|
filters.push(format!("appflowy_worker={}", level));
|
|
std::env::set_var("RUST_LOG", filters.join(","));
|
|
|
|
let subscriber = Subscriber::builder()
|
|
.with_ansi(true)
|
|
.with_env_filter(EnvFilter::from_default_env())
|
|
.finish();
|
|
subscriber.try_init().unwrap();
|
|
});
|
|
}
|