Update main.py

This commit is contained in:
whitney 2025-04-14 14:01:30 -07:00
parent 7798f7b8a2
commit 82628f0f19

36
main.py
View File

@ -10,10 +10,10 @@ from tools.download_image import download_image
from typing import TypedDict, Optional
import argparse
import json
import re
############################# State Definition ################################
class State(TypedDict):
question: str
@ -24,7 +24,6 @@ class State(TypedDict):
image_path: Optional[str]
############################# LLM Definition ##################################
llm = OllamaLLM(model="llama3")
class MemeInfo(BaseModel):
@ -46,7 +45,7 @@ def parse_input(state: State) -> State:
response = llm.invoke(prompt)
parsed = parser.invoke(response)
print(parsed)
return {**state, **parsed.dict()}
return {**state, **parsed.model_dump()}
######################## Graph Components Definition ##########################
def node_get_memes(state: State) -> State:
@ -63,7 +62,7 @@ def node_get_memes(state: State) -> State:
print("Found Meme: " + meme)
return {**state, "template_id": meme_id}
raise ValueError(f"No matching meme found for template_name: {state['template_name']}")
raise ValueError(f"No matching meme found for template_name: {state['template_name']}\nAvailable Memes:\n{memes}")
def node_caption_image(state: State) -> State:
input_data = {
@ -76,15 +75,13 @@ def node_caption_image(state: State) -> State:
return {**state, "caption_result": result}
def node_download_image(state: State) -> State:
caption_result = state["caption_result"]
print("Generated Image: " + caption_result)
if isinstance(caption_result, str) and caption_result.startswith("http"):
download_payload = json.dumps({"url": caption_result})
image_url = state["caption_result"]
if isinstance(image_url, str) and image_url.startswith("http"):
download_payload = json.dumps({"url": image_url})
result = download_image(download_payload)
print("Saved to: " + result)
return {**state, "image_path": result}
else:
print("Failed to download image. " + caption_result)
print("Failed to download image. " + image_url)
############################## Graph Definition ###############################
graph = StateGraph(State)
@ -103,7 +100,20 @@ graph.add_edge("download_image", END)
app = graph.compile()
#################################### Run ######################################
question = "Generate an image for the 'two buttons' meme with first text 'generated meme' and second text 'langgraph error'."
def main():
parser = argparse.ArgumentParser(description="Meme Generation Script")
parser.add_argument(
"--prompt",
type=str,
help="Optional prompt for the meme generation",
default="Generate an image for the 'two buttons' meme with first text 'generated meme' and second text 'langgraph error'."
)
args = parser.parse_args()
question = args.prompt
final_state = app.invoke({"question": question})
print(final_state["image_path"])
if __name__ == "__main__":
main()
final_state = app.invoke({"question": question})
print(final_state["image_path"])