WIP split implementation based on backend
This commit is contained in:
@@ -6,6 +6,7 @@ readme = "README.md"
|
||||
requires-python = ">=3.14"
|
||||
dependencies = [
|
||||
"beautifulsoup4>=4.14.3",
|
||||
"dotenv>=0.9.9",
|
||||
"google-genai>=1.59.0",
|
||||
"lxml>=6.0.2",
|
||||
"markdownify>=1.2.2",
|
||||
|
||||
@@ -5,6 +5,7 @@ import ollama
|
||||
|
||||
from tools import available_functions
|
||||
from ollamaagent import OllamaAgent
|
||||
from geminiagent import GeminiAgent
|
||||
|
||||
|
||||
class Backend(Enum):
|
||||
@@ -17,7 +18,7 @@ class Backend(Enum):
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
tools = [available_functions["fetch_web_page"].spec]
|
||||
tools = [available_functions["fetch_web_page"]]
|
||||
|
||||
|
||||
def run_agent(prompt: str, backend: Backend, model: str):
|
||||
@@ -25,16 +26,5 @@ def run_agent(prompt: str, backend: Backend, model: str):
|
||||
agent = OllamaAgent(model=model, tools=tools)
|
||||
return agent.prompt(message=prompt)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def execute_function(tool):
|
||||
function_name = tool["function"]["name"]
|
||||
args = tool["function"]["arguments"]
|
||||
logger.info(f"Agent is calling: {function_name}({args})")
|
||||
f = available_functions[function_name].function
|
||||
return {
|
||||
"role": "tool",
|
||||
"content": f(**args),
|
||||
"name": function_name,
|
||||
}
|
||||
agent = GeminiAgent(model=model, tools=tools)
|
||||
return agent.prompt(message=prompt)
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
from google import genai
|
||||
from logger import get_logger
|
||||
|
||||
DEFAULT_MODEL = "gemini-2.5-flash"
|
||||
API_KEY="AIzaSyAqmZCqfNLkegMq69E3-U3PDlryXmXfrJs"
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class GeminiAgent:
|
||||
|
||||
def __init__(self, model, tools, max_loop=10):
|
||||
if model:
|
||||
self.model = model
|
||||
else:
|
||||
self.model = DEFAULT_MODEL
|
||||
logger.info(f"Model: {self.model}")
|
||||
self.tools = tools
|
||||
self.max_loop = max_loop
|
||||
self.client = genai.Client(api_key=API_KEY)
|
||||
|
||||
|
||||
def prompt(self, message):
|
||||
response = self.client.models.generate_content(
|
||||
model = self.model, contents=message
|
||||
)
|
||||
return response.text
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import ollama
|
||||
|
||||
from logger import get_logger
|
||||
from agent import execute_function
|
||||
from tools import execute_function, Tool2
|
||||
|
||||
DEFAULT_MODEL = "ministral-3:8b"
|
||||
|
||||
@@ -14,29 +14,30 @@ system_prompt = {
|
||||
|
||||
class OllamaAgent:
|
||||
|
||||
def __init__(self, model, tools, max_loop=10):
|
||||
def __init__(self, model, tools: list[Tool2], max_loop=10):
|
||||
if model:
|
||||
self.model = model
|
||||
else:
|
||||
self.model = DEFAULT_MODEL
|
||||
logger.info("Model: {self.model}")
|
||||
logger.info(f"Model: {self.model}")
|
||||
self.tools = tools
|
||||
self.ollama_tools = list(map(lambda tool: tool.to_ollama(), tools))
|
||||
self.max_loop = max_loop
|
||||
|
||||
def prompt(message):
|
||||
def prompt(self, message):
|
||||
messages = [
|
||||
system_prompt,
|
||||
{"role": "user", "content": message},
|
||||
]
|
||||
loops = 0
|
||||
response = ollama.chat(model=self.model, messages=messages, tools=self.tools)
|
||||
response = ollama.chat(model=self.model, messages=messages, tools=self.ollama_tools)
|
||||
rmessage = response["message"]
|
||||
while "tool_calls" in rmessage and loops < max_loop:
|
||||
max_loop += 1
|
||||
while "tool_calls" in rmessage and loops < self.max_loop:
|
||||
self.max_loop += 1
|
||||
logger.debug(f"Tool calls: {len(rmessage["tool_calls"])}")
|
||||
for tool in rmessage["tool_calls"]:
|
||||
messages.append(execute_function(tool))
|
||||
response = ollama.chat(model=self.model, messages=messages, tools=self.tools)
|
||||
response = ollama.chat(model=self.model, messages=messages, tools=self.ollama_tools)
|
||||
rmessage = response["message"]
|
||||
return rmessage["content"]
|
||||
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from scrape import fetch_web_page
|
||||
from logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Tool:
|
||||
@@ -8,6 +12,47 @@ class Tool:
|
||||
self.function = function
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolProperty:
|
||||
name: str
|
||||
prop_type: str
|
||||
items: str | None
|
||||
description: str
|
||||
required: bool
|
||||
|
||||
|
||||
class Tool2:
|
||||
def __init__(
|
||||
self, function, name: str, description: str, properties: list[ToolProperty]
|
||||
):
|
||||
self.function = function
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.properties = properties
|
||||
|
||||
def to_ollama(self) -> dict:
|
||||
required = []
|
||||
props = {}
|
||||
for prop in self.properties:
|
||||
if prop.required:
|
||||
required.append(prop.name)
|
||||
props[prop.name] = {"type": prop.prop_type, "description": prop.description}
|
||||
if prop.items:
|
||||
props[prop.name]["items"] = {"type": prop.items.type}
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": props,
|
||||
"required": required,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def list_directory_contents(**kwargs):
|
||||
try:
|
||||
return str({path: os.listdir(kwargs["path"])})
|
||||
@@ -23,58 +68,34 @@ def get_current_directory(**kwargs):
|
||||
|
||||
|
||||
available_functions: map[str, Tool] = {
|
||||
"list_directory_contents": Tool(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "list_directory_contents",
|
||||
"description": "List files in a directory. Requires an absolute path (e.g., /home/user/project) obtained from get_current_directory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "The path to the directory",
|
||||
}
|
||||
},
|
||||
"required": ["path"],
|
||||
},
|
||||
},
|
||||
},
|
||||
"list_directory_contents": Tool2(
|
||||
list_directory_contents,
|
||||
"list_directory_contenst",
|
||||
"List files in a directory. Requires an absolute path (e.g., /home/user/project) obtained from get_current_directory.",
|
||||
[ToolProperty("path", "string", None, "The path to the directory", True)],
|
||||
),
|
||||
"get_current_directory": Tool(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_directory",
|
||||
"description": "Returns the path of the current directory. Use this to orient yourself and to know which directory you are working in.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
"get_current_directory": Tool2(
|
||||
get_current_directory,
|
||||
"get_current_directory",
|
||||
"Returns the path of the current directory. Use this to orient yourself and to know which directory you are working in.",
|
||||
[],
|
||||
),
|
||||
"fetch_web_page": Tool(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "fetch_web_page",
|
||||
"description": "Fetch and read the content of a web page via URL",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The full URL to fetch",
|
||||
}
|
||||
},
|
||||
"required": ["url"],
|
||||
},
|
||||
},
|
||||
},
|
||||
"fetch_web_page": Tool2(
|
||||
fetch_web_page,
|
||||
"fetch_web_page",
|
||||
"Fetch and read the content of a web page via URL",
|
||||
[ToolProperty("url", "string", None, "The full URL to fetch from", True)],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def execute_function(tool):
|
||||
function_name = tool["function"]["name"]
|
||||
args = tool["function"]["arguments"]
|
||||
logger.info(f"Agent is calling: {function_name}({args})")
|
||||
f = available_functions[function_name].function
|
||||
return {
|
||||
"role": "tool",
|
||||
"content": f(**args),
|
||||
"name": function_name,
|
||||
}
|
||||
|
||||
22
uv.lock
generated
22
uv.lock
generated
@@ -79,6 +79,17 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dotenv"
|
||||
version = "0.9.9"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "python-dotenv" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b2/b7/545d2c10c1fc15e48653c91efde329a790f2eecfbbf2bd16003b5db2bab0/dotenv-0.9.9-py2.py3-none-any.whl", hash = "sha256:29cf74a087b31dafdb5a446b6d7e11cbce8ed2741540e2339c69fbef92c94ce9", size = 1892, upload-time = "2025-02-19T22:15:01.647Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "google-auth"
|
||||
version = "2.47.0"
|
||||
@@ -309,12 +320,22 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/9f/ed/068e41660b832bb0b1aa5b58011dea2a3fe0ba7861ff38c4d4904c1c1a99/pydantic_core-2.41.5-cp314-cp314t-win_arm64.whl", hash = "sha256:35b44f37a3199f771c3eaa53051bc8a70cd7b54f333531c59e29fd4db5d15008", size = 1974769, upload-time = "2025-11-04T13:42:01.186Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "python-dotenv"
|
||||
version = "1.2.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f0/26/19cadc79a718c5edbec86fd4919a6b6d3f681039a2f6d66d14be94e75fb9/python_dotenv-1.2.1.tar.gz", hash = "sha256:42667e897e16ab0d66954af0e60a9caa94f0fd4ecf3aaf6d2d260eec1aa36ad6", size = 44221, upload-time = "2025-10-26T15:12:10.434Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/14/1b/a298b06749107c305e1fe0f814c6c74aea7b2f1e10989cb30f544a1b3253/python_dotenv-1.2.1-py3-none-any.whl", hash = "sha256:b81ee9561e9ca4004139c6cbba3a238c32b03e4894671e181b671e8cb8425d61", size = 21230, upload-time = "2025-10-26T15:12:09.109Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "python-scraper"
|
||||
version = "0.1.0"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "beautifulsoup4" },
|
||||
{ name = "dotenv" },
|
||||
{ name = "google-genai" },
|
||||
{ name = "lxml" },
|
||||
{ name = "markdownify" },
|
||||
@@ -325,6 +346,7 @@ dependencies = [
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "beautifulsoup4", specifier = ">=4.14.3" },
|
||||
{ name = "dotenv", specifier = ">=0.9.9" },
|
||||
{ name = "google-genai", specifier = ">=1.59.0" },
|
||||
{ name = "lxml", specifier = ">=6.0.2" },
|
||||
{ name = "markdownify", specifier = ">=1.2.2" },
|
||||
|
||||
Reference in New Issue
Block a user