Compare commits

...

3 Commits

Author SHA1 Message Date
ef4c2f9102 Update README 2025-04-14 14:02:53 -07:00
82628f0f19 Update main.py 2025-04-14 14:01:30 -07:00
7798f7b8a2 Finish the meme generator 2025-04-14 13:49:54 -07:00
8 changed files with 235 additions and 0 deletions

4
.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
output/
env/
tools/__pycache__/
tools/config.json

View File

@ -1,2 +1,23 @@
# memegraph
I previously built a meme generator with LangChain in https://github.com/runyanjake/memechain. This is the same project rewritten with LangGraph, which seems to have superseded LangChain recently.
The main reason to make the switch is to explore [multi-agent](https://github.com/langchain-ai/langgraph/blob/main/docs/docs/tutorials/multi_agent/multi-agent-collaboration.ipynb) workflows, which hopefully we can apply to this project at some point.
## Setup
### Python
```
python3 -m venv env
source env/bin/activate
pip install -r requirements.txt
python main.py
...
deactivate
```
## Usage
You can pass a prompt with `--prompt` when calling the script; if you omit it, a default prompt is used.
```
python main.py
python main.py --prompt "Generate an image for the 'Leonardo dicaprio cheers' meme with the text 'LANGGRAPH MEMES'."
```

119
main.py Normal file
View File

@ -0,0 +1,119 @@
from langchain_ollama.llms import OllamaLLM
from langchain_core.output_parsers import PydanticOutputParser
from langgraph.graph import StateGraph, END
from pydantic import BaseModel, Field
from tools.get_memes import get_memes
from tools.caption_image import caption_image
from tools.download_image import download_image
from typing import TypedDict, Optional
import argparse
import json
import re
############################# State Definition ################################
class State(TypedDict):
    """State dictionary handed from node to node in the meme graph."""

    # The user's request; main() supplies it as the only initial key.
    question: str
    # Template name extracted from the request by parse_input.
    template_name: Optional[str]
    # Caption texts extracted from the request by parse_input.
    text: Optional[list[str]]
    # Template ID filled in by node_get_memes.
    template_id: Optional[int]
    # Value returned by caption_image, stored by node_caption_image.
    # NOTE(review): annotated as dict, but node_download_image treats it as a
    # URL string -- confirm the intended type.
    caption_result: Optional[dict]
    # Value returned by download_image, stored by node_download_image.
    image_path: Optional[str]
############################# LLM Definition ##################################
# LLM instance that parse_input invokes with the extraction prompt.
llm = OllamaLLM(model="llama3")
# Structured result that parse_input extracts from the user's request.
# (Documented with comments rather than a docstring so that the schema used
# for the format instructions is left untouched.)
class MemeInfo(BaseModel):
    # Compared case-insensitively against the listed names in node_get_memes.
    template_name: str = Field(..., description="The meme template name")
    # Forwarded to caption_image as the meme's text entries.
    text: list[str] = Field(..., description="List of captions/texts for the meme")
# Supplies the format instructions embedded in the prompt and converts the
# LLM's reply into a MemeInfo inside parse_input.
parser = PydanticOutputParser(pydantic_object=MemeInfo)
def parse_input(state: State) -> State:
    """Extract the meme template name and caption texts from the request.

    Builds a prompt from ``state["question"]`` plus the parser's format
    instructions, sends it to the module-level ``llm``, and parses the reply
    into a ``MemeInfo``.

    Args:
        state: Current graph state; ``question`` must be present.

    Returns:
        A copy of ``state`` with the parsed ``MemeInfo`` fields merged in.
    """
    format_instructions = parser.get_format_instructions()
    user_request = state["question"]
    # The prompt lines stay flush left so the text sent to the model is
    # exactly the literal below.
    prompt = f"""
Extract the meme name and list of texts from the user's request.
Return the output in this JSON format:
{format_instructions}
User input: {user_request}
"""
    meme_info = parser.invoke(llm.invoke(prompt))
    print(meme_info)
    return {**state, **meme_info.model_dump()}
######################## Graph Components Definition ##########################
def node_get_memes(state: State) -> State:
    """Resolve ``state["template_name"]`` to a template ID.

    Fetches the meme listing via ``get_memes()`` (one ``ID: <id>, Name: <name>``
    entry per line) and looks for a name equal to the requested one, ignoring
    case and surrounding whitespace.

    Args:
        state: Current graph state; ``template_name`` should have been set by
            ``parse_input``.

    Returns:
        A copy of ``state`` with ``template_id`` set to the matching ID.

    Raises:
        ValueError: If no template name is available, or no listed meme
            matches it.
    """
    requested = state.get("template_name")
    if not requested:
        # Fail with a clear message instead of an AttributeError on None.
        raise ValueError("No template_name available; cannot look up a meme template.")

    # Normalize the target once rather than on every loop iteration.
    wanted = requested.strip().casefold()
    line_pattern = re.compile(r"ID: (\d+), Name: (.+)")

    memes = get_memes().strip().splitlines()
    for meme in memes:
        match = line_pattern.match(meme)
        if match and match.group(2).strip().casefold() == wanted:
            print("Found Meme: " + meme)
            return {**state, "template_id": int(match.group(1))}

    raise ValueError(f"No matching meme found for template_name: {state['template_name']}\nAvailable Memes:\n{memes}")
def node_caption_image(state: State) -> State:
    """Caption the selected template and record the result.

    Serializes ``template_id`` and ``text`` from the state to a JSON string,
    passes it to ``caption_image``, and prints what comes back.

    Args:
        state: Current graph state; ``template_id`` and ``text`` are read.

    Returns:
        A copy of ``state`` with ``caption_result`` set to the value returned
        by ``caption_image``.
    """
    request = json.dumps(
        {
            "template_id": int(state["template_id"]),
            "text": state["text"],
        }
    )
    caption_result = caption_image(request)
    print(caption_result)
    return {**state, "caption_result": caption_result}
def node_download_image(state: State) -> State:
    """Download the captioned meme referenced by ``state["caption_result"]``.

    Args:
        state: Current graph state; ``caption_result`` is expected to be the
            meme URL stored by ``node_caption_image``.

    Returns:
        A copy of ``state`` with ``image_path`` set to the value returned by
        ``download_image``, or to ``None`` when ``caption_result`` is not an
        ``http`` URL string.
    """
    image_url = state.get("caption_result")
    if not (isinstance(image_url, str) and image_url.startswith("http")):
        # Format instead of concatenating: image_url may be None here, and
        # "str" + None raises TypeError.
        print(f"Failed to download image. {image_url}")
        # Always hand back a state so later readers of "image_path" find the key.
        return {**state, "image_path": None}

    saved_path = download_image(json.dumps({"url": image_url}))
    return {**state, "image_path": saved_path}
############################## Graph Definition ###############################
# Wire the four nodes into a single straight-line pipeline over State:
# parse_input -> get_memes -> caption_image -> download_image -> END.
graph = StateGraph(State)
# Register each step under the name the edges below refer to.
graph.add_node("parse_input", parse_input)
graph.add_node("get_memes", node_get_memes)
graph.add_node("caption_image", node_caption_image)
graph.add_node("download_image", node_download_image)
# Execution starts at parse_input and follows the edges in order.
graph.set_entry_point("parse_input")
graph.add_edge("parse_input", "get_memes")
graph.add_edge("get_memes", "caption_image")
graph.add_edge("caption_image", "download_image")
graph.add_edge("download_image", END)
# Compiled graph that main() invokes.
app = graph.compile()
#################################### Run ######################################
def main():
    """Run the meme graph for the prompt given on the command line.

    Reads the optional ``--prompt`` argument (a built-in default is used when
    it is omitted), invokes the compiled graph with it as ``question``, and
    prints the resulting ``image_path``.
    """
    default_prompt = "Generate an image for the 'two buttons' meme with first text 'generated meme' and second text 'langgraph error'."

    # Named cli so it does not shadow the module-level output parser.
    cli = argparse.ArgumentParser(description="Meme Generation Script")
    cli.add_argument(
        "--prompt",
        type=str,
        default=default_prompt,
        help="Optional prompt for the meme generation",
    )
    question = cli.parse_args().prompt

    result = app.invoke({"question": question})
    print(result["image_path"])


if __name__ == "__main__":
    main()

7
requirements.txt Normal file
View File

@ -0,0 +1,7 @@
langchain_community
langchain_core>=0.1.35
langchain_experimental
langchain_ollama
langgraph
pillow
pydantic

40
tools/caption_image.py Normal file
View File

@ -0,0 +1,40 @@
import json
import requests
# Endpoint that caption_image POSTs the template ID, credentials and texts to.
CAPTION_IMAGE_URL = "https://api.imgflip.com/caption_image"
def load_config(config_path='tools/config.json'):
    """Load the JSON configuration that holds the API credentials.

    Args:
        config_path: Path of the JSON file to read. The default is relative
            to the current working directory, so it resolves only when the
            program is started from the project root.

    Returns:
        The parsed JSON content of the file.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON text is UTF-8 by convention; do not depend on the platform default.
    with open(config_path, encoding='utf-8') as config_file:
        return json.load(config_file)
def caption_image(input_data):
    """Caption a meme template through the caption endpoint.

    Args:
        input_data: JSON string with a ``template_id`` and a ``text`` list.
            A single-quoted, Python-dict-style string is still accepted as a
            fallback.

    Returns:
        The URL of the generated meme, or ``None`` when the API response
        reports failure.

    Raises:
        json.JSONDecodeError: If ``input_data`` cannot be parsed either way.
        KeyError: If a required key is missing from the input or the config.
    """
    try:
        # Parse valid JSON untouched so apostrophes inside captions
        # (e.g. "don't") are preserved instead of breaking the document.
        data = json.loads(input_data)
    except json.JSONDecodeError:
        # Fallback for callers that send single-quoted pseudo-JSON.
        data = json.loads(input_data.replace("'", '"'))

    template_id = data['template_id']
    text = data['text']

    config = load_config()
    payload = {
        "template_id": template_id,
        "username": config['username'],
        "password": config['password'],
    }
    # Each caption is sent under its positional key: text0, text1, ...
    for index, caption in enumerate(text):
        payload[f'text{index}'] = caption

    # A timeout keeps an unresponsive server from hanging the run forever.
    response = requests.post(CAPTION_IMAGE_URL, data=payload, timeout=30)
    result = response.json()

    if not result['success']:
        return None

    meme_url = result['data']['url']
    print(f"Meme created! URL: {meme_url}")
    return meme_url

View File

@ -0,0 +1,4 @@
{
"username":"xxxx",
"password":"yyyy"
}

22
tools/download_image.py Normal file
View File

@ -0,0 +1,22 @@
import os
import requests
import json
def download_image(input_data):
    """Download an image into the ``output`` directory.

    Args:
        input_data: JSON string with a ``url`` key. A single-quoted,
            Python-dict-style string is still accepted as a fallback.

    Returns:
        The path of the saved file, or ``None`` when the server does not
        answer with status 200.

    Raises:
        json.JSONDecodeError: If ``input_data`` cannot be parsed either way.
        KeyError: If ``url`` is missing from the input.
    """
    try:
        # Parse valid JSON untouched; rewriting quotes would corrupt any
        # URL that contains an apostrophe.
        data = json.loads(input_data)
    except json.JSONDecodeError:
        # Fallback for callers that send single-quoted pseudo-JSON.
        data = json.loads(input_data.replace("'", '"'))

    url = data['url']

    # A timeout keeps an unresponsive server from hanging the run forever.
    response = requests.get(url, timeout=30)
    if response.status_code != 200:
        print("Failed to download image.")
        return None

    output_dir = 'output'
    os.makedirs(output_dir, exist_ok=True)
    # The file is named after the last path segment of the URL.
    image_path = os.path.join(output_dir, url.split("/")[-1])
    with open(image_path, 'wb') as f:
        f.write(response.content)
    print(f"Image saved to {image_path}")
    return image_path

18
tools/get_memes.py Normal file
View File

@ -0,0 +1,18 @@
import requests
# Endpoint that get_memes_helper queries for the list of meme templates.
GET_MEMES_URL = "https://api.imgflip.com/get_memes"
def get_memes():
    """Return the meme template listing produced by ``get_memes_helper``."""
    # Thin wrapper: all of the work happens in the helper.
    return get_memes_helper()
def get_memes_helper():
    """Fetch the meme templates and format them one per line.

    Returns:
        A newline-separated string of ``ID: <id>, Name: <name>`` entries, or
        the message ``"No template_ids found."`` when the response reports
        failure.
    """
    payload = requests.get(GET_MEMES_URL).json()
    if not payload['success']:
        return "No template_ids found."

    return "\n".join(
        f"ID: {entry['id']}, Name: {entry['name']}"
        for entry in payload['data']['memes']
    )