Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions tools/lib/api.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,29 @@
import os
import requests
from requests.adapters import HTTPAdapter, Retry

from openpilot.common.swaglog import cloudlog
API_HOST = os.getenv('API_HOST', 'https://api.commadotai.com')

# TODO: this should be merged into common.api


class LogCallbackRetry(Retry):
def increment(self, method=None, url=None, *args, **kwargs):
if url:
cloudlog.warning(f"[API Failure] Retrying {method} {url}")
return super().increment(method, url, *args, **kwargs)

class CommaApi:
def __init__(self, token=None):
self.session = requests.Session()
self.session.headers['User-agent'] = 'OpenpilotTools'
if token:
self.session.headers['Authorization'] = 'JWT ' + token

retries = LogCallbackRetry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
self.session.mount('https://', HTTPAdapter(max_retries=retries))

def request(self, method, endpoint, **kwargs):
with self.session.request(method, API_HOST + '/' + endpoint, **kwargs) as resp:
resp_json = resp.json()
Expand Down
79 changes: 79 additions & 0 deletions tools/lib/tests/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import pytest
from openpilot.tools.lib.api import LogCallbackRetry, CommaApi, UnauthorizedError, APIError

class TestLogCallbackRetry:
def test_increment_logs_warning(self, mocker):
mock_cloudlog = mocker.patch("openpilot.tools.lib.api.cloudlog")
retry = LogCallbackRetry(total=1)

retry.increment(method="GET", url="http://test.com")

mock_cloudlog.warning.assert_called_with("[API Failure] Retrying GET http://test.com")

def test_increment_no_url(self, mocker):
mock_cloudlog = mocker.patch("openpilot.tools.lib.api.cloudlog")
retry = LogCallbackRetry(total=1)

retry.increment(method="GET")

mock_cloudlog.warning.assert_not_called()

def test_retry_configuration(self):
from openpilot.tools.lib.api import LogCallbackRetry
from requests.adapters import HTTPAdapter

api = CommaApi(token="test_token")
adapter = api.session.adapters['https://']
assert isinstance(adapter, HTTPAdapter)
assert isinstance(adapter.max_retries, LogCallbackRetry)
assert adapter.max_retries.total == 5


class TestCommaApi:
@pytest.fixture(autouse=True)
def setup(self):
self.api = CommaApi()

def test_init_token(self):
api = CommaApi(token="test_token")
assert api.session.headers['Authorization'] == 'JWT test_token'

def test_request_success(self, mocker):
mock_resp = mocker.MagicMock()
mock_resp.json.return_value = {"key": "value"}
mock_resp.status_code = 200
mock_request = mocker.patch("openpilot.tools.lib.api.requests.Session.request")
mock_request.return_value.__enter__.return_value = mock_resp

resp = self.api.request("GET", "test_endpoint")
assert resp == {"key": "value"}

def test_request_unauthorized(self, mocker):
mock_resp = mocker.MagicMock()
mock_resp.json.return_value = {"error": "unauthorized"}
mock_resp.status_code = 401
mock_request = mocker.patch("openpilot.tools.lib.api.requests.Session.request")
mock_request.return_value.__enter__.return_value = mock_resp

with pytest.raises(UnauthorizedError):
self.api.request("GET", "test_endpoint")

def test_request_api_error(self, mocker):
mock_resp = mocker.MagicMock()
mock_resp.json.return_value = {"error": "server error", "description": "details"}
mock_resp.status_code = 500
mock_request = mocker.patch("openpilot.tools.lib.api.requests.Session.request")
mock_request.return_value.__enter__.return_value = mock_resp

with pytest.raises(APIError) as cm:
self.api.request("GET", "test_endpoint")
assert cm.value.status_code == 500
assert "details" in str(cm.value)

def test_get_post(self, mocker):
mock_request = mocker.patch("openpilot.tools.lib.api.CommaApi.request")
self.api.get("endpoint", param="1")
mock_request.assert_called_with("GET", "endpoint", param="1")

self.api.post("endpoint", data="2")
mock_request.assert_called_with("POST", "endpoint", data="2")