WIP on persisting chat history on daemon
This commit is contained in:
@@ -4,8 +4,11 @@ version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.101"
|
||||
directories = "6.0.0"
|
||||
genai = "0.5.3"
|
||||
shared = { path = "../shared" }
|
||||
sqlx = { version = "0.8.6", features = ["runtime-tokio", "sqlite", "macros"] }
|
||||
tokio = { version = "1.49.0", features = ["full"] }
|
||||
tonic = "0.14.2"
|
||||
tonic-reflection = "0.14.2"
|
||||
|
||||
83
crates/daemon/src/chatpersistence.rs
Normal file
83
crates/daemon/src/chatpersistence.rs
Normal file
@@ -0,0 +1,83 @@
|
||||
use anyhow::Result;
|
||||
use directories::ProjectDirs;
|
||||
use sqlx::sqlite::SqliteConnectOptions;
|
||||
use sqlx::Row;
|
||||
use sqlx::SqlitePool;
|
||||
use tokio::fs;
|
||||
use tonic::async_trait;
|
||||
|
||||
pub struct ChatMessage {
|
||||
pub id: i64,
|
||||
pub text: String,
|
||||
pub is_user: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait ChatRepository {
|
||||
async fn save_message(&self, text: &str, is_user: &bool) -> Result<()>;
|
||||
async fn get_all_messages(&self) -> Result<Vec<ChatMessage>>;
|
||||
}
|
||||
|
||||
pub struct SqliteChatRepository {
|
||||
pool: SqlitePool,
|
||||
}
|
||||
|
||||
impl SqliteChatRepository {
|
||||
pub async fn new() -> Result<Self> {
|
||||
let project_dirs = ProjectDirs::from("com", "jarno", "wsagent")
|
||||
.ok_or_else(|| anyhow::anyhow!("Could not find home directory!"))?;
|
||||
let config_dir = project_dirs.config_dir();
|
||||
fs::create_dir_all(config_dir).await?;
|
||||
let db_path = config_dir.join("agent.db");
|
||||
let connection_str = format!("sqlite:{}", db_path.display());
|
||||
println!("Connection string: {}", connection_str);
|
||||
|
||||
let pool = SqlitePool::connect_with(
|
||||
SqliteConnectOptions::new()
|
||||
.filename(&db_path)
|
||||
.create_if_missing(true),
|
||||
)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
"CREATE TABLE IF NOT EXISTS message (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
text TEXT NOT NULL,
|
||||
is_user BOOL NOT NULL
|
||||
)",
|
||||
)
|
||||
.execute(&pool)
|
||||
.await?;
|
||||
|
||||
Ok(Self { pool })
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChatRepository for SqliteChatRepository {
|
||||
async fn save_message(&self, text: &str, is_user: &bool) -> Result<()> {
|
||||
sqlx::query("INSERT INTO messages (text, is_user) values (?, ?)")
|
||||
.bind(text)
|
||||
.bind(is_user)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_all_messages(&self) -> Result<Vec<ChatMessage>> {
|
||||
let rows = sqlx::query("SELECT id, text, is_user FROM messages ORDER BY id DESC LIMIT 10")
|
||||
.fetch_all(&self.pool)
|
||||
.await?;
|
||||
|
||||
let messages = rows
|
||||
.into_iter()
|
||||
.map(|row| ChatMessage {
|
||||
id: row.get(0),
|
||||
text: row.get(1),
|
||||
is_user: row.get(2),
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(messages)
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,16 @@
|
||||
use genai::chat::{ChatMessage, ChatRequest};
|
||||
mod chatpersistence;
|
||||
|
||||
use genai::chat::{ChatMessage, ChatRequest, ChatResponse};
|
||||
use genai::Client;
|
||||
use shared::ai::ai_daemon_server::{AiDaemon, AiDaemonServer};
|
||||
use shared::ai::{PromptRequest, PromptResponse};
|
||||
use shared::ai::{
|
||||
ChatHistoryRequest, ChatHistoryResponse, ChatRequest as CRequest, ChatResponse as CResponse,
|
||||
PromptRequest, PromptResponse,
|
||||
};
|
||||
use tonic::{transport::Server, Request, Response, Status};
|
||||
|
||||
use chatpersistence::SqliteChatRepository;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct DaemonServer {}
|
||||
|
||||
@@ -24,6 +31,27 @@ impl AiDaemon for DaemonServer {
|
||||
let reply = PromptResponse { response: response };
|
||||
Ok(Response::new(reply))
|
||||
}
|
||||
|
||||
async fn chat(&self, request: Request<CRequest>) -> Result<Response<CResponse>, Status> {
|
||||
let response = CResponse {
|
||||
id: 1,
|
||||
chat_id: 1,
|
||||
text: "asdf".to_string(),
|
||||
is_user: false,
|
||||
};
|
||||
return Ok(Response::new(response));
|
||||
}
|
||||
|
||||
async fn chat_history(
|
||||
&self,
|
||||
request: Request<ChatHistoryRequest>,
|
||||
) -> Result<Response<ChatHistoryResponse>, Status> {
|
||||
let response = ChatHistoryResponse {
|
||||
chat_id: 1,
|
||||
history: vec![],
|
||||
};
|
||||
Ok(Response::new(response))
|
||||
}
|
||||
}
|
||||
|
||||
async fn prompt_ollama(
|
||||
@@ -42,6 +70,8 @@ async fn prompt_ollama(
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let chat_repo = SqliteChatRepository::new().await?;
|
||||
|
||||
let addr_s = "[::1]:50051";
|
||||
let addr = addr_s.parse().unwrap();
|
||||
let daemon = DaemonServer::default();
|
||||
|
||||
Reference in New Issue
Block a user