From 8961b02455f1a928be1e1efd1ddd6c5a8bab44a8 Mon Sep 17 00:00:00 2001 From: windweller Date: Fri, 3 Oct 2025 15:40:37 -0400 Subject: [PATCH 01/14] initial changes --- opto/optimizers/optoprime_v2.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/opto/optimizers/optoprime_v2.py b/opto/optimizers/optoprime_v2.py index cc898bac..ce692bd2 100644 --- a/opto/optimizers/optoprime_v2.py +++ b/opto/optimizers/optoprime_v2.py @@ -37,6 +37,7 @@ class OptimizerPromptSymbolSet: instruction_section_title = "# Instruction" code_section_title = "# Code" documentation_section_title = "# Documentation" + context_section_title = "# Context" node_tag = "node" # nodes that are constants in the graph variable_tag = "variable" # nodes that can be changed @@ -141,6 +142,7 @@ def default_prompt_symbols(self) -> Dict[str, str]: "instruction": self.instruction_section_title, "code": self.code_section_title, "documentation": self.documentation_section_title, + "context": self.context_section_title } @@ -242,6 +244,7 @@ class OptimizerPromptSymbolSet2(OptimizerPromptSymbolSet): instruction_section_title = "# Instruction" code_section_title = "# Code" documentation_section_title = "# Documentation" + context_section_title = "# Context" node_tag = "const" # nodes that are constants in the graph variable_tag = "var" # nodes that can be changed @@ -264,6 +267,7 @@ class ProblemInstance: others: str outputs: str feedback: str + context: str optimizer_prompt_symbol_set: OptimizerPromptSymbolSet @@ -292,6 +296,9 @@ class ProblemInstance: # Feedback {feedback} + + # Context + {context} """ ) @@ -305,6 +312,7 @@ def __repr__(self) -> str: outputs=self.outputs, others=self.others, feedback=self.feedback, + context=self.context ) @@ -359,6 +367,7 @@ class OptoPrimeV2(OptoPrime): - {others_section_title}: the intermediate values created through the code execution. - {outputs_section_title}: the result of the code output. - {feedback_section_title}: the feedback about the code's execution result. + - {context_section_title}: the context information that might be useful to solve the problem. In `{variables_section_title}`, `{inputs_section_title}`, `{outputs_section_title}`, and `{others_section_title}`, the format is: @@ -413,17 +422,22 @@ class OptoPrimeV2(OptoPrime): example_prompt = dedent( """ - Here are some feasible but not optimal solutions for the current problem instance. Consider this as a hint to help you understand the problem better. ================================ - {examples} - ================================ """ ) + context_prompt = dedent( + """ + Here is some additional **context** to solving this problem: + + {context} + """ + ) + final_prompt = dedent( """ What are your suggestions on variables {names}? @@ -476,6 +490,7 @@ def __init__( ) self.example_problem_summary.variables = {'a': (5, "a > 0")} self.example_problem_summary.inputs = {'b': (1, None), 'c': (5, None)} + self.example_problem_summary.context = "" self.example_problem = self.problem_instance(self.example_problem_summary) self.example_response = self.optimizer_prompt_symbol_set.example_output( @@ -656,6 +671,7 @@ def problem_instance(self, summary, mask=None): constraint_tag=self.optimizer_prompt_symbol_set.constraint_tag) if self.optimizer_prompt_symbol_set.others_section_title not in mask else "" ), feedback=summary.user_feedback if self.optimizer_prompt_symbol_set.feedback_section_title not in mask else "", + context=summary.context if self.optimizer_prompt_symbol_set.context_section_title not in mask else "", optimizer_prompt_symbol_set=self.optimizer_prompt_symbol_set ) From 949d8ff99c2714cd4a65944999a77dcd44b4fd11 Mon Sep 17 00:00:00 2001 From: windweller Date: Fri, 3 Oct 2025 15:54:11 -0400 Subject: [PATCH 02/14] make context optional --- opto/optimizers/optoprime_v2.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/opto/optimizers/optoprime_v2.py b/opto/optimizers/optoprime_v2.py index ce692bd2..b56f4d56 100644 --- a/opto/optimizers/optoprime_v2.py +++ b/opto/optimizers/optoprime_v2.py @@ -267,7 +267,7 @@ class ProblemInstance: others: str outputs: str feedback: str - context: str + context: Optional[str] optimizer_prompt_symbol_set: OptimizerPromptSymbolSet @@ -296,14 +296,11 @@ class ProblemInstance: # Feedback {feedback} - - # Context - {context} """ ) def __repr__(self) -> str: - return self.problem_template.format( + optimization_query = self.problem_template.format( instruction=self.instruction, code=self.code, documentation=self.documentation, @@ -315,6 +312,18 @@ def __repr__(self) -> str: context=self.context ) + context_section = dedent(""" + + # Context + {context} + """) + + if self.context is not None and self.context.strip() != "": + context_section.format(context=self.context) + optimization_query += context_section + + return optimization_query + @dataclass class MemoryInstance: From e36dd7c578e8560056a1d8a81ac2b34e8388b05e Mon Sep 17 00:00:00 2001 From: windweller Date: Fri, 3 Oct 2025 17:00:44 -0400 Subject: [PATCH 03/14] finish adding image support to optoprime_v2 --- docs/tutorials/minibatch.ipynb | 22 +++++++------- opto/optimizers/optoprime_v2.py | 51 ++++++++++++++++++++++++++++----- opto/optimizers/utils.py | 17 +++++++++++ 3 files changed, 72 insertions(+), 18 deletions(-) diff --git a/docs/tutorials/minibatch.ipynb b/docs/tutorials/minibatch.ipynb index f752d866..e7cd4233 100644 --- a/docs/tutorials/minibatch.ipynb +++ b/docs/tutorials/minibatch.ipynb @@ -601,11 +601,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Step 1] \u001b[92mAverage test score: 1.0\u001b[0m\n", + "[Step 1] \u001B[92mAverage test score: 1.0\u001B[0m\n", "Epoch: 0. Iteration: 1\n", "[Step 1] Instantaneous train score: 1.0\n", "[Step 1] Average train score: 1.0\n", - "[Step 1] \u001b[91mParameter: str:20: You're a helpful agent\u001b[0m\n" + "[Step 1] \u001B[91mParameter: str:20: You're a helpful agent\u001B[0m\n" ] }, { @@ -641,11 +641,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Step 2] \u001b[92mAverage test score: 1.0\u001b[0m\n", + "[Step 2] \u001B[92mAverage test score: 1.0\u001B[0m\n", "Epoch: 0. Iteration: 2\n", "[Step 2] Instantaneous train score: 1.0\n", "[Step 2] Average train score: 1.0\n", - "[Step 2] \u001b[91mParameter: str:20: You're a helpful agent\u001b[0m\n" + "[Step 2] \u001B[91mParameter: str:20: You're a helpful agent\u001B[0m\n" ] }, { @@ -677,11 +677,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Step 3] \u001b[92mAverage test score: 1.0\u001b[0m\n", + "[Step 3] \u001B[92mAverage test score: 1.0\u001B[0m\n", "Epoch: 0. Iteration: 3\n", "[Step 3] Instantaneous train score: 1.0\n", "[Step 3] Average train score: 1.0\n", - "[Step 3] \u001b[91mParameter: str:20: You're a helpful agent\u001b[0m\n" + "[Step 3] \u001B[91mParameter: str:20: You're a helpful agent\u001B[0m\n" ] }, { @@ -714,11 +714,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Step 4] \u001b[92mAverage test score: 1.0\u001b[0m\n", + "[Step 4] \u001B[92mAverage test score: 1.0\u001B[0m\n", "Epoch: 0. Iteration: 4\n", "[Step 4] Instantaneous train score: 1.0\n", "[Step 4] Average train score: 1.0\n", - "[Step 4] \u001b[91mParameter: str:20: You're a helpful agent\u001b[0m\n" + "[Step 4] \u001B[91mParameter: str:20: You're a helpful agent\u001B[0m\n" ] }, { @@ -751,11 +751,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Step 5] \u001b[92mAverage test score: 1.0\u001b[0m\n", + "[Step 5] \u001B[92mAverage test score: 1.0\u001B[0m\n", "Epoch: 0. Iteration: 5\n", "[Step 5] Instantaneous train score: 1.0\n", "[Step 5] Average train score: 1.0\n", - "[Step 5] \u001b[91mParameter: str:20: You're a helpful agent\u001b[0m\n", + "[Step 5] \u001B[91mParameter: str:20: You're a helpful agent\u001B[0m\n", "FINISHED TRAINING\n" ] }, @@ -831,4 +831,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/opto/optimizers/optoprime_v2.py b/opto/optimizers/optoprime_v2.py index b56f4d56..87a57446 100644 --- a/opto/optimizers/optoprime_v2.py +++ b/opto/optimizers/optoprime_v2.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, asdict from opto.optimizers.optoprime import OptoPrime, FunctionFeedback from opto.trace.utils import dedent -from opto.optimizers.utils import truncate_expression, extract_xml_like_data +from opto.optimizers.utils import truncate_expression, extract_xml_like_data, encode_image_to_base64 from opto.trace.nodes import ParameterNode, Node, MessageNode from opto.trace.propagators import TraceGraph, GraphPropagator @@ -17,6 +17,11 @@ from typing import Dict, Any +@dataclass +class MultiModalPayload: + image_bytes: Optional[str] = None # base64 encoded image bytes + + class OptimizerPromptSymbolSet: """ By inheriting this class and pass into the optimizer. People can change the optimizer documentation @@ -327,7 +332,7 @@ def __repr__(self) -> str: @dataclass class MemoryInstance: - variables: Dict[str, Tuple[Any, str]] # name -> (data, constraint) + variables: Dict[str, Tuple[Any, str]] # name -> (data, constraint) feedback: str optimizer_prompt_symbol_set: OptimizerPromptSymbolSet @@ -472,11 +477,14 @@ def __init__( optimizer_prompt_symbol_set: OptimizerPromptSymbolSet = OptimizerPromptSymbolSet(), use_json_object_format=True, # whether to use json object format for the response when calling LLM truncate_expression=truncate_expression, + problem_context: Optional[str] = None, **kwargs, ): super().__init__(parameters, *args, propagator=propagator, **kwargs) self.truncate_expression = truncate_expression + self.problem_context = problem_context + self.multimodal_payload = MultiModalPayload() self.use_json_object_format = use_json_object_format if optimizer_prompt_symbol_set.expect_json and use_json_object_format else False self.ignore_extraction_error = ignore_extraction_error @@ -499,7 +507,6 @@ def __init__( ) self.example_problem_summary.variables = {'a': (5, "a > 0")} self.example_problem_summary.inputs = {'b': (1, None), 'c': (5, None)} - self.example_problem_summary.context = "" self.example_problem = self.problem_instance(self.example_problem_summary) self.example_response = self.optimizer_prompt_symbol_set.example_output( @@ -520,6 +527,23 @@ def __init__( self.prompt_symbols = copy.deepcopy(self.default_prompt_symbols) self.initialize_prompt() + def add_image_context(self, image_path: str, context: str = ""): + if self.problem_context is None: + self.problem_context = "" + self.problem_context += f"{context}\n\n" + + # we load in the image and convert to base64 + data_url = encode_image_to_base64(image_path) + self.multimodal_payload.image_bytes = data_url + + self.initialize_prompt() + + def add_context(self, context: str): + if self.problem_context is None: + self.problem_context = "" + self.problem_context += f"{context}\n\n" + self.initialize_prompt() + def initialize_prompt(self): self.representation_prompt = self.representation_prompt.format( variable_expression_format=dedent(f""" @@ -540,7 +564,8 @@ def initialize_prompt(self): instruction_section_title=self.optimizer_prompt_symbol_set.instruction_section_title.replace(" ", ""), code_section_title=self.optimizer_prompt_symbol_set.code_section_title.replace(" ", ""), documentation_section_title=self.optimizer_prompt_symbol_set.documentation_section_title.replace(" ", ""), - others_section_title=self.optimizer_prompt_symbol_set.others_section_title.replace(" ", "") + others_section_title=self.optimizer_prompt_symbol_set.others_section_title.replace(" ", ""), + context_section_title=self.optimizer_prompt_symbol_set.context_section_title.replace(" ", "") ) self.output_format_prompt = self.output_format_prompt_template.format( output_format=self.optimizer_prompt_symbol_set.output_format, @@ -553,7 +578,8 @@ def initialize_prompt(self): documentation_section_title=self.optimizer_prompt_symbol_set.documentation_section_title.replace(" ", ""), variables_section_title=self.optimizer_prompt_symbol_set.variables_section_title.replace(" ", ""), inputs_section_title=self.optimizer_prompt_symbol_set.inputs_section_title.replace(" ", ""), - others_section_title=self.optimizer_prompt_symbol_set.others_section_title.replace(" ", "") + others_section_title=self.optimizer_prompt_symbol_set.others_section_title.replace(" ", ""), + context_section_title=self.optimizer_prompt_symbol_set.context_section_title.replace(" ", "") ) def repr_node_value(self, node_dict, node_tag="node", @@ -680,7 +706,7 @@ def problem_instance(self, summary, mask=None): constraint_tag=self.optimizer_prompt_symbol_set.constraint_tag) if self.optimizer_prompt_symbol_set.others_section_title not in mask else "" ), feedback=summary.user_feedback if self.optimizer_prompt_symbol_set.feedback_section_title not in mask else "", - context=summary.context if self.optimizer_prompt_symbol_set.context_section_title not in mask else "", + context=self.problem_context if self.optimizer_prompt_symbol_set.context_section_title not in mask else "", optimizer_prompt_symbol_set=self.optimizer_prompt_symbol_set ) @@ -742,9 +768,20 @@ def call_llm( if verbose not in (False, "output"): print("Prompt\n", system_prompt + user_prompt) + user_message_content = [] + if self.multimodal_payload.image_bytes is not None: + user_message_content.append({ + "type": "image_url", + "image_url": { + "url": self.multimodal_payload.image_bytes + } + }) + + user_message_content.append({"type": "text", "text": user_prompt}) + messages = [ {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, + {"role": "user", "content": user_message_content}, ] response_format = {"type": "json_object"} if self.use_json_object_format else None diff --git a/opto/optimizers/utils.py b/opto/optimizers/utils.py index 13a5ad01..4fbec459 100644 --- a/opto/optimizers/utils.py +++ b/opto/optimizers/utils.py @@ -1,5 +1,8 @@ +import base64 +import mimetypes from typing import Dict, Any + def print_color(message, color=None, logger=None): colors = { "red": "\033[91m", @@ -134,3 +137,17 @@ def extract_xml_like_data(text: str, reasoning_tag: str = "reasoning", if var_name: # Only require name to be non-empty, value can be empty result['variables'][var_name] = var_value return result + + +def encode_image_to_base64(path: str) -> str: + # Read binary + with open(path, "rb") as f: + image_bytes = f.read() + # Guess MIME type from file extension + mime_type, _ = mimetypes.guess_type(path) + if mime_type is None: + # fallback + mime_type = "image/jpeg" + b64 = base64.b64encode(image_bytes).decode("utf-8") + data_url = f"data:{mime_type};base64,{b64}" + return data_url From 7dcb880644e909c5f4a7fcf9c210e95eadf6bc19 Mon Sep 17 00:00:00 2001 From: windweller Date: Fri, 3 Oct 2025 17:20:09 -0400 Subject: [PATCH 04/14] Finish updating OPRO to accept additional context --- opto/optimizers/opro_v2.py | 43 ++++++++++++++++++++++++++++----- opto/optimizers/optoprime_v2.py | 4 +-- 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/opto/optimizers/opro_v2.py b/opto/optimizers/opro_v2.py index ff5c801d..19b33e58 100644 --- a/opto/optimizers/opro_v2.py +++ b/opto/optimizers/opro_v2.py @@ -1,7 +1,7 @@ import json from textwrap import dedent from dataclasses import dataclass, asdict -from typing import Dict +from typing import Dict, Optional from opto.optimizers.optoprime_v2 import OptoPrimeV2, OptimizerPromptSymbolSet @@ -15,8 +15,8 @@ class OPROPromptSymbolSet(OptimizerPromptSymbolSet): Attributes ---------- - problem_context_section_title : str - Title for the problem context section in prompts. + instruction_section_title : str + Title for the instruction section in prompts. variable_section_title : str Title for the variable/solution section in prompts. feedback_section_title : str @@ -49,9 +49,10 @@ class OPROPromptSymbolSet(OptimizerPromptSymbolSet): more focused set of symbols specifically for OPRO optimization. """ - problem_context_section_title = "# Problem Context" + instruction_section_title = "# Instruction" variable_section_title = "# Solution" feedback_section_title = "# Feedback" + context_section_title = "# Context" node_tag = "node" # nodes that are constants in the graph variable_tag = "solution" # nodes that can be changed @@ -72,6 +73,7 @@ def default_prompt_symbols(self) -> Dict[str, str]: "variables": self.variables_section_title, "feedback": self.feedback_section_title, "instruction": self.instruction_section_title, + "context": self.context_section_title } @dataclass @@ -89,6 +91,9 @@ class ProblemInstance: The current proposed solution that can be modified. feedback : str Feedback about the current solution. + context: str + Optional context information that might be useful to solve the problem. + optimizer_prompt_symbol_set : OPROPromptSymbolSet The symbol set used for formatting the problem. problem_template : str @@ -107,12 +112,13 @@ class ProblemInstance: instruction: str variables: str feedback: str + context: Optional[str] optimizer_prompt_symbol_set: OPROPromptSymbolSet problem_template = dedent( """ - # Problem Context + # Instruction {instruction} # Solution @@ -124,12 +130,24 @@ class ProblemInstance: ) def __repr__(self) -> str: - return self.problem_template.format( + optimization_query = self.problem_template.format( instruction=self.instruction, variables=self.variables, feedback=self.feedback, ) + context_section = dedent(""" + + # Context + {context} + """) + + if self.context is not None and self.context.strip() != "": + context_section.format(context=self.context) + optimization_query += context_section + + return optimization_query + class OPROv2(OptoPrimeV2): """OPRO (Optimization by PROmpting) optimizer version 2. @@ -197,6 +215,7 @@ class OPROv2(OptoPrimeV2): - {instruction_section_title}: the instruction which describes the things you need to do or the question you should answer. - {variables_section_title}: the proposed solution that you can change/tweak (trainable). - {feedback_section_title}: the feedback about the solution. + - {context_section_title}: the context information that might be useful to solve the problem. If `data_type` is `code`, it means `{value_tag}` is the source code of a python code, which may include docstring and definitions. """ @@ -229,6 +248,14 @@ class OPROv2(OptoPrimeV2): """ ) + context_prompt = dedent( + """ + Here is some additional **context** to solving this problem: + + {context} + """ + ) + final_prompt = dedent( """ What are your revised solutions on {names}? @@ -244,6 +271,7 @@ def __init__(self, *args, optimizer_prompt_symbol_set: OptimizerPromptSymbolSet = None, include_example=False, # default example in OptoPrimeV2 does not work in OPRO memory_size=5, + problem_context: Optional[str] = None, **kwargs): """Initialize the OPROv2 optimizer. @@ -264,6 +292,7 @@ def __init__(self, *args, optimizer_prompt_symbol_set = optimizer_prompt_symbol_set or OPROPromptSymbolSet() super().__init__(*args, optimizer_prompt_symbol_set=optimizer_prompt_symbol_set, include_example=include_example, memory_size=memory_size, + problem_context=problem_context, **kwargs) def problem_instance(self, summary, mask=None): @@ -328,6 +357,7 @@ def initialize_prompt(self): variables_section_title=self.optimizer_prompt_symbol_set.variables_section_title.replace(" ", ""), feedback_section_title=self.optimizer_prompt_symbol_set.feedback_section_title.replace(" ", ""), instruction_section_title=self.optimizer_prompt_symbol_set.instruction_section_title.replace(" ", ""), + context_section_title=self.optimizer_prompt_symbol_set.context_section_title.replace(" ", "") ) self.output_format_prompt = self.output_format_prompt_template.format( output_format=self.optimizer_prompt_symbol_set.output_format, @@ -336,4 +366,5 @@ def initialize_prompt(self): instruction_section_title=self.optimizer_prompt_symbol_set.instruction_section_title.replace(" ", ""), feedback_section_title=self.optimizer_prompt_symbol_set.feedback_section_title.replace(" ", ""), variables_section_title=self.optimizer_prompt_symbol_set.variables_section_title.replace(" ", ""), + context_section_title=self.optimizer_prompt_symbol_set.context_section_title.replace(" ", "") ) diff --git a/opto/optimizers/optoprime_v2.py b/opto/optimizers/optoprime_v2.py index 87a57446..2486710b 100644 --- a/opto/optimizers/optoprime_v2.py +++ b/opto/optimizers/optoprime_v2.py @@ -156,7 +156,7 @@ class OptimizerPromptSymbolSetJSON(OptimizerPromptSymbolSet): expect_json = True - custom_output_format_instruction = """ + custom_output_format_instruction = dedent(""" {{ "reasoning": , "suggestion": {{ @@ -164,7 +164,7 @@ class OptimizerPromptSymbolSetJSON(OptimizerPromptSymbolSet): : , }} }} - """ + """) def example_output(self, reasoning, variables): """ From 8a31e8b3d487ac9e563fd13413c571e2020599d7 Mon Sep 17 00:00:00 2001 From: windweller Date: Sun, 5 Oct 2025 13:50:01 -0400 Subject: [PATCH 05/14] add context prompt into pickle save/load. Modify `test_priority_search`'s mock test to expect different kind of input --- opto/optimizers/optoprime_v2.py | 2 ++ tests/unit_tests/test_priority_search.py | 9 +++++++++ 2 files changed, 11 insertions(+) diff --git a/opto/optimizers/optoprime_v2.py b/opto/optimizers/optoprime_v2.py index 2486710b..2eb0a862 100644 --- a/opto/optimizers/optoprime_v2.py +++ b/opto/optimizers/optoprime_v2.py @@ -811,6 +811,7 @@ def save(self, path: str): "prompt_symbols": self.prompt_symbols, "representation_prompt": self.representation_prompt, "output_format_prompt": self.output_format_prompt, + 'context_prompt': self.context_prompt }, f) def load(self, path: str): @@ -830,3 +831,4 @@ def load(self, path: str): self.prompt_symbols = state["prompt_symbols"] self.representation_prompt = state["representation_prompt"] self.output_format_prompt = state["output_format_prompt"] + self.context_prompt = state["context_prompt"] diff --git a/tests/unit_tests/test_priority_search.py b/tests/unit_tests/test_priority_search.py index 2ebda047..4a698fba 100644 --- a/tests/unit_tests/test_priority_search.py +++ b/tests/unit_tests/test_priority_search.py @@ -121,6 +121,15 @@ def _llm_callable(messages, **kwargs): A dummy LLM callable that simulates a response. """ problem = messages[1]['content'] + # in newer LLM API (LiteLLM, OpenAI client, etc.), the user message content is now a list of typed messages: + # [{'type': 'text', 'text': '...'}, {'type': 'image', 'image_url': '...'}] + # this expansion is necessary for multi-modal inputs + + if type(problem) is list: + for typed_message in problem: + if typed_message['type'] == 'text': + problem = typed_message['text'] + break # extract name from name = re.findall(r"", problem) From c271d9c468e5e5ac1f67d3196fcbb75d1ecc31c4 Mon Sep 17 00:00:00 2001 From: windweller Date: Sun, 5 Oct 2025 14:58:49 -0400 Subject: [PATCH 06/14] comment out the small-LLM test --- .github/workflows/ci.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7889b69d..1fdcd036 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -67,6 +67,6 @@ jobs: run: pytest tests/unit_tests/ # 9) Run basic tests for each optimizer (some will fail due to the small LLM model chosen for free GitHub CI) - - name: Run optimizers test suite - run: pytest tests/llm_optimizers_tests/test_optimizer.py || true - continue-on-error: true +# - name: Run optimizers test suite +# run: pytest tests/llm_optimizers_tests/test_optimizer.py || true +# continue-on-error: true From 148d4bc7e6ac2560f414bc5687e2bf88d6e13d69 Mon Sep 17 00:00:00 2001 From: windweller Date: Mon, 6 Oct 2025 13:33:53 -0400 Subject: [PATCH 07/14] add multi-modal support for the LLM model as well --- opto/features/flows/compose.py | 17 ++++++--- opto/features/flows/types.py | 68 +++++++++++++++++++++++++++++++++- 2 files changed, 78 insertions(+), 7 deletions(-) diff --git a/opto/features/flows/compose.py b/opto/features/flows/compose.py index f075ac4b..0d05801a 100644 --- a/opto/features/flows/compose.py +++ b/opto/features/flows/compose.py @@ -1,6 +1,7 @@ import opto.trace as trace from typing import Union, get_type_hints, Any, Dict, List, Optional from opto.utils.llm import AbstractModel, LLM +from opto.features.flows.types import MultiModalPayload, QueryModel import contextvars """ @@ -178,7 +179,8 @@ def __init__(self, self.model_name = model_name if model_name else f"TracedLLM{len(current_llm_sessions)}" current_llm_sessions.append(1) # just a marker - def forward(self, user_query: str, chat_history_on: Optional[bool] = None) -> str: + def forward(self, user_query: str, chat_history_on: Optional[bool] = None, + payload: Optional[MultiModalPayload] = None) -> str: """This function takes user_query as input, and returns the response from the LLM, with the system prompt prepended. This method will always save chat history. @@ -187,17 +189,19 @@ def forward(self, user_query: str, chat_history_on: Optional[bool] = None) -> st If chat_history_on is True, the chat history will be included in the LLM input. Args: - user_query: The user query to send to the LLM + user_query: The user query to send to the LLM. Can be Returns: str: For direct pattern """ chat_history_on = self.chat_history_on if chat_history_on is None else chat_history_on + user_message = QueryModel(query=user_query, multimodal_payload=payload).query + messages = [{"role": "system", "content": self.system_prompt.data}] if chat_history_on: messages.extend(self.chat_history.get_messages()) - messages.append({"role": "user", "content": user_query}) + messages.append({"role": "user", "content": user_message}) response = self.llm(messages=messages) @@ -226,5 +230,8 @@ def call_llm(*args) -> str: return response_node - def chat(self, user_query: str) -> str: - return self.forward(user_query) + def chat(self, user_query: str, chat_history_on: Optional[bool] = None, + payload: Optional[MultiModalPayload] = None) -> str: + """Note that chat/forward always assumes it's a single turn of the conversation. History/context management will be accomplished + through other APIs""" + return self.forward(user_query, chat_history_on, payload) diff --git a/opto/features/flows/types.py b/opto/features/flows/types.py index 4196b926..e5589bed 100644 --- a/opto/features/flows/types.py +++ b/opto/features/flows/types.py @@ -1,10 +1,74 @@ """Types for opto flows.""" -from pydantic import BaseModel, Field, create_model, ConfigDict +from typing import List, Dict, Union +from pydantic import BaseModel, model_validator from typing import Any, Optional, Callable, Dict, Union, Type, List +from dataclasses import dataclass import re import json +from opto.optimizers.utils import encode_image_to_base64 + class TraceObject: def __str__(self): # Any subclass that inherits this will be friendly to the optimizer - raise NotImplementedError("Subclasses must implement __str__") \ No newline at end of file + raise NotImplementedError("Subclasses must implement __str__") + + +class MultiModalPayload(BaseModel): + image_bytes: Optional[str] = None # base64-encoded data URL + + @classmethod + def from_path(cls, path: str) -> "MultiModalPayload": + """Create a payload by loading an image from a local file path.""" + data_url = encode_image_to_base64(path) + return cls(image_bytes=data_url) + + def load_image(self, path: str) -> None: + """Mutate the current payload to include a new image.""" + self.image_bytes = encode_image_to_base64(path) + +class QueryModel(BaseModel): + # Expose "query" as already-normalized: always a List[Dict[str, Any]] + query: List[Dict[str, Any]] + multimodal_payload: Optional[MultiModalPayload] = None + + @model_validator(mode="before") + @classmethod + def normalize(cls, data: Any): + """ + Accepts: + { "query": "hello" } + { "query": "hello", "multimodal_payload": {"image_bytes": "..."} } + And always produces: + { "query": [ {text block}, maybe {image_url block} ], "multimodal_payload": ...} + """ + if not isinstance(data, dict): + raise TypeError("QueryModel input must be a dict") + + raw_query: str = data.get("query") + + # 1) Start with the text part + if isinstance(raw_query, str): + out: List[Dict[str, Any]] = [{"type": "text", "text": raw_query}] + else: + raise TypeError("`query` must be a string or a list of dicts") + + # 2) If we have an image, append an image block + payload = data.get("multimodal_payload") + image_bytes: Optional[str] = None + if payload is not None: + if isinstance(payload, dict): + image_bytes = payload.get("image_bytes") + else: + # Could be already-parsed MultiModalPayload + image_bytes = getattr(payload, "image_bytes", None) + + if image_bytes: + out = out + [{ + "type": "image_url", + "image_url": {"url": image_bytes} + }] + + # 3) Write back normalized fields + data["query"] = out + return data From b340d91c5940fed4935e5383dbe5b15572cf64ad Mon Sep 17 00:00:00 2001 From: windweller Date: Mon, 6 Oct 2025 14:32:32 -0400 Subject: [PATCH 08/14] update the image-context prompt on optimizer. Make LLM module better. --- opto/features/flows/compose.py | 15 ++++++++------- opto/optimizers/optoprime_v2.py | 4 ++++ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/opto/features/flows/compose.py b/opto/features/flows/compose.py index 0d05801a..93ac91e1 100644 --- a/opto/features/flows/compose.py +++ b/opto/features/flows/compose.py @@ -158,6 +158,7 @@ def __init__(self, system_prompt: The system prompt to use for LLM calls. If None and the class has a docstring, the docstring will be used. llm: The LLM model to use for inference chat_history_on: if on, maintain chat history for multi-turn conversations + model_name: override the default name of the model """ if system_prompt is None: system_prompt = "You are a helpful assistant." @@ -176,11 +177,12 @@ def __init__(self, self.chat_history_on = chat_history_on current_llm_sessions = USED_TracedLLM.get() - self.model_name = model_name if model_name else f"TracedLLM{len(current_llm_sessions)}" + self.model_name = model_name if model_name else f"{self.__class__.__name__}{len(current_llm_sessions)}" current_llm_sessions.append(1) # just a marker - def forward(self, user_query: str, chat_history_on: Optional[bool] = None, - payload: Optional[MultiModalPayload] = None) -> str: + def forward(self, user_query: str, + payload: Optional[MultiModalPayload] = None, + chat_history_on: Optional[bool] = None) -> str: """This function takes user_query as input, and returns the response from the LLM, with the system prompt prepended. This method will always save chat history. @@ -205,7 +207,7 @@ def forward(self, user_query: str, chat_history_on: Optional[bool] = None, response = self.llm(messages=messages) - @trace.bundle(output_name="TracedLLM_response") + @trace.bundle(output_name=f"{self.model_name}_response") def call_llm(*args) -> str: """Call the LLM model. Args: @@ -230,8 +232,7 @@ def call_llm(*args) -> str: return response_node - def chat(self, user_query: str, chat_history_on: Optional[bool] = None, - payload: Optional[MultiModalPayload] = None) -> str: + def chat(self, user_query: str, payload: Optional[MultiModalPayload] = None, chat_history_on: Optional[bool] = None) -> str: """Note that chat/forward always assumes it's a single turn of the conversation. History/context management will be accomplished through other APIs""" - return self.forward(user_query, chat_history_on, payload) + return self.forward(user_query, payload, chat_history_on) diff --git a/opto/optimizers/optoprime_v2.py b/opto/optimizers/optoprime_v2.py index 2eb0a862..376f0a26 100644 --- a/opto/optimizers/optoprime_v2.py +++ b/opto/optimizers/optoprime_v2.py @@ -530,6 +530,10 @@ def __init__( def add_image_context(self, image_path: str, context: str = ""): if self.problem_context is None: self.problem_context = "" + + if context == "": + context = "The attached image is given to the workflow. You should use the image to help you understand the problem and provide better suggestions. You can refer to the image when providing your suggestions." + self.problem_context += f"{context}\n\n" # we load in the image and convert to base64 From acae873b6aa9f245a5d3ffc589ab60b6f7ab41ef Mon Sep 17 00:00:00 2001 From: windweller Date: Mon, 6 Oct 2025 14:47:14 -0400 Subject: [PATCH 09/14] fix a bug on QueryModel not handling Node as input --- opto/features/flows/compose.py | 5 ++--- opto/features/flows/types.py | 9 ++++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/opto/features/flows/compose.py b/opto/features/flows/compose.py index 93ac91e1..3059a9fa 100644 --- a/opto/features/flows/compose.py +++ b/opto/features/flows/compose.py @@ -208,11 +208,10 @@ def forward(self, user_query: str, response = self.llm(messages=messages) @trace.bundle(output_name=f"{self.model_name}_response") - def call_llm(*args) -> str: + def call_llm(*messages) -> str: """Call the LLM model. Args: - All the conversation history so far, starting from system prompt, to alternating user/assistant messages, ending with the current user query. - + messages: All the conversation history so far, starting from system prompt, to alternating user/assistant messages, ending with the current user query. Returns: response from the LLM """ diff --git a/opto/features/flows/types.py b/opto/features/flows/types.py index e5589bed..33712944 100644 --- a/opto/features/flows/types.py +++ b/opto/features/flows/types.py @@ -6,7 +6,7 @@ import re import json from opto.optimizers.utils import encode_image_to_base64 - +from opto import trace class TraceObject: def __str__(self): @@ -45,13 +45,16 @@ def normalize(cls, data: Any): if not isinstance(data, dict): raise TypeError("QueryModel input must be a dict") - raw_query: str = data.get("query") + raw_query: Any = data.get("query") + if isinstance(raw_query, trace.Node): + assert isinstance(raw_query.data, (str, list)), "If using trace.Node, its data must be str" + raw_query = raw_query.data # 1) Start with the text part if isinstance(raw_query, str): out: List[Dict[str, Any]] = [{"type": "text", "text": raw_query}] else: - raise TypeError("`query` must be a string or a list of dicts") + raise TypeError("`query` must be a string") # 2) If we have an image, append an image block payload = data.get("multimodal_payload") From 2c27415f646093b5178575bdd702865af2eeee78 Mon Sep 17 00:00:00 2001 From: windweller Date: Mon, 13 Oct 2025 13:55:25 -0400 Subject: [PATCH 10/14] initial commit --- .../experimental_optimizers/__init__.py | 0 .../experimental_optimizers/agentic_opt.py | 46 +++++++++++++++++++ 2 files changed, 46 insertions(+) create mode 100644 opto/features/experimental_optimizers/__init__.py create mode 100644 opto/features/experimental_optimizers/agentic_opt.py diff --git a/opto/features/experimental_optimizers/__init__.py b/opto/features/experimental_optimizers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/opto/features/experimental_optimizers/agentic_opt.py b/opto/features/experimental_optimizers/agentic_opt.py new file mode 100644 index 00000000..61df2585 --- /dev/null +++ b/opto/features/experimental_optimizers/agentic_opt.py @@ -0,0 +1,46 @@ +""" +An Agentic Optimizer that has access to tools and other resources to perform complex optimization tasks. +Particularly useful for formal theorem proving, code optimization (including kernel code), and algorithm designs. +""" + +import json +from textwrap import dedent +from dataclasses import dataclass, asdict +from typing import Dict + +from opto.trace.nodes import ParameterNode, Node, MessageNode +from opto.trace.propagators import TraceGraph, GraphPropagator +from opto.trace.propagators.propagators import Propagator + +from opto.optimizers.optoprime_v2 import OptoPrimeV2, OptimizerPromptSymbolSet +from opto.optimizers.optimizer import Optimizer + +from typing import Any, List, Dict, Union, Tuple, Optional + +""" +A few design that it must have: +1. multi-turn conversation by default +2. can take in tools +""" + +class AgenticOptimizer(Optimizer): + def __init__( + self, + parameters: List[ParameterNode], + llm: AbstractModel = None, + *args, + propagator: Propagator = None, + objective: Union[None, str] = None, + ignore_extraction_error: bool = True, + # ignore the type conversion error when extracting updated values from LLM's suggestion + include_example=False, + memory_size=0, # Memory size to store the past feedback + max_tokens=4096, + log=True, + initial_var_char_limit=100, + optimizer_prompt_symbol_set: OptimizerPromptSymbolSet = None, + use_json_object_format=True, # whether to use json object format for the response when calling LLM + truncate_expression=truncate_expression, + **kwargs, + ): + pass From 961e43faa4d95ab4dea090cd2bc02b7ef1af468d Mon Sep 17 00:00:00 2001 From: windweller Date: Mon, 13 Oct 2025 15:04:35 -0400 Subject: [PATCH 11/14] add structured input/output for LLM but in a more Trace-like way (using functions and decorators) --- .../experimental_optimizers/agentic_opt.py | 1 + opto/features/flows/compose.py | 75 +++++++- opto/features/flows/types.py | 177 ++++++++++++++++++ 3 files changed, 251 insertions(+), 2 deletions(-) diff --git a/opto/features/experimental_optimizers/agentic_opt.py b/opto/features/experimental_optimizers/agentic_opt.py index 7c3ebaa8..106e2abb 100644 --- a/opto/features/experimental_optimizers/agentic_opt.py +++ b/opto/features/experimental_optimizers/agentic_opt.py @@ -30,6 +30,7 @@ Idea: write it like you would use for VeriBench 1. bug fix loop 2. external reward loop + """ diff --git a/opto/features/flows/compose.py b/opto/features/flows/compose.py index 3059a9fa..4e0be1ae 100644 --- a/opto/features/flows/compose.py +++ b/opto/features/flows/compose.py @@ -1,9 +1,10 @@ import opto.trace as trace -from typing import Union, get_type_hints, Any, Dict, List, Optional +from typing import Union, get_type_hints, Any, Dict, List, Optional, Callable from opto.utils.llm import AbstractModel, LLM -from opto.features.flows.types import MultiModalPayload, QueryModel +from opto.features.flows.types import MultiModalPayload, QueryModel, StructuredInput, StructuredOutput import contextvars +# =========== LLM Base Model =========== """ TracedLLM: 1. special operations that supports specifying inputs (system_prompt, user_prompt) to LLM and parsing of outputs, wrap @@ -235,3 +236,73 @@ def chat(self, user_query: str, payload: Optional[MultiModalPayload] = None, cha """Note that chat/forward always assumes it's a single turn of the conversation. History/context management will be accomplished through other APIs""" return self.forward(user_query, payload, chat_history_on) + +# =========== =========== + +# =========== Structured LLM Input/Output With Parsing =========== + +""" +Usage: + +@llm_call +def evaluate_person(person: Person) -> Preference: + "Evaluate if a person matches our criteria" + ... + +person = Person(name="Alice", age=30, income=75000) +preference = evaluate_person(person) + +TODO: add LLM call and parsing logic +TODO 2: add trace bundle and input/output conversion +""" + +def llm_call(func: Callable): + """ + Decorator that extracts input/output schemas from type-annotated functions. + + Usage: + @call_llm + def process_person(person: Person) -> Preference: + ... + + # Access schemas + process_person.input_type + process_person.output_type + process_person.input_schema + process_person.output_schema + """ + hints = get_type_hints(func) + + # Get first parameter type and return type + params = list(hints.items()) + input_type = None + output_type = None + + # Find first non-return parameter + for param_name, param_type in params: + if param_name != 'return': + input_type = param_type + break + + output_type = hints.get('return') + + # Validate types + if input_type and not issubclass(input_type, StructuredInput): + raise TypeError(f"Input type {input_type} must inherit from StructuredInput") + + if output_type and not issubclass(output_type, StructuredOutput): + raise TypeError(f"Output type {output_type} must inherit from StructuredOutput") + + # Attach metadata to function + func.input_type = input_type + func.output_type = output_type + func.input_schema = input_type.model_json_schema() if input_type else None + func.output_schema = output_type.model_json_schema() if output_type else None + + # Additional helper methods + func.get_input_docstring = lambda: input_type.get_docstring() if input_type else None + func.get_output_docstring = lambda: output_type.get_docstring() if output_type else None + func.get_input_fields = lambda: input_type.get_fields_info() if input_type else {} + func.get_output_fields = lambda: output_type.get_fields_info() if output_type else {} + + return func \ No newline at end of file diff --git a/opto/features/flows/types.py b/opto/features/flows/types.py index 33712944..79d46f91 100644 --- a/opto/features/flows/types.py +++ b/opto/features/flows/types.py @@ -8,12 +8,14 @@ from opto.optimizers.utils import encode_image_to_base64 from opto import trace + class TraceObject: def __str__(self): # Any subclass that inherits this will be friendly to the optimizer raise NotImplementedError("Subclasses must implement __str__") +# ====== Multi-Modal LLM Support ====== class MultiModalPayload(BaseModel): image_bytes: Optional[str] = None # base64-encoded data URL @@ -27,6 +29,7 @@ def load_image(self, path: str) -> None: """Mutate the current payload to include a new image.""" self.image_bytes = encode_image_to_base64(path) + class QueryModel(BaseModel): # Expose "query" as already-normalized: always a List[Dict[str, Any]] query: List[Dict[str, Any]] @@ -75,3 +78,177 @@ def normalize(cls, data: Any): # 3) Write back normalized fields data["query"] = out return data + + +# ====== ====== + +# ======= Structured LLM Info Support ====== + +from typing import Any, Dict, Optional, Type, get_type_hints +from pydantic import BaseModel, create_model, Field +import inspect + + +class StructuredData(BaseModel): + """ + Base class for structured data (inputs/outputs) with support for both + inheritance and dynamic on-the-fly usage. + """ + + _docstring: Optional[str] = None + + def __init_subclass__(cls, **kwargs): + """Called when a class inherits from StructuredData""" + super().__init_subclass__(**kwargs) + cls._docstring = inspect.getdoc(cls) + + def __new__(cls, docstring: Optional[str] = None, **kwargs): + """ + Handle both inheritance and on-the-fly usage. + """ + # Check if being used dynamically (direct instantiation with docstring) + if cls in (StructuredData, StructuredInput, StructuredOutput) and \ + docstring is not None and isinstance(docstring, str): + # Determine the appropriate class name for dynamic instances + if cls is StructuredInput or (cls is StructuredData and 'Input' in str(cls)): + dynamic_name = 'DynamicStructuredInput' + elif cls is StructuredOutput: + dynamic_name = 'DynamicStructuredOutput' + else: + dynamic_name = 'DynamicStructuredData' + + dynamic_cls = type( + dynamic_name, + (cls,), + { + '__doc__': docstring, + '_docstring': docstring, + '__module__': cls.__module__, + } + ) + instance = super(StructuredData, dynamic_cls).__new__(dynamic_cls) + return instance + + return super().__new__(cls) + + def __init__(self, docstring: Optional[str] = None, **kwargs): + """Initialize the instance""" + dynamic_names = ('DynamicStructuredData', 'DynamicStructuredInput', 'DynamicStructuredOutput') + if isinstance(docstring, str) and self.__class__.__name__ in dynamic_names: + super().__init__(**kwargs) + object.__setattr__(self, '_docstring', docstring) + else: + super().__init__(**kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + """Allow dynamic attribute setting for on-the-fly usage""" + if name.startswith('_'): + super().__setattr__(name, value) + else: + dynamic_names = ('DynamicStructuredData', 'DynamicStructuredInput', 'DynamicStructuredOutput') + if self.__class__.__name__ in dynamic_names: + object.__setattr__(self, name, value) + else: + super().__setattr__(name, value) + + @classmethod + def get_docstring(cls) -> Optional[str]: + """Get the docstring of the class""" + return getattr(cls, '_docstring', None) or inspect.getdoc(cls) + + @classmethod + def get_fields_info(cls) -> Dict[str, Any]: + """Get information about all fields""" + base_names = ('StructuredData', 'StructuredInput', 'StructuredOutput') + dynamic_names = ('DynamicStructuredData', 'DynamicStructuredInput', 'DynamicStructuredOutput') + + if cls.__name__ in dynamic_names or cls.__name__ in base_names: + return {} + + fields_info = {} + for field_name, field in cls.model_fields.items(): + fields_info[field_name] = { + 'type': field.annotation, + 'required': field.is_required(), + 'default': field.default if field.default is not None else None, + } + return fields_info + + def get_instance_fields(self) -> Dict[str, Any]: + """Get all fields and their values from the instance""" + dynamic_names = ('DynamicStructuredData', 'DynamicStructuredInput', 'DynamicStructuredOutput') + if self.__class__.__name__ in dynamic_names: + return { + k: v for k, v in self.__dict__.items() + if not k.startswith('_') + } + else: + return self.model_dump() + + def __str__(self, template: Optional[str] = None) -> str: + """ + Convert the structured data to a string format for LLM consumption. + Subclasses can override for specific formatting. + """ + docstring = self.get_docstring() or "No description provided" + fields = self.get_instance_fields() + + if template: + return template.format(docstring=docstring, fields=fields) + + lines = [f"Description: {docstring}", "", "Fields:"] + for field_name, field_value in fields.items(): + lines.append(f"- {field_name}: {field_value}") + + return "\n".join(lines) + + +class StructuredInput(StructuredData): + """ + Base class for structured inputs that can be used via inheritance or dynamically. + """ + + def __str__(self, template: Optional[str] = None) -> str: + """ + Convert the structured input to a string format emphasizing input data. + """ + docstring = self.get_docstring() or "No description provided" + fields = self.get_instance_fields() + + if template: + return template.format(docstring=docstring, fields=fields) + + lines = [f"Input: {docstring}", "", "Provided data:"] + for field_name, field_value in fields.items(): + lines.append(f"- {field_name}: {field_value}") + + return "\n".join(lines) + + +class StructuredOutput(StructuredData): + """ + Base class for structured outputs from LLM functions. + """ + + def __str__(self, template: Optional[str] = None) -> str: + """ + Convert the structured output to a string format emphasizing results. + """ + docstring = self.get_docstring() or "No description provided" + fields = self.get_instance_fields() + + if template: + return template.format(docstring=docstring, fields=fields) + + lines = [f"Output: {docstring}", "", "Results:"] + for field_name, field_value in fields.items(): + lines.append(f"- {field_name}: {field_value}") + + return "\n".join(lines) + + +# ======= ====== + +# ======= Agentic Optimizer Support ======= + +# ======= ======= From bcd7fc46f42957e28401d82e255dcd8556a1a61a Mon Sep 17 00:00:00 2001 From: windweller Date: Mon, 13 Oct 2025 15:56:13 -0400 Subject: [PATCH 12/14] modify decorator so that the LLM can be passed in. Not finished -- make input output traced nodes --- .../experimental_optimizers/agentic_opt.py | 2 + opto/features/flows/compose.py | 185 +++++++++++++++--- 2 files changed, 155 insertions(+), 32 deletions(-) diff --git a/opto/features/experimental_optimizers/agentic_opt.py b/opto/features/experimental_optimizers/agentic_opt.py index 106e2abb..831572ba 100644 --- a/opto/features/experimental_optimizers/agentic_opt.py +++ b/opto/features/experimental_optimizers/agentic_opt.py @@ -31,6 +31,8 @@ 1. bug fix loop 2. external reward loop +initial task prompt -> initial solution +initial solution -> improvement prompt -> improved solution -> improvement prompt -> improved solution """ diff --git a/opto/features/flows/compose.py b/opto/features/flows/compose.py index 4e0be1ae..900e44c1 100644 --- a/opto/features/flows/compose.py +++ b/opto/features/flows/compose.py @@ -252,18 +252,134 @@ def evaluate_person(person: Person) -> Preference: person = Person(name="Alice", age=30, income=75000) preference = evaluate_person(person) -TODO: add LLM call and parsing logic TODO 2: add trace bundle and input/output conversion """ -def llm_call(func: Callable): + +class StructuredLLMCallable: + """ + Wrapper class that makes a decorated function callable and automatically invokes the LLM. """ - Decorator that extracts input/output schemas from type-annotated functions. + + def __init__(self, func: Callable, llm, input_type, output_type): + self.func = func + self.llm = llm + self.input_type = input_type + self.output_type = output_type + + # Store schemas + self.input_schema = input_type.model_json_schema() if input_type else None + self.output_schema = output_type.model_json_schema() if output_type else None + + # Copy function metadata + self.__name__ = func.__name__ + self.__doc__ = func.__doc__ + self.__module__ = func.__module__ + self.__annotations__ = func.__annotations__ + + def __call__(self, input_data: StructuredInput, system_prompt: Optional[str] = None) -> StructuredOutput: + """ + Automatically invoke the LLM with the input data. + + Args: + input_data: Instance of StructuredInput + system_prompt: Optional custom system prompt. If not provided, uses default. + + Returns: + Instance of StructuredOutput + """ + # Validate input type + if not isinstance(input_data, self.input_type): + raise TypeError(f"Expected input of type {self.input_type}, got {type(input_data)}") + + # Convert input to string representation for LLM + input_str = str(input_data) + + # Get function docstring as task description + func_doc = self.func.__doc__ or "Process the input data" + + # Build system prompt + if system_prompt is None: + output_fields = list(self.get_output_fields().keys()) + system_prompt = f"""You are a helpful assistant that performs the following task: {func_doc} + +You will receive input data and must produce output in JSON format with the following fields: {output_fields} + +Output description: {self.output_type.get_docstring() or 'Structured output'} + +Always respond with valid JSON only, no additional text.""" + + # Build user message with input data + user_message = f"""{input_str} + +Please respond with a JSON object containing the required output fields.""" + + # Construct messages in the format expected by the LLM + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_message} + ] + + # Invoke LLM + response = self.llm(messages=messages) + + # Parse response into StructuredOutput + # Note: Assumes LLM returns JSON string + try: + output_instance = self.output_type.model_validate_json(response) + except Exception: + # Fallback: try parsing as dict + import json + try: + output_dict = json.loads(response) + output_instance = self.output_type(**output_dict) + except Exception as e: + raise ValueError(f"Failed to parse LLM response into {self.output_type}: {e}\nResponse: {response}") + + return output_instance + + def get_input_docstring(self) -> Optional[str]: + return self.input_type.get_docstring() if self.input_type else None + + def get_output_docstring(self) -> Optional[str]: + return self.output_type.get_docstring() if self.output_type else None + + def get_input_fields(self) -> Dict[str, Any]: + return self.input_type.get_fields_info() if self.input_type else {} + + def get_output_fields(self) -> Dict[str, Any]: + return self.output_type.get_fields_info() if self.output_type else {} + + def __repr__(self): + return f"" + + +def llm_call(func: Callable = None, *, llm=None, **kwargs): + """ + Decorator that extracts input/output schemas from type-annotated functions + and creates a callable that automatically invokes the LLM. + + Args: + func: The function to decorate (automatically passed when used without arguments) + llm: AbstractModel instance to use for LLM calls + **kwargs: Additional LLM configuration parameters Usage: - @call_llm + # Without arguments + @llm_call def process_person(person: Person) -> Preference: - ... + '''Evaluate if a person matches our criteria''' + pass + + # With arguments + @llm_call(llm=customized_llm) + def process_person(person: Person) -> Preference: + '''Evaluate if a person matches our criteria''' + pass + + # Call it directly - LLM is invoked automatically + person = Person(name="Alice", age=30, income=75000) + result = process_person(person) # Returns Preference instance # Access schemas process_person.input_type @@ -271,38 +387,43 @@ def process_person(person: Person) -> Preference: process_person.input_schema process_person.output_schema """ - hints = get_type_hints(func) - # Get first parameter type and return type - params = list(hints.items()) - input_type = None - output_type = None + def decorator(f: Callable): + hints = get_type_hints(f) - # Find first non-return parameter - for param_name, param_type in params: - if param_name != 'return': - input_type = param_type - break + # Get first parameter type and return type + params = list(hints.items()) + input_type = None + output_type = None + + # Find first non-return parameter + for param_name, param_type in params: + if param_name != 'return': + input_type = param_type + break - output_type = hints.get('return') + output_type = hints.get('return') - # Validate types - if input_type and not issubclass(input_type, StructuredInput): - raise TypeError(f"Input type {input_type} must inherit from StructuredInput") + # Validate types + if input_type and not issubclass(input_type, StructuredInput): + raise TypeError(f"Input type {input_type} must inherit from StructuredInput") - if output_type and not issubclass(output_type, StructuredOutput): - raise TypeError(f"Output type {output_type} must inherit from StructuredOutput") + if output_type and not issubclass(output_type, StructuredOutput): + raise TypeError(f"Output type {output_type} must inherit from StructuredOutput") - # Attach metadata to function - func.input_type = input_type - func.output_type = output_type - func.input_schema = input_type.model_json_schema() if input_type else None - func.output_schema = output_type.model_json_schema() if output_type else None + # Use default LLM if none provided + # Note: Replace this with your actual default LLM initialization + actual_llm = llm + if actual_llm is None: + actual_llm = LLM() # we use the default LLM - # Additional helper methods - func.get_input_docstring = lambda: input_type.get_docstring() if input_type else None - func.get_output_docstring = lambda: output_type.get_docstring() if output_type else None - func.get_input_fields = lambda: input_type.get_fields_info() if input_type else {} - func.get_output_fields = lambda: output_type.get_fields_info() if output_type else {} + # Create and return the callable wrapper + return StructuredLLMCallable(f, actual_llm, input_type, output_type) - return func \ No newline at end of file + # Handle both @llm_call and @llm_call(...) syntax + if func is None: + # Called with arguments: @llm_call(llm=custom_llm) + return decorator + else: + # Called without arguments: @llm_call + return decorator(func) \ No newline at end of file From 9c3722e7ecbdec10dac7dc080a74cfff0adbaaa4 Mon Sep 17 00:00:00 2001 From: windweller Date: Tue, 21 Oct 2025 15:23:17 -0400 Subject: [PATCH 13/14] re-organize the module --- .../experimental_optimizers/agentic_opt.py | 187 +++++++- opto/features/flows/compose.py | 429 ------------------ opto/features/flows/compose/__init__.py | 5 + opto/features/flows/compose/agentic_ops.py | 354 +++++++++++++++ opto/features/flows/compose/llm.py | 418 +++++++++++++++++ opto/features/flows/compose/parser.py | 200 ++++++++ opto/features/flows/types.py | 9 +- 7 files changed, 1157 insertions(+), 445 deletions(-) delete mode 100644 opto/features/flows/compose.py create mode 100644 opto/features/flows/compose/__init__.py create mode 100644 opto/features/flows/compose/agentic_ops.py create mode 100644 opto/features/flows/compose/llm.py create mode 100644 opto/features/flows/compose/parser.py diff --git a/opto/features/experimental_optimizers/agentic_opt.py b/opto/features/experimental_optimizers/agentic_opt.py index 831572ba..b1decb83 100644 --- a/opto/features/experimental_optimizers/agentic_opt.py +++ b/opto/features/experimental_optimizers/agentic_opt.py @@ -16,14 +16,18 @@ from opto.optimizers.optoprime_v2 import OptoPrimeV2, OptimizerPromptSymbolSet from opto.optimizers.optimizer import Optimizer +from opto.trainer.guide import Guide from opto.utils.llm import AbstractModel, LLM -from typing import Any, List, Dict, Union, Tuple, Optional +from typing import Any, Callable, Dict, List, Tuple, Optional + +from opto.features.flows.compose import Loop, ChatHistory, StopCondition, Check """ A few design that it must have: -1. multi-turn conversation by default (memory management) -2. can take in tools (RAG in particular, but MCP servers as well) +1. **multi-turn conversation by default (memory management)** +2. Flexibility to do tool-use +3. has scaffolding functions people can call to design their custom optimizer First an abstract agent with the features, then implement? @@ -33,26 +37,179 @@ initial task prompt -> initial solution initial solution -> improvement prompt -> improved solution -> improvement prompt -> improved solution + +(inside flow) Loop (use a boolean reward function) to keep executing + +Inherits this optimizer +and specify the main forward by just calling different scaffolding functions """ +# Base class that provides scaffolding class AgenticOptimizer(Optimizer): def __init__( self, parameters: List[ParameterNode], - llm: AbstractModel = None, *args, propagator: Propagator = None, - objective: Union[None, str] = None, - ignore_extraction_error: bool = True, - # ignore the type conversion error when extracting updated values from LLM's suggestion - include_example=False, - memory_size=0, # Memory size to store the past feedback - max_tokens=4096, - log=True, - initial_var_char_limit=2000, - optimizer_prompt_symbol_set: OptimizerPromptSymbolSet = None, - use_json_object_format=True, # whether to use json object format for the response when calling LLM - **kwargs, + **kwargs ): pass + + +""" +Add a veribench optimizer here. +A code optimizer is general, can work with kernel, others + +initial_user_message -> initial_code -> bug fix here +initial_code -> improvement_prompt -> improved_code -> improvement_prompt -> improved_code + +initial_code + 1. check if there is a bug. If there is, try 10 times to fix. return the code --> make it correct/compilable + 2. take an improvement step --> make it run faster + +improve initial code, bug fix 10 times -> improve bug-fixed code, bug fix 10 times -> ... + +optimizer = (improve initial code, bug fix 10 times) +priority_search + +TODO: +1. Make sure guide can be deep-copied +2. Declare a node that's trainable, we can have an initializer. When the node is created, we call the initializer. + +add an init function to the optimizer API + +define a class of Initializer + +init function of the optimizer can receive some arguments + +standardize optimizer API: +- objective, context + +State of the worker: the worker decides what to put into the state +(each candidate has its own optimizer) +(receive a candidate) (update rule is stateless) + +Flip the order between + +Use initializer on node, and then use **projection** on the node +Optimizer only focuses on improvement + +LLM-based initializer, LLM-based projection + +============= +planning agent: plan -> coding agent: write the code + +""" + +class CodeOptimizer(Optimizer): + def __init__( + self, + parameters: List[ParameterNode], + llm: AbstractModel = None, + *args, + propagator: Propagator = None, + bug_judge: Guide = None, # compiler() + reward_function: Callable[[str], float] = None, + max_bug_fix_tries: int = 5, + max_optimization_tries: int = 10, + chat_max_len: int = 25, + **kwargs + ): + super().__init__(parameters, *args, propagator=propagator, **kwargs) + + self.llm = llm or LLM() + + assert len(parameters) == 1, "CodeOptimizer expects a single ParameterNode as input" + self.init_code = parameters[0] + + # Initialize chat history + self.chat_history = ChatHistory(max_round=chat_max_len, auto_summary=False) + + # Store environment checker and reward function + self.bug_judge = bug_judge + self.max_bug_fix_tries = max_bug_fix_tries + self.max_optimization_tries = max_optimization_tries + + self.task_description = None + self.initial_instruction = None + + # 2. Do a single improvement step + + def initial_context(self, task_description, initial_instruction): + """ + This provides the history of how the initial code was produced + """ + self.task_description = task_description + self.initial_instruction = initial_instruction + self.chat_history.add_system_message(self.task_description) + self.chat_history.add(self.initial_instruction,'user') + self.chat_history.add(self.init_code, 'assistant') + + def bug_fix_step(self, lean4_code: str, max_try: int = 5) -> str: + """ + This function is used to self-correct the Lean 4 code. + It will be called by the LLM when the code does not compile. + """ + + # apply heuristic fixes + lean4_code = self.remove_import_error(lean4_code) + valid, error_details = self.bug_judge(lean4_code) + + if valid: + return lean4_code + + temp_conv_hist = self.chat_history.copy() + + counter = 0 + while not valid and counter < max_try: + # sometimes LLM will hallucinate import error, so we remove that import statement + + print(f"Attempt {counter+1}: Fixing compilation errors") + + detailed_error_message = self.concat_error_messages(error_details) + + temp_conv_hist.add(detailed_error_message + "\n\n" + f"Lean code compilation FAILED with {len(error_details)} errors. If a theorem keeps giving error, you can use := sorry to skip it. Please wrap your lean code in ```lean and ```", + "user") + + raw_program = self.llm(temp_conv_hist.get_messages(), verbose=False) + + lean4_code = self.simple_post_process(raw_program) + lean4_code = self.remove_import_error(lean4_code) + + valid, error_details = self.bug_judge(lean4_code) + + if valid: + print(f"Successfully fixed errors after {counter} attempts") + # we add to the round + self.chat_history = self.chat_history + temp_conv_hist[-1].remove_system_message() + return lean4_code + else: + counter += 1 + temp_conv_hist.add_message(lean4_code, "assistant") + + return lean4_code + + def _step(self, verbose=False, *args, **kwargs) -> Dict[ParameterNode, Any]: + """ + Each step, we perform bug fix for a few rounds, then do one improvement step + We add everything to the chat history + """ + lean4_code = self.bug_fix_step(self.init_code.data, max_try=self.max_bug_fix_tries) + # do one step improvement + lean4_code = self.improve_step(lean4_code) + + return lean4_code + + def _extract_code(self, response: str) -> str: + """Extract code from markdown code blocks.""" + import re + # Match python code blocks + pattern = r'```python\n(.*?)```' + matches = re.findall(pattern, response, re.DOTALL) + + if matches: + return matches[0].strip() + + # If no code block found, return the response as-is + return response.strip() diff --git a/opto/features/flows/compose.py b/opto/features/flows/compose.py deleted file mode 100644 index 900e44c1..00000000 --- a/opto/features/flows/compose.py +++ /dev/null @@ -1,429 +0,0 @@ -import opto.trace as trace -from typing import Union, get_type_hints, Any, Dict, List, Optional, Callable -from opto.utils.llm import AbstractModel, LLM -from opto.features.flows.types import MultiModalPayload, QueryModel, StructuredInput, StructuredOutput -import contextvars - -# =========== LLM Base Model =========== -""" -TracedLLM: -1. special operations that supports specifying inputs (system_prompt, user_prompt) to LLM and parsing of outputs, wrap - everything under one command. -2. Easy to use interface -- can be inherited by users. -3. Support multi-turn chatting (message history) - -Usage patterns: - -Direct use: (only supports single input, single output) (signature: str -> str) -llm = TracedLLM("You are a helpful assistant.") -response = llm("Hello, what's the weather in France today?") -""" - -USED_TracedLLM = contextvars.ContextVar('USED_TracedLLM', default=list()) - - -class ChatHistory: - def __init__(self, max_len=50, auto_summary=False): - """Initialize chat history for multi-turn conversation. - - Args: - max_len: Maximum number of messages to keep in history - auto_summary: Whether to automatically summarize old messages - """ - self.messages: List[Dict[str, Any]] = [] - self.max_len = max_len - self.auto_summary = auto_summary - - def __len__(self): - return len(self.messages) - - def add(self, content: Union[trace.Node, str], role): - """Add a message to history with role validation. - - Args: - content: The content of the message - role: The role of the message ("user" or "assistant") - """ - if role not in ["user", "assistant"]: - raise ValueError(f"Invalid role '{role}'. Must be 'user' or 'assistant'.") - - # Check for alternating user/assistant pattern - if len(self.messages) > 0: - last_msg = self.messages[-1] - if last_msg["role"] == role: - print(f"Warning: Adding consecutive {role} messages. Consider alternating user/assistant messages.") - - self.messages.append({"role": role, "content": content}) - self._trim_history() - - def append(self, message: Dict[str, Any]): - """Append a message directly to history.""" - if "role" not in message or "content" not in message: - raise ValueError("Message must have 'role' and 'content' fields.") - self.add(message["content"], message["role"]) - - def __iter__(self): - return iter(self.messages) - - def get_messages(self) -> List[Dict[str, str]]: - messages = [] - for message in self.messages: - if isinstance(message['content'], trace.Node): - messages.append({"role": message["role"], "content": message["content"].data}) - else: - messages.append(message) - return messages - - def get_messages_as_node(self, llm_name="") -> List[trace.Node]: - node_list = [] - for message in self.messages: - # If user query is a node and has other computation attached, we can't rename it - if isinstance(message['content'], trace.Node): - node_list.append(message['content']) - else: - role = message["role"] - content = message["content"] - name = f"{llm_name}_{role}" if llm_name else f"{role}" - if role == 'user': - name += "_query" - elif role == 'assistant': - name += "_response" - node_list.append(trace.node(content, name=name)) - - return node_list - - def _trim_history(self): - """Trim history to max_len while preserving first user message.""" - if len(self.messages) <= self.max_len: - return - - # Find first user message index - first_user_idx = None - for i, msg in enumerate(self.messages): - if msg["role"] == "user": - first_user_idx = i - break - - # Keep first user message - protected_messages = [] - if first_user_idx is not None: - first_user_msg = self.messages[first_user_idx] - protected_messages.append(first_user_msg) - - # Calculate how many recent messages we can keep - remaining_slots = self.max_len - len(protected_messages) - if remaining_slots > 0: - # Get recent messages - recent_messages = self.messages[-remaining_slots:] - # Avoid duplicating first user message - if first_user_idx is not None: - first_user_msg = self.messages[first_user_idx] - recent_messages = [msg for msg in recent_messages if msg != first_user_msg] - - self.messages = protected_messages + recent_messages - else: - self.messages = protected_messages - - -DEFAULT_SYSTEM_PROMPT_DESCRIPTION = ("the system prompt to the agent. By tuning this prompt, we can control the " - "behavior of the agent. For example, it can be used to provide instructions to " - "the agent (such as how to reason about the problem, how to use tools, " - "how to answer the question), or provide in-context examples of how to solve the " - "problem.") - - -@trace.model -class TracedLLM: - """ - This high-level model provides an easy-to-use interface for LLM calls with system prompts and optional chat history. - - Python usage patterns: - - llm = UF_LLM(system_prompt) - response = llm.chat(user_prompt) - response_2 = llm.chat(user_prompt_2) - - The underlying Trace Graph: - TracedLLM_response0 = TracedLLM.forward.call_llm(args_0=system_prompt0, args_1=TracedLLM0_user_query0) - TracedLLM_response1 = TracedLLM.forward.call_llm(args_0=system_prompt0, args_1=TracedLLM0_user_query0, args_2=TracedLLM_response0, args_3=TracedLLM0_user_query1) - TracedLLM_response2 = TracedLLM.forward.call_llm(args_0=system_prompt0, args_1=TracedLLM0_user_query0, args_2=TracedLLM_response0, args_3=TracedLLM0_user_query1, args_4=TracedLLM_response1, args_5=TracedLLM0_user_query2) - """ - - def __init__(self, - system_prompt: Union[str, None, trace.Node] = None, - llm: AbstractModel = None, chat_history_on=False, - trainable=False, model_name=None): - """Initialize TracedLLM with a system prompt. - - Args: - system_prompt: The system prompt to use for LLM calls. If None and the class has a docstring, the docstring will be used. - llm: The LLM model to use for inference - chat_history_on: if on, maintain chat history for multi-turn conversations - model_name: override the default name of the model - """ - if system_prompt is None: - system_prompt = "You are a helpful assistant." - - self.system_prompt = trace.node(system_prompt, name='system_prompt', - description=DEFAULT_SYSTEM_PROMPT_DESCRIPTION, - trainable=trainable) - # if system_prompt is already a node, then we have to override its trainable attribute - self.system_prompt.trainable = trainable - - if llm is None: - llm = LLM() - assert isinstance(llm, AbstractModel), f"{llm} must be an instance of AbstractModel" - self.llm = llm - self.chat_history = ChatHistory() - self.chat_history_on = chat_history_on - - current_llm_sessions = USED_TracedLLM.get() - self.model_name = model_name if model_name else f"{self.__class__.__name__}{len(current_llm_sessions)}" - current_llm_sessions.append(1) # just a marker - - def forward(self, user_query: str, - payload: Optional[MultiModalPayload] = None, - chat_history_on: Optional[bool] = None) -> str: - """This function takes user_query as input, and returns the response from the LLM, with the system prompt prepended. - This method will always save chat history. - - If chat_history_on is set to False, the chat history will not be included in the LLM input. - If chat_history_on is None, it will use the class-level chat_history_on setting. - If chat_history_on is True, the chat history will be included in the LLM input. - - Args: - user_query: The user query to send to the LLM. Can be - - Returns: - str: For direct pattern - """ - chat_history_on = self.chat_history_on if chat_history_on is None else chat_history_on - - user_message = QueryModel(query=user_query, multimodal_payload=payload).query - - messages = [{"role": "system", "content": self.system_prompt.data}] - if chat_history_on: - messages.extend(self.chat_history.get_messages()) - messages.append({"role": "user", "content": user_message}) - - response = self.llm(messages=messages) - - @trace.bundle(output_name=f"{self.model_name}_response") - def call_llm(*messages) -> str: - """Call the LLM model. - Args: - messages: All the conversation history so far, starting from system prompt, to alternating user/assistant messages, ending with the current user query. - Returns: - response from the LLM - """ - return response.choices[0].message.content - - user_query_node = trace.node(user_query, name=f"{self.model_name}_user_query") - arg_list = [self.system_prompt] - if chat_history_on: - arg_list += self.chat_history.get_messages_as_node(self.model_name) - arg_list += [user_query_node] - - response_node = call_llm(*arg_list) - - # save to chat history - self.chat_history.add(user_query_node, role="user") - self.chat_history.add(response_node, role="assistant") - - return response_node - - def chat(self, user_query: str, payload: Optional[MultiModalPayload] = None, chat_history_on: Optional[bool] = None) -> str: - """Note that chat/forward always assumes it's a single turn of the conversation. History/context management will be accomplished - through other APIs""" - return self.forward(user_query, payload, chat_history_on) - -# =========== =========== - -# =========== Structured LLM Input/Output With Parsing =========== - -""" -Usage: - -@llm_call -def evaluate_person(person: Person) -> Preference: - "Evaluate if a person matches our criteria" - ... - -person = Person(name="Alice", age=30, income=75000) -preference = evaluate_person(person) - -TODO 2: add trace bundle and input/output conversion -""" - - -class StructuredLLMCallable: - """ - Wrapper class that makes a decorated function callable and automatically invokes the LLM. - """ - - def __init__(self, func: Callable, llm, input_type, output_type): - self.func = func - self.llm = llm - self.input_type = input_type - self.output_type = output_type - - # Store schemas - self.input_schema = input_type.model_json_schema() if input_type else None - self.output_schema = output_type.model_json_schema() if output_type else None - - # Copy function metadata - self.__name__ = func.__name__ - self.__doc__ = func.__doc__ - self.__module__ = func.__module__ - self.__annotations__ = func.__annotations__ - - def __call__(self, input_data: StructuredInput, system_prompt: Optional[str] = None) -> StructuredOutput: - """ - Automatically invoke the LLM with the input data. - - Args: - input_data: Instance of StructuredInput - system_prompt: Optional custom system prompt. If not provided, uses default. - - Returns: - Instance of StructuredOutput - """ - # Validate input type - if not isinstance(input_data, self.input_type): - raise TypeError(f"Expected input of type {self.input_type}, got {type(input_data)}") - - # Convert input to string representation for LLM - input_str = str(input_data) - - # Get function docstring as task description - func_doc = self.func.__doc__ or "Process the input data" - - # Build system prompt - if system_prompt is None: - output_fields = list(self.get_output_fields().keys()) - system_prompt = f"""You are a helpful assistant that performs the following task: {func_doc} - -You will receive input data and must produce output in JSON format with the following fields: {output_fields} - -Output description: {self.output_type.get_docstring() or 'Structured output'} - -Always respond with valid JSON only, no additional text.""" - - # Build user message with input data - user_message = f"""{input_str} - -Please respond with a JSON object containing the required output fields.""" - - # Construct messages in the format expected by the LLM - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_message} - ] - - # Invoke LLM - response = self.llm(messages=messages) - - # Parse response into StructuredOutput - # Note: Assumes LLM returns JSON string - try: - output_instance = self.output_type.model_validate_json(response) - except Exception: - # Fallback: try parsing as dict - import json - try: - output_dict = json.loads(response) - output_instance = self.output_type(**output_dict) - except Exception as e: - raise ValueError(f"Failed to parse LLM response into {self.output_type}: {e}\nResponse: {response}") - - return output_instance - - def get_input_docstring(self) -> Optional[str]: - return self.input_type.get_docstring() if self.input_type else None - - def get_output_docstring(self) -> Optional[str]: - return self.output_type.get_docstring() if self.output_type else None - - def get_input_fields(self) -> Dict[str, Any]: - return self.input_type.get_fields_info() if self.input_type else {} - - def get_output_fields(self) -> Dict[str, Any]: - return self.output_type.get_fields_info() if self.output_type else {} - - def __repr__(self): - return f"" - - -def llm_call(func: Callable = None, *, llm=None, **kwargs): - """ - Decorator that extracts input/output schemas from type-annotated functions - and creates a callable that automatically invokes the LLM. - - Args: - func: The function to decorate (automatically passed when used without arguments) - llm: AbstractModel instance to use for LLM calls - **kwargs: Additional LLM configuration parameters - - Usage: - # Without arguments - @llm_call - def process_person(person: Person) -> Preference: - '''Evaluate if a person matches our criteria''' - pass - - # With arguments - @llm_call(llm=customized_llm) - def process_person(person: Person) -> Preference: - '''Evaluate if a person matches our criteria''' - pass - - # Call it directly - LLM is invoked automatically - person = Person(name="Alice", age=30, income=75000) - result = process_person(person) # Returns Preference instance - - # Access schemas - process_person.input_type - process_person.output_type - process_person.input_schema - process_person.output_schema - """ - - def decorator(f: Callable): - hints = get_type_hints(f) - - # Get first parameter type and return type - params = list(hints.items()) - input_type = None - output_type = None - - # Find first non-return parameter - for param_name, param_type in params: - if param_name != 'return': - input_type = param_type - break - - output_type = hints.get('return') - - # Validate types - if input_type and not issubclass(input_type, StructuredInput): - raise TypeError(f"Input type {input_type} must inherit from StructuredInput") - - if output_type and not issubclass(output_type, StructuredOutput): - raise TypeError(f"Output type {output_type} must inherit from StructuredOutput") - - # Use default LLM if none provided - # Note: Replace this with your actual default LLM initialization - actual_llm = llm - if actual_llm is None: - actual_llm = LLM() # we use the default LLM - - # Create and return the callable wrapper - return StructuredLLMCallable(f, actual_llm, input_type, output_type) - - # Handle both @llm_call and @llm_call(...) syntax - if func is None: - # Called with arguments: @llm_call(llm=custom_llm) - return decorator - else: - # Called without arguments: @llm_call - return decorator(func) \ No newline at end of file diff --git a/opto/features/flows/compose/__init__.py b/opto/features/flows/compose/__init__.py new file mode 100644 index 00000000..2950f740 --- /dev/null +++ b/opto/features/flows/compose/__init__.py @@ -0,0 +1,5 @@ +from opto.features.flows.compose.llm import TracedLLM, ChatHistory +from opto.features.flows.compose.parser import llm_call +from opto.features.flows.compose.agentic_ops import Loop, StopCondition, Check + +__all__ = ["TracedLLM", "ChatHistory", "llm_call", "Loop", "StopCondition", "Check"] \ No newline at end of file diff --git a/opto/features/flows/compose/agentic_ops.py b/opto/features/flows/compose/agentic_ops.py new file mode 100644 index 00000000..eb9c2f68 --- /dev/null +++ b/opto/features/flows/compose/agentic_ops.py @@ -0,0 +1,354 @@ +import opto.trace as trace +from typing import Tuple, Union, get_type_hints, Any, Dict, List, Optional, Callable +from opto.utils.llm import AbstractModel, LLM +from opto.features.flows.types import MultiModalPayload, QueryModel, StructuredInput, StructuredOutput, \ + ForwardMixin +from opto.trainer.guide import Guide +import numpy as np +import contextvars + +# =========== Mixin for Agentic Optimizer =========== + +""" + +@llm_call +def auto_correct(solution: InitialSolution) -> ImprovedSolution: + ... + +class InitialSolution: + "This is the intiial solution of a coding problem" + +class ImprovedSolution: + name: str + +a = auto_correct(solution) + +class CodeOptimizer: + def __init__(self): + # self.auto_correct = Loop(auto_correct) + # self.auto_correct = Loop(self.auto_correct_op) + self.auto_improve = Loop() + + def auto_correct_op(self, solution): + prompt = "\n improve this." + llm_response = self.llm(prompt + solution) + return llm_response + + def forward(self): + next_phase = self.auto_correct(solution) + return self.auto_improve(next_phase) + +TODO: +1. Support different styles of calling/initializing loop. Need to cover all use cases. +""" + + +class StopCondition(ForwardMixin): + """ + A stop condition for the loop. It can be a callable class instance or a function. + A simple stop condition is not necessary to be a subclass of this class. + + Implement an init method to pass in extra info into the stop condition + + Example Usage: + + class MaxIterationsOrConverged: + def __init__(self, max_iters=5, threshold=0.01): + self.max_iters = max_iters + self.threshold = threshold + + def __call__(self, param_history: List, result_history: List) -> bool: + # Stop if we've done enough iterations + if len(result_history) >= self.max_iters: + return True + + # Stop if results have converged (example) + if len(result_history) >= 2: + diff = abs(result_history[-1] - result_history[-2]) + if diff < self.threshold: + return True + + return False + """ + + def forward(self, param_history: List, result_history: List) -> bool: + """The Loop will call the stop condition with the param_history and result_history. + The stop condition should return a boolean value. + """ + raise NotImplementedError("Need to implement the forward function") + + +class Loop(ForwardMixin): + + def __init__(self, func: Callable[[Any], Any], stop_condition: Any = None): + assert callable(func), "func must be a callable" + + self.stop_condition = stop_condition + self.func = func + if stop_condition is not None: + self.check_stop_condition(stop_condition) + + def check_stop_condition(self, stop_condition: Any): + # Check if it's a callable class (has __call__ method and is an instance) + if not callable(stop_condition): + raise TypeError("stop_condition must be callable") + + # Check if it's a class instance (not a function or method) + # Functions/methods don't have __dict__ or they have __self__ for bound methods + if not hasattr(stop_condition, '__dict__') and not hasattr(stop_condition, '__self__'): + raise TypeError("stop_condition must be a callable class instance, not a function") + + # Get type hints from the __call__ method + try: + hints = get_type_hints(stop_condition.__call__) + if 'return' in hints: + # Optional: validate that return type annotation is bool + assert hints['return'] == bool, \ + f"stop_condition.__call__ must be annotated to return bool, got {hints['return']}" + except AttributeError: + pass # If __call__ doesn't have type hints, skip validation + + def step(self, *current_params: Any) -> Tuple[Any, Dict[str, Any]]: + """Override this method to define the loop logic. If `func` is passed in during init, this is not necessary.""" + if self.func is not None: + result, info = self.func(*current_params) + return result, info + raise NotImplementedError("Must provide func during initialization or override step method") + + def forward(self, *args, max_try: int = 10, stop_condition: Any = None) -> Tuple[ + Any, List[Dict[str, Any]]]: + """ + Execute the loop with the given initial parameters. + + Args: + *args: Initial parameters to pass to the function + max_try: Maximum number of iterations + stop_condition: Optional stop condition to override the one from __init__ + + Returns: + Tuple of (final_params, result_history) + """ + param_history = [] + result_history = [] + current_params = args + + # Use the stop_condition from forward() if provided, otherwise use the one from __init__ + active_stop_condition = stop_condition if stop_condition is not None else self.stop_condition + + # Check if the initial parameter is already good enough (only if stop_condition exists) + if active_stop_condition is not None: + should_stop = active_stop_condition(param_history, result_history) + assert isinstance(should_stop, (bool, np.bool_)), \ + f"stop_condition must return a boolean value, got {type(should_stop)}" + + if should_stop: + return current_params if len(current_params) > 1 else ( + current_params[0] if current_params else None), result_history + + for iteration in range(max_try): + # Track the parameters before calling step + param_history.append( + current_params if len(current_params) > 1 else (current_params[0] if current_params else None)) + + # Update current_params using step method + result, info = self.step(*current_params) + + # Normalize result to always be a tuple for consistency + if not isinstance(result, tuple): + current_params = (result,) + else: + current_params = result + + # Track the result from step + result_history.append(info) + + # Check stop condition after each step (only if it exists) + if active_stop_condition is not None: + should_stop = active_stop_condition(param_history, result_history) + assert isinstance(should_stop, (bool, np.bool_)), \ + f"stop_condition must return a boolean value, got {type(should_stop)}" + + if should_stop: + break + + # Return unpacked if single param, tuple if multiple params + final_result = current_params if len(current_params) > 1 else (current_params[0] if current_params else None) + return final_result, result_history + + +""" +Usage patterns: + +Check(should_optimize, value) + .then(lambda v: Loop(optimize_step, stop_condition)()) + .or_else(skip_optimization)() + +class ConditionalUpdateLoop(Loop): + def step(self, param): + return Check(needs_large_step, param) + .then(large_update) + .or_else(small_update)() + +Check(validate_data, data).then( + lambda d: Check(needs_preprocessing, d) + .then(lambda: Loop(preprocess, QualityStop())()) + .or_else(process_directly)() +).or_else(reject)() +""" + + +class Check(ForwardMixin): + """A DSL for conditional execution with fluent interface.""" + + def __init__(self, condition_func, *args, **kwargs): + """ + Initialize the Check with a condition function and its arguments. + + Args: + condition_func: A callable that returns a truthy/falsy value + *args: Positional arguments to pass to the condition function + **kwargs: Keyword arguments to pass to the condition function + """ + self.condition_func = condition_func + self.args = args + self.kwargs = kwargs + self.condition_result = None + self.condition_evaluated = False + self.then_func = None + self.then_args = None + self.then_kwargs = None + self.elif_branches = [] # List of (condition, func, args, kwargs) tuples + self.else_func = None + self.else_args = None + self.else_kwargs = None + self.do_func = None + self.do_args = None + self.do_kwargs = None + + def _evaluate_condition(self): + """Lazily evaluate the condition function.""" + if not self.condition_evaluated: + result = self.condition_func(*self.args, **self.kwargs) + # Store both the truthiness and the actual return value + self.condition_result = result + self.condition_evaluated = True + return self.condition_result + + def then(self, callback_func, *extra_args, **extra_kwargs): + """ + Define the function to execute if the condition is truthy. + + Args: + callback_func: The function to execute if condition is true + *extra_args: Additional positional arguments for the callback + **extra_kwargs: Additional keyword arguments for the callback + """ + self.then_func = callback_func + self.then_args = extra_args + self.then_kwargs = extra_kwargs + return self + + def elseif(self, condition_func, callback_func, *extra_args, **extra_kwargs): + """ + Add an elif branch with its own condition and callback. + + Args: + condition_func: The condition to check if previous conditions were false + callback_func: The function to execute if this condition is true + *extra_args: Additional positional arguments for the callback + **extra_kwargs: Additional keyword arguments for the callback + """ + self.elif_branches.append((condition_func, callback_func, extra_args, extra_kwargs)) + return self + + def or_else(self, callback_func, *extra_args, **extra_kwargs): + """ + Define the function to execute if all conditions are falsy. + + Args: + callback_func: The function to execute if all conditions are false + *extra_args: Additional positional arguments for the callback + **extra_kwargs: Additional keyword arguments for the callback + """ + self.else_func = callback_func + self.else_args = extra_args + self.else_kwargs = extra_kwargs + return self + + # Alternative names for or_else + otherwise = or_else + else_ = or_else + + def do(self, callback_func, *extra_args, **extra_kwargs): + """ + Define a function to execute after all branches, regardless of condition. + + Args: + callback_func: The function to always execute at the end + *extra_args: Additional positional arguments for the callback + **extra_kwargs: Additional keyword arguments for the callback + """ + self.do_func = callback_func + self.do_args = extra_args + self.do_kwargs = extra_kwargs + return self + + def forward(self): + """ + Execute the appropriate callback based on the condition results. + + Returns: + The return value of whichever callback was executed, + or the do callback if no branch was executed. + """ + condition_result = self._evaluate_condition() + execution_result = None + branch_executed = False + + # Check main condition + if condition_result: + if self.then_func: + # Combine original args with then args, plus condition result + all_args = self.args + self.then_args + # If condition returned a non-boolean value, include it + if condition_result is not True: + all_args = (condition_result,) + all_args + all_kwargs = {**self.kwargs, **self.then_kwargs} + execution_result = self.then_func(*all_args, **all_kwargs) + branch_executed = True + + # Check elif branches if main condition was false + if not branch_executed: + for elif_condition, elif_func, elif_args, elif_kwargs in self.elif_branches: + # Evaluate elif condition with original args + elif_result = elif_condition(*self.args, **self.kwargs) + if elif_result: + # Combine args similar to then branch + all_args = self.args + elif_args + if elif_result is not True: + all_args = (elif_result,) + all_args + all_kwargs = {**self.kwargs, **elif_kwargs} + execution_result = elif_func(*all_args, **all_kwargs) + branch_executed = True + break + + # Execute else if no condition was true + if not branch_executed and self.else_func: + # Combine original args with else args + all_args = self.args + self.else_args + all_kwargs = {**self.kwargs, **self.else_kwargs} + execution_result = self.else_func(*all_args, **all_kwargs) + branch_executed = True + + # Always execute do if defined + if self.do_func: + # Do gets the execution result as first arg if there was one + do_args = self.do_args + if execution_result is not None: + do_args = (execution_result,) + do_args + all_kwargs = {**self.kwargs, **self.do_kwargs} + do_result = self.do_func(*do_args, **all_kwargs) + # Return the execution result if there was one, otherwise the do result + return execution_result if execution_result is not None else do_result + + return execution_result \ No newline at end of file diff --git a/opto/features/flows/compose/llm.py b/opto/features/flows/compose/llm.py new file mode 100644 index 00000000..815c695b --- /dev/null +++ b/opto/features/flows/compose/llm.py @@ -0,0 +1,418 @@ +import opto.trace as trace +from typing import Tuple, Union, get_type_hints, Any, Dict, List, Optional, Callable +from opto.utils.llm import AbstractModel, LLM +from opto.features.flows.types import MultiModalPayload, QueryModel, StructuredInput, StructuredOutput, \ + ForwardMixin +from opto.trainer.guide import Guide +import numpy as np +import contextvars + +# =========== LLM Base Model =========== +""" +TracedLLM: +1. special operations that supports specifying inputs (system_prompt, user_prompt) to LLM and parsing of outputs, wrap + everything under one command. +2. Easy to use interface -- can be inherited by users. +3. Support multi-turn chatting (message history) + +Usage patterns: + +Direct use: (only supports single input, single output) (signature: str -> str) +llm = TracedLLM("You are a helpful assistant.") +response = llm("Hello, what's the weather in France today?") +""" + +USED_TracedLLM = contextvars.ContextVar('USED_TracedLLM', default=list()) + +class ChatHistory: + def __init__(self, max_round=25, auto_summary=False): + """Initialize chat history for multi-turn conversation. + + Args: + max_round: Maximum number of conversation rounds (user-assistant pairs) to keep in history + auto_summary: Whether to automatically summarize old messages + """ + self.messages: List[Dict[str, Any]] = [] + self.max_round = max_round + self.auto_summary = auto_summary + + def __len__(self): + return len(self.messages) + + def add_system_message(self, content: Union[trace.Node, str]): + """Add or replace a system message at the beginning of the chat history. + + Args: + content: The content of the system message + """ + # Check if the first message is a system message + if len(self.messages) > 0 and self.messages[0].get("role") == "system": + print("Warning: Replacing existing system message.") + self.messages[0] = {"role": "system", "content": content} + else: + # Insert system message at the beginning + self.messages.insert(0, {"role": "system", "content": content}) + self._trim_history() + + def add_message(self, content: Union[trace.Node, str], role): + """Alias for add""" + return self.add(content, role) + + def add(self, content: Union[trace.Node, str], role): + """Add a message to history with role validation. + + Args: + content: The content of the message + role: The role of the message ("user" or "assistant") + """ + if role not in ["user", "assistant"]: + raise ValueError(f"Invalid role '{role}'. Must be 'user' or 'assistant'.") + + # Check for alternating user/assistant pattern + if len(self.messages) > 0: + last_msg = self.messages[-1] + if last_msg["role"] == role: + print(f"Warning: Adding consecutive {role} messages. Consider alternating user/assistant messages.") + + self.messages.append({"role": role, "content": content}) + self._trim_history() + + def append(self, message: Dict[str, Any]): + """Append a message directly to history.""" + if "role" not in message or "content" not in message or type(message) is not dict: + raise ValueError("Message must have 'role' and 'content' fields.") + self.add(message["content"], message["role"]) + + def __iter__(self): + return iter(self.messages) + + def __getitem__(self, index): + """Get a specific round or slice of rounds as a ChatHistory object. + + Args: + index: Integer index or slice object + + Returns: + ChatHistory: A new ChatHistory object containing the selected round(s). + Each round includes [system_prompt, user_prompt, response] where system_prompt is + included if it exists in the chat history. + """ + # Get user and assistant message indices + user_indices = [i for i, msg in enumerate(self.messages) if msg["role"] == "user"] + assistant_indices = [i for i, msg in enumerate(self.messages) if msg["role"] == "assistant"] + + # Build rounds (user-assistant pairs that are complete) + rounds = [] + for user_idx in user_indices: + # Find corresponding assistant response + assistant_msg = None + for asst_idx in assistant_indices: + if asst_idx > user_idx: + assistant_msg = self.messages[asst_idx] + break + + if assistant_msg: + rounds.append([self.messages[user_idx], assistant_msg]) + + # Create new ChatHistory object + new_history = ChatHistory(max_round=self.max_round, auto_summary=self.auto_summary) + + # Handle slicing + if isinstance(index, slice): + selected_rounds = rounds[index] + # Add system message if it exists + if self.messages and self.messages[0].get("role") == "system": + new_history.messages.append(self.messages[0].copy()) + # Add all selected rounds + for round_msgs in selected_rounds: + for msg in round_msgs: + new_history.messages.append({"role": msg["role"], "content": msg["content"]}) + return new_history + + # Handle single index (including negative indexing) + round_msgs = rounds[index] # This will handle negative indices and raise IndexError if out of bounds + + # Add system message if it exists + if self.messages and self.messages[0].get("role") == "system": + new_history.messages.append({"role": self.messages[0]["role"], "content": self.messages[0]["content"]}) + + # Add the selected round + for msg in round_msgs: + new_history.messages.append({"role": msg["role"], "content": msg["content"]}) + + return new_history + + def get_messages(self) -> List[Dict[str, str]]: + messages = [] + for message in self.messages: + if isinstance(message['content'], trace.Node): + messages.append({"role": message["role"], "content": message["content"].data}) + else: + messages.append(message) + return messages + + def get_messages_as_node(self, llm_name="") -> List[trace.Node]: + node_list = [] + for message in self.messages: + # If user query is a node and has other computation attached, we can't rename it + if isinstance(message['content'], trace.Node): + node_list.append(message['content']) + else: + role = message["role"] + content = message["content"] + name = f"{llm_name}_{role}" if llm_name else f"{role}" + if role == 'user': + name += "_query" + elif role == 'assistant': + name += "_response" + node_list.append(trace.node(content, name=name)) + + return node_list + + def copy(self, include_system=True): + """Create a deep copy of the chat history. + + Args: + include_system: Whether to include the system message in the copy (default: True) + + Returns: + ChatHistory: A new ChatHistory instance with the same messages + """ + new_history = ChatHistory(max_round=self.max_round, auto_summary=self.auto_summary) + for message in self.messages: + # Skip system message if include_system is False + if not include_system and message["role"] == "system": + continue + # Create a new dict to avoid reference issues + new_message = {"role": message["role"], "content": message["content"]} + new_history.messages.append(new_message) + return new_history + + def remove_system_message(self): + """Create a copy of the chat history without the system message. + + Returns: + ChatHistory: A new ChatHistory instance without the system message + """ + return self.copy(include_system=False) + + def __add__(self, other): + """Merge two chat histories together. + + Args: + other: Another ChatHistory instance + + Returns: + ChatHistory: A new ChatHistory instance with messages from both histories + + Raises: + TypeError: If other is not a ChatHistory instance + ValueError: If both histories have system prompts + """ + if not isinstance(other, ChatHistory): + raise TypeError("Can only add ChatHistory instances together") + + # Check if both have system messages + has_system_self = self.messages and self.messages[0].get("role") == "system" + has_system_other = other.messages and other.messages[0].get("role") == "system" + + if has_system_self and has_system_other: + raise ValueError("Cannot merge two chat histories that both have system prompts") + + # Create new history with max of the two max_rounds + new_history = ChatHistory( + max_round=max(self.max_round, other.max_round), + auto_summary=self.auto_summary or other.auto_summary + ) + + # Add messages from self + for message in self.messages: + new_message = {"role": message["role"], "content": message["content"]} + new_history.messages.append(new_message) + + # Add messages from other, skipping system message if self already has one + for message in other.messages: + if has_system_self and message["role"] == "system": + continue + new_message = {"role": message["role"], "content": message["content"]} + new_history.messages.append(new_message) + + # Trim the combined history to respect max_round + new_history._trim_history() + + return new_history + + def _trim_history(self): + """Trim history to max_round while preserving system message and the first round.""" + # Count the number of rounds (user-assistant pairs) + user_indices = [i for i, msg in enumerate(self.messages) if msg["role"] == "user"] + assistant_indices = [i for i, msg in enumerate(self.messages) if msg["role"] == "assistant"] + + # If we don't have enough messages to form complete rounds, return + if len(user_indices) <= self.max_round or len(assistant_indices) < len(user_indices): + return + + protected_messages = [] + + # Keep system message if it exists + if self.messages and self.messages[0].get("role") == "system": + protected_messages.append(self.messages[0]) + + # Always keep the first round (first user message and first assistant response) + if user_indices: + first_user_idx = user_indices[0] + protected_messages.append(self.messages[first_user_idx]) + + # Find the first assistant message after the first user message + for i in assistant_indices: + if i > first_user_idx: + protected_messages.append(self.messages[i]) + break + + # Calculate how many recent rounds we can keep + num_protected_rounds = 1 if user_indices else 0 # We've protected the first round if it exists + remaining_rounds = self.max_round - num_protected_rounds + + if remaining_rounds > 0: + # Get the most recent rounds (user-assistant pairs) + recent_rounds = [] + + # Start from the most recent user message and work backwards + for i in range(len(user_indices) - 1, num_protected_rounds - 1, -1): + if len(recent_rounds) // 2 >= remaining_rounds: + break + + user_idx = user_indices[i] + user_msg = self.messages[user_idx] + + # Find the corresponding assistant response + assistant_msg = None + for j in assistant_indices: + if j > user_idx: + assistant_msg = self.messages[j] + break + + if assistant_msg: + # Add the pair in chronological order + recent_rounds = [user_msg, assistant_msg] + recent_rounds + + # Combine protected messages with recent rounds + self.messages = protected_messages + recent_rounds + else: + self.messages = protected_messages + + +DEFAULT_SYSTEM_PROMPT_DESCRIPTION = ("the system prompt to the agent. By tuning this prompt, we can control the " + "behavior of the agent. For example, it can be used to provide instructions to " + "the agent (such as how to reason about the problem, how to use tools, " + "how to answer the question), or provide in-context examples of how to solve the " + "problem.") + + +@trace.model +class TracedLLM: + """ + This high-level model provides an easy-to-use interface for LLM calls with system prompts and optional chat history. + + Python usage patterns: + + llm = UF_LLM(system_prompt) + response = llm.chat(user_prompt) + response_2 = llm.chat(user_prompt_2) + + The underlying Trace Graph: + TracedLLM_response0 = TracedLLM.forward.call_llm(args_0=system_prompt0, args_1=TracedLLM0_user_query0) + TracedLLM_response1 = TracedLLM.forward.call_llm(args_0=system_prompt0, args_1=TracedLLM0_user_query0, args_2=TracedLLM_response0, args_3=TracedLLM0_user_query1) + TracedLLM_response2 = TracedLLM.forward.call_llm(args_0=system_prompt0, args_1=TracedLLM0_user_query0, args_2=TracedLLM_response0, args_3=TracedLLM0_user_query1, args_4=TracedLLM_response1, args_5=TracedLLM0_user_query2) + """ + + def __init__(self, + system_prompt: Union[str, None, trace.Node] = None, + llm: AbstractModel = None, chat_history_on=False, + trainable=False, model_name=None): + """Initialize TracedLLM with a system prompt. + + Args: + system_prompt: The system prompt to use for LLM calls. If None and the class has a docstring, the docstring will be used. + llm: The LLM model to use for inference + chat_history_on: if on, maintain chat history for multi-turn conversations + model_name: override the default name of the model + """ + if system_prompt is None: + system_prompt = "You are a helpful assistant." + + self.system_prompt = trace.node(system_prompt, name='system_prompt', + description=DEFAULT_SYSTEM_PROMPT_DESCRIPTION, + trainable=trainable) + # if system_prompt is already a node, then we have to override its trainable attribute + self.system_prompt.trainable = trainable + + if llm is None: + llm = LLM() + assert isinstance(llm, AbstractModel), f"{llm} must be an instance of AbstractModel" + self.llm = llm + self.chat_history = ChatHistory() + self.chat_history_on = chat_history_on + + current_llm_sessions = USED_TracedLLM.get() + self.model_name = model_name if model_name else f"{self.__class__.__name__}{len(current_llm_sessions)}" + current_llm_sessions.append(1) # just a marker + + def forward(self, user_query: str, + payload: Optional[MultiModalPayload] = None, + chat_history_on: Optional[bool] = None) -> str: + """This function takes user_query as input, and returns the response from the LLM, with the system prompt prepended. + This method will always save chat history. + + If chat_history_on is set to False, the chat history will not be included in the LLM input. + If chat_history_on is None, it will use the class-level chat_history_on setting. + If chat_history_on is True, the chat history will be included in the LLM input. + + Args: + user_query: The user query to send to the LLM. Can be + + Returns: + str: For direct pattern + """ + chat_history_on = self.chat_history_on if chat_history_on is None else chat_history_on + + user_message = QueryModel(query=user_query, multimodal_payload=payload).query + + messages = [{"role": "system", "content": self.system_prompt.data}] + if chat_history_on: + messages.extend(self.chat_history.get_messages()) + messages.append({"role": "user", "content": user_message}) + + response = self.llm(messages=messages) + + @trace.bundle(output_name=f"{self.model_name}_response") + def call_llm(*messages) -> str: + """Call the LLM model. + Args: + messages: All the conversation history so far, starting from system prompt, to alternating user/assistant messages, ending with the current user query. + Returns: + response from the LLM + """ + return response.choices[0].message.content + + user_query_node = trace.node(user_query, name=f"{self.model_name}_user_query") + arg_list = [self.system_prompt] + if chat_history_on: + arg_list += self.chat_history.get_messages_as_node(self.model_name) + arg_list += [user_query_node] + + response_node = call_llm(*arg_list) + + # save to chat history + self.chat_history.add(user_query_node, role="user") + self.chat_history.add(response_node, role="assistant") + + return response_node + + def chat(self, user_query: str, payload: Optional[MultiModalPayload] = None, + chat_history_on: Optional[bool] = None) -> str: + """Note that chat/forward always assumes it's a single turn of the conversation. History/context management will be accomplished + through other APIs""" + return self.forward(user_query, payload, chat_history_on) + +# =========== =========== diff --git a/opto/features/flows/compose/parser.py b/opto/features/flows/compose/parser.py new file mode 100644 index 00000000..e9852616 --- /dev/null +++ b/opto/features/flows/compose/parser.py @@ -0,0 +1,200 @@ +import opto.trace as trace +from typing import Tuple, Union, get_type_hints, Any, Dict, List, Optional, Callable +from opto.utils.llm import AbstractModel, LLM +from opto.features.flows.types import MultiModalPayload, QueryModel, StructuredInput, StructuredOutput, \ + ForwardMixin +from opto.trainer.guide import Guide +import numpy as np +import contextvars + +# =========== Structured LLM Input/Output With Parsing =========== + +""" +Usage: + +@llm_call +def evaluate_person(person: Person) -> Preference: + "Evaluate if a person matches our criteria" + ... + +person = Person(name="Alice", age=30, income=75000) +preference = evaluate_person(person) + +TODO 2: add trace bundle and input/output conversion +""" + + +class StructuredLLMCallable: + """ + Wrapper class that makes a decorated function callable and automatically invokes the LLM. + """ + + def __init__(self, func: ForwardMixin, llm, input_type, output_type): + self.func = func + self.llm = llm + self.input_type = input_type + self.output_type = output_type + + # Store schemas + self.input_schema = input_type.model_json_schema() if input_type else None + self.output_schema = output_type.model_json_schema() if output_type else None + + # Copy function metadata + self.__name__ = func.__name__ + self.__doc__ = func.__doc__ + self.__module__ = func.__module__ + self.__annotations__ = func.__annotations__ + + def __call__(self, input_data: StructuredInput, system_prompt: Optional[str] = None) -> StructuredOutput: + """ + Automatically invoke the LLM with the input data. + + Args: + input_data: Instance of StructuredInput + system_prompt: Optional custom system prompt. If not provided, uses default. + + Returns: + Instance of StructuredOutput + """ + # Validate input type + if not isinstance(input_data, self.input_type): + raise TypeError(f"Expected input of type {self.input_type}, got {type(input_data)}") + + # Convert input to string representation for LLM + input_str = str(input_data) + + # Get function docstring as task description + func_doc = self.func.__doc__ or "Process the input data" + + # Build system prompt + if system_prompt is None: + output_fields = list(self.get_output_fields().keys()) + system_prompt = f"""You are a helpful assistant that performs the following task: {func_doc} + +You will receive input data and must produce output in JSON format with the following fields: {output_fields} + +Output description: {self.output_type.get_docstring() or 'Structured output'} + +Always respond with valid JSON only, no additional text.""" + + # Build user message with input data + user_message = f"""{input_str} + +Please respond with a JSON object containing the required output fields.""" + + # Construct messages in the format expected by the LLM + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_message} + ] + + # Invoke LLM + response = self.llm(messages=messages) + + # Parse response into StructuredOutput + # Note: Assumes LLM returns JSON string + try: + output_instance = self.output_type.model_validate_json(response) + except Exception: + # Fallback: try parsing as dict + import json + try: + output_dict = json.loads(response) + output_instance = self.output_type(**output_dict) + except Exception as e: + raise ValueError(f"Failed to parse LLM response into {self.output_type}: {e}\nResponse: {response}") + + return output_instance + + def get_input_docstring(self) -> Optional[str]: + return self.input_type.get_docstring() if self.input_type else None + + def get_output_docstring(self) -> Optional[str]: + return self.output_type.get_docstring() if self.output_type else None + + def get_input_fields(self) -> Dict[str, Any]: + return self.input_type.get_fields_info() if self.input_type else {} + + def get_output_fields(self) -> Dict[str, Any]: + return self.output_type.get_fields_info() if self.output_type else {} + + def __repr__(self): + return f"" + + +def llm_call(func: Callable = None, *, llm=None, **kwargs): + """ + Decorator that extracts input/output schemas from type-annotated functions + and creates a callable that automatically invokes the LLM. + + Args: + func: The function to decorate (automatically passed when used without arguments) + llm: AbstractModel instance to use for LLM calls + **kwargs: Additional LLM configuration parameters + + Usage: + # Without arguments + @llm_call + def process_person(person: Person) -> Preference: + '''Evaluate if a person matches our criteria''' + pass + + # With arguments + @llm_call(llm=customized_llm) + def process_person(person: Person) -> Preference: + '''Evaluate if a person matches our criteria''' + pass + + # Call it directly - LLM is invoked automatically + person = Person(name="Alice", age=30, income=75000) + result = process_person(person) # Returns Preference instance + + # Access schemas + process_person.input_type + process_person.output_type + process_person.input_schema + process_person.output_schema + """ + + def decorator(f: Callable): + hints = get_type_hints(f) + + # Get first parameter type and return type + params = list(hints.items()) + input_type = None + output_type = None + + # Find first non-return parameter + for param_name, param_type in params: + if param_name != 'return': + input_type = param_type + break + + output_type = hints.get('return') + + # Validate types + if input_type and not issubclass(input_type, StructuredInput): + raise TypeError(f"Input type {input_type} must inherit from StructuredInput") + + if output_type and not issubclass(output_type, StructuredOutput): + raise TypeError(f"Output type {output_type} must inherit from StructuredOutput") + + # Use default LLM if none provided + # Note: Replace this with your actual default LLM initialization + actual_llm = llm + if actual_llm is None: + actual_llm = LLM() # we use the default LLM + + # Create and return the callable wrapper + return StructuredLLMCallable(f, actual_llm, input_type, output_type) + + # Handle both @llm_call and @llm_call(...) syntax + if func is None: + # Called with arguments: @llm_call(llm=custom_llm) + return decorator + else: + # Called without arguments: @llm_call + return decorator(func) + + +# =========== =========== \ No newline at end of file diff --git a/opto/features/flows/types.py b/opto/features/flows/types.py index 79d46f91..685c1e41 100644 --- a/opto/features/flows/types.py +++ b/opto/features/flows/types.py @@ -15,6 +15,14 @@ def __str__(self): raise NotImplementedError("Subclasses must implement __str__") +class ForwardMixin: + def forward(self, *args, **kwargs): + raise NotImplementedError("Subclasses must implement forward") + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + # ====== Multi-Modal LLM Support ====== class MultiModalPayload(BaseModel): image_bytes: Optional[str] = None # base64-encoded data URL @@ -246,7 +254,6 @@ def __str__(self, template: Optional[str] = None) -> str: return "\n".join(lines) - # ======= ====== # ======= Agentic Optimizer Support ======= From 89368b60aba392688c397bcb64383b9bbe5e56bb Mon Sep 17 00:00:00 2001 From: windweller Date: Tue, 28 Oct 2025 13:18:40 -0400 Subject: [PATCH 14/14] updates --- .../experimental_optimizers/agentic_opt.py | 22 +- opto/features/flows/compose/agentic_ops.py | 1 + opto/features/flows/compose/parser.py | 450 +++++++++++++----- opto/trace/nodes.py | 8 +- 4 files changed, 351 insertions(+), 130 deletions(-) diff --git a/opto/features/experimental_optimizers/agentic_opt.py b/opto/features/experimental_optimizers/agentic_opt.py index b1decb83..adc07dd6 100644 --- a/opto/features/experimental_optimizers/agentic_opt.py +++ b/opto/features/experimental_optimizers/agentic_opt.py @@ -61,8 +61,8 @@ def __init__( Add a veribench optimizer here. A code optimizer is general, can work with kernel, others -initial_user_message -> initial_code -> bug fix here -initial_code -> improvement_prompt -> improved_code -> improvement_prompt -> improved_code +initial_task_message -> initial_code +(initial_code, reward) -> optimizer_step -> (improved_code, reward) -> optimizer_step -> improved_code initial_code 1. check if there is a bug. If there is, try 10 times to fix. return the code --> make it correct/compilable @@ -100,6 +100,23 @@ def __init__( ============= planning agent: plan -> coding agent: write the code +------------ + +CodeOptimizer: +1. Optimization history (different from memory) + +CodeOptimizer: +task: make this code run fast, parameterNode(initial code) + +oracle: gives value + +node(initial_node, description="this is xxx, the value should be...") + +------------ + +init_param -> 25 next params (top 5) -> 25 params -> 25 params + +Protein: multi-objective optimization """ class CodeOptimizer(Optimizer): @@ -128,6 +145,7 @@ def __init__( # Store environment checker and reward function self.bug_judge = bug_judge + self.reward_function = reward_function self.max_bug_fix_tries = max_bug_fix_tries self.max_optimization_tries = max_optimization_tries diff --git a/opto/features/flows/compose/agentic_ops.py b/opto/features/flows/compose/agentic_ops.py index eb9c2f68..7d0d87c9 100644 --- a/opto/features/flows/compose/agentic_ops.py +++ b/opto/features/flows/compose/agentic_ops.py @@ -231,6 +231,7 @@ def _evaluate_condition(self): result = self.condition_func(*self.args, **self.kwargs) # Store both the truthiness and the actual return value self.condition_result = result + self.condition_evaluated = True return self.condition_result diff --git a/opto/features/flows/compose/parser.py b/opto/features/flows/compose/parser.py index e9852616..25fb84fa 100644 --- a/opto/features/flows/compose/parser.py +++ b/opto/features/flows/compose/parser.py @@ -1,11 +1,234 @@ -import opto.trace as trace -from typing import Tuple, Union, get_type_hints, Any, Dict, List, Optional, Callable -from opto.utils.llm import AbstractModel, LLM -from opto.features.flows.types import MultiModalPayload, QueryModel, StructuredInput, StructuredOutput, \ - ForwardMixin -from opto.trainer.guide import Guide -import numpy as np -import contextvars +from typing import Optional, Callable, Any, Dict, Literal, get_type_hints +from abc import ABC, abstractmethod +import inspect +import json +from opto.features.flows.types import StructuredInput, StructuredOutput, ForwardMixin + + +class PromptAdapter(ABC): + """Base adapter for converting structured types to/from LLM messages""" + + @abstractmethod + def format_system_prompt( + self, + task_description: str, + input_type: type[StructuredInput], + output_type: type[StructuredOutput] + ) -> str: + """Generate system prompt explaining the task""" + pass + + @abstractmethod + def format_input(self, input_data: StructuredInput) -> str: + """Convert input instance to string for LLM""" + pass + + @abstractmethod + def parse_output(self, llm_response: str, output_type: type[StructuredOutput]) -> StructuredOutput: + """Parse LLM response into output instance""" + pass + + +class JSONAdapter(PromptAdapter): + """Standard JSON-based communication""" + + def format_system_prompt(self, task_description: str, input_type, output_type) -> str: + output_schema = json.dumps(output_type.model_json_schema(), indent=2) + return f"""Task: {task_description} + +Output Schema: +{output_schema} + +Respond with valid JSON matching the schema above.""" + + def format_input(self, input_data: StructuredInput) -> str: + return input_data.model_dump_json(indent=2) + + def parse_output(self, llm_response: str, output_type) -> StructuredOutput: + # Try direct JSON parsing + try: + return output_type.model_validate_json(llm_response) + except Exception: + # Fallback: extract JSON from markdown code blocks + import re + json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', llm_response, re.DOTALL) + if json_match: + return output_type.model_validate_json(json_match.group(1)) + raise ValueError(f"Failed to parse JSON response: {llm_response[:200]}") + + +class MarkdownAdapter(PromptAdapter): + """Innovative markdown-based format using YAML-style frontmatter + sections""" + + def format_system_prompt(self, task_description: str, input_type, output_type) -> str: + # Build field descriptions + output_fields = [] + for field_name, field_info in output_type.model_fields.items(): + field_type = field_info.annotation.__name__ if hasattr(field_info.annotation, '__name__') else str( + field_info.annotation) + desc = field_info.description or "" + output_fields.append(f"- **{field_name}** (`{field_type}`): {desc}") + + fields_str = "\n".join(output_fields) + + return f"""# Task +{task_description} + +# Output Format +Respond using markdown with YAML frontmatter for metadata and sections for complex content. + +## Required Fields +{fields_str} + +## Structure +```markdown +--- +field_name: simple value +other_field: simple value +--- + +# SectionField (if complex content) +Complex content here... + +# AnotherSectionField +More complex content... +``` + +Use frontmatter for simple fields (strings, numbers, booleans). +Use markdown sections (# FieldName) for complex content (long text, lists, nested structures).""" + + def format_input(self, input_data: StructuredInput) -> str: + """Format input as markdown with frontmatter""" + lines = ["---"] + + # Simple fields in frontmatter + complex_fields = {} + for field_name, value in input_data.model_dump().items(): + if isinstance(value, (str, int, float, bool, type(None))): + if isinstance(value, str) and ('\n' in value or len(value) > 100): + complex_fields[field_name] = value + else: + lines.append(f"{field_name}: {value}") + else: + complex_fields[field_name] = value + + lines.append("---") + lines.append("") + + # Complex fields as sections + for field_name, value in complex_fields.items(): + lines.append(f"# {field_name}") + if isinstance(value, str): + lines.append(value) + else: + lines.append(json.dumps(value, indent=2)) + lines.append("") + + return "\n".join(lines) + + def parse_output(self, llm_response: str, output_type) -> StructuredOutput: + """Parse markdown frontmatter + sections into structured output""" + import re + + result = {} + + # Extract frontmatter + frontmatter_match = re.match(r'^---\s*\n(.*?)\n---\s*\n', llm_response, re.DOTALL) + if frontmatter_match: + frontmatter = frontmatter_match.group(1) + remaining = llm_response[frontmatter_match.end():] + + # Parse YAML-style frontmatter + for line in frontmatter.split('\n'): + if ':' in line: + key, value = line.split(':', 1) + key = key.strip() + value = value.strip() + + # Type coercion + if value.lower() == 'true': + value = True + elif value.lower() == 'false': + value = False + elif value.isdigit(): + value = int(value) + else: + try: + value = float(value) + except ValueError: + pass # Keep as string + + result[key] = value + else: + remaining = llm_response + + # Extract sections + sections = re.split(r'^# (\w+)\s*$', remaining, flags=re.MULTILINE) + for i in range(1, len(sections), 2): + if i + 1 < len(sections): + field_name = sections[i].strip() + content = sections[i + 1].strip() + result[field_name] = content + + return output_type(**result) + + +class XMLAdapter(PromptAdapter): + """XML-based format similar to Anthropic's style""" + + def format_system_prompt(self, task_description: str, input_type, output_type) -> str: + # Build field descriptions + output_fields = [] + for field_name, field_info in output_type.model_fields.items(): + field_type = field_info.annotation.__name__ if hasattr(field_info.annotation, '__name__') else str( + field_info.annotation) + desc = field_info.description or "" + output_fields.append(f" <{field_name} type='{field_type}'>{desc}") + + fields_str = "\n".join(output_fields) + + return f"""Task: {task_description} + +Output Format: Respond with XML structure containing these fields: + +{fields_str} + + +Provide actual values, not descriptions.""" + + def format_input(self, input_data: StructuredInput) -> str: + """Convert to XML format""" + lines = [""] + for field_name, value in input_data.model_dump().items(): + if isinstance(value, (list, dict)): + value = json.dumps(value) + lines.append(f" <{field_name}>{value}") + lines.append("") + return "\n".join(lines) + + def parse_output(self, llm_response: str, output_type) -> StructuredOutput: + """Parse XML response""" + import re + + result = {} + + # Extract fields using regex (simple XML parsing) + for field_name in output_type.model_fields.keys(): + pattern = f'<{field_name}>(.*?)' + match = re.search(pattern, llm_response, re.DOTALL) + if match: + value = match.group(1).strip() + + # Try JSON parsing for complex types + try: + value = json.loads(value) + except (json.JSONDecodeError, TypeError): + pass + + result[field_name] = value + + return output_type(**result) + # =========== Structured LLM Input/Output With Parsing =========== @@ -25,176 +248,151 @@ def evaluate_person(person: Person) -> Preference: class StructuredLLMCallable: - """ - Wrapper class that makes a decorated function callable and automatically invokes the LLM. - """ - - def __init__(self, func: ForwardMixin, llm, input_type, output_type): + """Enhanced wrapper supporting multiple function patterns""" + + def __init__( + self, + func: Callable, + llm, + input_type: type[StructuredInput], + output_type: type[StructuredOutput], + has_preprocessing: bool, + adapter: PromptAdapter + ): self.func = func self.llm = llm self.input_type = input_type self.output_type = output_type + self.has_preprocessing = has_preprocessing + self.adapter = adapter - # Store schemas - self.input_schema = input_type.model_json_schema() if input_type else None - self.output_schema = output_type.model_json_schema() if output_type else None + # Store output_type in adapter if it's MarkdownAdapter (for user message generation) + if isinstance(adapter, MarkdownAdapter): + adapter.output_type = output_type - # Copy function metadata + # Copy metadata self.__name__ = func.__name__ self.__doc__ = func.__doc__ self.__module__ = func.__module__ self.__annotations__ = func.__annotations__ - def __call__(self, input_data: StructuredInput, system_prompt: Optional[str] = None) -> StructuredOutput: - """ - Automatically invoke the LLM with the input data. - - Args: - input_data: Instance of StructuredInput - system_prompt: Optional custom system prompt. If not provided, uses default. + def __call__( + self, + input_data: StructuredInput, + system_prompt: Optional[str] = None + ) -> StructuredOutput: + """Execute with automatic LLM invocation""" - Returns: - Instance of StructuredOutput - """ # Validate input type if not isinstance(input_data, self.input_type): - raise TypeError(f"Expected input of type {self.input_type}, got {type(input_data)}") - - # Convert input to string representation for LLM - input_str = str(input_data) - - # Get function docstring as task description - func_doc = self.func.__doc__ or "Process the input data" - - # Build system prompt + raise TypeError(f"Expected {self.input_type}, got {type(input_data)}") + + # Execute preprocessing if function has implementation + if self.has_preprocessing: + result = self.func(input_data) + + # Check if function returned the output type class (Usage 2 & 3) + if inspect.isclass(result) and issubclass(result, StructuredOutput): + self.output_type = result + # Update adapter's output_type + if isinstance(self.adapter, MarkdownAdapter): + self.adapter.output_type = result + elif isinstance(result, StructuredOutput): + # Function did full processing, return directly + return result + + # Build system prompt using adapter if system_prompt is None: - output_fields = list(self.get_output_fields().keys()) - system_prompt = f"""You are a helpful assistant that performs the following task: {func_doc} - -You will receive input data and must produce output in JSON format with the following fields: {output_fields} - -Output description: {self.output_type.get_docstring() or 'Structured output'} + task_description = self.func.__doc__ or "Process the input data" + system_prompt = self.adapter.format_system_prompt( + task_description, + self.input_type, + self.output_type + ) -Always respond with valid JSON only, no additional text.""" + # Format input using adapter + user_message = self.adapter.format_input(input_data) - # Build user message with input data - user_message = f"""{input_str} - -Please respond with a JSON object containing the required output fields.""" - - # Construct messages in the format expected by the LLM + # Invoke LLM messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_message} ] - # Invoke LLM response = self.llm(messages=messages) - # Parse response into StructuredOutput - # Note: Assumes LLM returns JSON string - try: - output_instance = self.output_type.model_validate_json(response) - except Exception: - # Fallback: try parsing as dict - import json - try: - output_dict = json.loads(response) - output_instance = self.output_type(**output_dict) - except Exception as e: - raise ValueError(f"Failed to parse LLM response into {self.output_type}: {e}\nResponse: {response}") + # Parse output using adapter + output_instance = self.adapter.parse_output(response, self.output_type) return output_instance - def get_input_docstring(self) -> Optional[str]: - return self.input_type.get_docstring() if self.input_type else None - - def get_output_docstring(self) -> Optional[str]: - return self.output_type.get_docstring() if self.output_type else None - - def get_input_fields(self) -> Dict[str, Any]: - return self.input_type.get_fields_info() if self.input_type else {} - def get_output_fields(self) -> Dict[str, Any]: - return self.output_type.get_fields_info() if self.output_type else {} - - def __repr__(self): - return f"" - - -def llm_call(func: Callable = None, *, llm=None, **kwargs): +def llm_call( + func: Callable = None, + *, + llm=None, + adapter: Literal["json", "markdown", "xml"] = "markdown", + type_hints: Optional[Dict[str, str]] = None, + **kwargs +): """ - Decorator that extracts input/output schemas from type-annotated functions - and creates a callable that automatically invokes the LLM. + Enhanced decorator supporting three usage patterns. Args: - func: The function to decorate (automatically passed when used without arguments) - llm: AbstractModel instance to use for LLM calls - **kwargs: Additional LLM configuration parameters - - Usage: - # Without arguments - @llm_call - def process_person(person: Person) -> Preference: - '''Evaluate if a person matches our criteria''' - pass - - # With arguments - @llm_call(llm=customized_llm) - def process_person(person: Person) -> Preference: - '''Evaluate if a person matches our criteria''' - pass - - # Call it directly - LLM is invoked automatically - person = Person(name="Alice", age=30, income=75000) - result = process_person(person) # Returns Preference instance - - # Access schemas - process_person.input_type - process_person.output_type - process_person.input_schema - process_person.output_schema + func: Function to decorate + llm: LLM instance + adapter: Output format - "json", "markdown", or "xml" + type_hints: Custom type hints for markdown adapter fields, e.g., + {"confidence": "must be a single float value between 0 and 1"} """ + # Create adapter instance + adapter_map = { + "json": JSONAdapter(), + "markdown": MarkdownAdapter(type_hints=type_hints), + "xml": XMLAdapter() + } + adapter_instance = adapter_map[adapter] + def decorator(f: Callable): hints = get_type_hints(f) - # Get first parameter type and return type + # Extract input type (first parameter) params = list(hints.items()) input_type = None - output_type = None - - # Find first non-return parameter for param_name, param_type in params: if param_name != 'return': input_type = param_type break + # Get output type from annotation output_type = hints.get('return') + # Detect if function has implementation (Usage 2 & 3) + source = inspect.getsource(f) + has_implementation = not source.strip().endswith('...') + # Validate types if input_type and not issubclass(input_type, StructuredInput): - raise TypeError(f"Input type {input_type} must inherit from StructuredInput") + raise TypeError(f"Input type must inherit from StructuredInput") if output_type and not issubclass(output_type, StructuredOutput): - raise TypeError(f"Output type {output_type} must inherit from StructuredOutput") + raise TypeError(f"Output type must inherit from StructuredOutput") # Use default LLM if none provided - # Note: Replace this with your actual default LLM initialization - actual_llm = llm - if actual_llm is None: - actual_llm = LLM() # we use the default LLM - - # Create and return the callable wrapper - return StructuredLLMCallable(f, actual_llm, input_type, output_type) - - # Handle both @llm_call and @llm_call(...) syntax + actual_llm = llm if llm is not None else LLM() + + return StructuredLLMCallable( + f, + actual_llm, + input_type, + output_type, + has_implementation, + adapter_instance + ) + + # Handle decorator syntax if func is None: - # Called with arguments: @llm_call(llm=custom_llm) return decorator else: - # Called without arguments: @llm_call return decorator(func) - - -# =========== =========== \ No newline at end of file diff --git a/opto/trace/nodes.py b/opto/trace/nodes.py index 0f4c85e5..5c86bbf2 100644 --- a/opto/trace/nodes.py +++ b/opto/trace/nodes.py @@ -8,7 +8,7 @@ import contextvars -def node(data, name=None, trainable=False, description=None): +def node(data, name=None, trainable=False, description=None, **kwargs): """Create a Node object from data. This is the primary factory function for creating nodes in the Trace computation graph. @@ -29,6 +29,9 @@ def node(data, name=None, trainable=False, description=None): description : str, optional A textual description of the node's purpose or constraints. Used as soft constraints during optimization and for documentation. + **kwargs : dict + Additional keyword arguments to pass to the Node or ParameterNode constructor, + such as 'info' or 'projections'. Returns ------- @@ -85,6 +88,7 @@ def node(data, name=None, trainable=False, description=None): name=name, trainable=True, description=description, + **kwargs ) else: if isinstance(data, Node): @@ -92,7 +96,7 @@ def node(data, name=None, trainable=False, description=None): warnings.warn(f"Name {name} is ignored because data is already a Node.") return data else: - return Node(data, name=name, description=description) + return Node(data, name=name, description=description, **kwargs) NAME_SCOPES = [] # A stack of name scopes