32 changes: 31 additions & 1 deletion nemoguardrails/actions/llm/utils.py
@@ -136,6 +136,35 @@ def _infer_model_name(llm: BaseLanguageModel):
    return "unknown"


def _filter_params_for_openai_reasoning_models(llm: BaseLanguageModel, llm_params: Optional[dict]) -> Optional[dict]:
    """Filter out unsupported parameters for OpenAI reasoning models.

    OpenAI reasoning models (o1, o3, and gpt-5, excluding gpt-5-chat) only support
    temperature=1. When .bind() passes any other temperature value, the API
    returns an error. This function removes the temperature parameter for these
    models so that the API default applies.

    See: https://github.com/langchain-ai/langchain/blob/master/libs/partners/openai/langchain_openai/chat_models/base.py
    """
    if not llm_params or "temperature" not in llm_params:
        return llm_params

    model_name = _infer_model_name(llm).lower()

    is_openai_reasoning_model = (
        model_name.startswith("o1")
        or model_name.startswith("o3")
        or (model_name.startswith("gpt-5") and "chat" not in model_name)
    )
Comment on lines +154 to +158 (Contributor):
style: verify this detection logic handles future OpenAI reasoning models (e.g., gpt-6 or o4 series) that may have similar temperature restrictions
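One concise way to resolve this (a sketch, not part of the PR; the "o4" prefix is hypothetical): keep the restricted prefixes in a module-level tuple so future model families can be added in one place, preserving the gpt-5-chat exclusion:

_OPENAI_REASONING_MODEL_PREFIXES = ("o1", "o3", "gpt-5")  # extend here if e.g. "o4" ships with the same restriction


def _is_openai_reasoning_model(model_name: str) -> bool:
    name = model_name.lower()
    # gpt-5-chat behaves like a standard chat model and keeps temperature support.
    if name.startswith("gpt-5") and "chat" in name:
        return False
    # str.startswith accepts a tuple of prefixes.
    return name.startswith(_OPENAI_REASONING_MODEL_PREFIXES)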


    if is_openai_reasoning_model:
        filtered = llm_params.copy()
        filtered.pop("temperature", None)
        return filtered

    return llm_params


async def llm_call(
    llm: Optional[BaseLanguageModel],
    prompt: Union[str, List[dict]],
@@ -164,8 +193,9 @@ async def llm_call(
    _setup_llm_call_info(llm, model_name, model_provider)
    all_callbacks = _prepare_callbacks(custom_callback_handlers)

+   filtered_params = _filter_params_for_openai_reasoning_models(llm, llm_params)
    generation_llm: Union[BaseLanguageModel, Runnable] = (
-       llm.bind(stop=stop, **llm_params) if llm_params and llm is not None else llm
+       llm.bind(stop=stop, **filtered_params) if filtered_params and llm is not None else llm
    )

    if isinstance(prompt, str):
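To see the new filtering in action, a minimal sketch (not part of the PR): a stand-in object with a model_name attribute is enough here, since the tests below suggest _infer_model_name picks the name up from that attribute.

from nemoguardrails.actions.llm.utils import _filter_params_for_openai_reasoning_models


class _FakeLLM:
    def __init__(self, model_name):
        self.model_name = model_name


params = {"temperature": 0.2, "max_tokens": 100}

# Reasoning model: temperature is dropped so the API default (1) applies.
print(_filter_params_for_openai_reasoning_models(_FakeLLM("o1-preview"), params))
# {'max_tokens': 100}

# Non-reasoning model: params pass through unchanged.
print(_filter_params_for_openai_reasoning_models(_FakeLLM("gpt-4o"), params))
# {'temperature': 0.2, 'max_tokens': 100}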
51 changes: 51 additions & 0 deletions tests/test_actions_llm_utils.py
@@ -21,6 +21,7 @@
    _extract_reasoning_from_content_blocks,
    _extract_tool_calls_from_attribute,
    _extract_tool_calls_from_content_blocks,
    _filter_params_for_openai_reasoning_models,
    _infer_provider_from_module,
    _store_reasoning_traces,
    _store_tool_calls,
@@ -532,3 +533,53 @@ def test_store_tool_calls_with_real_aimessage_multiple_tool_calls():
    assert len(tool_calls) == 2
    assert tool_calls[0]["name"] == "foo"
    assert tool_calls[1]["name"] == "bar"


def _create_llm(model_name):
    try:
        from langchain_openai import ChatOpenAI

        return ChatOpenAI(model=model_name)
    except Exception:
        # Fall back to a minimal stub when langchain_openai is unavailable;
        # the filter only needs an object whose model name can be inferred.
        class _MockLLM:
            def __init__(self, model_name):
                self.model_name = model_name

        return _MockLLM(model_name)


class TestFilterParamsForOpenAIReasoningModels:
    @pytest.mark.parametrize(
        "model,params,expected",
        [
            ("gpt-4", {"temperature": 0.5, "max_tokens": 100}, {"temperature": 0.5, "max_tokens": 100}),
            ("gpt-4o", {"temperature": 0.7}, {"temperature": 0.7}),
            ("gpt-4o-mini", {"temperature": 0.3, "max_tokens": 50}, {"temperature": 0.3, "max_tokens": 50}),
            ("gpt-5-chat", {"temperature": 0.5}, {"temperature": 0.5}),
            ("o1-preview", {"temperature": 0.001, "max_tokens": 100}, {"max_tokens": 100}),
            ("o1-mini", {"temperature": 0.5}, {}),
            ("o3", {"temperature": 0.001, "max_tokens": 200}, {"max_tokens": 200}),
            ("o3-mini", {"temperature": 0.1}, {}),
            ("gpt-5", {"temperature": 0.001}, {}),
            ("gpt-5-mini", {"temperature": 0.5, "max_tokens": 100}, {"max_tokens": 100}),
            ("gpt-5-nano", {"temperature": 0.001}, {}),
            ("o1-preview", {"max_tokens": 100}, {"max_tokens": 100}),
            ("o1-preview", {}, {}),
        ],
    )
    def test_filter_params(self, model, params, expected):
        llm = _create_llm(model)
        result = _filter_params_for_openai_reasoning_models(llm, params)
        assert result == expected

    def test_returns_none_when_llm_params_is_none(self):
        llm = _create_llm("gpt-4")
        result = _filter_params_for_openai_reasoning_models(llm, None)
        assert result is None

    def test_does_not_modify_original_params(self):
        llm = _create_llm("o1-preview")
        params = {"temperature": 0.5, "max_tokens": 100}
        _filter_params_for_openai_reasoning_models(llm, params)
        assert params == {"temperature": 0.5, "max_tokens": 100}