Skip to content

Commit d4e4eaa

Browse files
committed
Upgrade to dtls options
1 parent 8b9515c commit d4e4eaa

File tree

4 files changed

+640
-94
lines changed

4 files changed

+640
-94
lines changed

dtlstransport.go

Lines changed: 233 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -301,77 +301,135 @@ func (t *DTLSTransport) role() DTLSRole {
301301
}
302302

303303
// Start DTLS transport negotiation with the parameters of the remote DTLS transport.
304-
func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error { //nolint:gocognit,cyclop
305-
// Take lock and prepare connection, we must not hold the lock
306-
// when connecting
307-
prepareTransport := func() (DTLSRole, *dtls.Config, error) {
308-
t.lock.Lock()
309-
defer t.lock.Unlock()
304+
func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error {
305+
role, certificate, err := t.prepareStart(remoteParameters)
306+
if err != nil {
307+
return err
308+
}
310309

311-
if err := t.ensureICEConn(); err != nil {
312-
return DTLSRole(0), nil, err
313-
}
310+
dtlsEndpoint := t.iceTransport.newEndpoint(mux.MatchDTLS)
311+
dtlsEndpoint.SetOnClose(t.internalOnCloseHandler)
314312

315-
if t.state != DTLSTransportStateNew {
316-
return DTLSRole(0), nil, &rtcerr.InvalidStateError{Err: fmt.Errorf("%w: %s", errInvalidDTLSStart, t.state)}
317-
}
313+
sharedOpts := t.dtlsSharedOptions(certificate)
318314

319-
t.srtpEndpoint = t.iceTransport.newEndpoint(mux.MatchSRTP)
320-
t.srtcpEndpoint = t.iceTransport.newEndpoint(mux.MatchSRTCP)
321-
t.remoteParameters = remoteParameters
315+
dtlsConn, err := t.connectDTLS(dtlsEndpoint, role, sharedOpts)
316+
if err != nil {
317+
dtlsEndpoint.SetOnClose(nil)
318+
_ = dtlsEndpoint.Close()
322319

323-
cert := t.certificates[0]
324-
t.onStateChange(DTLSTransportStateConnecting)
320+
return t.failStart(err)
321+
}
322+
323+
if err = t.handshakeDTLS(dtlsConn); err != nil {
324+
dtlsEndpoint.SetOnClose(nil)
325+
_ = dtlsConn.Close()
326+
327+
return t.failStart(err)
328+
}
329+
330+
if err = t.completeStart(dtlsConn); err != nil {
331+
dtlsEndpoint.SetOnClose(nil)
332+
_ = dtlsConn.Close()
325333

326-
return t.role(), &dtls.Config{
327-
Certificates: []tls.Certificate{
328-
{
329-
Certificate: [][]byte{cert.x509Cert.Raw},
330-
PrivateKey: cert.privateKey,
331-
},
332-
},
333-
SRTPProtectionProfiles: func() []dtls.SRTPProtectionProfile {
334-
if len(t.api.settingEngine.srtpProtectionProfiles) > 0 {
335-
return t.api.settingEngine.srtpProtectionProfiles
336-
}
337-
338-
return defaultSrtpProtectionProfiles()
339-
}(),
340-
ClientAuth: dtls.RequireAnyClientCert,
341-
LoggerFactory: t.api.settingEngine.LoggerFactory,
342-
InsecureSkipVerify: !t.api.settingEngine.dtls.disableInsecureSkipVerify,
343-
CipherSuites: t.api.settingEngine.dtls.cipherSuites,
344-
CustomCipherSuites: t.api.settingEngine.dtls.customCipherSuites,
345-
}, nil
346-
}
347-
348-
var dtlsConn *dtls.Conn
349-
dtlsEndpoint := t.iceTransport.newEndpoint(mux.MatchDTLS)
350-
dtlsEndpoint.SetOnClose(t.internalOnCloseHandler)
351-
role, dtlsConfig, err := prepareTransport()
352-
if err != nil {
353334
return err
354335
}
355336

337+
return nil
338+
}
339+
340+
func (t *DTLSTransport) prepareStart(remoteParameters DTLSParameters) (DTLSRole, tls.Certificate, error) {
341+
t.lock.Lock()
342+
defer t.lock.Unlock()
343+
344+
if err := t.ensureICEConn(); err != nil {
345+
return DTLSRole(0), tls.Certificate{}, err
346+
}
347+
348+
if t.state != DTLSTransportStateNew {
349+
return DTLSRole(0), tls.Certificate{}, &rtcerr.InvalidStateError{
350+
Err: fmt.Errorf("%w: %s", errInvalidDTLSStart, t.state),
351+
}
352+
}
353+
354+
t.srtpEndpoint = t.iceTransport.newEndpoint(mux.MatchSRTP)
355+
t.srtcpEndpoint = t.iceTransport.newEndpoint(mux.MatchSRTCP)
356+
t.remoteParameters = remoteParameters
357+
358+
cert := t.certificates[0]
359+
t.onStateChange(DTLSTransportStateConnecting)
360+
361+
return t.role(), tls.Certificate{
362+
Certificate: [][]byte{cert.x509Cert.Raw},
363+
PrivateKey: cert.privateKey,
364+
}, nil
365+
}
366+
367+
func (t *DTLSTransport) dtlsSharedOptions(certificate tls.Certificate) []dtls.Option {
368+
sharedOpts := []dtls.Option{
369+
dtls.WithCertificates(certificate),
370+
dtls.WithSRTPProtectionProfiles(t.srtpProtectionProfiles()...),
371+
dtls.WithExtendedMasterSecret(t.api.settingEngine.dtls.extendedMasterSecret),
372+
dtls.WithInsecureSkipVerify(!t.api.settingEngine.dtls.disableInsecureSkipVerify),
373+
dtls.WithLoggerFactory(t.api.settingEngine.LoggerFactory),
374+
dtls.WithVerifyPeerCertificate(t.verifyPeerCertificateFunc()),
375+
}
376+
377+
if t.api.settingEngine.dtls.customCipherSuites != nil {
378+
sharedOpts = append(
379+
sharedOpts,
380+
dtls.WithCustomCipherSuites(t.api.settingEngine.dtls.customCipherSuites),
381+
)
382+
}
383+
384+
if t.api.settingEngine.dtls.retransmissionInterval > 0 {
385+
sharedOpts = append(
386+
sharedOpts,
387+
dtls.WithFlightInterval(t.api.settingEngine.dtls.retransmissionInterval),
388+
)
389+
}
390+
356391
if t.api.settingEngine.replayProtection.DTLS != nil {
357-
dtlsConfig.ReplayProtectionWindow = int(*t.api.settingEngine.replayProtection.DTLS) //nolint:gosec // G115
392+
sharedOpts = append(
393+
sharedOpts,
394+
dtls.WithReplayProtectionWindow(int(*t.api.settingEngine.replayProtection.DTLS)), //nolint:gosec // G115
395+
)
358396
}
359397

360-
if t.api.settingEngine.dtls.clientAuth != nil {
361-
dtlsConfig.ClientAuth = *t.api.settingEngine.dtls.clientAuth
362-
}
363-
364-
dtlsConfig.FlightInterval = t.api.settingEngine.dtls.retransmissionInterval
365-
dtlsConfig.InsecureSkipVerifyHello = t.api.settingEngine.dtls.insecureSkipHelloVerify
366-
dtlsConfig.EllipticCurves = t.api.settingEngine.dtls.ellipticCurves
367-
dtlsConfig.ExtendedMasterSecret = t.api.settingEngine.dtls.extendedMasterSecret
368-
dtlsConfig.ClientCAs = t.api.settingEngine.dtls.clientCAs
369-
dtlsConfig.RootCAs = t.api.settingEngine.dtls.rootCAs
370-
dtlsConfig.KeyLogWriter = t.api.settingEngine.dtls.keyLogWriter
371-
dtlsConfig.ClientHelloMessageHook = t.api.settingEngine.dtls.clientHelloMessageHook
372-
dtlsConfig.ServerHelloMessageHook = t.api.settingEngine.dtls.serverHelloMessageHook
373-
dtlsConfig.CertificateRequestMessageHook = t.api.settingEngine.dtls.certificateRequestMessageHook
374-
dtlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _verifiedChains [][]*x509.Certificate) error {
398+
if t.api.settingEngine.dtls.cipherSuites != nil {
399+
sharedOpts = append(
400+
sharedOpts,
401+
dtls.WithCipherSuites(t.api.settingEngine.dtls.cipherSuites...),
402+
)
403+
}
404+
405+
if len(t.api.settingEngine.dtls.ellipticCurves) > 0 {
406+
sharedOpts = append(
407+
sharedOpts,
408+
dtls.WithEllipticCurves(t.api.settingEngine.dtls.ellipticCurves...),
409+
)
410+
}
411+
412+
if t.api.settingEngine.dtls.rootCAs != nil {
413+
sharedOpts = append(sharedOpts, dtls.WithRootCAs(t.api.settingEngine.dtls.rootCAs))
414+
}
415+
416+
if t.api.settingEngine.dtls.keyLogWriter != nil {
417+
sharedOpts = append(sharedOpts, dtls.WithKeyLogWriter(t.api.settingEngine.dtls.keyLogWriter))
418+
}
419+
420+
return sharedOpts
421+
}
422+
423+
func (t *DTLSTransport) srtpProtectionProfiles() []dtls.SRTPProtectionProfile {
424+
if len(t.api.settingEngine.srtpProtectionProfiles) > 0 {
425+
return t.api.settingEngine.srtpProtectionProfiles
426+
}
427+
428+
return defaultSrtpProtectionProfiles()
429+
}
430+
431+
func (t *DTLSTransport) verifyPeerCertificateFunc() func([][]byte, [][]*x509.Certificate) error {
432+
return func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
375433
if len(rawCerts) == 0 {
376434
return errNoRemoteCertificate
377435
}
@@ -384,32 +442,105 @@ func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error { //nolint:
384442
return nil
385443
}
386444

387-
parsedRemoteCert, parseErr := x509.ParseCertificate(t.remoteCertificate)
388-
if parseErr != nil {
389-
return parseErr
445+
parsedRemoteCert, err := x509.ParseCertificate(t.remoteCertificate)
446+
if err != nil {
447+
return err
390448
}
391449

392450
return t.validateFingerPrint(parsedRemoteCert)
393451
}
452+
}
394453

395-
// Connect as DTLS Client/Server, function is blocking and we
396-
// must not hold the DTLSTransport lock
454+
func (t *DTLSTransport) connectDTLS(
455+
dtlsEndpoint *mux.Endpoint,
456+
role DTLSRole,
457+
sharedOpts []dtls.Option,
458+
) (*dtls.Conn, error) {
397459
if role == DTLSRoleClient {
398-
dtlsConn, err = dtls.Client(dtlsEndpoint, dtlsEndpoint.RemoteAddr(), dtlsConfig)
399-
} else {
400-
dtlsConn, err = dtls.Server(dtlsEndpoint, dtlsEndpoint.RemoteAddr(), dtlsConfig)
460+
clientOpts := t.toDTLSClientOptions(sharedOpts)
461+
462+
return dtls.ClientWithOptions(
463+
dtlsEndpoint,
464+
dtlsEndpoint.RemoteAddr(),
465+
clientOpts...,
466+
)
401467
}
402468

403-
if err == nil {
404-
if t.api.settingEngine.dtls.connectContextMaker != nil {
405-
handshakeCtx, _ := t.api.settingEngine.dtls.connectContextMaker()
406-
err = dtlsConn.HandshakeContext(handshakeCtx)
407-
} else {
408-
err = dtlsConn.Handshake()
409-
}
469+
serverOpts := t.toDTLSServerOptions(sharedOpts)
470+
471+
return dtls.ServerWithOptions(
472+
dtlsEndpoint,
473+
dtlsEndpoint.RemoteAddr(),
474+
serverOpts...,
475+
)
476+
}
477+
478+
func (t *DTLSTransport) toDTLSServerOptions(sharedOpts []dtls.Option) []dtls.ServerOption {
479+
serverOpts := make([]dtls.ServerOption, 0, len(sharedOpts)+5)
480+
for _, opt := range sharedOpts {
481+
serverOpts = append(serverOpts, opt)
482+
}
483+
484+
clientAuth := dtls.RequireAnyClientCert
485+
if t.api.settingEngine.dtls.clientAuth != nil {
486+
clientAuth = *t.api.settingEngine.dtls.clientAuth
487+
}
488+
489+
serverOpts = append(serverOpts,
490+
dtls.WithClientAuth(clientAuth),
491+
dtls.WithClientCAs(t.api.settingEngine.dtls.clientCAs),
492+
dtls.WithInsecureSkipVerifyHello(t.api.settingEngine.dtls.insecureSkipHelloVerify),
493+
)
494+
495+
if t.api.settingEngine.dtls.serverHelloMessageHook != nil {
496+
serverOpts = append(
497+
serverOpts,
498+
dtls.WithServerHelloMessageHook(t.api.settingEngine.dtls.serverHelloMessageHook),
499+
)
500+
}
501+
502+
if t.api.settingEngine.dtls.certificateRequestMessageHook != nil {
503+
serverOpts = append(
504+
serverOpts,
505+
dtls.WithCertificateRequestMessageHook(t.api.settingEngine.dtls.certificateRequestMessageHook),
506+
)
507+
}
508+
509+
return serverOpts
510+
}
511+
512+
func (t *DTLSTransport) toDTLSClientOptions(sharedOpts []dtls.Option) []dtls.ClientOption {
513+
clientOpts := make([]dtls.ClientOption, 0, len(sharedOpts)+1)
514+
for _, opt := range sharedOpts {
515+
clientOpts = append(clientOpts, opt)
516+
}
517+
518+
if t.api.settingEngine.dtls.clientHelloMessageHook != nil {
519+
clientOpts = append(
520+
clientOpts,
521+
dtls.WithClientHelloMessageHook(t.api.settingEngine.dtls.clientHelloMessageHook),
522+
)
523+
}
524+
525+
return clientOpts
526+
}
527+
528+
func (t *DTLSTransport) handshakeDTLS(dtlsConn *dtls.Conn) error {
529+
if t.api.settingEngine.dtls.connectContextMaker == nil {
530+
return dtlsConn.Handshake()
531+
}
532+
533+
handshakeCtx, cancel := t.api.settingEngine.dtls.connectContextMaker()
534+
if cancel != nil {
535+
defer cancel()
410536
}
411537

412-
// Re-take the lock, nothing beyond here is blocking
538+
return dtlsConn.HandshakeContext(handshakeCtx)
539+
}
540+
541+
func (t *DTLSTransport) completeStart(dtlsConn *dtls.Conn) error {
542+
srtpProtectionProfile, err := srtpProtectionProfileFromDTLSConn(dtlsConn)
543+
413544
t.lock.Lock()
414545
defer t.lock.Unlock()
415546

@@ -419,32 +550,43 @@ func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error { //nolint:
419550
return err
420551
}
421552

553+
t.srtpProtectionProfile = srtpProtectionProfile
554+
t.conn = dtlsConn
555+
t.onStateChange(DTLSTransportStateConnected)
556+
557+
return t.startSRTP()
558+
}
559+
560+
func (t *DTLSTransport) failStart(err error) error {
561+
t.lock.Lock()
562+
defer t.lock.Unlock()
563+
t.onStateChange(DTLSTransportStateFailed)
564+
565+
return err
566+
}
567+
568+
func srtpProtectionProfileFromDTLSConn(dtlsConn *dtls.Conn) (srtp.ProtectionProfile, error) {
422569
srtpProfile, ok := dtlsConn.SelectedSRTPProtectionProfile()
423570
if !ok {
424-
t.onStateChange(DTLSTransportStateFailed)
425-
426-
return ErrNoSRTPProtectionProfile
571+
return 0, ErrNoSRTPProtectionProfile
427572
}
428573

574+
return srtpProtectionProfileFromDTLS(srtpProfile)
575+
}
576+
577+
func srtpProtectionProfileFromDTLS(srtpProfile dtls.SRTPProtectionProfile) (srtp.ProtectionProfile, error) {
429578
switch srtpProfile {
430579
case dtls.SRTP_AEAD_AES_128_GCM:
431-
t.srtpProtectionProfile = srtp.ProtectionProfileAeadAes128Gcm
580+
return srtp.ProtectionProfileAeadAes128Gcm, nil
432581
case dtls.SRTP_AEAD_AES_256_GCM:
433-
t.srtpProtectionProfile = srtp.ProtectionProfileAeadAes256Gcm
582+
return srtp.ProtectionProfileAeadAes256Gcm, nil
434583
case dtls.SRTP_AES128_CM_HMAC_SHA1_80:
435-
t.srtpProtectionProfile = srtp.ProtectionProfileAes128CmHmacSha1_80
584+
return srtp.ProtectionProfileAes128CmHmacSha1_80, nil
436585
case dtls.SRTP_NULL_HMAC_SHA1_80:
437-
t.srtpProtectionProfile = srtp.ProtectionProfileNullHmacSha1_80
586+
return srtp.ProtectionProfileNullHmacSha1_80, nil
438587
default:
439-
t.onStateChange(DTLSTransportStateFailed)
440-
441-
return ErrNoSRTPProtectionProfile
588+
return 0, ErrNoSRTPProtectionProfile
442589
}
443-
444-
t.conn = dtlsConn
445-
t.onStateChange(DTLSTransportStateConnected)
446-
447-
return t.startSRTP()
448590
}
449591

450592
// Stop stops and closes the DTLSTransport object.

0 commit comments

Comments
 (0)