Skip to content

Commit 52f46fc

Browse files
committed
Adds sock parameters to query methods.
Allow passing a socket into dns.query.{udp,tcp,tls,udp_with_fallback}, and add tests for this.
1 parent c19e671 commit 52f46fc

File tree

2 files changed

+156
-42
lines changed

2 files changed

+156
-42
lines changed

dns/query.py

Lines changed: 93 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,12 @@ class NoDOH(dns.exception.DNSException):
8787
available."""
8888

8989

90-
def _compute_expiration(timeout):
90+
def _compute_times(timeout):
91+
now = time.time()
9192
if timeout is None:
92-
return None
93+
return (now, None)
9394
else:
94-
return time.time() + timeout
95+
return (now, now + timeout)
9596

9697
# This module can use either poll() or select() as the "polling backend".
9798
#
@@ -230,6 +231,21 @@ def _destination_and_source(af, where, port, source, source_port,
230231
destination = None
231232
return (af, destination, source)
232233

234+
def _make_socket(af, type, source, ssl_context=None, server_hostname=None):
235+
s = socket_factory(af, type)
236+
try:
237+
s.setblocking(False)
238+
if source is not None:
239+
s.bind(source)
240+
if ssl_context:
241+
return ssl_context.wrap_socket(s, do_handshake_on_connect=False,
242+
server_hostname=server_hostname)
243+
else:
244+
return s
245+
except Exception:
246+
s.close()
247+
raise
248+
233249
def https(q, where, timeout=None, port=443, source=None, source_port=0,
234250
one_rr_per_rrset=False, ignore_trailing=False,
235251
session=None, path='/dns-query', post=True,
@@ -424,7 +440,7 @@ def receive_udp(sock, destination, expiration=None,
424440

425441
def udp(q, where, timeout=None, port=53, source=None, source_port=0,
426442
ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,
427-
raise_on_truncation=False):
443+
raise_on_truncation=False, sock=None):
428444
"""Return the response obtained after sending a query via UDP.
429445
430446
*q*, a ``dns.message.Message``, the query to send
@@ -455,30 +471,37 @@ def udp(q, where, timeout=None, port=53, source=None, source_port=0,
455471
*raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
456472
the TC bit is set.
457473
474+
*sock*, a ``socket.socket``, or ``None``, the socket to use for the
475+
query. If ``None``, the default, a socket is created. Note that
476+
if a socket is provided, it must be a nonblocking datagram socket,
477+
and the *source* and *source_port* are ignored.
478+
458479
Returns a ``dns.message.Message``.
459480
"""
460481

461482
wire = q.to_wire()
462483
(af, destination, source) = _destination_and_source(None, where, port,
463484
source, source_port)
464-
with socket_factory(af, socket.SOCK_DGRAM, 0) as s:
465-
expiration = _compute_expiration(timeout)
466-
s.setblocking(0)
467-
if source is not None:
468-
s.bind(source)
469-
(_, sent_time) = send_udp(s, wire, destination, expiration)
485+
(begin_time, expiration) = _compute_times(timeout)
486+
with contextlib.ExitStack() as stack:
487+
if sock:
488+
s = sock
489+
else:
490+
s = stack.enter_context(_make_socket(af, socket.SOCK_DGRAM, source))
491+
send_udp(s, wire, destination, expiration)
470492
(r, received_time) = receive_udp(s, destination, expiration,
471493
ignore_unexpected, one_rr_per_rrset,
472494
q.keyring, q.mac, ignore_trailing,
473495
raise_on_truncation)
474-
r.time = received_time - sent_time
496+
r.time = received_time - begin_time
475497
if not q.is_response(r):
476498
raise BadResponse
477499
return r
478500

479501
def udp_with_fallback(q, where, timeout=None, port=53, source=None,
480502
source_port=0, ignore_unexpected=False,
481-
one_rr_per_rrset=False, ignore_trailing=False):
503+
one_rr_per_rrset=False, ignore_trailing=False,
504+
udp_sock=None, tcp_sock=None):
482505
"""Return the response to the query, trying UDP first and falling back
483506
to TCP if UDP results in a truncated response.
484507
@@ -507,17 +530,28 @@ def udp_with_fallback(q, where, timeout=None, port=53, source=None,
507530
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
508531
junk at end of the received message.
509532
533+
*udp_sock*, a ``socket.socket``, or ``None``, the socket to use for the
534+
UDP query. If ``None``, the default, a socket is created. Note that
535+
if a socket is provided, it must be a nonblocking datagram socket,
536+
and the *source* and *source_port* are ignored for the UDP query.
537+
538+
*tcp_sock*, a ``socket.socket``, or ``None``, the socket to use for the
539+
TCP query. If ``None``, the default, a socket is created. Note that
540+
if a socket is provided, it must be a nonblocking connected stream
541+
socket, and *where*, *source* and *source_port* are ignored for the TCP
542+
query.
543+
510544
Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True``
511545
if and only if TCP was used.
512546
"""
513547
try:
514548
response = udp(q, where, timeout, port, source, source_port,
515549
ignore_unexpected, one_rr_per_rrset,
516-
ignore_trailing, True)
550+
ignore_trailing, True, udp_sock)
517551
return (response, False)
518552
except dns.message.Truncated:
519553
response = tcp(q, where, timeout, port, source, source_port,
520-
one_rr_per_rrset, ignore_trailing)
554+
one_rr_per_rrset, ignore_trailing, tcp_sock)
521555
return (response, True)
522556

523557
def _net_read(sock, count, expiration):
@@ -634,12 +668,12 @@ def _connect(s, address, expiration):
634668

635669

636670
def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
637-
one_rr_per_rrset=False, ignore_trailing=False):
671+
one_rr_per_rrset=False, ignore_trailing=False, sock=None):
638672
"""Return the response obtained after sending a query via TCP.
639673
640674
*q*, a ``dns.message.Message``, the query to send
641675
642-
*where*, a ``str`` containing an IPv4 or IPv6 address, where
676+
*where*, a ``str`` containing an IPv4 or IPv6 address, where
643677
to send the message.
644678
645679
*timeout*, a ``float`` or ``None``, the number of seconds to wait before the
@@ -659,19 +693,31 @@ def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
659693
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
660694
junk at end of the received message.
661695
696+
*sock*, a ``socket.socket``, or ``None``, the socket to use for the
697+
query. If ``None``, the default, a socket is created. Note that
698+
if a socket is provided, it must be a nonblocking connected stream
699+
socket, and *where*, *source* and *source_port* are ignored.
700+
662701
Returns a ``dns.message.Message``.
663702
"""
664703

665704
wire = q.to_wire()
666-
(af, destination, source) = _destination_and_source(None, where, port,
667-
source, source_port)
668-
with socket_factory(af, socket.SOCK_STREAM, 0) as s:
669-
expiration = _compute_expiration(timeout)
670-
s.setblocking(0)
671-
begin_time = time.time()
672-
if source is not None:
673-
s.bind(source)
674-
_connect(s, destination, expiration)
705+
(begin_time, expiration) = _compute_times(timeout)
706+
with contextlib.ExitStack() as stack:
707+
if sock:
708+
#
709+
# Verify that the socket is connected, as if it's not connected,
710+
# it's not writable, and the polling in send_tcp() will time out or
711+
# hang forever.
712+
sock.getpeername()
713+
s = sock
714+
else:
715+
(af, destination, source) = _destination_and_source(None, where,
716+
port, source,
717+
source_port)
718+
s = stack.enter_context(_make_socket(af, socket.SOCK_STREAM,
719+
source))
720+
_connect(s, destination, expiration)
675721
send_tcp(s, wire, expiration)
676722
(r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset,
677723
q.keyring, q.mac, ignore_trailing)
@@ -693,7 +739,7 @@ def _tls_handshake(s, expiration):
693739

694740

695741
def tls(q, where, timeout=None, port=853, source=None, source_port=0,
696-
one_rr_per_rrset=False, ignore_trailing=False,
742+
one_rr_per_rrset=False, ignore_trailing=False, sock=None,
697743
ssl_context=None, server_hostname=None):
698744
"""Return the response obtained after sending a query via TLS.
699745
@@ -719,6 +765,11 @@ def tls(q, where, timeout=None, port=853, source=None, source_port=0,
719765
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
720766
junk at end of the received message.
721767
768+
*sock*, an ``ssl.SSLSocket``, or ``None``, the socket to use for the
769+
query. If ``None``, the default, a socket is created. Note that
770+
if a socket is provided, it must be a nonblocking connected SSL stream
771+
socket, and *where*, *source*, *source_port*, and *ssl_context* are ignored.
772+
722773
*ssl_context*, an ``ssl.SSLContext``, the context to use when establishing
723774
a TLS connection. If ``None``, the default, creates one with the default
724775
configuration.
@@ -730,21 +781,24 @@ def tls(q, where, timeout=None, port=853, source=None, source_port=0,
730781
Returns a ``dns.message.Message``.
731782
"""
732783

784+
if sock:
785+
#
786+
# If a socket was provided, there's no special TLS handling needed.
787+
#
788+
return tcp(q, where, timeout, port, source, source_port,
789+
one_rr_per_rrset, ignore_trailing, sock)
790+
733791
wire = q.to_wire()
792+
(begin_time, expiration) = _compute_times(timeout)
734793
(af, destination, source) = _destination_and_source(None, where, port,
735-
source, source_port)
736-
if ssl_context is None:
794+
source, source_port)
795+
if ssl_context is None and not sock:
737796
ssl_context = ssl.create_default_context()
738797
if server_hostname is None:
739798
ssl_context.check_hostname = False
740-
with ssl_context.wrap_socket(socket_factory(af, socket.SOCK_STREAM, 0),
741-
do_handshake_on_connect=False,
742-
server_hostname=server_hostname) as s:
743-
expiration = _compute_expiration(timeout)
744-
s.setblocking(0)
745-
begin_time = time.time()
746-
if source is not None:
747-
s.bind(source)
799+
800+
with _make_socket(af, socket.SOCK_STREAM, source, ssl_context=ssl_context,
801+
server_hostname=server_hostname) as s:
748802
_connect(s, destination, expiration)
749803
_tls_handshake(s, expiration)
750804
send_tcp(s, wire, expiration)
@@ -828,11 +882,8 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
828882
if use_udp and rdtype != dns.rdatatype.IXFR:
829883
raise ValueError('cannot do a UDP AXFR')
830884
sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM
831-
with socket_factory(af, sock_type, 0) as s:
832-
s.setblocking(0)
833-
if source is not None:
834-
s.bind(source)
835-
expiration = _compute_expiration(lifetime)
885+
with _make_socket(af, sock_type, source) as s:
886+
(_, expiration) = _compute_times(lifetime)
836887
_connect(s, destination, expiration)
837888
l = len(wire)
838889
if use_udp:
@@ -854,7 +905,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
854905
tsig_ctx = None
855906
first = True
856907
while not done:
857-
mexpiration = _compute_expiration(timeout)
908+
(_, mexpiration) = _compute_times(timeout)
858909
if mexpiration is None or \
859910
(expiration is not None and mexpiration > expiration):
860911
mexpiration = expiration

tests/test_query.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@
1818
import socket
1919
import unittest
2020

21+
try:
22+
import ssl
23+
have_ssl = True
24+
except Exception:
25+
have_ssl = False
26+
2127
import dns.message
2228
import dns.name
2329
import dns.rdataclass
@@ -46,6 +52,19 @@ def testQueryUDP(self):
4652
self.assertTrue('8.8.8.8' in seen)
4753
self.assertTrue('8.8.4.4' in seen)
4854

55+
def testQueryUDPWithSocket(self):
56+
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
57+
s.setblocking(0)
58+
qname = dns.name.from_text('dns.google.')
59+
q = dns.message.make_query(qname, dns.rdatatype.A)
60+
response = dns.query.udp(q, '8.8.8.8', sock=s)
61+
rrs = response.get_rrset(response.answer, qname,
62+
dns.rdataclass.IN, dns.rdatatype.A)
63+
self.assertTrue(rrs is not None)
64+
seen = set([rdata.address for rdata in rrs])
65+
self.assertTrue('8.8.8.8' in seen)
66+
self.assertTrue('8.8.4.4' in seen)
67+
4968
def testQueryTCP(self):
5069
qname = dns.name.from_text('dns.google.')
5170
q = dns.message.make_query(qname, dns.rdatatype.A)
@@ -57,6 +76,20 @@ def testQueryTCP(self):
5776
self.assertTrue('8.8.8.8' in seen)
5877
self.assertTrue('8.8.4.4' in seen)
5978

79+
def testQueryTCPWithSocket(self):
80+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
81+
s.connect(('8.8.8.8', 53))
82+
s.setblocking(0)
83+
qname = dns.name.from_text('dns.google.')
84+
q = dns.message.make_query(qname, dns.rdatatype.A)
85+
response = dns.query.tcp(q, None, sock=s)
86+
rrs = response.get_rrset(response.answer, qname,
87+
dns.rdataclass.IN, dns.rdatatype.A)
88+
self.assertTrue(rrs is not None)
89+
seen = set([rdata.address for rdata in rrs])
90+
self.assertTrue('8.8.8.8' in seen)
91+
self.assertTrue('8.8.4.4' in seen)
92+
6093
def testQueryTLS(self):
6194
qname = dns.name.from_text('dns.google.')
6295
q = dns.message.make_query(qname, dns.rdatatype.A)
@@ -68,12 +101,42 @@ def testQueryTLS(self):
68101
self.assertTrue('8.8.8.8' in seen)
69102
self.assertTrue('8.8.4.4' in seen)
70103

104+
@unittest.skipUnless(have_ssl, "No SSL support")
105+
def testQueryTLSWithSocket(self):
106+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
107+
s.connect(('8.8.8.8', 853))
108+
ctx = ssl.create_default_context()
109+
s = ctx.wrap_socket(s, server_hostname='dns.google')
110+
s.setblocking(0)
111+
qname = dns.name.from_text('dns.google.')
112+
q = dns.message.make_query(qname, dns.rdatatype.A)
113+
response = dns.query.tls(q, None, sock=s)
114+
rrs = response.get_rrset(response.answer, qname,
115+
dns.rdataclass.IN, dns.rdatatype.A)
116+
self.assertTrue(rrs is not None)
117+
seen = set([rdata.address for rdata in rrs])
118+
self.assertTrue('8.8.8.8' in seen)
119+
self.assertTrue('8.8.4.4' in seen)
120+
71121
def testQueryUDPFallback(self):
72122
qname = dns.name.from_text('.')
73123
q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
74124
(_, tcp) = dns.query.udp_with_fallback(q, '8.8.8.8')
75125
self.assertTrue(tcp)
76126

127+
def testQueryUDPFallbackWithSocket(self):
128+
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as udp_s:
129+
udp_s.setblocking(0)
130+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp_s:
131+
tcp_s.connect(('8.8.8.8', 53))
132+
tcp_s.setblocking(0)
133+
qname = dns.name.from_text('.')
134+
q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
135+
(_, tcp) = dns.query.udp_with_fallback(q, '8.8.8.8',
136+
udp_sock=udp_s,
137+
tcp_sock=tcp_s)
138+
self.assertTrue(tcp)
139+
77140
def testQueryUDPFallbackNoFallback(self):
78141
qname = dns.name.from_text('dns.google.')
79142
q = dns.message.make_query(qname, dns.rdatatype.A)

0 commit comments

Comments
 (0)