mirror of
https://github.com/runyanjake/memegraph.git
synced 2025-10-04 15:07:30 -07:00
Finish the meme generator
This commit is contained in:
parent
b19ff7102e
commit
7798f7b8a2
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@ -0,0 +1,4 @@
|
||||
output/
|
||||
env/
|
||||
tools/__pycache__/
|
||||
tools/config.json
|
15
README.md
15
README.md
@ -1,2 +1,17 @@
|
||||
# memegraph
|
||||
I previously did meme generation via Langchain in https://github.com/runyanjake/memechain. Here is the same project but written using Langgraph instead, since that seems to have superseded Langchain in recent time.
|
||||
|
||||
The main reason to make the switch is to explore [multi-agent](https://github.com/langchain-ai/langgraph/blob/main/docs/docs/tutorials/multi_agent/multi-agent-collaboration.ipynb) workflows, which hopefully we can apply to this project at some point.
|
||||
|
||||
## Setup
|
||||
|
||||
### Python
|
||||
```
|
||||
python3 -m venv env
|
||||
source env/bin/activate
|
||||
pip install -r requirements.txt
|
||||
python main.py
|
||||
...
|
||||
deactivate
|
||||
```
|
||||
|
||||
|
109
main.py
Normal file
109
main.py
Normal file
@ -0,0 +1,109 @@
|
||||
from langchain_ollama.llms import OllamaLLM
|
||||
from langchain_core.output_parsers import PydanticOutputParser
|
||||
from langgraph.graph import StateGraph, END
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from tools.get_memes import get_memes
|
||||
from tools.caption_image import caption_image
|
||||
from tools.download_image import download_image
|
||||
|
||||
from typing import TypedDict, Optional
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
|
||||
############################# State Definition ################################
|
||||
class State(TypedDict):
|
||||
question: str
|
||||
template_name: Optional[str]
|
||||
text: Optional[list[str]]
|
||||
template_id: Optional[int]
|
||||
caption_result: Optional[dict]
|
||||
image_path: Optional[str]
|
||||
|
||||
############################# LLM Definition ##################################
|
||||
|
||||
llm = OllamaLLM(model="llama3")
|
||||
|
||||
class MemeInfo(BaseModel):
|
||||
template_name: str = Field(..., description="The meme template name")
|
||||
text: list[str] = Field(..., description="List of captions/texts for the meme")
|
||||
|
||||
parser = PydanticOutputParser(pydantic_object=MemeInfo)
|
||||
|
||||
def parse_input(state: State) -> State:
|
||||
prompt = f"""
|
||||
Extract the meme name and list of texts from the user's request.
|
||||
Return the output in this JSON format:
|
||||
|
||||
{parser.get_format_instructions()}
|
||||
|
||||
User input: {state["question"]}
|
||||
"""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
parsed = parser.invoke(response)
|
||||
print(parsed)
|
||||
return {**state, **parsed.dict()}
|
||||
|
||||
######################## Graph Components Definition ##########################
|
||||
def node_get_memes(state: State) -> State:
|
||||
memes_str = get_memes()
|
||||
memes = memes_str.strip().splitlines()
|
||||
|
||||
for meme in memes:
|
||||
match = re.match(r"ID: (\d+), Name: (.+)", meme)
|
||||
if match:
|
||||
meme_id = int(match.group(1))
|
||||
meme_name = match.group(2).strip()
|
||||
|
||||
if meme_name.lower() == state["template_name"].lower():
|
||||
print("Found Meme: " + meme)
|
||||
return {**state, "template_id": meme_id}
|
||||
|
||||
raise ValueError(f"No matching meme found for template_name: {state['template_name']}")
|
||||
|
||||
def node_caption_image(state: State) -> State:
|
||||
input_data = {
|
||||
"template_id": int(state["template_id"]),
|
||||
"text": state["text"]
|
||||
}
|
||||
json_input = json.dumps(input_data)
|
||||
result = caption_image(json_input)
|
||||
print(result)
|
||||
return {**state, "caption_result": result}
|
||||
|
||||
def node_download_image(state: State) -> State:
|
||||
caption_result = state["caption_result"]
|
||||
print("Generated Image: " + caption_result)
|
||||
if isinstance(caption_result, str) and caption_result.startswith("http"):
|
||||
download_payload = json.dumps({"url": caption_result})
|
||||
result = download_image(download_payload)
|
||||
print("Saved to: " + result)
|
||||
return {**state, "image_path": result}
|
||||
else:
|
||||
print("Failed to download image. " + caption_result)
|
||||
|
||||
############################## Graph Definition ###############################
|
||||
graph = StateGraph(State)
|
||||
|
||||
graph.add_node("parse_input", parse_input)
|
||||
graph.add_node("get_memes", node_get_memes)
|
||||
graph.add_node("caption_image", node_caption_image)
|
||||
graph.add_node("download_image", node_download_image)
|
||||
|
||||
graph.set_entry_point("parse_input")
|
||||
graph.add_edge("parse_input", "get_memes")
|
||||
graph.add_edge("get_memes", "caption_image")
|
||||
graph.add_edge("caption_image", "download_image")
|
||||
graph.add_edge("download_image", END)
|
||||
|
||||
app = graph.compile()
|
||||
|
||||
#################################### Run ######################################
|
||||
question = "Generate an image for the 'two buttons' meme with first text 'generated meme' and second text 'langgraph error'."
|
||||
|
||||
final_state = app.invoke({"question": question})
|
||||
print(final_state["image_path"])
|
7
requirements.txt
Normal file
7
requirements.txt
Normal file
@ -0,0 +1,7 @@
|
||||
langchain_community
|
||||
langchain_core>=0.1.35
|
||||
langchain_experimental
|
||||
langchain_ollama
|
||||
langgraph
|
||||
pillow
|
||||
pydantic
|
40
tools/caption_image.py
Normal file
40
tools/caption_image.py
Normal file
@ -0,0 +1,40 @@
|
||||
import json
|
||||
import requests
|
||||
|
||||
CAPTION_IMAGE_URL = "https://api.imgflip.com/caption_image"
|
||||
|
||||
def load_config():
|
||||
with open('tools/config.json') as config_file:
|
||||
return json.load(config_file)
|
||||
|
||||
def caption_image(input_data):
|
||||
# Replace single quotes with double quotes because langchain likes to use single quotes
|
||||
input_data = input_data.replace("'", '"')
|
||||
|
||||
data = json.loads(input_data)
|
||||
template_id = data['template_id']
|
||||
text = data['text']
|
||||
|
||||
config = load_config()
|
||||
username = config['username']
|
||||
password = config['password']
|
||||
|
||||
url = CAPTION_IMAGE_URL
|
||||
payload = {
|
||||
"template_id": template_id,
|
||||
"username": username,
|
||||
"password": password,
|
||||
}
|
||||
|
||||
for i in range(len(text)):
|
||||
payload[f'text{i}'] = text[i]
|
||||
|
||||
response = requests.post(url, data=payload)
|
||||
result = response.json()
|
||||
|
||||
if result['success']:
|
||||
meme_url = result['data']['url']
|
||||
print(f"Meme created! URL: {meme_url}")
|
||||
return meme_url
|
||||
else:
|
||||
return None
|
4
tools/config.json.blanked
Normal file
4
tools/config.json.blanked
Normal file
@ -0,0 +1,4 @@
|
||||
{
|
||||
"username":"xxxx",
|
||||
"password":"yyyy"
|
||||
}
|
22
tools/download_image.py
Normal file
22
tools/download_image.py
Normal file
@ -0,0 +1,22 @@
|
||||
import os
|
||||
import requests
|
||||
import json
|
||||
|
||||
def download_image(input_data):
|
||||
input_data = input_data.replace("'", '"')
|
||||
|
||||
data = json.loads(input_data)
|
||||
url = data['url']
|
||||
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
output_dir = 'output'
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
image_path = os.path.join(output_dir, url.split("/")[-1])
|
||||
with open(image_path, 'wb') as f:
|
||||
f.write(response.content)
|
||||
print(f"Image saved to {image_path}")
|
||||
return image_path
|
||||
else:
|
||||
print("Failed to download image.")
|
||||
return None
|
18
tools/get_memes.py
Normal file
18
tools/get_memes.py
Normal file
@ -0,0 +1,18 @@
|
||||
import requests
|
||||
|
||||
GET_MEMES_URL = "https://api.imgflip.com/get_memes"
|
||||
|
||||
def get_memes():
|
||||
return get_memes_helper()
|
||||
|
||||
def get_memes_helper():
|
||||
response = requests.get(GET_MEMES_URL)
|
||||
result = response.json()
|
||||
|
||||
if result['success']:
|
||||
memes = result['data']['memes']
|
||||
memes_str = "\n".join([f"ID: {meme['id']}, Name: {meme['name']}" for meme in memes])
|
||||
return memes_str
|
||||
else:
|
||||
return "No template_ids found."
|
||||
|
Loading…
x
Reference in New Issue
Block a user