From 21558b88cd95da77a355936d25971de3b4758f4d Mon Sep 17 00:00:00 2001 From: Jarno Date: Sun, 18 Jan 2026 21:16:05 +0200 Subject: [PATCH] WIP split implementation based on backend --- pyproject.toml | 1 + src/AiAgentScraper/agent.py | 18 +---- src/AiAgentScraper/geminiagent.py | 27 +++++++ src/AiAgentScraper/ollamaagent.py | 17 +++-- src/AiAgentScraper/tools.py | 117 ++++++++++++++++++------------ uv.lock | 22 ++++++ 6 files changed, 132 insertions(+), 70 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8fb268e..1d88fe2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,7 @@ readme = "README.md" requires-python = ">=3.14" dependencies = [ "beautifulsoup4>=4.14.3", + "dotenv>=0.9.9", "google-genai>=1.59.0", "lxml>=6.0.2", "markdownify>=1.2.2", diff --git a/src/AiAgentScraper/agent.py b/src/AiAgentScraper/agent.py index 03e0685..f26fc42 100644 --- a/src/AiAgentScraper/agent.py +++ b/src/AiAgentScraper/agent.py @@ -5,6 +5,7 @@ import ollama from tools import available_functions from ollamaagent import OllamaAgent +from geminiagent import GeminiAgent class Backend(Enum): @@ -17,7 +18,7 @@ class Backend(Enum): logger = get_logger(__name__) -tools = [available_functions["fetch_web_page"].spec] +tools = [available_functions["fetch_web_page"]] def run_agent(prompt: str, backend: Backend, model: str): @@ -25,16 +26,5 @@ def run_agent(prompt: str, backend: Backend, model: str): agent = OllamaAgent(model=model, tools=tools) return agent.prompt(message=prompt) else: - raise NotImplementedError - - -def execute_function(tool): - function_name = tool["function"]["name"] - args = tool["function"]["arguments"] - logger.info(f"Agent is calling: {function_name}({args})") - f = available_functions[function_name].function - return { - "role": "tool", - "content": f(**args), - "name": function_name, - } + agent = GeminiAgent(model=model, tools=tools) + return agent.prompt(message=prompt) diff --git a/src/AiAgentScraper/geminiagent.py b/src/AiAgentScraper/geminiagent.py index e69de29..ceaf941 100644 --- a/src/AiAgentScraper/geminiagent.py +++ b/src/AiAgentScraper/geminiagent.py @@ -0,0 +1,27 @@ +from google import genai +from logger import get_logger + +DEFAULT_MODEL = "gemini-2.5-flash" +API_KEY="AIzaSyAqmZCqfNLkegMq69E3-U3PDlryXmXfrJs" + +logger = get_logger(__name__) + +class GeminiAgent: + + def __init__(self, model, tools, max_loop=10): + if model: + self.model = model + else: + self.model = DEFAULT_MODEL + logger.info(f"Model: {self.model}") + self.tools = tools + self.max_loop = max_loop + self.client = genai.Client(api_key=API_KEY) + + + def prompt(self, message): + response = self.client.models.generate_content( + model = self.model, contents=message + ) + return response.text + diff --git a/src/AiAgentScraper/ollamaagent.py b/src/AiAgentScraper/ollamaagent.py index b1b391d..b294213 100644 --- a/src/AiAgentScraper/ollamaagent.py +++ b/src/AiAgentScraper/ollamaagent.py @@ -1,7 +1,7 @@ import ollama from logger import get_logger -from agent import execute_function +from tools import execute_function, Tool2 DEFAULT_MODEL = "ministral-3:8b" @@ -14,29 +14,30 @@ system_prompt = { class OllamaAgent: - def __init__(self, model, tools, max_loop=10): + def __init__(self, model, tools: list[Tool2], max_loop=10): if model: self.model = model else: self.model = DEFAULT_MODEL - logger.info("Model: {self.model}") + logger.info(f"Model: {self.model}") self.tools = tools + self.ollama_tools = list(map(lambda tool: tool.to_ollama(), tools)) self.max_loop = max_loop - def prompt(message): + def prompt(self, message): messages = [ system_prompt, {"role": "user", "content": message}, ] loops = 0 - response = ollama.chat(model=self.model, messages=messages, tools=self.tools) + response = ollama.chat(model=self.model, messages=messages, tools=self.ollama_tools) rmessage = response["message"] - while "tool_calls" in rmessage and loops < max_loop: - max_loop += 1 + while "tool_calls" in rmessage and loops < self.max_loop: + self.max_loop += 1 logger.debug(f"Tool calls: {len(rmessage["tool_calls"])}") for tool in rmessage["tool_calls"]: messages.append(execute_function(tool)) - response = ollama.chat(model=self.model, messages=messages, tools=self.tools) + response = ollama.chat(model=self.model, messages=messages, tools=self.ollama_tools) rmessage = response["message"] return rmessage["content"] diff --git a/src/AiAgentScraper/tools.py b/src/AiAgentScraper/tools.py index 8cea562..27b5de4 100644 --- a/src/AiAgentScraper/tools.py +++ b/src/AiAgentScraper/tools.py @@ -1,5 +1,9 @@ import os +from dataclasses import dataclass from scrape import fetch_web_page +from logger import get_logger + +logger = get_logger(__name__) class Tool: @@ -8,6 +12,47 @@ class Tool: self.function = function +@dataclass +class ToolProperty: + name: str + prop_type: str + items: str | None + description: str + required: bool + + +class Tool2: + def __init__( + self, function, name: str, description: str, properties: list[ToolProperty] + ): + self.function = function + self.name = name + self.description = description + self.properties = properties + + def to_ollama(self) -> dict: + required = [] + props = {} + for prop in self.properties: + if prop.required: + required.append(prop.name) + props[prop.name] = {"type": prop.prop_type, "description": prop.description} + if prop.items: + props[prop.name]["items"] = {"type": prop.items.type} + return { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": { + "type": "object", + "properties": props, + "required": required, + }, + }, + } + + def list_directory_contents(**kwargs): try: return str({path: os.listdir(kwargs["path"])}) @@ -23,58 +68,34 @@ def get_current_directory(**kwargs): available_functions: map[str, Tool] = { - "list_directory_contents": Tool( - { - "type": "function", - "function": { - "name": "list_directory_contents", - "description": "List files in a directory. Requires an absolute path (e.g., /home/user/project) obtained from get_current_directory.", - "parameters": { - "type": "object", - "properties": { - "path": { - "type": "string", - "description": "The path to the directory", - } - }, - "required": ["path"], - }, - }, - }, + "list_directory_contents": Tool2( list_directory_contents, + "list_directory_contenst", + "List files in a directory. Requires an absolute path (e.g., /home/user/project) obtained from get_current_directory.", + [ToolProperty("path", "string", None, "The path to the directory", True)], ), - "get_current_directory": Tool( - { - "type": "function", - "function": { - "name": "get_current_directory", - "description": "Returns the path of the current directory. Use this to orient yourself and to know which directory you are working in.", - "parameters": { - "type": "object", - "properties": {}, - }, - }, - }, + "get_current_directory": Tool2( get_current_directory, + "get_current_directory", + "Returns the path of the current directory. Use this to orient yourself and to know which directory you are working in.", + [], ), - "fetch_web_page": Tool( - { - "type": "function", - "function": { - "name": "fetch_web_page", - "description": "Fetch and read the content of a web page via URL", - "parameters": { - "type": "object", - "properties": { - "url": { - "type": "string", - "description": "The full URL to fetch", - } - }, - "required": ["url"], - }, - }, - }, + "fetch_web_page": Tool2( fetch_web_page, + "fetch_web_page", + "Fetch and read the content of a web page via URL", + [ToolProperty("url", "string", None, "The full URL to fetch from", True)], ), } + + +def execute_function(tool): + function_name = tool["function"]["name"] + args = tool["function"]["arguments"] + logger.info(f"Agent is calling: {function_name}({args})") + f = available_functions[function_name].function + return { + "role": "tool", + "content": f(**args), + "name": function_name, + } diff --git a/uv.lock b/uv.lock index fbbe2ad..eeafa79 100644 --- a/uv.lock +++ b/uv.lock @@ -79,6 +79,17 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" }, ] +[[package]] +name = "dotenv" +version = "0.9.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "python-dotenv" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/b2/b7/545d2c10c1fc15e48653c91efde329a790f2eecfbbf2bd16003b5db2bab0/dotenv-0.9.9-py2.py3-none-any.whl", hash = "sha256:29cf74a087b31dafdb5a446b6d7e11cbce8ed2741540e2339c69fbef92c94ce9", size = 1892, upload-time = "2025-02-19T22:15:01.647Z" }, +] + [[package]] name = "google-auth" version = "2.47.0" @@ -309,12 +320,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9f/ed/068e41660b832bb0b1aa5b58011dea2a3fe0ba7861ff38c4d4904c1c1a99/pydantic_core-2.41.5-cp314-cp314t-win_arm64.whl", hash = "sha256:35b44f37a3199f771c3eaa53051bc8a70cd7b54f333531c59e29fd4db5d15008", size = 1974769, upload-time = "2025-11-04T13:42:01.186Z" }, ] +[[package]] +name = "python-dotenv" +version = "1.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f0/26/19cadc79a718c5edbec86fd4919a6b6d3f681039a2f6d66d14be94e75fb9/python_dotenv-1.2.1.tar.gz", hash = "sha256:42667e897e16ab0d66954af0e60a9caa94f0fd4ecf3aaf6d2d260eec1aa36ad6", size = 44221, upload-time = "2025-10-26T15:12:10.434Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/1b/a298b06749107c305e1fe0f814c6c74aea7b2f1e10989cb30f544a1b3253/python_dotenv-1.2.1-py3-none-any.whl", hash = "sha256:b81ee9561e9ca4004139c6cbba3a238c32b03e4894671e181b671e8cb8425d61", size = 21230, upload-time = "2025-10-26T15:12:09.109Z" }, +] + [[package]] name = "python-scraper" version = "0.1.0" source = { virtual = "." } dependencies = [ { name = "beautifulsoup4" }, + { name = "dotenv" }, { name = "google-genai" }, { name = "lxml" }, { name = "markdownify" }, @@ -325,6 +346,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "beautifulsoup4", specifier = ">=4.14.3" }, + { name = "dotenv", specifier = ">=0.9.9" }, { name = "google-genai", specifier = ">=1.59.0" }, { name = "lxml", specifier = ">=6.0.2" }, { name = "markdownify", specifier = ">=1.2.2" },