Skip to content

Commit 8c63cfd

Browse files
committed
unify chaining code
1 parent 00a22ad commit 8c63cfd

File tree

4 files changed

+194
-68
lines changed

4 files changed

+194
-68
lines changed

dns/message.py

Lines changed: 107 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,18 @@ def message(self):
8080
return self.kwargs['message']
8181

8282

83+
class NotQueryResponse(dns.exception.DNSException):
84+
"""Message is not a response to a query."""
85+
86+
87+
class ChainTooLong(dns.exception.DNSException):
88+
"""The CNAME chain is too long."""
89+
90+
91+
class AnswerForNXDOMAIN(dns.exception.DNSException):
92+
"""The rcode is NXDOMAIN but an answer was found."""
93+
94+
8395
class MessageSection(dns.enum.IntEnum):
8496
"""Message sections"""
8597
QUESTION = 0
@@ -94,6 +106,7 @@ def _maximum(cls):
94106
globals().update(MessageSection.__members__)
95107

96108
DEFAULT_EDNS_PAYLOAD = 1232
109+
MAX_CHAIN = 16
97110

98111
class Message:
99112
"""A DNS message."""
@@ -232,8 +245,10 @@ def is_response(self, other):
232245
dns.opcode.from_flags(self.flags) != \
233246
dns.opcode.from_flags(other.flags):
234247
return False
235-
if dns.rcode.from_flags(other.flags, other.ednsflags) != \
236-
dns.rcode.NOERROR:
248+
if other.rcode() in {dns.rcode.FORMERR, dns.rcode.SERVFAIL,
249+
dns.rcode.NOTIMP, dns.rcode.REFUSED}:
250+
# We don't check the question section in these cases, even
251+
# though they still ought to have the same question.
237252
return True
238253
if dns.opcode.is_update(self.flags):
239254
# This is assuming the "sender doesn't include anything
@@ -696,7 +711,96 @@ def _parse_special_rr_header(self, section, count, position,
696711

697712

698713
class QueryMessage(Message):
699-
pass
714+
def resolve_chaining(self):
715+
"""Follow the CNAME chain in the response to determine the answer
716+
RRset.
717+
718+
Raises NotQueryResponse if the message is not a response.
719+
720+
Raises dns.message.ChainTooLong if the CNAME chain is too long.
721+
722+
Raises AnswerForNXDOMAIN if the rcode is NXDOMAIN but an answer was
723+
found.
724+
725+
Raises dns.exception.FormError if the question count is not 1.
726+
727+
Returns a tuple (dns.name.Name, int, rrset) where the name is the
728+
canonical name, the int is the minimized TTL, and rrset is their
729+
answer RRset, which may be ``None`` if the chain was dangling or
730+
the response is an NXDOMAIN.
731+
"""
732+
if self.flags & dns.flags.QR == 0:
733+
raise NotQueryResponse
734+
if len(self.question) != 1:
735+
raise dns.exception.FormError
736+
question = self.question[0]
737+
qname = question.name
738+
min_ttl = -1
739+
rrset = None
740+
count = 0
741+
while count < MAX_CHAIN:
742+
try:
743+
rrset = self.find_rrset(self.answer, qname, question.rdclass,
744+
question.rdtype)
745+
if min_ttl == -1 or rrset.ttl < min_ttl:
746+
min_ttl = rrset.ttl
747+
break
748+
except KeyError:
749+
if question.rdtype != dns.rdatatype.CNAME:
750+
try:
751+
crrset = self.find_rrset(self.answer, qname,
752+
question.rdclass,
753+
dns.rdatatype.CNAME)
754+
if min_ttl == -1 or crrset.ttl < min_ttl:
755+
min_ttl = crrset.ttl
756+
for rd in crrset:
757+
qname = rd.target
758+
break
759+
count += 1
760+
continue
761+
except KeyError:
762+
# Exit the chaining loop
763+
break
764+
if count >= MAX_CHAIN:
765+
raise ChainTooLong
766+
if self.rcode() == dns.rcode.NXDOMAIN and rrset is not None:
767+
raise AnswerForNXDOMAIN
768+
if rrset is None:
769+
# Further minimize the TTL with NCACHE.
770+
auname = qname
771+
while True:
772+
# Look for an SOA RR whose owner name is a superdomain
773+
# of qname.
774+
try:
775+
srrset = self.find_rrset(self.authority, auname,
776+
question.rdclass,
777+
dns.rdatatype.SOA)
778+
if min_ttl == -1 or srrset.ttl < min_ttl:
779+
min_ttl = srrset.ttl
780+
if srrset[0].minimum < min_ttl:
781+
min_ttl = srrset[0].minimum
782+
break
783+
except KeyError:
784+
try:
785+
auname = auname.parent()
786+
except dns.name.NoParent:
787+
break
788+
return (qname, min_ttl, rrset)
789+
790+
def canonical_name(self):
791+
"""Return the canonical name of the first name in the question
792+
section.
793+
794+
Raises dns.message.NotQueryResponse if the message is not a response.
795+
796+
Raises dns.message.ChainTooLong if the CNAME chain is too long.
797+
798+
Raises AnswerForNXDOMAIN if the rcode is NXDOMAIN but an answer was
799+
found.
800+
801+
Raises dns.exception.FormError if the question count is not 1.
802+
"""
803+
return self.resolve_chaining()[0]
700804

701805

702806
def _maybe_import_update():

dns/resolver.py

Lines changed: 27 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -78,20 +78,16 @@ def canonical_name(self):
7878
"""Return the unresolved canonical name."""
7979
if 'qnames' not in self.kwargs:
8080
raise TypeError("parametrized exception required")
81-
IN = dns.rdataclass.IN
82-
CNAME = dns.rdatatype.CNAME
83-
cname = None
84-
# This code assumes the CNAME chain is in proper order, though
85-
# the Answer code does not make a similar assumption when
86-
# chaining.
8781
for qname in self.kwargs['qnames']:
8882
response = self.kwargs['responses'][qname]
89-
for answer in response.answer:
90-
if answer.rdtype != CNAME or answer.rdclass != IN:
91-
continue
92-
cname = answer[0].target
93-
if cname is not None:
94-
return cname
83+
try:
84+
cname = response.canonical_name()
85+
if cname != qname:
86+
return cname
87+
except Exception:
88+
# We can just eat this exception as it means there was
89+
# something wrong with the response.
90+
pass
9591
return self.kwargs['qnames'][0]
9692

9793
def __add__(self, e_nx):
@@ -209,50 +205,7 @@ def __init__(self, qname, rdtype, rdclass, response, nameserver=None,
209205
self.response = response
210206
self.nameserver = nameserver
211207
self.port = port
212-
min_ttl = -1
213-
rrset = None
214-
for count in range(0, 15):
215-
try:
216-
rrset = response.find_rrset(response.answer, qname,
217-
rdclass, rdtype)
218-
if min_ttl == -1 or rrset.ttl < min_ttl:
219-
min_ttl = rrset.ttl
220-
break
221-
except KeyError:
222-
if rdtype != dns.rdatatype.CNAME:
223-
try:
224-
crrset = response.find_rrset(response.answer,
225-
qname,
226-
rdclass,
227-
dns.rdatatype.CNAME)
228-
if min_ttl == -1 or crrset.ttl < min_ttl:
229-
min_ttl = crrset.ttl
230-
for rd in crrset:
231-
qname = rd.target
232-
break
233-
continue
234-
except KeyError:
235-
# Exit the chaining loop
236-
break
237-
self.canonical_name = qname
238-
self.rrset = rrset
239-
if rrset is None:
240-
while 1:
241-
# Look for a SOA RR whose owner name is a superdomain
242-
# of qname.
243-
try:
244-
srrset = response.find_rrset(response.authority, qname,
245-
rdclass, dns.rdatatype.SOA)
246-
if min_ttl == -1 or srrset.ttl < min_ttl:
247-
min_ttl = srrset.ttl
248-
if srrset[0].minimum < min_ttl:
249-
min_ttl = srrset[0].minimum
250-
break
251-
except KeyError:
252-
try:
253-
qname = qname.parent()
254-
except dns.name.NoParent:
255-
break
208+
(self.canonical_name, min_ttl, self.rrset) = response.resolve_chaining()
256209
self.expiration = time.time() + min_ttl
257210

258211
def __getattr__(self, attr): # pragma: no cover
@@ -698,25 +651,36 @@ def query_result(self, response, ex):
698651
assert response is not None
699652
rcode = response.rcode()
700653
if rcode == dns.rcode.NOERROR:
701-
answer = Answer(self.qname, self.rdtype, self.rdclass, response,
702-
self.nameserver, self.port)
654+
try:
655+
answer = Answer(self.qname, self.rdtype, self.rdclass, response,
656+
self.nameserver, self.port)
657+
except Exception:
658+
# The nameserver is no good, take it out of the mix.
659+
self.nameservers.remove(self.nameserver)
660+
return (None, False)
703661
if self.resolver.cache:
704662
self.resolver.cache.put((self.qname, self.rdtype,
705663
self.rdclass), answer)
706664
if answer.rrset is None and self.raise_on_no_answer:
707665
raise NoAnswer(response=answer.response)
708666
return (answer, True)
709667
elif rcode == dns.rcode.NXDOMAIN:
710-
self.nxdomain_responses[self.qname] = response
711-
# Make next_nameserver() return None, so caller breaks its
712-
# inner loop and calls next_request().
713-
if self.resolver.cache:
668+
# Further validate the response by making an Answer, even
669+
# if we aren't going to cache it.
670+
try:
714671
answer = Answer(self.qname, dns.rdatatype.ANY,
715672
dns.rdataclass.IN, response)
673+
except Exception:
674+
# The nameserver is no good, take it out of the mix.
675+
self.nameservers.remove(self.nameserver)
676+
return (None, False)
677+
self.nxdomain_responses[self.qname] = response
678+
if self.resolver.cache:
716679
self.resolver.cache.put((self.qname,
717680
dns.rdatatype.ANY,
718681
self.rdclass), answer)
719-
682+
# Make next_nameserver() return None, so caller breaks its
683+
# inner loop and calls next_request().
720684
return (None, True)
721685
elif rcode == dns.rcode.YXDOMAIN:
722686
yex = YXDOMAIN()

tests/test_resolution.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,22 @@ def make_negative_response(self, q, nxdomain=False):
8383
r.set_rcode(dns.rcode.NXDOMAIN)
8484
return r
8585

86+
def make_long_chain_response(self, q, count):
87+
r = dns.message.make_response(q)
88+
name = self.qname
89+
for i in range(count):
90+
rrs = r.get_rrset(r.answer, name, dns.rdataclass.IN,
91+
dns.rdatatype.CNAME, create=True)
92+
tname = dns.name.from_text(f'target{i}.')
93+
rrs.add(dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.CNAME,
94+
str(tname)), 300)
95+
name = tname
96+
rrs = r.get_rrset(r.answer, name, dns.rdataclass.IN,
97+
dns.rdatatype.A, create=True)
98+
rrs.add(dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A,
99+
'10.0.0.1'), 300)
100+
return r
101+
86102
def test_next_request_cache_hit(self):
87103
self.resolver.cache = dns.resolver.Cache()
88104
q = dns.message.make_query(self.qname, dns.rdatatype.A)
@@ -353,6 +369,36 @@ def test_query_result_nxdomain(self):
353369
self.assertTrue(answer is None)
354370
self.assertTrue(done)
355371

372+
def test_query_result_nxdomain_but_has_answer(self):
373+
q = dns.message.make_query(self.qname, dns.rdatatype.A)
374+
r = self.make_address_response(q)
375+
r.set_rcode(dns.rcode.NXDOMAIN)
376+
(_, _) = self.resn.next_request()
377+
(nameserver, _, _, _) = self.resn.next_nameserver()
378+
(answer, done) = self.resn.query_result(r, None)
379+
self.assertIsNone(answer)
380+
self.assertFalse(done)
381+
self.assertTrue(nameserver not in self.resn.nameservers)
382+
383+
def test_query_result_chain_not_too_long(self):
384+
q = dns.message.make_query(self.qname, dns.rdatatype.A)
385+
r = self.make_long_chain_response(q, 15)
386+
(_, _) = self.resn.next_request()
387+
(_, _, _, _) = self.resn.next_nameserver()
388+
(answer, done) = self.resn.query_result(r, None)
389+
self.assertIsNotNone(answer)
390+
self.assertTrue(done)
391+
392+
def test_query_result_chain_too_long(self):
393+
q = dns.message.make_query(self.qname, dns.rdatatype.A)
394+
r = self.make_long_chain_response(q, 16)
395+
(_, _) = self.resn.next_request()
396+
(nameserver, _, _, _) = self.resn.next_nameserver()
397+
(answer, done) = self.resn.query_result(r, None)
398+
self.assertIsNone(answer)
399+
self.assertFalse(done)
400+
self.assertTrue(nameserver not in self.resn.nameservers)
401+
356402
def test_query_result_nxdomain_cached(self):
357403
self.resolver.cache = dns.resolver.Cache()
358404
q = dns.message.make_query(self.qname, dns.rdatatype.A)

tests/test_resolver.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,18 @@ class Server(object):
104104
;ADDITIONAL
105105
"""
106106

107+
message_text_mx = """id 1234
108+
opcode QUERY
109+
rcode NOERROR
110+
flags QR AA RD
111+
;QUESTION
112+
example. IN MX
113+
;ANSWER
114+
example. 1 IN A 10.0.0.1
115+
;AUTHORITY
116+
;ADDITIONAL
117+
"""
118+
107119
dangling_cname_0_message_text = """id 10000
108120
opcode QUERY
109121
rcode NOERROR
@@ -222,7 +234,7 @@ def testCacheCleaning(self):
222234

223235
def testIndexErrorOnEmptyRRsetAccess(self):
224236
def bad():
225-
message = dns.message.from_text(message_text)
237+
message = dns.message.from_text(message_text_mx)
226238
name = dns.name.from_text('example.')
227239
answer = dns.resolver.Answer(name, dns.rdatatype.MX,
228240
dns.rdataclass.IN, message,
@@ -232,7 +244,7 @@ def bad():
232244

233245
def testIndexErrorOnEmptyRRsetDelete(self):
234246
def bad():
235-
message = dns.message.from_text(message_text)
247+
message = dns.message.from_text(message_text_mx)
236248
name = dns.name.from_text('example.')
237249
answer = dns.resolver.Answer(name, dns.rdatatype.MX,
238250
dns.rdataclass.IN, message,

0 commit comments

Comments
 (0)