Files
ws-agent/crates/daemon/src/chatpersistence.rs

96 lines
2.6 KiB
Rust

use anyhow::Result;
use directories::ProjectDirs;
use sqlx::sqlite::SqliteConnectOptions;
use sqlx::Row;
use sqlx::SqlitePool;
use tokio::fs;
use tonic::async_trait;
/// A single chat message as stored in the `messages` table.
///
/// Derives [`sqlx::FromRow`] so that query results can be mapped onto it
/// directly (see `save_message`). `Clone`, `PartialEq` and `Eq` are derived
/// so callers can copy and compare messages without hand-written code.
#[derive(Debug, Clone, PartialEq, Eq, sqlx::FromRow)]
pub struct ChatMessage {
    /// Row identifier; `get_latest_messages` orders by this column.
    pub id: i64,
    /// Message body.
    pub text: String,
    /// Flag stored alongside the text.
    /// NOTE(review): presumably `true` marks a message written by the user
    /// rather than the agent — confirm against the callers.
    pub is_user: bool,
}
/// Storage abstraction for chat messages.
///
/// Both operations are async and report failures through `anyhow::Result`.
#[async_trait]
pub trait ChatRepository {
/// Stores a message with the given `text` and `is_user` flag and returns the
/// stored [`ChatMessage`].
///
/// NOTE(review): `bool` is `Copy`, so taking `&bool` is unusual; it is kept
/// as-is because changing the signature would break existing callers and
/// implementors.
async fn save_message(&self, text: &str, is_user: &bool) -> Result<ChatMessage>;
/// Returns the most recently stored messages.
///
/// The SQLite implementation in this file returns at most the 10 rows with
/// the highest `id`, in ascending `id` order.
async fn get_latest_messages(&self) -> Result<Vec<ChatMessage>>;
}
/// [`ChatRepository`] implementation backed by a SQLite connection pool.
///
/// `Clone` is derived because `SqlitePool` is itself a cheaply clonable
/// handle: clones of the repository share the same underlying pool.
#[derive(Debug, Clone)]
pub struct SqliteChatRepository {
    /// Pool used for every query issued by this repository.
    pool: SqlitePool,
}
impl SqliteChatRepository {
    /// Opens the agent database and brings its schema up to date.
    ///
    /// The database file is `agent.db` inside the configuration directory
    /// that `ProjectDirs` resolves for `("com", "jarno", "wsagent")`. The
    /// directory and the file are created when missing, after which the
    /// migrations from `./migrations` are run against the new pool.
    ///
    /// # Errors
    ///
    /// Fails when the project directories cannot be determined, when the
    /// configuration directory cannot be created, when the database cannot
    /// be opened, or when a migration fails (the migration error is also
    /// written to stderr).
    pub async fn new() -> Result<Self> {
        let Some(dirs) = ProjectDirs::from("com", "jarno", "wsagent") else {
            return Err(anyhow::anyhow!("Could not find home directory!"));
        };

        let config_root = dirs.config_dir();
        fs::create_dir_all(config_root).await?;

        let database_file = config_root.join("agent.db");

        // Diagnostic output only: the pool below is configured from the path
        // itself, not from this string.
        let printable_url = format!("sqlite:{}", database_file.display());
        println!("Connection string: {}", printable_url);

        let options = SqliteConnectOptions::new()
            .filename(&database_file)
            .create_if_missing(true);
        let pool = SqlitePool::connect_with(options).await?;

        // Log the failure before handing it to the caller.
        if let Err(migration_error) = sqlx::migrate!("./migrations").run(&pool).await {
            eprintln!("Migration failed! {}", migration_error);
            return Err(migration_error.into());
        }

        Ok(Self { pool })
    }
}
#[async_trait]
impl ChatRepository for SqliteChatRepository {
async fn save_message(&self, text: &str, is_user: &bool) -> Result<ChatMessage> {
let result = sqlx::query_as::<_, ChatMessage>(
r#"
INSERT INTO messages (text, is_user)
VALUES (?, ?)
RETURNING id, text, is_user
"#,
)
.bind(text)
.bind(is_user)
.fetch_one(&self.pool)
.await
.inspect_err(|e| println!("sql error: {}", e))?;
Ok(result)
}
async fn get_latest_messages(&self) -> Result<Vec<ChatMessage>> {
let rows = sqlx::query(
r#"
SELECT * FROM (
SELECT id, text, is_user
FROM messages
ORDER BY id DESC
LIMIT 10
) AS subquery ORDER BY id ASC"#,
)
.fetch_all(&self.pool)
.await
.inspect_err(|e| println!("sql error: {}", e))?;
let messages = rows
.into_iter()
.map(|row| ChatMessage {
id: row.get(0),
text: row.get(1),
is_user: row.get(2),
})
.collect();
Ok(messages)
}
}