# Mirror of https://github.com/runyanjake/memechain.git
# Synced 2025-10-04 15:57:28 -07:00 (53 lines, 2.6 KiB, Python)
from langchain_core.prompts import ChatPromptTemplate
|
|
from langchain_ollama.llms import OllamaLLM
|
|
from langchain_core.tools import Tool
|
|
from langchain.agents import AgentExecutor, AgentType, Tool, initialize_agent
|
|
|
|
import requests
|
|
import os
|
|
|
|
from tools.get_memes import get_memes
|
|
from tools.caption_image import caption_image
|
|
from tools.download_image import download_image
|
|
|
|
# Instruction text for the meme-making assistant: find the template id,
# caption the image, download it -- each tool once, in that order -- then
# answer with the resulting URL.
system_prompt = """
You are an assistant that helps users create memes using the Imgflip API.

Given a prompt from the user you must perform the following steps in order, once, before responding:
1. Search for the numerical template_id of a requested meme using the "Get Memes" tool.
2. Generate a meme using the "Caption Image" tool once the template_id is found. The tool input must be valid JSON with the keys: "template_id" (integer) and "text" (list of strings).
3. Download the generated meme using the "Download Image" tool.

Your tool invocations must match the exact string of one of the tools listed above. Do not use any tool more than once.
Once the Download Image tool has been used, do not use any other tools. Parse the url out of the tool's response and respond with the url.
"""

# Complete prompt text: the assistant instructions, the user's question, and
# an answer lead-in. The doubled braces keep a literal `{question}` in the
# result of this f-string so it remains a template placeholder.
prompt_template = f"""
{system_prompt}

Question: {{question}}

Answer: Let's think step by step. We should generate exactly one meme given the directions of the user. I will not ask the user for additional input after their request. Once the meme is created, I will conclude our conversation.
"""

# NOTE(review): `prompt` is not referenced again in the visible script, so
# these instructions do not appear to reach the agent -- confirm whether it
# was meant to be wired into the agent's prompt.
prompt = ChatPromptTemplate.from_template(prompt_template)

# Tools offered to the agent. `name` is the label the model uses to pick a
# tool and `description` is the guidance text it is shown, so both are kept
# free of typos and refer to tools by their registered names.
tools = [
    Tool(
        name="Get Memes",
        func=get_memes,
        description=(
            "Does not take any arguments. Returns a list of template_ids "
            "(integer) and names (string) which are the titles of the memes "
            "that correspond to the template_id."
        ),
    ),
    Tool(
        name="Caption Image",
        func=caption_image,
        description=(
            "Given a template_id and list of text strings, returns the url "
            "of a new meme as a string. Tool input is valid json syntax, "
            "with the following keys: 'template_id' (integer) and 'text' "
            "(list of strings)."
        ),
    ),
    Tool(
        name="Download Image",
        func=download_image,
        description=(
            "Given a valid url returned from the \"Caption Image\" tool, "
            "downloads the image we made locally. Tool input is valid json "
            "syntax, with the following key: 'url' (string)."
        ),
    ),
]

# Ollama-backed llama3 model that drives the agent.
llm = OllamaLLM(model="llama3")

# Build the executor over the tools defined above.
# The agent kind is selected through the `agent` keyword; `agent_type` is not
# a parameter of `initialize_agent`, so passing it under that name never
# bound the requested type to the agent selection.
agent_executor = initialize_agent(
    tools=tools,
    llm=llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
)

# One-shot demo run: ask the agent for a captioned "two buttons" meme and
# print whatever the executor returns.
response = agent_executor.invoke(
    {
        "input": (
            "Generate an image for the 'two buttons' meme "
            "with first text 'generated meme' "
            "and second text 'langchain error'."
        ),
    }
)
print(response)