diff --git a/opto/trace/projections/__init__.py b/opto/trace/projections/__init__.py index 7264d5bd..6b285dc0 100644 --- a/opto/trace/projections/__init__.py +++ b/opto/trace/projections/__init__.py @@ -1,2 +1,2 @@ from opto.trace.projections.projections import Projection -from opto.trace.projections.code_projections import BlackCodeFormatter, DocstringProjection \ No newline at end of file +from opto.trace.projections.code_projections import BlackCodeFormatter, DocstringProjection, SuggestionNormalizationProjection \ No newline at end of file diff --git a/opto/trace/projections/code_projections.py b/opto/trace/projections/code_projections.py index 78a4642c..c7356a9d 100644 --- a/opto/trace/projections/code_projections.py +++ b/opto/trace/projections/code_projections.py @@ -1,5 +1,7 @@ from opto.trace.projections import Projection +import re +import ast class BlackCodeFormatter(Projection): # This requires the `black` package to be installed. @@ -28,4 +30,54 @@ def project(self, x: str) -> str: x = f'{x[0]}"""{self.docstring}"""{x[2]}' else: x = f'{x[0]}"""{self.docstring}"""' - return x \ No newline at end of file + return x + +class SuggestionNormalizationProjection(Projection): + """ + Normalize LLM-generated suggestion dicts: + - Literal-eval strings to their true types + - Alias frequent keys like "__code:8" ↔ "__code8" + - Black-reformat any code snippets + """ + def __init__(self, parameters): + self.parameters = parameters + + def project(self, suggestion: dict) -> dict: + from black import format_str, FileMode + def _find_key(node_name: str): + # exact match + if node_name in suggestion: + return node_name + # strip a colon before digits ("__code:8" → "__code8") + norm = re.sub(r":(?=\d+$)", "", node_name) + for k in suggestion: + if re.sub(r":(?=\d+$)", "", k) == norm: + return k + return None + + normalized: dict = {} + for node in self.parameters: + if not getattr(node, "trainable", False): + continue + key = _find_key(node.py_name) + if key is None: + continue + + raw_val = suggestion[key] + # re-format any Python defs + # check that key start with "__code" and contains "def" + if isinstance(raw_val, str) and key.startswith("__code"): + raw_val = format_str(raw_val, mode=FileMode()) + + # convert "123" → 123, "[1,2]" → [1,2], etc. + target_type = type(node.data) + if isinstance(raw_val, str) and target_type is not str: + try: + raw_val = target_type(ast.literal_eval(raw_val)) + except Exception: + pass + + # map by the parameter’s name, not the node itself + normalized[node.py_name] = raw_val + + return normalized diff --git a/tests/unit_tests/test_projection.py b/tests/unit_tests/test_projection.py index 794fffcd..84c9a1c4 100644 --- a/tests/unit_tests/test_projection.py +++ b/tests/unit_tests/test_projection.py @@ -1,4 +1,5 @@ -from opto.trace.projections import BlackCodeFormatter, DocstringProjection +from opto.trace.projections import BlackCodeFormatter, DocstringProjection, SuggestionNormalizationProjection +from types import SimpleNamespace def test_black_code_formatter(): code = """ @@ -35,4 +36,53 @@ def example_function(): assert formatted_code == new_code # assert '"""This is a new docstring."""' in formatted_code - # assert 'print("Hello, World!")' in formatted_code \ No newline at end of file + # assert 'print("Hello, World!")' in formatted_code + +def test_suggestion_normalization_projection(): + import re + import pytest + # Prepare a mock parameter list with various py_names, types, and trainable flags + params = [ + # code param: key comes in as "__code:1", should alias to "__code1" and be black‑formatted + SimpleNamespace(py_name="__code1", trainable=True, data=""), + # learning rate param: as float, but suggestion comes as a literal string + SimpleNamespace(py_name="__lr", trainable=True, data=0.0), + # should be skipped because not trainable + SimpleNamespace(py_name="__frozen", trainable=False, data=123), + # some other param, no suggestion provided + SimpleNamespace(py_name="__missing", trainable=True, data=1) + ] + + raw_suggestion = { + "__code:1": "def foo(x):return x*2", # needs black formatting + "__lr": "\"0.01\"", # needs literal‐eval → float + "__frozen": "999", # should be ignored + "unrelated": "[1,2,3]", # not in params + } + + proj = SuggestionNormalizationProjection(params) + normalized = proj.project(raw_suggestion) + + # It should only contain keys for trainable params that were suggested + assert set(normalized.keys()) == {"__code1", "__lr"} + + # 1) __code1 should be black‐formatted: 'def foo' newline indent 'return x * 2' + code_out = normalized["__code1"] + # check that there's exactly one indent (4 spaces) before the return, + # and that black added a trailing newline + assert re.search(r"def foo\(x\):\n {4}return x \* 2\n$", code_out) + + # 2) __lr should have been converted from the string "0.01" to float 0.01 + assert isinstance(normalized["__lr"], float) + assert normalized["__lr"] == pytest.approx(0.01) + + # 3) Non‐trainable or missing params should not appear + assert "__frozen" not in normalized + assert "__missing" not in normalized + + # --- literal‑eval failure should be left unchanged --- + # If ast.literal_eval raises, the original string remains + params_bad = [SimpleNamespace(py_name="__bad", trainable=True, data=100)] + raw_suggestion_bad = {"__bad": "not_a_number"} + normalized_bad = SuggestionNormalizationProjection(params_bad).project(raw_suggestion_bad) + assert normalized_bad["__bad"] == "not_a_number" \ No newline at end of file