mirror of
https://github.com/runyanjake/memechain.git
synced 2025-10-04 15:57:28 -07:00
First pass on langchain. Need to find a better POC example
This commit is contained in:
parent
ebe43463b6
commit
8ad6808ee5
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,3 +1,4 @@
|
||||
env/
|
||||
output/
|
||||
config.json
|
||||
tools/config.json
|
||||
tools/__pycache__/
|
||||
|
18
README.md
18
README.md
@ -5,12 +5,11 @@ So the idea here is to make an llm agent with tools that use https://imgflip.com
|
||||
|
||||
## Setup
|
||||
|
||||
### Python
|
||||
|
||||
```bash
|
||||
python -m venv env
|
||||
source env/bin/activate
|
||||
```
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
@ -26,3 +25,16 @@ Example:
|
||||
```bash
|
||||
python main.py
|
||||
```
|
||||
|
||||
### Ollama
|
||||
Note we are using the `langchain-ollama` python package (see requirements).
|
||||
Pull model
|
||||
```
|
||||
ollama pull llama3
|
||||
```
|
||||
If failing, check status of ollama process:
|
||||
```
|
||||
sudo service ollama status
|
||||
```
|
||||
|
||||
|
||||
|
96
main.py
96
main.py
@ -1,59 +1,51 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_ollama.llms import OllamaLLM
|
||||
from langchain_core.tools import Tool
|
||||
from langchain.agents import AgentExecutor, AgentType, Tool, initialize_agent
|
||||
|
||||
import requests
|
||||
import os
|
||||
import json
|
||||
|
||||
def load_config():
|
||||
with open('config.json') as config_file:
|
||||
return json.load(config_file)
|
||||
|
||||
def get_memes():
|
||||
url = "https://api.imgflip.com/get_memes"
|
||||
response = requests.get(url)
|
||||
result = response.json()
|
||||
|
||||
if result['success']:
|
||||
memes = result['data']['memes']
|
||||
for meme in memes:
|
||||
print(f"ID: {meme['id']}, Name: {meme['name']}")
|
||||
else:
|
||||
print("Failed to retrieve memes.")
|
||||
from tools.get_memes import get_memes
|
||||
from tools.caption_image import caption_image
|
||||
from tools.download_image import download_image
|
||||
|
||||
def create_meme(template_id, text0, text1):
|
||||
config = load_config()
|
||||
username = config['username']
|
||||
password = config['password']
|
||||
|
||||
url = "https://api.imgflip.com/caption_image"
|
||||
payload = {
|
||||
"template_id": template_id,
|
||||
"username": username,
|
||||
"password": password,
|
||||
"text0": text0,
|
||||
"text1": text1
|
||||
}
|
||||
|
||||
response = requests.post(url, data=payload)
|
||||
result = response.json()
|
||||
|
||||
if result['success']:
|
||||
meme_url = result['data']['url']
|
||||
print(f"Meme created! URL: {meme_url}")
|
||||
download_image(meme_url)
|
||||
else:
|
||||
print(f"Error: {result['error_message']}")
|
||||
system_prompt = """
|
||||
You are an assistant that looks up the numerical template_id of a meme from imgflip.
|
||||
The following tools are available to you:
|
||||
|
||||
def download_image(url):
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
output_dir = 'output'
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
image_path = os.path.join(output_dir, url.split("/")[-1])
|
||||
with open(image_path, 'wb') as f:
|
||||
f.write(response.content)
|
||||
print(f"Image saved to {image_path}")
|
||||
else:
|
||||
print("Failed to download image.")
|
||||
1. get_memes - Does not take any agruments. Returns a list of template_ids (integer) and names (string) which are the titles of the memes that correspond to the template_id.
|
||||
2. caption_image - Given a valid template_id, top text, and bottom text, generates an image with the desired text. Returns the url of the new meme as a string.
|
||||
3. download_image - Given a valid url returned from the caption_image tool, downloads the image we made locally.
|
||||
|
||||
Use these tools if necessary to answer questions.
|
||||
"""
|
||||
|
||||
prompt_template = f"""
|
||||
{system_prompt}
|
||||
|
||||
Question: {{question}}
|
||||
|
||||
Answer: Let's think step by step.
|
||||
"""
|
||||
|
||||
prompt = ChatPromptTemplate.from_template(prompt_template)
|
||||
|
||||
tools = [
|
||||
Tool(name="Get Memes", func=get_memes, description="Does not take any agruments. Returns a list of template_ids (integer) and names (string) which are the titles of the memes that correspond to the template_id."),
|
||||
Tool(name="Caption Image", func=caption_image, description="Given a valid template_id, top text, and bottom text, generates an image with the desired text. Returns the url of the new meme as a string."),
|
||||
Tool(name="Download Image", func=download_image, description="Given a valid url returned from the caption_image tool, downloads the image we made locally.")
|
||||
]
|
||||
|
||||
llm = OllamaLLM(model="llama3")
|
||||
|
||||
agent_executor = initialize_agent(
|
||||
tools=tools,
|
||||
llm=llm,
|
||||
agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
response = agent_executor.invoke({"input": "Generate an image for the 'stick poke' meme with the top text 'come on' and the bottom text 'do something'."})
|
||||
print(response)
|
||||
|
||||
create_meme(20007896, "Top text", "Bottom text")
|
||||
# get_memes()
|
@ -1 +1,5 @@
|
||||
requests
|
||||
langchain
|
||||
langchain-community
|
||||
langchain-core
|
||||
langchain-ollama
|
||||
|
16
test/0_langchain_ollama_test.py
Normal file
16
test/0_langchain_ollama_test.py
Normal file
@ -0,0 +1,16 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_ollama.llms import OllamaLLM
|
||||
|
||||
template = """Question: {question}
|
||||
|
||||
Answer: Let's think step by step."""
|
||||
|
||||
prompt = ChatPromptTemplate.from_template(template)
|
||||
|
||||
model = OllamaLLM(model="llama3")
|
||||
|
||||
chain = prompt | model
|
||||
|
||||
result = chain.invoke({"question": "What is LangChain?"})
|
||||
|
||||
print(result)
|
73
test/1_langchain_tools_test.py
Normal file
73
test/1_langchain_tools_test.py
Normal file
@ -0,0 +1,73 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_ollama.llms import OllamaLLM
|
||||
from langchain_core.tools import Tool
|
||||
from langchain.agents import AgentExecutor, AgentType, Tool, initialize_agent
|
||||
|
||||
|
||||
### Prompt Definitions
|
||||
|
||||
system_prompt = """
|
||||
You are an assistant that performs basic arithmetic operations.
|
||||
|
||||
The following tools are available to you:
|
||||
1. Add - Gets the sum of 2 numbers. Input Format: (a,b)
|
||||
2. Subtract - Gets the subtraction result of 2 numbers. Input Format: (a,b)
|
||||
|
||||
Use these tools if necessary to answer questions.
|
||||
"""
|
||||
|
||||
prompt_template = f"""
|
||||
{system_prompt}
|
||||
|
||||
Question: {{question}}
|
||||
|
||||
Answer: Let's think step by step.
|
||||
"""
|
||||
|
||||
prompt = ChatPromptTemplate.from_template(prompt_template)
|
||||
|
||||
|
||||
### Tool Definitions
|
||||
|
||||
def add(*args) -> float:
|
||||
args_tuple = string_to_tuple(args[0])
|
||||
assert len(args_tuple) == 2
|
||||
return args_tuple[0] + args_tuple[1]
|
||||
|
||||
def subtract(*args) -> float:
|
||||
args_tuple = string_to_tuple(args[0])
|
||||
assert len(args_tuple) == 2
|
||||
return args_tuple[0] - args_tuple[1]
|
||||
|
||||
# When tool is invoked, we get whatever the LLM wanted to send it as a string.
|
||||
def string_to_tuple(s):
|
||||
s = s.strip("()")
|
||||
vals = s.split(', ')
|
||||
vals = tuple(map(float, vals))
|
||||
return vals
|
||||
|
||||
tools = [
|
||||
Tool(name="Add", func=add, description="Performs addition of exactly two numbers."),
|
||||
Tool(name="Subtract", func=subtract, description="Performs subtraction of exactly two numbers."),
|
||||
]
|
||||
|
||||
|
||||
### Langchain Definition
|
||||
|
||||
llm = OllamaLLM(model="llama3")
|
||||
|
||||
agent_executor = initialize_agent(
|
||||
tools=tools,
|
||||
llm=llm,
|
||||
agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
response = agent_executor.invoke({"input": "What is 3 + 5?"})
|
||||
print(response)
|
||||
|
||||
response = agent_executor.invoke({"input": "What is 10 - 4?"})
|
||||
print(response)
|
||||
|
||||
response = agent_executor.invoke({"input": "What is 5 + 4? Also, what is 99 - 33?"})
|
||||
print(response)
|
32
tools/caption_image.py
Normal file
32
tools/caption_image.py
Normal file
@ -0,0 +1,32 @@
|
||||
import json
|
||||
import requests
|
||||
|
||||
CAPTION_IMAGE_URL = "https://api.imgflip.com/caption_image"
|
||||
|
||||
def load_config():
|
||||
with open('config.json') as config_file:
|
||||
return json.load(config_file)
|
||||
|
||||
def caption_image(template_id, text0, text1):
|
||||
config = load_config()
|
||||
username = config['username']
|
||||
password = config['password']
|
||||
|
||||
url = "https://api.imgflip.com/caption_image"
|
||||
payload = {
|
||||
"template_id": template_id,
|
||||
"username": username,
|
||||
"password": password,
|
||||
"text0": text0,
|
||||
"text1": text1
|
||||
}
|
||||
|
||||
response = requests.post(url, data=payload)
|
||||
result = response.json()
|
||||
|
||||
if result['success']:
|
||||
meme_url = result['data']['url']
|
||||
print(f"Meme created! URL: {meme_url}")
|
||||
return meme_url
|
||||
else:
|
||||
return None
|
14
tools/download_image.py
Normal file
14
tools/download_image.py
Normal file
@ -0,0 +1,14 @@
|
||||
import os
|
||||
import requests
|
||||
|
||||
def download_image(url):
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
output_dir = 'output'
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
image_path = os.path.join(output_dir, url.split("/")[-1])
|
||||
with open(image_path, 'wb') as f:
|
||||
f.write(response.content)
|
||||
print(f"Image saved to {image_path}")
|
||||
else:
|
||||
print("Failed to download image.")
|
18
tools/get_memes.py
Normal file
18
tools/get_memes.py
Normal file
@ -0,0 +1,18 @@
|
||||
import requests
|
||||
|
||||
GET_MEMES_URL = "https://api.imgflip.com/get_memes"
|
||||
|
||||
def get_memes(ignored):
|
||||
return get_memes_helper()
|
||||
|
||||
def get_memes_helper():
|
||||
response = requests.get(GET_MEMES_URL)
|
||||
result = response.json()
|
||||
|
||||
if result['success']:
|
||||
memes = result['data']['memes']
|
||||
memes_str = "\n".join([f"ID: {meme['id']}, Name: {meme['name']}" for meme in memes])
|
||||
return memes_str
|
||||
else:
|
||||
return "No template_ids found."
|
||||
|
8
tools/test_tools.py
Normal file
8
tools/test_tools.py
Normal file
@ -0,0 +1,8 @@
|
||||
from caption_image import caption_image
|
||||
from download_image import download_image
|
||||
from get_memes import get_memes
|
||||
|
||||
get_memes()
|
||||
meme_url = caption_image(20007896, "Top text", "Bottom text")
|
||||
download_image(meme_url)
|
||||
|
Loading…
x
Reference in New Issue
Block a user