feat: moved components to their own module, use ollama for the chat.
This commit is contained in:
@@ -7,7 +7,7 @@ use tokio::fs;
|
||||
use tonic::async_trait;
|
||||
|
||||
#[derive(Debug, sqlx::FromRow)]
|
||||
pub struct ChatMessage {
|
||||
pub struct ChatMessageData {
|
||||
pub id: i64,
|
||||
pub text: String,
|
||||
pub is_user: bool,
|
||||
@@ -15,8 +15,8 @@ pub struct ChatMessage {
|
||||
|
||||
#[async_trait]
|
||||
pub trait ChatRepository {
|
||||
async fn save_message(&self, text: &str, is_user: &bool) -> Result<ChatMessage>;
|
||||
async fn get_latest_messages(&self) -> Result<Vec<ChatMessage>>;
|
||||
async fn save_message(&self, text: &str, is_user: &bool) -> Result<ChatMessageData>;
|
||||
async fn get_latest_messages(&self) -> Result<Vec<ChatMessageData>>;
|
||||
}
|
||||
|
||||
pub struct SqliteChatRepository {
|
||||
@@ -51,8 +51,8 @@ impl SqliteChatRepository {
|
||||
|
||||
#[async_trait]
|
||||
impl ChatRepository for SqliteChatRepository {
|
||||
async fn save_message(&self, text: &str, is_user: &bool) -> Result<ChatMessage> {
|
||||
let result = sqlx::query_as::<_, ChatMessage>(
|
||||
async fn save_message(&self, text: &str, is_user: &bool) -> Result<ChatMessageData> {
|
||||
let result = sqlx::query_as::<_, ChatMessageData>(
|
||||
r#"
|
||||
INSERT INTO messages (text, is_user)
|
||||
VALUES (?, ?)
|
||||
@@ -67,7 +67,7 @@ impl ChatRepository for SqliteChatRepository {
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn get_latest_messages(&self) -> Result<Vec<ChatMessage>> {
|
||||
async fn get_latest_messages(&self) -> Result<Vec<ChatMessageData>> {
|
||||
let rows = sqlx::query(
|
||||
r#"
|
||||
SELECT * FROM (
|
||||
@@ -83,7 +83,7 @@ impl ChatRepository for SqliteChatRepository {
|
||||
|
||||
let messages = rows
|
||||
.into_iter()
|
||||
.map(|row| ChatMessage {
|
||||
.map(|row| ChatMessageData {
|
||||
id: row.get(0),
|
||||
text: row.get(1),
|
||||
is_user: row.get(2),
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
use crate::chatpersistence::{ChatMessage, ChatRepository};
|
||||
use crate::chatpersistence::{ChatMessageData, ChatRepository};
|
||||
use anyhow::Result;
|
||||
use genai::chat::{ChatMessage, ChatRequest};
|
||||
use genai::Client;
|
||||
use shared::ai::ai_daemon_server::AiDaemon;
|
||||
use shared::ai::{
|
||||
ChatHistoryRequest, ChatHistoryResponse, ChatMessage as CMessage, ChatRequest as CRequest,
|
||||
@@ -9,11 +12,15 @@ use tonic::{Code, Request, Response, Status};
|
||||
|
||||
pub struct DaemonServer {
|
||||
repo: Arc<dyn ChatRepository + Send + Sync>,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
impl DaemonServer {
|
||||
pub fn new(repo: Arc<dyn ChatRepository + Send + Sync>) -> Self {
|
||||
Self { repo }
|
||||
pub fn new(repo: Arc<dyn ChatRepository + Send + Sync>, client: Client) -> Self {
|
||||
Self {
|
||||
repo: repo,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,6 +28,17 @@ impl DaemonServer {
|
||||
impl AiDaemon for DaemonServer {
|
||||
async fn chat(&self, request: Request<CRequest>) -> Result<Response<CResponse>, Status> {
|
||||
let r = request.into_inner();
|
||||
let mut messages = gather_history(self.repo.clone())
|
||||
.await
|
||||
.map_err(|e| Status::new(Code::Internal, e.to_string()))?;
|
||||
messages.push(ChatMessage::user(r.text()));
|
||||
let model = "llama3.2:latest";
|
||||
let response = self
|
||||
.client
|
||||
.exec_chat(model, ChatRequest::new(messages), None)
|
||||
.await
|
||||
.map_err(|e| Status::new(Code::Internal, e.to_string()))?;
|
||||
|
||||
let user_message = message_to_dto(
|
||||
&self
|
||||
.repo
|
||||
@@ -28,11 +46,18 @@ impl AiDaemon for DaemonServer {
|
||||
.await
|
||||
.map_err(|e| Status::new(Code::Internal, e.to_string()))?,
|
||||
);
|
||||
let response_text = format!("Pong: {}", r.text());
|
||||
let response_text = match response.first_text() {
|
||||
Some(t) => t,
|
||||
None => "[No response from AI]",
|
||||
};
|
||||
|
||||
println!("User: {}", r.text());
|
||||
println!("AI: {}", response_text.clone());
|
||||
|
||||
let ai_message = message_to_dto(
|
||||
&self
|
||||
.repo
|
||||
.save_message(response_text.as_str(), &false)
|
||||
.save_message(response_text, &false)
|
||||
.await
|
||||
.map_err(|e| Status::new(Code::Internal, e.to_string()))?,
|
||||
);
|
||||
@@ -73,10 +98,21 @@ impl AiDaemon for DaemonServer {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn message_to_dto(msg: &ChatMessage) -> CMessage {
|
||||
pub fn message_to_dto(msg: &ChatMessageData) -> CMessage {
|
||||
CMessage {
|
||||
id: msg.id,
|
||||
text: msg.text.clone(),
|
||||
is_user: msg.is_user,
|
||||
}
|
||||
}
|
||||
|
||||
async fn gather_history(repo: Arc<dyn ChatRepository + Send + Sync>) -> Result<Vec<ChatMessage>> {
|
||||
let messages = repo.get_latest_messages().await?;
|
||||
Ok(messages
|
||||
.iter()
|
||||
.map(|m| match m.is_user {
|
||||
true => ChatMessage::assistant(m.text.clone()),
|
||||
false => ChatMessage::user(m.text.clone()),
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
@@ -29,9 +29,11 @@ async fn prompt_ollama(
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let chat_repo = SqliteChatRepository::new().await?;
|
||||
|
||||
let client = Client::default();
|
||||
|
||||
let addr_s = "[::1]:50051";
|
||||
let addr = addr_s.parse().unwrap();
|
||||
let daemon = DaemonServer::new(Arc::new(chat_repo));
|
||||
let daemon = DaemonServer::new(Arc::new(chat_repo), client);
|
||||
let reflection_service = tonic_reflection::server::Builder::configure()
|
||||
.register_encoded_file_descriptor_set(shared::ai::FILE_DESCRIPTOR_SET)
|
||||
.build_v1()?;
|
||||
|
||||
Reference in New Issue
Block a user