@@ -32,6 +32,8 @@ package gax
3232import (
3333 "context"
3434 "errors"
35+ "fmt"
36+ "strconv"
3537 "testing"
3638 "time"
3739
@@ -40,6 +42,7 @@ import (
4042 "github.com/googleapis/gax-go/v2/apierror"
4143 "google.golang.org/genproto/googleapis/rpc/errdetails"
4244 "google.golang.org/grpc/codes"
45+ "google.golang.org/grpc/metadata"
4346 "google.golang.org/grpc/status"
4447)
4548
@@ -264,3 +267,47 @@ func TestInvokeWithTimeout(t *testing.T) {
264267 })
265268 }
266269}
270+
271+ func TestInvokeRetryCount (t * testing.T ) {
272+ for _ , tracingEnabled := range []bool {true , false } {
273+ t .Run (fmt .Sprintf ("tracingEnabled=%v" , tracingEnabled ), func (t * testing.T ) {
274+ TestOnlyResetIsFeatureEnabled ()
275+ defer TestOnlyResetIsFeatureEnabled ()
276+
277+ if tracingEnabled {
278+ t .Setenv ("GOOGLE_SDK_GO_EXPERIMENTAL_TRACING" , "true" )
279+ } else {
280+ t .Setenv ("GOOGLE_SDK_GO_EXPERIMENTAL_TRACING" , "false" )
281+ }
282+
283+ const target = 3
284+ var retryCounts []int
285+ calls := 0
286+ apiCall := func (ctx context.Context , _ CallSettings ) error {
287+ calls ++
288+ md , _ := metadata .FromOutgoingContext (ctx )
289+ if vals := md ["gcp.grpc.resend_count" ]; len (vals ) > 0 {
290+ if count , err := strconv .Atoi (vals [0 ]); err == nil {
291+ retryCounts = append (retryCounts , count )
292+ }
293+ }
294+ if calls < target {
295+ return errors .New ("retry" )
296+ }
297+ return nil
298+ }
299+ var settings CallSettings
300+ WithRetry (func () Retryer { return boolRetryer (true ) }).Resolve (& settings )
301+ var sp recordSleeper
302+ invoke (context .Background (), apiCall , settings , sp .sleep )
303+
304+ var want []int
305+ if tracingEnabled {
306+ want = []int {0 , 1 , 2 }
307+ }
308+ if diff := cmp .Diff (want , retryCounts ); diff != "" {
309+ t .Errorf ("retry count mismatch (-want +got):\n %s" , diff )
310+ }
311+ })
312+ }
313+ }
0 commit comments