@@ -86,20 +86,20 @@ func (rt *testingRoundTripper) RoundTrip(req *http.Request) (*http.Response, err
8686}
8787
8888func TestFullIntegration (t * testing.T ) {
89- upstream , upstreamHandler := httpx .NewChanHandler (0 )
89+ upstream , upstreamHandler := httpx .NewChanHandler (1 )
9090 upstreamServer := httptest .NewTLSServer (upstream )
9191 defer upstreamServer .Close ()
9292
9393 // create the proxy
94- hostMapper := make (chan func (* http.Request ) (* HostConfig , error ))
95- reqMiddleware := make (chan ReqMiddleware )
96- respMiddleware := make (chan RespMiddleware )
94+ hostMapper := make (chan func (* http.Request ) (* HostConfig , error ), 1 )
95+ reqMiddleware := make (chan ReqMiddleware , 1 )
96+ respMiddleware := make (chan RespMiddleware , 1 )
9797
9898 type CustomErrorReq func (* http.Request , error )
9999 type CustomErrorResp func (* http.Response , error ) error
100100
101- onErrorReq := make (chan CustomErrorReq )
102- onErrorResp := make (chan CustomErrorResp )
101+ onErrorReq := make (chan CustomErrorReq , 1 )
102+ onErrorResp := make (chan CustomErrorResp , 1 )
103103
104104 proxy := httptest .NewTLSServer (New (
105105 func (ctx context.Context , r * http.Request ) (context.Context , * HostConfig , error ) {
@@ -122,17 +122,20 @@ func TestFullIntegration(t *testing.T) {
122122 return f (resp , config , body )
123123 }),
124124 WithOnError (func (request * http.Request , err error ) {
125- f := <- onErrorReq
126- if f == nil {
127- return
125+ select {
126+ case f := <- onErrorReq :
127+ f (request , err )
128+ default :
129+ t .Errorf ("unexpected error: %+v" , err )
128130 }
129- f (request , err )
130131 }, func (response * http.Response , err error ) error {
131- f := <- onErrorResp
132- if f == nil {
133- return nil
132+ select {
133+ case f := <- onErrorResp :
134+ return f (response , err )
135+ default :
136+ t .Errorf ("unexpected error: %+v" , err )
137+ return err
134138 }
135- return f (response , err )
136139 })))
137140
138141 cl := proxy .Client ()
@@ -315,8 +318,7 @@ func TestFullIntegration(t *testing.T) {
315318 req .Host = "auth.example.com"
316319 return req
317320 },
318- assertResponse : func (t * testing.T , r * http.Response ) {
319- },
321+ assertResponse : func (t * testing.T , r * http.Response ) {},
320322 respMiddleware : func (resp * http.Response , config * HostConfig , body []byte ) ([]byte , error ) {
321323 return nil , errors .New ("some response middleware error" )
322324 },
@@ -495,37 +497,55 @@ func TestFullIntegration(t *testing.T) {
495497 },
496498 } {
497499 t .Run ("case=" + tc .desc , func (t * testing.T ) {
498- go func () {
499- hostMapper <- func (r * http.Request ) (* HostConfig , error ) {
500- host := r .Host
501- hc , err := tc .hostMapper (host )
502- if err == nil {
503- hc .UpstreamHost = urlx .ParseOrPanic (upstreamServer .URL ).Host
504- hc .UpstreamScheme = urlx .ParseOrPanic (upstreamServer .URL ).Scheme
505- hc .TargetHost = hc .UpstreamHost
506- hc .TargetScheme = hc .UpstreamScheme
507- }
508- return hc , err
500+ hostMapper <- func (r * http.Request ) (* HostConfig , error ) {
501+ host := r .Host
502+ hc , err := tc .hostMapper (host )
503+ if err == nil {
504+ hc .UpstreamHost = urlx .ParseOrPanic (upstreamServer .URL ).Host
505+ hc .UpstreamScheme = urlx .ParseOrPanic (upstreamServer .URL ).Scheme
506+ hc .TargetHost = hc .UpstreamHost
507+ hc .TargetScheme = hc .UpstreamScheme
509508 }
509+ return hc , err
510+ }
511+ if tc .onErrReq != nil {
512+ onErrorReq <- tc .onErrReq
513+ }
514+ if tc .onErrResp != nil {
515+ onErrorResp <- tc .onErrResp
516+ }
517+
518+ if tc .onErrReq == nil {
519+ // we will only send a request if there is no request error
510520 reqMiddleware <- tc .reqMiddleware
521+ respMiddleware <- tc .respMiddleware
511522 upstreamHandler <- func (w http.ResponseWriter , r * http.Request ) {
512523 t := & remoteT {t : t , w : w , r : r }
513524 tc .handler (assert .New (t ), t , r )
514525 }
515- respMiddleware <- tc .respMiddleware
516- }()
517-
518- go func () {
519- onErrorReq <- tc .onErrReq
520- }()
521-
522- go func () {
523- onErrorResp <- tc .onErrResp
524- }()
526+ }
525527
526528 resp , err := cl .Do (tc .request (t ))
527529 require .NoError (t , err )
528530 tc .assertResponse (t , resp )
531+
532+ select {
533+ case <- hostMapper :
534+ t .Fatal ("host mapper not consumed" )
535+ case <- reqMiddleware :
536+ t .Fatal ("req middleware not consumed" )
537+ case <- respMiddleware :
538+ t .Fatal ("resp middleware not consumed" )
539+ case <- onErrorReq :
540+ t .Fatal ("req error not consumed" )
541+ case <- onErrorResp :
542+ t .Fatal ("resp error not consumed" )
543+ default :
544+ if len (upstreamHandler ) != 0 {
545+ t .Fatal ("upstream handler not consumed" )
546+ }
547+ return
548+ }
529549 })
530550 }
531551}
0 commit comments