diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d9ba3e7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +output/ +env/ +tools/__pycache__/ +tools/config.json diff --git a/README.md b/README.md index 73fb00e..3d5347e 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,17 @@ # memegraph I previously did meme generation via Langchain in https://github.com/runyanjake/memechain. Here is the same project but written using Langgraph instead, since that seems to have superseded Langchain in recent time. + +The main reason to make the switch is to explore [multi-agent](https://github.com/langchain-ai/langgraph/blob/main/docs/docs/tutorials/multi_agent/multi-agent-collaboration.ipynb) workflows, which hopefully we can apply to this project at some point. + +## Setup + +### Python +``` +python3 -m venv env +source env/bin/activate +pip install -r requirements.txt +python main.py +... +deactivate +``` + diff --git a/main.py b/main.py new file mode 100644 index 0000000..0451d7b --- /dev/null +++ b/main.py @@ -0,0 +1,109 @@ +from langchain_ollama.llms import OllamaLLM +from langchain_core.output_parsers import PydanticOutputParser +from langgraph.graph import StateGraph, END + +from pydantic import BaseModel, Field + +from tools.get_memes import get_memes +from tools.caption_image import caption_image +from tools.download_image import download_image + +from typing import TypedDict, Optional + +import json +import re + + +############################# State Definition ################################ +class State(TypedDict): + question: str + template_name: Optional[str] + text: Optional[list[str]] + template_id: Optional[int] + caption_result: Optional[dict] + image_path: Optional[str] + +############################# LLM Definition ################################## + +llm = OllamaLLM(model="llama3") + +class MemeInfo(BaseModel): + template_name: str = Field(..., description="The meme template name") + text: list[str] = Field(..., description="List of captions/texts for the meme") + +parser = PydanticOutputParser(pydantic_object=MemeInfo) + +def parse_input(state: State) -> State: + prompt = f""" + Extract the meme name and list of texts from the user's request. + Return the output in this JSON format: + + {parser.get_format_instructions()} + + User input: {state["question"]} + """ + + response = llm.invoke(prompt) + parsed = parser.invoke(response) + print(parsed) + return {**state, **parsed.dict()} + +######################## Graph Components Definition ########################## +def node_get_memes(state: State) -> State: + memes_str = get_memes() + memes = memes_str.strip().splitlines() + + for meme in memes: + match = re.match(r"ID: (\d+), Name: (.+)", meme) + if match: + meme_id = int(match.group(1)) + meme_name = match.group(2).strip() + + if meme_name.lower() == state["template_name"].lower(): + print("Found Meme: " + meme) + return {**state, "template_id": meme_id} + + raise ValueError(f"No matching meme found for template_name: {state['template_name']}") + +def node_caption_image(state: State) -> State: + input_data = { + "template_id": int(state["template_id"]), + "text": state["text"] + } + json_input = json.dumps(input_data) + result = caption_image(json_input) + print(result) + return {**state, "caption_result": result} + +def node_download_image(state: State) -> State: + caption_result = state["caption_result"] + print("Generated Image: " + caption_result) + if isinstance(caption_result, str) and caption_result.startswith("http"): + download_payload = json.dumps({"url": caption_result}) + result = download_image(download_payload) + print("Saved to: " + result) + return {**state, "image_path": result} + else: + print("Failed to download image. " + caption_result) + +############################## Graph Definition ############################### +graph = StateGraph(State) + +graph.add_node("parse_input", parse_input) +graph.add_node("get_memes", node_get_memes) +graph.add_node("caption_image", node_caption_image) +graph.add_node("download_image", node_download_image) + +graph.set_entry_point("parse_input") +graph.add_edge("parse_input", "get_memes") +graph.add_edge("get_memes", "caption_image") +graph.add_edge("caption_image", "download_image") +graph.add_edge("download_image", END) + +app = graph.compile() + +#################################### Run ###################################### +question = "Generate an image for the 'two buttons' meme with first text 'generated meme' and second text 'langgraph error'." + +final_state = app.invoke({"question": question}) +print(final_state["image_path"]) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..170e543 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +langchain_community +langchain_core>=0.1.35 +langchain_experimental +langchain_ollama +langgraph +pillow +pydantic diff --git a/tools/caption_image.py b/tools/caption_image.py new file mode 100644 index 0000000..373be40 --- /dev/null +++ b/tools/caption_image.py @@ -0,0 +1,40 @@ +import json +import requests + +CAPTION_IMAGE_URL = "https://api.imgflip.com/caption_image" + +def load_config(): + with open('tools/config.json') as config_file: + return json.load(config_file) + +def caption_image(input_data): + # Replace single quotes with double quotes because langchain likes to use single quotes + input_data = input_data.replace("'", '"') + + data = json.loads(input_data) + template_id = data['template_id'] + text = data['text'] + + config = load_config() + username = config['username'] + password = config['password'] + + url = CAPTION_IMAGE_URL + payload = { + "template_id": template_id, + "username": username, + "password": password, + } + + for i in range(len(text)): + payload[f'text{i}'] = text[i] + + response = requests.post(url, data=payload) + result = response.json() + + if result['success']: + meme_url = result['data']['url'] + print(f"Meme created! URL: {meme_url}") + return meme_url + else: + return None diff --git a/tools/config.json.blanked b/tools/config.json.blanked new file mode 100644 index 0000000..8d9a35c --- /dev/null +++ b/tools/config.json.blanked @@ -0,0 +1,4 @@ +{ + "username":"xxxx", + "password":"yyyy" +} diff --git a/tools/download_image.py b/tools/download_image.py new file mode 100644 index 0000000..c50d6b3 --- /dev/null +++ b/tools/download_image.py @@ -0,0 +1,22 @@ +import os +import requests +import json + +def download_image(input_data): + input_data = input_data.replace("'", '"') + + data = json.loads(input_data) + url = data['url'] + + response = requests.get(url) + if response.status_code == 200: + output_dir = 'output' + os.makedirs(output_dir, exist_ok=True) + image_path = os.path.join(output_dir, url.split("/")[-1]) + with open(image_path, 'wb') as f: + f.write(response.content) + print(f"Image saved to {image_path}") + return image_path + else: + print("Failed to download image.") + return None diff --git a/tools/get_memes.py b/tools/get_memes.py new file mode 100644 index 0000000..e1a25ea --- /dev/null +++ b/tools/get_memes.py @@ -0,0 +1,18 @@ +import requests + +GET_MEMES_URL = "https://api.imgflip.com/get_memes" + +def get_memes(): + return get_memes_helper() + +def get_memes_helper(): + response = requests.get(GET_MEMES_URL) + result = response.json() + + if result['success']: + memes = result['data']['memes'] + memes_str = "\n".join([f"ID: {meme['id']}, Name: {meme['name']}" for meme in memes]) + return memes_str + else: + return "No template_ids found." +