diff --git a/docs/tls-security-profile.md b/docs/tls-security-profile.md new file mode 100644 index 000000000..578621298 --- /dev/null +++ b/docs/tls-security-profile.md @@ -0,0 +1,46 @@ +# TLS Security Profile Configuration + +This document describes how to configure and test the TLS security profile for outgoing connections to the Llama Stack provider. + +## Overview + +The TLS security profile allows you to enforce specific TLS security settings for connections from Lightspeed Stack to the Llama Stack server. This includes: + +- **Profile Type**: Predefined security profiles (OldType, IntermediateType, ModernType, Custom) +- **Minimum TLS Version**: Enforce minimum TLS protocol version (TLS 1.0 - 1.3) +- **Cipher Suites**: Specify allowed cipher suites +- **CA Certificate**: Custom CA certificate for server verification +- **Skip Verification**: Option to skip TLS verification (testing only) + +## Configuration + +Add the `tls_security_profile` section under `llama_stack` in your configuration file: + +```yaml +llama_stack: + url: https://llama-stack-server:8321 + use_as_library_client: false + tls_security_profile: + type: ModernType + minTLSVersion: VersionTLS13 + caCertPath: /path/to/ca-certificate.crt +``` + +### Configuration Options + +| Field | Type | Description | +|-------|------|-------------| +| `type` | string | Profile type: `OldType`, `IntermediateType`, `ModernType`, or `Custom` | +| `minTLSVersion` | string | Minimum TLS version: `VersionTLS10`, `VersionTLS11`, `VersionTLS12`, `VersionTLS13` | +| `ciphers` | list[string] | List of allowed cipher suites (optional, uses profile defaults) | +| `caCertPath` | string | Path to CA certificate file for server verification | +| `skipTLSVerification` | boolean | Skip TLS certificate verification (default: false, **testing only**) | + +### Profile Types + +| Profile | Min TLS Version | Description | +|---------|-----------------|-------------| +| `OldType` | TLS 1.0 | Legacy compatibility, wide cipher support | +| `IntermediateType` | TLS 1.2 | Balanced security and compatibility | +| `ModernType` | TLS 1.3 | Maximum security, TLS 1.3 only | +| `Custom` | Configurable | User-defined settings | diff --git a/src/client.py b/src/client.py index cb9d3ad32..deb660c2b 100644 --- a/src/client.py +++ b/src/client.py @@ -1,15 +1,17 @@ """Llama Stack client retrieval class.""" import logging - +import ssl from typing import Optional +import httpx from llama_stack import ( AsyncLlamaStackAsLibraryClient, # type: ignore ) from llama_stack_client import AsyncLlamaStackClient # type: ignore -from models.config import LlamaStackConfiguration +from models.config import LlamaStackConfiguration, TLSSecurityProfile from utils.types import Singleton +from utils import tls logger = logging.getLogger(__name__) @@ -20,6 +22,76 @@ class AsyncLlamaStackClientHolder(metaclass=Singleton): _lsc: Optional[AsyncLlamaStackClient] = None + def _construct_httpx_client( + self, tls_security_profile: Optional[TLSSecurityProfile] + ) -> Optional[httpx.AsyncClient]: + """Construct HTTPX client with TLS security profile configuration. + + Args: + tls_security_profile: TLS security profile configuration. + + Returns: + Configured httpx.AsyncClient if TLS profile is set, None otherwise. + """ + # if security profile is not set, return None to use default httpx client + if tls_security_profile is None or tls_security_profile.profile_type is None: + logger.info("No TLS security profile configured, using default settings") + return None + + logger.info("TLS security profile: %s", tls_security_profile.profile_type) + + # get the TLS profile type + profile_type = tls.TLSProfiles(tls_security_profile.profile_type) + + # retrieve ciphers - custom list or profile-based + ciphers = tls.ciphers_as_string(tls_security_profile.ciphers, profile_type) + logger.info("TLS ciphers: %s", ciphers) + + # retrieve minimum TLS version + min_tls_ver = tls.min_tls_version( + tls_security_profile.min_tls_version, profile_type + ) + logger.info("Minimum TLS version: %s", min_tls_ver) + + ssl_version = tls.ssl_tls_version(min_tls_ver) + logger.info("SSL version: %s", ssl_version) + + # check if TLS verification should be skipped (for testing only) + if tls_security_profile.skip_tls_verification: + logger.warning( + "TLS verification is disabled. This is insecure and should " + "only be used for testing purposes." + ) + return httpx.AsyncClient(verify=False) + + # create SSL context with the configured settings + context = ssl.create_default_context() + + # load CA certificate if specified + if tls_security_profile.ca_cert_path is not None: + logger.info("Loading CA certificate from: %s", tls_security_profile.ca_cert_path) + context.load_verify_locations(cafile=str(tls_security_profile.ca_cert_path)) + + if ssl_version is not None: + context.minimum_version = ssl_version + + if ciphers is not None: + # Note: TLS 1.3 ciphers cannot be set via set_ciphers() - they are + # automatically negotiated when TLS 1.3 is used. The set_ciphers() + # method only affects TLS 1.2 and below cipher selection. + try: + context.set_ciphers(ciphers) + except ssl.SSLError as e: + logger.warning( + "Could not set ciphers '%s': %s. " + "TLS 1.3 ciphers are automatically negotiated.", + ciphers, + e, + ) + + logger.info("Creating httpx.AsyncClient with TLS security profile") + return httpx.AsyncClient(verify=context) + async def load(self, llama_stack_config: LlamaStackConfiguration) -> None: """Retrieve Async Llama stack client according to configuration.""" if llama_stack_config.use_as_library_client is True: @@ -37,6 +109,12 @@ async def load(self, llama_stack_config: LlamaStackConfiguration) -> None: raise ValueError(msg) else: logger.info("Using Llama stack running as a service") + + # construct httpx client with TLS security profile if configured + http_client = self._construct_httpx_client( + llama_stack_config.tls_security_profile + ) + self._lsc = AsyncLlamaStackClient( base_url=llama_stack_config.url, api_key=( @@ -44,6 +122,7 @@ async def load(self, llama_stack_config: LlamaStackConfiguration) -> None: if llama_stack_config.api_key is not None else None ), + http_client=http_client, ) def get_client(self) -> AsyncLlamaStackClient: diff --git a/src/constants.py b/src/constants.py index 82ea14151..cea9625bf 100644 --- a/src/constants.py +++ b/src/constants.py @@ -152,3 +152,7 @@ # quota limiters constants USER_QUOTA_LIMITER = "user_limiter" CLUSTER_QUOTA_LIMITER = "cluster_limiter" + +# TLS security profile constants +DEFAULT_SSL_VERSION = "TLSv1_2" +DEFAULT_SSL_CIPHERS = "DEFAULT" diff --git a/src/models/config.py b/src/models/config.py index d8703ca67..be31afcad 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -28,6 +28,7 @@ import constants from utils import checks +from utils import tls class ConfigurationBase(BaseModel): @@ -76,6 +77,98 @@ def check_tls_configuration(self) -> Self: return self +class TLSSecurityProfile(ConfigurationBase): + """TLS security profile for outgoing connections. + + This configuration allows customizing the TLS security settings for + outgoing connections to LM providers. Users can specify: + - A predefined profile type (OldType, IntermediateType, ModernType, Custom) + - Minimum TLS version (VersionTLS10, VersionTLS11, VersionTLS12, VersionTLS13) + - List of allowed cipher suites + - CA certificate path for custom certificate authorities + - Option to skip TLS verification (for testing only) + + Example configuration: + tls_security_profile: + type: Custom + minTLSVersion: VersionTLS13 + ciphers: + - TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 + - TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 + caCertPath: /path/to/ca.crt + """ + + profile_type: Optional[str] = Field( + None, + alias="type", + title="Profile type", + description="TLS profile type: OldType, IntermediateType, ModernType, or Custom", + ) + min_tls_version: Optional[str] = Field( + None, + alias="minTLSVersion", + title="Minimum TLS version", + description="Minimum TLS version: VersionTLS10, VersionTLS11, VersionTLS12, VersionTLS13", + ) + ciphers: Optional[list[str]] = Field( + None, + title="Ciphers", + description="List of allowed cipher suites", + ) + ca_cert_path: Optional[FilePath] = Field( + None, + alias="caCertPath", + title="CA certificate path", + description="Path to CA certificate file for verifying server certificates", + ) + skip_tls_verification: bool = Field( + False, + alias="skipTLSVerification", + title="Skip TLS verification", + description="Skip TLS certificate verification (for testing only, not recommended for production)", + ) + + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + @model_validator(mode="after") + def check_tls_security_profile(self) -> Self: + """Validate TLS security profile configuration.""" + # check the TLS profile type + if self.profile_type is not None: + try: + tls.TLSProfiles(self.profile_type) + except ValueError as e: + valid_profiles = [p.value for p in tls.TLSProfiles] + raise ValueError( + f"Invalid TLS profile type '{self.profile_type}'. " + f"Valid types: {valid_profiles}" + ) from e + + # check the TLS protocol version + if self.min_tls_version is not None: + try: + tls.TLSProtocolVersion(self.min_tls_version) + except ValueError as e: + valid_versions = [v.value for v in tls.TLSProtocolVersion] + raise ValueError( + f"Invalid minimal TLS version '{self.min_tls_version}'. " + f"Valid versions: {valid_versions}" + ) from e + + # check ciphers - validate against profile if not Custom + if self.ciphers is not None and self.profile_type is not None: + if self.profile_type != tls.TLSProfiles.CUSTOM_TYPE: + profile = tls.TLSProfiles(self.profile_type) + supported_ciphers = tls.TLS_CIPHERS.get(profile, []) + for cipher in self.ciphers: + if cipher not in supported_ciphers: + raise ValueError( + f"Unsupported cipher '{cipher}' for profile '{self.profile_type}'" + ) + + return self + + class CORSConfiguration(ConfigurationBase): """CORS configuration. @@ -431,6 +524,12 @@ class LlamaStackConfiguration(ConfigurationBase): description="Path to configuration file used when Llama Stack is run in library mode", ) + tls_security_profile: Optional[TLSSecurityProfile] = Field( + None, + title="TLS security profile", + description="TLS security profile for outgoing connections to Llama Stack", + ) + @model_validator(mode="after") def check_llama_stack_model(self) -> Self: """ diff --git a/src/utils/tls.py b/src/utils/tls.py new file mode 100644 index 000000000..24255e137 --- /dev/null +++ b/src/utils/tls.py @@ -0,0 +1,149 @@ +"""TLS-related data structures and constants. + +For further information please look at TLS security profiles source: +https://github.com/openshift/api/blob/master/config/v1/types_tlssecurityprofile.go +""" + +import logging +import ssl +from enum import StrEnum +from typing import Optional + +logger = logging.getLogger(__name__) + + +class TLSProfiles(StrEnum): + """TLS profile names.""" + + OLD_TYPE = "OldType" + INTERMEDIATE_TYPE = "IntermediateType" + MODERN_TYPE = "ModernType" + CUSTOM_TYPE = "Custom" + + +class TLSProtocolVersion(StrEnum): + """TLS protocol versions.""" + + # version 1.0 of the TLS security protocol. + VERSION_TLS_10 = "VersionTLS10" + # version 1.1 of the TLS security protocol. + VERSION_TLS_11 = "VersionTLS11" + # version 1.2 of the TLS security protocol. + VERSION_TLS_12 = "VersionTLS12" + # version 1.3 of the TLS security protocol. + VERSION_TLS_13 = "VersionTLS13" + + +# Minimal TLS versions required for each TLS profile +MIN_TLS_VERSIONS: dict[TLSProfiles, TLSProtocolVersion] = { + TLSProfiles.OLD_TYPE: TLSProtocolVersion.VERSION_TLS_10, + TLSProfiles.INTERMEDIATE_TYPE: TLSProtocolVersion.VERSION_TLS_12, + TLSProfiles.MODERN_TYPE: TLSProtocolVersion.VERSION_TLS_13, +} + +# TLS ciphers defined for each TLS profile +TLS_CIPHERS: dict[TLSProfiles, list[str]] = { + TLSProfiles.OLD_TYPE: [ + "TLS_AES_128_GCM_SHA256", + "TLS_AES_256_GCM_SHA384", + "TLS_CHACHA20_POLY1305_SHA256", + "ECDHE-ECDSA-AES128-GCM-SHA256", + "ECDHE-RSA-AES128-GCM-SHA256", + "ECDHE-ECDSA-AES256-GCM-SHA384", + "ECDHE-RSA-AES256-GCM-SHA384", + "ECDHE-ECDSA-CHACHA20-POLY1305", + "ECDHE-RSA-CHACHA20-POLY1305", + "DHE-RSA-AES128-GCM-SHA256", + "DHE-RSA-AES256-GCM-SHA384", + "DHE-RSA-CHACHA20-POLY1305", + "ECDHE-ECDSA-AES128-SHA256", + "ECDHE-RSA-AES128-SHA256", + "ECDHE-ECDSA-AES128-SHA", + "ECDHE-RSA-AES128-SHA", + "ECDHE-ECDSA-AES256-SHA384", + "ECDHE-RSA-AES256-SHA384", + "ECDHE-ECDSA-AES256-SHA", + "ECDHE-RSA-AES256-SHA", + "DHE-RSA-AES128-SHA256", + "DHE-RSA-AES256-SHA256", + "AES128-GCM-SHA256", + "AES256-GCM-SHA384", + "AES128-SHA256", + "AES256-SHA256", + "AES128-SHA", + "AES256-SHA", + "DES-CBC3-SHA", + ], + TLSProfiles.INTERMEDIATE_TYPE: [ + "TLS_AES_128_GCM_SHA256", + "TLS_AES_256_GCM_SHA384", + "TLS_CHACHA20_POLY1305_SHA256", + "ECDHE-ECDSA-AES128-GCM-SHA256", + "ECDHE-RSA-AES128-GCM-SHA256", + "ECDHE-ECDSA-AES256-GCM-SHA384", + "ECDHE-RSA-AES256-GCM-SHA384", + "ECDHE-ECDSA-CHACHA20-POLY1305", + "ECDHE-RSA-CHACHA20-POLY1305", + "DHE-RSA-AES128-GCM-SHA256", + "DHE-RSA-AES256-GCM-SHA384", + ], + TLSProfiles.MODERN_TYPE: [ + "TLS_AES_128_GCM_SHA256", + "TLS_AES_256_GCM_SHA384", + "TLS_CHACHA20_POLY1305_SHA256", + ], +} + + +def ssl_tls_version( + tls_protocol_version: Optional[TLSProtocolVersion], +) -> Optional[ssl.TLSVersion]: + """Convert TLS protocol version string into its ssl.TLSVersion equivalent.""" + if tls_protocol_version is None: + return None + tls_versions = { + TLSProtocolVersion.VERSION_TLS_10: ssl.TLSVersion.TLSv1, + TLSProtocolVersion.VERSION_TLS_11: ssl.TLSVersion.TLSv1_1, + TLSProtocolVersion.VERSION_TLS_12: ssl.TLSVersion.TLSv1_2, + TLSProtocolVersion.VERSION_TLS_13: ssl.TLSVersion.TLSv1_3, + } + return tls_versions.get(tls_protocol_version, None) + + +def min_tls_version( + specified_tls_version: Optional[str], + tls_profile: TLSProfiles, +) -> Optional[TLSProtocolVersion]: + """Retrieve minimal TLS version for the profile or from specified configuration.""" + if specified_tls_version is not None: + try: + return TLSProtocolVersion(specified_tls_version) + except ValueError: + logger.warning( + "Invalid TLS version '%s', using profile default", specified_tls_version + ) + return MIN_TLS_VERSIONS.get(tls_profile) + + +def ciphers_from_list(ciphers: Optional[list[str]]) -> Optional[str]: + """Convert list of ciphers into one string to be consumable by SSL context.""" + if ciphers is None: + return None + return ":".join(ciphers) + + +def ciphers_for_tls_profile(tls_profile: TLSProfiles) -> Optional[str]: + """Retrieve list of ciphers for specified TLS profile.""" + ciphers = TLS_CIPHERS.get(tls_profile, None) + return ciphers_from_list(ciphers) + + +def ciphers_as_string( + ciphers: Optional[list[str]], tls_profile: TLSProfiles +) -> Optional[str]: + """Retrieve ciphers as one string for custom list or TLS profile-based list.""" + ciphers_as_str = ciphers_from_list(ciphers) + if ciphers_as_str is None: + return ciphers_for_tls_profile(tls_profile) + return ciphers_as_str + diff --git a/tests/configuration/lightspeed-stack-tls.yaml b/tests/configuration/lightspeed-stack-tls.yaml new file mode 100644 index 000000000..5de6edf1c --- /dev/null +++ b/tests/configuration/lightspeed-stack-tls.yaml @@ -0,0 +1,21 @@ +name: TLS Test Configuration +service: + host: localhost + port: 8080 + auth_enabled: false + workers: 1 + color_log: true + access_log: true +llama_stack: + url: https://localhost:8321 + use_as_library_client: false + api_key: test-key + tls_security_profile: + type: ModernType + minTLSVersion: VersionTLS13 + ciphers: + - TLS_AES_128_GCM_SHA256 + - TLS_AES_256_GCM_SHA384 +user_data_collection: + feedback_enabled: false + diff --git a/tests/integration/test_tls_configuration.py b/tests/integration/test_tls_configuration.py new file mode 100644 index 000000000..f9a77a0a0 --- /dev/null +++ b/tests/integration/test_tls_configuration.py @@ -0,0 +1,118 @@ +"""Integration tests for TLS security profile configuration.""" + +import ssl +from unittest.mock import patch, MagicMock + +import pytest + +from configuration import AppConfig +from client import AsyncLlamaStackClientHolder + + +@pytest.fixture(name="tls_configuration_filename") +def tls_configuration_filename_fixture() -> str: + """Retrieve TLS configuration file name for integration tests.""" + return "tests/configuration/lightspeed-stack-tls.yaml" + + +def test_loading_tls_configuration(tls_configuration_filename: str) -> None: + """Test loading configuration with TLS security profile.""" + cfg = AppConfig() + cfg.load_configuration(tls_configuration_filename) + + # check if configuration is loaded + assert cfg is not None + assert cfg.configuration is not None + + # check 'llama_stack' section + ls_config = cfg.llama_stack_configuration + assert ls_config.url == "https://localhost:8321" + assert ls_config.use_as_library_client is False + + # check TLS security profile + tls_profile = ls_config.tls_security_profile + assert tls_profile is not None + assert tls_profile.profile_type == "ModernType" + assert tls_profile.min_tls_version == "VersionTLS13" + assert tls_profile.ciphers == [ + "TLS_AES_128_GCM_SHA256", + "TLS_AES_256_GCM_SHA384", + ] + + +def test_tls_configuration_defaults(tls_configuration_filename: str) -> None: + """Test that TLS configuration has correct default values.""" + cfg = AppConfig() + cfg.load_configuration(tls_configuration_filename) + + tls_profile = cfg.llama_stack_configuration.tls_security_profile + assert tls_profile is not None + + # These should be None/False by default when not specified + assert tls_profile.ca_cert_path is None + assert tls_profile.skip_tls_verification is False + + +@pytest.mark.asyncio +async def test_client_construction_with_tls_profile( + tls_configuration_filename: str, +) -> None: + """Test that AsyncLlamaStackClientHolder constructs client with TLS settings.""" + cfg = AppConfig() + cfg.load_configuration(tls_configuration_filename) + + holder = AsyncLlamaStackClientHolder() + + # Mock httpx.AsyncClient to capture the SSL context + with patch("client.httpx.AsyncClient") as mock_async_client: + with patch("client.AsyncLlamaStackClient") as mock_llama_client: + mock_async_client.return_value = MagicMock() + mock_llama_client.return_value = MagicMock() + + await holder.load(cfg.llama_stack_configuration) + + # Verify httpx.AsyncClient was called with an SSL context + mock_async_client.assert_called_once() + call_kwargs = mock_async_client.call_args.kwargs + verify_arg = call_kwargs.get("verify") + + # Should be an SSL context, not a boolean + assert isinstance(verify_arg, ssl.SSLContext) + + # Verify minimum TLS version is set to TLS 1.3 + assert verify_arg.minimum_version == ssl.TLSVersion.TLSv1_3 + + # Verify AsyncLlamaStackClient was called with the custom http_client + mock_llama_client.assert_called_once() + llama_call_kwargs = mock_llama_client.call_args.kwargs + assert "http_client" in llama_call_kwargs + assert llama_call_kwargs["http_client"] is not None + + +@pytest.mark.asyncio +async def test_client_construction_without_tls_profile() -> None: + """Test that client is constructed normally without TLS profile.""" + from models.config import LlamaStackConfiguration + + cfg = LlamaStackConfiguration( + url="http://localhost:8321", + use_as_library_client=False, + tls_security_profile=None, + ) + + holder = AsyncLlamaStackClientHolder() + + with patch("client.httpx.AsyncClient") as mock_async_client: + with patch("client.AsyncLlamaStackClient") as mock_llama_client: + mock_llama_client.return_value = MagicMock() + + await holder.load(cfg) + + # httpx.AsyncClient should NOT be called when no TLS profile + mock_async_client.assert_not_called() + + # AsyncLlamaStackClient should be called with http_client=None + mock_llama_client.assert_called_once() + llama_call_kwargs = mock_llama_client.call_args.kwargs + assert llama_call_kwargs.get("http_client") is None + diff --git a/tests/unit/models/config/test_dump_configuration.py b/tests/unit/models/config/test_dump_configuration.py index 38177a8a7..58d584315 100644 --- a/tests/unit/models/config/test_dump_configuration.py +++ b/tests/unit/models/config/test_dump_configuration.py @@ -135,6 +135,7 @@ def test_dump_configuration(tmp_path: Path) -> None: "use_as_library_client": True, "api_key": "**********", "library_client_config_path": "tests/configuration/run.yaml", + "tls_security_profile": None, }, "user_data_collection": { "feedback_enabled": False, @@ -435,6 +436,7 @@ def test_dump_configuration_with_quota_limiters(tmp_path: Path) -> None: "use_as_library_client": True, "api_key": "**********", "library_client_config_path": "tests/configuration/run.yaml", + "tls_security_profile": None, }, "user_data_collection": { "feedback_enabled": False, @@ -620,6 +622,7 @@ def test_dump_configuration_byok(tmp_path: Path) -> None: "use_as_library_client": True, "api_key": "**********", "library_client_config_path": "tests/configuration/run.yaml", + "tls_security_profile": None, }, "user_data_collection": { "feedback_enabled": False, diff --git a/tests/unit/models/config/test_tls_security_profile.py b/tests/unit/models/config/test_tls_security_profile.py new file mode 100644 index 000000000..7f9ae134a --- /dev/null +++ b/tests/unit/models/config/test_tls_security_profile.py @@ -0,0 +1,165 @@ +"""Unit tests for TLSSecurityProfile configuration model.""" + +import pytest + +from models.config import TLSSecurityProfile, LlamaStackConfiguration + + +class TestTLSSecurityProfileInit: + """Tests for TLSSecurityProfile initialization.""" + + def test_default_initialization(self) -> None: + """Test default initialization with no parameters.""" + profile = TLSSecurityProfile() + assert profile.profile_type is None + assert profile.min_tls_version is None + assert profile.ciphers is None + + def test_initialization_with_profile_type(self) -> None: + """Test initialization with profile type.""" + profile = TLSSecurityProfile(profile_type="ModernType") + assert profile.profile_type == "ModernType" + assert profile.min_tls_version is None + assert profile.ciphers is None + + def test_initialization_with_alias(self) -> None: + """Test initialization using YAML-style aliases.""" + profile = TLSSecurityProfile(type="IntermediateType", minTLSVersion="VersionTLS12") + assert profile.profile_type == "IntermediateType" + assert profile.min_tls_version == "VersionTLS12" + + def test_initialization_with_all_fields(self) -> None: + """Test initialization with all fields.""" + ciphers = ["TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384"] + profile = TLSSecurityProfile( + profile_type="Custom", + min_tls_version="VersionTLS13", + ciphers=ciphers, + ) + assert profile.profile_type == "Custom" + assert profile.min_tls_version == "VersionTLS13" + assert profile.ciphers == ciphers + + +class TestTLSSecurityProfileValidation: + """Tests for TLSSecurityProfile validation.""" + + def test_valid_old_type_profile(self) -> None: + """Test valid OldType profile.""" + profile = TLSSecurityProfile(profile_type="OldType") + assert profile.profile_type == "OldType" + + def test_valid_intermediate_type_profile(self) -> None: + """Test valid IntermediateType profile.""" + profile = TLSSecurityProfile(profile_type="IntermediateType") + assert profile.profile_type == "IntermediateType" + + def test_valid_modern_type_profile(self) -> None: + """Test valid ModernType profile.""" + profile = TLSSecurityProfile(profile_type="ModernType") + assert profile.profile_type == "ModernType" + + def test_valid_custom_profile(self) -> None: + """Test valid Custom profile.""" + profile = TLSSecurityProfile(profile_type="Custom") + assert profile.profile_type == "Custom" + + def test_invalid_profile_type(self) -> None: + """Test invalid profile type raises error.""" + with pytest.raises(ValueError, match="Invalid TLS profile type"): + TLSSecurityProfile(profile_type="InvalidType") + + def test_valid_tls_versions(self) -> None: + """Test all valid TLS versions.""" + for version in ["VersionTLS10", "VersionTLS11", "VersionTLS12", "VersionTLS13"]: + profile = TLSSecurityProfile(min_tls_version=version) + assert profile.min_tls_version == version + + def test_invalid_tls_version(self) -> None: + """Test invalid TLS version raises error.""" + with pytest.raises(ValueError, match="Invalid minimal TLS version"): + TLSSecurityProfile(min_tls_version="VersionTLS14") + + def test_cipher_validation_non_custom_profile(self) -> None: + """Test that ciphers must be valid for non-Custom profiles.""" + # Using a cipher not in the ModernType profile + with pytest.raises(ValueError, match="Unsupported cipher"): + TLSSecurityProfile( + profile_type="ModernType", + ciphers=["INVALID_CIPHER"], + ) + + def test_cipher_validation_custom_profile_allows_any(self) -> None: + """Test that Custom profile allows any ciphers.""" + profile = TLSSecurityProfile( + profile_type="Custom", + ciphers=["ANY_CIPHER_ALLOWED"], + ) + assert profile.ciphers == ["ANY_CIPHER_ALLOWED"] + + def test_valid_ciphers_for_profile(self) -> None: + """Test valid ciphers for IntermediateType profile.""" + profile = TLSSecurityProfile( + profile_type="IntermediateType", + ciphers=["TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384"], + ) + assert profile.ciphers == ["TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384"] + + +class TestTLSSecurityProfileExtraFields: + """Tests for extra field handling.""" + + def test_extra_fields_forbidden(self) -> None: + """Test that extra fields are rejected.""" + with pytest.raises(ValueError): + TLSSecurityProfile( + profile_type="ModernType", + unknown_field="value", + ) + + +class TestLlamaStackConfigurationWithTLS: + """Tests for LlamaStackConfiguration with TLS security profile.""" + + def test_llama_stack_config_without_tls_profile(self) -> None: + """Test LlamaStackConfiguration without TLS profile.""" + config = LlamaStackConfiguration( + url="https://llama-stack:8321", + use_as_library_client=False, + ) + assert config.tls_security_profile is None + + def test_llama_stack_config_with_tls_profile(self) -> None: + """Test LlamaStackConfiguration with TLS profile.""" + tls_profile = TLSSecurityProfile( + profile_type="ModernType", + min_tls_version="VersionTLS13", + ) + config = LlamaStackConfiguration( + url="https://llama-stack:8321", + use_as_library_client=False, + tls_security_profile=tls_profile, + ) + assert config.tls_security_profile is not None + assert config.tls_security_profile.profile_type == "ModernType" + assert config.tls_security_profile.min_tls_version == "VersionTLS13" + + def test_llama_stack_config_with_custom_ciphers(self) -> None: + """Test LlamaStackConfiguration with custom TLS ciphers.""" + ciphers = [ + "TLS_AES_128_GCM_SHA256", + "TLS_AES_256_GCM_SHA384", + ] + tls_profile = TLSSecurityProfile( + profile_type="Custom", + min_tls_version="VersionTLS13", + ciphers=ciphers, + ) + config = LlamaStackConfiguration( + url="https://llama-stack:8321", + use_as_library_client=False, + tls_security_profile=tls_profile, + ) + assert config.tls_security_profile is not None + assert config.tls_security_profile.ciphers == ciphers + diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 5405092fe..bfc538b7d 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1,9 +1,12 @@ """Unit tests for functions defined in src/client.py.""" +import ssl +from unittest.mock import patch, MagicMock + import pytest from client import AsyncLlamaStackClientHolder -from models.config import LlamaStackConfiguration +from models.config import LlamaStackConfiguration, TLSSecurityProfile def test_async_client_get_client_method() -> None: @@ -71,3 +74,93 @@ async def test_get_async_llama_stack_wrong_configuration() -> None: ): client = AsyncLlamaStackClientHolder() await client.load(cfg) + + +class TestConstructHttpxClient: + """Tests for _construct_httpx_client method.""" + + def test_construct_httpx_client_no_profile(self) -> None: + """Test that None is returned when no TLS profile is set.""" + holder = AsyncLlamaStackClientHolder() + result = holder._construct_httpx_client(None) + assert result is None + + def test_construct_httpx_client_profile_type_none(self) -> None: + """Test that None is returned when profile_type is None.""" + holder = AsyncLlamaStackClientHolder() + profile = TLSSecurityProfile() + result = holder._construct_httpx_client(profile) + assert result is None + + def test_construct_httpx_client_with_modern_profile(self) -> None: + """Test that httpx client is created with ModernType profile.""" + holder = AsyncLlamaStackClientHolder() + profile = TLSSecurityProfile( + profile_type="ModernType", + min_tls_version="VersionTLS13", + ) + result = holder._construct_httpx_client(profile) + assert result is not None + + def test_construct_httpx_client_with_custom_ciphers(self) -> None: + """Test that httpx client is created with custom ciphers.""" + holder = AsyncLlamaStackClientHolder() + profile = TLSSecurityProfile( + profile_type="Custom", + min_tls_version="VersionTLS12", + ciphers=["TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384"], + ) + result = holder._construct_httpx_client(profile) + assert result is not None + + @patch("client.ssl.create_default_context") + def test_ssl_context_minimum_version_set(self, mock_create_context: MagicMock) -> None: + """Test that SSL context minimum version is set correctly.""" + mock_context = MagicMock() + mock_create_context.return_value = mock_context + + holder = AsyncLlamaStackClientHolder() + profile = TLSSecurityProfile( + profile_type="ModernType", + min_tls_version="VersionTLS13", + ) + holder._construct_httpx_client(profile) + + mock_create_context.assert_called_once() + assert mock_context.minimum_version == ssl.TLSVersion.TLSv1_3 + + @patch("client.ssl.create_default_context") + def test_ssl_context_ciphers_set(self, mock_create_context: MagicMock) -> None: + """Test that SSL context ciphers are set correctly.""" + mock_context = MagicMock() + mock_create_context.return_value = mock_context + + holder = AsyncLlamaStackClientHolder() + profile = TLSSecurityProfile( + profile_type="Custom", + ciphers=["CIPHER1", "CIPHER2"], + ) + holder._construct_httpx_client(profile) + + mock_context.set_ciphers.assert_called_once_with("CIPHER1:CIPHER2") + + +async def test_get_async_llama_stack_remote_client_with_tls() -> None: + """Test initialization of Llama Stack client with TLS profile.""" + tls_profile = TLSSecurityProfile( + profile_type="ModernType", + min_tls_version="VersionTLS13", + ) + cfg = LlamaStackConfiguration( + url="http://localhost:8321", + api_key=None, + use_as_library_client=False, + library_client_config_path="./tests/configuration/minimal-stack.yaml", + tls_security_profile=tls_profile, + ) + client = AsyncLlamaStackClientHolder() + await client.load(cfg) + assert client is not None + + ls_client = client.get_client() + assert ls_client is not None diff --git a/tests/unit/utils/test_tls.py b/tests/unit/utils/test_tls.py new file mode 100644 index 000000000..f637a558d --- /dev/null +++ b/tests/unit/utils/test_tls.py @@ -0,0 +1,225 @@ +"""Unit tests for TLS utilities defined in src/utils/tls.py.""" + +import ssl + +import pytest + +from utils.tls import ( + TLSProfiles, + TLSProtocolVersion, + MIN_TLS_VERSIONS, + TLS_CIPHERS, + ssl_tls_version, + min_tls_version, + ciphers_from_list, + ciphers_for_tls_profile, + ciphers_as_string, +) + + +class TestTLSProfiles: + """Tests for TLSProfiles enum.""" + + def test_tls_profiles_values(self) -> None: + """Test that TLSProfiles has expected values.""" + assert TLSProfiles.OLD_TYPE == "OldType" + assert TLSProfiles.INTERMEDIATE_TYPE == "IntermediateType" + assert TLSProfiles.MODERN_TYPE == "ModernType" + assert TLSProfiles.CUSTOM_TYPE == "Custom" + + def test_tls_profiles_from_string(self) -> None: + """Test creating TLSProfiles from string.""" + assert TLSProfiles("OldType") == TLSProfiles.OLD_TYPE + assert TLSProfiles("IntermediateType") == TLSProfiles.INTERMEDIATE_TYPE + assert TLSProfiles("ModernType") == TLSProfiles.MODERN_TYPE + assert TLSProfiles("Custom") == TLSProfiles.CUSTOM_TYPE + + def test_tls_profiles_invalid(self) -> None: + """Test invalid TLS profile raises error.""" + with pytest.raises(ValueError): + TLSProfiles("InvalidType") + + +class TestTLSProtocolVersion: + """Tests for TLSProtocolVersion enum.""" + + def test_tls_protocol_version_values(self) -> None: + """Test that TLSProtocolVersion has expected values.""" + assert TLSProtocolVersion.VERSION_TLS_10 == "VersionTLS10" + assert TLSProtocolVersion.VERSION_TLS_11 == "VersionTLS11" + assert TLSProtocolVersion.VERSION_TLS_12 == "VersionTLS12" + assert TLSProtocolVersion.VERSION_TLS_13 == "VersionTLS13" + + def test_tls_protocol_version_from_string(self) -> None: + """Test creating TLSProtocolVersion from string.""" + assert TLSProtocolVersion("VersionTLS10") == TLSProtocolVersion.VERSION_TLS_10 + assert TLSProtocolVersion("VersionTLS11") == TLSProtocolVersion.VERSION_TLS_11 + assert TLSProtocolVersion("VersionTLS12") == TLSProtocolVersion.VERSION_TLS_12 + assert TLSProtocolVersion("VersionTLS13") == TLSProtocolVersion.VERSION_TLS_13 + + def test_tls_protocol_version_invalid(self) -> None: + """Test invalid TLS version raises error.""" + with pytest.raises(ValueError): + TLSProtocolVersion("VersionTLS14") + + +class TestMinTLSVersionsMapping: + """Tests for MIN_TLS_VERSIONS mapping.""" + + def test_old_type_min_version(self) -> None: + """Test OldType has TLS 1.0 as minimum.""" + assert MIN_TLS_VERSIONS[TLSProfiles.OLD_TYPE] == TLSProtocolVersion.VERSION_TLS_10 + + def test_intermediate_type_min_version(self) -> None: + """Test IntermediateType has TLS 1.2 as minimum.""" + assert ( + MIN_TLS_VERSIONS[TLSProfiles.INTERMEDIATE_TYPE] + == TLSProtocolVersion.VERSION_TLS_12 + ) + + def test_modern_type_min_version(self) -> None: + """Test ModernType has TLS 1.3 as minimum.""" + assert MIN_TLS_VERSIONS[TLSProfiles.MODERN_TYPE] == TLSProtocolVersion.VERSION_TLS_13 + + +class TestTLSCiphersMapping: + """Tests for TLS_CIPHERS mapping.""" + + def test_old_type_has_ciphers(self) -> None: + """Test OldType has ciphers defined.""" + assert TLSProfiles.OLD_TYPE in TLS_CIPHERS + assert len(TLS_CIPHERS[TLSProfiles.OLD_TYPE]) > 0 + + def test_intermediate_type_has_ciphers(self) -> None: + """Test IntermediateType has ciphers defined.""" + assert TLSProfiles.INTERMEDIATE_TYPE in TLS_CIPHERS + assert len(TLS_CIPHERS[TLSProfiles.INTERMEDIATE_TYPE]) > 0 + + def test_modern_type_has_ciphers(self) -> None: + """Test ModernType has ciphers defined.""" + assert TLSProfiles.MODERN_TYPE in TLS_CIPHERS + assert len(TLS_CIPHERS[TLSProfiles.MODERN_TYPE]) > 0 + + def test_modern_type_has_fewer_ciphers_than_old(self) -> None: + """Test ModernType has fewer ciphers than OldType (more restrictive).""" + assert len(TLS_CIPHERS[TLSProfiles.MODERN_TYPE]) < len( + TLS_CIPHERS[TLSProfiles.OLD_TYPE] + ) + + +class TestSslTlsVersion: + """Tests for ssl_tls_version function.""" + + def test_ssl_tls_version_tls10(self) -> None: + """Test conversion of TLS 1.0.""" + result = ssl_tls_version(TLSProtocolVersion.VERSION_TLS_10) + assert result == ssl.TLSVersion.TLSv1 + + def test_ssl_tls_version_tls11(self) -> None: + """Test conversion of TLS 1.1.""" + result = ssl_tls_version(TLSProtocolVersion.VERSION_TLS_11) + assert result == ssl.TLSVersion.TLSv1_1 + + def test_ssl_tls_version_tls12(self) -> None: + """Test conversion of TLS 1.2.""" + result = ssl_tls_version(TLSProtocolVersion.VERSION_TLS_12) + assert result == ssl.TLSVersion.TLSv1_2 + + def test_ssl_tls_version_tls13(self) -> None: + """Test conversion of TLS 1.3.""" + result = ssl_tls_version(TLSProtocolVersion.VERSION_TLS_13) + assert result == ssl.TLSVersion.TLSv1_3 + + def test_ssl_tls_version_none(self) -> None: + """Test conversion of None returns None.""" + result = ssl_tls_version(None) + assert result is None + + +class TestMinTlsVersion: + """Tests for min_tls_version function.""" + + def test_min_tls_version_specified(self) -> None: + """Test that specified version overrides profile default.""" + result = min_tls_version("VersionTLS13", TLSProfiles.OLD_TYPE) + assert result == TLSProtocolVersion.VERSION_TLS_13 + + def test_min_tls_version_from_profile(self) -> None: + """Test that profile default is used when no version specified.""" + result = min_tls_version(None, TLSProfiles.MODERN_TYPE) + assert result == TLSProtocolVersion.VERSION_TLS_13 + + def test_min_tls_version_invalid_falls_back_to_profile(self) -> None: + """Test that invalid version falls back to profile default.""" + result = min_tls_version("InvalidVersion", TLSProfiles.INTERMEDIATE_TYPE) + assert result == TLSProtocolVersion.VERSION_TLS_12 + + +class TestCiphersFromList: + """Tests for ciphers_from_list function.""" + + def test_ciphers_from_list_with_ciphers(self) -> None: + """Test conversion of cipher list to string.""" + ciphers = ["CIPHER1", "CIPHER2", "CIPHER3"] + result = ciphers_from_list(ciphers) + assert result == "CIPHER1:CIPHER2:CIPHER3" + + def test_ciphers_from_list_single(self) -> None: + """Test conversion of single cipher.""" + ciphers = ["CIPHER1"] + result = ciphers_from_list(ciphers) + assert result == "CIPHER1" + + def test_ciphers_from_list_empty(self) -> None: + """Test conversion of empty list.""" + result = ciphers_from_list([]) + assert result == "" + + def test_ciphers_from_list_none(self) -> None: + """Test conversion of None returns None.""" + result = ciphers_from_list(None) + assert result is None + + +class TestCiphersForTlsProfile: + """Tests for ciphers_for_tls_profile function.""" + + def test_ciphers_for_old_type(self) -> None: + """Test getting ciphers for OldType profile.""" + result = ciphers_for_tls_profile(TLSProfiles.OLD_TYPE) + assert result is not None + assert ":" in result # Should be colon-separated + + def test_ciphers_for_modern_type(self) -> None: + """Test getting ciphers for ModernType profile.""" + result = ciphers_for_tls_profile(TLSProfiles.MODERN_TYPE) + assert result is not None + # Modern type should have TLS 1.3 ciphers + assert "TLS_AES_128_GCM_SHA256" in result + + def test_ciphers_for_custom_type(self) -> None: + """Test Custom type returns None (no predefined ciphers).""" + result = ciphers_for_tls_profile(TLSProfiles.CUSTOM_TYPE) + assert result is None + + +class TestCiphersAsString: + """Tests for ciphers_as_string function.""" + + def test_ciphers_as_string_custom_list(self) -> None: + """Test that custom cipher list is used when provided.""" + custom_ciphers = ["CUSTOM1", "CUSTOM2"] + result = ciphers_as_string(custom_ciphers, TLSProfiles.MODERN_TYPE) + assert result == "CUSTOM1:CUSTOM2" + + def test_ciphers_as_string_profile_default(self) -> None: + """Test that profile ciphers are used when no custom list.""" + result = ciphers_as_string(None, TLSProfiles.MODERN_TYPE) + assert result is not None + assert "TLS_AES_128_GCM_SHA256" in result + + def test_ciphers_as_string_empty_list_uses_profile(self) -> None: + """Test that empty list results in empty string (not profile default).""" + result = ciphers_as_string([], TLSProfiles.MODERN_TYPE) + assert result == "" +