@@ -87,11 +87,12 @@ class NoDOH(dns.exception.DNSException):
8787 available."""
8888
8989
90- def _compute_expiration (timeout ):
90+ def _compute_times (timeout ):
91+ now = time .time ()
9192 if timeout is None :
92- return None
93+ return ( now , None )
9394 else :
94- return time . time () + timeout
95+ return ( now , now + timeout )
9596
9697# This module can use either poll() or select() as the "polling backend".
9798#
@@ -230,6 +231,21 @@ def _destination_and_source(af, where, port, source, source_port,
230231 destination = None
231232 return (af , destination , source )
232233
234+ def _make_socket (af , type , source , ssl_context = None , server_hostname = None ):
235+ s = socket_factory (af , type )
236+ try :
237+ s .setblocking (False )
238+ if source is not None :
239+ s .bind (source )
240+ if ssl_context :
241+ return ssl_context .wrap_socket (s , do_handshake_on_connect = False ,
242+ server_hostname = server_hostname )
243+ else :
244+ return s
245+ except Exception :
246+ s .close ()
247+ raise
248+
233249def https (q , where , timeout = None , port = 443 , source = None , source_port = 0 ,
234250 one_rr_per_rrset = False , ignore_trailing = False ,
235251 session = None , path = '/dns-query' , post = True ,
@@ -424,7 +440,7 @@ def receive_udp(sock, destination, expiration=None,
424440
425441def udp (q , where , timeout = None , port = 53 , source = None , source_port = 0 ,
426442 ignore_unexpected = False , one_rr_per_rrset = False , ignore_trailing = False ,
427- raise_on_truncation = False ):
443+ raise_on_truncation = False , sock = None ):
428444 """Return the response obtained after sending a query via UDP.
429445
430446 *q*, a ``dns.message.Message``, the query to send
@@ -455,30 +471,37 @@ def udp(q, where, timeout=None, port=53, source=None, source_port=0,
455471 *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
456472 the TC bit is set.
457473
474+ *sock*, a ``socket.socket``, or ``None``, the socket to use for the
475+ query. If ``None``, the default, a socket is created. Note that
476+ if a socket is provided, it must be a nonblocking datagram socket,
477+ and the *source* and *source_port* are ignored.
478+
458479 Returns a ``dns.message.Message``.
459480 """
460481
461482 wire = q .to_wire ()
462483 (af , destination , source ) = _destination_and_source (None , where , port ,
463484 source , source_port )
464- with socket_factory (af , socket .SOCK_DGRAM , 0 ) as s :
465- expiration = _compute_expiration (timeout )
466- s .setblocking (0 )
467- if source is not None :
468- s .bind (source )
469- (_ , sent_time ) = send_udp (s , wire , destination , expiration )
485+ (begin_time , expiration ) = _compute_times (timeout )
486+ with contextlib .ExitStack () as stack :
487+ if sock :
488+ s = sock
489+ else :
490+ s = stack .enter_context (_make_socket (af , socket .SOCK_DGRAM , source ))
491+ send_udp (s , wire , destination , expiration )
470492 (r , received_time ) = receive_udp (s , destination , expiration ,
471493 ignore_unexpected , one_rr_per_rrset ,
472494 q .keyring , q .mac , ignore_trailing ,
473495 raise_on_truncation )
474- r .time = received_time - sent_time
496+ r .time = received_time - begin_time
475497 if not q .is_response (r ):
476498 raise BadResponse
477499 return r
478500
479501def udp_with_fallback (q , where , timeout = None , port = 53 , source = None ,
480502 source_port = 0 , ignore_unexpected = False ,
481- one_rr_per_rrset = False , ignore_trailing = False ):
503+ one_rr_per_rrset = False , ignore_trailing = False ,
504+ udp_sock = None , tcp_sock = None ):
482505 """Return the response to the query, trying UDP first and falling back
483506 to TCP if UDP results in a truncated response.
484507
@@ -507,17 +530,28 @@ def udp_with_fallback(q, where, timeout=None, port=53, source=None,
507530 *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
508531 junk at end of the received message.
509532
533+ *udp_sock*, a ``socket.socket``, or ``None``, the socket to use for the
534+ UDP query. If ``None``, the default, a socket is created. Note that
535+ if a socket is provided, it must be a nonblocking datagram socket,
536+ and the *source* and *source_port* are ignored for the UDP query.
537+
538+ *tcp_sock*, a ``socket.socket``, or ``None``, the socket to use for the
539+ TCP query. If ``None``, the default, a socket is created. Note that
540+ if a socket is provided, it must be a nonblocking connected stream
541+ socket, and *where*, *source* and *source_port* are ignored for the TCP
542+ query.
543+
510544 Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True``
511545 if and only if TCP was used.
512546 """
513547 try :
514548 response = udp (q , where , timeout , port , source , source_port ,
515549 ignore_unexpected , one_rr_per_rrset ,
516- ignore_trailing , True )
550+ ignore_trailing , True , udp_sock )
517551 return (response , False )
518552 except dns .message .Truncated :
519553 response = tcp (q , where , timeout , port , source , source_port ,
520- one_rr_per_rrset , ignore_trailing )
554+ one_rr_per_rrset , ignore_trailing , tcp_sock )
521555 return (response , True )
522556
523557def _net_read (sock , count , expiration ):
@@ -634,12 +668,12 @@ def _connect(s, address, expiration):
634668
635669
636670def tcp (q , where , timeout = None , port = 53 , source = None , source_port = 0 ,
637- one_rr_per_rrset = False , ignore_trailing = False ):
671+ one_rr_per_rrset = False , ignore_trailing = False , sock = None ):
638672 """Return the response obtained after sending a query via TCP.
639673
640674 *q*, a ``dns.message.Message``, the query to send
641675
642- *where*, a ``str`` containing an IPv4 or IPv6 address, where
676+ *where*, a ``str`` containing an IPv4 or IPv6 address, where
643677 to send the message.
644678
645679 *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
@@ -659,19 +693,31 @@ def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
659693 *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
660694 junk at end of the received message.
661695
696+ *sock*, a ``socket.socket``, or ``None``, the socket to use for the
697+ query. If ``None``, the default, a socket is created. Note that
698+ if a socket is provided, it must be a nonblocking connected stream
699+ socket, and *where*, *source* and *source_port* are ignored.
700+
662701 Returns a ``dns.message.Message``.
663702 """
664703
665704 wire = q .to_wire ()
666- (af , destination , source ) = _destination_and_source (None , where , port ,
667- source , source_port )
668- with socket_factory (af , socket .SOCK_STREAM , 0 ) as s :
669- expiration = _compute_expiration (timeout )
670- s .setblocking (0 )
671- begin_time = time .time ()
672- if source is not None :
673- s .bind (source )
674- _connect (s , destination , expiration )
705+ (begin_time , expiration ) = _compute_times (timeout )
706+ with contextlib .ExitStack () as stack :
707+ if sock :
708+ #
709+ # Verify that the socket is connected, as if it's not connected,
710+ # it's not writable, and the polling in send_tcp() will time out or
711+ # hang forever.
712+ sock .getpeername ()
713+ s = sock
714+ else :
715+ (af , destination , source ) = _destination_and_source (None , where ,
716+ port , source ,
717+ source_port )
718+ s = stack .enter_context (_make_socket (af , socket .SOCK_STREAM ,
719+ source ))
720+ _connect (s , destination , expiration )
675721 send_tcp (s , wire , expiration )
676722 (r , received_time ) = receive_tcp (s , expiration , one_rr_per_rrset ,
677723 q .keyring , q .mac , ignore_trailing )
@@ -693,7 +739,7 @@ def _tls_handshake(s, expiration):
693739
694740
695741def tls (q , where , timeout = None , port = 853 , source = None , source_port = 0 ,
696- one_rr_per_rrset = False , ignore_trailing = False ,
742+ one_rr_per_rrset = False , ignore_trailing = False , sock = None ,
697743 ssl_context = None , server_hostname = None ):
698744 """Return the response obtained after sending a query via TLS.
699745
@@ -719,6 +765,11 @@ def tls(q, where, timeout=None, port=853, source=None, source_port=0,
719765 *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
720766 junk at end of the received message.
721767
768+ *sock*, an ``ssl.SSLSocket``, or ``None``, the socket to use for the
769+ query. If ``None``, the default, a socket is created. Note that
770+ if a socket is provided, it must be a nonblocking connected SSL stream
771+ socket, and *where*, *source*, *source_port*, and *ssl_context* are ignored.
772+
722773 *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing
723774 a TLS connection. If ``None``, the default, creates one with the default
724775 configuration.
@@ -730,21 +781,24 @@ def tls(q, where, timeout=None, port=853, source=None, source_port=0,
730781 Returns a ``dns.message.Message``.
731782 """
732783
784+ if sock :
785+ #
786+ # If a socket was provided, there's no special TLS handling needed.
787+ #
788+ return tcp (q , where , timeout , port , source , source_port ,
789+ one_rr_per_rrset , ignore_trailing , sock )
790+
733791 wire = q .to_wire ()
792+ (begin_time , expiration ) = _compute_times (timeout )
734793 (af , destination , source ) = _destination_and_source (None , where , port ,
735- source , source_port )
736- if ssl_context is None :
794+ source , source_port )
795+ if ssl_context is None and not sock :
737796 ssl_context = ssl .create_default_context ()
738797 if server_hostname is None :
739798 ssl_context .check_hostname = False
740- with ssl_context .wrap_socket (socket_factory (af , socket .SOCK_STREAM , 0 ),
741- do_handshake_on_connect = False ,
742- server_hostname = server_hostname ) as s :
743- expiration = _compute_expiration (timeout )
744- s .setblocking (0 )
745- begin_time = time .time ()
746- if source is not None :
747- s .bind (source )
799+
800+ with _make_socket (af , socket .SOCK_STREAM , source , ssl_context = ssl_context ,
801+ server_hostname = server_hostname ) as s :
748802 _connect (s , destination , expiration )
749803 _tls_handshake (s , expiration )
750804 send_tcp (s , wire , expiration )
@@ -828,11 +882,8 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
828882 if use_udp and rdtype != dns .rdatatype .IXFR :
829883 raise ValueError ('cannot do a UDP AXFR' )
830884 sock_type = socket .SOCK_DGRAM if use_udp else socket .SOCK_STREAM
831- with socket_factory (af , sock_type , 0 ) as s :
832- s .setblocking (0 )
833- if source is not None :
834- s .bind (source )
835- expiration = _compute_expiration (lifetime )
885+ with _make_socket (af , sock_type , source ) as s :
886+ (_ , expiration ) = _compute_times (lifetime )
836887 _connect (s , destination , expiration )
837888 l = len (wire )
838889 if use_udp :
@@ -854,7 +905,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
854905 tsig_ctx = None
855906 first = True
856907 while not done :
857- mexpiration = _compute_expiration (timeout )
908+ ( _ , mexpiration ) = _compute_times (timeout )
858909 if mexpiration is None or \
859910 (expiration is not None and mexpiration > expiration ):
860911 mexpiration = expiration
0 commit comments