# Mirror of https://github.com/runyanjake/memegraph.git
# Synced 2025-10-04 15:07:30 -07:00 (120 lines, 3.8 KiB, Python)
from langchain_ollama.llms import OllamaLLM
|
|
from langchain_core.output_parsers import PydanticOutputParser
|
|
from langgraph.graph import StateGraph, END
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from tools.get_memes import get_memes
|
|
from tools.caption_image import caption_image
|
|
from tools.download_image import download_image
|
|
|
|
from typing import TypedDict, Optional
|
|
|
|
import argparse
|
|
import json
|
|
import re
|
|
|
|
############################# State Definition ################################
|
|
# Shared state dictionary threaded through every node of the meme graph.
# Nodes read from it and return an updated copy; main() starts the graph
# with only "question" populated, so the other keys are filled in as the
# pipeline progresses.
State = TypedDict(
    "State",
    {
        # Free-form user request supplied when the graph is invoked.
        "question": str,
        # Meme template name extracted from the request by parse_input().
        "template_name": Optional[str],
        # Caption strings extracted by parse_input(); forwarded as-is to
        # caption_image() by node_caption_image().
        "text": Optional[list[str]],
        # Numeric template ID resolved by node_get_memes().
        "template_id": Optional[int],
        # Raw return value of caption_image(). node_download_image() only
        # proceeds when this is a str starting with "http".
        # NOTE(review): annotated as dict, but the consumer expects a URL
        # string -- confirm what caption_image() actually returns.
        "caption_result": Optional[dict],
        # Return value of download_image() for the captioned meme.
        "image_path": Optional[str],
    },
)
|
|
|
|
############################# LLM Definition ##################################
|
|
# Ollama-served "llama3" model; parse_input() sends its extraction prompt here.
llm = OllamaLLM(
    model="llama3",
)
|
|
|
|
# Structured result the LLM must return for a meme request.
# NOTE: intentionally no class docstring -- pydantic copies a model docstring
# into its JSON schema, which would alter the format instructions that
# parse_input() embeds in the prompt.
class MemeInfo(BaseModel):
    # Required: name of the meme template to look up in node_get_memes().
    template_name: str = Field(default=..., description="The meme template name")
    # Required: caption strings, forwarded unchanged to caption_image().
    text: list[str] = Field(default=..., description="List of captions/texts for the meme")
|
|
|
|
# Parses the LLM's raw reply into a MemeInfo and provides the JSON format
# instructions that parse_input() puts in its prompt. (main() has its own
# local argparse variable also named "parser"; it does not touch this one.)
parser = PydanticOutputParser(
    pydantic_object=MemeInfo,
)
|
|
|
|
def parse_input(state: State) -> State:
    """Ask the LLM to extract a meme template name and caption list.

    Args:
        state: Graph state; only ``state["question"]`` is read.

    Returns:
        A new state dict: every existing key copied, plus ``template_name``
        and ``text`` taken from the parsed MemeInfo.

    Raises:
        Propagates any exception from ``llm.invoke`` or ``parser.invoke``
        (e.g. when the reply does not match the MemeInfo schema).
    """
    # Same evaluation order as an inline f-string: instructions first,
    # then the question lookup.
    format_instructions = parser.get_format_instructions()
    question = state["question"]
    prompt = f"""
    Extract the meme name and list of texts from the user's request.
    Return the output in this JSON format:

    {format_instructions}

    User input: {question}
    """

    raw_reply = llm.invoke(prompt)
    meme_info = parser.invoke(raw_reply)
    print(meme_info)  # debug trace of the structured extraction

    updated = dict(state)
    updated.update(meme_info.model_dump())
    return updated
|
|
|
|
######################## Graph Components Definition ##########################
|
|
# Shape of one line in the get_memes() listing, e.g. "ID: 123, Name: Foo".
# Compiled once instead of on every loop iteration.
_MEME_LINE_RE = re.compile(r"ID: (\d+), Name: (.+)")


def node_get_memes(state: State) -> State:
    """Resolve ``state["template_name"]`` to a numeric meme template ID.

    ``get_memes()`` returns a newline-separated listing; lines shaped like
    ``"ID: <int>, Name: <name>"`` are considered and all others are
    ignored. Names are compared case-insensitively after trimming
    surrounding whitespace on both sides.

    Args:
        state: Graph state; ``template_name`` is read.

    Returns:
        A new state dict with ``template_id`` set to the first matching ID.

    Raises:
        ValueError: if ``template_name`` is missing or blank, or if no
            listed meme has that name.
    """
    # Guard explicitly: a None name used to fail with an opaque
    # AttributeError on .lower().
    wanted = (state.get("template_name") or "").strip()
    if not wanted:
        raise ValueError(
            "No template_name was extracted from the request; cannot look up a meme."
        )
    wanted_key = wanted.casefold()

    # Strip each line so indentation/trailing spaces in the listing don't
    # prevent the anchored regex from matching.
    memes = [line.strip() for line in get_memes().strip().splitlines()]

    for meme in memes:
        match = _MEME_LINE_RE.match(meme)
        if match is None:
            continue
        meme_id = int(match.group(1))
        meme_name = match.group(2).strip()
        if meme_name.casefold() == wanted_key:
            print("Found Meme: " + meme)
            return {**state, "template_id": meme_id}

    # One meme per line reads far better than a Python list repr.
    available = "\n".join(memes)
    raise ValueError(
        f"No matching meme found for template_name: {state['template_name']}\n"
        f"Available Memes:\n{available}"
    )
|
|
|
|
def node_caption_image(state: State) -> State:
    """Call ``caption_image`` for the resolved template and captions.

    Sends a JSON object ``{"template_id": <int>, "text": <captions>}`` and
    stores whatever ``caption_image`` returns under ``caption_result``.

    Args:
        state: Graph state; ``template_id`` (coerced with ``int``) and
            ``text`` are read.

    Returns:
        A new state dict with ``caption_result`` set.
    """
    request_fields = dict(
        template_id=int(state["template_id"]),
        text=state["text"],
    )
    caption_output = caption_image(json.dumps(request_fields))
    print(caption_output)
    return dict(state, caption_result=caption_output)
|
|
|
|
def node_download_image(state: State) -> State:
    """Download the captioned meme when ``caption_result`` is a URL string.

    Args:
        state: Graph state; ``caption_result`` is read.

    Returns:
        A new state dict with ``image_path`` set either to what
        ``download_image`` returned, or to ``None`` when ``caption_result``
        is not a str starting with ``"http"``. In that case the failure is
        reported on stdout rather than raised.
    """
    image_url = state.get("caption_result")
    if isinstance(image_url, str) and image_url.startswith("http"):
        download_payload = json.dumps({"url": image_url})
        result = download_image(download_payload)
        return {**state, "image_path": result}

    # image_url may be a dict or None here. Format it instead of using
    # str + image_url, which raised TypeError and masked the real failure.
    print(f"Failed to download image. {image_url}")
    # Set image_path explicitly so callers can read it without a KeyError.
    return {**state, "image_path": None}
|
|
|
|
############################## Graph Definition ###############################
|
|
def _build_graph() -> StateGraph:
    """Wire the linear pipeline parse_input -> get_memes -> caption_image -> download_image -> END."""
    # Execution order of the nodes; edges connect each entry to the next.
    pipeline = [
        ("parse_input", parse_input),
        ("get_memes", node_get_memes),
        ("caption_image", node_caption_image),
        ("download_image", node_download_image),
    ]

    builder = StateGraph(State)
    for node_name, node_fn in pipeline:
        builder.add_node(node_name, node_fn)

    builder.set_entry_point(pipeline[0][0])

    node_names = [node_name for node_name, _ in pipeline]
    # Pair every node with its successor; the last one flows into END.
    for upstream, downstream in zip(node_names, node_names[1:] + [END]):
        builder.add_edge(upstream, downstream)

    return builder


graph = _build_graph()

# Compiled, invokable form of the graph used by main().
app = graph.compile()
|
|
|
|
#################################### Run ######################################
|
|
def main():
    """Command-line entry point: generate a meme from ``--prompt`` and print its path.

    Raises:
        SystemExit: with a message (non-zero exit status) when the graph
            finishes without an ``image_path``.
    """
    # Named arg_parser so it does not shadow the module-level output parser.
    arg_parser = argparse.ArgumentParser(description="Meme Generation Script")
    arg_parser.add_argument(
        "--prompt",
        type=str,
        help="Optional prompt for the meme generation",
        default="Generate an image for the 'two buttons' meme with first text 'generated meme' and second text 'langgraph error'.",
    )
    args = arg_parser.parse_args()

    final_state = app.invoke({"question": args.prompt})

    # The download step may leave image_path unset or None. Report that
    # clearly instead of crashing with a KeyError traceback.
    image_path = final_state.get("image_path")
    if image_path is None:
        raise SystemExit("Meme generation finished without an image_path.")
    print(image_path)


if __name__ == "__main__":
    main()
|
|
|