From 82628f0f1996fdc767b6d2f0a5157949b0c350d3 Mon Sep 17 00:00:00 2001 From: whitney Date: Mon, 14 Apr 2025 14:01:30 -0700 Subject: [PATCH] Update main.py --- main.py | 36 +++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/main.py b/main.py index 0451d7b..d94d830 100644 --- a/main.py +++ b/main.py @@ -10,10 +10,10 @@ from tools.download_image import download_image from typing import TypedDict, Optional +import argparse import json import re - ############################# State Definition ################################ class State(TypedDict): question: str @@ -24,7 +24,6 @@ class State(TypedDict): image_path: Optional[str] ############################# LLM Definition ################################## - llm = OllamaLLM(model="llama3") class MemeInfo(BaseModel): @@ -46,7 +45,7 @@ def parse_input(state: State) -> State: response = llm.invoke(prompt) parsed = parser.invoke(response) print(parsed) - return {**state, **parsed.dict()} + return {**state, **parsed.model_dump()} ######################## Graph Components Definition ########################## def node_get_memes(state: State) -> State: @@ -63,7 +62,7 @@ def node_get_memes(state: State) -> State: print("Found Meme: " + meme) return {**state, "template_id": meme_id} - raise ValueError(f"No matching meme found for template_name: {state['template_name']}") + raise ValueError(f"No matching meme found for template_name: {state['template_name']}\nAvailable Memes:\n{memes}") def node_caption_image(state: State) -> State: input_data = { @@ -76,15 +75,13 @@ def node_caption_image(state: State) -> State: return {**state, "caption_result": result} def node_download_image(state: State) -> State: - caption_result = state["caption_result"] - print("Generated Image: " + caption_result) - if isinstance(caption_result, str) and caption_result.startswith("http"): - download_payload = json.dumps({"url": caption_result}) + image_url = state["caption_result"] + if isinstance(image_url, str) and image_url.startswith("http"): + download_payload = json.dumps({"url": image_url}) result = download_image(download_payload) - print("Saved to: " + result) return {**state, "image_path": result} else: - print("Failed to download image. " + caption_result) + print("Failed to download image. " + image_url) ############################## Graph Definition ############################### graph = StateGraph(State) @@ -103,7 +100,20 @@ graph.add_edge("download_image", END) app = graph.compile() #################################### Run ###################################### -question = "Generate an image for the 'two buttons' meme with first text 'generated meme' and second text 'langgraph error'." +def main(): + parser = argparse.ArgumentParser(description="Meme Generation Script") + parser.add_argument( + "--prompt", + type=str, + help="Optional prompt for the meme generation", + default="Generate an image for the 'two buttons' meme with first text 'generated meme' and second text 'langgraph error'." + ) + args = parser.parse_args() + + question = args.prompt + final_state = app.invoke({"question": question}) + print(final_state["image_path"]) + +if __name__ == "__main__": + main() -final_state = app.invoke({"question": question}) -print(final_state["image_path"])