mirror of
https://github.com/runyanjake/memegraph.git
synced 2025-10-04 23:17:31 -07:00
Update main.py
This commit is contained in:
parent
7798f7b8a2
commit
82628f0f19
36
main.py
36
main.py
@ -10,10 +10,10 @@ from tools.download_image import download_image
|
|||||||
|
|
||||||
from typing import TypedDict, Optional
|
from typing import TypedDict, Optional
|
||||||
|
|
||||||
|
import argparse
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
|
||||||
############################# State Definition ################################
|
############################# State Definition ################################
|
||||||
class State(TypedDict):
|
class State(TypedDict):
|
||||||
question: str
|
question: str
|
||||||
@ -24,7 +24,6 @@ class State(TypedDict):
|
|||||||
image_path: Optional[str]
|
image_path: Optional[str]
|
||||||
|
|
||||||
############################# LLM Definition ##################################
|
############################# LLM Definition ##################################
|
||||||
|
|
||||||
llm = OllamaLLM(model="llama3")
|
llm = OllamaLLM(model="llama3")
|
||||||
|
|
||||||
class MemeInfo(BaseModel):
|
class MemeInfo(BaseModel):
|
||||||
@ -46,7 +45,7 @@ def parse_input(state: State) -> State:
|
|||||||
response = llm.invoke(prompt)
|
response = llm.invoke(prompt)
|
||||||
parsed = parser.invoke(response)
|
parsed = parser.invoke(response)
|
||||||
print(parsed)
|
print(parsed)
|
||||||
return {**state, **parsed.dict()}
|
return {**state, **parsed.model_dump()}
|
||||||
|
|
||||||
######################## Graph Components Definition ##########################
|
######################## Graph Components Definition ##########################
|
||||||
def node_get_memes(state: State) -> State:
|
def node_get_memes(state: State) -> State:
|
||||||
@ -63,7 +62,7 @@ def node_get_memes(state: State) -> State:
|
|||||||
print("Found Meme: " + meme)
|
print("Found Meme: " + meme)
|
||||||
return {**state, "template_id": meme_id}
|
return {**state, "template_id": meme_id}
|
||||||
|
|
||||||
raise ValueError(f"No matching meme found for template_name: {state['template_name']}")
|
raise ValueError(f"No matching meme found for template_name: {state['template_name']}\nAvailable Memes:\n{memes}")
|
||||||
|
|
||||||
def node_caption_image(state: State) -> State:
|
def node_caption_image(state: State) -> State:
|
||||||
input_data = {
|
input_data = {
|
||||||
@ -76,15 +75,13 @@ def node_caption_image(state: State) -> State:
|
|||||||
return {**state, "caption_result": result}
|
return {**state, "caption_result": result}
|
||||||
|
|
||||||
def node_download_image(state: State) -> State:
|
def node_download_image(state: State) -> State:
|
||||||
caption_result = state["caption_result"]
|
image_url = state["caption_result"]
|
||||||
print("Generated Image: " + caption_result)
|
if isinstance(image_url, str) and image_url.startswith("http"):
|
||||||
if isinstance(caption_result, str) and caption_result.startswith("http"):
|
download_payload = json.dumps({"url": image_url})
|
||||||
download_payload = json.dumps({"url": caption_result})
|
|
||||||
result = download_image(download_payload)
|
result = download_image(download_payload)
|
||||||
print("Saved to: " + result)
|
|
||||||
return {**state, "image_path": result}
|
return {**state, "image_path": result}
|
||||||
else:
|
else:
|
||||||
print("Failed to download image. " + caption_result)
|
print("Failed to download image. " + image_url)
|
||||||
|
|
||||||
############################## Graph Definition ###############################
|
############################## Graph Definition ###############################
|
||||||
graph = StateGraph(State)
|
graph = StateGraph(State)
|
||||||
@ -103,7 +100,20 @@ graph.add_edge("download_image", END)
|
|||||||
app = graph.compile()
|
app = graph.compile()
|
||||||
|
|
||||||
#################################### Run ######################################
|
#################################### Run ######################################
|
||||||
question = "Generate an image for the 'two buttons' meme with first text 'generated meme' and second text 'langgraph error'."
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Meme Generation Script")
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt",
|
||||||
|
type=str,
|
||||||
|
help="Optional prompt for the meme generation",
|
||||||
|
default="Generate an image for the 'two buttons' meme with first text 'generated meme' and second text 'langgraph error'."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
question = args.prompt
|
||||||
|
final_state = app.invoke({"question": question})
|
||||||
|
print(final_state["image_path"])
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
final_state = app.invoke({"question": question})
|
|
||||||
print(final_state["image_path"])
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user