diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7889b69d..1fdcd036 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -67,6 +67,6 @@ jobs: run: pytest tests/unit_tests/ # 9) Run basic tests for each optimizer (some will fail due to the small LLM model chosen for free GitHub CI) - - name: Run optimizers test suite - run: pytest tests/llm_optimizers_tests/test_optimizer.py || true - continue-on-error: true +# - name: Run optimizers test suite +# run: pytest tests/llm_optimizers_tests/test_optimizer.py || true +# continue-on-error: true diff --git a/docs/tutorials/minibatch.ipynb b/docs/tutorials/minibatch.ipynb index 890d13f1..95076033 100644 --- a/docs/tutorials/minibatch.ipynb +++ b/docs/tutorials/minibatch.ipynb @@ -601,11 +601,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Step 1] \u001b[92mAverage test score: 1.0\u001b[0m\n", + "[Step 1] \u001B[92mAverage test score: 1.0\u001B[0m\n", "Epoch: 0. Iteration: 1\n", "[Step 1] Instantaneous train score: 1.0\n", "[Step 1] Average train score: 1.0\n", - "[Step 1] \u001b[91mParameter: str:20: You're a helpful agent\u001b[0m\n" + "[Step 1] \u001B[91mParameter: str:20: You're a helpful agent\u001B[0m\n" ] }, { @@ -641,11 +641,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Step 2] \u001b[92mAverage test score: 1.0\u001b[0m\n", + "[Step 2] \u001B[92mAverage test score: 1.0\u001B[0m\n", "Epoch: 0. Iteration: 2\n", "[Step 2] Instantaneous train score: 1.0\n", "[Step 2] Average train score: 1.0\n", - "[Step 2] \u001b[91mParameter: str:20: You're a helpful agent\u001b[0m\n" + "[Step 2] \u001B[91mParameter: str:20: You're a helpful agent\u001B[0m\n" ] }, { @@ -677,11 +677,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Step 3] \u001b[92mAverage test score: 1.0\u001b[0m\n", + "[Step 3] \u001B[92mAverage test score: 1.0\u001B[0m\n", "Epoch: 0. Iteration: 3\n", "[Step 3] Instantaneous train score: 1.0\n", "[Step 3] Average train score: 1.0\n", - "[Step 3] \u001b[91mParameter: str:20: You're a helpful agent\u001b[0m\n" + "[Step 3] \u001B[91mParameter: str:20: You're a helpful agent\u001B[0m\n" ] }, { @@ -714,11 +714,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Step 4] \u001b[92mAverage test score: 1.0\u001b[0m\n", + "[Step 4] \u001B[92mAverage test score: 1.0\u001B[0m\n", "Epoch: 0. Iteration: 4\n", "[Step 4] Instantaneous train score: 1.0\n", "[Step 4] Average train score: 1.0\n", - "[Step 4] \u001b[91mParameter: str:20: You're a helpful agent\u001b[0m\n" + "[Step 4] \u001B[91mParameter: str:20: You're a helpful agent\u001B[0m\n" ] }, { @@ -751,11 +751,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Step 5] \u001b[92mAverage test score: 1.0\u001b[0m\n", + "[Step 5] \u001B[92mAverage test score: 1.0\u001B[0m\n", "Epoch: 0. Iteration: 5\n", "[Step 5] Instantaneous train score: 1.0\n", "[Step 5] Average train score: 1.0\n", - "[Step 5] \u001b[91mParameter: str:20: You're a helpful agent\u001b[0m\n", + "[Step 5] \u001B[91mParameter: str:20: You're a helpful agent\u001B[0m\n", "FINISHED TRAINING\n" ] }, @@ -831,4 +831,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/opto/features/experimental_optimizers/__init__.py b/opto/features/experimental_optimizers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/opto/features/experimental_optimizers/agentic_opt.py b/opto/features/experimental_optimizers/agentic_opt.py new file mode 100644 index 00000000..adc07dd6 --- /dev/null +++ b/opto/features/experimental_optimizers/agentic_opt.py @@ -0,0 +1,233 @@ +""" +An Agentic Optimizer that has access to tools and other resources to perform complex optimization tasks. +Particularly useful for formal theorem proving, code optimization (including kernel code), and algorithm designs. + +We use types defined in opto.features.flows +""" + +import json +from textwrap import dedent +from dataclasses import dataclass, asdict +from typing import Dict + +from opto.trace.nodes import ParameterNode, Node, MessageNode +from opto.trace.propagators import TraceGraph, GraphPropagator +from opto.trace.propagators.propagators import Propagator + +from opto.optimizers.optoprime_v2 import OptoPrimeV2, OptimizerPromptSymbolSet +from opto.optimizers.optimizer import Optimizer +from opto.trainer.guide import Guide +from opto.utils.llm import AbstractModel, LLM + +from typing import Any, Callable, Dict, List, Tuple, Optional + +from opto.features.flows.compose import Loop, ChatHistory, StopCondition, Check + +""" +A few design that it must have: +1. **multi-turn conversation by default (memory management)** +2. Flexibility to do tool-use +3. has scaffolding functions people can call to design their custom optimizer + +First an abstract agent with the features, then implement? + +Idea: write it like you would use for VeriBench +1. bug fix loop +2. external reward loop + +initial task prompt -> initial solution +initial solution -> improvement prompt -> improved solution -> improvement prompt -> improved solution + +(inside flow) Loop (use a boolean reward function) to keep executing + +Inherits this optimizer +and specify the main forward by just calling different scaffolding functions +""" + + +# Base class that provides scaffolding +class AgenticOptimizer(Optimizer): + def __init__( + self, + parameters: List[ParameterNode], + *args, + propagator: Propagator = None, + **kwargs + ): + pass + + +""" +Add a veribench optimizer here. +A code optimizer is general, can work with kernel, others + +initial_task_message -> initial_code +(initial_code, reward) -> optimizer_step -> (improved_code, reward) -> optimizer_step -> improved_code + +initial_code + 1. check if there is a bug. If there is, try 10 times to fix. return the code --> make it correct/compilable + 2. take an improvement step --> make it run faster + +improve initial code, bug fix 10 times -> improve bug-fixed code, bug fix 10 times -> ... + +optimizer = (improve initial code, bug fix 10 times) +priority_search + +TODO: +1. Make sure guide can be deep-copied +2. Declare a node that's trainable, we can have an initializer. When the node is created, we call the initializer. + +add an init function to the optimizer API + +define a class of Initializer + +init function of the optimizer can receive some arguments + +standardize optimizer API: +- objective, context + +State of the worker: the worker decides what to put into the state +(each candidate has its own optimizer) +(receive a candidate) (update rule is stateless) + +Flip the order between + +Use initializer on node, and then use **projection** on the node +Optimizer only focuses on improvement + +LLM-based initializer, LLM-based projection + +============= +planning agent: plan -> coding agent: write the code + +------------ + +CodeOptimizer: +1. Optimization history (different from memory) + +CodeOptimizer: +task: make this code run fast, parameterNode(initial code) + +oracle: gives value + +node(initial_node, description="this is xxx, the value should be...") + +------------ + +init_param -> 25 next params (top 5) -> 25 params -> 25 params + +Protein: multi-objective optimization +""" + +class CodeOptimizer(Optimizer): + def __init__( + self, + parameters: List[ParameterNode], + llm: AbstractModel = None, + *args, + propagator: Propagator = None, + bug_judge: Guide = None, # compiler() + reward_function: Callable[[str], float] = None, + max_bug_fix_tries: int = 5, + max_optimization_tries: int = 10, + chat_max_len: int = 25, + **kwargs + ): + super().__init__(parameters, *args, propagator=propagator, **kwargs) + + self.llm = llm or LLM() + + assert len(parameters) == 1, "CodeOptimizer expects a single ParameterNode as input" + self.init_code = parameters[0] + + # Initialize chat history + self.chat_history = ChatHistory(max_round=chat_max_len, auto_summary=False) + + # Store environment checker and reward function + self.bug_judge = bug_judge + self.reward_function = reward_function + self.max_bug_fix_tries = max_bug_fix_tries + self.max_optimization_tries = max_optimization_tries + + self.task_description = None + self.initial_instruction = None + + # 2. Do a single improvement step + + def initial_context(self, task_description, initial_instruction): + """ + This provides the history of how the initial code was produced + """ + self.task_description = task_description + self.initial_instruction = initial_instruction + self.chat_history.add_system_message(self.task_description) + self.chat_history.add(self.initial_instruction,'user') + self.chat_history.add(self.init_code, 'assistant') + + def bug_fix_step(self, lean4_code: str, max_try: int = 5) -> str: + """ + This function is used to self-correct the Lean 4 code. + It will be called by the LLM when the code does not compile. + """ + + # apply heuristic fixes + lean4_code = self.remove_import_error(lean4_code) + valid, error_details = self.bug_judge(lean4_code) + + if valid: + return lean4_code + + temp_conv_hist = self.chat_history.copy() + + counter = 0 + while not valid and counter < max_try: + # sometimes LLM will hallucinate import error, so we remove that import statement + + print(f"Attempt {counter+1}: Fixing compilation errors") + + detailed_error_message = self.concat_error_messages(error_details) + + temp_conv_hist.add(detailed_error_message + "\n\n" + f"Lean code compilation FAILED with {len(error_details)} errors. If a theorem keeps giving error, you can use := sorry to skip it. Please wrap your lean code in ```lean and ```", + "user") + + raw_program = self.llm(temp_conv_hist.get_messages(), verbose=False) + + lean4_code = self.simple_post_process(raw_program) + lean4_code = self.remove_import_error(lean4_code) + + valid, error_details = self.bug_judge(lean4_code) + + if valid: + print(f"Successfully fixed errors after {counter} attempts") + # we add to the round + self.chat_history = self.chat_history + temp_conv_hist[-1].remove_system_message() + return lean4_code + else: + counter += 1 + temp_conv_hist.add_message(lean4_code, "assistant") + + return lean4_code + + def _step(self, verbose=False, *args, **kwargs) -> Dict[ParameterNode, Any]: + """ + Each step, we perform bug fix for a few rounds, then do one improvement step + We add everything to the chat history + """ + lean4_code = self.bug_fix_step(self.init_code.data, max_try=self.max_bug_fix_tries) + # do one step improvement + lean4_code = self.improve_step(lean4_code) + + return lean4_code + + def _extract_code(self, response: str) -> str: + """Extract code from markdown code blocks.""" + import re + # Match python code blocks + pattern = r'```python\n(.*?)```' + matches = re.findall(pattern, response, re.DOTALL) + + if matches: + return matches[0].strip() + + # If no code block found, return the response as-is + return response.strip() diff --git a/opto/features/flows/compose.py b/opto/features/flows/compose.py deleted file mode 100644 index dff95bbc..00000000 --- a/opto/features/flows/compose.py +++ /dev/null @@ -1,229 +0,0 @@ -import opto.trace as trace -from typing import Union, get_type_hints, Any, Dict, List, Optional -from opto.utils.llm import AbstractModel, LLM -import contextvars - -""" -TracedLLM: -1. special operations that supports specifying inputs (system_prompt, user_prompt) to LLM and parsing of outputs, wrap - everything under one command. -2. Easy to use interface -- can be inherited by users. -3. Support multi-turn chatting (message history) - -Usage patterns: - -Direct use: (only supports single input, single output) (signature: str -> str) -llm = TracedLLM("You are a helpful assistant.") -response = llm("Hello, what's the weather in France today?") -""" - -USED_TracedLLM = contextvars.ContextVar('USED_TracedLLM', default=list()) - - -class ChatHistory: - def __init__(self, max_len=50, auto_summary=False): - """Initialize chat history for multi-turn conversation. - - Args: - max_len: Maximum number of messages to keep in history - auto_summary: Whether to automatically summarize old messages - """ - self.messages: List[Dict[str, Any]] = [] - self.max_len = max_len - self.auto_summary = auto_summary - - def __len__(self): - return len(self.messages) - - def add(self, content: Union[trace.Node, str], role): - """Add a message to history with role validation. - - Args: - content: The content of the message - role: The role of the message ("user" or "assistant") - """ - if role not in ["user", "assistant"]: - raise ValueError(f"Invalid role '{role}'. Must be 'user' or 'assistant'.") - - # Check for alternating user/assistant pattern - if len(self.messages) > 0: - last_msg = self.messages[-1] - if last_msg["role"] == role: - print(f"Warning: Adding consecutive {role} messages. Consider alternating user/assistant messages.") - - self.messages.append({"role": role, "content": content}) - self._trim_history() - - def append(self, message: Dict[str, Any]): - """Append a message directly to history.""" - if "role" not in message or "content" not in message: - raise ValueError("Message must have 'role' and 'content' fields.") - self.add(message["content"], message["role"]) - - def __iter__(self): - return iter(self.messages) - - def get_messages(self) -> List[Dict[str, str]]: - messages = [] - for message in self.messages: - if isinstance(message['content'], trace.Node): - messages.append({"role": message["role"], "content": message["content"].data}) - else: - messages.append(message) - return messages - - def get_messages_as_node(self, llm_name="") -> List[trace.Node]: - node_list = [] - for message in self.messages: - # If user query is a node and has other computation attached, we can't rename it - if isinstance(message['content'], trace.Node): - node_list.append(message['content']) - else: - role = message["role"] - content = message["content"] - name = f"{llm_name}_{role}" if llm_name else f"{role}" - if role == 'user': - name += "_query" - elif role == 'assistant': - name += "_response" - node_list.append(trace.node(content, name=name)) - - return node_list - - def _trim_history(self): - """Trim history to max_len while preserving first user message.""" - if len(self.messages) <= self.max_len: - return - - # Find first user message index - first_user_idx = None - for i, msg in enumerate(self.messages): - if msg["role"] == "user": - first_user_idx = i - break - - # Keep first user message - protected_messages = [] - if first_user_idx is not None: - first_user_msg = self.messages[first_user_idx] - protected_messages.append(first_user_msg) - - # Calculate how many recent messages we can keep - remaining_slots = self.max_len - len(protected_messages) - if remaining_slots > 0: - # Get recent messages - recent_messages = self.messages[-remaining_slots:] - # Avoid duplicating first user message - if first_user_idx is not None: - first_user_msg = self.messages[first_user_idx] - recent_messages = [msg for msg in recent_messages if msg != first_user_msg] - - self.messages = protected_messages + recent_messages - else: - self.messages = protected_messages - - -DEFAULT_SYSTEM_PROMPT_DESCRIPTION = ("the system prompt to the agent. By tuning this prompt, we can control the " - "behavior of the agent. For example, it can be used to provide instructions to " - "the agent (such as how to reason about the problem, how to use tools, " - "how to answer the question), or provide in-context examples of how to solve the " - "problem.") - - -@trace.model -class TracedLLM: - """ - This high-level model provides an easy-to-use interface for LLM calls with system prompts and optional chat history. - - Python usage patterns: - - llm = UF_LLM(system_prompt) - response = llm.chat(user_prompt) - response_2 = llm.chat(user_prompt_2) - - The underlying Trace Graph: - TracedLLM_response0 = TracedLLM.forward.call_llm(args_0=system_prompt0, args_1=TracedLLM0_user_query0) - TracedLLM_response1 = TracedLLM.forward.call_llm(args_0=system_prompt0, args_1=TracedLLM0_user_query0, args_2=TracedLLM_response0, args_3=TracedLLM0_user_query1) - TracedLLM_response2 = TracedLLM.forward.call_llm(args_0=system_prompt0, args_1=TracedLLM0_user_query0, args_2=TracedLLM_response0, args_3=TracedLLM0_user_query1, args_4=TracedLLM_response1, args_5=TracedLLM0_user_query2) - """ - - def __init__(self, - system_prompt: Union[str, None, trace.Node] = None, - llm: AbstractModel = None, chat_history_on=False, - trainable=False, model_name=None): - """Initialize TracedLLM with a system prompt. - - Args: - system_prompt: The system prompt to use for LLM calls. If None and the class has a docstring, the docstring will be used. - llm: The LLM model to use for inference - chat_history_on: if on, maintain chat history for multi-turn conversations - """ - if system_prompt is None: - system_prompt = "You are a helpful assistant." - - self.system_prompt = trace.node(system_prompt, name='system_prompt', - description=DEFAULT_SYSTEM_PROMPT_DESCRIPTION, - trainable=trainable) - # if system_prompt is already a node, then we have to override its trainable attribute - self.system_prompt.trainable = trainable - - if llm is None: - llm = LLM() - assert isinstance(llm, AbstractModel), f"{llm} must be an instance of AbstractModel" - self.llm = llm - self.chat_history = ChatHistory() - self.chat_history_on = chat_history_on - - current_llm_sessions = USED_TracedLLM.get() - self.model_name = model_name if model_name else f"{self.__class__.__name__}{len(current_llm_sessions)}" - current_llm_sessions.append(1) # just a marker - - def forward(self, user_query: str, chat_history_on: Optional[bool] = None) -> str: - """This function takes user_query as input, and returns the response from the LLM, with the system prompt prepended. - This method will always save chat history. - - If chat_history_on is set to False, the chat history will not be included in the LLM input. - If chat_history_on is None, it will use the class-level chat_history_on setting. - If chat_history_on is True, the chat history will be included in the LLM input. - - Args: - user_query: The user query to send to the LLM - - Returns: - str: For direct pattern - """ - chat_history_on = self.chat_history_on if chat_history_on is None else chat_history_on - - messages = [{"role": "system", "content": self.system_prompt.data}] - if chat_history_on: - messages.extend(self.chat_history.get_messages()) - messages.append({"role": "user", "content": user_query}) - - response = self.llm(messages=messages) - - @trace.bundle(output_name=f"{self.model_name}_response") - def call_llm(*messages) -> str: - """Call the LLM model. - Args: - messages: All the conversation history so far, starting from system prompt, to alternating user/assistant messages, ending with the current user query. - Returns: - response from the LLM - """ - return response.choices[0].message.content - - user_query_node = trace.node(user_query, name=f"{self.model_name}_user_query") - arg_list = [self.system_prompt] - if chat_history_on: - arg_list += self.chat_history.get_messages_as_node(self.model_name) - arg_list += [user_query_node] - - response_node = call_llm(*arg_list) - - # save to chat history - self.chat_history.add(user_query_node, role="user") - self.chat_history.add(response_node, role="assistant") - - return response_node - - def chat(self, user_query: str) -> str: - return self.forward(user_query) diff --git a/opto/features/flows/compose/__init__.py b/opto/features/flows/compose/__init__.py new file mode 100644 index 00000000..2950f740 --- /dev/null +++ b/opto/features/flows/compose/__init__.py @@ -0,0 +1,5 @@ +from opto.features.flows.compose.llm import TracedLLM, ChatHistory +from opto.features.flows.compose.parser import llm_call +from opto.features.flows.compose.agentic_ops import Loop, StopCondition, Check + +__all__ = ["TracedLLM", "ChatHistory", "llm_call", "Loop", "StopCondition", "Check"] \ No newline at end of file diff --git a/opto/features/flows/compose/agentic_ops.py b/opto/features/flows/compose/agentic_ops.py new file mode 100644 index 00000000..7d0d87c9 --- /dev/null +++ b/opto/features/flows/compose/agentic_ops.py @@ -0,0 +1,355 @@ +import opto.trace as trace +from typing import Tuple, Union, get_type_hints, Any, Dict, List, Optional, Callable +from opto.utils.llm import AbstractModel, LLM +from opto.features.flows.types import MultiModalPayload, QueryModel, StructuredInput, StructuredOutput, \ + ForwardMixin +from opto.trainer.guide import Guide +import numpy as np +import contextvars + +# =========== Mixin for Agentic Optimizer =========== + +""" + +@llm_call +def auto_correct(solution: InitialSolution) -> ImprovedSolution: + ... + +class InitialSolution: + "This is the intiial solution of a coding problem" + +class ImprovedSolution: + name: str + +a = auto_correct(solution) + +class CodeOptimizer: + def __init__(self): + # self.auto_correct = Loop(auto_correct) + # self.auto_correct = Loop(self.auto_correct_op) + self.auto_improve = Loop() + + def auto_correct_op(self, solution): + prompt = "\n improve this." + llm_response = self.llm(prompt + solution) + return llm_response + + def forward(self): + next_phase = self.auto_correct(solution) + return self.auto_improve(next_phase) + +TODO: +1. Support different styles of calling/initializing loop. Need to cover all use cases. +""" + + +class StopCondition(ForwardMixin): + """ + A stop condition for the loop. It can be a callable class instance or a function. + A simple stop condition is not necessary to be a subclass of this class. + + Implement an init method to pass in extra info into the stop condition + + Example Usage: + + class MaxIterationsOrConverged: + def __init__(self, max_iters=5, threshold=0.01): + self.max_iters = max_iters + self.threshold = threshold + + def __call__(self, param_history: List, result_history: List) -> bool: + # Stop if we've done enough iterations + if len(result_history) >= self.max_iters: + return True + + # Stop if results have converged (example) + if len(result_history) >= 2: + diff = abs(result_history[-1] - result_history[-2]) + if diff < self.threshold: + return True + + return False + """ + + def forward(self, param_history: List, result_history: List) -> bool: + """The Loop will call the stop condition with the param_history and result_history. + The stop condition should return a boolean value. + """ + raise NotImplementedError("Need to implement the forward function") + + +class Loop(ForwardMixin): + + def __init__(self, func: Callable[[Any], Any], stop_condition: Any = None): + assert callable(func), "func must be a callable" + + self.stop_condition = stop_condition + self.func = func + if stop_condition is not None: + self.check_stop_condition(stop_condition) + + def check_stop_condition(self, stop_condition: Any): + # Check if it's a callable class (has __call__ method and is an instance) + if not callable(stop_condition): + raise TypeError("stop_condition must be callable") + + # Check if it's a class instance (not a function or method) + # Functions/methods don't have __dict__ or they have __self__ for bound methods + if not hasattr(stop_condition, '__dict__') and not hasattr(stop_condition, '__self__'): + raise TypeError("stop_condition must be a callable class instance, not a function") + + # Get type hints from the __call__ method + try: + hints = get_type_hints(stop_condition.__call__) + if 'return' in hints: + # Optional: validate that return type annotation is bool + assert hints['return'] == bool, \ + f"stop_condition.__call__ must be annotated to return bool, got {hints['return']}" + except AttributeError: + pass # If __call__ doesn't have type hints, skip validation + + def step(self, *current_params: Any) -> Tuple[Any, Dict[str, Any]]: + """Override this method to define the loop logic. If `func` is passed in during init, this is not necessary.""" + if self.func is not None: + result, info = self.func(*current_params) + return result, info + raise NotImplementedError("Must provide func during initialization or override step method") + + def forward(self, *args, max_try: int = 10, stop_condition: Any = None) -> Tuple[ + Any, List[Dict[str, Any]]]: + """ + Execute the loop with the given initial parameters. + + Args: + *args: Initial parameters to pass to the function + max_try: Maximum number of iterations + stop_condition: Optional stop condition to override the one from __init__ + + Returns: + Tuple of (final_params, result_history) + """ + param_history = [] + result_history = [] + current_params = args + + # Use the stop_condition from forward() if provided, otherwise use the one from __init__ + active_stop_condition = stop_condition if stop_condition is not None else self.stop_condition + + # Check if the initial parameter is already good enough (only if stop_condition exists) + if active_stop_condition is not None: + should_stop = active_stop_condition(param_history, result_history) + assert isinstance(should_stop, (bool, np.bool_)), \ + f"stop_condition must return a boolean value, got {type(should_stop)}" + + if should_stop: + return current_params if len(current_params) > 1 else ( + current_params[0] if current_params else None), result_history + + for iteration in range(max_try): + # Track the parameters before calling step + param_history.append( + current_params if len(current_params) > 1 else (current_params[0] if current_params else None)) + + # Update current_params using step method + result, info = self.step(*current_params) + + # Normalize result to always be a tuple for consistency + if not isinstance(result, tuple): + current_params = (result,) + else: + current_params = result + + # Track the result from step + result_history.append(info) + + # Check stop condition after each step (only if it exists) + if active_stop_condition is not None: + should_stop = active_stop_condition(param_history, result_history) + assert isinstance(should_stop, (bool, np.bool_)), \ + f"stop_condition must return a boolean value, got {type(should_stop)}" + + if should_stop: + break + + # Return unpacked if single param, tuple if multiple params + final_result = current_params if len(current_params) > 1 else (current_params[0] if current_params else None) + return final_result, result_history + + +""" +Usage patterns: + +Check(should_optimize, value) + .then(lambda v: Loop(optimize_step, stop_condition)()) + .or_else(skip_optimization)() + +class ConditionalUpdateLoop(Loop): + def step(self, param): + return Check(needs_large_step, param) + .then(large_update) + .or_else(small_update)() + +Check(validate_data, data).then( + lambda d: Check(needs_preprocessing, d) + .then(lambda: Loop(preprocess, QualityStop())()) + .or_else(process_directly)() +).or_else(reject)() +""" + + +class Check(ForwardMixin): + """A DSL for conditional execution with fluent interface.""" + + def __init__(self, condition_func, *args, **kwargs): + """ + Initialize the Check with a condition function and its arguments. + + Args: + condition_func: A callable that returns a truthy/falsy value + *args: Positional arguments to pass to the condition function + **kwargs: Keyword arguments to pass to the condition function + """ + self.condition_func = condition_func + self.args = args + self.kwargs = kwargs + self.condition_result = None + self.condition_evaluated = False + self.then_func = None + self.then_args = None + self.then_kwargs = None + self.elif_branches = [] # List of (condition, func, args, kwargs) tuples + self.else_func = None + self.else_args = None + self.else_kwargs = None + self.do_func = None + self.do_args = None + self.do_kwargs = None + + def _evaluate_condition(self): + """Lazily evaluate the condition function.""" + if not self.condition_evaluated: + result = self.condition_func(*self.args, **self.kwargs) + # Store both the truthiness and the actual return value + self.condition_result = result + + self.condition_evaluated = True + return self.condition_result + + def then(self, callback_func, *extra_args, **extra_kwargs): + """ + Define the function to execute if the condition is truthy. + + Args: + callback_func: The function to execute if condition is true + *extra_args: Additional positional arguments for the callback + **extra_kwargs: Additional keyword arguments for the callback + """ + self.then_func = callback_func + self.then_args = extra_args + self.then_kwargs = extra_kwargs + return self + + def elseif(self, condition_func, callback_func, *extra_args, **extra_kwargs): + """ + Add an elif branch with its own condition and callback. + + Args: + condition_func: The condition to check if previous conditions were false + callback_func: The function to execute if this condition is true + *extra_args: Additional positional arguments for the callback + **extra_kwargs: Additional keyword arguments for the callback + """ + self.elif_branches.append((condition_func, callback_func, extra_args, extra_kwargs)) + return self + + def or_else(self, callback_func, *extra_args, **extra_kwargs): + """ + Define the function to execute if all conditions are falsy. + + Args: + callback_func: The function to execute if all conditions are false + *extra_args: Additional positional arguments for the callback + **extra_kwargs: Additional keyword arguments for the callback + """ + self.else_func = callback_func + self.else_args = extra_args + self.else_kwargs = extra_kwargs + return self + + # Alternative names for or_else + otherwise = or_else + else_ = or_else + + def do(self, callback_func, *extra_args, **extra_kwargs): + """ + Define a function to execute after all branches, regardless of condition. + + Args: + callback_func: The function to always execute at the end + *extra_args: Additional positional arguments for the callback + **extra_kwargs: Additional keyword arguments for the callback + """ + self.do_func = callback_func + self.do_args = extra_args + self.do_kwargs = extra_kwargs + return self + + def forward(self): + """ + Execute the appropriate callback based on the condition results. + + Returns: + The return value of whichever callback was executed, + or the do callback if no branch was executed. + """ + condition_result = self._evaluate_condition() + execution_result = None + branch_executed = False + + # Check main condition + if condition_result: + if self.then_func: + # Combine original args with then args, plus condition result + all_args = self.args + self.then_args + # If condition returned a non-boolean value, include it + if condition_result is not True: + all_args = (condition_result,) + all_args + all_kwargs = {**self.kwargs, **self.then_kwargs} + execution_result = self.then_func(*all_args, **all_kwargs) + branch_executed = True + + # Check elif branches if main condition was false + if not branch_executed: + for elif_condition, elif_func, elif_args, elif_kwargs in self.elif_branches: + # Evaluate elif condition with original args + elif_result = elif_condition(*self.args, **self.kwargs) + if elif_result: + # Combine args similar to then branch + all_args = self.args + elif_args + if elif_result is not True: + all_args = (elif_result,) + all_args + all_kwargs = {**self.kwargs, **elif_kwargs} + execution_result = elif_func(*all_args, **all_kwargs) + branch_executed = True + break + + # Execute else if no condition was true + if not branch_executed and self.else_func: + # Combine original args with else args + all_args = self.args + self.else_args + all_kwargs = {**self.kwargs, **self.else_kwargs} + execution_result = self.else_func(*all_args, **all_kwargs) + branch_executed = True + + # Always execute do if defined + if self.do_func: + # Do gets the execution result as first arg if there was one + do_args = self.do_args + if execution_result is not None: + do_args = (execution_result,) + do_args + all_kwargs = {**self.kwargs, **self.do_kwargs} + do_result = self.do_func(*do_args, **all_kwargs) + # Return the execution result if there was one, otherwise the do result + return execution_result if execution_result is not None else do_result + + return execution_result \ No newline at end of file diff --git a/opto/features/flows/compose/llm.py b/opto/features/flows/compose/llm.py new file mode 100644 index 00000000..815c695b --- /dev/null +++ b/opto/features/flows/compose/llm.py @@ -0,0 +1,418 @@ +import opto.trace as trace +from typing import Tuple, Union, get_type_hints, Any, Dict, List, Optional, Callable +from opto.utils.llm import AbstractModel, LLM +from opto.features.flows.types import MultiModalPayload, QueryModel, StructuredInput, StructuredOutput, \ + ForwardMixin +from opto.trainer.guide import Guide +import numpy as np +import contextvars + +# =========== LLM Base Model =========== +""" +TracedLLM: +1. special operations that supports specifying inputs (system_prompt, user_prompt) to LLM and parsing of outputs, wrap + everything under one command. +2. Easy to use interface -- can be inherited by users. +3. Support multi-turn chatting (message history) + +Usage patterns: + +Direct use: (only supports single input, single output) (signature: str -> str) +llm = TracedLLM("You are a helpful assistant.") +response = llm("Hello, what's the weather in France today?") +""" + +USED_TracedLLM = contextvars.ContextVar('USED_TracedLLM', default=list()) + +class ChatHistory: + def __init__(self, max_round=25, auto_summary=False): + """Initialize chat history for multi-turn conversation. + + Args: + max_round: Maximum number of conversation rounds (user-assistant pairs) to keep in history + auto_summary: Whether to automatically summarize old messages + """ + self.messages: List[Dict[str, Any]] = [] + self.max_round = max_round + self.auto_summary = auto_summary + + def __len__(self): + return len(self.messages) + + def add_system_message(self, content: Union[trace.Node, str]): + """Add or replace a system message at the beginning of the chat history. + + Args: + content: The content of the system message + """ + # Check if the first message is a system message + if len(self.messages) > 0 and self.messages[0].get("role") == "system": + print("Warning: Replacing existing system message.") + self.messages[0] = {"role": "system", "content": content} + else: + # Insert system message at the beginning + self.messages.insert(0, {"role": "system", "content": content}) + self._trim_history() + + def add_message(self, content: Union[trace.Node, str], role): + """Alias for add""" + return self.add(content, role) + + def add(self, content: Union[trace.Node, str], role): + """Add a message to history with role validation. + + Args: + content: The content of the message + role: The role of the message ("user" or "assistant") + """ + if role not in ["user", "assistant"]: + raise ValueError(f"Invalid role '{role}'. Must be 'user' or 'assistant'.") + + # Check for alternating user/assistant pattern + if len(self.messages) > 0: + last_msg = self.messages[-1] + if last_msg["role"] == role: + print(f"Warning: Adding consecutive {role} messages. Consider alternating user/assistant messages.") + + self.messages.append({"role": role, "content": content}) + self._trim_history() + + def append(self, message: Dict[str, Any]): + """Append a message directly to history.""" + if "role" not in message or "content" not in message or type(message) is not dict: + raise ValueError("Message must have 'role' and 'content' fields.") + self.add(message["content"], message["role"]) + + def __iter__(self): + return iter(self.messages) + + def __getitem__(self, index): + """Get a specific round or slice of rounds as a ChatHistory object. + + Args: + index: Integer index or slice object + + Returns: + ChatHistory: A new ChatHistory object containing the selected round(s). + Each round includes [system_prompt, user_prompt, response] where system_prompt is + included if it exists in the chat history. + """ + # Get user and assistant message indices + user_indices = [i for i, msg in enumerate(self.messages) if msg["role"] == "user"] + assistant_indices = [i for i, msg in enumerate(self.messages) if msg["role"] == "assistant"] + + # Build rounds (user-assistant pairs that are complete) + rounds = [] + for user_idx in user_indices: + # Find corresponding assistant response + assistant_msg = None + for asst_idx in assistant_indices: + if asst_idx > user_idx: + assistant_msg = self.messages[asst_idx] + break + + if assistant_msg: + rounds.append([self.messages[user_idx], assistant_msg]) + + # Create new ChatHistory object + new_history = ChatHistory(max_round=self.max_round, auto_summary=self.auto_summary) + + # Handle slicing + if isinstance(index, slice): + selected_rounds = rounds[index] + # Add system message if it exists + if self.messages and self.messages[0].get("role") == "system": + new_history.messages.append(self.messages[0].copy()) + # Add all selected rounds + for round_msgs in selected_rounds: + for msg in round_msgs: + new_history.messages.append({"role": msg["role"], "content": msg["content"]}) + return new_history + + # Handle single index (including negative indexing) + round_msgs = rounds[index] # This will handle negative indices and raise IndexError if out of bounds + + # Add system message if it exists + if self.messages and self.messages[0].get("role") == "system": + new_history.messages.append({"role": self.messages[0]["role"], "content": self.messages[0]["content"]}) + + # Add the selected round + for msg in round_msgs: + new_history.messages.append({"role": msg["role"], "content": msg["content"]}) + + return new_history + + def get_messages(self) -> List[Dict[str, str]]: + messages = [] + for message in self.messages: + if isinstance(message['content'], trace.Node): + messages.append({"role": message["role"], "content": message["content"].data}) + else: + messages.append(message) + return messages + + def get_messages_as_node(self, llm_name="") -> List[trace.Node]: + node_list = [] + for message in self.messages: + # If user query is a node and has other computation attached, we can't rename it + if isinstance(message['content'], trace.Node): + node_list.append(message['content']) + else: + role = message["role"] + content = message["content"] + name = f"{llm_name}_{role}" if llm_name else f"{role}" + if role == 'user': + name += "_query" + elif role == 'assistant': + name += "_response" + node_list.append(trace.node(content, name=name)) + + return node_list + + def copy(self, include_system=True): + """Create a deep copy of the chat history. + + Args: + include_system: Whether to include the system message in the copy (default: True) + + Returns: + ChatHistory: A new ChatHistory instance with the same messages + """ + new_history = ChatHistory(max_round=self.max_round, auto_summary=self.auto_summary) + for message in self.messages: + # Skip system message if include_system is False + if not include_system and message["role"] == "system": + continue + # Create a new dict to avoid reference issues + new_message = {"role": message["role"], "content": message["content"]} + new_history.messages.append(new_message) + return new_history + + def remove_system_message(self): + """Create a copy of the chat history without the system message. + + Returns: + ChatHistory: A new ChatHistory instance without the system message + """ + return self.copy(include_system=False) + + def __add__(self, other): + """Merge two chat histories together. + + Args: + other: Another ChatHistory instance + + Returns: + ChatHistory: A new ChatHistory instance with messages from both histories + + Raises: + TypeError: If other is not a ChatHistory instance + ValueError: If both histories have system prompts + """ + if not isinstance(other, ChatHistory): + raise TypeError("Can only add ChatHistory instances together") + + # Check if both have system messages + has_system_self = self.messages and self.messages[0].get("role") == "system" + has_system_other = other.messages and other.messages[0].get("role") == "system" + + if has_system_self and has_system_other: + raise ValueError("Cannot merge two chat histories that both have system prompts") + + # Create new history with max of the two max_rounds + new_history = ChatHistory( + max_round=max(self.max_round, other.max_round), + auto_summary=self.auto_summary or other.auto_summary + ) + + # Add messages from self + for message in self.messages: + new_message = {"role": message["role"], "content": message["content"]} + new_history.messages.append(new_message) + + # Add messages from other, skipping system message if self already has one + for message in other.messages: + if has_system_self and message["role"] == "system": + continue + new_message = {"role": message["role"], "content": message["content"]} + new_history.messages.append(new_message) + + # Trim the combined history to respect max_round + new_history._trim_history() + + return new_history + + def _trim_history(self): + """Trim history to max_round while preserving system message and the first round.""" + # Count the number of rounds (user-assistant pairs) + user_indices = [i for i, msg in enumerate(self.messages) if msg["role"] == "user"] + assistant_indices = [i for i, msg in enumerate(self.messages) if msg["role"] == "assistant"] + + # If we don't have enough messages to form complete rounds, return + if len(user_indices) <= self.max_round or len(assistant_indices) < len(user_indices): + return + + protected_messages = [] + + # Keep system message if it exists + if self.messages and self.messages[0].get("role") == "system": + protected_messages.append(self.messages[0]) + + # Always keep the first round (first user message and first assistant response) + if user_indices: + first_user_idx = user_indices[0] + protected_messages.append(self.messages[first_user_idx]) + + # Find the first assistant message after the first user message + for i in assistant_indices: + if i > first_user_idx: + protected_messages.append(self.messages[i]) + break + + # Calculate how many recent rounds we can keep + num_protected_rounds = 1 if user_indices else 0 # We've protected the first round if it exists + remaining_rounds = self.max_round - num_protected_rounds + + if remaining_rounds > 0: + # Get the most recent rounds (user-assistant pairs) + recent_rounds = [] + + # Start from the most recent user message and work backwards + for i in range(len(user_indices) - 1, num_protected_rounds - 1, -1): + if len(recent_rounds) // 2 >= remaining_rounds: + break + + user_idx = user_indices[i] + user_msg = self.messages[user_idx] + + # Find the corresponding assistant response + assistant_msg = None + for j in assistant_indices: + if j > user_idx: + assistant_msg = self.messages[j] + break + + if assistant_msg: + # Add the pair in chronological order + recent_rounds = [user_msg, assistant_msg] + recent_rounds + + # Combine protected messages with recent rounds + self.messages = protected_messages + recent_rounds + else: + self.messages = protected_messages + + +DEFAULT_SYSTEM_PROMPT_DESCRIPTION = ("the system prompt to the agent. By tuning this prompt, we can control the " + "behavior of the agent. For example, it can be used to provide instructions to " + "the agent (such as how to reason about the problem, how to use tools, " + "how to answer the question), or provide in-context examples of how to solve the " + "problem.") + + +@trace.model +class TracedLLM: + """ + This high-level model provides an easy-to-use interface for LLM calls with system prompts and optional chat history. + + Python usage patterns: + + llm = UF_LLM(system_prompt) + response = llm.chat(user_prompt) + response_2 = llm.chat(user_prompt_2) + + The underlying Trace Graph: + TracedLLM_response0 = TracedLLM.forward.call_llm(args_0=system_prompt0, args_1=TracedLLM0_user_query0) + TracedLLM_response1 = TracedLLM.forward.call_llm(args_0=system_prompt0, args_1=TracedLLM0_user_query0, args_2=TracedLLM_response0, args_3=TracedLLM0_user_query1) + TracedLLM_response2 = TracedLLM.forward.call_llm(args_0=system_prompt0, args_1=TracedLLM0_user_query0, args_2=TracedLLM_response0, args_3=TracedLLM0_user_query1, args_4=TracedLLM_response1, args_5=TracedLLM0_user_query2) + """ + + def __init__(self, + system_prompt: Union[str, None, trace.Node] = None, + llm: AbstractModel = None, chat_history_on=False, + trainable=False, model_name=None): + """Initialize TracedLLM with a system prompt. + + Args: + system_prompt: The system prompt to use for LLM calls. If None and the class has a docstring, the docstring will be used. + llm: The LLM model to use for inference + chat_history_on: if on, maintain chat history for multi-turn conversations + model_name: override the default name of the model + """ + if system_prompt is None: + system_prompt = "You are a helpful assistant." + + self.system_prompt = trace.node(system_prompt, name='system_prompt', + description=DEFAULT_SYSTEM_PROMPT_DESCRIPTION, + trainable=trainable) + # if system_prompt is already a node, then we have to override its trainable attribute + self.system_prompt.trainable = trainable + + if llm is None: + llm = LLM() + assert isinstance(llm, AbstractModel), f"{llm} must be an instance of AbstractModel" + self.llm = llm + self.chat_history = ChatHistory() + self.chat_history_on = chat_history_on + + current_llm_sessions = USED_TracedLLM.get() + self.model_name = model_name if model_name else f"{self.__class__.__name__}{len(current_llm_sessions)}" + current_llm_sessions.append(1) # just a marker + + def forward(self, user_query: str, + payload: Optional[MultiModalPayload] = None, + chat_history_on: Optional[bool] = None) -> str: + """This function takes user_query as input, and returns the response from the LLM, with the system prompt prepended. + This method will always save chat history. + + If chat_history_on is set to False, the chat history will not be included in the LLM input. + If chat_history_on is None, it will use the class-level chat_history_on setting. + If chat_history_on is True, the chat history will be included in the LLM input. + + Args: + user_query: The user query to send to the LLM. Can be + + Returns: + str: For direct pattern + """ + chat_history_on = self.chat_history_on if chat_history_on is None else chat_history_on + + user_message = QueryModel(query=user_query, multimodal_payload=payload).query + + messages = [{"role": "system", "content": self.system_prompt.data}] + if chat_history_on: + messages.extend(self.chat_history.get_messages()) + messages.append({"role": "user", "content": user_message}) + + response = self.llm(messages=messages) + + @trace.bundle(output_name=f"{self.model_name}_response") + def call_llm(*messages) -> str: + """Call the LLM model. + Args: + messages: All the conversation history so far, starting from system prompt, to alternating user/assistant messages, ending with the current user query. + Returns: + response from the LLM + """ + return response.choices[0].message.content + + user_query_node = trace.node(user_query, name=f"{self.model_name}_user_query") + arg_list = [self.system_prompt] + if chat_history_on: + arg_list += self.chat_history.get_messages_as_node(self.model_name) + arg_list += [user_query_node] + + response_node = call_llm(*arg_list) + + # save to chat history + self.chat_history.add(user_query_node, role="user") + self.chat_history.add(response_node, role="assistant") + + return response_node + + def chat(self, user_query: str, payload: Optional[MultiModalPayload] = None, + chat_history_on: Optional[bool] = None) -> str: + """Note that chat/forward always assumes it's a single turn of the conversation. History/context management will be accomplished + through other APIs""" + return self.forward(user_query, payload, chat_history_on) + +# =========== =========== diff --git a/opto/features/flows/compose/parser.py b/opto/features/flows/compose/parser.py new file mode 100644 index 00000000..25fb84fa --- /dev/null +++ b/opto/features/flows/compose/parser.py @@ -0,0 +1,398 @@ +from typing import Optional, Callable, Any, Dict, Literal, get_type_hints +from abc import ABC, abstractmethod +import inspect +import json +from opto.features.flows.types import StructuredInput, StructuredOutput, ForwardMixin + + +class PromptAdapter(ABC): + """Base adapter for converting structured types to/from LLM messages""" + + @abstractmethod + def format_system_prompt( + self, + task_description: str, + input_type: type[StructuredInput], + output_type: type[StructuredOutput] + ) -> str: + """Generate system prompt explaining the task""" + pass + + @abstractmethod + def format_input(self, input_data: StructuredInput) -> str: + """Convert input instance to string for LLM""" + pass + + @abstractmethod + def parse_output(self, llm_response: str, output_type: type[StructuredOutput]) -> StructuredOutput: + """Parse LLM response into output instance""" + pass + + +class JSONAdapter(PromptAdapter): + """Standard JSON-based communication""" + + def format_system_prompt(self, task_description: str, input_type, output_type) -> str: + output_schema = json.dumps(output_type.model_json_schema(), indent=2) + return f"""Task: {task_description} + +Output Schema: +{output_schema} + +Respond with valid JSON matching the schema above.""" + + def format_input(self, input_data: StructuredInput) -> str: + return input_data.model_dump_json(indent=2) + + def parse_output(self, llm_response: str, output_type) -> StructuredOutput: + # Try direct JSON parsing + try: + return output_type.model_validate_json(llm_response) + except Exception: + # Fallback: extract JSON from markdown code blocks + import re + json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', llm_response, re.DOTALL) + if json_match: + return output_type.model_validate_json(json_match.group(1)) + raise ValueError(f"Failed to parse JSON response: {llm_response[:200]}") + + +class MarkdownAdapter(PromptAdapter): + """Innovative markdown-based format using YAML-style frontmatter + sections""" + + def format_system_prompt(self, task_description: str, input_type, output_type) -> str: + # Build field descriptions + output_fields = [] + for field_name, field_info in output_type.model_fields.items(): + field_type = field_info.annotation.__name__ if hasattr(field_info.annotation, '__name__') else str( + field_info.annotation) + desc = field_info.description or "" + output_fields.append(f"- **{field_name}** (`{field_type}`): {desc}") + + fields_str = "\n".join(output_fields) + + return f"""# Task +{task_description} + +# Output Format +Respond using markdown with YAML frontmatter for metadata and sections for complex content. + +## Required Fields +{fields_str} + +## Structure +```markdown +--- +field_name: simple value +other_field: simple value +--- + +# SectionField (if complex content) +Complex content here... + +# AnotherSectionField +More complex content... +``` + +Use frontmatter for simple fields (strings, numbers, booleans). +Use markdown sections (# FieldName) for complex content (long text, lists, nested structures).""" + + def format_input(self, input_data: StructuredInput) -> str: + """Format input as markdown with frontmatter""" + lines = ["---"] + + # Simple fields in frontmatter + complex_fields = {} + for field_name, value in input_data.model_dump().items(): + if isinstance(value, (str, int, float, bool, type(None))): + if isinstance(value, str) and ('\n' in value or len(value) > 100): + complex_fields[field_name] = value + else: + lines.append(f"{field_name}: {value}") + else: + complex_fields[field_name] = value + + lines.append("---") + lines.append("") + + # Complex fields as sections + for field_name, value in complex_fields.items(): + lines.append(f"# {field_name}") + if isinstance(value, str): + lines.append(value) + else: + lines.append(json.dumps(value, indent=2)) + lines.append("") + + return "\n".join(lines) + + def parse_output(self, llm_response: str, output_type) -> StructuredOutput: + """Parse markdown frontmatter + sections into structured output""" + import re + + result = {} + + # Extract frontmatter + frontmatter_match = re.match(r'^---\s*\n(.*?)\n---\s*\n', llm_response, re.DOTALL) + if frontmatter_match: + frontmatter = frontmatter_match.group(1) + remaining = llm_response[frontmatter_match.end():] + + # Parse YAML-style frontmatter + for line in frontmatter.split('\n'): + if ':' in line: + key, value = line.split(':', 1) + key = key.strip() + value = value.strip() + + # Type coercion + if value.lower() == 'true': + value = True + elif value.lower() == 'false': + value = False + elif value.isdigit(): + value = int(value) + else: + try: + value = float(value) + except ValueError: + pass # Keep as string + + result[key] = value + else: + remaining = llm_response + + # Extract sections + sections = re.split(r'^# (\w+)\s*$', remaining, flags=re.MULTILINE) + for i in range(1, len(sections), 2): + if i + 1 < len(sections): + field_name = sections[i].strip() + content = sections[i + 1].strip() + result[field_name] = content + + return output_type(**result) + + +class XMLAdapter(PromptAdapter): + """XML-based format similar to Anthropic's style""" + + def format_system_prompt(self, task_description: str, input_type, output_type) -> str: + # Build field descriptions + output_fields = [] + for field_name, field_info in output_type.model_fields.items(): + field_type = field_info.annotation.__name__ if hasattr(field_info.annotation, '__name__') else str( + field_info.annotation) + desc = field_info.description or "" + output_fields.append(f" <{field_name} type='{field_type}'>{desc}") + + fields_str = "\n".join(output_fields) + + return f"""Task: {task_description} + +Output Format: Respond with XML structure containing these fields: + +{fields_str} + + +Provide actual values, not descriptions.""" + + def format_input(self, input_data: StructuredInput) -> str: + """Convert to XML format""" + lines = [""] + for field_name, value in input_data.model_dump().items(): + if isinstance(value, (list, dict)): + value = json.dumps(value) + lines.append(f" <{field_name}>{value}") + lines.append("") + return "\n".join(lines) + + def parse_output(self, llm_response: str, output_type) -> StructuredOutput: + """Parse XML response""" + import re + + result = {} + + # Extract fields using regex (simple XML parsing) + for field_name in output_type.model_fields.keys(): + pattern = f'<{field_name}>(.*?)' + match = re.search(pattern, llm_response, re.DOTALL) + if match: + value = match.group(1).strip() + + # Try JSON parsing for complex types + try: + value = json.loads(value) + except (json.JSONDecodeError, TypeError): + pass + + result[field_name] = value + + return output_type(**result) + + +# =========== Structured LLM Input/Output With Parsing =========== + +""" +Usage: + +@llm_call +def evaluate_person(person: Person) -> Preference: + "Evaluate if a person matches our criteria" + ... + +person = Person(name="Alice", age=30, income=75000) +preference = evaluate_person(person) + +TODO 2: add trace bundle and input/output conversion +""" + + +class StructuredLLMCallable: + """Enhanced wrapper supporting multiple function patterns""" + + def __init__( + self, + func: Callable, + llm, + input_type: type[StructuredInput], + output_type: type[StructuredOutput], + has_preprocessing: bool, + adapter: PromptAdapter + ): + self.func = func + self.llm = llm + self.input_type = input_type + self.output_type = output_type + self.has_preprocessing = has_preprocessing + self.adapter = adapter + + # Store output_type in adapter if it's MarkdownAdapter (for user message generation) + if isinstance(adapter, MarkdownAdapter): + adapter.output_type = output_type + + # Copy metadata + self.__name__ = func.__name__ + self.__doc__ = func.__doc__ + self.__module__ = func.__module__ + self.__annotations__ = func.__annotations__ + + def __call__( + self, + input_data: StructuredInput, + system_prompt: Optional[str] = None + ) -> StructuredOutput: + """Execute with automatic LLM invocation""" + + # Validate input type + if not isinstance(input_data, self.input_type): + raise TypeError(f"Expected {self.input_type}, got {type(input_data)}") + + # Execute preprocessing if function has implementation + if self.has_preprocessing: + result = self.func(input_data) + + # Check if function returned the output type class (Usage 2 & 3) + if inspect.isclass(result) and issubclass(result, StructuredOutput): + self.output_type = result + # Update adapter's output_type + if isinstance(self.adapter, MarkdownAdapter): + self.adapter.output_type = result + elif isinstance(result, StructuredOutput): + # Function did full processing, return directly + return result + + # Build system prompt using adapter + if system_prompt is None: + task_description = self.func.__doc__ or "Process the input data" + system_prompt = self.adapter.format_system_prompt( + task_description, + self.input_type, + self.output_type + ) + + # Format input using adapter + user_message = self.adapter.format_input(input_data) + + # Invoke LLM + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_message} + ] + + response = self.llm(messages=messages) + + # Parse output using adapter + output_instance = self.adapter.parse_output(response, self.output_type) + + return output_instance + + +def llm_call( + func: Callable = None, + *, + llm=None, + adapter: Literal["json", "markdown", "xml"] = "markdown", + type_hints: Optional[Dict[str, str]] = None, + **kwargs +): + """ + Enhanced decorator supporting three usage patterns. + + Args: + func: Function to decorate + llm: LLM instance + adapter: Output format - "json", "markdown", or "xml" + type_hints: Custom type hints for markdown adapter fields, e.g., + {"confidence": "must be a single float value between 0 and 1"} + """ + + # Create adapter instance + adapter_map = { + "json": JSONAdapter(), + "markdown": MarkdownAdapter(type_hints=type_hints), + "xml": XMLAdapter() + } + adapter_instance = adapter_map[adapter] + + def decorator(f: Callable): + hints = get_type_hints(f) + + # Extract input type (first parameter) + params = list(hints.items()) + input_type = None + for param_name, param_type in params: + if param_name != 'return': + input_type = param_type + break + + # Get output type from annotation + output_type = hints.get('return') + + # Detect if function has implementation (Usage 2 & 3) + source = inspect.getsource(f) + has_implementation = not source.strip().endswith('...') + + # Validate types + if input_type and not issubclass(input_type, StructuredInput): + raise TypeError(f"Input type must inherit from StructuredInput") + + if output_type and not issubclass(output_type, StructuredOutput): + raise TypeError(f"Output type must inherit from StructuredOutput") + + # Use default LLM if none provided + actual_llm = llm if llm is not None else LLM() + + return StructuredLLMCallable( + f, + actual_llm, + input_type, + output_type, + has_implementation, + adapter_instance + ) + + # Handle decorator syntax + if func is None: + return decorator + else: + return decorator(func) diff --git a/opto/features/flows/types.py b/opto/features/flows/types.py index 4196b926..685c1e41 100644 --- a/opto/features/flows/types.py +++ b/opto/features/flows/types.py @@ -1,10 +1,261 @@ """Types for opto flows.""" -from pydantic import BaseModel, Field, create_model, ConfigDict +from typing import List, Dict, Union +from pydantic import BaseModel, model_validator from typing import Any, Optional, Callable, Dict, Union, Type, List +from dataclasses import dataclass import re import json +from opto.optimizers.utils import encode_image_to_base64 +from opto import trace + class TraceObject: def __str__(self): # Any subclass that inherits this will be friendly to the optimizer - raise NotImplementedError("Subclasses must implement __str__") \ No newline at end of file + raise NotImplementedError("Subclasses must implement __str__") + + +class ForwardMixin: + def forward(self, *args, **kwargs): + raise NotImplementedError("Subclasses must implement forward") + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + +# ====== Multi-Modal LLM Support ====== +class MultiModalPayload(BaseModel): + image_bytes: Optional[str] = None # base64-encoded data URL + + @classmethod + def from_path(cls, path: str) -> "MultiModalPayload": + """Create a payload by loading an image from a local file path.""" + data_url = encode_image_to_base64(path) + return cls(image_bytes=data_url) + + def load_image(self, path: str) -> None: + """Mutate the current payload to include a new image.""" + self.image_bytes = encode_image_to_base64(path) + + +class QueryModel(BaseModel): + # Expose "query" as already-normalized: always a List[Dict[str, Any]] + query: List[Dict[str, Any]] + multimodal_payload: Optional[MultiModalPayload] = None + + @model_validator(mode="before") + @classmethod + def normalize(cls, data: Any): + """ + Accepts: + { "query": "hello" } + { "query": "hello", "multimodal_payload": {"image_bytes": "..."} } + And always produces: + { "query": [ {text block}, maybe {image_url block} ], "multimodal_payload": ...} + """ + if not isinstance(data, dict): + raise TypeError("QueryModel input must be a dict") + + raw_query: Any = data.get("query") + if isinstance(raw_query, trace.Node): + assert isinstance(raw_query.data, (str, list)), "If using trace.Node, its data must be str" + raw_query = raw_query.data + + # 1) Start with the text part + if isinstance(raw_query, str): + out: List[Dict[str, Any]] = [{"type": "text", "text": raw_query}] + else: + raise TypeError("`query` must be a string") + + # 2) If we have an image, append an image block + payload = data.get("multimodal_payload") + image_bytes: Optional[str] = None + if payload is not None: + if isinstance(payload, dict): + image_bytes = payload.get("image_bytes") + else: + # Could be already-parsed MultiModalPayload + image_bytes = getattr(payload, "image_bytes", None) + + if image_bytes: + out = out + [{ + "type": "image_url", + "image_url": {"url": image_bytes} + }] + + # 3) Write back normalized fields + data["query"] = out + return data + + +# ====== ====== + +# ======= Structured LLM Info Support ====== + +from typing import Any, Dict, Optional, Type, get_type_hints +from pydantic import BaseModel, create_model, Field +import inspect + + +class StructuredData(BaseModel): + """ + Base class for structured data (inputs/outputs) with support for both + inheritance and dynamic on-the-fly usage. + """ + + _docstring: Optional[str] = None + + def __init_subclass__(cls, **kwargs): + """Called when a class inherits from StructuredData""" + super().__init_subclass__(**kwargs) + cls._docstring = inspect.getdoc(cls) + + def __new__(cls, docstring: Optional[str] = None, **kwargs): + """ + Handle both inheritance and on-the-fly usage. + """ + # Check if being used dynamically (direct instantiation with docstring) + if cls in (StructuredData, StructuredInput, StructuredOutput) and \ + docstring is not None and isinstance(docstring, str): + # Determine the appropriate class name for dynamic instances + if cls is StructuredInput or (cls is StructuredData and 'Input' in str(cls)): + dynamic_name = 'DynamicStructuredInput' + elif cls is StructuredOutput: + dynamic_name = 'DynamicStructuredOutput' + else: + dynamic_name = 'DynamicStructuredData' + + dynamic_cls = type( + dynamic_name, + (cls,), + { + '__doc__': docstring, + '_docstring': docstring, + '__module__': cls.__module__, + } + ) + instance = super(StructuredData, dynamic_cls).__new__(dynamic_cls) + return instance + + return super().__new__(cls) + + def __init__(self, docstring: Optional[str] = None, **kwargs): + """Initialize the instance""" + dynamic_names = ('DynamicStructuredData', 'DynamicStructuredInput', 'DynamicStructuredOutput') + if isinstance(docstring, str) and self.__class__.__name__ in dynamic_names: + super().__init__(**kwargs) + object.__setattr__(self, '_docstring', docstring) + else: + super().__init__(**kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + """Allow dynamic attribute setting for on-the-fly usage""" + if name.startswith('_'): + super().__setattr__(name, value) + else: + dynamic_names = ('DynamicStructuredData', 'DynamicStructuredInput', 'DynamicStructuredOutput') + if self.__class__.__name__ in dynamic_names: + object.__setattr__(self, name, value) + else: + super().__setattr__(name, value) + + @classmethod + def get_docstring(cls) -> Optional[str]: + """Get the docstring of the class""" + return getattr(cls, '_docstring', None) or inspect.getdoc(cls) + + @classmethod + def get_fields_info(cls) -> Dict[str, Any]: + """Get information about all fields""" + base_names = ('StructuredData', 'StructuredInput', 'StructuredOutput') + dynamic_names = ('DynamicStructuredData', 'DynamicStructuredInput', 'DynamicStructuredOutput') + + if cls.__name__ in dynamic_names or cls.__name__ in base_names: + return {} + + fields_info = {} + for field_name, field in cls.model_fields.items(): + fields_info[field_name] = { + 'type': field.annotation, + 'required': field.is_required(), + 'default': field.default if field.default is not None else None, + } + return fields_info + + def get_instance_fields(self) -> Dict[str, Any]: + """Get all fields and their values from the instance""" + dynamic_names = ('DynamicStructuredData', 'DynamicStructuredInput', 'DynamicStructuredOutput') + if self.__class__.__name__ in dynamic_names: + return { + k: v for k, v in self.__dict__.items() + if not k.startswith('_') + } + else: + return self.model_dump() + + def __str__(self, template: Optional[str] = None) -> str: + """ + Convert the structured data to a string format for LLM consumption. + Subclasses can override for specific formatting. + """ + docstring = self.get_docstring() or "No description provided" + fields = self.get_instance_fields() + + if template: + return template.format(docstring=docstring, fields=fields) + + lines = [f"Description: {docstring}", "", "Fields:"] + for field_name, field_value in fields.items(): + lines.append(f"- {field_name}: {field_value}") + + return "\n".join(lines) + + +class StructuredInput(StructuredData): + """ + Base class for structured inputs that can be used via inheritance or dynamically. + """ + + def __str__(self, template: Optional[str] = None) -> str: + """ + Convert the structured input to a string format emphasizing input data. + """ + docstring = self.get_docstring() or "No description provided" + fields = self.get_instance_fields() + + if template: + return template.format(docstring=docstring, fields=fields) + + lines = [f"Input: {docstring}", "", "Provided data:"] + for field_name, field_value in fields.items(): + lines.append(f"- {field_name}: {field_value}") + + return "\n".join(lines) + + +class StructuredOutput(StructuredData): + """ + Base class for structured outputs from LLM functions. + """ + + def __str__(self, template: Optional[str] = None) -> str: + """ + Convert the structured output to a string format emphasizing results. + """ + docstring = self.get_docstring() or "No description provided" + fields = self.get_instance_fields() + + if template: + return template.format(docstring=docstring, fields=fields) + + lines = [f"Output: {docstring}", "", "Results:"] + for field_name, field_value in fields.items(): + lines.append(f"- {field_name}: {field_value}") + + return "\n".join(lines) + +# ======= ====== + +# ======= Agentic Optimizer Support ======= + +# ======= ======= diff --git a/opto/optimizers/opro_v2.py b/opto/optimizers/opro_v2.py index ff5c801d..19b33e58 100644 --- a/opto/optimizers/opro_v2.py +++ b/opto/optimizers/opro_v2.py @@ -1,7 +1,7 @@ import json from textwrap import dedent from dataclasses import dataclass, asdict -from typing import Dict +from typing import Dict, Optional from opto.optimizers.optoprime_v2 import OptoPrimeV2, OptimizerPromptSymbolSet @@ -15,8 +15,8 @@ class OPROPromptSymbolSet(OptimizerPromptSymbolSet): Attributes ---------- - problem_context_section_title : str - Title for the problem context section in prompts. + instruction_section_title : str + Title for the instruction section in prompts. variable_section_title : str Title for the variable/solution section in prompts. feedback_section_title : str @@ -49,9 +49,10 @@ class OPROPromptSymbolSet(OptimizerPromptSymbolSet): more focused set of symbols specifically for OPRO optimization. """ - problem_context_section_title = "# Problem Context" + instruction_section_title = "# Instruction" variable_section_title = "# Solution" feedback_section_title = "# Feedback" + context_section_title = "# Context" node_tag = "node" # nodes that are constants in the graph variable_tag = "solution" # nodes that can be changed @@ -72,6 +73,7 @@ def default_prompt_symbols(self) -> Dict[str, str]: "variables": self.variables_section_title, "feedback": self.feedback_section_title, "instruction": self.instruction_section_title, + "context": self.context_section_title } @dataclass @@ -89,6 +91,9 @@ class ProblemInstance: The current proposed solution that can be modified. feedback : str Feedback about the current solution. + context: str + Optional context information that might be useful to solve the problem. + optimizer_prompt_symbol_set : OPROPromptSymbolSet The symbol set used for formatting the problem. problem_template : str @@ -107,12 +112,13 @@ class ProblemInstance: instruction: str variables: str feedback: str + context: Optional[str] optimizer_prompt_symbol_set: OPROPromptSymbolSet problem_template = dedent( """ - # Problem Context + # Instruction {instruction} # Solution @@ -124,12 +130,24 @@ class ProblemInstance: ) def __repr__(self) -> str: - return self.problem_template.format( + optimization_query = self.problem_template.format( instruction=self.instruction, variables=self.variables, feedback=self.feedback, ) + context_section = dedent(""" + + # Context + {context} + """) + + if self.context is not None and self.context.strip() != "": + context_section.format(context=self.context) + optimization_query += context_section + + return optimization_query + class OPROv2(OptoPrimeV2): """OPRO (Optimization by PROmpting) optimizer version 2. @@ -197,6 +215,7 @@ class OPROv2(OptoPrimeV2): - {instruction_section_title}: the instruction which describes the things you need to do or the question you should answer. - {variables_section_title}: the proposed solution that you can change/tweak (trainable). - {feedback_section_title}: the feedback about the solution. + - {context_section_title}: the context information that might be useful to solve the problem. If `data_type` is `code`, it means `{value_tag}` is the source code of a python code, which may include docstring and definitions. """ @@ -229,6 +248,14 @@ class OPROv2(OptoPrimeV2): """ ) + context_prompt = dedent( + """ + Here is some additional **context** to solving this problem: + + {context} + """ + ) + final_prompt = dedent( """ What are your revised solutions on {names}? @@ -244,6 +271,7 @@ def __init__(self, *args, optimizer_prompt_symbol_set: OptimizerPromptSymbolSet = None, include_example=False, # default example in OptoPrimeV2 does not work in OPRO memory_size=5, + problem_context: Optional[str] = None, **kwargs): """Initialize the OPROv2 optimizer. @@ -264,6 +292,7 @@ def __init__(self, *args, optimizer_prompt_symbol_set = optimizer_prompt_symbol_set or OPROPromptSymbolSet() super().__init__(*args, optimizer_prompt_symbol_set=optimizer_prompt_symbol_set, include_example=include_example, memory_size=memory_size, + problem_context=problem_context, **kwargs) def problem_instance(self, summary, mask=None): @@ -328,6 +357,7 @@ def initialize_prompt(self): variables_section_title=self.optimizer_prompt_symbol_set.variables_section_title.replace(" ", ""), feedback_section_title=self.optimizer_prompt_symbol_set.feedback_section_title.replace(" ", ""), instruction_section_title=self.optimizer_prompt_symbol_set.instruction_section_title.replace(" ", ""), + context_section_title=self.optimizer_prompt_symbol_set.context_section_title.replace(" ", "") ) self.output_format_prompt = self.output_format_prompt_template.format( output_format=self.optimizer_prompt_symbol_set.output_format, @@ -336,4 +366,5 @@ def initialize_prompt(self): instruction_section_title=self.optimizer_prompt_symbol_set.instruction_section_title.replace(" ", ""), feedback_section_title=self.optimizer_prompt_symbol_set.feedback_section_title.replace(" ", ""), variables_section_title=self.optimizer_prompt_symbol_set.variables_section_title.replace(" ", ""), + context_section_title=self.optimizer_prompt_symbol_set.context_section_title.replace(" ", "") ) diff --git a/opto/optimizers/optoprime_v2.py b/opto/optimizers/optoprime_v2.py index a512af8e..376f0a26 100644 --- a/opto/optimizers/optoprime_v2.py +++ b/opto/optimizers/optoprime_v2.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, asdict from opto.optimizers.optoprime import OptoPrime, FunctionFeedback from opto.trace.utils import dedent -from opto.optimizers.utils import truncate_expression, extract_xml_like_data +from opto.optimizers.utils import truncate_expression, extract_xml_like_data, encode_image_to_base64 from opto.trace.nodes import ParameterNode, Node, MessageNode from opto.trace.propagators import TraceGraph, GraphPropagator @@ -17,6 +17,11 @@ from typing import Dict, Any +@dataclass +class MultiModalPayload: + image_bytes: Optional[str] = None # base64 encoded image bytes + + class OptimizerPromptSymbolSet: """ By inheriting this class and pass into the optimizer. People can change the optimizer documentation @@ -37,6 +42,7 @@ class OptimizerPromptSymbolSet: instruction_section_title = "# Instruction" code_section_title = "# Code" documentation_section_title = "# Documentation" + context_section_title = "# Context" node_tag = "node" # nodes that are constants in the graph variable_tag = "variable" # nodes that can be changed @@ -141,6 +147,7 @@ def default_prompt_symbols(self) -> Dict[str, str]: "instruction": self.instruction_section_title, "code": self.code_section_title, "documentation": self.documentation_section_title, + "context": self.context_section_title } @@ -149,7 +156,7 @@ class OptimizerPromptSymbolSetJSON(OptimizerPromptSymbolSet): expect_json = True - custom_output_format_instruction = """ + custom_output_format_instruction = dedent(""" {{ "reasoning": , "suggestion": {{ @@ -157,7 +164,7 @@ class OptimizerPromptSymbolSetJSON(OptimizerPromptSymbolSet): : , }} }} - """ + """) def example_output(self, reasoning, variables): """ @@ -172,9 +179,65 @@ def example_output(self, reasoning, variables): } return json.dumps(output, indent=2) - def output_response_extractor(self, response: str, suggestion_tag = "suggestion") -> Dict[str, Any]: - # Use extract_llm_suggestion from OptoPrime => it could be implemented the other way around (OptoPrime would uses this helper but it should be moved out of OptoPrimev2) - return OptoPrime.extract_llm_suggestion(self, response, suggestion_tag=suggestion_tag, reasoning_tag="reasoning", return_only_suggestion=False) + def output_response_extractor(self, response: str) -> Dict[str, Any]: + reasoning = "" + suggestion_tag = "suggestion" + + if "```" in response: + response = response.replace("```", "").strip() + + suggestion = {} + attempt_n = 0 + while attempt_n < 2: + try: + suggestion = json.loads(response)[suggestion_tag] + reasoning = json.loads(response)[self.reasoning_tag] + break + except json.JSONDecodeError: + # Remove things outside the brackets + response = re.findall(r"{.*}", response, re.DOTALL) + if len(response) > 0: + response = response[0] + attempt_n += 1 + except Exception: + attempt_n += 1 + + if not isinstance(suggestion, dict): + suggestion = {} + + if len(suggestion) == 0: + # we try to extract key/value separately and return it as a dictionary + pattern = rf'"{suggestion_tag}"\s*:\s*\{{(.*?)\}}' + suggestion_match = re.search(pattern, str(response), re.DOTALL) + if suggestion_match: + suggestion = {} + # Extract the entire content of the suggestion dictionary + suggestion_content = suggestion_match.group(1) + # Regex to extract each key-value pair; + # This scheme assumes double quotes but is robust to missing commas at the end of the line + pair_pattern = r'"([a-zA-Z0-9_]+)"\s*:\s*"(.*)"' + # Find all matches of key-value pairs + pairs = re.findall(pair_pattern, suggestion_content, re.DOTALL) + for key, value in pairs: + suggestion[key] = value + + if len(suggestion) == 0: + print(f"Cannot extract suggestion from LLM's response:") + print(response) + + # if the suggested value is a code, and the entire code body is empty (i.e., not even function signature is present) + # then we remove such suggestion + keys_to_remove = [] + for key, value in suggestion.items(): + if "__code" in key and value.strip() == "": + keys_to_remove.append(key) + for key in keys_to_remove: + del suggestion[key] + + extracted_data = {"reasoning": reasoning, + "variables": suggestion} + + return extracted_data class OptimizerPromptSymbolSet2(OptimizerPromptSymbolSet): @@ -186,6 +249,7 @@ class OptimizerPromptSymbolSet2(OptimizerPromptSymbolSet): instruction_section_title = "# Instruction" code_section_title = "# Code" documentation_section_title = "# Documentation" + context_section_title = "# Context" node_tag = "const" # nodes that are constants in the graph variable_tag = "var" # nodes that can be changed @@ -208,6 +272,7 @@ class ProblemInstance: others: str outputs: str feedback: str + context: Optional[str] optimizer_prompt_symbol_set: OptimizerPromptSymbolSet @@ -240,7 +305,7 @@ class ProblemInstance: ) def __repr__(self) -> str: - return self.problem_template.format( + optimization_query = self.problem_template.format( instruction=self.instruction, code=self.code, documentation=self.documentation, @@ -249,12 +314,25 @@ def __repr__(self) -> str: outputs=self.outputs, others=self.others, feedback=self.feedback, + context=self.context ) + context_section = dedent(""" + + # Context + {context} + """) + + if self.context is not None and self.context.strip() != "": + context_section.format(context=self.context) + optimization_query += context_section + + return optimization_query + @dataclass class MemoryInstance: - variables: Dict[str, Tuple[Any, str]] # name -> (data, constraint) + variables: Dict[str, Tuple[Any, str]] # name -> (data, constraint) feedback: str optimizer_prompt_symbol_set: OptimizerPromptSymbolSet @@ -303,6 +381,7 @@ class OptoPrimeV2(OptoPrime): - {others_section_title}: the intermediate values created through the code execution. - {outputs_section_title}: the result of the code output. - {feedback_section_title}: the feedback about the code's execution result. + - {context_section_title}: the context information that might be useful to solve the problem. In `{variables_section_title}`, `{inputs_section_title}`, `{outputs_section_title}`, and `{others_section_title}`, the format is: @@ -357,17 +436,22 @@ class OptoPrimeV2(OptoPrime): example_prompt = dedent( """ - Here are some feasible but not optimal solutions for the current problem instance. Consider this as a hint to help you understand the problem better. ================================ - {examples} - ================================ """ ) + context_prompt = dedent( + """ + Here is some additional **context** to solving this problem: + + {context} + """ + ) + final_prompt = dedent( """ What are your suggestions on variables {names}? @@ -387,20 +471,20 @@ def __init__( # ignore the type conversion error when extracting updated values from LLM's suggestion include_example=False, memory_size=0, # Memory size to store the past feedback - max_tokens=8192, + max_tokens=4096, log=True, - initial_var_char_limit=2000, + initial_var_char_limit=100, optimizer_prompt_symbol_set: OptimizerPromptSymbolSet = OptimizerPromptSymbolSet(), use_json_object_format=True, # whether to use json object format for the response when calling LLM truncate_expression=truncate_expression, + problem_context: Optional[str] = None, **kwargs, ): super().__init__(parameters, *args, propagator=propagator, **kwargs) - if optimizer_prompt_symbol_set is None: - optimizer_prompt_symbol_set = OptimizerPromptSymbolSet() - self.truncate_expression = truncate_expression + self.problem_context = problem_context + self.multimodal_payload = MultiModalPayload() self.use_json_object_format = use_json_object_format if optimizer_prompt_symbol_set.expect_json and use_json_object_format else False self.ignore_extraction_error = ignore_extraction_error @@ -443,6 +527,27 @@ def __init__( self.prompt_symbols = copy.deepcopy(self.default_prompt_symbols) self.initialize_prompt() + def add_image_context(self, image_path: str, context: str = ""): + if self.problem_context is None: + self.problem_context = "" + + if context == "": + context = "The attached image is given to the workflow. You should use the image to help you understand the problem and provide better suggestions. You can refer to the image when providing your suggestions." + + self.problem_context += f"{context}\n\n" + + # we load in the image and convert to base64 + data_url = encode_image_to_base64(image_path) + self.multimodal_payload.image_bytes = data_url + + self.initialize_prompt() + + def add_context(self, context: str): + if self.problem_context is None: + self.problem_context = "" + self.problem_context += f"{context}\n\n" + self.initialize_prompt() + def initialize_prompt(self): self.representation_prompt = self.representation_prompt.format( variable_expression_format=dedent(f""" @@ -463,7 +568,8 @@ def initialize_prompt(self): instruction_section_title=self.optimizer_prompt_symbol_set.instruction_section_title.replace(" ", ""), code_section_title=self.optimizer_prompt_symbol_set.code_section_title.replace(" ", ""), documentation_section_title=self.optimizer_prompt_symbol_set.documentation_section_title.replace(" ", ""), - others_section_title=self.optimizer_prompt_symbol_set.others_section_title.replace(" ", "") + others_section_title=self.optimizer_prompt_symbol_set.others_section_title.replace(" ", ""), + context_section_title=self.optimizer_prompt_symbol_set.context_section_title.replace(" ", "") ) self.output_format_prompt = self.output_format_prompt_template.format( output_format=self.optimizer_prompt_symbol_set.output_format, @@ -476,7 +582,8 @@ def initialize_prompt(self): documentation_section_title=self.optimizer_prompt_symbol_set.documentation_section_title.replace(" ", ""), variables_section_title=self.optimizer_prompt_symbol_set.variables_section_title.replace(" ", ""), inputs_section_title=self.optimizer_prompt_symbol_set.inputs_section_title.replace(" ", ""), - others_section_title=self.optimizer_prompt_symbol_set.others_section_title.replace(" ", "") + others_section_title=self.optimizer_prompt_symbol_set.others_section_title.replace(" ", ""), + context_section_title=self.optimizer_prompt_symbol_set.context_section_title.replace(" ", "") ) def repr_node_value(self, node_dict, node_tag="node", @@ -603,6 +710,7 @@ def problem_instance(self, summary, mask=None): constraint_tag=self.optimizer_prompt_symbol_set.constraint_tag) if self.optimizer_prompt_symbol_set.others_section_title not in mask else "" ), feedback=summary.user_feedback if self.optimizer_prompt_symbol_set.feedback_section_title not in mask else "", + context=self.problem_context if self.optimizer_prompt_symbol_set.context_section_title not in mask else "", optimizer_prompt_symbol_set=self.optimizer_prompt_symbol_set ) @@ -664,9 +772,20 @@ def call_llm( if verbose not in (False, "output"): print("Prompt\n", system_prompt + user_prompt) + user_message_content = [] + if self.multimodal_payload.image_bytes is not None: + user_message_content.append({ + "type": "image_url", + "image_url": { + "url": self.multimodal_payload.image_bytes + } + }) + + user_message_content.append({"type": "text", "text": user_prompt}) + messages = [ {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, + {"role": "user", "content": user_message_content}, ] response_format = {"type": "json_object"} if self.use_json_object_format else None @@ -678,3 +797,42 @@ def call_llm( if verbose: print("LLM response:\n", response) return response + + def save(self, path: str): + """Save the optimizer state to a file.""" + with open(path, 'wb') as f: + pickle.dump({ + "truncate_expression": self.truncate_expression, + "use_json_object_format": self.use_json_object_format, + "ignore_extraction_error": self.ignore_extraction_error, + "objective": self.objective, + "initial_var_char_limit": self.initial_var_char_limit, + "optimizer_prompt_symbol_set": self.optimizer_prompt_symbol_set, + "include_example": self.include_example, + "max_tokens": self.max_tokens, + "memory": self.memory, + "default_prompt_symbols": self.default_prompt_symbols, + "prompt_symbols": self.prompt_symbols, + "representation_prompt": self.representation_prompt, + "output_format_prompt": self.output_format_prompt, + 'context_prompt': self.context_prompt + }, f) + + def load(self, path: str): + """Load the optimizer state from a file.""" + with open(path, 'rb') as f: + state = pickle.load(f) + self.truncate_expression = state["truncate_expression"] + self.use_json_object_format = state["use_json_object_format"] + self.ignore_extraction_error = state["ignore_extraction_error"] + self.objective = state["objective"] + self.initial_var_char_limit = state["initial_var_char_limit"] + self.optimizer_prompt_symbol_set = state["optimizer_prompt_symbol_set"] + self.include_example = state["include_example"] + self.max_tokens = state["max_tokens"] + self.memory = state["memory"] + self.default_prompt_symbols = state["default_prompt_symbols"] + self.prompt_symbols = state["prompt_symbols"] + self.representation_prompt = state["representation_prompt"] + self.output_format_prompt = state["output_format_prompt"] + self.context_prompt = state["context_prompt"] diff --git a/opto/optimizers/utils.py b/opto/optimizers/utils.py index 13a5ad01..4fbec459 100644 --- a/opto/optimizers/utils.py +++ b/opto/optimizers/utils.py @@ -1,5 +1,8 @@ +import base64 +import mimetypes from typing import Dict, Any + def print_color(message, color=None, logger=None): colors = { "red": "\033[91m", @@ -134,3 +137,17 @@ def extract_xml_like_data(text: str, reasoning_tag: str = "reasoning", if var_name: # Only require name to be non-empty, value can be empty result['variables'][var_name] = var_value return result + + +def encode_image_to_base64(path: str) -> str: + # Read binary + with open(path, "rb") as f: + image_bytes = f.read() + # Guess MIME type from file extension + mime_type, _ = mimetypes.guess_type(path) + if mime_type is None: + # fallback + mime_type = "image/jpeg" + b64 = base64.b64encode(image_bytes).decode("utf-8") + data_url = f"data:{mime_type};base64,{b64}" + return data_url diff --git a/opto/trace/nodes.py b/opto/trace/nodes.py index 0f4c85e5..5c86bbf2 100644 --- a/opto/trace/nodes.py +++ b/opto/trace/nodes.py @@ -8,7 +8,7 @@ import contextvars -def node(data, name=None, trainable=False, description=None): +def node(data, name=None, trainable=False, description=None, **kwargs): """Create a Node object from data. This is the primary factory function for creating nodes in the Trace computation graph. @@ -29,6 +29,9 @@ def node(data, name=None, trainable=False, description=None): description : str, optional A textual description of the node's purpose or constraints. Used as soft constraints during optimization and for documentation. + **kwargs : dict + Additional keyword arguments to pass to the Node or ParameterNode constructor, + such as 'info' or 'projections'. Returns ------- @@ -85,6 +88,7 @@ def node(data, name=None, trainable=False, description=None): name=name, trainable=True, description=description, + **kwargs ) else: if isinstance(data, Node): @@ -92,7 +96,7 @@ def node(data, name=None, trainable=False, description=None): warnings.warn(f"Name {name} is ignored because data is already a Node.") return data else: - return Node(data, name=name, description=description) + return Node(data, name=name, description=description, **kwargs) NAME_SCOPES = [] # A stack of name scopes diff --git a/tests/unit_tests/test_priority_search.py b/tests/unit_tests/test_priority_search.py index a7ff24d3..bc215523 100644 --- a/tests/unit_tests/test_priority_search.py +++ b/tests/unit_tests/test_priority_search.py @@ -130,6 +130,15 @@ def _llm_callable(messages, **kwargs): A dummy LLM callable that simulates a response. """ problem = messages[1]['content'] + # in newer LLM API (LiteLLM, OpenAI client, etc.), the user message content is now a list of typed messages: + # [{'type': 'text', 'text': '...'}, {'type': 'image', 'image_url': '...'}] + # this expansion is necessary for multi-modal inputs + + if type(problem) is list: + for typed_message in problem: + if typed_message['type'] == 'text': + problem = typed_message['text'] + break # extract name from name = re.findall(r"", problem)