96 lines
2.6 KiB
Rust
96 lines
2.6 KiB
Rust
use anyhow::Result;
|
|
use directories::ProjectDirs;
|
|
use sqlx::sqlite::SqliteConnectOptions;
|
|
use sqlx::Row;
|
|
use sqlx::SqlitePool;
|
|
use tokio::fs;
|
|
use tonic::async_trait;
|
|
|
|
#[derive(Debug, sqlx::FromRow)]
|
|
pub struct ChatMessage {
|
|
pub id: i64,
|
|
pub text: String,
|
|
pub is_user: bool,
|
|
}
|
|
|
|
#[async_trait]
|
|
pub trait ChatRepository {
|
|
async fn save_message(&self, text: &str, is_user: &bool) -> Result<ChatMessage>;
|
|
async fn get_latest_messages(&self) -> Result<Vec<ChatMessage>>;
|
|
}
|
|
|
|
pub struct SqliteChatRepository {
|
|
pool: SqlitePool,
|
|
}
|
|
|
|
impl SqliteChatRepository {
|
|
pub async fn new() -> Result<Self> {
|
|
let project_dirs = ProjectDirs::from("com", "jarno", "wsagent")
|
|
.ok_or_else(|| anyhow::anyhow!("Could not find home directory!"))?;
|
|
let config_dir = project_dirs.config_dir();
|
|
fs::create_dir_all(config_dir).await?;
|
|
let db_path = config_dir.join("agent.db");
|
|
let connection_str = format!("sqlite:{}", db_path.display());
|
|
println!("Connection string: {}", connection_str);
|
|
|
|
let pool = SqlitePool::connect_with(
|
|
SqliteConnectOptions::new()
|
|
.filename(&db_path)
|
|
.create_if_missing(true),
|
|
)
|
|
.await?;
|
|
|
|
sqlx::migrate!("./migrations")
|
|
.run(&pool)
|
|
.await
|
|
.inspect_err(|e| eprintln!("Migration failed! {}", e))?;
|
|
|
|
Ok(Self { pool })
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl ChatRepository for SqliteChatRepository {
|
|
async fn save_message(&self, text: &str, is_user: &bool) -> Result<ChatMessage> {
|
|
let result = sqlx::query_as::<_, ChatMessage>(
|
|
r#"
|
|
INSERT INTO messages (text, is_user)
|
|
VALUES (?, ?)
|
|
RETURNING id, text, is_user
|
|
"#,
|
|
)
|
|
.bind(text)
|
|
.bind(is_user)
|
|
.fetch_one(&self.pool)
|
|
.await
|
|
.inspect_err(|e| println!("sql error: {}", e))?;
|
|
Ok(result)
|
|
}
|
|
|
|
async fn get_latest_messages(&self) -> Result<Vec<ChatMessage>> {
|
|
let rows = sqlx::query(
|
|
r#"
|
|
SELECT * FROM (
|
|
SELECT id, text, is_user
|
|
FROM messages
|
|
ORDER BY id DESC
|
|
LIMIT 10
|
|
) AS subquery ORDER BY id ASC"#,
|
|
)
|
|
.fetch_all(&self.pool)
|
|
.await
|
|
.inspect_err(|e| println!("sql error: {}", e))?;
|
|
|
|
let messages = rows
|
|
.into_iter()
|
|
.map(|row| ChatMessage {
|
|
id: row.get(0),
|
|
text: row.get(1),
|
|
is_user: row.get(2),
|
|
})
|
|
.collect();
|
|
|
|
Ok(messages)
|
|
}
|
|
}
|