From a6f532c1a7fcd5b77fc7ae2cafca53ec32035221 Mon Sep 17 00:00:00 2001 From: Brian Wellington Date: Thu, 17 Oct 2024 16:08:31 -0700 Subject: [PATCH] Fix AXFR-style IXFR with multiple messages. The inbound xfr code is conflating the expected rdtype in responses with the incremental/replacement response style. This causes a problem when an AXFR-style IXFR response spans multiple messages, as resetting the style to AXFR (replacement) also changed the expected type in the question section of future responses to AXFR. This change separates out the style from the expected rdtype. --- dns/xfr.py | 25 ++++++++++++++----------- tests/test_xfr.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/dns/xfr.py b/dns/xfr.py index f1b875934..d17dd4848 100644 --- a/dns/xfr.py +++ b/dns/xfr.py @@ -83,8 +83,13 @@ def __init__( if rdtype == dns.rdatatype.IXFR: if serial is None: raise ValueError("a starting serial must be supplied for IXFRs") - elif is_udp: - raise ValueError("is_udp specified for AXFR") + self.incremental = True + elif rdtype == dns.rdatatype.AXFR: + if is_udp: + raise ValueError("is_udp specified for AXFR") + self.incremental = False + else: + raise ValueError("rdtype is not IXFR or AXFR") self.serial = serial self.is_udp = is_udp (_, _, self.origin) = txn_manager.origin_information() @@ -103,8 +108,7 @@ def process_message(self, message: dns.message.Message) -> bool: Returns `True` if the transfer is complete, and `False` otherwise. """ if self.txn is None: - replacement = self.rdtype == dns.rdatatype.AXFR - self.txn = self.txn_manager.writer(replacement) + self.txn = self.txn_manager.writer(not self.incremental) rcode = message.rcode() if rcode != dns.rcode.NOERROR: raise TransferError(rcode) @@ -131,7 +135,7 @@ def process_message(self, message: dns.message.Message) -> bool: raise dns.exception.FormError("first RRset is not an SOA") answer_index = 1 self.soa_rdataset = rdataset.copy() # pyright: ignore - if self.rdtype == dns.rdatatype.IXFR: + if self.incremental: assert self.soa_rdataset is not None soa = cast(dns.rdtypes.ANY.SOA.SOA, self.soa_rdataset[0]) if soa.serial == self.serial: @@ -168,7 +172,7 @@ def process_message(self, message: dns.message.Message) -> bool: # # Every time we see an origin SOA delete_mode inverts # - if self.rdtype == dns.rdatatype.IXFR: + if self.incremental: self.delete_mode = not self.delete_mode # # If this SOA Rdataset is equal to the first we saw @@ -177,8 +181,7 @@ def process_message(self, message: dns.message.Message) -> bool: # part of the response. # if rdataset == self.soa_rdataset and ( - self.rdtype == dns.rdatatype.AXFR - or (self.rdtype == dns.rdatatype.IXFR and self.delete_mode) + (not self.incremental) or self.delete_mode ): # # This is the final SOA @@ -187,7 +190,7 @@ def process_message(self, message: dns.message.Message) -> bool: if self.expecting_SOA: # We got an empty IXFR sequence! raise dns.exception.FormError("empty IXFR sequence") - if self.rdtype == dns.rdatatype.IXFR and self.serial != soa.serial: + if self.incremental and self.serial != soa.serial: raise dns.exception.FormError("unexpected end of IXFR sequence") self.txn.replace(name, rdataset) self.txn.commit() @@ -199,7 +202,7 @@ def process_message(self, message: dns.message.Message) -> bool: # self.expecting_SOA = False soa = cast(dns.rdtypes.ANY.SOA.SOA, rdataset[0]) - if self.rdtype == dns.rdatatype.IXFR: + if self.incremental: if self.delete_mode: # This is the start of an IXFR deletion set if soa.serial != self.serial: @@ -220,7 +223,7 @@ def process_message(self, message: dns.message.Message) -> bool: # SOA RR, but saw something else, so this must be an # AXFR response. # - self.rdtype = dns.rdatatype.AXFR + self.incremental = False self.expecting_SOA = False self.delete_mode = False self.txn.rollback() diff --git a/tests/test_xfr.py b/tests/test_xfr.py index 458cdf9b7..257397c26 100644 --- a/tests/test_xfr.py +++ b/tests/test_xfr.py @@ -263,6 +263,30 @@ class Server(object): @ 3600 IN SOA foo bar 1 2 3 4 5 """ +ixfr_axfr1 = """id 1 +opcode QUERY +rcode NOERROR +flags AA +;QUESTION +example. IN IXFR +;ANSWER +@ 3600 IN SOA foo bar 1 2 3 4 5 +@ 3600 IN NS ns1 +@ 3600 IN NS ns2 +""" +ixfr_axfr2 = """id 1 +opcode QUERY +rcode NOERROR +flags AA +;QUESTION +example. IN IXFR +;ANSWER +bar.foo 300 IN MX 0 blaz.foo +ns1 3600 IN A 10.0.0.1 +ns2 3600 IN A 10.0.0.2 +@ 3600 IN SOA foo bar 1 2 3 4 5 +""" + def test_basic_axfr(): z = dns.versioned.Zone("example.") @@ -394,6 +418,19 @@ def test_ixfr_is_axfr(): assert z == ez +def test_ixfr_is_axfr_two_parts(): + z = dns.versioned.Zone("example.") + m1 = dns.message.from_text(ixfr_axfr1, origin=z.origin, one_rr_per_rrset=True) + m2 = dns.message.from_text(ixfr_axfr2, origin=z.origin, one_rr_per_rrset=True) + with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=0xFFFFFFFF) as xfr: + done = xfr.process_message(m1) + assert not done + done = xfr.process_message(m2) + assert done + ez = dns.zone.from_text(base, "example.") + assert z == ez + + def test_ixfr_requires_serial(): z = dns.zone.from_text(base, "example.", zone_factory=dns.versioned.Zone) with pytest.raises(ValueError):