Skip to content

Commit

Permalink
Fix AXFR-style IXFR with multiple messages.
Browse files Browse the repository at this point in the history
The inbound xfr code is conflating the expected rdtype in responses with
the incremental/replacement response style.  This causes a problem when
an AXFR-style IXFR response spans multiple messages, as resetting the
style to AXFR (replacement) also changed the expected type in the
question section of future responses to AXFR.

This change separates out the style from the expected rdtype.
  • Loading branch information
bwelling committed Oct 17, 2024
1 parent b9e75af commit a6f532c
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 11 deletions.
25 changes: 14 additions & 11 deletions dns/xfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,13 @@ def __init__(
if rdtype == dns.rdatatype.IXFR:
if serial is None:
raise ValueError("a starting serial must be supplied for IXFRs")
elif is_udp:
raise ValueError("is_udp specified for AXFR")
self.incremental = True
elif rdtype == dns.rdatatype.AXFR:
if is_udp:
raise ValueError("is_udp specified for AXFR")
self.incremental = False
else:
raise ValueError("rdtype is not IXFR or AXFR")
self.serial = serial
self.is_udp = is_udp
(_, _, self.origin) = txn_manager.origin_information()
Expand All @@ -103,8 +108,7 @@ def process_message(self, message: dns.message.Message) -> bool:
Returns `True` if the transfer is complete, and `False` otherwise.
"""
if self.txn is None:
replacement = self.rdtype == dns.rdatatype.AXFR
self.txn = self.txn_manager.writer(replacement)
self.txn = self.txn_manager.writer(not self.incremental)
rcode = message.rcode()
if rcode != dns.rcode.NOERROR:
raise TransferError(rcode)
Expand All @@ -131,7 +135,7 @@ def process_message(self, message: dns.message.Message) -> bool:
raise dns.exception.FormError("first RRset is not an SOA")
answer_index = 1
self.soa_rdataset = rdataset.copy() # pyright: ignore
if self.rdtype == dns.rdatatype.IXFR:
if self.incremental:
assert self.soa_rdataset is not None
soa = cast(dns.rdtypes.ANY.SOA.SOA, self.soa_rdataset[0])
if soa.serial == self.serial:
Expand Down Expand Up @@ -168,7 +172,7 @@ def process_message(self, message: dns.message.Message) -> bool:
#
# Every time we see an origin SOA delete_mode inverts
#
if self.rdtype == dns.rdatatype.IXFR:
if self.incremental:
self.delete_mode = not self.delete_mode
#
# If this SOA Rdataset is equal to the first we saw
Expand All @@ -177,8 +181,7 @@ def process_message(self, message: dns.message.Message) -> bool:
# part of the response.
#
if rdataset == self.soa_rdataset and (
self.rdtype == dns.rdatatype.AXFR
or (self.rdtype == dns.rdatatype.IXFR and self.delete_mode)
(not self.incremental) or self.delete_mode
):
#
# This is the final SOA
Expand All @@ -187,7 +190,7 @@ def process_message(self, message: dns.message.Message) -> bool:
if self.expecting_SOA:
# We got an empty IXFR sequence!
raise dns.exception.FormError("empty IXFR sequence")
if self.rdtype == dns.rdatatype.IXFR and self.serial != soa.serial:
if self.incremental and self.serial != soa.serial:
raise dns.exception.FormError("unexpected end of IXFR sequence")
self.txn.replace(name, rdataset)
self.txn.commit()
Expand All @@ -199,7 +202,7 @@ def process_message(self, message: dns.message.Message) -> bool:
#
self.expecting_SOA = False
soa = cast(dns.rdtypes.ANY.SOA.SOA, rdataset[0])
if self.rdtype == dns.rdatatype.IXFR:
if self.incremental:
if self.delete_mode:
# This is the start of an IXFR deletion set
if soa.serial != self.serial:
Expand All @@ -220,7 +223,7 @@ def process_message(self, message: dns.message.Message) -> bool:
# SOA RR, but saw something else, so this must be an
# AXFR response.
#
self.rdtype = dns.rdatatype.AXFR
self.incremental = False
self.expecting_SOA = False
self.delete_mode = False
self.txn.rollback()
Expand Down
37 changes: 37 additions & 0 deletions tests/test_xfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,30 @@ class Server(object):
@ 3600 IN SOA foo bar 1 2 3 4 5
"""

ixfr_axfr1 = """id 1
opcode QUERY
rcode NOERROR
flags AA
;QUESTION
example. IN IXFR
;ANSWER
@ 3600 IN SOA foo bar 1 2 3 4 5
@ 3600 IN NS ns1
@ 3600 IN NS ns2
"""
ixfr_axfr2 = """id 1
opcode QUERY
rcode NOERROR
flags AA
;QUESTION
example. IN IXFR
;ANSWER
bar.foo 300 IN MX 0 blaz.foo
ns1 3600 IN A 10.0.0.1
ns2 3600 IN A 10.0.0.2
@ 3600 IN SOA foo bar 1 2 3 4 5
"""


def test_basic_axfr():
z = dns.versioned.Zone("example.")
Expand Down Expand Up @@ -394,6 +418,19 @@ def test_ixfr_is_axfr():
assert z == ez


def test_ixfr_is_axfr_two_parts():
z = dns.versioned.Zone("example.")
m1 = dns.message.from_text(ixfr_axfr1, origin=z.origin, one_rr_per_rrset=True)
m2 = dns.message.from_text(ixfr_axfr2, origin=z.origin, one_rr_per_rrset=True)
with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=0xFFFFFFFF) as xfr:
done = xfr.process_message(m1)
assert not done
done = xfr.process_message(m2)
assert done
ez = dns.zone.from_text(base, "example.")
assert z == ez


def test_ixfr_requires_serial():
z = dns.zone.from_text(base, "example.", zone_factory=dns.versioned.Zone)
with pytest.raises(ValueError):
Expand Down

0 comments on commit a6f532c

Please sign in to comment.