feat: allow starting a new chat
This commit is contained in:
@@ -19,11 +19,11 @@ pub trait ChatRepository {
|
||||
&self,
|
||||
text: &str,
|
||||
is_user: &bool,
|
||||
chat_id: &i32,
|
||||
chat_id: &i64,
|
||||
) -> Result<ChatMessageData>;
|
||||
async fn get_latest_messages(&self, chat_id: &i32, count: &i32)
|
||||
async fn get_latest_messages(&self, chat_id: &i64, count: &i64)
|
||||
-> Result<Vec<ChatMessageData>>;
|
||||
async fn get_chat_ids(&self) -> Result<Box<[i32]>>;
|
||||
async fn get_chat_ids(&self) -> Result<Box<[i64]>>;
|
||||
}
|
||||
|
||||
pub struct SqliteChatRepository {
|
||||
@@ -62,7 +62,7 @@ impl ChatRepository for SqliteChatRepository {
|
||||
&self,
|
||||
text: &str,
|
||||
is_user: &bool,
|
||||
chat_id: &i32,
|
||||
chat_id: &i64,
|
||||
) -> Result<ChatMessageData> {
|
||||
let result = sqlx::query_as::<_, ChatMessageData>(
|
||||
r#"
|
||||
@@ -82,8 +82,8 @@ impl ChatRepository for SqliteChatRepository {
|
||||
|
||||
async fn get_latest_messages(
|
||||
&self,
|
||||
chat_id: &i32,
|
||||
count: &i32,
|
||||
chat_id: &i64,
|
||||
count: &i64,
|
||||
) -> Result<Vec<ChatMessageData>> {
|
||||
// From all chat ids get the latest id.
|
||||
let rows = sqlx::query(
|
||||
@@ -116,15 +116,15 @@ impl ChatRepository for SqliteChatRepository {
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
async fn get_chat_ids(&self) -> Result<Box<[i32]>> {
|
||||
async fn get_chat_ids(&self) -> Result<Box<[i64]>> {
|
||||
let rows = sqlx::query("SELECT DISTINCT(chat_id) FROM messages ORDER BY chat_id DESC")
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.inspect_err(|e| println!("sql error: {}", e))?;
|
||||
let ids: Vec<i32> = rows
|
||||
let ids: Vec<i64> = rows
|
||||
.into_iter()
|
||||
.map(|row| {
|
||||
let i: i32 = row.get(0);
|
||||
let i: i64 = row.get(0);
|
||||
i
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::chatpersistence::{ChatMessageData, ChatRepository};
|
||||
use anyhow::Result;
|
||||
use genai::chat::{ChatMessage, ChatRequest};
|
||||
use genai::Client;
|
||||
use shared::ai::ai_daemon_server::AiDaemon;
|
||||
use shared::ai::ai_service_server::AiService;
|
||||
use shared::ai::{
|
||||
ChatHistoryRequest, ChatHistoryResponse, ChatMessage as CMessage, ChatRequest as CRequest,
|
||||
ChatResponse as CResponse, DaemonStatusRequest, DaemonStatusResponse,
|
||||
@@ -25,10 +25,10 @@ impl DaemonServer {
|
||||
}
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl AiDaemon for DaemonServer {
|
||||
impl AiService for DaemonServer {
|
||||
async fn chat(&self, request: Request<CRequest>) -> Result<Response<CResponse>, Status> {
|
||||
let r = request.into_inner();
|
||||
let chat_id = get_chat_id(self.repo.clone(), r.chat_id)
|
||||
let chat_id = id_or_new(self.repo.clone(), r.chat_id)
|
||||
.await
|
||||
.map_err(|e| Status::new(Code::Internal, e.to_string()))?;
|
||||
let mut messages = gather_history(self.repo.clone(), &chat_id)
|
||||
@@ -45,7 +45,7 @@ impl AiDaemon for DaemonServer {
|
||||
let user_message = message_to_dto(
|
||||
&self
|
||||
.repo
|
||||
.save_message(r.text(), &true, &0)
|
||||
.save_message(r.text(), &true, &chat_id)
|
||||
.await
|
||||
.map_err(|e| Status::new(Code::Internal, e.to_string()))?,
|
||||
);
|
||||
@@ -60,12 +60,12 @@ impl AiDaemon for DaemonServer {
|
||||
let ai_message = message_to_dto(
|
||||
&self
|
||||
.repo
|
||||
.save_message(response_text, &false, &0)
|
||||
.save_message(response_text, &false, &chat_id)
|
||||
.await
|
||||
.map_err(|e| Status::new(Code::Internal, e.to_string()))?,
|
||||
);
|
||||
let response = CResponse {
|
||||
chat_id: 1,
|
||||
chat_id: ai_message.chat_id,
|
||||
messages: vec![user_message, ai_message],
|
||||
};
|
||||
return Ok(Response::new(response));
|
||||
@@ -75,7 +75,7 @@ impl AiDaemon for DaemonServer {
|
||||
&self,
|
||||
request: Request<ChatHistoryRequest>,
|
||||
) -> Result<Response<ChatHistoryResponse>, Status> {
|
||||
let chat_id = get_chat_id(self.repo.clone(), request.into_inner().chat_id)
|
||||
let chat_id = get_latest_chat_id(self.repo.clone(), request.into_inner().chat_id)
|
||||
.await
|
||||
.map_err(|e| Status::new(Code::Internal, e.to_string()))?;
|
||||
let messages = self
|
||||
@@ -85,7 +85,7 @@ impl AiDaemon for DaemonServer {
|
||||
.map_err(|e| Status::new(Code::Internal, e.to_string()))?;
|
||||
|
||||
let response = ChatHistoryResponse {
|
||||
chat_id: 1,
|
||||
chat_id: chat_id,
|
||||
history: messages.iter().map(|m| message_to_dto(m)).collect(),
|
||||
};
|
||||
Ok(Response::new(response))
|
||||
@@ -107,6 +107,7 @@ impl AiDaemon for DaemonServer {
|
||||
pub fn message_to_dto(msg: &ChatMessageData) -> CMessage {
|
||||
CMessage {
|
||||
id: msg.id,
|
||||
chat_id: msg.chat_id,
|
||||
text: msg.text.clone(),
|
||||
is_user: msg.is_user,
|
||||
}
|
||||
@@ -114,7 +115,7 @@ pub fn message_to_dto(msg: &ChatMessageData) -> CMessage {
|
||||
|
||||
async fn gather_history(
|
||||
repo: Arc<dyn ChatRepository + Send + Sync>,
|
||||
chat_id: &i32,
|
||||
chat_id: &i64,
|
||||
) -> Result<Vec<ChatMessage>> {
|
||||
let messages = repo.get_latest_messages(chat_id, &10).await?;
|
||||
Ok(messages
|
||||
@@ -126,12 +127,22 @@ async fn gather_history(
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn get_chat_id(
|
||||
async fn get_latest_chat_id(
|
||||
repo: Arc<dyn ChatRepository + Send + Sync>,
|
||||
chat_id: Option<i64>,
|
||||
) -> Result<i32> {
|
||||
) -> Result<i64> {
|
||||
Ok(match chat_id {
|
||||
Some(i) => i as i32,
|
||||
Some(i) => i,
|
||||
None => repo.get_chat_ids().await?.get(0).copied().unwrap_or(0),
|
||||
})
|
||||
}
|
||||
|
||||
async fn id_or_new(
|
||||
repo: Arc<dyn ChatRepository + Send + Sync>,
|
||||
chat_id: Option<i64>,
|
||||
) -> Result<i64> {
|
||||
Ok(match chat_id {
|
||||
Some(i) => i,
|
||||
None => repo.get_chat_ids().await?.get(0).copied().unwrap_or(0) + 1,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ mod daemongrpc;
|
||||
use std::sync::Arc;
|
||||
|
||||
use genai::Client;
|
||||
use shared::ai::ai_daemon_server::AiDaemonServer;
|
||||
use shared::ai::ai_service_server::AiServiceServer;
|
||||
use tonic::transport::Server;
|
||||
|
||||
use chatpersistence::SqliteChatRepository;
|
||||
@@ -24,7 +24,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.build_v1()?;
|
||||
println!("Started daemon at {}", addr_s);
|
||||
Server::builder()
|
||||
.add_service(AiDaemonServer::new(daemon))
|
||||
.add_service(AiServiceServer::new(daemon))
|
||||
.add_service(reflection_service)
|
||||
.serve(addr)
|
||||
.await?;
|
||||
|
||||
@@ -16,6 +16,7 @@ pub mod chatmessage {
|
||||
|
||||
pub enum TauriCommand {
|
||||
Chat,
|
||||
SetChatId,
|
||||
ChatHistory,
|
||||
DaemonState,
|
||||
ToggleDarkMode,
|
||||
@@ -27,6 +28,7 @@ pub mod chatmessage {
|
||||
match self {
|
||||
TauriCommand::TogglePopup => "toggle_popup",
|
||||
TauriCommand::Chat => "chat",
|
||||
TauriCommand::SetChatId => "set_chat_id",
|
||||
TauriCommand::ChatHistory => "chat_history",
|
||||
TauriCommand::DaemonState => "daemon_state",
|
||||
TauriCommand::ToggleDarkMode => "toggle_dark_mode",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
syntax = "proto3";
|
||||
package ai_daemon;
|
||||
|
||||
service AiDaemon {
|
||||
service AiService {
|
||||
rpc Chat(ChatRequest) returns (ChatResponse);
|
||||
rpc ChatHistory(ChatHistoryRequest) returns (ChatHistoryResponse);
|
||||
rpc DaemonStatus(DaemonStatusRequest) returns (DaemonStatusResponse);
|
||||
@@ -9,6 +9,7 @@ service AiDaemon {
|
||||
|
||||
message ChatMessage {
|
||||
int64 id = 1;
|
||||
int64 chat_id = 2;
|
||||
string text = 10;
|
||||
bool is_user = 20;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user