mirror of
https://github.com/kyverno/kyverno.git
synced 2025-03-15 04:07:46 +00:00
* fix: use http.MaxBytesReader instead of content length for API Calls * feat: add unit tests * feat: added test for chunked transfer --------- Signed-off-by: Vishal Choudhary <vishal.choudhary@nirmata.com> Co-authored-by: Vishal Choudhary <vishal.choudhary@nirmata.com>
This commit is contained in:
parent
98f2162413
commit
92028dfd9b
2 changed files with 78 additions and 43 deletions
|
@ -141,19 +141,14 @@ func (a *apiCall) executeServiceCall(ctx context.Context, apiCall *kyvernov1.API
|
||||||
return nil, fmt.Errorf("failed to execute HTTP request for APICall %s: %w", a.entry.Name, err)
|
return nil, fmt.Errorf("failed to execute HTTP request for APICall %s: %w", a.entry.Name, err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
var w http.ResponseWriter
|
||||||
|
|
||||||
if a.config.maxAPICallResponseLength != 0 {
|
if a.config.maxAPICallResponseLength != 0 {
|
||||||
if resp.ContentLength <= 0 {
|
resp.Body = http.MaxBytesReader(w, resp.Body, a.config.maxAPICallResponseLength)
|
||||||
return nil, fmt.Errorf("content length header must be present.")
|
|
||||||
}
|
|
||||||
if resp.ContentLength > a.config.maxAPICallResponseLength {
|
|
||||||
return nil, fmt.Errorf("content length must be less than max response length of %d.", a.config.maxAPICallResponseLength)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
reader := io.LimitReader(resp.Body, max(a.config.maxAPICallResponseLength, resp.ContentLength))
|
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
b, err := io.ReadAll(reader)
|
b, err := io.ReadAll(resp.Body)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return nil, fmt.Errorf("HTTP %s: %s", resp.Status, string(b))
|
return nil, fmt.Errorf("HTTP %s: %s", resp.Status, string(b))
|
||||||
}
|
}
|
||||||
|
@ -161,10 +156,14 @@ func (a *apiCall) executeServiceCall(ctx context.Context, apiCall *kyvernov1.API
|
||||||
return nil, fmt.Errorf("HTTP %s", resp.Status)
|
return nil, fmt.Errorf("HTTP %s", resp.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(reader)
|
body, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if _, ok := err.(*http.MaxBytesError); ok {
|
||||||
|
return nil, fmt.Errorf("response length must be less than max allowed response length of %d.", a.config.maxAPICallResponseLength)
|
||||||
|
} else {
|
||||||
return nil, fmt.Errorf("failed to read data from APICall %s: %w", a.entry.Name, err)
|
return nil, fmt.Errorf("failed to read data from APICall %s: %w", a.entry.Name, err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
a.logger.Info("executed service APICall", "name", a.entry.Name, "len", len(body))
|
a.logger.Info("executed service APICall", "name", a.entry.Name, "len", len(body))
|
||||||
return body, nil
|
return body, nil
|
||||||
|
|
|
@ -2,6 +2,7 @@ package apicall
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
@ -21,13 +22,31 @@ var (
|
||||||
apiConfig = APICallConfiguration{
|
apiConfig = APICallConfiguration{
|
||||||
maxAPICallResponseLength: 1 * 1000 * 1000,
|
maxAPICallResponseLength: 1 * 1000 * 1000,
|
||||||
}
|
}
|
||||||
|
apiConfigMaxSizeExceed = APICallConfiguration{
|
||||||
|
maxAPICallResponseLength: 10,
|
||||||
|
}
|
||||||
|
apiConfigWithoutSecurityCheck = APICallConfiguration{
|
||||||
|
maxAPICallResponseLength: 0,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func buildTestServer(responseData []byte) *httptest.Server {
|
func buildTestServer(responseData []byte, useChunked bool) *httptest.Server {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.HandleFunc("/resource", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/resource", func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method == "GET" {
|
if r.Method == "GET" {
|
||||||
w.Write(responseData)
|
w.Write(responseData)
|
||||||
|
|
||||||
|
if useChunked {
|
||||||
|
flusher, ok := w.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
panic("expected http.ResponseWriter to be an http.Flusher")
|
||||||
|
}
|
||||||
|
for i := 1; i <= 10; i++ {
|
||||||
|
fmt.Fprintf(w, "Chunk #%d\n", i)
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,8 +61,9 @@ func buildTestServer(responseData []byte) *httptest.Server {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_serviceGetRequest(t *testing.T) {
|
func Test_serviceGetRequest(t *testing.T) {
|
||||||
|
testfn := func(t *testing.T, useChunked bool) {
|
||||||
serverResponse := []byte(`{ "day": "Sunday" }`)
|
serverResponse := []byte(`{ "day": "Sunday" }`)
|
||||||
s := buildTestServer(serverResponse)
|
s := buildTestServer(serverResponse, false)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
entry := kyvernov1.ContextEntry{}
|
entry := kyvernov1.ContextEntry{}
|
||||||
|
@ -78,11 +98,27 @@ func Test_serviceGetRequest(t *testing.T) {
|
||||||
assert.NilError(t, err)
|
assert.NilError(t, err)
|
||||||
assert.Assert(t, data != nil, "nil data")
|
assert.Assert(t, data != nil, "nil data")
|
||||||
assert.Equal(t, string(serverResponse), string(data))
|
assert.Equal(t, string(serverResponse), string(data))
|
||||||
|
|
||||||
|
call, err = New(logr.Discard(), jp, entry, ctx, nil, apiConfigMaxSizeExceed)
|
||||||
|
assert.NilError(t, err)
|
||||||
|
data, err = call.FetchAndLoad(context.TODO())
|
||||||
|
assert.ErrorContains(t, err, "response length must be less than max allowed response length of 10.")
|
||||||
|
|
||||||
|
call, err = New(logr.Discard(), jp, entry, ctx, nil, apiConfigWithoutSecurityCheck)
|
||||||
|
assert.NilError(t, err)
|
||||||
|
data, err = call.FetchAndLoad(context.TODO())
|
||||||
|
assert.NilError(t, err)
|
||||||
|
assert.Assert(t, data != nil, "nil data")
|
||||||
|
assert.Equal(t, string(serverResponse), string(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Standard", func(t *testing.T) { testfn(t, false) })
|
||||||
|
t.Run("Chunked", func(t *testing.T) { testfn(t, true) })
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_servicePostRequest(t *testing.T) {
|
func Test_servicePostRequest(t *testing.T) {
|
||||||
serverResponse := []byte(`{ "day": "Monday" }`)
|
serverResponse := []byte(`{ "day": "Monday" }`)
|
||||||
s := buildTestServer(serverResponse)
|
s := buildTestServer(serverResponse, false)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
entry := kyvernov1.ContextEntry{
|
entry := kyvernov1.ContextEntry{
|
||||||
|
|
Loading…
Add table
Reference in a new issue