@@ -3,12 +3,19 @@ package pq
33// This file contains SSL tests
44
55import (
6+ "bytes"
67 _ "crypto/sha256"
8+ "crypto/tls"
79 "crypto/x509"
810 "database/sql"
11+ "fmt"
12+ "io"
13+ "net"
914 "os"
1015 "path/filepath"
16+ "strings"
1117 "testing"
18+ "time"
1219)
1320
1421func maybeSkipSSLTests (t * testing.T ) {
@@ -280,3 +287,135 @@ func TestSSLClientCertificates(t *testing.T) {
280287 }
281288 }
282289}
290+
291+ // Check that clint sends SNI data when `sslsni` is not disabled
292+ func TestSNISupport (t * testing.T ) {
293+ t .Parallel ()
294+ tests := []struct {
295+ name string
296+ conn_param string
297+ hostname string
298+ expected_sni string
299+ }{
300+ {
301+ name : "SNI is set by default" ,
302+ conn_param : "" ,
303+ hostname : "localhost" ,
304+ expected_sni : "localhost" ,
305+ },
306+ {
307+ name : "SNI is passed when asked for" ,
308+ conn_param : "sslsni=1" ,
309+ hostname : "localhost" ,
310+ expected_sni : "localhost" ,
311+ },
312+ {
313+ name : "SNI is not passed when disabled" ,
314+ conn_param : "sslsni=0" ,
315+ hostname : "localhost" ,
316+ expected_sni : "" ,
317+ },
318+ {
319+ name : "SNI is not set for IPv4" ,
320+ conn_param : "" ,
321+ hostname : "127.0.0.1" ,
322+ expected_sni : "" ,
323+ },
324+ }
325+ for _ , tt := range tests {
326+ tt := tt
327+ t .Run (tt .name , func (t * testing.T ) {
328+ t .Parallel ()
329+
330+ // Start mock postgres server on OS-provided port
331+ listener , err := net .Listen ("tcp" , "127.0.0.1:" )
332+ if err != nil {
333+ t .Fatal (err )
334+ }
335+ serverErrChan := make (chan error , 1 )
336+ serverSNINameChan := make (chan string , 1 )
337+ go mockPostgresSSL (listener , serverErrChan , serverSNINameChan )
338+
339+ defer listener .Close ()
340+ defer close (serverErrChan )
341+ defer close (serverSNINameChan )
342+
343+ // Try to establish a connection with the mock server. Connection will error out after TLS
344+ // clientHello, but it is enough to catch SNI data on the server side
345+ port := strings .Split (listener .Addr ().String (), ":" )[1 ]
346+ connStr := fmt .Sprintf ("sslmode=require host=%s port=%s %s" , tt .hostname , port , tt .conn_param )
347+
348+ // We are okay to skip this error as we are polling serverErrChan and we'll get an error
349+ // or timeout from the server side in case of problems here.
350+ db , _ := sql .Open ("postgres" , connStr )
351+ _ , _ = db .Exec ("SELECT 1" )
352+
353+ // Check SNI data
354+ select {
355+ case sniHost := <- serverSNINameChan :
356+ if sniHost != tt .expected_sni {
357+ t .Fatalf ("Expected SNI to be 'localhost', got '%+v' instead" , sniHost )
358+ }
359+ case err = <- serverErrChan :
360+ t .Fatalf ("mock server failed with error: %+v" , err )
361+ case <- time .After (time .Second ):
362+ t .Fatal ("exceeded connection timeout without erroring out" )
363+ }
364+ })
365+ }
366+ }
367+
368+ // Make a postgres mock server to test TLS SNI
369+ //
370+ // Accepts postgres StartupMessage and handles TLS clientHello, then closes a connection.
371+ // While reading clientHello catch passed SNI data and report it to nameChan.
372+ func mockPostgresSSL (listener net.Listener , errChan chan error , nameChan chan string ) {
373+ var sniHost string
374+
375+ conn , err := listener .Accept ()
376+ if err != nil {
377+ errChan <- err
378+ return
379+ }
380+ defer conn .Close ()
381+
382+ err = conn .SetDeadline (time .Now ().Add (time .Second ))
383+ if err != nil {
384+ errChan <- err
385+ return
386+ }
387+
388+ // Receive StartupMessage with SSL Request
389+ startupMessage := make ([]byte , 8 )
390+ if _ , err := io .ReadFull (conn , startupMessage ); err != nil {
391+ errChan <- err
392+ return
393+ }
394+ // StartupMessage: first four bytes -- total len = 8, last four bytes SslRequestNumber
395+ if ! bytes .Equal (startupMessage , []byte {0 , 0 , 0 , 0x8 , 0x4 , 0xd2 , 0x16 , 0x2f }) {
396+ errChan <- fmt .Errorf ("unexpected startup message: %#v" , startupMessage )
397+ return
398+ }
399+
400+ // Respond with SSLOk
401+ _ , err = conn .Write ([]byte ("S" ))
402+ if err != nil {
403+ errChan <- err
404+ return
405+ }
406+
407+ // Set up TLS context to catch clientHello. It will always error out during handshake
408+ // as no certificate is set.
409+ srv := tls .Server (conn , & tls.Config {
410+ GetConfigForClient : func (argHello * tls.ClientHelloInfo ) (* tls.Config , error ) {
411+ sniHost = argHello .ServerName
412+ return nil , nil
413+ },
414+ })
415+ defer srv .Close ()
416+
417+ // Do the TLS handshake ignoring errors
418+ _ = srv .Handshake ()
419+
420+ nameChan <- sniHost
421+ }
0 commit comments