mirror of
https://github.com/runyanjake/memechain.git
synced 2025-10-05 08:07:29 -07:00
First pass on langchain. Need to find a better POC example
This commit is contained in:
parent
ebe43463b6
commit
8ad6808ee5
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,3 +1,4 @@
|
|||||||
env/
|
env/
|
||||||
output/
|
output/
|
||||||
config.json
|
tools/config.json
|
||||||
|
tools/__pycache__/
|
||||||
|
18
README.md
18
README.md
@ -5,12 +5,11 @@ So the idea here is to make an llm agent with tools that use https://imgflip.com
|
|||||||
|
|
||||||
## Setup
|
## Setup
|
||||||
|
|
||||||
|
### Python
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m venv env
|
python -m venv env
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
```
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -26,3 +25,16 @@ Example:
|
|||||||
```bash
|
```bash
|
||||||
python main.py
|
python main.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Ollama
|
||||||
|
Note we are using the `langchain-ollama` python package (see requirements).
|
||||||
|
Pull model
|
||||||
|
```
|
||||||
|
ollama pull llama3
|
||||||
|
```
|
||||||
|
If failing, check status of ollama process:
|
||||||
|
```
|
||||||
|
sudo service ollama status
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
86
main.py
86
main.py
@ -1,59 +1,51 @@
|
|||||||
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
|
from langchain_ollama.llms import OllamaLLM
|
||||||
|
from langchain_core.tools import Tool
|
||||||
|
from langchain.agents import AgentExecutor, AgentType, Tool, initialize_agent
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import os
|
import os
|
||||||
import json
|
|
||||||
|
|
||||||
def load_config():
|
from tools.get_memes import get_memes
|
||||||
with open('config.json') as config_file:
|
from tools.caption_image import caption_image
|
||||||
return json.load(config_file)
|
from tools.download_image import download_image
|
||||||
|
|
||||||
def get_memes():
|
system_prompt = """
|
||||||
url = "https://api.imgflip.com/get_memes"
|
You are an assistant that looks up the numerical template_id of a meme from imgflip.
|
||||||
response = requests.get(url)
|
The following tools are available to you:
|
||||||
result = response.json()
|
|
||||||
|
|
||||||
if result['success']:
|
1. get_memes - Does not take any agruments. Returns a list of template_ids (integer) and names (string) which are the titles of the memes that correspond to the template_id.
|
||||||
memes = result['data']['memes']
|
2. caption_image - Given a valid template_id, top text, and bottom text, generates an image with the desired text. Returns the url of the new meme as a string.
|
||||||
for meme in memes:
|
3. download_image - Given a valid url returned from the caption_image tool, downloads the image we made locally.
|
||||||
print(f"ID: {meme['id']}, Name: {meme['name']}")
|
|
||||||
else:
|
|
||||||
print("Failed to retrieve memes.")
|
|
||||||
|
|
||||||
def create_meme(template_id, text0, text1):
|
Use these tools if necessary to answer questions.
|
||||||
config = load_config()
|
"""
|
||||||
username = config['username']
|
|
||||||
password = config['password']
|
|
||||||
|
|
||||||
url = "https://api.imgflip.com/caption_image"
|
prompt_template = f"""
|
||||||
payload = {
|
{system_prompt}
|
||||||
"template_id": template_id,
|
|
||||||
"username": username,
|
|
||||||
"password": password,
|
|
||||||
"text0": text0,
|
|
||||||
"text1": text1
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(url, data=payload)
|
Question: {{question}}
|
||||||
result = response.json()
|
|
||||||
|
|
||||||
if result['success']:
|
Answer: Let's think step by step.
|
||||||
meme_url = result['data']['url']
|
"""
|
||||||
print(f"Meme created! URL: {meme_url}")
|
|
||||||
download_image(meme_url)
|
|
||||||
else:
|
|
||||||
print(f"Error: {result['error_message']}")
|
|
||||||
|
|
||||||
def download_image(url):
|
prompt = ChatPromptTemplate.from_template(prompt_template)
|
||||||
response = requests.get(url)
|
|
||||||
if response.status_code == 200:
|
|
||||||
output_dir = 'output'
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
image_path = os.path.join(output_dir, url.split("/")[-1])
|
|
||||||
with open(image_path, 'wb') as f:
|
|
||||||
f.write(response.content)
|
|
||||||
print(f"Image saved to {image_path}")
|
|
||||||
else:
|
|
||||||
print("Failed to download image.")
|
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
Tool(name="Get Memes", func=get_memes, description="Does not take any agruments. Returns a list of template_ids (integer) and names (string) which are the titles of the memes that correspond to the template_id."),
|
||||||
|
Tool(name="Caption Image", func=caption_image, description="Given a valid template_id, top text, and bottom text, generates an image with the desired text. Returns the url of the new meme as a string."),
|
||||||
|
Tool(name="Download Image", func=download_image, description="Given a valid url returned from the caption_image tool, downloads the image we made locally.")
|
||||||
|
]
|
||||||
|
|
||||||
|
llm = OllamaLLM(model="llama3")
|
||||||
|
|
||||||
|
agent_executor = initialize_agent(
|
||||||
|
tools=tools,
|
||||||
|
llm=llm,
|
||||||
|
agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
||||||
|
verbose=True
|
||||||
|
)
|
||||||
|
|
||||||
|
response = agent_executor.invoke({"input": "Generate an image for the 'stick poke' meme with the top text 'come on' and the bottom text 'do something'."})
|
||||||
|
print(response)
|
||||||
|
|
||||||
create_meme(20007896, "Top text", "Bottom text")
|
|
||||||
# get_memes()
|
|
@ -1 +1,5 @@
|
|||||||
requests
|
requests
|
||||||
|
langchain
|
||||||
|
langchain-community
|
||||||
|
langchain-core
|
||||||
|
langchain-ollama
|
||||||
|
16
test/0_langchain_ollama_test.py
Normal file
16
test/0_langchain_ollama_test.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
|
from langchain_ollama.llms import OllamaLLM
|
||||||
|
|
||||||
|
template = """Question: {question}
|
||||||
|
|
||||||
|
Answer: Let's think step by step."""
|
||||||
|
|
||||||
|
prompt = ChatPromptTemplate.from_template(template)
|
||||||
|
|
||||||
|
model = OllamaLLM(model="llama3")
|
||||||
|
|
||||||
|
chain = prompt | model
|
||||||
|
|
||||||
|
result = chain.invoke({"question": "What is LangChain?"})
|
||||||
|
|
||||||
|
print(result)
|
73
test/1_langchain_tools_test.py
Normal file
73
test/1_langchain_tools_test.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
|
from langchain_ollama.llms import OllamaLLM
|
||||||
|
from langchain_core.tools import Tool
|
||||||
|
from langchain.agents import AgentExecutor, AgentType, Tool, initialize_agent
|
||||||
|
|
||||||
|
|
||||||
|
### Prompt Definitions
|
||||||
|
|
||||||
|
system_prompt = """
|
||||||
|
You are an assistant that performs basic arithmetic operations.
|
||||||
|
|
||||||
|
The following tools are available to you:
|
||||||
|
1. Add - Gets the sum of 2 numbers. Input Format: (a,b)
|
||||||
|
2. Subtract - Gets the subtraction result of 2 numbers. Input Format: (a,b)
|
||||||
|
|
||||||
|
Use these tools if necessary to answer questions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
prompt_template = f"""
|
||||||
|
{system_prompt}
|
||||||
|
|
||||||
|
Question: {{question}}
|
||||||
|
|
||||||
|
Answer: Let's think step by step.
|
||||||
|
"""
|
||||||
|
|
||||||
|
prompt = ChatPromptTemplate.from_template(prompt_template)
|
||||||
|
|
||||||
|
|
||||||
|
### Tool Definitions
|
||||||
|
|
||||||
|
def add(*args) -> float:
|
||||||
|
args_tuple = string_to_tuple(args[0])
|
||||||
|
assert len(args_tuple) == 2
|
||||||
|
return args_tuple[0] + args_tuple[1]
|
||||||
|
|
||||||
|
def subtract(*args) -> float:
|
||||||
|
args_tuple = string_to_tuple(args[0])
|
||||||
|
assert len(args_tuple) == 2
|
||||||
|
return args_tuple[0] - args_tuple[1]
|
||||||
|
|
||||||
|
# When tool is invoked, we get whatever the LLM wanted to send it as a string.
|
||||||
|
def string_to_tuple(s):
|
||||||
|
s = s.strip("()")
|
||||||
|
vals = s.split(', ')
|
||||||
|
vals = tuple(map(float, vals))
|
||||||
|
return vals
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
Tool(name="Add", func=add, description="Performs addition of exactly two numbers."),
|
||||||
|
Tool(name="Subtract", func=subtract, description="Performs subtraction of exactly two numbers."),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
### Langchain Definition
|
||||||
|
|
||||||
|
llm = OllamaLLM(model="llama3")
|
||||||
|
|
||||||
|
agent_executor = initialize_agent(
|
||||||
|
tools=tools,
|
||||||
|
llm=llm,
|
||||||
|
agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
||||||
|
verbose=True
|
||||||
|
)
|
||||||
|
|
||||||
|
response = agent_executor.invoke({"input": "What is 3 + 5?"})
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
response = agent_executor.invoke({"input": "What is 10 - 4?"})
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
response = agent_executor.invoke({"input": "What is 5 + 4? Also, what is 99 - 33?"})
|
||||||
|
print(response)
|
32
tools/caption_image.py
Normal file
32
tools/caption_image.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import json
|
||||||
|
import requests
|
||||||
|
|
||||||
|
CAPTION_IMAGE_URL = "https://api.imgflip.com/caption_image"
|
||||||
|
|
||||||
|
def load_config():
|
||||||
|
with open('config.json') as config_file:
|
||||||
|
return json.load(config_file)
|
||||||
|
|
||||||
|
def caption_image(template_id, text0, text1):
|
||||||
|
config = load_config()
|
||||||
|
username = config['username']
|
||||||
|
password = config['password']
|
||||||
|
|
||||||
|
url = "https://api.imgflip.com/caption_image"
|
||||||
|
payload = {
|
||||||
|
"template_id": template_id,
|
||||||
|
"username": username,
|
||||||
|
"password": password,
|
||||||
|
"text0": text0,
|
||||||
|
"text1": text1
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(url, data=payload)
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
if result['success']:
|
||||||
|
meme_url = result['data']['url']
|
||||||
|
print(f"Meme created! URL: {meme_url}")
|
||||||
|
return meme_url
|
||||||
|
else:
|
||||||
|
return None
|
14
tools/download_image.py
Normal file
14
tools/download_image.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
import os
|
||||||
|
import requests
|
||||||
|
|
||||||
|
def download_image(url):
|
||||||
|
response = requests.get(url)
|
||||||
|
if response.status_code == 200:
|
||||||
|
output_dir = 'output'
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
image_path = os.path.join(output_dir, url.split("/")[-1])
|
||||||
|
with open(image_path, 'wb') as f:
|
||||||
|
f.write(response.content)
|
||||||
|
print(f"Image saved to {image_path}")
|
||||||
|
else:
|
||||||
|
print("Failed to download image.")
|
18
tools/get_memes.py
Normal file
18
tools/get_memes.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
import requests
|
||||||
|
|
||||||
|
GET_MEMES_URL = "https://api.imgflip.com/get_memes"
|
||||||
|
|
||||||
|
def get_memes(ignored):
|
||||||
|
return get_memes_helper()
|
||||||
|
|
||||||
|
def get_memes_helper():
|
||||||
|
response = requests.get(GET_MEMES_URL)
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
if result['success']:
|
||||||
|
memes = result['data']['memes']
|
||||||
|
memes_str = "\n".join([f"ID: {meme['id']}, Name: {meme['name']}" for meme in memes])
|
||||||
|
return memes_str
|
||||||
|
else:
|
||||||
|
return "No template_ids found."
|
||||||
|
|
8
tools/test_tools.py
Normal file
8
tools/test_tools.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
from caption_image import caption_image
|
||||||
|
from download_image import download_image
|
||||||
|
from get_memes import get_memes
|
||||||
|
|
||||||
|
get_memes()
|
||||||
|
meme_url = caption_image(20007896, "Top text", "Bottom text")
|
||||||
|
download_image(meme_url)
|
||||||
|
|
Loading…
x
Reference in New Issue
Block a user