WIP split implementation based on backend

This commit is contained in:
2026-01-18 21:16:05 +02:00
parent 63efca03f8
commit 21558b88cd
6 changed files with 132 additions and 70 deletions

View File

@@ -5,6 +5,7 @@ import ollama
from tools import available_functions
from ollamaagent import OllamaAgent
from geminiagent import GeminiAgent
class Backend(Enum):
@@ -17,7 +18,7 @@ class Backend(Enum):
logger = get_logger(__name__)
tools = [available_functions["fetch_web_page"].spec]
tools = [available_functions["fetch_web_page"]]
def run_agent(prompt: str, backend: Backend, model: str):
@@ -25,16 +26,5 @@ def run_agent(prompt: str, backend: Backend, model: str):
agent = OllamaAgent(model=model, tools=tools)
return agent.prompt(message=prompt)
else:
raise NotImplementedError
def execute_function(tool):
function_name = tool["function"]["name"]
args = tool["function"]["arguments"]
logger.info(f"Agent is calling: {function_name}({args})")
f = available_functions[function_name].function
return {
"role": "tool",
"content": f(**args),
"name": function_name,
}
agent = GeminiAgent(model=model, tools=tools)
return agent.prompt(message=prompt)

View File

@@ -0,0 +1,27 @@
from google import genai
from logger import get_logger
DEFAULT_MODEL = "gemini-2.5-flash"
API_KEY="AIzaSyAqmZCqfNLkegMq69E3-U3PDlryXmXfrJs"
logger = get_logger(__name__)
class GeminiAgent:
def __init__(self, model, tools, max_loop=10):
if model:
self.model = model
else:
self.model = DEFAULT_MODEL
logger.info(f"Model: {self.model}")
self.tools = tools
self.max_loop = max_loop
self.client = genai.Client(api_key=API_KEY)
def prompt(self, message):
response = self.client.models.generate_content(
model = self.model, contents=message
)
return response.text

View File

@@ -1,7 +1,7 @@
import ollama
from logger import get_logger
from agent import execute_function
from tools import execute_function, Tool2
DEFAULT_MODEL = "ministral-3:8b"
@@ -14,29 +14,30 @@ system_prompt = {
class OllamaAgent:
def __init__(self, model, tools, max_loop=10):
def __init__(self, model, tools: list[Tool2], max_loop=10):
if model:
self.model = model
else:
self.model = DEFAULT_MODEL
logger.info("Model: {self.model}")
logger.info(f"Model: {self.model}")
self.tools = tools
self.ollama_tools = list(map(lambda tool: tool.to_ollama(), tools))
self.max_loop = max_loop
def prompt(message):
def prompt(self, message):
messages = [
system_prompt,
{"role": "user", "content": message},
]
loops = 0
response = ollama.chat(model=self.model, messages=messages, tools=self.tools)
response = ollama.chat(model=self.model, messages=messages, tools=self.ollama_tools)
rmessage = response["message"]
while "tool_calls" in rmessage and loops < max_loop:
max_loop += 1
while "tool_calls" in rmessage and loops < self.max_loop:
self.max_loop += 1
logger.debug(f"Tool calls: {len(rmessage["tool_calls"])}")
for tool in rmessage["tool_calls"]:
messages.append(execute_function(tool))
response = ollama.chat(model=self.model, messages=messages, tools=self.tools)
response = ollama.chat(model=self.model, messages=messages, tools=self.ollama_tools)
rmessage = response["message"]
return rmessage["content"]

View File

@@ -1,5 +1,9 @@
import os
from dataclasses import dataclass
from scrape import fetch_web_page
from logger import get_logger
logger = get_logger(__name__)
class Tool:
@@ -8,6 +12,47 @@ class Tool:
self.function = function
@dataclass
class ToolProperty:
name: str
prop_type: str
items: str | None
description: str
required: bool
class Tool2:
def __init__(
self, function, name: str, description: str, properties: list[ToolProperty]
):
self.function = function
self.name = name
self.description = description
self.properties = properties
def to_ollama(self) -> dict:
required = []
props = {}
for prop in self.properties:
if prop.required:
required.append(prop.name)
props[prop.name] = {"type": prop.prop_type, "description": prop.description}
if prop.items:
props[prop.name]["items"] = {"type": prop.items.type}
return {
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": {
"type": "object",
"properties": props,
"required": required,
},
},
}
def list_directory_contents(**kwargs):
try:
return str({path: os.listdir(kwargs["path"])})
@@ -23,58 +68,34 @@ def get_current_directory(**kwargs):
available_functions: map[str, Tool] = {
"list_directory_contents": Tool(
{
"type": "function",
"function": {
"name": "list_directory_contents",
"description": "List files in a directory. Requires an absolute path (e.g., /home/user/project) obtained from get_current_directory.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "The path to the directory",
}
},
"required": ["path"],
},
},
},
"list_directory_contents": Tool2(
list_directory_contents,
"list_directory_contenst",
"List files in a directory. Requires an absolute path (e.g., /home/user/project) obtained from get_current_directory.",
[ToolProperty("path", "string", None, "The path to the directory", True)],
),
"get_current_directory": Tool(
{
"type": "function",
"function": {
"name": "get_current_directory",
"description": "Returns the path of the current directory. Use this to orient yourself and to know which directory you are working in.",
"parameters": {
"type": "object",
"properties": {},
},
},
},
"get_current_directory": Tool2(
get_current_directory,
"get_current_directory",
"Returns the path of the current directory. Use this to orient yourself and to know which directory you are working in.",
[],
),
"fetch_web_page": Tool(
{
"type": "function",
"function": {
"name": "fetch_web_page",
"description": "Fetch and read the content of a web page via URL",
"parameters": {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The full URL to fetch",
}
},
"required": ["url"],
},
},
},
"fetch_web_page": Tool2(
fetch_web_page,
"fetch_web_page",
"Fetch and read the content of a web page via URL",
[ToolProperty("url", "string", None, "The full URL to fetch from", True)],
),
}
def execute_function(tool):
function_name = tool["function"]["name"]
args = tool["function"]["arguments"]
logger.info(f"Agent is calling: {function_name}({args})")
f = available_functions[function_name].function
return {
"role": "tool",
"content": f(**args),
"name": function_name,
}