mirror of
https://github.com/runyanjake/memegraph.git
synced 2025-10-05 07:27:30 -07:00
Compare commits
No commits in common. "ef4c2f910214795f82b767a9545b30bb76a8e3dc" and "b19ff7102e5b3d09d4297c449da5dc9b6d31ddf1" have entirely different histories.
ef4c2f9102
...
b19ff7102e
4
.gitignore
vendored
4
.gitignore
vendored
@ -1,4 +0,0 @@
|
|||||||
output/
|
|
||||||
env/
|
|
||||||
tools/__pycache__/
|
|
||||||
tools/config.json
|
|
21
README.md
21
README.md
@ -1,23 +1,2 @@
|
|||||||
# memegraph
|
# memegraph
|
||||||
I previously did meme generation via Langchain in https://github.com/runyanjake/memechain. Here is the same project but written using Langgraph instead, since that seems to have superseded Langchain in recent time.
|
I previously did meme generation via Langchain in https://github.com/runyanjake/memechain. Here is the same project but written using Langgraph instead, since that seems to have superseded Langchain in recent time.
|
||||||
|
|
||||||
The main reason to make the switch is to explore [multi-agent](https://github.com/langchain-ai/langgraph/blob/main/docs/docs/tutorials/multi_agent/multi-agent-collaboration.ipynb) workflows, which hopefully we can apply to this project at some point.
|
|
||||||
|
|
||||||
## Setup
|
|
||||||
|
|
||||||
### Python
|
|
||||||
```
|
|
||||||
python3 -m venv env
|
|
||||||
source env/bin/activate
|
|
||||||
pip install -r requirements.txt
|
|
||||||
python main.py
|
|
||||||
...
|
|
||||||
deactivate
|
|
||||||
```
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
Can specify prompt when calling script or will use default prompt.
|
|
||||||
```
|
|
||||||
python main.py
|
|
||||||
python main.py --prompt "Generate an image for the 'Leonardo dicaprio cheers' meme with the text 'LANGGRAPH MEMES'."
|
|
||||||
```
|
|
||||||
|
119
main.py
119
main.py
@ -1,119 +0,0 @@
|
|||||||
from langchain_ollama.llms import OllamaLLM
|
|
||||||
from langchain_core.output_parsers import PydanticOutputParser
|
|
||||||
from langgraph.graph import StateGraph, END
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from tools.get_memes import get_memes
|
|
||||||
from tools.caption_image import caption_image
|
|
||||||
from tools.download_image import download_image
|
|
||||||
|
|
||||||
from typing import TypedDict, Optional
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
|
|
||||||
############################# State Definition ################################
|
|
||||||
class State(TypedDict):
|
|
||||||
question: str
|
|
||||||
template_name: Optional[str]
|
|
||||||
text: Optional[list[str]]
|
|
||||||
template_id: Optional[int]
|
|
||||||
caption_result: Optional[dict]
|
|
||||||
image_path: Optional[str]
|
|
||||||
|
|
||||||
############################# LLM Definition ##################################
|
|
||||||
llm = OllamaLLM(model="llama3")
|
|
||||||
|
|
||||||
class MemeInfo(BaseModel):
|
|
||||||
template_name: str = Field(..., description="The meme template name")
|
|
||||||
text: list[str] = Field(..., description="List of captions/texts for the meme")
|
|
||||||
|
|
||||||
parser = PydanticOutputParser(pydantic_object=MemeInfo)
|
|
||||||
|
|
||||||
def parse_input(state: State) -> State:
|
|
||||||
prompt = f"""
|
|
||||||
Extract the meme name and list of texts from the user's request.
|
|
||||||
Return the output in this JSON format:
|
|
||||||
|
|
||||||
{parser.get_format_instructions()}
|
|
||||||
|
|
||||||
User input: {state["question"]}
|
|
||||||
"""
|
|
||||||
|
|
||||||
response = llm.invoke(prompt)
|
|
||||||
parsed = parser.invoke(response)
|
|
||||||
print(parsed)
|
|
||||||
return {**state, **parsed.model_dump()}
|
|
||||||
|
|
||||||
######################## Graph Components Definition ##########################
|
|
||||||
def node_get_memes(state: State) -> State:
|
|
||||||
memes_str = get_memes()
|
|
||||||
memes = memes_str.strip().splitlines()
|
|
||||||
|
|
||||||
for meme in memes:
|
|
||||||
match = re.match(r"ID: (\d+), Name: (.+)", meme)
|
|
||||||
if match:
|
|
||||||
meme_id = int(match.group(1))
|
|
||||||
meme_name = match.group(2).strip()
|
|
||||||
|
|
||||||
if meme_name.lower() == state["template_name"].lower():
|
|
||||||
print("Found Meme: " + meme)
|
|
||||||
return {**state, "template_id": meme_id}
|
|
||||||
|
|
||||||
raise ValueError(f"No matching meme found for template_name: {state['template_name']}\nAvailable Memes:\n{memes}")
|
|
||||||
|
|
||||||
def node_caption_image(state: State) -> State:
|
|
||||||
input_data = {
|
|
||||||
"template_id": int(state["template_id"]),
|
|
||||||
"text": state["text"]
|
|
||||||
}
|
|
||||||
json_input = json.dumps(input_data)
|
|
||||||
result = caption_image(json_input)
|
|
||||||
print(result)
|
|
||||||
return {**state, "caption_result": result}
|
|
||||||
|
|
||||||
def node_download_image(state: State) -> State:
|
|
||||||
image_url = state["caption_result"]
|
|
||||||
if isinstance(image_url, str) and image_url.startswith("http"):
|
|
||||||
download_payload = json.dumps({"url": image_url})
|
|
||||||
result = download_image(download_payload)
|
|
||||||
return {**state, "image_path": result}
|
|
||||||
else:
|
|
||||||
print("Failed to download image. " + image_url)
|
|
||||||
|
|
||||||
############################## Graph Definition ###############################
|
|
||||||
graph = StateGraph(State)
|
|
||||||
|
|
||||||
graph.add_node("parse_input", parse_input)
|
|
||||||
graph.add_node("get_memes", node_get_memes)
|
|
||||||
graph.add_node("caption_image", node_caption_image)
|
|
||||||
graph.add_node("download_image", node_download_image)
|
|
||||||
|
|
||||||
graph.set_entry_point("parse_input")
|
|
||||||
graph.add_edge("parse_input", "get_memes")
|
|
||||||
graph.add_edge("get_memes", "caption_image")
|
|
||||||
graph.add_edge("caption_image", "download_image")
|
|
||||||
graph.add_edge("download_image", END)
|
|
||||||
|
|
||||||
app = graph.compile()
|
|
||||||
|
|
||||||
#################################### Run ######################################
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="Meme Generation Script")
|
|
||||||
parser.add_argument(
|
|
||||||
"--prompt",
|
|
||||||
type=str,
|
|
||||||
help="Optional prompt for the meme generation",
|
|
||||||
default="Generate an image for the 'two buttons' meme with first text 'generated meme' and second text 'langgraph error'."
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
question = args.prompt
|
|
||||||
final_state = app.invoke({"question": question})
|
|
||||||
print(final_state["image_path"])
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
@ -1,7 +0,0 @@
|
|||||||
langchain_community
|
|
||||||
langchain_core>=0.1.35
|
|
||||||
langchain_experimental
|
|
||||||
langchain_ollama
|
|
||||||
langgraph
|
|
||||||
pillow
|
|
||||||
pydantic
|
|
@ -1,40 +0,0 @@
|
|||||||
import json
|
|
||||||
import requests
|
|
||||||
|
|
||||||
CAPTION_IMAGE_URL = "https://api.imgflip.com/caption_image"
|
|
||||||
|
|
||||||
def load_config():
|
|
||||||
with open('tools/config.json') as config_file:
|
|
||||||
return json.load(config_file)
|
|
||||||
|
|
||||||
def caption_image(input_data):
|
|
||||||
# Replace single quotes with double quotes because langchain likes to use single quotes
|
|
||||||
input_data = input_data.replace("'", '"')
|
|
||||||
|
|
||||||
data = json.loads(input_data)
|
|
||||||
template_id = data['template_id']
|
|
||||||
text = data['text']
|
|
||||||
|
|
||||||
config = load_config()
|
|
||||||
username = config['username']
|
|
||||||
password = config['password']
|
|
||||||
|
|
||||||
url = CAPTION_IMAGE_URL
|
|
||||||
payload = {
|
|
||||||
"template_id": template_id,
|
|
||||||
"username": username,
|
|
||||||
"password": password,
|
|
||||||
}
|
|
||||||
|
|
||||||
for i in range(len(text)):
|
|
||||||
payload[f'text{i}'] = text[i]
|
|
||||||
|
|
||||||
response = requests.post(url, data=payload)
|
|
||||||
result = response.json()
|
|
||||||
|
|
||||||
if result['success']:
|
|
||||||
meme_url = result['data']['url']
|
|
||||||
print(f"Meme created! URL: {meme_url}")
|
|
||||||
return meme_url
|
|
||||||
else:
|
|
||||||
return None
|
|
@ -1,4 +0,0 @@
|
|||||||
{
|
|
||||||
"username":"xxxx",
|
|
||||||
"password":"yyyy"
|
|
||||||
}
|
|
@ -1,22 +0,0 @@
|
|||||||
import os
|
|
||||||
import requests
|
|
||||||
import json
|
|
||||||
|
|
||||||
def download_image(input_data):
|
|
||||||
input_data = input_data.replace("'", '"')
|
|
||||||
|
|
||||||
data = json.loads(input_data)
|
|
||||||
url = data['url']
|
|
||||||
|
|
||||||
response = requests.get(url)
|
|
||||||
if response.status_code == 200:
|
|
||||||
output_dir = 'output'
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
image_path = os.path.join(output_dir, url.split("/")[-1])
|
|
||||||
with open(image_path, 'wb') as f:
|
|
||||||
f.write(response.content)
|
|
||||||
print(f"Image saved to {image_path}")
|
|
||||||
return image_path
|
|
||||||
else:
|
|
||||||
print("Failed to download image.")
|
|
||||||
return None
|
|
@ -1,18 +0,0 @@
|
|||||||
import requests
|
|
||||||
|
|
||||||
GET_MEMES_URL = "https://api.imgflip.com/get_memes"
|
|
||||||
|
|
||||||
def get_memes():
|
|
||||||
return get_memes_helper()
|
|
||||||
|
|
||||||
def get_memes_helper():
|
|
||||||
response = requests.get(GET_MEMES_URL)
|
|
||||||
result = response.json()
|
|
||||||
|
|
||||||
if result['success']:
|
|
||||||
memes = result['data']['memes']
|
|
||||||
memes_str = "\n".join([f"ID: {meme['id']}, Name: {meme['name']}" for meme in memes])
|
|
||||||
return memes_str
|
|
||||||
else:
|
|
||||||
return "No template_ids found."
|
|
||||||
|
|
Loading…
x
Reference in New Issue
Block a user