Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions blackwords.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ func BlackWords(words ...string) func(http.Handler) http.Handler {
if body != "" {
for _, word := range words {
if strings.Contains(body, strings.ToLower(word)) {
w.WriteHeader(http.StatusForbidden)
RenderJSON(w, JSON{"error": "one of blacklisted words detected"})
_ = EncodeJSON(w, http.StatusForbidden, JSON{"error": "one of blacklisted words detected"})
return
}
}
Expand Down
16 changes: 16 additions & 0 deletions blackwords_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,19 @@ func TestBlackwordsFn(t *testing.T) {
})
}
}

func TestBlackwordsContentType(t *testing.T) {
bwMiddleware := BlackWords("bad1", "bad2")
ts := httptest.NewServer(bwMiddleware(getTestHandlerBlah()))
defer ts.Close()

client := http.Client{Timeout: 5 * time.Second}
req, err := http.NewRequest("GET", ts.URL+"/something", bytes.NewBuffer([]byte("contains bad1 word")))
assert.NoError(t, err)

resp, err := client.Do(req)
assert.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
assert.Equal(t, "application/json; charset=utf-8", resp.Header.Get("Content-Type"))
}
3 changes: 1 addition & 2 deletions metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ func Metrics(onlyIps ...string) func(http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
if r.Method == "GET" && strings.HasSuffix(strings.ToLower(r.URL.Path), "/metrics") {
if matched, ip, err := matchSourceIP(r, onlyIps); !matched || err != nil {
w.WriteHeader(http.StatusForbidden)
RenderJSON(w, JSON{"error": fmt.Sprintf("ip %s rejected", ip)})
_ = EncodeJSON(w, http.StatusForbidden, JSON{"error": fmt.Sprintf("ip %s rejected", ip)})
return
}
expvar.Handler().ServeHTTP(w, r)
Expand Down
15 changes: 15 additions & 0 deletions metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,18 @@ func TestMetricsRejected(t *testing.T) {
defer resp.Body.Close()
assert.Equal(t, 403, resp.StatusCode)
}

func TestMetricsContentType(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, err := w.Write([]byte("blah blah"))
require.NoError(t, err)
})
ts := httptest.NewServer(Metrics("1.1.1.1")(handler))
defer ts.Close()

resp, err := http.Get(ts.URL + "/metrics")
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
assert.Equal(t, "application/json; charset=utf-8", resp.Header.Get("Content-Type"))
}
7 changes: 3 additions & 4 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,11 @@ func Health(path string, checkers ...func(ctx context.Context) (name string, err
}
resp = append(resp, hh)
}
status := http.StatusOK
if anyError {
w.WriteHeader(http.StatusServiceUnavailable)
} else {
w.WriteHeader(http.StatusOK)
status = http.StatusServiceUnavailable
}
RenderJSON(w, resp)
_ = EncodeJSON(w, status, resp)
}
return http.HandlerFunc(fn)
}
Expand Down
35 changes: 35 additions & 0 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,41 @@ func TestHealthFailed(t *testing.T) {
assert.Equal(t, `[{"name":"check1","status":"ok"},{"name":"check2","status":"failed","error":"some error"}]`+"\n", string(b))
}

func TestHealthContentType(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, err := w.Write([]byte("blah blah"))
require.NoError(t, err)
})

t.Run("healthy returns json content-type", func(t *testing.T) {
check := func(context.Context) (string, error) {
return "check1", nil
}
ts := httptest.NewServer(Health("/health", check)(handler))
defer ts.Close()

resp, err := http.Get(ts.URL + "/health")
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "application/json; charset=utf-8", resp.Header.Get("Content-Type"))
})

t.Run("unhealthy returns json content-type", func(t *testing.T) {
check := func(context.Context) (string, error) {
return "check1", errors.New("some error")
}
ts := httptest.NewServer(Health("/health", check)(handler))
defer ts.Close()

resp, err := http.Get(ts.URL + "/health")
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
assert.Equal(t, "application/json; charset=utf-8", resp.Header.Get("Content-Type"))
})
}

func TestReject(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, err := w.Write([]byte("blah blah"))
Expand Down
6 changes: 2 additions & 4 deletions onlyfrom.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ func OnlyFrom(onlyIps ...string) func(http.Handler) http.Handler {
}
matched, ip, err := matchSourceIP(r, onlyIps)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
RenderJSON(w, JSON{"error": fmt.Sprintf("can't get realip: %s", err)})
_ = EncodeJSON(w, http.StatusInternalServerError, JSON{"error": fmt.Sprintf("can't get realip: %s", err)})
return
}
if matched {
Expand All @@ -31,8 +30,7 @@ func OnlyFrom(onlyIps ...string) func(http.Handler) http.Handler {
return
}

w.WriteHeader(http.StatusForbidden)
RenderJSON(w, JSON{"error": fmt.Sprintf("ip %q rejected", ip)})
_ = EncodeJSON(w, http.StatusForbidden, JSON{"error": fmt.Sprintf("ip %q rejected", ip)})
}
return http.HandlerFunc(fn)
}
Expand Down
33 changes: 33 additions & 0 deletions onlyfrom_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,36 @@ func TestOnlyFromErrors(t *testing.T) {
})
}
}

func TestOnlyFromContentType(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, err := w.Write([]byte("blah blah"))
require.NoError(t, err)
})

t.Run("rejected ip returns json content-type", func(t *testing.T) {
ts := httptest.NewServer(OnlyFrom("1.1.1.1")(handler))
defer ts.Close()

resp, err := http.Get(ts.URL + "/blah")
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
assert.Equal(t, "application/json; charset=utf-8", resp.Header.Get("Content-Type"))
})

t.Run("invalid remote addr returns json content-type", func(t *testing.T) {
outerHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.RemoteAddr = "bad-addr"
OnlyFrom("1.1.1.1")(handler).ServeHTTP(w, r)
})
ts := httptest.NewServer(outerHandler)
defer ts.Close()

resp, err := http.Get(ts.URL + "/blah")
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
assert.Equal(t, "application/json; charset=utf-8", resp.Header.Get("Content-Type"))
})
}
Loading