Skip to content

PromptAgent

About

PromptAgent optimizes a prompt by scoring, generating feedback, and generating new prompts based on the feedback. PromptAgent starts from a seed prompt with known errors on the training data. At each step, a scored prompt and a sample of its errors are passed to a language model to produce a feedback "action". Then the prompt, errors, trajectory, and feedback action are passed to a language model to produce new prompts. These new prompts are scored, and either the best prompt from each branch (search_mode="beam") or all prompts (search_mode="greedy") are retained for the next step.

Usage

The PromptAgentOptimizer requires a description of the failures after each step. You must provide this feedback in your evaluator by capturing errors and saving them in the prompt object's errors attribute.

[!IMPORTANT] Your evaluator function **MUST** save any errors to the prompt object's errors attribute; otherwise the optimization will fail.

# Fix: import was misspelled "lagnchain_openai", which raises ModuleNotFoundError.
from langchain_openai import ChatOpenAI
from prompt_optimizer import PredictionError, Prompt
from prompt_optimizer.optimizers import PromptAgentOptimizer

# Simple QA validation set: each row pairs a question with its exact expected answer.
validation_set = [
    {"question": "What is the capital of France?", "answer": "Paris"},
    {"question": "What is the largest planet in our solar system?", "answer": "Jupiter"},
    {"question": "What is the smallest planet in our solar system?", "answer": "Mercury"},
    {"question": "What is the longest river in the world?", "answer": "Nile"},
    {"question": "What is the smallest river in the world?", "answer": "Reprua River"},
]

# A langchain ChatModel for generating new prompts
client = ChatOpenAI(model="gpt-5", temperature=0.7)

# Evaluator function
def evaluator(prompt: Prompt, validation_set: list[dict]) -> float:
    """
    Score a prompt by exact-match accuracy and record errors on the prompt.

    Args:
        prompt (Prompt):
            The prompt under evaluation. Mismatched predictions are appended
            to ``prompt.errors`` — required by PromptAgentOptimizer.
        validation_set (list[dict]):
            Rows with "question" and "answer" keys.

    Returns:
        float: Fraction of exact-match answers, in [0, 1].

    """
    num_correct = 0
    agent = get_agent()  # NOTE(review): get_agent is assumed to be defined elsewhere in the docs
    for row in validation_set:
        # Fix: keys must match the validation_set above (was row["input"] / row["target"]).
        question = row["question"]
        messages = [{"role": "system", "content": prompt.content}, {"role": "user", "content": question}]
        response = agent.invoke(messages)
        prediction = response.content.strip()

        # Reward exact matches and collect errors
        actual = row["answer"]
        if actual == prediction:
            num_correct += 1
        else:
            # Save prediction error - Required for PromptAgentOptimizer
            error = PredictionError(input=question, prediction=prediction, actual=actual, feedback=None)
            prompt.errors.append(error)

    # Compute the score (exact-match accuracy)
    score = num_correct / len(validation_set)

    return score

# Initialize the optimizer
baseline_prompt = "Answer the user's questions to the best of your ability."
optimizer = PromptAgentOptimizer(
    client=client,
    # NOTE(review): the signature documents list[Prompt]; a raw string is passed
    # here, so presumably seeds are coerced to Prompt internally — confirm.
    seed_prompts=[baseline_prompt],
    validation_set=validation_set,
    max_depth=3,  # number of search iterations
    evaluator=evaluator,
)

# Run the optimization; returns the highest-scoring prompt found.
optimized_prompt = optimizer.run()

Citation

@misc{wang2023promptagentstrategicplanninglanguage,
    title={PromptAgent: Strategic Planning with Language Models Enables Expert-level Prompt Optimization},
    author={Xinyuan Wang and Chenxi Li and Zhen Wang and Fan Bai and Haotian Luo and Jiayou Zhang and Nebojsa Jojic and Eric P. Xing and Zhiting Hu},
    year={2023},
    eprint={2310.16427},
    archivePrefix={arXiv},
    primaryClass={cs.CL},
    url={https://arxiv.org/abs/2310.16427},
}

Source

PromptAgentOptimizer

Bases: BaseOptimizer

PromptAgent Optimizer.

Based on PromptAgent: Strategic Planning with Language Models Enables Expert-level Prompt Optimization.

@misc{wang2023promptagentstrategicplanninglanguage,
    title={PromptAgent: Strategic Planning with Language Models Enables Expert-level Prompt Optimization},
    author={Xinyuan Wang and Chenxi Li and Zhen Wang and Fan Bai and Haotian Luo and Jiayou Zhang and Nebojsa Jojic and Eric P. Xing and Zhiting Hu},
    year={2023},
    eprint={2310.16427},
    archivePrefix={arXiv},
    primaryClass={cs.CL},
    url={https://arxiv.org/abs/2310.16427},
}
Source code in src/prompt_optimizer/optimizers/promptagent.py
class PromptAgentOptimizer(BaseOptimizer):
    """
    PromptAgent Optimizer.

    Based on PromptAgent: Strategic Planning with Language Models Enables Expert-level Prompt Optimization.

    ```
    @misc{wang2023promptagentstrategicplanninglanguage,
        title={PromptAgent: Strategic Planning with Language Models Enables Expert-level Prompt Optimization},
        author={Xinyuan Wang and Chenxi Li and Zhen Wang and Fan Bai and Haotian Luo and Jiayou Zhang and Nebojsa Jojic and Eric P. Xing and Zhiting Hu},
        year={2023},
        eprint={2310.16427},
        archivePrefix={arXiv},
        primaryClass={cs.CL},
        url={https://arxiv.org/abs/2310.16427},
    }
    ```
    """

    def __init__(
        self,
        *,
        client: ClientType,
        seed_prompts: list[Prompt],
        validation_set: ValidationSetType,
        max_depth: int,
        evaluator: Callable[[Prompt, ValidationSetType], ScoreType],
        output_path: Optional[Union[str, Path]] = None,
        batch_size: int = 5,
        expand_width: int = 3,
        num_samples: int = 2,
        search_mode: Literal["beam", "greedy"] = "beam",
        score_threshold: Optional[Union[float, int]] = None,
        **kwargs,
    ):
        """
        Initialize the PromptAgent Optimizer.

        Args:
            client (ClientType):
                Language model client to use for prompt generation and feedback.
            seed_prompts (list[Prompt]):
                List of prompts to seed generation.
            validation_set (ValidationSetType):
                Set of examples to evaluate the prompt on.
            max_depth (int):
                Maximum iteration depth for prompt generation.
            evaluator (Callable[[Prompt, ValidationSetType], ScoreType]):
                Function that takes a prompt and the validation data and returns a score.
            output_path (Union[str, Path], optional):
                Path to store run results. Should be a .jsonl file path.
                If None, no outputs will be written to disk. Defaults to None.
            batch_size (int, optional):
                Number of errors to sample for each action / new prompt generation. Defaults to 5.
            expand_width (int, optional):
                Number of feedback actions to generate per prompt. Defaults to 3.
            num_samples (int, optional):
                Number of new prompts to generate per feedback action. Defaults to 2.
            search_mode (Literal["beam", "greedy"], optional):
                Mode for filtering prompt candidates after each step. "greedy" keeps all prompts from the previous step.
                "beam" keeps only the highest scoring prompt from each branch of the previous step. Defaults to "beam".
            score_threshold (float, optional):
                Threshold for early convergence. If a prompt exceeds this score after any iteration, the optimization loop
                immediately ends. If set to None, the optimization loop will not terminate early. Defaults to None.
            kwargs:
                Additional keyword arguments.

        """
        super().__init__(
            client=client,
            seed_prompts=seed_prompts,
            validation_set=validation_set,
            max_depth=max_depth,
            evaluator=evaluator,
            output_path=output_path,
        )
        self.batch_size = batch_size
        self.expand_width = expand_width
        self.num_samples = num_samples
        self.search_mode = search_mode
        self.score_threshold = score_threshold
        # Metadata key linking a candidate prompt back to its parent prompt.
        self.PARENT_KEY = "_parent"

    def _extract_responses(self, content: str) -> list[str]:
        """
        Extract the responses between <START> and <END>.

        Args:
            content (str): Output string from an LLM generation request.

        Returns:
            list[str]: List of all responses within <START> and </?END> or <START> and </?START>.

        """
        # Non-greedy capture; the closing tag may be <END>, </END>, <START>, or
        # </START>, since models frequently emit a malformed closing tag.
        pattern = r"<START>(.*?)(?:<\/?END>|<\/?START>)"
        matches = re.findall(pattern, content, flags=re.DOTALL)
        return matches

    def _generate(self, metaprompt_template: str, template_kwargs: dict) -> str:
        """
        Generate a completion for a given template and kwargs.

        Args:
            metaprompt_template (str): Template for the metaprompt.
            template_kwargs (dict): Keyword arguments to fill the template values.

        Returns:
            str: The stripped completion text.
                (Docstring fix: this method returns a single string, not list[str].)

        """
        metaprompt = metaprompt_template.format(**template_kwargs)
        # Renamed local from `input` to avoid shadowing the builtin.
        messages = [{"role": "user", "content": metaprompt}]
        raw_response = self.client.invoke(input=messages)
        response = raw_response.content.strip()
        return response

    def _map_trajectory(self, prompt: Optional[Prompt] = None):
        """
        Recursively render the prompt's ancestry as "Prompt/Score" lines, oldest first.

        Args:
            prompt (Prompt, optional): Leaf prompt of the trajectory. Defaults to None.

        Returns:
            str: Trajectory text, or "" when there is no prompt.

        """
        if prompt is None:
            return ""
        # Walk up the parent chain stored in metadata; root prompts have no parent.
        parent = prompt.metadata.get(self.PARENT_KEY, None)
        return (self._map_trajectory(parent) + "\n\n" + f"Prompt: {prompt.content}\nScore: {prompt.score}").strip()

    def generate_prompt_candidates(self, *, prompts: list[Prompt], **kwargs) -> list[Prompt]:
        """
        Generate prompt candidates using gradients.

        For each prompt with recorded errors, sample error batches, generate a
        feedback "action", and generate new candidate prompts from that action.

        Args:
            prompts (list[Prompt]): Prompts retained from the previous step.
            kwargs: Additional keyword arguments (unused here).

        Returns:
            list[Prompt]: All newly generated candidate prompts.

        """
        prompt_candidates = []
        for prompt in track(prompts, description="Generating prompt candidates", transient=True):
            # Prompts with no recorded errors provide no gradient signal; skip them.
            if len(prompt.errors) == 0:
                continue

            # Map prompt trajectory
            trajectory_prompts = self._map_trajectory(prompt=prompt)

            for _ in range(self.expand_width):
                # Sample and collect errors into error string.
                # NOTE(review): random.choices samples WITH replacement, so the same
                # error may appear more than once in a batch — confirm this is intended.
                error_sample = random.choices(prompt.errors, k=self.batch_size)
                error_string = "\n\n".join(
                    [
                        ERROR_STRING_FEEDBACK_TEMPLATE.format(index=i + 1, input=error.input, prediction=error.prediction, feedback=error.feedback)
                        if error.feedback is not None
                        else ERROR_STRING_TEMPLATE.format(index=i + 1, input=error.input, prediction=error.prediction, actual=error.actual)
                        for i, error in enumerate(error_sample)
                    ]
                )

                # Generate actions
                template_kwargs = {
                    "prompt": prompt.content,
                    "error_string": error_string,
                    "steps_per_gradient": self.num_samples,
                    "trajectory_prompts": trajectory_prompts,
                }
                action = self._generate(metaprompt_template=ERROR_FEEDBACK_TEMPLATE, template_kwargs=template_kwargs)

                # Generate new prompts from the action
                template_kwargs.update({"error_feedback": action})
                raw_new_prompts = self._generate(metaprompt_template=STATE_TRANSIT_PROMPT_TEMPLATE, template_kwargs=template_kwargs)
                new_prompts = self._extract_responses(raw_new_prompts)
                new_prompts = new_prompts[: self.num_samples]
                # Fix: build a fresh metadata dict for each candidate. Previously all
                # siblings shared one dict object, so mutating one candidate's
                # metadata silently changed its siblings' metadata too.
                new_prompt_candidates = [
                    Prompt(content=new_prompt, metadata={self.PARENT_KEY: prompt, "_action": action, "_resampled": False})
                    for new_prompt in new_prompts
                ]

                # Save prompts to prompt candidates
                prompt_candidates.extend(new_prompt_candidates)

        return prompt_candidates

    def _get_best_prompt(self, prompts: list[Prompt]):
        """
        Get the highest scoring prompt.

        Raises:
            ValueError: If any prompt has not been scored yet.

        """
        if any(prompt.score is None for prompt in prompts):
            raise ValueError("All prompts must be scored before calling this function.")
        return max(prompts, key=lambda x: x.score)

    def select_prompt_candidates(self, *, prompts: list[Prompt], validation_set: ValidationSetType) -> list[Prompt]:
        """
        Select prompt candidates according to the search mode.

        Args:
            prompts (list[Prompt]): Candidate prompts generated this step.
            validation_set (ValidationSetType): Data used to score the candidates.

        Returns:
            list[Prompt]: Retained candidates for the next step.

        Raises:
            ValueError: If search_mode is not "beam" or "greedy".

        """
        self._score_prompts(prompts=prompts, validation_set=validation_set)
        # If this is the first iteration, keep all prompts
        if len(self._p) == 1:
            return prompts

        # Otherwise, keep based on search_mode
        elif self.search_mode == "greedy":
            # Keep all prompts
            return prompts

        elif self.search_mode == "beam":
            # Select the best prompt in each branch
            # Split prompts by parent
            parent_to_prompts = {prompt.metadata[self.PARENT_KEY]: [] for prompt in prompts}
            for prompt in prompts:
                parent_to_prompts[prompt.metadata[self.PARENT_KEY]].append(prompt)

            # Get the best prompt from each parent
            return [self._get_best_prompt(prompts=branch_prompts) for _, branch_prompts in parent_to_prompts.items()]

        # Fix: an unrecognized search_mode previously fell through and returned
        # None, surfacing later as a confusing TypeError. Fail loudly instead.
        raise ValueError(f"Unsupported search_mode: {self.search_mode!r}")

    def check_early_convergence(self, *, all_prompts: list[list[Prompt]]):
        """
        Check if the early convergence criteria is met.

        Returns:
            bool: True when the best score so far meets score_threshold.

        """
        # Early stopping is disabled when no threshold was configured.
        if self.score_threshold is None:
            return False

        # Flatten all iterations
        prompts = sum(all_prompts, start=[])

        # Check if early convergence criteria is met
        highest_score = max(prompts, key=lambda x: x.score).score
        if highest_score >= self.score_threshold:
            return True
        return False

    def select_best_prompt(self, *, all_prompts: list[list[Prompt]]) -> Prompt:
        """
        Select the top scoring prompt across all iterations.

        Returns:
            Prompt: The single highest-scoring prompt.

        """
        # Flatten all iterations
        prompts = sum(all_prompts, start=[])

        # Select the single prompt with the highest score
        best_prompt = self._get_best_prompt(prompts=prompts)
        logger.info(f"Best score: {best_prompt.score:.3f}")
        return best_prompt

__init__(*, client, seed_prompts, validation_set, max_depth, evaluator, output_path=None, batch_size=5, expand_width=3, num_samples=2, search_mode='beam', score_threshold=None, **kwargs)

Initialize the PromptAgent Optimizer.

Parameters:

Name Type Description Default
client ClientType

Language model client to use for prompt generation and feedback.

required
seed_prompts list[Prompt]

List of prompts to seed generation.

required
validation_set ValidationSetType

Set of examples to evaluate the prompt on.

required
max_depth int

Maximum iteration depth for prompt generation.

required
evaluator Callable[[Prompt, ValidationSetType], ScoreType]

Function that takes a prompt and the validation data and returns a score.

required
output_path Union[str, Path]

Path to store run results. Should be a .jsonl file path. If None, no outputs will be written to disk. Defaults to None.

None
batch_size int

Number of errors to sample for each action / new prompt generation. Defaults to 5.

5
expand_width int

Number of feedback actions to generate per prompt. Defaults to 3.

3
num_samples int

Number of new prompts to generate per feedback action. Defaults to 2.

2
search_mode Literal['beam', 'greedy']

Mode for filtering prompt candidates after each step. "greedy" keeps all prompts from the previous step. "beam" keeps only the highest scoring prompt from each branch of the previous step. Defaults to "beam".

'beam'
score_threshold float

Threshold for early convergence. If a prompt exceeds this score after any iteration, the optimization loop immediately ends. If set to None, the optimization loop will not terminate early. Defaults to None.

None
kwargs

Additional keyword arguments.

{}
Source code in src/prompt_optimizer/optimizers/promptagent.py
def __init__(
    self,
    *,
    client: ClientType,
    seed_prompts: list[Prompt],
    validation_set: ValidationSetType,
    max_depth: int,
    evaluator: Callable[[Prompt, ValidationSetType], ScoreType],
    output_path: Optional[Union[str, Path]] = None,
    batch_size: int = 5,
    expand_width: int = 3,
    num_samples: int = 2,
    search_mode: Literal["beam", "greedy"] = "beam",
    score_threshold: Optional[Union[float, int]] = None,
    **kwargs,
):
    """
    Initialize the PromptAgent Optimizer.

    Args:
        client (ClientType):
            Language model client to use for prompt generation and feedback.
        seed_prompts (list[Prompt]):
            List of prompts to seed generation.
        validation_set (ValidationSetType):
            Set of examples to evaluate the prompt on.
        max_depth (int):
            Maximum iteration depth for prompt generation.
        evaluator (Callable[[Prompt, ValidationSetType], ScoreType]):
            Function that takes a prompt and the validation data and returns a score.
        output_path (Union[str, Path], optional):
            Path to store run results. Should be a .jsonl file path.
            If None, no outputs will be written to disk. Defaults to None.
        batch_size (int, optional):
            Number of errors to sample for each action / new prompt generation. Defaults to 5.
        expand_width (int, optional):
            Number of feedback actions to generate per prompt. Defaults to 3.
        num_samples (int, optional):
            Number of new prompts to generate per feedback action. Defaults to 2.
        search_mode (Literal["beam", "greedy"], optional):
            Mode for filtering prompt candidates after each step. "greedy" keeps all prompts from the previous step.
            "beam" keeps only the highest scoring prompt from each branch of the previous step. Defaults to "beam".
        score_threshold (float, optional):
            Threshold for early convergence. If a prompt exceeds this score after any iteration, the optimization loop
            immediately ends. If set to None, the optimization loop will not terminate early. Defaults to None.
        kwargs:
            Additional keyword arguments.

    """
    # Shared setup (client, seeds, evaluator, output path) lives in BaseOptimizer.
    super().__init__(
        client=client,
        seed_prompts=seed_prompts,
        validation_set=validation_set,
        max_depth=max_depth,
        evaluator=evaluator,
        output_path=output_path,
    )
    self.batch_size = batch_size  # errors sampled per feedback generation
    self.expand_width = expand_width  # feedback actions generated per prompt
    self.num_samples = num_samples  # new prompts generated per action
    self.search_mode = search_mode  # "beam" or "greedy" candidate filtering
    self.score_threshold = score_threshold  # early-stop score, or None to disable
    # Metadata key linking a candidate prompt back to its parent prompt.
    self.PARENT_KEY = "_parent"

check_early_convergence(*, all_prompts)

Check if the early convergence criteria is met.

Source code in src/prompt_optimizer/optimizers/promptagent.py
def check_early_convergence(self, *, all_prompts: list[list[Prompt]]):
    """Return True when the best prompt so far meets the score threshold."""
    # Early stopping is disabled when no threshold was configured.
    if self.score_threshold is None:
        return False

    # Flatten the per-iteration lists into a single pool of prompts.
    pool = [candidate for iteration in all_prompts for candidate in iteration]

    # Converged once the best observed score reaches the threshold.
    best = max(pool, key=lambda candidate: candidate.score)
    return best.score >= self.score_threshold

generate_prompt_candidates(*, prompts, **kwargs)

Generate prompt candidates using gradients.

Source code in src/prompt_optimizer/optimizers/promptagent.py
def generate_prompt_candidates(self, *, prompts: list[Prompt], **kwargs) -> list[Prompt]:
    """Generate prompt candidates using gradients."""
    prompt_candidates = []
    for prompt in track(prompts, description="Generating prompt candidates", transient=True):
        # Prompts with no recorded errors provide no gradient signal; skip them.
        if len(prompt.errors) == 0:
            continue

        # Map prompt trajectory
        trajectory_prompts = self._map_trajectory(prompt=prompt)

        for _ in range(self.expand_width):
            # Sample and collect errors into error string
            # NOTE(review): random.choices samples WITH replacement, so the same
            # error may appear more than once per batch — confirm this is intended.
            error_sample = random.choices(prompt.errors, k=self.batch_size)
            error_string = "\n\n".join(
                [
                    ERROR_STRING_FEEDBACK_TEMPLATE.format(index=i + 1, input=error.input, prediction=error.prediction, feedback=error.feedback)
                    if error.feedback is not None
                    else ERROR_STRING_TEMPLATE.format(index=i + 1, input=error.input, prediction=error.prediction, actual=error.actual)
                    for i, error in enumerate(error_sample)
                ]
            )

            # Generate actions
            template_kwargs = {
                "prompt": prompt.content,
                "error_string": error_string,
                "steps_per_gradient": self.num_samples,
                "trajectory_prompts": trajectory_prompts,
            }
            action = self._generate(metaprompt_template=ERROR_FEEDBACK_TEMPLATE, template_kwargs=template_kwargs)

            # Generate new prompts from the action
            template_kwargs.update({"error_feedback": action})
            raw_new_prompts = self._generate(metaprompt_template=STATE_TRANSIT_PROMPT_TEMPLATE, template_kwargs=template_kwargs)
            new_prompts = self._extract_responses(raw_new_prompts)
            new_prompts = new_prompts[: self.num_samples]
            # NOTE(review): this single metadata dict is shared by every sibling
            # candidate created below — mutating one candidate's metadata would
            # affect its siblings. Consider a fresh dict per candidate.
            metadata = {self.PARENT_KEY: prompt, "_action": action, "_resampled": False}
            new_prompt_candidates = [Prompt(content=new_prompt, metadata=metadata) for new_prompt in new_prompts]

            # Save prompts to prompt candidates
            prompt_candidates.extend(new_prompt_candidates)

    return prompt_candidates

get_all_prompts(include_candidates=False)

Get all the prompts from the latest training run.

The default behavior returns a list of lists, where each internal list contains the retained candidates after one iteration step. Setting include_candidates to True will also include all generated candidate prompts.

Parameters:

Name Type Description Default
include_candidates bool

Whether to include all the candidate prompts in the output. If True, candidate prompts from each iteration will be included. Defaults to False.

False

Returns:

Type Description
list[list[Prompt]]

list[list[Prompt]]: List of lists where each list contains the prompts from each iteration. E.g. list[0] contains prompts from the first iteration, list[1] the second, etc. If include_candidates is False, each inner list contains only the retained prompts at each iteration. If include_candidates is True, each inner list contains all candidate prompts at each iteration, including those that were discarded.

Source code in src/prompt_optimizer/optimizers/base.py
def get_all_prompts(self, include_candidates: bool = False) -> list[list[Prompt]]:
    """
    Return the prompts recorded during the latest training run.

    Each inner list corresponds to one iteration step (index 0 is the first
    iteration, index 1 the second, and so on). By default only the prompts
    retained after each step are returned; with include_candidates=True the
    result contains every generated candidate per step, including the ones
    that were discarded.

    Args:
        include_candidates (bool, optional):
            Whether to return all candidate prompts instead of only the
            retained ones. Defaults to False.

    Returns:
        list[list[Prompt]]: Per-iteration prompt lists.

    """
    # _g accumulates every generated candidate; _p only the retained survivors.
    return self._g if include_candidates else self._p

run()

Run the optimization pipeline.

Source code in src/prompt_optimizer/optimizers/base.py
def run(self) -> Prompt:
    """Run the optimization pipeline and return the best prompt found."""
    # Score seed_prompts
    self.seed_prompts = self._score_prompts(self.seed_prompts, self.validation_set)

    # Initialize objects:
    # _p[t] holds the prompts retained after step t; _g[t] holds every candidate
    # generated at step t. Index 0 is the seed generation for both.
    self._p = [self.seed_prompts]
    self._g = [self.seed_prompts]

    # Iterate until max depth
    for t in track(range(1, self.max_depth + 1), description="Step", total=self.max_depth):
        # Generate prompt candidates from the prompts retained at the previous step
        g_t = self.generate_prompt_candidates(prompts=self._p[t - 1], validation_set=self.validation_set)
        self._g.append(g_t)
        # Select prompt candidates (scores them and filters per search_mode)
        p_t = self.select_prompt_candidates(prompts=self._g[t], validation_set=self.validation_set)
        self._p.append(p_t)
        # Check for early convergence (score_threshold reached)
        if self.check_early_convergence(all_prompts=self._p):
            break

    # Save prompts if requested (no-op when output_path is None)
    self.save_prompts(output_path=self.output_path)

    # Return best prompt across all iterations, seeds included
    return self.select_best_prompt(all_prompts=self._p)

save_prompts(output_path)

Save prompts in jsonl format.

Source code in src/prompt_optimizer/optimizers/base.py
def save_prompts(self, output_path: Optional[Union[str, Path]]):
    """
    Save the retained prompts to ``output_path`` in jsonl format.

    Args:
        output_path (Union[str, Path], optional):
            Destination .jsonl file path. If None, nothing is written.

    """
    # Exit if no output path is set.
    # Fix: guard on the parameter that is actually opened below — previously this
    # checked self.output_path while writing to output_path, so a caller passing
    # None with self.output_path set would crash in open().
    if output_path is None:
        return

    # Gather prompts retained at every iteration and deduplicate them
    prompts = sum(self._p, start=[])
    prompts = list(set(prompts))

    # Save the prompts to the file, one JSON document per line (jsonl)
    lines = [prompt.model_dump_json() for prompt in prompts]
    with open(output_path, "w") as f:
        for line in lines:
            f.write(line)
            f.write("\n")

select_best_prompt(*, all_prompts)

Select the top scoring prompt.

Source code in src/prompt_optimizer/optimizers/promptagent.py
def select_best_prompt(self, *, all_prompts: list[list[Prompt]]) -> Prompt:
    """Return the single highest-scoring prompt across every iteration."""
    # Flatten the per-iteration lists into one pool of candidates.
    candidate_pool = [candidate for iteration in all_prompts for candidate in iteration]

    # Delegate to the shared helper, which also validates that scores exist.
    best_prompt = self._get_best_prompt(prompts=candidate_pool)
    logger.info(f"Best score: {best_prompt.score:.3f}")
    return best_prompt

select_prompt_candidates(*, prompts, validation_set)

Select prompt candidates according to the search mode.

Source code in src/prompt_optimizer/optimizers/promptagent.py
def select_prompt_candidates(self, *, prompts: list[Prompt], validation_set: ValidationSetType) -> list[Prompt]:
    """Select prompt candidates according to the search mode."""
    # NOTE(review): a search_mode other than "beam"/"greedy" falls through every
    # branch and returns None — confirm search_mode is validated upstream.
    self._score_prompts(prompts=prompts, validation_set=validation_set)
    # If this is the first iteration, keep all prompts
    if len(self._p) == 1:
        return prompts

    # Otherwise, keep based on search_mode
    elif self.search_mode == "greedy":
        # Keep all prompts
        return prompts

    elif self.search_mode == "beam":
        # Select the best prompt in each branch
        # Split prompts by parent
        parent_to_prompts = {prompt.metadata[self.PARENT_KEY]: [] for prompt in prompts}
        for prompt in prompts:
            parent_to_prompts[prompt.metadata[self.PARENT_KEY]].append(prompt)

        # Get the best prompt from each parent
        return [self._get_best_prompt(prompts=branch_prompts) for _, branch_prompts in parent_to_prompts.items()]