diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..831a313 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,43 @@ +# Git +.git +.gitignore + +# Python +__pycache__ +*.py[cod] +*$py.class +*.so +.Python +.venv +venv/ +ENV/ +env/ +.eggs/ +*.egg-info/ +*.egg + +# Testing +.pytest_cache/ +.coverage +htmlcov/ +.tox/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# Docker +Dockerfile +docker-compose*.yml +.docker/ + +# Documentation +*.md +docs/ + +# Misc +.env +.env.* +*.log diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..52758ad --- /dev/null +++ b/.gitignore @@ -0,0 +1,65 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +.idea/ + +# Virtual environments +.venv/ +venv/ +ENV/ +env/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo +*~ +.project +.pydevproject +.settings/ + +# Testing +.pytest_cache/ +.coverage +htmlcov/ +.tox/ +.nox/ + +# mypy +.mypy_cache/ + +# Environment variables +.env +.env.local +.env.*.local +*.env + +# Docker +docker-compose.override.yml + +# OS +.DS_Store +Thumbs.db + +# Logs +*.log +logs/ diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..b4a98cb --- /dev/null +++ b/Dockerfile @@ -0,0 +1,34 @@ +FROM python:3.13-slim AS builder + +ENV PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 + +WORKDIR /app + +RUN python -m venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +COPY requirements.txt . +RUN pip install --upgrade pip && \ + pip install -r requirements.txt + +FROM python:3.13-slim AS production + +ENV PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 \ + PATH="/opt/venv/bin:$PATH" + +WORKDIR /app + +COPY --from=builder --chown=1000:1000 /opt/venv /opt/venv +COPY --chown=1000:1000 api/ ./api/ +COPY --chown=1000:1000 cli/ ./cli/ + +USER 1000:1000 + +EXPOSE 8000 + +HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ + CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1 + +CMD ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/README.md b/README.md index 3145d38..ecfb1d3 100644 --- a/README.md +++ b/README.md @@ -1,31 +1,338 @@ -# Instructions +# Server Inventory Management API -You are developing an inventory management software solution for a cloud services company that provisions servers in multiple data centers. You must build a CRUD app for tracking the state of all the servers. +A CRUD application for tracking servers across multiple data centers. -Deliverables: -- PR to https://github.com/Mathpix/hiring-challenge-devops-python that includes: -- API code -- CLI code -- pytest test suite -- Working Docker Compose stack +## Quick Start -Short API.md on how to run everything, also a short API and CLI spec +### Using Docker Compose (Recommended) -Required endpoints: -- POST /servers → create a server -- GET /servers → list all servers -- GET /servers/{id} → get one server -- PUT /servers/{id} → update server -- DELETE /servers/{id} → delete server +```bash +# Start the entire stack (API + PostgreSQL) +docker-compose up -d -Requirements: -- Use FastAPI or Flask -- Store data in PostgreSQL -- Use raw SQL +# View logs +docker-compose logs -f api -Validate that: -- hostname is unique -- IP address looks like an IP +# Stop the stack +docker-compose down +``` -State is one of: active, offline, retired +The API will be available at `http://localhost:8000`. +### Running Locally + +1. **Start PostgreSQL** + ```bash + docker-compose up -d db + ``` + +2. **Install dependencies** + ```bash + pip install -r requirements.txt + ``` + +3. **Set environment variables** + ```bash + export DATABASE_URL=postgresql://postgres:postgres@localhost:5432/inventory + ``` + +4. **Start the API** + ```bash + uvicorn api.main:app --reload + ``` + +### Running Tests + +```bash +pytest -v +``` + +--- + +## API Specification + +Base URL: `http://localhost:8000` + +### Endpoints + +| Method | Endpoint | Description | +|--------|----------|-------------| +| POST | /servers | Create a new server | +| GET | /servers | List all servers | +| GET | /servers/{id} | Get a server by ID | +| PUT | /servers/{id} | Update a server | +| DELETE | /servers/{id} | Delete a server | +| GET | /health | Health check | + +--- + +### POST /servers + +Create a new server. + +**Request Body:** +```json +{ + "hostname": "web-server-01", + "ip_address": "192.168.1.100", + "datacenter": "us-east-1", + "state": "active" +} +``` + +**Validation Rules:** +- `hostname`: Required, must be unique, 1-255 alphanumeric characters with hyphens and dots +- `ip_address`: Required, must be valid IPv4 or IPv6 address +- `datacenter`: Required, string +- `state`: Required, one of: `active`, `offline`, `retired` + +**Response (201 Created):** +```json +{ + "id": 1, + "hostname": "web-server-01", + "ip_address": "192.168.1.100", + "datacenter": "us-east-1", + "state": "active", + "created_at": "2024-01-15T10:30:00.000000", + "updated_at": "2024-01-15T10:30:00.000000" +} +``` + +**Error Responses:** +- `409 Conflict`: Hostname already exists +- `422 Unprocessable Entity`: Validation error + +--- + +### GET /servers + +List all servers. + +**Response (200 OK):** +```json +[ + { + "id": 1, + "hostname": "web-server-01", + "ip_address": "192.168.1.100", + "datacenter": "us-east-1", + "state": "active", + "created_at": "2024-01-15T10:30:00.000000", + "updated_at": "2024-01-15T10:30:00.000000" + } +] +``` + +--- + +### GET /servers/{id} + +Get a server by ID. + +**Response (200 OK):** +```json +{ + "id": 1, + "hostname": "web-server-01", + "ip_address": "192.168.1.100", + "datacenter": "us-east-1", + "state": "active", + "created_at": "2024-01-15T10:30:00.000000", + "updated_at": "2024-01-15T10:30:00.000000" +} +``` + +**Error Responses:** +- `404 Not Found`: Server not found + +--- + +### PUT /servers/{id} + +Update an existing server. + +**Request Body:** +```json +{ + "hostname": "web-server-01-updated", + "ip_address": "10.0.0.1", + "datacenter": "us-west-2", + "state": "offline" +} +``` + +**Response (200 OK):** +```json +{ + "id": 1, + "hostname": "web-server-01-updated", + "ip_address": "10.0.0.1", + "datacenter": "us-west-2", + "state": "offline", + "created_at": "2024-01-15T10:30:00.000000", + "updated_at": "2024-01-15T11:00:00.000000" +} +``` + +**Error Responses:** +- `404 Not Found`: Server not found +- `409 Conflict`: Hostname already exists (on another server) +- `422 Unprocessable Entity`: Validation error + +--- + +### DELETE /servers/{id} + +Delete a server. + +**Response (204 No Content):** Empty body + +**Error Responses:** +- `404 Not Found`: Server not found + +--- + +### GET /health + +Health check endpoint. + +**Response (200 OK):** +```json +{ + "status": "healthy", + "database": "connected" +} +``` + +--- + +## CLI Specification + +The CLI interacts with the API to manage servers. + +### Installation + +```bash +pip install -r requirements.txt +``` + +### Configuration + +Set the API URL via environment variable: +```bash +export API_URL=http://localhost:8000 +``` + +Or use the `--api-url` flag: +```bash +python -m cli.main --api-url http://localhost:8000 list +``` + +### Commands + +#### List Servers + +```bash +# List all servers +python -m cli.main list + +# Output as JSON +python -m cli.main list --json-output +python -m cli.main list -j +``` + +**Example Output:** +``` +ID: 1 + Hostname: web-server-01 + IP Address: 192.168.1.100 + Datacenter: us-east-1 + State: active + Created: 2024-01-15T10:30:00.000000 + Updated: 2024-01-15T10:30:00.000000 +``` + +#### Get Server + +```bash +# Get server by ID +python -m cli.main get 1 + +# Output as JSON +python -m cli.main get 1 --json-output +``` + +#### Create Server + +```bash +# Create a new server +python -m cli.main create \ + --hostname web-server-02 \ + --ip-address 192.168.1.101 \ + --datacenter us-east-1 \ + --state active + +# Short form +python -m cli.main create \ + -n web-server-02 \ + -i 192.168.1.101 \ + -d us-east-1 \ + -s active +``` + +**Options:** +| Flag | Short | Required | Description | +|------|-------|----------|-------------| +| --hostname | -n | Yes | Server hostname | +| --ip-address | -i | Yes | IP address (IPv4 or IPv6) | +| --datacenter | -d | Yes | Data center location | +| --state | -s | No | Server state (default: active) | +| --json-output | -j | No | Output as JSON | + +#### Update Server + +```bash +# Update an existing server +python -m cli.main update 1 \ + --hostname web-server-01-updated \ + --ip-address 10.0.0.1 \ + --datacenter us-west-2 \ + --state offline +``` + +**Options:** Same as create command (all required for update) + +#### Delete Server + +```bash +# Delete with confirmation prompt +python -m cli.main delete 1 + +# Skip confirmation +python -m cli.main delete 1 --yes +python -m cli.main delete 1 -y +``` + +### Exit Codes + +| Code | Meaning | +|------|---------| +| 0 | Success | +| 1 | Error (API error, connection error, etc.) | + +--- + +## Data Model + +### Server + +| Field | Type | Description | +|-------|------|-------------| +| id | integer | Auto-generated unique identifier | +| hostname | string | Unique server hostname (1-255 chars) | +| ip_address | string | Valid IPv4 or IPv6 address | +| datacenter | string | Data center location | +| state | enum | One of: active, offline, retired | +| created_at | datetime | Record creation timestamp | +| updated_at | datetime | Last update timestamp | diff --git a/api/__init__.py b/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/database.py b/api/database.py new file mode 100644 index 0000000..3780fb2 --- /dev/null +++ b/api/database.py @@ -0,0 +1,112 @@ +import os +import threading +import psycopg2 +from psycopg2.extras import RealDictCursor +from contextlib import contextmanager + +DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://postgres:postgres@localhost:5432/inventory") + +_local = threading.local() + + +def get_connection(): + if not hasattr(_local, 'conn') or _local.conn is None or _local.conn.closed: + _local.conn = psycopg2.connect(DATABASE_URL, cursor_factory=RealDictCursor) + return _local.conn + + +def close_db(): + if hasattr(_local, 'conn') and _local.conn is not None: + _local.conn.close() + _local.conn = None + + +@contextmanager +def get_db(): + conn = get_connection() + try: + yield conn + conn.commit() + except Exception: + conn.rollback() + raise + + +def init_db(): + with get_db() as conn: + with conn.cursor() as cur: + cur.execute(""" + CREATE TABLE IF NOT EXISTS servers ( + id SERIAL PRIMARY KEY, + hostname VARCHAR(255) NOT NULL UNIQUE, + ip_address VARCHAR(45) NOT NULL, + datacenter VARCHAR(255) NOT NULL, + state VARCHAR(20) NOT NULL CHECK (state IN ('active', 'offline', 'retired')), + owner VARCHAR(255) NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + cur.execute("CREATE INDEX IF NOT EXISTS idx_servers_datacenter ON servers(datacenter)") + cur.execute("CREATE INDEX IF NOT EXISTS idx_servers_state ON servers(state)") + + +def create_server(hostname: str, ip_address: str, datacenter: str, state: str, owner: str) -> dict: + with get_db() as conn: + with conn.cursor() as cur: + cur.execute( + """ + INSERT INTO servers (hostname, ip_address, datacenter, state, owner) + VALUES (%s, %s, %s, %s, %s) RETURNING id, hostname, ip_address, datacenter, state, owner, created_at, updated_at + """, + (hostname, ip_address, datacenter, state, owner) + ) + return dict(cur.fetchone()) + + +def get_all_servers(skip: int = 0, limit: int = 100) -> list: + with get_db() as conn: + with conn.cursor() as cur: + cur.execute( + "SELECT id, hostname, ip_address, datacenter, state, owner, created_at, updated_at FROM servers ORDER BY id LIMIT %s OFFSET %s", + (limit, skip) + ) + return [dict(row) for row in cur.fetchall()] + + +def get_server_by_id(server_id: int) -> dict | None: + with get_db() as conn: + with conn.cursor() as cur: + cur.execute( + "SELECT id, hostname, ip_address, datacenter, state, owner, created_at, updated_at FROM servers WHERE id = %s", + (server_id,) + ) + row = cur.fetchone() + return dict(row) if row else None + + +def update_server(server_id: int, hostname: str, ip_address: str, datacenter: str, state: str, owner: str) -> dict | None: + with get_db() as conn: + with conn.cursor() as cur: + cur.execute( + """ + UPDATE servers + SET hostname = %s, + ip_address = %s, + datacenter = %s, + state = %s, + owner = %s, + updated_at = CURRENT_TIMESTAMP + WHERE id = %s RETURNING id, hostname, ip_address, datacenter, state, owner, created_at, updated_at + """, + (hostname, ip_address, datacenter, state, owner, server_id) + ) + row = cur.fetchone() + return dict(row) if row else None + + +def delete_server(server_id: int) -> bool: + with get_db() as conn: + with conn.cursor() as cur: + cur.execute("DELETE FROM servers WHERE id = %s", (server_id,)) + return cur.rowcount > 0 diff --git a/api/main.py b/api/main.py new file mode 100644 index 0000000..a69c402 --- /dev/null +++ b/api/main.py @@ -0,0 +1,143 @@ +import logging +import sys +from contextlib import asynccontextmanager + +from fastapi import FastAPI, HTTPException, Query, status +from psycopg2.errors import UniqueViolation + +from api.models import ServerCreate, ServerResponse +from api.database import ( + init_db, + close_db, + create_server, + get_all_servers, + get_server_by_id, + update_server, + delete_server, + get_db, +) + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(sys.stdout)] +) +logger = logging.getLogger(__name__) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + logger.info("Starting Server Inventory API") + init_db() + yield + close_db() + + +app = FastAPI( + title="Server Inventory API", + description="CRUD API for managing server inventory across data centers", + version="1.0.0", + lifespan=lifespan, +) + + +@app.post("/servers", response_model=ServerResponse, status_code=status.HTTP_201_CREATED) +def create_server_endpoint(server: ServerCreate): + logger.info("Creating server: hostname=%s, datacenter=%s", server.hostname, server.datacenter) + try: + result = create_server( + hostname=server.hostname, + ip_address=server.ip_address, + datacenter=server.datacenter, + state=server.state.value, + owner=server.owner, + ) + logger.info("Server created: id=%d, hostname=%s", result["id"], result["hostname"]) + return ServerResponse(**result) + except UniqueViolation: + logger.warning("Duplicate hostname rejected: %s", server.hostname) + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=f"Server with hostname '{server.hostname}' already exists" + ) + + +@app.get("/servers", response_model=list[ServerResponse]) +def list_servers_endpoint( + skip: int = Query(default=0, ge=0, description="Number of records to skip"), + limit: int = Query(default=100, ge=1, le=1000, description="Maximum number of records to return") +): + logger.debug("Listing servers: skip=%d, limit=%d", skip, limit) + servers = get_all_servers(skip=skip, limit=limit) + logger.debug("Retrieved %d servers", len(servers)) + return [ServerResponse(**s) for s in servers] + + +@app.get("/servers/{server_id}", response_model=ServerResponse) +def get_server_endpoint(server_id: int): + logger.debug("Getting server: id=%d", server_id) + server = get_server_by_id(server_id) + if not server: + logger.warning("Server not found: id=%d", server_id) + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Server with id {server_id} not found" + ) + return ServerResponse(**server) + + +@app.put("/servers/{server_id}", response_model=ServerResponse) +def update_server_endpoint(server_id: int, server: ServerCreate): + logger.info("Updating server: id=%d, hostname=%s", server_id, server.hostname) + try: + result = update_server( + server_id=server_id, + hostname=server.hostname, + ip_address=server.ip_address, + datacenter=server.datacenter, + state=server.state.value, + owner=server.owner, + ) + if not result: + logger.warning("Server not found for update: id=%d", server_id) + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Server with id {server_id} not found" + ) + logger.info("Server updated: id=%d", server_id) + return ServerResponse(**result) + except UniqueViolation: + logger.warning("Duplicate hostname rejected on update: %s", server.hostname) + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=f"Server with hostname '{server.hostname}' already exists" + ) + + +@app.delete("/servers/{server_id}", status_code=status.HTTP_204_NO_CONTENT) +def delete_server_endpoint(server_id: int): + logger.info("Deleting server: id=%d", server_id) + if not delete_server(server_id): + logger.warning("Server not found for delete: id=%d", server_id) + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Server with id {server_id} not found" + ) + logger.info("Server deleted: id=%d", server_id) + return None + + +@app.get("/health") +def health_check(): + try: + with get_db() as conn: + with conn.cursor() as cur: + cur.execute("SELECT 1") + return {"status": "healthy", "database": "connected"} + except Exception as e: + logger.error("Health check failed: %s", str(e)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Database unavailable" + ) diff --git a/api/models.py b/api/models.py new file mode 100644 index 0000000..c035e3b --- /dev/null +++ b/api/models.py @@ -0,0 +1,76 @@ +import re +from datetime import datetime +from enum import Enum +from ipaddress import ip_address, AddressValueError +from pydantic import BaseModel, field_validator + + +class ServerState(str, Enum): + active = "active" + offline = "offline" + retired = "retired" + + +# RFC 1123 compliant hostname label pattern +_LABEL_PATTERN = re.compile(r'^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?$') + + +def _validate_hostname(hostname: str) -> bool: + if not hostname or len(hostname) > 255: + return False + labels = hostname.split('.') + if any(label == '' for label in labels): + return False + return all(_LABEL_PATTERN.match(label) for label in labels) + + +def _validate_ip(ip: str) -> bool: + try: + ip_address(ip) + return True + except (AddressValueError, ValueError): + return False + + +class ServerCreate(BaseModel): + hostname: str + ip_address: str + datacenter: str + state: ServerState = ServerState.active + owner: str + + @field_validator("hostname") + @classmethod + def hostname_valid(cls, v: str) -> str: + if not _validate_hostname(v): + raise ValueError("Invalid hostname format") + return v + + @field_validator("ip_address") + @classmethod + def ip_valid(cls, v: str) -> str: + if not _validate_ip(v): + raise ValueError("Invalid IP address format") + return v + + @field_validator("datacenter") + @classmethod + def datacenter_valid(cls, v: str) -> str: + if not v or not v.strip(): + raise ValueError("Datacenter must be a non-empty string") + if len(v) > 255: + raise ValueError("Datacenter must be 255 characters or fewer") + return v + + +class ServerResponse(BaseModel): + model_config = {"from_attributes": True} + + id: int + hostname: str + ip_address: str + datacenter: str + state: ServerState + owner: str + created_at: datetime + updated_at: datetime diff --git a/cli/__init__.py b/cli/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cli/main.py b/cli/main.py new file mode 100644 index 0000000..743e7e7 --- /dev/null +++ b/cli/main.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 +import os +import sys +import json +import logging +import click +import requests + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + +DEFAULT_TIMEOUT = int(os.getenv("CLI_TIMEOUT", "30")) + + +def format_server(server: dict) -> str: + return ( + f"ID: {server['id']}\n" + f" Hostname: {server['hostname']}\n" + f" IP Address: {server['ip_address']}\n" + f" Datacenter: {server['datacenter']}\n" + f" State: {server['state']}\n" + f" Owner: {server['owner']}\n" + f" Created: {server['created_at']}\n" + f" Updated: {server['updated_at']}" + ) + + +def handle_error(response: requests.Response): + try: + detail = response.json().get("detail", response.text) + except json.JSONDecodeError: + detail = response.text + click.echo(f"Error: {detail}", err=True) + sys.exit(1) + + +def make_request(method: str, url: str, timeout: int, **kwargs) -> requests.Response: + """Make HTTP request with proper timeout handling.""" + # Use tuple for (connect_timeout, read_timeout) + timeout_tuple = (min(5, timeout), timeout) + try: + logger.debug("Making %s request to %s", method.upper(), url) + return requests.request(method, url, timeout=timeout_tuple, **kwargs) + except requests.exceptions.ConnectTimeout: + click.echo("Error: Connection timed out", err=True) + sys.exit(1) + except requests.exceptions.ReadTimeout: + click.echo("Error: Request timed out waiting for response", err=True) + sys.exit(1) + except requests.exceptions.ConnectionError: + click.echo("Error: Could not connect to API", err=True) + sys.exit(1) + + +@click.group() +@click.option("--api-url", envvar="API_URL", default="http://localhost:8000", help="API base URL") +@click.option("--timeout", "-t", envvar="CLI_TIMEOUT", default=DEFAULT_TIMEOUT, type=int, + help="Request timeout in seconds") +@click.option("--verbose", "-v", is_flag=True, help="Enable verbose output") +@click.pass_context +def cli(ctx, api_url, timeout, verbose): + ctx.ensure_object(dict) + ctx.obj["api_url"] = api_url + ctx.obj["timeout"] = timeout + if verbose: + logging.getLogger().setLevel(logging.DEBUG) + + +@cli.command("list") +@click.option("--json-output", "-j", is_flag=True, help="Output as JSON") +@click.option("--skip", default=0, type=int, help="Number of records to skip") +@click.option("--limit", default=100, type=int, help="Maximum number of records to return") +@click.pass_context +def list_servers(ctx, json_output, skip, limit): + api_url = ctx.obj["api_url"] + timeout = ctx.obj["timeout"] + + params = {"skip": skip, "limit": limit} + response = make_request("get", f"{api_url}/servers", timeout, params=params) + + if response.status_code != 200: + handle_error(response) + + servers = response.json() + + if json_output: + click.echo(json.dumps(servers, indent=2)) + elif not servers: + click.echo("No servers found.") + else: + for server in servers: + click.echo(format_server(server)) + click.echo() + + +@cli.command("get") +@click.argument("server_id", type=int) +@click.option("--json-output", "-j", is_flag=True, help="Output as JSON") +@click.pass_context +def get_server(ctx, server_id, json_output): + api_url = ctx.obj["api_url"] + timeout = ctx.obj["timeout"] + + response = make_request("get", f"{api_url}/servers/{server_id}", timeout) + + if response.status_code != 200: + handle_error(response) + + server = response.json() + + if json_output: + click.echo(json.dumps(server, indent=2)) + else: + click.echo(format_server(server)) + + +@cli.command("create") +@click.option("--hostname", "-n", required=True, help="Server hostname") +@click.option("--ip-address", "-i", required=True, help="Server IP address") +@click.option("--datacenter", "-d", required=True, help="Data center location") +@click.option("--state", "-s", type=click.Choice(["active", "offline", "retired"]), + default="active", help="Server state") +@click.option("--owner", "-o", required=True, help="Server owner") +@click.option("--json-output", "-j", is_flag=True, help="Output as JSON") +@click.pass_context +def create_server(ctx, hostname, ip_address, datacenter, state, owner, json_output): + api_url = ctx.obj["api_url"] + timeout = ctx.obj["timeout"] + payload = { + "hostname": hostname, + "ip_address": ip_address, + "datacenter": datacenter, + "state": state, + "owner": owner, + } + + logger.info("Creating server: %s", hostname) + response = make_request("post", f"{api_url}/servers", timeout, json=payload) + + if response.status_code != 201: + handle_error(response) + + server = response.json() + logger.info("Server created with ID: %d", server["id"]) + + if json_output: + click.echo(json.dumps(server, indent=2)) + else: + click.echo("Server created successfully:") + click.echo(format_server(server)) + + +@cli.command("update") +@click.argument("server_id", type=int) +@click.option("--hostname", "-n", required=True, help="Server hostname") +@click.option("--ip-address", "-i", required=True, help="Server IP address") +@click.option("--datacenter", "-d", required=True, help="Data center location") +@click.option("--state", "-s", type=click.Choice(["active", "offline", "retired"]), + required=True, help="Server state") +@click.option("--owner", "-o", required=True, help="Server owner") +@click.option("--json-output", "-j", is_flag=True, help="Output as JSON") +@click.pass_context +def update_server(ctx, server_id, hostname, ip_address, datacenter, state, owner, json_output): + api_url = ctx.obj["api_url"] + timeout = ctx.obj["timeout"] + payload = { + "hostname": hostname, + "ip_address": ip_address, + "datacenter": datacenter, + "state": state, + "owner": owner, + } + + logger.info("Updating server %d", server_id) + response = make_request("put", f"{api_url}/servers/{server_id}", timeout, json=payload) + + if response.status_code != 200: + handle_error(response) + + server = response.json() + logger.info("Server %d updated successfully", server_id) + + if json_output: + click.echo(json.dumps(server, indent=2)) + else: + click.echo("Server updated successfully:") + click.echo(format_server(server)) + + +@cli.command("delete") +@click.argument("server_id", type=int) +@click.option("--yes", "-y", is_flag=True, help="Skip confirmation") +@click.pass_context +def delete_server(ctx, server_id, yes): + api_url = ctx.obj["api_url"] + timeout = ctx.obj["timeout"] + + if not yes: + click.confirm(f"Are you sure you want to delete server {server_id}?", abort=True) + + logger.info("Deleting server %d", server_id) + response = make_request("delete", f"{api_url}/servers/{server_id}", timeout) + + if response.status_code != 204: + handle_error(response) + + logger.info("Server %d deleted successfully", server_id) + click.echo(f"Server {server_id} deleted successfully.") + + +def main(): + cli(obj={}) + + +if __name__ == "__main__": + main() diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..8e0ec11 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,36 @@ +services: + db: + image: postgres:15-alpine + environment: + POSTGRES_USER: ${POSTGRES_USER:-postgres} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-postgres} + POSTGRES_DB: ${POSTGRES_DB:-inventory} + ports: + - "${DB_PORT:-5432}:5432" + volumes: + - postgres_data:/var/lib/postgresql/data + healthcheck: + test: [ "CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-postgres}" ] + interval: 5s + timeout: 5s + retries: 5 + + api: + build: . + ports: + - "${API_PORT:-8000}:8000" + environment: + DATABASE_URL: postgresql://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-inventory} + depends_on: + db: + condition: service_healthy + healthcheck: + test: [ "CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" ] + interval: 10s + timeout: 5s + start_period: 5s + retries: 3 + restart: unless-stopped + +volumes: + postgres_data: diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..5ee6477 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +testpaths = tests diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..40117f7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +fastapi>=0.115.0 +uvicorn>=0.32.0 +psycopg2-binary>=2.9.10 +pydantic>=2.10.0 +httpx>=0.28.0 +click>=8.1.7 +requests>=2.32.0 +pytest>=8.3.0 +pytest-asyncio>=0.24.0 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..aa7794f --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,123 @@ +from datetime import datetime +from unittest.mock import patch, MagicMock +import pytest +from fastapi.testclient import TestClient + + +class MockDatabase: + """In-memory mock database for testing.""" + + def __init__(self): + self.servers = {} + self.next_id = 1 + + def reset(self): + self.servers = {} + self.next_id = 1 + + def create_server(self, hostname: str, ip_address: str, datacenter: str, state: str, owner: str) -> dict: + # Check for duplicate hostname + for server in self.servers.values(): + if server["hostname"] == hostname: + from psycopg2.errors import UniqueViolation + raise UniqueViolation("duplicate key value violates unique constraint") + + now = datetime.now() + server = { + "id": self.next_id, + "hostname": hostname, + "ip_address": ip_address, + "datacenter": datacenter, + "state": state, + "owner": owner, + "created_at": now, + "updated_at": now, + } + self.servers[self.next_id] = server + self.next_id += 1 + return server + + def get_all_servers(self, skip: int = 0, limit: int = 100) -> list: + all_servers = sorted(self.servers.values(), key=lambda s: s["id"]) + return all_servers[skip:skip + limit] + + def get_server_by_id(self, server_id: int) -> dict | None: + return self.servers.get(server_id) + + def update_server(self, server_id: int, hostname: str, ip_address: str, datacenter: str, state: str, owner: str) -> dict | None: + if server_id not in self.servers: + return None + + # Check for duplicate hostname (excluding current server) + for sid, server in self.servers.items(): + if sid != server_id and server["hostname"] == hostname: + from psycopg2.errors import UniqueViolation + raise UniqueViolation("duplicate key value violates unique constraint") + + server = self.servers[server_id] + server["hostname"] = hostname + server["ip_address"] = ip_address + server["datacenter"] = datacenter + server["state"] = state + server["owner"] = owner + server["updated_at"] = datetime.now() + return server + + def delete_server(self, server_id: int) -> bool: + if server_id in self.servers: + del self.servers[server_id] + return True + return False + + +# Global mock database instance +mock_db = MockDatabase() + + +@pytest.fixture(autouse=True) +def reset_mock_db(): + """Reset mock database before each test.""" + mock_db.reset() + + +@pytest.fixture +def client(): + """Create a test client with mocked database.""" + with patch("api.main.create_server", mock_db.create_server), \ + patch("api.main.get_all_servers", mock_db.get_all_servers), \ + patch("api.main.get_server_by_id", mock_db.get_server_by_id), \ + patch("api.main.update_server", mock_db.update_server), \ + patch("api.main.delete_server", mock_db.delete_server), \ + patch("api.main.get_db") as mock_get_db, \ + patch("api.main.init_db"), \ + patch("api.main.close_db"): + # Mock get_db for health check + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_conn.__enter__ = MagicMock(return_value=mock_conn) + mock_conn.__exit__ = MagicMock(return_value=False) + mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + mock_get_db.return_value = mock_conn + + from api.main import app + yield TestClient(app) + + +@pytest.fixture +def sample_server(): + """Sample server data for testing.""" + return { + "hostname": "web-server-01", + "ip_address": "192.168.1.100", + "datacenter": "us-east-1", + "state": "active", + "owner": "dimitrije" + } + + +@pytest.fixture +def created_server(client, sample_server): + """Create a server and return its data.""" + response = client.post("/servers", json=sample_server) + return response.json() diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..b149a92 --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,315 @@ +from concurrent.futures import ThreadPoolExecutor, as_completed + + +class TestCreateServer: + + def test_create_server_success(self, client, sample_server): + response = client.post("/servers", json=sample_server) + + assert response.status_code == 201 + data = response.json() + assert data["hostname"] == sample_server["hostname"] + assert data["ip_address"] == sample_server["ip_address"] + assert data["datacenter"] == sample_server["datacenter"] + assert data["state"] == sample_server["state"] + assert data["owner"] == sample_server["owner"] + assert "id" in data + assert "created_at" in data + assert "updated_at" in data + + def test_create_server_duplicate_hostname(self, client, sample_server): + client.post("/servers", json=sample_server) + response = client.post("/servers", json=sample_server) + + assert response.status_code == 409 + assert "already exists" in response.json()["detail"] + + def test_create_server_invalid_ip(self, client, sample_server): + sample_server["ip_address"] = "invalid-ip" + response = client.post("/servers", json=sample_server) + + assert response.status_code == 422 + + def test_create_server_invalid_state(self, client, sample_server): + sample_server["state"] = "invalid-state" + response = client.post("/servers", json=sample_server) + + assert response.status_code == 422 + + def test_create_server_ipv6(self, client, sample_server): + sample_server["ip_address"] = "2001:0db8:85a3:0000:0000:8a2e:0370:7334" + response = client.post("/servers", json=sample_server) + + assert response.status_code == 201 + assert response.json()["ip_address"] == sample_server["ip_address"] + + def test_create_server_missing_fields(self, client): + response = client.post("/servers", json={"hostname": "test"}) + + assert response.status_code == 422 + + +class TestListServers: + + def test_list_servers_empty(self, client): + response = client.get("/servers") + + assert response.status_code == 200 + assert response.json() == [] + + def test_list_servers_with_data(self, client, sample_server): + client.post("/servers", json=sample_server) + + sample_server2 = sample_server.copy() + sample_server2["hostname"] = "web-server-02" + sample_server2["ip_address"] = "192.168.1.101" + client.post("/servers", json=sample_server2) + + response = client.get("/servers") + + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + + def test_list_servers_pagination(self, client, sample_server): + # Create multiple servers + for i in range(5): + server = sample_server.copy() + server["hostname"] = f"web-server-{i:02d}" + server["ip_address"] = f"192.168.1.{100 + i}" + client.post("/servers", json=server) + + # Test limit + response = client.get("/servers?limit=2") + assert response.status_code == 200 + assert len(response.json()) == 2 + + # Test skip + response = client.get("/servers?skip=3") + assert response.status_code == 200 + assert len(response.json()) == 2 + + # Test skip and limit + response = client.get("/servers?skip=1&limit=2") + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + assert data[0]["hostname"] == "web-server-01" + + def test_list_servers_pagination_validation(self, client): + # Test invalid skip + response = client.get("/servers?skip=-1") + assert response.status_code == 422 + + # Test invalid limit + response = client.get("/servers?limit=0") + assert response.status_code == 422 + + # Test limit exceeds max + response = client.get("/servers?limit=1001") + assert response.status_code == 422 + + +class TestGetServer: + + def test_get_server_success(self, client, created_server): + response = client.get(f"/servers/{created_server['id']}") + + assert response.status_code == 200 + assert response.json()["hostname"] == created_server["hostname"] + + def test_get_server_not_found(self, client): + response = client.get("/servers/99999") + + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +class TestUpdateServer: + + def test_update_server_success(self, client, created_server): + updated_data = { + "hostname": "updated-hostname", + "ip_address": "10.0.0.1", + "datacenter": "us-west-2", + "state": "offline", + "owner": "dimitrije" + } + response = client.put(f"/servers/{created_server['id']}", json=updated_data) + + assert response.status_code == 200 + data = response.json() + assert data["hostname"] == "updated-hostname" + assert data["ip_address"] == "10.0.0.1" + assert data["datacenter"] == "us-west-2" + assert data["state"] == "offline" + assert data["owner"] == "dimitrije" + + def test_update_server_not_found(self, client, sample_server): + response = client.put("/servers/99999", json=sample_server) + + assert response.status_code == 404 + + def test_update_server_duplicate_hostname(self, client, sample_server): + client.post("/servers", json=sample_server) + + sample_server2 = sample_server.copy() + sample_server2["hostname"] = "web-server-02" + sample_server2["ip_address"] = "192.168.1.101" + response2 = client.post("/servers", json=sample_server2) + server2_id = response2.json()["id"] + + sample_server2["hostname"] = sample_server["hostname"] + response = client.put(f"/servers/{server2_id}", json=sample_server2) + + assert response.status_code == 409 + + def test_update_server_same_hostname(self, client, created_server): + updated_data = { + "hostname": created_server["hostname"], + "ip_address": "10.0.0.1", + "datacenter": "us-west-2", + "state": "offline", + "owner": created_server["owner"] + } + response = client.put(f"/servers/{created_server['id']}", json=updated_data) + + assert response.status_code == 200 + + +class TestDeleteServer: + + def test_delete_server_success(self, client, created_server): + response = client.delete(f"/servers/{created_server['id']}") + + assert response.status_code == 204 + + get_response = client.get(f"/servers/{created_server['id']}") + assert get_response.status_code == 404 + + def test_delete_server_not_found(self, client): + response = client.delete("/servers/99999") + + assert response.status_code == 404 + + +class TestValidation: + + def test_valid_states(self, client, sample_server): + for state in ["active", "offline", "retired"]: + sample_server["state"] = state + sample_server["hostname"] = f"server-{state}" + response = client.post("/servers", json=sample_server) + assert response.status_code == 201 + + def test_valid_ipv4_addresses(self, client, sample_server): + valid_ips = ["0.0.0.0", "255.255.255.255", "192.168.1.1", "10.0.0.1"] + for i, ip in enumerate(valid_ips): + sample_server["ip_address"] = ip + sample_server["hostname"] = f"server-ipv4-{i}" + response = client.post("/servers", json=sample_server) + assert response.status_code == 201 + + def test_valid_ipv6_addresses(self, client, sample_server): + valid_ips = [ + "::1", + "fe80::1", + "2001:db8::1", + "2001:0db8:85a3:0000:0000:8a2e:0370:7334" + ] + for i, ip in enumerate(valid_ips): + sample_server["ip_address"] = ip + sample_server["hostname"] = f"server-ipv6-{i}" + response = client.post("/servers", json=sample_server) + assert response.status_code == 201 + + def test_valid_hostnames(self, client, sample_server): + valid_hostnames = [ + "a", + "server1", + "web-server", + "app.example.com", + "db-01.us-east-1.example.com", + ] + for i, hostname in enumerate(valid_hostnames): + sample_server["hostname"] = hostname + sample_server["ip_address"] = f"10.0.0.{i + 1}" + response = client.post("/servers", json=sample_server) + assert response.status_code == 201, f"Failed for hostname: {hostname}" + + def test_invalid_hostnames(self, client, sample_server): + invalid_hostnames = [ + "-server", + "server-", + ".server", + "server.", + "server..name", + "ser ver", + "a" * 256, + ] + for hostname in invalid_hostnames: + sample_server["hostname"] = hostname + response = client.post("/servers", json=sample_server) + assert response.status_code == 422, f"Should fail for hostname: {hostname}" + + +class TestHealthCheck: + + def test_health_check(self, client): + response = client.get("/health") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + assert data["database"] == "connected" + + +class TestConcurrency: + """Test concurrent access to prevent race conditions.""" + + def test_concurrent_create_different_hostnames(self, client, sample_server): + """Multiple threads creating servers with different hostnames should all succeed.""" + + def create_server(index): + server = sample_server.copy() + server["hostname"] = f"concurrent-server-{index}" + server["ip_address"] = f"10.0.0.{index}" + return client.post("/servers", json=server) + + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(create_server, i) for i in range(10)] + results = [f.result() for f in as_completed(futures)] + + success_count = sum(1 for r in results if r.status_code == 201) + assert success_count == 10 + + def test_concurrent_create_same_hostname(self, client, sample_server): + """Multiple threads creating servers with the same hostname - exactly one should succeed.""" + + def create_server(): + server = sample_server.copy() + server["hostname"] = "race-condition-test" + return client.post("/servers", json=server) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(create_server) for _ in range(5)] + results = [f.result() for f in as_completed(futures)] + + success_count = sum(1 for r in results if r.status_code == 201) + conflict_count = sum(1 for r in results if r.status_code == 409) + + assert success_count == 1 + assert conflict_count == 4 + + def test_concurrent_reads(self, client, created_server): + """Multiple concurrent reads should all succeed.""" + + def read_server(): + return client.get(f"/servers/{created_server['id']}") + + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(read_server) for _ in range(10)] + results = [f.result() for f in as_completed(futures)] + + assert all(r.status_code == 200 for r in results) + assert all(r.json()["id"] == created_server["id"] for r in results) diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..c6617b1 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,335 @@ +import json +import pytest +from unittest.mock import patch, MagicMock +from click.testing import CliRunner +from cli.main import cli + + +@pytest.fixture +def runner(): + return CliRunner() + + +@pytest.fixture +def mock_server(): + return { + "id": 1, + "hostname": "web-server-01", + "ip_address": "192.168.1.100", + "datacenter": "us-east-1", + "state": "active", + "owner": "dimitrije", + "created_at": "2024-01-01T00:00:00", + "updated_at": "2024-01-01T00:00:00" + } + + +class TestListCommand: + + @patch("cli.main.requests.request") + def test_list_servers_empty(self, mock_request, runner): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = [] + mock_request.return_value = mock_response + + result = runner.invoke(cli, ["list"]) + + assert result.exit_code == 0 + assert "No servers found" in result.output + + @patch("cli.main.requests.request") + def test_list_servers_with_data(self, mock_request, runner, mock_server): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = [mock_server] + mock_request.return_value = mock_response + + result = runner.invoke(cli, ["list"]) + + assert result.exit_code == 0 + assert mock_server["hostname"] in result.output + + @patch("cli.main.requests.request") + def test_list_servers_json_output(self, mock_request, runner, mock_server): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = [mock_server] + mock_request.return_value = mock_response + + result = runner.invoke(cli, ["list", "--json-output"]) + + assert result.exit_code == 0 + output = json.loads(result.output) + assert len(output) == 1 + assert output[0]["hostname"] == mock_server["hostname"] + + @patch("cli.main.requests.request") + def test_list_servers_with_pagination(self, mock_request, runner, mock_server): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = [mock_server] + mock_request.return_value = mock_response + + result = runner.invoke(cli, ["list", "--skip", "10", "--limit", "50"]) + + assert result.exit_code == 0 + call_kwargs = mock_request.call_args[1] + assert call_kwargs["params"]["skip"] == 10 + assert call_kwargs["params"]["limit"] == 50 + + +class TestGetCommand: + + @patch("cli.main.requests.request") + def test_get_server_success(self, mock_request, runner, mock_server): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = mock_server + mock_request.return_value = mock_response + + result = runner.invoke(cli, ["get", "1"]) + + assert result.exit_code == 0 + assert mock_server["hostname"] in result.output + + @patch("cli.main.requests.request") + def test_get_server_not_found(self, mock_request, runner): + mock_response = MagicMock() + mock_response.status_code = 404 + mock_response.json.return_value = {"detail": "Server not found"} + mock_request.return_value = mock_response + + result = runner.invoke(cli, ["get", "999"]) + + assert result.exit_code == 1 + assert "Server not found" in result.output + + @patch("cli.main.requests.request") + def test_get_server_json_output(self, mock_request, runner, mock_server): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = mock_server + mock_request.return_value = mock_response + + result = runner.invoke(cli, ["get", "1", "--json-output"]) + + assert result.exit_code == 0 + output = json.loads(result.output) + assert output["hostname"] == mock_server["hostname"] + + +class TestCreateCommand: + + @patch("cli.main.requests.request") + def test_create_server_success(self, mock_request, runner, mock_server): + mock_response = MagicMock() + mock_response.status_code = 201 + mock_response.json.return_value = mock_server + mock_request.return_value = mock_response + + result = runner.invoke(cli, [ + "create", + "--hostname", "web-server-01", + "--ip-address", "192.168.1.100", + "--datacenter", "us-east-1", + "--state", "active", + "--owner", "dimitrije" + ]) + + assert result.exit_code == 0 + assert "created successfully" in result.output + + @patch("cli.main.requests.request") + def test_create_server_with_short_options(self, mock_request, runner, mock_server): + mock_response = MagicMock() + mock_response.status_code = 201 + mock_response.json.return_value = mock_server + mock_request.return_value = mock_response + + result = runner.invoke(cli, [ + "create", + "-n", "web-server-01", + "-i", "192.168.1.100", + "-d", "us-east-1", + "-s", "active", + "-o", "dimitrije" + ]) + + assert result.exit_code == 0 + assert "created successfully" in result.output + + @patch("cli.main.requests.request") + def test_create_server_duplicate(self, mock_request, runner): + mock_response = MagicMock() + mock_response.status_code = 409 + mock_response.json.return_value = {"detail": "hostname already exists"} + mock_request.return_value = mock_response + + result = runner.invoke(cli, [ + "create", + "--hostname", "web-server-01", + "--ip-address", "192.168.1.100", + "--datacenter", "us-east-1", + "--state", "active", + "--owner", "dimitrije" + ]) + + assert result.exit_code == 1 + assert "already exists" in result.output + + def test_create_server_missing_required(self, runner): + result = runner.invoke(cli, ["create", "--hostname", "test"]) + + assert result.exit_code != 0 + + +class TestUpdateCommand: + + @patch("cli.main.requests.request") + def test_update_server_success(self, mock_request, runner, mock_server): + mock_server["state"] = "offline" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = mock_server + mock_request.return_value = mock_response + + result = runner.invoke(cli, [ + "update", "1", + "--hostname", "web-server-01", + "--ip-address", "192.168.1.100", + "--datacenter", "us-east-1", + "--state", "offline", + "--owner", "dimitrije" + ]) + + assert result.exit_code == 0 + assert "updated successfully" in result.output + + @patch("cli.main.requests.request") + def test_update_server_not_found(self, mock_request, runner): + mock_response = MagicMock() + mock_response.status_code = 404 + mock_response.json.return_value = {"detail": "Server not found"} + mock_request.return_value = mock_response + + result = runner.invoke(cli, [ + "update", "999", + "--hostname", "test", + "--ip-address", "192.168.1.1", + "--datacenter", "us-east-1", + "--state", "active", + "--owner", "dimitrije" + ]) + + assert result.exit_code == 1 + + +class TestDeleteCommand: + + @patch("cli.main.requests.request") + def test_delete_server_success(self, mock_request, runner): + mock_response = MagicMock() + mock_response.status_code = 204 + mock_request.return_value = mock_response + + result = runner.invoke(cli, ["delete", "1", "--yes"]) + + assert result.exit_code == 0 + assert "deleted successfully" in result.output + + @patch("cli.main.requests.request") + def test_delete_server_not_found(self, mock_request, runner): + mock_response = MagicMock() + mock_response.status_code = 404 + mock_response.json.return_value = {"detail": "Server not found"} + mock_request.return_value = mock_response + + result = runner.invoke(cli, ["delete", "999", "--yes"]) + + assert result.exit_code == 1 + + def test_delete_server_confirmation_abort(self, runner): + result = runner.invoke(cli, ["delete", "1"], input="n\n") + + assert result.exit_code != 0 + + +class TestConnectionError: + + @patch("cli.main.requests.request") + def test_connection_error(self, mock_request, runner): + import requests + mock_request.side_effect = requests.exceptions.ConnectionError() + + result = runner.invoke(cli, ["list"]) + + assert result.exit_code == 1 + assert "Could not connect" in result.output + + @patch("cli.main.requests.request") + def test_connect_timeout(self, mock_request, runner): + import requests + mock_request.side_effect = requests.exceptions.ConnectTimeout() + + result = runner.invoke(cli, ["list"]) + + assert result.exit_code == 1 + assert "Connection timed out" in result.output + + @patch("cli.main.requests.request") + def test_read_timeout(self, mock_request, runner): + import requests + mock_request.side_effect = requests.exceptions.ReadTimeout() + + result = runner.invoke(cli, ["list"]) + + assert result.exit_code == 1 + assert "Request timed out" in result.output + + +class TestApiUrlOption: + + @patch("cli.main.requests.request") + def test_custom_api_url(self, mock_request, runner): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = [] + mock_request.return_value = mock_response + + result = runner.invoke(cli, ["--api-url", "http://custom:9000", "list"]) + + assert result.exit_code == 0 + mock_request.assert_called_once() + call_args = mock_request.call_args + assert "http://custom:9000" in call_args[0][1] + + +class TestTimeoutOption: + + @patch("cli.main.requests.request") + def test_custom_timeout(self, mock_request, runner): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = [] + mock_request.return_value = mock_response + + result = runner.invoke(cli, ["--timeout", "60", "list"]) + + assert result.exit_code == 0 + call_kwargs = mock_request.call_args[1] + assert call_kwargs["timeout"] == (5, 60) + + +class TestVerboseOption: + + @patch("cli.main.requests.request") + def test_verbose_output(self, mock_request, runner): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = [] + mock_request.return_value = mock_response + + result = runner.invoke(cli, ["--verbose", "list"]) + + assert result.exit_code == 0 diff --git a/tests/test_validators.py b/tests/test_validators.py new file mode 100644 index 0000000..38e4ae6 --- /dev/null +++ b/tests/test_validators.py @@ -0,0 +1,165 @@ +from api.models import _validate_ip as validate_ip_address, _validate_hostname as validate_hostname, ServerState + +VALID_STATES = {s.value for s in ServerState} + + +def validate_state(state: str) -> bool: + return state in VALID_STATES + + +class TestIPAddressValidation: + + def test_valid_ipv4(self): + valid_ips = [ + "192.168.1.1", + "10.0.0.1", + "0.0.0.0", + "255.255.255.255", + "127.0.0.1", + ] + for ip in valid_ips: + assert validate_ip_address(ip) is True, f"Should be valid: {ip}" + + def test_valid_ipv6(self): + valid_ips = [ + "::1", + "fe80::1", + "2001:db8::1", + "2001:0db8:85a3:0000:0000:8a2e:0370:7334", + "::ffff:192.168.1.1", + ] + for ip in valid_ips: + assert validate_ip_address(ip) is True, f"Should be valid: {ip}" + + def test_invalid_ip_addresses(self): + invalid_ips = [ + "invalid", + "192.168.1", + "192.168.1.1.1", + "256.0.0.1", + "192.168.1.999", + "", + "localhost", + "192.168.1.1:8080", + ] + for ip in invalid_ips: + assert validate_ip_address(ip) is False, f"Should be invalid: {ip}" + + +class TestStateValidation: + + def test_valid_states(self): + valid_states = ["active", "offline", "retired"] + for state in valid_states: + assert validate_state(state) is True, f"Should be valid: {state}" + + def test_invalid_states(self): + invalid_states = [ + "running", + "stopped", + "Active", + "OFFLINE", + "", + "inactive", + ] + for state in invalid_states: + assert validate_state(state) is False, f"Should be invalid: {state}" + + +class TestHostnameValidation: + + def test_valid_simple_hostnames(self): + valid_hostnames = [ + "a", + "server1", + "web-server", + "my-host-name", + "server123", + "123server", + "a1", + "1a", + ] + for hostname in valid_hostnames: + assert validate_hostname(hostname) is True, f"Should be valid: {hostname}" + + def test_valid_fqdn_hostnames(self): + valid_hostnames = [ + "server.example.com", + "web-01.us-east-1.example.com", + "db.prod.internal", + "a.b.c", + "server1.server2.server3", + ] + for hostname in valid_hostnames: + assert validate_hostname(hostname) is True, f"Should be valid: {hostname}" + + def test_invalid_hostnames_leading_trailing_chars(self): + invalid_hostnames = [ + "-server", + "server-", + ".server", + "server.", + "-", + ".", + ] + for hostname in invalid_hostnames: + assert validate_hostname(hostname) is False, f"Should be invalid: {hostname}" + + def test_invalid_hostnames_consecutive_dots(self): + invalid_hostnames = [ + "server..name", + "a..b", + "server...name", + ] + for hostname in invalid_hostnames: + assert validate_hostname(hostname) is False, f"Should be invalid: {hostname}" + + def test_invalid_hostnames_special_chars(self): + invalid_hostnames = [ + "server name", + "server_name", + "server@name", + "server#name", + "server!", + ] + for hostname in invalid_hostnames: + assert validate_hostname(hostname) is False, f"Should be invalid: {hostname}" + + def test_hostname_length_limits(self): + # Max total length is 255 (must use valid labels of max 63 chars each) + # Create a 255 char hostname: 63 + 1 + 63 + 1 + 63 + 1 + 63 = 255 + long_valid = "a" * 63 + "." + "b" * 63 + "." + "c" * 63 + "." + "d" * 63 + assert len(long_valid) == 255 + assert validate_hostname(long_valid) is True + + # 256 chars should fail + too_long = long_valid + "e" + assert validate_hostname(too_long) is False + + # Empty hostname + assert validate_hostname("") is False + + # Each label max 63 chars + assert validate_hostname("a" * 63) is True + assert validate_hostname("a" * 63 + ".b") is True + assert validate_hostname("a" * 64) is False # Single label > 63 chars + + def test_hostname_label_rules(self): + # Labels must start and end with alphanumeric + assert validate_hostname("a-b") is True + assert validate_hostname("a--b") is True # Consecutive hyphens in middle is OK + assert validate_hostname("-ab") is False + assert validate_hostname("ab-") is False + + # Each label in FQDN follows same rules + assert validate_hostname("ok.ok") is True + assert validate_hostname("-bad.ok") is False + assert validate_hostname("ok.-bad") is False + assert validate_hostname("ok.bad-") is False + + def test_single_character_labels(self): + assert validate_hostname("a") is True + assert validate_hostname("a.b") is True + assert validate_hostname("a.b.c") is True + assert validate_hostname("1") is True + assert validate_hostname("1.2.3") is True