Finish the meme generator

This commit is contained in:
whitney 2025-04-14 13:49:54 -07:00
parent b19ff7102e
commit 7798f7b8a2
8 changed files with 219 additions and 0 deletions

4
.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
output/
env/
tools/__pycache__/
tools/config.json

View File

@ -1,2 +1,17 @@
# memegraph
I previously built a meme generator with LangChain in https://github.com/runyanjake/memechain. This is the same project rewritten with LangGraph, since that seems to have superseded LangChain in recent times.
The main reason for making the switch is to explore [multi-agent](https://github.com/langchain-ai/langgraph/blob/main/docs/docs/tutorials/multi_agent/multi-agent-collaboration.ipynb) workflows, which we can hopefully apply to this project at some point.
## Setup
### Python
```
python3 -m venv env
source env/bin/activate
pip install -r requirements.txt
python main.py
...
deactivate
```

109
main.py Normal file
View File

@ -0,0 +1,109 @@
from langchain_ollama.llms import OllamaLLM
from langchain_core.output_parsers import PydanticOutputParser
from langgraph.graph import StateGraph, END
from pydantic import BaseModel, Field
from tools.get_memes import get_memes
from tools.caption_image import caption_image
from tools.download_image import download_image
from typing import TypedDict, Optional
import json
import re
############################# State Definition ################################
class State(TypedDict):
    """Shared state handed from node to node in the meme-generation graph.

    Only ``question`` is supplied when the graph is invoked; the remaining
    keys are filled in by the nodes as the pipeline advances.
    """

    # The user's natural-language meme request.
    question: str
    # Template name extracted from the question by parse_input.
    template_name: Optional[str]
    # Caption texts extracted from the question by parse_input.
    text: Optional[list[str]]
    # Numeric template id resolved by node_get_memes.
    template_id: Optional[int]
    # Value returned by caption_image, stored by node_caption_image.
    # NOTE(review): annotated as dict, but caption_image returns a URL string
    # (or None on failure) and node_download_image treats it as a str.
    caption_result: Optional[dict]
    # Path returned by download_image, stored by node_download_image.
    image_path: Optional[str]
############################# LLM Definition ##################################
# Ollama-backed LLM (model "llama3") used by parse_input to read the request.
llm = OllamaLLM(model="llama3")


# Structured fields the LLM is asked to extract from the user's question.
# NOTE(review): documented with comments rather than a class docstring, since a
# docstring would presumably show up in the schema the parser embeds in the
# prompt -- verify before adding one.
class MemeInfo(BaseModel):
    template_name: str = Field(..., description="The meme template name")
    text: list[str] = Field(..., description="List of captions/texts for the meme")


# Supplies the format instructions for the prompt and parses the LLM reply
# into a MemeInfo instance.
parser = PydanticOutputParser(pydantic_object=MemeInfo)
def parse_input(state: State) -> State:
    """Extract the meme template name and caption texts from the question.

    Builds a prompt from ``state["question"]`` plus the parser's format
    instructions, sends it to the module-level ``llm``, and parses the reply
    into a ``MemeInfo``.

    Args:
        state: Current graph state; must contain ``question``.

    Returns:
        A copy of ``state`` with the ``MemeInfo`` fields (``template_name``
        and ``text``) merged in.
    """
    prompt = f"""
Extract the meme name and list of texts from the user's request.
Return the output in this JSON format:
{parser.get_format_instructions()}
User input: {state["question"]}
"""
    response = llm.invoke(prompt)
    parsed = parser.invoke(response)
    print(parsed)
    # Pydantic v2 deprecates .dict() in favour of .model_dump(); prefer the new
    # name when it exists and fall back to .dict() on Pydantic v1.
    dump_fields = getattr(parsed, "model_dump", None) or parsed.dict
    return {**state, **dump_fields()}
######################## Graph Components Definition ##########################
def node_get_memes(state: State) -> State:
    """Resolve the requested template name to a template id.

    Fetches the listing from ``get_memes()`` (one ``ID: <id>, Name: <name>``
    entry per line) and looks for a name equal to ``state["template_name"]``,
    ignoring case and surrounding whitespace.

    Args:
        state: Current graph state; ``template_name`` should be set.

    Returns:
        A copy of ``state`` with ``template_id`` set to the matching id.

    Raises:
        ValueError: If no template name is present or no listed meme matches.
    """
    requested = state.get("template_name")
    if not requested:
        # Fail with the same error type as "not found" instead of crashing on
        # None.lower() when the LLM did not yield a template name.
        raise ValueError(f"No matching meme found for template_name: {requested}")

    # Normalise once, outside the loop; casefold() is the Unicode-aware
    # equivalent of lower() for caseless comparison.
    wanted = requested.strip().casefold()
    line_pattern = re.compile(r"ID: (\d+), Name: (.+)")

    for meme in get_memes().strip().splitlines():
        match = line_pattern.match(meme)
        if match is None:
            # Lines that are not listing entries (e.g. an error message).
            continue
        if match.group(2).strip().casefold() == wanted:
            print("Found Meme: " + meme)
            return {**state, "template_id": int(match.group(1))}

    raise ValueError(f"No matching meme found for template_name: {requested}")
def node_caption_image(state: State) -> State:
    """Ask ``caption_image`` to render the meme and record what it returns.

    The template id and caption texts from ``state`` are serialised to a JSON
    string, which is the input format ``caption_image`` expects.

    Args:
        state: Current graph state; ``template_id`` and ``text`` must be set.

    Returns:
        A copy of ``state`` with ``caption_result`` set to the value returned
        by ``caption_image``.
    """
    request = json.dumps(
        {
            "template_id": int(state["template_id"]),
            "text": state["text"],
        }
    )
    outcome = caption_image(request)
    print(outcome)
    return {**state, "caption_result": outcome}
def node_download_image(state: State) -> State:
    """Download the captioned meme when ``caption_result`` is an http(s) URL.

    Args:
        state: Current graph state; ``caption_result`` must be set (it may be
            None when captioning failed).

    Returns:
        A copy of ``state`` with ``image_path`` set to the value returned by
        ``download_image``, or to None when there is no usable URL.
    """
    caption_result = state["caption_result"]
    # f-strings instead of "+" so a None/non-str result is reported rather
    # than raising TypeError on the very path meant to handle failures.
    print(f"Generated Image: {caption_result}")

    if not (isinstance(caption_result, str) and caption_result.startswith("http")):
        print(f"Failed to download image. {caption_result}")
        # Always hand back a full state so later readers of "image_path"
        # find the key even when nothing was downloaded.
        return {**state, "image_path": None}

    download_payload = json.dumps({"url": caption_result})
    result = download_image(download_payload)
    print(f"Saved to: {result}")
    return {**state, "image_path": result}
############################## Graph Definition ###############################
graph = StateGraph(State)

# Register each node under the name the edges below refer to.
for _node_name, _node_fn in (
    ("parse_input", parse_input),
    ("get_memes", node_get_memes),
    ("caption_image", node_caption_image),
    ("download_image", node_download_image),
):
    graph.add_node(_node_name, _node_fn)

graph.set_entry_point("parse_input")

# Strictly linear pipeline: parse -> look up template -> caption -> download.
for _source, _target in (
    ("parse_input", "get_memes"),
    ("get_memes", "caption_image"),
    ("caption_image", "download_image"),
    ("download_image", END),
):
    graph.add_edge(_source, _target)

app = graph.compile()
#################################### Run ######################################
# Example request: run the compiled graph once and print where the image landed.
question = (
    "Generate an image for the 'two buttons' meme "
    "with first text 'generated meme' and second text 'langgraph error'."
)
final_state = app.invoke({"question": question})
print(final_state["image_path"])

7
requirements.txt Normal file
View File

@ -0,0 +1,7 @@
langchain_community
langchain_core>=0.1.35
langchain_experimental
langchain_ollama
langgraph
pillow
pydantic

40
tools/caption_image.py Normal file
View File

@ -0,0 +1,40 @@
import json
import requests
CAPTION_IMAGE_URL = "https://api.imgflip.com/caption_image"
def load_config():
    """Return the parsed contents of ``tools/config.json``.

    The path is relative, so it is resolved against the current working
    directory of the process.
    """
    with open('tools/config.json') as handle:
        settings = json.loads(handle.read())
    return settings
def _load_caption_request(raw):
    """Decode the tool input string, tolerating Python-style single quotes."""
    import ast  # local import: only needed for the non-JSON fallback

    try:
        # Valid JSON is taken as-is, so apostrophes inside captions survive.
        return json.loads(raw)
    except json.JSONDecodeError:
        pass
    try:
        # Single-quoted dict literal, as langchain likes to produce.
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError, TypeError):
        # Last resort: the original blanket quote replacement.
        return json.loads(raw.replace("'", '"'))


def caption_image(input_data):
    """Create a captioned meme through the ``CAPTION_IMAGE_URL`` endpoint.

    Args:
        input_data: String encoding an object with ``template_id`` and
            ``text`` (a list of captions). JSON is preferred; a single-quoted
            Python-style literal is also accepted.

    Returns:
        The URL of the generated meme, or None when the API reports failure.
    """
    data = _load_caption_request(input_data)
    template_id = data['template_id']
    text = data['text']

    config = load_config()
    payload = {
        "template_id": template_id,
        "username": config['username'],
        "password": config['password'],
    }
    # Captions are sent as text0, text1, ... in list order.
    # NOTE(review): templates with more than two text boxes may need the
    # API's "boxes" parameter instead -- confirm against the Imgflip docs.
    for index, caption in enumerate(text):
        payload[f'text{index}'] = caption

    # A timeout keeps an unresponsive server from hanging the run forever.
    response = requests.post(CAPTION_IMAGE_URL, data=payload, timeout=30)
    result = response.json()

    if not result['success']:
        # Report the failure instead of returning None silently.
        print(f"Failed to caption image: {result.get('error_message')}")
        return None

    meme_url = result['data']['url']
    print(f"Meme created! URL: {meme_url}")
    return meme_url

View File

@ -0,0 +1,4 @@
{
"username":"xxxx",
"password":"yyyy"
}

22
tools/download_image.py Normal file
View File

@ -0,0 +1,22 @@
import os
import requests
import json
def _load_download_request(raw):
    """Decode the tool input string, tolerating Python-style single quotes."""
    import ast  # local import: only needed for the non-JSON fallback

    try:
        # Valid JSON is taken as-is, so a URL containing an apostrophe is
        # not corrupted by quote replacement.
        return json.loads(raw)
    except json.JSONDecodeError:
        pass
    try:
        # Single-quoted dict literal, as langchain likes to produce.
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError, TypeError):
        # Last resort: the original blanket quote replacement.
        return json.loads(raw.replace("'", '"'))


def download_image(input_data):
    """Download an image into the ``output`` directory.

    Args:
        input_data: String encoding an object with a ``url`` key. JSON is
            preferred; a single-quoted Python-style literal is also accepted.

    Returns:
        The path of the saved file, or None when the server does not answer
        with HTTP 200.
    """
    data = _load_download_request(input_data)
    url = data['url']

    # A timeout keeps an unresponsive server from hanging the run forever.
    response = requests.get(url, timeout=30)
    if response.status_code != 200:
        print("Failed to download image.")
        return None

    output_dir = 'output'
    os.makedirs(output_dir, exist_ok=True)
    # The file is named after the last path segment of the URL.
    image_path = os.path.join(output_dir, url.split("/")[-1])
    with open(image_path, 'wb') as f:
        f.write(response.content)
    print(f"Image saved to {image_path}")
    return image_path

18
tools/get_memes.py Normal file
View File

@ -0,0 +1,18 @@
import requests
GET_MEMES_URL = "https://api.imgflip.com/get_memes"
def get_memes():
    """Return the meme template listing produced by ``get_memes_helper``."""
    listing = get_memes_helper()
    return listing
def get_memes_helper():
    """Fetch the template list from ``GET_MEMES_URL`` and format it as text.

    Returns:
        One ``ID: <id>, Name: <name>`` line per template, joined with
        newlines, or the string ``"No template_ids found."`` when the API
        response is not flagged as successful.
    """
    payload = requests.get(GET_MEMES_URL).json()
    if not payload['success']:
        return "No template_ids found."
    return "\n".join(
        f"ID: {entry['id']}, Name: {entry['name']}"
        for entry in payload['data']['memes']
    )