diff --git a/blackwords.go b/blackwords.go index 01c8502..8954db3 100644 --- a/blackwords.go +++ b/blackwords.go @@ -20,8 +20,7 @@ func BlackWords(words ...string) func(http.Handler) http.Handler { if body != "" { for _, word := range words { if strings.Contains(body, strings.ToLower(word)) { - w.WriteHeader(http.StatusForbidden) - RenderJSON(w, JSON{"error": "one of blacklisted words detected"}) + _ = EncodeJSON(w, http.StatusForbidden, JSON{"error": "one of blacklisted words detected"}) return } } diff --git a/blackwords_test.go b/blackwords_test.go index 67c368f..02c6216 100644 --- a/blackwords_test.go +++ b/blackwords_test.go @@ -82,3 +82,19 @@ func TestBlackwordsFn(t *testing.T) { }) } } + +func TestBlackwordsContentType(t *testing.T) { + bwMiddleware := BlackWords("bad1", "bad2") + ts := httptest.NewServer(bwMiddleware(getTestHandlerBlah())) + defer ts.Close() + + client := http.Client{Timeout: 5 * time.Second} + req, err := http.NewRequest("GET", ts.URL+"/something", bytes.NewBuffer([]byte("contains bad1 word"))) + assert.NoError(t, err) + + resp, err := client.Do(req) + assert.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusForbidden, resp.StatusCode) + assert.Equal(t, "application/json; charset=utf-8", resp.Header.Get("Content-Type")) +} diff --git a/metrics.go b/metrics.go index 1cf0a41..6edc2da 100644 --- a/metrics.go +++ b/metrics.go @@ -13,8 +13,7 @@ func Metrics(onlyIps ...string) func(http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { if r.Method == "GET" && strings.HasSuffix(strings.ToLower(r.URL.Path), "/metrics") { if matched, ip, err := matchSourceIP(r, onlyIps); !matched || err != nil { - w.WriteHeader(http.StatusForbidden) - RenderJSON(w, JSON{"error": fmt.Sprintf("ip %s rejected", ip)}) + _ = EncodeJSON(w, http.StatusForbidden, JSON{"error": fmt.Sprintf("ip %s rejected", ip)}) return } expvar.Handler().ServeHTTP(w, r) diff --git a/metrics_test.go b/metrics_test.go index 96e17aa..bcb2f5b 100644 --- a/metrics_test.go +++ b/metrics_test.go @@ -43,3 +43,18 @@ func TestMetricsRejected(t *testing.T) { defer resp.Body.Close() assert.Equal(t, 403, resp.StatusCode) } + +func TestMetricsContentType(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, err := w.Write([]byte("blah blah")) + require.NoError(t, err) + }) + ts := httptest.NewServer(Metrics("1.1.1.1")(handler)) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/metrics") + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusForbidden, resp.StatusCode) + assert.Equal(t, "application/json; charset=utf-8", resp.Header.Get("Content-Type")) +} diff --git a/middleware.go b/middleware.go index 86b3cc1..c04af1f 100644 --- a/middleware.go +++ b/middleware.go @@ -82,12 +82,11 @@ func Health(path string, checkers ...func(ctx context.Context) (name string, err } resp = append(resp, hh) } + status := http.StatusOK if anyError { - w.WriteHeader(http.StatusServiceUnavailable) - } else { - w.WriteHeader(http.StatusOK) + status = http.StatusServiceUnavailable } - RenderJSON(w, resp) + _ = EncodeJSON(w, status, resp) } return http.HandlerFunc(fn) } diff --git a/middleware_test.go b/middleware_test.go index 7f57d7f..fe2da88 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -271,6 +271,41 @@ func TestHealthFailed(t *testing.T) { assert.Equal(t, `[{"name":"check1","status":"ok"},{"name":"check2","status":"failed","error":"some error"}]`+"\n", string(b)) } +func TestHealthContentType(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, err := w.Write([]byte("blah blah")) + require.NoError(t, err) + }) + + t.Run("healthy returns json content-type", func(t *testing.T) { + check := func(context.Context) (string, error) { + return "check1", nil + } + ts := httptest.NewServer(Health("/health", check)(handler)) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/health") + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "application/json; charset=utf-8", resp.Header.Get("Content-Type")) + }) + + t.Run("unhealthy returns json content-type", func(t *testing.T) { + check := func(context.Context) (string, error) { + return "check1", errors.New("some error") + } + ts := httptest.NewServer(Health("/health", check)(handler)) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/health") + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + assert.Equal(t, "application/json; charset=utf-8", resp.Header.Get("Content-Type")) + }) +} + func TestReject(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, err := w.Write([]byte("blah blah")) diff --git a/onlyfrom.go b/onlyfrom.go index cccd77e..acc8ff2 100644 --- a/onlyfrom.go +++ b/onlyfrom.go @@ -21,8 +21,7 @@ func OnlyFrom(onlyIps ...string) func(http.Handler) http.Handler { } matched, ip, err := matchSourceIP(r, onlyIps) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - RenderJSON(w, JSON{"error": fmt.Sprintf("can't get realip: %s", err)}) + _ = EncodeJSON(w, http.StatusInternalServerError, JSON{"error": fmt.Sprintf("can't get realip: %s", err)}) return } if matched { @@ -31,8 +30,7 @@ func OnlyFrom(onlyIps ...string) func(http.Handler) http.Handler { return } - w.WriteHeader(http.StatusForbidden) - RenderJSON(w, JSON{"error": fmt.Sprintf("ip %q rejected", ip)}) + _ = EncodeJSON(w, http.StatusForbidden, JSON{"error": fmt.Sprintf("ip %q rejected", ip)}) } return http.HandlerFunc(fn) } diff --git a/onlyfrom_test.go b/onlyfrom_test.go index 0d0d6fb..c86bece 100644 --- a/onlyfrom_test.go +++ b/onlyfrom_test.go @@ -150,3 +150,36 @@ func TestOnlyFromErrors(t *testing.T) { }) } } + +func TestOnlyFromContentType(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, err := w.Write([]byte("blah blah")) + require.NoError(t, err) + }) + + t.Run("rejected ip returns json content-type", func(t *testing.T) { + ts := httptest.NewServer(OnlyFrom("1.1.1.1")(handler)) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/blah") + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusForbidden, resp.StatusCode) + assert.Equal(t, "application/json; charset=utf-8", resp.Header.Get("Content-Type")) + }) + + t.Run("invalid remote addr returns json content-type", func(t *testing.T) { + outerHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.RemoteAddr = "bad-addr" + OnlyFrom("1.1.1.1")(handler).ServeHTTP(w, r) + }) + ts := httptest.NewServer(outerHandler) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/blah") + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.Equal(t, "application/json; charset=utf-8", resp.Header.Get("Content-Type")) + }) +}