WIP split implementation based on backend
This commit is contained in:
@@ -6,6 +6,7 @@ readme = "README.md"
|
|||||||
requires-python = ">=3.14"
|
requires-python = ">=3.14"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"beautifulsoup4>=4.14.3",
|
"beautifulsoup4>=4.14.3",
|
||||||
|
"dotenv>=0.9.9",
|
||||||
"google-genai>=1.59.0",
|
"google-genai>=1.59.0",
|
||||||
"lxml>=6.0.2",
|
"lxml>=6.0.2",
|
||||||
"markdownify>=1.2.2",
|
"markdownify>=1.2.2",
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import ollama
|
|||||||
|
|
||||||
from tools import available_functions
|
from tools import available_functions
|
||||||
from ollamaagent import OllamaAgent
|
from ollamaagent import OllamaAgent
|
||||||
|
from geminiagent import GeminiAgent
|
||||||
|
|
||||||
|
|
||||||
class Backend(Enum):
|
class Backend(Enum):
|
||||||
@@ -17,7 +18,7 @@ class Backend(Enum):
|
|||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
tools = [available_functions["fetch_web_page"].spec]
|
tools = [available_functions["fetch_web_page"]]
|
||||||
|
|
||||||
|
|
||||||
def run_agent(prompt: str, backend: Backend, model: str):
|
def run_agent(prompt: str, backend: Backend, model: str):
|
||||||
@@ -25,16 +26,5 @@ def run_agent(prompt: str, backend: Backend, model: str):
|
|||||||
agent = OllamaAgent(model=model, tools=tools)
|
agent = OllamaAgent(model=model, tools=tools)
|
||||||
return agent.prompt(message=prompt)
|
return agent.prompt(message=prompt)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
agent = GeminiAgent(model=model, tools=tools)
|
||||||
|
return agent.prompt(message=prompt)
|
||||||
|
|
||||||
def execute_function(tool):
|
|
||||||
function_name = tool["function"]["name"]
|
|
||||||
args = tool["function"]["arguments"]
|
|
||||||
logger.info(f"Agent is calling: {function_name}({args})")
|
|
||||||
f = available_functions[function_name].function
|
|
||||||
return {
|
|
||||||
"role": "tool",
|
|
||||||
"content": f(**args),
|
|
||||||
"name": function_name,
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -0,0 +1,27 @@
|
|||||||
|
from google import genai
|
||||||
|
from logger import get_logger
|
||||||
|
|
||||||
|
DEFAULT_MODEL = "gemini-2.5-flash"
|
||||||
|
API_KEY="AIzaSyAqmZCqfNLkegMq69E3-U3PDlryXmXfrJs"
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
class GeminiAgent:
|
||||||
|
|
||||||
|
def __init__(self, model, tools, max_loop=10):
|
||||||
|
if model:
|
||||||
|
self.model = model
|
||||||
|
else:
|
||||||
|
self.model = DEFAULT_MODEL
|
||||||
|
logger.info(f"Model: {self.model}")
|
||||||
|
self.tools = tools
|
||||||
|
self.max_loop = max_loop
|
||||||
|
self.client = genai.Client(api_key=API_KEY)
|
||||||
|
|
||||||
|
|
||||||
|
def prompt(self, message):
|
||||||
|
response = self.client.models.generate_content(
|
||||||
|
model = self.model, contents=message
|
||||||
|
)
|
||||||
|
return response.text
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import ollama
|
import ollama
|
||||||
|
|
||||||
from logger import get_logger
|
from logger import get_logger
|
||||||
from agent import execute_function
|
from tools import execute_function, Tool2
|
||||||
|
|
||||||
DEFAULT_MODEL = "ministral-3:8b"
|
DEFAULT_MODEL = "ministral-3:8b"
|
||||||
|
|
||||||
@@ -14,29 +14,30 @@ system_prompt = {
|
|||||||
|
|
||||||
class OllamaAgent:
|
class OllamaAgent:
|
||||||
|
|
||||||
def __init__(self, model, tools, max_loop=10):
|
def __init__(self, model, tools: list[Tool2], max_loop=10):
|
||||||
if model:
|
if model:
|
||||||
self.model = model
|
self.model = model
|
||||||
else:
|
else:
|
||||||
self.model = DEFAULT_MODEL
|
self.model = DEFAULT_MODEL
|
||||||
logger.info("Model: {self.model}")
|
logger.info(f"Model: {self.model}")
|
||||||
self.tools = tools
|
self.tools = tools
|
||||||
|
self.ollama_tools = list(map(lambda tool: tool.to_ollama(), tools))
|
||||||
self.max_loop = max_loop
|
self.max_loop = max_loop
|
||||||
|
|
||||||
def prompt(message):
|
def prompt(self, message):
|
||||||
messages = [
|
messages = [
|
||||||
system_prompt,
|
system_prompt,
|
||||||
{"role": "user", "content": message},
|
{"role": "user", "content": message},
|
||||||
]
|
]
|
||||||
loops = 0
|
loops = 0
|
||||||
response = ollama.chat(model=self.model, messages=messages, tools=self.tools)
|
response = ollama.chat(model=self.model, messages=messages, tools=self.ollama_tools)
|
||||||
rmessage = response["message"]
|
rmessage = response["message"]
|
||||||
while "tool_calls" in rmessage and loops < max_loop:
|
while "tool_calls" in rmessage and loops < self.max_loop:
|
||||||
max_loop += 1
|
self.max_loop += 1
|
||||||
logger.debug(f"Tool calls: {len(rmessage["tool_calls"])}")
|
logger.debug(f"Tool calls: {len(rmessage["tool_calls"])}")
|
||||||
for tool in rmessage["tool_calls"]:
|
for tool in rmessage["tool_calls"]:
|
||||||
messages.append(execute_function(tool))
|
messages.append(execute_function(tool))
|
||||||
response = ollama.chat(model=self.model, messages=messages, tools=self.tools)
|
response = ollama.chat(model=self.model, messages=messages, tools=self.ollama_tools)
|
||||||
rmessage = response["message"]
|
rmessage = response["message"]
|
||||||
return rmessage["content"]
|
return rmessage["content"]
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
from scrape import fetch_web_page
|
from scrape import fetch_web_page
|
||||||
|
from logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Tool:
|
class Tool:
|
||||||
@@ -8,6 +12,47 @@ class Tool:
|
|||||||
self.function = function
|
self.function = function
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolProperty:
|
||||||
|
name: str
|
||||||
|
prop_type: str
|
||||||
|
items: str | None
|
||||||
|
description: str
|
||||||
|
required: bool
|
||||||
|
|
||||||
|
|
||||||
|
class Tool2:
|
||||||
|
def __init__(
|
||||||
|
self, function, name: str, description: str, properties: list[ToolProperty]
|
||||||
|
):
|
||||||
|
self.function = function
|
||||||
|
self.name = name
|
||||||
|
self.description = description
|
||||||
|
self.properties = properties
|
||||||
|
|
||||||
|
def to_ollama(self) -> dict:
|
||||||
|
required = []
|
||||||
|
props = {}
|
||||||
|
for prop in self.properties:
|
||||||
|
if prop.required:
|
||||||
|
required.append(prop.name)
|
||||||
|
props[prop.name] = {"type": prop.prop_type, "description": prop.description}
|
||||||
|
if prop.items:
|
||||||
|
props[prop.name]["items"] = {"type": prop.items.type}
|
||||||
|
return {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": props,
|
||||||
|
"required": required,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def list_directory_contents(**kwargs):
|
def list_directory_contents(**kwargs):
|
||||||
try:
|
try:
|
||||||
return str({path: os.listdir(kwargs["path"])})
|
return str({path: os.listdir(kwargs["path"])})
|
||||||
@@ -23,58 +68,34 @@ def get_current_directory(**kwargs):
|
|||||||
|
|
||||||
|
|
||||||
available_functions: map[str, Tool] = {
|
available_functions: map[str, Tool] = {
|
||||||
"list_directory_contents": Tool(
|
"list_directory_contents": Tool2(
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "list_directory_contents",
|
|
||||||
"description": "List files in a directory. Requires an absolute path (e.g., /home/user/project) obtained from get_current_directory.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"path": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The path to the directory",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["path"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
list_directory_contents,
|
list_directory_contents,
|
||||||
|
"list_directory_contenst",
|
||||||
|
"List files in a directory. Requires an absolute path (e.g., /home/user/project) obtained from get_current_directory.",
|
||||||
|
[ToolProperty("path", "string", None, "The path to the directory", True)],
|
||||||
),
|
),
|
||||||
"get_current_directory": Tool(
|
"get_current_directory": Tool2(
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "get_current_directory",
|
|
||||||
"description": "Returns the path of the current directory. Use this to orient yourself and to know which directory you are working in.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
get_current_directory,
|
get_current_directory,
|
||||||
|
"get_current_directory",
|
||||||
|
"Returns the path of the current directory. Use this to orient yourself and to know which directory you are working in.",
|
||||||
|
[],
|
||||||
),
|
),
|
||||||
"fetch_web_page": Tool(
|
"fetch_web_page": Tool2(
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "fetch_web_page",
|
|
||||||
"description": "Fetch and read the content of a web page via URL",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"url": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The full URL to fetch",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["url"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
fetch_web_page,
|
fetch_web_page,
|
||||||
|
"fetch_web_page",
|
||||||
|
"Fetch and read the content of a web page via URL",
|
||||||
|
[ToolProperty("url", "string", None, "The full URL to fetch from", True)],
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def execute_function(tool):
|
||||||
|
function_name = tool["function"]["name"]
|
||||||
|
args = tool["function"]["arguments"]
|
||||||
|
logger.info(f"Agent is calling: {function_name}({args})")
|
||||||
|
f = available_functions[function_name].function
|
||||||
|
return {
|
||||||
|
"role": "tool",
|
||||||
|
"content": f(**args),
|
||||||
|
"name": function_name,
|
||||||
|
}
|
||||||
|
|||||||
22
uv.lock
generated
22
uv.lock
generated
@@ -79,6 +79,17 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" },
|
{ url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "dotenv"
|
||||||
|
version = "0.9.9"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "python-dotenv" },
|
||||||
|
]
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b2/b7/545d2c10c1fc15e48653c91efde329a790f2eecfbbf2bd16003b5db2bab0/dotenv-0.9.9-py2.py3-none-any.whl", hash = "sha256:29cf74a087b31dafdb5a446b6d7e11cbce8ed2741540e2339c69fbef92c94ce9", size = 1892, upload-time = "2025-02-19T22:15:01.647Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "google-auth"
|
name = "google-auth"
|
||||||
version = "2.47.0"
|
version = "2.47.0"
|
||||||
@@ -309,12 +320,22 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/9f/ed/068e41660b832bb0b1aa5b58011dea2a3fe0ba7861ff38c4d4904c1c1a99/pydantic_core-2.41.5-cp314-cp314t-win_arm64.whl", hash = "sha256:35b44f37a3199f771c3eaa53051bc8a70cd7b54f333531c59e29fd4db5d15008", size = 1974769, upload-time = "2025-11-04T13:42:01.186Z" },
|
{ url = "https://files.pythonhosted.org/packages/9f/ed/068e41660b832bb0b1aa5b58011dea2a3fe0ba7861ff38c4d4904c1c1a99/pydantic_core-2.41.5-cp314-cp314t-win_arm64.whl", hash = "sha256:35b44f37a3199f771c3eaa53051bc8a70cd7b54f333531c59e29fd4db5d15008", size = 1974769, upload-time = "2025-11-04T13:42:01.186Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "python-dotenv"
|
||||||
|
version = "1.2.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/f0/26/19cadc79a718c5edbec86fd4919a6b6d3f681039a2f6d66d14be94e75fb9/python_dotenv-1.2.1.tar.gz", hash = "sha256:42667e897e16ab0d66954af0e60a9caa94f0fd4ecf3aaf6d2d260eec1aa36ad6", size = 44221, upload-time = "2025-10-26T15:12:10.434Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/14/1b/a298b06749107c305e1fe0f814c6c74aea7b2f1e10989cb30f544a1b3253/python_dotenv-1.2.1-py3-none-any.whl", hash = "sha256:b81ee9561e9ca4004139c6cbba3a238c32b03e4894671e181b671e8cb8425d61", size = 21230, upload-time = "2025-10-26T15:12:09.109Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "python-scraper"
|
name = "python-scraper"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
source = { virtual = "." }
|
source = { virtual = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "beautifulsoup4" },
|
{ name = "beautifulsoup4" },
|
||||||
|
{ name = "dotenv" },
|
||||||
{ name = "google-genai" },
|
{ name = "google-genai" },
|
||||||
{ name = "lxml" },
|
{ name = "lxml" },
|
||||||
{ name = "markdownify" },
|
{ name = "markdownify" },
|
||||||
@@ -325,6 +346,7 @@ dependencies = [
|
|||||||
[package.metadata]
|
[package.metadata]
|
||||||
requires-dist = [
|
requires-dist = [
|
||||||
{ name = "beautifulsoup4", specifier = ">=4.14.3" },
|
{ name = "beautifulsoup4", specifier = ">=4.14.3" },
|
||||||
|
{ name = "dotenv", specifier = ">=0.9.9" },
|
||||||
{ name = "google-genai", specifier = ">=1.59.0" },
|
{ name = "google-genai", specifier = ">=1.59.0" },
|
||||||
{ name = "lxml", specifier = ">=6.0.2" },
|
{ name = "lxml", specifier = ">=6.0.2" },
|
||||||
{ name = "markdownify", specifier = ">=1.2.2" },
|
{ name = "markdownify", specifier = ">=1.2.2" },
|
||||||
|
|||||||
Reference in New Issue
Block a user