# memegraph/main.py
#
# 110 lines
# 3.6 KiB
# Python

from langchain_ollama.llms import OllamaLLM
from langchain_core.output_parsers import PydanticOutputParser
from langgraph.graph import StateGraph, END
from pydantic import BaseModel, Field
from tools.get_memes import get_memes
from tools.caption_image import caption_image
from tools.download_image import download_image
from typing import TypedDict, Optional
import json
import re
############################# State Definition ################################
class State(TypedDict):
    """Shared state dictionary passed between the graph nodes in this file."""

    # The user's natural-language request; read by parse_input.
    question: str
    # Meme template name; filled in by parse_input, matched in node_get_memes.
    template_name: Optional[str]
    # Caption texts; filled in by parse_input, sent on by node_caption_image.
    text: Optional[list[str]]
    # Numeric template ID; set by node_get_memes.
    template_id: Optional[int]
    # Return value of caption_image(); set by node_caption_image.
    # NOTE(review): annotated as dict, but node_download_image treats it as a
    # str (checks startswith("http")) - confirm what caption_image returns.
    caption_result: Optional[dict]
    # Return value of download_image(); set by node_download_image.
    image_path: Optional[str]
############################# LLM Definition ##################################
# Module-level Ollama LLM client for the "llama3" model; invoked by parse_input.
llm = OllamaLLM(model="llama3")
# Structured fields that parse_input extracts from the user's request.
# (Documented with comments rather than a class docstring so the schema the
# parser embeds in the prompt is left untouched.)
class MemeInfo(BaseModel):
    # Compared case-insensitively against template names in node_get_memes.
    template_name: str = Field(..., description="The meme template name")
    # Passed on as the "text" value of the caption_image payload.
    text: list[str] = Field(..., description="List of captions/texts for the meme")
# Output parser for MemeInfo: supplies the format instructions embedded in the
# parse_input prompt and parses the LLM reply.
parser = PydanticOutputParser(pydantic_object=MemeInfo)
def parse_input(state: State) -> State:
    """Extract the meme template name and caption texts from the user's request.

    Builds a prompt from ``state["question"]`` and the parser's format
    instructions, sends it to the module-level ``llm``, and parses the reply
    with the module-level ``parser``.

    Args:
        state: Current graph state; only ``state["question"]`` is read.

    Returns:
        A copy of ``state`` updated with the fields of the parsed model.

    Raises:
        Exceptions raised by ``llm.invoke`` or ``parser.invoke`` propagate
        unchanged; nothing is caught here.
    """
    prompt = f"""
Extract the meme name and list of texts from the user's request.
Return the output in this JSON format:
{parser.get_format_instructions()}
User input: {state["question"]}
"""
    response = llm.invoke(prompt)
    parsed = parser.invoke(response)
    print(parsed)
    # Pydantic v2 deprecates BaseModel.dict() in favour of model_dump(); keep
    # the dict() fallback so the node still works on a Pydantic v1 install.
    if hasattr(parsed, "model_dump"):
        fields = parsed.model_dump()
    else:
        fields = parsed.dict()
    return {**state, **fields}
######################## Graph Components Definition ##########################
# Matches one line of get_memes() output of the form "ID: <digits>, Name: <name>".
_MEME_LINE_RE = re.compile(r"ID: (\d+), Name: (.+)")


def node_get_memes(state: State) -> State:
    """Find the template ID whose name matches ``state["template_name"]``.

    Each line returned by ``get_memes()`` is matched against
    ``"ID: <digits>, Name: <name>"``; lines that do not match are skipped.
    Names are compared after stripping surrounding whitespace and case-folding.

    Args:
        state: Current graph state; ``template_name`` is read.

    Returns:
        A copy of ``state`` with ``"template_id"`` set to the matching ID.

    Raises:
        ValueError: If ``template_name`` is missing/empty or no line matches.
    """
    wanted = state.get("template_name")
    # Fail with the same error type as "not found" instead of an
    # AttributeError on None when the parse step yielded no template name.
    if not wanted or not wanted.strip():
        raise ValueError(f"No matching meme found for template_name: {wanted}")
    # Normalise once, outside the loop.
    wanted_key = wanted.strip().casefold()

    for meme in get_memes().strip().splitlines():
        match = _MEME_LINE_RE.match(meme)
        if match is None:
            continue
        if match.group(2).strip().casefold() == wanted_key:
            print("Found Meme: " + meme)
            return {**state, "template_id": int(match.group(1))}
    raise ValueError(f"No matching meme found for template_name: {state['template_name']}")
def node_caption_image(state: State) -> State:
    """Caption the selected template and record the tool's result.

    Serialises ``template_id`` (coerced to ``int``) and ``text`` from ``state``
    into a JSON string, passes it to ``caption_image``, prints whatever comes
    back, and returns a copy of ``state`` with that value stored under
    ``"caption_result"``.
    """
    payload = json.dumps(
        {"template_id": int(state["template_id"]), "text": state["text"]}
    )
    captioned = caption_image(payload)
    print(captioned)

    updated = dict(state)
    updated["caption_result"] = captioned
    return updated
def node_download_image(state: State) -> State:
    """Download the captioned image when ``caption_result`` looks like a URL.

    Args:
        state: Current graph state; ``caption_result`` is read.

    Returns:
        A copy of ``state`` with ``"image_path"`` set to the value returned by
        ``download_image`` when ``caption_result`` is a string starting with
        ``"http"``, and to ``None`` otherwise.
    """
    caption_result = state.get("caption_result")
    # f-strings rather than "+" so a non-str result (dict, None, ...) is
    # reported instead of raising TypeError before the type check below.
    print(f"Generated Image: {caption_result}")

    if not (isinstance(caption_result, str) and caption_result.startswith("http")):
        print(f"Failed to download image. {caption_result}")
        # Always set the key: the run section reads "image_path" from the
        # final state, and an implicit None return would leave it unset.
        return {**state, "image_path": None}

    download_payload = json.dumps({"url": caption_result})
    result = download_image(download_payload)
    print(f"Saved to: {result}")
    return {**state, "image_path": result}
############################## Graph Definition ###############################
# Linear workflow: each node runs once, in the order listed, then the graph ends.
graph = StateGraph(State)

_pipeline = (
    ("parse_input", parse_input),
    ("get_memes", node_get_memes),
    ("caption_image", node_caption_image),
    ("download_image", node_download_image),
)
for _node_name, _node_fn in _pipeline:
    graph.add_node(_node_name, _node_fn)

# The first stage is the entry point; every stage feeds the next, and the
# last one feeds END.
graph.set_entry_point(_pipeline[0][0])
_order = [name for name, _ in _pipeline] + [END]
for _src, _dst in zip(_order, _order[1:]):
    graph.add_edge(_src, _dst)

app = graph.compile()
#################################### Run ######################################
if __name__ == "__main__":
    # Run the demo only when executed as a script, so importing this module
    # (e.g. to reuse `app`) does not trigger an LLM call.
    question = "Generate an image for the 'two buttons' meme with first text 'generated meme' and second text 'langgraph error'."
    final_state = app.invoke({"question": question})
    # .get(): "image_path" can be absent when the download node made no update.
    print(final_state.get("image_path"))