diff --git a/examples/saved_scenario_create_update.ipynb b/examples/saved_scenario_create_update.ipynb new file mode 100644 index 0000000..a4135d5 --- /dev/null +++ b/examples/saved_scenario_create_update.ipynb @@ -0,0 +1,191 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Saving Scenarios to MyETM\n", + "\n", + "This notebook demonstrates how to save scenarios to MyETM as SavedScenarios." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from example_helpers import setup_notebook\n", + "setup_notebook()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pyetm.clients.base_client import BaseClient\n", + "from pyetm.models.scenario import Scenario\n", + "from pyetm.models.saved_scenario import SavedScenario\n", + "\n", + "client = BaseClient()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using scenario.save()\n", + "\n", + "The simplest way to save a scenario is using the `.save()` method." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a session scenario\n", + "scenario = Scenario.new(area_code=\"nl\", end_year=2050, title=\"My Test Scenario\")\n", + "print(f\"Created session scenario {scenario.id}\")\n", + "\n", + "# Save it to MyETM - automatically uses scenario.id, scenario.title, and scenario.private\n", + "saved_scenario = scenario.save(description=\"Saved using save()\")\n", + "\n", + "print(f\"Saved as SavedScenario {saved_scenario.id}\")\n", + "print(f\"Title: {saved_scenario.title}\")\n", + "print(f\"Description: {saved_scenario.description}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The underlying scenario's title will be taken for the saved scenario, unless another title is specified. If there are no titles, an error will be thrown." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "scenario2 = Scenario.new(area_code=\"nl\", end_year=2050, title=\"Original Title\")\n", + "saved_scenario2 = scenario2.save(title=\"Custom Title\", private=True)\n", + "\n", + "print(f\"Scenario title: {scenario2.title}\")\n", + "print(f\"SavedScenario title: {saved_scenario2.title}\")\n", + "print(f\"Private: {saved_scenario2.private}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Accessing the underlying scenario\n", + "\n", + "You can update the title, description, and privacy settings:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "underlying_scenario = saved_scenario.get_scenario(client)\n", + "print(f\"Scenario ID: {underlying_scenario.id}\")\n", + "print(f\"Area code: {underlying_scenario.area_code}\")\n", + "print(f\"End year: {underlying_scenario.end_year}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Alternative Methods to create a saved scenario\n", + "\n", + "You can also use `SavedScenario.create()` or `SavedScenario.from_scenario()` directly:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Method 2: Using SavedScenario.from_scenario()\n", + "scenario3 = Scenario.new(area_code=\"nl\", end_year=2050)\n", + "saved3 = SavedScenario.from_scenario(client, scenario3, \"From scenario method\")\n", + "print(f\"Method 2 - SavedScenario ID: {saved3.id}\")\n", + "\n", + "# Method 3: Using SavedScenario.create() with explicit params\n", + "saved4 = SavedScenario.create(\n", + " client,\n", + " params={\n", + " \"scenario_id\": scenario3.id,\n", + " \"title\": \"Created with explicit params\",\n", + " \"description\": \"Using SavedScenario.create()\",\n", + " \"private\": False\n", + " }\n", + ")\n", + "print(f\"Method 3 - SavedScenario ID: {saved4.id}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Update\n", + "\n", + "You can update any of the saved scenario's modifiable fields such as title and description." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "saved_scenario.update(\n", + " client,\n", + " title=\"Updated Title\",\n", + " description=\"Updated description\"\n", + ")\n", + "\n", + "print(f\"Updated title: {saved_scenario.title}\")\n", + "print(f\"Updated description: {saved_scenario.description}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pyetm-Rh4Np-o3-py3.12", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/pyetm/clients/session.py b/src/pyetm/clients/session.py index fe7f7ae..b9e43c3 100644 --- a/src/pyetm/clients/session.py +++ b/src/pyetm/clients/session.py @@ -45,7 +45,7 @@ def content(self) -> bytes: @field_validator("status_code", mode="before") @classmethod - def raise_for_status(cls, value, info: ValidationInfo) -> int: + def raise_for_status(cls, value, info: ValidationInfo) -> None: """Raise appropriate exception for HTTP errors.""" if value == 401: raise PermissionError("Invalid or missing ETM_API_TOKEN") diff --git a/src/pyetm/models/__init__.py b/src/pyetm/models/__init__.py index 6e0abb7..33ee7f3 100644 --- a/src/pyetm/models/__init__.py +++ b/src/pyetm/models/__init__.py @@ -1,6 +1,7 @@ from .custom_curves import CustomCurves from .gqueries import Gqueries from .inputs import Input, Inputs +from .saved_scenario import SavedScenario from .scenario import Scenario from .scenarios import Scenarios from .sortables import Sortable, Sortables diff --git a/src/pyetm/models/saved_scenario.py b/src/pyetm/models/saved_scenario.py new file mode 100644 index 0000000..f25e111 --- /dev/null +++ b/src/pyetm/models/saved_scenario.py @@ -0,0 +1,216 @@ +from __future__ import annotations +from datetime import datetime +from typing import Any, Dict, Optional, TYPE_CHECKING +from pydantic import Field, PrivateAttr +from pyetm.models.base import Base +from pyetm.clients import BaseClient +from pyetm.services.scenario_runners.create_saved_scenario import ( + CreateSavedScenarioRunner, +) +from pyetm.services.scenario_runners.update_saved_scenario import ( + UpdateSavedScenarioRunner, +) +from pyetm.services.scenario_runners.fetch_saved_scenario import ( + FetchSavedScenarioRunner, +) + +if TYPE_CHECKING: + from pyetm.models.scenario import Scenario + + +class SavedScenarioError(Exception): + """Base saved scenario error""" + + +class SavedScenario(Base): + """ + Pydantic model for a MyETM SavedScenario. + + A SavedScenario wraps an ETEngine session scenario and persists it in MyETM. + The response includes both SavedScenario metadata and the full nested Scenario. + """ + + id: int = Field(..., description="Unique saved scenario identifier in MyETM") + scenario_id: int = Field(..., description="Reference to ETEngine scenario") + title: str = Field(..., description="Title of the saved scenario") + description: Optional[str] = None + private: Optional[bool] = False + area_code: Optional[str] = None + end_year: Optional[int] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + scenario: Optional[Dict[str, Any]] = None + + _scenario_session: Optional[Scenario] = PrivateAttr(None) + + @classmethod + def create( + cls, params: Dict[str, Any], client: Optional[BaseClient] = None + ) -> "SavedScenario": + """ + Create a new SavedScenario in MyETM from an existing session scenario. + + Args: + params: Dictionary with required keys (scenario_id, title) and optional keys + (description, private) + client: Optional BaseClient instance + + Returns: + SavedScenario instance + + Raises: + SavedScenarioError if creation fails + """ + if client is None: + client = BaseClient() + result = CreateSavedScenarioRunner.run(client, params) + + if not result.success: + raise SavedScenarioError( + f"Could not create saved scenario: {result.errors}" + ) + + saved_scenario = cls.model_validate(result.data) + for warning in result.errors: + saved_scenario.add_warning("base", warning) + + for field, value in params.items(): + if hasattr(saved_scenario, field) and field not in result.data: + setattr(saved_scenario, field, value) + + return saved_scenario + + @classmethod + def from_scenario( + cls, + scenario: "Scenario", + title: str, + client: Optional[BaseClient] = None, + **kwargs, + ) -> "SavedScenario": + """ + Convenience method to create SavedScenario from a Scenario instance. + + Args: + scenario: Scenario instance to save + title: Title for the saved scenario + client: Optional BaseClient instance + **kwargs: Optional params (description, private) + + Returns: + SavedScenario instance + """ + params = {"scenario_id": scenario.id, "title": title, **kwargs} + return cls.create(params, client=client) + + @classmethod + def load( + cls, saved_scenario_id: int, client: Optional[BaseClient] = None + ) -> "SavedScenario": + """ + Load an existing SavedScenario from MyETM by its ID. + + Args: + saved_scenario_id: The ID of the saved scenario to load + client: Optional BaseClient instance + + Returns: + SavedScenario instance + + Raises: + SavedScenarioError if loading fails + """ + if client is None: + client = BaseClient() + + template = type("T", (), {"id": saved_scenario_id}) + result = FetchSavedScenarioRunner.run(client, template) + + if not result.success: + raise SavedScenarioError( + f"Could not load saved scenario {saved_scenario_id}: {result.errors}" + ) + + saved_scenario = cls.model_validate(result.data) + for warning in result.errors: + saved_scenario.add_warning("base", warning) + + return saved_scenario + + @classmethod + def new( + cls, + scenario_id: int, + title: str, + client: Optional[BaseClient] = None, + **kwargs, + ) -> "SavedScenario": + """ + Create a new SavedScenario from an ETEngine scenario ID. + + Args: + scenario_id: The ETEngine scenario ID to save + title: Title for the saved scenario + client: Optional BaseClient instance + **kwargs: Optional params (description, private) + + Returns: + SavedScenario instance + + Raises: + SavedScenarioError if creation fails + """ + params = {"scenario_id": scenario_id, "title": title, **kwargs} + return cls.create(params, client=client) + + @property + def session(self) -> "Scenario": + """ + Get the current underlying ETEngine Scenario for this SavedScenario. + + Returns: + Scenario: The current ETEngine scenario session (cached after first access) + """ + from pyetm.models.scenario import Scenario + + # Return cached if already loaded + if self._scenario_session is not None: + return self._scenario_session + + # Build from nested data if available (e.g., from SavedScenario.load()) + if self.scenario is not None: + self._scenario_session = Scenario.model_validate(self.scenario) + return self._scenario_session + + # Fetch fresh from ETEngine API + self._scenario_session = Scenario.load(self.scenario_id) + return self._scenario_session + + def update(self, client: Optional[BaseClient] = None, **kwargs) -> None: + """ + Update this SavedScenario + + Args: + client: Optional BaseClient instance + **kwargs: Fields to update (title, description, private, discarded) + """ + if client is None: + client = BaseClient() + result = UpdateSavedScenarioRunner.run(client, self.id, kwargs) + + if not result.success: + raise SavedScenarioError( + f"Could not update saved scenario: {result.errors}" + ) + + for warning in result.errors: + self.add_warning("update", warning) + + if result.data: + for field, value in result.data.items(): + if hasattr(self, field): + setattr(self, field, value) + + for field, value in kwargs.items(): + if hasattr(self, field) and (not result.data or field not in result.data): + setattr(self, field, value) diff --git a/src/pyetm/models/scenario.py b/src/pyetm/models/scenario.py index 659a261..6170c8a 100644 --- a/src/pyetm/models/scenario.py +++ b/src/pyetm/models/scenario.py @@ -81,11 +81,14 @@ def new(cls, area_code: str, end_year: int, **kwargs) -> "Scenario": if not result.success: raise ScenarioError(f"Could not create scenario: {result.errors}") - # parse into a Scenario scenario = cls.model_validate(result.data) for warning in result.errors: scenario.add_warning("base", warning) + for field, value in scenario_data.items(): + if hasattr(scenario, field) and field not in result.data: + setattr(scenario, field, value) + return scenario @classmethod @@ -130,6 +133,36 @@ def to_excel(self, path: PathLike | str, **export_options) -> None: ScenarioExcelService.export_to_excel([self], path, **export_options) + def save( + self, client: Optional[BaseClient] = None, title: Optional[str] = None, **kwargs + ): + """ + Save this scenario to MyETM as a SavedScenario. + + Returns: + SavedScenario instance + """ + from pyetm.models.saved_scenario import SavedScenario + + client = client or BaseClient() + + save_title = title or self.title + if not save_title: + raise ScenarioError( + "Title is required to save scenario. Provide title parameter or set scenario.title" + ) + + params = { + "scenario_id": self.id, + "title": save_title, + **kwargs, + } + + if self.private is not None: + params.setdefault("private", self.private) + + return SavedScenario.create(params, client=client) + def update_metadata(self, **kwargs) -> Dict[str, Any]: """ Update metadata for this scenario. diff --git a/src/pyetm/services/scenario_runners/base_runner.py b/src/pyetm/services/scenario_runners/base_runner.py index 64cd1e2..685cd3c 100644 --- a/src/pyetm/services/scenario_runners/base_runner.py +++ b/src/pyetm/services/scenario_runners/base_runner.py @@ -100,6 +100,42 @@ def _make_batch_requests( return make_batch_requests(client, formatted_requests) + @classmethod + def _validate_required_fields( + cls, data: Dict[str, Any], required_keys: List[str] + ) -> List[str]: + """ + Check for missing required fields. + """ + missing = [key for key in required_keys if key not in data] + if missing: + return [f"Missing required fields: {', '.join(missing)}"] + return [] + + @staticmethod + def _filter_allowed_fields( + data: Dict[str, Any], + allowed_keys: List[str], + context: str, + ) -> tuple[Dict[str, Any], List[str]]: + """ + Filter dictionary to only allowed keys, returning filtered data and warnings. + + Args: + data: Input dictionary to filter + allowed_keys: List of keys to keep + context: Description for warning messages (e.g. "create saved scenario") + + Returns: + (filtered_data, warnings) tuple + """ + filtered = {k: v for k, v in data.items() if k in allowed_keys} + ignored_keys = set(data.keys()) - set(filtered.keys()) + warnings = [ + f"Ignoring invalid field for {context}: {key!r}" for key in ignored_keys + ] + return filtered, warnings + @classmethod def _validate_response_keys( cls, data: Dict[str, Any], required_keys: List[str], fill_missing: bool = False diff --git a/src/pyetm/services/scenario_runners/create_saved_scenario.py b/src/pyetm/services/scenario_runners/create_saved_scenario.py new file mode 100644 index 0000000..9354aaa --- /dev/null +++ b/src/pyetm/services/scenario_runners/create_saved_scenario.py @@ -0,0 +1,71 @@ +from typing import Any, Dict +from pyetm.services.scenario_runners.base_runner import BaseRunner +from ..service_result import ServiceResult +from pyetm.clients.base_client import BaseClient + + +class CreateSavedScenarioRunner(BaseRunner[Dict[str, Any]]): + """ + Runner for creating a SavedScenario in MyETM from a SessionID scenario. + + POST /api/v3/saved_scenarios + + Args: + client: The HTTP client to use + saved_scenario_data: Dictionary with scenario_id, title, description, private + **kwargs: Additional arguments passed to the request + """ + + REQUIRED_KEYS = ["scenario_id", "title"] + OPTIONAL_KEYS = ["description", "private"] + + @staticmethod + def run( + client: BaseClient, saved_scenario_data: Dict[str, Any], **kwargs + ) -> ServiceResult[Dict[str, Any]]: + """ + Create a new SavedScenario in MyETM. + + Example usage: + result = CreateSavedScenarioRunner.run( + client=client, + saved_scenario_data={ + "scenario_id": 123, + "title": "My Saved Scenario", + "description": "Optional description", + "private": False + } + ) + """ + errors = CreateSavedScenarioRunner._validate_required_fields( + saved_scenario_data, CreateSavedScenarioRunner.REQUIRED_KEYS + ) + + if errors: + return ServiceResult.fail(errors) + + all_allowed = ( + CreateSavedScenarioRunner.REQUIRED_KEYS + + CreateSavedScenarioRunner.OPTIONAL_KEYS + ) + filtered_data, warnings = CreateSavedScenarioRunner._filter_allowed_fields( + saved_scenario_data, + all_allowed, + "create saved scenario", + ) + + payload = {"saved_scenario": filtered_data} + + result = CreateSavedScenarioRunner._make_request( + client=client, + method="post", + path="/saved_scenarios", + payload=payload, + **kwargs, + ) + + if result.success and warnings: + combined_errors = list(result.errors) + warnings + return ServiceResult.ok(data=result.data, errors=combined_errors) + + return result diff --git a/src/pyetm/services/scenario_runners/update_saved_scenario.py b/src/pyetm/services/scenario_runners/update_saved_scenario.py new file mode 100644 index 0000000..2c71121 --- /dev/null +++ b/src/pyetm/services/scenario_runners/update_saved_scenario.py @@ -0,0 +1,70 @@ +from typing import Any, Dict +from pyetm.services.scenario_runners.base_runner import BaseRunner +from ..service_result import ServiceResult +from pyetm.clients.base_client import BaseClient + + +class UpdateSavedScenarioRunner(BaseRunner[Dict[str, Any]]): + """ + Runner for updating a SavedScenario in MyETM. + + PUT /api/v3/saved_scenarios/:id + + Args: + client: The HTTP client to use + saved_scenario_id: ID of the SavedScenario to update + update_data: Dictionary with fields to update (title, description, private, discarded) + **kwargs: Additional arguments passed to the request + """ + + ALLOWED_KEYS = ["title", "scenario_id", "private", "discarded"] + + @staticmethod + def run( + client: BaseClient, + saved_scenario_id: int, + update_data: Dict[str, Any], + **kwargs, + ) -> ServiceResult[Dict[str, Any]]: + """ + Update an existing SavedScenario in MyETM. + + Example usage: + result = UpdateSavedScenarioRunner.run( + client=client, + saved_scenario_id=123, + update_data={ + "title": "Updated Title", + "description": "New description" + } + ) + """ + if not update_data: + return ServiceResult.fail(["No fields provided for update"]) + + filtered_data, warnings = UpdateSavedScenarioRunner._filter_allowed_fields( + update_data, + UpdateSavedScenarioRunner.ALLOWED_KEYS, + "update saved scenario", + ) + + if not filtered_data: + return ServiceResult.fail( + ["No valid fields provided for update"] + warnings + ) + + payload = {"saved_scenario": filtered_data} + + result = UpdateSavedScenarioRunner._make_request( + client=client, + method="put", + path=f"/saved_scenarios/{saved_scenario_id}", + payload=payload, + **kwargs, + ) + + if result.success and warnings: + combined_errors = list(result.errors) + warnings + return ServiceResult.ok(data=result.data, errors=combined_errors) + + return result diff --git a/tests/models/conftest.py b/tests/models/conftest.py index 2a57532..39d6d87 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -13,6 +13,7 @@ from pyetm.models.sortables import Sortables from pyetm.models.scenario import Scenario from pyetm.models.output_curves import OutputCurves +from pyetm.models.saved_scenario import SavedScenario # --- Scenario Fixtures --- # @@ -405,3 +406,34 @@ def patch_add_frame(monkeypatch): raising=True, ) return m + + +# --- SavedScenario Fixtures --- # + + +@pytest.fixture +def saved_scenario_data(): + """Saved scenario data for testing.""" + return { + "id": 456, + "scenario_id": 123, + "title": "My Saved Scenario", + "description": "A test description", + "private": False, + "area_code": "nl", + "end_year": 2050, + "created_at": "2025-01-01T12:00:00Z", + "updated_at": "2025-01-02T12:00:00Z", + } + + +@pytest.fixture +def saved_scenario(saved_scenario_data): + """A basic SavedScenario instance for testing.""" + return SavedScenario.model_validate(saved_scenario_data) + + +@pytest.fixture +def mock_client(): + """Mock BaseClient for testing.""" + return Mock() diff --git a/tests/models/test_saved_scenario.py b/tests/models/test_saved_scenario.py new file mode 100644 index 0000000..7454c91 --- /dev/null +++ b/tests/models/test_saved_scenario.py @@ -0,0 +1,427 @@ +from unittest.mock import Mock, patch +import pytest +from datetime import datetime +from pyetm.models.saved_scenario import SavedScenario, SavedScenarioError +from pyetm.models.scenario import Scenario +from pyetm.services.scenario_runners.create_saved_scenario import ( + CreateSavedScenarioRunner, +) +from pyetm.services.scenario_runners.update_saved_scenario import ( + UpdateSavedScenarioRunner, +) + + +# --- Model Validation Tests --- # + + +def test_saved_scenario_session_validation_minimal(): + """Test SavedScenario model validates with minimal required fields.""" + data = { + "id": 1, + "scenario_id": 100, + "title": "Test Scenario", + } + saved_scenario = SavedScenario.model_validate(data) + assert saved_scenario.id == 1 + assert saved_scenario.scenario_id == 100 + assert saved_scenario.title == "Test Scenario" + assert saved_scenario.description is None + assert saved_scenario.private is False + + +def test_saved_scenario_session_validation_full(saved_scenario_data): + """Test SavedScenario model validates with all fields.""" + saved_scenario = SavedScenario.model_validate(saved_scenario_data) + assert saved_scenario.id == 456 + assert saved_scenario.scenario_id == 123 + assert saved_scenario.title == "My Saved Scenario" + assert saved_scenario.description == "A test description" + assert saved_scenario.private is False + assert saved_scenario.area_code == "nl" + assert saved_scenario.end_year == 2050 + + +def test_saved_scenario_session_with_nested_scenario(): + """Test SavedScenario model with nested scenario data.""" + data = { + "id": 1, + "scenario_id": 100, + "title": "Test Scenario", + "scenario": { + "id": 100, + "area_code": "nl", + "end_year": 2050, + }, + } + saved_scenario = SavedScenario.model_validate(data) + assert saved_scenario.scenario is not None + assert saved_scenario.scenario["id"] == 100 + + +# --- Create Tests --- # + + +def test_create_saved_scenario_success(monkeypatch, ok_service_result, mock_client): + """Test successful SavedScenario creation.""" + created_data = { + "id": 789, + "scenario_id": 123, + "title": "New Saved Scenario", + "description": "Created via API", + "private": True, + } + + monkeypatch.setattr( + CreateSavedScenarioRunner, + "run", + lambda client, params: ok_service_result(created_data), + ) + + params = { + "scenario_id": 123, + "title": "New Saved Scenario", + "description": "Created via API", + "private": True, + } + + saved_scenario = SavedScenario.create(params, client=mock_client) + assert saved_scenario.id == 789 + assert saved_scenario.scenario_id == 123 + assert saved_scenario.title == "New Saved Scenario" + assert saved_scenario.private is True + assert len(saved_scenario.warnings) == 0 + + +def test_create_saved_scenario_with_warnings( + monkeypatch, ok_service_result, mock_client +): + """Test SavedScenario creation with warnings.""" + created_data = { + "id": 790, + "scenario_id": 123, + "title": "Saved Scenario", + } + warnings = ["Ignoring invalid field for create saved scenario: 'invalid_field'"] + + monkeypatch.setattr( + CreateSavedScenarioRunner, + "run", + lambda client, params: ok_service_result(created_data, warnings), + ) + + params = { + "scenario_id": 123, + "title": "Saved Scenario", + "invalid_field": "should_be_ignored", + } + + saved_scenario = SavedScenario.create(params, client=mock_client) + assert saved_scenario.id == 790 + base_warnings = saved_scenario.warnings.get_by_field("base") + assert len(base_warnings) == 1 + assert base_warnings[0].message == warnings[0] + + +def test_create_saved_scenario_failure(monkeypatch, fail_service_result, mock_client): + """Test SavedScenario creation failure raises SavedScenarioError.""" + monkeypatch.setattr( + CreateSavedScenarioRunner, + "run", + lambda client, params: fail_service_result(["Missing required field: title"]), + ) + + params = {"scenario_id": 123} # Missing title + + with pytest.raises(SavedScenarioError, match="Could not create saved scenario"): + SavedScenario.create(params, client=mock_client) + + +def test_create_saved_scenario_preserves_params_not_in_response( + monkeypatch, ok_service_result, mock_client +): + """Test that params not returned by API are still set on the instance.""" + created_data = { + "id": 791, + "scenario_id": 123, + "title": "Saved Scenario", + # description not in response + } + + monkeypatch.setattr( + CreateSavedScenarioRunner, + "run", + lambda client, params: ok_service_result(created_data), + ) + + params = { + "scenario_id": 123, + "title": "Saved Scenario", + "description": "Local description", + } + + saved_scenario = SavedScenario.create(params, client=mock_client) + # description should be set from params since it wasn't in response + assert saved_scenario.description == "Local description" + + +# --- from_scenario Tests --- # + + +def test_from_scenario_success(monkeypatch, ok_service_result, mock_client): + """Test creating SavedScenario from a Scenario instance.""" + # Create a mock scenario + scenario = Mock(spec=Scenario) + scenario.id = 999 + + created_data = { + "id": 800, + "scenario_id": 999, + "title": "From Scenario", + "description": "Created from scenario", + "private": False, + } + + monkeypatch.setattr( + CreateSavedScenarioRunner, + "run", + lambda client, params: ok_service_result(created_data), + ) + + saved_scenario = SavedScenario.from_scenario( + scenario, + title="From Scenario", + client=mock_client, + description="Created from scenario", + ) + + assert saved_scenario.id == 800 + assert saved_scenario.scenario_id == 999 + assert saved_scenario.title == "From Scenario" + assert saved_scenario.description == "Created from scenario" + + +def test_from_scenario_with_kwargs(monkeypatch, ok_service_result, mock_client): + """Test from_scenario passes kwargs correctly.""" + scenario = Mock(spec=Scenario) + scenario.id = 1000 + + created_data = { + "id": 801, + "scenario_id": 1000, + "title": "Private Scenario", + "private": True, + } + + captured_params = {} + + def capture_run(client, params): + captured_params.update(params) + return ok_service_result(created_data) + + monkeypatch.setattr(CreateSavedScenarioRunner, "run", capture_run) + + SavedScenario.from_scenario( + scenario, title="Private Scenario", client=mock_client, private=True + ) + + assert captured_params["scenario_id"] == 1000 + assert captured_params["title"] == "Private Scenario" + assert captured_params["private"] is True + + +# --- Update Tests --- # + + +def test_update_saved_scenario_success( + monkeypatch, ok_service_result, saved_scenario, mock_client +): + """Test successful SavedScenario update.""" + updated_data = { + "id": 456, + "title": "Updated Title", + "description": "Updated description", + } + + monkeypatch.setattr( + UpdateSavedScenarioRunner, + "run", + lambda client, id, kwargs: ok_service_result(updated_data), + ) + + saved_scenario.update( + mock_client, title="Updated Title", description="Updated description" + ) + + assert saved_scenario.title == "Updated Title" + assert saved_scenario.description == "Updated description" + assert len(saved_scenario.warnings) == 0 + + +def test_update_saved_scenario_with_warnings( + monkeypatch, ok_service_result, saved_scenario, mock_client +): + """Test SavedScenario update with warnings.""" + updated_data = {"id": 456, "title": "New Title"} + warnings = ["Ignoring invalid field for update saved scenario: 'invalid_field'"] + + monkeypatch.setattr( + UpdateSavedScenarioRunner, + "run", + lambda client, id, kwargs: ok_service_result(updated_data, warnings), + ) + + saved_scenario.update(mock_client, title="New Title", invalid_field="ignored") + + assert saved_scenario.title == "New Title" + update_warnings = saved_scenario.warnings.get_by_field("update") + assert len(update_warnings) == 1 + assert update_warnings[0].message == warnings[0] + + +def test_update_saved_scenario_failure( + monkeypatch, fail_service_result, saved_scenario, mock_client +): + """Test SavedScenario update failure raises SavedScenarioError.""" + monkeypatch.setattr( + UpdateSavedScenarioRunner, + "run", + lambda client, id, kwargs: fail_service_result(["403: Forbidden"]), + ) + + with pytest.raises(SavedScenarioError, match="Could not update saved scenario"): + saved_scenario.update(mock_client, title="New Title") + + +def test_update_saved_scenario_applies_response_data( + monkeypatch, ok_service_result, saved_scenario, mock_client +): + """Test that update applies response data to the instance.""" + original_title = saved_scenario.title + updated_data = { + "id": 456, + "title": "Server Updated Title", + "private": True, + "updated_at": "2025-12-15T10:00:00Z", + } + + monkeypatch.setattr( + UpdateSavedScenarioRunner, + "run", + lambda client, id, kwargs: ok_service_result(updated_data), + ) + + saved_scenario.update(mock_client, title="Requested Title") + + # Should use server response, not the requested value + assert saved_scenario.title == "Server Updated Title" + assert saved_scenario.private is True + + +def test_update_saved_scenario_applies_kwargs_if_not_in_response( + monkeypatch, ok_service_result, saved_scenario, mock_client +): + """Test that kwargs are applied if not in response data.""" + updated_data = {"id": 456} # Response doesn't include title + + monkeypatch.setattr( + UpdateSavedScenarioRunner, + "run", + lambda client, id, kwargs: ok_service_result(updated_data), + ) + + saved_scenario.update(mock_client, title="Local Title") + + # Should use local value since it wasn't in response + assert saved_scenario.title == "Local Title" + + +def test_update_saved_scenario_discard( + monkeypatch, ok_service_result, saved_scenario, mock_client +): + """Test discarding a SavedScenario.""" + updated_data = {"id": 456, "discarded": True} + + monkeypatch.setattr( + UpdateSavedScenarioRunner, + "run", + lambda client, id, kwargs: ok_service_result(updated_data), + ) + + saved_scenario.update(mock_client, discarded=True) + # discarded is not a model field, but should not raise an error + + +def test_update_saved_scenario_change_privacy( + monkeypatch, ok_service_result, saved_scenario, mock_client +): + """Test changing privacy setting.""" + assert saved_scenario.private is False + + updated_data = {"id": 456, "private": True} + + monkeypatch.setattr( + UpdateSavedScenarioRunner, + "run", + lambda client, id, kwargs: ok_service_result(updated_data), + ) + + saved_scenario.update(mock_client, private=True) + assert saved_scenario.private is True + + +# --- session Property Tests --- # + + +def test_session_property_from_nested_data(saved_scenario): + """Test session property creates Scenario from nested data.""" + saved_scenario.scenario = { + "id": 123, + "area_code": "nl", + "end_year": 2050, + } + + scenario = saved_scenario.session + assert scenario.id == 123 + assert scenario.area_code == "nl" + assert scenario.end_year == 2050 + + +def test_session_property_caches_result(saved_scenario): + """Test session property caches the Scenario instance.""" + saved_scenario.scenario = { + "id": 123, + "area_code": "nl", + "end_year": 2050, + } + + scenario1 = saved_scenario.session + scenario2 = saved_scenario.session + + # Should return the same cached instance + assert scenario1 is scenario2 + + +def test_session_property_returns_cached_model(saved_scenario): + """Test session property returns cached model if set.""" + cached_scenario = Mock(spec=Scenario) + cached_scenario.id = 999 + saved_scenario._scenario_session = cached_scenario + + scenario = saved_scenario.session + assert scenario is cached_scenario + assert scenario.id == 999 + + +def test_session_property_fetches_if_no_nested_data(monkeypatch, saved_scenario): + """Test session property fetches if no nested scenario data.""" + saved_scenario.scenario = None + saved_scenario._scenario_session = None + + fetched_scenario = Mock(spec=Scenario) + fetched_scenario.id = 123 + + with patch.object(Scenario, "load", return_value=fetched_scenario) as mock_load: + scenario = saved_scenario.session + + mock_load.assert_called_once_with(123) + assert scenario is fetched_scenario diff --git a/tests/services/scenario_runners/test_create_saved_scenario.py b/tests/services/scenario_runners/test_create_saved_scenario.py new file mode 100644 index 0000000..d1a3efe --- /dev/null +++ b/tests/services/scenario_runners/test_create_saved_scenario.py @@ -0,0 +1,251 @@ +from pyetm.services.scenario_runners.create_saved_scenario import ( + CreateSavedScenarioRunner, +) + + +def test_create_saved_scenario_success_minimal(dummy_client, fake_response): + """Test creating a SavedScenario with only required fields.""" + body = { + "id": 456, + "scenario_id": 123, + "title": "My Saved Scenario", + "description": None, + "private": False, + } + response = fake_response(ok=True, status_code=201, json_data=body) + client = dummy_client(response, method="post") + + saved_scenario_data = {"scenario_id": 123, "title": "My Saved Scenario"} + + result = CreateSavedScenarioRunner.run(client, saved_scenario_data) + assert result.success is True + assert result.data == body + assert result.errors == [] + assert client.calls == [ + ("/saved_scenarios", {"json": {"saved_scenario": saved_scenario_data}}) + ] + + +def test_create_saved_scenario_success_with_optional_fields( + dummy_client, fake_response +): + """Test creating a SavedScenario with all fields.""" + body = { + "id": 457, + "scenario_id": 123, + "title": "My Saved Scenario", + "description": "A detailed description", + "private": True, + } + response = fake_response(ok=True, status_code=201, json_data=body) + client = dummy_client(response, method="post") + + saved_scenario_data = { + "scenario_id": 123, + "title": "My Saved Scenario", + "description": "A detailed description", + "private": True, + } + + result = CreateSavedScenarioRunner.run(client, saved_scenario_data) + assert result.success is True + assert result.data == body + assert result.errors == [] + assert client.calls == [ + ("/saved_scenarios", {"json": {"saved_scenario": saved_scenario_data}}) + ] + + +def test_create_saved_scenario_missing_required_field_scenario_id( + dummy_client, fake_response +): + """Test that missing scenario_id returns an error.""" + client = dummy_client({}, method="post") + + saved_scenario_data = {"title": "My Saved Scenario"} # Missing scenario_id + + result = CreateSavedScenarioRunner.run(client, saved_scenario_data) + assert result.success is False + assert result.data is None + assert "Missing required fields: scenario_id" in result.errors[0] + assert len(client.calls) == 0 # Should not make API call + + +def test_create_saved_scenario_missing_required_field_title( + dummy_client, fake_response +): + """Test that missing title returns an error.""" + client = dummy_client({}, method="post") + + saved_scenario_data = {"scenario_id": 123} # Missing title + + result = CreateSavedScenarioRunner.run(client, saved_scenario_data) + assert result.success is False + assert result.data is None + assert "Missing required fields: title" in result.errors[0] + assert len(client.calls) == 0 # Should not make API call + + +def test_create_saved_scenario_missing_both_required_fields( + dummy_client, fake_response +): + """Test that missing both required fields returns an error.""" + client = dummy_client({}, method="post") + + saved_scenario_data = {"private": True} # Missing both required fields + + result = CreateSavedScenarioRunner.run(client, saved_scenario_data) + assert result.success is False + assert result.data is None + error_msg = result.errors[0] + assert "Missing required fields:" in error_msg + assert "scenario_id" in error_msg + assert "title" in error_msg + assert len(client.calls) == 0 # Should not make API call + + +def test_create_saved_scenario_filters_invalid_fields(dummy_client, fake_response): + """Test that invalid fields are filtered and warnings are returned.""" + body = { + "id": 458, + "scenario_id": 123, + "title": "My Saved Scenario", + } + response = fake_response(ok=True, status_code=201, json_data=body) + client = dummy_client(response, method="post") + + saved_scenario_data = { + "scenario_id": 123, + "title": "My Saved Scenario", + "description": "Valid description", # Valid + "id": 999, # Invalid - should be filtered + "created_at": "2019-01-01", # Invalid - should be filtered + "invalid_field": "value", # Invalid - should be filtered + } + + result = CreateSavedScenarioRunner.run(client, saved_scenario_data) + assert result.success is True + assert result.data == body + + # Should have warnings for filtered fields + expected_warnings = [ + "Ignoring invalid field for create saved scenario: 'id'", + "Ignoring invalid field for create saved scenario: 'created_at'", + "Ignoring invalid field for create saved scenario: 'invalid_field'", + ] + for warning in expected_warnings: + assert warning in result.errors + + # Should only send valid fields + expected_payload = { + "saved_scenario": { + "scenario_id": 123, + "title": "My Saved Scenario", + "description": "Valid description", + } + } + assert client.calls == [("/saved_scenarios", {"json": expected_payload})] + + +def test_create_saved_scenario_http_failure_422(dummy_client, fake_response): + """Test handling of 422 validation error.""" + response = fake_response(ok=False, status_code=422, text="Validation Error") + client = dummy_client(response, method="post") + + saved_scenario_data = {"scenario_id": 123, "title": "My Saved Scenario"} + + result = CreateSavedScenarioRunner.run(client, saved_scenario_data) + assert result.success is False + assert result.data is None + assert result.errors == ["422: Validation Error"] + + +def test_create_saved_scenario_http_failure_401(dummy_client, fake_response): + """Test handling of 401 unauthorized error.""" + response = fake_response(ok=False, status_code=401, text="Unauthorized") + client = dummy_client(response, method="post") + + saved_scenario_data = {"scenario_id": 123, "title": "My Saved Scenario"} + + result = CreateSavedScenarioRunner.run(client, saved_scenario_data) + assert result.success is False + assert result.data is None + assert result.errors == ["401: Unauthorized"] + + +def test_create_saved_scenario_http_failure_404(dummy_client, fake_response): + """Test handling of 404 not found error (scenario doesn't exist).""" + response = fake_response(ok=False, status_code=404, text="Scenario not found") + client = dummy_client(response, method="post") + + saved_scenario_data = {"scenario_id": 99999, "title": "My Saved Scenario"} + + result = CreateSavedScenarioRunner.run(client, saved_scenario_data) + assert result.success is False + assert result.data is None + assert result.errors == ["404: Scenario not found"] + + +def test_create_saved_scenario_connection_error(dummy_client): + """Test handling of connection errors.""" + client = dummy_client(ConnectionError("Connection failed"), method="post") + + saved_scenario_data = {"scenario_id": 123, "title": "My Saved Scenario"} + + result = CreateSavedScenarioRunner.run(client, saved_scenario_data) + assert result.success is False + assert result.data is None + assert any("Connection failed" in err for err in result.errors) + + +def test_create_saved_scenario_with_kwargs(dummy_client, fake_response): + """Test that kwargs are passed through to the request.""" + body = {"id": 459, "scenario_id": 123, "title": "My Saved Scenario"} + response = fake_response(ok=True, status_code=201, json_data=body) + client = dummy_client(response, method="post") + + saved_scenario_data = {"scenario_id": 123, "title": "My Saved Scenario"} + + result = CreateSavedScenarioRunner.run(client, saved_scenario_data, timeout=30) + assert result.success is True + assert result.data == body + assert result.errors == [] + # Verify basic structure + assert len(client.calls) == 1 + assert client.calls[0][0] == "/saved_scenarios" + assert client.calls[0][1]["json"] == {"saved_scenario": saved_scenario_data} + + +def test_create_saved_scenario_payload_structure(dummy_client, fake_response): + """Test that the payload is correctly structured for the API.""" + body = {"id": 460, "scenario_id": 123, "title": "Test Scenario"} + response = fake_response(ok=True, status_code=201, json_data=body) + client = dummy_client(response, method="post") + + saved_scenario_data = { + "scenario_id": 123, + "title": "Test Scenario", + "description": "Test description", + "private": True, + } + + CreateSavedScenarioRunner.run(client, saved_scenario_data) + + # Verify the exact payload structure + expected_call = ( + "/saved_scenarios", + {"json": {"saved_scenario": saved_scenario_data}}, + ) + assert client.calls == [expected_call] + + +def test_create_saved_scenario_empty_data(dummy_client, fake_response): + """Test that empty data returns an error.""" + client = dummy_client({}, method="post") + + saved_scenario_data = {} + + result = CreateSavedScenarioRunner.run(client, saved_scenario_data) + assert result.success is False + assert result.data is None + assert len(client.calls) == 0 # Should not make API call diff --git a/tests/services/scenario_runners/test_update_saved_scenario.py b/tests/services/scenario_runners/test_update_saved_scenario.py new file mode 100644 index 0000000..318142e --- /dev/null +++ b/tests/services/scenario_runners/test_update_saved_scenario.py @@ -0,0 +1,318 @@ +from pyetm.services.scenario_runners.update_saved_scenario import ( + UpdateSavedScenarioRunner, +) + + +def test_update_saved_scenario_success_single_field(dummy_client, fake_response): + """Test updating a single field of a SavedScenario.""" + body = { + "id": 456, + "scenario_id": 123, + "title": "Updated Title", + "description": None, + "private": False, + } + response = fake_response(ok=True, status_code=200, json_data=body) + client = dummy_client(response, method="put") + + update_data = {"title": "Updated Title"} + + result = UpdateSavedScenarioRunner.run( + client, saved_scenario_id=456, update_data=update_data + ) + assert result.success is True + assert result.data == body + assert result.errors == [] + assert client.calls == [ + ("/saved_scenarios/456", {"json": {"saved_scenario": update_data}}) + ] + + +def test_update_saved_scenario_success_multiple_fields(dummy_client, fake_response): + """Test updating multiple fields of a SavedScenario.""" + body = { + "id": 456, + "scenario_id": 123, + "title": "New Title", + "private": True, + } + response = fake_response(ok=True, status_code=200, json_data=body) + client = dummy_client(response, method="put") + + update_data = { + "title": "New Title", + "private": True, + } + + result = UpdateSavedScenarioRunner.run( + client, saved_scenario_id=456, update_data=update_data + ) + assert result.success is True + assert result.data == body + assert result.errors == [] + assert client.calls == [ + ("/saved_scenarios/456", {"json": {"saved_scenario": update_data}}) + ] + + +def test_update_saved_scenario_success_all_allowed_fields(dummy_client, fake_response): + """Test updating all allowed fields of a SavedScenario.""" + body = { + "id": 456, + "scenario_id": 123, + "title": "Full Update", + "private": True, + "discarded": False, + } + response = fake_response(ok=True, status_code=200, json_data=body) + client = dummy_client(response, method="put") + + update_data = { + "title": "Full Update", + "private": True, + "discarded": False, + } + + result = UpdateSavedScenarioRunner.run( + client, saved_scenario_id=456, update_data=update_data + ) + assert result.success is True + assert result.data == body + assert result.errors == [] + assert client.calls == [ + ("/saved_scenarios/456", {"json": {"saved_scenario": update_data}}) + ] + + +def test_update_saved_scenario_empty_update_data(dummy_client, fake_response): + """Test that empty update data returns an error.""" + client = dummy_client({}, method="put") + + update_data = {} + + result = UpdateSavedScenarioRunner.run( + client, saved_scenario_id=456, update_data=update_data + ) + assert result.success is False + assert result.data is None + assert "No fields provided for update" in result.errors + assert len(client.calls) == 0 # Should not make API call + + +def test_update_saved_scenario_filters_invalid_fields(dummy_client, fake_response): + """Test that invalid fields are filtered and warnings are returned.""" + body = { + "id": 456, + "title": "Updated Title", + } + response = fake_response(ok=True, status_code=200, json_data=body) + client = dummy_client(response, method="put") + + update_data = { + "title": "Updated Title", # Valid + "id": 999, # Invalid - should be filtered + "scenario_id": 123, # Valid + "created_at": "2019-01-01", # Invalid - should be filtered + "invalid_field": "value", # Invalid - should be filtered + } + + result = UpdateSavedScenarioRunner.run( + client, saved_scenario_id=456, update_data=update_data + ) + assert result.success is True + assert result.data == body + + # Should have warnings for filtered fields + expected_warnings = [ + "Ignoring invalid field for update saved scenario: 'id'", + "Ignoring invalid field for update saved scenario: 'created_at'", + "Ignoring invalid field for update saved scenario: 'invalid_field'", + ] + for warning in expected_warnings: + assert warning in result.errors + + # Should only send valid fields + expected_payload = { + "saved_scenario": {"title": "Updated Title", "scenario_id": 123} + } + assert client.calls == [("/saved_scenarios/456", {"json": expected_payload})] + + +def test_update_saved_scenario_only_invalid_fields(dummy_client, fake_response): + """Test that only invalid fields returns an error.""" + client = dummy_client({}, method="put") + + update_data = { + "id": 999, # Invalid + "invalid_field": "value", # Invalid + } + + result = UpdateSavedScenarioRunner.run( + client, saved_scenario_id=456, update_data=update_data + ) + assert result.success is False + assert result.data is None + assert "No valid fields provided for update" in result.errors + assert len(client.calls) == 0 # Should not make API call + + +def test_update_saved_scenario_http_failure_422(dummy_client, fake_response): + """Test handling of 422 validation error.""" + response = fake_response(ok=False, status_code=422, text="Validation Error") + client = dummy_client(response, method="put") + + update_data = {"title": ""} # Invalid empty title + + result = UpdateSavedScenarioRunner.run( + client, saved_scenario_id=456, update_data=update_data + ) + assert result.success is False + assert result.data is None + assert result.errors == ["422: Validation Error"] + + +def test_update_saved_scenario_http_failure_401(dummy_client, fake_response): + """Test handling of 401 unauthorized error.""" + response = fake_response(ok=False, status_code=401, text="Unauthorized") + client = dummy_client(response, method="put") + + update_data = {"title": "New Title"} + + result = UpdateSavedScenarioRunner.run( + client, saved_scenario_id=456, update_data=update_data + ) + assert result.success is False + assert result.data is None + assert result.errors == ["401: Unauthorized"] + + +def test_update_saved_scenario_http_failure_404(dummy_client, fake_response): + """Test handling of 404 not found error.""" + response = fake_response(ok=False, status_code=404, text="SavedScenario not found") + client = dummy_client(response, method="put") + + update_data = {"title": "New Title"} + + result = UpdateSavedScenarioRunner.run( + client, saved_scenario_id=99999, update_data=update_data + ) + assert result.success is False + assert result.data is None + assert result.errors == ["404: SavedScenario not found"] + + +def test_update_saved_scenario_http_failure_403(dummy_client, fake_response): + """Test handling of 403 forbidden error (not owner).""" + response = fake_response( + ok=False, status_code=403, text="Not authorized to update this scenario" + ) + client = dummy_client(response, method="put") + + update_data = {"title": "New Title"} + + result = UpdateSavedScenarioRunner.run( + client, saved_scenario_id=456, update_data=update_data + ) + assert result.success is False + assert result.data is None + assert result.errors == ["403: Not authorized to update this scenario"] + + +def test_update_saved_scenario_connection_error(dummy_client): + """Test handling of connection errors.""" + client = dummy_client(ConnectionError("Connection failed"), method="put") + + update_data = {"title": "New Title"} + + result = UpdateSavedScenarioRunner.run( + client, saved_scenario_id=456, update_data=update_data + ) + assert result.success is False + assert result.data is None + assert any("Connection failed" in err for err in result.errors) + + +def test_update_saved_scenario_with_kwargs(dummy_client, fake_response): + """Test that kwargs are passed through to the request.""" + body = {"id": 456, "title": "Updated Title"} + response = fake_response(ok=True, status_code=200, json_data=body) + client = dummy_client(response, method="put") + + update_data = {"title": "Updated Title"} + + result = UpdateSavedScenarioRunner.run( + client, saved_scenario_id=456, update_data=update_data, timeout=30 + ) + assert result.success is True + assert result.data == body + assert result.errors == [] + # Verify basic structure + assert len(client.calls) == 1 + assert client.calls[0][0] == "/saved_scenarios/456" + assert client.calls[0][1]["json"] == {"saved_scenario": update_data} + + +def test_update_saved_scenario_payload_structure(dummy_client, fake_response): + """Test that the payload is correctly structured for the API.""" + body = {"id": 456, "title": "Test Title"} + response = fake_response(ok=True, status_code=200, json_data=body) + client = dummy_client(response, method="put") + + update_data = { + "title": "Test Title", + } + + UpdateSavedScenarioRunner.run( + client, saved_scenario_id=456, update_data=update_data + ) + + # Verify the exact payload structure + expected_call = ( + "/saved_scenarios/456", + {"json": {"saved_scenario": update_data}}, + ) + assert client.calls == [expected_call] + + +def test_update_saved_scenario_discard(dummy_client, fake_response): + """Test discarding a SavedScenario.""" + body = { + "id": 456, + "discarded": True, + } + response = fake_response(ok=True, status_code=200, json_data=body) + client = dummy_client(response, method="put") + + update_data = {"discarded": True} + + result = UpdateSavedScenarioRunner.run( + client, saved_scenario_id=456, update_data=update_data + ) + assert result.success is True + assert result.data == body + assert result.errors == [] + assert client.calls == [ + ("/saved_scenarios/456", {"json": {"saved_scenario": update_data}}) + ] + + +def test_update_saved_scenario_change_privacy(dummy_client, fake_response): + """Test changing privacy setting of a SavedScenario.""" + body = { + "id": 456, + "private": True, + } + response = fake_response(ok=True, status_code=200, json_data=body) + client = dummy_client(response, method="put") + + update_data = {"private": True} + + result = UpdateSavedScenarioRunner.run( + client, saved_scenario_id=456, update_data=update_data + ) + assert result.success is True + assert result.data == body + assert result.errors == [] + assert client.calls == [ + ("/saved_scenarios/456", {"json": {"saved_scenario": update_data}}) + ]