149 lines
4.5 KiB
Rust
149 lines
4.5 KiB
Rust
use crate::chatpersistence::{ChatMessageData, ChatRepository};
|
|
use anyhow::Result;
|
|
use genai::chat::{ChatMessage, ChatRequest};
|
|
use genai::Client;
|
|
use shared::ai::ai_service_server::AiService;
|
|
use shared::ai::{
|
|
ChatHistoryRequest, ChatHistoryResponse, ChatMessage as CMessage, ChatRequest as CRequest,
|
|
ChatResponse as CResponse, DaemonStatusRequest, DaemonStatusResponse,
|
|
};
|
|
use std::sync::Arc;
|
|
use tonic::{Code, Request, Response, Status};
|
|
|
|
pub struct DaemonServer {
|
|
repo: Arc<dyn ChatRepository + Send + Sync>,
|
|
client: Client,
|
|
}
|
|
|
|
impl DaemonServer {
|
|
pub fn new(repo: Arc<dyn ChatRepository + Send + Sync>, client: Client) -> Self {
|
|
Self {
|
|
repo: repo,
|
|
client: client,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[tonic::async_trait]
|
|
impl AiService for DaemonServer {
|
|
async fn chat(&self, request: Request<CRequest>) -> Result<Response<CResponse>, Status> {
|
|
let r = request.into_inner();
|
|
let chat_id = id_or_new(self.repo.clone(), r.chat_id)
|
|
.await
|
|
.map_err(|e| Status::new(Code::Internal, e.to_string()))?;
|
|
let mut messages = gather_history(self.repo.clone(), &chat_id)
|
|
.await
|
|
.map_err(|e| Status::new(Code::Internal, e.to_string()))?;
|
|
messages.push(ChatMessage::user(r.text()));
|
|
let model = "llama3.2:latest";
|
|
let response = self
|
|
.client
|
|
.exec_chat(model, ChatRequest::new(messages), None)
|
|
.await
|
|
.map_err(|e| Status::new(Code::Internal, e.to_string()))?;
|
|
|
|
let user_message = message_to_dto(
|
|
&self
|
|
.repo
|
|
.save_message(r.text(), &true, &chat_id)
|
|
.await
|
|
.map_err(|e| Status::new(Code::Internal, e.to_string()))?,
|
|
);
|
|
let response_text = match response.first_text() {
|
|
Some(t) => t,
|
|
None => "[No response from AI]",
|
|
};
|
|
|
|
println!("User: {}", r.text());
|
|
println!("AI: {}", response_text);
|
|
|
|
let ai_message = message_to_dto(
|
|
&self
|
|
.repo
|
|
.save_message(response_text, &false, &chat_id)
|
|
.await
|
|
.map_err(|e| Status::new(Code::Internal, e.to_string()))?,
|
|
);
|
|
let response = CResponse {
|
|
chat_id: ai_message.chat_id,
|
|
messages: vec![user_message, ai_message],
|
|
};
|
|
return Ok(Response::new(response));
|
|
}
|
|
|
|
async fn chat_history(
|
|
&self,
|
|
request: Request<ChatHistoryRequest>,
|
|
) -> Result<Response<ChatHistoryResponse>, Status> {
|
|
let chat_id = get_latest_chat_id(self.repo.clone(), request.into_inner().chat_id)
|
|
.await
|
|
.map_err(|e| Status::new(Code::Internal, e.to_string()))?;
|
|
let messages = self
|
|
.repo
|
|
.get_latest_messages(&chat_id, &20)
|
|
.await
|
|
.map_err(|e| Status::new(Code::Internal, e.to_string()))?;
|
|
|
|
let response = ChatHistoryResponse {
|
|
chat_id: chat_id,
|
|
history: messages.iter().map(|m| message_to_dto(m)).collect(),
|
|
};
|
|
Ok(Response::new(response))
|
|
}
|
|
|
|
async fn daemon_status(
|
|
&self,
|
|
_: Request<DaemonStatusRequest>,
|
|
) -> Result<Response<DaemonStatusResponse>, Status> {
|
|
let status = DaemonStatusResponse {
|
|
is_ok: true,
|
|
message: None,
|
|
error: None,
|
|
};
|
|
Ok(Response::new(status))
|
|
}
|
|
}
|
|
|
|
pub fn message_to_dto(msg: &ChatMessageData) -> CMessage {
|
|
CMessage {
|
|
id: msg.id,
|
|
chat_id: msg.chat_id,
|
|
text: msg.text.clone(),
|
|
is_user: msg.is_user,
|
|
}
|
|
}
|
|
|
|
async fn gather_history(
|
|
repo: Arc<dyn ChatRepository + Send + Sync>,
|
|
chat_id: &i64,
|
|
) -> Result<Vec<ChatMessage>> {
|
|
let messages = repo.get_latest_messages(chat_id, &10).await?;
|
|
Ok(messages
|
|
.iter()
|
|
.map(|m| match m.is_user {
|
|
true => ChatMessage::assistant(m.text.clone()),
|
|
false => ChatMessage::user(m.text.clone()),
|
|
})
|
|
.collect())
|
|
}
|
|
|
|
async fn get_latest_chat_id(
|
|
repo: Arc<dyn ChatRepository + Send + Sync>,
|
|
chat_id: Option<i64>,
|
|
) -> Result<i64> {
|
|
Ok(match chat_id {
|
|
Some(i) => i,
|
|
None => repo.get_chat_ids().await?.get(0).copied().unwrap_or(0),
|
|
})
|
|
}
|
|
|
|
async fn id_or_new(
|
|
repo: Arc<dyn ChatRepository + Send + Sync>,
|
|
chat_id: Option<i64>,
|
|
) -> Result<i64> {
|
|
Ok(match chat_id {
|
|
Some(i) => i,
|
|
None => repo.get_chat_ids().await?.get(0).copied().unwrap_or(0) + 1,
|
|
})
|
|
}
|