First pass on langchain. Need to find a better POC example

This commit is contained in:
whitney 2025-02-10 17:52:11 -08:00
parent ebe43463b6
commit 8ad6808ee5
10 changed files with 226 additions and 56 deletions

3
.gitignore vendored
View File

@ -1,3 +1,4 @@
env/
output/
config.json
tools/config.json
tools/__pycache__/

View File

@ -5,12 +5,11 @@ So the idea here is to make an llm agent with tools that use https://imgflip.com
## Setup
### Python
```bash
python -m venv env
source env/bin/activate
```
```bash
pip install -r requirements.txt
```
@ -26,3 +25,16 @@ Example:
```bash
python main.py
```
### Ollama
Note: this project uses the `langchain-ollama` Python package (see `requirements.txt`).
Pull the model:
```
ollama pull llama3
```
If the pull fails, check the status of the Ollama service:
```
sudo service ollama status
```

96
main.py
View File

@ -1,59 +1,51 @@
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama.llms import OllamaLLM
from langchain_core.tools import Tool
from langchain.agents import AgentExecutor, AgentType, Tool, initialize_agent
import requests
import os
import json
def load_config():
    """Load the imgflip credentials from ``config.json``.

    The path is relative, so the file is looked up in the current working
    directory.

    Returns:
        The parsed JSON document. Callers in this file read its
        ``username`` and ``password`` keys.

    Raises:
        FileNotFoundError: If ``config.json`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open('config.json', encoding='utf-8') as config_file:
        return json.load(config_file)
def get_memes(*, timeout=10):
    """Print the id and name of every meme template returned by imgflip.

    Args:
        timeout: Seconds to wait for the server before giving up.

    Returns:
        None. The listing (or a failure message) is written to stdout.
    """
    url = "https://api.imgflip.com/get_memes"
    # Without a timeout, requests waits indefinitely on a stalled connection.
    response = requests.get(url, timeout=timeout)
    result = response.json()
    if not result['success']:
        print("Failed to retrieve memes.")
        return
    for meme in result['data']['memes']:
        print(f"ID: {meme['id']}, Name: {meme['name']}")
from tools.get_memes import get_memes
from tools.caption_image import caption_image
from tools.download_image import download_image
def create_meme(template_id, text0, text1, *, timeout=10):
    """Caption a meme template on imgflip and download the result.

    Args:
        template_id: imgflip template id to caption.
        text0: Value sent as the ``text0`` field of the request.
        text1: Value sent as the ``text1`` field of the request.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        None. On success the meme URL is printed and handed to
        ``download_image``; on failure the API error message is printed.
    """
    config = load_config()
    payload = {
        "template_id": template_id,
        "username": config['username'],
        "password": config['password'],
        "text0": text0,
        "text1": text1,
    }
    # Without a timeout, requests waits indefinitely on a stalled connection.
    response = requests.post(
        "https://api.imgflip.com/caption_image",
        data=payload,
        timeout=timeout,
    )
    result = response.json()
    if result['success']:
        meme_url = result['data']['url']
        print(f"Meme created! URL: {meme_url}")
        download_image(meme_url)
    else:
        print(f"Error: {result['error_message']}")
system_prompt = """
You are an assistant that looks up the numerical template_id of a meme from imgflip.
The following tools are available to you:
def download_image(url, *, timeout=10):
    """Download ``url`` into the local ``output`` directory.

    Args:
        url: Address of the image to fetch. A falsy value (``None`` or an
            empty string) is reported as a failed download instead of
            being passed to ``requests``.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The path of the saved file, or ``None`` when nothing was saved.
    """
    # Function-scope imports keep this block self-contained.
    import posixpath
    from urllib.parse import urlsplit

    if not url:
        print("Failed to download image.")
        return None

    # Without a timeout, requests waits indefinitely on a stalled connection.
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        print("Failed to download image.")
        return None

    output_dir = 'output'
    os.makedirs(output_dir, exist_ok=True)
    # Name the file after the URL path only, so a query string or trailing
    # slash cannot produce an empty or unusable file name.
    filename = posixpath.basename(urlsplit(url).path.rstrip("/"))
    if filename in ("", ".", ".."):
        filename = "image"
    image_path = os.path.join(output_dir, filename)
    with open(image_path, 'wb') as f:
        f.write(response.content)
    print(f"Image saved to {image_path}")
    return image_path
1. get_memes - Does not take any agruments. Returns a list of template_ids (integer) and names (string) which are the titles of the memes that correspond to the template_id.
2. caption_image - Given a valid template_id, top text, and bottom text, generates an image with the desired text. Returns the url of the new meme as a string.
3. download_image - Given a valid url returned from the caption_image tool, downloads the image we made locally.
Use these tools if necessary to answer questions.
"""
# Wrap the system prompt in a question/answer scaffold. ``{question}`` is
# left as a literal placeholder for the prompt template to fill in later.
prompt_template = "\n".join([
    "",
    system_prompt,
    "Question: {question}",
    "Answer: Let's think step by step.",
    "",
])
# NOTE(review): ``prompt`` is not referenced again in the visible code (it is
# not passed to the agent) -- confirm whether it is still needed.
prompt = ChatPromptTemplate.from_template(prompt_template)
def _caption_image_from_tool_input(tool_input):
    """Adapt a single-string tool input to ``caption_image``'s three arguments.

    A ``Tool`` hands its ``func`` one string, while ``caption_image`` needs a
    template id, top text and bottom text, so the string is split here.

    Returns:
        Whatever ``caption_image`` returns, or an explanatory message when
        the input does not contain three comma-separated fields.
    """
    raw = tool_input.strip().strip("()")
    # Split at most twice so the bottom text may itself contain commas.
    fields = [field.strip().strip("'\"") for field in raw.split(",", 2)]
    if len(fields) != 3:
        return "Invalid input. Expected: template_id, top text, bottom text"
    return caption_image(*fields)


# Tools offered to the agent; the descriptions are what the LLM reads when
# deciding which tool to call and how to format its input.
tools = [
    Tool(
        name="Get Memes",
        func=get_memes,
        description="Does not take any arguments. Returns a list of template_ids (integer) and names (string) which are the titles of the memes that correspond to the template_id.",
    ),
    Tool(
        name="Caption Image",
        func=_caption_image_from_tool_input,
        description="Given a valid template_id, top text, and bottom text, generates an image with the desired text. Input Format: template_id, top text, bottom text. Returns the url of the new meme as a string.",
    ),
    Tool(
        name="Download Image",
        func=download_image,
        description="Given a valid url returned from the caption_image tool, downloads the image we made locally.",
    ),
]
# Local Ollama model that drives the agent.
llm = OllamaLLM(model="llama3")
# initialize_agent selects the agent kind through its ``agent`` parameter;
# ``agent_type`` is not one of its parameters, so the requested type was
# never actually applied.
agent_executor = initialize_agent(
    tools=tools,
    llm=llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
)
response = agent_executor.invoke({"input": "Generate an image for the 'stick poke' meme with the top text 'come on' and the bottom text 'do something'."})
print(response)
create_meme(20007896, "Top text", "Bottom text")
# get_memes()

View File

@ -1 +1,5 @@
requests
langchain
langchain-community
langchain-core
langchain-ollama

View File

@ -0,0 +1,16 @@
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama.llms import OllamaLLM
# Minimal LangChain + Ollama check: build a one-question prompt, pipe it
# into the local llama3 model and print whatever comes back.
template = "Question: {question}\nAnswer: Let's think step by step."
prompt = ChatPromptTemplate.from_template(template)
model = OllamaLLM(model="llama3")
# ``|`` composes the prompt and the model into a single runnable chain.
chain = prompt | model
question = {"question": "What is LangChain?"}
result = chain.invoke(question)
print(result)

View File

@ -0,0 +1,73 @@
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama.llms import OllamaLLM
from langchain_core.tools import Tool
from langchain.agents import AgentExecutor, AgentType, Tool, initialize_agent
### Prompt Definitions
# Instructions describing the two arithmetic tools and the "(a,b)" input
# format they expect.
system_prompt = (
    "\n"
    "You are an assistant that performs basic arithmetic operations.\n"
    "The following tools are available to you:\n"
    "1. Add - Gets the sum of 2 numbers. Input Format: (a,b)\n"
    "2. Subtract - Gets the subtraction result of 2 numbers. Input Format: (a,b)\n"
    "Use these tools if necessary to answer questions.\n"
)
# Wrap the system prompt in a question/answer scaffold; ``{question}`` stays a
# literal placeholder for the prompt template.
prompt_template = "\n".join([
    "",
    system_prompt,
    "Question: {question}",
    "Answer: Let's think step by step.",
    "",
])
# NOTE(review): ``prompt`` is not referenced again in the visible code (it is
# not passed to the agent) -- confirm whether it is still needed.
prompt = ChatPromptTemplate.from_template(prompt_template)
### Tool Definitions
def add(*args) -> float:
    """Return the sum of the two numbers encoded in the tool input.

    Only ``args[0]`` is used; it is parsed with ``string_to_tuple`` and must
    yield exactly two values.
    """
    operands = string_to_tuple(args[0])
    assert len(operands) == 2
    left, right = operands
    return left + right
def subtract(*args) -> float:
    """Return the first number minus the second, both taken from the tool input.

    Only ``args[0]`` is used; it is parsed with ``string_to_tuple`` and must
    yield exactly two values.
    """
    operands = string_to_tuple(args[0])
    assert len(operands) == 2
    minuend, subtrahend = operands
    return minuend - subtrahend
# When tool is invoked, we get whatever the LLM wanted to send it as a string.
def string_to_tuple(s):
    """Parse a string such as ``"(3,5)"`` or ``"(3, 5)"`` into a tuple of floats.

    Args:
        s: Comma-separated numbers, optionally wrapped in parentheses and
            surrounded by whitespace.

    Returns:
        A tuple with one float per comma-separated field.

    Raises:
        ValueError: If any field is not a valid number.
    """
    # Drop surrounding whitespace first so the parentheses are really at the
    # ends of the string when they are stripped.
    inner = s.strip().strip("()")
    # Split on the bare comma: the documented format "(a,b)" has no space,
    # and float() already ignores whitespace around each field.
    return tuple(float(field) for field in inner.split(","))
# Tools offered to the agent. The descriptions are what the agent is given
# about each tool, so they state the "(a,b)" input format the parsers expect.
tools = [
    Tool(
        name="Add",
        func=add,
        description="Performs addition of exactly two numbers. Input Format: (a,b)",
    ),
    Tool(
        name="Subtract",
        func=subtract,
        description="Performs subtraction of exactly two numbers. Input Format: (a,b)",
    ),
]
### Langchain Definition
llm = OllamaLLM(model="llama3")
# initialize_agent selects the agent kind through its ``agent`` parameter;
# ``agent_type`` is not one of its parameters, so the requested type was
# never actually applied.
agent_executor = initialize_agent(
    tools=tools,
    llm=llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
)
# Run the sample questions one after another and print each raw response.
for question in (
    "What is 3 + 5?",
    "What is 10 - 4?",
    "What is 5 + 4? Also, what is 99 - 33?",
):
    response = agent_executor.invoke({"input": question})
    print(response)

32
tools/caption_image.py Normal file
View File

@ -0,0 +1,32 @@
import json
import requests
CAPTION_IMAGE_URL = "https://api.imgflip.com/caption_image"
def load_config():
    """Load the imgflip credentials from ``config.json``.

    The path is relative, so the file is looked up in the current working
    directory.

    Returns:
        The parsed JSON document. ``caption_image`` reads its ``username``
        and ``password`` keys.

    Raises:
        FileNotFoundError: If ``config.json`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open('config.json', encoding='utf-8') as config_file:
        return json.load(config_file)
def caption_image(template_id, text0, text1, *, timeout=10):
    """Caption a meme template on imgflip and return the new image's URL.

    Args:
        template_id: imgflip template id to caption.
        text0: Value sent as the ``text0`` field of the request.
        text1: Value sent as the ``text1`` field of the request.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The URL of the generated meme, or ``None`` when the API reports
        failure (the reason is printed).
    """
    config = load_config()
    payload = {
        "template_id": template_id,
        "username": config['username'],
        "password": config['password'],
        "text0": text0,
        "text1": text1,
    }
    # Use the module-level endpoint constant instead of repeating the URL,
    # and bound the wait so a stalled connection cannot hang the caller.
    response = requests.post(CAPTION_IMAGE_URL, data=payload, timeout=timeout)
    result = response.json()
    if not result['success']:
        # Surface the API's reason instead of failing silently.
        print(f"Error: {result.get('error_message', 'unknown error')}")
        return None
    meme_url = result['data']['url']
    print(f"Meme created! URL: {meme_url}")
    return meme_url

14
tools/download_image.py Normal file
View File

@ -0,0 +1,14 @@
import os
import requests
def download_image(url, *, timeout=10):
    """Download ``url`` into the local ``output`` directory.

    Args:
        url: Address of the image to fetch. A falsy value (``None`` or an
            empty string) is reported as a failed download instead of
            being passed to ``requests``.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The path of the saved file, or ``None`` when nothing was saved.
    """
    # Function-scope imports keep this block self-contained.
    import posixpath
    from urllib.parse import urlsplit

    if not url:
        print("Failed to download image.")
        return None

    # Without a timeout, requests waits indefinitely on a stalled connection.
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        print("Failed to download image.")
        return None

    output_dir = 'output'
    os.makedirs(output_dir, exist_ok=True)
    # Name the file after the URL path only, so a query string or trailing
    # slash cannot produce an empty or unusable file name.
    filename = posixpath.basename(urlsplit(url).path.rstrip("/"))
    if filename in ("", ".", ".."):
        filename = "image"
    image_path = os.path.join(output_dir, filename)
    with open(image_path, 'wb') as f:
        f.write(response.content)
    print(f"Image saved to {image_path}")
    return image_path

18
tools/get_memes.py Normal file
View File

@ -0,0 +1,18 @@
import requests
GET_MEMES_URL = "https://api.imgflip.com/get_memes"
def get_memes(ignored=None):
    """Return the meme template listing produced by ``get_memes_helper``.

    Args:
        ignored: Unused. It exists so the function can be called with a
            single input value; it defaults to ``None`` so the function can
            also be called with no arguments.

    Returns:
        Whatever ``get_memes_helper`` returns.
    """
    return get_memes_helper()
def get_memes_helper(*, timeout=10):
    """Fetch the imgflip template list and format it as text.

    Args:
        timeout: Seconds to wait for the server before giving up.

    Returns:
        One ``ID: <id>, Name: <name>`` line per template joined with
        newlines, or ``"No template_ids found."`` when the API reports
        failure.
    """
    # Without a timeout, requests waits indefinitely on a stalled connection.
    response = requests.get(GET_MEMES_URL, timeout=timeout)
    result = response.json()
    if not result['success']:
        return "No template_ids found."
    return "\n".join(
        f"ID: {meme['id']}, Name: {meme['name']}"
        for meme in result['data']['memes']
    )

8
tools/test_tools.py Normal file
View File

@ -0,0 +1,8 @@
from caption_image import caption_image
from download_image import download_image
from get_memes import get_memes
# Manual smoke test of the three tools against the live imgflip API.
# get_memes is declared with one (ignored) parameter, so pass a placeholder,
# and print the listing it returns so the run shows something.
print(get_memes(None))
meme_url = caption_image(20007896, "Top text", "Bottom text")
# caption_image returns None when the API reports failure; only download a
# real URL.
if meme_url:
    download_image(meme_url)
else:
    print("Caption request failed; nothing to download.")