feat: storing chat messages and fetching latest messages
This commit is contained in:
@@ -1,11 +1,13 @@
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use directories::ProjectDirs;
|
use directories::ProjectDirs;
|
||||||
|
use shared::ai::ChatMessage as CMessage;
|
||||||
use sqlx::sqlite::SqliteConnectOptions;
|
use sqlx::sqlite::SqliteConnectOptions;
|
||||||
use sqlx::Row;
|
use sqlx::Row;
|
||||||
use sqlx::SqlitePool;
|
use sqlx::SqlitePool;
|
||||||
use tokio::fs;
|
use tokio::fs;
|
||||||
use tonic::async_trait;
|
use tonic::async_trait;
|
||||||
|
|
||||||
|
#[derive(Debug, sqlx::FromRow)]
|
||||||
pub struct ChatMessage {
|
pub struct ChatMessage {
|
||||||
pub id: i64,
|
pub id: i64,
|
||||||
pub text: String,
|
pub text: String,
|
||||||
@@ -14,8 +16,16 @@ pub struct ChatMessage {
|
|||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait ChatRepository {
|
pub trait ChatRepository {
|
||||||
async fn save_message(&self, text: &str, is_user: &bool) -> Result<()>;
|
async fn save_message(&self, text: &str, is_user: &bool) -> Result<ChatMessage>;
|
||||||
async fn get_all_messages(&self) -> Result<Vec<ChatMessage>>;
|
async fn get_latest_messages(&self) -> Result<Vec<ChatMessage>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn message_to_dto(msg: &ChatMessage) -> CMessage {
|
||||||
|
CMessage {
|
||||||
|
id: msg.id,
|
||||||
|
text: msg.text.clone(),
|
||||||
|
is_user: msg.is_user,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct SqliteChatRepository {
|
pub struct SqliteChatRepository {
|
||||||
@@ -40,14 +50,15 @@ impl SqliteChatRepository {
|
|||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
"CREATE TABLE IF NOT EXISTS message (
|
"CREATE TABLE IF NOT EXISTS messages (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
text TEXT NOT NULL,
|
text TEXT NOT NULL,
|
||||||
is_user BOOL NOT NULL
|
is_user BOOL NOT NULL
|
||||||
)",
|
)",
|
||||||
)
|
)
|
||||||
.execute(&pool)
|
.execute(&pool)
|
||||||
.await?;
|
.await
|
||||||
|
.inspect_err(|e| println!("sql error: {}", e))?;
|
||||||
|
|
||||||
Ok(Self { pool })
|
Ok(Self { pool })
|
||||||
}
|
}
|
||||||
@@ -55,19 +66,35 @@ impl SqliteChatRepository {
|
|||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl ChatRepository for SqliteChatRepository {
|
impl ChatRepository for SqliteChatRepository {
|
||||||
async fn save_message(&self, text: &str, is_user: &bool) -> Result<()> {
|
async fn save_message(&self, text: &str, is_user: &bool) -> Result<ChatMessage> {
|
||||||
sqlx::query("INSERT INTO messages (text, is_user) values (?, ?)")
|
let result = sqlx::query_as::<_, ChatMessage>(
|
||||||
.bind(text)
|
r#"
|
||||||
.bind(is_user)
|
INSERT INTO messages (text, is_user)
|
||||||
.execute(&self.pool)
|
VALUES (?, ?)
|
||||||
.await?;
|
RETURNING id, text, is_user
|
||||||
Ok(())
|
"#,
|
||||||
|
)
|
||||||
|
.bind(text)
|
||||||
|
.bind(is_user)
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await
|
||||||
|
.inspect_err(|e| println!("sql error: {}", e))?;
|
||||||
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_all_messages(&self) -> Result<Vec<ChatMessage>> {
|
async fn get_latest_messages(&self) -> Result<Vec<ChatMessage>> {
|
||||||
let rows = sqlx::query("SELECT id, text, is_user FROM messages ORDER BY id DESC LIMIT 10")
|
let rows = sqlx::query(
|
||||||
.fetch_all(&self.pool)
|
r#"
|
||||||
.await?;
|
SELECT * FROM (
|
||||||
|
SELECT id, text, is_user
|
||||||
|
FROM messages
|
||||||
|
ORDER BY id DESC
|
||||||
|
LIMIT 10
|
||||||
|
) AS subquery ORDER BY id ASC"#,
|
||||||
|
)
|
||||||
|
.fetch_all(&self.pool)
|
||||||
|
.await
|
||||||
|
.inspect_err(|e| println!("sql error: {}", e))?;
|
||||||
|
|
||||||
let messages = rows
|
let messages = rows
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
mod chatpersistence;
|
mod chatpersistence;
|
||||||
|
|
||||||
use std::cell::Cell;
|
|
||||||
use std::sync::atomic::AtomicI64;
|
use std::sync::atomic::AtomicI64;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use genai::chat::{ChatMessage, ChatRequest};
|
use genai::chat::{ChatMessage, ChatRequest};
|
||||||
use genai::Client;
|
use genai::Client;
|
||||||
@@ -10,18 +10,22 @@ use shared::ai::{
|
|||||||
ChatHistoryRequest, ChatHistoryResponse, ChatMessage as CMessage, ChatRequest as CRequest,
|
ChatHistoryRequest, ChatHistoryResponse, ChatMessage as CMessage, ChatRequest as CRequest,
|
||||||
ChatResponse as CResponse, PromptRequest, PromptResponse,
|
ChatResponse as CResponse, PromptRequest, PromptResponse,
|
||||||
};
|
};
|
||||||
use tonic::{transport::Server, Request, Response, Status};
|
use tonic::{transport::Server, Code, Request, Response, Status};
|
||||||
|
|
||||||
use chatpersistence::SqliteChatRepository;
|
use chatpersistence::SqliteChatRepository;
|
||||||
|
|
||||||
|
use crate::chatpersistence::{message_to_dto, ChatRepository};
|
||||||
|
|
||||||
pub struct DaemonServer {
|
pub struct DaemonServer {
|
||||||
message_counter: AtomicI64,
|
message_counter: AtomicI64,
|
||||||
|
repo: Arc<dyn ChatRepository + Send + Sync>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for DaemonServer {
|
impl DaemonServer {
|
||||||
fn default() -> Self {
|
pub fn new(repo: Arc<dyn ChatRepository + Send + Sync>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
message_counter: AtomicI64::new(0),
|
message_counter: AtomicI64::new(0),
|
||||||
|
repo,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -46,36 +50,41 @@ impl AiDaemon for DaemonServer {
|
|||||||
|
|
||||||
async fn chat(&self, request: Request<CRequest>) -> Result<Response<CResponse>, Status> {
|
async fn chat(&self, request: Request<CRequest>) -> Result<Response<CResponse>, Status> {
|
||||||
let r = request.into_inner();
|
let r = request.into_inner();
|
||||||
println!("<<<: {}", r.text());
|
let user_message = message_to_dto(
|
||||||
|
&self
|
||||||
|
.repo
|
||||||
|
.save_message(r.text(), &true)
|
||||||
|
.await
|
||||||
|
.map_err(|e| Status::new(Code::Internal, e.to_string()))?,
|
||||||
|
);
|
||||||
|
let response_text = format!("Pong: {}", r.text());
|
||||||
|
let ai_message = message_to_dto(
|
||||||
|
&self
|
||||||
|
.repo
|
||||||
|
.save_message(response_text.as_str(), &false)
|
||||||
|
.await
|
||||||
|
.map_err(|e| Status::new(Code::Internal, e.to_string()))?,
|
||||||
|
);
|
||||||
let response = CResponse {
|
let response = CResponse {
|
||||||
chat_id: 1,
|
chat_id: 1,
|
||||||
messages: vec![
|
messages: vec![user_message, ai_message],
|
||||||
CMessage {
|
|
||||||
id: self
|
|
||||||
.message_counter
|
|
||||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
|
|
||||||
text: r.text().to_string(),
|
|
||||||
is_user: true,
|
|
||||||
},
|
|
||||||
CMessage {
|
|
||||||
id: self
|
|
||||||
.message_counter
|
|
||||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
|
|
||||||
text: format!("Pong: {}", r.text()),
|
|
||||||
is_user: false,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
};
|
};
|
||||||
return Ok(Response::new(response));
|
return Ok(Response::new(response));
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn chat_history(
|
async fn chat_history(
|
||||||
&self,
|
&self,
|
||||||
request: Request<ChatHistoryRequest>,
|
_: Request<ChatHistoryRequest>,
|
||||||
) -> Result<Response<ChatHistoryResponse>, Status> {
|
) -> Result<Response<ChatHistoryResponse>, Status> {
|
||||||
|
let messages = self
|
||||||
|
.repo
|
||||||
|
.get_latest_messages()
|
||||||
|
.await
|
||||||
|
.map_err(|e| Status::new(Code::Internal, e.to_string()))?;
|
||||||
|
|
||||||
let response = ChatHistoryResponse {
|
let response = ChatHistoryResponse {
|
||||||
chat_id: 1,
|
chat_id: 1,
|
||||||
history: vec![],
|
history: messages.iter().map(|m| message_to_dto(m)).collect(),
|
||||||
};
|
};
|
||||||
Ok(Response::new(response))
|
Ok(Response::new(response))
|
||||||
}
|
}
|
||||||
@@ -101,7 +110,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
|
|
||||||
let addr_s = "[::1]:50051";
|
let addr_s = "[::1]:50051";
|
||||||
let addr = addr_s.parse().unwrap();
|
let addr = addr_s.parse().unwrap();
|
||||||
let daemon = DaemonServer::default();
|
let daemon = DaemonServer::new(Arc::new(chat_repo));
|
||||||
let reflection_service = tonic_reflection::server::Builder::configure()
|
let reflection_service = tonic_reflection::server::Builder::configure()
|
||||||
.register_encoded_file_descriptor_set(shared::ai::FILE_DESCRIPTOR_SET)
|
.register_encoded_file_descriptor_set(shared::ai::FILE_DESCRIPTOR_SET)
|
||||||
.build_v1()?;
|
.build_v1()?;
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ message ChatHistoryRequest {
|
|||||||
|
|
||||||
message ChatHistoryResponse {
|
message ChatHistoryResponse {
|
||||||
int64 chat_id = 1;
|
int64 chat_id = 1;
|
||||||
repeated ChatResponse history = 10;
|
repeated ChatMessage history = 10;
|
||||||
}
|
}
|
||||||
|
|
||||||
message PromptRequest {
|
message PromptRequest {
|
||||||
|
|||||||
@@ -2,10 +2,13 @@
|
|||||||
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]
|
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]
|
||||||
|
|
||||||
use feshared::chatmessage::{Message, MessageHistory};
|
use feshared::chatmessage::{Message, MessageHistory};
|
||||||
use shared::ai::{ai_daemon_client::AiDaemonClient, ChatRequest, PromptRequest};
|
use shared::ai::{
|
||||||
|
ai_daemon_client::AiDaemonClient, ChatHistoryRequest, ChatRequest, PromptRequest,
|
||||||
|
};
|
||||||
use tauri::{Emitter, Manager, State};
|
use tauri::{Emitter, Manager, State};
|
||||||
use tauri_plugin_global_shortcut::{Code, GlobalShortcutExt, Modifiers, Shortcut, ShortcutState};
|
use tauri_plugin_global_shortcut::{Code, GlobalShortcutExt, Modifiers, Shortcut, ShortcutState};
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
|
use tonic::{client, Response};
|
||||||
|
|
||||||
struct AppState {
|
struct AppState {
|
||||||
grpc_client: Mutex<AiDaemonClient<tonic::transport::Channel>>,
|
grpc_client: Mutex<AiDaemonClient<tonic::transport::Channel>>,
|
||||||
@@ -31,17 +34,6 @@ fn toggle_popup(app_handle: tauri::AppHandle) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tauri::command]
|
|
||||||
async fn prompt_llm(state: State<'_, AppState>, prompt: String) -> Result<String, String> {
|
|
||||||
println!(">>>> {}", prompt);
|
|
||||||
let mut client = state.grpc_client.lock().await;
|
|
||||||
let request = tonic::Request::new(PromptRequest { prompt });
|
|
||||||
match client.prompt(request).await {
|
|
||||||
Ok(response) => Ok(response.into_inner().response),
|
|
||||||
Err(e) => Err(format!("gRPC error: {}", e)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tauri::command]
|
#[tauri::command]
|
||||||
async fn chat(
|
async fn chat(
|
||||||
state: State<'_, AppState>,
|
state: State<'_, AppState>,
|
||||||
@@ -72,7 +64,10 @@ async fn chat(
|
|||||||
})
|
})
|
||||||
.collect())
|
.collect())
|
||||||
}
|
}
|
||||||
Err(e) => Err(format!("gRPC error: {}", e)),
|
Err(e) => {
|
||||||
|
println!("gRPC error: {}", e);
|
||||||
|
Err(format!("gRPC error: {}", e))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -81,25 +76,28 @@ async fn chat_history(
|
|||||||
state: State<'_, AppState>,
|
state: State<'_, AppState>,
|
||||||
chat_id: Option<i64>,
|
chat_id: Option<i64>,
|
||||||
) -> Result<MessageHistory, String> {
|
) -> Result<MessageHistory, String> {
|
||||||
let history = MessageHistory {
|
let mut client = state.grpc_client.lock().await;
|
||||||
chat_id: match chat_id {
|
let result = client
|
||||||
Some(_) => chat_id,
|
.chat_history(ChatHistoryRequest { chat_id: None })
|
||||||
None => Some(-1),
|
.await;
|
||||||
},
|
match result {
|
||||||
history: vec![
|
Ok(response) => {
|
||||||
Message {
|
let r = response.into_inner();
|
||||||
id: 1,
|
Ok(MessageHistory {
|
||||||
text: String::from("asd"),
|
chat_id: None,
|
||||||
is_user: false,
|
history: r
|
||||||
},
|
.history
|
||||||
Message {
|
.iter()
|
||||||
id: 2,
|
.map(|m| Message {
|
||||||
text: String::from("yeah!!!!"),
|
id: m.id,
|
||||||
is_user: true,
|
is_user: m.is_user,
|
||||||
},
|
text: m.text.clone(),
|
||||||
],
|
})
|
||||||
};
|
.collect(),
|
||||||
Ok(history)
|
})
|
||||||
|
}
|
||||||
|
Err(e) => Err(format!("gRPC error: {e}")),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
@@ -117,12 +115,7 @@ async fn main() {
|
|||||||
current_chat: Mutex::new(None),
|
current_chat: Mutex::new(None),
|
||||||
})
|
})
|
||||||
.plugin(tauri_plugin_global_shortcut::Builder::new().build())
|
.plugin(tauri_plugin_global_shortcut::Builder::new().build())
|
||||||
.invoke_handler(tauri::generate_handler![
|
.invoke_handler(tauri::generate_handler![toggle_popup, chat_history, chat,])
|
||||||
toggle_popup,
|
|
||||||
prompt_llm,
|
|
||||||
chat_history,
|
|
||||||
chat,
|
|
||||||
])
|
|
||||||
.setup(|app| {
|
.setup(|app| {
|
||||||
/* Auto-hide popup when focus is lost
|
/* Auto-hide popup when focus is lost
|
||||||
if let Some(window) = app.get_webview_window("popup") {
|
if let Some(window) = app.get_webview_window("popup") {
|
||||||
|
|||||||
Reference in New Issue
Block a user