diff --git a/.gitignore b/.gitignore index c16328c..2653b96 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ env/ output/ -config.json \ No newline at end of file +tools/config.json +tools/__pycache__/ diff --git a/README.md b/README.md index 52f9586..3acd9db 100644 --- a/README.md +++ b/README.md @@ -5,12 +5,11 @@ So the idea here is to make an llm agent with tools that use https://imgflip.com ## Setup +### Python + ```bash python -m venv env source env/bin/activate -``` - -```bash pip install -r requirements.txt ``` @@ -26,3 +25,16 @@ Example: ```bash python main.py ``` + +### Ollama +Note we are using the `langchain-ollama` python package (see requirements). +Pull model +``` +ollama pull llama3 +``` +If failing, check status of ollama process: +``` +sudo service ollama status +``` + + diff --git a/main.py b/main.py index f5a0c59..faf59d7 100644 --- a/main.py +++ b/main.py @@ -1,59 +1,51 @@ +from langchain_core.prompts import ChatPromptTemplate +from langchain_ollama.llms import OllamaLLM +from langchain_core.tools import Tool +from langchain.agents import AgentExecutor, AgentType, Tool, initialize_agent + import requests import os -import json -def load_config(): - with open('config.json') as config_file: - return json.load(config_file) - -def get_memes(): - url = "https://api.imgflip.com/get_memes" - response = requests.get(url) - result = response.json() - - if result['success']: - memes = result['data']['memes'] - for meme in memes: - print(f"ID: {meme['id']}, Name: {meme['name']}") - else: - print("Failed to retrieve memes.") +from tools.get_memes import get_memes +from tools.caption_image import caption_image +from tools.download_image import download_image -def create_meme(template_id, text0, text1): - config = load_config() - username = config['username'] - password = config['password'] - - url = "https://api.imgflip.com/caption_image" - payload = { - "template_id": template_id, - "username": username, - "password": password, - "text0": text0, - "text1": text1 - } - - response = requests.post(url, data=payload) - result = response.json() - - if result['success']: - meme_url = result['data']['url'] - print(f"Meme created! URL: {meme_url}") - download_image(meme_url) - else: - print(f"Error: {result['error_message']}") +system_prompt = """ + You are an assistant that looks up the numerical template_id of a meme from imgflip. + The following tools are available to you: -def download_image(url): - response = requests.get(url) - if response.status_code == 200: - output_dir = 'output' - os.makedirs(output_dir, exist_ok=True) - image_path = os.path.join(output_dir, url.split("/")[-1]) - with open(image_path, 'wb') as f: - f.write(response.content) - print(f"Image saved to {image_path}") - else: - print("Failed to download image.") + 1. get_memes - Does not take any agruments. Returns a list of template_ids (integer) and names (string) which are the titles of the memes that correspond to the template_id. + 2. caption_image - Given a valid template_id, top text, and bottom text, generates an image with the desired text. Returns the url of the new meme as a string. + 3. download_image - Given a valid url returned from the caption_image tool, downloads the image we made locally. + Use these tools if necessary to answer questions. +""" + +prompt_template = f""" + {system_prompt} + + Question: {{question}} + + Answer: Let's think step by step. +""" + +prompt = ChatPromptTemplate.from_template(prompt_template) + +tools = [ + Tool(name="Get Memes", func=get_memes, description="Does not take any agruments. Returns a list of template_ids (integer) and names (string) which are the titles of the memes that correspond to the template_id."), + Tool(name="Caption Image", func=caption_image, description="Given a valid template_id, top text, and bottom text, generates an image with the desired text. Returns the url of the new meme as a string."), + Tool(name="Download Image", func=download_image, description="Given a valid url returned from the caption_image tool, downloads the image we made locally.") +] + +llm = OllamaLLM(model="llama3") + +agent_executor = initialize_agent( + tools=tools, + llm=llm, + agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION, + verbose=True +) + +response = agent_executor.invoke({"input": "Generate an image for the 'stick poke' meme with the top text 'come on' and the bottom text 'do something'."}) +print(response) -create_meme(20007896, "Top text", "Bottom text") -# get_memes() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f229360..e2a6ca5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,5 @@ requests +langchain +langchain-community +langchain-core +langchain-ollama diff --git a/test/0_langchain_ollama_test.py b/test/0_langchain_ollama_test.py new file mode 100644 index 0000000..515be8a --- /dev/null +++ b/test/0_langchain_ollama_test.py @@ -0,0 +1,16 @@ +from langchain_core.prompts import ChatPromptTemplate +from langchain_ollama.llms import OllamaLLM + +template = """Question: {question} + +Answer: Let's think step by step.""" + +prompt = ChatPromptTemplate.from_template(template) + +model = OllamaLLM(model="llama3") + +chain = prompt | model + +result = chain.invoke({"question": "What is LangChain?"}) + +print(result) diff --git a/test/1_langchain_tools_test.py b/test/1_langchain_tools_test.py new file mode 100644 index 0000000..dcc35b7 --- /dev/null +++ b/test/1_langchain_tools_test.py @@ -0,0 +1,73 @@ +from langchain_core.prompts import ChatPromptTemplate +from langchain_ollama.llms import OllamaLLM +from langchain_core.tools import Tool +from langchain.agents import AgentExecutor, AgentType, Tool, initialize_agent + + +### Prompt Definitions + +system_prompt = """ + You are an assistant that performs basic arithmetic operations. + + The following tools are available to you: + 1. Add - Gets the sum of 2 numbers. Input Format: (a,b) + 2. Subtract - Gets the subtraction result of 2 numbers. Input Format: (a,b) + + Use these tools if necessary to answer questions. +""" + +prompt_template = f""" + {system_prompt} + + Question: {{question}} + + Answer: Let's think step by step. +""" + +prompt = ChatPromptTemplate.from_template(prompt_template) + + +### Tool Definitions + +def add(*args) -> float: + args_tuple = string_to_tuple(args[0]) + assert len(args_tuple) == 2 + return args_tuple[0] + args_tuple[1] + +def subtract(*args) -> float: + args_tuple = string_to_tuple(args[0]) + assert len(args_tuple) == 2 + return args_tuple[0] - args_tuple[1] + +# When tool is invoked, we get whatever the LLM wanted to send it as a string. +def string_to_tuple(s): + s = s.strip("()") + vals = s.split(', ') + vals = tuple(map(float, vals)) + return vals + +tools = [ + Tool(name="Add", func=add, description="Performs addition of exactly two numbers."), + Tool(name="Subtract", func=subtract, description="Performs subtraction of exactly two numbers."), +] + + +### Langchain Definition + +llm = OllamaLLM(model="llama3") + +agent_executor = initialize_agent( + tools=tools, + llm=llm, + agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION, + verbose=True +) + +response = agent_executor.invoke({"input": "What is 3 + 5?"}) +print(response) + +response = agent_executor.invoke({"input": "What is 10 - 4?"}) +print(response) + +response = agent_executor.invoke({"input": "What is 5 + 4? Also, what is 99 - 33?"}) +print(response) diff --git a/tools/caption_image.py b/tools/caption_image.py new file mode 100644 index 0000000..149ff29 --- /dev/null +++ b/tools/caption_image.py @@ -0,0 +1,32 @@ +import json +import requests + +CAPTION_IMAGE_URL = "https://api.imgflip.com/caption_image" + +def load_config(): + with open('config.json') as config_file: + return json.load(config_file) + +def caption_image(template_id, text0, text1): + config = load_config() + username = config['username'] + password = config['password'] + + url = "https://api.imgflip.com/caption_image" + payload = { + "template_id": template_id, + "username": username, + "password": password, + "text0": text0, + "text1": text1 + } + + response = requests.post(url, data=payload) + result = response.json() + + if result['success']: + meme_url = result['data']['url'] + print(f"Meme created! URL: {meme_url}") + return meme_url + else: + return None diff --git a/tools/download_image.py b/tools/download_image.py new file mode 100644 index 0000000..8c50fe4 --- /dev/null +++ b/tools/download_image.py @@ -0,0 +1,14 @@ +import os +import requests + +def download_image(url): + response = requests.get(url) + if response.status_code == 200: + output_dir = 'output' + os.makedirs(output_dir, exist_ok=True) + image_path = os.path.join(output_dir, url.split("/")[-1]) + with open(image_path, 'wb') as f: + f.write(response.content) + print(f"Image saved to {image_path}") + else: + print("Failed to download image.") diff --git a/tools/get_memes.py b/tools/get_memes.py new file mode 100644 index 0000000..2019d2e --- /dev/null +++ b/tools/get_memes.py @@ -0,0 +1,18 @@ +import requests + +GET_MEMES_URL = "https://api.imgflip.com/get_memes" + +def get_memes(ignored): + return get_memes_helper() + +def get_memes_helper(): + response = requests.get(GET_MEMES_URL) + result = response.json() + + if result['success']: + memes = result['data']['memes'] + memes_str = "\n".join([f"ID: {meme['id']}, Name: {meme['name']}" for meme in memes]) + return memes_str + else: + return "No template_ids found." + diff --git a/tools/test_tools.py b/tools/test_tools.py new file mode 100644 index 0000000..e3c6a22 --- /dev/null +++ b/tools/test_tools.py @@ -0,0 +1,8 @@ +from caption_image import caption_image +from download_image import download_image +from get_memes import get_memes + +get_memes() +meme_url = caption_image(20007896, "Top text", "Bottom text") +download_image(meme_url) +