Skip to content

Commit a7a3c8a

Browse files
authored
feat: measure external latency (#779)
1 parent f29ab1d commit a7a3c8a

File tree

9 files changed

+252
-43
lines changed

9 files changed

+252
-43
lines changed

.github/workflows/format.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
- uses: actions/checkout@v3
1212
- uses: actions/setup-go@v3
1313
with:
14-
go-version: "1.21"
14+
go-version: "1.22"
1515
- run: make format
1616
- name: Indicate formatting issues
1717
run: git diff HEAD --exit-code --color

.github/workflows/licenses.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
- uses: actions/checkout@v2
1515
- uses: actions/setup-go@v2
1616
with:
17-
go-version: "1.21"
17+
go-version: "1.22"
1818
- uses: actions/setup-node@v2
1919
with:
2020
node-version: "18"

.github/workflows/test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
- uses: actions/checkout@v2
2020
- uses: actions/setup-go@v2
2121
with:
22-
go-version: "1.21"
22+
go-version: "1.22"
2323
- run: |
2424
go test -tags sqlite -failfast -short -timeout=20m $(go list ./... | grep -v sqlcon | grep -v watcherx | grep -v pkgerx | grep -v configx)
2525
shell: bash
@@ -55,7 +55,7 @@ jobs:
5555
uses: actions/checkout@v2
5656
- uses: actions/setup-go@v2
5757
with:
58-
go-version: "1.21"
58+
go-version: "1.22"
5959
- name: golangci-lint
6060
uses: golangci/golangci-lint-action@v3
6161
with:

go.mod

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
module github.com/ory/x
22

3-
go 1.21
3+
go 1.22
4+
5+
toolchain go1.22.2
46

57
require (
68
code.dny.dev/ssrf v0.2.0

httpx/external_latency.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// Copyright © 2024 Ory Corp
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package httpx
5+
6+
import (
7+
"net/http"
8+
"time"
9+
10+
"github.com/ory/x/reqlog"
11+
)
12+
13+
// MeasureExternalLatencyTransport is an http.RoundTripper that measures the latency of all requests as external latency.
14+
type MeasureExternalLatencyTransport struct {
15+
Transport http.RoundTripper
16+
}
17+
18+
var _ http.RoundTripper = (*MeasureExternalLatencyTransport)(nil)
19+
20+
func (m *MeasureExternalLatencyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
21+
upstreamHostPath := req.URL.Scheme + "://" + req.URL.Host + req.URL.Path
22+
defer reqlog.StartMeasureExternalCall(req.Context(), "http_request", upstreamHostPath, time.Now())
23+
24+
t := m.Transport
25+
if t == nil {
26+
t = http.DefaultTransport
27+
}
28+
return t.RoundTrip(req)
29+
}

proxy/proxy_full_test.go

Lines changed: 57 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -86,20 +86,20 @@ func (rt *testingRoundTripper) RoundTrip(req *http.Request) (*http.Response, err
8686
}
8787

8888
func TestFullIntegration(t *testing.T) {
89-
upstream, upstreamHandler := httpx.NewChanHandler(0)
89+
upstream, upstreamHandler := httpx.NewChanHandler(1)
9090
upstreamServer := httptest.NewTLSServer(upstream)
9191
defer upstreamServer.Close()
9292

9393
// create the proxy
94-
hostMapper := make(chan func(*http.Request) (*HostConfig, error))
95-
reqMiddleware := make(chan ReqMiddleware)
96-
respMiddleware := make(chan RespMiddleware)
94+
hostMapper := make(chan func(*http.Request) (*HostConfig, error), 1)
95+
reqMiddleware := make(chan ReqMiddleware, 1)
96+
respMiddleware := make(chan RespMiddleware, 1)
9797

9898
type CustomErrorReq func(*http.Request, error)
9999
type CustomErrorResp func(*http.Response, error) error
100100

101-
onErrorReq := make(chan CustomErrorReq)
102-
onErrorResp := make(chan CustomErrorResp)
101+
onErrorReq := make(chan CustomErrorReq, 1)
102+
onErrorResp := make(chan CustomErrorResp, 1)
103103

104104
proxy := httptest.NewTLSServer(New(
105105
func(ctx context.Context, r *http.Request) (context.Context, *HostConfig, error) {
@@ -122,17 +122,20 @@ func TestFullIntegration(t *testing.T) {
122122
return f(resp, config, body)
123123
}),
124124
WithOnError(func(request *http.Request, err error) {
125-
f := <-onErrorReq
126-
if f == nil {
127-
return
125+
select {
126+
case f := <-onErrorReq:
127+
f(request, err)
128+
default:
129+
t.Errorf("unexpected error: %+v", err)
128130
}
129-
f(request, err)
130131
}, func(response *http.Response, err error) error {
131-
f := <-onErrorResp
132-
if f == nil {
133-
return nil
132+
select {
133+
case f := <-onErrorResp:
134+
return f(response, err)
135+
default:
136+
t.Errorf("unexpected error: %+v", err)
137+
return err
134138
}
135-
return f(response, err)
136139
})))
137140

138141
cl := proxy.Client()
@@ -315,8 +318,7 @@ func TestFullIntegration(t *testing.T) {
315318
req.Host = "auth.example.com"
316319
return req
317320
},
318-
assertResponse: func(t *testing.T, r *http.Response) {
319-
},
321+
assertResponse: func(t *testing.T, r *http.Response) {},
320322
respMiddleware: func(resp *http.Response, config *HostConfig, body []byte) ([]byte, error) {
321323
return nil, errors.New("some response middleware error")
322324
},
@@ -495,37 +497,55 @@ func TestFullIntegration(t *testing.T) {
495497
},
496498
} {
497499
t.Run("case="+tc.desc, func(t *testing.T) {
498-
go func() {
499-
hostMapper <- func(r *http.Request) (*HostConfig, error) {
500-
host := r.Host
501-
hc, err := tc.hostMapper(host)
502-
if err == nil {
503-
hc.UpstreamHost = urlx.ParseOrPanic(upstreamServer.URL).Host
504-
hc.UpstreamScheme = urlx.ParseOrPanic(upstreamServer.URL).Scheme
505-
hc.TargetHost = hc.UpstreamHost
506-
hc.TargetScheme = hc.UpstreamScheme
507-
}
508-
return hc, err
500+
hostMapper <- func(r *http.Request) (*HostConfig, error) {
501+
host := r.Host
502+
hc, err := tc.hostMapper(host)
503+
if err == nil {
504+
hc.UpstreamHost = urlx.ParseOrPanic(upstreamServer.URL).Host
505+
hc.UpstreamScheme = urlx.ParseOrPanic(upstreamServer.URL).Scheme
506+
hc.TargetHost = hc.UpstreamHost
507+
hc.TargetScheme = hc.UpstreamScheme
509508
}
509+
return hc, err
510+
}
511+
if tc.onErrReq != nil {
512+
onErrorReq <- tc.onErrReq
513+
}
514+
if tc.onErrResp != nil {
515+
onErrorResp <- tc.onErrResp
516+
}
517+
518+
if tc.onErrReq == nil {
519+
// we will only send a request if there is no request error
510520
reqMiddleware <- tc.reqMiddleware
521+
respMiddleware <- tc.respMiddleware
511522
upstreamHandler <- func(w http.ResponseWriter, r *http.Request) {
512523
t := &remoteT{t: t, w: w, r: r}
513524
tc.handler(assert.New(t), t, r)
514525
}
515-
respMiddleware <- tc.respMiddleware
516-
}()
517-
518-
go func() {
519-
onErrorReq <- tc.onErrReq
520-
}()
521-
522-
go func() {
523-
onErrorResp <- tc.onErrResp
524-
}()
526+
}
525527

526528
resp, err := cl.Do(tc.request(t))
527529
require.NoError(t, err)
528530
tc.assertResponse(t, resp)
531+
532+
select {
533+
case <-hostMapper:
534+
t.Fatal("host mapper not consumed")
535+
case <-reqMiddleware:
536+
t.Fatal("req middleware not consumed")
537+
case <-respMiddleware:
538+
t.Fatal("resp middleware not consumed")
539+
case <-onErrorReq:
540+
t.Fatal("req error not consumed")
541+
case <-onErrorResp:
542+
t.Fatal("resp error not consumed")
543+
default:
544+
if len(upstreamHandler) != 0 {
545+
t.Fatal("upstream handler not consumed")
546+
}
547+
return
548+
}
529549
})
530550
}
531551
}

reqlog/external_latency.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// Copyright © 2024 Ory Corp
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package reqlog
5+
6+
import (
7+
"context"
8+
"sync"
9+
"time"
10+
)
11+
12+
// WithEnableExternalLatencyMeasurement returns a context that measures external latencies.
13+
func WithEnableExternalLatencyMeasurement(ctx context.Context) context.Context {
14+
container := contextContainer{
15+
latencies: make([]externalLatency, 0),
16+
}
17+
return context.WithValue(ctx, externalLatencyKey, &container)
18+
}
19+
20+
// StartMeasureExternalCall starts measuring the duration of an external call.
21+
// The returned function has to be called to record the duration.
22+
func StartMeasureExternalCall(ctx context.Context, cause, detail string, start time.Time) {
23+
container, ok := ctx.Value(externalLatencyKey).(*contextContainer)
24+
if !ok {
25+
return
26+
}
27+
if _, ok := ctx.Value(disableExternalLatencyMeasurement).(bool); ok {
28+
return
29+
}
30+
31+
container.Lock()
32+
defer container.Unlock()
33+
container.latencies = append(container.latencies, externalLatency{
34+
Took: time.Since(start),
35+
Cause: cause,
36+
Detail: detail,
37+
})
38+
}
39+
40+
// totalExternalLatency returns the total duration of all external calls.
41+
func totalExternalLatency(ctx context.Context) (total time.Duration) {
42+
if _, ok := ctx.Value(disableExternalLatencyMeasurement).(bool); ok {
43+
return 0
44+
}
45+
container, ok := ctx.Value(externalLatencyKey).(*contextContainer)
46+
if !ok {
47+
return 0
48+
}
49+
50+
container.Lock()
51+
defer container.Unlock()
52+
for _, l := range container.latencies {
53+
total += l.Took
54+
}
55+
return total
56+
}
57+
58+
// WithDisableExternalLatencyMeasurement returns a context that does not measure external latencies.
59+
// Use this when you want to disable external latency measurements for a specific request.
60+
func WithDisableExternalLatencyMeasurement(ctx context.Context) context.Context {
61+
return context.WithValue(ctx, disableExternalLatencyMeasurement, true)
62+
}
63+
64+
type (
65+
externalLatency = struct {
66+
Took time.Duration
67+
Cause, Detail string
68+
}
69+
contextContainer = struct {
70+
latencies []externalLatency
71+
sync.Mutex
72+
}
73+
contextKey int
74+
)
75+
76+
const (
77+
externalLatencyKey contextKey = 1
78+
disableExternalLatencyMeasurement contextKey = 2
79+
)

reqlog/external_latency_test.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// Copyright © 2024 Ory Corp
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package reqlog
5+
6+
import (
7+
"encoding/json"
8+
"io"
9+
"net/http"
10+
"net/http/httptest"
11+
"sync"
12+
"testing"
13+
"time"
14+
15+
"github.com/stretchr/testify/assert"
16+
"github.com/stretchr/testify/require"
17+
"github.com/tidwall/gjson"
18+
"golang.org/x/sync/errgroup"
19+
)
20+
21+
func TestExternalLatencyMiddleware(t *testing.T) {
22+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
23+
NewMiddleware().ServeHTTP(w, r, func(w http.ResponseWriter, r *http.Request) {
24+
var wg sync.WaitGroup
25+
26+
wg.Add(3)
27+
for i := range 3 {
28+
ctx := r.Context()
29+
if i%3 == 0 {
30+
ctx = WithDisableExternalLatencyMeasurement(ctx)
31+
}
32+
go func() {
33+
defer StartMeasureExternalCall(ctx, "", "", time.Now())
34+
time.Sleep(100 * time.Millisecond)
35+
wg.Done()
36+
}()
37+
}
38+
wg.Wait()
39+
total := totalExternalLatency(r.Context())
40+
_ = json.NewEncoder(w).Encode(map[string]any{
41+
"total": total,
42+
})
43+
})
44+
}))
45+
defer ts.Close()
46+
47+
bodies := make([][]byte, 100)
48+
eg := errgroup.Group{}
49+
for i := range bodies {
50+
eg.Go(func() error {
51+
res, err := http.Get(ts.URL)
52+
if err != nil {
53+
return err
54+
}
55+
defer res.Body.Close()
56+
bodies[i], err = io.ReadAll(res.Body)
57+
if err != nil {
58+
return err
59+
}
60+
return nil
61+
})
62+
}
63+
64+
require.NoError(t, eg.Wait())
65+
66+
for _, body := range bodies {
67+
actualTotal := gjson.GetBytes(body, "total").Int()
68+
assert.GreaterOrEqual(t, actualTotal, int64(200*time.Millisecond), string(body))
69+
assert.Less(t, actualTotal, int64(300*time.Millisecond), string(body))
70+
}
71+
}

reqlog/middleware.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ func (m *Middleware) ServeHTTP(rw http.ResponseWriter, r *http.Request, next htt
139139
nrw = negroni.NewResponseWriter(rw)
140140
}
141141

142+
r = r.WithContext(WithEnableExternalLatencyMeasurement(r.Context()))
142143
next(nrw, r)
143144

144145
latency := m.clock.Since(start)
@@ -161,11 +162,18 @@ func DefaultBefore(entry *logrusx.Logger, req *http.Request, remoteAddr string)
161162

162163
// DefaultAfter is the default func assigned to *Middleware.After
163164
func DefaultAfter(entry *logrusx.Logger, req *http.Request, res negroni.ResponseWriter, latency time.Duration, name string) *logrusx.Logger {
164-
return entry.WithRequest(req).WithField("http_response", map[string]interface{}{
165+
e := entry.WithRequest(req).WithField("http_response", map[string]any{
165166
"status": res.Status(),
166167
"size": res.Size(),
167168
"text_status": http.StatusText(res.Status()),
168169
"took": latency,
169170
"headers": entry.HTTPHeadersRedacted(res.Header()),
170171
})
172+
if el := totalExternalLatency(req.Context()); el > 0 {
173+
e = e.WithFields(map[string]any{
174+
"took_internal": latency - el,
175+
"took_external": el,
176+
})
177+
}
178+
return e
171179
}

0 commit comments

Comments
 (0)