@@ -301,77 +301,135 @@ func (t *DTLSTransport) role() DTLSRole {
301301}
302302
303303// Start DTLS transport negotiation with the parameters of the remote DTLS transport.
304- func (t * DTLSTransport ) Start (remoteParameters DTLSParameters ) error { //nolint:gocognit,cyclop
305- // Take lock and prepare connection, we must not hold the lock
306- // when connecting
307- prepareTransport := func () (DTLSRole , * dtls.Config , error ) {
308- t .lock .Lock ()
309- defer t .lock .Unlock ()
304+ func (t * DTLSTransport ) Start (remoteParameters DTLSParameters ) error {
305+ role , certificate , err := t .prepareStart (remoteParameters )
306+ if err != nil {
307+ return err
308+ }
310309
311- if err := t .ensureICEConn (); err != nil {
312- return DTLSRole (0 ), nil , err
313- }
310+ dtlsEndpoint := t .iceTransport .newEndpoint (mux .MatchDTLS )
311+ dtlsEndpoint .SetOnClose (t .internalOnCloseHandler )
314312
315- if t .state != DTLSTransportStateNew {
316- return DTLSRole (0 ), nil , & rtcerr.InvalidStateError {Err : fmt .Errorf ("%w: %s" , errInvalidDTLSStart , t .state )}
317- }
313+ sharedOpts := t .dtlsSharedOptions (certificate )
318314
319- t .srtpEndpoint = t .iceTransport .newEndpoint (mux .MatchSRTP )
320- t .srtcpEndpoint = t .iceTransport .newEndpoint (mux .MatchSRTCP )
321- t .remoteParameters = remoteParameters
315+ dtlsConn , err := t .connectDTLS (dtlsEndpoint , role , sharedOpts )
316+ if err != nil {
317+ dtlsEndpoint .SetOnClose (nil )
318+ _ = dtlsEndpoint .Close ()
322319
323- cert := t .certificates [0 ]
324- t .onStateChange (DTLSTransportStateConnecting )
320+ return t .failStart (err )
321+ }
322+
323+ if err = t .handshakeDTLS (dtlsConn ); err != nil {
324+ dtlsEndpoint .SetOnClose (nil )
325+ _ = dtlsConn .Close ()
326+
327+ return t .failStart (err )
328+ }
329+
330+ if err = t .completeStart (dtlsConn ); err != nil {
331+ dtlsEndpoint .SetOnClose (nil )
332+ _ = dtlsConn .Close ()
325333
326- return t .role (), & dtls.Config {
327- Certificates : []tls.Certificate {
328- {
329- Certificate : [][]byte {cert .x509Cert .Raw },
330- PrivateKey : cert .privateKey ,
331- },
332- },
333- SRTPProtectionProfiles : func () []dtls.SRTPProtectionProfile {
334- if len (t .api .settingEngine .srtpProtectionProfiles ) > 0 {
335- return t .api .settingEngine .srtpProtectionProfiles
336- }
337-
338- return defaultSrtpProtectionProfiles ()
339- }(),
340- ClientAuth : dtls .RequireAnyClientCert ,
341- LoggerFactory : t .api .settingEngine .LoggerFactory ,
342- InsecureSkipVerify : ! t .api .settingEngine .dtls .disableInsecureSkipVerify ,
343- CipherSuites : t .api .settingEngine .dtls .cipherSuites ,
344- CustomCipherSuites : t .api .settingEngine .dtls .customCipherSuites ,
345- }, nil
346- }
347-
348- var dtlsConn * dtls.Conn
349- dtlsEndpoint := t .iceTransport .newEndpoint (mux .MatchDTLS )
350- dtlsEndpoint .SetOnClose (t .internalOnCloseHandler )
351- role , dtlsConfig , err := prepareTransport ()
352- if err != nil {
353334 return err
354335 }
355336
337+ return nil
338+ }
339+
340+ func (t * DTLSTransport ) prepareStart (remoteParameters DTLSParameters ) (DTLSRole , tls.Certificate , error ) {
341+ t .lock .Lock ()
342+ defer t .lock .Unlock ()
343+
344+ if err := t .ensureICEConn (); err != nil {
345+ return DTLSRole (0 ), tls.Certificate {}, err
346+ }
347+
348+ if t .state != DTLSTransportStateNew {
349+ return DTLSRole (0 ), tls.Certificate {}, & rtcerr.InvalidStateError {
350+ Err : fmt .Errorf ("%w: %s" , errInvalidDTLSStart , t .state ),
351+ }
352+ }
353+
354+ t .srtpEndpoint = t .iceTransport .newEndpoint (mux .MatchSRTP )
355+ t .srtcpEndpoint = t .iceTransport .newEndpoint (mux .MatchSRTCP )
356+ t .remoteParameters = remoteParameters
357+
358+ cert := t .certificates [0 ]
359+ t .onStateChange (DTLSTransportStateConnecting )
360+
361+ return t .role (), tls.Certificate {
362+ Certificate : [][]byte {cert .x509Cert .Raw },
363+ PrivateKey : cert .privateKey ,
364+ }, nil
365+ }
366+
367+ func (t * DTLSTransport ) dtlsSharedOptions (certificate tls.Certificate ) []dtls.Option {
368+ sharedOpts := []dtls.Option {
369+ dtls .WithCertificates (certificate ),
370+ dtls .WithSRTPProtectionProfiles (t .srtpProtectionProfiles ()... ),
371+ dtls .WithExtendedMasterSecret (t .api .settingEngine .dtls .extendedMasterSecret ),
372+ dtls .WithInsecureSkipVerify (! t .api .settingEngine .dtls .disableInsecureSkipVerify ),
373+ dtls .WithLoggerFactory (t .api .settingEngine .LoggerFactory ),
374+ dtls .WithVerifyPeerCertificate (t .verifyPeerCertificateFunc ()),
375+ }
376+
377+ if t .api .settingEngine .dtls .customCipherSuites != nil {
378+ sharedOpts = append (
379+ sharedOpts ,
380+ dtls .WithCustomCipherSuites (t .api .settingEngine .dtls .customCipherSuites ),
381+ )
382+ }
383+
384+ if t .api .settingEngine .dtls .retransmissionInterval > 0 {
385+ sharedOpts = append (
386+ sharedOpts ,
387+ dtls .WithFlightInterval (t .api .settingEngine .dtls .retransmissionInterval ),
388+ )
389+ }
390+
356391 if t .api .settingEngine .replayProtection .DTLS != nil {
357- dtlsConfig .ReplayProtectionWindow = int (* t .api .settingEngine .replayProtection .DTLS ) //nolint:gosec // G115
392+ sharedOpts = append (
393+ sharedOpts ,
394+ dtls .WithReplayProtectionWindow (int (* t .api .settingEngine .replayProtection .DTLS )), //nolint:gosec // G115
395+ )
358396 }
359397
360- if t .api .settingEngine .dtls .clientAuth != nil {
361- dtlsConfig .ClientAuth = * t .api .settingEngine .dtls .clientAuth
362- }
363-
364- dtlsConfig .FlightInterval = t .api .settingEngine .dtls .retransmissionInterval
365- dtlsConfig .InsecureSkipVerifyHello = t .api .settingEngine .dtls .insecureSkipHelloVerify
366- dtlsConfig .EllipticCurves = t .api .settingEngine .dtls .ellipticCurves
367- dtlsConfig .ExtendedMasterSecret = t .api .settingEngine .dtls .extendedMasterSecret
368- dtlsConfig .ClientCAs = t .api .settingEngine .dtls .clientCAs
369- dtlsConfig .RootCAs = t .api .settingEngine .dtls .rootCAs
370- dtlsConfig .KeyLogWriter = t .api .settingEngine .dtls .keyLogWriter
371- dtlsConfig .ClientHelloMessageHook = t .api .settingEngine .dtls .clientHelloMessageHook
372- dtlsConfig .ServerHelloMessageHook = t .api .settingEngine .dtls .serverHelloMessageHook
373- dtlsConfig .CertificateRequestMessageHook = t .api .settingEngine .dtls .certificateRequestMessageHook
374- dtlsConfig .VerifyPeerCertificate = func (rawCerts [][]byte , _verifiedChains [][]* x509.Certificate ) error {
398+ if t .api .settingEngine .dtls .cipherSuites != nil {
399+ sharedOpts = append (
400+ sharedOpts ,
401+ dtls .WithCipherSuites (t .api .settingEngine .dtls .cipherSuites ... ),
402+ )
403+ }
404+
405+ if len (t .api .settingEngine .dtls .ellipticCurves ) > 0 {
406+ sharedOpts = append (
407+ sharedOpts ,
408+ dtls .WithEllipticCurves (t .api .settingEngine .dtls .ellipticCurves ... ),
409+ )
410+ }
411+
412+ if t .api .settingEngine .dtls .rootCAs != nil {
413+ sharedOpts = append (sharedOpts , dtls .WithRootCAs (t .api .settingEngine .dtls .rootCAs ))
414+ }
415+
416+ if t .api .settingEngine .dtls .keyLogWriter != nil {
417+ sharedOpts = append (sharedOpts , dtls .WithKeyLogWriter (t .api .settingEngine .dtls .keyLogWriter ))
418+ }
419+
420+ return sharedOpts
421+ }
422+
423+ func (t * DTLSTransport ) srtpProtectionProfiles () []dtls.SRTPProtectionProfile {
424+ if len (t .api .settingEngine .srtpProtectionProfiles ) > 0 {
425+ return t .api .settingEngine .srtpProtectionProfiles
426+ }
427+
428+ return defaultSrtpProtectionProfiles ()
429+ }
430+
431+ func (t * DTLSTransport ) verifyPeerCertificateFunc () func ([][]byte , [][]* x509.Certificate ) error {
432+ return func (rawCerts [][]byte , _ [][]* x509.Certificate ) error {
375433 if len (rawCerts ) == 0 {
376434 return errNoRemoteCertificate
377435 }
@@ -384,32 +442,105 @@ func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error { //nolint:
384442 return nil
385443 }
386444
387- parsedRemoteCert , parseErr := x509 .ParseCertificate (t .remoteCertificate )
388- if parseErr != nil {
389- return parseErr
445+ parsedRemoteCert , err := x509 .ParseCertificate (t .remoteCertificate )
446+ if err != nil {
447+ return err
390448 }
391449
392450 return t .validateFingerPrint (parsedRemoteCert )
393451 }
452+ }
394453
395- // Connect as DTLS Client/Server, function is blocking and we
396- // must not hold the DTLSTransport lock
454+ func (t * DTLSTransport ) connectDTLS (
455+ dtlsEndpoint * mux.Endpoint ,
456+ role DTLSRole ,
457+ sharedOpts []dtls.Option ,
458+ ) (* dtls.Conn , error ) {
397459 if role == DTLSRoleClient {
398- dtlsConn , err = dtls .Client (dtlsEndpoint , dtlsEndpoint .RemoteAddr (), dtlsConfig )
399- } else {
400- dtlsConn , err = dtls .Server (dtlsEndpoint , dtlsEndpoint .RemoteAddr (), dtlsConfig )
460+ clientOpts := t .toDTLSClientOptions (sharedOpts )
461+
462+ return dtls .ClientWithOptions (
463+ dtlsEndpoint ,
464+ dtlsEndpoint .RemoteAddr (),
465+ clientOpts ... ,
466+ )
401467 }
402468
403- if err == nil {
404- if t .api .settingEngine .dtls .connectContextMaker != nil {
405- handshakeCtx , _ := t .api .settingEngine .dtls .connectContextMaker ()
406- err = dtlsConn .HandshakeContext (handshakeCtx )
407- } else {
408- err = dtlsConn .Handshake ()
409- }
469+ serverOpts := t .toDTLSServerOptions (sharedOpts )
470+
471+ return dtls .ServerWithOptions (
472+ dtlsEndpoint ,
473+ dtlsEndpoint .RemoteAddr (),
474+ serverOpts ... ,
475+ )
476+ }
477+
478+ func (t * DTLSTransport ) toDTLSServerOptions (sharedOpts []dtls.Option ) []dtls.ServerOption {
479+ serverOpts := make ([]dtls.ServerOption , 0 , len (sharedOpts )+ 5 )
480+ for _ , opt := range sharedOpts {
481+ serverOpts = append (serverOpts , opt )
482+ }
483+
484+ clientAuth := dtls .RequireAnyClientCert
485+ if t .api .settingEngine .dtls .clientAuth != nil {
486+ clientAuth = * t .api .settingEngine .dtls .clientAuth
487+ }
488+
489+ serverOpts = append (serverOpts ,
490+ dtls .WithClientAuth (clientAuth ),
491+ dtls .WithClientCAs (t .api .settingEngine .dtls .clientCAs ),
492+ dtls .WithInsecureSkipVerifyHello (t .api .settingEngine .dtls .insecureSkipHelloVerify ),
493+ )
494+
495+ if t .api .settingEngine .dtls .serverHelloMessageHook != nil {
496+ serverOpts = append (
497+ serverOpts ,
498+ dtls .WithServerHelloMessageHook (t .api .settingEngine .dtls .serverHelloMessageHook ),
499+ )
500+ }
501+
502+ if t .api .settingEngine .dtls .certificateRequestMessageHook != nil {
503+ serverOpts = append (
504+ serverOpts ,
505+ dtls .WithCertificateRequestMessageHook (t .api .settingEngine .dtls .certificateRequestMessageHook ),
506+ )
507+ }
508+
509+ return serverOpts
510+ }
511+
512+ func (t * DTLSTransport ) toDTLSClientOptions (sharedOpts []dtls.Option ) []dtls.ClientOption {
513+ clientOpts := make ([]dtls.ClientOption , 0 , len (sharedOpts )+ 1 )
514+ for _ , opt := range sharedOpts {
515+ clientOpts = append (clientOpts , opt )
516+ }
517+
518+ if t .api .settingEngine .dtls .clientHelloMessageHook != nil {
519+ clientOpts = append (
520+ clientOpts ,
521+ dtls .WithClientHelloMessageHook (t .api .settingEngine .dtls .clientHelloMessageHook ),
522+ )
523+ }
524+
525+ return clientOpts
526+ }
527+
528+ func (t * DTLSTransport ) handshakeDTLS (dtlsConn * dtls.Conn ) error {
529+ if t .api .settingEngine .dtls .connectContextMaker == nil {
530+ return dtlsConn .Handshake ()
531+ }
532+
533+ handshakeCtx , cancel := t .api .settingEngine .dtls .connectContextMaker ()
534+ if cancel != nil {
535+ defer cancel ()
410536 }
411537
412- // Re-take the lock, nothing beyond here is blocking
538+ return dtlsConn .HandshakeContext (handshakeCtx )
539+ }
540+
541+ func (t * DTLSTransport ) completeStart (dtlsConn * dtls.Conn ) error {
542+ srtpProtectionProfile , err := srtpProtectionProfileFromDTLSConn (dtlsConn )
543+
413544 t .lock .Lock ()
414545 defer t .lock .Unlock ()
415546
@@ -419,32 +550,43 @@ func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error { //nolint:
419550 return err
420551 }
421552
553+ t .srtpProtectionProfile = srtpProtectionProfile
554+ t .conn = dtlsConn
555+ t .onStateChange (DTLSTransportStateConnected )
556+
557+ return t .startSRTP ()
558+ }
559+
560+ func (t * DTLSTransport ) failStart (err error ) error {
561+ t .lock .Lock ()
562+ defer t .lock .Unlock ()
563+ t .onStateChange (DTLSTransportStateFailed )
564+
565+ return err
566+ }
567+
568+ func srtpProtectionProfileFromDTLSConn (dtlsConn * dtls.Conn ) (srtp.ProtectionProfile , error ) {
422569 srtpProfile , ok := dtlsConn .SelectedSRTPProtectionProfile ()
423570 if ! ok {
424- t .onStateChange (DTLSTransportStateFailed )
425-
426- return ErrNoSRTPProtectionProfile
571+ return 0 , ErrNoSRTPProtectionProfile
427572 }
428573
574+ return srtpProtectionProfileFromDTLS (srtpProfile )
575+ }
576+
577+ func srtpProtectionProfileFromDTLS (srtpProfile dtls.SRTPProtectionProfile ) (srtp.ProtectionProfile , error ) {
429578 switch srtpProfile {
430579 case dtls .SRTP_AEAD_AES_128_GCM :
431- t . srtpProtectionProfile = srtp .ProtectionProfileAeadAes128Gcm
580+ return srtp .ProtectionProfileAeadAes128Gcm , nil
432581 case dtls .SRTP_AEAD_AES_256_GCM :
433- t . srtpProtectionProfile = srtp .ProtectionProfileAeadAes256Gcm
582+ return srtp .ProtectionProfileAeadAes256Gcm , nil
434583 case dtls .SRTP_AES128_CM_HMAC_SHA1_80 :
435- t . srtpProtectionProfile = srtp .ProtectionProfileAes128CmHmacSha1_80
584+ return srtp .ProtectionProfileAes128CmHmacSha1_80 , nil
436585 case dtls .SRTP_NULL_HMAC_SHA1_80 :
437- t . srtpProtectionProfile = srtp .ProtectionProfileNullHmacSha1_80
586+ return srtp .ProtectionProfileNullHmacSha1_80 , nil
438587 default :
439- t .onStateChange (DTLSTransportStateFailed )
440-
441- return ErrNoSRTPProtectionProfile
588+ return 0 , ErrNoSRTPProtectionProfile
442589 }
443-
444- t .conn = dtlsConn
445- t .onStateChange (DTLSTransportStateConnected )
446-
447- return t .startSRTP ()
448590}
449591
450592// Stop stops and closes the DTLSTransport object.
0 commit comments