214 lines
5.3 KiB
Rust
214 lines
5.3 KiB
Rust
use crate::error::WorkerError;
|
|
use anyhow::anyhow;
|
|
use aws_sdk_s3::error::SdkError;
|
|
|
|
use anyhow::Result;
|
|
use aws_sdk_s3::operation::get_object::GetObjectError;
|
|
use aws_sdk_s3::primitives::ByteStream;
|
|
use axum::async_trait;
|
|
use base64::engine::general_purpose::STANDARD;
|
|
use base64::Engine;
|
|
use futures::AsyncReadExt;
|
|
use std::ops::Deref;
|
|
use std::path::{Path, PathBuf};
|
|
use tokio::fs;
|
|
use tokio::fs::File;
|
|
use tokio::io::AsyncWriteExt;
|
|
use tokio_util::compat::TokioAsyncReadCompatExt;
|
|
use tracing::error;
|
|
use uuid::Uuid;
|
|
|
|
#[async_trait]
|
|
pub trait S3Client: Send + Sync {
|
|
async fn get_blob_stream(&self, object_key: &str) -> Result<S3StreamResponse, WorkerError>;
|
|
async fn put_blob(
|
|
&self,
|
|
object_key: &str,
|
|
content: ByteStream,
|
|
content_type: Option<&str>,
|
|
) -> Result<(), WorkerError>;
|
|
async fn delete_blob(&self, object_key: &str) -> Result<(), WorkerError>;
|
|
}
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub struct S3ClientImpl {
|
|
pub inner: aws_sdk_s3::Client,
|
|
pub bucket: String,
|
|
}
|
|
|
|
impl Deref for S3ClientImpl {
|
|
type Target = aws_sdk_s3::Client;
|
|
|
|
fn deref(&self) -> &Self::Target {
|
|
&self.inner
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl S3Client for S3ClientImpl {
|
|
async fn get_blob_stream(&self, object_key: &str) -> Result<S3StreamResponse, WorkerError> {
|
|
match self
|
|
.inner
|
|
.get_object()
|
|
.bucket(&self.bucket)
|
|
.key(object_key)
|
|
.send()
|
|
.await
|
|
{
|
|
Ok(output) => {
|
|
let stream = output.body.into_async_read().compat();
|
|
let content_type = output.content_type;
|
|
let content_length = output.content_length;
|
|
Ok(S3StreamResponse {
|
|
stream: Box::new(stream),
|
|
content_type,
|
|
content_length,
|
|
})
|
|
},
|
|
Err(SdkError::ServiceError(service_err)) => match service_err.err() {
|
|
GetObjectError::NoSuchKey(_) => Err(WorkerError::RecordNotFound(format!(
|
|
"blob not found for key:{object_key}"
|
|
))),
|
|
_ => Err(WorkerError::from(anyhow!(
|
|
"Failed to get object from S3: {:?}",
|
|
service_err
|
|
))),
|
|
},
|
|
Err(err) => Err(WorkerError::from(anyhow!(
|
|
"Failed to get object from S3: {}",
|
|
err
|
|
))),
|
|
}
|
|
}
|
|
|
|
async fn put_blob(
|
|
&self,
|
|
object_key: &str,
|
|
content: ByteStream,
|
|
content_type: Option<&str>,
|
|
) -> Result<(), WorkerError> {
|
|
match self
|
|
.inner
|
|
.put_object()
|
|
.bucket(&self.bucket)
|
|
.key(object_key)
|
|
.body(content)
|
|
.content_type(content_type.unwrap_or("application/octet-stream"))
|
|
.send()
|
|
.await
|
|
{
|
|
Ok(_) => Ok(()),
|
|
Err(err) => Err(WorkerError::from(anyhow!(
|
|
"Failed to put object to S3: {}",
|
|
err
|
|
))),
|
|
}
|
|
}
|
|
|
|
async fn delete_blob(&self, object_key: &str) -> Result<(), WorkerError> {
|
|
match self
|
|
.inner
|
|
.delete_object()
|
|
.bucket(&self.bucket)
|
|
.key(object_key)
|
|
.send()
|
|
.await
|
|
{
|
|
Ok(_) => Ok(()),
|
|
Err(SdkError::ServiceError(service_err)) => Err(WorkerError::from(anyhow!(
|
|
"Failed to delete object from S3: {:?}",
|
|
service_err
|
|
))),
|
|
Err(err) => Err(WorkerError::from(anyhow!(
|
|
"Failed to delete object from S3: {}",
|
|
err
|
|
))),
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct S3StreamResponse {
|
|
pub stream: Box<dyn futures::AsyncBufRead + Unpin + Send>,
|
|
pub content_type: Option<String>,
|
|
pub content_length: Option<i64>,
|
|
}
|
|
|
|
/// Path of a downloaded file that is deleted (best effort) once this value is
/// dropped; the deletion is performed by the type's `Drop` implementation.
pub struct AutoRemoveDownloadedFile(PathBuf);
|
|
|
|
impl AsRef<Path> for AutoRemoveDownloadedFile {
|
|
fn as_ref(&self) -> &Path {
|
|
&self.0
|
|
}
|
|
}
|
|
|
|
impl AsRef<PathBuf> for AutoRemoveDownloadedFile {
|
|
fn as_ref(&self) -> &PathBuf {
|
|
&self.0
|
|
}
|
|
}
|
|
|
|
impl Deref for AutoRemoveDownloadedFile {
|
|
type Target = PathBuf;
|
|
|
|
fn deref(&self) -> &Self::Target {
|
|
&self.0
|
|
}
|
|
}
|
|
|
|
impl Drop for AutoRemoveDownloadedFile {
|
|
fn drop(&mut self) {
|
|
let path = self.0.clone();
|
|
tokio::spawn(async move {
|
|
if let Err(err) = fs::remove_file(&path).await {
|
|
error!(
|
|
"Failed to delete the auto remove downloaded file: {:?}, error: {}",
|
|
path, err
|
|
)
|
|
}
|
|
});
|
|
}
|
|
}
|
|
|
|
pub async fn download_file(
|
|
storage_dir: &Path,
|
|
stream: Box<dyn futures::AsyncBufRead + Unpin + Send>,
|
|
expected_md5_base64: &Option<String>,
|
|
) -> Result<AutoRemoveDownloadedFile, anyhow::Error> {
|
|
let zip_file_path = storage_dir.join(format!("{}.zip", Uuid::new_v4()));
|
|
write_stream_to_file(&zip_file_path, expected_md5_base64, stream).await?;
|
|
Ok(AutoRemoveDownloadedFile(zip_file_path))
|
|
}
|
|
|
|
pub async fn write_stream_to_file(
|
|
file_path: &PathBuf,
|
|
expected_md5_base64: &Option<String>,
|
|
mut stream: Box<dyn futures::AsyncBufRead + Unpin + Send>,
|
|
) -> Result<(), anyhow::Error> {
|
|
let mut context = md5::Context::new();
|
|
let mut file = File::create(file_path).await?;
|
|
let mut buffer = vec![0u8; 1_048_576];
|
|
loop {
|
|
let bytes_read = stream.read(&mut buffer).await?;
|
|
if bytes_read == 0 {
|
|
break;
|
|
}
|
|
context.consume(&buffer[..bytes_read]);
|
|
file.write_all(&buffer[..bytes_read]).await?;
|
|
}
|
|
|
|
let digest = context.compute();
|
|
let md5_base64 = STANDARD.encode(digest.as_ref());
|
|
if let Some(expected_md5) = expected_md5_base64 {
|
|
if md5_base64 != *expected_md5 {
|
|
error!(
|
|
"[Import]: MD5 mismatch, expected: {}, current: {}",
|
|
expected_md5, md5_base64
|
|
);
|
|
return Err(anyhow!("MD5 mismatch"));
|
|
}
|
|
}
|
|
|
|
file.flush().await?;
|
|
Ok(())
|
|
}
|