WIP split implementation based on backend
This commit is contained in:
@@ -5,6 +5,7 @@ import ollama
|
||||
|
||||
from tools import available_functions
|
||||
from ollamaagent import OllamaAgent
|
||||
from geminiagent import GeminiAgent
|
||||
|
||||
|
||||
class Backend(Enum):
|
||||
@@ -17,7 +18,7 @@ class Backend(Enum):
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
tools = [available_functions["fetch_web_page"].spec]
|
||||
tools = [available_functions["fetch_web_page"]]
|
||||
|
||||
|
||||
def run_agent(prompt: str, backend: Backend, model: str):
|
||||
@@ -25,16 +26,5 @@ def run_agent(prompt: str, backend: Backend, model: str):
|
||||
agent = OllamaAgent(model=model, tools=tools)
|
||||
return agent.prompt(message=prompt)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def execute_function(tool):
|
||||
function_name = tool["function"]["name"]
|
||||
args = tool["function"]["arguments"]
|
||||
logger.info(f"Agent is calling: {function_name}({args})")
|
||||
f = available_functions[function_name].function
|
||||
return {
|
||||
"role": "tool",
|
||||
"content": f(**args),
|
||||
"name": function_name,
|
||||
}
|
||||
agent = GeminiAgent(model=model, tools=tools)
|
||||
return agent.prompt(message=prompt)
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
from google import genai
|
||||
from logger import get_logger
|
||||
|
||||
DEFAULT_MODEL = "gemini-2.5-flash"
|
||||
API_KEY="AIzaSyAqmZCqfNLkegMq69E3-U3PDlryXmXfrJs"
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class GeminiAgent:
|
||||
|
||||
def __init__(self, model, tools, max_loop=10):
|
||||
if model:
|
||||
self.model = model
|
||||
else:
|
||||
self.model = DEFAULT_MODEL
|
||||
logger.info(f"Model: {self.model}")
|
||||
self.tools = tools
|
||||
self.max_loop = max_loop
|
||||
self.client = genai.Client(api_key=API_KEY)
|
||||
|
||||
|
||||
def prompt(self, message):
|
||||
response = self.client.models.generate_content(
|
||||
model = self.model, contents=message
|
||||
)
|
||||
return response.text
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import ollama
|
||||
|
||||
from logger import get_logger
|
||||
from agent import execute_function
|
||||
from tools import execute_function, Tool2
|
||||
|
||||
DEFAULT_MODEL = "ministral-3:8b"
|
||||
|
||||
@@ -14,29 +14,30 @@ system_prompt = {
|
||||
|
||||
class OllamaAgent:
|
||||
|
||||
def __init__(self, model, tools, max_loop=10):
|
||||
def __init__(self, model, tools: list[Tool2], max_loop=10):
|
||||
if model:
|
||||
self.model = model
|
||||
else:
|
||||
self.model = DEFAULT_MODEL
|
||||
logger.info("Model: {self.model}")
|
||||
logger.info(f"Model: {self.model}")
|
||||
self.tools = tools
|
||||
self.ollama_tools = list(map(lambda tool: tool.to_ollama(), tools))
|
||||
self.max_loop = max_loop
|
||||
|
||||
def prompt(message):
|
||||
def prompt(self, message):
|
||||
messages = [
|
||||
system_prompt,
|
||||
{"role": "user", "content": message},
|
||||
]
|
||||
loops = 0
|
||||
response = ollama.chat(model=self.model, messages=messages, tools=self.tools)
|
||||
response = ollama.chat(model=self.model, messages=messages, tools=self.ollama_tools)
|
||||
rmessage = response["message"]
|
||||
while "tool_calls" in rmessage and loops < max_loop:
|
||||
max_loop += 1
|
||||
while "tool_calls" in rmessage and loops < self.max_loop:
|
||||
self.max_loop += 1
|
||||
logger.debug(f"Tool calls: {len(rmessage["tool_calls"])}")
|
||||
for tool in rmessage["tool_calls"]:
|
||||
messages.append(execute_function(tool))
|
||||
response = ollama.chat(model=self.model, messages=messages, tools=self.tools)
|
||||
response = ollama.chat(model=self.model, messages=messages, tools=self.ollama_tools)
|
||||
rmessage = response["message"]
|
||||
return rmessage["content"]
|
||||
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from scrape import fetch_web_page
|
||||
from logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Tool:
|
||||
@@ -8,6 +12,47 @@ class Tool:
|
||||
self.function = function
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolProperty:
|
||||
name: str
|
||||
prop_type: str
|
||||
items: str | None
|
||||
description: str
|
||||
required: bool
|
||||
|
||||
|
||||
class Tool2:
|
||||
def __init__(
|
||||
self, function, name: str, description: str, properties: list[ToolProperty]
|
||||
):
|
||||
self.function = function
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.properties = properties
|
||||
|
||||
def to_ollama(self) -> dict:
|
||||
required = []
|
||||
props = {}
|
||||
for prop in self.properties:
|
||||
if prop.required:
|
||||
required.append(prop.name)
|
||||
props[prop.name] = {"type": prop.prop_type, "description": prop.description}
|
||||
if prop.items:
|
||||
props[prop.name]["items"] = {"type": prop.items.type}
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": props,
|
||||
"required": required,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def list_directory_contents(**kwargs):
|
||||
try:
|
||||
return str({path: os.listdir(kwargs["path"])})
|
||||
@@ -23,58 +68,34 @@ def get_current_directory(**kwargs):
|
||||
|
||||
|
||||
available_functions: map[str, Tool] = {
|
||||
"list_directory_contents": Tool(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "list_directory_contents",
|
||||
"description": "List files in a directory. Requires an absolute path (e.g., /home/user/project) obtained from get_current_directory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "The path to the directory",
|
||||
}
|
||||
},
|
||||
"required": ["path"],
|
||||
},
|
||||
},
|
||||
},
|
||||
"list_directory_contents": Tool2(
|
||||
list_directory_contents,
|
||||
"list_directory_contenst",
|
||||
"List files in a directory. Requires an absolute path (e.g., /home/user/project) obtained from get_current_directory.",
|
||||
[ToolProperty("path", "string", None, "The path to the directory", True)],
|
||||
),
|
||||
"get_current_directory": Tool(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_directory",
|
||||
"description": "Returns the path of the current directory. Use this to orient yourself and to know which directory you are working in.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
"get_current_directory": Tool2(
|
||||
get_current_directory,
|
||||
"get_current_directory",
|
||||
"Returns the path of the current directory. Use this to orient yourself and to know which directory you are working in.",
|
||||
[],
|
||||
),
|
||||
"fetch_web_page": Tool(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "fetch_web_page",
|
||||
"description": "Fetch and read the content of a web page via URL",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The full URL to fetch",
|
||||
}
|
||||
},
|
||||
"required": ["url"],
|
||||
},
|
||||
},
|
||||
},
|
||||
"fetch_web_page": Tool2(
|
||||
fetch_web_page,
|
||||
"fetch_web_page",
|
||||
"Fetch and read the content of a web page via URL",
|
||||
[ToolProperty("url", "string", None, "The full URL to fetch from", True)],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def execute_function(tool):
|
||||
function_name = tool["function"]["name"]
|
||||
args = tool["function"]["arguments"]
|
||||
logger.info(f"Agent is calling: {function_name}({args})")
|
||||
f = available_functions[function_name].function
|
||||
return {
|
||||
"role": "tool",
|
||||
"content": f(**args),
|
||||
"name": function_name,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user